Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
129 changes: 127 additions & 2 deletions tests/trace/test_base_object_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,19 @@
import weave
from weave.trace import base_objects
from weave.trace.refs import ObjectRef
from weave.trace.weave_client import WeaveClient
from weave.trace.serialization.serialize import to_json
from weave.trace.weave_client import WeaveClient, map_to_refs
from weave.trace_server import trace_server_interface as tsi
from weave.trace_server.errors import ObjectNameTypeCollision
from weave.trace_server.calls_query_builder.calls_query_builder import (
_DUMP_SUFFIX,
ALLOWED_CALL_FIELDS,
ALLOWED_DYNAMIC_FIELD_PREFIXES,
)
from weave.trace_server.errors import (
InvalidFieldError,
InvalidRequest,
ObjectNameTypeCollision,
)
from weave.trace_server.interface.builtin_object_classes.test_only_example import (
TestOnlyNestedBaseModel,
)
Expand Down Expand Up @@ -1153,3 +1163,118 @@ def test_obj_create_rejects_name_type_collision(client: WeaveClient):
}
)
)


def _create_monitor(client: WeaveClient, name: str, query: dict | None):
"""Create a Monitor via the client's serialization so the stored `query` carries weave bookkeeping keys, exercising the server-side strip + validation."""
monitor = weave.Monitor(name=name, scorers=[], query=query)
val = to_json(map_to_refs(monitor), client.project_id, client)
return client.server.obj_create(
tsi.ObjCreateReq.model_validate(
{
"obj": {
"project_id": client.project_id,
"object_id": name,
"val": val,
}
}
)
)


def test_monitor_create_rejects_unknown_query_field(client: WeaveClient):
"""A Monitor query on an unknown field is rejected with the complete allowed-field list."""
bad_query = {
"$expr": {"$eq": [{"$getField": "operation_name"}, {"$literal": "predict"}]}
}
with pytest.raises(InvalidFieldError) as exc_info:
_create_monitor(client, "bad-monitor", bad_query)

allowed = ", ".join(
sorted(k for k in ALLOWED_CALL_FIELDS if not k.endswith(_DUMP_SUFFIX))
)
expected_message = (
"Field operation_name is not allowed. "
f"Allowed fields: {allowed}. "
f"Allowed dynamic field prefixes: {', '.join(ALLOWED_DYNAMIC_FIELD_PREFIXES)}"
)
assert str(exc_info.value) == expected_message
assert _DUMP_SUFFIX not in str(exc_info.value)

# A structurally invalid query (empty $and) is a bad request, not a field error.
with pytest.raises(InvalidRequest):
_create_monitor(client, "empty-and-monitor", {"$expr": {"$and": []}})

# Neither rejected monitor was stored.
objs_res = client.server.objs_query(
tsi.ObjQueryReq.model_validate(
{
"project_id": client.project_id,
"filter": {"object_ids": ["bad-monitor", "empty-and-monitor"]},
}
)
)
assert objs_res.objs == []


def test_monitor_create_accepts_valid_query_fields(client: WeaveClient):
"""Static, dynamic, and absent monitor queries all create successfully."""
valid_query = {
"$expr": {
"$and": [
{"$eq": [{"$getField": "parent_id"}, {"$literal": None}]},
{
"$eq": [
{"$getField": "summary.weave.status"},
{"$literal": "success"},
]
},
]
}
}
dynamic_query = {
"$expr": {"$eq": [{"$getField": "inputs.foo"}, {"$literal": "bar"}]}
}

_create_monitor(client, "valid-monitor", valid_query)
_create_monitor(client, "dynamic-monitor", dynamic_query)
_create_monitor(client, "no-query-monitor", None)

# A Monitor-classed object whose `query` is not a recognizable query shape is
# left untouched (we only validate queries we understand), never 500'd.
monitor = weave.Monitor(name="opaque-monitor", scorers=[], query=None)
opaque_val = to_json(map_to_refs(monitor), client.project_id, client)
opaque_val["query"] = "not-a-query"
client.server.obj_create(
tsi.ObjCreateReq.model_validate(
{
"obj": {
"project_id": client.project_id,
"object_id": "opaque-monitor",
"val": opaque_val,
}
}
)
)

objs_res = client.server.objs_query(
tsi.ObjQueryReq.model_validate(
{
"project_id": client.project_id,
"filter": {
"object_ids": [
"valid-monitor",
"dynamic-monitor",
"no-query-monitor",
"opaque-monitor",
]
},
}
)
)
assert {obj.object_id for obj in objs_res.objs} == {
"valid-monitor",
"dynamic-monitor",
"no-query-monitor",
"opaque-monitor",
}
92 changes: 90 additions & 2 deletions weave/trace_server/calls_query_builder/calls_query_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from dataclasses import dataclass
from typing import Any, Literal, cast

from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, ValidationError
from typing_extensions import Self

from weave.shared.trace_server_interface_util import (
Expand Down Expand Up @@ -64,7 +64,7 @@
trace_id_index_expr,
)
from weave.trace_server.common_interface import SortBy
from weave.trace_server.errors import InvalidFieldError
from weave.trace_server.errors import InvalidFieldError, InvalidRequest
from weave.trace_server.interface import query as tsi_query
from weave.trace_server.interface.feedback_types import MULTI_VALUE_FEEDBACK_TYPES
from weave.trace_server.interface.query import (
Expand Down Expand Up @@ -1943,6 +1943,94 @@ def get_field_by_name(name: str) -> CallsMergedField:
return ALLOWED_CALL_FIELDS[name]


# Field references get_field_by_name accepts beyond exact ALLOWED_CALL_FIELDS

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's hard to know what these constants mean until reading the rest of the code and using claude to analyze the PR. I think these would be more self-documenting if a couple of changes were made:

  • It looks like they're mainly for creating the error message, but they're up above the validation code. Consider moving them to be closer to the error message and/or giving a little high-level context before launching into the details.
  • Consider moving the error messaging and maybe the validation into a separate file. This file is huge, and adding more new code to it isn't going to help readability or agent context limits.

# keys, used only to build the user-facing error message. The `*_dump` prefixes
# are derived so they can't drift; the special entries must be kept in sync by
# hand with the explicit branches in get_field_by_name above
# (`annotation_queue_items.queue_id` is an exact ref, not a prefix).
_DUMP_SUFFIX = "_dump"
_SPECIAL_DYNAMIC_FIELD_PREFIXES = (
"feedback.*",
"annotation_queue_items.queue_id",
"summary.weave.*",
)
ALLOWED_DYNAMIC_FIELD_PREFIXES = _SPECIAL_DYNAMIC_FIELD_PREFIXES + tuple(
f"{name[: -len(_DUMP_SUFFIX)]}.*"
for name, field in ALLOWED_CALL_FIELDS.items()
if isinstance(field, CallsMergedDynamicField) and name.endswith(_DUMP_SUFFIX)
)


# Serialized `_class_name`/`_bases` values for Monitor objects. These mirror the
# SDK classes in weave/flow/monitor.py but are kept as server-side strings on
# purpose: the trace server must not import from weave/flow.
MONITOR_OBJECT_CLASSES = frozenset({"Monitor", "ClassifierMonitor"})


def validate_monitor_query_fields(
base_object_class: str | None,
leaf_object_class: str | None,
val: object,
) -> None:
"""Reject a Monitor whose `query` references a field outside the allowed set."""

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The comment and implementation don't quite match. While it is true that this will detect a query with unacceptable fields, I think it will also detect queries that don't compile. At least that's how the code reads. That also seems reasonable, but the mismatch between the comment and code is confusing.

if (

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Half of this function is answering the question "is this a monitor that we should validate". That's not immediately obvious. Consider ways to make the code more self-documenting.

base_object_class not in MONITOR_OBJECT_CLASSES
and leaf_object_class not in MONITOR_OBJECT_CLASSES
):
return
if not isinstance(val, dict):
return
raw_query = val.get("query")
if raw_query is None:
return
cleaned = _strip_weave_object_keys(raw_query)
try:
query = tsi_query.Query.model_validate(cleaned)
except ValidationError:
# Not a recognizable query (e.g. a user object that happens to share
# the Monitor class name); leave the write untouched.
return
validate_query_compiles(query)


def validate_query_compiles(query: tsi_query.Query) -> None:
"""Validate that `query` references only allowed call fields and is well-formed."""

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Similar comment about mismatch between comment and what the code appears to be doing. Consider connecting the dots here.

try:
process_query_to_conditions(query, ParamBuilder(), "calls_merged")
except InvalidFieldError as e:
raise InvalidFieldError(_invalid_field_message(str(e))) from e
except (ValueError, TypeError) as e:
raise InvalidRequest(f"Invalid query: {e}") from e


_WEAVE_BOOKKEEPING_KEYS = frozenset({"_type", "_class_name", "_bases"})


def _strip_weave_object_keys(value: object) -> object:
"""Drop weave bookkeeping keys (`_type`, `_class_name`, `_bases`) from a serialized query."""
if isinstance(value, dict):
return {
k: _strip_weave_object_keys(v)
for k, v in value.items()
if k not in _WEAVE_BOOKKEEPING_KEYS
}
if isinstance(value, list):
return [_strip_weave_object_keys(v) for v in value]
return value


def _invalid_field_message(reason: str) -> str:
"""Append the allowed field list and dynamic prefixes to a field rejection."""
allowed = ", ".join(
sorted(k for k in ALLOWED_CALL_FIELDS if not k.endswith(_DUMP_SUFFIX))
)
prefixes = ", ".join(ALLOWED_DYNAMIC_FIELD_PREFIXES)
return (
f"{reason}. Allowed fields: {allowed}. "
f"Allowed dynamic field prefixes: {prefixes}"
)


def _field_as_sql_maybe_agg(
field: CallsMergedField,
pb: ParamBuilder,
Expand Down
6 changes: 6 additions & 0 deletions weave/trace_server/clickhouse_trace_server_batched.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@
build_calls_complete_update_query,
build_calls_stats_query,
combine_conditions,
validate_monitor_query_fields,
)
from weave.trace_server.calls_query_builder.usage_query_builder import (
build_usage_query,
Expand Down Expand Up @@ -1988,6 +1989,11 @@ def obj_create(self, req: tsi.ObjCreateReq) -> tsi.ObjCreateRes:
actual=digest,
label=f"obj {req.obj.object_id!r}",
)
validate_monitor_query_fields(
digest_result.base_object_class,
digest_result.leaf_object_class,
processed_val,
)

kind = get_kind(processed_val)
self._reject_obj_name_type_collision(
Expand Down
8 changes: 8 additions & 0 deletions weave/trace_server/sqlite_trace_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@
from weave.trace_server import eval_results_helpers as eval_helpers
from weave.trace_server import trace_server_interface as tsi
from weave.trace_server.call_stats_helpers import validate_call_stats_range
from weave.trace_server.calls_query_builder.calls_query_builder import (
validate_monitor_query_fields,
)
from weave.trace_server.ch_sentinel_values import EXPIRE_AT_NEVER
from weave.trace_server.common_interface import SortBy
from weave.trace_server.digest_validation import validate_expected_digest
Expand Down Expand Up @@ -1907,6 +1910,11 @@ def obj_create(self, req: tsi.ObjCreateReq) -> tsi.ObjCreateRes:
actual=digest,
label=f"obj {req.obj.object_id!r}",
)
validate_monitor_query_fields(
digest_result.base_object_class,
digest_result.leaf_object_class,
processed_val,
)
project_id, object_id, wb_user_id = (
req.obj.project_id,
req.obj.object_id,
Expand Down
Loading