Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 4 additions & 33 deletions tests/trace/data_serialization/test_cases/library_cases.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,10 +79,10 @@ def evaluation_equality_check(a, b):
# When doing this, replace "llm_as_a_judge_scorer_digest" with the current value of llm_as_a_judge_scorer_digest_for_current_non_legacy_test_on_old_python
# Do this, rather than creating a new variable, because each new version of legacy test case will need a different value.
llm_as_a_judge_scorer_digest_for_current_non_legacy_test_on_current_python = (
"usU9eU7is5YeNlwmYcSOHYfjJB8xHGCXXUVpm6dBbfc"
"4U7vV5XKCkJ0uOdkflK1O2b6jU3vUY17ZJtoItgc0iA"
)
llm_as_a_judge_scorer_digest_for_current_non_legacy_test_on_old_python = (
"usU9eU7is5YeNlwmYcSOHYfjJB8xHGCXXUVpm6dBbfc"
"4U7vV5XKCkJ0uOdkflK1O2b6jU3vUY17ZJtoItgc0iA"
)
llm_as_a_judge_scorer_digest = (
llm_as_a_judge_scorer_digest_for_current_non_legacy_test_on_current_python
Expand Down Expand Up @@ -170,32 +170,21 @@ def evaluation_equality_check(a, b):
"name": None,
"description": None,
"column_map": None,
"model": "weave:///shawn/test-project/object/LLMStructuredCompletionModel:gsLyIHy6h9PE8KVMoKpXcYykXOMQamcLBTvzPU7vNN4",
"model": "weave:///shawn/test-project/object/LLMStructuredCompletionModel:pzXf4DUrjqEMPKQTP4mZnjUp2G7lEGocXS8J1Jk8dqg",
"enable_image_input_scoring": True,
"enable_audio_input_scoring": True,
"enable_video_input_scoring": True,
"media_scoring_json_paths": [
"$.messages[0].content[1].input_audio"
],
"scoring_prompt": "Here are the inputs: {inputs}. Here is the output: {output}. Is the output correct?",
"score": "weave:///shawn/test-project/op/LLMAsAJudgeScorer.score:6xWBXgbLjYI67G1Uvms2dCWP2izbVABBwvqmx00CUT4",
"summarize": "weave:///shawn/test-project/op/Scorer.summarize:R9dPVXqD4IgSlmmGS5RA8uQPehMlvQ0CHRJMEMf1AMQ",
"_class_name": "LLMAsAJudgeScorer",
"_bases": ["Scorer", "Object", "BaseModel"],
},
},
{
"object_id": "LLMAsAJudgeScorer.score",
"digest": "6xWBXgbLjYI67G1Uvms2dCWP2izbVABBwvqmx00CUT4",
"exp_val": {
"_type": "CustomWeaveType",
"weave_type": {"type": "Op"},
"files": {"obj.py": "fqXqYs4C4l0HpQOaRfbVXwsvwYUZhMYyn4cvK0wnCMU"},
},
},
{
"object_id": "LLMStructuredCompletionModel",
"digest": "gsLyIHy6h9PE8KVMoKpXcYykXOMQamcLBTvzPU7vNN4",
"digest": "pzXf4DUrjqEMPKQTP4mZnjUp2G7lEGocXS8J1Jk8dqg",
"exp_val": {
"_type": "LLMStructuredCompletionModel",
"name": None,
Expand Down Expand Up @@ -228,20 +217,10 @@ def evaluation_equality_check(a, b):
"_class_name": "LLMStructuredCompletionModelDefaultParams",
"_bases": ["BaseModel"],
},
"predict": "weave:///shawn/test-project/op/LLMStructuredCompletionModel.predict:M6uEk3KmOzZagYl3tJBeoiOHX7opfOQyuqnSguDXjPI",
"_class_name": "LLMStructuredCompletionModel",
"_bases": ["Model", "Object", "BaseModel"],
},
},
{
"object_id": "LLMStructuredCompletionModel.predict",
"digest": "M6uEk3KmOzZagYl3tJBeoiOHX7opfOQyuqnSguDXjPI",
"exp_val": {
"_type": "CustomWeaveType",
"weave_type": {"type": "Op"},
"files": {"obj.py": "1GtS3cAyf0xckKcss0LQesVtm44iEG49EsX1xuzTmvc"},
},
},
{
"object_id": "Evaluation.summarize",
"digest": "Y0s05NYTuqlmXieehHPogfq2JXKl4Y1Xgy8CKumdmjI",
Expand Down Expand Up @@ -291,14 +270,6 @@ def evaluation_equality_check(a, b):
"digest": "vY6VtT9xBAKNfqhozgQdWEGuijncPtmZLYKrXexUERY",
"exp_content": b'import weave\nfrom weave.object.obj import Object\nfrom weave.trace.table import Table\nfrom weave.flow.util import transpose\nfrom weave.flow.scorer import get_scorer_attributes\nfrom weave.flow.scorer import auto_summarize\nfrom weave.trace.op import op\n\nclass EvaluationResults(Object):\n rows: Table\n\n@weave.op\n@op\nasync def summarize(self, eval_table: EvaluationResults) -> dict:\n eval_table_rows = list(eval_table.rows)\n cols = transpose(eval_table_rows)\n summary = {}\n\n for name, vals in cols.items():\n if name == "scores":\n if scorers := self.scorers:\n for scorer in scorers:\n scorer_attributes = get_scorer_attributes(scorer)\n scorer_name = scorer_attributes.scorer_name\n summarize_fn = scorer_attributes.summarize_fn\n scorer_stats = transpose(vals)\n score_table = scorer_stats[scorer_name]\n scored = summarize_fn(score_table)\n summary[scorer_name] = scored\n else:\n model_output_summary = auto_summarize(vals)\n if model_output_summary:\n summary[name] = model_output_summary\n return summary\n',
},
{
"digest": "fqXqYs4C4l0HpQOaRfbVXwsvwYUZhMYyn4cvK0wnCMU",
"exp_content": b'import weave\nfrom typing import Any\nfrom weave.prompt.prompt import MessagesPrompt\nfrom weave.trace.op import op\n\n@weave.op\n@op\ndef score(self, *, output: str, **kwargs: Any) -> Any:\n """Score the output using the scoring_prompt."""\n if isinstance(self.scoring_prompt, MessagesPrompt):\n model_input = self.scoring_prompt.format(output=output, **kwargs)\n else:\n scoring_prompt = self.scoring_prompt.format(output=output, **kwargs)\n model_input = [{"role": "user", "content": scoring_prompt}]\n return self.model.predict(model_input)\n',
},
{
"digest": "1GtS3cAyf0xckKcss0LQesVtm44iEG49EsX1xuzTmvc",
"exp_content": b'import weave\nfrom typing import Annotated as MessageListLike\nfrom typing import Annotated as LLMStructuredModelParamsLike\nfrom typing import Any\nfrom weave.trace.context.weave_client_context import get_weave_client\nfrom weave.trace.context.weave_client_context import WeaveInitError\nfrom weave.utils.project_id import to_project_id\nfrom typing import Literal as ResponseFormat\nimport json\nfrom pydantic.main import BaseModel\nfrom weave.trace.op import op\n\nclass Message(BaseModel):\n """A message in a conversation with an LLM.\n\n Attributes:\n role: The role of the message\'s author. Can be: system, user, assistant, function or tool.\n content: The contents of the message. Required for all messages, but may be null for assistant messages with function calls.\n name: The name of the author of the message. Required if role is "function". Must match the name of the function represented in content.\n Can contain characters (a-z, A-Z, 0-9), and underscores, with a maximum length of 64 characters.\n function_call: The name and arguments of a function that should be called, as generated by the model.\n tool_call_id: Tool call that this message is responding to.\n """\n\n role: str\n content: str | list[dict] | None = None\n name: str | None = None\n function_call: dict | None = None\n tool_call_id: str | None = None\n\ndef parse_response(\n response_payload: dict, response_format: ResponseFormat | None\n) -> Message | str | dict[str, Any]:\n """Extract the model output from an LLM completion response payload.\n\n Raises:\n RuntimeError: the provider returned a top-level `error` field.\n ValueError: the payload is malformed (missing choices/message), the\n content is None/empty, or json_object parsing failed.\n """\n if response_payload.get("error"):\n raise RuntimeError(f"LLM API returned an error: {response_payload[\'error\']}")\n\n choices = response_payload.get("choices")\n if not choices:\n raise ValueError(\n "LLM response is missing 
\'choices\' -> the upstream call likely failed "\n "(invalid API key, content filtering, or provider error). "\n f"Response keys: {sorted(response_payload.keys())}"\n )\n\n message = choices[0].get("message") if isinstance(choices[0], dict) else None\n if not isinstance(message, dict):\n raise TypeError(\n f"LLM response choice did not contain a message dict: {choices[0]!r}"\n )\n content = message.get("content")\n\n if response_format == "text":\n if content is None:\n raise ValueError(\n "LLM response content is None -> the model returned no text. "\n "Check your API key, model config, and content filtering settings."\n )\n return content\n elif response_format == "json_object":\n if content is None or (isinstance(content, str) and not content.strip()):\n raise ValueError(\n "LLM response content was empty when JSON output was requested. "\n "Check your API key and that the model supports JSON mode."\n )\n try:\n return json.loads(content)\n except json.JSONDecodeError as e:\n snippet = content if len(content) <= 200 else content[:200] + "..."\n raise ValueError(\n f"LLM response was not valid JSON (response_format=json_object). "\n f"Content snippet: {snippet!r}"\n ) from e\n else:\n raise ValueError(f"Invalid response_format: {response_format}")\n\n@weave.op\n@op\ndef predict(\n self,\n user_input: MessageListLike | None = None,\n config: LLMStructuredModelParamsLike | None = None,\n **template_vars: Any,\n) -> Message | str | dict[str, Any]:\n """Generates a prediction by preparing messages (template + user_input)\n and calling the LLM completions endpoint with overridden config, using the provided client.\n\n Messages are prepared in one of two ways:\n 1. If default_params.prompt is set, the referenced MessagesPrompt object is\n loaded and its format() method is called with template_vars to generate messages.\n 2. 
If default_params.messages_template is set (and prompt is not), the template\n messages are used with template variable substitution.\n\n Note: If both prompt and messages_template are provided, prompt takes precedence.\n\n Args:\n user_input: The user input messages to append after template messages\n config: Optional configuration to override default parameters\n **template_vars: Variables to substitute in the messages template using {variable_name} syntax\n """\n if user_input is None:\n user_input = []\n\n current_client = get_weave_client()\n if current_client is None:\n raise WeaveInitError(\n "You must call `weave.init(<project_name>)` first, to predict with a LLMStructuredCompletionModel"\n )\n\n req = self.prepare_completion_request(\n project_id=to_project_id(current_client.entity, current_client.project),\n user_input=user_input,\n config=config,\n **template_vars,\n )\n\n # 5. Call the LLM API\n try:\n api_response = current_client.server.completions_create(req=req)\n except Exception as e:\n raise RuntimeError("Failed to call LLM completions endpoint.") from e\n\n # 6. Extract the message from the API response\n try:\n # The \'response\' attribute of CompletionsCreateRes is a dict\n response_payload = api_response.response\n response_format = (\n req.inputs.response_format.get("type")\n if req.inputs.response_format is not None\n else None\n )\n return parse_response(response_payload, response_format)\n except (\n KeyError,\n IndexError,\n TypeError,\n AttributeError,\n ValueError,\n json.JSONDecodeError,\n ) as e:\n raise RuntimeError(\n f"Failed to extract message from LLM response payload. Response: {api_response.response}"\n ) from e\n',
},
],
# Sad ... equality is really a pain to assert here (and is broken)
# TODO: Write a good equality check and make it work
Expand Down
74 changes: 74 additions & 0 deletions tests/trace/test_llm_as_a_judge_scorer.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
from unittest.mock import patch

import weave
from weave.flow.scorer import Scorer
from weave.prompt.prompt import MessagesPrompt
from weave.scorers import LLMAsAJudgeScorer
from weave.trace.object_record import pydantic_object_record
from weave.trace.refs import ObjectRef
from weave.trace_server import trace_server_interface as tsi
from weave.trace_server.interface.builtin_object_classes.builtin_object_registry import (
LLMStructuredCompletionModel,
)
Expand Down Expand Up @@ -105,3 +109,73 @@ def test_score_with_messages_prompt():
assert len(messages) == 2
assert messages[0]["content"] == "You are a math judge."
assert messages[1]["content"] == "Expected: 4, Got: 4"


def _make_judge_scorer() -> LLMAsAJudgeScorer:
    """Build the LLMAsAJudgeScorer fixture used by the record/publish tests.

    The nested LLMStructuredCompletionModel uses the ``gpt-4o-mini`` model id
    with a ``json_object`` response format, and the scoring prompt only
    interpolates ``{output}``.
    """
    judge_params = LLMStructuredCompletionModelDefaultParams(
        response_format="json_object",
    )
    judge_model = LLMStructuredCompletionModel(
        llm_model_id="gpt-4o-mini",
        default_params=judge_params,
    )
    return LLMAsAJudgeScorer(
        model=judge_model,
        scoring_prompt="Output: {output}",
    )


def test_llm_as_a_judge_scorer_record_excludes_op_methods():
    """WB-35184: neither the judge scorer nor its nested model records @op methods.

    Recording them would embed CustomWeaveType(Op) payloads on publish, which
    the scoring worker rejects (``_assert_safe_scorer_payload``); a judge
    monitor created programmatically would then silently never score. Both
    classes opt out through ``_weave_exclude_ops_from_record``, while an
    ordinary Scorer subclass keeps recording its ops.
    """
    judge = _make_judge_scorer()

    # The scorer's record carries no op attributes but keeps its class name.
    judge_record = pydantic_object_record(judge)
    for op_name in ("score", "summarize"):
        assert op_name not in judge_record.__dict__
    assert judge_record._class_name == "LLMAsAJudgeScorer"

    # Same expectation for the nested completion model.
    nested_record = pydantic_object_record(judge.model)
    assert "predict" not in nested_record.__dict__
    assert nested_record._class_name == "LLMStructuredCompletionModel"

    # Control case: a Scorer subclass that does not opt out still records ops.
    class _PlainScorer(Scorer):
        pass

    baseline_record = pydantic_object_record(_PlainScorer(name="plain"))
    for op_name in ("score", "summarize"):
        assert op_name in baseline_record.__dict__


def test_llm_as_a_judge_scorer_publish_has_no_op_refs(client):
    """A published judge scorer stores no op refs, so the scoring worker accepts it.

    The worker walks the stored payload, follows refs, and fails closed when it
    meets a nested CustomWeaveType(Op). Before WB-35184 the scorer's
    score/summarize and the nested model's predict were serialized as op refs
    and tripped that guard.
    """
    judge = _make_judge_scorer()
    published_ref = weave.publish(judge)

    def read_stored(target: ObjectRef) -> dict:
        # Fetch the raw stored value for the object version `target` points at.
        request = tsi.ObjReadReq(
            project_id=client.project_id,
            object_id=target.name,
            digest=target.digest,
        )
        return client.server.obj_read(request).obj.val

    scorer_payload = read_stored(published_ref)
    for op_field in ("score", "summarize"):
        assert op_field not in scorer_payload

    # The nested model is stored under its own ref; resolve it and check too.
    nested_ref = ObjectRef.parse_uri(scorer_payload["model"])
    model_payload = read_stored(nested_ref)
    assert "predict" not in model_payload

    # Loading the published ref still yields a usable scorer and model.
    restored = weave.get(published_ref.uri)
    assert isinstance(restored, LLMAsAJudgeScorer)
    assert isinstance(restored.model, LLMStructuredCompletionModel)
8 changes: 7 additions & 1 deletion weave/scorers/llm_as_a_judge_scorer.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any
from typing import Any, ClassVar

from pydantic import AliasChoices, ConfigDict, Field, field_validator

Expand Down Expand Up @@ -29,6 +29,12 @@ class LLMAsAJudgeScorer(Scorer):

model_config = ConfigDict(populate_by_name=True)

# Don't serialize score()/summarize() as op refs on publish: the resulting
# CustomWeaveType(Op) payloads are what the scoring worker's safety guard
# rejects, and nothing reads them (the @op still wraps the live method). This
# matches the op-free shape the Weave UI already persists. See WB-35184.
_weave_exclude_ops_from_record: ClassVar[bool] = True

model: LLMStructuredCompletionModel
scoring_prompt: str | MessagesPrompt
enable_image_input_scoring: bool = False
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import json
from typing import Annotated, Any, Literal
from typing import Annotated, Any, ClassVar, Literal

from pydantic import BaseModel, BeforeValidator, Field

Expand Down Expand Up @@ -121,6 +121,12 @@ def cast_to_llm_structured_model_params(


class LLMStructuredCompletionModel(Model):
# Don't serialize predict() as an op ref on publish: nested inside a
# published LLMAsAJudgeScorer it embeds a CustomWeaveType(Op) payload the
# scoring worker's safety guard rejects, and nothing reads the ref (the @op
# still wraps the live method). See WB-35184.
_weave_exclude_ops_from_record: ClassVar[bool] = True

# <provider>/<model> or ref to a provider model
llm_model_id: str | base_object_def.RefStr

Expand Down
Loading