-
Notifications
You must be signed in to change notification settings - Fork 153
chore(weave): validate monitor query fields on save #7191
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -31,7 +31,7 @@ | |
| from dataclasses import dataclass | ||
| from typing import Any, Literal, cast | ||
|
|
||
| from pydantic import BaseModel, Field | ||
| from pydantic import BaseModel, Field, ValidationError | ||
| from typing_extensions import Self | ||
|
|
||
| from weave.shared.trace_server_interface_util import ( | ||
|
|
@@ -64,7 +64,7 @@ | |
| trace_id_index_expr, | ||
| ) | ||
| from weave.trace_server.common_interface import SortBy | ||
| from weave.trace_server.errors import InvalidFieldError | ||
| from weave.trace_server.errors import InvalidFieldError, InvalidRequest | ||
| from weave.trace_server.interface import query as tsi_query | ||
| from weave.trace_server.interface.feedback_types import MULTI_VALUE_FEEDBACK_TYPES | ||
| from weave.trace_server.interface.query import ( | ||
|
|
@@ -1943,6 +1943,94 @@ def get_field_by_name(name: str) -> CallsMergedField: | |
| return ALLOWED_CALL_FIELDS[name] | ||
|
|
||
|
|
||
# Field references that get_field_by_name accepts in addition to exact
# ALLOWED_CALL_FIELDS keys. They exist only to build the user-facing error
# message; nothing here drives the lookup itself.

# Suffix of the ALLOWED_CALL_FIELDS keys whose dynamic fields are advertised
# below as `<name without suffix>.*`.
_DUMP_SUFFIX = "_dump"

# Hand-maintained entries: keep these in sync with the explicit branches in
# get_field_by_name. Despite the tuple's name, `annotation_queue_items.queue_id`
# is an exact reference, not a prefix.
_SPECIAL_DYNAMIC_FIELD_PREFIXES = (
    "feedback.*",
    "annotation_queue_items.queue_id",
    "summary.weave.*",
)

# Full list shown to users: the hand-maintained entries first, followed by one
# `<name>.*` entry per dynamic `*_dump` field. The second part is derived from
# ALLOWED_CALL_FIELDS so it cannot drift from the real field table.
ALLOWED_DYNAMIC_FIELD_PREFIXES = (
    *_SPECIAL_DYNAMIC_FIELD_PREFIXES,
    *(
        f"{name[: -len(_DUMP_SUFFIX)]}.*"
        for name, field in ALLOWED_CALL_FIELDS.items()
        if name.endswith(_DUMP_SUFFIX) and isinstance(field, CallsMergedDynamicField)
    ),
)
|
|
||
|
|
||
| # Serialized `_class_name`/`_bases` values for Monitor objects. These mirror the | ||
| # SDK classes in weave/flow/monitor.py but are kept as server-side strings on | ||
| # purpose: the trace server must not import from weave/flow. | ||
| MONITOR_OBJECT_CLASSES = frozenset({"Monitor", "ClassifierMonitor"}) | ||
|
|
||
|
|
||
| def validate_monitor_query_fields( | ||
| base_object_class: str | None, | ||
| leaf_object_class: str | None, | ||
| val: object, | ||
| ) -> None: | ||
| """Reject a Monitor whose `query` references a field outside the allowed set.""" | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The comment and implementation don't quite match. While it is true that this will detect a query with unacceptable fields, I think it will also detect queries that don't compile. At least that's how the code reads. That also seems reasonable, but the mismatch between the comment and code is confusing. |
||
| if ( | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Half of this function is answering the question "is this a monitor that we should validate". That's not immediately obvious. Consider ways to make the code more self-documenting. |
||
| base_object_class not in MONITOR_OBJECT_CLASSES | ||
| and leaf_object_class not in MONITOR_OBJECT_CLASSES | ||
| ): | ||
| return | ||
| if not isinstance(val, dict): | ||
| return | ||
| raw_query = val.get("query") | ||
| if raw_query is None: | ||
| return | ||
| cleaned = _strip_weave_object_keys(raw_query) | ||
| try: | ||
| query = tsi_query.Query.model_validate(cleaned) | ||
| except ValidationError: | ||
| # Not a recognizable query (e.g. a user object that happens to share | ||
| # the Monitor class name); leave the write untouched. | ||
| return | ||
| validate_query_compiles(query) | ||
|
|
||
|
|
||
def validate_query_compiles(query: tsi_query.Query) -> None:
    """Validate `query` by compiling it, raising if compilation fails.

    Validation is a dry run: the query is passed through
    `process_query_to_conditions` for the `calls_merged` table with a fresh
    ParamBuilder, and the compiled output is discarded. A query therefore
    passes only if it both references allowed call fields and is well-formed
    enough to compile.

    Args:
        query: The parsed query to check.

    Raises:
        InvalidFieldError: The compiler rejected a field reference. The
            original message is extended with the allowed fields and dynamic
            prefixes.
        InvalidRequest: The compiler raised ValueError or TypeError for any
            other reason.
    """
    try:
        # Only success or failure matters here; nothing is kept from the run.
        process_query_to_conditions(query, ParamBuilder(), "calls_merged")
    except InvalidFieldError as field_error:
        # NOTE(review): this handler is listed before the broader one below;
        # presumably InvalidFieldError would otherwise be caught there, so
        # keep the order. Confirm against the error class hierarchy.
        detailed_message = _invalid_field_message(str(field_error))
        raise InvalidFieldError(detailed_message) from field_error
    except (ValueError, TypeError) as compile_error:
        raise InvalidRequest(f"Invalid query: {compile_error}") from compile_error
|
|
||
|
|
||
# Keys that weave adds when serializing an object; they are not part of the
# query itself and are removed before the query is parsed.
_WEAVE_BOOKKEEPING_KEYS = frozenset({"_type", "_class_name", "_bases"})


def _strip_weave_object_keys(value: object) -> object:
    """Return a copy of a serialized query without weave bookkeeping keys.

    Dicts and lists are walked recursively: every dict entry whose key is in
    _WEAVE_BOOKKEEPING_KEYS (`_type`, `_class_name`, `_bases`) is dropped at
    any depth. Values of any other type are returned unchanged.
    """
    if isinstance(value, list):
        return [_strip_weave_object_keys(item) for item in value]
    if not isinstance(value, dict):
        return value

    stripped = {}
    for key, child in value.items():
        if key in _WEAVE_BOOKKEEPING_KEYS:
            continue
        stripped[key] = _strip_weave_object_keys(child)
    return stripped
|
|
||
|
|
||
def _invalid_field_message(reason: str) -> str:
    """Build the user-facing message for a rejected field reference.

    The result is `reason` followed by the sorted ALLOWED_CALL_FIELDS names
    (those ending in the `_dump` suffix are left out) and then the entries of
    ALLOWED_DYNAMIC_FIELD_PREFIXES, in their defined order.
    """
    static_field_names = sorted(
        name for name in ALLOWED_CALL_FIELDS if not name.endswith(_DUMP_SUFFIX)
    )
    allowed_fields = ", ".join(static_field_names)
    dynamic_prefixes = ", ".join(ALLOWED_DYNAMIC_FIELD_PREFIXES)

    message = f"{reason}. Allowed fields: {allowed_fields}. "
    message += f"Allowed dynamic field prefixes: {dynamic_prefixes}"
    return message
|
|
||
|
|
||
| def _field_as_sql_maybe_agg( | ||
| field: CallsMergedField, | ||
| pb: ParamBuilder, | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's hard to know what these constants mean until reading the rest of the code and using Claude to analyze the PR. I think these would be more self-documenting if a couple of changes were made: