Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions tests/trace/test_trace_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from weave.shared.refs_internal import InvalidInternalRef
from weave.trace_server import trace_server_interface as tsi
from weave.trace_server.errors import RefObjectsNotFoundError


def test_save_object(client):
Expand Down Expand Up @@ -81,3 +82,25 @@ def test_robust_to_url_sensitive_chars(client):
)
)
assert read_res.vals[0] == bad_val[bad_key]


def test_refs_read_batch_missing_refs_reports_digests(client):
project_id = client.project_id
create_res = client.server.obj_create(
tsi.ObjCreateReq(
obj=tsi.ObjSchemaForInsert(
project_id=project_id, object_id="real-obj", val={"a": 1}
)
)
)
real_ref = f"weave:///{project_id}/object/real-obj:{create_res.digest}"
missing_digest = "0" * 43
missing_ref = f"weave:///{project_id}/object/missing-obj:{missing_digest}"

# The missing object surfaces as a RefObjectsNotFoundError carrying the missing
# digest as a structured field
with pytest.raises(RefObjectsNotFoundError) as exc_info:
client.server.refs_read_batch(
tsi.RefsReadBatchReq(refs=[real_ref, missing_ref])
)
assert missing_digest in exc_info.value.missing_object_digests
7 changes: 5 additions & 2 deletions weave/trace_server/clickhouse_trace_server_batched.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,7 @@
NotFoundError,
ObjectDeletedError,
ObjectNameTypeCollision,
RefObjectsNotFoundError,
RequestTooLarge,
handle_clickhouse_query_error,
)
Expand Down Expand Up @@ -5801,8 +5802,10 @@ def get_object_refs_root_val(
objs = self._select_objs_query(object_query_builder)
found_digests = {obj.digest for obj in objs}
if len(ref_digests) != len(found_digests):
raise NotFoundError(
f"Ref read contains {len(ref_digests)} digests, but found {len(found_digests)} objects. Diff digests: {ref_digests - found_digests}"
missing_digests = sorted(ref_digests - found_digests)
raise RefObjectsNotFoundError(
f"Ref read contains {len(ref_digests)} digests, but found {len(found_digests)} objects. Diff digests: {missing_digests}",
missing_digests,
)
# filter out deleted objects
valid_objects = [obj for obj in objs if obj.deleted_at is None]
Expand Down
23 changes: 23 additions & 0 deletions weave/trace_server/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,14 @@ class NotFoundError(Error):
pass


class RefObjectsNotFoundError(NotFoundError):
"""Raised when reference objects are not found."""

def __init__(self, message: str, missing_object_digests: list[str]):
self.missing_object_digests = missing_object_digests
super().__init__(message)


class MissingLLMApiKeyError(Error):
"""Raised when a LLM API key is missing for completion."""

Expand Down Expand Up @@ -319,6 +327,13 @@ def _setup_common_errors(self) -> None:

# 404
self.register(NotFoundError, 404)
# Exact-type registration (the registry matches by exact type), so the
# missing-digest field is surfaced in the response body.
self.register(
RefObjectsNotFoundError,
404,
_format_ref_objects_not_found_error,
)
self.register(ProjectNotFound, 404)
self.register(RunNotFound, 404)
self.register(ObjectDeletedError, 404, _format_object_deleted_error)
Expand Down Expand Up @@ -494,6 +509,14 @@ def _format_object_deleted_error(exc: Exception) -> dict[str, Any]:
return _format_error_to_json_with_extra(exc, extra)


def _format_ref_objects_not_found_error(exc: Exception) -> dict[str, Any]:
"""Format RefObjectsNotFoundError with the missing object digests."""
extra = {}
if isinstance(exc, RefObjectsNotFoundError):
extra["missing_object_digests"] = exc.missing_object_digests
return _format_error_to_json_with_extra(exc, extra)


def _get_transport_server_error_status_code(exc: Exception) -> int:
"""Get status code for TransportServerError, preserving 4xx codes, defaulting to 500.

Expand Down
5 changes: 4 additions & 1 deletion weave/trace_server/sqlite_trace_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
NotFoundError,
ObjectDeletedError,
ObjectNameTypeCollision,
RefObjectsNotFoundError,
)
from weave.trace_server.feedback import (
TABLE_FEEDBACK,
Expand Down Expand Up @@ -2656,7 +2657,9 @@ def read_ref(r: ri.InternalObjectRef) -> Any:
include_deleted=True,
)
if len(objs) == 0:
raise NotFoundError(f"Obj {r.name}:{r.version} not found")
raise RefObjectsNotFoundError(
f"Obj {r.name}:{r.version} not found", [r.version]
)
obj = objs[0]
if obj.deleted_at is not None:
return None
Expand Down
Loading