diff --git a/AGENTS.md b/AGENTS.md index 73b4d54b10ff..1ac4aafaf11d 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -72,6 +72,34 @@ _Important:_ For OpenAI Codex agents (most likely you!), your environment does n - `weave/` - Core implementation - `weave/` - Python package implementation - `weave/trace_server` - Backend server implementation + - `weave/trace_server_bindings` - Client-side remote server binding (backed by `weave_server_sdk`) + +### Client/Server Layering (lint-enforced) + +Client code must not use the trace server's interface/model modules. This is +enforced incrementally by ruff `TID251` (`flake8-tidy-imports.banned-api` in +`pyproject.toml`); the server package itself, `tests/`, `scripts/`, and +`trace_server_mock/` are exempt. + +- API request/response types come from `weave_server_sdk.models` — generated + from the trace server's OpenAPI spec (the source of truth; the generator + lives in wandb/core `services/weave-trace/tools/codegen`). The package is + TEMPORARILY resolved from test PyPI via `[tool.uv.sources]`. +- Surface the published SDK cannot express yet (endpoints excluded from the + OpenAPI spec, 0.0.1 codegen gaps) lives in + `weave/trace_server_bindings/models.py` as documented gap models — prefer + deleting these when a regenerated SDK covers them. Never fall back to the + server's tsi models in client code. +- Client code outside the bindings package must not import + `weave.trace_server_bindings.models` either (also TID251-banned): API types + come from `weave_server_sdk` directly. Where a gap model is genuinely + unavoidable today, the import carries a spot-level + `# noqa: TID251` with a reason, so the exemption list burns down as the + SDK regenerates. +- In tests, construct `weave_server_sdk.models` (or plain JSON dicts), not + tsi models. The in-process test servers parse incoming requests with their + own types via `tests/trace_server/conftest_lib/request_coercion.py` — the + in-process equivalent of the HTTP seam. ## Generated Files — Do Not Hand-Edit diff --git a/pyproject.toml b/pyproject.toml index 631721bde608..c7a209fe2ada 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -330,13 +330,20 @@ exclude = ["rules"] # list grows as the migration proceeds. The server package itself, tests, and # scripts are exempt via per-file-ignores below. "weave.trace_server.http_service_interface".msg = "Client code must not use the server's HTTP body models. Use weave_server_sdk.models (generated from the OpenAPI spec) instead." +"weave.trace_server.trace_server_interface".msg = "Client code must not use the server's request/response models. Use weave_server_sdk.models (generated from the OpenAPI spec) or the gap models in weave.trace_server_bindings.models instead." +"weave.trace_server.interface.query".msg = "Client code must not use the server's query AST models. Use weave_server_sdk.models instead." +"weave.trace_server.common_interface".msg = "Client code must not use the server's common interface models. Use weave_server_sdk.models instead." +"weave.trace_server.service_interface".msg = "Client code must not use the server's service interface. Use weave_server_sdk.models and weave.trace_server_bindings instead." +"weave.trace_server_bindings.models".msg = "Client code must import API types from weave_server_sdk directly. The bindings' gap models exist only because weave-server-sdk 0.0.1 cannot express that surface; where one is unavoidable, use a spot-level noqa with a reason so the exemption burns down when the SDK regenerates." [tool.ruff.lint.per-file-ignores] "!/weave/trace/**/*.py" = ["T201"] "!/tests/**/*.py" = ["RUF059"] -# The trace server may import its own modules; tests exercise the server -# directly; scripts and the mock server are operational server tooling. +# The trace server may import its own modules (and the bindings package its +# own models); tests exercise the server directly; scripts and the mock +# server are operational server tooling. "weave/trace_server/**/*.py" = ["TID251"] +"weave/trace_server_bindings/**/*.py" = ["TID251"] "scripts/**/*.py" = ["TID251"] "trace_server_mock/**/*.py" = ["TID251"] # Tests intentionally use async functions without await, compare known float diff --git a/tests/conftest.py b/tests/conftest.py index 2e1d0115cfbf..aa97b0a633e5 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -16,6 +16,7 @@ import weave from tests.trace.util import DummyTestException from tests.trace_server.conftest import TEST_ENTITY, get_trace_server_flag +from tests.trace_server.conftest_lib.request_coercion import RequestCoercingTraceServer from weave.trace import weave_client, weave_init from weave.trace.context import weave_client_context from weave.trace.context.call_context import set_call_stack @@ -460,7 +461,9 @@ def create_client( elif trace_server_flag == "http": server = RemoteHTTPTraceServer(trace_server_flag) else: - server = trace_server + # The client sends weave_server_sdk models; the in-process servers + # parse them with their own request types, like the HTTP seam would. + server = RequestCoercingTraceServer(trace_server) # Removing this as it lead to passing tests that were not passing in prod! # Keeping off for now until it is the default behavior. diff --git a/tests/flow/test_evaluation_imperative.py b/tests/flow/test_evaluation_imperative.py index 105b3505261c..30eba820f932 100644 --- a/tests/flow/test_evaluation_imperative.py +++ b/tests/flow/test_evaluation_imperative.py @@ -6,13 +6,13 @@ from typing import TypedDict import pytest +from weave_server_sdk.models import ObjectVersionFilter import weave from weave.evaluation.eval_imperative import EvaluationLogger, Model, Scorer from weave.integrations.integration_utilities import op_name_from_call from weave.trace.context import call_context from weave.trace.serialization.serialize import to_json -from weave.trace_server.trace_server_interface import ObjectVersionFilter class ExampleRow(TypedDict): diff --git a/tests/flow/test_monitor.py b/tests/flow/test_monitor.py index 0ca1a1429036..9ca8c93525d4 100644 --- a/tests/flow/test_monitor.py +++ b/tests/flow/test_monitor.py @@ -1,10 +1,10 @@ import pytest from pydantic import ValidationError +from weave_server_sdk.models import Query from weave.flow.monitor import Monitor from weave.scorers import ValidJSONScorer from weave.trace.api import publish -from weave.trace_server.interface.query import Query def test_init_pass(): diff --git a/tests/trace/data_serialization/test_cases/config_cases.py b/tests/trace/data_serialization/test_cases/config_cases.py index 81717c5445a9..eaee5eede32c 100644 --- a/tests/trace/data_serialization/test_cases/config_cases.py +++ b/tests/trace/data_serialization/test_cases/config_cases.py @@ -120,23 +120,15 @@ "sampling_rate": 0.5, "scorers": [], "op_names": ["weave.genai.turn_ended"], + # The generated query models type operand lists as list[Any], so + # leaf operands serialize as plain dicts (no _type metadata). "query": { "_type": "Query", "$expr": { "_type": "GtOperation", "$gt": [ - { - "_type": "GetFieldOperator", - "$getField": "started_at", - "_class_name": "GetFieldOperator", - "_bases": ["BaseModel"], - }, - { - "_type": "LiteralOperation", - "$literal": 1742540400, - "_class_name": "LiteralOperation", - "_bases": ["BaseModel"], - }, + {"$getField": "started_at"}, + {"$literal": 1742540400}, ], "_class_name": "GtOperation", "_bases": ["BaseModel"], diff --git a/tests/trace/server_utils.py b/tests/trace/server_utils.py index 9c8efdd6cf12..b341e5f443ae 100644 --- a/tests/trace/server_utils.py +++ b/tests/trace/server_utils.py @@ -11,7 +11,15 @@ TEST_ENTITY = "shawn" # Attribute names used by each middleware layer to reference the next server. -_NEXT_SERVER_ATTRS = ("server", "_next_trace_server", "_internal_trace_server") +# _server must come first: the flushing proxy and the request-coercion +# wrapper forward unknown attributes to their inner server, so probing the +# other names on them would skip a layer. +_NEXT_SERVER_ATTRS = ( + "_server", + "server", + "_next_trace_server", + "_internal_trace_server", +) def find_server_layer(server: tsi.TraceServerInterface, layer_type: type[T]) -> T: @@ -32,7 +40,10 @@ def find_server_layer(server: tsi.TraceServerInterface, layer_type: type[T]) -> visited.add(obj_id) next_layer = None for attr in _NEXT_SERVER_ATTRS: - next_layer = getattr(current, attr, None) + # Probe the instance dict directly: several wrappers forward + # unknown attributes to their inner server, so getattr would + # tunnel through and skip layers. + next_layer = vars(current).get(attr) if next_layer is not None: break current = next_layer diff --git a/tests/trace/test_client_annotation_queue_sdk.py b/tests/trace/test_client_annotation_queue_sdk.py index cc3c1fb42635..6eb4a4a9510e 100644 --- a/tests/trace/test_client_annotation_queue_sdk.py +++ b/tests/trace/test_client_annotation_queue_sdk.py @@ -3,11 +3,11 @@ import datetime import pytest +from weave_server_sdk.models import AnnotationQueueItemsFilter, SortBy from tests.trace.util import client_is_sqlite from weave.trace.call import Call from weave.trace.weave_client import WeaveClient -from weave.trace_server.common_interface import AnnotationQueueItemsFilter, SortBy from weave.trace_server.errors import NotFoundError from weave.trace_server.trace_server_interface import ( AnnotationQueueAddCallsRes, diff --git a/tests/trace/test_saved_view.py b/tests/trace/test_saved_view.py index 80a4fd98f18d..e85aaaa925e1 100644 --- a/tests/trace/test_saved_view.py +++ b/tests/trace/test_saved_view.py @@ -1,6 +1,7 @@ import datetime import pytest +from weave_server_sdk import models as tsi import weave from weave.flow.saved_view import ( @@ -11,7 +12,6 @@ to_seconds, ) from weave.trace.api import ObjectRef -from weave.trace_server import trace_server_interface as tsi def test_to_seconds(): diff --git a/tests/trace/test_wal_client_writes.py b/tests/trace/test_wal_client_writes.py index cab4679a74d4..8603859d397f 100644 --- a/tests/trace/test_wal_client_writes.py +++ b/tests/trace/test_wal_client_writes.py @@ -32,7 +32,7 @@ from weave.durability.wal_writer import JSONLWALWriter from weave.trace import weave_client from weave.trace.settings import UserSettings, override_settings -from weave.trace_server import trace_server_interface as tsi +from weave.trace_server_bindings import models as tsi def _read_all_wal_records(client: weave.WeaveClient) -> list[dict]: @@ -119,6 +119,7 @@ def test_obj_create(self, wal_client): "val": {"model": "gpt-4", "temp": 0.7}, "builtin_object_class": None, "expected_digest": None, + "set_base_object_class": None, "wb_user_id": None, } }, @@ -398,6 +399,7 @@ def test_write_only_manager(self, tmp_path): "val": {"hello": "world"}, "builtin_object_class": None, "expected_digest": None, + "set_base_object_class": None, "wb_user_id": None, } }, diff --git a/tests/trace/test_weave_client.py b/tests/trace/test_weave_client.py index 9eb3ee8ee0c6..9aecd6bad381 100644 --- a/tests/trace/test_weave_client.py +++ b/tests/trace/test_weave_client.py @@ -12,6 +12,10 @@ import pytest from pydantic import ValidationError +# `tsi` aliases the server interface in this file (coercion-seam tests); +# the binding-facing tests use the SDK models module unaliased. +from weave_server_sdk import models + import weave import weave.trace.call import weave.trace_server.trace_server_interface as tsi @@ -1601,8 +1605,8 @@ def test_table_partitioning(network_proxy_client, use_parallel_table_upload): 100 * 1024 ) # very large buffer to ensure a single request res = remote_client.table_create( - tsi.TableCreateReq( - table=tsi.TableSchemaForInsert( + models.TableCreateReq( + table=models.TableSchemaForInsert( project_id=client.project_id, rows=rows, ) @@ -4380,8 +4384,8 @@ def test_table_create_from_digests(network_proxy_client): # Create a table with these rows to get the row digests table_res = client.server.table_create( - tsi.TableCreateReq( - table=tsi.TableSchemaForInsert( + models.TableCreateReq( + table=models.TableSchemaForInsert( project_id=client.project_id, rows=rows, ) @@ -4393,7 +4397,7 @@ def test_table_create_from_digests(network_proxy_client): # Now create a new table using the same row digests from_digests_res = client.server.table_create_from_digests( - tsi.TableCreateFromDigestsReq( + models.TableCreateFromDigestsReq( project_id=client.project_id, row_digests=row_digests, ) @@ -4411,8 +4415,8 @@ def test_table_create_from_digests(network_proxy_client): ] more_table_res = client.server.table_create( - tsi.TableCreateReq( - table=tsi.TableSchemaForInsert( + models.TableCreateReq( + table=models.TableSchemaForInsert( project_id=client.project_id, rows=more_rows, ) @@ -4423,7 +4427,7 @@ def test_table_create_from_digests(network_proxy_client): # Test with a different order of row digests - should produce different digest combined_res = client.server.table_create_from_digests( - tsi.TableCreateFromDigestsReq( + models.TableCreateFromDigestsReq( project_id=client.project_id, row_digests=combined_digests, ) @@ -4431,7 +4435,7 @@ def test_table_create_from_digests(network_proxy_client): # now get the new table new_table_res = basic_client.server.table_query( - tsi.TableQueryReq( + models.TableQueryReq( project_id=client.project_id, digest=combined_res.digest, ) @@ -4450,7 +4454,7 @@ def test_table_create_from_digests(network_proxy_client): # Test with a different order of row digests - should produce different digest shuffled_digests = [row_digests[2], row_digests[0], row_digests[1]] # [3, 1, 2] shuffled_res = client.server.table_create_from_digests( - tsi.TableCreateFromDigestsReq( + models.TableCreateFromDigestsReq( project_id=client.project_id, row_digests=shuffled_digests, ) diff --git a/tests/trace_server/conftest_lib/request_coercion.py b/tests/trace_server/conftest_lib/request_coercion.py new file mode 100644 index 000000000000..17f2c6936f00 --- /dev/null +++ b/tests/trace_server/conftest_lib/request_coercion.py @@ -0,0 +1,87 @@ +"""Generic request-family coercion for the in-process test servers. + +The Weave client speaks ``weave_server_sdk`` models (generated from the +trace server's OpenAPI spec). The in-process test servers declare the +server's own request types and rely on ``isinstance`` checks internally +(e.g. the query AST in the ClickHouse query builder), so a foreign-but- +field-compatible pydantic model cannot be passed through structurally. + +In production the seam is HTTP: the client serializes to JSON and the server +parses JSON into its own types. This wrapper is the in-process equivalent of +that seam — it serializes whatever request arrives (client SDK model, legacy +model, or plain dict) to JSON-shaped data and lets the wrapped server's own +parameter annotation parse it. ``exclude_none`` matches the SDK's wire +encoding (None means unset). + +This is server-side test infrastructure: it introspects the wrapped server's +own signatures and never imports client or SDK types. Responses are returned +unchanged — the server's response models are JSON-shape-identical to the +SDK's, and callers access fields, not types. +""" + +from __future__ import annotations + +import functools +import inspect +from collections.abc import Callable +from typing import Any, get_type_hints + +from pydantic import BaseModel + + +def _first_param_model(method: Callable[..., Any]) -> type[BaseModel] | None: + """Return the method's first parameter type if it is a pydantic model.""" + try: + sig = inspect.signature(method) + hints = get_type_hints(method) + except (TypeError, ValueError, NameError): + return None + for name in sig.parameters: + if name in {"self", "cls"}: + continue + annotation = hints.get(name) + if isinstance(annotation, type) and issubclass(annotation, BaseModel): + return annotation + return None + return None + + +class RequestCoercingTraceServer: + """Parses incoming requests with the wrapped server's own request types.""" + + def __init__(self, server: Any) -> None: + self._server = server + self._wrapped_methods: dict[str, Callable[..., Any]] = {} + + def __getattr__(self, name: str) -> Any: + attr = getattr(self._server, name) + if name.startswith("_") or not callable(attr): + return attr + if name in self._wrapped_methods: + return self._wrapped_methods[name] + + expected = _first_param_model(attr) + if expected is None: + return attr + + # functools.wraps matters beyond hygiene: the caching middleware + # namespaces cache keys by func.__name__, and invalidation prefixes + # are built from the real method names. + @functools.wraps(attr) + def wrapper( + req: Any = None, + *args: Any, + _m: Any = attr, + _t: Any = expected, + **kwargs: Any, + ) -> Any: + if isinstance(req, BaseModel) and not isinstance(req, _t): + req = _t.model_validate( + req.model_dump(by_alias=True, exclude_none=True) + ) + elif isinstance(req, dict): + req = _t.model_validate(req) + return _m(req, *args, **kwargs) + + self._wrapped_methods[name] = wrapper + return wrapper diff --git a/tests/trace_server_bindings/conftest.py b/tests/trace_server_bindings/conftest.py index 1548f25acb6e..2be5a878e0a2 100644 --- a/tests/trace_server_bindings/conftest.py +++ b/tests/trace_server_bindings/conftest.py @@ -5,8 +5,8 @@ import httpx import pytest import tenacity +from weave_server_sdk import models as tsi -from weave.trace_server import trace_server_interface as tsi from weave.trace_server.ids import generate_id from weave.trace_server_bindings.remote_http_trace_server import ( RemoteHTTPTraceServer, diff --git a/tests/trace_server_bindings/test_call_batch_processor.py b/tests/trace_server_bindings/test_call_batch_processor.py index bbb4d23e2552..cb731b3639d1 100644 --- a/tests/trace_server_bindings/test_call_batch_processor.py +++ b/tests/trace_server_bindings/test_call_batch_processor.py @@ -7,7 +7,7 @@ import httpx import pytest -from weave.trace_server import trace_server_interface as tsi +from weave.trace_server_bindings import models as tsi from weave.trace_server_bindings.call_batch_processor import CallBatchProcessor from weave.trace_server_bindings.models import ( CompleteBatchItem, diff --git a/tests/trace_server_bindings/test_http_behavior.py b/tests/trace_server_bindings/test_http_behavior.py index b78abd66e9b6..8f6f1a47941c 100644 --- a/tests/trace_server_bindings/test_http_behavior.py +++ b/tests/trace_server_bindings/test_http_behavior.py @@ -27,7 +27,7 @@ generate_start, ) from weave.trace.display.term import configure_logger -from weave.trace_server import trace_server_interface as tsi +from weave.trace_server_bindings import models as tsi from weave.trace_server_bindings.async_batch_processor import AsyncBatchProcessor from weave.trace_server_bindings.call_batch_processor import CallBatchProcessor from weave.trace_server_bindings.http_utils import ( @@ -39,6 +39,10 @@ EndBatchItem, StartBatchItem, ) + +# Client-side batch envelopes and gap models; not part of the generated SDK +# (the calls_complete v2 and multipart file endpoints are excluded from the +# OpenAPI spec). from weave.trace_server_bindings.remote_http_trace_server import ( RemoteHTTPTraceServer, ) @@ -280,12 +284,14 @@ def test_eager_calls_use_v2_start_end_endpoints(): start = generate_start(id="call-id", project_id="entity/project") ended_at = datetime.datetime.now(tz=datetime.timezone.utc) started_at = ended_at - datetime.timedelta(seconds=1) - end = tsi.EndedCallSchemaForInsertWithStartedAt( + # started_at rides as an extra field (the published SDK model doesn't + # declare it yet). + end = tsi.EndedCallSchemaForInsert( project_id="entity/project", id="call-id", ended_at=ended_at, started_at=started_at, - summary={"result": "Test summary"}, + summary={}, ) try: @@ -305,7 +311,7 @@ def test_eager_calls_use_v2_start_end_endpoints(): payload_started_at = datetime.datetime.fromisoformat( end_payload["end"]["started_at"].replace("Z", "+00:00") ) - assert payload_started_at == end.started_at + assert payload_started_at == started_at assert end_payload["end"]["id"] == "call-id" finally: shutdown(server) @@ -433,13 +439,13 @@ def test_post_timeout(monkeypatch, log_collector): call_start_ok_response(call_id), ) new_server = make_server(transport2, should_batch=False) - fast_retried_via_sdk = fast_retry( + fast_retried_call_sdk = fast_retry( MethodType( - new_server._via_sdk.__wrapped__, # type: ignore[attr-defined] + new_server._call_sdk.__wrapped__, # type: ignore[attr-defined] new_server, ) ) - new_server._via_sdk = fast_retried_via_sdk + new_server._call_sdk = fast_retried_call_sdk response = new_server.call_start(tsi.CallStartReq(start=generate_start(call_id))) assert response.id == call_id diff --git a/tests/trace_server_bindings/test_tags_aliases_routes.py b/tests/trace_server_bindings/test_tags_aliases_routes.py index 817b9426c795..eeaa0b65400b 100644 --- a/tests/trace_server_bindings/test_tags_aliases_routes.py +++ b/tests/trace_server_bindings/test_tags_aliases_routes.py @@ -13,7 +13,10 @@ import pytest from tests.trace_server_bindings.conftest import SpyTransport -from weave.trace_server import trace_server_interface as tsi +from weave.trace_server_bindings import models as tsi + +# Request envelopes for routes whose ids travel in the URL path; these are +# client-side gap models (the generated SDK has body models only). from weave.trace_server_bindings.remote_http_trace_server import ( RemoteHTTPTraceServer, ) diff --git a/tests/trace_server_bindings/test_trace_server_bindings.py b/tests/trace_server_bindings/test_trace_server_bindings.py index b427a388d2ca..d40d2e5f7fcb 100644 --- a/tests/trace_server_bindings/test_trace_server_bindings.py +++ b/tests/trace_server_bindings/test_trace_server_bindings.py @@ -13,13 +13,13 @@ import httpx import pytest +from weave_server_sdk import models as tsi from tests.trace_server_bindings.conftest import ( generate_call_start_end_pair, generate_end, generate_start, ) -from weave.trace_server import trace_server_interface as tsi from weave.trace_server_bindings.models import ( Batch, EndBatchItem, diff --git a/weave/durability/wal_sender.py b/weave/durability/wal_sender.py index 93a00071ea26..f4f53245b930 100644 --- a/weave/durability/wal_sender.py +++ b/weave/durability/wal_sender.py @@ -43,6 +43,7 @@ from pydantic import BaseModel from typing_extensions import Self +from weave_server_sdk import models as tsi from weave.durability.wal import ( WALConsumer, @@ -56,8 +57,11 @@ from weave.durability.wal_directory_manager import FileWALDirectoryManager from weave.durability.wal_lock import is_writer_alive from weave.telemetry.trace_sentry import log_error -from weave.trace_server import trace_server_interface as tsi from weave.trace_server_bindings.client_interface import TraceServerClientInterface + +# FileCreateReq is a binding gap model (multipart upload is not expressible in +# weave-server-sdk 0.0.1); remove when a regenerated SDK covers it. +from weave.trace_server_bindings.models import FileCreateReq # noqa: TID251 from weave.trace_server_bindings.remote_http_trace_server import ( RemoteHTTPTraceServer, ) @@ -298,7 +302,7 @@ def _run(self) -> None: "call_end": tsi.CallEndReq, "obj_create": tsi.ObjCreateReq, "table_create": tsi.TableCreateReq, - "file_create": tsi.FileCreateReq, + "file_create": FileCreateReq, } diff --git a/weave/evaluation/eval.py b/weave/evaluation/eval.py index b67cacdceea8..c28e5db28753 100644 --- a/weave/evaluation/eval.py +++ b/weave/evaluation/eval.py @@ -11,6 +11,8 @@ from pydantic import PrivateAttr from typing_extensions import Self +from weave_server_sdk import models as tsi +from weave_server_sdk.models import CallsFilter from weave.dataset.dataset import Dataset from weave.flow import util @@ -41,8 +43,6 @@ from weave.trace.vals import WeaveObject from weave.trace.weave_client import get_ref from weave.trace_server import constants -from weave.trace_server import trace_server_interface as tsi -from weave.trace_server.trace_server_interface import CallsFilter from weave.utils.project_id import from_project_id logger = logging.getLogger(__name__) diff --git a/weave/evaluation/otel_eval_linker.py b/weave/evaluation/otel_eval_linker.py index 5400158fca84..99b9798e6537 100644 --- a/weave/evaluation/otel_eval_linker.py +++ b/weave/evaluation/otel_eval_linker.py @@ -18,6 +18,7 @@ from typing import TYPE_CHECKING from opentelemetry.sdk.trace import ReadableSpan, Span, SpanProcessor +from weave_server_sdk import models as tsi from weave.evaluation.eval import ( _attach_genai_span_ref_to_call_summary, @@ -26,7 +27,6 @@ _find_current_predict_and_score_call, ) from weave.trace_server import constants -from weave.trace_server import trace_server_interface as tsi if TYPE_CHECKING: from opentelemetry.context import Context diff --git a/weave/flow/leaderboard.py b/weave/flow/leaderboard.py index 24fe35ac3aa5..ca051d5ce32b 100644 --- a/weave/flow/leaderboard.py +++ b/weave/flow/leaderboard.py @@ -3,10 +3,11 @@ from dataclasses import dataclass from typing import Any +from weave_server_sdk.models import CallsFilter + from weave.trace.refs import OpRef from weave.trace.weave_client import WeaveClient, get_ref from weave.trace_server.interface.builtin_object_classes import leaderboard -from weave.trace_server.trace_server_interface import CallsFilter from weave.utils.project_id import from_project_id diff --git a/weave/flow/monitor.py b/weave/flow/monitor.py index 52779d153809..84fa3a499fd4 100644 --- a/weave/flow/monitor.py +++ b/weave/flow/monitor.py @@ -2,6 +2,7 @@ from pydantic import Field, field_validator from typing_extensions import NotRequired, Self, TypedDict +from weave_server_sdk.models import Query from weave.flow.casting import Scorer from weave.object.obj import Object @@ -13,7 +14,6 @@ from weave.trace.objectify import register_object from weave.trace.refs import OpRef, Ref from weave.trace.vals import WeaveObject -from weave.trace_server.interface.query import Query DebounceAggregationField: TypeAlias = Literal["trace_id", "thread_id"] DebounceAggregationMethod: TypeAlias = Literal["last_message", "all_messages"] diff --git a/weave/flow/saved_view.py b/weave/flow/saved_view.py index c7f47710f87e..d2782eb9e79c 100644 --- a/weave/flow/saved_view.py +++ b/weave/flow/saved_view.py @@ -1,15 +1,18 @@ from __future__ import annotations from datetime import datetime -from typing import Any, Literal, TypedDict +from typing import Any, Literal, TypeAlias, TypedDict -from pydantic import BaseModel +from pydantic import BaseModel, TypeAdapter from typing_extensions import Self +from weave_server_sdk import models as tsi +from weave_server_sdk.models import SortBy from weave.trace import urls from weave.trace.api import publish as weave_publish from weave.trace.api import ref as weave_ref from weave.trace.call import CallsIter +from weave.trace.casting import cast_to_query from weave.trace.context import weave_client_context from weave.trace.display import display from weave.trace.display.grid import Grid @@ -17,14 +20,42 @@ from weave.trace.refs import ObjectRef, OpRef from weave.trace.traverse import ObjectPath, get_paths from weave.trace.vals import WeaveObject -from weave.trace_server import trace_server_interface as tsi -from weave.trace_server.common_interface import SortBy -from weave.trace_server.interface import query as tsi_query from weave.trace_server.interface.builtin_object_classes.saved_view import Column, Pin from weave.trace_server.interface.builtin_object_classes.saved_view import ( SavedView as SavedViewBase, ) +TRACE_STATUS_SUCCESS: tsi.TraceStatus = "success" +TRACE_STATUS_ERROR: tsi.TraceStatus = "error" +TRACE_STATUS_RUNNING: tsi.TraceStatus = "running" + +# The tsi query module exported an `Operand` union alias; the generated SDK +# inlines the union instead, so spell it out here. +OperandType: TypeAlias = ( + tsi.LiteralOperation + | tsi.GetFieldOperator + | tsi.ConvertOperation + | tsi.AndOperation + | tsi.OrOperation + | tsi.NotOperation + | tsi.EqOperation + | tsi.GtOperation + | tsi.GteOperation + | tsi.InOperation + | tsi.ContainsOperation +) + +# The generated query models type operand lists as list[Any], leaving leaf +# operands as raw dicts after validation; parse them on access. +_OPERAND_ADAPTER: TypeAdapter[Any] = TypeAdapter(OperandType) + + +def _parse_operand(item: Any) -> Any: + if isinstance(item, dict): + return _OPERAND_ADAPTER.validate_python(item) + return item + + KNOWN_COLUMNS = [ "id", "display_name", @@ -291,18 +322,18 @@ def filters_to_query(filters: Filters | None) -> tsi.Query | None: return tsi.Query(**{"$expr": expr}) -def operand_to_filter_eq(operand: tsi_query.EqOperation) -> Filter: - first = operand.eq_[0] - second = operand.eq_[1] - if isinstance(first, tsi_query.ConvertOperation) and first.convert_.to in { +def operand_to_filter_eq(operand: tsi.EqOperation) -> Filter: + first = _parse_operand(operand.eq[0]) + second = _parse_operand(operand.eq[1]) + if isinstance(first, tsi.ConvertOperation) and first.convert.to in { "double", "int", }: - first = first.convert_.input - if isinstance(first, tsi_query.GetFieldOperator) and isinstance( - second, tsi_query.LiteralOperation + first = _parse_operand(first.convert.input) + if isinstance(first, tsi.GetFieldOperator) and isinstance( + second, tsi.LiteralOperation ): - value = second.literal_ + value = second.literal if isinstance(value, str): if value == "": operator = "(any): isEmpty" @@ -313,46 +344,46 @@ def operand_to_filter_eq(operand: tsi_query.EqOperation) -> Filter: operator = "(number): =" else: raise QueryTranslationException(f"Could not parse {operand}") - field = first.get_field_ + field = first.get_field return Filter(field=field, operator=operator, value=value) raise QueryTranslationException(f"Could not parse {operand}") -def operand_to_filter_contains(operand: tsi_query.ContainsOperation) -> Filter: - input = operand.contains_.input - substr = operand.contains_.substr - case_insensitive = operand.contains_.case_insensitive +def operand_to_filter_contains(operand: tsi.ContainsOperation) -> Filter: + input = _parse_operand(operand.contains.input) + substr = _parse_operand(operand.contains.substr) + case_insensitive = operand.contains.case_insensitive # TODO: Handle case_insensitive correctly - if isinstance(input, tsi_query.GetFieldOperator) and isinstance( - substr, tsi_query.LiteralOperation + if isinstance(input, tsi.GetFieldOperator) and isinstance( + substr, tsi.LiteralOperation ): - value = substr.literal_ + value = substr.literal if isinstance(value, str): operator = "(string): contains" else: raise QueryTranslationException(f"Could not parse {operand}") - field = input.get_field_ + field = input.get_field return Filter(field=field, operator=operator, value=value) raise QueryTranslationException(f"Could not parse {operand}") -def operand_to_filter_gt(operand: tsi_query.GtOperation) -> Filter: - first = operand.gt_[0] - second = operand.gt_[1] - if isinstance(first, tsi_query.ConvertOperation) and first.convert_.to in { +def operand_to_filter_gt(operand: tsi.GtOperation) -> Filter: + first = _parse_operand(operand.gt[0]) + second = _parse_operand(operand.gt[1]) + if isinstance(first, tsi.ConvertOperation) and first.convert.to in { "double", "int", }: - first = first.convert_.input - if isinstance(first, tsi_query.GetFieldOperator) and isinstance( - second, tsi_query.LiteralOperation + first = _parse_operand(first.convert.input) + if isinstance(first, tsi.GetFieldOperator) and isinstance( + second, tsi.LiteralOperation ): - value = second.literal_ + value = second.literal if isinstance(value, (int, float)): operator = "(number): >" else: raise QueryTranslationException(f"Could not parse {operand}") - field = first.get_field_ + field = first.get_field if field == "started_at": operator = "(date): after" value = datetime.fromtimestamp(value).isoformat() @@ -360,38 +391,39 @@ def operand_to_filter_gt(operand: tsi_query.GtOperation) -> Filter: raise QueryTranslationException(f"Could not parse {operand}") -def operand_to_filter_gte(operand: tsi_query.GteOperation) -> Filter: - first = operand.gte_[0] - second = operand.gte_[1] - if isinstance(first, tsi_query.ConvertOperation) and first.convert_.to in { +def operand_to_filter_gte(operand: tsi.GteOperation) -> Filter: + first = _parse_operand(operand.gte[0]) + second = _parse_operand(operand.gte[1]) + if isinstance(first, tsi.ConvertOperation) and first.convert.to in { "double", "int", }: - first = first.convert_.input - if isinstance(first, tsi_query.GetFieldOperator) and isinstance( - second, tsi_query.LiteralOperation + first = _parse_operand(first.convert.input) + if isinstance(first, tsi.GetFieldOperator) and isinstance( + second, tsi.LiteralOperation ): - value = second.literal_ + value = second.literal if isinstance(value, (int, float)): operator = "(number): >=" else: raise QueryTranslationException(f"Could not parse {operand}") - field = first.get_field_ + field = first.get_field return Filter(field=field, operator=operator, value=value) raise QueryTranslationException(f"Could not parse {operand}") -def operand_to_filter(operand: tsi_query.Operand) -> Filter: - if isinstance(operand, tsi_query.EqOperation): +def operand_to_filter(operand: OperandType) -> Filter: + operand = _parse_operand(operand) + if isinstance(operand, tsi.EqOperation): return operand_to_filter_eq(operand) - if isinstance(operand, tsi_query.ContainsOperation): + if isinstance(operand, tsi.ContainsOperation): return operand_to_filter_contains(operand) - if isinstance(operand, tsi_query.GtOperation): + if isinstance(operand, tsi.GtOperation): return operand_to_filter_gt(operand) - if isinstance(operand, tsi_query.GteOperation): + if isinstance(operand, tsi.GteOperation): return operand_to_filter_gte(operand) - if isinstance(operand, tsi_query.NotOperation): - filter = operand_to_filter(operand.not_[0]) + if isinstance(operand, tsi.NotOperation): + filter = operand_to_filter(_parse_operand(operand.not_[0])) if filter.operator == "(number): >=": filter.operator = "(number): <" elif filter.operator == "(number): >": @@ -417,7 +449,7 @@ def operand_to_filter(operand: tsi_query.Operand) -> Filter: else: raise QueryTranslationException(f"Could not parse {filter}") return filter - if isinstance(operand, tsi_query.OrOperation) and len(operand.or_) > 0: + if isinstance(operand, tsi.OrOperation) and len(operand.or_) > 0: operands = [operand_to_filter(o) for o in operand.or_] if all(o.field == operands[0].field for o in operands) and all( o.operator == "(string): equals" for o in operands @@ -432,24 +464,26 @@ def query_to_filters(query: tsi.Query | None) -> Filters | None: """Convert Saved View Query to Filters representation.""" if query is None: return None + query = cast_to_query(query) - if isinstance(query.expr_, tsi_query.AndOperation): - operands = query.expr_.and_ + expr = _parse_operand(query.expr) + if isinstance(expr, tsi.AndOperation): + operands = expr.and_ if not operands: return None return [operand_to_filter(o) for o in operands] if isinstance( - query.expr_, + expr, ( - tsi_query.EqOperation, - tsi_query.GtOperation, - tsi_query.GteOperation, - tsi_query.NotOperation, - tsi_query.ContainsOperation, + tsi.EqOperation, + tsi.GtOperation, + tsi.GteOperation, + tsi.NotOperation, + tsi.ContainsOperation, ), ): - return [operand_to_filter(query.expr_)] + return [operand_to_filter(expr)] raise QueryTranslationException(f"Could not parse {query}") @@ -475,11 +509,12 @@ def get_object_path(obj: WeaveObject, path: str | ObjectPath) -> Any: def render_status(value: Any) -> str: - if value == tsi.TraceStatus.SUCCESS: + # tsi.TraceStatus is a Literal type, so compare against its string values. + if value == TRACE_STATUS_SUCCESS: return "✅" - elif value == tsi.TraceStatus.ERROR: + elif value == TRACE_STATUS_ERROR: return "❌" - elif value == tsi.TraceStatus.RUNNING: + elif value == TRACE_STATUS_RUNNING: return "⏳" return value diff --git a/weave/trace/call.py b/weave/trace/call.py index 2bd642c4e972..1158a8b4b101 100644 --- a/weave/trace/call.py +++ b/weave/trace/call.py @@ -7,6 +7,15 @@ from concurrent.futures import Future from typing import TYPE_CHECKING, Any, TypedDict +from weave_server_sdk.models import ( + CallSchema, + CallsFilter, + CallsQueryReq, + CallsQueryStatsReq, + Query, + SortBy, +) + from weave.trace import urls from weave.trace.context import weave_client_context from weave.trace.feedback import RefFeedbackQuery @@ -17,16 +26,8 @@ from weave.trace.serialization.serialize import from_json from weave.trace.util import log_once from weave.trace.vals import WeaveObject -from weave.trace_server.common_interface import SortBy from weave.trace_server.constants import MAX_DISPLAY_NAME_LENGTH -from weave.trace_server.interface.query import Query -from weave.trace_server.trace_server_interface import ( - CallSchema, - CallsFilter, - CallsQueryReq, - CallsQueryStatsReq, - TraceServerInterface, -) +from weave.trace_server_bindings.client_interface import TraceServerClientInterface from weave.utils.attributes_dict import AttributesDict from weave.utils.paginated_iterator import PaginatedIterator from weave.utils.project_id import from_project_id @@ -345,7 +346,7 @@ def elide_display_name(name: str) -> str: def _make_calls_iterator( - server: TraceServerInterface, + server: TraceServerClientInterface, project_id: str, filter: CallsFilter, limit_override: int | None = None, @@ -424,7 +425,10 @@ def size_func() -> int: def make_client_call( - entity: str, project: str, server_call: CallSchema, server: TraceServerInterface + entity: str, + project: str, + server_call: CallSchema, + server: TraceServerClientInterface, ) -> WeaveObject: if (call_id := server_call.id) is None: raise ValueError("Call ID is None") diff --git a/weave/trace/casting.py b/weave/trace/casting.py index b93bc3d5ccfa..f97d57ee5d32 100644 --- a/weave/trace/casting.py +++ b/weave/trace/casting.py @@ -1,17 +1,37 @@ -from typing import Annotated, Any +from typing import Annotated, Any, TypeAlias import pydantic +from weave_server_sdk.models import CallsFilter, Query, SortBy -from weave.trace_server.common_interface import SortBy -from weave.trace_server.interface.query import Query -from weave.trace_server.trace_server_interface import CallsFilter + +def _reject_unknown_keys(obj: dict, model: type[pydantic.BaseModel]) -> None: + """Raise for dict keys the model does not declare. + + The generated SDK models allow extra fields (forward compatibility on + responses), but user-supplied dicts at this boundary should fail fast on + typos like the strict legacy models did. + """ + allowed = set(model.model_fields) + for field in model.model_fields.values(): + if field.alias: + allowed.add(field.alias) + if extras := set(obj) - allowed: + raise ValueError( + f"Extra inputs are not permitted: {sorted(extras)} for {model.__name__}" + ) def cast_to_calls_filter(obj: Any) -> CallsFilter: if isinstance(obj, CallsFilter): return obj + if isinstance(obj, pydantic.BaseModel): + # Foreign model families (e.g. legacy tsi.CallsFilter) are + # field-compatible; re-validate. + return CallsFilter.model_validate(obj.model_dump(by_alias=True)) + if isinstance(obj, dict): + _reject_unknown_keys(obj, CallsFilter) return CallsFilter(**obj) raise TypeError(f"Unable to cast to CallsFilter: {obj}") @@ -21,7 +41,11 @@ def cast_to_sort_by(obj: Any) -> SortBy: if isinstance(obj, SortBy): return obj + if isinstance(obj, pydantic.BaseModel): + return SortBy.model_validate(obj.model_dump(by_alias=True)) + if isinstance(obj, dict): + _reject_unknown_keys(obj, SortBy) return SortBy(**obj) raise TypeError(f"Unable to cast to SortBy: {obj}") @@ -31,12 +55,18 @@ def cast_to_query(obj: Any) -> Query: if isinstance(obj, Query): return obj + if isinstance(obj, pydantic.BaseModel): + return Query.model_validate(obj.model_dump(by_alias=True)) + if isinstance(obj, dict): + _reject_unknown_keys(obj, Query) return Query(**obj) raise TypeError(f"Unable to cast to Query: {obj}") -CallsFilterLike = Annotated[CallsFilter, pydantic.BeforeValidator(cast_to_calls_filter)] -SortByLike = Annotated[SortBy, pydantic.BeforeValidator(cast_to_sort_by)] -QueryLike = Annotated[Query, pydantic.BeforeValidator(cast_to_query)] +CallsFilterLike: TypeAlias = Annotated[ + CallsFilter, pydantic.BeforeValidator(cast_to_calls_filter) +] +SortByLike: TypeAlias = Annotated[SortBy, pydantic.BeforeValidator(cast_to_sort_by)] +QueryLike: TypeAlias = Annotated[Query, pydantic.BeforeValidator(cast_to_query)] diff --git a/weave/trace/feedback.py b/weave/trace/feedback.py index 63586deaf1d8..567a4727c082 100644 --- a/weave/trace/feedback.py +++ b/weave/trace/feedback.py @@ -2,10 +2,15 @@ from __future__ import annotations +import datetime import json from collections.abc import Iterable, Iterator from typing import Any +from pydantic import Field +from weave_server_sdk import models as tsi +from weave_server_sdk.models import Query + from weave.trace import util from weave.trace.context import weave_client_context from weave.trace.display import display @@ -13,18 +18,30 @@ from weave.trace.display.rich.container import AbstractRichContainer from weave.trace.display.rich.refs import Refs from weave.trace.refs import ObjectRef, Ref -from weave.trace_server import trace_server_interface as tsi -from weave.trace_server.interface.query import Query from weave.utils.project_id import to_project_id -class Feedbacks(AbstractRichContainer[tsi.Feedback]): +class Feedback(tsi.FeedbackCreateReq): + """A feedback row, as returned by feedback queries. + + The generated SDK models feedback query results as plain dicts + (FeedbackQueryRes.result); this gives the client a typed row. + """ + + id: str + created_at: datetime.datetime + wb_user_id: str = Field( + description="The user who created the feedback.", + ) + + +class Feedbacks(AbstractRichContainer[Feedback]): """A collection of Feedback objects with utilities.""" show_refs: bool def __init__( - self, show_refs: bool, feedbacks: Iterable[tsi.Feedback] | None = None + self, show_refs: bool, feedbacks: Iterable[Feedback] | None = None ) -> None: super().__init__("Feedback", feedbacks) self.show_refs = show_refs @@ -43,7 +60,7 @@ def _add_table_columns(self, table: display.Table) -> None: table.add_column("ID", overflow="fold") table.add_column("Creator") - def _item_to_row(self, item: tsi.Feedback) -> list: + def _item_to_row(self, item: Feedback) -> list: feedback = item type_ = feedback.feedback_type @@ -114,10 +131,10 @@ def __init__( self.feedbacks = None - def __iter__(self) -> Iterator[tsi.Feedback]: + def __iter__(self) -> Iterator[Feedback]: yield from self.execute() - def __getitem__(self, index: int) -> tsi.Feedback: + def __getitem__(self, index: int) -> Feedback: return self.execute()[index] def __len__(self) -> int: @@ -140,7 +157,7 @@ def refresh(self) -> Feedbacks: response = self.client.server.feedback_query(req) # Response is dicts because API allows user to specify fields, but we don't # expose that in this Python API. - return Feedbacks(self.show_refs, (tsi.Feedback(**r) for r in response.result)) + return Feedbacks(self.show_refs, (Feedback(**r) for r in response.result)) def execute(self) -> Feedbacks: if self.feedbacks is not None: diff --git a/weave/trace/interface_query_builder.py b/weave/trace/interface_query_builder.py index 2795876e04b2..9769437e01a0 100644 --- a/weave/trace/interface_query_builder.py +++ b/weave/trace/interface_query_builder.py @@ -1,10 +1,33 @@ -from typing import Any +from typing import Any, TypeAlias -from weave.trace_server.interface.query import ( +from weave_server_sdk.models import ( + AndOperation, + ContainsOperation, + ConvertOperation, + EqOperation, GetFieldOperator, + GteOperation, + GtOperation, + InOperation, LiteralOperation, NotOperation, - Operand, + OrOperation, +) + +# The tsi query module exported an `Operand` union alias; the generated SDK +# inlines the union instead, so spell it out here. +Operand: TypeAlias = ( + LiteralOperation + | GetFieldOperator + | ConvertOperation + | AndOperation + | OrOperation + | NotOperation + | EqOperation + | GtOperation + | GteOperation + | InOperation + | ContainsOperation ) diff --git a/weave/trace/project_id_resolver.py b/weave/trace/project_id_resolver.py index 25d3c97440ea..a5b27257c515 100644 --- a/weave/trace/project_id_resolver.py +++ b/weave/trace/project_id_resolver.py @@ -14,10 +14,10 @@ from typing import TYPE_CHECKING, Any from httpx import HTTPStatusError as HTTPError +from weave_server_sdk.models import ProjectsInfoReq from weave.trace.settings import should_enable_client_side_digests from weave.trace_server.errors import DigestMismatchError -from weave.trace_server.trace_server_interface import ProjectsInfoReq if TYPE_CHECKING: from weave.trace_server_bindings.client_interface import TraceServerClientInterface diff --git a/weave/trace/serialization/custom_objs.py b/weave/trace/serialization/custom_objs.py index bc5cf432e94b..e8b77eeae126 100644 --- a/weave/trace/serialization/custom_objs.py +++ b/weave/trace/serialization/custom_objs.py @@ -4,6 +4,8 @@ from collections.abc import Mapping from typing import Any, Literal, TypedDict +from weave_server_sdk.models import FileContentReadReq + from weave.trace.context.weave_client_context import ( get_weave_client, require_weave_client, @@ -23,10 +25,7 @@ is_probably_legacy_inline_load, ) from weave.trace.settings import should_allow_unsafe_custom_obj_decode -from weave.trace_server.trace_server_interface import ( - FileContentReadReq, - TraceServerInterface, -) +from weave.trace_server_bindings.client_interface import TraceServerClientInterface logger = logging.getLogger(__name__) @@ -180,7 +179,7 @@ def _ensure_bytes(value: str | bytes) -> bytes: def _load_custom_obj_files( - project_id: str, server: TraceServerInterface, file_digests: dict + project_id: str, server: TraceServerClientInterface, file_digests: dict ) -> dict[str, bytes]: loaded_files: dict[str, bytes] = {} for name, digest in file_digests.items(): diff --git a/weave/trace/serialization/serialize.py b/weave/trace/serialization/serialize.py index 7b487d4d4207..1c0f49919a73 100644 --- a/weave/trace/serialization/serialize.py +++ b/weave/trace/serialization/serialize.py @@ -12,10 +12,11 @@ from weave.trace.refs import ObjectRef, Ref, TableRef from weave.trace.serialization import custom_objs from weave.trace.serialization.dictifiable import try_to_dict -from weave.trace_server.trace_server_interface import ( - FileCreateReq, - TraceServerInterface, -) +from weave.trace_server_bindings.client_interface import TraceServerClientInterface + +# FileCreateReq is a binding gap model (multipart upload is not expressible +# in weave-server-sdk 0.0.1); remove when a regenerated SDK covers it. +from weave.trace_server_bindings.models import FileCreateReq # noqa: TID251 from weave.utils.sanitize import REDACTED_VALUE, should_redact if TYPE_CHECKING: @@ -400,7 +401,7 @@ def isinstance_namedtuple(obj: Any) -> bool: ) -def from_json(obj: Any, project_id: str, server: TraceServerInterface) -> Any: +def from_json(obj: Any, project_id: str, server: TraceServerClientInterface) -> Any: if isinstance(obj, list): return [from_json(v, project_id, server) for v in obj] elif isinstance(obj, dict): diff --git a/weave/trace/vals.py b/weave/trace/vals.py index d172dafd74c7..4d7f9b4bf5c7 100644 --- a/weave/trace/vals.py +++ b/weave/trace/vals.py @@ -7,6 +7,12 @@ from typing import Any, Literal, SupportsIndex, cast from pydantic import BaseModel +from weave_server_sdk.models import ( + ObjReadReq, + TableQueryReq, + TableQueryStatsReq, + TableRowFilter, +) from weave.trace import box from weave.trace.context.tests_context import get_raise_on_captured_errors @@ -26,13 +32,7 @@ from weave.trace.serialization.serialize import from_json from weave.trace.table import Table from weave.trace_server.errors import ObjectDeletedError -from weave.trace_server.trace_server_interface import ( - ObjReadReq, - TableQueryReq, - TableQueryStatsReq, - TableRowFilter, - TraceServerInterface, -) +from weave.trace_server_bindings.client_interface import TraceServerClientInterface from weave.trace_server_bindings.http_utils import retry_on_not_found from weave.utils.iterators import ThreadSafeLazyList from weave.utils.project_id import to_project_id @@ -116,7 +116,7 @@ class Traceable: mutations: list[Mutation] | None = None root: "Traceable" parent: "Traceable | None" = None - server: TraceServerInterface + server: TraceServerClientInterface _is_dirty: bool = False def _mark_dirty(self) -> None: @@ -186,7 +186,7 @@ def attribute_access_result( val_attr_val: Any, attr_name: str, *, - server: TraceServerInterface | None, + server: TraceServerClientInterface | None, ) -> Any: # Not ideal, what about properties? if callable(val_attr_val): @@ -222,7 +222,7 @@ def __init__( self, val: Any, ref: RefWithExtra | None, - server: TraceServerInterface, + server: TraceServerClientInterface, root: Traceable | None, parent: Traceable | None = None, ) -> None: @@ -301,7 +301,7 @@ class WeaveTable(Traceable): # noqa: PLW1641 def __init__( self, - server: TraceServerInterface, + server: TraceServerClientInterface, table_ref: TableRef | None = None, ref: RefWithExtra | None = None, filter: TableRowFilter | None = None, @@ -616,7 +616,7 @@ class WeaveList(Traceable, list): # noqa: PLW1641 def __init__( self, *args: Any, - server: TraceServerInterface, + server: TraceServerClientInterface, ref: RefWithExtra | None = None, root: Traceable | None = None, parent: Traceable | None = None, @@ -695,7 +695,7 @@ class WeaveDict(Traceable, dict): # noqa: PLW1641 def __init__( self, *args: Any, - server: TraceServerInterface, + server: TraceServerClientInterface, ref: RefWithExtra | None = None, root: Traceable | None = None, parent: Traceable | None = None, @@ -779,7 +779,7 @@ def unwrap(self) -> Any: def make_trace_obj( val: Any, new_ref: RefWithExtra | None, # Can this actually be None? - server: TraceServerInterface, + server: TraceServerClientInterface, root: Traceable | None, parent: Any = None, ) -> Any: diff --git a/weave/trace/weave_client.py b/weave/trace/weave_client.py index a30cb159ed0d..7c4540a14476 100644 --- a/weave/trace/weave_client.py +++ b/weave/trace/weave_client.py @@ -15,6 +15,51 @@ import pydantic from httpx import HTTPStatusError as HTTPError +from weave_server_sdk.models import ( + AnnotationQueueAddCallsRes, + AnnotationQueueCreateReq, + AnnotationQueueItemSchema, + AnnotationQueueItemsFilter, + AnnotationQueueSchema, + AnnotationQueuesQueryReq, + AnnotationQueuesStatsReq, + AnnotationQueueStatsSchema, + CallEndReq, + CallsDeleteReq, + CallsFilter, + CallsQueryReq, + CallStartReq, + CallUpdateReq, + CostCreateInput, + CostCreateReq, + CostCreateRes, + CostPurgeReq, + CostQueryOutput, + CostQueryReq, + EndedCallSchemaForInsert, + FeedbackCreateReq, + FileCreateRes, + ObjCreateReq, + ObjCreateRes, + ObjDeleteReq, + ObjectVersionFilter, + ObjQueryReq, + ObjReadReq, + ObjSchema, + ObjSchemaForInsert, + Query, + RefsReadBatchReq, + SortBy, + StartedCallSchemaForInsert, + TableAppendSpec, + TableAppendSpecPayload, + TableCreateFromDigestsReq, + TableCreateReq, + TableCreateRes, + TableSchemaForInsert, + TableUpdateReq, + TraceStatus, +) from weave.chat.chat import Chat from weave.chat.inference_models import InferenceModels @@ -35,7 +80,13 @@ elide_display_name, make_client_call, ) -from weave.trace.casting import CallsFilterLike, QueryLike, SortByLike +from weave.trace.casting import ( + CallsFilterLike, + QueryLike, + SortByLike, + cast_to_query, + cast_to_sort_by, +) from weave.trace.concurrent.futures import FutureExecutor from weave.trace.constants import TRACE_CALL_EMOJI from weave.trace.context import call_context @@ -106,7 +157,6 @@ get_global_wb_run_context, ) from weave.trace.weave_client_send_file_cache import WeaveClientSendFileCache -from weave.trace_server.common_interface import AnnotationQueueItemsFilter, SortBy from weave.trace_server.constants import MAX_OBJECT_NAME_LENGTH from weave.trace_server.errors import DigestMismatchError, InvalidExternalRef from weave.trace_server.ids import generate_id @@ -116,61 +166,6 @@ runnable_feedback_runnable_ref_selector, ) from weave.trace_server.trace_server_converter import universal_ext_to_int_ref_converter -from weave.trace_server.trace_server_interface import ( - AliasesListReq, - AnnotationQueueAddCallsReq, - AnnotationQueueAddCallsRes, - AnnotationQueueCreateReq, - AnnotationQueueDeleteReq, - AnnotationQueueItemSchema, - AnnotationQueueItemsQueryReq, - AnnotationQueueReadReq, - AnnotationQueueSchema, - AnnotationQueuesQueryReq, - AnnotationQueuesStatsReq, - AnnotationQueueStatsSchema, - AnnotationQueueUpdateReq, - CallEndReq, - CallsDeleteReq, - CallsFilter, - CallsQueryReq, - CallStartReq, - CallUpdateReq, - CostCreateInput, - CostCreateReq, - CostCreateRes, - CostPurgeReq, - CostQueryOutput, - CostQueryReq, - EndedCallSchemaForInsertWithStartedAt, - FeedbackCreateReq, - FileCreateReq, - FileCreateRes, - ObjAddTagsReq, - ObjCreateReq, - ObjCreateRes, - ObjDeleteReq, - ObjectVersionFilter, - ObjQueryReq, - ObjReadReq, - ObjRemoveAliasesReq, - ObjRemoveTagsReq, - ObjSchema, - ObjSchemaForInsert, - ObjSetAliasesReq, - Query, - RefsReadBatchReq, - StartedCallSchemaForInsert, - TableAppendSpec, - TableAppendSpecPayload, - TableCreateFromDigestsReq, - TableCreateReq, - TableCreateRes, - TableSchemaForInsert, - TableUpdateReq, - TagsListReq, - TraceStatus, -) from weave.trace_server_bindings.async_batch_processor import AsyncBatchProcessor from weave.trace_server_bindings.call_batch_processor import CallBatchProcessor from weave.trace_server_bindings.client_interface import TraceServerClientInterface @@ -186,7 +181,26 @@ LinkAssetToRegistryTarget, link_asset_to_registry, ) -from weave.trace_server_bindings.models import StartBatchItem + +# Binding gap models: request envelopes for routes whose ids travel in the +# URL path, the multipart file upload, and the batch envelope for the eager +# start path — none expressible in weave-server-sdk 0.0.1. Remove when a +# regenerated SDK covers them. +from weave.trace_server_bindings.models import ( # noqa: TID251 + AliasesListReq, + AnnotationQueueAddCallsReq, + AnnotationQueueDeleteReq, + AnnotationQueueItemsQueryReq, + AnnotationQueueReadReq, + AnnotationQueueUpdateReq, + FileCreateReq, + ObjAddTagsReq, + ObjRemoveAliasesReq, + ObjRemoveTagsReq, + ObjSetAliasesReq, + StartBatchItem, + TagsListReq, +) from weave.utils.attributes_dict import AttributesDict from weave.utils.capture_info import get_capture_info from weave.utils.dict_utils import sum_dict_leaves, zip_dicts @@ -204,6 +218,11 @@ logger = logging.getLogger(__name__) +# tsi.TraceStatus is a Literal type in weave-server-sdk; runtime constants for +# building summary status counts. +TRACE_STATUS_SUCCESS: TraceStatus = "success" +TRACE_STATUS_ERROR: TraceStatus = "error" + class NoInternalProjectIDError(Exception): """Raised when client-side digest computation cannot proceed because @@ -805,7 +824,9 @@ def list_annotation_queues( AnnotationQueuesQueryReq( project_id=self.project_id, name=name, - sort_by=sort_by, + sort_by=[cast_to_sort_by(s) for s in sort_by] + if sort_by is not None + else None, limit=limit, offset=offset, ) @@ -884,8 +905,14 @@ def list_annotation_queue_items( AnnotationQueueItemsQueryReq( project_id=self.project_id, queue_id=queue_id, - filter=filter, - sort_by=sort_by, + filter=AnnotationQueueItemsFilter.model_validate( + filter.model_dump(by_alias=True) + ) + if filter is not None + else None, + sort_by=[cast_to_sort_by(s) for s in sort_by] + if sort_by is not None + else None, limit=limit, offset=offset, include_position=include_position, @@ -1204,12 +1231,12 @@ def finish_call( # Create client-side rollup of status_counts_by_op status_counts_dict = computed_summary.setdefault( RESERVED_SUMMARY_STATUS_COUNTS_KEY, - {TraceStatus.SUCCESS: 0, TraceStatus.ERROR: 0}, + {TRACE_STATUS_SUCCESS: 0, TRACE_STATUS_ERROR: 0}, ) if exception: - status_counts_dict[TraceStatus.ERROR] += 1 + status_counts_dict[TRACE_STATUS_ERROR] += 1 else: - status_counts_dict[TraceStatus.SUCCESS] += 1 + status_counts_dict[TRACE_STATUS_SUCCESS] += 1 # Merge any user-provided summary values with computed values merged_summary = copy.deepcopy(call.summary or {}) @@ -1255,7 +1282,10 @@ def send_end_call() -> None: ) call_end_req = CallEndReq( - end=EndedCallSchemaForInsertWithStartedAt( + # started_at is an extra field: the published SDK's + # EndedCallSchemaForInsert does not declare it yet, and the + # models allow extras. + end=EndedCallSchemaForInsert( project_id=project_id, id=call.id, started_at=call.started_at, @@ -1711,7 +1741,7 @@ def get_feedback( # Find all feedback objects with a specific feedback type with # mongo-style query. - from weave.trace_server.interface.query import Query + from weave_server_sdk.models import Query query = Query( **{ @@ -1748,8 +1778,10 @@ def get_feedback( {"$literal": query}, ], } - elif isinstance(query, Query): - expr = query.expr_ + elif query is not None: + # Accept Query, foreign model families (e.g. legacy tsi.Query), + # and raw dicts. + expr = cast_to_query(query).expr if reaction: expr = { @@ -1892,8 +1924,10 @@ def query_costs( {"$literal": query}, ], } - elif isinstance(query, Query): - expr = query.expr_ + elif query is not None: + # Accept Query, foreign model families (e.g. legacy tsi.Query), + # and raw dicts. + expr = cast_to_query(query).expr if llm_ids: expr = { diff --git a/weave/trace/weave_client_send_file_cache.py b/weave/trace/weave_client_send_file_cache.py index 82ea5b78e329..12efc4096519 100644 --- a/weave/trace/weave_client_send_file_cache.py +++ b/weave/trace/weave_client_send_file_cache.py @@ -8,7 +8,11 @@ from concurrent.futures import Future from typing import Generic, TypeVar -from weave.trace_server.trace_server_interface import FileCreateReq, FileCreateRes +from weave_server_sdk.models import FileCreateRes + +# FileCreateReq is a binding gap model (multipart upload is not expressible +# in weave-server-sdk 0.0.1); remove when a regenerated SDK covers it. +from weave.trace_server_bindings.models import FileCreateReq # noqa: TID251 # Define generic type variables K = TypeVar("K") # Key type diff --git a/weave/trace/weave_init.py b/weave/trace/weave_init.py index bcb61a9ba567..385f2c761ad8 100644 --- a/weave/trace/weave_init.py +++ b/weave/trace/weave_init.py @@ -31,8 +31,9 @@ from weave.wandb_interface.context import get_wandb_api_context if TYPE_CHECKING: + from weave_server_sdk.models import ServerInfoRes + from weave.trace.op import PostprocessInputsFunc, PostprocessOutputFunc - from weave.trace_server.service_interface import ServerInfoRes logger = logging.getLogger(__name__) diff --git a/weave/trace_server/interface/builtin_object_classes/llm_structured_model.py b/weave/trace_server/interface/builtin_object_classes/llm_structured_model.py index 0024ee4c3174..0a6db5db6187 100644 --- a/weave/trace_server/interface/builtin_object_classes/llm_structured_model.py +++ b/weave/trace_server/interface/builtin_object_classes/llm_structured_model.py @@ -8,7 +8,7 @@ from weave.trace import vals from weave.trace.context.weave_client_context import WeaveInitError, get_weave_client from weave.trace_server.interface.builtin_object_classes import base_object_def -from weave.trace_server.trace_server_interface import ( +from weave.trace_server_bindings.models import ( CompletionsCreateReq, CompletionsCreateRequestInputs, ) diff --git a/weave/trace_server/interface/builtin_object_classes/saved_view.py b/weave/trace_server/interface/builtin_object_classes/saved_view.py index d7c51b6c952c..6471debe281e 100644 --- a/weave/trace_server/interface/builtin_object_classes/saved_view.py +++ b/weave/trace_server/interface/builtin_object_classes/saved_view.py @@ -1,9 +1,9 @@ from typing import Literal from pydantic import BaseModel, Field +from weave_server_sdk import models as tsi +from weave_server_sdk.models import SortBy -from weave.trace_server import trace_server_interface as tsi -from weave.trace_server.common_interface import SortBy from weave.trace_server.interface.builtin_object_classes import base_object_def PathElement = str | int diff --git a/weave/trace_server_bindings/caching_middleware_trace_server.py b/weave/trace_server_bindings/caching_middleware_trace_server.py index c88de40ecef1..979ae0454159 100644 --- a/weave/trace_server_bindings/caching_middleware_trace_server.py +++ b/weave/trace_server_bindings/caching_middleware_trace_server.py @@ -18,7 +18,7 @@ server_cache_size_limit, use_server_cache, ) -from weave.trace_server import trace_server_interface as tsi +from weave.trace_server_bindings import models as tsi from weave.trace_server_bindings.caches import DiskCache, LRUCache, StackedCache from weave.trace_server_bindings.client_interface import TraceServerClientInterface from weave.trace_server_bindings.delegating_trace_server import ( @@ -243,6 +243,7 @@ def _with_cache_pydantic( func: Callable[[TReq], TRes], req: TReq, res_type: type[TRes], + make_cache_key: Callable[[TReq], str] | None = None, ) -> TRes: """Cache the result of a function that takes and returns Pydantic models. @@ -253,15 +254,28 @@ def _with_cache_pydantic( func: The function to cache results for req: The request object (must be a Pydantic model) res_type: The response type (must be a Pydantic model) + make_cache_key: Optional cache-key builder. Methods whose entries + are prefix-invalidated must use a key whose leading fields + match the invalidation prefix. Returns: The function result, either from cache or from calling func """ + + def call_and_normalize(r: TReq) -> TRes: + res = func(r) + # Wrapped servers may return a field-compatible foreign model + # family (e.g. the in-process test servers); normalize so hits + # and misses return the same declared type. + if not isinstance(res, res_type): + res = res_type.model_validate(res.model_dump(by_alias=True)) + return res + return self._with_cache( - func, + call_and_normalize, req, func.__name__, - pydantic_bytes_safe_dump, + make_cache_key or pydantic_bytes_safe_dump, lambda res: res.model_dump_json(), res_type.model_validate_json, ) @@ -291,14 +305,20 @@ def obj_read(self, req: tsi.ObjReadReq) -> tsi.ObjReadRes: if not digest_is_cacheable(req.digest): return self._next_trace_server.obj_read(req) return self._with_cache_pydantic( - self._next_trace_server.obj_read, req, tsi.ObjReadRes + self._next_trace_server.obj_read, + req, + tsi.ObjReadRes, + make_cache_key=_obj_read_cache_key, ) # Obj API def obj_create(self, req: tsi.ObjCreateReq) -> tsi.ObjCreateRes: # All obj_create requests are cacheable! return self._with_cache_pydantic( - self._next_trace_server.obj_create, req, tsi.ObjCreateRes + self._next_trace_server.obj_create, + req, + tsi.ObjCreateRes, + make_cache_key=_obj_create_cache_key, ) def obj_delete(self, req: tsi.ObjDeleteReq) -> tsi.ObjDeleteRes: @@ -393,7 +413,9 @@ def table_query_stats(self, req: tsi.TableQueryStatsReq) -> tsi.TableQueryStatsR if not digest_is_cacheable(req.digest): return self._next_trace_server.table_query_stats(req) return self._with_cache_pydantic( - self._next_trace_server.table_query_stats, req, tsi.TableQueryStatsRes + self._next_trace_server.table_query_stats, + req, + tsi.TableQueryStatsRes, ) def table_query_stats_batch( @@ -484,20 +506,25 @@ def files_stats(self, req: tsi.FilesStatsReq) -> tsi.FilesStatsRes: tsi.FilesStatsRes, ) - # Object APIs - def op_read(self, req: tsi.OpReadReq) -> tsi.OpReadRes: - if not digest_is_cacheable(req.digest): - return self._next_trace_server.op_read(req) - return self._with_cache_pydantic( - self._next_trace_server.op_read, req, tsi.OpReadRes - ) - def dataset_read(self, req: tsi.DatasetReadReq) -> tsi.DatasetReadRes: - if not digest_is_cacheable(req.digest): - return self._next_trace_server.dataset_read(req) - return self._with_cache_pydantic( - self._next_trace_server.dataset_read, req, tsi.DatasetReadRes - ) +def _reorder_leading(d: dict[str, Any], leading: tuple[str, ...]) -> dict[str, Any]: + """Return a copy of d with `leading` keys first (in that order). + + The generated SDK models declare fields in sorted order, but the cache's + prefix-invalidation scheme requires identifying fields to lead the + serialized key. + """ + return {key: d.pop(key) for key in leading} | d + + +def _obj_read_cache_key(req: tsi.ObjReadReq) -> str: + raw = _reorder_leading(req.model_dump(), ("project_id", "object_id", "digest")) + return json.dumps(_bytes_to_base64(raw), ensure_ascii=False) + + +def _obj_create_cache_key(req: tsi.ObjCreateReq) -> str: + obj = _reorder_leading(req.model_dump()["obj"], ("project_id", "object_id")) + return json.dumps({"obj": _bytes_to_base64(obj)}, ensure_ascii=False) def _build_invalidation_prefix(namespace: str, match_fields: dict[str, Any]) -> str: @@ -516,20 +543,19 @@ def _build_invalidation_prefix(namespace: str, match_fields: dict[str, Any]) -> return f"{namespace}_{serialized.rstrip('}')}" +def _bytes_to_base64(obj: Any) -> Any: + """Convert bytes to base64 strings for JSON serialization, recursively.""" + if isinstance(obj, bytes): + return base64.b64encode(obj).decode("utf-8") + elif isinstance(obj, dict): + return {k: _bytes_to_base64(v) for k, v in obj.items()} + elif isinstance(obj, list): + return [_bytes_to_base64(v) for v in obj] + return obj + + def pydantic_bytes_safe_dump(obj: BaseModel) -> str: - raw_dict = obj.model_dump() - - # Convert bytes to base64 string for JSON serialization - def _bytes_to_base64(obj: Any) -> Any: - if isinstance(obj, bytes): - return base64.b64encode(obj).decode("utf-8") - elif isinstance(obj, dict): - return {k: _bytes_to_base64(v) for k, v in obj.items()} - elif isinstance(obj, list): - return [_bytes_to_base64(v) for v in obj] - return obj - - processed_dict = _bytes_to_base64(raw_dict) + processed_dict = _bytes_to_base64(obj.model_dump()) return json.dumps(processed_dict, ensure_ascii=False) diff --git a/weave/trace_server_bindings/call_batch_processor.py b/weave/trace_server_bindings/call_batch_processor.py index d7a83df99134..d25721722bd7 100644 --- a/weave/trace_server_bindings/call_batch_processor.py +++ b/weave/trace_server_bindings/call_batch_processor.py @@ -16,7 +16,7 @@ from cachetools import TTLCache -from weave.trace_server import trace_server_interface as tsi +from weave.trace_server_bindings import models as tsi from weave.trace_server_bindings.async_batch_processor import ( AsyncBatchProcessor, SkipIndividualProcessingError, diff --git a/weave/trace_server_bindings/client_interface.py b/weave/trace_server_bindings/client_interface.py index 6da25711d8cc..f77c8c1813dc 100644 --- a/weave/trace_server_bindings/client_interface.py +++ b/weave/trace_server_bindings/client_interface.py @@ -1,16 +1,140 @@ +"""The client's view of a trace server. + +This protocol is owned by the client and typed entirely with the models in +``weave.trace_server_bindings.models`` — the OpenAPI-generated +``weave_server_sdk.models`` (the source of truth, re-exported) plus a handful +of gap models for surface the published SDK does not yet express. The ``tsi`` +alias is a temporary migration aid; the follow-up PR imports the models +directly. + +It deliberately contains only the methods the client (WeaveClient, flow +modules, caching middleware, WAL) actually uses. Server implementations may — +and do — expose more; extra methods reach callers through the delegation +mixin's passthrough rather than this contract. +""" + from __future__ import annotations +from collections.abc import Iterator from typing import Protocol -from weave.trace_server.service_interface import ServiceInterface -from weave.trace_server.trace_server_interface import FullTraceServerInterface +from weave.trace_server_bindings import models as tsi +from weave.trace_server_bindings.async_batch_processor import AsyncBatchProcessor +from weave.trace_server_bindings.call_batch_processor import CallBatchProcessor + + +class TraceServerClientInterface(Protocol): + """What the Weave client requires of a trace server.""" + + # ---- service ---------------------------------------------------------- + + def ensure_project_exists( + self, entity: str, project: str + ) -> tsi.EnsureProjectExistsRes: ... + def set_auth(self, auth: tuple[str, str]) -> None: ... + def get_call_processor( + self, + ) -> AsyncBatchProcessor | CallBatchProcessor | None: ... + def get_feedback_processor(self) -> AsyncBatchProcessor | None: ... + def server_info(self) -> tsi.ServerInfoRes: ... + def projects_info(self, req: tsi.ProjectsInfoReq) -> list[tsi.ProjectsInfoRes]: ... + + # ---- calls ------------------------------------------------------------- + + def call_start(self, req: tsi.CallStartReq) -> tsi.CallStartRes: ... + def call_end(self, req: tsi.CallEndReq) -> tsi.CallEndRes: ... + def call_update(self, req: tsi.CallUpdateReq) -> tsi.CallUpdateRes: ... + def calls_delete(self, req: tsi.CallsDeleteReq) -> tsi.CallsDeleteRes: ... + def calls_query_stream( + self, req: tsi.CallsQueryReq + ) -> Iterator[tsi.CallSchema]: ... + def calls_query_stats( + self, req: tsi.CallsQueryStatsReq + ) -> tsi.CallsQueryStatsRes: ... + + # ---- objects ----------------------------------------------------------- + + def obj_create(self, req: tsi.ObjCreateReq) -> tsi.ObjCreateRes: ... + def obj_read(self, req: tsi.ObjReadReq) -> tsi.ObjReadRes: ... + def objs_query(self, req: tsi.ObjQueryReq) -> tsi.ObjQueryRes: ... + def obj_delete(self, req: tsi.ObjDeleteReq) -> tsi.ObjDeleteRes: ... + + # Request envelopes for these come from bindings models (the OpenAPI spec + # carries the ids in the URL path, so the SDK has body models only). + def obj_add_tags(self, req: tsi.ObjAddTagsReq) -> tsi.ObjAddTagsRes: ... + def obj_remove_tags(self, req: tsi.ObjRemoveTagsReq) -> tsi.ObjRemoveTagsRes: ... + def obj_set_aliases(self, req: tsi.ObjSetAliasesReq) -> tsi.ObjSetAliasesRes: ... + def obj_remove_aliases( + self, req: tsi.ObjRemoveAliasesReq + ) -> tsi.ObjRemoveAliasesRes: ... + def tags_list(self, req: tsi.TagsListReq) -> tsi.TagsListRes: ... + def aliases_list(self, req: tsi.AliasesListReq) -> tsi.AliasesListRes: ... + + # ---- tables ------------------------------------------------------------ + + def table_create(self, req: tsi.TableCreateReq) -> tsi.TableCreateRes: ... + def table_update(self, req: tsi.TableUpdateReq) -> tsi.TableUpdateRes: ... + def table_query(self, req: tsi.TableQueryReq) -> tsi.TableQueryRes: ... + def table_query_stats( + self, req: tsi.TableQueryStatsReq + ) -> tsi.TableQueryStatsRes: ... + def table_query_stats_batch( + self, req: tsi.TableQueryStatsBatchReq + ) -> tsi.TableQueryStatsBatchRes: ... + def table_create_from_digests( + self, req: tsi.TableCreateFromDigestsReq + ) -> tsi.TableCreateFromDigestsRes: ... + + # ---- refs / files ------------------------------------------------------ + + def refs_read_batch(self, req: tsi.RefsReadBatchReq) -> tsi.RefsReadBatchRes: ... + def file_create(self, req: tsi.FileCreateReq) -> tsi.FileCreateRes: ... + def file_content_read( + self, req: tsi.FileContentReadReq + ) -> tsi.FileContentReadRes: ... + def files_stats(self, req: tsi.FilesStatsReq) -> tsi.FilesStatsRes: ... + + # ---- feedback ---------------------------------------------------------- + + def feedback_create(self, req: tsi.FeedbackCreateReq) -> tsi.FeedbackCreateRes: ... + def feedback_query(self, req: tsi.FeedbackQueryReq) -> tsi.FeedbackQueryRes: ... + def feedback_purge(self, req: tsi.FeedbackPurgeReq) -> tsi.FeedbackPurgeRes: ... + + # ---- costs -------------------------------------------------------------- + + def cost_create(self, req: tsi.CostCreateReq) -> tsi.CostCreateRes: ... + def cost_query(self, req: tsi.CostQueryReq) -> tsi.CostQueryRes: ... + def cost_purge(self, req: tsi.CostPurgeReq) -> tsi.CostPurgeRes: ... + # ---- annotation queues --------------------------------------------------- -class TraceServerClientInterface(FullTraceServerInterface, ServiceInterface, Protocol): - """Combined interface for trace server client implementations. + def annotation_queue_create( + self, req: tsi.AnnotationQueueCreateReq + ) -> tsi.AnnotationQueueCreateRes: ... + def annotation_queues_query_stream( + self, req: tsi.AnnotationQueuesQueryReq + ) -> Iterator[tsi.AnnotationQueueSchema]: ... + def annotation_queue_read( + self, req: tsi.AnnotationQueueReadReq + ) -> tsi.AnnotationQueueReadRes: ... + def annotation_queue_delete( + self, req: tsi.AnnotationQueueDeleteReq + ) -> tsi.AnnotationQueueDeleteRes: ... + def annotation_queue_update( + self, req: tsi.AnnotationQueueUpdateReq + ) -> tsi.AnnotationQueueUpdateRes: ... + def annotation_queue_add_calls( + self, req: tsi.AnnotationQueueAddCallsReq + ) -> tsi.AnnotationQueueAddCallsRes: ... + def annotation_queue_items_query( + self, req: tsi.AnnotationQueueItemsQueryReq + ) -> tsi.AnnotationQueueItemsQueryRes: ... + def annotation_queues_stats( + self, req: tsi.AnnotationQueuesStatsReq + ) -> tsi.AnnotationQueuesStatsRes: ... - Union of the storage interface (FullTraceServerInterface) and the service - interface (ServiceInterface). - """ + # ---- completions ---------------------------------------------------------- - pass + def completions_create( + self, req: tsi.CompletionsCreateReq + ) -> tsi.CompletionsCreateRes: ... diff --git a/weave/trace_server_bindings/delegating_trace_server.py b/weave/trace_server_bindings/delegating_trace_server.py index 514475852441..a7c7877071dd 100644 --- a/weave/trace_server_bindings/delegating_trace_server.py +++ b/weave/trace_server_bindings/delegating_trace_server.py @@ -2,18 +2,12 @@ from typing import Any -from weave.trace_server import trace_server_interface as tsi -from weave.trace_server.service_interface import ServiceInterface +from weave.trace_server_bindings.client_interface import TraceServerClientInterface _TRACE_SERVER_METHOD_NAMES = frozenset( { name - for interface in ( - tsi.TraceServerInterface, - tsi.ObjectInterface, - ServiceInterface, - ) - for name, value in vars(interface).items() + for name, value in vars(TraceServerClientInterface).items() if callable(value) and not name.startswith("_") } ) diff --git a/weave/trace_server_bindings/http_utils.py b/weave/trace_server_bindings/http_utils.py index 8568d3604ef7..da4152808e9e 100644 --- a/weave/trace_server_bindings/http_utils.py +++ b/weave/trace_server_bindings/http_utils.py @@ -6,8 +6,8 @@ import httpx import tenacity from typing_extensions import ParamSpec +from weave_server_sdk import models as tsi -from weave.trace_server import trace_server_interface as tsi from weave.trace_server.errors import NotFoundError, ObjectDeletedError from weave.trace_server_bindings.async_batch_processor import AsyncBatchProcessor from weave.utils.retry import _is_retryable_exception, with_retry diff --git a/weave/trace_server_bindings/link_asset_to_registry.py b/weave/trace_server_bindings/link_asset_to_registry.py index b58ca976f16f..02babe87bc69 100644 --- a/weave/trace_server_bindings/link_asset_to_registry.py +++ b/weave/trace_server_bindings/link_asset_to_registry.py @@ -3,7 +3,6 @@ from pydantic import BaseModel, ConfigDict, Field from weave.trace.env import weave_trace_server_url -from weave.trace_server.common_interface import BaseModelStrict from weave.trace_server_bindings.http_utils import handle_response_error from weave.utils import http_requests from weave.wandb_interface.context import get_wandb_api_context @@ -11,13 +10,19 @@ LINK_TO_REGISTRY_PATH = "/link_to_registry" -class LinkAssetToRegistryTarget(BaseModelStrict): +class LinkAssetToRegistryTarget(BaseModel): + # Strict: reject unknown fields so client typos fail fast. + model_config = ConfigDict(extra="forbid") + portfolio_name: str entity_name: str project_name: str -class LinkAssetToRegistryReq(BaseModelStrict): +class LinkAssetToRegistryReq(BaseModel): + # Strict: reject unknown fields so client typos fail fast. + model_config = ConfigDict(extra="forbid") + ref: str target: LinkAssetToRegistryTarget aliases: list[str] = Field(default_factory=list) diff --git a/weave/trace_server_bindings/models.py b/weave/trace_server_bindings/models.py index 7f368fd45910..81c8070c7f16 100644 --- a/weave/trace_server_bindings/models.py +++ b/weave/trace_server_bindings/models.py @@ -1,8 +1,483 @@ -from typing import Literal +"""Client-side models for the trace server bindings. -from pydantic import BaseModel +Types here fall into two groups: -from weave.trace_server import trace_server_interface as tsi +1. Batch-item envelopes used by the client's batching machinery (these never + cross the wire as-is; their ``req`` payloads do). +2. Gap models for API surface the published ``weave-server-sdk`` does not yet + express: multipart file upload, binary file content, the calls_complete + payload, and the wandb-side ensure-project response. Each should be + deleted when a regenerated SDK covers it. +""" + +from __future__ import annotations + +import datetime +from typing import Any, Literal + +from pydantic import BaseModel, Field +from weave_server_sdk import models as tsi + +# TEMPORARY migration aid: re-export the generated SDK models so binding-layer +# files can keep referencing every model (SDK and gap alike) through one +# `tsi`-aliased namespace, keeping this PR's diffs minimal. The follow-up PR +# removes the alias and this re-export in favor of direct imports. Explicit +# (not a star import) so the SDK module's own imports cannot shadow names here. +from weave_server_sdk.models import ( # noqa: F401 + AggregationType, + AliasesListRes, + AndOperation, + AnnotationQueueAddCallsBody, + AnnotationQueueAddCallsRes, + AnnotationQueueCreateReq, + AnnotationQueueCreateRes, + AnnotationQueueDeleteRes, + AnnotationQueueItemProgressUpdateBody, + AnnotationQueueItemSchema, + AnnotationQueueItemsFilter, + AnnotationQueueItemsQueryBody, + AnnotationQueueItemsQueryRes, + AnnotationQueueReadRes, + AnnotationQueueSchema, + AnnotationQueuesQueryReq, + AnnotationQueuesStatsReq, + AnnotationQueuesStatsRes, + AnnotationQueueStatsSchema, + AnnotationQueueUpdateBody, + AnnotationQueueUpdateRes, + AnnotatorQueueItemsProgressUpdateRes, + Body_file_create_file_create_post, + CallBatchEndMode, + CallBatchStartMode, + CallCreateBatchReq, + CallCreateBatchRes, + CallEndReq, + CallEndRes, + CallMetricSpec, + CallReadReq, + CallReadRes, + CallSchema, + CallsDeleteReq, + CallsDeleteRes, + CallsFilter, + CallsQueryReq, + CallsQueryStatsReq, + CallsQueryStatsRes, + CallsScoreReq, + CallsScoreRes, + CallStartReq, + CallStartRes, + CallStatsReq, + CallStatsRes, + CallsUsageReq, + CallsUsageRes, + CallUpdateReq, + CallUpdateRes, + ContainsOperation, + ContainsSpec, + ConvertOperation, + ConvertSpec, + CostCreateInput, + CostCreateReq, + CostCreateRes, + CostPurgeReq, + CostPurgeRes, + CostQueryOutput, + CostQueryReq, + CostQueryRes, + CreateAndLinkPayload, + CreateAndLinkTarget, + CreateAndLinkWeaveAssetRes, + Datacenter, + DatasetCreateBody, + DatasetCreateRes, + DatasetDeleteRes, + DatasetReadRes, + EndedCallSchemaForInsert, + EqOperation, + EvalResultsEvaluationSummary, + EvalResultsFilter, + EvalResultsQueryBody, + EvalResultsQueryRes, + EvalResultsRow, + EvalResultsRowEvaluation, + EvalResultsScorerStats, + EvalResultsSortBy, + EvalResultsSummaryRes, + EvalResultsTrial, + EvaluateModelReq, + EvaluateModelRes, + EvaluationCreateBody, + EvaluationCreateRes, + EvaluationDeleteRes, + EvaluationReadRes, + EvaluationRunCreateBody, + EvaluationRunCreateRes, + EvaluationRunDeleteRes, + EvaluationRunFinishBody, + EvaluationRunFinishRes, + EvaluationRunReadRes, + EvaluationStatusComplete, + EvaluationStatusFailed, + EvaluationStatusNotFound, + EvaluationStatusReq, + EvaluationStatusRes, + EvaluationStatusRunning, + FeedbackCreateBatchReq, + FeedbackCreateBatchRes, + FeedbackCreateReq, + FeedbackCreateRes, + FeedbackMetricSpec, + FeedbackPayloadPath, + FeedbackPayloadSchemaReq, + FeedbackPayloadSchemaRes, + FeedbackPurgeReq, + FeedbackPurgeRes, + FeedbackQueryReq, + FeedbackQueryRes, + FeedbackReplaceReq, + FeedbackReplaceRes, + FeedbackStatsReq, + FeedbackStatsRes, + FileContentReadReq, + FileCreateRes, + FilesStatsReq, + FilesStatsRes, + GenAISpanRef, + Geolocation, + GeolocationRes, + GetFieldOperator, + GteOperation, + GtOperation, + HTTPValidationError, + ImageGenerationCreateReq, + ImageGenerationCreateRes, + ImageGenerationRequestInputs, + InOperation, + LiteralOperation, + LLMAggregatedUsage, + LLMUsageSchema, + LteOperation, + LtOperation, + ModelCreateBody, + ModelCreateRes, + ModelDeleteRes, + ModelReadRes, + NotOperation, + NvidiaHardwareOption, + NvidiaHardwareRes, + NvidiaServerlessPricing, + ObjAddTagsRes, + ObjCreateReq, + ObjCreateRes, + ObjDeleteReq, + ObjDeleteRes, + ObjectVersionFilter, + ObjQueryReq, + ObjQueryRes, + ObjReadReq, + ObjReadRes, + ObjRemoveAliasesBody, + ObjRemoveAliasesRes, + ObjRemoveTagsRes, + ObjSchema, + ObjSchemaForInsert, + ObjSetAliasesBody, + ObjSetAliasesRes, + ObjTagsBody, + OpCreateBody, + OpCreateRes, + OpDeleteRes, + OpReadRes, + OrOperation, + PredictionCreateBody, + PredictionCreateRes, + PredictionDeleteRes, + PredictionFinishRes, + PredictionReadRes, + Pricing, + ProjectsInfoReq, + ProjectsInfoRes, + Query, + RefsReadBatchReq, + RefsReadBatchRes, + RouterOpenRouterModel, + RouterOpenRouterModelsRes, + ScoreCreateBody, + ScoreCreateRes, + ScoreDeleteRes, + ScorerCreateBody, + ScorerCreateRes, + ScorerDeleteRes, + ScoreReadRes, + ScorerReadRes, + ServerInfoRes, + SortBy, + StartedCallSchemaForInsert, + SummaryInsertMap, + TableAppendSpec, + TableAppendSpecPayload, + TableCreateFromDigestsReq, + TableCreateFromDigestsRes, + TableCreateReq, + TableCreateRes, + TableInsertSpec, + TableInsertSpecPayload, + TablePopSpec, + TablePopSpecPayload, + TableQueryReq, + TableQueryRes, + TableQueryStatsBatchReq, + TableQueryStatsBatchRes, + TableQueryStatsReq, + TableQueryStatsRes, + TableRowFilter, + TableRowSchema, + TableSchemaForInsert, + TableStatsRow, + TableUpdateReq, + TableUpdateRes, + TagsListRes, + ThreadsQueryFilter, + ThreadsQueryReq, + TraceStatus, + TraceUsageReq, + TraceUsageRes, + UsageMetricSpec, + ValidationError, +) + + +class CompletedCallSchemaForInsert(BaseModel): + """Schema for inserting a completed call directly. + + This represents a call that is already finished at insertion time, with + both start and end information provided together. Used by the + calls_complete endpoint, which is excluded from the OpenAPI spec (and so + absent from weave-server-sdk). + """ + + project_id: str + id: str + trace_id: str + op_name: str + started_at: datetime.datetime + ended_at: datetime.datetime + attributes: dict[str, Any] + inputs: dict[str, Any] + + display_name: str | None = None + parent_id: str | None = None + thread_id: str | None = None + turn_id: str | None = None + otel_dump: dict[str, Any] | None = None + exception: str | None = None + output: Any | None = None + summary: tsi.SummaryInsertMap | None = None + wb_user_id: str | None = None + wb_run_id: str | None = None + wb_run_step: int | None = None + wb_run_step_end: int | None = None + + +class CallsUpsertCompleteReq(BaseModel): + """Request body for the v2 calls_complete endpoint (absent from the SDK).""" + + batch: list[CompletedCallSchemaForInsert] + + +class CallsUpsertCompleteRes(BaseModel): + """Response body for the v2 calls_complete endpoint (absent from the SDK).""" + + +class CallStartV2Req(BaseModel): + """Request body for the eager v2 call-start endpoint (absent from the SDK).""" + + start: tsi.StartedCallSchemaForInsert + + +class CallStartV2Res(BaseModel): + id: str + trace_id: str + + +class CallEndV2Req(BaseModel): + """Request body for the eager v2 call-end endpoint (absent from the SDK). + + Note: ``started_at`` rides along as an extra field on the SDK's + EndedCallSchemaForInsert (the published model does not declare it yet). + """ + + end: tsi.EndedCallSchemaForInsert + + +class CallEndV2Res(BaseModel): + pass + + +class CallsQueryRes(BaseModel): + """Aggregate calls-query response (the SDK only exposes the stream form).""" + + calls: list[tsi.CallSchema] + + +class CompletionsCreateRequestInputs(BaseModel): + """LLM completion parameters for the /completions/create endpoint. + + That endpoint is excluded from the OpenAPI spec (include_in_schema=False + on the server), so the SDK has no model for it. + """ + + model: str + messages: list = [] + timeout: float | str | None = None + temperature: float | None = None + top_p: float | None = None + n: int | None = None + stop: str | list | None = None + max_completion_tokens: int | None = None + max_tokens: int | None = None + modalities: list | None = None + presence_penalty: float | None = None + frequency_penalty: float | None = None + stream: bool | None = None + logit_bias: dict | None = None + user: str | None = None + # openai v1.0+ new params + response_format: dict | type[BaseModel] | None = None + seed: int | None = None + tools: list | None = None + tool_choice: str | dict | None = None + logprobs: bool | None = None + top_logprobs: int | None = None + parallel_tool_calls: bool | None = None + extra_headers: dict | None = None + # soon to be deprecated params by OpenAI + functions: list | None = None + function_call: str | None = None + api_version: str | None = None + # Weave-specific params + prompt: str | None = Field( + None, + description="Reference to a Weave Prompt object (e.g., 'weave:///entity/project/object/prompt_name:version'). " + "If provided, the messages from this prompt will be prepended to the messages in this request. " + "Template variables in the prompt messages can be substituted using the template_vars parameter.", + ) + template_vars: dict[str, Any] | None = Field( + None, + description="Dictionary of template variables to substitute in prompt messages. " + "Variables in messages like '{variable_name}' will be replaced with the corresponding values. " + "Applied to both prompt messages (if prompt is provided) and regular messages.", + ) + vertex_credentials: str | None = Field( + None, + description="JSON string of Vertex AI service account credentials. " + "When provided for vertex_ai models (e.g. vertex_ai/gemini-2.5-pro), used for authentication " + "instead of api_key. Not persisted in trace storage.", + ) + + +class CompletionsCreateReq(BaseModel): + """Request for /completions/create (excluded from the OpenAPI spec).""" + + project_id: str + inputs: CompletionsCreateRequestInputs + wb_user_id: str | None = None + track_llm_call: bool | None = True + trace_id: str | None = None + parent_id: str | None = None + + +class CompletionsCreateRes(BaseModel): + """Response of /completions/create (excluded from the OpenAPI spec).""" + + response: dict[str, Any] + weave_call_id: str | None = None # Deprecated: use span_id instead + span_id: str | None = None + trace_id: str | None = None + conversation_id: str | None = None + + +class FileCreateReq(BaseModel): + """Multipart file-upload request. + + The SDK's generated files.create lost its multipart body, so the binding + posts the form directly and the client expresses the request with this + model. + """ + + project_id: str + name: str + content: bytes + expected_digest: str | None = None + + +class FileContentReadRes(BaseModel): + """Binary file-content response (the SDK returns raw bytes untyped).""" + + content: bytes + + +class EnsureProjectExistsRes(BaseModel): + """Response of the wandb-side ensure-project call (not a trace-server API).""" + + project_name: str + + +# --------------------------------------------------------------------------- +# Request envelopes for routes whose OpenAPI definition carries ids in the URL +# path. The SDK (correctly) has no request models for these — only body +# models. The client surface is method(req), so these compose the SDK body +# with the path fields. +# --------------------------------------------------------------------------- + + +class ObjAddTagsReq(tsi.ObjTagsBody): + object_id: str + digest: str + + +class ObjRemoveTagsReq(tsi.ObjTagsBody): + object_id: str + digest: str + + +class ObjSetAliasesReq(tsi.ObjSetAliasesBody): + object_id: str + + +class ObjRemoveAliasesReq(tsi.ObjRemoveAliasesBody): + object_id: str + + +class TagsListReq(BaseModel): + project_id: str + + +class AliasesListReq(BaseModel): + project_id: str + + +class AnnotationQueueReadReq(BaseModel): + queue_id: str + project_id: str + + +class AnnotationQueueDeleteReq(BaseModel): + queue_id: str + project_id: str + + +class AnnotationQueueUpdateReq(tsi.AnnotationQueueUpdateBody): + queue_id: str + + +class AnnotationQueueAddCallsReq(tsi.AnnotationQueueAddCallsBody): + queue_id: str + + +class AnnotationQueueItemsQueryReq(tsi.AnnotationQueueItemsQueryBody): + queue_id: str class StartBatchItem(BaseModel): @@ -19,7 +494,7 @@ class CompleteBatchItem(BaseModel): """A complete call ready to be sent to calls_complete endpoint.""" mode: Literal["complete"] = "complete" - req: tsi.CompletedCallSchemaForInsert + req: CompletedCallSchemaForInsert class Batch(BaseModel): @@ -32,3 +507,271 @@ class EntityProjectInfo(BaseModel): entity: str project: str project_id: str + + +# --------------------------------------------------------------------------- +# Spec-excluded read APIs (include_in_schema=False on the server, so the SDK +# has no models for them). +# --------------------------------------------------------------------------- + + +class ProjectStatsReq(BaseModel): + project_id: str + include_trace_storage_size: bool | None = True + include_object_storage_size: bool | None = True + include_table_storage_size: bool | None = True + include_file_storage_size: bool | None = True + + +class ProjectStatsRes(BaseModel): + trace_storage_size_bytes: int + objects_storage_size_bytes: int + tables_storage_size_bytes: int + files_storage_size_bytes: int + + +class ProjectTTLSettingsReadReq(BaseModel): + project_id: str + + +class ProjectTTLSettingsReadRes(BaseModel): + retention_days: int | None = None + + +class ProjectTTLSettingsUpdateReq(BaseModel): + project_id: str + retention_days: int | None = None + wb_user_id: str | None = None + + +class ProjectTTLSettingsUpdateRes(BaseModel): + retention_days: int | None = None + + +class ThreadSchema(BaseModel): + """A thread row from /threads/stream_query (the SDK leaves rows untyped).""" + + thread_id: str + turn_count: int + start_time: datetime.datetime + last_updated: datetime.datetime + first_turn_id: str | None = None + last_turn_id: str | None = None + p50_turn_duration_ms: float | None = None + p99_turn_duration_ms: float | None = None + + +class FeedbackAggregateBucket(BaseModel): + time_bucket_start_ms: int | None = None + group: dict[str, str] = Field(default_factory=dict) + total_count: int = 0 + scored_count: int = 0 + tag_counts: dict[str, int] = Field(default_factory=dict) + rating_counts: dict[str, int] = Field(default_factory=dict) + rating_sums: dict[str, float] = Field(default_factory=dict) + + +class FeedbackAggregateReq(BaseModel): + project_id: str + after_ms: int + before_ms: int + time_bucket_seconds: int | None = None + feedback_types: list[str] = Field(default_factory=list) + tags: list[str] = Field(default_factory=list) + rating_min: float | None = None + rating_max: float | None = None + monitor_ids: list[str] = Field(default_factory=list) + scorer_ids: list[str] = Field(default_factory=list) + span_agent_names: list[str] = Field(default_factory=list) + span_types: list[str] = Field(default_factory=list) + group_by: list[str] = Field(default_factory=list) + + +class FeedbackAggregateRes(BaseModel): + time_bucket_seconds: int | None = None + after_ms: int + before_ms: int + buckets: list[FeedbackAggregateBucket] + + +class RescoreReq(BaseModel): + """Server-side rescore request; the remote binding does not support it.""" + + source_evaluation_run_id: str + scorer_refs: list[str] + project_id: str + wb_user_id: str | None = None + + +# --------------------------------------------------------------------------- +# v2 object API request envelopes. The OpenAPI spec carries entity/project in +# the URL path, so the SDK only has body models; the client surface stays +# method(req), so these compose the SDK bodies with project_id (and the +# server-populated wb_user_id passthrough the legacy models carried). +# --------------------------------------------------------------------------- + + +class _V2ReqMixin(BaseModel): + project_id: str + wb_user_id: str | None = None + + +class _V2ReadReq(_V2ReqMixin): + object_id: str + digest: str + + +class _V2ListReq(_V2ReqMixin): + limit: int | None = None + offset: int | None = None + + +class _V2DeleteReq(_V2ReqMixin): + object_id: str + digests: list[str] | None = None + + +class OpCreateReq(tsi.OpCreateBody, _V2ReqMixin): + pass + + +class OpReadReq(_V2ReadReq): + pass + + +class OpListReq(_V2ListReq): + pass + + +class OpDeleteReq(_V2DeleteReq): + pass + + +class DatasetCreateReq(tsi.DatasetCreateBody, _V2ReqMixin): + pass + + +class DatasetReadReq(_V2ReadReq): + pass + + +class DatasetListReq(_V2ListReq): + pass + + +class DatasetDeleteReq(_V2DeleteReq): + pass + + +class ScorerCreateReq(tsi.ScorerCreateBody, _V2ReqMixin): + pass + + +class ScorerReadReq(_V2ReadReq): + pass + + +class ScorerListReq(_V2ListReq): + pass + + +class ScorerDeleteReq(_V2DeleteReq): + pass + + +class EvaluationCreateReq(tsi.EvaluationCreateBody, _V2ReqMixin): + pass + + +class EvaluationReadReq(_V2ReadReq): + pass + + +class EvaluationListReq(_V2ListReq): + pass + + +class EvaluationDeleteReq(_V2DeleteReq): + pass + + +class ModelCreateReq(tsi.ModelCreateBody, _V2ReqMixin): + pass + + +class ModelReadReq(_V2ReadReq): + pass + + +class ModelListReq(_V2ListReq): + pass + + +class ModelDeleteReq(_V2DeleteReq): + pass + + +class EvaluationRunFilter(BaseModel): + evaluations: list[str] | None = None + models: list[str] | None = None + evaluation_run_ids: list[str] | None = None + + +class EvaluationRunCreateReq(tsi.EvaluationRunCreateBody, _V2ReqMixin): + pass + + +class EvaluationRunReadReq(BaseModel): + project_id: str + evaluation_run_id: str + + +class EvaluationRunListReq(BaseModel): + project_id: str + filter: EvaluationRunFilter | None = None + limit: int | None = None + offset: int | None = None + + +class EvaluationRunDeleteReq(_V2ReqMixin): + evaluation_run_ids: list[str] + + +class EvaluationRunFinishReq(tsi.EvaluationRunFinishBody, _V2ReqMixin): + evaluation_run_id: str + + +class PredictionCreateReq(tsi.PredictionCreateBody, _V2ReqMixin): + pass + + +class PredictionReadReq(_V2ReqMixin): + prediction_id: str + + +class PredictionListReq(_V2ListReq): + evaluation_run_id: str | None = None + + +class PredictionDeleteReq(_V2ReqMixin): + prediction_ids: list[str] + + +class PredictionFinishReq(_V2ReqMixin): + prediction_id: str + + +class ScoreCreateReq(tsi.ScoreCreateBody, _V2ReqMixin): + pass + + +class ScoreReadReq(_V2ReqMixin): + score_id: str + + +class ScoreListReq(_V2ListReq): + evaluation_run_id: str | None = None + + +class ScoreDeleteReq(_V2ReqMixin): + score_ids: list[str] diff --git a/weave/trace_server_bindings/remote_http_trace_server.py b/weave/trace_server_bindings/remote_http_trace_server.py index 5049b6f21479..94cba782e53a 100644 --- a/weave/trace_server_bindings/remote_http_trace_server.py +++ b/weave/trace_server_bindings/remote_http_trace_server.py @@ -1,9 +1,9 @@ """Remote trace server binding backed by the generated ``weave-server-sdk``. -``RemoteHTTPTraceServer`` keeps its tsi-typed ``TraceServerClientInterface`` -surface, but delegates HTTP transport and request/response typing to the -``weave_server_sdk`` package (generated from the trace server's OpenAPI spec — -the source of truth for the API shape). +The binding speaks ``weave_server_sdk.models`` end to end — requests and +responses are the SDK's generated types (the OpenAPI spec is the source of +truth), plus a few gap models from ``weave.trace_server_bindings.models`` for +surface the published SDK does not yet express. Design notes: @@ -12,24 +12,22 @@ ``http_timeout()`` honored) is injected into the SDK so every request — SDK-routed or raw — shares one connection pool, auth, and event hooks. - A response event hook routes every non-2xx response through - ``handle_response_error`` *before* the SDK sees it, so callers observe the - exact same ``httpx.HTTPStatusError`` / ``CallsCompleteModeRequired`` - semantics as ``RemoteHTTPTraceServer`` (retry predicates, 413 batch - splitting, and calls_complete auto-upgrade all key off these). + ``handle_response_error`` *before* the SDK sees it, so callers observe + ``httpx.HTTPStatusError`` / ``CallsCompleteModeRequired`` (retry predicates, + 413 batch splitting, and calls_complete auto-upgrade all key off these). - A request event hook injects the dynamic ``X-Weave-Retry-Id`` header at send time so every retry attempt carries the current retry id. - Endpoints the SDK cannot reach go through ``_raw_request``/``_raw_stream`` with an explicit reason. Two categories: 1. Endpoints excluded from the OpenAPI spec (``include_in_schema=False`` on - the server): calls_complete v2, eager v2 call start/end, completions, - project stats, TTL settings. + the server): calls_complete v2, eager v2 call start/end, completions. 2. weave-server-sdk 0.0.1 codegen bugs (duplicate method names where the last definition wins, lost multipart body): single feedback create, obj tag add/remove, /trace/usage, file create. Remove these hatches when a fixed SDK ships. -- Streaming endpoints (``*_stream`` and the v2 jsonl list endpoints) use - ``_raw_stream`` because the published SDK buffers jsonl responses into - memory; the raw path preserves line-by-line streaming. +- Streaming endpoints (``*_stream``) use ``_raw_stream`` because the published + SDK buffers jsonl responses into memory; the raw path preserves line-by-line + streaming. """ from __future__ import annotations @@ -37,16 +35,14 @@ import datetime import io import logging -from collections.abc import Callable, Iterator +from collections.abc import Iterator from typing import Any, TypeVar, cast from zoneinfo import ZoneInfo import httpx -from pydantic import BaseModel, Field, validate_call -from pydantic.json_schema import SkipJsonSchema +from pydantic import BaseModel, validate_call from typing_extensions import Self from weave_server_sdk import WeaveTrace -from weave_server_sdk import models as sdk_models from weave.trace.env import ssl_verify, weave_trace_server_url from weave.trace.settings import ( @@ -55,9 +51,8 @@ should_enable_disk_fallback, should_use_calls_complete, ) -from weave.trace_server import trace_server_interface as tsi from weave.trace_server.ids import generate_id -from weave.trace_server.service_interface import ServerInfoRes +from weave.trace_server_bindings import models as tsi from weave.trace_server_bindings.async_batch_processor import AsyncBatchProcessor from weave.trace_server_bindings.call_batch_processor import CallBatchProcessor from weave.trace_server_bindings.client_interface import TraceServerClientInterface @@ -91,14 +86,16 @@ CALL_START_V2_PATH = "/v2/{entity}/{project}/call/start" CALL_END_V2_PATH = "/v2/{entity}/{project}/call/end" COMPLETIONS_CREATE_PATH = "/completions/create" + +# Endpoints reached via _raw_request because weave-server-sdk 0.0.1 cannot call +# them (duplicate generated method names where the last definition wins, or a +# lost multipart body). Remove once a fixed SDK is published. PROJECT_STATS_PATH = "/project/stats" PROJECT_TTL_SETTINGS_READ_PATH = "/project/ttl_settings/read" PROJECT_TTL_SETTINGS_UPDATE_PATH = "/project/ttl_settings/update" FEEDBACK_AGGREGATE_PATH = "/feedback/aggregate" +THREADS_STREAM_QUERY_PATH = "/threads/stream_query" -# Endpoints reached via _raw_request because weave-server-sdk 0.0.1 cannot call -# them (duplicate generated method names where the last definition wins, or a -# lost multipart body). Remove once a fixed SDK is published. FEEDBACK_CREATE_PATH = "/feedback/create" FEEDBACK_BATCH_CREATE_PATH = "/feedback/batch/create" TRACE_USAGE_PATH = "/trace/usage" @@ -110,7 +107,6 @@ # Streaming endpoints; the published SDK buffers jsonl bodies, so these are # reached via _raw_stream to preserve line-by-line streaming. CALLS_STREAM_QUERY_PATH = "/calls/stream_query" -THREADS_STREAM_QUERY_PATH = "/threads/stream_query" ANNOTATION_QUEUES_QUERY_PATH = "/annotation_queues/query" CALL_UPSERT_BATCH_PATH = "/call/upsert_batch" @@ -194,13 +190,11 @@ def _inject_dynamic_headers(self, request: httpx.Request) -> None: request.headers["X-Weave-Retry-Id"] = retry_id def _raise_for_status(self, response: httpx.Response) -> None: - """Surface error responses with RemoteHTTPTraceServer's exception - semantics (httpx response hook). + """Surface error responses as httpx errors (httpx response hook). Raising here means the SDK's own exception types never surface: retry predicates, 413 batch splitting, and client code keep seeing - ``httpx.HTTPStatusError`` / ``CallsCompleteModeRequired`` exactly as - before. + ``httpx.HTTPStatusError`` / ``CallsCompleteModeRequired``. """ if response.status_code >= 400: # Event hooks fire before the body is read; load it so @@ -227,39 +221,6 @@ def set_auth(self, auth: tuple[str, str]) -> None: # ---- request helpers ---------------------------------------------------- - @with_retry - def _via_sdk( - self, - req: BaseModel, - sdk_req_type: type[BaseModel], - sdk_method: Callable[..., Any], - res_type: type[TRes], - **path_args: Any, - ) -> TRes: - """Round-trip a tsi request through the typed SDK binding. - - ``by_alias`` is required since query models have Mongo-style properties - aliased to start with ``$``. - """ - body = sdk_req_type.model_validate(req.model_dump(by_alias=True)) - sdk_res = sdk_method(body, **path_args) - if sdk_res is None: - return res_type() - return res_type.model_validate(sdk_res.model_dump(by_alias=True)) - - @with_retry - def _via_sdk_no_body( - self, - sdk_method: Callable[..., Any], - res_type: type[TRes], - **call_args: Any, - ) -> TRes: - """Call an SDK binding that takes only path/query arguments.""" - sdk_res = sdk_method(**call_args) - if sdk_res is None: - return res_type() - return res_type.model_validate(sdk_res.model_dump(by_alias=True)) - @with_retry def _raw_request( self, @@ -318,8 +279,8 @@ def _open_stream( ) -> httpx.Response: """Open a streaming response; retries cover connection/headers only. - Mid-stream failures are not retried, matching RemoteHTTPTraceServer. - The caller owns the returned response and must close() it. + Mid-stream failures are not retried. The caller owns the returned + response and must close() it. """ request = self._http.build_request( method, @@ -332,6 +293,11 @@ def _open_stream( ) return self._http.send(request, stream=True) + @with_retry + def _call_sdk(self, sdk_method: Any, *args: Any, **kwargs: Any) -> Any: + """Invoke a typed SDK binding with the standard retry policy.""" + return sdk_method(*args, **kwargs) + # ---- batching ----------------------------------------------------------- @with_retry @@ -474,7 +440,7 @@ def _send_call_start_v2(self, start: tsi.StartedCallSchemaForInsert) -> None: ) @with_retry - def _send_call_end_v2(self, end: tsi.EndedCallSchemaForInsertWithStartedAt) -> None: + def _send_call_end_v2(self, end: tsi.EndedCallSchemaForInsert) -> None: """Send a single call end to the v2 endpoint.""" entity, project = from_project_id(end.project_id) req = tsi.CallEndV2Req(end=end) @@ -553,7 +519,7 @@ def encode_batch(batch: list[CompleteBatchItem]) -> bytes: ) def get_call_processor(self) -> AsyncBatchProcessor | CallBatchProcessor | None: - """Custom method not defined on the formal TraceServerInterface to expose + """Custom method not defined on the formal client interface to expose the underlying call processor. Should be formalized in a client-side interface. """ return self.call_processor @@ -606,16 +572,12 @@ def send_feedback_batch(encoded_data: bytes) -> None: e, ) - # Feedback endpoint doesn't support id, created_at, so we need to strip them - class FeedbackCreateReqStripped(tsi.FeedbackCreateReq): - id: SkipJsonSchema[str] = Field(exclude=True) - created_at: SkipJsonSchema[datetime.datetime | None] = Field( - exclude=True, default=None - ) - - # Fall back to individual feedback creation calls + # Fall back to individual feedback creation calls. The + # single-create endpoint doesn't accept an id, so strip it. for item in batch: - item_copy = FeedbackCreateReqStripped(**item.model_dump()) + item_copy = tsi.FeedbackCreateReq.model_validate( + item.model_dump(exclude={"id"}, exclude_none=True) + ) try: self._raw_request( "POST", @@ -645,7 +607,7 @@ class FeedbackCreateReqStripped(tsi.FeedbackCreateReq): ) def get_feedback_processor(self) -> AsyncBatchProcessor | None: - """Custom method not defined on the formal TraceServerInterface to expose + """Custom method not defined on the formal client interface to expose the underlying feedback processor. Should be formalized in a client-side interface. """ return self.feedback_processor @@ -653,18 +615,15 @@ def get_feedback_processor(self) -> AsyncBatchProcessor | None: # ---- service ------------------------------------------------------------ @with_retry - def server_info(self) -> ServerInfoRes: - res = self._sdk.services.server_info() - return ServerInfoRes.model_validate(res.model_dump()) + def server_info(self) -> tsi.ServerInfoRes: + return self._sdk.services.server_info() @validate_call @with_retry def projects_info(self, req: tsi.ProjectsInfoReq) -> list[tsi.ProjectsInfoRes]: - body = sdk_models.ProjectsInfoReq.model_validate(req.model_dump()) - res = self._sdk.service.create_projects_info(body) - return [tsi.ProjectsInfoRes.model_validate(item.model_dump()) for item in res] + return self._sdk.service.create_projects_info(req) - def otel_export(self, req: tsi.OTelExportReq) -> tsi.OTelExportRes: + def otel_export(self, req: Any) -> Any: # TODO: Add docs link (DOCS-1390) raise NotImplementedError("Sending otel traces directly is not yet supported.") @@ -681,17 +640,10 @@ def call_start(self, req: tsi.CallStartReq) -> tsi.CallStartRes: ) self.call_processor.enqueue_start(StartBatchItem(req=req)) return tsi.CallStartRes(id=req.start.id, trace_id=req.start.trace_id) - return self._via_sdk( - req, sdk_models.CallStartReq, self._sdk.calls.start, tsi.CallStartRes - ) + return self._call_sdk(self._sdk.calls.start, req) def call_start_batch(self, req: tsi.CallCreateBatchReq) -> tsi.CallCreateBatchRes: - return self._via_sdk( - req, - sdk_models.CallCreateBatchReq, - self._sdk.calls.upsert_batch, - tsi.CallCreateBatchRes, - ) + return self._call_sdk(self._sdk.calls.upsert_batch, req) @validate_call def call_end(self, req: tsi.CallEndReq) -> tsi.CallEndRes: @@ -700,15 +652,11 @@ def call_end(self, req: tsi.CallEndReq) -> tsi.CallEndRes: self.call_processor.enqueue([EndBatchItem(req=req)]) return tsi.CallEndRes() - return self._via_sdk( - req, sdk_models.CallEndReq, self._sdk.calls.end, tsi.CallEndRes - ) + return self._call_sdk(self._sdk.calls.end, req) @validate_call def call_read(self, req: tsi.CallReadReq) -> tsi.CallReadRes: - return self._via_sdk( - req, sdk_models.CallReadReq, self._sdk.calls.read, tsi.CallReadRes - ) + return self._call_sdk(self._sdk.calls.read, req) @validate_call def calls_query(self, req: tsi.CallsQueryReq) -> tsi.CallsQueryRes: @@ -723,12 +671,7 @@ def calls_query_stream(self, req: tsi.CallsQueryReq) -> Iterator[tsi.CallSchema] @validate_call def calls_query_stats(self, req: tsi.CallsQueryStatsReq) -> tsi.CallsQueryStatsRes: - return self._via_sdk( - req, - sdk_models.CallsQueryStatsReq, - self._sdk.calls.query_stats, - tsi.CallsQueryStatsRes, - ) + return self._call_sdk(self._sdk.calls.query_stats, req) @validate_call def trace_usage(self, req: tsi.TraceUsageReq) -> tsi.TraceUsageRes: @@ -740,54 +683,37 @@ def trace_usage(self, req: tsi.TraceUsageReq) -> tsi.TraceUsageRes: @validate_call def calls_usage(self, req: tsi.CallsUsageReq) -> tsi.CallsUsageRes: - return self._via_sdk( - req, - sdk_models.CallsUsageReq, - self._sdk.calls.create_usage, - tsi.CallsUsageRes, - ) + return self._call_sdk(self._sdk.calls.create_usage, req) @validate_call def calls_delete(self, req: tsi.CallsDeleteReq) -> tsi.CallsDeleteRes: - return self._via_sdk( - req, sdk_models.CallsDeleteReq, self._sdk.calls.delete, tsi.CallsDeleteRes - ) + return self._call_sdk(self._sdk.calls.delete, req) @validate_call def call_update(self, req: tsi.CallUpdateReq) -> tsi.CallUpdateRes: - return self._via_sdk( - req, sdk_models.CallUpdateReq, self._sdk.calls.update, tsi.CallUpdateRes - ) + return self._call_sdk(self._sdk.calls.update, req) # ---- Obj API -------------------------------------------------------------- @validate_call def obj_create(self, req: tsi.ObjCreateReq) -> tsi.ObjCreateRes: - return self._via_sdk( - req, sdk_models.ObjCreateReq, self._sdk.objects.create, tsi.ObjCreateRes - ) + return self._call_sdk(self._sdk.objects.create, req) @validate_call def obj_read(self, req: tsi.ObjReadReq) -> tsi.ObjReadRes: - return self._via_sdk( - req, sdk_models.ObjReadReq, self._sdk.objects.read, tsi.ObjReadRes - ) + return self._call_sdk(self._sdk.objects.read, req) @validate_call def objs_query(self, req: tsi.ObjQueryReq) -> tsi.ObjQueryRes: - return self._via_sdk( - req, sdk_models.ObjQueryReq, self._sdk.objects.query, tsi.ObjQueryRes - ) + return self._call_sdk(self._sdk.objects.query, req) def obj_delete(self, req: tsi.ObjDeleteReq) -> tsi.ObjDeleteRes: - return self._via_sdk( - req, sdk_models.ObjDeleteReq, self._sdk.objects.delete, tsi.ObjDeleteRes - ) + return self._call_sdk(self._sdk.objects.delete, req) def obj_add_tags(self, req: tsi.ObjAddTagsReq) -> tsi.ObjAddTagsRes: # SDK 0.0.1: objects.tags for add-tags is shadowed by the /tags list # overload of the same generated name. - body = sdk_models.ObjTagsBody(project_id=req.project_id, tags=req.tags) + body = tsi.ObjTagsBody(project_id=req.project_id, tags=req.tags) return self._raw_request( "PUT", OBJ_ADD_TAGS_PATH.format(object_id=req.object_id, digest=req.digest), @@ -798,7 +724,7 @@ def obj_add_tags(self, req: tsi.ObjAddTagsReq) -> tsi.ObjAddTagsRes: def obj_remove_tags(self, req: tsi.ObjRemoveTagsReq) -> tsi.ObjRemoveTagsRes: # SDK 0.0.1: objects.create_remove for remove-tags is shadowed by the # remove-aliases overload of the same generated name. - body = sdk_models.ObjTagsBody(project_id=req.project_id, tags=req.tags) + body = tsi.ObjTagsBody(project_id=req.project_id, tags=req.tags) return self._raw_request( "POST", OBJ_REMOVE_TAGS_PATH.format(object_id=req.object_id, digest=req.digest), @@ -807,49 +733,32 @@ def obj_remove_tags(self, req: tsi.ObjRemoveTagsReq) -> tsi.ObjRemoveTagsRes: ) def obj_set_aliases(self, req: tsi.ObjSetAliasesReq) -> tsi.ObjSetAliasesRes: - body = sdk_models.ObjSetAliasesBody( + body = tsi.ObjSetAliasesBody( project_id=req.project_id, digest=req.digest, aliases=req.aliases ) - res = self._via_sdk_no_body( - self._sdk.objects.update_aliases, - tsi.ObjSetAliasesRes, - body=body, - object_id=req.object_id, + return self._call_sdk( + self._sdk.objects.update_aliases, body, object_id=req.object_id ) - return res def obj_remove_aliases( self, req: tsi.ObjRemoveAliasesReq ) -> tsi.ObjRemoveAliasesRes: - body = sdk_models.ObjRemoveAliasesBody( - project_id=req.project_id, aliases=req.aliases - ) - return self._via_sdk_no_body( - self._sdk.objects.create_remove, - tsi.ObjRemoveAliasesRes, - body=body, - object_id=req.object_id, + body = tsi.ObjRemoveAliasesBody(project_id=req.project_id, aliases=req.aliases) + return self._call_sdk( + self._sdk.objects.create_remove, body, object_id=req.object_id ) def tags_list(self, req: tsi.TagsListReq) -> tsi.TagsListRes: - return self._via_sdk_no_body( - self._sdk.objects.tags, tsi.TagsListRes, project_id=req.project_id - ) + return self._call_sdk(self._sdk.objects.tags, project_id=req.project_id) def aliases_list(self, req: tsi.AliasesListReq) -> tsi.AliasesListRes: - return self._via_sdk_no_body( - self._sdk.objects.list_aliases, - tsi.AliasesListRes, - project_id=req.project_id, - ) + return self._call_sdk(self._sdk.objects.list_aliases, project_id=req.project_id) # ---- Table API ------------------------------------------------------------ @validate_call def table_create(self, req: tsi.TableCreateReq) -> tsi.TableCreateRes: - return self._via_sdk( - req, sdk_models.TableCreateReq, self._sdk.tables.create, tsi.TableCreateRes - ) + return self._call_sdk(self._sdk.tables.create, req) @validate_call def table_update(self, req: tsi.TableUpdateReq) -> tsi.TableUpdateRes: @@ -879,18 +788,11 @@ def table_update(self, req: tsi.TableUpdateReq) -> tsi.TableUpdateRes: digest=second_half_res.digest, updated_row_digests=all_digests ) else: - return self._via_sdk( - req, - sdk_models.TableUpdateReq, - self._sdk.tables.update, - tsi.TableUpdateRes, - ) + return self._call_sdk(self._sdk.tables.update, req) @validate_call def table_query(self, req: tsi.TableQueryReq) -> tsi.TableQueryRes: - return self._via_sdk( - req, sdk_models.TableQueryReq, self._sdk.tables.query, tsi.TableQueryRes - ) + return self._call_sdk(self._sdk.tables.query, req) @validate_call def table_query_stream( @@ -902,24 +804,14 @@ def table_query_stream( @validate_call def table_query_stats(self, req: tsi.TableQueryStatsReq) -> tsi.TableQueryStatsRes: - return self._via_sdk( - req, - sdk_models.TableQueryStatsReq, - self._sdk.tables.query_stats, - tsi.TableQueryStatsRes, - ) + return self._call_sdk(self._sdk.tables.query_stats, req) @validate_call def table_create_from_digests( self, req: tsi.TableCreateFromDigestsReq ) -> tsi.TableCreateFromDigestsRes: """Create a table by specifying row digests instead of actual rows.""" - return self._via_sdk( - req, - sdk_models.TableCreateFromDigestsReq, - self._sdk.tables.create_create_from_digests, - tsi.TableCreateFromDigestsRes, - ) + return self._call_sdk(self._sdk.tables.create_create_from_digests, req) def unretried_table_create_from_digests( self, req: tsi.TableCreateFromDigestsReq @@ -928,33 +820,17 @@ def unretried_table_create_from_digests( probe: a missing endpoint (404) must fail fast, and a flaky probe must not stall table saves behind retries. """ - return self._via_sdk.__wrapped__( # type: ignore[attr-defined] - self, - req, - sdk_models.TableCreateFromDigestsReq, - self._sdk.tables.create_create_from_digests, - tsi.TableCreateFromDigestsRes, - ) + return self._sdk.tables.create_create_from_digests(req) @validate_call def table_query_stats_batch( - self, req: tsi.TableQueryStatsReq - ) -> tsi.TableQueryStatsRes: - return self._via_sdk( - req, - sdk_models.TableQueryStatsBatchReq, - self._sdk.tables.create_query_stats_batch, - tsi.TableQueryStatsBatchRes, - ) + self, req: tsi.TableQueryStatsBatchReq + ) -> tsi.TableQueryStatsBatchRes: + return self._call_sdk(self._sdk.tables.create_query_stats_batch, req) @validate_call def refs_read_batch(self, req: tsi.RefsReadBatchReq) -> tsi.RefsReadBatchRes: - return self._via_sdk( - req, - sdk_models.RefsReadBatchReq, - self._sdk.refs.read_batch, - tsi.RefsReadBatchRes, - ) + return self._call_sdk(self._sdk.refs.read_batch, req) # ---- File API ------------------------------------------------------------- @@ -987,12 +863,7 @@ def file_content_read(self, req: tsi.FileContentReadReq) -> tsi.FileContentReadR return tsi.FileContentReadRes(content=bytes_buffer.getvalue()) def files_stats(self, req: tsi.FilesStatsReq) -> tsi.FilesStatsRes: - return self._via_sdk( - req, - sdk_models.FilesStatsReq, - self._sdk.files.query_stats, - tsi.FilesStatsRes, - ) + return self._call_sdk(self._sdk.files.query_stats, req) # ---- Feedback API ----------------------------------------------------------- @@ -1016,7 +887,10 @@ def feedback_create(self, req: tsi.FeedbackCreateReq) -> tsi.FeedbackCreateRes: # SDK 0.0.1: feedback.create for single create is shadowed by the # batch-create overload of the same generated name. return self._raw_request( - "POST", FEEDBACK_CREATE_PATH, req=req, res_type=tsi.FeedbackCreateRes + "POST", + FEEDBACK_CREATE_PATH, + req=req, + res_type=tsi.FeedbackCreateRes, ) def feedback_create_batch( @@ -1024,89 +898,43 @@ def feedback_create_batch( ) -> tsi.FeedbackCreateBatchRes: # Note: the SDK method is named `create` for /feedback/batch/create # (duplicate-name shadowing in 0.0.1; the batch overload won). - return self._via_sdk( - req, - sdk_models.FeedbackCreateBatchReq, - self._sdk.feedback.create, - tsi.FeedbackCreateBatchRes, - ) + return self._call_sdk(self._sdk.feedback.create, req) @validate_call def feedback_query(self, req: tsi.FeedbackQueryReq) -> tsi.FeedbackQueryRes: - return self._via_sdk( - req, - sdk_models.FeedbackQueryReq, - self._sdk.feedback.query, - tsi.FeedbackQueryRes, - ) + return self._call_sdk(self._sdk.feedback.query, req) @validate_call def feedback_purge(self, req: tsi.FeedbackPurgeReq) -> tsi.FeedbackPurgeRes: - return self._via_sdk( - req, - sdk_models.FeedbackPurgeReq, - self._sdk.feedback.purge, - tsi.FeedbackPurgeRes, - ) + return self._call_sdk(self._sdk.feedback.purge, req) @validate_call def feedback_replace(self, req: tsi.FeedbackReplaceReq) -> tsi.FeedbackReplaceRes: - return self._via_sdk( - req, - sdk_models.FeedbackReplaceReq, - self._sdk.feedback.replace, - tsi.FeedbackReplaceRes, - ) + return self._call_sdk(self._sdk.feedback.replace, req) @validate_call def feedback_stats(self, req: tsi.FeedbackStatsReq) -> tsi.FeedbackStatsRes: - return self._via_sdk( - req, - sdk_models.FeedbackStatsReq, - self._sdk.feedback.create_stats, - tsi.FeedbackStatsRes, - ) - - @validate_call - def feedback_aggregate( - self, req: tsi.FeedbackAggregateReq - ) -> tsi.FeedbackAggregateRes: - """Query the feedback table for aggregate scores over time.""" - # Not yet present in the published SDK. - return self._raw_request( - "POST", FEEDBACK_AGGREGATE_PATH, req=req, res_type=tsi.FeedbackAggregateRes - ) + return self._call_sdk(self._sdk.feedback.create_stats, req) @validate_call def feedback_payload_schema( self, req: tsi.FeedbackPayloadSchemaReq ) -> tsi.FeedbackPayloadSchemaRes: - return self._via_sdk( - req, - sdk_models.FeedbackPayloadSchemaReq, - self._sdk.feedback.create_payload_schema, - tsi.FeedbackPayloadSchemaRes, - ) + return self._call_sdk(self._sdk.feedback.create_payload_schema, req) # ---- Cost API --------------------------------------------------------------- @validate_call def cost_query(self, req: tsi.CostQueryReq) -> tsi.CostQueryRes: - return self._via_sdk( - req, sdk_models.CostQueryReq, self._sdk.costs.query, tsi.CostQueryRes - ) + return self._call_sdk(self._sdk.costs.query, req) @validate_call def cost_create(self, req: tsi.CostCreateReq) -> tsi.CostCreateRes: - return self._via_sdk( - req, sdk_models.CostCreateReq, self._sdk.costs.create, tsi.CostCreateRes - ) + return self._call_sdk(self._sdk.costs.create, req) @validate_call def cost_purge(self, req: tsi.CostPurgeReq) -> tsi.CostPurgeRes: - return self._via_sdk( - req, sdk_models.CostPurgeReq, self._sdk.costs.purge, tsi.CostPurgeRes - ) + return self._call_sdk(self._sdk.costs.purge, req) # ---- Execution APIs --------------------------------------------------------- @@ -1129,59 +957,14 @@ def completions_create_stream( def image_create( self, req: tsi.ImageGenerationCreateReq ) -> tsi.ImageGenerationCreateRes: - return self._via_sdk( - req, - sdk_models.ImageGenerationCreateReq, - self._sdk.images.create, - tsi.ImageGenerationCreateRes, - ) - - def project_stats(self, req: tsi.ProjectStatsReq) -> tsi.ProjectStatsRes: - # Excluded from the OpenAPI spec (include_in_schema=False). - return self._raw_request( - "POST", PROJECT_STATS_PATH, req=req, res_type=tsi.ProjectStatsRes - ) - - def project_ttl_settings_read( - self, req: tsi.ProjectTTLSettingsReadReq - ) -> tsi.ProjectTTLSettingsReadRes: - # Excluded from the OpenAPI spec (include_in_schema=False). - return self._raw_request( - "POST", - PROJECT_TTL_SETTINGS_READ_PATH, - req=req, - res_type=tsi.ProjectTTLSettingsReadRes, - ) - - def project_ttl_settings_update( - self, req: tsi.ProjectTTLSettingsUpdateReq - ) -> tsi.ProjectTTLSettingsUpdateRes: - # Excluded from the OpenAPI spec (include_in_schema=False). - return self._raw_request( - "POST", - PROJECT_TTL_SETTINGS_UPDATE_PATH, - req=req, - res_type=tsi.ProjectTTLSettingsUpdateRes, - ) - - def threads_query_stream( - self, req: tsi.ThreadsQueryReq - ) -> Iterator[tsi.ThreadSchema]: - return self._raw_stream( - "POST", THREADS_STREAM_QUERY_PATH, req=req, res_type=tsi.ThreadSchema - ) + return self._call_sdk(self._sdk.images.create, req) # ---- Annotation Queue API ----------------------------------------------------- def annotation_queue_create( self, req: tsi.AnnotationQueueCreateReq ) -> tsi.AnnotationQueueCreateRes: - return self._via_sdk( - req, - sdk_models.AnnotationQueueCreateReq, - self._sdk.annotation_queues.create_annotation_queues, - tsi.AnnotationQueueCreateRes, - ) + return self._call_sdk(self._sdk.annotation_queues.create_annotation_queues, req) def annotation_queues_query_stream( self, req: tsi.AnnotationQueuesQueryReq @@ -1196,9 +979,8 @@ def annotation_queues_query_stream( def annotation_queue_read( self, req: tsi.AnnotationQueueReadReq ) -> tsi.AnnotationQueueReadRes: - return self._via_sdk_no_body( + return self._call_sdk( self._sdk.annotation_queues.list_annotation_queues, - tsi.AnnotationQueueReadRes, queue_id=req.queue_id, project_id=req.project_id, ) @@ -1206,9 +988,8 @@ def annotation_queue_read( def annotation_queue_delete( self, req: tsi.AnnotationQueueDeleteReq ) -> tsi.AnnotationQueueDeleteRes: - return self._via_sdk_no_body( + return self._call_sdk( self._sdk.annotation_queues.delete_annotation_queues, - tsi.AnnotationQueueDeleteRes, queue_id=req.queue_id, project_id=req.project_id, ) @@ -1217,16 +998,12 @@ def annotation_queue_update( self, req: tsi.AnnotationQueueUpdateReq ) -> tsi.AnnotationQueueUpdateRes: # Body type excludes queue_id from the request body (it's in the URL path) - body = sdk_models.AnnotationQueueUpdateBody( - project_id=req.project_id, - name=req.name, - description=req.description, - scorer_refs=req.scorer_refs, + body = tsi.AnnotationQueueUpdateBody.model_validate( + req.model_dump(exclude={"queue_id"}) ) - return self._via_sdk_no_body( + return self._call_sdk( self._sdk.annotation_queues.update_annotation_queues, - tsi.AnnotationQueueUpdateRes, - body=body, + body, queue_id=req.queue_id, ) @@ -1234,59 +1011,28 @@ def annotation_queue_add_calls( self, req: tsi.AnnotationQueueAddCallsReq ) -> tsi.AnnotationQueueAddCallsRes: # Body type excludes queue_id from the request body (it's in the URL path) - body = sdk_models.AnnotationQueueAddCallsBody( - project_id=req.project_id, - call_ids=req.call_ids, - display_fields=req.display_fields, + body = tsi.AnnotationQueueAddCallsBody.model_validate( + req.model_dump(exclude={"queue_id"}) ) - return self._via_sdk_no_body( - self._sdk.annotation_queues.create_items, - tsi.AnnotationQueueAddCallsRes, - body=body, - queue_id=req.queue_id, + return self._call_sdk( + self._sdk.annotation_queues.create_items, body, queue_id=req.queue_id ) def annotation_queue_items_query( self, req: tsi.AnnotationQueueItemsQueryReq ) -> tsi.AnnotationQueueItemsQueryRes: # Body type excludes queue_id from the request body (it's in the URL path) - body = sdk_models.AnnotationQueueItemsQueryBody.model_validate( + body = tsi.AnnotationQueueItemsQueryBody.model_validate( req.model_dump(exclude={"queue_id"}, by_alias=True) ) - return self._via_sdk_no_body( - self._sdk.annotation_queues.query, - tsi.AnnotationQueueItemsQueryRes, - body=body, - queue_id=req.queue_id, + return self._call_sdk( + self._sdk.annotation_queues.query, body, queue_id=req.queue_id ) def annotation_queues_stats( self, req: tsi.AnnotationQueuesStatsReq ) -> tsi.AnnotationQueuesStatsRes: - return self._via_sdk( - req, - sdk_models.AnnotationQueuesStatsReq, - self._sdk.annotation_queues.create_stats, - tsi.AnnotationQueuesStatsRes, - ) - - def annotator_queue_items_progress_update( - self, req: tsi.AnnotatorQueueItemsProgressUpdateReq - ) -> tsi.AnnotatorQueueItemsProgressUpdateRes: - # Body type excludes queue_id, item_id, and wb_user_id from the request - # body (queue_id and item_id are in the URL path, wb_user_id is set - # server-side from auth) - body = sdk_models.AnnotationQueueItemProgressUpdateBody( - project_id=req.project_id, - annotation_state=req.annotation_state, - ) - return self._via_sdk_no_body( - self._sdk.annotation_queues.create_progress, - tsi.AnnotatorQueueItemsProgressUpdateRes, - body=body, - queue_id=req.queue_id, - item_id=req.item_id, - ) + return self._call_sdk(self._sdk.annotation_queues.create_stats, req) # ---- Server-side execution (not supported remotely) ---------------------------- @@ -1298,52 +1044,102 @@ def evaluation_status( ) -> tsi.EvaluationStatusRes: raise NotImplementedError("evaluation_status is not implemented") - def rescore(self, req: tsi.RescoreReq) -> tsi.RescoreRes: - raise NotImplementedError("rescore is not implemented") - def calls_score(self, req: tsi.CallsScoreReq) -> tsi.CallsScoreRes: raise NotImplementedError("calls_score is not implemented") - # ---- V2 APIs -------------------------------------------------------------- + @validate_call + def feedback_aggregate( + self, req: tsi.FeedbackAggregateReq + ) -> tsi.FeedbackAggregateRes: + """Query the feedback table for aggregate scores over time.""" + # Not yet present in the published SDK. + return self._raw_request( + "POST", FEEDBACK_AGGREGATE_PATH, req=req, res_type=tsi.FeedbackAggregateRes + ) + + def project_stats(self, req: tsi.ProjectStatsReq) -> tsi.ProjectStatsRes: + # Excluded from the OpenAPI spec (include_in_schema=False). + return self._raw_request( + "POST", PROJECT_STATS_PATH, req=req, res_type=tsi.ProjectStatsRes + ) + + def project_ttl_settings_read( + self, req: tsi.ProjectTTLSettingsReadReq + ) -> tsi.ProjectTTLSettingsReadRes: + # Excluded from the OpenAPI spec (include_in_schema=False). + return self._raw_request( + "POST", + PROJECT_TTL_SETTINGS_READ_PATH, + req=req, + res_type=tsi.ProjectTTLSettingsReadRes, + ) + + def project_ttl_settings_update( + self, req: tsi.ProjectTTLSettingsUpdateReq + ) -> tsi.ProjectTTLSettingsUpdateRes: + # Excluded from the OpenAPI spec (include_in_schema=False). + return self._raw_request( + "POST", + PROJECT_TTL_SETTINGS_UPDATE_PATH, + req=req, + res_type=tsi.ProjectTTLSettingsUpdateRes, + ) + + def threads_query_stream( + self, req: tsi.ThreadsQueryReq + ) -> Iterator[tsi.ThreadSchema]: + # Raw stream: the published SDK buffers jsonl, and it has no thread + # row model. + return self._raw_stream( + "POST", THREADS_STREAM_QUERY_PATH, req=req, res_type=tsi.ThreadSchema + ) + + def rescore(self, req: tsi.RescoreReq) -> Any: + raise NotImplementedError("rescore is not implemented") + + # ---- V2 object APIs -------------------------------------------------------- - def _v2_body_create( + def _v2_create( self, req: BaseModel, body_type: type[BaseModel], - sdk_method: Callable[..., Any], - res_type: type[TRes], - ) -> TRes: - """Create via a v2 endpoint: project_id moves to the path; the rest is body.""" + sdk_method: Any, + ) -> Any: + """Create via a v2 endpoint: project_id moves to the URL path.""" entity, project = from_project_id(req.project_id) # type: ignore[attr-defined] body = body_type.model_validate(req.model_dump(exclude={"project_id"})) - return self._via_sdk_no_body( - sdk_method, res_type, body=body, entity=entity, project=project - ) + return self._call_sdk(sdk_method, body, entity=entity, project=project) def _v2_list_stream( self, - req: BaseModel, + project_id: str, res_type: type[TRes], kind: str, params: dict[str, Any], ) -> Iterator[TRes]: """Stream a v2 jsonl list endpoint (the SDK buffers jsonl responses).""" - entity, project = from_project_id(req.project_id) # type: ignore[attr-defined] + entity, project = from_project_id(project_id) url = f"/v2/{entity}/{project}/{kind}" return self._raw_stream("GET", url, params=params, res_type=res_type) + @staticmethod + def _v2_list_params(req: Any) -> dict[str, Any]: + params: dict[str, Any] = {} + if req.limit is not None: + params["limit"] = req.limit + if req.offset is not None: + params["offset"] = req.offset + return params + @validate_call def op_create(self, req: tsi.OpCreateReq) -> tsi.OpCreateRes: - return self._v2_body_create( - req, sdk_models.OpCreateBody, self._sdk.v2_ops.create, tsi.OpCreateRes - ) + return self._v2_create(req, tsi.OpCreateBody, self._sdk.v2_ops.create) @validate_call def op_read(self, req: tsi.OpReadReq) -> tsi.OpReadRes: entity, project = from_project_id(req.project_id) - return self._via_sdk_no_body( + return self._call_sdk( self._sdk.v2_ops.read, - tsi.OpReadRes, entity=entity, project=project, object_id=req.object_id, @@ -1352,22 +1148,15 @@ def op_read(self, req: tsi.OpReadReq) -> tsi.OpReadRes: @validate_call def op_list(self, req: tsi.OpListReq) -> Iterator[tsi.OpReadRes]: - params: dict[str, Any] = {} - if req.limit is not None: - params["limit"] = req.limit - if req.offset is not None: - params["offset"] = req.offset - # `eager` is missing from the SDK's generated v2_ops.list signature. - if req.eager: - params["eager"] = "true" - return self._v2_list_stream(req, tsi.OpReadRes, "ops", params) + return self._v2_list_stream( + req.project_id, tsi.OpReadRes, "ops", self._v2_list_params(req) + ) @validate_call def op_delete(self, req: tsi.OpDeleteReq) -> tsi.OpDeleteRes: entity, project = from_project_id(req.project_id) - return self._via_sdk_no_body( + return self._call_sdk( self._sdk.v2_ops.delete, - tsi.OpDeleteRes, entity=entity, project=project, object_id=req.object_id, @@ -1376,19 +1165,13 @@ def op_delete(self, req: tsi.OpDeleteReq) -> tsi.OpDeleteRes: @validate_call def dataset_create(self, req: tsi.DatasetCreateReq) -> tsi.DatasetCreateRes: - return self._v2_body_create( - req, - sdk_models.DatasetCreateBody, - self._sdk.v2_datasets.create, - tsi.DatasetCreateRes, - ) + return self._v2_create(req, tsi.DatasetCreateBody, self._sdk.v2_datasets.create) @validate_call def dataset_read(self, req: tsi.DatasetReadReq) -> tsi.DatasetReadRes: entity, project = from_project_id(req.project_id) - return self._via_sdk_no_body( + return self._call_sdk( self._sdk.v2_datasets.read, - tsi.DatasetReadRes, entity=entity, project=project, object_id=req.object_id, @@ -1397,19 +1180,18 @@ def dataset_read(self, req: tsi.DatasetReadReq) -> tsi.DatasetReadRes: @validate_call def dataset_list(self, req: tsi.DatasetListReq) -> Iterator[tsi.DatasetReadRes]: - params: dict[str, Any] = {} - if req.limit is not None: - params["limit"] = req.limit - if req.offset is not None: - params["offset"] = req.offset - return self._v2_list_stream(req, tsi.DatasetReadRes, "datasets", params) + return self._v2_list_stream( + req.project_id, + tsi.DatasetReadRes, + "datasets", + self._v2_list_params(req), + ) @validate_call def dataset_delete(self, req: tsi.DatasetDeleteReq) -> tsi.DatasetDeleteRes: entity, project = from_project_id(req.project_id) - return self._via_sdk_no_body( + return self._call_sdk( self._sdk.v2_datasets.delete, - tsi.DatasetDeleteRes, entity=entity, project=project, object_id=req.object_id, @@ -1418,19 +1200,13 @@ def dataset_delete(self, req: tsi.DatasetDeleteReq) -> tsi.DatasetDeleteRes: @validate_call def scorer_create(self, req: tsi.ScorerCreateReq) -> tsi.ScorerCreateRes: - return self._v2_body_create( - req, - sdk_models.ScorerCreateBody, - self._sdk.v2_scorers.create, - tsi.ScorerCreateRes, - ) + return self._v2_create(req, tsi.ScorerCreateBody, self._sdk.v2_scorers.create) @validate_call def scorer_read(self, req: tsi.ScorerReadReq) -> tsi.ScorerReadRes: entity, project = from_project_id(req.project_id) - return self._via_sdk_no_body( + return self._call_sdk( self._sdk.v2_scorers.read, - tsi.ScorerReadRes, entity=entity, project=project, object_id=req.object_id, @@ -1439,19 +1215,18 @@ def scorer_read(self, req: tsi.ScorerReadReq) -> tsi.ScorerReadRes: @validate_call def scorer_list(self, req: tsi.ScorerListReq) -> Iterator[tsi.ScorerReadRes]: - params: dict[str, Any] = {} - if req.limit is not None: - params["limit"] = req.limit - if req.offset is not None: - params["offset"] = req.offset - return self._v2_list_stream(req, tsi.ScorerReadRes, "scorers", params) + return self._v2_list_stream( + req.project_id, + tsi.ScorerReadRes, + "scorers", + self._v2_list_params(req), + ) @validate_call def scorer_delete(self, req: tsi.ScorerDeleteReq) -> tsi.ScorerDeleteRes: entity, project = from_project_id(req.project_id) - return self._via_sdk_no_body( + return self._call_sdk( self._sdk.v2_scorers.delete, - tsi.ScorerDeleteRes, entity=entity, project=project, object_id=req.object_id, @@ -1462,19 +1237,15 @@ def scorer_delete(self, req: tsi.ScorerDeleteReq) -> tsi.ScorerDeleteRes: def evaluation_create( self, req: tsi.EvaluationCreateReq ) -> tsi.EvaluationCreateRes: - return self._v2_body_create( - req, - sdk_models.EvaluationCreateBody, - self._sdk.v2_evaluations.create, - tsi.EvaluationCreateRes, + return self._v2_create( + req, tsi.EvaluationCreateBody, self._sdk.v2_evaluations.create ) @validate_call def evaluation_read(self, req: tsi.EvaluationReadReq) -> tsi.EvaluationReadRes: entity, project = from_project_id(req.project_id) - return self._via_sdk_no_body( + return self._call_sdk( self._sdk.v2_evaluations.read, - tsi.EvaluationReadRes, entity=entity, project=project, object_id=req.object_id, @@ -1485,44 +1256,35 @@ def evaluation_read(self, req: tsi.EvaluationReadReq) -> tsi.EvaluationReadRes: def evaluation_list( self, req: tsi.EvaluationListReq ) -> Iterator[tsi.EvaluationReadRes]: - params: dict[str, Any] = {} - if req.limit is not None: - params["limit"] = req.limit - if req.offset is not None: - params["offset"] = req.offset - return self._v2_list_stream(req, tsi.EvaluationReadRes, "evaluations", params) + return self._v2_list_stream( + req.project_id, + tsi.EvaluationReadRes, + "evaluations", + self._v2_list_params(req), + ) @validate_call def evaluation_delete( self, req: tsi.EvaluationDeleteReq ) -> tsi.EvaluationDeleteRes: entity, project = from_project_id(req.project_id) - return self._via_sdk_no_body( + return self._call_sdk( self._sdk.v2_evaluations.delete, - tsi.EvaluationDeleteRes, entity=entity, project=project, object_id=req.object_id, digests=req.digests, ) - # ---- Model V2 API ----------------------------------------------------------- - @validate_call def model_create(self, req: tsi.ModelCreateReq) -> tsi.ModelCreateRes: - return self._v2_body_create( - req, - sdk_models.ModelCreateBody, - self._sdk.v2_models.create, - tsi.ModelCreateRes, - ) + return self._v2_create(req, tsi.ModelCreateBody, self._sdk.v2_models.create) @validate_call def model_read(self, req: tsi.ModelReadReq) -> tsi.ModelReadRes: entity, project = from_project_id(req.project_id) - return self._via_sdk_no_body( + return self._call_sdk( self._sdk.v2_models.read, - tsi.ModelReadRes, entity=entity, project=project, object_id=req.object_id, @@ -1531,36 +1293,32 @@ def model_read(self, req: tsi.ModelReadReq) -> tsi.ModelReadRes: @validate_call def model_list(self, req: tsi.ModelListReq) -> Iterator[tsi.ModelReadRes]: - params: dict[str, Any] = {} - if req.limit is not None: - params["limit"] = req.limit - if req.offset is not None: - params["offset"] = req.offset - return self._v2_list_stream(req, tsi.ModelReadRes, "models", params) + return self._v2_list_stream( + req.project_id, + tsi.ModelReadRes, + "models", + self._v2_list_params(req), + ) @validate_call def model_delete(self, req: tsi.ModelDeleteReq) -> tsi.ModelDeleteRes: entity, project = from_project_id(req.project_id) - return self._via_sdk_no_body( + return self._call_sdk( self._sdk.v2_models.delete, - tsi.ModelDeleteRes, entity=entity, project=project, object_id=req.object_id, digests=req.digests, ) - # ---- Evaluation Run V2 API ---------------------------------------------------- - @validate_call def evaluation_run_create( self, req: tsi.EvaluationRunCreateReq ) -> tsi.EvaluationRunCreateRes: - return self._v2_body_create( + return self._v2_create( req, - sdk_models.EvaluationRunCreateBody, + tsi.EvaluationRunCreateBody, self._sdk.v2_evaluation_runs.create, - tsi.EvaluationRunCreateRes, ) @validate_call @@ -1568,9 +1326,8 @@ def evaluation_run_read( self, req: tsi.EvaluationRunReadReq ) -> tsi.EvaluationRunReadRes: entity, project = from_project_id(req.project_id) - return self._via_sdk_no_body( + return self._call_sdk( self._sdk.v2_evaluation_runs.read, - tsi.EvaluationRunReadRes, entity=entity, project=project, evaluation_run_id=req.evaluation_run_id, @@ -1582,11 +1339,7 @@ def evaluation_run_list( ) -> Iterator[tsi.EvaluationRunReadRes]: # Raw: the SDK's generated list signature renames the filter params # (evaluations vs evaluation_refs), so use the wire format directly. - params: dict[str, Any] = {} - if req.limit is not None: - params["limit"] = req.limit - if req.offset is not None: - params["offset"] = req.offset + params = self._v2_list_params(req) if req.filter: if req.filter.evaluations: params["evaluation_refs"] = ",".join(req.filter.evaluations) @@ -1595,7 +1348,7 @@ def evaluation_run_list( if req.filter.evaluation_run_ids: params["evaluation_run_ids"] = ",".join(req.filter.evaluation_run_ids) return self._v2_list_stream( - req, tsi.EvaluationRunReadRes, "evaluation_runs", params + req.project_id, tsi.EvaluationRunReadRes, "evaluation_runs", params ) @validate_call @@ -1603,9 +1356,8 @@ def evaluation_run_delete( self, req: tsi.EvaluationRunDeleteReq ) -> tsi.EvaluationRunDeleteRes: entity, project = from_project_id(req.project_id) - return self._via_sdk_no_body( + return self._call_sdk( self._sdk.v2_evaluation_runs.delete, - tsi.EvaluationRunDeleteRes, entity=entity, project=project, evaluation_run_ids=req.evaluation_run_ids, @@ -1616,37 +1368,30 @@ def evaluation_run_finish( self, req: tsi.EvaluationRunFinishReq ) -> tsi.EvaluationRunFinishRes: entity, project = from_project_id(req.project_id) - body = sdk_models.EvaluationRunFinishBody.model_validate( + body = tsi.EvaluationRunFinishBody.model_validate( req.model_dump(exclude={"project_id", "evaluation_run_id"}) ) - return self._via_sdk_no_body( + return self._call_sdk( self._sdk.v2_evaluation_runs.finish, - tsi.EvaluationRunFinishRes, - body=body, + body, entity=entity, project=project, evaluation_run_id=req.evaluation_run_id, ) - # ---- Prediction V2 API ---------------------------------------------------------- - @validate_call def prediction_create( self, req: tsi.PredictionCreateReq ) -> tsi.PredictionCreateRes: - return self._v2_body_create( - req, - sdk_models.PredictionCreateBody, - self._sdk.v2_predictions.create, - tsi.PredictionCreateRes, + return self._v2_create( + req, tsi.PredictionCreateBody, self._sdk.v2_predictions.create ) @validate_call def prediction_read(self, req: tsi.PredictionReadReq) -> tsi.PredictionReadRes: entity, project = from_project_id(req.project_id) - return self._via_sdk_no_body( + return self._call_sdk( self._sdk.v2_predictions.read, - tsi.PredictionReadRes, entity=entity, project=project, prediction_id=req.prediction_id, @@ -1656,23 +1401,20 @@ def prediction_read(self, req: tsi.PredictionReadReq) -> tsi.PredictionReadRes: def prediction_list( self, req: tsi.PredictionListReq ) -> Iterator[tsi.PredictionReadRes]: - params: dict[str, Any] = {} + params = self._v2_list_params(req) if req.evaluation_run_id is not None: params["evaluation_run_id"] = req.evaluation_run_id - if req.limit is not None: - params["limit"] = req.limit - if req.offset is not None: - params["offset"] = req.offset - return self._v2_list_stream(req, tsi.PredictionReadRes, "predictions", params) + return self._v2_list_stream( + req.project_id, tsi.PredictionReadRes, "predictions", params + ) @validate_call def prediction_delete( self, req: tsi.PredictionDeleteReq ) -> tsi.PredictionDeleteRes: entity, project = from_project_id(req.project_id) - return self._via_sdk_no_body( + return self._call_sdk( self._sdk.v2_predictions.delete, - tsi.PredictionDeleteRes, entity=entity, project=project, prediction_ids=req.prediction_ids, @@ -1683,31 +1425,22 @@ def prediction_finish( self, req: tsi.PredictionFinishReq ) -> tsi.PredictionFinishRes: entity, project = from_project_id(req.project_id) - return self._via_sdk_no_body( + return self._call_sdk( self._sdk.v2_predictions.finish, - tsi.PredictionFinishRes, entity=entity, project=project, prediction_id=req.prediction_id, ) - # ---- Score V2 API --------------------------------------------------------------- - @validate_call def score_create(self, req: tsi.ScoreCreateReq) -> tsi.ScoreCreateRes: - return self._v2_body_create( - req, - sdk_models.ScoreCreateBody, - self._sdk.v2_scores.create, - tsi.ScoreCreateRes, - ) + return self._v2_create(req, tsi.ScoreCreateBody, self._sdk.v2_scores.create) @validate_call def score_read(self, req: tsi.ScoreReadReq) -> tsi.ScoreReadRes: entity, project = from_project_id(req.project_id) - return self._via_sdk_no_body( + return self._call_sdk( self._sdk.v2_scores.read, - tsi.ScoreReadRes, entity=entity, project=project, score_id=req.score_id, @@ -1715,21 +1448,16 @@ def score_read(self, req: tsi.ScoreReadReq) -> tsi.ScoreReadRes: @validate_call def score_list(self, req: tsi.ScoreListReq) -> Iterator[tsi.ScoreReadRes]: - params: dict[str, Any] = {} + params = self._v2_list_params(req) if req.evaluation_run_id is not None: params["evaluation_run_id"] = req.evaluation_run_id - if req.limit is not None: - params["limit"] = req.limit - if req.offset is not None: - params["offset"] = req.offset - return self._v2_list_stream(req, tsi.ScoreReadRes, "scores", params) + return self._v2_list_stream(req.project_id, tsi.ScoreReadRes, "scores", params) @validate_call def score_delete(self, req: tsi.ScoreDeleteReq) -> tsi.ScoreDeleteRes: entity, project = from_project_id(req.project_id) - return self._via_sdk_no_body( + return self._call_sdk( self._sdk.v2_scores.delete, - tsi.ScoreDeleteRes, entity=entity, project=project, score_ids=req.score_ids,