diff --git a/pyproject.toml b/pyproject.toml index edff8675da93..631721bde608 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -58,6 +58,11 @@ dependencies = [ # (gen_ai.tool.call.arguments, cache token names, # gen_ai.usage.reasoning.output_tokens) added in 0.63b0. "opentelemetry-semantic-conventions>=0.63b0", + + # Generated trace-server client (source of truth for the API types). + # TEMPORARY: resolved from test PyPI via [tool.uv.sources] below until the + # package is published to real PyPI. + "weave-server-sdk==0.0.1", ] [project.optional-dependencies] @@ -233,6 +238,7 @@ select = [ "FIX003", # https://docs.astral.sh/ruff/rules/line-contains-xxx/ "I", # https://docs.astral.sh/ruff/rules/#isort-i "W", # https://docs.astral.sh/ruff/rules/#warning-w + "TID251", # banned-api: enforces the client/server import boundary (see flake8-tidy-imports below) "TID252", # https://docs.astral.sh/ruff/rules/relative-imports/#relative-imports-tid252 "UP", # https://docs.astral.sh/ruff/rules/#pyupgrade-up "TRY", # https://docs.astral.sh/ruff/rules/#tryceratops-try @@ -318,12 +324,25 @@ line-length = 88 show-fixes = true exclude = ["rules"] +[tool.ruff.lint.flake8-tidy-imports.banned-api] +# Client code must not depend on the trace server's interface modules; API +# types come from weave_server_sdk (generated from the OpenAPI spec). The ban +# list grows as the migration proceeds. The server package itself, tests, and +# scripts are exempt via per-file-ignores below. +"weave.trace_server.http_service_interface".msg = "Client code must not use the server's HTTP body models. Use weave_server_sdk.models (generated from the OpenAPI spec) instead." + [tool.ruff.lint.per-file-ignores] "!/weave/trace/**/*.py" = ["T201"] "!/tests/**/*.py" = ["RUF059"] -# Tests intentionally use async functions without await, compare known float values, -# and use pytest.raises match strings with regex-like characters. 
-"tests/**/*.py" = ["RUF029", "RUF043", "RUF069"] +# The trace server may import its own modules; tests exercise the server +# directly; scripts and the mock server are operational server tooling. +"weave/trace_server/**/*.py" = ["TID251"] +"scripts/**/*.py" = ["TID251"] +"trace_server_mock/**/*.py" = ["TID251"] +# Tests intentionally use async functions without await, compare known float +# values, use pytest.raises match strings with regex-like characters, and may +# exercise the trace server directly (TID251). +"tests/**/*.py" = ["RUF029", "RUF043", "RUF069", "TID251"] "weave/trace/serialization/op_type.py" = ["RUF100", "N802"] "weave/trace_server/costs/update_costs.py" = ["PLW1514"] "weave/type_handlers/Video/video.py" = ["PLW0603"] @@ -526,6 +545,17 @@ conflicts = [ ], ] +# TEMPORARY: weave-server-sdk is only published to test PyPI so far. Resolve +# just that package from the test index; everything else stays on real PyPI. +# Remove this (and the index below) once it is published to real PyPI. +[tool.uv.sources] +weave-server-sdk = { index = "testpypi" } + +[[tool.uv.index]] +name = "testpypi" +url = "https://test.pypi.org/simple/" +explicit = true + # uv workspace members. Each listed path is a sibling Python package that uv # manages alongside the root `weave` project. Members can depend on each # other and on the root by name; uv resolves those to the local workspace diff --git a/tests/conftest.py b/tests/conftest.py index 249fcab7a144..2e1d0115cfbf 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -7,6 +7,7 @@ from typing import Any from unittest.mock import MagicMock, patch +import httpx import pytest from fastapi import FastAPI from fastapi.responses import StreamingResponse @@ -597,11 +598,12 @@ def client( @pytest.fixture def network_proxy_client(client, monkeypatch): - """This fixture is used to test the `RemoteHTTPTraceServer` class. 
There is - almost no logic in this class, other than a little batching, so we typically - skip it for simplicity. However, we can use this fixture to test such logic. - It initializes a mini FastAPI app that proxies requests from the - `RemoteHTTPTraceServer` to the underlying `client.server` object. + """This fixture is used to test the `RemoteHTTPTraceServer` class. + There is almost no logic in this class, other than a little batching, so we + typically skip it for simplicity. However, we can use this fixture to test + such logic. It initializes a mini FastAPI app and routes the server's HTTP + transport into it, proxying requests to the underlying `client.server` + object. We probably will want to flesh this out more in the future, but this is a starting point. @@ -707,12 +709,25 @@ def obj_read(req: tsi.ObjReadReq) -> tsi.ObjReadRes: with TestClient(app) as c: - def post(url, data=None, json=None, **kwargs): - kwargs.pop("stream", None) - return c.post(url, data=data, json=json, **kwargs) - - orig_post = weave.utils.http_requests.post - weave.utils.http_requests.post = post + class TestClientTransport(httpx.BaseTransport): + """Routes the server's httpx requests into the FastAPI TestClient.""" + + def handle_request(self, request: httpx.Request) -> httpx.Response: + request.read() + resp = c.request( + request.method, + request.url.path, + params=request.url.params, + content=request.content, + headers={ + k: v + for k, v in request.headers.items() + if k.lower() not in {"host", "content-length"} + }, + ) + return httpx.Response( + resp.status_code, headers=resp.headers, content=resp.content + ) def make_fast_async_batch_processor(*args, **kwargs): kwargs.setdefault("min_batch_interval", 0) @@ -733,14 +748,15 @@ def make_fast_call_batch_processor(*args, **kwargs): make_fast_call_batch_processor, ) + # Absolute base URL required for httpx cookie handling; the transport + # routes by path, so the host is never dialed. 
remote_client = RemoteHTTPTraceServer( - trace_server_url="", + trace_server_url="http://testserver", should_batch=True, + transport=TestClientTransport(), ) yield (client, remote_client, records) - weave.utils.http_requests.post = orig_post - @pytest.fixture(autouse=True) def caching_client_isolation(monkeypatch, tmp_path): diff --git a/tests/trace_server_bindings/conftest.py b/tests/trace_server_bindings/conftest.py index 2a810877c675..1548f25acb6e 100644 --- a/tests/trace_server_bindings/conftest.py +++ b/tests/trace_server_bindings/conftest.py @@ -2,6 +2,7 @@ from types import MethodType from unittest.mock import MagicMock +import httpx import pytest import tenacity @@ -59,24 +60,60 @@ def generate_call_start_end_pair( return tsi.CallStartReq(start=start), tsi.CallEndReq(end=end) +# ============================================================================= +# HTTP transport spy +# ============================================================================= + + +class SpyTransport(httpx.BaseTransport): + """httpx transport that records requests and replays queued responses. + + Queue items may be ``httpx.Response`` objects or exceptions to raise. + When the queue is empty, returns ``default_response`` (200 ``{}`` unless + overridden). 
+ """ + + def __init__( + self, + *items: httpx.Response | Exception, + default_response: httpx.Response | None = None, + ) -> None: + self.queue: list[httpx.Response | Exception] = list(items) + self.requests: list[httpx.Request] = [] + self.default_response = default_response + + def handle_request(self, request: httpx.Request) -> httpx.Response: + request.read() + self.requests.append(request) + if self.queue: + item = self.queue.pop(0) + if isinstance(item, Exception): + raise item + return item + if self.default_response is not None: + return self.default_response + return httpx.Response(200, json={}) + + @property + def urls(self) -> list[str]: + return [str(r.url) for r in self.requests] + + # ============================================================================= # Fixtures # ============================================================================= @pytest.fixture -def success_response(): - """Common fixture for mocking a successful HTTP response.""" - response = MagicMock() - response.status_code = 200 - response.json.return_value = {"id": "test_id", "trace_id": "test_trace_id"} - return response +def server_class(): + """The remote trace server implementation under test.""" + return RemoteHTTPTraceServer @pytest.fixture -def server(request): - """Common server fixture configured by the indirect parameter.""" - server_ = RemoteHTTPTraceServer("http://example.com", should_batch=True) +def server(request, server_class): + """Common server fixture parametrized by batching/retry behavior.""" + server_ = server_class("http://example.com", should_batch=True) if request.param == "normal": server_._send_batch_to_server = MagicMock() diff --git a/tests/trace_server_bindings/test_http_behavior.py b/tests/trace_server_bindings/test_http_behavior.py index 66ae13d2a87a..b78abd66e9b6 100644 --- a/tests/trace_server_bindings/test_http_behavior.py +++ b/tests/trace_server_bindings/test_http_behavior.py @@ -1,7 +1,11 @@ """HTTP behavior tests for 
RemoteHTTPTraceServer. These tests verify HTTP request/response handling, retry behavior for various -status codes, and error handling specific to RemoteHTTPTraceServer. +status codes, and error handling of the remote trace server binding. + +Mocking happens at the httpx transport boundary (the external seam), so the +full stack — SDK encode/decode, event hooks, error translation, retry +predicates — is exercised for real. """ from __future__ import annotations @@ -10,7 +14,6 @@ import json import logging from types import MethodType -from unittest.mock import MagicMock, patch import httpx import pytest @@ -18,6 +21,7 @@ from pydantic import ValidationError from tests.trace_server_bindings.conftest import ( + SpyTransport, generate_end, generate_id, generate_start, @@ -28,6 +32,7 @@ from weave.trace_server_bindings.call_batch_processor import CallBatchProcessor from weave.trace_server_bindings.http_utils import ( ERROR_CODE_CALLS_COMPLETE_MODE_REQUIRED, + CallsCompleteModeRequired, ) from weave.trace_server_bindings.models import ( CompleteBatchItem, @@ -38,6 +43,31 @@ RemoteHTTPTraceServer, ) +BASE_URL = "http://example.com" + + +def make_server( + transport: httpx.BaseTransport, + should_batch: bool = False, + **kwargs, +) -> RemoteHTTPTraceServer: + return RemoteHTTPTraceServer( + BASE_URL, should_batch=should_batch, transport=transport, **kwargs + ) + + +def shutdown(server: RemoteHTTPTraceServer) -> None: + if server.call_processor and server.call_processor.is_accepting_new_work(): + server.call_processor.stop_accepting_new_work_and_flush_queue() + if server.feedback_processor and server.feedback_processor.is_accepting_new_work(): + server.feedback_processor.stop_accepting_new_work_and_flush_queue() + + +def call_start_ok_response(call_id: str) -> httpx.Response: + return httpx.Response( + 200, json=tsi.CallStartRes(id=call_id, trace_id="test_trace_id").model_dump() + ) + def make_calls_complete_required_response() -> httpx.Response: """Create a 400 response 
indicating the project requires calls_complete mode.""" @@ -47,58 +77,175 @@ def make_calls_complete_required_response() -> httpx.Response: "error_code": ERROR_CODE_CALLS_COMPLETE_MODE_REQUIRED, "message": "Project requires calls_complete mode", }, - request=httpx.Request("POST", "http://example.com/call/upsert_batch"), ) -@pytest.fixture -def unbatched_server(): - """Create a RemoteHTTPTraceServer instance without batching for testing.""" - return RemoteHTTPTraceServer("http://example.com") - - -@patch("weave.utils.http_requests.post") -def test_call_start_ok(mock_post, unbatched_server): +def test_call_start_ok(): """Test successful call_start request.""" call_id = generate_id() - mock_response = httpx.Response( - 200, - json=dict(tsi.CallStartRes(id=call_id, trace_id="test_trace_id")), - request=httpx.Request("POST", "http://test.com"), - ) - mock_post.return_value = mock_response + transport = SpyTransport(call_start_ok_response(call_id)) + server = make_server(transport) + start = generate_start(call_id) - unbatched_server.call_start(tsi.CallStartReq(start=start)) - mock_post.assert_called_once() + result = server.call_start(tsi.CallStartReq(start=start)) + + assert transport.urls == [f"{BASE_URL}/call/start"] + sent = json.loads(transport.requests[0].content) + assert sent["start"]["id"] == call_id + assert result.id == call_id + assert result.trace_id == "test_trace_id" -@patch("weave.utils.http_requests.post") -def test_400_no_retry(mock_post, unbatched_server): +def test_400_no_retry(): """Test that 400 errors are not retried.""" call_id = generate_id() - resp1 = httpx.Response( - 400, - json=dict(tsi.CallStartRes(id=call_id, trace_id="test_trace_id")), - request=httpx.Request("POST", "http://test.com"), - ) - mock_post.side_effect = [resp1] + transport = SpyTransport(httpx.Response(400, json={"error": "Bad Request"})) + server = make_server(transport) start = generate_start(call_id) with pytest.raises(httpx.HTTPStatusError): - 
unbatched_server.call_start(tsi.CallStartReq(start=start)) + server.call_start(tsi.CallStartReq(start=start)) + + # Should only be called once (no retry for 400) + assert len(transport.requests) == 1 -def test_invalid_no_retry(unbatched_server): +def test_invalid_no_retry(): """Test that validation errors are not retried.""" + transport = SpyTransport() + server = make_server(transport) with pytest.raises(ValidationError): - unbatched_server.call_start(tsi.CallStartReq(start={"invalid": "broken"})) + server.call_start(tsi.CallStartReq(start={"invalid": "broken"})) + assert len(transport.requests) == 0 + + +def test_500_502_503_504_429_retry(monkeypatch): + """Test that 5xx and 429 errors are retried.""" + monkeypatch.setenv("WEAVE_RETRY_MAX_ATTEMPTS", "6") + monkeypatch.setenv("WEAVE_RETRY_MAX_INTERVAL", "0.1") + call_id = generate_id() + + transport = SpyTransport( + httpx.Response(500), + httpx.Response(502), + httpx.Response(503), + httpx.Response(504), + httpx.Response(429), + call_start_ok_response(call_id), + ) + server = make_server(transport) + + start = generate_start(call_id) + result = server.call_start(tsi.CallStartReq(start=start)) + assert result.id == call_id + assert len(transport.requests) == 6 + + +def test_other_error_retry(monkeypatch): + """Test that connection errors are retried.""" + monkeypatch.setenv("WEAVE_RETRY_MAX_ATTEMPTS", "6") + monkeypatch.setenv("WEAVE_RETRY_MAX_INTERVAL", "0.1") + call_id = generate_id() + + transport = SpyTransport( + ConnectionResetError(), + ConnectionError(), + OSError(), + TimeoutError(), + call_start_ok_response(call_id), + ) + server = make_server(transport) + + start = generate_start(call_id) + result = server.call_start(tsi.CallStartReq(start=start)) + assert result.id == call_id -@patch("weave.utils.http_requests.post") -def test_calls_complete_batch_endpoint_and_payload(mock_post, monkeypatch): +def test_retry_id_header_injected(monkeypatch): + """Every attempt of a request carries the same X-Weave-Retry-Id 
header.""" + monkeypatch.setenv("WEAVE_RETRY_MAX_ATTEMPTS", "3") + monkeypatch.setenv("WEAVE_RETRY_MAX_INTERVAL", "0.1") + call_id = generate_id() + transport = SpyTransport(httpx.Response(500), call_start_ok_response(call_id)) + server = make_server(transport) + + server.call_start(tsi.CallStartReq(start=generate_start(call_id))) + + retry_ids = [r.headers.get("X-Weave-Retry-Id") for r in transport.requests] + assert all(retry_ids) + # Same logical request, so both attempts share one retry id + assert len(set(retry_ids)) == 1 + + +def test_extra_headers_and_auth_are_sent(): + """Constructor extra_headers and auth flow into every request.""" + call_id = generate_id() + transport = SpyTransport(call_start_ok_response(call_id)) + server = make_server( + transport, auth=("api", "secret-key"), extra_headers={"X-Custom": "yes"} + ) + + server.call_start(tsi.CallStartReq(start=generate_start(call_id))) + + request = transport.requests[0] + assert request.headers["X-Custom"] == "yes" + assert request.headers["Authorization"].startswith("Basic ") + + +def test_set_auth_applies_to_subsequent_requests(): + """set_auth after construction updates the live client.""" + call_id = generate_id() + transport = SpyTransport(default_response=None) + transport.default_response = call_start_ok_response(call_id) + server = make_server(transport) + + server.call_start(tsi.CallStartReq(start=generate_start(call_id))) + assert "Authorization" not in transport.requests[0].headers + + server.set_auth(("api", "secret-key")) + server.call_start(tsi.CallStartReq(start=generate_start(call_id))) + assert transport.requests[1].headers["Authorization"].startswith("Basic ") + + +def test_typed_sdk_route_obj_read(): + """A typed SDK route hits the right path and converts the response.""" + transport = SpyTransport( + httpx.Response( + 200, + json={ + "obj": { + "project_id": "entity/project", + "object_id": "my-obj", + "created_at": "2024-01-01T00:00:00Z", + "deleted_at": None, + "digest": "abc", + 
"version_index": 0, + "is_latest": 1, + "kind": "object", + "base_object_class": None, + "leaf_object_class": None, + "val": {"a": 1}, + } + }, + ) + ) + server = make_server(transport) + + res = server.obj_read( + tsi.ObjReadReq(project_id="entity/project", object_id="my-obj", digest="abc") + ) + + assert transport.urls == [f"{BASE_URL}/obj/read"] + assert isinstance(res, tsi.ObjReadRes) + assert res.obj.object_id == "my-obj" + assert res.obj.val == {"a": 1} + + +def test_calls_complete_batch_endpoint_and_payload(monkeypatch): """Send calls_complete batches to the v2 endpoint with correct payload.""" monkeypatch.setenv("WEAVE_USE_CALLS_COMPLETE", "true") - server = RemoteHTTPTraceServer("http://example.com", should_batch=True) + transport = SpyTransport() + server = make_server(transport, should_batch=True) complete = tsi.CompletedCallSchemaForInsert( project_id="entity/project", @@ -114,33 +261,21 @@ def test_calls_complete_batch_endpoint_and_payload(mock_post, monkeypatch): ) batch = [CompleteBatchItem(req=complete)] - mock_post.return_value = httpx.Response( - 200, - json={}, - request=httpx.Request("POST", "http://example.com"), - ) - try: server._flush_calls_complete(batch) finally: - if server.call_processor: - server.call_processor.stop_accepting_new_work_and_flush_queue() - if server.feedback_processor: - server.feedback_processor.stop_accepting_new_work_and_flush_queue() - - assert mock_post.call_count == 1 - url = mock_post.call_args[0][0] - assert url == "http://example.com/v2/entity/project/calls/complete" - sent_data = mock_post.call_args[1]["data"] - payload = json.loads(sent_data.decode("utf-8")) + shutdown(server) + + assert transport.urls == [f"{BASE_URL}/v2/entity/project/calls/complete"] + payload = json.loads(transport.requests[0].content) expected = tsi.CallsUpsertCompleteReq(batch=[complete]).model_dump(mode="json") assert payload == expected -@patch("weave.utils.http_requests.post") -def 
test_eager_calls_use_v2_start_end_endpoints(mock_post): +def test_eager_calls_use_v2_start_end_endpoints(): """Use v2 endpoints for eager start/end and include started_at in end.""" - server = RemoteHTTPTraceServer("http://example.com", should_batch=True) + transport = SpyTransport() + server = make_server(transport, should_batch=True) start = generate_start(id="call-id", project_id="entity/project") ended_at = datetime.datetime.now(tz=datetime.timezone.utc) @@ -153,11 +288,6 @@ def test_eager_calls_use_v2_start_end_endpoints(mock_post): summary={"result": "Test summary"}, ) - mock_post.side_effect = [ - httpx.Response(200, request=httpx.Request("POST", "http://example.com")), - httpx.Response(200, request=httpx.Request("POST", "http://example.com")), - ] - try: server._flush_calls_eager( [ @@ -166,50 +296,33 @@ def test_eager_calls_use_v2_start_end_endpoints(mock_post): ] ) - urls = [call[0][0] for call in mock_post.call_args_list] - assert urls == [ - "http://example.com/v2/entity/project/call/start", - "http://example.com/v2/entity/project/call/end", + assert transport.urls == [ + f"{BASE_URL}/v2/entity/project/call/start", + f"{BASE_URL}/v2/entity/project/call/end", ] - end_payload = json.loads(mock_post.call_args_list[1][1]["data"].decode("utf-8")) + end_payload = json.loads(transport.requests[1].content) payload_started_at = datetime.datetime.fromisoformat( end_payload["end"]["started_at"].replace("Z", "+00:00") ) assert payload_started_at == end.started_at assert end_payload["end"]["id"] == "call-id" finally: - if server.call_processor: - server.call_processor.stop_accepting_new_work_and_flush_queue() - if server.feedback_processor: - server.feedback_processor.stop_accepting_new_work_and_flush_queue() + shutdown(server) @pytest.mark.disable_logging_error_check def test_eager_non_retryable_error_drops_item(caplog): """Drop eager items on non-retryable errors without raising.""" - server = RemoteHTTPTraceServer("http://example.com", should_batch=True) + 
transport = SpyTransport(default_response=httpx.Response(400, json={})) + server = make_server(transport, should_batch=True) start = generate_start(id="call-id", project_id="entity/project") - def _raise_non_retryable(_: StartBatchItem) -> None: - raise httpx.HTTPStatusError( - "400", - request=httpx.Request("POST", "http://example.com"), - response=httpx.Response( - 400, request=httpx.Request("POST", "http://example.com") - ), - ) - - server._send_call_start_v2 = _raise_non_retryable # type: ignore[assignment] - caplog.set_level(logging.ERROR) try: server._flush_calls_eager([StartBatchItem(req=tsi.CallStartReq(start=start))]) finally: - if server.call_processor: - server.call_processor.stop_accepting_new_work_and_flush_queue() - if server.feedback_processor: - server.feedback_processor.stop_accepting_new_work_and_flush_queue() + shutdown(server) assert any("dropped call start ids" in record.message for record in caplog.records) @@ -217,7 +330,8 @@ def _raise_non_retryable(_: StartBatchItem) -> None: @pytest.mark.disable_logging_error_check def test_eager_retryable_error_logs_and_continues(caplog): """Log and drop item on retryable error, continue with remaining items.""" - server = RemoteHTTPTraceServer("http://example.com", should_batch=True) + transport = SpyTransport() + server = make_server(transport, should_batch=True) start1 = generate_start(id="call-id-1", project_id="entity/project") start2 = generate_start(id="call-id-2", project_id="entity/project") @@ -228,10 +342,8 @@ def _raise_retryable_once(start) -> None: if start.id == "call-id-1": raise httpx.HTTPStatusError( "500", - request=httpx.Request("POST", "http://example.com"), - response=httpx.Response( - 500, request=httpx.Request("POST", "http://example.com") - ), + request=httpx.Request("POST", BASE_URL), + response=httpx.Response(500, request=httpx.Request("POST", BASE_URL)), ) # call-id-2 succeeds @@ -246,10 +358,7 @@ def _raise_retryable_once(start) -> None: ] ) finally: - if 
server.call_processor: - server.call_processor.stop_accepting_new_work_and_flush_queue() - if server.feedback_processor: - server.feedback_processor.stop_accepting_new_work_and_flush_queue() + shutdown(server) # Item 1 was logged as dropped assert any("dropped call start ids" in record.message for record in caplog.records) @@ -258,150 +367,93 @@ def _raise_retryable_once(start) -> None: assert "call-id-2" in call_attempts -@patch("weave.utils.http_requests.post") -def test_500_502_503_504_429_retry(mock_post, unbatched_server, monkeypatch): - """Test that 5xx and 429 errors are retried.""" - monkeypatch.setenv("WEAVE_RETRY_MAX_ATTEMPTS", "6") - monkeypatch.setenv("WEAVE_RETRY_MAX_INTERVAL", "0.1") - call_id = generate_id() - - mock_post.side_effect = [ - httpx.Response(500, request=httpx.Request("POST", "http://test.com")), - httpx.Response(502, request=httpx.Request("POST", "http://test.com")), - httpx.Response(503, request=httpx.Request("POST", "http://test.com")), - httpx.Response(504, request=httpx.Request("POST", "http://test.com")), - httpx.Response(429, request=httpx.Request("POST", "http://test.com")), - httpx.Response( - 200, - json=dict(tsi.CallStartRes(id=call_id, trace_id="test_trace_id")), - request=httpx.Request("POST", "http://test.com"), - ), - ] - start = generate_start(call_id) - unbatched_server.call_start(tsi.CallStartReq(start=start)) - - -@patch("weave.utils.http_requests.post") -def test_other_error_retry(mock_post, unbatched_server, monkeypatch): - """Test that connection errors are retried.""" - monkeypatch.setenv("WEAVE_RETRY_MAX_ATTEMPTS", "6") - monkeypatch.setenv("WEAVE_RETRY_MAX_INTERVAL", "0.1") - call_id = generate_id() - - mock_post.side_effect = [ - ConnectionResetError(), - ConnectionError(), - OSError(), - TimeoutError(), - httpx.Response( - 200, - json=dict(tsi.CallStartRes(id=call_id, trace_id="test_trace_id")), - request=httpx.Request("POST", "http://test.com"), - ), - ] - start = generate_start(call_id) - 
unbatched_server.call_start(tsi.CallStartReq(start=start)) - - -@patch("weave.utils.http_requests.post") -def test_timeout_retry_mechanism(mock_post, success_response, monkeypatch): +def test_timeout_retry_mechanism(monkeypatch): """Test that timeouts trigger the retry mechanism.""" monkeypatch.setenv("WEAVE_USE_CALLS_COMPLETE", "false") - server = RemoteHTTPTraceServer("http://example.com", should_batch=True) - - # Mock server to raise errors twice, then succeed - mock_post.side_effect = [ + monkeypatch.setenv("WEAVE_RETRY_MAX_INTERVAL", "0.1") + transport = SpyTransport( httpx.TimeoutException("Connection timed out"), - httpx.HTTPStatusError( - "500 Server Error", request=MagicMock(), response=MagicMock(status_code=500) - ), - success_response, - ] + httpx.Response(500), + httpx.Response(200, json={}), + ) + server = make_server(transport, should_batch=True) # Trying to send a batch should fail 2 times, then succeed server.call_start(tsi.CallStartReq(start=generate_start())) server.call_processor.stop_accepting_new_work_and_flush_queue() - # Verify that requests.post was called 3 times - assert mock_post.call_count == 3 + assert len(transport.requests) == 3 + assert all(url == f"{BASE_URL}/call/upsert_batch" for url in transport.urls) + +@pytest.mark.disable_logging_error_check +def test_post_timeout(monkeypatch, log_collector): + """Test batch recovery after timeout exhaustion. -@pytest.fixture -def fast_retrying_server(monkeypatch): - """Create a RemoteHTTPTraceServer with fast retry settings for testing.""" + This test verifies that we can still send new batches even if one batch + times out and exhausts all retries. 
+ """ + configure_logger() monkeypatch.setenv("WEAVE_USE_CALLS_COMPLETE", "false") - server = RemoteHTTPTraceServer("http://example.com", should_batch=True) + fast_retry = tenacity.retry( wait=tenacity.wait_fixed(0.1), stop=tenacity.stop_after_attempt(2), reraise=True, ) + + # Phase 1: Try but fail to process the first batch + transport = SpyTransport( + default_response=None, + ) + transport.default_response = None + transport.queue = [ + httpx.TimeoutException("Connection timed out"), + httpx.TimeoutException("Connection timed out"), + ] + server = make_server(transport, should_batch=True) unwrapped_send_batch_to_server = MethodType( server._send_batch_to_server.__wrapped__, # type: ignore[attr-defined] server, ) server._send_batch_to_server = fast_retry(unwrapped_send_batch_to_server) - yield server - if server.call_processor: - server.call_processor.stop_accepting_new_work_and_flush_queue() - if server.feedback_processor: - server.feedback_processor.stop_accepting_new_work_and_flush_queue() - - -@pytest.mark.disable_logging_error_check -@patch("weave.utils.http_requests.post") -def test_post_timeout(mock_post, success_response, fast_retrying_server, log_collector): - """Test batch recovery after timeout exhaustion. - This test verifies that we can still send new batches even if one batch - times out and exhausts all retries. 
- """ - configure_logger() - # Configure mock to timeout twice to exhaust retries - mock_post.side_effect = [ - httpx.TimeoutException("Connection timed out"), - httpx.TimeoutException("Connection timed out"), - ] - - # Phase 1: Try but fail to process the first batch - fast_retrying_server.call_start(tsi.CallStartReq(start=generate_start())) - fast_retrying_server.call_processor.stop_accepting_new_work_and_flush_queue() + server.call_start(tsi.CallStartReq(start=generate_start())) + server.call_processor.stop_accepting_new_work_and_flush_queue() logs = log_collector.get_warning_logs() assert len(logs) >= 1 assert any("requeuing batch" in log.msg for log in logs) + shutdown(server) - # Phase 2: Reset mock and verify we can still process a new batch - mock_post.reset_mock() - mock_post.side_effect = [ + # Phase 2: Verify a fresh server can still process a new batch after a + # transient timeout + call_id = generate_id() + transport2 = SpyTransport( httpx.TimeoutException("Connection timed out"), - success_response, - ] - - # Create a new server since the old one has shutdown its batch processor - new_server = RemoteHTTPTraceServer("http://example.com", should_batch=False) - fast_retry = tenacity.retry( - wait=tenacity.wait_fixed(0.1), - stop=tenacity.stop_after_attempt(2), - reraise=True, + call_start_ok_response(call_id), ) - unwrapped_send_batch_to_server = MethodType( - new_server._send_batch_to_server.__wrapped__, # type: ignore[attr-defined] - new_server, + new_server = make_server(transport2, should_batch=False) + fast_retried_via_sdk = fast_retry( + MethodType( + new_server._via_sdk.__wrapped__, # type: ignore[attr-defined] + new_server, + ) ) - new_server._send_batch_to_server = fast_retry(unwrapped_send_batch_to_server) + new_server._via_sdk = fast_retried_via_sdk - # Should succeed with retry - start_req = tsi.CallStartReq(start=generate_start()) - response = new_server.call_start(start_req) - assert response.id == "test_id" + response = 
new_server.call_start(tsi.CallStartReq(start=generate_start(call_id))) + assert response.id == call_id assert response.trace_id == "test_trace_id" -@patch("weave.utils.http_requests.post") -def test_auto_upgrade_to_calls_complete_on_error(mock_post, monkeypatch): +def test_auto_upgrade_to_calls_complete_on_error(monkeypatch): """Verify client switches to CallBatchProcessor when server returns CALLS_COMPLETE_MODE_REQUIRED.""" monkeypatch.setenv("WEAVE_USE_CALLS_COMPLETE", "false") - server = RemoteHTTPTraceServer("http://example.com", should_batch=True) + transport = SpyTransport( + make_calls_complete_required_response(), + httpx.Response(200, json={}), + ) + server = make_server(transport, should_batch=True) # Verify initial state: using legacy AsyncBatchProcessor assert server.use_calls_complete is False @@ -409,13 +461,6 @@ def test_auto_upgrade_to_calls_complete_on_error(mock_post, monkeypatch): assert not isinstance(server.call_processor, CallBatchProcessor) old_processor = server.call_processor - mock_post.side_effect = [ - make_calls_complete_required_response(), - httpx.Response( - 200, json={}, request=httpx.Request("POST", "http://example.com") - ), - ] - call_id = generate_id() start = StartBatchItem( req=tsi.CallStartReq(start=generate_start(call_id, "entity/project")) @@ -430,22 +475,16 @@ def test_auto_upgrade_to_calls_complete_on_error(mock_post, monkeypatch): assert server.use_calls_complete is True assert isinstance(server.call_processor, CallBatchProcessor) assert old_processor.stop_accepting_work_event.is_set() - assert any("/calls/complete" in call[0][0] for call in mock_post.call_args_list) + assert any("/calls/complete" in url for url in transport.urls) finally: - if server.call_processor and server.call_processor.is_accepting_new_work(): - server.call_processor.stop_accepting_new_work_and_flush_queue() - if server.feedback_processor: - server.feedback_processor.stop_accepting_new_work_and_flush_queue() + shutdown(server) 
-@patch("weave.utils.http_requests.post") -def test_eager_calls_complete_required_is_reraised(mock_post, monkeypatch): +def test_eager_calls_complete_required_is_reraised(monkeypatch): """Verify CallsCompleteModeRequired in eager path is re-raised for caller to handle.""" - from weave.trace_server_bindings.http_utils import CallsCompleteModeRequired - monkeypatch.setenv("WEAVE_USE_CALLS_COMPLETE", "true") - server = RemoteHTTPTraceServer("http://example.com", should_batch=True) - mock_post.side_effect = [make_calls_complete_required_response()] + transport = SpyTransport(make_calls_complete_required_response()) + server = make_server(transport, should_batch=True) start = StartBatchItem( req=tsi.CallStartReq(start=generate_start("call-id", "entity/project")) @@ -455,7 +494,81 @@ def test_eager_calls_complete_required_is_reraised(mock_post, monkeypatch): with pytest.raises(CallsCompleteModeRequired): server._flush_calls_eager([start]) finally: - if server.call_processor and server.call_processor.is_accepting_new_work(): - server.call_processor.stop_accepting_new_work_and_flush_queue() - if server.feedback_processor: - server.feedback_processor.stop_accepting_new_work_and_flush_queue() + shutdown(server) + + +def test_calls_query_stream_parses_jsonl(): + """calls_query_stream yields CallSchema objects from a jsonl response.""" + call = { + "project_id": "entity/project", + "id": "call-1", + "op_name": "op", + "trace_id": "trace-1", + "started_at": "2024-01-01T00:00:00Z", + "attributes": {}, + "inputs": {}, + } + body = "\n".join([json.dumps(call), json.dumps({**call, "id": "call-2"})]) + transport = SpyTransport( + httpx.Response( + 200, + content=body.encode("utf-8"), + headers={"content-type": "application/jsonl"}, + ) + ) + server = make_server(transport) + + calls = list( + server.calls_query_stream(tsi.CallsQueryReq(project_id="entity/project")) + ) + + assert transport.urls == [f"{BASE_URL}/calls/stream_query"] + assert [c.id for c in calls] == ["call-1", 
"call-2"] + assert all(isinstance(c, tsi.CallSchema) for c in calls) + + +def test_file_create_sends_multipart(): + """file_create posts multipart form data (SDK 0.0.1 lost the body).""" + transport = SpyTransport(httpx.Response(200, json={"digest": "digest-1"})) + server = make_server(transport) + + res = server.file_create( + tsi.FileCreateReq( + project_id="entity/project", name="file.txt", content=b"hello" + ) + ) + + assert res.digest == "digest-1" + request = transport.requests[0] + assert str(request.url) == f"{BASE_URL}/files/create" + assert request.headers["content-type"].startswith("multipart/form-data") + assert b"hello" in request.content + assert b"entity/project" in request.content + + +def test_feedback_create_unbatched_uses_single_route(): + """Unbatched feedback_create posts to /feedback/create (SDK route shadowed).""" + transport = SpyTransport( + httpx.Response( + 200, + json={ + "id": "feedback-1", + "created_at": "2024-01-01T00:00:00Z", + "wb_user_id": "user", + "payload": {"note": "hi"}, + }, + ) + ) + server = make_server(transport) + + res = server.feedback_create( + tsi.FeedbackCreateReq( + project_id="entity/project", + weave_ref="weave:///entity/project/call/call-1", + feedback_type="wandb.note.1", + payload={"note": "hi"}, + ) + ) + + assert transport.urls == [f"{BASE_URL}/feedback/create"] + assert res.id == "feedback-1" diff --git a/tests/trace_server_bindings/test_tags_aliases_routes.py b/tests/trace_server_bindings/test_tags_aliases_routes.py index 994c969c9751..817b9426c795 100644 --- a/tests/trace_server_bindings/test_tags_aliases_routes.py +++ b/tests/trace_server_bindings/test_tags_aliases_routes.py @@ -1,17 +1,18 @@ """Tests for RESTful tags and aliases routes in RemoteHTTPTraceServer. These tests verify that the tag/alias methods send the correct HTTP method, -URL path, and request body through the RemoteHTTPTraceServer client. +URL path, and request body. 
Requests are observed at the httpx transport +boundary, so the SDK routing and conversion layers are exercised for real. """ from __future__ import annotations import json -from unittest.mock import MagicMock, patch import httpx import pytest +from tests.trace_server_bindings.conftest import SpyTransport from weave.trace_server import trace_server_interface as tsi from weave.trace_server_bindings.remote_http_trace_server import ( RemoteHTTPTraceServer, @@ -21,9 +22,14 @@ @pytest.fixture -def server(): - """Create a RemoteHTTPTraceServer with mocked HTTP methods.""" - srv = RemoteHTTPTraceServer(BASE_URL, should_batch=False) +def transport(): + return SpyTransport() + + +@pytest.fixture +def server(transport): + """Create a RemoteHTTPTraceServer over a spy transport.""" + srv = RemoteHTTPTraceServer(BASE_URL, should_batch=False, transport=transport) yield srv if srv.call_processor: srv.call_processor.stop_accepting_new_work_and_flush_queue() @@ -31,164 +37,116 @@ def server(): srv.feedback_processor.stop_accepting_new_work_and_flush_queue() -def _mock_response(json_data: dict | None = None) -> MagicMock: - """Create a mock httpx.Response with the given JSON data.""" - resp = MagicMock(spec=httpx.Response) - resp.status_code = 200 - resp.json.return_value = json_data or {} - return resp - - class TestObjAddTags: - def test_sends_correct_request(self, server): - mock_resp = _mock_response() - with patch.object(server, "put", return_value=mock_resp) as mock_put: - server.obj_add_tags( - tsi.ObjAddTagsReq( - project_id="entity/project", - object_id="my-obj", - digest="abc123", - tags=["production", "reviewed"], - ) + def test_sends_correct_request(self, server, transport): + server.obj_add_tags( + tsi.ObjAddTagsReq( + project_id="entity/project", + object_id="my-obj", + digest="abc123", + tags=["production", "reviewed"], ) + ) - mock_put.assert_called_once() - call_url = mock_put.call_args[0][0] - assert call_url == "/objs/my-obj/versions/abc123/tags" + request = 
transport.requests[0] + assert request.method == "PUT" + assert str(request.url) == f"{BASE_URL}/objs/my-obj/versions/abc123/tags" - sent_data = mock_put.call_args[1]["data"] - body = json.loads(sent_data) - assert body["project_id"] == "entity/project" - assert body["tags"] == ["production", "reviewed"] - # object_id and digest should NOT be in the body (they're in the URL) - assert "object_id" not in body - assert "digest" not in body + body = json.loads(request.content) + assert body["project_id"] == "entity/project" + assert body["tags"] == ["production", "reviewed"] + # object_id and digest should NOT be in the body (they're in the URL) + assert "object_id" not in body + assert "digest" not in body class TestObjRemoveTags: - def test_sends_correct_request(self, server): - mock_resp = _mock_response() - with patch.object(server, "post", return_value=mock_resp) as mock_post: - server.obj_remove_tags( - tsi.ObjRemoveTagsReq( - project_id="entity/project", - object_id="my-obj", - digest="abc123", - tags=["staging"], - ) + def test_sends_correct_request(self, server, transport): + server.obj_remove_tags( + tsi.ObjRemoveTagsReq( + project_id="entity/project", + object_id="my-obj", + digest="abc123", + tags=["staging"], ) + ) - mock_post.assert_called_once() - call_url = mock_post.call_args[0][0] - assert call_url == "/objs/my-obj/versions/abc123/tags/remove" + request = transport.requests[0] + assert request.method == "POST" + assert str(request.url) == f"{BASE_URL}/objs/my-obj/versions/abc123/tags/remove" - sent_data = mock_post.call_args[1]["data"] - body = json.loads(sent_data) - assert body["project_id"] == "entity/project" - assert body["tags"] == ["staging"] - assert "object_id" not in body - assert "digest" not in body + body = json.loads(request.content) + assert body["project_id"] == "entity/project" + assert body["tags"] == ["staging"] + assert "object_id" not in body + assert "digest" not in body class TestObjSetAliases: - def 
test_sends_correct_request(self, server): - mock_resp = _mock_response() - with patch.object(server, "put", return_value=mock_resp) as mock_put: - server.obj_set_aliases( - tsi.ObjSetAliasesReq( - project_id="entity/project", - object_id="my-obj", - digest="abc123", - aliases=["staging", "candidate"], - ) + def test_sends_correct_request(self, server, transport): + server.obj_set_aliases( + tsi.ObjSetAliasesReq( + project_id="entity/project", + object_id="my-obj", + digest="abc123", + aliases=["staging", "candidate"], ) + ) - mock_put.assert_called_once() - call_url = mock_put.call_args[0][0] - assert call_url == "/objs/my-obj/aliases" + request = transport.requests[0] + assert request.method == "PUT" + assert str(request.url) == f"{BASE_URL}/objs/my-obj/aliases" - sent_data = mock_put.call_args[1]["data"] - body = json.loads(sent_data) - assert body["project_id"] == "entity/project" - assert body["digest"] == "abc123" - assert body["aliases"] == ["staging", "candidate"] - assert "object_id" not in body + body = json.loads(request.content) + assert body["project_id"] == "entity/project" + assert body["digest"] == "abc123" + assert body["aliases"] == ["staging", "candidate"] + assert "object_id" not in body class TestObjRemoveAliases: - def test_sends_correct_request(self, server): - mock_resp = _mock_response() - with patch.object(server, "post", return_value=mock_resp) as mock_post: - server.obj_remove_aliases( - tsi.ObjRemoveAliasesReq( - project_id="entity/project", - object_id="my-obj", - aliases=["staging"], - ) + def test_sends_correct_request(self, server, transport): + server.obj_remove_aliases( + tsi.ObjRemoveAliasesReq( + project_id="entity/project", + object_id="my-obj", + aliases=["staging"], ) + ) - mock_post.assert_called_once() - call_url = mock_post.call_args[0][0] - assert call_url == "/objs/my-obj/aliases/remove" + request = transport.requests[0] + assert request.method == "POST" + assert str(request.url) == 
f"{BASE_URL}/objs/my-obj/aliases/remove" - sent_data = mock_post.call_args[1]["data"] - body = json.loads(sent_data) - assert body["project_id"] == "entity/project" - assert body["aliases"] == ["staging"] - assert "object_id" not in body + body = json.loads(request.content) + assert body["project_id"] == "entity/project" + assert body["aliases"] == ["staging"] + assert "object_id" not in body class TestTagsList: - def test_sends_get_to_correct_url(self, server): - mock_resp = _mock_response({"tags": ["prod", "staging"]}) - with patch.object(server, "get", return_value=mock_resp) as mock_get: - result = server.tags_list(tsi.TagsListReq(project_id="entity/project")) - - mock_get.assert_called_once() - call_url = mock_get.call_args[0][0] - assert call_url == "/tags" + def test_sends_get_with_project_id_param(self, server, transport): + transport.queue.append(httpx.Response(200, json={"tags": ["prod", "staging"]})) + result = server.tags_list(tsi.TagsListReq(project_id="entity/project")) - def test_passes_project_id_as_query_param(self, server): - mock_resp = _mock_response({"tags": ["prod"]}) - with patch.object(server, "get", return_value=mock_resp) as mock_get: - server.tags_list(tsi.TagsListReq(project_id="entity/project")) - - call_kwargs = mock_get.call_args[1] - assert call_kwargs["params"] == {"project_id": "entity/project"} - - def test_returns_parsed_response(self, server): - mock_resp = _mock_response({"tags": ["prod", "staging"]}) - with patch.object(server, "get", return_value=mock_resp): - result = server.tags_list(tsi.TagsListReq(project_id="entity/project")) - - assert isinstance(result, tsi.TagsListRes) - assert result.tags == ["prod", "staging"] + request = transport.requests[0] + assert request.method == "GET" + assert request.url.path == "/tags" + assert request.url.params["project_id"] == "entity/project" + assert isinstance(result, tsi.TagsListRes) + assert result.tags == ["prod", "staging"] class TestAliasesList: - def 
test_sends_get_to_correct_url(self, server): - mock_resp = _mock_response({"aliases": ["deploy", "staging"]}) - with patch.object(server, "get", return_value=mock_resp) as mock_get: - server.aliases_list(tsi.AliasesListReq(project_id="entity/project")) - - mock_get.assert_called_once() - call_url = mock_get.call_args[0][0] - assert call_url == "/aliases" - - def test_passes_project_id_as_query_param(self, server): - mock_resp = _mock_response({"aliases": ["deploy"]}) - with patch.object(server, "get", return_value=mock_resp) as mock_get: - server.aliases_list(tsi.AliasesListReq(project_id="entity/project")) - - call_kwargs = mock_get.call_args[1] - assert call_kwargs["params"] == {"project_id": "entity/project"} - - def test_returns_parsed_response(self, server): - mock_resp = _mock_response({"aliases": ["deploy", "staging"]}) - with patch.object(server, "get", return_value=mock_resp): - result = server.aliases_list( - tsi.AliasesListReq(project_id="entity/project") - ) - - assert isinstance(result, tsi.AliasesListRes) - assert result.aliases == ["deploy", "staging"] + def test_sends_get_with_project_id_param(self, server, transport): + transport.queue.append( + httpx.Response(200, json={"aliases": ["deploy", "staging"]}) + ) + result = server.aliases_list(tsi.AliasesListReq(project_id="entity/project")) + + request = transport.requests[0] + assert request.method == "GET" + assert request.url.path == "/aliases" + assert request.url.params["project_id"] == "entity/project" + assert isinstance(result, tsi.AliasesListRes) + assert result.aliases == ["deploy", "staging"] diff --git a/tests/utils/test_http_requests.py b/tests/utils/test_http_requests.py index 6c21773a9ad5..6fd237e2ac30 100644 --- a/tests/utils/test_http_requests.py +++ b/tests/utils/test_http_requests.py @@ -33,7 +33,7 @@ def test_request_hook_logs_when_enabled(monkeypatch): request = httpx.Request("GET", "https://api.wandb.ai/calls") with patch.object(http_requests, "pprint_request") as 
mock_pprint_request: - http_requests._log_request(request) + http_requests.log_request(request) mock_pprint_request.assert_called_once_with(request) assert isinstance(request.extensions.get("weave_start_time"), float) @@ -43,7 +43,7 @@ def test_request_hook_noop_when_disabled(): request = httpx.Request("GET", "https://api.wandb.ai/calls") with patch.object(http_requests, "pprint_request") as mock_pprint_request: - http_requests._log_request(request) + http_requests.log_request(request) mock_pprint_request.assert_not_called() assert "weave_start_time" not in request.extensions @@ -59,7 +59,7 @@ def test_response_hook_logs_when_enabled(monkeypatch): patch.object(http_requests, "pprint_response") as mock_pprint_response, patch("weave.utils.http_requests.time", return_value=2.0), ): - http_requests._log_response(response) + http_requests.log_response(response) mock_pprint_response.assert_called_once_with(response) diff --git a/uv.lock b/uv.lock index e3b3e0c26b7e..361fff69c68c 100644 --- a/uv.lock +++ b/uv.lock @@ -10762,6 +10762,7 @@ dependencies = [ { name = "sentry-sdk", marker = "platform_python_implementation == 'CPython' or (extra == 'extra-5-weave-crewai' and extra == 'extra-5-weave-dspy') or (extra == 'extra-5-weave-crewai' and extra == 'extra-5-weave-instructor') or (extra == 'extra-5-weave-crewai' and extra == 'extra-5-weave-langchain') or (extra == 'extra-5-weave-crewai' and extra == 'extra-5-weave-scorers') or (extra == 'extra-5-weave-dspy' and extra == 'extra-5-weave-gepa') or (extra == 'extra-5-weave-langchain' and extra == 'extra-5-weave-vertexai')" }, { name = "tenacity", marker = "platform_python_implementation == 'CPython' or (extra == 'extra-5-weave-crewai' and extra == 'extra-5-weave-dspy') or (extra == 'extra-5-weave-crewai' and extra == 'extra-5-weave-instructor') or (extra == 'extra-5-weave-crewai' and extra == 'extra-5-weave-langchain') or (extra == 'extra-5-weave-crewai' and extra == 'extra-5-weave-scorers') or (extra == 'extra-5-weave-dspy' 
and extra == 'extra-5-weave-gepa') or (extra == 'extra-5-weave-langchain' and extra == 'extra-5-weave-vertexai')" }, { name = "tzdata", marker = "(platform_python_implementation == 'CPython' and sys_platform == 'win32') or (platform_python_implementation != 'CPython' and extra == 'extra-5-weave-crewai' and extra == 'extra-5-weave-dspy') or (platform_python_implementation != 'CPython' and extra == 'extra-5-weave-crewai' and extra == 'extra-5-weave-instructor') or (platform_python_implementation != 'CPython' and extra == 'extra-5-weave-crewai' and extra == 'extra-5-weave-langchain') or (platform_python_implementation != 'CPython' and extra == 'extra-5-weave-crewai' and extra == 'extra-5-weave-scorers') or (platform_python_implementation != 'CPython' and extra == 'extra-5-weave-dspy' and extra == 'extra-5-weave-gepa') or (platform_python_implementation != 'CPython' and extra == 'extra-5-weave-langchain' and extra == 'extra-5-weave-vertexai') or (sys_platform != 'win32' and extra == 'extra-5-weave-crewai' and extra == 'extra-5-weave-dspy') or (sys_platform != 'win32' and extra == 'extra-5-weave-crewai' and extra == 'extra-5-weave-instructor') or (sys_platform != 'win32' and extra == 'extra-5-weave-crewai' and extra == 'extra-5-weave-langchain') or (sys_platform != 'win32' and extra == 'extra-5-weave-crewai' and extra == 'extra-5-weave-scorers') or (sys_platform != 'win32' and extra == 'extra-5-weave-dspy' and extra == 'extra-5-weave-gepa') or (sys_platform != 'win32' and extra == 'extra-5-weave-langchain' and extra == 'extra-5-weave-vertexai')" }, + { name = "weave-server-sdk", marker = "platform_python_implementation == 'CPython' or (extra == 'extra-5-weave-crewai' and extra == 'extra-5-weave-dspy') or (extra == 'extra-5-weave-crewai' and extra == 'extra-5-weave-instructor') or (extra == 'extra-5-weave-crewai' and extra == 'extra-5-weave-langchain') or (extra == 'extra-5-weave-crewai' and extra == 'extra-5-weave-scorers') or (extra == 'extra-5-weave-dspy' and extra == 
'extra-5-weave-gepa') or (extra == 'extra-5-weave-langchain' and extra == 'extra-5-weave-vertexai')" }, ] [package.optional-dependencies] @@ -11088,6 +11089,7 @@ requires-dist = [ { name = "verifiers", marker = "python_full_version >= '3.11' and extra == 'verifiers'", specifier = ">=0.1.3.post0,<0.1.13" }, { name = "vertexai", marker = "platform_python_implementation != 'PyPy' and extra == 'vertexai'", specifier = ">=1.70.0" }, { name = "wandb", marker = "extra == 'wandb'", specifier = ">=0.17.1" }, + { name = "weave-server-sdk", specifier = "==0.0.1", index = "https://test.pypi.org/simple/" }, ] provides-extras = ["anthropic", "autogen", "bedrock", "cerebras", "claude-agent-sdk", "cohere", "crewai", "dspy", "fastmcp", "gepa", "google-genai", "groq", "huggingface", "instructor", "langchain", "langchain-nvidia-ai-endpoints", "litellm", "llamaindex", "mistral", "modal", "notdiamond", "openai", "openai-agents", "presidio", "rich", "scorers", "smolagents", "trace-server", "verdict", "verifiers", "vertexai", "video-support", "wandb"] @@ -11153,6 +11155,20 @@ verifiers-test = [ { name = "verifiers", marker = "python_full_version >= '3.11'", specifier = ">=0.1.3.post0,<0.1.13" }, ] +[[package]] +name = "weave-server-sdk" +version = "0.0.1" +source = { registry = "https://test.pypi.org/simple/" } +dependencies = [ + { name = "httpx", version = "0.27.2", source = { registry = "https://pypi.org/simple" }, marker = "(python_full_version < '3.14' and platform_python_implementation == 'CPython' and extra == 'extra-5-weave-crewai') or (python_full_version >= '3.14' and extra == 'extra-5-weave-crewai' and extra == 'extra-5-weave-dspy') or (python_full_version >= '3.14' and extra == 'extra-5-weave-crewai' and extra == 'extra-5-weave-instructor') or (python_full_version >= '3.14' and extra == 'extra-5-weave-crewai' and extra == 'extra-5-weave-langchain') or (python_full_version >= '3.14' and extra == 'extra-5-weave-crewai' and extra == 'extra-5-weave-scorers') or 
(platform_python_implementation != 'CPython' and extra == 'extra-5-weave-crewai' and extra == 'extra-5-weave-dspy') or (platform_python_implementation != 'CPython' and extra == 'extra-5-weave-crewai' and extra == 'extra-5-weave-instructor') or (platform_python_implementation != 'CPython' and extra == 'extra-5-weave-crewai' and extra == 'extra-5-weave-langchain') or (platform_python_implementation != 'CPython' and extra == 'extra-5-weave-crewai' and extra == 'extra-5-weave-scorers') or (platform_python_implementation != 'CPython' and extra == 'extra-5-weave-dspy' and extra == 'extra-5-weave-gepa') or (platform_python_implementation != 'CPython' and extra == 'extra-5-weave-langchain' and extra == 'extra-5-weave-vertexai') or (extra != 'extra-5-weave-crewai' and extra == 'extra-5-weave-dspy' and extra == 'extra-5-weave-gepa') or (extra != 'extra-5-weave-crewai' and extra == 'extra-5-weave-langchain' and extra == 'extra-5-weave-vertexai')" }, + { name = "httpx", version = "0.28.1", source = { registry = "https://pypi.org/simple" }, marker = "(python_full_version >= '3.14' and platform_python_implementation == 'CPython') or (python_full_version < '3.14' and extra == 'extra-5-weave-crewai' and extra == 'extra-5-weave-dspy') or (python_full_version < '3.14' and extra == 'extra-5-weave-crewai' and extra == 'extra-5-weave-instructor') or (python_full_version < '3.14' and extra == 'extra-5-weave-crewai' and extra == 'extra-5-weave-langchain') or (python_full_version < '3.14' and extra == 'extra-5-weave-crewai' and extra == 'extra-5-weave-scorers') or (platform_python_implementation == 'CPython' and extra != 'extra-5-weave-crewai') or (platform_python_implementation != 'CPython' and extra == 'extra-5-weave-crewai' and extra == 'extra-5-weave-dspy') or (platform_python_implementation != 'CPython' and extra == 'extra-5-weave-crewai' and extra == 'extra-5-weave-instructor') or (platform_python_implementation != 'CPython' and extra == 'extra-5-weave-crewai' and extra == 
'extra-5-weave-langchain') or (platform_python_implementation != 'CPython' and extra == 'extra-5-weave-crewai' and extra == 'extra-5-weave-scorers') or (platform_python_implementation != 'CPython' and extra == 'extra-5-weave-dspy' and extra == 'extra-5-weave-gepa') or (platform_python_implementation != 'CPython' and extra == 'extra-5-weave-langchain' and extra == 'extra-5-weave-vertexai')" }, + { name = "pydantic", marker = "platform_python_implementation == 'CPython' or (extra == 'extra-5-weave-crewai' and extra == 'extra-5-weave-dspy') or (extra == 'extra-5-weave-crewai' and extra == 'extra-5-weave-instructor') or (extra == 'extra-5-weave-crewai' and extra == 'extra-5-weave-langchain') or (extra == 'extra-5-weave-crewai' and extra == 'extra-5-weave-scorers') or (extra == 'extra-5-weave-dspy' and extra == 'extra-5-weave-gepa') or (extra == 'extra-5-weave-langchain' and extra == 'extra-5-weave-vertexai')" }, +] +sdist = { url = "https://test-files.pythonhosted.org/packages/4e/29/419e7dd1b628ee7a8f7501717b188eff04bbd5fa7c9937243701292517b1/weave_server_sdk-0.0.1.tar.gz", hash = "sha256:21d32d702ffdbee6f0ca00a31e550a6cb97076def9780da1cd18d3e23a7103f5", size = 22409, upload-time = "2026-06-11T04:23:44.073Z" } +wheels = [ + { url = "https://test-files.pythonhosted.org/packages/49/51/0027425fddfa07ff11e41f362d397f33f7247fcee3ca61c581215c6b190f/weave_server_sdk-0.0.1-py3-none-any.whl", hash = "sha256:c5a39f2f9573946e5f89b3cd3b1350226f047594b1122480ebc0e06b3ee0b8bc", size = 38852, upload-time = "2026-06-11T04:23:43.141Z" }, +] + [[package]] name = "webcolors" version = "25.10.0" diff --git a/weave/trace/weave_client.py b/weave/trace/weave_client.py index 4ff7325e0d8a..a30cb159ed0d 100644 --- a/weave/trace/weave_client.py +++ b/weave/trace/weave_client.py @@ -2489,11 +2489,8 @@ def test_func(req: TableCreateFromDigestsReq) -> Any: if hasattr(server, "_next_trace_server"): server = server._next_trace_server - assert hasattr(server, "_post_request_executor") - assert 
hasattr(server._post_request_executor, "__wrapped__") - return server._post_request_executor.__wrapped__( - server, "/table/create_from_digests", req - ) + assert hasattr(server, "unretried_table_create_from_digests") + return server.unretried_table_create_from_digests(req) use_parallel_chunks = check_endpoint_exists( test_func, test_req, "table_create_from_digests" diff --git a/weave/trace/weave_init.py b/weave/trace/weave_init.py index b7ba2d07d59b..bcb61a9ba567 100644 --- a/weave/trace/weave_init.py +++ b/weave/trace/weave_init.py @@ -27,7 +27,6 @@ CachingMiddlewareTraceServer, ) from weave.trace_server_bindings.client_interface import TraceServerClientInterface -from weave.trace_server_bindings.remote_http_trace_server import RemoteHTTPTraceServer from weave.trace_server_version import MIN_TRACE_SERVER_VERSION from weave.wandb_interface.context import get_wandb_api_context @@ -328,6 +327,12 @@ def init_weave_get_server( api_key: str | None = None, should_batch: bool = True, ) -> TraceServerClientInterface: + # Imported lazily so `import weave` does not pay for weave_server_sdk's + # model definitions until a client is actually initialized. + from weave.trace_server_bindings.remote_http_trace_server import ( + RemoteHTTPTraceServer, + ) + res = RemoteHTTPTraceServer.from_env(should_batch) if api_key is not None: res.set_auth(("api", api_key)) diff --git a/weave/trace_server_bindings/remote_http_trace_server.py b/weave/trace_server_bindings/remote_http_trace_server.py index b7c9bdade5fd..5049b6f21479 100644 --- a/weave/trace_server_bindings/remote_http_trace_server.py +++ b/weave/trace_server_bindings/remote_http_trace_server.py @@ -1,22 +1,60 @@ +"""Remote trace server binding backed by the generated ``weave-server-sdk``. 
+ +``RemoteHTTPTraceServer`` keeps its tsi-typed ``TraceServerClientInterface`` +surface, but delegates HTTP transport and request/response typing to the +``weave_server_sdk`` package (generated from the trace server's OpenAPI spec — +the source of truth for the API shape). + +Design notes: + +- A single ``httpx.Client`` (built like ``weave.utils.http_requests``: default + transport for env-proxy handling, no connection limits, ``ssl_verify()`` and + ``http_timeout()`` honored) is injected into the SDK so every request — + SDK-routed or raw — shares one connection pool, auth, and event hooks. +- A response event hook routes every non-2xx response through + ``handle_response_error`` *before* the SDK sees it, so callers observe the + exact same ``httpx.HTTPStatusError`` / ``CallsCompleteModeRequired`` + semantics as ``RemoteHTTPTraceServer`` (retry predicates, 413 batch + splitting, and calls_complete auto-upgrade all key off these). +- A request event hook injects the dynamic ``X-Weave-Retry-Id`` header at send + time so every retry attempt carries the current retry id. +- Endpoints the SDK cannot reach go through ``_raw_request``/``_raw_stream`` + with an explicit reason. Two categories: + 1. Endpoints excluded from the OpenAPI spec (``include_in_schema=False`` on + the server): calls_complete v2, eager v2 call start/end, completions, + project stats, TTL settings. + 2. weave-server-sdk 0.0.1 codegen bugs (duplicate method names where the + last definition wins, lost multipart body): single feedback create, obj + tag add/remove, /trace/usage, file create. Remove these hatches when a + fixed SDK ships. +- Streaming endpoints (``*_stream`` and the v2 jsonl list endpoints) use + ``_raw_stream`` because the published SDK buffers jsonl responses into + memory; the raw path preserves line-by-line streaming. 
+""" + +from __future__ import annotations + import datetime import io import logging -from collections.abc import Iterator -from typing import Any, cast +from collections.abc import Callable, Iterator +from typing import Any, TypeVar, cast from zoneinfo import ZoneInfo import httpx from pydantic import BaseModel, Field, validate_call from pydantic.json_schema import SkipJsonSchema from typing_extensions import Self +from weave_server_sdk import WeaveTrace +from weave_server_sdk import models as sdk_models -from weave.trace.env import weave_trace_server_url +from weave.trace.env import ssl_verify, weave_trace_server_url from weave.trace.settings import ( + http_timeout, max_calls_queue_size, should_enable_disk_fallback, should_use_calls_complete, ) -from weave.trace_server import http_service_interface as his from weave.trace_server import trace_server_interface as tsi from weave.trace_server.ids import generate_id from weave.trace_server.service_interface import ServerInfoRes @@ -38,23 +76,50 @@ EntityProjectInfo, StartBatchItem, ) -from weave.utils import http_requests +from weave.utils.http_requests import CLIENT_LIMITS, log_request, log_response from weave.utils.project_id import from_project_id from weave.utils.retry import get_current_retry_id, with_retry from weave.wandb_interface import project_creator +TRes = TypeVar("TRes", bound=BaseModel) + logger = logging.getLogger(__name__) -# Default timeout values (in seconds) -# DEFAULT_CONNECT_TIMEOUT = 10 -# DEFAULT_READ_TIMEOUT = 30 -# DEFAULT_TIMEOUT = (DEFAULT_CONNECT_TIMEOUT, DEFAULT_READ_TIMEOUT) +# Endpoints reached via _raw_request/_raw_stream because they are excluded from +# the OpenAPI spec (include_in_schema=False on the server). 
+CALLS_COMPLETE_PATH = "/v2/{entity}/{project}/calls/complete" +CALL_START_V2_PATH = "/v2/{entity}/{project}/call/start" +CALL_END_V2_PATH = "/v2/{entity}/{project}/call/end" +COMPLETIONS_CREATE_PATH = "/completions/create" +PROJECT_STATS_PATH = "/project/stats" +PROJECT_TTL_SETTINGS_READ_PATH = "/project/ttl_settings/read" +PROJECT_TTL_SETTINGS_UPDATE_PATH = "/project/ttl_settings/update" +FEEDBACK_AGGREGATE_PATH = "/feedback/aggregate" + +# Endpoints reached via _raw_request because weave-server-sdk 0.0.1 cannot call +# them (duplicate generated method names where the last definition wins, or a +# lost multipart body). Remove once a fixed SDK is published. +FEEDBACK_CREATE_PATH = "/feedback/create" +FEEDBACK_BATCH_CREATE_PATH = "/feedback/batch/create" +TRACE_USAGE_PATH = "/trace/usage" +OBJ_ADD_TAGS_PATH = "/objs/{object_id}/versions/{digest}/tags" +OBJ_REMOVE_TAGS_PATH = "/objs/{object_id}/versions/{digest}/tags/remove" +FILE_CREATE_PATH = "/files/create" +FILE_CONTENT_PATH = "/files/content" + +# Streaming endpoints; the published SDK buffers jsonl bodies, so these are +# reached via _raw_stream to preserve line-by-line streaming. 
+CALLS_STREAM_QUERY_PATH = "/calls/stream_query" +THREADS_STREAM_QUERY_PATH = "/threads/stream_query" +ANNOTATION_QUEUES_QUERY_PATH = "/annotation_queues/query" +CALL_UPSERT_BATCH_PATH = "/call/upsert_batch" class RemoteHTTPTraceServer(TraceServerClientInterface): + """The weave-server-sdk-backed remote trace server binding.""" + trace_server_url: str - # My current batching is not safe in notebooks, disable it for now def __init__( self, trace_server_url: str, @@ -63,13 +128,24 @@ def __init__( remote_request_bytes_limit: int = REMOTE_REQUEST_BYTES_LIMIT, auth: tuple[str, str] | None = None, extra_headers: dict[str, str] | None = None, + transport: httpx.BaseTransport | None = None, ): super().__init__() - self.trace_server_url = trace_server_url + self.trace_server_url = trace_server_url.rstrip("/") self.should_batch = should_batch self.use_calls_complete = should_use_calls_complete() and should_batch self.call_processor: AsyncBatchProcessor | CallBatchProcessor | None = None self.feedback_processor: AsyncBatchProcessor | None = None + self._auth: tuple[str, str] | None = auth + self._extra_headers: dict[str, str] | None = extra_headers + self.remote_request_bytes_limit = remote_request_bytes_limit + # Test seam: lets tests (and in-process fixtures) substitute the HTTP + # transport while keeping the full client stack (hooks, SDK) in play. 
+ self._transport = transport + + self._http = self._build_http_client() + self._sdk = WeaveTrace(http_client=self._http) + if self.should_batch: if self.use_calls_complete: self.call_processor = CallBatchProcessor( @@ -89,9 +165,48 @@ def __init__( max_queue_size=max_calls_queue_size(), enable_disk_fallback=should_enable_disk_fallback(), ) - self._auth: tuple[str, str] | None = auth - self._extra_headers: dict[str, str] | None = extra_headers - self.remote_request_bytes_limit = remote_request_bytes_limit + + # ---- transport --------------------------------------------------------- + + def _build_http_client(self) -> httpx.Client: + kwargs: dict[str, Any] = {} + if self._transport is not None: + kwargs["transport"] = self._transport + return httpx.Client( + base_url=self.trace_server_url, + auth=self._auth, + headers=self._extra_headers, + # Default transport (unless injected) so env proxy handling + # (incl. NO_PROXY) works natively. + event_hooks={ + "request": [log_request, self._inject_dynamic_headers], + "response": [log_response, self._raise_for_status], + }, + timeout=http_timeout(), + limits=CLIENT_LIMITS, + verify=ssl_verify(), + **kwargs, + ) + + def _inject_dynamic_headers(self, request: httpx.Request) -> None: + """Inject per-attempt headers at send time (httpx request hook).""" + if retry_id := get_current_retry_id(): + request.headers["X-Weave-Retry-Id"] = retry_id + + def _raise_for_status(self, response: httpx.Response) -> None: + """Surface error responses with RemoteHTTPTraceServer's exception + semantics (httpx response hook). + + Raising here means the SDK's own exception types never surface: retry + predicates, 413 batch splitting, and client code keep seeing + ``httpx.HTTPStatusError`` / ``CallsCompleteModeRequired`` exactly as + before. + """ + if response.status_code >= 400: + # Event hooks fire before the body is read; load it so + # handle_response_error can inspect the error payload. 
+ response.read() + handle_response_error(response, str(response.request.url)) def ensure_project_exists( self, entity: str, project: str @@ -108,69 +223,128 @@ def from_env(cls, should_batch: bool = False) -> Self: def set_auth(self, auth: tuple[str, str]) -> None: self._auth = auth + self._http.auth = auth - def _build_dynamic_request_headers(self) -> dict[str, str]: - """Build headers for HTTP requests, including extra headers and retry ID.""" - headers = dict(self._extra_headers) if self._extra_headers else {} - if retry_id := get_current_retry_id(): - headers["X-Weave-Retry-Id"] = retry_id - return headers - - def get(self, url: str, *args: Any, **kwargs: Any) -> httpx.Response: - headers = self._build_dynamic_request_headers() + # ---- request helpers ---------------------------------------------------- - return http_requests.get( - self.trace_server_url + url, - *args, - auth=self._auth, - headers=headers, - **kwargs, - ) + @with_retry + def _via_sdk( + self, + req: BaseModel, + sdk_req_type: type[BaseModel], + sdk_method: Callable[..., Any], + res_type: type[TRes], + **path_args: Any, + ) -> TRes: + """Round-trip a tsi request through the typed SDK binding. + + ``by_alias`` is required since query models have Mongo-style properties + aliased to start with ``$``. 
+ """ + body = sdk_req_type.model_validate(req.model_dump(by_alias=True)) + sdk_res = sdk_method(body, **path_args) + if sdk_res is None: + return res_type() + return res_type.model_validate(sdk_res.model_dump(by_alias=True)) - def post(self, url: str, *args: Any, **kwargs: Any) -> httpx.Response: - headers = self._build_dynamic_request_headers() + @with_retry + def _via_sdk_no_body( + self, + sdk_method: Callable[..., Any], + res_type: type[TRes], + **call_args: Any, + ) -> TRes: + """Call an SDK binding that takes only path/query arguments.""" + sdk_res = sdk_method(**call_args) + if sdk_res is None: + return res_type() + return res_type.model_validate(sdk_res.model_dump(by_alias=True)) - return http_requests.post( - self.trace_server_url + url, - *args, - auth=self._auth, - headers=headers, - **kwargs, + @with_retry + def _raw_request( + self, + method: str, + path: str, + *, + req: BaseModel | None = None, + params: dict[str, Any] | None = None, + res_type: type[TRes], + ) -> TRes: + """Call an endpoint the SDK cannot reach (see module docstring).""" + r = self._http.request( + method, + path, + content=req.model_dump_json(by_alias=True).encode("utf-8") + if req is not None + else None, + headers={"content-type": "application/json"} if req is not None else None, + params=params, ) + return res_type.model_validate(r.json()) - def put(self, url: str, *args: Any, **kwargs: Any) -> httpx.Response: - headers = self._build_dynamic_request_headers() - - return http_requests.put( - self.trace_server_url + url, - *args, - auth=self._auth, - headers={**headers, **kwargs.pop("headers", {})}, - **kwargs, + @with_retry + def _raw_post_bytes(self, path: str, encoded_data: bytes) -> None: + """POST pre-encoded json bytes (batch flush hot path; no re-parse).""" + self._http.post( + path, content=encoded_data, headers={"content-type": "application/json"} ) - def delete(self, url: str, *args: Any, **kwargs: Any) -> httpx.Response: - headers = 
self._build_dynamic_request_headers() + def _raw_stream( + self, + method: str, + path: str, + *, + req: BaseModel | None = None, + params: dict[str, Any] | None = None, + res_type: type[TRes], + ) -> Iterator[TRes]: + """Stream a jsonl endpoint line-by-line (the SDK buffers jsonl).""" + r = self._open_stream(method, path, req=req, params=params) + try: + for line in r.iter_lines(): + if line: + yield res_type.model_validate_json(line) + finally: + r.close() - return http_requests.delete( - self.trace_server_url + url, - *args, - auth=self._auth, - headers=headers, - **kwargs, + @with_retry + def _open_stream( + self, + method: str, + path: str, + *, + req: BaseModel | None = None, + params: dict[str, Any] | None = None, + ) -> httpx.Response: + """Open a streaming response; retries cover connection/headers only. + + Mid-stream failures are not retried, matching RemoteHTTPTraceServer. + The caller owns the returned response and must close() it. + """ + request = self._http.build_request( + method, + path, + content=req.model_dump_json(by_alias=True).encode("utf-8") + if req is not None + else None, + headers={"content-type": "application/json"} if req is not None else None, + params=params, ) + return self._http.send(request, stream=True) + + # ---- batching ----------------------------------------------------------- @with_retry def _send_batch_to_server(self, encoded_data: bytes) -> None: - """Send a batch of data to the server with retry logic. + """Send an encoded batch of calls to the server with retry logic. - This method is separated from _flush_calls to avoid recursive retries. + Separated from _flush_calls to avoid recursive retries. 
""" - r = self.post( - "/call/upsert_batch", - data=encoded_data, # type: ignore + self._http.post( + CALL_UPSERT_BATCH_PATH, + content=encoded_data, + headers={"content-type": "application/json"}, ) - handle_response_error(r, "/call/upsert_batch") def _flush_calls( self, @@ -181,9 +355,9 @@ def _flush_calls( """Process a batch of calls, splitting if necessary and sending to the server. This method handles the logic of splitting batches that are too large, - but delegates the actual server communication (with retries) to _send_batch_to_server. + but delegates the actual server communication (with retries) to + _send_batch_to_server. """ - # Call processor must be defined for this method assert self.call_processor is not None if len(batch) == 0: return @@ -220,13 +394,10 @@ def _upgrade_to_calls_complete( ) -> None: """Upgrade from legacy AsyncBatchProcessor to CallBatchProcessor. - This is called when the server indicates a project requires calls_complete mode. - The upgrade happens transparently: we replace the processor and re-enqueue - the current batch items. No calls are dropped during this upgrade. - - Args: - batch: The batch of items that failed to send (will be re-enqueued). - error_message: The error message from the server (for logging). + This is called when the server indicates a project requires + calls_complete mode. The upgrade happens transparently: we replace the + processor and re-enqueue the current batch items. No calls are dropped + during this upgrade. """ # Already upgraded? Just re-enqueue to the new processor if self.use_calls_complete: @@ -241,10 +412,8 @@ def _upgrade_to_calls_complete( error_message, ) - # Store old processor reference for cleanup old_processor = self.call_processor - # Create new CallBatchProcessor self.use_calls_complete = True self.call_processor = CallBatchProcessor( complete_processor_fn=self._flush_calls_complete, @@ -277,9 +446,9 @@ def _flush_calls_eager( to be visible immediately in the UI. 
Uses single call/start and call/end endpoints for easier rate limiting. - Each item is sent individually with retry logic (@with_retry). If all retries - are exhausted, the item is logged and dropped, then processing continues - with remaining items in the batch. + Each item is sent individually with retry logic (@with_retry). If all + retries are exhausted, the item is logged and dropped, then processing + continues with remaining items in the batch. """ for item in batch: try: @@ -296,22 +465,24 @@ def _flush_calls_eager( @with_retry def _send_call_start_v2(self, start: tsi.StartedCallSchemaForInsert) -> None: """Send a single call start to the v2 endpoint.""" - project_id = start.project_id - entity, project = project_id.split("/", 1) - url = f"/v2/{entity}/{project}/call/start" + entity, project = from_project_id(start.project_id) req = tsi.CallStartV2Req(start=start) - r = self.post(url, data=req.model_dump_json().encode("utf-8")) - handle_response_error(r, url) + self._http.post( + CALL_START_V2_PATH.format(entity=entity, project=project), + content=req.model_dump_json().encode("utf-8"), + headers={"content-type": "application/json"}, + ) @with_retry def _send_call_end_v2(self, end: tsi.EndedCallSchemaForInsertWithStartedAt) -> None: """Send a single call end to the v2 endpoint.""" - project_id = end.project_id - entity, project = project_id.split("/", 1) - url = f"/v2/{entity}/{project}/call/end" + entity, project = from_project_id(end.project_id) req = tsi.CallEndV2Req(end=end) - r = self.post(url, data=req.model_dump_json().encode("utf-8")) - handle_response_error(r, url) + self._http.post( + CALL_END_V2_PATH.format(entity=entity, project=project), + content=req.model_dump_json().encode("utf-8"), + headers={"content-type": "application/json"}, + ) def _extract_entity_project( self, batch: list[CompleteBatchItem] @@ -334,14 +505,13 @@ def _extract_entity_project( return EntityProjectInfo(entity=entity, project=project, project_id=project_id) - @with_retry 
def _send_calls_complete_to_server( self, entity: str, project: str, encoded_data: bytes ) -> None: """Send a batch of completed calls to the server with retry logic.""" - url = f"/v2/{entity}/{project}/calls/complete" - r = self.post(url, data=encoded_data) - handle_response_error(r, url) + self._raw_post_bytes( + CALLS_COMPLETE_PATH.format(entity=entity, project=project), encoded_data + ) def _flush_calls_complete( self, @@ -349,7 +519,7 @@ def _flush_calls_complete( *, _should_update_batch_size: bool = True, ) -> None: - """Process a batch of complete calls and send to the calls/upsert endpoint. + """Process a batch of complete calls and send to the calls/complete endpoint. This is the new calls_complete path. Complete calls have both start and end information bundled together. @@ -389,15 +559,17 @@ def get_call_processor(self) -> AsyncBatchProcessor | CallBatchProcessor | None: return self.call_processor def _send_feedback_batch_to_server(self, encoded_data: bytes) -> None: - """Send a batch of feedback data to the server with retry logic. + """Send a batch of feedback data to the server. - This method is separated from _flush_feedback to avoid recursive retries. + No request-level retry here: failures are classified by the caller + (404 falls back to individual creates; retryable errors requeue at the + batch-processor level). """ - r = self.post( - "/feedback/batch/create", - data=encoded_data, # type: ignore + self._http.post( + FEEDBACK_BATCH_CREATE_PATH, + content=encoded_data, + headers={"content-type": "application/json"}, ) - handle_response_error(r, "/feedback/batch/create") def _flush_feedback( self, @@ -406,9 +578,9 @@ def _flush_feedback( """Process a batch of feedback, splitting if necessary and sending to the server. This method handles the logic of splitting batches that are too large, - but delegates the actual server communication (with retries) to _send_feedback_batch_to_server. 
+ but delegates the actual server communication (with retries) to + _send_feedback_batch_to_server. """ - # Feedback processor must be defined for this method assert self.feedback_processor is not None if len(batch) == 0: return @@ -445,11 +617,11 @@ class FeedbackCreateReqStripped(tsi.FeedbackCreateReq): for item in batch: item_copy = FeedbackCreateReqStripped(**item.model_dump()) try: - self._generic_request( - "/feedback/create", - item_copy, - FeedbackCreateReqStripped, - tsi.FeedbackCreateRes, + self._raw_request( + "POST", + FEEDBACK_CREATE_PATH, + req=item_copy, + res_type=tsi.FeedbackCreateRes, ) except Exception as individual_error: logger.warning( @@ -478,127 +650,26 @@ def get_feedback_processor(self) -> AsyncBatchProcessor | None: """ return self.feedback_processor - @with_retry - def _post_request_executor( - self, - url: str, - req: BaseModel, - stream: bool = False, - ) -> httpx.Response: - r = self.post( - url, - # `by_alias` is required since we have Mongo-style properties in the - # query models that are aliased to conform to start with `$`. Without - # this, the model_dump will use the internal property names which are - # not valid for the `model_validate` step. 
- data=req.model_dump_json(by_alias=True).encode("utf-8"), - stream=stream, - ) - handle_response_error(r, url) - return r - - @with_retry - def _get_request_executor( - self, - url: str, - params: dict[str, Any] | None = None, - stream: bool = False, - ) -> httpx.Response: - r = self.get(url, params=params or {}, stream=stream) - handle_response_error(r, url) - return r - - def _put_request_executor( - self, - url: str, - req: BaseModel, - stream: bool = False, - ) -> httpx.Response: - r = self.put( - url, - data=req.model_dump_json(by_alias=True).encode("utf-8"), - stream=stream, - ) - handle_response_error(r, url) - return r - - @with_retry - def _delete_request_executor( - self, - url: str, - params: dict[str, Any] | None = None, - stream: bool = False, - ) -> httpx.Response: - r = self.delete(url, params=params or {}, stream=stream) - handle_response_error(r, url) - return r - - def _generic_request( - self, - url: str, - req: BaseModel, - req_model: type[BaseModel], - res_model: type[BaseModel], - method: str = "POST", - params: dict[str, Any] | None = None, - ) -> BaseModel: - if method == "POST": - r = self._post_request_executor(url, req) - elif method == "PUT": - r = self._put_request_executor(url, req) - elif method == "GET": - r = self._get_request_executor(url, params) - elif method == "DELETE": - r = self._delete_request_executor(url, params) - else: - raise ValueError(f"Unsupported HTTP method: {method}") - - return res_model.model_validate(r.json()) - - def _generic_stream_request( - self, - url: str, - req: BaseModel, - req_model: type[BaseModel], - res_model: type[BaseModel], - method: str = "POST", - params: dict[str, Any] | None = None, - ) -> Iterator[BaseModel]: - if method == "POST": - r = self._post_request_executor(url, req, stream=True) - elif method == "GET": - r = self._get_request_executor(url, params, stream=True) - elif method == "DELETE": - r = self._delete_request_executor(url, params, stream=True) - else: - raise 
ValueError(f"Unsupported HTTP method: {method}") - - try: - for line in r.iter_lines(): - if line: - yield res_model.model_validate_json(line) - finally: - r.close() + # ---- service ------------------------------------------------------------ @with_retry def server_info(self) -> ServerInfoRes: - r = self.get( - "/server_info", - ) - handle_response_error(r, "/server_info") - return ServerInfoRes.model_validate(r.json()) + res = self._sdk.services.server_info() + return ServerInfoRes.model_validate(res.model_dump()) @validate_call + @with_retry def projects_info(self, req: tsi.ProjectsInfoReq) -> list[tsi.ProjectsInfoRes]: - r = self._post_request_executor("/service/projects_info", req) - handle_response_error(r, "/service/projects_info") - return [tsi.ProjectsInfoRes.model_validate(item) for item in r.json()] + body = sdk_models.ProjectsInfoReq.model_validate(req.model_dump()) + res = self._sdk.service.create_projects_info(body) + return [tsi.ProjectsInfoRes.model_validate(item.model_dump()) for item in res] def otel_export(self, req: tsi.OTelExportReq) -> tsi.OTelExportRes: # TODO: Add docs link (DOCS-1390) raise NotImplementedError("Sending otel traces directly is not yet supported.") - # Call API + # ---- Call API ------------------------------------------------------------- + @validate_call def call_start(self, req: tsi.CallStartReq) -> tsi.CallStartRes: if self.should_batch: @@ -610,13 +681,16 @@ def call_start(self, req: tsi.CallStartReq) -> tsi.CallStartRes: ) self.call_processor.enqueue_start(StartBatchItem(req=req)) return tsi.CallStartRes(id=req.start.id, trace_id=req.start.trace_id) - return self._generic_request( - "/call/start", req, tsi.CallStartReq, tsi.CallStartRes + return self._via_sdk( + req, sdk_models.CallStartReq, self._sdk.calls.start, tsi.CallStartRes ) def call_start_batch(self, req: tsi.CallCreateBatchReq) -> tsi.CallCreateBatchRes: - return self._generic_request( - "/call/upsert_batch", req, tsi.CallCreateBatchReq, tsi.CallCreateBatchRes 
+ return self._via_sdk( + req, + sdk_models.CallCreateBatchReq, + self._sdk.calls.upsert_batch, + tsi.CallCreateBatchRes, ) @validate_call @@ -626,12 +700,14 @@ def call_end(self, req: tsi.CallEndReq) -> tsi.CallEndRes: self.call_processor.enqueue([EndBatchItem(req=req)]) return tsi.CallEndRes() - return self._generic_request("/call/end", req, tsi.CallEndReq, tsi.CallEndRes) + return self._via_sdk( + req, sdk_models.CallEndReq, self._sdk.calls.end, tsi.CallEndRes + ) @validate_call def call_read(self, req: tsi.CallReadReq) -> tsi.CallReadRes: - return self._generic_request( - "/call/read", req, tsi.CallReadReq, tsi.CallReadRes + return self._via_sdk( + req, sdk_models.CallReadReq, self._sdk.calls.read, tsi.CallReadRes ) @validate_call @@ -641,129 +717,138 @@ def calls_query(self, req: tsi.CallsQueryReq) -> tsi.CallsQueryRes: @validate_call def calls_query_stream(self, req: tsi.CallsQueryReq) -> Iterator[tsi.CallSchema]: - return self._generic_stream_request( - "/calls/stream_query", req, tsi.CallsQueryReq, tsi.CallSchema + return self._raw_stream( + "POST", CALLS_STREAM_QUERY_PATH, req=req, res_type=tsi.CallSchema ) @validate_call def calls_query_stats(self, req: tsi.CallsQueryStatsReq) -> tsi.CallsQueryStatsRes: - return self._generic_request( - "/calls/query_stats", req, tsi.CallsQueryStatsReq, tsi.CallsQueryStatsRes + return self._via_sdk( + req, + sdk_models.CallsQueryStatsReq, + self._sdk.calls.query_stats, + tsi.CallsQueryStatsRes, ) @validate_call def trace_usage(self, req: tsi.TraceUsageReq) -> tsi.TraceUsageRes: - return self._generic_request( - "/trace/usage", req, tsi.TraceUsageReq, tsi.TraceUsageRes + # SDK 0.0.1: calls.create_usage for /trace/usage is shadowed by the + # /calls/usage overload of the same generated name. 
+ return self._raw_request( + "POST", TRACE_USAGE_PATH, req=req, res_type=tsi.TraceUsageRes ) @validate_call def calls_usage(self, req: tsi.CallsUsageReq) -> tsi.CallsUsageRes: - return self._generic_request( - "/calls/usage", req, tsi.CallsUsageReq, tsi.CallsUsageRes + return self._via_sdk( + req, + sdk_models.CallsUsageReq, + self._sdk.calls.create_usage, + tsi.CallsUsageRes, ) @validate_call def calls_delete(self, req: tsi.CallsDeleteReq) -> tsi.CallsDeleteRes: - return self._generic_request( - "/calls/delete", req, tsi.CallsDeleteReq, tsi.CallsDeleteRes + return self._via_sdk( + req, sdk_models.CallsDeleteReq, self._sdk.calls.delete, tsi.CallsDeleteRes ) @validate_call def call_update(self, req: tsi.CallUpdateReq) -> tsi.CallUpdateRes: - return self._generic_request( - "/call/update", req, tsi.CallUpdateReq, tsi.CallUpdateRes + return self._via_sdk( + req, sdk_models.CallUpdateReq, self._sdk.calls.update, tsi.CallUpdateRes ) - # Obj API + # ---- Obj API -------------------------------------------------------------- @validate_call def obj_create(self, req: tsi.ObjCreateReq) -> tsi.ObjCreateRes: - return self._generic_request( - "/obj/create", req, tsi.ObjCreateReq, tsi.ObjCreateRes + return self._via_sdk( + req, sdk_models.ObjCreateReq, self._sdk.objects.create, tsi.ObjCreateRes ) @validate_call def obj_read(self, req: tsi.ObjReadReq) -> tsi.ObjReadRes: - return self._generic_request("/obj/read", req, tsi.ObjReadReq, tsi.ObjReadRes) + return self._via_sdk( + req, sdk_models.ObjReadReq, self._sdk.objects.read, tsi.ObjReadRes + ) @validate_call def objs_query(self, req: tsi.ObjQueryReq) -> tsi.ObjQueryRes: - return self._generic_request( - "/objs/query", req, tsi.ObjQueryReq, tsi.ObjQueryRes + return self._via_sdk( + req, sdk_models.ObjQueryReq, self._sdk.objects.query, tsi.ObjQueryRes ) def obj_delete(self, req: tsi.ObjDeleteReq) -> tsi.ObjDeleteRes: - return self._generic_request( - "/obj/delete", req, tsi.ObjDeleteReq, tsi.ObjDeleteRes + return self._via_sdk( + 
req, sdk_models.ObjDeleteReq, self._sdk.objects.delete, tsi.ObjDeleteRes ) def obj_add_tags(self, req: tsi.ObjAddTagsReq) -> tsi.ObjAddTagsRes: - body = his.ObjTagsBody(project_id=req.project_id, tags=req.tags) - return self._generic_request( - f"/objs/{req.object_id}/versions/{req.digest}/tags", - body, - his.ObjTagsBody, - tsi.ObjAddTagsRes, - method="PUT", + # SDK 0.0.1: objects.tags for add-tags is shadowed by the /tags list + # overload of the same generated name. + body = sdk_models.ObjTagsBody(project_id=req.project_id, tags=req.tags) + return self._raw_request( + "PUT", + OBJ_ADD_TAGS_PATH.format(object_id=req.object_id, digest=req.digest), + req=body, + res_type=tsi.ObjAddTagsRes, ) def obj_remove_tags(self, req: tsi.ObjRemoveTagsReq) -> tsi.ObjRemoveTagsRes: - body = his.ObjTagsBody(project_id=req.project_id, tags=req.tags) - return self._generic_request( - f"/objs/{req.object_id}/versions/{req.digest}/tags/remove", - body, - his.ObjTagsBody, - tsi.ObjRemoveTagsRes, + # SDK 0.0.1: objects.create_remove for remove-tags is shadowed by the + # remove-aliases overload of the same generated name. 
+ body = sdk_models.ObjTagsBody(project_id=req.project_id, tags=req.tags) + return self._raw_request( + "POST", + OBJ_REMOVE_TAGS_PATH.format(object_id=req.object_id, digest=req.digest), + req=body, + res_type=tsi.ObjRemoveTagsRes, ) def obj_set_aliases(self, req: tsi.ObjSetAliasesReq) -> tsi.ObjSetAliasesRes: - body = his.ObjSetAliasesBody( + body = sdk_models.ObjSetAliasesBody( project_id=req.project_id, digest=req.digest, aliases=req.aliases ) - return self._generic_request( - f"/objs/{req.object_id}/aliases", - body, - his.ObjSetAliasesBody, + res = self._via_sdk_no_body( + self._sdk.objects.update_aliases, tsi.ObjSetAliasesRes, - method="PUT", + body=body, + object_id=req.object_id, ) + return res def obj_remove_aliases( self, req: tsi.ObjRemoveAliasesReq ) -> tsi.ObjRemoveAliasesRes: - body = his.ObjRemoveAliasesBody(project_id=req.project_id, aliases=req.aliases) - return self._generic_request( - f"/objs/{req.object_id}/aliases/remove", - body, - his.ObjRemoveAliasesBody, + body = sdk_models.ObjRemoveAliasesBody( + project_id=req.project_id, aliases=req.aliases + ) + return self._via_sdk_no_body( + self._sdk.objects.create_remove, tsi.ObjRemoveAliasesRes, + body=body, + object_id=req.object_id, ) def tags_list(self, req: tsi.TagsListReq) -> tsi.TagsListRes: - return self._generic_request( - "/tags", - req, - tsi.TagsListReq, - tsi.TagsListRes, - method="GET", - params={"project_id": req.project_id}, + return self._via_sdk_no_body( + self._sdk.objects.tags, tsi.TagsListRes, project_id=req.project_id ) def aliases_list(self, req: tsi.AliasesListReq) -> tsi.AliasesListRes: - return self._generic_request( - "/aliases", - req, - tsi.AliasesListReq, + return self._via_sdk_no_body( + self._sdk.objects.list_aliases, tsi.AliasesListRes, - method="GET", - params={"project_id": req.project_id}, + project_id=req.project_id, ) + # ---- Table API ------------------------------------------------------------ + @validate_call def table_create(self, req: tsi.TableCreateReq) 
-> tsi.TableCreateRes: - return self._generic_request( - "/table/create", req, tsi.TableCreateReq, tsi.TableCreateRes + return self._via_sdk( + req, sdk_models.TableCreateReq, self._sdk.tables.create, tsi.TableCreateRes ) @validate_call @@ -794,14 +879,17 @@ def table_update(self, req: tsi.TableUpdateReq) -> tsi.TableUpdateRes: digest=second_half_res.digest, updated_row_digests=all_digests ) else: - return self._generic_request( - "/table/update", req, tsi.TableUpdateReq, tsi.TableUpdateRes + return self._via_sdk( + req, + sdk_models.TableUpdateReq, + self._sdk.tables.update, + tsi.TableUpdateRes, ) @validate_call def table_query(self, req: tsi.TableQueryReq) -> tsi.TableQueryRes: - return self._generic_request( - "/table/query", req, tsi.TableQueryReq, tsi.TableQueryRes + return self._via_sdk( + req, sdk_models.TableQueryReq, self._sdk.tables.query, tsi.TableQueryRes ) @validate_call @@ -814,8 +902,11 @@ def table_query_stream( @validate_call def table_query_stats(self, req: tsi.TableQueryStatsReq) -> tsi.TableQueryStatsRes: - return self._generic_request( - "/table/query_stats", req, tsi.TableQueryStatsReq, tsi.TableQueryStatsRes + return self._via_sdk( + req, + sdk_models.TableQueryStatsReq, + self._sdk.tables.query_stats, + tsi.TableQueryStatsRes, ) @validate_call @@ -823,10 +914,25 @@ def table_create_from_digests( self, req: tsi.TableCreateFromDigestsReq ) -> tsi.TableCreateFromDigestsRes: """Create a table by specifying row digests instead of actual rows.""" - return self._generic_request( - "/table/create_from_digests", + return self._via_sdk( + req, + sdk_models.TableCreateFromDigestsReq, + self._sdk.tables.create_create_from_digests, + tsi.TableCreateFromDigestsRes, + ) + + def unretried_table_create_from_digests( + self, req: tsi.TableCreateFromDigestsReq + ) -> tsi.TableCreateFromDigestsRes: + """Single-attempt variant used by the client's endpoint-availability + probe: a missing endpoint (404) must fail fast, and a flaky probe must + not stall table 
saves behind retries. + """ + return self._via_sdk.__wrapped__( # type: ignore[attr-defined] + self, req, - tsi.TableCreateFromDigestsReq, + sdk_models.TableCreateFromDigestsReq, + self._sdk.tables.create_create_from_digests, tsi.TableCreateFromDigestsRes, ) @@ -834,50 +940,62 @@ def table_create_from_digests( def table_query_stats_batch( self, req: tsi.TableQueryStatsReq ) -> tsi.TableQueryStatsRes: - return self._generic_request( - "/table/query_stats_batch", + return self._via_sdk( req, - tsi.TableQueryStatsBatchReq, + sdk_models.TableQueryStatsBatchReq, + self._sdk.tables.create_query_stats_batch, tsi.TableQueryStatsBatchRes, ) @validate_call def refs_read_batch(self, req: tsi.RefsReadBatchReq) -> tsi.RefsReadBatchRes: - return self._generic_request( - "/refs/read_batch", req, tsi.RefsReadBatchReq, tsi.RefsReadBatchRes + return self._via_sdk( + req, + sdk_models.RefsReadBatchReq, + self._sdk.refs.read_batch, + tsi.RefsReadBatchRes, ) + # ---- File API ------------------------------------------------------------- + @with_retry def file_create(self, req: tsi.FileCreateReq) -> tsi.FileCreateRes: + # SDK 0.0.1: files.create lost its multipart body in generation; post + # the multipart form directly. 
data: dict[str, str] = {"project_id": req.project_id} if req.expected_digest is not None: data["expected_digest"] = req.expected_digest - r = self.post( - "/files/create", + r = self._http.post( + FILE_CREATE_PATH, data=data, files={"file": (req.name, req.content)}, ) - handle_response_error(r, "/files/create") return tsi.FileCreateRes.model_validate(r.json()) @with_retry def file_content_read(self, req: tsi.FileContentReadReq) -> tsi.FileContentReadRes: - r = self.post( - "/files/content", - json={"project_id": req.project_id, "digest": req.digest}, - ) - handle_response_error(r, "/files/content") - # TODO: Should stream to disk rather than to memory - bytes = io.BytesIO() - bytes.writelines(r.iter_bytes()) - bytes.seek(0) - return tsi.FileContentReadRes(content=bytes.read()) + # Raw to keep the response streamed to a buffer rather than the SDK's + # whole-body bytes return. + r = self._open_stream("POST", FILE_CONTENT_PATH, req=req) + try: + # TODO: Should stream to disk rather than to memory + bytes_buffer = io.BytesIO() + for chunk in r.iter_bytes(): + bytes_buffer.write(chunk) + finally: + r.close() + return tsi.FileContentReadRes(content=bytes_buffer.getvalue()) def files_stats(self, req: tsi.FilesStatsReq) -> tsi.FilesStatsRes: - return self._generic_request( - "/files/stats", req, tsi.FilesStatsReq, tsi.FilesStatsRes + return self._via_sdk( + req, + sdk_models.FilesStatsReq, + self._sdk.files.query_stats, + tsi.FilesStatsRes, ) + # ---- Feedback API ----------------------------------------------------------- + @validate_call def feedback_create(self, req: tsi.FeedbackCreateReq) -> tsi.FeedbackCreateRes: if self.should_batch: @@ -895,42 +1013,58 @@ def feedback_create(self, req: tsi.FeedbackCreateReq) -> tsi.FeedbackCreateRes: payload=req.payload, ) else: - return self._generic_request( - "/feedback/create", req, tsi.FeedbackCreateReq, tsi.FeedbackCreateRes + # SDK 0.0.1: feedback.create for single create is shadowed by the + # batch-create overload of the 
same generated name. + return self._raw_request( + "POST", FEEDBACK_CREATE_PATH, req=req, res_type=tsi.FeedbackCreateRes ) def feedback_create_batch( self, req: tsi.FeedbackCreateBatchReq ) -> tsi.FeedbackCreateBatchRes: - return self._generic_request( - "/feedback/batch/create", + # Note: the SDK method is named `create` for /feedback/batch/create + # (duplicate-name shadowing in 0.0.1; the batch overload won). + return self._via_sdk( req, - tsi.FeedbackCreateBatchReq, + sdk_models.FeedbackCreateBatchReq, + self._sdk.feedback.create, tsi.FeedbackCreateBatchRes, ) @validate_call def feedback_query(self, req: tsi.FeedbackQueryReq) -> tsi.FeedbackQueryRes: - return self._generic_request( - "/feedback/query", req, tsi.FeedbackQueryReq, tsi.FeedbackQueryRes + return self._via_sdk( + req, + sdk_models.FeedbackQueryReq, + self._sdk.feedback.query, + tsi.FeedbackQueryRes, ) @validate_call def feedback_purge(self, req: tsi.FeedbackPurgeReq) -> tsi.FeedbackPurgeRes: - return self._generic_request( - "/feedback/purge", req, tsi.FeedbackPurgeReq, tsi.FeedbackPurgeRes + return self._via_sdk( + req, + sdk_models.FeedbackPurgeReq, + self._sdk.feedback.purge, + tsi.FeedbackPurgeRes, ) @validate_call def feedback_replace(self, req: tsi.FeedbackReplaceReq) -> tsi.FeedbackReplaceRes: - return self._generic_request( - "/feedback/replace", req, tsi.FeedbackReplaceReq, tsi.FeedbackReplaceRes + return self._via_sdk( + req, + sdk_models.FeedbackReplaceReq, + self._sdk.feedback.replace, + tsi.FeedbackReplaceRes, ) @validate_call def feedback_stats(self, req: tsi.FeedbackStatsReq) -> tsi.FeedbackStatsRes: - return self._generic_request( - "/feedback/stats", req, tsi.FeedbackStatsReq, tsi.FeedbackStatsRes + return self._via_sdk( + req, + sdk_models.FeedbackStatsReq, + self._sdk.feedback.create_stats, + tsi.FeedbackStatsRes, ) @validate_call @@ -938,51 +1072,50 @@ def feedback_aggregate( self, req: tsi.FeedbackAggregateReq ) -> tsi.FeedbackAggregateRes: """Query the feedback table for 
aggregate scores over time.""" - return self._generic_request( - "/feedback/aggregate", - req, - tsi.FeedbackAggregateReq, - tsi.FeedbackAggregateRes, + # Not yet present in the published SDK. + return self._raw_request( + "POST", FEEDBACK_AGGREGATE_PATH, req=req, res_type=tsi.FeedbackAggregateRes ) @validate_call def feedback_payload_schema( self, req: tsi.FeedbackPayloadSchemaReq ) -> tsi.FeedbackPayloadSchemaRes: - return self._generic_request( - "/feedback/payload_schema", + return self._via_sdk( req, - tsi.FeedbackPayloadSchemaReq, + sdk_models.FeedbackPayloadSchemaReq, + self._sdk.feedback.create_payload_schema, tsi.FeedbackPayloadSchemaRes, ) - # Cost API + # ---- Cost API --------------------------------------------------------------- + @validate_call def cost_query(self, req: tsi.CostQueryReq) -> tsi.CostQueryRes: - return self._generic_request( - "/cost/query", req, tsi.CostQueryReq, tsi.CostQueryRes + return self._via_sdk( + req, sdk_models.CostQueryReq, self._sdk.costs.query, tsi.CostQueryRes ) @validate_call def cost_create(self, req: tsi.CostCreateReq) -> tsi.CostCreateRes: - return self._generic_request( - "/cost/create", req, tsi.CostCreateReq, tsi.CostCreateRes + return self._via_sdk( + req, sdk_models.CostCreateReq, self._sdk.costs.create, tsi.CostCreateRes ) @validate_call def cost_purge(self, req: tsi.CostPurgeReq) -> tsi.CostPurgeRes: - return self._generic_request( - "/cost/purge", req, tsi.CostPurgeReq, tsi.CostPurgeRes + return self._via_sdk( + req, sdk_models.CostPurgeReq, self._sdk.costs.purge, tsi.CostPurgeRes ) + # ---- Execution APIs --------------------------------------------------------- + def completions_create( self, req: tsi.CompletionsCreateReq ) -> tsi.CompletionsCreateRes: - return self._generic_request( - "/completions/create", - req, - tsi.CompletionsCreateReq, - tsi.CompletionsCreateRes, + # Excluded from the OpenAPI spec (include_in_schema=False). 
+ return self._raw_request( + "POST", COMPLETIONS_CREATE_PATH, req=req, res_type=tsi.CompletionsCreateRes ) def completions_create_stream( @@ -996,169 +1129,167 @@ def completions_create_stream( def image_create( self, req: tsi.ImageGenerationCreateReq ) -> tsi.ImageGenerationCreateRes: - return self._generic_request( - "/image/create", + return self._via_sdk( req, - tsi.ImageGenerationCreateReq, + sdk_models.ImageGenerationCreateReq, + self._sdk.images.create, tsi.ImageGenerationCreateRes, ) def project_stats(self, req: tsi.ProjectStatsReq) -> tsi.ProjectStatsRes: - return self._generic_request( - "/project/stats", req, tsi.ProjectStatsReq, tsi.ProjectStatsRes + # Excluded from the OpenAPI spec (include_in_schema=False). + return self._raw_request( + "POST", PROJECT_STATS_PATH, req=req, res_type=tsi.ProjectStatsRes ) def project_ttl_settings_read( self, req: tsi.ProjectTTLSettingsReadReq ) -> tsi.ProjectTTLSettingsReadRes: - return self._generic_request( - "/project/ttl_settings/read", - req, - tsi.ProjectTTLSettingsReadReq, - tsi.ProjectTTLSettingsReadRes, + # Excluded from the OpenAPI spec (include_in_schema=False). + return self._raw_request( + "POST", + PROJECT_TTL_SETTINGS_READ_PATH, + req=req, + res_type=tsi.ProjectTTLSettingsReadRes, ) def project_ttl_settings_update( self, req: tsi.ProjectTTLSettingsUpdateReq ) -> tsi.ProjectTTLSettingsUpdateRes: - return self._generic_request( - "/project/ttl_settings/update", - req, - tsi.ProjectTTLSettingsUpdateReq, - tsi.ProjectTTLSettingsUpdateRes, + # Excluded from the OpenAPI spec (include_in_schema=False). 
+ return self._raw_request( + "POST", + PROJECT_TTL_SETTINGS_UPDATE_PATH, + req=req, + res_type=tsi.ProjectTTLSettingsUpdateRes, ) def threads_query_stream( self, req: tsi.ThreadsQueryReq ) -> Iterator[tsi.ThreadSchema]: - return self._generic_stream_request( - "/threads/stream_query", req, tsi.ThreadsQueryReq, tsi.ThreadSchema + return self._raw_stream( + "POST", THREADS_STREAM_QUERY_PATH, req=req, res_type=tsi.ThreadSchema ) - # Annotation Queue API + # ---- Annotation Queue API ----------------------------------------------------- + def annotation_queue_create( self, req: tsi.AnnotationQueueCreateReq ) -> tsi.AnnotationQueueCreateRes: - return self._generic_request( - "/annotation_queues", + return self._via_sdk( req, - tsi.AnnotationQueueCreateReq, + sdk_models.AnnotationQueueCreateReq, + self._sdk.annotation_queues.create_annotation_queues, tsi.AnnotationQueueCreateRes, ) def annotation_queues_query_stream( self, req: tsi.AnnotationQueuesQueryReq ) -> Iterator[tsi.AnnotationQueueSchema]: - return self._generic_stream_request( - "/annotation_queues/query", - req, - tsi.AnnotationQueuesQueryReq, - tsi.AnnotationQueueSchema, + return self._raw_stream( + "POST", + ANNOTATION_QUEUES_QUERY_PATH, + req=req, + res_type=tsi.AnnotationQueueSchema, ) def annotation_queue_read( self, req: tsi.AnnotationQueueReadReq ) -> tsi.AnnotationQueueReadRes: - return self._generic_request( - f"/annotation_queues/{req.queue_id}", - req, - tsi.AnnotationQueueReadReq, + return self._via_sdk_no_body( + self._sdk.annotation_queues.list_annotation_queues, tsi.AnnotationQueueReadRes, - method="GET", - params={"project_id": req.project_id}, + queue_id=req.queue_id, + project_id=req.project_id, ) def annotation_queue_delete( self, req: tsi.AnnotationQueueDeleteReq ) -> tsi.AnnotationQueueDeleteRes: - return self._generic_request( - f"/annotation_queues/{req.queue_id}", - req, - tsi.AnnotationQueueDeleteReq, + return self._via_sdk_no_body( + 
self._sdk.annotation_queues.delete_annotation_queues, tsi.AnnotationQueueDeleteRes, - method="DELETE", - params={"project_id": req.project_id}, + queue_id=req.queue_id, + project_id=req.project_id, ) def annotation_queue_update( self, req: tsi.AnnotationQueueUpdateReq ) -> tsi.AnnotationQueueUpdateRes: - # Convert to Body type to exclude queue_id from request body (it's in the URL path) - body = his.AnnotationQueueUpdateBody( + # Body type excludes queue_id from the request body (it's in the URL path) + body = sdk_models.AnnotationQueueUpdateBody( project_id=req.project_id, name=req.name, description=req.description, scorer_refs=req.scorer_refs, ) - return self._generic_request( - f"/annotation_queues/{req.queue_id}", - body, - his.AnnotationQueueUpdateBody, + return self._via_sdk_no_body( + self._sdk.annotation_queues.update_annotation_queues, tsi.AnnotationQueueUpdateRes, - method="PUT", + body=body, + queue_id=req.queue_id, ) def annotation_queue_add_calls( self, req: tsi.AnnotationQueueAddCallsReq ) -> tsi.AnnotationQueueAddCallsRes: - # Convert to Body type to exclude queue_id from request body (it's in the URL path) - body = his.AnnotationQueueAddCallsBody( + # Body type excludes queue_id from the request body (it's in the URL path) + body = sdk_models.AnnotationQueueAddCallsBody( project_id=req.project_id, call_ids=req.call_ids, display_fields=req.display_fields, ) - return self._generic_request( - f"/annotation_queues/{req.queue_id}/items", - body, - his.AnnotationQueueAddCallsBody, + return self._via_sdk_no_body( + self._sdk.annotation_queues.create_items, tsi.AnnotationQueueAddCallsRes, + body=body, + queue_id=req.queue_id, ) def annotation_queue_items_query( self, req: tsi.AnnotationQueueItemsQueryReq ) -> tsi.AnnotationQueueItemsQueryRes: - # Convert to Body type to exclude queue_id from request body (it's in the URL path) - body = his.AnnotationQueueItemsQueryBody( - project_id=req.project_id, - filter=req.filter, - sort_by=req.sort_by, - 
limit=req.limit, - offset=req.offset, - include_position=req.include_position, - ) - return self._generic_request( - f"/annotation_queues/{req.queue_id}/items/query", - body, - his.AnnotationQueueItemsQueryBody, + # Body type excludes queue_id from the request body (it's in the URL path) + body = sdk_models.AnnotationQueueItemsQueryBody.model_validate( + req.model_dump(exclude={"queue_id"}, by_alias=True) + ) + return self._via_sdk_no_body( + self._sdk.annotation_queues.query, tsi.AnnotationQueueItemsQueryRes, + body=body, + queue_id=req.queue_id, ) def annotation_queues_stats( self, req: tsi.AnnotationQueuesStatsReq ) -> tsi.AnnotationQueuesStatsRes: - return self._generic_request( - "/annotation_queues/stats", + return self._via_sdk( req, - tsi.AnnotationQueuesStatsReq, + sdk_models.AnnotationQueuesStatsReq, + self._sdk.annotation_queues.create_stats, tsi.AnnotationQueuesStatsRes, ) def annotator_queue_items_progress_update( self, req: tsi.AnnotatorQueueItemsProgressUpdateReq ) -> tsi.AnnotatorQueueItemsProgressUpdateRes: - # Convert to Body type to exclude queue_id, item_id, and wb_user_id from request body - # (queue_id and item_id are in the URL path, wb_user_id is set server-side from auth) - body = his.AnnotationQueueItemProgressUpdateBody( + # Body type excludes queue_id, item_id, and wb_user_id from the request + # body (queue_id and item_id are in the URL path, wb_user_id is set + # server-side from auth) + body = sdk_models.AnnotationQueueItemProgressUpdateBody( project_id=req.project_id, annotation_state=req.annotation_state, ) - return self._generic_request( - f"/annotation_queues/{req.queue_id}/items/{req.item_id}/progress", - body, - his.AnnotationQueueItemProgressUpdateBody, + return self._via_sdk_no_body( + self._sdk.annotation_queues.create_progress, tsi.AnnotatorQueueItemsProgressUpdateRes, + body=body, + queue_id=req.queue_id, + item_id=req.item_id, ) + # ---- Server-side execution (not supported remotely) ---------------------------- + def 
evaluate_model(self, req: tsi.EvaluateModelReq) -> tsi.EvaluateModelRes: raise NotImplementedError("evaluate_model is not implemented") @@ -1173,348 +1304,262 @@ def rescore(self, req: tsi.RescoreReq) -> tsi.RescoreRes: def calls_score(self, req: tsi.CallsScoreReq) -> tsi.CallsScoreRes: raise NotImplementedError("calls_score is not implemented") - # === V2 APIs === + # ---- V2 APIs -------------------------------------------------------------- + + def _v2_body_create( + self, + req: BaseModel, + body_type: type[BaseModel], + sdk_method: Callable[..., Any], + res_type: type[TRes], + ) -> TRes: + """Create via a v2 endpoint: project_id moves to the path; the rest is body.""" + entity, project = from_project_id(req.project_id) # type: ignore[attr-defined] + body = body_type.model_validate(req.model_dump(exclude={"project_id"})) + return self._via_sdk_no_body( + sdk_method, res_type, body=body, entity=entity, project=project + ) + + def _v2_list_stream( + self, + req: BaseModel, + res_type: type[TRes], + kind: str, + params: dict[str, Any], + ) -> Iterator[TRes]: + """Stream a v2 jsonl list endpoint (the SDK buffers jsonl responses).""" + entity, project = from_project_id(req.project_id) # type: ignore[attr-defined] + url = f"/v2/{entity}/{project}/{kind}" + return self._raw_stream("GET", url, params=params, res_type=res_type) @validate_call def op_create(self, req: tsi.OpCreateReq) -> tsi.OpCreateRes: - entity, project = from_project_id(req.project_id) - url = f"/v2/{entity}/{project}/ops" - # For create, we need to send the body without project_id (OpCreateBody) - body_data = req.model_dump(exclude={"project_id"}) - body = tsi.OpCreateBody.model_validate(body_data) - return self._generic_request( - url, - body, - tsi.OpCreateBody, - tsi.OpCreateRes, - method="POST", + return self._v2_body_create( + req, sdk_models.OpCreateBody, self._sdk.v2_ops.create, tsi.OpCreateRes ) @validate_call def op_read(self, req: tsi.OpReadReq) -> tsi.OpReadRes: entity, project = 
from_project_id(req.project_id) - url = f"/v2/{entity}/{project}/ops/{req.object_id}/versions/{req.digest}" - return self._generic_request( - url, - req, - tsi.OpReadReq, + return self._via_sdk_no_body( + self._sdk.v2_ops.read, tsi.OpReadRes, - method="GET", + entity=entity, + project=project, + object_id=req.object_id, + digest=req.digest, ) @validate_call def op_list(self, req: tsi.OpListReq) -> Iterator[tsi.OpReadRes]: - entity, project = from_project_id(req.project_id) - url = f"/v2/{entity}/{project}/ops" - # Build query params params: dict[str, Any] = {} if req.limit is not None: params["limit"] = req.limit if req.offset is not None: params["offset"] = req.offset + # `eager` is missing from the SDK's generated v2_ops.list signature. if req.eager: params["eager"] = "true" - return self._generic_stream_request( - url, - req, - tsi.OpListReq, - tsi.OpReadRes, - method="GET", - params=params, - ) + return self._v2_list_stream(req, tsi.OpReadRes, "ops", params) @validate_call def op_delete(self, req: tsi.OpDeleteReq) -> tsi.OpDeleteRes: entity, project = from_project_id(req.project_id) - url = f"/v2/{entity}/{project}/ops/{req.object_id}" - # Build query params - params = {} - if req.digests: - params["digests"] = req.digests - return self._generic_request( - url, - req, - tsi.OpDeleteReq, + return self._via_sdk_no_body( + self._sdk.v2_ops.delete, tsi.OpDeleteRes, - method="DELETE", - params=params, + entity=entity, + project=project, + object_id=req.object_id, + digests=req.digests, ) @validate_call def dataset_create(self, req: tsi.DatasetCreateReq) -> tsi.DatasetCreateRes: - entity, project = from_project_id(req.project_id) - url = f"/v2/{entity}/{project}/datasets" - # For create, we need to send the body without project_id (DatasetCreateBody) - body_data = req.model_dump(exclude={"project_id"}) - body = tsi.DatasetCreateBody.model_validate(body_data) - return self._generic_request( - url, - body, - tsi.DatasetCreateBody, + return self._v2_body_create( + req, 
+ sdk_models.DatasetCreateBody, + self._sdk.v2_datasets.create, tsi.DatasetCreateRes, - method="POST", ) @validate_call def dataset_read(self, req: tsi.DatasetReadReq) -> tsi.DatasetReadRes: entity, project = from_project_id(req.project_id) - url = f"/v2/{entity}/{project}/datasets/{req.object_id}/versions/{req.digest}" - return self._generic_request( - url, - req, - tsi.DatasetReadReq, + return self._via_sdk_no_body( + self._sdk.v2_datasets.read, tsi.DatasetReadRes, - method="GET", + entity=entity, + project=project, + object_id=req.object_id, + digest=req.digest, ) @validate_call def dataset_list(self, req: tsi.DatasetListReq) -> Iterator[tsi.DatasetReadRes]: - entity, project = from_project_id(req.project_id) - url = f"/v2/{entity}/{project}/datasets" - # Build query params params: dict[str, Any] = {} if req.limit is not None: params["limit"] = req.limit if req.offset is not None: params["offset"] = req.offset - return self._generic_stream_request( - url, - req, - tsi.DatasetListReq, - tsi.DatasetReadRes, - method="GET", - params=params, - ) + return self._v2_list_stream(req, tsi.DatasetReadRes, "datasets", params) @validate_call def dataset_delete(self, req: tsi.DatasetDeleteReq) -> tsi.DatasetDeleteRes: entity, project = from_project_id(req.project_id) - url = f"/v2/{entity}/{project}/datasets/{req.object_id}" - # Build query params - params = {} - if req.digests: - params["digests"] = req.digests - return self._generic_request( - url, - req, - tsi.DatasetDeleteReq, + return self._via_sdk_no_body( + self._sdk.v2_datasets.delete, tsi.DatasetDeleteRes, - method="DELETE", - params=params, + entity=entity, + project=project, + object_id=req.object_id, + digests=req.digests, ) @validate_call def scorer_create(self, req: tsi.ScorerCreateReq) -> tsi.ScorerCreateRes: - entity, project = from_project_id(req.project_id) - url = f"/v2/{entity}/{project}/scorers" - # For create, we need to send the body without project_id (ScorerCreateBody) - body_data = 
req.model_dump(exclude={"project_id"}) - body = tsi.ScorerCreateBody.model_validate(body_data) - return self._generic_request( - url, - body, - tsi.ScorerCreateBody, + return self._v2_body_create( + req, + sdk_models.ScorerCreateBody, + self._sdk.v2_scorers.create, tsi.ScorerCreateRes, - method="POST", ) @validate_call def scorer_read(self, req: tsi.ScorerReadReq) -> tsi.ScorerReadRes: entity, project = from_project_id(req.project_id) - url = f"/v2/{entity}/{project}/scorers/{req.object_id}/versions/{req.digest}" - return self._generic_request( - url, - req, - tsi.ScorerReadReq, + return self._via_sdk_no_body( + self._sdk.v2_scorers.read, tsi.ScorerReadRes, - method="GET", + entity=entity, + project=project, + object_id=req.object_id, + digest=req.digest, ) @validate_call def scorer_list(self, req: tsi.ScorerListReq) -> Iterator[tsi.ScorerReadRes]: - entity, project = from_project_id(req.project_id) - url = f"/v2/{entity}/{project}/scorers" - # Build query params - params = {} + params: dict[str, Any] = {} if req.limit is not None: params["limit"] = req.limit if req.offset is not None: params["offset"] = req.offset - return self._generic_stream_request( - url, - req, - tsi.ScorerListReq, - tsi.ScorerReadRes, - method="GET", - params=params, - ) + return self._v2_list_stream(req, tsi.ScorerReadRes, "scorers", params) @validate_call def scorer_delete(self, req: tsi.ScorerDeleteReq) -> tsi.ScorerDeleteRes: entity, project = from_project_id(req.project_id) - url = f"/v2/{entity}/{project}/scorers/{req.object_id}" - # Build query params - params = {} - if req.digests: - params["digests"] = req.digests - return self._generic_request( - url, - req, - tsi.ScorerDeleteReq, + return self._via_sdk_no_body( + self._sdk.v2_scorers.delete, tsi.ScorerDeleteRes, - method="DELETE", - params=params, + entity=entity, + project=project, + object_id=req.object_id, + digests=req.digests, ) @validate_call def evaluation_create( self, req: tsi.EvaluationCreateReq ) -> 
tsi.EvaluationCreateRes: - entity, project = from_project_id(req.project_id) - url = f"/v2/{entity}/{project}/evaluations" - # For create, we need to send the body without project_id (EvaluationCreateBody) - body_data = req.model_dump(exclude={"project_id"}) - body = tsi.EvaluationCreateBody.model_validate(body_data) - return self._generic_request( - url, - body, - tsi.EvaluationCreateBody, + return self._v2_body_create( + req, + sdk_models.EvaluationCreateBody, + self._sdk.v2_evaluations.create, tsi.EvaluationCreateRes, - method="POST", ) @validate_call def evaluation_read(self, req: tsi.EvaluationReadReq) -> tsi.EvaluationReadRes: entity, project = from_project_id(req.project_id) - url = ( - f"/v2/{entity}/{project}/evaluations/{req.object_id}/versions/{req.digest}" - ) - return self._generic_request( - url, - req, - tsi.EvaluationReadReq, + return self._via_sdk_no_body( + self._sdk.v2_evaluations.read, tsi.EvaluationReadRes, - method="GET", + entity=entity, + project=project, + object_id=req.object_id, + digest=req.digest, ) @validate_call def evaluation_list( self, req: tsi.EvaluationListReq ) -> Iterator[tsi.EvaluationReadRes]: - entity, project = from_project_id(req.project_id) - url = f"/v2/{entity}/{project}/evaluations" - # Build query params - params = {} + params: dict[str, Any] = {} if req.limit is not None: params["limit"] = req.limit if req.offset is not None: params["offset"] = req.offset - return self._generic_stream_request( - url, - req, - tsi.EvaluationListReq, - tsi.EvaluationReadRes, - method="GET", - params=params, - ) + return self._v2_list_stream(req, tsi.EvaluationReadRes, "evaluations", params) @validate_call def evaluation_delete( self, req: tsi.EvaluationDeleteReq ) -> tsi.EvaluationDeleteRes: entity, project = from_project_id(req.project_id) - url = f"/v2/{entity}/{project}/evaluations/{req.object_id}" - # Build query params - params = {} - if req.digests: - params["digests"] = req.digests - return self._generic_request( - url, - req, - 
tsi.EvaluationDeleteReq, + return self._via_sdk_no_body( + self._sdk.v2_evaluations.delete, tsi.EvaluationDeleteRes, - method="DELETE", - params=params, + entity=entity, + project=project, + object_id=req.object_id, + digests=req.digests, ) - # Model V2 API + # ---- Model V2 API ----------------------------------------------------------- @validate_call def model_create(self, req: tsi.ModelCreateReq) -> tsi.ModelCreateRes: - entity, project = from_project_id(req.project_id) - url = f"/v2/{entity}/{project}/models" - body = tsi.ModelCreateBody.model_validate( - req.model_dump(exclude={"project_id"}) - ) - return self._generic_request( - url, - body, - tsi.ModelCreateBody, + return self._v2_body_create( + req, + sdk_models.ModelCreateBody, + self._sdk.v2_models.create, tsi.ModelCreateRes, - method="POST", ) @validate_call def model_read(self, req: tsi.ModelReadReq) -> tsi.ModelReadRes: entity, project = from_project_id(req.project_id) - url = f"/v2/{entity}/{project}/models/{req.object_id}/versions/{req.digest}" - return self._generic_request( - url, - req, - tsi.ModelReadReq, + return self._via_sdk_no_body( + self._sdk.v2_models.read, tsi.ModelReadRes, - method="GET", + entity=entity, + project=project, + object_id=req.object_id, + digest=req.digest, ) @validate_call def model_list(self, req: tsi.ModelListReq) -> Iterator[tsi.ModelReadRes]: - entity, project = from_project_id(req.project_id) - url = f"/v2/{entity}/{project}/models" - # Build query params - params = {} + params: dict[str, Any] = {} if req.limit is not None: params["limit"] = req.limit if req.offset is not None: params["offset"] = req.offset - return self._generic_stream_request( - url, - req, - tsi.ModelListReq, - tsi.ModelReadRes, - method="GET", - params=params, - ) + return self._v2_list_stream(req, tsi.ModelReadRes, "models", params) @validate_call def model_delete(self, req: tsi.ModelDeleteReq) -> tsi.ModelDeleteRes: entity, project = from_project_id(req.project_id) - url = 
f"/v2/{entity}/{project}/models/{req.object_id}" - # Build query params - params = {} - if req.digests: - params["digests"] = req.digests - return self._generic_request( - url, - req, - tsi.ModelDeleteReq, + return self._via_sdk_no_body( + self._sdk.v2_models.delete, tsi.ModelDeleteRes, - method="DELETE", - params=params, + entity=entity, + project=project, + object_id=req.object_id, + digests=req.digests, ) + # ---- Evaluation Run V2 API ---------------------------------------------------- + @validate_call def evaluation_run_create( self, req: tsi.EvaluationRunCreateReq ) -> tsi.EvaluationRunCreateRes: - entity, project = from_project_id(req.project_id) - url = f"/v2/{entity}/{project}/evaluation_runs" - # For create, we need to send the body without project_id (EvaluationRunCreateBody) - body_data = req.model_dump(exclude={"project_id"}) - body = tsi.EvaluationRunCreateBody.model_validate(body_data) - return self._generic_request( - url, - body, - tsi.EvaluationRunCreateBody, + return self._v2_body_create( + req, + sdk_models.EvaluationRunCreateBody, + self._sdk.v2_evaluation_runs.create, tsi.EvaluationRunCreateRes, ) @@ -1523,22 +1568,20 @@ def evaluation_run_read( self, req: tsi.EvaluationRunReadReq ) -> tsi.EvaluationRunReadRes: entity, project = from_project_id(req.project_id) - url = f"/v2/{entity}/{project}/evaluation_runs/{req.evaluation_run_id}" - return self._generic_request( - url, - req, - tsi.EvaluationRunReadReq, + return self._via_sdk_no_body( + self._sdk.v2_evaluation_runs.read, tsi.EvaluationRunReadRes, - method="GET", + entity=entity, + project=project, + evaluation_run_id=req.evaluation_run_id, ) @validate_call def evaluation_run_list( self, req: tsi.EvaluationRunListReq ) -> Iterator[tsi.EvaluationRunReadRes]: - entity, project = from_project_id(req.project_id) - url = f"/v2/{entity}/{project}/evaluation_runs" - # Build query params + # Raw: the SDK's generated list signature renames the filter params + # (evaluations vs evaluation_refs), so 
use the wire format directly. params: dict[str, Any] = {} if req.limit is not None: params["limit"] = req.limit @@ -1551,13 +1594,8 @@ def evaluation_run_list( params["model_refs"] = ",".join(req.filter.models) if req.filter.evaluation_run_ids: params["evaluation_run_ids"] = ",".join(req.filter.evaluation_run_ids) - return self._generic_stream_request( - url, - req, - tsi.EvaluationRunListReq, - tsi.EvaluationRunReadRes, - method="GET", - params=params, + return self._v2_list_stream( + req, tsi.EvaluationRunReadRes, "evaluation_runs", params ) @validate_call @@ -1565,16 +1603,12 @@ def evaluation_run_delete( self, req: tsi.EvaluationRunDeleteReq ) -> tsi.EvaluationRunDeleteRes: entity, project = from_project_id(req.project_id) - url = f"/v2/{entity}/{project}/evaluation_runs" - # Build query params - evaluation_run_ids are passed as a query param - params = {"evaluation_run_ids": req.evaluation_run_ids} - return self._generic_request( - url, - req, - tsi.EvaluationRunDeleteReq, + return self._via_sdk_no_body( + self._sdk.v2_evaluation_runs.delete, tsi.EvaluationRunDeleteRes, - method="DELETE", - params=params, + entity=entity, + project=project, + evaluation_run_ids=req.evaluation_run_ids, ) @validate_call @@ -1582,53 +1616,46 @@ def evaluation_run_finish( self, req: tsi.EvaluationRunFinishReq ) -> tsi.EvaluationRunFinishRes: entity, project = from_project_id(req.project_id) - url = f"/v2/{entity}/{project}/evaluation_runs/{req.evaluation_run_id}/finish" - return self._generic_request( - url, - req, - tsi.EvaluationRunFinishReq, + body = sdk_models.EvaluationRunFinishBody.model_validate( + req.model_dump(exclude={"project_id", "evaluation_run_id"}) + ) + return self._via_sdk_no_body( + self._sdk.v2_evaluation_runs.finish, tsi.EvaluationRunFinishRes, - method="POST", + body=body, + entity=entity, + project=project, + evaluation_run_id=req.evaluation_run_id, ) - # Prediction V2 API + # ---- Prediction V2 API ---------------------------------------------------------- 
@validate_call def prediction_create( self, req: tsi.PredictionCreateReq ) -> tsi.PredictionCreateRes: - entity, project = from_project_id(req.project_id) - url = f"/v2/{entity}/{project}/predictions" - body = tsi.PredictionCreateBody.model_validate( - req.model_dump(exclude={"project_id"}) - ) - return self._generic_request( - url, - body, - tsi.PredictionCreateBody, + return self._v2_body_create( + req, + sdk_models.PredictionCreateBody, + self._sdk.v2_predictions.create, tsi.PredictionCreateRes, - method="POST", ) @validate_call def prediction_read(self, req: tsi.PredictionReadReq) -> tsi.PredictionReadRes: entity, project = from_project_id(req.project_id) - url = f"/v2/{entity}/{project}/predictions/{req.prediction_id}" - return self._generic_request( - url, - req, - tsi.PredictionReadReq, + return self._via_sdk_no_body( + self._sdk.v2_predictions.read, tsi.PredictionReadRes, - method="GET", + entity=entity, + project=project, + prediction_id=req.prediction_id, ) @validate_call def prediction_list( self, req: tsi.PredictionListReq ) -> Iterator[tsi.PredictionReadRes]: - entity, project = from_project_id(req.project_id) - url = f"/v2/{entity}/{project}/predictions" - # Build query params params: dict[str, Any] = {} if req.evaluation_run_id is not None: params["evaluation_run_id"] = req.evaluation_run_id @@ -1636,30 +1663,19 @@ def prediction_list( params["limit"] = req.limit if req.offset is not None: params["offset"] = req.offset - return self._generic_stream_request( - url, - req, - tsi.PredictionListReq, - tsi.PredictionReadRes, - method="GET", - params=params, - ) + return self._v2_list_stream(req, tsi.PredictionReadRes, "predictions", params) @validate_call def prediction_delete( self, req: tsi.PredictionDeleteReq ) -> tsi.PredictionDeleteRes: entity, project = from_project_id(req.project_id) - url = f"/v2/{entity}/{project}/predictions" - # Build query params - prediction_ids are passed as a query param - params = {"prediction_ids": req.prediction_ids} - 
return self._generic_request( - url, - req, - tsi.PredictionDeleteReq, + return self._via_sdk_no_body( + self._sdk.v2_predictions.delete, tsi.PredictionDeleteRes, - method="DELETE", - params=params, + entity=entity, + project=project, + prediction_ids=req.prediction_ids, ) @validate_call @@ -1667,49 +1683,38 @@ def prediction_finish( self, req: tsi.PredictionFinishReq ) -> tsi.PredictionFinishRes: entity, project = from_project_id(req.project_id) - url = f"/v2/{entity}/{project}/predictions/{req.prediction_id}/finish" - return self._generic_request( - url, - req, - tsi.PredictionFinishReq, + return self._via_sdk_no_body( + self._sdk.v2_predictions.finish, tsi.PredictionFinishRes, - method="POST", + entity=entity, + project=project, + prediction_id=req.prediction_id, ) - # Score V2 API + # ---- Score V2 API --------------------------------------------------------------- @validate_call def score_create(self, req: tsi.ScoreCreateReq) -> tsi.ScoreCreateRes: - entity, project = from_project_id(req.project_id) - url = f"/v2/{entity}/{project}/scores" - body = tsi.ScoreCreateBody.model_validate( - req.model_dump(exclude={"project_id"}) - ) - return self._generic_request( - url, - body, - tsi.ScoreCreateBody, + return self._v2_body_create( + req, + sdk_models.ScoreCreateBody, + self._sdk.v2_scores.create, tsi.ScoreCreateRes, - method="POST", ) @validate_call def score_read(self, req: tsi.ScoreReadReq) -> tsi.ScoreReadRes: entity, project = from_project_id(req.project_id) - url = f"/v2/{entity}/{project}/scores/{req.score_id}" - return self._generic_request( - url, - req, - tsi.ScoreReadReq, + return self._via_sdk_no_body( + self._sdk.v2_scores.read, tsi.ScoreReadRes, - method="GET", + entity=entity, + project=project, + score_id=req.score_id, ) @validate_call def score_list(self, req: tsi.ScoreListReq) -> Iterator[tsi.ScoreReadRes]: - entity, project = from_project_id(req.project_id) - url = f"/v2/{entity}/{project}/scores" - # Build query params params: dict[str, Any] = 
{} if req.evaluation_run_id is not None: params["evaluation_run_id"] = req.evaluation_run_id @@ -1717,31 +1722,20 @@ def score_list(self, req: tsi.ScoreListReq) -> Iterator[tsi.ScoreReadRes]: params["limit"] = req.limit if req.offset is not None: params["offset"] = req.offset - return self._generic_stream_request( - url, - req, - tsi.ScoreListReq, - tsi.ScoreReadRes, - method="GET", - params=params, - ) + return self._v2_list_stream(req, tsi.ScoreReadRes, "scores", params) @validate_call def score_delete(self, req: tsi.ScoreDeleteReq) -> tsi.ScoreDeleteRes: entity, project = from_project_id(req.project_id) - url = f"/v2/{entity}/{project}/scores" - # Build query params - score_ids are passed as a query param - params = {"score_ids": req.score_ids} - return self._generic_request( - url, - req, - tsi.ScoreDeleteReq, + return self._via_sdk_no_body( + self._sdk.v2_scores.delete, tsi.ScoreDeleteRes, - method="DELETE", - params=params, + entity=entity, + project=project, + score_ids=req.score_ids, ) - # === Calls V2 API === + # ---- Calls V2 API --------------------------------------------------------------- def calls_complete( self, req: tsi.CallsUpsertCompleteReq @@ -1754,16 +1748,14 @@ def calls_complete( if not req.batch: return tsi.CallsUpsertCompleteRes() - # Extract entity/project from first item first_item = req.batch[0] entity, project = from_project_id(first_item.project_id) - - url = f"/v2/{entity}/{project}/calls/complete" - return self._generic_request( - url, - req, - tsi.CallsUpsertCompleteReq, - tsi.CallsUpsertCompleteRes, + # Excluded from the OpenAPI spec (include_in_schema=False). 
+ return self._raw_request( + "POST", + CALLS_COMPLETE_PATH.format(entity=entity, project=project), + req=req, + res_type=tsi.CallsUpsertCompleteRes, ) def call_start_v2(self, req: tsi.CallStartV2Req) -> tsi.CallStartV2Res: @@ -1771,15 +1763,13 @@ def call_start_v2(self, req: tsi.CallStartV2Req) -> tsi.CallStartV2Res: This endpoint is used for eager ops that need their start visible immediately. """ - project_id = req.start.project_id - entity, project = from_project_id(project_id) - - url = f"/v2/{entity}/{project}/call/start" - return self._generic_request( - url, - req, - tsi.CallStartV2Req, - tsi.CallStartV2Res, + entity, project = from_project_id(req.start.project_id) + # Excluded from the OpenAPI spec (include_in_schema=False). + return self._raw_request( + "POST", + CALL_START_V2_PATH.format(entity=entity, project=project), + req=req, + res_type=tsi.CallStartV2Res, ) def call_end_v2(self, req: tsi.CallEndV2Req) -> tsi.CallEndV2Res: @@ -1787,13 +1777,11 @@ def call_end_v2(self, req: tsi.CallEndV2Req) -> tsi.CallEndV2Res: This endpoint is used for eager ops that need their end sent separately. """ - project_id = req.end.project_id - entity, project = from_project_id(project_id) - - url = f"/v2/{entity}/{project}/call/end" - return self._generic_request( - url, - req, - tsi.CallEndV2Req, - tsi.CallEndV2Res, + entity, project = from_project_id(req.end.project_id) + # Excluded from the OpenAPI spec (include_in_schema=False). 
+ return self._raw_request( + "POST", + CALL_END_V2_PATH.format(entity=entity, project=project), + req=req, + res_type=tsi.CallEndV2Res, ) diff --git a/weave/utils/http_requests.py b/weave/utils/http_requests.py index 81339d6d0ab6..10a887681d2f 100644 --- a/weave/utils/http_requests.py +++ b/weave/utils/http_requests.py @@ -147,7 +147,8 @@ def _is_debug_http_enabled() -> bool: return os.environ.get("WEAVE_DEBUG_HTTP") == "1" -def _log_request(request: Request) -> None: +def log_request(request: Request) -> None: + """Httpx request hook: pretty-print the request when WEAVE_DEBUG_HTTP=1.""" if not _is_debug_http_enabled(): return @@ -156,7 +157,8 @@ def _log_request(request: Request) -> None: pprint_request(request) -def _log_response(response: Response) -> None: +def log_response(response: Response) -> None: + """Httpx response hook: pretty-print the response when WEAVE_DEBUG_HTTP=1.""" if not _is_debug_http_enabled(): return @@ -183,7 +185,7 @@ def _build_client(verify: bool, timeout: float) -> httpx.Client: return httpx.Client( # Use HTTPX's default transport so env proxy handling (including # NO_PROXY) works natively. - event_hooks={"request": [_log_request], "response": [_log_response]}, + event_hooks={"request": [log_request], "response": [log_response]}, timeout=timeout, limits=CLIENT_LIMITS, verify=verify,