diff --git a/tests/trace/test_trace_server.py b/tests/trace/test_trace_server.py index 0efdee7003ec..d568b8c2f1a7 100644 --- a/tests/trace/test_trace_server.py +++ b/tests/trace/test_trace_server.py @@ -4,6 +4,7 @@ from weave.shared.refs_internal import InvalidInternalRef from weave.trace_server import trace_server_interface as tsi +from weave.trace_server.errors import RefObjectsNotFoundError def test_save_object(client): @@ -81,3 +82,25 @@ def test_robust_to_url_sensitive_chars(client): ) ) assert read_res.vals[0] == bad_val[bad_key] + + +def test_refs_read_batch_missing_refs_reports_digests(client): + project_id = client.project_id + create_res = client.server.obj_create( + tsi.ObjCreateReq( + obj=tsi.ObjSchemaForInsert( + project_id=project_id, object_id="real-obj", val={"a": 1} + ) + ) + ) + real_ref = f"weave:///{project_id}/object/real-obj:{create_res.digest}" + missing_digest = "0" * 43 + missing_ref = f"weave:///{project_id}/object/missing-obj:{missing_digest}" + + # The missing object surfaces as a RefObjectsNotFoundError carrying the missing + # digest as a structured field + with pytest.raises(RefObjectsNotFoundError) as exc_info: + client.server.refs_read_batch( + tsi.RefsReadBatchReq(refs=[real_ref, missing_ref]) + ) + assert missing_digest in exc_info.value.missing_object_digests diff --git a/weave/trace_server/clickhouse_trace_server_batched.py b/weave/trace_server/clickhouse_trace_server_batched.py index 1ab0308a0f63..e81d15a3a6ed 100644 --- a/weave/trace_server/clickhouse_trace_server_batched.py +++ b/weave/trace_server/clickhouse_trace_server_batched.py @@ -179,6 +179,7 @@ NotFoundError, ObjectDeletedError, ObjectNameTypeCollision, + RefObjectsNotFoundError, RequestTooLarge, handle_clickhouse_query_error, ) @@ -5801,8 +5802,10 @@ def get_object_refs_root_val( objs = self._select_objs_query(object_query_builder) found_digests = {obj.digest for obj in objs} if len(ref_digests) != len(found_digests): - raise NotFoundError( - f"Ref read contains {len(ref_digests)} digests, but found {len(found_digests)} objects. Diff digests: {ref_digests - found_digests}" + missing_digests = sorted(ref_digests - found_digests) + raise RefObjectsNotFoundError( + f"Ref read contains {len(ref_digests)} digests, but found {len(found_digests)} objects. Diff digests: {missing_digests}", + missing_digests, ) # filter out deleted objects valid_objects = [obj for obj in objs if obj.deleted_at is None] diff --git a/weave/trace_server/errors.py b/weave/trace_server/errors.py index 2f1c341eb174..029999e677df 100644 --- a/weave/trace_server/errors.py +++ b/weave/trace_server/errors.py @@ -156,6 +156,14 @@ class NotFoundError(Error): pass +class RefObjectsNotFoundError(NotFoundError): + """Raised when reference objects are not found.""" + + def __init__(self, message: str, missing_object_digests: list[str]): + self.missing_object_digests = missing_object_digests + super().__init__(message) + + class MissingLLMApiKeyError(Error): """Raised when a LLM API key is missing for completion.""" @@ -319,6 +327,13 @@ def _setup_common_errors(self) -> None: # 404 self.register(NotFoundError, 404) + # Exact-type registration (the registry matches by exact type), so the + # missing-digest field is surfaced in the response body. + self.register( + RefObjectsNotFoundError, + 404, + _format_ref_objects_not_found_error, + ) self.register(ProjectNotFound, 404) self.register(RunNotFound, 404) self.register(ObjectDeletedError, 404, _format_object_deleted_error) @@ -494,6 +509,14 @@ def _format_object_deleted_error(exc: Exception) -> dict[str, Any]: return _format_error_to_json_with_extra(exc, extra) +def _format_ref_objects_not_found_error(exc: Exception) -> dict[str, Any]: + """Format RefObjectsNotFoundError with the missing object digests.""" + extra = {} + if isinstance(exc, RefObjectsNotFoundError): + extra["missing_object_digests"] = exc.missing_object_digests + return _format_error_to_json_with_extra(exc, extra) + + def _get_transport_server_error_status_code(exc: Exception) -> int: """Get status code for TransportServerError, preserving 4xx codes, defaulting to 500. diff --git a/weave/trace_server/sqlite_trace_server.py b/weave/trace_server/sqlite_trace_server.py index 328f75b397d7..ef553ffccfdf 100644 --- a/weave/trace_server/sqlite_trace_server.py +++ b/weave/trace_server/sqlite_trace_server.py @@ -41,6 +41,7 @@ NotFoundError, ObjectDeletedError, ObjectNameTypeCollision, + RefObjectsNotFoundError, ) from weave.trace_server.feedback import ( TABLE_FEEDBACK, @@ -2656,7 +2657,9 @@ def read_ref(r: ri.InternalObjectRef) -> Any: include_deleted=True, ) if len(objs) == 0: - raise NotFoundError(f"Obj {r.name}:{r.version} not found") + raise RefObjectsNotFoundError( + f"Obj {r.name}:{r.version} not found", [r.version] + ) obj = objs[0] if obj.deleted_at is not None: return None