diff --git a/tests/trace/test_base_object_classes.py b/tests/trace/test_base_object_classes.py
index c3b770426274..17c178749bd2 100644
--- a/tests/trace/test_base_object_classes.py
+++ b/tests/trace/test_base_object_classes.py
@@ -20,9 +20,19 @@
 import weave
 from weave.trace import base_objects
 from weave.trace.refs import ObjectRef
-from weave.trace.weave_client import WeaveClient
+from weave.trace.serialization.serialize import to_json
+from weave.trace.weave_client import WeaveClient, map_to_refs
 from weave.trace_server import trace_server_interface as tsi
-from weave.trace_server.errors import ObjectNameTypeCollision
+from weave.trace_server.calls_query_builder.calls_query_builder import (
+    _DUMP_SUFFIX,
+    ALLOWED_CALL_FIELDS,
+    ALLOWED_DYNAMIC_FIELD_PREFIXES,
+)
+from weave.trace_server.errors import (
+    InvalidFieldError,
+    InvalidRequest,
+    ObjectNameTypeCollision,
+)
 from weave.trace_server.interface.builtin_object_classes.test_only_example import (
     TestOnlyNestedBaseModel,
 )
@@ -1153,3 +1163,118 @@ def test_obj_create_rejects_name_type_collision(client: WeaveClient):
                 }
             )
         )
+
+
+def _create_monitor(client: WeaveClient, name: str, query: dict | None):
+    """Create a Monitor via the client's serialization so the stored `query` carries weave bookkeeping keys, exercising the server-side strip + validation."""
+    monitor = weave.Monitor(name=name, scorers=[], query=query)
+    val = to_json(map_to_refs(monitor), client.project_id, client)
+    return client.server.obj_create(
+        tsi.ObjCreateReq.model_validate(
+            {
+                "obj": {
+                    "project_id": client.project_id,
+                    "object_id": name,
+                    "val": val,
+                }
+            }
+        )
+    )
+
+
+def test_monitor_create_rejects_unknown_query_field(client: WeaveClient):
+    """A Monitor query on an unknown field is rejected with the complete allowed-field list."""
+    bad_query = {
+        "$expr": {"$eq": [{"$getField": "operation_name"}, {"$literal": "predict"}]}
+    }
+    with pytest.raises(InvalidFieldError) as exc_info:
+        _create_monitor(client, "bad-monitor", bad_query)
+
+    allowed = ", ".join(
+        sorted(k for k in ALLOWED_CALL_FIELDS if not k.endswith(_DUMP_SUFFIX))
+    )
+    expected_message = (
+        "Field operation_name is not allowed. "
+        f"Allowed fields: {allowed}. "
+        f"Allowed dynamic field prefixes: {', '.join(ALLOWED_DYNAMIC_FIELD_PREFIXES)}"
+    )
+    assert str(exc_info.value) == expected_message
+    assert _DUMP_SUFFIX not in str(exc_info.value)
+
+    # A structurally invalid query (empty $and) is a bad request, not a field error.
+    with pytest.raises(InvalidRequest):
+        _create_monitor(client, "empty-and-monitor", {"$expr": {"$and": []}})
+
+    # Neither rejected monitor was stored.
+    objs_res = client.server.objs_query(
+        tsi.ObjQueryReq.model_validate(
+            {
+                "project_id": client.project_id,
+                "filter": {"object_ids": ["bad-monitor", "empty-and-monitor"]},
+            }
+        )
+    )
+    assert objs_res.objs == []
+
+
+def test_monitor_create_accepts_valid_query_fields(client: WeaveClient):
+    """Static, dynamic, and absent monitor queries all create successfully."""
+    valid_query = {
+        "$expr": {
+            "$and": [
+                {"$eq": [{"$getField": "parent_id"}, {"$literal": None}]},
+                {
+                    "$eq": [
+                        {"$getField": "summary.weave.status"},
+                        {"$literal": "success"},
+                    ]
+                },
+            ]
+        }
+    }
+    dynamic_query = {
+        "$expr": {"$eq": [{"$getField": "inputs.foo"}, {"$literal": "bar"}]}
+    }
+
+    _create_monitor(client, "valid-monitor", valid_query)
+    _create_monitor(client, "dynamic-monitor", dynamic_query)
+    _create_monitor(client, "no-query-monitor", None)
+
+    # A Monitor-classed object whose `query` is not a recognizable query shape is
+    # left untouched (we only validate queries we understand), never 500'd.
+    monitor = weave.Monitor(name="opaque-monitor", scorers=[], query=None)
+    opaque_val = to_json(map_to_refs(monitor), client.project_id, client)
+    opaque_val["query"] = "not-a-query"
+    client.server.obj_create(
+        tsi.ObjCreateReq.model_validate(
+            {
+                "obj": {
+                    "project_id": client.project_id,
+                    "object_id": "opaque-monitor",
+                    "val": opaque_val,
+                }
+            }
+        )
+    )
+
+    objs_res = client.server.objs_query(
+        tsi.ObjQueryReq.model_validate(
+            {
+                "project_id": client.project_id,
+                "filter": {
+                    "object_ids": [
+                        "valid-monitor",
+                        "dynamic-monitor",
+                        "no-query-monitor",
+                        "opaque-monitor",
+                    ]
+                },
+            }
+        )
+    )
+    assert {obj.object_id for obj in objs_res.objs} == {
+        "valid-monitor",
+        "dynamic-monitor",
+        "no-query-monitor",
+        "opaque-monitor",
+    }
diff --git a/weave/trace_server/calls_query_builder/calls_query_builder.py b/weave/trace_server/calls_query_builder/calls_query_builder.py
index b489436f831f..42b8cf921247 100644
--- a/weave/trace_server/calls_query_builder/calls_query_builder.py
+++ b/weave/trace_server/calls_query_builder/calls_query_builder.py
@@ -31,7 +31,7 @@
 from dataclasses import dataclass
 from typing import Any, Literal, cast
 
-from pydantic import BaseModel, Field
+from pydantic import BaseModel, Field, ValidationError
 from typing_extensions import Self
 
 from weave.shared.trace_server_interface_util import (
@@ -64,7 +64,7 @@
     trace_id_index_expr,
 )
 from weave.trace_server.common_interface import SortBy
-from weave.trace_server.errors import InvalidFieldError
+from weave.trace_server.errors import InvalidFieldError, InvalidRequest
 from weave.trace_server.interface import query as tsi_query
 from weave.trace_server.interface.feedback_types import MULTI_VALUE_FEEDBACK_TYPES
 from weave.trace_server.interface.query import (
@@ -1943,6 +1943,94 @@ def get_field_by_name(name: str) -> CallsMergedField:
     return ALLOWED_CALL_FIELDS[name]
 
 
+# Field references get_field_by_name accepts beyond exact ALLOWED_CALL_FIELDS
+# keys, used only to build the user-facing error message. The `*_dump` prefixes
+# are derived so they can't drift; the special entries must be kept in sync by
+# hand with the explicit branches in get_field_by_name above
+# (`annotation_queue_items.queue_id` is an exact ref, not a prefix).
+_DUMP_SUFFIX = "_dump"
+_SPECIAL_DYNAMIC_FIELD_PREFIXES = (
+    "feedback.*",
+    "annotation_queue_items.queue_id",
+    "summary.weave.*",
+)
+ALLOWED_DYNAMIC_FIELD_PREFIXES = _SPECIAL_DYNAMIC_FIELD_PREFIXES + tuple(
+    f"{name[: -len(_DUMP_SUFFIX)]}.*"
+    for name, field in ALLOWED_CALL_FIELDS.items()
+    if isinstance(field, CallsMergedDynamicField) and name.endswith(_DUMP_SUFFIX)
+)
+
+
+# Serialized `_class_name`/`_bases` values for Monitor objects. These mirror the
+# SDK classes in weave/flow/monitor.py but are kept as server-side strings on
+# purpose: the trace server must not import from weave/flow.
+MONITOR_OBJECT_CLASSES = frozenset({"Monitor", "ClassifierMonitor"})
+
+
+def validate_monitor_query_fields(
+    base_object_class: str | None,
+    leaf_object_class: str | None,
+    val: object,
+) -> None:
+    """Reject a Monitor whose `query` references a field outside the allowed set."""
+    if (
+        base_object_class not in MONITOR_OBJECT_CLASSES
+        and leaf_object_class not in MONITOR_OBJECT_CLASSES
+    ):
+        return
+    if not isinstance(val, dict):
+        return
+    raw_query = val.get("query")
+    if raw_query is None:
+        return
+    cleaned = _strip_weave_object_keys(raw_query)
+    try:
+        query = tsi_query.Query.model_validate(cleaned)
+    except ValidationError:
+        # Not a recognizable query (e.g. a user object that happens to share
+        # the Monitor class name); leave the write untouched.
+        return
+    validate_query_compiles(query)
+
+
+def validate_query_compiles(query: tsi_query.Query) -> None:
+    """Validate that `query` references only allowed call fields and is well-formed."""
+    try:
+        process_query_to_conditions(query, ParamBuilder(), "calls_merged")
+    except InvalidFieldError as e:
+        raise InvalidFieldError(_invalid_field_message(str(e))) from e
+    except (ValueError, TypeError) as e:
+        raise InvalidRequest(f"Invalid query: {e}") from e
+
+
+_WEAVE_BOOKKEEPING_KEYS = frozenset({"_type", "_class_name", "_bases"})
+
+
+def _strip_weave_object_keys(value: object) -> object:
+    """Drop weave bookkeeping keys (`_type`, `_class_name`, `_bases`) from a serialized query."""
+    if isinstance(value, dict):
+        return {
+            k: _strip_weave_object_keys(v)
+            for k, v in value.items()
+            if k not in _WEAVE_BOOKKEEPING_KEYS
+        }
+    if isinstance(value, list):
+        return [_strip_weave_object_keys(v) for v in value]
+    return value
+
+
+def _invalid_field_message(reason: str) -> str:
+    """Append the allowed field list and dynamic prefixes to a field rejection."""
+    allowed = ", ".join(
+        sorted(k for k in ALLOWED_CALL_FIELDS if not k.endswith(_DUMP_SUFFIX))
+    )
+    prefixes = ", ".join(ALLOWED_DYNAMIC_FIELD_PREFIXES)
+    return (
+        f"{reason}. Allowed fields: {allowed}. "
+        f"Allowed dynamic field prefixes: {prefixes}"
+    )
+
+
 def _field_as_sql_maybe_agg(
     field: CallsMergedField,
     pb: ParamBuilder,
diff --git a/weave/trace_server/clickhouse_trace_server_batched.py b/weave/trace_server/clickhouse_trace_server_batched.py
index e025b166d200..1ebd5897c1fc 100644
--- a/weave/trace_server/clickhouse_trace_server_batched.py
+++ b/weave/trace_server/clickhouse_trace_server_batched.py
@@ -103,6 +103,7 @@
     build_calls_complete_update_query,
     build_calls_stats_query,
     combine_conditions,
+    validate_monitor_query_fields,
 )
 from weave.trace_server.calls_query_builder.usage_query_builder import (
     build_usage_query,
@@ -1988,6 +1989,11 @@ def obj_create(self, req: tsi.ObjCreateReq) -> tsi.ObjCreateRes:
             actual=digest,
             label=f"obj {req.obj.object_id!r}",
         )
+        validate_monitor_query_fields(
+            digest_result.base_object_class,
+            digest_result.leaf_object_class,
+            processed_val,
+        )
 
         kind = get_kind(processed_val)
         self._reject_obj_name_type_collision(
diff --git a/weave/trace_server/sqlite_trace_server.py b/weave/trace_server/sqlite_trace_server.py
index 07bb568fd85d..dbd5eb7bb98e 100644
--- a/weave/trace_server/sqlite_trace_server.py
+++ b/weave/trace_server/sqlite_trace_server.py
@@ -33,6 +33,9 @@
 from weave.trace_server import eval_results_helpers as eval_helpers
 from weave.trace_server import trace_server_interface as tsi
 from weave.trace_server.call_stats_helpers import validate_call_stats_range
+from weave.trace_server.calls_query_builder.calls_query_builder import (
+    validate_monitor_query_fields,
+)
 from weave.trace_server.ch_sentinel_values import EXPIRE_AT_NEVER
 from weave.trace_server.common_interface import SortBy
 from weave.trace_server.digest_validation import validate_expected_digest
@@ -1907,6 +1910,11 @@ def obj_create(self, req: tsi.ObjCreateReq) -> tsi.ObjCreateRes:
             actual=digest,
             label=f"obj {req.obj.object_id!r}",
         )
+        validate_monitor_query_fields(
+            digest_result.base_object_class,
+            digest_result.leaf_object_class,
+            processed_val,
+        )
         project_id, object_id, wb_user_id = (
             req.obj.project_id,
             req.obj.object_id,