diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index f431b9fdc69c..2a0ecab1d3ec 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -149,7 +149,7 @@ jobs: "pandas_test", "fastmcp", "smolagents", - # "stainless", + "stainless", "autogen_tests", ] # Shards that don't support certain Python versions. Excluding here @@ -578,7 +578,7 @@ jobs: "pandas_test", "fastmcp", "smolagents", - # "stainless", + "stainless", "autogen_tests", ] # Shards that don't support certain Python versions. Excluding here diff --git a/pyproject.toml b/pyproject.toml index 54fffe1ebaaf..a13bcad23b96 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -58,21 +58,17 @@ dependencies = [ # (gen_ai.tool.call.arguments, cache token names, # gen_ai.usage.reasoning.output_tokens) added in 0.63b0. "opentelemetry-semantic-conventions>=0.63b0", + + # Generated trace-server client (source of truth for the API types). + # TEMPORARY: resolved from test PyPI via [tool.uv.sources] below until the + # package is published to real PyPI. + "weave-server-sdk==0.0.1", ] [project.optional-dependencies] wandb = ["wandb>=0.17.1"] rich = ["rich"] # Optional dependency for enhanced console output -# DO NOT MODIFY THE STAINLESS DEPENDENCY GROUP MANUALLY. IT IS AUTO-GENERATED BY THE CODE GENERATOR. -# 1. For normal dev, pin to a SHA and allow direct references. This will look like: -# weave_server_sdk @ git+https://github.com/wandb/weave-stainless@daf91fdd07535c570eb618b343e2d125f9b09e25 -# 2. For deploys, pin to a specific version and remove allow-direct-references. This will look like: -# weave_server_sdk==0.0.1 -# stainless = [ -# "weave_server_sdk @ git+https://github.com/wandb/weave-stainless.git@9ce4f0792a54f05fba22531faee0c439ae9d70fa", -# ] - # `trace_server` is the dependency list of the trace server itself. We eventually will extract # this to a separate package. Note, when that happens, we will need to pull along some of the # default dependencies as well. @@ -537,6 +533,17 @@ conflicts = [ ], ] +# TEMPORARY: weave-server-sdk is only published to test PyPI so far. Resolve +# just that package from the test index; everything else stays on real PyPI. +# Remove this (and the index below) once it is published to real PyPI. +[tool.uv.sources] +weave-server-sdk = { index = "testpypi" } + +[[tool.uv.index]] +name = "testpypi" +url = "https://test.pypi.org/simple/" +explicit = true + # uv workspace members. Each listed path is a sibling Python package that uv # manages alongside the root `weave` project. Members can depend on each # other and on the root by name; uv resolves those to the local workspace diff --git a/tests/trace_server_bindings/test_http_behavior_stainless.py b/tests/trace_server_bindings/test_http_behavior_stainless.py index 99bf93f56ceb..25eb43f87bf4 100644 --- a/tests/trace_server_bindings/test_http_behavior_stainless.py +++ b/tests/trace_server_bindings/test_http_behavior_stainless.py @@ -2,189 +2,606 @@ These tests verify HTTP request/response handling, retry behavior for various status codes, and error handling specific to StainlessRemoteHTTPTraceServer. + +Mocking happens at the httpx transport boundary (the external seam), so the +full stack — SDK encode/decode, event hooks, error translation, retry +predicates — is exercised for real. """ from __future__ import annotations -from unittest.mock import MagicMock +import datetime +import json +import logging +from types import MethodType +import httpx import pytest -import requests import tenacity from pydantic import ValidationError -from tests.trace_server_bindings.conftest import generate_id, generate_start +from tests.trace_server_bindings.conftest import ( + generate_end, + generate_id, + generate_start, +) from weave.trace.display.term import configure_logger from weave.trace_server import trace_server_interface as tsi +from weave.trace_server_bindings.async_batch_processor import AsyncBatchProcessor +from weave.trace_server_bindings.call_batch_processor import CallBatchProcessor +from weave.trace_server_bindings.http_utils import ( + ERROR_CODE_CALLS_COMPLETE_MODE_REQUIRED, + CallsCompleteModeRequired, +) +from weave.trace_server_bindings.models import ( + CompleteBatchItem, + EndBatchItem, + StartBatchItem, +) from weave.trace_server_bindings.stainless_remote_http_trace_server import ( StainlessRemoteHTTPTraceServer, ) -from weave.utils.retry import with_retry +BASE_URL = "http://example.com" -@pytest.fixture -def unbatched_server(): - """Create a StainlessRemoteHTTPTraceServer instance without batching for testing.""" - return StainlessRemoteHTTPTraceServer("http://example.com") +class SpyTransport(httpx.BaseTransport): + """httpx transport that records requests and replays queued responses. -def test_call_start_ok(unbatched_server): - """Test successful call_start request.""" - call_id = generate_id() + Queue items may be ``httpx.Response`` objects or exceptions to raise. + When the queue is empty, returns ``default_response`` (200 ``{}`` unless + overridden). + """ - mock_response = MagicMock() - mock_response.model_dump.return_value = { - "id": call_id, - "trace_id": "test_trace_id", - } - unbatched_server._stainless_client.calls.start = MagicMock( - return_value=mock_response + def __init__( + self, + *items: httpx.Response | Exception, + default_response: httpx.Response | None = None, + ) -> None: + self.queue: list[httpx.Response | Exception] = list(items) + self.requests: list[httpx.Request] = [] + self.default_response = default_response + + def handle_request(self, request: httpx.Request) -> httpx.Response: + request.read() + self.requests.append(request) + if self.queue: + item = self.queue.pop(0) + if isinstance(item, Exception): + raise item + return item + if self.default_response is not None: + return self.default_response + return httpx.Response(200, json={}) + + @property + def urls(self) -> list[str]: + return [str(r.url) for r in self.requests] + + +def make_server( + transport: httpx.BaseTransport, + should_batch: bool = False, + **kwargs, +) -> StainlessRemoteHTTPTraceServer: + return StainlessRemoteHTTPTraceServer( + BASE_URL, should_batch=should_batch, transport=transport, **kwargs ) + +def shutdown(server: StainlessRemoteHTTPTraceServer) -> None: + if server.call_processor and server.call_processor.is_accepting_new_work(): + server.call_processor.stop_accepting_new_work_and_flush_queue() + if server.feedback_processor and server.feedback_processor.is_accepting_new_work(): + server.feedback_processor.stop_accepting_new_work_and_flush_queue() + + +def call_start_ok_response(call_id: str) -> httpx.Response: + return httpx.Response( + 200, json=tsi.CallStartRes(id=call_id, trace_id="test_trace_id").model_dump() + ) + + +def make_calls_complete_required_response() -> httpx.Response: + """Create a 400 response indicating the project requires calls_complete mode.""" + return httpx.Response( + 400, + json={ + "error_code": ERROR_CODE_CALLS_COMPLETE_MODE_REQUIRED, + "message": "Project requires calls_complete mode", + }, + ) + + +def test_call_start_ok(): + """Test successful call_start request.""" + call_id = generate_id() + transport = SpyTransport(call_start_ok_response(call_id)) + server = make_server(transport) + start = generate_start(call_id) - result = unbatched_server.call_start(tsi.CallStartReq(start=start)) + result = server.call_start(tsi.CallStartReq(start=start)) - unbatched_server._stainless_client.calls.start.assert_called_once() + assert transport.urls == [f"{BASE_URL}/call/start"] + sent = json.loads(transport.requests[0].content) + assert sent["start"]["id"] == call_id assert result.id == call_id assert result.trace_id == "test_trace_id" -def test_400_no_retry(unbatched_server): +def test_400_no_retry(): """Test that 400 errors are not retried.""" - from weave_server_sdk import APIStatusError - call_id = generate_id() - error_response = MagicMock() - error_response.status_code = 400 - error = APIStatusError( - message="Bad Request", - response=error_response, - body={"error": "Bad Request"}, - ) - - unbatched_server._stainless_client.calls.start = MagicMock(side_effect=error) + transport = SpyTransport(httpx.Response(400, json={"error": "Bad Request"})) + server = make_server(transport) start = generate_start(call_id) - with pytest.raises(APIStatusError): - unbatched_server.call_start(tsi.CallStartReq(start=start)) + with pytest.raises(httpx.HTTPStatusError): + server.call_start(tsi.CallStartReq(start=start)) # Should only be called once (no retry for 400) - assert unbatched_server._stainless_client.calls.start.call_count == 1 + assert len(transport.requests) == 1 -def test_invalid_no_retry(unbatched_server): +def test_invalid_no_retry(): """Test that validation errors are not retried.""" + transport = SpyTransport() + server = make_server(transport) with pytest.raises(ValidationError): - unbatched_server.call_start(tsi.CallStartReq(start={"invalid": "broken"})) + server.call_start(tsi.CallStartReq(start={"invalid": "broken"})) + assert len(transport.requests) == 0 + + +def test_500_502_503_504_429_retry(monkeypatch): + """Test that 5xx and 429 errors are retried.""" + monkeypatch.setenv("WEAVE_RETRY_MAX_ATTEMPTS", "6") + monkeypatch.setenv("WEAVE_RETRY_MAX_INTERVAL", "0.1") + call_id = generate_id() + + transport = SpyTransport( + httpx.Response(500), + httpx.Response(502), + httpx.Response(503), + httpx.Response(504), + httpx.Response(429), + call_start_ok_response(call_id), + ) + server = make_server(transport) + + start = generate_start(call_id) + result = server.call_start(tsi.CallStartReq(start=start)) + assert result.id == call_id + assert len(transport.requests) == 6 + + +def test_other_error_retry(monkeypatch): + """Test that connection errors are retried.""" + monkeypatch.setenv("WEAVE_RETRY_MAX_ATTEMPTS", "6") + monkeypatch.setenv("WEAVE_RETRY_MAX_INTERVAL", "0.1") + call_id = generate_id() + + transport = SpyTransport( + ConnectionResetError(), + ConnectionError(), + OSError(), + TimeoutError(), + call_start_ok_response(call_id), + ) + server = make_server(transport) + + start = generate_start(call_id) + result = server.call_start(tsi.CallStartReq(start=start)) + assert result.id == call_id + + +def test_retry_id_header_injected(monkeypatch): + """Every request carries the per-attempt X-Weave-Retry-Id header.""" + monkeypatch.setenv("WEAVE_RETRY_MAX_ATTEMPTS", "3") + monkeypatch.setenv("WEAVE_RETRY_MAX_INTERVAL", "0.1") + call_id = generate_id() + transport = SpyTransport(httpx.Response(500), call_start_ok_response(call_id)) + server = make_server(transport) + + server.call_start(tsi.CallStartReq(start=generate_start(call_id))) + + retry_ids = [r.headers.get("X-Weave-Retry-Id") for r in transport.requests] + assert all(retry_ids) + # Same logical request, so both attempts share one retry id + assert len(set(retry_ids)) == 1 + + +def test_extra_headers_and_auth_are_sent(): + """Constructor extra_headers and auth flow into every request.""" + call_id = generate_id() + transport = SpyTransport(call_start_ok_response(call_id)) + server = make_server( + transport, auth=("api", "secret-key"), extra_headers={"X-Custom": "yes"} + ) + + server.call_start(tsi.CallStartReq(start=generate_start(call_id))) + + request = transport.requests[0] + assert request.headers["X-Custom"] == "yes" + assert request.headers["Authorization"].startswith("Basic ") + + +def test_set_auth_applies_to_subsequent_requests(): + """set_auth after construction updates the live client.""" + call_id = generate_id() + transport = SpyTransport(default_response=None) + transport.default_response = call_start_ok_response(call_id) + server = make_server(transport) + + server.call_start(tsi.CallStartReq(start=generate_start(call_id))) + assert "Authorization" not in transport.requests[0].headers + + server.set_auth(("api", "secret-key")) + server.call_start(tsi.CallStartReq(start=generate_start(call_id))) + assert transport.requests[1].headers["Authorization"].startswith("Basic ") + + +def test_typed_sdk_route_obj_read(): + """A typed SDK route hits the right path and converts the response.""" + transport = SpyTransport( + httpx.Response( + 200, + json={ + "obj": { + "project_id": "entity/project", + "object_id": "my-obj", + "created_at": "2024-01-01T00:00:00Z", + "deleted_at": None, + "digest": "abc", + "version_index": 0, + "is_latest": 1, + "kind": "object", + "base_object_class": None, + "leaf_object_class": None, + "val": {"a": 1}, + } + }, + ) + ) + server = make_server(transport) + + res = server.obj_read( + tsi.ObjReadReq(project_id="entity/project", object_id="my-obj", digest="abc") + ) + + assert transport.urls == [f"{BASE_URL}/obj/read"] + assert isinstance(res, tsi.ObjReadRes) + assert res.obj.object_id == "my-obj" + assert res.obj.val == {"a": 1} + + +def test_calls_complete_batch_endpoint_and_payload(monkeypatch): + """Send calls_complete batches to the v2 endpoint with correct payload.""" + monkeypatch.setenv("WEAVE_USE_CALLS_COMPLETE", "true") + transport = SpyTransport() + server = make_server(transport, should_batch=True) + + complete = tsi.CompletedCallSchemaForInsert( + project_id="entity/project", + id="call-id", + trace_id="trace-id", + op_name="test_op", + started_at=datetime.datetime.now(tz=datetime.timezone.utc), + ended_at=datetime.datetime.now(tz=datetime.timezone.utc), + attributes={"a": 1}, + inputs={"b": 2}, + output={"c": 3}, + summary={"result": "ok"}, + ) + batch = [CompleteBatchItem(req=complete)] + + try: + server._flush_calls_complete(batch) + finally: + shutdown(server) + + assert transport.urls == [f"{BASE_URL}/v2/entity/project/calls/complete"] + payload = json.loads(transport.requests[0].content) + expected = tsi.CallsUpsertCompleteReq(batch=[complete]).model_dump(mode="json") + assert payload == expected + + +def test_eager_calls_use_v2_start_end_endpoints(): + """Use v2 endpoints for eager start/end and include started_at in end.""" + transport = SpyTransport() + server = make_server(transport, should_batch=True) + + start = generate_start(id="call-id", project_id="entity/project") + ended_at = datetime.datetime.now(tz=datetime.timezone.utc) + started_at = ended_at - datetime.timedelta(seconds=1) + end = tsi.EndedCallSchemaForInsertWithStartedAt( + project_id="entity/project", + id="call-id", + ended_at=ended_at, + started_at=started_at, + summary={"result": "Test summary"}, + ) + + try: + server._flush_calls_eager( + [ + StartBatchItem(req=tsi.CallStartReq(start=start)), + EndBatchItem(req=tsi.CallEndReq(end=end)), + ] + ) + + assert transport.urls == [ + f"{BASE_URL}/v2/entity/project/call/start", + f"{BASE_URL}/v2/entity/project/call/end", + ] + + end_payload = json.loads(transport.requests[1].content) + payload_started_at = datetime.datetime.fromisoformat( + end_payload["end"]["started_at"].replace("Z", "+00:00") + ) + assert payload_started_at == end.started_at + assert end_payload["end"]["id"] == "call-id" + finally: + shutdown(server) @pytest.mark.disable_logging_error_check -def test_timeout_retry_mechanism(success_response): - """Test that timeouts trigger the retry mechanism.""" - server = StainlessRemoteHTTPTraceServer("http://example.com", should_batch=True) +def test_eager_non_retryable_error_drops_item(caplog): + """Drop eager items on non-retryable errors without raising.""" + transport = SpyTransport(default_response=httpx.Response(400, json={})) + server = make_server(transport, should_batch=True) + start = generate_start(id="call-id", project_id="entity/project") - # Mock _send_batch_to_server to raise errors twice, then succeed - call_count = 0 + caplog.set_level(logging.ERROR) + try: + server._flush_calls_eager([StartBatchItem(req=tsi.CallStartReq(start=start))]) + finally: + shutdown(server) - def mock_send_batch(encoded_data: bytes) -> None: - nonlocal call_count - call_count += 1 - if call_count == 1: - raise requests.exceptions.Timeout("Connection timed out") - elif call_count == 2: - raise requests.exceptions.HTTPError("500 Server Error") - else: - return + assert any("dropped call start ids" in record.message for record in caplog.records) - # Wrap the mock with the retry decorator to preserve retry behavior - server._send_batch_to_server = with_retry(mock_send_batch) + +@pytest.mark.disable_logging_error_check +def test_eager_retryable_error_logs_and_continues(caplog): + """Log and drop item on retryable error, continue with remaining items.""" + transport = SpyTransport() + server = make_server(transport, should_batch=True) + start1 = generate_start(id="call-id-1", project_id="entity/project") + start2 = generate_start(id="call-id-2", project_id="entity/project") + + call_attempts = [] + + def _raise_retryable_once(start) -> None: + call_attempts.append(start.id) + if start.id == "call-id-1": + raise httpx.HTTPStatusError( + "500", + request=httpx.Request("POST", BASE_URL), + response=httpx.Response(500, request=httpx.Request("POST", BASE_URL)), + ) + # call-id-2 succeeds + + server._send_call_start_v2 = _raise_retryable_once # type: ignore[assignment] + + try: + # Should NOT raise - logs and drops item 1, continues with item 2 + server._flush_calls_eager( + [ + StartBatchItem(req=tsi.CallStartReq(start=start1)), + StartBatchItem(req=tsi.CallStartReq(start=start2)), + ] + ) + finally: + shutdown(server) + + # Item 1 was logged as dropped + assert any("dropped call start ids" in record.message for record in caplog.records) + assert any("call-id-1" in record.message for record in caplog.records) + # Item 2 was still processed + assert "call-id-2" in call_attempts + + +def test_timeout_retry_mechanism(monkeypatch): + """Test that timeouts trigger the retry mechanism.""" + monkeypatch.setenv("WEAVE_USE_CALLS_COMPLETE", "false") + monkeypatch.setenv("WEAVE_RETRY_MAX_INTERVAL", "0.1") + transport = SpyTransport( + httpx.TimeoutException("Connection timed out"), + httpx.Response(500), + httpx.Response(200, json={}), + ) + server = make_server(transport, should_batch=True) # Trying to send a batch should fail 2 times, then succeed server.call_start(tsi.CallStartReq(start=generate_start())) server.call_processor.stop_accepting_new_work_and_flush_queue() - # Verify that _send_batch_to_server was called 3 times (2 failures + 1 success) - assert call_count == 3 - - -@pytest.fixture -def fast_retrying_server(): - """Create a StainlessRemoteHTTPTraceServer with fast retry settings for testing.""" - server = StainlessRemoteHTTPTraceServer("http://example.com", should_batch=True) - fast_retry = tenacity.retry( - wait=tenacity.wait_fixed(0.1), - stop=tenacity.stop_after_attempt(2), - reraise=True, - ) - original_stainless_request = server._stainless_request - server._stainless_request = fast_retry(original_stainless_request) - yield server - if server.call_processor: - server.call_processor.stop_accepting_new_work_and_flush_queue() - if server.feedback_processor: - server.feedback_processor.stop_accepting_new_work_and_flush_queue() + assert len(transport.requests) == 3 + assert all(url == f"{BASE_URL}/call/upsert_batch" for url in transport.urls) @pytest.mark.disable_logging_error_check -def test_post_timeout(success_response, fast_retrying_server, log_collector): +def test_post_timeout(monkeypatch, log_collector): """Test batch recovery after timeout exhaustion. This test verifies that we can still send new batches even if one batch times out and exhausts all retries. """ configure_logger() - call_count = 0 + monkeypatch.setenv("WEAVE_USE_CALLS_COMPLETE", "false") - def mock_send_batch_timeout(encoded_data: bytes) -> None: - nonlocal call_count - call_count += 1 - raise requests.exceptions.Timeout("Connection timed out") - - # Wrap the mock with the retry decorator to preserve retry behavior - fast_retrying_server._send_batch_to_server = with_retry(mock_send_batch_timeout) + fast_retry = tenacity.retry( + wait=tenacity.wait_fixed(0.1), + stop=tenacity.stop_after_attempt(2), + reraise=True, + ) # Phase 1: Try but fail to process the first batch - fast_retrying_server.call_start(tsi.CallStartReq(start=generate_start())) - fast_retrying_server.call_processor.stop_accepting_new_work_and_flush_queue() + transport = SpyTransport( + default_response=None, + ) + transport.default_response = None + transport.queue = [ + httpx.TimeoutException("Connection timed out"), + httpx.TimeoutException("Connection timed out"), + ] + server = make_server(transport, should_batch=True) + unwrapped_send_batch_to_server = MethodType( + server._send_batch_to_server.__wrapped__, # type: ignore[attr-defined] + server, + ) + server._send_batch_to_server = fast_retry(unwrapped_send_batch_to_server) + + server.call_start(tsi.CallStartReq(start=generate_start())) + server.call_processor.stop_accepting_new_work_and_flush_queue() logs = log_collector.get_warning_logs() assert len(logs) >= 1 - assert any( - "requeueing batch" in log.msg or "batch failed" in log.msg for log in logs - ) - - # Phase 2: Reset mock and verify we can still process a new batch - call_count = 0 - - def mock_start_success(*args, **kwargs): - nonlocal call_count - call_count += 1 - if call_count == 1: - raise requests.exceptions.Timeout("Connection timed out") - else: - mock_response = MagicMock() - mock_response.id = "test_id" - mock_response.trace_id = "test_trace_id" - mock_response.model_dump.return_value = { - "id": "test_id", - "trace_id": "test_trace_id", - } - return mock_response - - # Create a new server since the old one has shutdown its batch processor - new_server = StainlessRemoteHTTPTraceServer( - "http://example.com", should_batch=False + assert any("requeuing batch" in log.msg for log in logs) + shutdown(server) + + # Phase 2: Verify a fresh server can still process a new batch after a + # transient timeout + call_id = generate_id() + transport2 = SpyTransport( + httpx.TimeoutException("Connection timed out"), + call_start_ok_response(call_id), ) - fast_retry = tenacity.retry( - wait=tenacity.wait_fixed(0.1), - stop=tenacity.stop_after_attempt(2), - reraise=True, + new_server = make_server(transport2, should_batch=False) + fast_retried_via_sdk = fast_retry( + MethodType( + new_server._via_sdk.__wrapped__, # type: ignore[attr-defined] + new_server, + ) ) - original_stainless_request = new_server._stainless_request - new_server._stainless_request = fast_retry(original_stainless_request) - new_server._stainless_client.calls.start = mock_start_success + new_server._via_sdk = fast_retried_via_sdk - # Should succeed with retry - start_req = tsi.CallStartReq(start=generate_start()) - response = new_server.call_start(start_req) - assert response.id == "test_id" + response = new_server.call_start(tsi.CallStartReq(start=generate_start(call_id))) + assert response.id == call_id assert response.trace_id == "test_trace_id" + + +def test_auto_upgrade_to_calls_complete_on_error(monkeypatch): + """Verify client switches to CallBatchProcessor when server returns CALLS_COMPLETE_MODE_REQUIRED.""" + monkeypatch.setenv("WEAVE_USE_CALLS_COMPLETE", "false") + transport = SpyTransport( + make_calls_complete_required_response(), + httpx.Response(200, json={}), + ) + server = make_server(transport, should_batch=True) + + # Verify initial state: using legacy AsyncBatchProcessor + assert server.use_calls_complete is False + assert isinstance(server.call_processor, AsyncBatchProcessor) + assert not isinstance(server.call_processor, CallBatchProcessor) + old_processor = server.call_processor + + call_id = generate_id() + start = StartBatchItem( + req=tsi.CallStartReq(start=generate_start(call_id, "entity/project")) + ) + end = EndBatchItem(req=tsi.CallEndReq(end=generate_end(call_id, "entity/project"))) + + try: + server._flush_calls([start, end]) + server.call_processor.stop_accepting_new_work_and_flush_queue() + + # Verify upgrade happened + assert server.use_calls_complete is True + assert isinstance(server.call_processor, CallBatchProcessor) + assert old_processor.stop_accepting_work_event.is_set() + assert any("/calls/complete" in url for url in transport.urls) + finally: + shutdown(server) + + +def test_eager_calls_complete_required_is_reraised(monkeypatch): + """Verify CallsCompleteModeRequired in eager path is re-raised for caller to handle.""" + monkeypatch.setenv("WEAVE_USE_CALLS_COMPLETE", "true") + transport = SpyTransport(make_calls_complete_required_response()) + server = make_server(transport, should_batch=True) + + start = StartBatchItem( + req=tsi.CallStartReq(start=generate_start("call-id", "entity/project")) + ) + + try: + with pytest.raises(CallsCompleteModeRequired): + server._flush_calls_eager([start]) + finally: + shutdown(server) + + +def test_calls_query_stream_parses_jsonl(): + """calls_query_stream yields CallSchema objects from a jsonl response.""" + call = { + "project_id": "entity/project", + "id": "call-1", + "op_name": "op", + "trace_id": "trace-1", + "started_at": "2024-01-01T00:00:00Z", + "attributes": {}, + "inputs": {}, + } + body = "\n".join([json.dumps(call), json.dumps({**call, "id": "call-2"})]) + transport = SpyTransport( + httpx.Response( + 200, + content=body.encode("utf-8"), + headers={"content-type": "application/jsonl"}, + ) + ) + server = make_server(transport) + + calls = list( + server.calls_query_stream(tsi.CallsQueryReq(project_id="entity/project")) + ) + + assert transport.urls == [f"{BASE_URL}/calls/stream_query"] + assert [c.id for c in calls] == ["call-1", "call-2"] + assert all(isinstance(c, tsi.CallSchema) for c in calls) + + +def test_file_create_sends_multipart(): + """file_create posts multipart form data (SDK 0.0.1 lost the body).""" + transport = SpyTransport(httpx.Response(200, json={"digest": "digest-1"})) + server = make_server(transport) + + res = server.file_create( + tsi.FileCreateReq( + project_id="entity/project", name="file.txt", content=b"hello" + ) + ) + + assert res.digest == "digest-1" + request = transport.requests[0] + assert str(request.url) == f"{BASE_URL}/files/create" + assert request.headers["content-type"].startswith("multipart/form-data") + assert b"hello" in request.content + assert b"entity/project" in request.content + + +def test_feedback_create_unbatched_uses_single_route(): + """Unbatched feedback_create posts to /feedback/create (SDK route shadowed).""" + transport = SpyTransport( + httpx.Response( + 200, + json={ + "id": "feedback-1", + "created_at": "2024-01-01T00:00:00Z", + "wb_user_id": "user", + "payload": {"note": "hi"}, + }, + ) + ) + server = make_server(transport) + + res = server.feedback_create( + tsi.FeedbackCreateReq( + project_id="entity/project", + weave_ref="weave:///entity/project/call/call-1", + feedback_type="wandb.note.1", + payload={"note": "hi"}, + ) + ) + + assert transport.urls == [f"{BASE_URL}/feedback/create"] + assert res.id == "feedback-1" diff --git a/tests/utils/test_http_requests.py b/tests/utils/test_http_requests.py index 6c21773a9ad5..6fd237e2ac30 100644 --- a/tests/utils/test_http_requests.py +++ b/tests/utils/test_http_requests.py @@ -33,7 +33,7 @@ def test_request_hook_logs_when_enabled(monkeypatch): request = httpx.Request("GET", "https://api.wandb.ai/calls") with patch.object(http_requests, "pprint_request") as mock_pprint_request: - http_requests._log_request(request) + http_requests.log_request(request) mock_pprint_request.assert_called_once_with(request) assert isinstance(request.extensions.get("weave_start_time"), float) @@ -43,7 +43,7 @@ def test_request_hook_noop_when_disabled(): request = httpx.Request("GET", "https://api.wandb.ai/calls") with patch.object(http_requests, "pprint_request") as mock_pprint_request: - http_requests._log_request(request) + http_requests.log_request(request) mock_pprint_request.assert_not_called() assert "weave_start_time" not in request.extensions @@ -59,7 +59,7 @@ def test_response_hook_logs_when_enabled(monkeypatch): patch.object(http_requests, "pprint_response") as mock_pprint_response, patch("weave.utils.http_requests.time", return_value=2.0), ): - http_requests._log_response(response) + http_requests.log_response(response) mock_pprint_response.assert_called_once_with(response) diff --git a/uv.lock b/uv.lock index e3b3e0c26b7e..361fff69c68c 100644 --- a/uv.lock +++ b/uv.lock @@ -10762,6 +10762,7 @@ dependencies = [ { name = "sentry-sdk", marker = "platform_python_implementation == 'CPython' or (extra == 'extra-5-weave-crewai' and extra == 'extra-5-weave-dspy') or (extra == 'extra-5-weave-crewai' and extra == 'extra-5-weave-instructor') or (extra == 'extra-5-weave-crewai' and extra == 'extra-5-weave-langchain') or (extra == 'extra-5-weave-crewai' and extra == 'extra-5-weave-scorers') or (extra == 'extra-5-weave-dspy' and extra == 'extra-5-weave-gepa') or (extra == 'extra-5-weave-langchain' and extra == 'extra-5-weave-vertexai')" }, { name = "tenacity", marker = "platform_python_implementation == 'CPython' or (extra == 'extra-5-weave-crewai' and extra == 'extra-5-weave-dspy') or (extra == 'extra-5-weave-crewai' and extra == 'extra-5-weave-instructor') or (extra == 'extra-5-weave-crewai' and extra == 'extra-5-weave-langchain') or (extra == 'extra-5-weave-crewai' and extra == 'extra-5-weave-scorers') or (extra == 'extra-5-weave-dspy' and extra == 'extra-5-weave-gepa') or (extra == 'extra-5-weave-langchain' and extra == 'extra-5-weave-vertexai')" }, { name = "tzdata", marker = "(platform_python_implementation == 'CPython' and sys_platform == 'win32') or (platform_python_implementation != 'CPython' and extra == 'extra-5-weave-crewai' and extra == 'extra-5-weave-dspy') or (platform_python_implementation != 'CPython' and extra == 'extra-5-weave-crewai' and extra == 'extra-5-weave-instructor') or (platform_python_implementation != 'CPython' and extra == 'extra-5-weave-crewai' and extra == 'extra-5-weave-langchain') or (platform_python_implementation != 'CPython' and extra == 'extra-5-weave-crewai' and extra == 'extra-5-weave-scorers') or (platform_python_implementation != 'CPython' and extra == 'extra-5-weave-dspy' and extra == 'extra-5-weave-gepa') or (platform_python_implementation != 'CPython' and extra == 'extra-5-weave-langchain' and extra == 'extra-5-weave-vertexai') or (sys_platform != 'win32' and extra == 'extra-5-weave-crewai' and extra == 'extra-5-weave-dspy') or (sys_platform != 'win32' and extra == 'extra-5-weave-crewai' and extra == 'extra-5-weave-instructor') or (sys_platform != 'win32' and extra == 'extra-5-weave-crewai' and extra == 'extra-5-weave-langchain') or (sys_platform != 'win32' and extra == 'extra-5-weave-crewai' and extra == 'extra-5-weave-scorers') or (sys_platform != 'win32' and extra == 'extra-5-weave-dspy' and extra == 'extra-5-weave-gepa') or (sys_platform != 'win32' and extra == 'extra-5-weave-langchain' and extra == 'extra-5-weave-vertexai')" }, + { name = "weave-server-sdk", marker = "platform_python_implementation == 'CPython' or (extra == 'extra-5-weave-crewai' and extra == 'extra-5-weave-dspy') or (extra == 'extra-5-weave-crewai' and extra == 'extra-5-weave-instructor') or (extra == 'extra-5-weave-crewai' and extra == 'extra-5-weave-langchain') or (extra == 'extra-5-weave-crewai' and extra == 'extra-5-weave-scorers') or (extra == 'extra-5-weave-dspy' and extra == 'extra-5-weave-gepa') or (extra == 'extra-5-weave-langchain' and extra == 'extra-5-weave-vertexai')" }, ] [package.optional-dependencies] @@ -11088,6 +11089,7 @@ requires-dist = [ { name = "verifiers", marker = "python_full_version >= '3.11' and extra == 'verifiers'", specifier = ">=0.1.3.post0,<0.1.13" }, { name = "vertexai", marker = "platform_python_implementation != 'PyPy' and extra == 'vertexai'", specifier = ">=1.70.0" }, { name = "wandb", marker = "extra == 'wandb'", specifier = ">=0.17.1" }, + { name = "weave-server-sdk", specifier = "==0.0.1", index = "https://test.pypi.org/simple/" }, ] provides-extras = ["anthropic", "autogen", "bedrock", "cerebras", "claude-agent-sdk", "cohere", "crewai", "dspy", "fastmcp", "gepa", "google-genai", "groq", "huggingface", "instructor", "langchain", "langchain-nvidia-ai-endpoints", "litellm", "llamaindex", "mistral", "modal", "notdiamond", "openai", "openai-agents", "presidio", "rich", "scorers", "smolagents", "trace-server", "verdict", "verifiers", "vertexai", "video-support", "wandb"] @@ -11153,6 +11155,20 @@ verifiers-test = [ { name = "verifiers", marker = "python_full_version >= '3.11'", specifier = ">=0.1.3.post0,<0.1.13" }, ] +[[package]] +name = "weave-server-sdk" +version = "0.0.1" +source = { registry = "https://test.pypi.org/simple/" } +dependencies = [ + { name = "httpx", version = "0.27.2", source = { registry = "https://pypi.org/simple" }, marker = "(python_full_version < '3.14' and platform_python_implementation == 'CPython' and extra == 'extra-5-weave-crewai') or (python_full_version >= '3.14' and extra == 'extra-5-weave-crewai' and extra == 'extra-5-weave-dspy') or (python_full_version >= '3.14' and extra == 'extra-5-weave-crewai' and extra == 'extra-5-weave-instructor') or (python_full_version >= '3.14' and extra == 'extra-5-weave-crewai' and extra == 'extra-5-weave-langchain') or (python_full_version >= '3.14' and extra == 'extra-5-weave-crewai' and extra == 'extra-5-weave-scorers') or (platform_python_implementation != 'CPython' and extra == 'extra-5-weave-crewai' and extra == 'extra-5-weave-dspy') or (platform_python_implementation != 'CPython' and extra == 'extra-5-weave-crewai' and extra == 'extra-5-weave-instructor') or (platform_python_implementation != 'CPython' and extra == 'extra-5-weave-crewai' and extra == 'extra-5-weave-langchain') or (platform_python_implementation != 'CPython' and extra == 'extra-5-weave-crewai' and extra == 'extra-5-weave-scorers') or (platform_python_implementation != 'CPython' and extra == 'extra-5-weave-dspy' and extra == 'extra-5-weave-gepa') or (platform_python_implementation != 'CPython' and extra == 'extra-5-weave-langchain' and extra == 'extra-5-weave-vertexai') or (extra != 'extra-5-weave-crewai' and extra == 'extra-5-weave-dspy' and extra == 'extra-5-weave-gepa') or (extra != 'extra-5-weave-crewai' and extra == 'extra-5-weave-langchain' and extra == 'extra-5-weave-vertexai')" }, + { name = "httpx", version = "0.28.1", source = { registry = "https://pypi.org/simple" }, marker = "(python_full_version >= '3.14' and platform_python_implementation == 'CPython') or (python_full_version < '3.14' and extra == 'extra-5-weave-crewai' and extra == 'extra-5-weave-dspy') or (python_full_version < '3.14' and extra == 'extra-5-weave-crewai' and extra == 'extra-5-weave-instructor') or (python_full_version < '3.14' and extra == 'extra-5-weave-crewai' and extra == 'extra-5-weave-langchain') or (python_full_version < '3.14' and extra == 'extra-5-weave-crewai' and extra == 'extra-5-weave-scorers') or (platform_python_implementation == 'CPython' and extra != 'extra-5-weave-crewai') or (platform_python_implementation != 'CPython' and extra == 'extra-5-weave-crewai' and extra == 'extra-5-weave-dspy') or (platform_python_implementation != 'CPython' and extra == 'extra-5-weave-crewai' and extra == 'extra-5-weave-instructor') or (platform_python_implementation != 'CPython' and extra == 'extra-5-weave-crewai' and extra == 'extra-5-weave-langchain') or (platform_python_implementation != 'CPython' and extra == 'extra-5-weave-crewai' and extra == 'extra-5-weave-scorers') or (platform_python_implementation != 'CPython' and extra == 'extra-5-weave-dspy' and extra == 'extra-5-weave-gepa') or (platform_python_implementation != 'CPython' and extra == 'extra-5-weave-langchain' and extra == 'extra-5-weave-vertexai')" }, + { name = "pydantic", marker = "platform_python_implementation == 'CPython' or (extra == 'extra-5-weave-crewai' and extra == 'extra-5-weave-dspy') or (extra == 'extra-5-weave-crewai' and extra == 'extra-5-weave-instructor') or (extra == 'extra-5-weave-crewai' and extra == 'extra-5-weave-langchain') or (extra == 'extra-5-weave-crewai' and extra == 'extra-5-weave-scorers') or (extra == 'extra-5-weave-dspy' and extra == 'extra-5-weave-gepa') or (extra == 'extra-5-weave-langchain' and extra == 'extra-5-weave-vertexai')" }, +] +sdist = { url = "https://test-files.pythonhosted.org/packages/4e/29/419e7dd1b628ee7a8f7501717b188eff04bbd5fa7c9937243701292517b1/weave_server_sdk-0.0.1.tar.gz", hash = "sha256:21d32d702ffdbee6f0ca00a31e550a6cb97076def9780da1cd18d3e23a7103f5", size = 22409, upload-time = "2026-06-11T04:23:44.073Z" } +wheels = [ + { url = "https://test-files.pythonhosted.org/packages/49/51/0027425fddfa07ff11e41f362d397f33f7247fcee3ca61c581215c6b190f/weave_server_sdk-0.0.1-py3-none-any.whl", hash = "sha256:c5a39f2f9573946e5f89b3cd3b1350226f047594b1122480ebc0e06b3ee0b8bc", size = 38852, upload-time = "2026-06-11T04:23:43.141Z" }, +] + [[package]] name = "webcolors" version = "25.10.0" diff --git a/weave/trace_server_bindings/stainless_remote_http_trace_server.py b/weave/trace_server_bindings/stainless_remote_http_trace_server.py index d230ad46c853..a3075d90b285 100644 --- a/weave/trace_server_bindings/stainless_remote_http_trace_server.py +++ b/weave/trace_server_bindings/stainless_remote_http_trace_server.py @@ -1,50 +1,122 @@ +"""Remote trace server binding backed by the generated ``weave-server-sdk``. + +This is a drop-in replacement for ``RemoteHTTPTraceServer``. It speaks the same +tsi-typed ``TraceServerClientInterface`` surface, but delegates HTTP transport +and request/response typing to the ``weave_server_sdk`` package (generated from +the trace server's OpenAPI spec — the source of truth for the API shape). + +Design notes: + +- A single ``httpx.Client`` (built like ``weave.utils.http_requests``: default + transport for env-proxy handling, no connection limits, ``ssl_verify()`` and + ``http_timeout()`` honored) is injected into the SDK so every request — + SDK-routed or raw — shares one connection pool, auth, and event hooks. +- A response event hook routes every non-2xx response through + ``handle_response_error`` *before* the SDK sees it, so callers observe the + exact same ``httpx.HTTPStatusError`` / ``CallsCompleteModeRequired`` + semantics as ``RemoteHTTPTraceServer`` (retry predicates, 413 batch + splitting, and calls_complete auto-upgrade all key off these). +- A request event hook injects the dynamic ``X-Weave-Retry-Id`` header at send + time so every retry attempt carries the current retry id. +- Endpoints the SDK cannot reach go through ``_raw_request``/``_raw_stream`` + with an explicit reason. Two categories: + 1. Endpoints excluded from the OpenAPI spec (``include_in_schema=False`` on + the server): calls_complete v2, eager v2 call start/end, completions, + project stats, TTL settings. + 2. weave-server-sdk 0.0.1 codegen bugs (duplicate method names where the + last definition wins, lost multipart body): single feedback create, obj + tag add/remove, /trace/usage, file create. Remove these hatches when a + fixed SDK ships. +- Streaming endpoints (``*_stream`` and the v2 jsonl list endpoints) use + ``_raw_stream`` because the published SDK buffers jsonl responses into + memory; the raw path preserves line-by-line streaming. +""" + from __future__ import annotations import datetime import io import logging from collections.abc import Callable, Iterator -from typing import Any, TypeVar +from typing import Any, TypeVar, cast from zoneinfo import ZoneInfo -from pydantic import BaseModel, validate_call +import httpx +from pydantic import BaseModel, Field, validate_call +from pydantic.json_schema import SkipJsonSchema from typing_extensions import Self -from weave_server_sdk import Client as StainlessClient - -from weave.trace.env import weave_trace_server_url -from weave.trace.settings import max_calls_queue_size, should_enable_disk_fallback +from weave_server_sdk import WeaveTrace +from weave_server_sdk import models as sdk_models + +from weave.trace.env import ssl_verify, weave_trace_server_url +from weave.trace.settings import ( + http_timeout, + max_calls_queue_size, + should_enable_disk_fallback, + should_use_calls_complete, +) from weave.trace_server import trace_server_interface as tsi from weave.trace_server.ids import generate_id from weave.trace_server.service_interface import ServerInfoRes from weave.trace_server_bindings.async_batch_processor import AsyncBatchProcessor +from weave.trace_server_bindings.call_batch_processor import CallBatchProcessor from weave.trace_server_bindings.client_interface import TraceServerClientInterface from weave.trace_server_bindings.http_utils import ( REMOTE_REQUEST_BYTES_LIMIT, + CallsCompleteModeRequired, + handle_response_error, log_dropped_call_batch, log_dropped_feedback_batch, process_batch_with_retry, ) from weave.trace_server_bindings.models import ( Batch, + CompleteBatchItem, EndBatchItem, + EntityProjectInfo, StartBatchItem, ) +from weave.utils.http_requests import CLIENT_LIMITS, log_request, log_response from weave.utils.project_id import from_project_id from weave.utils.retry import get_current_retry_id, with_retry from weave.wandb_interface import project_creator -TReq = TypeVar("TReq", bound=BaseModel) TRes = TypeVar("TRes", bound=BaseModel) logger = logging.getLogger(__name__) +# Endpoints reached via _raw_request/_raw_stream because they are excluded from +# the OpenAPI spec (include_in_schema=False on the server). +CALLS_COMPLETE_PATH = "/v2/{entity}/{project}/calls/complete" +CALL_START_V2_PATH = "/v2/{entity}/{project}/call/start" +CALL_END_V2_PATH = "/v2/{entity}/{project}/call/end" +COMPLETIONS_CREATE_PATH = "/completions/create" +PROJECT_STATS_PATH = "/project/stats" +PROJECT_TTL_SETTINGS_READ_PATH = "/project/ttl_settings/read" +PROJECT_TTL_SETTINGS_UPDATE_PATH = "/project/ttl_settings/update" +FEEDBACK_AGGREGATE_PATH = "/feedback/aggregate" + +# Endpoints reached via _raw_request because weave-server-sdk 0.0.1 cannot call +# them (duplicate generated method names where the last definition wins, or a +# lost multipart body). Remove once a fixed SDK is published. +FEEDBACK_CREATE_PATH = "/feedback/create" +FEEDBACK_BATCH_CREATE_PATH = "/feedback/batch/create" +TRACE_USAGE_PATH = "/trace/usage" +OBJ_ADD_TAGS_PATH = "/objs/{object_id}/versions/{digest}/tags" +OBJ_REMOVE_TAGS_PATH = "/objs/{object_id}/versions/{digest}/tags/remove" +FILE_CREATE_PATH = "/files/create" +FILE_CONTENT_PATH = "/files/content" + +# Streaming endpoints; the published SDK buffers jsonl bodies, so these are +# reached via _raw_stream to preserve line-by-line streaming. +CALLS_STREAM_QUERY_PATH = "/calls/stream_query" +THREADS_STREAM_QUERY_PATH = "/threads/stream_query" +ANNOTATION_QUEUES_QUERY_PATH = "/annotation_queues/query" +CALL_UPSERT_BATCH_PATH = "/call/upsert_batch" -class StainlessRemoteHTTPTraceServer(TraceServerClientInterface): - """Drop-in replacement for RemoteHTTPTraceServer using the stainless client. - This implementation uses the stainless-generated client instead of manual HTTP requests. - It maintains the same interface and behavior as RemoteHTTPTraceServer. - """ +class StainlessRemoteHTTPTraceServer(TraceServerClientInterface): + """SDK-backed drop-in replacement for ``RemoteHTTPTraceServer``.""" trace_server_url: str @@ -54,44 +126,88 @@ def __init__( should_batch: bool = False, *, remote_request_bytes_limit: int = REMOTE_REQUEST_BYTES_LIMIT, - username: str = "", - password: str = "", + auth: tuple[str, str] | None = None, extra_headers: dict[str, str] | None = None, + transport: httpx.BaseTransport | None = None, ): + super().__init__() self.trace_server_url = trace_server_url.rstrip("/") self.should_batch = should_batch - self.call_processor: AsyncBatchProcessor | None = None + self.use_calls_complete = should_use_calls_complete() and should_batch + self.call_processor: AsyncBatchProcessor | CallBatchProcessor | None = None self.feedback_processor: AsyncBatchProcessor | None = None + self._auth: tuple[str, str] | None = auth + self._extra_headers: dict[str, str] | None = extra_headers self.remote_request_bytes_limit = remote_request_bytes_limit - self._extra_headers: dict[str, str] = extra_headers or {} - self._username: str = username - self._password: str = password - - # Initialize stainless client - default_headers = self._extra_headers.copy() - if retry_id := get_current_retry_id(): - default_headers["X-Weave-Retry-Id"] = retry_id + # Test seam: lets tests (and in-process fixtures) substitute the HTTP + # transport while keeping the full client stack (hooks, SDK) in play. + self._transport = transport - self._stainless_client = StainlessClient( - base_url=trace_server_url, - username=username, - password=password, - default_headers=default_headers, - batch_requests=False, # We handle batching ourselves - ) + self._http = self._build_http_client() + self._sdk = WeaveTrace(http_client=self._http) if self.should_batch: - self.call_processor = AsyncBatchProcessor( - self._flush_calls, - max_queue_size=max_calls_queue_size(), - enable_disk_fallback=should_enable_disk_fallback(), - ) + if self.use_calls_complete: + self.call_processor = CallBatchProcessor( + complete_processor_fn=self._flush_calls_complete, + eager_processor_fn=self._flush_calls_eager, + max_queue_size=max_calls_queue_size(), + enable_disk_fallback=should_enable_disk_fallback(), + ) + else: + self.call_processor = AsyncBatchProcessor( + self._flush_calls, + max_queue_size=max_calls_queue_size(), + enable_disk_fallback=should_enable_disk_fallback(), + ) self.feedback_processor = AsyncBatchProcessor( self._flush_feedback, max_queue_size=max_calls_queue_size(), enable_disk_fallback=should_enable_disk_fallback(), ) + # ---- transport --------------------------------------------------------- + + def _build_http_client(self) -> httpx.Client: + kwargs: dict[str, Any] = {} + if self._transport is not None: + kwargs["transport"] = self._transport + return httpx.Client( + base_url=self.trace_server_url, + auth=self._auth, + headers=self._extra_headers, + # Default transport (unless injected) so env proxy handling + # (incl. NO_PROXY) works natively. + event_hooks={ + "request": [log_request, self._inject_dynamic_headers], + "response": [log_response, self._raise_for_status], + }, + timeout=http_timeout(), + limits=CLIENT_LIMITS, + verify=ssl_verify(), + **kwargs, + ) + + def _inject_dynamic_headers(self, request: httpx.Request) -> None: + """Inject per-attempt headers at send time (httpx request hook).""" + if retry_id := get_current_retry_id(): + request.headers["X-Weave-Retry-Id"] = retry_id + + def _raise_for_status(self, response: httpx.Response) -> None: + """Surface error responses with RemoteHTTPTraceServer's exception + semantics (httpx response hook). + + Raising here means the SDK's own exception types never surface: retry + predicates, 413 batch splitting, and client code keep seeing + ``httpx.HTTPStatusError`` / ``CallsCompleteModeRequired`` exactly as + before. + """ + if response.status_code >= 400: + # Event hooks fire before the body is read; load it so + # handle_response_error can inspect the error payload. + response.read() + handle_response_error(response, str(response.request.url)) + def ensure_project_exists( self, entity: str, project: str ) -> tsi.EnsureProjectExistsRes: @@ -106,190 +222,129 @@ def from_env(cls, should_batch: bool = False) -> Self: return cls(weave_trace_server_url(), should_batch) def set_auth(self, auth: tuple[str, str]) -> None: - """Set authentication credentials. - - Args: - auth: Tuple of (username, password) for authentication. - """ - self._username, self._password = auth - # Recreate stainless client with new credentials - default_headers = self._extra_headers.copy() - if retry_id := get_current_retry_id(): - default_headers["X-Weave-Retry-Id"] = retry_id - - self._stainless_client = StainlessClient( - base_url=self.trace_server_url, - username=self._username, - password=self._password, - default_headers=default_headers, - batch_requests=False, # We handle batching ourselves - ) + self._auth = auth + self._http.auth = auth - def _update_client_headers(self) -> None: - """Update client headers with current retry ID and extra headers.""" - headers = self._extra_headers.copy() - if retry_id := get_current_retry_id(): - headers["X-Weave-Retry-Id"] = retry_id - if headers: - self._stainless_client = self._stainless_client.copy( - default_headers=headers - ) + # ---- request helpers ---------------------------------------------------- - def _stainless_request( + @with_retry + def _via_sdk( self, req: BaseModel, + sdk_req_type: type[BaseModel], + sdk_method: Callable[..., Any], res_type: type[TRes], - stainless_api: Callable[..., Any], - *, - exclude: set[str] | None = None, - **extra_kwargs: Any, + **path_args: Any, ) -> TRes: - """Helper method to make a stainless API request with proper type conversion. + """Round-trip a tsi request through the typed SDK binding. - Args: - req: Request object (already validated by @validate_call). - res_type: Type of the response model. - stainless_api: Stainless API callable to invoke. - exclude: Set of field names to exclude from request dump. - **extra_kwargs: Additional keyword arguments to pass to the API. - - Returns: - Validated response model instance. + ``by_alias`` is required since query models have Mongo-style properties + aliased to start with ``$``. """ - self._update_client_headers() - - dump_kwargs: dict[str, Any] = {"by_alias": True} - exclude_set = set(extra_kwargs.keys()) - if exclude: - exclude_set.update(exclude) - if exclude_set: - dump_kwargs["exclude"] = exclude_set + body = sdk_req_type.model_validate(req.model_dump(by_alias=True)) + sdk_res = sdk_method(body, **path_args) + if sdk_res is None: + return res_type() + return res_type.model_validate(sdk_res.model_dump(by_alias=True)) - req_dict = req.model_dump(**dump_kwargs) - response = stainless_api(**req_dict, **extra_kwargs) - return res_type.model_validate(response.model_dump()) - - def _stainless_request_object( + @with_retry + def _via_sdk_no_body( self, - req: BaseModel, + sdk_method: Callable[..., Any], res_type: type[TRes], - stainless_api: Callable[..., Any], - *, - exclude: set[str] | None = None, - **extra_kwargs: Any, + **call_args: Any, ) -> TRes: - """Helper method for Object API requests that split project_id into entity/project. - - Args: - req: Request object (already validated by @validate_call). - res_type: Type of the response model. - stainless_api: Stainless API callable to invoke. - exclude: Set of field names to exclude from request dump. - **extra_kwargs: Additional keyword arguments to pass to the API. - - Returns: - Validated response model instance. - """ - self._update_client_headers() - entity, project = from_project_id(req.project_id) + """Call an SDK binding that takes only path/query arguments.""" + sdk_res = sdk_method(**call_args) + if sdk_res is None: + return res_type() + return res_type.model_validate(sdk_res.model_dump(by_alias=True)) - exclude_set = {"project_id"} - if exclude: - exclude_set.update(exclude) - exclude_set.update(extra_kwargs.keys()) - - dump_kwargs: dict[str, Any] = {"by_alias": True} - if exclude_set: - dump_kwargs["exclude"] = exclude_set + @with_retry + def _raw_request( + self, + method: str, + path: str, + *, + req: BaseModel | None = None, + params: dict[str, Any] | None = None, + res_type: type[TRes], + ) -> TRes: + """Call an endpoint the SDK cannot reach (see module docstring).""" + r = self._http.request( + method, + path, + content=req.model_dump_json(by_alias=True).encode("utf-8") + if req is not None + else None, + headers={"content-type": "application/json"} if req is not None else None, + params=params, + ) + return res_type.model_validate(r.json()) - req_dict = req.model_dump(**dump_kwargs) - response = stainless_api( - entity=entity, project=project, **req_dict, **extra_kwargs + @with_retry + def _raw_post_bytes(self, path: str, encoded_data: bytes) -> None: + """POST pre-encoded json bytes (batch flush hot path; no re-parse).""" + self._http.post( + path, content=encoded_data, headers={"content-type": "application/json"} ) - return res_type.model_validate(response.model_dump()) - - def _prepare_v2_request(self, req: BaseModel) -> tuple[str, str]: - """Prepare v2 API request by updating headers and splitting project_id. - Args: - req: Request object with project_id attribute. - - Returns: - Tuple of (entity, project) from split project_id. - - Examples: - >>> entity, project = self._prepare_v2_request(req) - """ - self._update_client_headers() - return from_project_id(req.project_id) - - def _stainless_list_object( + def _raw_stream( self, - req: BaseModel, - res_type: type[TRes], - stainless_api: Callable[..., Any], + method: str, + path: str, *, - exclude: set[str] | None = None, - **extra_kwargs: Any, + req: BaseModel | None = None, + params: dict[str, Any] | None = None, + res_type: type[TRes], ) -> Iterator[TRes]: - """Helper method for Object API list requests that split project_id into entity/project. + """Stream a jsonl endpoint line-by-line (the SDK buffers jsonl).""" + r = self._open_stream(method, path, req=req, params=params) + try: + for line in r.iter_lines(): + if line: + yield res_type.model_validate_json(line) + finally: + r.close() - Args: - req: Request object (already validated by @validate_call). - res_type: Type of the response model to yield. - stainless_api: Stainless API callable to invoke. - exclude: Set of field names to exclude from request dump. - **extra_kwargs: Additional keyword arguments to pass to the API. + @with_retry + def _open_stream( + self, + method: str, + path: str, + *, + req: BaseModel | None = None, + params: dict[str, Any] | None = None, + ) -> httpx.Response: + """Open a streaming response; retries cover connection/headers only. - Yields: - Validated response model instances of type res_type. + Mid-stream failures are not retried, matching RemoteHTTPTraceServer. + The caller owns the returned response and must close() it. """ - self._update_client_headers() - entity, project = from_project_id(req.project_id) - - exclude_set = {"project_id"} - if exclude: - exclude_set.update(exclude) - exclude_set.update(extra_kwargs.keys()) - - dump_kwargs: dict[str, Any] = {"by_alias": True, "exclude_none": True} - if exclude_set: - dump_kwargs["exclude"] = exclude_set - - req_dict = req.model_dump(**dump_kwargs) - response = stainless_api( - entity=entity, project=project, **req_dict, **extra_kwargs + request = self._http.build_request( + method, + path, + content=req.model_dump_json(by_alias=True).encode("utf-8") + if req is not None + else None, + headers={"content-type": "application/json"} if req is not None else None, + params=params, ) - for item in response: - yield res_type.model_validate(item) + return self._http.send(request, stream=True) + + # ---- batching ----------------------------------------------------------- @with_retry def _send_batch_to_server(self, encoded_data: bytes) -> None: - """Send an encoded batch of calls to the server using the stainless client. + """Send an encoded batch of calls to the server with retry logic. - Args: - encoded_data: Encoded batch data to send. + Separated from _flush_calls to avoid recursive retries. """ - self._update_client_headers() - # Parse the batch and convert to stainless format - batch_data = Batch.model_validate_json(encoded_data.decode("utf-8")) - stainless_batch = [] - for item in batch_data.batch: - if isinstance(item, StartBatchItem): - stainless_batch.append( - { - "mode": "start", - "req": item.req.model_dump(by_alias=True), - } - ) - elif isinstance(item, EndBatchItem): - stainless_batch.append( - { - "mode": "end", - "req": item.req.model_dump(by_alias=True), - } - ) - self._stainless_client.calls.upsert_batch(batch=stainless_batch) + self._http.post( + CALL_UPSERT_BATCH_PATH, + content=encoded_data, + headers={"content-type": "application/json"}, + ) def _flush_calls( self, @@ -299,9 +354,9 @@ def _flush_calls( ) -> None: """Process a batch of calls, splitting if necessary and sending to the server. - Args: - batch: List of batch items to process. - _should_update_batch_size: Whether to update batch size based on response. + This method handles the logic of splitting batches that are too large, + but delegates the actual server communication (with retries) to + _send_batch_to_server. """ assert self.call_processor is not None if len(batch) == 0: @@ -318,11 +373,178 @@ def encode_batch(batch: list[StartBatchItem | EndBatchItem]) -> bytes: data = Batch(batch=batch).model_dump_json() return data.encode("utf-8") + try: + process_batch_with_retry( + batch_name="calls", + batch=batch, + remote_request_bytes_limit=self.remote_request_bytes_limit, + send_batch_fn=self._send_batch_to_server, + processor_obj=self.call_processor, + should_update_batch_size=_should_update_batch_size, + get_item_id_fn=get_item_id, + log_dropped_fn=log_dropped_call_batch, + encode_batch_fn=encode_batch, + ) + except CallsCompleteModeRequired as e: + # Project requires calls_complete mode - upgrade and re-enqueue the batch + self._upgrade_to_calls_complete(batch, str(e)) + + def _upgrade_to_calls_complete( + self, batch: list[StartBatchItem | EndBatchItem], error_message: str + ) -> None: + """Upgrade from legacy AsyncBatchProcessor to CallBatchProcessor. + + This is called when the server indicates a project requires + calls_complete mode. The upgrade happens transparently: we replace the + processor and re-enqueue the current batch items. No calls are dropped + during this upgrade. + """ + # Already upgraded? Just re-enqueue to the new processor + if self.use_calls_complete: + if isinstance(self.call_processor, CallBatchProcessor): + self.call_processor.enqueue( + cast(list[StartBatchItem | EndBatchItem | CompleteBatchItem], batch) + ) + return + + logger.warning( + "Project has been previously written to with `use_calls_complete=True` and requires 'calls_complete' mode. Automatically upgrading SDK to use the more performant calls_complete processor. Server message: %s", + error_message, + ) + + old_processor = self.call_processor + + self.use_calls_complete = True + self.call_processor = CallBatchProcessor( + complete_processor_fn=self._flush_calls_complete, + eager_processor_fn=self._flush_calls_eager, + max_queue_size=max_calls_queue_size(), + enable_disk_fallback=should_enable_disk_fallback(), + ) + + # Re-enqueue the batch items to the new processor + # Cast needed: list is invariant, but StartBatchItem | EndBatchItem is a valid subset of BatchItem + self.call_processor.enqueue( + cast(list[StartBatchItem | EndBatchItem | CompleteBatchItem], batch) + ) + + # Stop the old processor gracefully - any remaining items in its queue + # will be caught by _flush_calls which will re-enqueue them to the + # new processor via this same method (the "already upgraded" path above) + if old_processor is not None: + old_processor.stop_accepting_work_event.set() + + def _flush_calls_eager( + self, + batch: list[StartBatchItem | EndBatchItem], + *, + _should_update_batch_size: bool = True, + ) -> None: + """Process eager start/end items via v2 single endpoints. + + This is used for ops like Evaluation.evaluate that need their start + to be visible immediately in the UI. Uses single call/start and call/end + endpoints for easier rate limiting. + + Each item is sent individually with retry logic (@with_retry). If all + retries are exhausted, the item is logged and dropped, then processing + continues with remaining items in the batch. + """ + for item in batch: + try: + if isinstance(item, StartBatchItem): + self._send_call_start_v2(item.req.start) + elif isinstance(item, EndBatchItem): + self._send_call_end_v2(item.req.end) + except CallsCompleteModeRequired: + # Re-raise so caller can handle the upgrade to calls_complete mode + raise + except Exception as e: + log_dropped_call_batch([item], e) + + @with_retry + def _send_call_start_v2(self, start: tsi.StartedCallSchemaForInsert) -> None: + """Send a single call start to the v2 endpoint.""" + entity, project = from_project_id(start.project_id) + req = tsi.CallStartV2Req(start=start) + self._http.post( + CALL_START_V2_PATH.format(entity=entity, project=project), + content=req.model_dump_json().encode("utf-8"), + headers={"content-type": "application/json"}, + ) + + @with_retry + def _send_call_end_v2(self, end: tsi.EndedCallSchemaForInsertWithStartedAt) -> None: + """Send a single call end to the v2 endpoint.""" + entity, project = from_project_id(end.project_id) + req = tsi.CallEndV2Req(end=end) + self._http.post( + CALL_END_V2_PATH.format(entity=entity, project=project), + content=req.model_dump_json().encode("utf-8"), + headers={"content-type": "application/json"}, + ) + + def _extract_entity_project( + self, batch: list[CompleteBatchItem] + ) -> EntityProjectInfo: + """Extract entity, project, and project_id from first batch item.""" + if not batch: + raise ValueError("Cannot extract entity/project from empty batch") + + first_item = batch[0] + project_id = first_item.req.project_id + + if not project_id or "/" not in project_id: + raise ValueError( + f"Invalid project_id format: {project_id}. Expected 'entity/project'" + ) + + entity, project = project_id.split("/", 1) + if not entity or not project: + raise ValueError(f"Invalid project_id: {project_id}") + + return EntityProjectInfo(entity=entity, project=project, project_id=project_id) + + def _send_calls_complete_to_server( + self, entity: str, project: str, encoded_data: bytes + ) -> None: + """Send a batch of completed calls to the server with retry logic.""" + self._raw_post_bytes( + CALLS_COMPLETE_PATH.format(entity=entity, project=project), encoded_data + ) + + def _flush_calls_complete( + self, + batch: list[CompleteBatchItem], + *, + _should_update_batch_size: bool = True, + ) -> None: + """Process a batch of complete calls and send to the calls/complete endpoint. + + This is the new calls_complete path. Complete calls have both start and + end information bundled together. + """ + assert self.call_processor is not None + if not batch: + return + + ep_info = self._extract_entity_project(batch) + + def get_item_id(item: CompleteBatchItem) -> str: + return f"{item.req.id}-complete" + + def encode_batch(batch: list[CompleteBatchItem]) -> bytes: + api_batch = [item.req for item in batch] + req = tsi.CallsUpsertCompleteReq(batch=api_batch) + return req.model_dump_json().encode("utf-8") + process_batch_with_retry( - batch_name="calls", + batch_name="calls_complete", batch=batch, remote_request_bytes_limit=self.remote_request_bytes_limit, - send_batch_fn=self._send_batch_to_server, + send_batch_fn=lambda data: self._send_calls_complete_to_server( + ep_info.entity, ep_info.project, data + ), processor_obj=self.call_processor, should_update_batch_size=_should_update_batch_size, get_item_id_fn=get_item_id, @@ -330,22 +552,34 @@ def encode_batch(batch: list[StartBatchItem | EndBatchItem]) -> bytes: encode_batch_fn=encode_batch, ) - def get_call_processor(self) -> AsyncBatchProcessor | None: - """Get the call processor for batching. - - Returns: - AsyncBatchProcessor instance or None if batching is disabled. + def get_call_processor(self) -> AsyncBatchProcessor | CallBatchProcessor | None: + """Custom method not defined on the formal TraceServerInterface to expose + the underlying call processor. Should be formalized in a client-side interface. """ return self.call_processor + def _send_feedback_batch_to_server(self, encoded_data: bytes) -> None: + """Send a batch of feedback data to the server. + + No request-level retry here: failures are classified by the caller + (404 falls back to individual creates; retryable errors requeue at the + batch-processor level). + """ + self._http.post( + FEEDBACK_BATCH_CREATE_PATH, + content=encoded_data, + headers={"content-type": "application/json"}, + ) + def _flush_feedback( self, batch: list[tsi.FeedbackCreateReq], ) -> None: """Process a batch of feedback, splitting if necessary and sending to the server. - Args: - batch: List of feedback requests to process. + This method handles the logic of splitting batches that are too large, + but delegates the actual server communication (with retries) to + _send_feedback_batch_to_server. """ assert self.feedback_processor is not None if len(batch) == 0: @@ -360,37 +594,42 @@ def encode_batch(batch: list[tsi.FeedbackCreateReq]) -> bytes: return data.encode("utf-8") def send_feedback_batch(encoded_data: bytes) -> None: - self._update_client_headers() try: - batch_req = tsi.FeedbackCreateBatchReq.model_validate_json( - encoded_data.decode("utf-8") - ) - # Convert to stainless format - stainless_batch = [ - item.model_dump(exclude={"id", "created_at"}, exclude_none=True) - for item in batch_req.batch - ] - self._stainless_client.feedback.batch_create(batch=stainless_batch) - except Exception as e: + self._send_feedback_batch_to_server(encoded_data) + except (httpx.HTTPError, httpx.HTTPStatusError) as e: # If batching endpoint doesn't exist (404) fall back to individual calls - if hasattr(e, "status_code") and e.status_code == 404: + if ( + response := getattr(e, "response", None) + ) and response.status_code == 404: logger.debug( "Batching endpoint not available, falling back to individual feedback creation: %s", e, ) + + # Feedback endpoint doesn't support id, created_at, so we need to strip them + class FeedbackCreateReqStripped(tsi.FeedbackCreateReq): + id: SkipJsonSchema[str] = Field(exclude=True) + created_at: SkipJsonSchema[datetime.datetime | None] = Field( + exclude=True, default=None + ) + # Fall back to individual feedback creation calls - for item in batch_req.batch: + for item in batch: + item_copy = FeedbackCreateReqStripped(**item.model_dump()) try: - item_dict = item.model_dump( - exclude={"id", "created_at"}, exclude_none=True + self._raw_request( + "POST", + FEEDBACK_CREATE_PATH, + req=item_copy, + res_type=tsi.FeedbackCreateRes, ) - self._stainless_client.feedback.create(**item_dict) except Exception as individual_error: logger.warning( "Failed to create individual feedback: %s", individual_error, ) else: + # Re-raise server errors (5xx) as they're not client compatibility issues raise process_batch_with_retry( @@ -406,421 +645,218 @@ def send_feedback_batch(encoded_data: bytes) -> None: ) def get_feedback_processor(self) -> AsyncBatchProcessor | None: - """Get the feedback processor for batching. - - Returns: - AsyncBatchProcessor instance or None if batching is disabled. + """Custom method not defined on the formal TraceServerInterface to expose + the underlying feedback processor. Should be formalized in a client-side interface. """ return self.feedback_processor - def server_info(self) -> ServerInfoRes: - """Get server information. + # ---- service ------------------------------------------------------------ - Returns: - ServerInfoRes with server information. - """ - self._update_client_headers() - response = self._stainless_client.services.server_info() - return ServerInfoRes.model_validate(response.model_dump()) + @with_retry + def server_info(self) -> ServerInfoRes: + res = self._sdk.services.server_info() + return ServerInfoRes.model_validate(res.model_dump()) @validate_call + @with_retry def projects_info(self, req: tsi.ProjectsInfoReq) -> list[tsi.ProjectsInfoRes]: - self._update_client_headers() - response = self._stainless_client.services.projects_info( - project_ids=req.project_ids, - ) - return [ - tsi.ProjectsInfoRes.model_validate(item.model_dump()) for item in response - ] + body = sdk_models.ProjectsInfoReq.model_validate(req.model_dump()) + res = self._sdk.service.create_projects_info(body) + return [tsi.ProjectsInfoRes.model_validate(item.model_dump()) for item in res] - @validate_call def otel_export(self, req: tsi.OTelExportReq) -> tsi.OTelExportRes: - """Export OTEL traces. - - Args: - req: OTEL export request. - - Returns: - OTEL export response. - - Raises: - NotImplementedError: OTEL export is not yet supported. - """ + # TODO: Add docs link (DOCS-1390) raise NotImplementedError("Sending otel traces directly is not yet supported.") - # Call API + # ---- Call API ------------------------------------------------------------- + @validate_call def call_start(self, req: tsi.CallStartReq) -> tsi.CallStartRes: - """Start a call. - - Args: - req: Call start request. - - Returns: - Call start response. - """ if self.should_batch: assert self.call_processor is not None + if req.start.id is None or req.start.trace_id is None: raise ValueError( "CallStartReq must have id and trace_id when batching." ) self.call_processor.enqueue_start(StartBatchItem(req=req)) return tsi.CallStartRes(id=req.start.id, trace_id=req.start.trace_id) - - return self._stainless_request( - req, tsi.CallStartRes, self._stainless_client.calls.start + return self._via_sdk( + req, sdk_models.CallStartReq, self._sdk.calls.start, tsi.CallStartRes ) - @validate_call def call_start_batch(self, req: tsi.CallCreateBatchReq) -> tsi.CallCreateBatchRes: - """Start a batch of calls. - - Args: - req: Batch create request. - - Returns: - Batch create response. - """ - self._update_client_headers() - # Convert to stainless format - stainless_batch = [] - for item in req.batch: - if item.mode == "start": - stainless_batch.append( - { - "mode": "start", - "req": item.req.model_dump(by_alias=True), - } - ) - elif item.mode == "end": - stainless_batch.append( - { - "mode": "end", - "req": item.req.model_dump(by_alias=True), - } - ) - response = self._stainless_client.calls.upsert_batch(batch=stainless_batch) - # Convert response back - res_items = [] - for item in response.batch: - if hasattr(item, "id"): # CallStartRes - res_items.append(tsi.CallStartRes.model_validate(item.model_dump())) - else: # CallEndRes - res_items.append(tsi.CallEndRes()) - return tsi.CallCreateBatchRes(res=res_items) + return self._via_sdk( + req, + sdk_models.CallCreateBatchReq, + self._sdk.calls.upsert_batch, + tsi.CallCreateBatchRes, + ) @validate_call def call_end(self, req: tsi.CallEndReq) -> tsi.CallEndRes: - """End a call. - - Args: - req: Call end request. - - Returns: - Call end response. - """ if self.should_batch: assert self.call_processor is not None + self.call_processor.enqueue([EndBatchItem(req=req)]) return tsi.CallEndRes() - - return self._stainless_request( - req, tsi.CallEndRes, self._stainless_client.calls.end + return self._via_sdk( + req, sdk_models.CallEndReq, self._sdk.calls.end, tsi.CallEndRes ) @validate_call def call_read(self, req: tsi.CallReadReq) -> tsi.CallReadRes: - """Read a call. - - Args: - req: Call read request. - - Returns: - Call read response. - """ - return self._stainless_request( - req, tsi.CallReadRes, self._stainless_client.calls.read + return self._via_sdk( + req, sdk_models.CallReadReq, self._sdk.calls.read, tsi.CallReadRes ) @validate_call def calls_query(self, req: tsi.CallsQueryReq) -> tsi.CallsQueryRes: - """Query calls. - - Args: - req: Calls query request. - - Returns: - Calls query response. - """ + # This previously called the deprecated /calls/query endpoint. return tsi.CallsQueryRes(calls=list(self.calls_query_stream(req))) @validate_call def calls_query_stream(self, req: tsi.CallsQueryReq) -> Iterator[tsi.CallSchema]: - """Stream query calls. - - Args: - req: Calls query request. - - Yields: - CallSchema instances. - """ - self._update_client_headers() - req_dict = req.model_dump(by_alias=True) - # Use stream_query endpoint - response = self._stainless_client.calls.stream_query(**req_dict) - for item in response: - yield tsi.CallSchema.model_validate(item) + return self._raw_stream( + "POST", CALLS_STREAM_QUERY_PATH, req=req, res_type=tsi.CallSchema + ) @validate_call def calls_query_stats(self, req: tsi.CallsQueryStatsReq) -> tsi.CallsQueryStatsRes: - """Query call statistics. - - Args: - req: Calls query stats request. - - Returns: - Calls query stats response. - """ - return self._stainless_request( + return self._via_sdk( req, + sdk_models.CallsQueryStatsReq, + self._sdk.calls.query_stats, tsi.CallsQueryStatsRes, - self._stainless_client.calls.query_stats, ) @validate_call def trace_usage(self, req: tsi.TraceUsageReq) -> tsi.TraceUsageRes: - """Compute per-call usage with descendant rollup. - - Args: - req: Trace usage request. - - Returns: - Trace usage response. - - Examples: - >>> server = StainlessRemoteHTTPTraceServer("http://example.com") - >>> req = tsi.TraceUsageReq(project_id="entity/project") - >>> _ = server.trace_usage(req) # doctest: +SKIP - """ - return self._stainless_request( - req, - tsi.TraceUsageRes, - self._stainless_client.trace.usage, + # SDK 0.0.1: calls.create_usage for /trace/usage is shadowed by the + # /calls/usage overload of the same generated name. + return self._raw_request( + "POST", TRACE_USAGE_PATH, req=req, res_type=tsi.TraceUsageRes ) @validate_call def calls_usage(self, req: tsi.CallsUsageReq) -> tsi.CallsUsageRes: - """Compute aggregated usage for multiple root calls. - - Args: - req: Calls usage request. - - Returns: - Calls usage response. - """ - return self._stainless_request( + return self._via_sdk( req, + sdk_models.CallsUsageReq, + self._sdk.calls.create_usage, tsi.CallsUsageRes, - self._stainless_client.calls.usage, ) @validate_call def calls_delete(self, req: tsi.CallsDeleteReq) -> tsi.CallsDeleteRes: - """Delete calls. - - Args: - req: Calls delete request. - - Returns: - Calls delete response. - """ - return self._stainless_request( - req, - tsi.CallsDeleteRes, - self._stainless_client.calls.delete, + return self._via_sdk( + req, sdk_models.CallsDeleteReq, self._sdk.calls.delete, tsi.CallsDeleteRes ) @validate_call def call_update(self, req: tsi.CallUpdateReq) -> tsi.CallUpdateRes: - """Update a call. - - Args: - req: Call update request. - - Returns: - Call update response. - """ - return self._stainless_request( - req, - tsi.CallUpdateRes, - self._stainless_client.calls.update, + return self._via_sdk( + req, sdk_models.CallUpdateReq, self._sdk.calls.update, tsi.CallUpdateRes ) - # Obj API + # ---- Obj API -------------------------------------------------------------- + @validate_call def obj_create(self, req: tsi.ObjCreateReq) -> tsi.ObjCreateRes: - """Create an object. - - Args: - req: Object create request. - - Returns: - Object create response. - """ - return self._stainless_request( - req, - tsi.ObjCreateRes, - self._stainless_client.objects.create, + return self._via_sdk( + req, sdk_models.ObjCreateReq, self._sdk.objects.create, tsi.ObjCreateRes ) @validate_call def obj_read(self, req: tsi.ObjReadReq) -> tsi.ObjReadRes: - """Read an object. - - Args: - req: Object read request. - - Returns: - Object read response. - """ - return self._stainless_request( - req, tsi.ObjReadRes, self._stainless_client.objects.read + return self._via_sdk( + req, sdk_models.ObjReadReq, self._sdk.objects.read, tsi.ObjReadRes ) @validate_call def objs_query(self, req: tsi.ObjQueryReq) -> tsi.ObjQueryRes: - """Query objects. - - Args: - req: Object query request. - - Returns: - Object query response. - """ - return self._stainless_request( - req, - tsi.ObjQueryRes, - self._stainless_client.objects.query, + return self._via_sdk( + req, sdk_models.ObjQueryReq, self._sdk.objects.query, tsi.ObjQueryRes ) - @validate_call def obj_delete(self, req: tsi.ObjDeleteReq) -> tsi.ObjDeleteRes: - """Delete an object. - - Args: - req: Object delete request. - - Returns: - Object delete response. - """ - return self._stainless_request( - req, - tsi.ObjDeleteRes, - self._stainless_client.objects.delete, + return self._via_sdk( + req, sdk_models.ObjDeleteReq, self._sdk.objects.delete, tsi.ObjDeleteRes ) - # Tag and Alias API - # NOTE: These methods require the Stainless SDK to include tag/alias endpoints. - # Until the SDK spec is updated, these will raise NotImplementedError at call time. def obj_add_tags(self, req: tsi.ObjAddTagsReq) -> tsi.ObjAddTagsRes: - try: - return self._stainless_request( - req, tsi.ObjAddTagsRes, self._stainless_client.objects.tags.add - ) - except AttributeError: - raise NotImplementedError( - "Tag operations are not yet supported by the Stainless SDK. " - "Please upgrade the SDK or use RemoteHTTPTraceServer instead." - ) from None + # SDK 0.0.1: objects.tags for add-tags is shadowed by the /tags list + # overload of the same generated name. + body = sdk_models.ObjTagsBody(project_id=req.project_id, tags=req.tags) + return self._raw_request( + "PUT", + OBJ_ADD_TAGS_PATH.format(object_id=req.object_id, digest=req.digest), + req=body, + res_type=tsi.ObjAddTagsRes, + ) def obj_remove_tags(self, req: tsi.ObjRemoveTagsReq) -> tsi.ObjRemoveTagsRes: - try: - return self._stainless_request( - req, tsi.ObjRemoveTagsRes, self._stainless_client.objects.tags.remove - ) - except AttributeError: - raise NotImplementedError( - "Tag operations are not yet supported by the Stainless SDK. " - "Please upgrade the SDK or use RemoteHTTPTraceServer instead." - ) from None + # SDK 0.0.1: objects.create_remove for remove-tags is shadowed by the + # remove-aliases overload of the same generated name. + body = sdk_models.ObjTagsBody(project_id=req.project_id, tags=req.tags) + return self._raw_request( + "POST", + OBJ_REMOVE_TAGS_PATH.format(object_id=req.object_id, digest=req.digest), + req=body, + res_type=tsi.ObjRemoveTagsRes, + ) def obj_set_aliases(self, req: tsi.ObjSetAliasesReq) -> tsi.ObjSetAliasesRes: - try: - return self._stainless_request( - req, tsi.ObjSetAliasesRes, self._stainless_client.objects.aliases.set - ) - except AttributeError: - raise NotImplementedError( - "Alias operations are not yet supported by the Stainless SDK. " - "Please upgrade the SDK or use RemoteHTTPTraceServer instead." - ) from None + body = sdk_models.ObjSetAliasesBody( + project_id=req.project_id, digest=req.digest, aliases=req.aliases + ) + res = self._via_sdk_no_body( + self._sdk.objects.update_aliases, + tsi.ObjSetAliasesRes, + body=body, + object_id=req.object_id, + ) + return res def obj_remove_aliases( self, req: tsi.ObjRemoveAliasesReq ) -> tsi.ObjRemoveAliasesRes: - try: - return self._stainless_request( - req, - tsi.ObjRemoveAliasesRes, - self._stainless_client.objects.aliases.remove, - ) - except AttributeError: - raise NotImplementedError( - "Alias operations are not yet supported by the Stainless SDK. " - "Please upgrade the SDK or use RemoteHTTPTraceServer instead." - ) from None + body = sdk_models.ObjRemoveAliasesBody( + project_id=req.project_id, aliases=req.aliases + ) + return self._via_sdk_no_body( + self._sdk.objects.create_remove, + tsi.ObjRemoveAliasesRes, + body=body, + object_id=req.object_id, + ) def tags_list(self, req: tsi.TagsListReq) -> tsi.TagsListRes: - try: - return self._stainless_request( - req, tsi.TagsListRes, self._stainless_client.objects.tags.list - ) - except AttributeError: - raise NotImplementedError( - "Tag operations are not yet supported by the Stainless SDK. " - "Please upgrade the SDK or use RemoteHTTPTraceServer instead." - ) from None + return self._via_sdk_no_body( + self._sdk.objects.tags, tsi.TagsListRes, project_id=req.project_id + ) def aliases_list(self, req: tsi.AliasesListReq) -> tsi.AliasesListRes: - try: - return self._stainless_request( - req, tsi.AliasesListRes, self._stainless_client.objects.aliases.list - ) - except AttributeError: - raise NotImplementedError( - "Alias operations are not yet supported by the Stainless SDK. " - "Please upgrade the SDK or use RemoteHTTPTraceServer instead." - ) from None + return self._via_sdk_no_body( + self._sdk.objects.list_aliases, + tsi.AliasesListRes, + project_id=req.project_id, + ) + + # ---- Table API ------------------------------------------------------------ - # Table API @validate_call def table_create(self, req: tsi.TableCreateReq) -> tsi.TableCreateRes: - """Create a table. - - Args: - req: Table create request. - - Returns: - Table create response. - """ - return self._stainless_request( - req, - tsi.TableCreateRes, - self._stainless_client.tables.create, + return self._via_sdk( + req, sdk_models.TableCreateReq, self._sdk.tables.create, tsi.TableCreateRes ) @validate_call def table_update(self, req: tsi.TableUpdateReq) -> tsi.TableUpdateRes: - """Update a table. - - Args: - req: Table update request. - - Returns: - Table update response. + """Similar to `calls/batch_upsert`, we can dynamically adjust the payload size + due to the property that table updates can be decomposed into a series of + updates. """ - # Handle large requests by splitting estimated_bytes = len(req.model_dump_json(by_alias=True).encode("utf-8")) if estimated_bytes > self.remote_request_bytes_limit and len(req.updates) > 1: split_ndx = len(req.updates) // 2 @@ -843,188 +879,110 @@ def table_update(self, req: tsi.TableUpdateReq) -> tsi.TableUpdateRes: digest=second_half_res.digest, updated_row_digests=all_digests ) else: - return self._stainless_request( + return self._via_sdk( req, + sdk_models.TableUpdateReq, + self._sdk.tables.update, tsi.TableUpdateRes, - self._stainless_client.tables.update, ) @validate_call def table_query(self, req: tsi.TableQueryReq) -> tsi.TableQueryRes: - """Query a table. - - Args: - req: Table query request. - - Returns: - Table query response. - """ - return self._stainless_request( - req, - tsi.TableQueryRes, - self._stainless_client.tables.query, + return self._via_sdk( + req, sdk_models.TableQueryReq, self._sdk.tables.query, tsi.TableQueryRes ) @validate_call def table_query_stream( self, req: tsi.TableQueryReq ) -> Iterator[tsi.TableRowSchema]: - """Stream query a table. - - Args: - req: Table query request. - - Yields: - TableRowSchema instances. - """ # Need to manually iterate over this until the stream endpoint is built and shipped. res = self.table_query(req) yield from res.rows @validate_call def table_query_stats(self, req: tsi.TableQueryStatsReq) -> tsi.TableQueryStatsRes: - """Query table statistics. - - Args: - req: Table query stats request. - - Returns: - Table query stats response. - """ - return self._stainless_request( + return self._via_sdk( req, + sdk_models.TableQueryStatsReq, + self._sdk.tables.query_stats, tsi.TableQueryStatsRes, - self._stainless_client.tables.query_stats, ) @validate_call def table_create_from_digests( self, req: tsi.TableCreateFromDigestsReq ) -> tsi.TableCreateFromDigestsRes: - """Create a table from digests. - - Args: - req: Table create from digests request. - - Returns: - Table create from digests response. - """ - return self._stainless_request( + """Create a table by specifying row digests instead of actual rows.""" + return self._via_sdk( req, + sdk_models.TableCreateFromDigestsReq, + self._sdk.tables.create_create_from_digests, tsi.TableCreateFromDigestsRes, - self._stainless_client.tables.create_from_digests, ) @validate_call def table_query_stats_batch( - self, req: tsi.TableQueryStatsBatchReq - ) -> tsi.TableQueryStatsBatchRes: - """Query table statistics in batch. - - Args: - req: Table query stats batch request. - - Returns: - Table query stats batch response. - """ - return self._stainless_request( + self, req: tsi.TableQueryStatsReq + ) -> tsi.TableQueryStatsRes: + return self._via_sdk( req, + sdk_models.TableQueryStatsBatchReq, + self._sdk.tables.create_query_stats_batch, tsi.TableQueryStatsBatchRes, - self._stainless_client.tables.query_stats_batch, ) @validate_call def refs_read_batch(self, req: tsi.RefsReadBatchReq) -> tsi.RefsReadBatchRes: - """Read refs in batch. - - Args: - req: Refs read batch request. - - Returns: - Refs read batch response. - """ - return self._stainless_request( + return self._via_sdk( req, + sdk_models.RefsReadBatchReq, + self._sdk.refs.read_batch, tsi.RefsReadBatchRes, - self._stainless_client.refs.read_batch, ) - @validate_call - def file_create(self, req: tsi.FileCreateReq) -> tsi.FileCreateRes: - """Create a file. - - Args: - req: File create request. + # ---- File API ------------------------------------------------------------- - Returns: - File create response. - """ - self._update_client_headers() - # Files API uses multipart/form-data - stainless expects (filename, content) tuple - file_tuple = (req.name, req.content) - kwargs: dict[str, Any] = { - "file": file_tuple, - "project_id": req.project_id, - } + @with_retry + def file_create(self, req: tsi.FileCreateReq) -> tsi.FileCreateRes: + # SDK 0.0.1: files.create lost its multipart body in generation; post + # the multipart form directly. + data: dict[str, str] = {"project_id": req.project_id} if req.expected_digest is not None: - kwargs["expected_digest"] = req.expected_digest - response = self._stainless_client.files.create(**kwargs) - return tsi.FileCreateRes.model_validate(response.model_dump()) + data["expected_digest"] = req.expected_digest + r = self._http.post( + FILE_CREATE_PATH, + data=data, + files={"file": (req.name, req.content)}, + ) + return tsi.FileCreateRes.model_validate(r.json()) - @validate_call + @with_retry def file_content_read(self, req: tsi.FileContentReadReq) -> tsi.FileContentReadRes: - """Read file content. - - Args: - req: File content read request. - - Returns: - File content read response. - """ - self._update_client_headers() - response = self._stainless_client.files.content( - digest=req.digest, project_id=req.project_id - ) - # TODO: Should stream to disk rather than to memory - bytes_content = io.BytesIO() - # BinaryAPIResponse has content property or we can read it directly - if hasattr(response, "content"): - bytes_content.write(response.content) - elif hasattr(response, "iter_bytes"): - for chunk in response.iter_bytes(): - bytes_content.write(chunk) - else: - # Fallback: read from raw response - bytes_content.write(response.read()) - bytes_content.seek(0) - return tsi.FileContentReadRes(content=bytes_content.read()) + # Raw to keep the response streamed to a buffer rather than the SDK's + # whole-body bytes return. + r = self._open_stream("POST", FILE_CONTENT_PATH, req=req) + try: + # TODO: Should stream to disk rather than to memory + bytes_buffer = io.BytesIO() + for chunk in r.iter_bytes(): + bytes_buffer.write(chunk) + finally: + r.close() + return tsi.FileContentReadRes(content=bytes_buffer.getvalue()) - @validate_call def files_stats(self, req: tsi.FilesStatsReq) -> tsi.FilesStatsRes: - """Get file statistics. - - Args: - req: Files stats request. - - Returns: - Files stats response. - """ - return self._stainless_request( + return self._via_sdk( req, + sdk_models.FilesStatsReq, + self._sdk.files.query_stats, tsi.FilesStatsRes, - self._stainless_client.files.stats, ) + # ---- Feedback API ----------------------------------------------------------- + @validate_call def feedback_create(self, req: tsi.FeedbackCreateReq) -> tsi.FeedbackCreateRes: - """Create feedback. - - Args: - req: Feedback create request. - - Returns: - Feedback create response. - """ if self.should_batch: assert self.feedback_processor is not None @@ -1040,981 +998,775 @@ def feedback_create(self, req: tsi.FeedbackCreateReq) -> tsi.FeedbackCreateRes: payload=req.payload, ) else: - self._update_client_headers() - req_dict = req.model_dump( - exclude={"id", "created_at"}, exclude_none=True, by_alias=True + # SDK 0.0.1: feedback.create for single create is shadowed by the + # batch-create overload of the same generated name. + return self._raw_request( + "POST", FEEDBACK_CREATE_PATH, req=req, res_type=tsi.FeedbackCreateRes ) - response = self._stainless_client.feedback.create(**req_dict) - return tsi.FeedbackCreateRes.model_validate(response.model_dump()) - @validate_call def feedback_create_batch( self, req: tsi.FeedbackCreateBatchReq ) -> tsi.FeedbackCreateBatchRes: - """Create feedback in batch. - - Args: - req: Feedback create batch request. - - Returns: - Feedback create batch response. - """ - self._update_client_headers() - # Convert to stainless format - stainless_batch = [ - item.model_dump( - exclude={"id", "created_at"}, exclude_none=True, by_alias=True - ) - for item in req.batch - ] - response = self._stainless_client.feedback.batch_create(batch=stainless_batch) - return tsi.FeedbackCreateBatchRes.model_validate(response.model_dump()) + # Note: the SDK method is named `create` for /feedback/batch/create + # (duplicate-name shadowing in 0.0.1; the batch overload won). + return self._via_sdk( + req, + sdk_models.FeedbackCreateBatchReq, + self._sdk.feedback.create, + tsi.FeedbackCreateBatchRes, + ) @validate_call def feedback_query(self, req: tsi.FeedbackQueryReq) -> tsi.FeedbackQueryRes: - """Query feedback. - - Args: - req: Feedback query request. - - Returns: - Feedback query response. - """ - return self._stainless_request( + return self._via_sdk( req, + sdk_models.FeedbackQueryReq, + self._sdk.feedback.query, tsi.FeedbackQueryRes, - self._stainless_client.feedback.query, ) @validate_call def feedback_purge(self, req: tsi.FeedbackPurgeReq) -> tsi.FeedbackPurgeRes: - """Purge feedback. - - Args: - req: Feedback purge request. - - Returns: - Feedback purge response. - """ - return self._stainless_request( + return self._via_sdk( req, + sdk_models.FeedbackPurgeReq, + self._sdk.feedback.purge, tsi.FeedbackPurgeRes, - self._stainless_client.feedback.purge, ) @validate_call def feedback_replace(self, req: tsi.FeedbackReplaceReq) -> tsi.FeedbackReplaceRes: - """Replace feedback. - - Args: - req: Feedback replace request. - - Returns: - Feedback replace response. - """ - return self._stainless_request( + return self._via_sdk( req, + sdk_models.FeedbackReplaceReq, + self._sdk.feedback.replace, tsi.FeedbackReplaceRes, - self._stainless_client.feedback.replace, ) - # Cost API @validate_call - def cost_query(self, req: tsi.CostQueryReq) -> tsi.CostQueryRes: - """Query costs. - - Args: - req: Cost query request. - - Returns: - Cost query response. - """ - return self._stainless_request( + def feedback_stats(self, req: tsi.FeedbackStatsReq) -> tsi.FeedbackStatsRes: + return self._via_sdk( req, - tsi.CostQueryRes, - self._stainless_client.costs.query, + sdk_models.FeedbackStatsReq, + self._sdk.feedback.create_stats, + tsi.FeedbackStatsRes, ) @validate_call - def cost_create(self, req: tsi.CostCreateReq) -> tsi.CostCreateRes: - """Create cost. - - Args: - req: Cost create request. + def feedback_aggregate( + self, req: tsi.FeedbackAggregateReq + ) -> tsi.FeedbackAggregateRes: + """Query the feedback table for aggregate scores over time.""" + # Not yet present in the published SDK. + return self._raw_request( + "POST", FEEDBACK_AGGREGATE_PATH, req=req, res_type=tsi.FeedbackAggregateRes + ) - Returns: - Cost create response. - """ - return self._stainless_request( + @validate_call + def feedback_payload_schema( + self, req: tsi.FeedbackPayloadSchemaReq + ) -> tsi.FeedbackPayloadSchemaRes: + return self._via_sdk( req, - tsi.CostCreateRes, - self._stainless_client.costs.create, + sdk_models.FeedbackPayloadSchemaReq, + self._sdk.feedback.create_payload_schema, + tsi.FeedbackPayloadSchemaRes, ) - @validate_call - def cost_purge(self, req: tsi.CostPurgeReq) -> tsi.CostPurgeRes: - """Purge costs. + # ---- Cost API --------------------------------------------------------------- - Args: - req: Cost purge request. + @validate_call + def cost_query(self, req: tsi.CostQueryReq) -> tsi.CostQueryRes: + return self._via_sdk( + req, sdk_models.CostQueryReq, self._sdk.costs.query, tsi.CostQueryRes + ) - Returns: - Cost purge response. - """ - return self._stainless_request( - req, - tsi.CostPurgeRes, - self._stainless_client.costs.purge, + @validate_call + def cost_create(self, req: tsi.CostCreateReq) -> tsi.CostCreateRes: + return self._via_sdk( + req, sdk_models.CostCreateReq, self._sdk.costs.create, tsi.CostCreateRes ) @validate_call + def cost_purge(self, req: tsi.CostPurgeReq) -> tsi.CostPurgeRes: + return self._via_sdk( + req, sdk_models.CostPurgeReq, self._sdk.costs.purge, tsi.CostPurgeRes + ) + + # ---- Execution APIs --------------------------------------------------------- + def completions_create( self, req: tsi.CompletionsCreateReq ) -> tsi.CompletionsCreateRes: - """Create completion. - - Args: - req: Completions create request. - - Returns: - Completions create response. - """ - return self._stainless_request( - req, - tsi.CompletionsCreateRes, - self._stainless_client.completions.create, + # Excluded from the OpenAPI spec (include_in_schema=False). + return self._raw_request( + "POST", COMPLETIONS_CREATE_PATH, req=req, res_type=tsi.CompletionsCreateRes ) - @validate_call def completions_create_stream( self, req: tsi.CompletionsCreateReq ) -> Iterator[dict[str, Any]]: - """Create completion stream. - - Args: - req: Completions create request. - - Yields: - Dictionary chunks of the streamed response. - """ # For remote servers, streaming is not implemented # Fall back to non-streaming completion response = self.completions_create(req) yield {"response": response.response, "weave_call_id": response.weave_call_id} - @validate_call def image_create( self, req: tsi.ImageGenerationCreateReq ) -> tsi.ImageGenerationCreateRes: - """Create image generation. - - Args: - req: Image generation create request. - - Returns: - Image generation create response. - """ - # Image generation may not be in stainless client yet - raise NotImplementedError( - "Image generation not yet implemented in stainless client" + return self._via_sdk( + req, + sdk_models.ImageGenerationCreateReq, + self._sdk.images.create, + tsi.ImageGenerationCreateRes, ) - @validate_call def project_stats(self, req: tsi.ProjectStatsReq) -> tsi.ProjectStatsRes: - """Get project statistics. - - Args: - req: Project stats request. - - Returns: - Project stats response. - """ - return self._stainless_request( - req, - tsi.ProjectStatsRes, - self._stainless_client.services.project_stats, + # Excluded from the OpenAPI spec (include_in_schema=False). + return self._raw_request( + "POST", PROJECT_STATS_PATH, req=req, res_type=tsi.ProjectStatsRes ) - @validate_call def project_ttl_settings_read( self, req: tsi.ProjectTTLSettingsReadReq ) -> tsi.ProjectTTLSettingsReadRes: - raise NotImplementedError( - "project_ttl_settings_read is not yet implemented in stainless client" + # Excluded from the OpenAPI spec (include_in_schema=False). + return self._raw_request( + "POST", + PROJECT_TTL_SETTINGS_READ_PATH, + req=req, + res_type=tsi.ProjectTTLSettingsReadRes, ) - @validate_call def project_ttl_settings_update( self, req: tsi.ProjectTTLSettingsUpdateReq ) -> tsi.ProjectTTLSettingsUpdateRes: - raise NotImplementedError( - "project_ttl_settings_update is not yet implemented in stainless client" + # Excluded from the OpenAPI spec (include_in_schema=False). + return self._raw_request( + "POST", + PROJECT_TTL_SETTINGS_UPDATE_PATH, + req=req, + res_type=tsi.ProjectTTLSettingsUpdateRes, ) - @validate_call def threads_query_stream( self, req: tsi.ThreadsQueryReq ) -> Iterator[tsi.ThreadSchema]: - """Stream query threads. - - Args: - req: Threads query request. + return self._raw_stream( + "POST", THREADS_STREAM_QUERY_PATH, req=req, res_type=tsi.ThreadSchema + ) - Yields: - ThreadSchema instances. - """ - self._update_client_headers() - req_dict = req.model_dump(by_alias=True) - response = self._stainless_client.threads.stream_query(**req_dict) - for item in response: - yield tsi.ThreadSchema.model_validate(item) + # ---- Annotation Queue API ----------------------------------------------------- - @validate_call - def evaluate_model(self, req: tsi.EvaluateModelReq) -> tsi.EvaluateModelRes: - """Evaluate model. + def annotation_queue_create( + self, req: tsi.AnnotationQueueCreateReq + ) -> tsi.AnnotationQueueCreateRes: + return self._via_sdk( + req, + sdk_models.AnnotationQueueCreateReq, + self._sdk.annotation_queues.create_annotation_queues, + tsi.AnnotationQueueCreateRes, + ) + + def annotation_queues_query_stream( + self, req: tsi.AnnotationQueuesQueryReq + ) -> Iterator[tsi.AnnotationQueueSchema]: + return self._raw_stream( + "POST", + ANNOTATION_QUEUES_QUERY_PATH, + req=req, + res_type=tsi.AnnotationQueueSchema, + ) + + def annotation_queue_read( + self, req: tsi.AnnotationQueueReadReq + ) -> tsi.AnnotationQueueReadRes: + return self._via_sdk_no_body( + self._sdk.annotation_queues.list_annotation_queues, + tsi.AnnotationQueueReadRes, + queue_id=req.queue_id, + project_id=req.project_id, + ) + + def annotation_queue_delete( + self, req: tsi.AnnotationQueueDeleteReq + ) -> tsi.AnnotationQueueDeleteRes: + return self._via_sdk_no_body( + self._sdk.annotation_queues.delete_annotation_queues, + tsi.AnnotationQueueDeleteRes, + queue_id=req.queue_id, + project_id=req.project_id, + ) + + def annotation_queue_update( + self, req: tsi.AnnotationQueueUpdateReq + ) -> tsi.AnnotationQueueUpdateRes: + # Body type excludes queue_id from the request body (it's in the URL path) + body = sdk_models.AnnotationQueueUpdateBody( + project_id=req.project_id, + name=req.name, + description=req.description, + scorer_refs=req.scorer_refs, + ) + return self._via_sdk_no_body( + self._sdk.annotation_queues.update_annotation_queues, + tsi.AnnotationQueueUpdateRes, + body=body, + queue_id=req.queue_id, + ) + + def annotation_queue_add_calls( + self, req: tsi.AnnotationQueueAddCallsReq + ) -> tsi.AnnotationQueueAddCallsRes: + # Body type excludes queue_id from the request body (it's in the URL path) + body = sdk_models.AnnotationQueueAddCallsBody( + project_id=req.project_id, + call_ids=req.call_ids, + display_fields=req.display_fields, + ) + return self._via_sdk_no_body( + self._sdk.annotation_queues.create_items, + tsi.AnnotationQueueAddCallsRes, + body=body, + queue_id=req.queue_id, + ) + + def annotation_queue_items_query( + self, req: tsi.AnnotationQueueItemsQueryReq + ) -> tsi.AnnotationQueueItemsQueryRes: + # Body type excludes queue_id from the request body (it's in the URL path) + body = sdk_models.AnnotationQueueItemsQueryBody.model_validate( + req.model_dump(exclude={"queue_id"}, by_alias=True) + ) + return self._via_sdk_no_body( + self._sdk.annotation_queues.query, + tsi.AnnotationQueueItemsQueryRes, + body=body, + queue_id=req.queue_id, + ) + + def annotation_queues_stats( + self, req: tsi.AnnotationQueuesStatsReq + ) -> tsi.AnnotationQueuesStatsRes: + return self._via_sdk( + req, + sdk_models.AnnotationQueuesStatsReq, + self._sdk.annotation_queues.create_stats, + tsi.AnnotationQueuesStatsRes, + ) - Args: - req: Evaluate model request. + def annotator_queue_items_progress_update( + self, req: tsi.AnnotatorQueueItemsProgressUpdateReq + ) -> tsi.AnnotatorQueueItemsProgressUpdateRes: + # Body type excludes queue_id, item_id, and wb_user_id from the request + # body (queue_id and item_id are in the URL path, wb_user_id is set + # server-side from auth) + body = sdk_models.AnnotationQueueItemProgressUpdateBody( + project_id=req.project_id, + annotation_state=req.annotation_state, + ) + return self._via_sdk_no_body( + self._sdk.annotation_queues.create_progress, + tsi.AnnotatorQueueItemsProgressUpdateRes, + body=body, + queue_id=req.queue_id, + item_id=req.item_id, + ) - Returns: - Evaluate model response. + # ---- Server-side execution (not supported remotely) ---------------------------- - Raises: - NotImplementedError: Not implemented. - """ + def evaluate_model(self, req: tsi.EvaluateModelReq) -> tsi.EvaluateModelRes: raise NotImplementedError("evaluate_model is not implemented") - @validate_call def evaluation_status( self, req: tsi.EvaluationStatusReq ) -> tsi.EvaluationStatusRes: - """Get evaluation status. - - Args: - req: Evaluation status request. - - Returns: - Evaluation status response. - - Raises: - NotImplementedError: Not implemented. - """ raise NotImplementedError("evaluation_status is not implemented") - @validate_call - def calls_score(self, req: tsi.CallsScoreReq) -> tsi.CallsScoreRes: - """Score calls. + def rescore(self, req: tsi.RescoreReq) -> tsi.RescoreRes: + raise NotImplementedError("rescore is not implemented") - Args: - req: Calls score request. + def calls_score(self, req: tsi.CallsScoreReq) -> tsi.CallsScoreRes: + raise NotImplementedError("calls_score is not implemented") - Returns: - Calls score response. + # ---- V2 APIs -------------------------------------------------------------- - Raises: - NotImplementedError: Not implemented. - """ - raise NotImplementedError("calls_score is not implemented") + def _v2_body_create( + self, + req: BaseModel, + body_type: type[BaseModel], + sdk_method: Callable[..., Any], + res_type: type[TRes], + ) -> TRes: + """Create via a v2 endpoint: project_id moves to the path; the rest is body.""" + entity, project = from_project_id(req.project_id) # type: ignore[attr-defined] + body = body_type.model_validate(req.model_dump(exclude={"project_id"})) + return self._via_sdk_no_body( + sdk_method, res_type, body=body, entity=entity, project=project + ) - # === Object APIs === + def _v2_list_stream( + self, + req: BaseModel, + res_type: type[TRes], + kind: str, + params: dict[str, Any], + ) -> Iterator[TRes]: + """Stream a v2 jsonl list endpoint (the SDK buffers jsonl responses).""" + entity, project = from_project_id(req.project_id) # type: ignore[attr-defined] + url = f"/v2/{entity}/{project}/{kind}" + return self._raw_stream("GET", url, params=params, res_type=res_type) @validate_call def op_create(self, req: tsi.OpCreateReq) -> tsi.OpCreateRes: - """Create op. - - Args: - req: Op create request. - - Returns: - Op create response. - """ - entity, project = self._prepare_v2_request(req) - response = self._stainless_client.v2.ops.create( - entity=entity, - project=project, - name=req.name, - source_code=req.source_code, + return self._v2_body_create( + req, sdk_models.OpCreateBody, self._sdk.v2_ops.create, tsi.OpCreateRes ) - return tsi.OpCreateRes.model_validate(response.model_dump()) @validate_call def op_read(self, req: tsi.OpReadReq) -> tsi.OpReadRes: - """Read op. - - Args: - req: Op read request. - - Returns: - Op read response. - """ - entity, project = self._prepare_v2_request(req) - response = self._stainless_client.v2.ops.read( + entity, project = from_project_id(req.project_id) + return self._via_sdk_no_body( + self._sdk.v2_ops.read, + tsi.OpReadRes, entity=entity, project=project, object_id=req.object_id, digest=req.digest, ) - return tsi.OpReadRes.model_validate(response.model_dump()) @validate_call def op_list(self, req: tsi.OpListReq) -> Iterator[tsi.OpReadRes]: - """List ops. - - Args: - req: Op list request. - - Yields: - OpReadRes instances. - """ - entity, project = self._prepare_v2_request(req) - response = self._stainless_client.v2.ops.list( - entity=entity, - project=project, - limit=req.limit, - offset=req.offset, - ) - for item in response: - yield tsi.OpReadRes.model_validate(item) + params: dict[str, Any] = {} + if req.limit is not None: + params["limit"] = req.limit + if req.offset is not None: + params["offset"] = req.offset + # `eager` is missing from the SDK's generated v2_ops.list signature. + if req.eager: + params["eager"] = "true" + return self._v2_list_stream(req, tsi.OpReadRes, "ops", params) @validate_call def op_delete(self, req: tsi.OpDeleteReq) -> tsi.OpDeleteRes: - """Delete op. - - Args: - req: Op delete request. - - Returns: - Op delete response. - """ - entity, project = self._prepare_v2_request(req) - response = self._stainless_client.v2.ops.delete( + entity, project = from_project_id(req.project_id) + return self._via_sdk_no_body( + self._sdk.v2_ops.delete, + tsi.OpDeleteRes, entity=entity, project=project, object_id=req.object_id, + digests=req.digests, ) - return tsi.OpDeleteRes.model_validate(response.model_dump()) @validate_call def dataset_create(self, req: tsi.DatasetCreateReq) -> tsi.DatasetCreateRes: - """Create dataset. - - Args: - req: Dataset create request. - - Returns: - Dataset create response. - """ - entity, project = self._prepare_v2_request(req) - response = self._stainless_client.v2.datasets.create( - entity=entity, - project=project, - rows=req.rows, - description=req.description, - name=req.name, + return self._v2_body_create( + req, + sdk_models.DatasetCreateBody, + self._sdk.v2_datasets.create, + tsi.DatasetCreateRes, ) - return tsi.DatasetCreateRes.model_validate(response.model_dump()) @validate_call def dataset_read(self, req: tsi.DatasetReadReq) -> tsi.DatasetReadRes: - """Read dataset. - - Args: - req: Dataset read request. - - Returns: - Dataset read response. - """ - entity, project = self._prepare_v2_request(req) - response = self._stainless_client.v2.datasets.read( + entity, project = from_project_id(req.project_id) + return self._via_sdk_no_body( + self._sdk.v2_datasets.read, + tsi.DatasetReadRes, entity=entity, project=project, object_id=req.object_id, digest=req.digest, ) - return tsi.DatasetReadRes.model_validate(response.model_dump()) @validate_call def dataset_list(self, req: tsi.DatasetListReq) -> Iterator[tsi.DatasetReadRes]: - """List datasets. - - Args: - req: Dataset list request. - - Yields: - DatasetReadRes instances. - """ - entity, project = self._prepare_v2_request(req) - response = self._stainless_client.v2.datasets.list( - entity=entity, - project=project, - limit=req.limit, - offset=req.offset, - ) - for item in response: - yield tsi.DatasetReadRes.model_validate(item) + params: dict[str, Any] = {} + if req.limit is not None: + params["limit"] = req.limit + if req.offset is not None: + params["offset"] = req.offset + return self._v2_list_stream(req, tsi.DatasetReadRes, "datasets", params) @validate_call def dataset_delete(self, req: tsi.DatasetDeleteReq) -> tsi.DatasetDeleteRes: - """Delete dataset. - - Args: - req: Dataset delete request. - - Returns: - Dataset delete response. - """ - entity, project = self._prepare_v2_request(req) - response = self._stainless_client.v2.datasets.delete( + entity, project = from_project_id(req.project_id) + return self._via_sdk_no_body( + self._sdk.v2_datasets.delete, + tsi.DatasetDeleteRes, entity=entity, project=project, object_id=req.object_id, + digests=req.digests, ) - return tsi.DatasetDeleteRes.model_validate(response.model_dump()) @validate_call def scorer_create(self, req: tsi.ScorerCreateReq) -> tsi.ScorerCreateRes: - """Create scorer. - - Args: - req: Scorer create request. - - Returns: - Scorer create response. - """ - entity, project = self._prepare_v2_request(req) - response = self._stainless_client.v2.scorers.create( - entity=entity, - project=project, - name=req.name, - op_source_code=req.op_source_code, - description=req.description, + return self._v2_body_create( + req, + sdk_models.ScorerCreateBody, + self._sdk.v2_scorers.create, + tsi.ScorerCreateRes, ) - return tsi.ScorerCreateRes.model_validate(response.model_dump()) @validate_call def scorer_read(self, req: tsi.ScorerReadReq) -> tsi.ScorerReadRes: - """Read scorer. - - Args: - req: Scorer read request. - - Returns: - Scorer read response. - """ - entity, project = self._prepare_v2_request(req) - response = self._stainless_client.v2.scorers.read( + entity, project = from_project_id(req.project_id) + return self._via_sdk_no_body( + self._sdk.v2_scorers.read, + tsi.ScorerReadRes, entity=entity, project=project, object_id=req.object_id, digest=req.digest, ) - return tsi.ScorerReadRes.model_validate(response.model_dump()) @validate_call def scorer_list(self, req: tsi.ScorerListReq) -> Iterator[tsi.ScorerReadRes]: - """List scorers. - - Args: - req: Scorer list request. - - Yields: - ScorerReadRes instances. - """ - entity, project = self._prepare_v2_request(req) - response = self._stainless_client.v2.scorers.list( - entity=entity, - project=project, - limit=req.limit, - offset=req.offset, - ) - for item in response: - yield tsi.ScorerReadRes.model_validate(item) + params: dict[str, Any] = {} + if req.limit is not None: + params["limit"] = req.limit + if req.offset is not None: + params["offset"] = req.offset + return self._v2_list_stream(req, tsi.ScorerReadRes, "scorers", params) @validate_call def scorer_delete(self, req: tsi.ScorerDeleteReq) -> tsi.ScorerDeleteRes: - """Delete scorer. - - Args: - req: Scorer delete request. - - Returns: - Scorer delete response. - """ - entity, project = self._prepare_v2_request(req) - response = self._stainless_client.v2.scorers.delete( + entity, project = from_project_id(req.project_id) + return self._via_sdk_no_body( + self._sdk.v2_scorers.delete, + tsi.ScorerDeleteRes, entity=entity, project=project, object_id=req.object_id, + digests=req.digests, ) - return tsi.ScorerDeleteRes.model_validate(response.model_dump()) @validate_call def evaluation_create( self, req: tsi.EvaluationCreateReq ) -> tsi.EvaluationCreateRes: - """Create evaluation. - - Args: - req: Evaluation create request. - - Returns: - Evaluation create response. - """ - entity, project = self._prepare_v2_request(req) - response = self._stainless_client.v2.evaluations.create( - entity=entity, - project=project, - dataset=req.dataset, - name=req.name, - description=req.description, - scorers=req.scorers, - trials=req.trials, + return self._v2_body_create( + req, + sdk_models.EvaluationCreateBody, + self._sdk.v2_evaluations.create, + tsi.EvaluationCreateRes, ) - return tsi.EvaluationCreateRes.model_validate(response.model_dump()) @validate_call def evaluation_read(self, req: tsi.EvaluationReadReq) -> tsi.EvaluationReadRes: - """Read evaluation. - - Args: - req: Evaluation read request. - - Returns: - Evaluation read response. - """ - entity, project = self._prepare_v2_request(req) - response = self._stainless_client.v2.evaluations.read( + entity, project = from_project_id(req.project_id) + return self._via_sdk_no_body( + self._sdk.v2_evaluations.read, + tsi.EvaluationReadRes, entity=entity, project=project, object_id=req.object_id, digest=req.digest, ) - return tsi.EvaluationReadRes.model_validate(response.model_dump()) @validate_call def evaluation_list( self, req: tsi.EvaluationListReq ) -> Iterator[tsi.EvaluationReadRes]: - """List evaluations. - - Args: - req: Evaluation list request. - - Yields: - EvaluationReadRes instances. - """ - entity, project = self._prepare_v2_request(req) - response = self._stainless_client.v2.evaluations.list( - entity=entity, - project=project, - limit=req.limit, - offset=req.offset, - ) - for item in response: - yield tsi.EvaluationReadRes.model_validate(item) + params: dict[str, Any] = {} + if req.limit is not None: + params["limit"] = req.limit + if req.offset is not None: + params["offset"] = req.offset + return self._v2_list_stream(req, tsi.EvaluationReadRes, "evaluations", params) @validate_call def evaluation_delete( self, req: tsi.EvaluationDeleteReq ) -> tsi.EvaluationDeleteRes: - """Delete evaluation. - - Args: - req: Evaluation delete request. - - Returns: - Evaluation delete response. - """ - entity, project = self._prepare_v2_request(req) - response = self._stainless_client.v2.evaluations.delete( + entity, project = from_project_id(req.project_id) + return self._via_sdk_no_body( + self._sdk.v2_evaluations.delete, + tsi.EvaluationDeleteRes, entity=entity, project=project, object_id=req.object_id, + digests=req.digests, ) - return tsi.EvaluationDeleteRes.model_validate(response.model_dump()) + + # ---- Model V2 API ----------------------------------------------------------- @validate_call def model_create(self, req: tsi.ModelCreateReq) -> tsi.ModelCreateRes: - """Create model. - - Args: - req: Model create request. - - Returns: - Model create response. - """ - entity, project = self._prepare_v2_request(req) - response = self._stainless_client.v2.models.create( - entity=entity, - project=project, - name=req.name, - source_code=req.source_code, - attributes=req.attributes, - description=req.description, + return self._v2_body_create( + req, + sdk_models.ModelCreateBody, + self._sdk.v2_models.create, + tsi.ModelCreateRes, ) - return tsi.ModelCreateRes.model_validate(response.model_dump()) @validate_call def model_read(self, req: tsi.ModelReadReq) -> tsi.ModelReadRes: - """Read model. - - Args: - req: Model read request. - - Returns: - Model read response. - """ - entity, project = self._prepare_v2_request(req) - response = self._stainless_client.v2.models.read( + entity, project = from_project_id(req.project_id) + return self._via_sdk_no_body( + self._sdk.v2_models.read, + tsi.ModelReadRes, entity=entity, project=project, object_id=req.object_id, digest=req.digest, ) - return tsi.ModelReadRes.model_validate(response.model_dump()) @validate_call def model_list(self, req: tsi.ModelListReq) -> Iterator[tsi.ModelReadRes]: - """List models. - - Args: - req: Model list request. - - Yields: - ModelReadRes instances. - """ - entity, project = self._prepare_v2_request(req) - response = self._stainless_client.v2.models.list( - entity=entity, - project=project, - limit=req.limit, - offset=req.offset, - ) - for item in response: - yield tsi.ModelReadRes.model_validate(item) + params: dict[str, Any] = {} + if req.limit is not None: + params["limit"] = req.limit + if req.offset is not None: + params["offset"] = req.offset + return self._v2_list_stream(req, tsi.ModelReadRes, "models", params) @validate_call def model_delete(self, req: tsi.ModelDeleteReq) -> tsi.ModelDeleteRes: - """Delete model. - - Args: - req: Model delete request. - - Returns: - Model delete response. - """ - entity, project = self._prepare_v2_request(req) - response = self._stainless_client.v2.models.delete( + entity, project = from_project_id(req.project_id) + return self._via_sdk_no_body( + self._sdk.v2_models.delete, + tsi.ModelDeleteRes, entity=entity, project=project, object_id=req.object_id, + digests=req.digests, ) - return tsi.ModelDeleteRes.model_validate(response.model_dump()) + + # ---- Evaluation Run V2 API ---------------------------------------------------- @validate_call def evaluation_run_create( self, req: tsi.EvaluationRunCreateReq ) -> tsi.EvaluationRunCreateRes: - """Create evaluation run. - - Args: - req: Evaluation run create request. - - Returns: - Evaluation run create response. - """ - entity, project = self._prepare_v2_request(req) - response = self._stainless_client.v2.evaluation_runs.create( - entity=entity, - project=project, - evaluation=req.evaluation, - model=req.model, + return self._v2_body_create( + req, + sdk_models.EvaluationRunCreateBody, + self._sdk.v2_evaluation_runs.create, + tsi.EvaluationRunCreateRes, ) - return tsi.EvaluationRunCreateRes.model_validate(response.model_dump()) @validate_call def evaluation_run_read( self, req: tsi.EvaluationRunReadReq ) -> tsi.EvaluationRunReadRes: - """Read evaluation run. - - Args: - req: Evaluation run read request. - - Returns: - Evaluation run read response. - """ - entity, project = self._prepare_v2_request(req) - response = self._stainless_client.v2.evaluation_runs.read( + entity, project = from_project_id(req.project_id) + return self._via_sdk_no_body( + self._sdk.v2_evaluation_runs.read, + tsi.EvaluationRunReadRes, entity=entity, project=project, evaluation_run_id=req.evaluation_run_id, ) - return tsi.EvaluationRunReadRes.model_validate(response.model_dump()) @validate_call def evaluation_run_list( self, req: tsi.EvaluationRunListReq ) -> Iterator[tsi.EvaluationRunReadRes]: - """List evaluation runs. - - Args: - req: Evaluation run list request. - - Yields: - EvaluationRunReadRes instances. - """ - entity, project = self._prepare_v2_request(req) - - # Extract filter parameters with explicit typing - evaluation_refs: str | None = ( - ",".join(req.filter.evaluations) - if req.filter and req.filter.evaluations - else None - ) - model_refs: str | None = ( - ",".join(req.filter.models) if req.filter and req.filter.models else None - ) - evaluation_run_ids: str | None = ( - ",".join(req.filter.evaluation_run_ids) - if req.filter and req.filter.evaluation_run_ids - else None - ) - - # Call stainless API with typed parameters - # Pass filter parameters explicitly as typed keyword arguments - response = self._stainless_client.v2.evaluation_runs.list( - entity=entity, - project=project, - limit=req.limit, - offset=req.offset, - evaluation_refs=evaluation_refs, - model_refs=model_refs, - evaluation_run_ids=evaluation_run_ids, + # Raw: the SDK's generated list signature renames the filter params + # (evaluations vs evaluation_refs), so use the wire format directly. + params: dict[str, Any] = {} + if req.limit is not None: + params["limit"] = req.limit + if req.offset is not None: + params["offset"] = req.offset + if req.filter: + if req.filter.evaluations: + params["evaluation_refs"] = ",".join(req.filter.evaluations) + if req.filter.models: + params["model_refs"] = ",".join(req.filter.models) + if req.filter.evaluation_run_ids: + params["evaluation_run_ids"] = ",".join(req.filter.evaluation_run_ids) + return self._v2_list_stream( + req, tsi.EvaluationRunReadRes, "evaluation_runs", params ) - for item in response: - yield tsi.EvaluationRunReadRes.model_validate(item) - @validate_call def evaluation_run_delete( self, req: tsi.EvaluationRunDeleteReq ) -> tsi.EvaluationRunDeleteRes: - """Delete evaluation run. - - Args: - req: Evaluation run delete request. - - Returns: - Evaluation run delete response. - """ - entity, project = self._prepare_v2_request(req) - response = self._stainless_client.v2.evaluation_runs.delete( + entity, project = from_project_id(req.project_id) + return self._via_sdk_no_body( + self._sdk.v2_evaluation_runs.delete, + tsi.EvaluationRunDeleteRes, entity=entity, project=project, evaluation_run_ids=req.evaluation_run_ids, ) - return tsi.EvaluationRunDeleteRes.model_validate(response.model_dump()) @validate_call def evaluation_run_finish( self, req: tsi.EvaluationRunFinishReq ) -> tsi.EvaluationRunFinishRes: - """Finish evaluation run. - - Args: - req: Evaluation run finish request. - - Returns: - Evaluation run finish response. - """ - entity, project = self._prepare_v2_request(req) - response = self._stainless_client.v2.evaluation_runs.finish( + entity, project = from_project_id(req.project_id) + body = sdk_models.EvaluationRunFinishBody.model_validate( + req.model_dump(exclude={"project_id", "evaluation_run_id"}) + ) + return self._via_sdk_no_body( + self._sdk.v2_evaluation_runs.finish, + tsi.EvaluationRunFinishRes, + body=body, entity=entity, project=project, evaluation_run_id=req.evaluation_run_id, - summary=req.summary, ) - return tsi.EvaluationRunFinishRes.model_validate(response.model_dump()) + + # ---- Prediction V2 API ---------------------------------------------------------- @validate_call def prediction_create( self, req: tsi.PredictionCreateReq ) -> tsi.PredictionCreateRes: - """Create prediction. - - Args: - req: Prediction create request. - - Returns: - Prediction create response. - """ - entity, project = self._prepare_v2_request(req) - response = self._stainless_client.v2.predictions.create( - entity=entity, - project=project, - inputs=req.inputs, - model=req.model, - output=req.output, - evaluation_run_id=req.evaluation_run_id, + return self._v2_body_create( + req, + sdk_models.PredictionCreateBody, + self._sdk.v2_predictions.create, + tsi.PredictionCreateRes, ) - return tsi.PredictionCreateRes.model_validate(response.model_dump()) @validate_call def prediction_read(self, req: tsi.PredictionReadReq) -> tsi.PredictionReadRes: - """Read prediction. - - Args: - req: Prediction read request. - - Returns: - Prediction read response. - """ - entity, project = self._prepare_v2_request(req) - response = self._stainless_client.v2.predictions.read( + entity, project = from_project_id(req.project_id) + return self._via_sdk_no_body( + self._sdk.v2_predictions.read, + tsi.PredictionReadRes, entity=entity, project=project, prediction_id=req.prediction_id, ) - return tsi.PredictionReadRes.model_validate(response.model_dump()) @validate_call def prediction_list( self, req: tsi.PredictionListReq ) -> Iterator[tsi.PredictionReadRes]: - """List predictions. - - Args: - req: Prediction list request. - - Yields: - PredictionReadRes instances. - """ - entity, project = self._prepare_v2_request(req) - response = self._stainless_client.v2.predictions.list( - entity=entity, - project=project, - evaluation_run_id=req.evaluation_run_id, - limit=req.limit, - offset=req.offset, - ) - for item in response: - yield tsi.PredictionReadRes.model_validate(item) + params: dict[str, Any] = {} + if req.evaluation_run_id is not None: + params["evaluation_run_id"] = req.evaluation_run_id + if req.limit is not None: + params["limit"] = req.limit + if req.offset is not None: + params["offset"] = req.offset + return self._v2_list_stream(req, tsi.PredictionReadRes, "predictions", params) @validate_call def prediction_delete( self, req: tsi.PredictionDeleteReq ) -> tsi.PredictionDeleteRes: - """Delete prediction. - - Args: - req: Prediction delete request. - - Returns: - Prediction delete response. - """ - entity, project = self._prepare_v2_request(req) - response = self._stainless_client.v2.predictions.delete( + entity, project = from_project_id(req.project_id) + return self._via_sdk_no_body( + self._sdk.v2_predictions.delete, + tsi.PredictionDeleteRes, entity=entity, project=project, prediction_ids=req.prediction_ids, ) - return tsi.PredictionDeleteRes.model_validate(response.model_dump()) @validate_call def prediction_finish( self, req: tsi.PredictionFinishReq ) -> tsi.PredictionFinishRes: - """Finish prediction. - - Args: - req: Prediction finish request. - - Returns: - Prediction finish response. - """ - entity, project = self._prepare_v2_request(req) - response = self._stainless_client.v2.predictions.finish( + entity, project = from_project_id(req.project_id) + return self._via_sdk_no_body( + self._sdk.v2_predictions.finish, + tsi.PredictionFinishRes, entity=entity, project=project, prediction_id=req.prediction_id, ) - return tsi.PredictionFinishRes.model_validate(response.model_dump()) + + # ---- Score V2 API --------------------------------------------------------------- @validate_call def score_create(self, req: tsi.ScoreCreateReq) -> tsi.ScoreCreateRes: - """Create score. - - Args: - req: Score create request. - - Returns: - Score create response. - """ - entity, project = self._prepare_v2_request(req) - response = self._stainless_client.v2.scores.create( - entity=entity, - project=project, - prediction_id=req.prediction_id, - scorer=req.scorer, - value=req.value, - evaluation_run_id=req.evaluation_run_id, + return self._v2_body_create( + req, + sdk_models.ScoreCreateBody, + self._sdk.v2_scores.create, + tsi.ScoreCreateRes, ) - return tsi.ScoreCreateRes.model_validate(response.model_dump()) @validate_call def score_read(self, req: tsi.ScoreReadReq) -> tsi.ScoreReadRes: - """Read score. - - Args: - req: Score read request. - - Returns: - Score read response. - """ - entity, project = self._prepare_v2_request(req) - response = self._stainless_client.v2.scores.read( + entity, project = from_project_id(req.project_id) + return self._via_sdk_no_body( + self._sdk.v2_scores.read, + tsi.ScoreReadRes, entity=entity, project=project, score_id=req.score_id, ) - return tsi.ScoreReadRes.model_validate(response.model_dump()) @validate_call def score_list(self, req: tsi.ScoreListReq) -> Iterator[tsi.ScoreReadRes]: - """List scores. + params: dict[str, Any] = {} + if req.evaluation_run_id is not None: + params["evaluation_run_id"] = req.evaluation_run_id + if req.limit is not None: + params["limit"] = req.limit + if req.offset is not None: + params["offset"] = req.offset + return self._v2_list_stream(req, tsi.ScoreReadRes, "scores", params) - Args: - req: Score list request. - - Yields: - ScoreReadRes instances. - """ - entity, project = self._prepare_v2_request(req) - response = self._stainless_client.v2.scores.list( + @validate_call + def score_delete(self, req: tsi.ScoreDeleteReq) -> tsi.ScoreDeleteRes: + entity, project = from_project_id(req.project_id) + return self._via_sdk_no_body( + self._sdk.v2_scores.delete, + tsi.ScoreDeleteRes, entity=entity, project=project, - evaluation_run_id=req.evaluation_run_id, - limit=req.limit, - offset=req.offset, + score_ids=req.score_ids, ) - for item in response: - yield tsi.ScoreReadRes.model_validate(item) - @validate_call - def score_delete(self, req: tsi.ScoreDeleteReq) -> tsi.ScoreDeleteRes: - """Delete score. + # ---- Calls V2 API --------------------------------------------------------------- - Args: - req: Score delete request. + def calls_complete( + self, req: tsi.CallsUpsertCompleteReq + ) -> tsi.CallsUpsertCompleteRes: + """Batch complete calls endpoint (v2). - Returns: - Score delete response. + This endpoint is used when use_calls_complete is enabled to send + complete calls (with both start and end information) in batches. """ - entity, project = self._prepare_v2_request(req) - response = self._stainless_client.v2.scores.delete( - entity=entity, - project=project, - score_ids=req.score_ids, + if not req.batch: + return tsi.CallsUpsertCompleteRes() + + first_item = req.batch[0] + entity, project = from_project_id(first_item.project_id) + # Excluded from the OpenAPI spec (include_in_schema=False). + return self._raw_request( + "POST", + CALLS_COMPLETE_PATH.format(entity=entity, project=project), + req=req, + res_type=tsi.CallsUpsertCompleteRes, + ) + + def call_start_v2(self, req: tsi.CallStartV2Req) -> tsi.CallStartV2Res: + """Single call start endpoint (v2). + + This endpoint is used for eager ops that need their start visible immediately. + """ + entity, project = from_project_id(req.start.project_id) + # Excluded from the OpenAPI spec (include_in_schema=False). + return self._raw_request( + "POST", + CALL_START_V2_PATH.format(entity=entity, project=project), + req=req, + res_type=tsi.CallStartV2Res, + ) + + def call_end_v2(self, req: tsi.CallEndV2Req) -> tsi.CallEndV2Res: + """Single call end endpoint (v2). + + This endpoint is used for eager ops that need their end sent separately. + """ + entity, project = from_project_id(req.end.project_id) + # Excluded from the OpenAPI spec (include_in_schema=False). + return self._raw_request( + "POST", + CALL_END_V2_PATH.format(entity=entity, project=project), + req=req, + res_type=tsi.CallEndV2Res, ) - return tsi.ScoreDeleteRes.model_validate(response.model_dump()) diff --git a/weave/utils/http_requests.py b/weave/utils/http_requests.py index 81339d6d0ab6..10a887681d2f 100644 --- a/weave/utils/http_requests.py +++ b/weave/utils/http_requests.py @@ -147,7 +147,8 @@ def _is_debug_http_enabled() -> bool: return os.environ.get("WEAVE_DEBUG_HTTP") == "1" -def _log_request(request: Request) -> None: +def log_request(request: Request) -> None: + """Httpx request hook: pretty-print the request when WEAVE_DEBUG_HTTP=1.""" if not _is_debug_http_enabled(): return @@ -156,7 +157,8 @@ def _log_request(request: Request) -> None: pprint_request(request) -def _log_response(response: Response) -> None: +def log_response(response: Response) -> None: + """Httpx response hook: pretty-print the response when WEAVE_DEBUG_HTTP=1.""" if not _is_debug_http_enabled(): return @@ -183,7 +185,7 @@ def _build_client(verify: bool, timeout: float) -> httpx.Client: return httpx.Client( # Use HTTPX's default transport so env proxy handling (including # NO_PROXY) works natively. - event_hooks={"request": [_log_request], "response": [_log_response]}, + event_hooks={"request": [log_request], "response": [log_response]}, timeout=timeout, limits=CLIENT_LIMITS, verify=verify,