Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions tests/trace/test_saved_view.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import datetime

import pytest
from weave_server_sdk import models as tsi
from weave_server_sdk.models import Query

import weave
from weave.flow.saved_view import (
Expand Down Expand Up @@ -35,12 +35,12 @@ def test_query_to_filters_none():


def test_query_to_filters_one_filter():
query = tsi.Query(**{"$expr": {"$eq": [{"$getField": "rank"}, {"$literal": 1}]}})
query = Query(**{"$expr": {"$eq": [{"$getField": "rank"}, {"$literal": 1}]}})
assert query_to_filters(query) == [
Filter(field="rank", operator="(number): =", value=1)
]

query = tsi.Query(
query = Query(
**{
"$expr": {
"$gt": [
Expand All @@ -54,7 +54,7 @@ def test_query_to_filters_one_filter():
Filter(field="completion_token_cost", operator="(number): >", value=25)
]

query = tsi.Query(
query = Query(
**{
"$expr": {
"$eq": [{"$getField": "inputs.model"}, {"$literal": "gpt-4o-mini"}]
Expand All @@ -67,7 +67,7 @@ def test_query_to_filters_one_filter():


def test_query_to_filters_multiple_filters():
query = tsi.Query(
query = Query(
**{
"$expr": {
"$and": [
Expand Down Expand Up @@ -149,7 +149,7 @@ def test_filter_manipulation():
view = weave.SavedView("traces", "My saved view")

view.add_filter("inputs.model", "equals", "gpt-3.5-turbo")
assert view.base.definition.query == tsi.Query(
assert view.base.definition.query == Query(
**{
"$expr": {
"$and": [
Expand Down
34 changes: 21 additions & 13 deletions tests/trace/test_wal_client_writes.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,14 @@

import pytest
from PIL import Image
from weave_server_sdk.models import (
CallStartReq,
ObjCreateReq,
ObjSchemaForInsert,
StartedCallSchemaForInsert,
TableCreateReq,
TableSchemaForInsert,
)

import weave
from weave.durability.wal_client_id import compute_client_id
Expand All @@ -32,7 +40,7 @@
from weave.durability.wal_writer import JSONLWALWriter
from weave.trace import weave_client
from weave.trace.settings import UserSettings, override_settings
from weave.trace_server_bindings import models as tsi
from weave.trace_server_bindings.models import FileCreateReq


def _read_all_wal_records(client: weave.WeaveClient) -> list[dict]:
Expand Down Expand Up @@ -370,8 +378,8 @@ def test_write_only_manager(self, tmp_path):

mgr.write(
"obj_create",
tsi.ObjCreateReq(
obj=tsi.ObjSchemaForInsert(
ObjCreateReq(
obj=ObjSchemaForInsert(
project_id="test-entity/test-project",
object_id="test_obj",
val={"hello": "world"},
Expand Down Expand Up @@ -438,8 +446,8 @@ def test_handler_calls_correct_server_method(self):
mock_server = MagicMock()
handlers = TraceServerHandlers(mock_server).as_dict()

req = tsi.ObjCreateReq(
obj=tsi.ObjSchemaForInsert(
req = ObjCreateReq(
obj=ObjSchemaForInsert(
project_id="e/p",
object_id="test",
val={"x": 1},
Expand All @@ -450,16 +458,16 @@ def test_handler_calls_correct_server_method(self):

mock_server.obj_create.assert_called_once()
call_arg = mock_server.obj_create.call_args[0][0]
assert isinstance(call_arg, tsi.ObjCreateReq)
assert isinstance(call_arg, ObjCreateReq)
assert call_arg.obj.object_id == "test"

def test_handler_calls_table_create(self):
"""table_create handler should call server.table_create."""
mock_server = MagicMock()
handlers = TraceServerHandlers(mock_server).as_dict()

req = tsi.TableCreateReq(
table=tsi.TableSchemaForInsert(
req = TableCreateReq(
table=TableSchemaForInsert(
project_id="e/p",
rows=[{"val": {"x": 1}, "digest": "abc123"}],
)
Expand All @@ -486,7 +494,7 @@ def test_file_create_roundtrips_bytes_content(self):
# WAL today (model_dump(mode="json") rejects non-utf8 bytes at write
# time, so non-utf8 content never appears in a replayed record).
original_content = "small file body — including non-ascii: café 🎉".encode()
req = tsi.FileCreateReq(
req = FileCreateReq(
project_id="e/p",
name="note.txt",
content=original_content,
Expand All @@ -497,7 +505,7 @@ def test_file_create_roundtrips_bytes_content(self):

mock_server.file_create.assert_called_once()
call_arg = mock_server.file_create.call_args[0][0]
assert isinstance(call_arg, tsi.FileCreateReq)
assert isinstance(call_arg, FileCreateReq)
assert call_arg.content == original_content

def test_call_start_roundtrips_through_handler(self):
Expand All @@ -511,8 +519,8 @@ def test_call_start_roundtrips_through_handler(self):
handlers = TraceServerHandlers(mock_server).as_dict()

started = datetime.datetime.now(datetime.timezone.utc)
req = tsi.CallStartReq(
start=tsi.StartedCallSchemaForInsert(
req = CallStartReq(
start=StartedCallSchemaForInsert(
project_id="e/p",
op_name="predict",
trace_id="t1",
Expand All @@ -526,7 +534,7 @@ def test_call_start_roundtrips_through_handler(self):

mock_server.call_start.assert_called_once()
call_arg = mock_server.call_start.call_args[0][0]
assert isinstance(call_arg, tsi.CallStartReq)
assert isinstance(call_arg, CallStartReq)
assert call_arg.start.op_name == "predict"
assert call_arg.start.started_at == started
assert call_arg.start.attributes == {"k": "v"}
Expand Down
19 changes: 12 additions & 7 deletions tests/trace_server_bindings/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,12 @@
import httpx
import pytest
import tenacity
from weave_server_sdk import models as tsi
from weave_server_sdk.models import (
CallEndReq,
CallStartReq,
EndedCallSchemaForInsert,
StartedCallSchemaForInsert,
)

from weave.trace_server.ids import generate_id
from weave.trace_server_bindings.remote_http_trace_server import (
Expand All @@ -20,9 +25,9 @@
def generate_start(
id: str | None = None,
project_id: str = "test",
) -> tsi.StartedCallSchemaForInsert:
) -> StartedCallSchemaForInsert:
"""Generate a test StartedCallSchemaForInsert."""
return tsi.StartedCallSchemaForInsert(
return StartedCallSchemaForInsert(
project_id=project_id,
id=id or generate_id(),
op_name="test_name",
Expand All @@ -37,9 +42,9 @@ def generate_start(
def generate_end(
id: str | None = None,
project_id: str = "test",
) -> tsi.EndedCallSchemaForInsert:
) -> EndedCallSchemaForInsert:
"""Generate a test EndedCallSchemaForInsert."""
return tsi.EndedCallSchemaForInsert(
return EndedCallSchemaForInsert(
project_id=project_id,
id=id or generate_id(),
ended_at=datetime.datetime.now(tz=datetime.timezone.utc)
Expand All @@ -53,11 +58,11 @@ def generate_end(
def generate_call_start_end_pair(
id: str | None = None,
project_id: str = "test",
) -> tuple[tsi.CallStartReq, tsi.CallEndReq]:
) -> tuple[CallStartReq, CallEndReq]:
"""Generate a matching pair of CallStartReq and CallEndReq for testing."""
start = generate_start(id, project_id)
end = generate_end(id, project_id)
return tsi.CallStartReq(start=start), tsi.CallEndReq(end=end)
return CallStartReq(start=start), CallEndReq(end=end)


# =============================================================================
Expand Down
22 changes: 14 additions & 8 deletions tests/trace_server_bindings/test_call_batch_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,17 @@

import httpx
import pytest
from weave_server_sdk.models import (
CallEndReq,
CallStartReq,
EndedCallSchemaForInsert,
StartedCallSchemaForInsert,
)

from weave.trace_server_bindings import models as tsi
from weave.trace_server_bindings.call_batch_processor import CallBatchProcessor
from weave.trace_server_bindings.models import (
CompleteBatchItem,
CompletedCallSchemaForInsert,
EndBatchItem,
StartBatchItem,
)
Expand All @@ -35,7 +41,7 @@ def _make_start_item(
'call-1'
"""
started_at = datetime.datetime.now(datetime.timezone.utc)
start = tsi.StartedCallSchemaForInsert(
start = StartedCallSchemaForInsert(
project_id=project_id,
id=call_id,
trace_id=trace_id,
Expand All @@ -44,7 +50,7 @@ def _make_start_item(
attributes={},
inputs={},
)
return StartBatchItem(req=tsi.CallStartReq(start=start))
return StartBatchItem(req=CallStartReq(start=start))


def _make_end_item(call_id: str, *, project_id: str = "proj") -> EndBatchItem:
Expand All @@ -63,13 +69,13 @@ def _make_end_item(call_id: str, *, project_id: str = "proj") -> EndBatchItem:
'call-1'
"""
ended_at = datetime.datetime.now(datetime.timezone.utc)
end = tsi.EndedCallSchemaForInsert(
end = EndedCallSchemaForInsert(
project_id=project_id,
id=call_id,
ended_at=ended_at,
summary={},
)
return EndBatchItem(req=tsi.CallEndReq(end=end))
return EndBatchItem(req=CallEndReq(end=end))


def _make_complete_item(
Expand All @@ -92,7 +98,7 @@ def _make_complete_item(
"""
started_at = datetime.datetime.now(datetime.timezone.utc)
ended_at = started_at + datetime.timedelta(seconds=1)
complete = tsi.CompletedCallSchemaForInsert(
complete = CompletedCallSchemaForInsert(
project_id=project_id,
id=call_id,
trace_id=trace_id,
Expand Down Expand Up @@ -193,7 +199,7 @@ def test_missing_trace_id_raises_value_error() -> None:
processor = CallBatchProcessor(complete_fn, eager_fn, min_batch_interval=0.01)

started_at = datetime.datetime.now(datetime.timezone.utc)
start = tsi.StartedCallSchemaForInsert(
start = StartedCallSchemaForInsert(
project_id="proj",
id="call-1",
trace_id=None,
Expand All @@ -202,7 +208,7 @@ def test_missing_trace_id_raises_value_error() -> None:
attributes={},
inputs={},
)
start_item = StartBatchItem(req=tsi.CallStartReq(start=start))
start_item = StartBatchItem(req=CallStartReq(start=start))
end_item = _make_end_item("call-1")

processor.enqueue([start_item])
Expand Down
Loading
Loading