diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 2a0ecab1d3ec..78ff3ff1deec 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -149,7 +149,6 @@ jobs: "pandas_test", "fastmcp", "smolagents", - "stainless", "autogen_tests", ] # Shards that don't support certain Python versions. Excluding here @@ -578,7 +577,6 @@ jobs: "pandas_test", "fastmcp", "smolagents", - "stainless", "autogen_tests", ] # Shards that don't support certain Python versions. Excluding here diff --git a/noxfile.py b/noxfile.py index b398669f1f61..f2bfdbd4d7b2 100644 --- a/noxfile.py +++ b/noxfile.py @@ -112,7 +112,6 @@ def lint(session: nox.Session): "trace", "trace_calls_complete_only", "trace_no_server", - "stainless", ], ) def tests(session: nox.Session, shard: str): @@ -185,7 +184,6 @@ def tests(session: nox.Session, shard: str): "trace_server": ["tests/trace_server/", "tests/shared/"], "trace_server_bindings": ["tests/trace_server_bindings/"], "trace_server_migrator": ["tests/trace_server_migrator/"], - "stainless": ["tests/trace_server_bindings/"], "scorers": ["tests/scorers/"], "autogen_tests": ["tests/integrations/autogen/"], "verifiers_test": ["tests/integrations/verifiers/"], @@ -248,10 +246,6 @@ def tests(session: nox.Session, shard: str): if shard == "trace_calls_complete_only": env["WEAVE_USE_CALLS_COMPLETE"] = "true" - # Set trace-server flag for stainless shard - if shard == "stainless": - pytest_args.extend(["--remote-http-trace-server=stainless"]) - if shard == "verifiers_test": # Pinning to this commit because the latest version of the gsm8k environment is broken. session.install(GSM8K_ENVIRONMENT_PACKAGE) diff --git a/tests/conftest.py b/tests/conftest.py index 249fcab7a144..b80db4b4645c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -7,6 +7,7 @@ from typing import Any from unittest.mock import MagicMock, patch +import httpx import pytest from fastapi import FastAPI from fastapi.responses import StreamingResponse @@ -20,13 +21,15 @@ from weave.trace.context.call_context import set_call_stack from weave.trace.settings import replace_settings from weave.trace_server import trace_server_interface as tsi -from weave.trace_server_bindings import remote_http_trace_server +from weave.trace_server_bindings import stainless_remote_http_trace_server from weave.trace_server_bindings.async_batch_processor import AsyncBatchProcessor from weave.trace_server_bindings.caching_middleware_trace_server import ( CachingMiddlewareTraceServer, ) from weave.trace_server_bindings.call_batch_processor import CallBatchProcessor -from weave.trace_server_bindings.remote_http_trace_server import RemoteHTTPTraceServer +from weave.trace_server_bindings.stainless_remote_http_trace_server import ( + StainlessRemoteHTTPTraceServer, +) pytest_plugins = ["tests.trace_server.conftest"] @@ -457,7 +460,7 @@ def create_client( # Note: this is only for local dev testing and should be removed return weave_init.init_weave("dev_testing") elif trace_server_flag == "http": - server = RemoteHTTPTraceServer(trace_server_flag) + server = StainlessRemoteHTTPTraceServer(trace_server_flag) else: server = trace_server @@ -597,11 +600,12 @@ def client( @pytest.fixture def network_proxy_client(client, monkeypatch): - """This fixture is used to test the `RemoteHTTPTraceServer` class. There is - almost no logic in this class, other than a little batching, so we typically - skip it for simplicity. However, we can use this fixture to test such logic. - It initializes a mini FastAPI app that proxies requests from the - `RemoteHTTPTraceServer` to the underlying `client.server` object. + """This fixture is used to test the `StainlessRemoteHTTPTraceServer` class. + There is almost no logic in this class, other than a little batching, so we + typically skip it for simplicity. However, we can use this fixture to test + such logic. It initializes a mini FastAPI app and routes the server's HTTP + transport into it, proxying requests to the underlying `client.server` + object. We probably will want to flesh this out more in the future, but this is a starting point. @@ -707,12 +711,25 @@ def obj_read(req: tsi.ObjReadReq) -> tsi.ObjReadRes: with TestClient(app) as c: - def post(url, data=None, json=None, **kwargs): - kwargs.pop("stream", None) - return c.post(url, data=data, json=json, **kwargs) - - orig_post = weave.utils.http_requests.post - weave.utils.http_requests.post = post + class TestClientTransport(httpx.BaseTransport): + """Routes the server's httpx requests into the FastAPI TestClient.""" + + def handle_request(self, request: httpx.Request) -> httpx.Response: + request.read() + resp = c.request( + request.method, + request.url.path, + params=request.url.params, + content=request.content, + headers={ + k: v + for k, v in request.headers.items() + if k.lower() not in {"host", "content-length"} + }, + ) + return httpx.Response( + resp.status_code, headers=resp.headers, content=resp.content + ) def make_fast_async_batch_processor(*args, **kwargs): kwargs.setdefault("min_batch_interval", 0) @@ -723,24 +740,25 @@ def make_fast_call_batch_processor(*args, **kwargs): return CallBatchProcessor(*args, **kwargs) monkeypatch.setattr( - remote_http_trace_server, + stainless_remote_http_trace_server, "AsyncBatchProcessor", make_fast_async_batch_processor, ) monkeypatch.setattr( - remote_http_trace_server, + stainless_remote_http_trace_server, "CallBatchProcessor", make_fast_call_batch_processor, ) - remote_client = RemoteHTTPTraceServer( - trace_server_url="", + # Absolute base URL required for httpx cookie handling; the transport + # routes by path, so the host is never dialed. + remote_client = StainlessRemoteHTTPTraceServer( + trace_server_url="http://testserver", should_batch=True, + transport=TestClientTransport(), ) yield (client, remote_client, records) - weave.utils.http_requests.post = orig_post - @pytest.fixture(autouse=True) def caching_client_isolation(monkeypatch, tmp_path): diff --git a/tests/trace/test_wave_client_file_cache.py b/tests/trace/test_wave_client_file_cache.py index 327b1be622d8..69e95a08ba5b 100644 --- a/tests/trace/test_wave_client_file_cache.py +++ b/tests/trace/test_wave_client_file_cache.py @@ -11,7 +11,9 @@ WeaveClientSendFileCache, ) from weave.trace_server.trace_server_interface import FileCreateReq, FileCreateRes -from weave.trace_server_bindings.remote_http_trace_server import RemoteHTTPTraceServer +from weave.trace_server_bindings.stainless_remote_http_trace_server import ( + StainlessRemoteHTTPTraceServer, +) class TestThreadSafeLRUCache: @@ -402,7 +404,7 @@ def offline_client(monkeypatch): monkeypatch.setenv("WEAVE_RETRY_MAX_ATTEMPTS", "2") monkeypatch.setenv("WEAVE_RETRY_MAX_INTERVAL", "0.01") monkeypatch.setenv("WEAVE_ENABLE_WAL", "false") - server = RemoteHTTPTraceServer("http://example.com") + server = StainlessRemoteHTTPTraceServer("http://example.com") client = WeaveClient( entity="ent", project="proj", diff --git a/tests/trace/test_weave_init_availability.py b/tests/trace/test_weave_init_availability.py index 3910331fc0fc..bff6f3e118b9 100644 --- a/tests/trace/test_weave_init_availability.py +++ b/tests/trace/test_weave_init_availability.py @@ -4,12 +4,12 @@ from unittest.mock import MagicMock from weave.trace import weave_init -from weave.trace_server_bindings import remote_http_trace_server +from weave.trace_server_bindings import stainless_remote_http_trace_server def test_get_server_info_json_decode_error(): """Test that _get_server_info returns None when server info cannot be decoded.""" - mock_server = MagicMock(spec=remote_http_trace_server.RemoteHTTPTraceServer) + mock_server = MagicMock(spec=stainless_remote_http_trace_server.StainlessRemoteHTTPTraceServer) mock_server.server_info.side_effect = json.JSONDecodeError("test error", "doc", 0) result = weave_init._get_server_info(mock_server) @@ -20,7 +20,7 @@ def test_get_server_info_json_decode_error(): def test_get_server_info_success(): """Test that _get_server_info returns server info when server is available.""" - mock_server = MagicMock(spec=remote_http_trace_server.RemoteHTTPTraceServer) + mock_server = MagicMock(spec=stainless_remote_http_trace_server.StainlessRemoteHTTPTraceServer) server_info = {"version": "1.0.0"} mock_server.server_info.return_value = server_info diff --git a/tests/trace_server/conftest.py b/tests/trace_server/conftest.py index 3086d5a830ad..32db608eaa47 100644 --- a/tests/trace_server/conftest.py +++ b/tests/trace_server/conftest.py @@ -68,12 +68,6 @@ def pytest_addoption(parser): default="false", help="Use a clickhouse process instead of a container", ) - parser.addoption( - "--remote-http-trace-server", - action="store", - default="remote", - help="Specify the remote HTTP trace server implementation: remote or stainless", - ) except ValueError: pass @@ -82,7 +76,6 @@ def pytest_collection_modifyitems(config, items): # Add the trace_server marker to: # 1. All tests in the trace_server directory (regardless of fixture usage) # 2. All tests that use the trace_server fixture (for tests outside this directory) - # Note: Filtering based on remote-http-trace-server flag is handled in tests/trace_server_bindings/conftest.py for item in items: # Check if the test is in the trace_server directory by checking parent directories if "trace_server" in item.path.parts: @@ -110,15 +103,6 @@ def get_trace_server_flag(request): return weave_server_flag -def get_remote_http_trace_server_flag(request): - """Get the remote HTTP trace server implementation to use. - - Returns: - str: Either 'remote' for RemoteHTTPTraceServer or 'stainless' for StainlessRemoteHTTPTraceServer - """ - return request.config.getoption("--remote-http-trace-server") - - @pytest.fixture(autouse=True) def reset_project_version_cache(): project_version.reset_project_residence_cache() diff --git a/tests/trace_server_bindings/conftest.py b/tests/trace_server_bindings/conftest.py index f8a540f55890..ec157f651f25 100644 --- a/tests/trace_server_bindings/conftest.py +++ b/tests/trace_server_bindings/conftest.py @@ -2,13 +2,14 @@ from types import MethodType from unittest.mock import MagicMock +import httpx import pytest import tenacity from weave.trace_server import trace_server_interface as tsi from weave.trace_server.ids import generate_id -from weave.trace_server_bindings.remote_http_trace_server import ( - RemoteHTTPTraceServer, +from weave.trace_server_bindings.stainless_remote_http_trace_server import ( + StainlessRemoteHTTPTraceServer, ) # ============================================================================= @@ -60,35 +61,58 @@ def generate_call_start_end_pair( # ============================================================================= -# Fixtures +# HTTP transport spy # ============================================================================= -@pytest.fixture -def success_response(): - """Common fixture for mocking a successful HTTP response.""" - response = MagicMock() - response.status_code = 200 - response.json.return_value = {"id": "test_id", "trace_id": "test_trace_id"} - return response +class SpyTransport(httpx.BaseTransport): + """httpx transport that records requests and replays queued responses. + Queue items may be ``httpx.Response`` objects or exceptions to raise. + When the queue is empty, returns ``default_response`` (200 ``{}`` unless + overridden). + """ -@pytest.fixture -def server_class(request): - """Returns the appropriate server class based on --remote-http-trace-server flag.""" - flag = request.config.getoption("--remote-http-trace-server", default="remote") - if flag == "stainless": - from weave.trace_server_bindings.stainless_remote_http_trace_server import ( - StainlessRemoteHTTPTraceServer, - ) + def __init__( + self, + *items: httpx.Response | Exception, + default_response: httpx.Response | None = None, + ) -> None: + self.queue: list[httpx.Response | Exception] = list(items) + self.requests: list[httpx.Request] = [] + self.default_response = default_response + + def handle_request(self, request: httpx.Request) -> httpx.Response: + request.read() + self.requests.append(request) + if self.queue: + item = self.queue.pop(0) + if isinstance(item, Exception): + raise item + return item + if self.default_response is not None: + return self.default_response + return httpx.Response(200, json={}) + + @property + def urls(self) -> list[str]: + return [str(r.url) for r in self.requests] - return StainlessRemoteHTTPTraceServer - return RemoteHTTPTraceServer + +# ============================================================================= +# Fixtures +# ============================================================================= + + +@pytest.fixture +def server_class(): + """The remote trace server implementation under test.""" + return StainlessRemoteHTTPTraceServer @pytest.fixture def server(request, server_class): - """Common server fixture that uses server_class based on the CLI flag.""" + """Common server fixture parametrized by batching/retry behavior.""" server_ = server_class("http://example.com", should_batch=True) if request.param == "normal": @@ -116,25 +140,6 @@ def server(request, server_class): server_.feedback_processor.stop_accepting_new_work_and_flush_queue() -def pytest_ignore_collect(collection_path, config): - """Ignore test files based on --remote-http-trace-server flag. - - This runs before collection, preventing files from being imported at all. - """ - if "trace_server_bindings" not in collection_path.parts: - return None - - flag = config.getoption("--remote-http-trace-server", default="remote") - filename = collection_path.name - - if flag == "remote" and filename.endswith("_stainless.py"): - return True - if flag == "stainless" and filename.endswith("_remote.py"): - return True - - return None - - def pytest_collection_modifyitems(config, items): """Add trace_server marker to all tests in this directory.""" for item in items: diff --git a/tests/trace_server_bindings/test_http_behavior_stainless.py b/tests/trace_server_bindings/test_http_behavior.py similarity index 94% rename from tests/trace_server_bindings/test_http_behavior_stainless.py rename to tests/trace_server_bindings/test_http_behavior.py index 25eb43f87bf4..e0ea18eae9ea 100644 --- a/tests/trace_server_bindings/test_http_behavior_stainless.py +++ b/tests/trace_server_bindings/test_http_behavior.py @@ -1,7 +1,7 @@ """HTTP behavior tests for StainlessRemoteHTTPTraceServer. These tests verify HTTP request/response handling, retry behavior for various -status codes, and error handling specific to StainlessRemoteHTTPTraceServer. +status codes, and error handling of the remote trace server binding. Mocking happens at the httpx transport boundary (the external seam), so the full stack — SDK encode/decode, event hooks, error translation, retry @@ -21,6 +21,7 @@ from pydantic import ValidationError from tests.trace_server_bindings.conftest import ( + SpyTransport, generate_end, generate_id, generate_start, @@ -45,40 +46,6 @@ BASE_URL = "http://example.com" -class SpyTransport(httpx.BaseTransport): - """httpx transport that records requests and replays queued responses. - - Queue items may be ``httpx.Response`` objects or exceptions to raise. - When the queue is empty, returns ``default_response`` (200 ``{}`` unless - overridden). - """ - - def __init__( - self, - *items: httpx.Response | Exception, - default_response: httpx.Response | None = None, - ) -> None: - self.queue: list[httpx.Response | Exception] = list(items) - self.requests: list[httpx.Request] = [] - self.default_response = default_response - - def handle_request(self, request: httpx.Request) -> httpx.Response: - request.read() - self.requests.append(request) - if self.queue: - item = self.queue.pop(0) - if isinstance(item, Exception): - raise item - return item - if self.default_response is not None: - return self.default_response - return httpx.Response(200, json={}) - - @property - def urls(self) -> list[str]: - return [str(r.url) for r in self.requests] - - def make_server( transport: httpx.BaseTransport, should_batch: bool = False, diff --git a/tests/trace_server_bindings/test_http_behavior_remote.py b/tests/trace_server_bindings/test_http_behavior_remote.py deleted file mode 100644 index 66ae13d2a87a..000000000000 --- a/tests/trace_server_bindings/test_http_behavior_remote.py +++ /dev/null @@ -1,461 +0,0 @@ -"""HTTP behavior tests for RemoteHTTPTraceServer. - -These tests verify HTTP request/response handling, retry behavior for various -status codes, and error handling specific to RemoteHTTPTraceServer. -""" - -from __future__ import annotations - -import datetime -import json -import logging -from types import MethodType -from unittest.mock import MagicMock, patch - -import httpx -import pytest -import tenacity -from pydantic import ValidationError - -from tests.trace_server_bindings.conftest import ( - generate_end, - generate_id, - generate_start, -) -from weave.trace.display.term import configure_logger -from weave.trace_server import trace_server_interface as tsi -from weave.trace_server_bindings.async_batch_processor import AsyncBatchProcessor -from weave.trace_server_bindings.call_batch_processor import CallBatchProcessor -from weave.trace_server_bindings.http_utils import ( - ERROR_CODE_CALLS_COMPLETE_MODE_REQUIRED, -) -from weave.trace_server_bindings.models import ( - CompleteBatchItem, - EndBatchItem, - StartBatchItem, -) -from weave.trace_server_bindings.remote_http_trace_server import ( - RemoteHTTPTraceServer, -) - - -def make_calls_complete_required_response() -> httpx.Response: - """Create a 400 response indicating the project requires calls_complete mode.""" - return httpx.Response( - 400, - json={ - "error_code": ERROR_CODE_CALLS_COMPLETE_MODE_REQUIRED, - "message": "Project requires calls_complete mode", - }, - request=httpx.Request("POST", "http://example.com/call/upsert_batch"), - ) - - -@pytest.fixture -def unbatched_server(): - """Create a RemoteHTTPTraceServer instance without batching for testing.""" - return RemoteHTTPTraceServer("http://example.com") - - -@patch("weave.utils.http_requests.post") -def test_call_start_ok(mock_post, unbatched_server): - """Test successful call_start request.""" - call_id = generate_id() - mock_response = httpx.Response( - 200, - json=dict(tsi.CallStartRes(id=call_id, trace_id="test_trace_id")), - request=httpx.Request("POST", "http://test.com"), - ) - mock_post.return_value = mock_response - start = generate_start(call_id) - unbatched_server.call_start(tsi.CallStartReq(start=start)) - mock_post.assert_called_once() - - -@patch("weave.utils.http_requests.post") -def test_400_no_retry(mock_post, unbatched_server): - """Test that 400 errors are not retried.""" - call_id = generate_id() - resp1 = httpx.Response( - 400, - json=dict(tsi.CallStartRes(id=call_id, trace_id="test_trace_id")), - request=httpx.Request("POST", "http://test.com"), - ) - mock_post.side_effect = [resp1] - - start = generate_start(call_id) - with pytest.raises(httpx.HTTPStatusError): - unbatched_server.call_start(tsi.CallStartReq(start=start)) - - -def test_invalid_no_retry(unbatched_server): - """Test that validation errors are not retried.""" - with pytest.raises(ValidationError): - unbatched_server.call_start(tsi.CallStartReq(start={"invalid": "broken"})) - - -@patch("weave.utils.http_requests.post") -def test_calls_complete_batch_endpoint_and_payload(mock_post, monkeypatch): - """Send calls_complete batches to the v2 endpoint with correct payload.""" - monkeypatch.setenv("WEAVE_USE_CALLS_COMPLETE", "true") - server = RemoteHTTPTraceServer("http://example.com", should_batch=True) - - complete = tsi.CompletedCallSchemaForInsert( - project_id="entity/project", - id="call-id", - trace_id="trace-id", - op_name="test_op", - started_at=datetime.datetime.now(tz=datetime.timezone.utc), - ended_at=datetime.datetime.now(tz=datetime.timezone.utc), - attributes={"a": 1}, - inputs={"b": 2}, - output={"c": 3}, - summary={"result": "ok"}, - ) - batch = [CompleteBatchItem(req=complete)] - - mock_post.return_value = httpx.Response( - 200, - json={}, - request=httpx.Request("POST", "http://example.com"), - ) - - try: - server._flush_calls_complete(batch) - finally: - if server.call_processor: - server.call_processor.stop_accepting_new_work_and_flush_queue() - if server.feedback_processor: - server.feedback_processor.stop_accepting_new_work_and_flush_queue() - - assert mock_post.call_count == 1 - url = mock_post.call_args[0][0] - assert url == "http://example.com/v2/entity/project/calls/complete" - sent_data = mock_post.call_args[1]["data"] - payload = json.loads(sent_data.decode("utf-8")) - expected = tsi.CallsUpsertCompleteReq(batch=[complete]).model_dump(mode="json") - assert payload == expected - - -@patch("weave.utils.http_requests.post") -def test_eager_calls_use_v2_start_end_endpoints(mock_post): - """Use v2 endpoints for eager start/end and include started_at in end.""" - server = RemoteHTTPTraceServer("http://example.com", should_batch=True) - - start = generate_start(id="call-id", project_id="entity/project") - ended_at = datetime.datetime.now(tz=datetime.timezone.utc) - started_at = ended_at - datetime.timedelta(seconds=1) - end = tsi.EndedCallSchemaForInsertWithStartedAt( - project_id="entity/project", - id="call-id", - ended_at=ended_at, - started_at=started_at, - summary={"result": "Test summary"}, - ) - - mock_post.side_effect = [ - httpx.Response(200, request=httpx.Request("POST", "http://example.com")), - httpx.Response(200, request=httpx.Request("POST", "http://example.com")), - ] - - try: - server._flush_calls_eager( - [ - StartBatchItem(req=tsi.CallStartReq(start=start)), - EndBatchItem(req=tsi.CallEndReq(end=end)), - ] - ) - - urls = [call[0][0] for call in mock_post.call_args_list] - assert urls == [ - "http://example.com/v2/entity/project/call/start", - "http://example.com/v2/entity/project/call/end", - ] - - end_payload = json.loads(mock_post.call_args_list[1][1]["data"].decode("utf-8")) - payload_started_at = datetime.datetime.fromisoformat( - end_payload["end"]["started_at"].replace("Z", "+00:00") - ) - assert payload_started_at == end.started_at - assert end_payload["end"]["id"] == "call-id" - finally: - if server.call_processor: - server.call_processor.stop_accepting_new_work_and_flush_queue() - if server.feedback_processor: - server.feedback_processor.stop_accepting_new_work_and_flush_queue() - - -@pytest.mark.disable_logging_error_check -def test_eager_non_retryable_error_drops_item(caplog): - """Drop eager items on non-retryable errors without raising.""" - server = RemoteHTTPTraceServer("http://example.com", should_batch=True) - start = generate_start(id="call-id", project_id="entity/project") - - def _raise_non_retryable(_: StartBatchItem) -> None: - raise httpx.HTTPStatusError( - "400", - request=httpx.Request("POST", "http://example.com"), - response=httpx.Response( - 400, request=httpx.Request("POST", "http://example.com") - ), - ) - - server._send_call_start_v2 = _raise_non_retryable # type: ignore[assignment] - - caplog.set_level(logging.ERROR) - try: - server._flush_calls_eager([StartBatchItem(req=tsi.CallStartReq(start=start))]) - finally: - if server.call_processor: - server.call_processor.stop_accepting_new_work_and_flush_queue() - if server.feedback_processor: - server.feedback_processor.stop_accepting_new_work_and_flush_queue() - - assert any("dropped call start ids" in record.message for record in caplog.records) - - -@pytest.mark.disable_logging_error_check -def test_eager_retryable_error_logs_and_continues(caplog): - """Log and drop item on retryable error, continue with remaining items.""" - server = RemoteHTTPTraceServer("http://example.com", should_batch=True) - start1 = generate_start(id="call-id-1", project_id="entity/project") - start2 = generate_start(id="call-id-2", project_id="entity/project") - - call_attempts = [] - - def _raise_retryable_once(start) -> None: - call_attempts.append(start.id) - if start.id == "call-id-1": - raise httpx.HTTPStatusError( - "500", - request=httpx.Request("POST", "http://example.com"), - response=httpx.Response( - 500, request=httpx.Request("POST", "http://example.com") - ), - ) - # call-id-2 succeeds - - server._send_call_start_v2 = _raise_retryable_once # type: ignore[assignment] - - try: - # Should NOT raise - logs and drops item 1, continues with item 2 - server._flush_calls_eager( - [ - StartBatchItem(req=tsi.CallStartReq(start=start1)), - StartBatchItem(req=tsi.CallStartReq(start=start2)), - ] - ) - finally: - if server.call_processor: - server.call_processor.stop_accepting_new_work_and_flush_queue() - if server.feedback_processor: - server.feedback_processor.stop_accepting_new_work_and_flush_queue() - - # Item 1 was logged as dropped - assert any("dropped call start ids" in record.message for record in caplog.records) - assert any("call-id-1" in record.message for record in caplog.records) - # Item 2 was still processed - assert "call-id-2" in call_attempts - - -@patch("weave.utils.http_requests.post") -def test_500_502_503_504_429_retry(mock_post, unbatched_server, monkeypatch): - """Test that 5xx and 429 errors are retried.""" - monkeypatch.setenv("WEAVE_RETRY_MAX_ATTEMPTS", "6") - monkeypatch.setenv("WEAVE_RETRY_MAX_INTERVAL", "0.1") - call_id = generate_id() - - mock_post.side_effect = [ - httpx.Response(500, request=httpx.Request("POST", "http://test.com")), - httpx.Response(502, request=httpx.Request("POST", "http://test.com")), - httpx.Response(503, request=httpx.Request("POST", "http://test.com")), - httpx.Response(504, request=httpx.Request("POST", "http://test.com")), - httpx.Response(429, request=httpx.Request("POST", "http://test.com")), - httpx.Response( - 200, - json=dict(tsi.CallStartRes(id=call_id, trace_id="test_trace_id")), - request=httpx.Request("POST", "http://test.com"), - ), - ] - start = generate_start(call_id) - unbatched_server.call_start(tsi.CallStartReq(start=start)) - - -@patch("weave.utils.http_requests.post") -def test_other_error_retry(mock_post, unbatched_server, monkeypatch): - """Test that connection errors are retried.""" - monkeypatch.setenv("WEAVE_RETRY_MAX_ATTEMPTS", "6") - monkeypatch.setenv("WEAVE_RETRY_MAX_INTERVAL", "0.1") - call_id = generate_id() - - mock_post.side_effect = [ - ConnectionResetError(), - ConnectionError(), - OSError(), - TimeoutError(), - httpx.Response( - 200, - json=dict(tsi.CallStartRes(id=call_id, trace_id="test_trace_id")), - request=httpx.Request("POST", "http://test.com"), - ), - ] - start = generate_start(call_id) - unbatched_server.call_start(tsi.CallStartReq(start=start)) - - -@patch("weave.utils.http_requests.post") -def test_timeout_retry_mechanism(mock_post, success_response, monkeypatch): - """Test that timeouts trigger the retry mechanism.""" - monkeypatch.setenv("WEAVE_USE_CALLS_COMPLETE", "false") - server = RemoteHTTPTraceServer("http://example.com", should_batch=True) - - # Mock server to raise errors twice, then succeed - mock_post.side_effect = [ - httpx.TimeoutException("Connection timed out"), - httpx.HTTPStatusError( - "500 Server Error", request=MagicMock(), response=MagicMock(status_code=500) - ), - success_response, - ] - - # Trying to send a batch should fail 2 times, then succeed - server.call_start(tsi.CallStartReq(start=generate_start())) - server.call_processor.stop_accepting_new_work_and_flush_queue() - - # Verify that requests.post was called 3 times - assert mock_post.call_count == 3 - - -@pytest.fixture -def fast_retrying_server(monkeypatch): - """Create a RemoteHTTPTraceServer with fast retry settings for testing.""" - monkeypatch.setenv("WEAVE_USE_CALLS_COMPLETE", "false") - server = RemoteHTTPTraceServer("http://example.com", should_batch=True) - fast_retry = tenacity.retry( - wait=tenacity.wait_fixed(0.1), - stop=tenacity.stop_after_attempt(2), - reraise=True, - ) - unwrapped_send_batch_to_server = MethodType( - server._send_batch_to_server.__wrapped__, # type: ignore[attr-defined] - server, - ) - server._send_batch_to_server = fast_retry(unwrapped_send_batch_to_server) - yield server - if server.call_processor: - server.call_processor.stop_accepting_new_work_and_flush_queue() - if server.feedback_processor: - server.feedback_processor.stop_accepting_new_work_and_flush_queue() - - -@pytest.mark.disable_logging_error_check -@patch("weave.utils.http_requests.post") -def test_post_timeout(mock_post, success_response, fast_retrying_server, log_collector): - """Test batch recovery after timeout exhaustion. - - This test verifies that we can still send new batches even if one batch - times out and exhausts all retries. - """ - configure_logger() - # Configure mock to timeout twice to exhaust retries - mock_post.side_effect = [ - httpx.TimeoutException("Connection timed out"), - httpx.TimeoutException("Connection timed out"), - ] - - # Phase 1: Try but fail to process the first batch - fast_retrying_server.call_start(tsi.CallStartReq(start=generate_start())) - fast_retrying_server.call_processor.stop_accepting_new_work_and_flush_queue() - logs = log_collector.get_warning_logs() - assert len(logs) >= 1 - assert any("requeuing batch" in log.msg for log in logs) - - # Phase 2: Reset mock and verify we can still process a new batch - mock_post.reset_mock() - mock_post.side_effect = [ - httpx.TimeoutException("Connection timed out"), - success_response, - ] - - # Create a new server since the old one has shutdown its batch processor - new_server = RemoteHTTPTraceServer("http://example.com", should_batch=False) - fast_retry = tenacity.retry( - wait=tenacity.wait_fixed(0.1), - stop=tenacity.stop_after_attempt(2), - reraise=True, - ) - unwrapped_send_batch_to_server = MethodType( - new_server._send_batch_to_server.__wrapped__, # type: ignore[attr-defined] - new_server, - ) - new_server._send_batch_to_server = fast_retry(unwrapped_send_batch_to_server) - - # Should succeed with retry - start_req = tsi.CallStartReq(start=generate_start()) - response = new_server.call_start(start_req) - assert response.id == "test_id" - assert response.trace_id == "test_trace_id" - - -@patch("weave.utils.http_requests.post") -def test_auto_upgrade_to_calls_complete_on_error(mock_post, monkeypatch): - """Verify client switches to CallBatchProcessor when server returns CALLS_COMPLETE_MODE_REQUIRED.""" - monkeypatch.setenv("WEAVE_USE_CALLS_COMPLETE", "false") - server = RemoteHTTPTraceServer("http://example.com", should_batch=True) - - # Verify initial state: using legacy AsyncBatchProcessor - assert server.use_calls_complete is False - assert isinstance(server.call_processor, AsyncBatchProcessor) - assert not isinstance(server.call_processor, CallBatchProcessor) - old_processor = server.call_processor - - mock_post.side_effect = [ - make_calls_complete_required_response(), - httpx.Response( - 200, json={}, request=httpx.Request("POST", "http://example.com") - ), - ] - - call_id = generate_id() - start = StartBatchItem( - req=tsi.CallStartReq(start=generate_start(call_id, "entity/project")) - ) - end = EndBatchItem(req=tsi.CallEndReq(end=generate_end(call_id, "entity/project"))) - - try: - server._flush_calls([start, end]) - server.call_processor.stop_accepting_new_work_and_flush_queue() - - # Verify upgrade happened - assert server.use_calls_complete is True - assert isinstance(server.call_processor, CallBatchProcessor) - assert old_processor.stop_accepting_work_event.is_set() - assert any("/calls/complete" in call[0][0] for call in mock_post.call_args_list) - finally: - if server.call_processor and server.call_processor.is_accepting_new_work(): - server.call_processor.stop_accepting_new_work_and_flush_queue() - if server.feedback_processor: - server.feedback_processor.stop_accepting_new_work_and_flush_queue() - - -@patch("weave.utils.http_requests.post") -def test_eager_calls_complete_required_is_reraised(mock_post, monkeypatch): - """Verify CallsCompleteModeRequired in eager path is re-raised for caller to handle.""" - from weave.trace_server_bindings.http_utils import CallsCompleteModeRequired - - monkeypatch.setenv("WEAVE_USE_CALLS_COMPLETE", "true") - server = RemoteHTTPTraceServer("http://example.com", should_batch=True) - mock_post.side_effect = [make_calls_complete_required_response()] - - start = StartBatchItem( - req=tsi.CallStartReq(start=generate_start("call-id", "entity/project")) - ) - - try: - with pytest.raises(CallsCompleteModeRequired): - server._flush_calls_eager([start]) - finally: - if server.call_processor and server.call_processor.is_accepting_new_work(): - server.call_processor.stop_accepting_new_work_and_flush_queue() - if server.feedback_processor: - server.feedback_processor.stop_accepting_new_work_and_flush_queue() diff --git a/tests/trace_server_bindings/test_tags_aliases_routes.py b/tests/trace_server_bindings/test_tags_aliases_routes.py index 994c969c9751..c63b68b4f9be 100644 --- a/tests/trace_server_bindings/test_tags_aliases_routes.py +++ b/tests/trace_server_bindings/test_tags_aliases_routes.py @@ -1,29 +1,37 @@ -"""Tests for RESTful tags and aliases routes in RemoteHTTPTraceServer. +"""Tests for RESTful tags and aliases routes in StainlessRemoteHTTPTraceServer. These tests verify that the tag/alias methods send the correct HTTP method, -URL path, and request body through the RemoteHTTPTraceServer client. +URL path, and request body. Requests are observed at the httpx transport +boundary, so the SDK routing and conversion layers are exercised for real. """ from __future__ import annotations import json -from unittest.mock import MagicMock, patch import httpx import pytest +from tests.trace_server_bindings.conftest import SpyTransport from weave.trace_server import trace_server_interface as tsi -from weave.trace_server_bindings.remote_http_trace_server import ( - RemoteHTTPTraceServer, +from weave.trace_server_bindings.stainless_remote_http_trace_server import ( + StainlessRemoteHTTPTraceServer, ) BASE_URL = "http://example.com" @pytest.fixture -def server(): - """Create a RemoteHTTPTraceServer with mocked HTTP methods.""" - srv = RemoteHTTPTraceServer(BASE_URL, should_batch=False) +def transport(): + return SpyTransport() + + +@pytest.fixture +def server(transport): + """Create a StainlessRemoteHTTPTraceServer over a spy transport.""" + srv = StainlessRemoteHTTPTraceServer( + BASE_URL, should_batch=False, transport=transport + ) yield srv if srv.call_processor: srv.call_processor.stop_accepting_new_work_and_flush_queue() @@ -31,164 +39,116 @@ def server(): srv.feedback_processor.stop_accepting_new_work_and_flush_queue() -def _mock_response(json_data: dict | None = None) -> MagicMock: - """Create a mock httpx.Response with the given JSON data.""" - resp = MagicMock(spec=httpx.Response) - resp.status_code = 200 - resp.json.return_value = json_data or {} - return resp - - class TestObjAddTags: - def test_sends_correct_request(self, server): - mock_resp = _mock_response() - with patch.object(server, "put", return_value=mock_resp) as mock_put: - server.obj_add_tags( - tsi.ObjAddTagsReq( - project_id="entity/project", - object_id="my-obj", - digest="abc123", - tags=["production", "reviewed"], - ) + def test_sends_correct_request(self, server, transport): + server.obj_add_tags( + tsi.ObjAddTagsReq( + project_id="entity/project", + object_id="my-obj", + digest="abc123", + tags=["production", "reviewed"], ) + ) - mock_put.assert_called_once() - call_url = mock_put.call_args[0][0] - assert call_url == "/objs/my-obj/versions/abc123/tags" + request = transport.requests[0] + assert request.method == "PUT" + assert str(request.url) == f"{BASE_URL}/objs/my-obj/versions/abc123/tags" - sent_data = mock_put.call_args[1]["data"] - body = json.loads(sent_data) - assert body["project_id"] == "entity/project" - assert body["tags"] == ["production", "reviewed"] - # object_id and digest should NOT be in the body (they're in the URL) - assert "object_id" not in body - assert "digest" not in body + body = json.loads(request.content) + assert body["project_id"] == "entity/project" + assert body["tags"] == ["production", "reviewed"] + # object_id and digest should NOT be in the body (they're in the URL) + assert "object_id" not in body + assert "digest" not in body class TestObjRemoveTags: - def test_sends_correct_request(self, server): - mock_resp = _mock_response() - with patch.object(server, "post", return_value=mock_resp) as mock_post: - server.obj_remove_tags( - tsi.ObjRemoveTagsReq( - project_id="entity/project", - object_id="my-obj", - digest="abc123", - tags=["staging"], - ) + def test_sends_correct_request(self, server, transport): + server.obj_remove_tags( + tsi.ObjRemoveTagsReq( + project_id="entity/project", + object_id="my-obj", + digest="abc123", + tags=["staging"], ) + ) - mock_post.assert_called_once() - call_url = mock_post.call_args[0][0] - assert call_url == "/objs/my-obj/versions/abc123/tags/remove" + request = transport.requests[0] + assert request.method == "POST" + assert str(request.url) == f"{BASE_URL}/objs/my-obj/versions/abc123/tags/remove" - sent_data = mock_post.call_args[1]["data"] - body = json.loads(sent_data) - assert body["project_id"] == "entity/project" - assert body["tags"] == ["staging"] - assert "object_id" not in body - assert "digest" not in body + body = json.loads(request.content) + assert body["project_id"] == "entity/project" + assert body["tags"] == ["staging"] + assert "object_id" not in body + assert "digest" not in body class TestObjSetAliases: - def test_sends_correct_request(self, server): - mock_resp = _mock_response() - with patch.object(server, "put", return_value=mock_resp) as mock_put: - server.obj_set_aliases( - tsi.ObjSetAliasesReq( - project_id="entity/project", - object_id="my-obj", - digest="abc123", - aliases=["staging", "candidate"], - ) + def test_sends_correct_request(self, server, transport): + server.obj_set_aliases( + tsi.ObjSetAliasesReq( + project_id="entity/project", + object_id="my-obj", + digest="abc123", + aliases=["staging", "candidate"], ) + ) - mock_put.assert_called_once() - call_url = mock_put.call_args[0][0] - assert call_url == "/objs/my-obj/aliases" + request = transport.requests[0] + assert request.method == "PUT" + assert str(request.url) == f"{BASE_URL}/objs/my-obj/aliases" - sent_data = mock_put.call_args[1]["data"] - body = json.loads(sent_data) - assert body["project_id"] == "entity/project" - assert body["digest"] == "abc123" - assert body["aliases"] == ["staging", "candidate"] - assert "object_id" not in body + body = json.loads(request.content) + assert body["project_id"] == "entity/project" + assert body["digest"] == "abc123" + assert body["aliases"] == ["staging", "candidate"] + assert "object_id" not in body class TestObjRemoveAliases: - def test_sends_correct_request(self, server): - mock_resp = _mock_response() - with patch.object(server, "post", return_value=mock_resp) as mock_post: - server.obj_remove_aliases( - tsi.ObjRemoveAliasesReq( - project_id="entity/project", - object_id="my-obj", - aliases=["staging"], - ) + def test_sends_correct_request(self, server, transport): + server.obj_remove_aliases( + tsi.ObjRemoveAliasesReq( + project_id="entity/project", + object_id="my-obj", + aliases=["staging"], ) + ) - mock_post.assert_called_once() - call_url = mock_post.call_args[0][0] - assert call_url == "/objs/my-obj/aliases/remove" + request = transport.requests[0] + assert request.method == "POST" + assert str(request.url) == f"{BASE_URL}/objs/my-obj/aliases/remove" - sent_data = mock_post.call_args[1]["data"] - body = json.loads(sent_data) - assert body["project_id"] == "entity/project" - assert body["aliases"] == ["staging"] - assert "object_id" not in body + body = json.loads(request.content) + assert body["project_id"] == "entity/project" + assert body["aliases"] == ["staging"] + assert "object_id" not in body class TestTagsList: - def test_sends_get_to_correct_url(self, server): - mock_resp = _mock_response({"tags": ["prod", "staging"]}) - with patch.object(server, "get", return_value=mock_resp) as mock_get: - result = server.tags_list(tsi.TagsListReq(project_id="entity/project")) - - mock_get.assert_called_once() - call_url = mock_get.call_args[0][0] - assert call_url == "/tags" + def test_sends_get_with_project_id_param(self, server, transport): + transport.queue.append(httpx.Response(200, json={"tags": ["prod", "staging"]})) + result = server.tags_list(tsi.TagsListReq(project_id="entity/project")) - def test_passes_project_id_as_query_param(self, server): - mock_resp = _mock_response({"tags": ["prod"]}) - with patch.object(server, "get", return_value=mock_resp) as mock_get: - server.tags_list(tsi.TagsListReq(project_id="entity/project")) - - call_kwargs = mock_get.call_args[1] - assert call_kwargs["params"] == {"project_id": "entity/project"} - - def test_returns_parsed_response(self, server): - mock_resp = _mock_response({"tags": ["prod", "staging"]}) - with patch.object(server, "get", return_value=mock_resp): - result = server.tags_list(tsi.TagsListReq(project_id="entity/project")) - - assert isinstance(result, tsi.TagsListRes) - assert result.tags == ["prod", "staging"] + request = transport.requests[0] + assert request.method == "GET" + assert request.url.path == "/tags" + assert request.url.params["project_id"] == "entity/project" + assert isinstance(result, tsi.TagsListRes) + assert result.tags == ["prod", "staging"] class TestAliasesList: - def test_sends_get_to_correct_url(self, server): - mock_resp = _mock_response({"aliases": ["deploy", "staging"]}) - with patch.object(server, "get", return_value=mock_resp) as mock_get: - server.aliases_list(tsi.AliasesListReq(project_id="entity/project")) - - mock_get.assert_called_once() - call_url = mock_get.call_args[0][0] - assert call_url == "/aliases" - - def test_passes_project_id_as_query_param(self, server): - mock_resp = _mock_response({"aliases": ["deploy"]}) - with patch.object(server, "get", return_value=mock_resp) as mock_get: - server.aliases_list(tsi.AliasesListReq(project_id="entity/project")) - - call_kwargs = mock_get.call_args[1] - assert call_kwargs["params"] == {"project_id": "entity/project"} - - def test_returns_parsed_response(self, server): - mock_resp = _mock_response({"aliases": ["deploy", "staging"]}) - with patch.object(server, "get", return_value=mock_resp): - result = server.aliases_list( - tsi.AliasesListReq(project_id="entity/project") - ) - - assert isinstance(result, tsi.AliasesListRes) - assert result.aliases == ["deploy", "staging"] + def test_sends_get_with_project_id_param(self, server, transport): + transport.queue.append( + httpx.Response(200, json={"aliases": ["deploy", "staging"]}) + ) + result = server.aliases_list(tsi.AliasesListReq(project_id="entity/project")) + + request = transport.requests[0] + assert request.method == "GET" + assert request.url.path == "/aliases" + assert request.url.params["project_id"] == "entity/project" + assert isinstance(result, tsi.AliasesListRes) + assert result.aliases == ["deploy", "staging"] diff --git a/tests/trace_server_bindings/test_trace_server_bindings.py b/tests/trace_server_bindings/test_trace_server_bindings.py index b960cc3f4364..1ec2eeccc37a 100644 --- a/tests/trace_server_bindings/test_trace_server_bindings.py +++ b/tests/trace_server_bindings/test_trace_server_bindings.py @@ -1,8 +1,7 @@ -"""Tests for RemoteHTTPTraceServer and StainlessRemoteHTTPTraceServer bindings. +"""Tests for the StainlessRemoteHTTPTraceServer binding. -These tests verify the batching, splitting, and retry behavior of both server -implementations. The --remote-http-trace-server flag controls which implementation -is tested (default: "remote", or "stainless"). +These tests verify the batching, splitting, and retry behavior of the remote +trace server binding. """ from __future__ import annotations @@ -14,7 +13,6 @@ import httpx import pytest -import requests from tests.trace_server_bindings.conftest import ( generate_call_start_end_pair, @@ -27,9 +25,6 @@ EndBatchItem, StartBatchItem, ) -from weave.trace_server_bindings.remote_http_trace_server import ( - RemoteHTTPTraceServer, -) @pytest.mark.parametrize("server", ["small_limit"], indirect=True) @@ -214,12 +209,11 @@ def test_non_uniform_batch_items(server): @pytest.mark.disable_logging_error_check @pytest.mark.parametrize("server", ["normal"], indirect=True) @pytest.mark.parametrize("log_collector", ["warning"], indirect=True) -def test_drop_data_when_queue_is_full(server, server_class, log_collector): +def test_drop_data_when_queue_is_full(server, log_collector): """Test that items are dropped when the queue is full.""" - # For StainlessRemoteHTTPTraceServer, set _dropped_item_count to 0 - # so the next drop (1st) will log (logging happens at 1, 1001, 2001, etc.) - if server_class.__name__ == "StainlessRemoteHTTPTraceServer": - server.call_processor._dropped_item_count = 0 + # Set _dropped_item_count to 0 so the next drop (1st) will log + # (logging happens at 1, 1001, 2001, etc.) + server.call_processor._dropped_item_count = 0 # Replace the real queue with a mock that raises Full when put_nowait is called mock_queue = MagicMock() @@ -245,7 +239,7 @@ def test_drop_data_when_queue_is_full(server, server_class, log_collector): @pytest.mark.disable_logging_error_check @pytest.mark.parametrize("server", ["normal"], indirect=True) -def test_requeue_after_max_retries(server, server_class, caplog): +def test_requeue_after_max_retries(server, caplog): """Test that batches are requeued after max retries.""" caplog.set_level(logging.WARNING) @@ -255,15 +249,9 @@ def test_requeue_after_max_retries(server, server_class, caplog): # Mock enqueue to verify it gets called, and _send_batch_to_server to throw an exception server.call_processor.enqueue = MagicMock() - # Use the appropriate exception type for each implementation - if server_class == RemoteHTTPTraceServer: - server._send_batch_to_server = MagicMock( - side_effect=httpx.ConnectError("Connection error") - ) - else: - server._send_batch_to_server = MagicMock( - side_effect=requests.ConnectionError("Connection error") - ) + server._send_batch_to_server = MagicMock( + side_effect=httpx.ConnectError("Connection error") + ) # Create a batch start, end = generate_call_start_end_pair() diff --git a/weave/durability/wal_sender.py b/weave/durability/wal_sender.py index 93a00071ea26..53cab7ea4770 100644 --- a/weave/durability/wal_sender.py +++ b/weave/durability/wal_sender.py @@ -58,8 +58,8 @@ from weave.telemetry.trace_sentry import log_error from weave.trace_server import trace_server_interface as tsi from weave.trace_server_bindings.client_interface import TraceServerClientInterface -from weave.trace_server_bindings.remote_http_trace_server import ( - RemoteHTTPTraceServer, +from weave.trace_server_bindings.stainless_remote_http_trace_server import ( + StainlessRemoteHTTPTraceServer, ) logger = logging.getLogger(__name__) @@ -447,9 +447,9 @@ def main(argv: list[str] | None = None) -> None: logging.basicConfig(level=logging.INFO, format="%(levelname)s %(message)s") - # RemoteHTTPTraceServer is missing a few abstract methods that the - # WAL sender doesn't use. Suppress via Any intermediate. - remote_cls: Any = RemoteHTTPTraceServer + # StainlessRemoteHTTPTraceServer is missing a few abstract methods that + # the WAL sender doesn't use. Suppress via Any intermediate. + remote_cls: Any = StainlessRemoteHTTPTraceServer server: TraceServerClientInterface = remote_cls( args.trace_server_url, auth=("api", args.api_key) ) diff --git a/weave/trace/api.py b/weave/trace/api.py index 6c6fc7745668..18bd085ad9a9 100644 --- a/weave/trace/api.py +++ b/weave/trace/api.py @@ -100,10 +100,6 @@ def init( This includes connection time, data transfer, and server processing. Increase for slow networks or when working with large payloads. Default: `30.0` - - `use_stainless_server` (bool): Uses the Stainless-generated HTTP client which - provides better type safety, automatic retries, and improved error handling. This is - experimental and may become the default in future versions. - Default: `False` - `use_calls_complete` (bool): Uses an optimized write path that batches complete call data (start and end) into a single request instead of separate start/end requests. This reduces server load and improves performance, especially for short-lived ops. diff --git a/weave/trace/settings.py b/weave/trace/settings.py index 9889504f2019..ba385b11a349 100644 --- a/weave/trace/settings.py +++ b/weave/trace/settings.py @@ -243,15 +243,6 @@ class UserSettings: Can be overridden with the environment variable `WEAVE_HTTP_TIMEOUT` """ - use_stainless_server: bool = False - """ - Toggles use of the stainless-generated HTTP client for trace server communication. - - If True, uses StainlessRemoteHTTPTraceServer instead of RemoteHTTPTraceServer. - This provides better type safety and automatic client generation from OpenAPI specs. - Can be overridden with the environment variable `WEAVE_USE_STAINLESS_SERVER` - """ - use_calls_complete: bool = True """ Toggles use of the calls_complete write path for new calls. @@ -369,7 +360,6 @@ class _SettingsOverrides(TypedDict, total=False): enable_disk_fallback: bool use_parallel_table_upload: bool http_timeout: float - use_stainless_server: bool use_calls_complete: bool enable_client_side_digests: bool enable_wal: bool @@ -613,13 +603,6 @@ def http_timeout() -> float: return _env_or_default("http_timeout", _current_settings.get().http_timeout) -def should_use_stainless_server() -> bool: - """Returns whether the stainless-generated HTTP client should be used.""" - return _env_or_default( - "use_stainless_server", _current_settings.get().use_stainless_server - ) - - def should_use_calls_complete() -> bool: """Returns whether the calls_complete write path should be used.""" return _env_or_default( diff --git a/weave/trace/weave_client.py b/weave/trace/weave_client.py index ca1bd323b34e..27f218a27d49 100644 --- a/weave/trace/weave_client.py +++ b/weave/trace/weave_client.py @@ -417,7 +417,7 @@ def __init__( self._server_feedback_processor: AsyncBatchProcessor | None = None # This is a short-term hack to get around the fact that we are reaching into # the underlying implementation of the specific server to get the call processor. - # The `RemoteHTTPTraceServer` contains a call processor and we use that to control + # The `StainlessRemoteHTTPTraceServer` contains a call processor and we use that to control # some client-side flushing mechanics. We should move this to the interface layer. However, # we don't really want the server-side implementations to need to define no-ops as that is # even uglier. So we are using this "hasattr" check to avoid forcing the server-side implementations @@ -2166,7 +2166,7 @@ def _handle_digest_mismatch( """ # DigestMismatchError: raised directly by local servers (SQLite, ClickHouse). is_local_mismatch = isinstance(e, DigestMismatchError) - # HTTPError 409: raised by RemoteHTTPTraceServer when the remote + # HTTPError 409: raised by StainlessRemoteHTTPTraceServer when the remote # server returns HTTP 409 Conflict for a digest mismatch. is_remote_mismatch = ( isinstance(e, HTTPError) @@ -2489,11 +2489,8 @@ def test_func(req: TableCreateFromDigestsReq) -> Any: if hasattr(server, "_next_trace_server"): server = server._next_trace_server - assert hasattr(server, "_post_request_executor") - assert hasattr(server._post_request_executor, "__wrapped__") - return server._post_request_executor.__wrapped__( - server, "/table/create_from_digests", req - ) + assert hasattr(server, "unretried_table_create_from_digests") + return server.unretried_table_create_from_digests(req) use_parallel_chunks = check_endpoint_exists( test_func, test_req, "table_create_from_digests" @@ -3069,7 +3066,7 @@ def sanitize_object_name(name: str) -> str: def _get_call_processor(server: Any) -> Any: """Get the call processor from a server, traversing through middleware wrappers. - Most production clients (RemoteHTTPTraceServer, StainlessRemoteHTTPTraceServer) + Most production clients (StainlessRemoteHTTPTraceServer) use batching and have a call_processor. This traverses through middleware like CachingMiddlewareTraceServer to find it. diff --git a/weave/trace/weave_init.py b/weave/trace/weave_init.py index c718e58da197..c0e6577ea27a 100644 --- a/weave/trace/weave_init.py +++ b/weave/trace/weave_init.py @@ -16,7 +16,6 @@ from weave.trace.context import weave_client_context from weave.trace.settings import ( should_redact_pii, - should_use_stainless_server, use_server_cache, ) from weave.trace.urls import otel_traces_endpoint @@ -28,7 +27,6 @@ CachingMiddlewareTraceServer, ) from weave.trace_server_bindings.client_interface import TraceServerClientInterface -from weave.trace_server_bindings.remote_http_trace_server import RemoteHTTPTraceServer from weave.trace_server_version import MIN_TRACE_SERVER_VERSION from weave.wandb_interface.context import get_wandb_api_context @@ -329,15 +327,15 @@ def init_weave_get_server( api_key: str | None = None, should_batch: bool = True, ) -> TraceServerClientInterface: - res: TraceServerClientInterface - if should_use_stainless_server(): - from weave.trace_server_bindings.stainless_remote_http_trace_server import ( - StainlessRemoteHTTPTraceServer, - ) + # Imported lazily so `import weave` does not pay for weave_server_sdk's + # model definitions until a client is actually initialized. + from weave.trace_server_bindings.stainless_remote_http_trace_server import ( + StainlessRemoteHTTPTraceServer, + ) - res = StainlessRemoteHTTPTraceServer.from_env(should_batch) - else: - res = RemoteHTTPTraceServer.from_env(should_batch) + res: TraceServerClientInterface = StainlessRemoteHTTPTraceServer.from_env( + should_batch + ) if api_key is not None: res.set_auth(("api", api_key)) return res diff --git a/weave/trace_server_bindings/remote_http_trace_server.py b/weave/trace_server_bindings/remote_http_trace_server.py deleted file mode 100644 index b7c9bdade5fd..000000000000 --- a/weave/trace_server_bindings/remote_http_trace_server.py +++ /dev/null @@ -1,1799 +0,0 @@ -import datetime -import io -import logging -from collections.abc import Iterator -from typing import Any, cast -from zoneinfo import ZoneInfo - -import httpx -from pydantic import BaseModel, Field, validate_call -from pydantic.json_schema import SkipJsonSchema -from typing_extensions import Self - -from weave.trace.env import weave_trace_server_url -from weave.trace.settings import ( - max_calls_queue_size, - should_enable_disk_fallback, - should_use_calls_complete, -) -from weave.trace_server import http_service_interface as his -from weave.trace_server import trace_server_interface as tsi -from weave.trace_server.ids import generate_id -from weave.trace_server.service_interface import ServerInfoRes -from weave.trace_server_bindings.async_batch_processor import AsyncBatchProcessor -from weave.trace_server_bindings.call_batch_processor import CallBatchProcessor -from weave.trace_server_bindings.client_interface import TraceServerClientInterface -from weave.trace_server_bindings.http_utils import ( - REMOTE_REQUEST_BYTES_LIMIT, - CallsCompleteModeRequired, - handle_response_error, - log_dropped_call_batch, - log_dropped_feedback_batch, - process_batch_with_retry, -) -from weave.trace_server_bindings.models import ( - Batch, - CompleteBatchItem, - EndBatchItem, - EntityProjectInfo, - StartBatchItem, -) -from weave.utils import http_requests -from weave.utils.project_id import from_project_id -from weave.utils.retry import get_current_retry_id, with_retry -from weave.wandb_interface import project_creator - -logger = logging.getLogger(__name__) - -# Default timeout values (in seconds) -# DEFAULT_CONNECT_TIMEOUT = 10 -# DEFAULT_READ_TIMEOUT = 30 -# DEFAULT_TIMEOUT = (DEFAULT_CONNECT_TIMEOUT, DEFAULT_READ_TIMEOUT) - - -class RemoteHTTPTraceServer(TraceServerClientInterface): - trace_server_url: str - - # My current batching is not safe in notebooks, disable it for now - def __init__( - self, - trace_server_url: str, - should_batch: bool = False, - *, - remote_request_bytes_limit: int = REMOTE_REQUEST_BYTES_LIMIT, - auth: tuple[str, str] | None = None, - extra_headers: dict[str, str] | None = None, - ): - super().__init__() - self.trace_server_url = trace_server_url - self.should_batch = should_batch - self.use_calls_complete = should_use_calls_complete() and should_batch - self.call_processor: AsyncBatchProcessor | CallBatchProcessor | None = None - self.feedback_processor: AsyncBatchProcessor | None = None - if self.should_batch: - if self.use_calls_complete: - self.call_processor = CallBatchProcessor( - complete_processor_fn=self._flush_calls_complete, - eager_processor_fn=self._flush_calls_eager, - max_queue_size=max_calls_queue_size(), - enable_disk_fallback=should_enable_disk_fallback(), - ) - else: - self.call_processor = AsyncBatchProcessor( - self._flush_calls, - max_queue_size=max_calls_queue_size(), - enable_disk_fallback=should_enable_disk_fallback(), - ) - self.feedback_processor = AsyncBatchProcessor( - self._flush_feedback, - max_queue_size=max_calls_queue_size(), - enable_disk_fallback=should_enable_disk_fallback(), - ) - self._auth: tuple[str, str] | None = auth - self._extra_headers: dict[str, str] | None = extra_headers - self.remote_request_bytes_limit = remote_request_bytes_limit - - def ensure_project_exists( - self, entity: str, project: str - ) -> tsi.EnsureProjectExistsRes: - # TODO: This should happen in the wandb backend, not here, and it's slow - # (hundreds of ms) - return tsi.EnsureProjectExistsRes.model_validate( - project_creator.ensure_project_exists(entity, project) - ) - - @classmethod - def from_env(cls, should_batch: bool = False) -> Self: - return cls(weave_trace_server_url(), should_batch) - - def set_auth(self, auth: tuple[str, str]) -> None: - self._auth = auth - - def _build_dynamic_request_headers(self) -> dict[str, str]: - """Build headers for HTTP requests, including extra headers and retry ID.""" - headers = dict(self._extra_headers) if self._extra_headers else {} - if retry_id := get_current_retry_id(): - headers["X-Weave-Retry-Id"] = retry_id - return headers - - def get(self, url: str, *args: Any, **kwargs: Any) -> httpx.Response: - headers = self._build_dynamic_request_headers() - - return http_requests.get( - self.trace_server_url + url, - *args, - auth=self._auth, - headers=headers, - **kwargs, - ) - - def post(self, url: str, *args: Any, **kwargs: Any) -> httpx.Response: - headers = self._build_dynamic_request_headers() - - return http_requests.post( - self.trace_server_url + url, - *args, - auth=self._auth, - headers=headers, - **kwargs, - ) - - def put(self, url: str, *args: Any, **kwargs: Any) -> httpx.Response: - headers = self._build_dynamic_request_headers() - - return http_requests.put( - self.trace_server_url + url, - *args, - auth=self._auth, - headers={**headers, **kwargs.pop("headers", {})}, - **kwargs, - ) - - def delete(self, url: str, *args: Any, **kwargs: Any) -> httpx.Response: - headers = self._build_dynamic_request_headers() - - return http_requests.delete( - self.trace_server_url + url, - *args, - auth=self._auth, - headers=headers, - **kwargs, - ) - - @with_retry - def _send_batch_to_server(self, encoded_data: bytes) -> None: - """Send a batch of data to the server with retry logic. - - This method is separated from _flush_calls to avoid recursive retries. - """ - r = self.post( - "/call/upsert_batch", - data=encoded_data, # type: ignore - ) - handle_response_error(r, "/call/upsert_batch") - - def _flush_calls( - self, - batch: list[StartBatchItem | EndBatchItem], - *, - _should_update_batch_size: bool = True, - ) -> None: - """Process a batch of calls, splitting if necessary and sending to the server. - - This method handles the logic of splitting batches that are too large, - but delegates the actual server communication (with retries) to _send_batch_to_server. - """ - # Call processor must be defined for this method - assert self.call_processor is not None - if len(batch) == 0: - return - - def get_item_id(item: StartBatchItem | EndBatchItem) -> str: - if isinstance(item, StartBatchItem): - return f"{item.req.start.id}-start" - elif isinstance(item, EndBatchItem): - return f"{item.req.end.id}-end" - return "unknown" - - def encode_batch(batch: list[StartBatchItem | EndBatchItem]) -> bytes: - data = Batch(batch=batch).model_dump_json() - return data.encode("utf-8") - - try: - process_batch_with_retry( - batch_name="calls", - batch=batch, - remote_request_bytes_limit=self.remote_request_bytes_limit, - send_batch_fn=self._send_batch_to_server, - processor_obj=self.call_processor, - should_update_batch_size=_should_update_batch_size, - get_item_id_fn=get_item_id, - log_dropped_fn=log_dropped_call_batch, - encode_batch_fn=encode_batch, - ) - except CallsCompleteModeRequired as e: - # Project requires calls_complete mode - upgrade and re-enqueue the batch - self._upgrade_to_calls_complete(batch, str(e)) - - def _upgrade_to_calls_complete( - self, batch: list[StartBatchItem | EndBatchItem], error_message: str - ) -> None: - """Upgrade from legacy AsyncBatchProcessor to CallBatchProcessor. - - This is called when the server indicates a project requires calls_complete mode. - The upgrade happens transparently: we replace the processor and re-enqueue - the current batch items. No calls are dropped during this upgrade. - - Args: - batch: The batch of items that failed to send (will be re-enqueued). - error_message: The error message from the server (for logging). - """ - # Already upgraded? Just re-enqueue to the new processor - if self.use_calls_complete: - if isinstance(self.call_processor, CallBatchProcessor): - self.call_processor.enqueue( - cast(list[StartBatchItem | EndBatchItem | CompleteBatchItem], batch) - ) - return - - logger.warning( - "Project has been previously written to with `use_calls_complete=True` and requires 'calls_complete' mode. Automatically upgrading SDK to use the more performant calls_complete processor. Server message: %s", - error_message, - ) - - # Store old processor reference for cleanup - old_processor = self.call_processor - - # Create new CallBatchProcessor - self.use_calls_complete = True - self.call_processor = CallBatchProcessor( - complete_processor_fn=self._flush_calls_complete, - eager_processor_fn=self._flush_calls_eager, - max_queue_size=max_calls_queue_size(), - enable_disk_fallback=should_enable_disk_fallback(), - ) - - # Re-enqueue the batch items to the new processor - # Cast needed: list is invariant, but StartBatchItem | EndBatchItem is a valid subset of BatchItem - self.call_processor.enqueue( - cast(list[StartBatchItem | EndBatchItem | CompleteBatchItem], batch) - ) - - # Stop the old processor gracefully - any remaining items in its queue - # will be caught by _flush_calls which will re-enqueue them to the - # new processor via this same method (the "already upgraded" path above) - if old_processor is not None: - old_processor.stop_accepting_work_event.set() - - def _flush_calls_eager( - self, - batch: list[StartBatchItem | EndBatchItem], - *, - _should_update_batch_size: bool = True, - ) -> None: - """Process eager start/end items via v2 single endpoints. - - This is used for ops like Evaluation.evaluate that need their start - to be visible immediately in the UI. Uses single call/start and call/end - endpoints for easier rate limiting. - - Each item is sent individually with retry logic (@with_retry). If all retries - are exhausted, the item is logged and dropped, then processing continues - with remaining items in the batch. - """ - for item in batch: - try: - if isinstance(item, StartBatchItem): - self._send_call_start_v2(item.req.start) - elif isinstance(item, EndBatchItem): - self._send_call_end_v2(item.req.end) - except CallsCompleteModeRequired: - # Re-raise so caller can handle the upgrade to calls_complete mode - raise - except Exception as e: - log_dropped_call_batch([item], e) - - @with_retry - def _send_call_start_v2(self, start: tsi.StartedCallSchemaForInsert) -> None: - """Send a single call start to the v2 endpoint.""" - project_id = start.project_id - entity, project = project_id.split("/", 1) - url = f"/v2/{entity}/{project}/call/start" - req = tsi.CallStartV2Req(start=start) - r = self.post(url, data=req.model_dump_json().encode("utf-8")) - handle_response_error(r, url) - - @with_retry - def _send_call_end_v2(self, end: tsi.EndedCallSchemaForInsertWithStartedAt) -> None: - """Send a single call end to the v2 endpoint.""" - project_id = end.project_id - entity, project = project_id.split("/", 1) - url = f"/v2/{entity}/{project}/call/end" - req = tsi.CallEndV2Req(end=end) - r = self.post(url, data=req.model_dump_json().encode("utf-8")) - handle_response_error(r, url) - - def _extract_entity_project( - self, batch: list[CompleteBatchItem] - ) -> EntityProjectInfo: - """Extract entity, project, and project_id from first batch item.""" - if not batch: - raise ValueError("Cannot extract entity/project from empty batch") - - first_item = batch[0] - project_id = first_item.req.project_id - - if not project_id or "/" not in project_id: - raise ValueError( - f"Invalid project_id format: {project_id}. Expected 'entity/project'" - ) - - entity, project = project_id.split("/", 1) - if not entity or not project: - raise ValueError(f"Invalid project_id: {project_id}") - - return EntityProjectInfo(entity=entity, project=project, project_id=project_id) - - @with_retry - def _send_calls_complete_to_server( - self, entity: str, project: str, encoded_data: bytes - ) -> None: - """Send a batch of completed calls to the server with retry logic.""" - url = f"/v2/{entity}/{project}/calls/complete" - r = self.post(url, data=encoded_data) - handle_response_error(r, url) - - def _flush_calls_complete( - self, - batch: list[CompleteBatchItem], - *, - _should_update_batch_size: bool = True, - ) -> None: - """Process a batch of complete calls and send to the calls/upsert endpoint. - - This is the new calls_complete path. Complete calls have both start and - end information bundled together. - """ - assert self.call_processor is not None - if not batch: - return - - ep_info = self._extract_entity_project(batch) - - def get_item_id(item: CompleteBatchItem) -> str: - return f"{item.req.id}-complete" - - def encode_batch(batch: list[CompleteBatchItem]) -> bytes: - api_batch = [item.req for item in batch] - req = tsi.CallsUpsertCompleteReq(batch=api_batch) - return req.model_dump_json().encode("utf-8") - - process_batch_with_retry( - batch_name="calls_complete", - batch=batch, - remote_request_bytes_limit=self.remote_request_bytes_limit, - send_batch_fn=lambda data: self._send_calls_complete_to_server( - ep_info.entity, ep_info.project, data - ), - processor_obj=self.call_processor, - should_update_batch_size=_should_update_batch_size, - get_item_id_fn=get_item_id, - log_dropped_fn=log_dropped_call_batch, - encode_batch_fn=encode_batch, - ) - - def get_call_processor(self) -> AsyncBatchProcessor | CallBatchProcessor | None: - """Custom method not defined on the formal TraceServerInterface to expose - the underlying call processor. Should be formalized in a client-side interface. - """ - return self.call_processor - - def _send_feedback_batch_to_server(self, encoded_data: bytes) -> None: - """Send a batch of feedback data to the server with retry logic. - - This method is separated from _flush_feedback to avoid recursive retries. - """ - r = self.post( - "/feedback/batch/create", - data=encoded_data, # type: ignore - ) - handle_response_error(r, "/feedback/batch/create") - - def _flush_feedback( - self, - batch: list[tsi.FeedbackCreateReq], - ) -> None: - """Process a batch of feedback, splitting if necessary and sending to the server. - - This method handles the logic of splitting batches that are too large, - but delegates the actual server communication (with retries) to _send_feedback_batch_to_server. - """ - # Feedback processor must be defined for this method - assert self.feedback_processor is not None - if len(batch) == 0: - return - - def get_item_id(item: tsi.FeedbackCreateReq) -> str: - return f"{item.id}" - - def encode_batch(batch: list[tsi.FeedbackCreateReq]) -> bytes: - batch_req = tsi.FeedbackCreateBatchReq(batch=batch) - data = batch_req.model_dump_json() - return data.encode("utf-8") - - def send_feedback_batch(encoded_data: bytes) -> None: - try: - self._send_feedback_batch_to_server(encoded_data) - except (httpx.HTTPError, httpx.HTTPStatusError) as e: - # If batching endpoint doesn't exist (404) fall back to individual calls - if ( - response := getattr(e, "response", None) - ) and response.status_code == 404: - logger.debug( - "Batching endpoint not available, falling back to individual feedback creation: %s", - e, - ) - - # Feedback endpoint doesn't support id, created_at, so we need to strip them - class FeedbackCreateReqStripped(tsi.FeedbackCreateReq): - id: SkipJsonSchema[str] = Field(exclude=True) - created_at: SkipJsonSchema[datetime.datetime | None] = Field( - exclude=True, default=None - ) - - # Fall back to individual feedback creation calls - for item in batch: - item_copy = FeedbackCreateReqStripped(**item.model_dump()) - try: - self._generic_request( - "/feedback/create", - item_copy, - FeedbackCreateReqStripped, - tsi.FeedbackCreateRes, - ) - except Exception as individual_error: - logger.warning( - "Failed to create individual feedback: %s", - individual_error, - ) - else: - # Re-raise server errors (5xx) as they're not client compatibility issues - raise - - process_batch_with_retry( - batch_name="feedback", - batch=batch, - remote_request_bytes_limit=self.remote_request_bytes_limit, - send_batch_fn=send_feedback_batch, - processor_obj=self.feedback_processor, - should_update_batch_size=True, - get_item_id_fn=get_item_id, - log_dropped_fn=log_dropped_feedback_batch, - encode_batch_fn=encode_batch, - ) - - def get_feedback_processor(self) -> AsyncBatchProcessor | None: - """Custom method not defined on the formal TraceServerInterface to expose - the underlying feedback processor. Should be formalized in a client-side interface. - """ - return self.feedback_processor - - @with_retry - def _post_request_executor( - self, - url: str, - req: BaseModel, - stream: bool = False, - ) -> httpx.Response: - r = self.post( - url, - # `by_alias` is required since we have Mongo-style properties in the - # query models that are aliased to conform to start with `$`. Without - # this, the model_dump will use the internal property names which are - # not valid for the `model_validate` step. - data=req.model_dump_json(by_alias=True).encode("utf-8"), - stream=stream, - ) - handle_response_error(r, url) - return r - - @with_retry - def _get_request_executor( - self, - url: str, - params: dict[str, Any] | None = None, - stream: bool = False, - ) -> httpx.Response: - r = self.get(url, params=params or {}, stream=stream) - handle_response_error(r, url) - return r - - def _put_request_executor( - self, - url: str, - req: BaseModel, - stream: bool = False, - ) -> httpx.Response: - r = self.put( - url, - data=req.model_dump_json(by_alias=True).encode("utf-8"), - stream=stream, - ) - handle_response_error(r, url) - return r - - @with_retry - def _delete_request_executor( - self, - url: str, - params: dict[str, Any] | None = None, - stream: bool = False, - ) -> httpx.Response: - r = self.delete(url, params=params or {}, stream=stream) - handle_response_error(r, url) - return r - - def _generic_request( - self, - url: str, - req: BaseModel, - req_model: type[BaseModel], - res_model: type[BaseModel], - method: str = "POST", - params: dict[str, Any] | None = None, - ) -> BaseModel: - if method == "POST": - r = self._post_request_executor(url, req) - elif method == "PUT": - r = self._put_request_executor(url, req) - elif method == "GET": - r = self._get_request_executor(url, params) - elif method == "DELETE": - r = self._delete_request_executor(url, params) - else: - raise ValueError(f"Unsupported HTTP method: {method}") - - return res_model.model_validate(r.json()) - - def _generic_stream_request( - self, - url: str, - req: BaseModel, - req_model: type[BaseModel], - res_model: type[BaseModel], - method: str = "POST", - params: dict[str, Any] | None = None, - ) -> Iterator[BaseModel]: - if method == "POST": - r = self._post_request_executor(url, req, stream=True) - elif method == "GET": - r = self._get_request_executor(url, params, stream=True) - elif method == "DELETE": - r = self._delete_request_executor(url, params, stream=True) - else: - raise ValueError(f"Unsupported HTTP method: {method}") - - try: - for line in r.iter_lines(): - if line: - yield res_model.model_validate_json(line) - finally: - r.close() - - @with_retry - def server_info(self) -> ServerInfoRes: - r = self.get( - "/server_info", - ) - handle_response_error(r, "/server_info") - return ServerInfoRes.model_validate(r.json()) - - @validate_call - def projects_info(self, req: tsi.ProjectsInfoReq) -> list[tsi.ProjectsInfoRes]: - r = self._post_request_executor("/service/projects_info", req) - handle_response_error(r, "/service/projects_info") - return [tsi.ProjectsInfoRes.model_validate(item) for item in r.json()] - - def otel_export(self, req: tsi.OTelExportReq) -> tsi.OTelExportRes: - # TODO: Add docs link (DOCS-1390) - raise NotImplementedError("Sending otel traces directly is not yet supported.") - - # Call API - @validate_call - def call_start(self, req: tsi.CallStartReq) -> tsi.CallStartRes: - if self.should_batch: - assert self.call_processor is not None - - if req.start.id is None or req.start.trace_id is None: - raise ValueError( - "CallStartReq must have id and trace_id when batching." - ) - self.call_processor.enqueue_start(StartBatchItem(req=req)) - return tsi.CallStartRes(id=req.start.id, trace_id=req.start.trace_id) - return self._generic_request( - "/call/start", req, tsi.CallStartReq, tsi.CallStartRes - ) - - def call_start_batch(self, req: tsi.CallCreateBatchReq) -> tsi.CallCreateBatchRes: - return self._generic_request( - "/call/upsert_batch", req, tsi.CallCreateBatchReq, tsi.CallCreateBatchRes - ) - - @validate_call - def call_end(self, req: tsi.CallEndReq) -> tsi.CallEndRes: - if self.should_batch: - assert self.call_processor is not None - - self.call_processor.enqueue([EndBatchItem(req=req)]) - return tsi.CallEndRes() - return self._generic_request("/call/end", req, tsi.CallEndReq, tsi.CallEndRes) - - @validate_call - def call_read(self, req: tsi.CallReadReq) -> tsi.CallReadRes: - return self._generic_request( - "/call/read", req, tsi.CallReadReq, tsi.CallReadRes - ) - - @validate_call - def calls_query(self, req: tsi.CallsQueryReq) -> tsi.CallsQueryRes: - # This previously called the deprecated /calls/query endpoint. - return tsi.CallsQueryRes(calls=list(self.calls_query_stream(req))) - - @validate_call - def calls_query_stream(self, req: tsi.CallsQueryReq) -> Iterator[tsi.CallSchema]: - return self._generic_stream_request( - "/calls/stream_query", req, tsi.CallsQueryReq, tsi.CallSchema - ) - - @validate_call - def calls_query_stats(self, req: tsi.CallsQueryStatsReq) -> tsi.CallsQueryStatsRes: - return self._generic_request( - "/calls/query_stats", req, tsi.CallsQueryStatsReq, tsi.CallsQueryStatsRes - ) - - @validate_call - def trace_usage(self, req: tsi.TraceUsageReq) -> tsi.TraceUsageRes: - return self._generic_request( - "/trace/usage", req, tsi.TraceUsageReq, tsi.TraceUsageRes - ) - - @validate_call - def calls_usage(self, req: tsi.CallsUsageReq) -> tsi.CallsUsageRes: - return self._generic_request( - "/calls/usage", req, tsi.CallsUsageReq, tsi.CallsUsageRes - ) - - @validate_call - def calls_delete(self, req: tsi.CallsDeleteReq) -> tsi.CallsDeleteRes: - return self._generic_request( - "/calls/delete", req, tsi.CallsDeleteReq, tsi.CallsDeleteRes - ) - - @validate_call - def call_update(self, req: tsi.CallUpdateReq) -> tsi.CallUpdateRes: - return self._generic_request( - "/call/update", req, tsi.CallUpdateReq, tsi.CallUpdateRes - ) - - # Obj API - - @validate_call - def obj_create(self, req: tsi.ObjCreateReq) -> tsi.ObjCreateRes: - return self._generic_request( - "/obj/create", req, tsi.ObjCreateReq, tsi.ObjCreateRes - ) - - @validate_call - def obj_read(self, req: tsi.ObjReadReq) -> tsi.ObjReadRes: - return self._generic_request("/obj/read", req, tsi.ObjReadReq, tsi.ObjReadRes) - - @validate_call - def objs_query(self, req: tsi.ObjQueryReq) -> tsi.ObjQueryRes: - return self._generic_request( - "/objs/query", req, tsi.ObjQueryReq, tsi.ObjQueryRes - ) - - def obj_delete(self, req: tsi.ObjDeleteReq) -> tsi.ObjDeleteRes: - return self._generic_request( - "/obj/delete", req, tsi.ObjDeleteReq, tsi.ObjDeleteRes - ) - - def obj_add_tags(self, req: tsi.ObjAddTagsReq) -> tsi.ObjAddTagsRes: - body = his.ObjTagsBody(project_id=req.project_id, tags=req.tags) - return self._generic_request( - f"/objs/{req.object_id}/versions/{req.digest}/tags", - body, - his.ObjTagsBody, - tsi.ObjAddTagsRes, - method="PUT", - ) - - def obj_remove_tags(self, req: tsi.ObjRemoveTagsReq) -> tsi.ObjRemoveTagsRes: - body = his.ObjTagsBody(project_id=req.project_id, tags=req.tags) - return self._generic_request( - f"/objs/{req.object_id}/versions/{req.digest}/tags/remove", - body, - his.ObjTagsBody, - tsi.ObjRemoveTagsRes, - ) - - def obj_set_aliases(self, req: tsi.ObjSetAliasesReq) -> tsi.ObjSetAliasesRes: - body = his.ObjSetAliasesBody( - project_id=req.project_id, digest=req.digest, aliases=req.aliases - ) - return self._generic_request( - f"/objs/{req.object_id}/aliases", - body, - his.ObjSetAliasesBody, - tsi.ObjSetAliasesRes, - method="PUT", - ) - - def obj_remove_aliases( - self, req: tsi.ObjRemoveAliasesReq - ) -> tsi.ObjRemoveAliasesRes: - body = his.ObjRemoveAliasesBody(project_id=req.project_id, aliases=req.aliases) - return self._generic_request( - f"/objs/{req.object_id}/aliases/remove", - body, - his.ObjRemoveAliasesBody, - tsi.ObjRemoveAliasesRes, - ) - - def tags_list(self, req: tsi.TagsListReq) -> tsi.TagsListRes: - return self._generic_request( - "/tags", - req, - tsi.TagsListReq, - tsi.TagsListRes, - method="GET", - params={"project_id": req.project_id}, - ) - - def aliases_list(self, req: tsi.AliasesListReq) -> tsi.AliasesListRes: - return self._generic_request( - "/aliases", - req, - tsi.AliasesListReq, - tsi.AliasesListRes, - method="GET", - params={"project_id": req.project_id}, - ) - - @validate_call - def table_create(self, req: tsi.TableCreateReq) -> tsi.TableCreateRes: - return self._generic_request( - "/table/create", req, tsi.TableCreateReq, tsi.TableCreateRes - ) - - @validate_call - def table_update(self, req: tsi.TableUpdateReq) -> tsi.TableUpdateRes: - """Similar to `calls/batch_upsert`, we can dynamically adjust the payload size - due to the property that table updates can be decomposed into a series of - updates. - """ - estimated_bytes = len(req.model_dump_json(by_alias=True).encode("utf-8")) - if estimated_bytes > self.remote_request_bytes_limit and len(req.updates) > 1: - split_ndx = len(req.updates) // 2 - first_half_req = tsi.TableUpdateReq( - project_id=req.project_id, - base_digest=req.base_digest, - updates=req.updates[:split_ndx], - ) - first_half_res = self.table_update(first_half_req) - second_half_req = tsi.TableUpdateReq( - project_id=req.project_id, - base_digest=first_half_res.digest, - updates=req.updates[split_ndx:], - ) - second_half_res = self.table_update(second_half_req) - all_digests = ( - first_half_res.updated_row_digests + second_half_res.updated_row_digests - ) - return tsi.TableUpdateRes( - digest=second_half_res.digest, updated_row_digests=all_digests - ) - else: - return self._generic_request( - "/table/update", req, tsi.TableUpdateReq, tsi.TableUpdateRes - ) - - @validate_call - def table_query(self, req: tsi.TableQueryReq) -> tsi.TableQueryRes: - return self._generic_request( - "/table/query", req, tsi.TableQueryReq, tsi.TableQueryRes - ) - - @validate_call - def table_query_stream( - self, req: tsi.TableQueryReq - ) -> Iterator[tsi.TableRowSchema]: - # Need to manually iterate over this until the stream endpoint is built and shipped. - res = self.table_query(req) - yield from res.rows - - @validate_call - def table_query_stats(self, req: tsi.TableQueryStatsReq) -> tsi.TableQueryStatsRes: - return self._generic_request( - "/table/query_stats", req, tsi.TableQueryStatsReq, tsi.TableQueryStatsRes - ) - - @validate_call - def table_create_from_digests( - self, req: tsi.TableCreateFromDigestsReq - ) -> tsi.TableCreateFromDigestsRes: - """Create a table by specifying row digests instead of actual rows.""" - return self._generic_request( - "/table/create_from_digests", - req, - tsi.TableCreateFromDigestsReq, - tsi.TableCreateFromDigestsRes, - ) - - @validate_call - def table_query_stats_batch( - self, req: tsi.TableQueryStatsReq - ) -> tsi.TableQueryStatsRes: - return self._generic_request( - "/table/query_stats_batch", - req, - tsi.TableQueryStatsBatchReq, - tsi.TableQueryStatsBatchRes, - ) - - @validate_call - def refs_read_batch(self, req: tsi.RefsReadBatchReq) -> tsi.RefsReadBatchRes: - return self._generic_request( - "/refs/read_batch", req, tsi.RefsReadBatchReq, tsi.RefsReadBatchRes - ) - - @with_retry - def file_create(self, req: tsi.FileCreateReq) -> tsi.FileCreateRes: - data: dict[str, str] = {"project_id": req.project_id} - if req.expected_digest is not None: - data["expected_digest"] = req.expected_digest - r = self.post( - "/files/create", - data=data, - files={"file": (req.name, req.content)}, - ) - handle_response_error(r, "/files/create") - return tsi.FileCreateRes.model_validate(r.json()) - - @with_retry - def file_content_read(self, req: tsi.FileContentReadReq) -> tsi.FileContentReadRes: - r = self.post( - "/files/content", - json={"project_id": req.project_id, "digest": req.digest}, - ) - handle_response_error(r, "/files/content") - # TODO: Should stream to disk rather than to memory - bytes = io.BytesIO() - bytes.writelines(r.iter_bytes()) - bytes.seek(0) - return tsi.FileContentReadRes(content=bytes.read()) - - def files_stats(self, req: tsi.FilesStatsReq) -> tsi.FilesStatsRes: - return self._generic_request( - "/files/stats", req, tsi.FilesStatsReq, tsi.FilesStatsRes - ) - - @validate_call - def feedback_create(self, req: tsi.FeedbackCreateReq) -> tsi.FeedbackCreateRes: - if self.should_batch: - assert self.feedback_processor is not None - - feedback_id = req.id or generate_id() - req.id = feedback_id - - self.feedback_processor.enqueue([req]) - return tsi.FeedbackCreateRes( - id=feedback_id, - # technically incorrect, this can be off by a few seconds - created_at=datetime.datetime.now(ZoneInfo("UTC")), - wb_user_id=req.wb_user_id or "", - payload=req.payload, - ) - else: - return self._generic_request( - "/feedback/create", req, tsi.FeedbackCreateReq, tsi.FeedbackCreateRes - ) - - def feedback_create_batch( - self, req: tsi.FeedbackCreateBatchReq - ) -> tsi.FeedbackCreateBatchRes: - return self._generic_request( - "/feedback/batch/create", - req, - tsi.FeedbackCreateBatchReq, - tsi.FeedbackCreateBatchRes, - ) - - @validate_call - def feedback_query(self, req: tsi.FeedbackQueryReq) -> tsi.FeedbackQueryRes: - return self._generic_request( - "/feedback/query", req, tsi.FeedbackQueryReq, tsi.FeedbackQueryRes - ) - - @validate_call - def feedback_purge(self, req: tsi.FeedbackPurgeReq) -> tsi.FeedbackPurgeRes: - return self._generic_request( - "/feedback/purge", req, tsi.FeedbackPurgeReq, tsi.FeedbackPurgeRes - ) - - @validate_call - def feedback_replace(self, req: tsi.FeedbackReplaceReq) -> tsi.FeedbackReplaceRes: - return self._generic_request( - "/feedback/replace", req, tsi.FeedbackReplaceReq, tsi.FeedbackReplaceRes - ) - - @validate_call - def feedback_stats(self, req: tsi.FeedbackStatsReq) -> tsi.FeedbackStatsRes: - return self._generic_request( - "/feedback/stats", req, tsi.FeedbackStatsReq, tsi.FeedbackStatsRes - ) - - @validate_call - def feedback_aggregate( - self, req: tsi.FeedbackAggregateReq - ) -> tsi.FeedbackAggregateRes: - """Query the feedback table for aggregate scores over time.""" - return self._generic_request( - "/feedback/aggregate", - req, - tsi.FeedbackAggregateReq, - tsi.FeedbackAggregateRes, - ) - - @validate_call - def feedback_payload_schema( - self, req: tsi.FeedbackPayloadSchemaReq - ) -> tsi.FeedbackPayloadSchemaRes: - return self._generic_request( - "/feedback/payload_schema", - req, - tsi.FeedbackPayloadSchemaReq, - tsi.FeedbackPayloadSchemaRes, - ) - - # Cost API - @validate_call - def cost_query(self, req: tsi.CostQueryReq) -> tsi.CostQueryRes: - return self._generic_request( - "/cost/query", req, tsi.CostQueryReq, tsi.CostQueryRes - ) - - @validate_call - def cost_create(self, req: tsi.CostCreateReq) -> tsi.CostCreateRes: - return self._generic_request( - "/cost/create", req, tsi.CostCreateReq, tsi.CostCreateRes - ) - - @validate_call - def cost_purge(self, req: tsi.CostPurgeReq) -> tsi.CostPurgeRes: - return self._generic_request( - "/cost/purge", req, tsi.CostPurgeReq, tsi.CostPurgeRes - ) - - def completions_create( - self, req: tsi.CompletionsCreateReq - ) -> tsi.CompletionsCreateRes: - return self._generic_request( - "/completions/create", - req, - tsi.CompletionsCreateReq, - tsi.CompletionsCreateRes, - ) - - def completions_create_stream( - self, req: tsi.CompletionsCreateReq - ) -> Iterator[dict[str, Any]]: - # For remote servers, streaming is not implemented - # Fall back to non-streaming completion - response = self.completions_create(req) - yield {"response": response.response, "weave_call_id": response.weave_call_id} - - def image_create( - self, req: tsi.ImageGenerationCreateReq - ) -> tsi.ImageGenerationCreateRes: - return self._generic_request( - "/image/create", - req, - tsi.ImageGenerationCreateReq, - tsi.ImageGenerationCreateRes, - ) - - def project_stats(self, req: tsi.ProjectStatsReq) -> tsi.ProjectStatsRes: - return self._generic_request( - "/project/stats", req, tsi.ProjectStatsReq, tsi.ProjectStatsRes - ) - - def project_ttl_settings_read( - self, req: tsi.ProjectTTLSettingsReadReq - ) -> tsi.ProjectTTLSettingsReadRes: - return self._generic_request( - "/project/ttl_settings/read", - req, - tsi.ProjectTTLSettingsReadReq, - tsi.ProjectTTLSettingsReadRes, - ) - - def project_ttl_settings_update( - self, req: tsi.ProjectTTLSettingsUpdateReq - ) -> tsi.ProjectTTLSettingsUpdateRes: - return self._generic_request( - "/project/ttl_settings/update", - req, - tsi.ProjectTTLSettingsUpdateReq, - tsi.ProjectTTLSettingsUpdateRes, - ) - - def threads_query_stream( - self, req: tsi.ThreadsQueryReq - ) -> Iterator[tsi.ThreadSchema]: - return self._generic_stream_request( - "/threads/stream_query", req, tsi.ThreadsQueryReq, tsi.ThreadSchema - ) - - # Annotation Queue API - def annotation_queue_create( - self, req: tsi.AnnotationQueueCreateReq - ) -> tsi.AnnotationQueueCreateRes: - return self._generic_request( - "/annotation_queues", - req, - tsi.AnnotationQueueCreateReq, - tsi.AnnotationQueueCreateRes, - ) - - def annotation_queues_query_stream( - self, req: tsi.AnnotationQueuesQueryReq - ) -> Iterator[tsi.AnnotationQueueSchema]: - return self._generic_stream_request( - "/annotation_queues/query", - req, - tsi.AnnotationQueuesQueryReq, - tsi.AnnotationQueueSchema, - ) - - def annotation_queue_read( - self, req: tsi.AnnotationQueueReadReq - ) -> tsi.AnnotationQueueReadRes: - return self._generic_request( - f"/annotation_queues/{req.queue_id}", - req, - tsi.AnnotationQueueReadReq, - tsi.AnnotationQueueReadRes, - method="GET", - params={"project_id": req.project_id}, - ) - - def annotation_queue_delete( - self, req: tsi.AnnotationQueueDeleteReq - ) -> tsi.AnnotationQueueDeleteRes: - return self._generic_request( - f"/annotation_queues/{req.queue_id}", - req, - tsi.AnnotationQueueDeleteReq, - tsi.AnnotationQueueDeleteRes, - method="DELETE", - params={"project_id": req.project_id}, - ) - - def annotation_queue_update( - self, req: tsi.AnnotationQueueUpdateReq - ) -> tsi.AnnotationQueueUpdateRes: - # Convert to Body type to exclude queue_id from request body (it's in the URL path) - body = his.AnnotationQueueUpdateBody( - project_id=req.project_id, - name=req.name, - description=req.description, - scorer_refs=req.scorer_refs, - ) - return self._generic_request( - f"/annotation_queues/{req.queue_id}", - body, - his.AnnotationQueueUpdateBody, - tsi.AnnotationQueueUpdateRes, - method="PUT", - ) - - def annotation_queue_add_calls( - self, req: tsi.AnnotationQueueAddCallsReq - ) -> tsi.AnnotationQueueAddCallsRes: - # Convert to Body type to exclude queue_id from request body (it's in the URL path) - body = his.AnnotationQueueAddCallsBody( - project_id=req.project_id, - call_ids=req.call_ids, - display_fields=req.display_fields, - ) - return self._generic_request( - f"/annotation_queues/{req.queue_id}/items", - body, - his.AnnotationQueueAddCallsBody, - tsi.AnnotationQueueAddCallsRes, - ) - - def annotation_queue_items_query( - self, req: tsi.AnnotationQueueItemsQueryReq - ) -> tsi.AnnotationQueueItemsQueryRes: - # Convert to Body type to exclude queue_id from request body (it's in the URL path) - body = his.AnnotationQueueItemsQueryBody( - project_id=req.project_id, - filter=req.filter, - sort_by=req.sort_by, - limit=req.limit, - offset=req.offset, - include_position=req.include_position, - ) - return self._generic_request( - f"/annotation_queues/{req.queue_id}/items/query", - body, - his.AnnotationQueueItemsQueryBody, - tsi.AnnotationQueueItemsQueryRes, - ) - - def annotation_queues_stats( - self, req: tsi.AnnotationQueuesStatsReq - ) -> tsi.AnnotationQueuesStatsRes: - return self._generic_request( - "/annotation_queues/stats", - req, - tsi.AnnotationQueuesStatsReq, - tsi.AnnotationQueuesStatsRes, - ) - - def annotator_queue_items_progress_update( - self, req: tsi.AnnotatorQueueItemsProgressUpdateReq - ) -> tsi.AnnotatorQueueItemsProgressUpdateRes: - # Convert to Body type to exclude queue_id, item_id, and wb_user_id from request body - # (queue_id and item_id are in the URL path, wb_user_id is set server-side from auth) - body = his.AnnotationQueueItemProgressUpdateBody( - project_id=req.project_id, - annotation_state=req.annotation_state, - ) - return self._generic_request( - f"/annotation_queues/{req.queue_id}/items/{req.item_id}/progress", - body, - his.AnnotationQueueItemProgressUpdateBody, - tsi.AnnotatorQueueItemsProgressUpdateRes, - ) - - def evaluate_model(self, req: tsi.EvaluateModelReq) -> tsi.EvaluateModelRes: - raise NotImplementedError("evaluate_model is not implemented") - - def evaluation_status( - self, req: tsi.EvaluationStatusReq - ) -> tsi.EvaluationStatusRes: - raise NotImplementedError("evaluation_status is not implemented") - - def rescore(self, req: tsi.RescoreReq) -> tsi.RescoreRes: - raise NotImplementedError("rescore is not implemented") - - def calls_score(self, req: tsi.CallsScoreReq) -> tsi.CallsScoreRes: - raise NotImplementedError("calls_score is not implemented") - - # === V2 APIs === - - @validate_call - def op_create(self, req: tsi.OpCreateReq) -> tsi.OpCreateRes: - entity, project = from_project_id(req.project_id) - url = f"/v2/{entity}/{project}/ops" - # For create, we need to send the body without project_id (OpCreateBody) - body_data = req.model_dump(exclude={"project_id"}) - body = tsi.OpCreateBody.model_validate(body_data) - return self._generic_request( - url, - body, - tsi.OpCreateBody, - tsi.OpCreateRes, - method="POST", - ) - - @validate_call - def op_read(self, req: tsi.OpReadReq) -> tsi.OpReadRes: - entity, project = from_project_id(req.project_id) - url = f"/v2/{entity}/{project}/ops/{req.object_id}/versions/{req.digest}" - return self._generic_request( - url, - req, - tsi.OpReadReq, - tsi.OpReadRes, - method="GET", - ) - - @validate_call - def op_list(self, req: tsi.OpListReq) -> Iterator[tsi.OpReadRes]: - entity, project = from_project_id(req.project_id) - url = f"/v2/{entity}/{project}/ops" - # Build query params - params: dict[str, Any] = {} - if req.limit is not None: - params["limit"] = req.limit - if req.offset is not None: - params["offset"] = req.offset - if req.eager: - params["eager"] = "true" - return self._generic_stream_request( - url, - req, - tsi.OpListReq, - tsi.OpReadRes, - method="GET", - params=params, - ) - - @validate_call - def op_delete(self, req: tsi.OpDeleteReq) -> tsi.OpDeleteRes: - entity, project = from_project_id(req.project_id) - url = f"/v2/{entity}/{project}/ops/{req.object_id}" - # Build query params - params = {} - if req.digests: - params["digests"] = req.digests - return self._generic_request( - url, - req, - tsi.OpDeleteReq, - tsi.OpDeleteRes, - method="DELETE", - params=params, - ) - - @validate_call - def dataset_create(self, req: tsi.DatasetCreateReq) -> tsi.DatasetCreateRes: - entity, project = from_project_id(req.project_id) - url = f"/v2/{entity}/{project}/datasets" - # For create, we need to send the body without project_id (DatasetCreateBody) - body_data = req.model_dump(exclude={"project_id"}) - body = tsi.DatasetCreateBody.model_validate(body_data) - return self._generic_request( - url, - body, - tsi.DatasetCreateBody, - tsi.DatasetCreateRes, - method="POST", - ) - - @validate_call - def dataset_read(self, req: tsi.DatasetReadReq) -> tsi.DatasetReadRes: - entity, project = from_project_id(req.project_id) - url = f"/v2/{entity}/{project}/datasets/{req.object_id}/versions/{req.digest}" - return self._generic_request( - url, - req, - tsi.DatasetReadReq, - tsi.DatasetReadRes, - method="GET", - ) - - @validate_call - def dataset_list(self, req: tsi.DatasetListReq) -> Iterator[tsi.DatasetReadRes]: - entity, project = from_project_id(req.project_id) - url = f"/v2/{entity}/{project}/datasets" - # Build query params - params: dict[str, Any] = {} - if req.limit is not None: - params["limit"] = req.limit - if req.offset is not None: - params["offset"] = req.offset - return self._generic_stream_request( - url, - req, - tsi.DatasetListReq, - tsi.DatasetReadRes, - method="GET", - params=params, - ) - - @validate_call - def dataset_delete(self, req: tsi.DatasetDeleteReq) -> tsi.DatasetDeleteRes: - entity, project = from_project_id(req.project_id) - url = f"/v2/{entity}/{project}/datasets/{req.object_id}" - # Build query params - params = {} - if req.digests: - params["digests"] = req.digests - return self._generic_request( - url, - req, - tsi.DatasetDeleteReq, - tsi.DatasetDeleteRes, - method="DELETE", - params=params, - ) - - @validate_call - def scorer_create(self, req: tsi.ScorerCreateReq) -> tsi.ScorerCreateRes: - entity, project = from_project_id(req.project_id) - url = f"/v2/{entity}/{project}/scorers" - # For create, we need to send the body without project_id (ScorerCreateBody) - body_data = req.model_dump(exclude={"project_id"}) - body = tsi.ScorerCreateBody.model_validate(body_data) - return self._generic_request( - url, - body, - tsi.ScorerCreateBody, - tsi.ScorerCreateRes, - method="POST", - ) - - @validate_call - def scorer_read(self, req: tsi.ScorerReadReq) -> tsi.ScorerReadRes: - entity, project = from_project_id(req.project_id) - url = f"/v2/{entity}/{project}/scorers/{req.object_id}/versions/{req.digest}" - return self._generic_request( - url, - req, - tsi.ScorerReadReq, - tsi.ScorerReadRes, - method="GET", - ) - - @validate_call - def scorer_list(self, req: tsi.ScorerListReq) -> Iterator[tsi.ScorerReadRes]: - entity, project = from_project_id(req.project_id) - url = f"/v2/{entity}/{project}/scorers" - # Build query params - params = {} - if req.limit is not None: - params["limit"] = req.limit - if req.offset is not None: - params["offset"] = req.offset - return self._generic_stream_request( - url, - req, - tsi.ScorerListReq, - tsi.ScorerReadRes, - method="GET", - params=params, - ) - - @validate_call - def scorer_delete(self, req: tsi.ScorerDeleteReq) -> tsi.ScorerDeleteRes: - entity, project = from_project_id(req.project_id) - url = f"/v2/{entity}/{project}/scorers/{req.object_id}" - # Build query params - params = {} - if req.digests: - params["digests"] = req.digests - return self._generic_request( - url, - req, - tsi.ScorerDeleteReq, - tsi.ScorerDeleteRes, - method="DELETE", - params=params, - ) - - @validate_call - def evaluation_create( - self, req: tsi.EvaluationCreateReq - ) -> tsi.EvaluationCreateRes: - entity, project = from_project_id(req.project_id) - url = f"/v2/{entity}/{project}/evaluations" - # For create, we need to send the body without project_id (EvaluationCreateBody) - body_data = req.model_dump(exclude={"project_id"}) - body = tsi.EvaluationCreateBody.model_validate(body_data) - return self._generic_request( - url, - body, - tsi.EvaluationCreateBody, - tsi.EvaluationCreateRes, - method="POST", - ) - - @validate_call - def evaluation_read(self, req: tsi.EvaluationReadReq) -> tsi.EvaluationReadRes: - entity, project = from_project_id(req.project_id) - url = ( - f"/v2/{entity}/{project}/evaluations/{req.object_id}/versions/{req.digest}" - ) - return self._generic_request( - url, - req, - tsi.EvaluationReadReq, - tsi.EvaluationReadRes, - method="GET", - ) - - @validate_call - def evaluation_list( - self, req: tsi.EvaluationListReq - ) -> Iterator[tsi.EvaluationReadRes]: - entity, project = from_project_id(req.project_id) - url = f"/v2/{entity}/{project}/evaluations" - # Build query params - params = {} - if req.limit is not None: - params["limit"] = req.limit - if req.offset is not None: - params["offset"] = req.offset - return self._generic_stream_request( - url, - req, - tsi.EvaluationListReq, - tsi.EvaluationReadRes, - method="GET", - params=params, - ) - - @validate_call - def evaluation_delete( - self, req: tsi.EvaluationDeleteReq - ) -> tsi.EvaluationDeleteRes: - entity, project = from_project_id(req.project_id) - url = f"/v2/{entity}/{project}/evaluations/{req.object_id}" - # Build query params - params = {} - if req.digests: - params["digests"] = req.digests - return self._generic_request( - url, - req, - tsi.EvaluationDeleteReq, - tsi.EvaluationDeleteRes, - method="DELETE", - params=params, - ) - - # Model V2 API - - @validate_call - def model_create(self, req: tsi.ModelCreateReq) -> tsi.ModelCreateRes: - entity, project = from_project_id(req.project_id) - url = f"/v2/{entity}/{project}/models" - body = tsi.ModelCreateBody.model_validate( - req.model_dump(exclude={"project_id"}) - ) - return self._generic_request( - url, - body, - tsi.ModelCreateBody, - tsi.ModelCreateRes, - method="POST", - ) - - @validate_call - def model_read(self, req: tsi.ModelReadReq) -> tsi.ModelReadRes: - entity, project = from_project_id(req.project_id) - url = f"/v2/{entity}/{project}/models/{req.object_id}/versions/{req.digest}" - return self._generic_request( - url, - req, - tsi.ModelReadReq, - tsi.ModelReadRes, - method="GET", - ) - - @validate_call - def model_list(self, req: tsi.ModelListReq) -> Iterator[tsi.ModelReadRes]: - entity, project = from_project_id(req.project_id) - url = f"/v2/{entity}/{project}/models" - # Build query params - params = {} - if req.limit is not None: - params["limit"] = req.limit - if req.offset is not None: - params["offset"] = req.offset - return self._generic_stream_request( - url, - req, - tsi.ModelListReq, - tsi.ModelReadRes, - method="GET", - params=params, - ) - - @validate_call - def model_delete(self, req: tsi.ModelDeleteReq) -> tsi.ModelDeleteRes: - entity, project = from_project_id(req.project_id) - url = f"/v2/{entity}/{project}/models/{req.object_id}" - # Build query params - params = {} - if req.digests: - params["digests"] = req.digests - return self._generic_request( - url, - req, - tsi.ModelDeleteReq, - tsi.ModelDeleteRes, - method="DELETE", - params=params, - ) - - @validate_call - def evaluation_run_create( - self, req: tsi.EvaluationRunCreateReq - ) -> tsi.EvaluationRunCreateRes: - entity, project = from_project_id(req.project_id) - url = f"/v2/{entity}/{project}/evaluation_runs" - # For create, we need to send the body without project_id (EvaluationRunCreateBody) - body_data = req.model_dump(exclude={"project_id"}) - body = tsi.EvaluationRunCreateBody.model_validate(body_data) - return self._generic_request( - url, - body, - tsi.EvaluationRunCreateBody, - tsi.EvaluationRunCreateRes, - ) - - @validate_call - def evaluation_run_read( - self, req: tsi.EvaluationRunReadReq - ) -> tsi.EvaluationRunReadRes: - entity, project = from_project_id(req.project_id) - url = f"/v2/{entity}/{project}/evaluation_runs/{req.evaluation_run_id}" - return self._generic_request( - url, - req, - tsi.EvaluationRunReadReq, - tsi.EvaluationRunReadRes, - method="GET", - ) - - @validate_call - def evaluation_run_list( - self, req: tsi.EvaluationRunListReq - ) -> Iterator[tsi.EvaluationRunReadRes]: - entity, project = from_project_id(req.project_id) - url = f"/v2/{entity}/{project}/evaluation_runs" - # Build query params - params: dict[str, Any] = {} - if req.limit is not None: - params["limit"] = req.limit - if req.offset is not None: - params["offset"] = req.offset - if req.filter: - if req.filter.evaluations: - params["evaluation_refs"] = ",".join(req.filter.evaluations) - if req.filter.models: - params["model_refs"] = ",".join(req.filter.models) - if req.filter.evaluation_run_ids: - params["evaluation_run_ids"] = ",".join(req.filter.evaluation_run_ids) - return self._generic_stream_request( - url, - req, - tsi.EvaluationRunListReq, - tsi.EvaluationRunReadRes, - method="GET", - params=params, - ) - - @validate_call - def evaluation_run_delete( - self, req: tsi.EvaluationRunDeleteReq - ) -> tsi.EvaluationRunDeleteRes: - entity, project = from_project_id(req.project_id) - url = f"/v2/{entity}/{project}/evaluation_runs" - # Build query params - evaluation_run_ids are passed as a query param - params = {"evaluation_run_ids": req.evaluation_run_ids} - return self._generic_request( - url, - req, - tsi.EvaluationRunDeleteReq, - tsi.EvaluationRunDeleteRes, - method="DELETE", - params=params, - ) - - @validate_call - def evaluation_run_finish( - self, req: tsi.EvaluationRunFinishReq - ) -> tsi.EvaluationRunFinishRes: - entity, project = from_project_id(req.project_id) - url = f"/v2/{entity}/{project}/evaluation_runs/{req.evaluation_run_id}/finish" - return self._generic_request( - url, - req, - tsi.EvaluationRunFinishReq, - tsi.EvaluationRunFinishRes, - method="POST", - ) - - # Prediction V2 API - - @validate_call - def prediction_create( - self, req: tsi.PredictionCreateReq - ) -> tsi.PredictionCreateRes: - entity, project = from_project_id(req.project_id) - url = f"/v2/{entity}/{project}/predictions" - body = tsi.PredictionCreateBody.model_validate( - req.model_dump(exclude={"project_id"}) - ) - return self._generic_request( - url, - body, - tsi.PredictionCreateBody, - tsi.PredictionCreateRes, - method="POST", - ) - - @validate_call - def prediction_read(self, req: tsi.PredictionReadReq) -> tsi.PredictionReadRes: - entity, project = from_project_id(req.project_id) - url = f"/v2/{entity}/{project}/predictions/{req.prediction_id}" - return self._generic_request( - url, - req, - tsi.PredictionReadReq, - tsi.PredictionReadRes, - method="GET", - ) - - @validate_call - def prediction_list( - self, req: tsi.PredictionListReq - ) -> Iterator[tsi.PredictionReadRes]: - entity, project = from_project_id(req.project_id) - url = f"/v2/{entity}/{project}/predictions" - # Build query params - params: dict[str, Any] = {} - if req.evaluation_run_id is not None: - params["evaluation_run_id"] = req.evaluation_run_id - if req.limit is not None: - params["limit"] = req.limit - if req.offset is not None: - params["offset"] = req.offset - return self._generic_stream_request( - url, - req, - tsi.PredictionListReq, - tsi.PredictionReadRes, - method="GET", - params=params, - ) - - @validate_call - def prediction_delete( - self, req: tsi.PredictionDeleteReq - ) -> tsi.PredictionDeleteRes: - entity, project = from_project_id(req.project_id) - url = f"/v2/{entity}/{project}/predictions" - # Build query params - prediction_ids are passed as a query param - params = {"prediction_ids": req.prediction_ids} - return self._generic_request( - url, - req, - tsi.PredictionDeleteReq, - tsi.PredictionDeleteRes, - method="DELETE", - params=params, - ) - - @validate_call - def prediction_finish( - self, req: tsi.PredictionFinishReq - ) -> tsi.PredictionFinishRes: - entity, project = from_project_id(req.project_id) - url = f"/v2/{entity}/{project}/predictions/{req.prediction_id}/finish" - return self._generic_request( - url, - req, - tsi.PredictionFinishReq, - tsi.PredictionFinishRes, - method="POST", - ) - - # Score V2 API - - @validate_call - def score_create(self, req: tsi.ScoreCreateReq) -> tsi.ScoreCreateRes: - entity, project = from_project_id(req.project_id) - url = f"/v2/{entity}/{project}/scores" - body = tsi.ScoreCreateBody.model_validate( - req.model_dump(exclude={"project_id"}) - ) - return self._generic_request( - url, - body, - tsi.ScoreCreateBody, - tsi.ScoreCreateRes, - method="POST", - ) - - @validate_call - def score_read(self, req: tsi.ScoreReadReq) -> tsi.ScoreReadRes: - entity, project = from_project_id(req.project_id) - url = f"/v2/{entity}/{project}/scores/{req.score_id}" - return self._generic_request( - url, - req, - tsi.ScoreReadReq, - tsi.ScoreReadRes, - method="GET", - ) - - @validate_call - def score_list(self, req: tsi.ScoreListReq) -> Iterator[tsi.ScoreReadRes]: - entity, project = from_project_id(req.project_id) - url = f"/v2/{entity}/{project}/scores" - # Build query params - params: dict[str, Any] = {} - if req.evaluation_run_id is not None: - params["evaluation_run_id"] = req.evaluation_run_id - if req.limit is not None: - params["limit"] = req.limit - if req.offset is not None: - params["offset"] = req.offset - return self._generic_stream_request( - url, - req, - tsi.ScoreListReq, - tsi.ScoreReadRes, - method="GET", - params=params, - ) - - @validate_call - def score_delete(self, req: tsi.ScoreDeleteReq) -> tsi.ScoreDeleteRes: - entity, project = from_project_id(req.project_id) - url = f"/v2/{entity}/{project}/scores" - # Build query params - score_ids are passed as a query param - params = {"score_ids": req.score_ids} - return self._generic_request( - url, - req, - tsi.ScoreDeleteReq, - tsi.ScoreDeleteRes, - method="DELETE", - params=params, - ) - - # === Calls V2 API === - - def calls_complete( - self, req: tsi.CallsUpsertCompleteReq - ) -> tsi.CallsUpsertCompleteRes: - """Batch complete calls endpoint (v2). - - This endpoint is used when use_calls_complete is enabled to send - complete calls (with both start and end information) in batches. - """ - if not req.batch: - return tsi.CallsUpsertCompleteRes() - - # Extract entity/project from first item - first_item = req.batch[0] - entity, project = from_project_id(first_item.project_id) - - url = f"/v2/{entity}/{project}/calls/complete" - return self._generic_request( - url, - req, - tsi.CallsUpsertCompleteReq, - tsi.CallsUpsertCompleteRes, - ) - - def call_start_v2(self, req: tsi.CallStartV2Req) -> tsi.CallStartV2Res: - """Single call start endpoint (v2). - - This endpoint is used for eager ops that need their start visible immediately. - """ - project_id = req.start.project_id - entity, project = from_project_id(project_id) - - url = f"/v2/{entity}/{project}/call/start" - return self._generic_request( - url, - req, - tsi.CallStartV2Req, - tsi.CallStartV2Res, - ) - - def call_end_v2(self, req: tsi.CallEndV2Req) -> tsi.CallEndV2Res: - """Single call end endpoint (v2). - - This endpoint is used for eager ops that need their end sent separately. - """ - project_id = req.end.project_id - entity, project = from_project_id(project_id) - - url = f"/v2/{entity}/{project}/call/end" - return self._generic_request( - url, - req, - tsi.CallEndV2Req, - tsi.CallEndV2Res, - ) diff --git a/weave/trace_server_bindings/stainless_remote_http_trace_server.py b/weave/trace_server_bindings/stainless_remote_http_trace_server.py index a3075d90b285..46c4fbee219e 100644 --- a/weave/trace_server_bindings/stainless_remote_http_trace_server.py +++ b/weave/trace_server_bindings/stainless_remote_http_trace_server.py @@ -921,6 +921,21 @@ def table_create_from_digests( tsi.TableCreateFromDigestsRes, ) + def unretried_table_create_from_digests( + self, req: tsi.TableCreateFromDigestsReq + ) -> tsi.TableCreateFromDigestsRes: + """Single-attempt variant used by the client's endpoint-availability + probe: a missing endpoint (404) must fail fast, and a flaky probe must + not stall table saves behind retries. + """ + return self._via_sdk.__wrapped__( # type: ignore[attr-defined] + self, + req, + sdk_models.TableCreateFromDigestsReq, + self._sdk.tables.create_create_from_digests, + tsi.TableCreateFromDigestsRes, + ) + @validate_call def table_query_stats_batch( self, req: tsi.TableQueryStatsReq