Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions .github/workflows/test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,6 @@ jobs:
"pandas_test",
"fastmcp",
"smolagents",
"stainless",
"autogen_tests",
]
# Shards that don't support certain Python versions. Excluding here
Expand Down Expand Up @@ -578,7 +577,6 @@ jobs:
"pandas_test",
"fastmcp",
"smolagents",
"stainless",
"autogen_tests",
]
# Shards that don't support certain Python versions. Excluding here
Expand Down
6 changes: 0 additions & 6 deletions noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,6 @@ def lint(session: nox.Session):
"trace",
"trace_calls_complete_only",
"trace_no_server",
"stainless",
],
)
def tests(session: nox.Session, shard: str):
Expand Down Expand Up @@ -185,7 +184,6 @@ def tests(session: nox.Session, shard: str):
"trace_server": ["tests/trace_server/", "tests/shared/"],
"trace_server_bindings": ["tests/trace_server_bindings/"],
"trace_server_migrator": ["tests/trace_server_migrator/"],
"stainless": ["tests/trace_server_bindings/"],
"scorers": ["tests/scorers/"],
"autogen_tests": ["tests/integrations/autogen/"],
"verifiers_test": ["tests/integrations/verifiers/"],
Expand Down Expand Up @@ -248,10 +246,6 @@ def tests(session: nox.Session, shard: str):
if shard == "trace_calls_complete_only":
env["WEAVE_USE_CALLS_COMPLETE"] = "true"

# Set trace-server flag for stainless shard
if shard == "stainless":
pytest_args.extend(["--remote-http-trace-server=stainless"])

if shard == "verifiers_test":
# Pinning to this commit because the latest version of the gsm8k environment is broken.
session.install(GSM8K_ENVIRONMENT_PACKAGE)
Expand Down
58 changes: 38 additions & 20 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from typing import Any
from unittest.mock import MagicMock, patch

import httpx
import pytest
from fastapi import FastAPI
from fastapi.responses import StreamingResponse
Expand All @@ -20,13 +21,15 @@
from weave.trace.context.call_context import set_call_stack
from weave.trace.settings import replace_settings
from weave.trace_server import trace_server_interface as tsi
from weave.trace_server_bindings import remote_http_trace_server
from weave.trace_server_bindings import stainless_remote_http_trace_server
from weave.trace_server_bindings.async_batch_processor import AsyncBatchProcessor
from weave.trace_server_bindings.caching_middleware_trace_server import (
CachingMiddlewareTraceServer,
)
from weave.trace_server_bindings.call_batch_processor import CallBatchProcessor
from weave.trace_server_bindings.remote_http_trace_server import RemoteHTTPTraceServer
from weave.trace_server_bindings.stainless_remote_http_trace_server import (
StainlessRemoteHTTPTraceServer,
)

pytest_plugins = ["tests.trace_server.conftest"]

Expand Down Expand Up @@ -457,7 +460,7 @@ def create_client(
# Note: this is only for local dev testing and should be removed
return weave_init.init_weave("dev_testing")
elif trace_server_flag == "http":
server = RemoteHTTPTraceServer(trace_server_flag)
server = StainlessRemoteHTTPTraceServer(trace_server_flag)
else:
server = trace_server

Expand Down Expand Up @@ -597,11 +600,12 @@ def client(

@pytest.fixture
def network_proxy_client(client, monkeypatch):
"""This fixture is used to test the `RemoteHTTPTraceServer` class. There is
almost no logic in this class, other than a little batching, so we typically
skip it for simplicity. However, we can use this fixture to test such logic.
It initializes a mini FastAPI app that proxies requests from the
`RemoteHTTPTraceServer` to the underlying `client.server` object.
"""This fixture is used to test the `StainlessRemoteHTTPTraceServer` class.
There is almost no logic in this class, other than a little batching, so we
typically skip it for simplicity. However, we can use this fixture to test
such logic. It initializes a mini FastAPI app and routes the server's HTTP
transport into it, proxying requests to the underlying `client.server`
object.

We probably will want to flesh this out more in the future, but this is a
starting point.
Expand Down Expand Up @@ -707,12 +711,25 @@ def obj_read(req: tsi.ObjReadReq) -> tsi.ObjReadRes:

with TestClient(app) as c:

def post(url, data=None, json=None, **kwargs):
    """Forward a POST to the enclosing TestClient ``c``, ignoring ``stream``.

    Any ``stream`` keyword supplied by the caller is discarded before the
    call is delegated (presumably because the TestClient's ``post`` does not
    accept it -- confirm against the TestClient API).
    """
    forwarded = {key: value for key, value in kwargs.items() if key != "stream"}
    return c.post(url, data=data, json=json, **forwarded)

orig_post = weave.utils.http_requests.post
weave.utils.http_requests.post = post
class TestClientTransport(httpx.BaseTransport):
    """Routes the server's httpx requests into the FastAPI TestClient.

    Each outgoing request is replayed against ``c`` (the ``TestClient``
    opened by the enclosing ``with`` block) using only the URL path and
    query parameters, and the TestClient's answer is re-wrapped as a fresh
    ``httpx.Response`` for the caller.
    """

    # Request headers that are not copied onto the forwarded request.
    _DROPPED_REQUEST_HEADERS = frozenset({"host", "content-length"})

    # Response headers that describe the wire framing/encoding of the
    # TestClient response. ``resp.content`` is the already-decoded body, so
    # copying these onto the rebuilt response would make httpx try to decode
    # the body a second time or advertise a stale length.
    _DROPPED_RESPONSE_HEADERS = frozenset(
        {"content-encoding", "content-length", "transfer-encoding"}
    )

    def handle_request(self, request: httpx.Request) -> httpx.Response:
        """Replay ``request`` against the TestClient and return its response.

        Args:
            request: The outgoing request produced by the trace server client.

        Returns:
            A new ``httpx.Response`` carrying the TestClient's status code,
            body, and headers (minus the framing/encoding headers above).
        """
        # Load the request body so ``request.content`` can be forwarded.
        request.read()
        forwarded_headers = {
            k: v
            for k, v in request.headers.items()
            if k.lower() not in self._DROPPED_REQUEST_HEADERS
        }
        resp = c.request(
            request.method,
            request.url.path,
            params=request.url.params,
            content=request.content,
            headers=forwarded_headers,
        )
        # NOTE(review): assumes the TestClient returns an httpx.Response
        # (true for httpx-based Starlette TestClients). multi_items() keeps
        # repeated headers such as several Set-Cookie values as separate
        # entries instead of merging them.
        response_headers = [
            (k, v)
            for k, v in resp.headers.multi_items()
            if k.lower() not in self._DROPPED_RESPONSE_HEADERS
        ]
        return httpx.Response(
            resp.status_code, headers=response_headers, content=resp.content
        )

def make_fast_async_batch_processor(*args, **kwargs):
kwargs.setdefault("min_batch_interval", 0)
Expand All @@ -723,24 +740,25 @@ def make_fast_call_batch_processor(*args, **kwargs):
return CallBatchProcessor(*args, **kwargs)

monkeypatch.setattr(
remote_http_trace_server,
stainless_remote_http_trace_server,
"AsyncBatchProcessor",
make_fast_async_batch_processor,
)
monkeypatch.setattr(
remote_http_trace_server,
stainless_remote_http_trace_server,
"CallBatchProcessor",
make_fast_call_batch_processor,
)

remote_client = RemoteHTTPTraceServer(
trace_server_url="",
# Absolute base URL required for httpx cookie handling; the transport
# routes by path, so the host is never dialed.
remote_client = StainlessRemoteHTTPTraceServer(
trace_server_url="http://testserver",
should_batch=True,
transport=TestClientTransport(),
)
yield (client, remote_client, records)

weave.utils.http_requests.post = orig_post


@pytest.fixture(autouse=True)
def caching_client_isolation(monkeypatch, tmp_path):
Expand Down
6 changes: 4 additions & 2 deletions tests/trace/test_wave_client_file_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@
WeaveClientSendFileCache,
)
from weave.trace_server.trace_server_interface import FileCreateReq, FileCreateRes
from weave.trace_server_bindings.remote_http_trace_server import RemoteHTTPTraceServer
from weave.trace_server_bindings.stainless_remote_http_trace_server import (
StainlessRemoteHTTPTraceServer,
)


class TestThreadSafeLRUCache:
Expand Down Expand Up @@ -402,7 +404,7 @@ def offline_client(monkeypatch):
monkeypatch.setenv("WEAVE_RETRY_MAX_ATTEMPTS", "2")
monkeypatch.setenv("WEAVE_RETRY_MAX_INTERVAL", "0.01")
monkeypatch.setenv("WEAVE_ENABLE_WAL", "false")
server = RemoteHTTPTraceServer("http://example.com")
server = StainlessRemoteHTTPTraceServer("http://example.com")
client = WeaveClient(
entity="ent",
project="proj",
Expand Down
6 changes: 3 additions & 3 deletions tests/trace/test_weave_init_availability.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@
from unittest.mock import MagicMock

from weave.trace import weave_init
from weave.trace_server_bindings import remote_http_trace_server
from weave.trace_server_bindings import stainless_remote_http_trace_server


def test_get_server_info_json_decode_error():
"""Test that _get_server_info returns None when server info cannot be decoded."""
mock_server = MagicMock(spec=remote_http_trace_server.RemoteHTTPTraceServer)
mock_server = MagicMock(spec=stainless_remote_http_trace_server.StainlessRemoteHTTPTraceServer)
mock_server.server_info.side_effect = json.JSONDecodeError("test error", "doc", 0)

result = weave_init._get_server_info(mock_server)
Expand All @@ -20,7 +20,7 @@ def test_get_server_info_json_decode_error():

def test_get_server_info_success():
"""Test that _get_server_info returns server info when server is available."""
mock_server = MagicMock(spec=remote_http_trace_server.RemoteHTTPTraceServer)
mock_server = MagicMock(spec=stainless_remote_http_trace_server.StainlessRemoteHTTPTraceServer)
server_info = {"version": "1.0.0"}
mock_server.server_info.return_value = server_info

Expand Down
16 changes: 0 additions & 16 deletions tests/trace_server/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,12 +68,6 @@ def pytest_addoption(parser):
default="false",
help="Use a clickhouse process instead of a container",
)
parser.addoption(
"--remote-http-trace-server",
action="store",
default="remote",
help="Specify the remote HTTP trace server implementation: remote or stainless",
)
except ValueError:
pass

Expand All @@ -82,7 +76,6 @@ def pytest_collection_modifyitems(config, items):
# Add the trace_server marker to:
# 1. All tests in the trace_server directory (regardless of fixture usage)
# 2. All tests that use the trace_server fixture (for tests outside this directory)
# Note: Filtering based on remote-http-trace-server flag is handled in tests/trace_server_bindings/conftest.py
for item in items:
# Check if the test is in the trace_server directory by checking parent directories
if "trace_server" in item.path.parts:
Expand Down Expand Up @@ -110,15 +103,6 @@ def get_trace_server_flag(request):
return weave_server_flag


def get_remote_http_trace_server_flag(request):
    """Look up which remote HTTP trace server implementation was requested.

    Reads the ``--remote-http-trace-server`` command-line option from the
    pytest config attached to ``request``.

    Returns:
        str: 'remote' for RemoteHTTPTraceServer or 'stainless' for
        StainlessRemoteHTTPTraceServer.
    """
    option_name = "--remote-http-trace-server"
    pytest_config = request.config
    return pytest_config.getoption(option_name)


@pytest.fixture(autouse=True)
def reset_project_version_cache():
project_version.reset_project_residence_cache()
Expand Down
85 changes: 45 additions & 40 deletions tests/trace_server_bindings/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,14 @@
from types import MethodType
from unittest.mock import MagicMock

import httpx
import pytest
import tenacity

from weave.trace_server import trace_server_interface as tsi
from weave.trace_server.ids import generate_id
from weave.trace_server_bindings.remote_http_trace_server import (
RemoteHTTPTraceServer,
from weave.trace_server_bindings.stainless_remote_http_trace_server import (
StainlessRemoteHTTPTraceServer,
)

# =============================================================================
Expand Down Expand Up @@ -60,35 +61,58 @@ def generate_call_start_end_pair(


# =============================================================================
# Fixtures
# HTTP transport spy
# =============================================================================


@pytest.fixture
def success_response():
    """Provide a mocked successful HTTP response.

    The mock reports status code 200 and its ``json()`` call returns a
    payload with ``id`` and ``trace_id`` keys.
    """
    payload = {"id": "test_id", "trace_id": "test_trace_id"}
    # Dotted keyword names configure attributes of child mocks.
    return MagicMock(status_code=200, **{"json.return_value": payload})
class SpyTransport(httpx.BaseTransport):
"""httpx transport that records requests and replays queued responses.

Queue items may be ``httpx.Response`` objects or exceptions to raise.
When the queue is empty, returns ``default_response`` (200 ``{}`` unless
overridden).
"""

@pytest.fixture
def server_class(request):
"""Returns the appropriate server class based on --remote-http-trace-server flag."""
flag = request.config.getoption("--remote-http-trace-server", default="remote")
if flag == "stainless":
from weave.trace_server_bindings.stainless_remote_http_trace_server import (
StainlessRemoteHTTPTraceServer,
)
def __init__(
    self,
    *items: httpx.Response | Exception,
    default_response: httpx.Response | None = None,
) -> None:
    """Set up the replay queue and the request log.

    Args:
        *items: Responses to return, or exceptions to raise, in order.
        default_response: Response returned once the queue is exhausted;
            ``None`` selects the built-in empty 200 response.
    """
    self.default_response = default_response
    self.requests: list[httpx.Request] = []
    self.queue: list[httpx.Response | Exception] = [*items]

def handle_request(self, request: httpx.Request) -> httpx.Response:
    """Record ``request`` and answer it from the queue or the default.

    Queued exceptions are raised instead of returned. When the queue is
    empty, ``default_response`` is returned if set, otherwise a 200
    response with an empty JSON object.
    """
    # Load the body before recording the request.
    request.read()
    self.requests.append(request)

    if not self.queue:
        fallback = self.default_response
        if fallback is None:
            return httpx.Response(200, json={})
        return fallback

    queued = self.queue.pop(0)
    if isinstance(queued, Exception):
        raise queued
    return queued

@property
def urls(self) -> list[str]:
    """URLs of all recorded requests, as strings, in arrival order."""
    recorded_urls = (recorded.url for recorded in self.requests)
    return list(map(str, recorded_urls))

return StainlessRemoteHTTPTraceServer
return RemoteHTTPTraceServer

# =============================================================================
# Fixtures
# =============================================================================


@pytest.fixture
def server_class():
    """Provide the remote trace server class exercised by these tests.

    The class itself is returned, not an instance.
    """
    implementation = StainlessRemoteHTTPTraceServer
    return implementation


@pytest.fixture
def server(request, server_class):
"""Common server fixture that uses server_class based on the CLI flag."""
"""Common server fixture parametrized by batching/retry behavior."""
server_ = server_class("http://example.com", should_batch=True)

if request.param == "normal":
Expand Down Expand Up @@ -116,25 +140,6 @@ def server(request, server_class):
server_.feedback_processor.stop_accepting_new_work_and_flush_queue()


def pytest_ignore_collect(collection_path, config):
    """Skip test files that belong to the non-selected server implementation.

    Only paths containing a ``trace_server_bindings`` component are
    considered. Within those, the ``--remote-http-trace-server`` flag
    (default ``remote``) decides which filename suffix is excluded:
    ``_stainless.py`` files for ``remote`` and ``_remote.py`` files for
    ``stainless``.

    Returns:
        ``True`` to ignore the path, ``None`` to make no decision.
    """
    if "trace_server_bindings" not in collection_path.parts:
        return None

    flag = config.getoption("--remote-http-trace-server", default="remote")
    # Each flag value maps to the suffix of the *other* implementation's files.
    suffix_to_skip = {
        "remote": "_stainless.py",
        "stainless": "_remote.py",
    }.get(flag)

    if suffix_to_skip is not None and collection_path.name.endswith(suffix_to_skip):
        return True
    return None


def pytest_collection_modifyitems(config, items):
"""Add trace_server marker to all tests in this directory."""
for item in items:
Expand Down
Loading
Loading