diff --git a/dev_docs/BuiltinObjectClasses.md b/dev_docs/BuiltinObjectClasses.md index 5587591adca8..d927737945a0 100644 --- a/dev_docs/BuiltinObjectClasses.md +++ b/dev_docs/BuiltinObjectClasses.md @@ -80,7 +80,7 @@ While many Weave Objects are free-form and user-defined, there is often a need f Here's how to define and use a validated base object: -1. **Define your schema** (in `weave/trace_server/interface/builtin_object_classes/your_schema.py`): +1. **Define your schema** (in `weave/shared/builtin_object_classes/your_schema.py`): ```python from pydantic import BaseModel @@ -160,10 +160,10 @@ Run `make synchronize-base-object-schemas` to ensure the frontend TypeScript typ ### Architecture Flow -1. Define your schema in a python file in the `weave/trace_server/interface/builtin_object_classes/test_only_example.py` directory. See `weave/trace_server/interface/builtin_object_classes/test_only_example.py` as an example. -2. Make sure to register your schemas in `weave/trace_server/interface/builtin_object_classes/builtin_object_registry.py` by calling `register_base_object`. +1. Define your schema in a python file in the `weave/shared/builtin_object_classes/test_only_example.py` directory. See `weave/shared/builtin_object_classes/test_only_example.py` as an example. +2. Make sure to register your schemas in `weave/shared/builtin_object_classes/builtin_object_registry.py` by calling `register_base_object`. 3. Run `make synchronize-base-object-schemas` to generate the frontend types. - - The first step (`make generate_base_object_schemas`) will run `scripts/generate_base_object_schemas.py` to generate a JSON schema in `weave/trace_server/interface/builtin_object_classes/generated/generated_builtin_object_class_schemas.json`. + - The first step (`make generate_base_object_schemas`) will run `scripts/generate_base_object_schemas.py` to generate a JSON schema in `weave/shared/builtin_object_classes/generated/generated_builtin_object_class_schemas.json`. - The second step (yarn `generate-schemas`) will read this file and use it to generate the frontend types located in `frontends/weave/src/components/PagePanelComponents/Home/Browse3/pages/wfReactInterface/generatedBuiltinObjectClasses.zod.ts`. 4. Now, each use case uses different parts: 1. `Python Writing`. Users can directly import these classes and use them as normal Pydantic models, which get published with `weave.publish`. The python client correct builds the requisite payload. diff --git a/scripts/generate_base_object_schemas.py b/scripts/generate_base_object_schemas.py index 37b0ad1192cb..2737051a153a 100644 --- a/scripts/generate_base_object_schemas.py +++ b/scripts/generate_base_object_schemas.py @@ -3,15 +3,14 @@ from pydantic import create_model -from weave.trace_server.interface.builtin_object_classes.builtin_object_registry import ( +from weave.shared.builtin_object_classes.builtin_object_registry import ( BUILTIN_OBJECT_REGISTRY, ) OUTPUT_DIR = ( Path(__file__).parent.parent / "weave" - / "trace_server" - / "interface" + / "shared" / "builtin_object_classes" / "generated" ) diff --git a/tests/trace/test_llm_structured_completion_model.py b/tests/trace/test_llm_structured_completion_model.py index 80fdecceaa7e..9e641fde5b6b 100644 --- a/tests/trace/test_llm_structured_completion_model.py +++ b/tests/trace/test_llm_structured_completion_model.py @@ -193,7 +193,7 @@ def test_llm_structured_completion_model_filtering(client: WeaveClient): @patch( - "weave.trace_server.interface.builtin_object_classes.llm_structured_model.get_weave_client" + "weave.shared.builtin_object_classes.llm_structured_model.get_weave_client" ) def test_llm_structured_completion_model_predict_text_response(mock_get_client): """Test the predict function with mocked LLM API response for text format.""" @@ -245,7 +245,7 @@ def test_llm_structured_completion_model_predict_text_response(mock_get_client): @patch( - "weave.trace_server.interface.builtin_object_classes.llm_structured_model.get_weave_client" + "weave.shared.builtin_object_classes.llm_structured_model.get_weave_client" ) def test_llm_structured_completion_model_predict_json_response(mock_get_client): """Test the predict function with mocked LLM API response for JSON format.""" @@ -284,7 +284,7 @@ def test_llm_structured_completion_model_predict_json_response(mock_get_client): @patch( - "weave.trace_server.interface.builtin_object_classes.llm_structured_model.get_weave_client" + "weave.shared.builtin_object_classes.llm_structured_model.get_weave_client" ) def test_llm_structured_completion_model_predict_with_template(mock_get_client): """Test the predict function with message templates and template variables.""" @@ -343,7 +343,7 @@ def test_llm_structured_completion_model_predict_with_template(mock_get_client): @patch( - "weave.trace_server.interface.builtin_object_classes.llm_structured_model.get_weave_client" + "weave.shared.builtin_object_classes.llm_structured_model.get_weave_client" ) def test_llm_structured_completion_model_predict_with_config_override(mock_get_client): """Test the predict function with config parameter overriding defaults.""" @@ -395,7 +395,7 @@ def test_llm_structured_completion_model_predict_with_config_override(mock_get_c @patch( - "weave.trace_server.interface.builtin_object_classes.llm_structured_model.get_weave_client" + "weave.shared.builtin_object_classes.llm_structured_model.get_weave_client" ) def test_llm_structured_completion_model_predict_error_handling(mock_get_client): """Test the predict function error handling.""" @@ -656,7 +656,7 @@ def test_cast_to_message(): @patch( - "weave.trace_server.interface.builtin_object_classes.llm_structured_model.get_weave_client" + "weave.shared.builtin_object_classes.llm_structured_model.get_weave_client" ) def test_llm_structured_completion_model_predict_with_prompt( mock_get_client, client: WeaveClient @@ -733,7 +733,7 @@ def test_llm_structured_completion_model_predict_with_prompt( @patch( - "weave.trace_server.interface.builtin_object_classes.llm_structured_model.get_weave_client" + "weave.shared.builtin_object_classes.llm_structured_model.get_weave_client" ) def test_llm_structured_completion_model_prompt_takes_precedence( mock_get_client, client: WeaveClient diff --git a/weave/flow/annotation_spec.py b/weave/flow/annotation_spec.py index aefaa01f0f70..6b2df4f4559f 100644 --- a/weave/flow/annotation_spec.py +++ b/weave/flow/annotation_spec.py @@ -1,4 +1,4 @@ -from weave.trace_server.interface.builtin_object_classes import annotation_spec +from weave.shared.builtin_object_classes import annotation_spec # Re-export: AnnotationSpec = annotation_spec.AnnotationSpec diff --git a/weave/flow/leaderboard.py b/weave/flow/leaderboard.py index 24fe35ac3aa5..b4e268961eba 100644 --- a/weave/flow/leaderboard.py +++ b/weave/flow/leaderboard.py @@ -3,9 +3,9 @@ from dataclasses import dataclass from typing import Any +from weave.shared.builtin_object_classes import leaderboard from weave.trace.refs import OpRef from weave.trace.weave_client import WeaveClient, get_ref -from weave.trace_server.interface.builtin_object_classes import leaderboard from weave.trace_server.trace_server_interface import CallsFilter from weave.utils.project_id import from_project_id diff --git a/weave/flow/saved_view.py b/weave/flow/saved_view.py index c7f47710f87e..ba7c01ab3a96 100644 --- a/weave/flow/saved_view.py +++ b/weave/flow/saved_view.py @@ -6,6 +6,10 @@ from pydantic import BaseModel from typing_extensions import Self +from weave.shared.builtin_object_classes.saved_view import Column, Pin +from weave.shared.builtin_object_classes.saved_view import ( + SavedView as SavedViewBase, +) from weave.trace import urls from weave.trace.api import publish as weave_publish from weave.trace.api import ref as weave_ref @@ -20,10 +24,6 @@ from weave.trace_server import trace_server_interface as tsi from weave.trace_server.common_interface import SortBy from weave.trace_server.interface import query as tsi_query -from weave.trace_server.interface.builtin_object_classes.saved_view import Column, Pin -from weave.trace_server.interface.builtin_object_classes.saved_view import ( - SavedView as SavedViewBase, -) KNOWN_COLUMNS = [ "id", diff --git a/weave/scorers/llm_as_a_judge_scorer.py b/weave/scorers/llm_as_a_judge_scorer.py index 38b943c6de45..a0244a9e71d5 100644 --- a/weave/scorers/llm_as_a_judge_scorer.py +++ b/weave/scorers/llm_as_a_judge_scorer.py @@ -4,13 +4,13 @@ from weave.flow.scorer import Scorer from weave.prompt.prompt import MessagesPrompt +from weave.shared.builtin_object_classes.llm_structured_model import ( + LLMStructuredCompletionModel, +) from weave.trace.context.weave_client_context import get_weave_client from weave.trace.objectify import maybe_objectify, register_object from weave.trace.op import op from weave.trace.vals import make_trace_obj -from weave.trace_server.interface.builtin_object_classes.llm_structured_model import ( - LLMStructuredCompletionModel, -) @register_object diff --git a/weave/shared/builtin_object_classes/alert_spec.py b/weave/shared/builtin_object_classes/alert_spec.py new file mode 100644 index 000000000000..7b832f0eba9c --- /dev/null +++ b/weave/shared/builtin_object_classes/alert_spec.py @@ -0,0 +1,75 @@ +from typing import Literal + +from pydantic import BaseModel, ConfigDict, Field + +from weave.shared.builtin_object_classes import base_object_def + + +class WeaveMetricThresholdSpec(BaseModel): + """Alert specification for weave metric threshold alerts. + + Fields align with gorilla's WeaveMetricThresholdFilter and the + alert_worker's WeaveMetricFilter. Extra keys are permitted so that + the schema can evolve without breaking existing stored objects. + """ + + model_config = ConfigDict(extra="allow") + + alert_condition: str = Field( + default="THRESHOLD", + description="Trigger condition type (e.g. 'THRESHOLD')", + ) + comparison_operator: Literal["GREATER_THAN", "LESS_THAN", "EQUAL"] = Field( + default="GREATER_THAN", + description="Direction of the threshold comparison", + ) + metric_path: str = Field( + default="", + description="Dot-notation path to the metric value in the call (e.g. 'output.score')", + ) + op_str: str = Field( + default="", + description="Op or scorer ref to scope this alert to; empty means all ops", + ) + threshold: float = Field( + default=0.0, + description="Threshold value that triggers the alert", + ) + window_size: int | None = Field( + default=None, + description="Max number of recent calls to include in the window", + ) + window_duration: int | None = Field( + default=None, + description="Time window in seconds for historical calls", + ) + monitor_ref: str | None = Field( + default=None, + description="Monitor ref; presence indicates alert was created from a monitor", + ) + scorer_ref: str | None = Field( + default=None, + description="Scorer object ref identifying which scorer instance within a monitor to alert on. " + "Disambiguates scorers that share the same op class (e.g. multiple LLMAsAJudgeScorer instances). " + "Filterable via inputs.self on scorer calls or input_refs on CallsFilter.", + ) + aggregation_function: Literal["mean", "median", "min", "max", "mode"] = Field( + default="mean", + description="Aggregation function applied to metric values within the window before threshold comparison", + ) + + +class AlertSpec(base_object_def.BaseObject): + op_scope: list[str] | None = Field( + default=None, + description="If provided, this alert only applies to calls from the given op refs", + examples=[ + ["weave:///entity/project/op/name:digest"], + ["weave:///entity/project/op/name:*"], + ], + ) + + spec: WeaveMetricThresholdSpec = Field( + default_factory=WeaveMetricThresholdSpec, + description="Alert specification (threshold config, window config, etc.)", + ) diff --git a/weave/shared/builtin_object_classes/annotation_spec.py b/weave/shared/builtin_object_classes/annotation_spec.py new file mode 100644 index 000000000000..818b0573a16c --- /dev/null +++ b/weave/shared/builtin_object_classes/annotation_spec.py @@ -0,0 +1,126 @@ +from typing import Any + +from pydantic import BaseModel, Field, create_model, field_validator, model_validator +from pydantic.fields import FieldInfo + +from weave.shared.builtin_object_classes import base_object_def + +SUPPORTED_PRIMITIVES = (int, float, bool, str) + + +class AnnotationSpec(base_object_def.BaseObject): + field_schema: dict[str, Any] = Field( + default={}, + description="Expected to be valid JSON Schema. Can be provided as a dict, a Pydantic model class, a tuple of a primitive type and a Pydantic Field, or primitive type", + examples=[ + # String feedback + {"type": "string", "maxLength": 100}, + # Number feedback + {"type": "number", "minimum": 0, "maximum": 100}, + # Integer feedback + {"type": "integer", "minimum": 0, "maximum": 100}, + # Boolean feedback + {"type": "boolean"}, + # Categorical feedback + {"type": "string", "enum": ["option1", "option2"]}, + ], + ) + + # TODO + # If true, all unique creators will have their + # own value for this feedback type. Otherwise, + # by default, the value is shared and can be edited. + unique_among_creators: bool = False + + # TODO + # If provided, this feedback type will only be shown + # when a call is generated from the given op ref + op_scope: list[str] | None = Field( + default=None, + examples=[ + ["weave:///entity/project/op/name:digest"], + ["weave:///entity/project/op/name:*"], + ], + ) + + @model_validator(mode="before") + @classmethod + def preprocess_field_schema(cls, data: dict[str, Any]) -> dict[str, Any]: + if "field_schema" not in data: + return data + + field_schema = data["field_schema"] + + temp_field_tuple = None + # Handle Pydantic Field + if isinstance(field_schema, tuple): + if len(field_schema) != 2: + raise ValueError("Expected a tuple of length 2") + annotation, field = field_schema + if ( + not isinstance(annotation, type) + ) or annotation not in SUPPORTED_PRIMITIVES: + raise TypeError("Expected annotation to be a primitive type") + if not isinstance(field, FieldInfo): + raise TypeError("Expected field to be a Pydantic Field") + temp_field_tuple = (annotation, field) + elif field_schema in SUPPORTED_PRIMITIVES: + temp_field_tuple = (field_schema, Field()) + + if temp_field_tuple is not None: + # Create a temporary model to leverage Pydantic's schema generation + TempModel = create_model("TempModel", field=temp_field_tuple) # noqa: N806 + + schema = TempModel.model_json_schema()["properties"]["field"] + + if ( + "title" in schema and schema["title"] == "Field" + ): # default title for Field + schema.pop("title") + + data["field_schema"] = schema + return data + + # Handle Pydantic model + if isinstance(field_schema, type) and issubclass(field_schema, BaseModel): + # Read back through `data` (typed `Any`): pydantic is invisible to + # the type-check env, so `issubclass` cannot narrow off bare `type`. + data["field_schema"] = data["field_schema"].model_json_schema() + return data + + return data + + @field_validator("field_schema") + def validate_field_schema(cls, schema: dict[str, Any]) -> dict[str, Any]: # noqa: N805 + # Imported lazily: `import jsonschema` eagerly loads every installed + # format-checker lib (rfc3987_syntax builds a Lark grammar, ~0.45s), + # so keeping it out of module scope avoids paying that on `import weave`. + import jsonschema + + # Validate the schema + try: + jsonschema.validate(None, schema) + except jsonschema.exceptions.SchemaError: + raise + except jsonschema.exceptions.ValidationError: + pass # we don't care that `None` does not conform + return schema + + def value_is_valid(self, payload: Any) -> bool: + """Validates a payload against this annotation spec's schema. + + Args: + payload: The data to validate against the schema + + Returns: + bool: True if validation succeeds, False otherwise + """ + # Lazy import: see note in validate_field_schema (avoids ~0.45s of + # jsonschema format-checker imports at `import weave` time). + import jsonschema + + try: + jsonschema.validate(payload, self.field_schema) + except jsonschema.exceptions.ValidationError: + return False + return True diff --git a/weave/shared/builtin_object_classes/base_object_def.py b/weave/shared/builtin_object_classes/base_object_def.py new file mode 100644 index 000000000000..07574e7e547c --- /dev/null +++ b/weave/shared/builtin_object_classes/base_object_def.py @@ -0,0 +1,10 @@ +import pydantic + +RefStr = str + + +# This is just an alternative to weave.Object for the server side. +# I _think_ this will go away once we have the full weave system on the server +class BaseObject(pydantic.BaseModel): + name: str | None = None + description: str | None = None diff --git a/weave/shared/builtin_object_classes/builtin_object_registry.py b/weave/shared/builtin_object_classes/builtin_object_registry.py new file mode 100644 index 000000000000..eef1840eec58 --- /dev/null +++ b/weave/shared/builtin_object_classes/builtin_object_registry.py @@ -0,0 +1,52 @@ +from weave.shared.builtin_object_classes.alert_spec import AlertSpec +from weave.shared.builtin_object_classes.annotation_spec import ( + AnnotationSpec, +) +from weave.shared.builtin_object_classes.base_object_def import ( + BaseObject, +) +from weave.shared.builtin_object_classes.comparison_view import ( + ComparisonView, +) +from weave.shared.builtin_object_classes.leaderboard import Leaderboard +from weave.shared.builtin_object_classes.llm_structured_model import ( + LLMStructuredCompletionModel, +) +from weave.shared.builtin_object_classes.provider import ( + Provider, + ProviderModel, +) +from weave.shared.builtin_object_classes.saved_view import ( + ChartConfig, + SavedView, +) +from weave.shared.builtin_object_classes.test_only_example import ( + TestOnlyExample, + TestOnlyInheritedBaseObject, + TestOnlyNestedBaseObject, +) + +BUILTIN_OBJECT_REGISTRY: dict[str, type[BaseObject]] = {} + + +def register_base_object(cls: type[BaseObject]) -> None: + """Register a BaseObject class in the global registry. + + Args: + cls: The BaseObject class to register + """ + BUILTIN_OBJECT_REGISTRY[cls.__name__] = cls + + +register_base_object(TestOnlyExample) +register_base_object(TestOnlyNestedBaseObject) +register_base_object(TestOnlyInheritedBaseObject) +register_base_object(Leaderboard) +register_base_object(AlertSpec) +register_base_object(AnnotationSpec) +register_base_object(Provider) +register_base_object(ProviderModel) +register_base_object(SavedView) +register_base_object(ComparisonView) +register_base_object(LLMStructuredCompletionModel) +register_base_object(ChartConfig) diff --git a/weave/shared/builtin_object_classes/comparison_view.py b/weave/shared/builtin_object_classes/comparison_view.py new file mode 100644 index 000000000000..692b5876c9b4 --- /dev/null +++ b/weave/shared/builtin_object_classes/comparison_view.py @@ -0,0 +1,50 @@ +"""ComparisonView builtin object class for saving comparison view configurations. + +This allows users to save and restore comparison configurations including +evaluation call IDs and selected metrics. +""" + +from pydantic import BaseModel + +from weave.shared.builtin_object_classes import base_object_def + + +class ComparisonViewDefinition(BaseModel): + """Definition of a comparison view's configuration. + + Args: + evaluation_call_ids (list[str]): List of evaluation call IDs being compared. + selected_metrics (list[str] | None): List of metrics that are visible in plots. + + Examples: + >>> definition = ComparisonViewDefinition( + ... evaluation_call_ids=["call_1", "call_2"], + ... selected_metrics=["accuracy", "f1_score"] + ... ) + """ + + evaluation_call_ids: list[str] + selected_metrics: list[str] | None = None + + +class ComparisonView(base_object_def.BaseObject): + """A saved comparison view configuration. + + Args: + label (str): Human-readable name for the comparison view. + definition (ComparisonViewDefinition): The view's configuration. + + Examples: + >>> view = ComparisonView( + ... label="My Comparison", + ... definition=ComparisonViewDefinition( + ... evaluation_call_ids=["call_1", "call_2"] + ... ) + ... ) + """ + + label: str + definition: ComparisonViewDefinition + + +__all__ = ["ComparisonView", "ComparisonViewDefinition"] diff --git a/weave/trace_server/interface/builtin_object_classes/generated/generated_builtin_object_class_schemas.json b/weave/shared/builtin_object_classes/generated/generated_builtin_object_class_schemas.json similarity index 100% rename from weave/trace_server/interface/builtin_object_classes/generated/generated_builtin_object_class_schemas.json rename to weave/shared/builtin_object_classes/generated/generated_builtin_object_class_schemas.json diff --git a/weave/shared/builtin_object_classes/leaderboard.py b/weave/shared/builtin_object_classes/leaderboard.py new file mode 100644 index 000000000000..96938cc08c0f --- /dev/null +++ b/weave/shared/builtin_object_classes/leaderboard.py @@ -0,0 +1,14 @@ +from pydantic import BaseModel + +from weave.shared.builtin_object_classes import base_object_def + + +class LeaderboardColumn(BaseModel): + evaluation_object_ref: base_object_def.RefStr + scorer_name: str + summary_metric_path: str + should_minimize: bool | None = None + + +class Leaderboard(base_object_def.BaseObject): + columns: list[LeaderboardColumn] diff --git a/weave/shared/builtin_object_classes/llm_structured_model.py b/weave/shared/builtin_object_classes/llm_structured_model.py new file mode 100644 index 000000000000..735f3fd2fd68 --- /dev/null +++ b/weave/shared/builtin_object_classes/llm_structured_model.py @@ -0,0 +1,373 @@ +import json +from typing import Annotated, Any, Literal + +from pydantic import BaseModel, BeforeValidator, Field + +from weave import Model, op +from weave.prompt.prompt import format_message_with_template_vars +from weave.shared.builtin_object_classes import base_object_def +from weave.trace import vals +from weave.trace.context.weave_client_context import WeaveInitError, get_weave_client +from weave.trace_server.trace_server_interface import ( + CompletionsCreateReq, + CompletionsCreateRequestInputs, +) +from weave.utils.project_id import to_project_id + +ResponseFormat = Literal["json_object", "json_schema", "text"] + + +def is_response_format(value: Any) -> bool: + return isinstance(value, str) and value in {"json_object", "text"} + + +class Message(BaseModel): + """A message in a conversation with an LLM. + + Attributes: + role: The role of the message's author. Can be: system, user, assistant, function or tool. + content: The contents of the message. Required for all messages, but may be null for assistant messages with function calls. + name: The name of the author of the message. Required if role is "function". Must match the name of the function represented in content. + Can contain characters (a-z, A-Z, 0-9), and underscores, with a maximum length of 64 characters. + function_call: The name and arguments of a function that should be called, as generated by the model. + tool_call_id: Tool call that this message is responding to. + """ + + role: str + content: str | list[dict] | None = None + name: str | None = None + function_call: dict | None = None + tool_call_id: str | None = None + + +class LLMStructuredCompletionModelDefaultParams(BaseModel): + """Default parameters for LLMStructuredCompletionModel. + + Attributes: + messages_template: A list of Messages to use as a template. Messages can contain + template variables using {variable_name} syntax. These will be substituted + when predict() is called with template_vars. + prompt: A reference string to a MessagesPrompt object. If provided, this takes + precedence over messages_template. The referenced prompt's format() method will + be used to generate messages with template variable substitution. + Example: "weave:///entity/project/object/my_prompt:latest" + """ + + # This is a list of Messages, loosely following litellm's message format + # https://docs.litellm.ai/docs/completion/input#properties-of-messages + messages_template: list[Message] | None = None + prompt: base_object_def.RefStr | None = None + + temperature: float | None = None + top_p: float | None = None + max_tokens: int | None = None + presence_penalty: float | None = None + frequency_penalty: float | None = None + stop: list[str] | None = None + n_times: int | None = None + functions: list[dict] | None = None + + # Either json, text, or json_schema + response_format: ResponseFormat | None = None + + # TODO: Currently not used. Fast follow up with json_schema + # if default_params.response_format is set to JSON_SCHEMA, this will be used + # response_format_schema: dict | None = None + + +def cast_to_message_list(obj: Any) -> list[Message]: + if isinstance(obj, Message): + return [obj] + elif isinstance(obj, dict): + return [Message.model_validate(obj)] + elif isinstance(obj, str): + return [Message(content=obj, role="user")] + elif isinstance(obj, list): + return [cast_to_message(item) for item in obj] + raise TypeError("Unable to cast to Message") + + +def cast_to_message(obj: Any) -> Message: + if isinstance(obj, Message): + return obj + elif isinstance(obj, dict): + return Message.model_validate(obj) + elif isinstance(obj, str): + return Message(content=obj, role="user") + raise TypeError("Unable to cast to Message") + + +def cast_to_llm_structured_model_params( + obj: Any, +) -> LLMStructuredCompletionModelDefaultParams: + if isinstance(obj, LLMStructuredCompletionModelDefaultParams): + return obj + elif isinstance(obj, dict): + return LLMStructuredCompletionModelDefaultParams.model_validate(obj) + elif isinstance(obj, vals.Traceable): + return LLMStructuredCompletionModelDefaultParams.model_validate( + vals.unwrap(obj) # Recursively "unwrap" to a dict with plain python types + ) + + raise TypeError("Unable to cast to LLMStructuredCompletionModelDefaultParams") + + +MessageListLike = Annotated[list[Message], BeforeValidator(cast_to_message_list)] +MessageLike = Annotated[Message, BeforeValidator(cast_to_message)] +LLMStructuredModelParamsLike = Annotated[ + LLMStructuredCompletionModelDefaultParams, + BeforeValidator(cast_to_llm_structured_model_params), +] + + +class LLMStructuredCompletionModel(Model): + # / or ref to a provider model + llm_model_id: str | base_object_def.RefStr + + default_params: LLMStructuredModelParamsLike = Field( + default_factory=LLMStructuredCompletionModelDefaultParams + ) + + @op + def predict( + self, + user_input: MessageListLike | None = None, + config: LLMStructuredModelParamsLike | None = None, + **template_vars: Any, + ) -> Message | str | dict[str, Any]: + """Generates a prediction by preparing messages (template + user_input) + and calling the LLM completions endpoint with overridden config, using the provided client. + + Messages are prepared in one of two ways: + 1. If default_params.prompt is set, the referenced MessagesPrompt object is + loaded and its format() method is called with template_vars to generate messages. + 2. If default_params.messages_template is set (and prompt is not), the template + messages are used with template variable substitution. + + Note: If both prompt and messages_template are provided, prompt takes precedence. + + Args: + user_input: The user input messages to append after template messages + config: Optional configuration to override default parameters + **template_vars: Variables to substitute in the messages template using {variable_name} syntax + """ + if user_input is None: + user_input = [] + + current_client = get_weave_client() + if current_client is None: + raise WeaveInitError( + "You must call `weave.init()` first, to predict with a LLMStructuredCompletionModel" + ) + + req = self.prepare_completion_request( + project_id=to_project_id(current_client.entity, current_client.project), + user_input=user_input, + config=config, + **template_vars, + ) + + # 5. Call the LLM API + try: + api_response = current_client.server.completions_create(req=req) + except Exception as e: + raise RuntimeError("Failed to call LLM completions endpoint.") from e + + # 6. Extract the message from the API response + try: + # The 'response' attribute of CompletionsCreateRes is a dict + response_payload = api_response.response + response_format = ( + req.inputs.response_format.get("type") + if req.inputs.response_format is not None + else None + ) + return parse_response(response_payload, response_format) + except ( + KeyError, + IndexError, + TypeError, + AttributeError, + ValueError, + json.JSONDecodeError, + ) as e: + raise RuntimeError( + f"Failed to extract message from LLM response payload. Response: {api_response.response}" + ) from e + + def prepare_completion_request( + self, + project_id: str, + user_input: MessageListLike, + config: LLMStructuredModelParamsLike | None, + **template_vars: Any, + ) -> CompletionsCreateReq: + # Ensure user_input is properly converted to a list of Message objects + # This is needed because the @op decorator might interfere with Pydantic validation + if not isinstance(user_input, list) or ( + user_input and not isinstance(user_input[0], Message) + ): + user_input = cast_to_message_list(user_input) + + # 1. Prepare messages from messages_template (if no prompt is set) + # Note: If prompt is set, we don't prepare messages here - we pass the prompt + # reference to the completions endpoint which will resolve and substitute it + template_msgs = None + + # Only use messages_template if prompt is NOT set + if ( + self.default_params + and self.default_params.messages_template + and not self.default_params.prompt + ): + template_msgs = self.default_params.messages_template + if template_vars: + # Convert Message objects to dicts, apply template vars, convert back + formatted_dicts = [ + format_message_with_template_vars( + msg.model_dump(exclude_none=True), **template_vars + ) + for msg in template_msgs + ] + template_msgs = [Message.model_validate(d) for d in formatted_dicts] + + prepared_messages_dicts = _prepare_llm_messages(template_msgs, user_input) + + # 2. Prepare completion parameters, starting with defaults from LLMStructuredCompletionModel + completion_params: dict[str, Any] = {} + default_p_model = self.default_params + if default_p_model: + completion_params = parse_params_to_litellm_params(default_p_model) + + # 3. Override parameters with the provided config dictionary + if config: + completion_params = { + **completion_params, + **parse_params_to_litellm_params(config), + } + + # 4. Create the completion inputs + model_id_str = str(self.llm_model_id) + + # Include template_vars if they exist + if template_vars: + completion_params["template_vars"] = template_vars + + completion_inputs = CompletionsCreateRequestInputs( + model=model_id_str, messages=prepared_messages_dicts, **completion_params + ) + req = CompletionsCreateReq( + project_id=project_id, + inputs=completion_inputs, + ) + + return req + + +def parse_response( + response_payload: dict, response_format: ResponseFormat | None +) -> Message | str | dict[str, Any]: + """Extract the model output from an LLM completion response payload. + + Raises: + RuntimeError: the provider returned a top-level `error` field. + ValueError: the payload is malformed (missing choices/message), the + content is None/empty, or json_object parsing failed. + """ + if response_payload.get("error"): + raise RuntimeError(f"LLM API returned an error: {response_payload['error']}") + + choices = response_payload.get("choices") + if not choices: + raise ValueError( + "LLM response is missing 'choices' -> the upstream call likely failed " + "(invalid API key, content filtering, or provider error). " + f"Response keys: {sorted(response_payload.keys())}" + ) + + message = choices[0].get("message") if isinstance(choices[0], dict) else None + if not isinstance(message, dict): + raise TypeError( + f"LLM response choice did not contain a message dict: {choices[0]!r}" + ) + content = message.get("content") + + if response_format == "text": + if content is None: + raise ValueError( + "LLM response content is None -> the model returned no text. " + "Check your API key, model config, and content filtering settings." + ) + return content + elif response_format == "json_object": + if content is None or (isinstance(content, str) and not content.strip()): + raise ValueError( + "LLM response content was empty when JSON output was requested. " + "Check your API key and that the model supports JSON mode." + ) + try: + return json.loads(content) + except json.JSONDecodeError as e: + snippet = content if len(content) <= 200 else content[:200] + "..." + raise ValueError( + f"LLM response was not valid JSON (response_format=json_object). " + f"Content snippet: {snippet!r}" + ) from e + else: + raise ValueError(f"Invalid response_format: {response_format}") + + +def _prepare_llm_messages( + template_messages: list[Message] | None, + user_input: list[Message], +) -> list[dict[str, Any]]: + """Prepares a list of message dictionaries for the LLM API from a message template and user input. + Helper function for PlaygroundModel.predict. + Returns a list of message dictionaries. + """ + final_messages_dicts: list[dict[str, Any]] = [] + + # 1. Initialize messages from template + if template_messages: + for msg_template in template_messages: + msg_dict = msg_template.model_dump(exclude_none=True) + final_messages_dicts.append(msg_dict) + + # 2. Append user_input messages + for u_msg in user_input: + final_messages_dicts.append(u_msg.model_dump(exclude_none=True)) + + return final_messages_dicts + + +def parse_params_to_litellm_params( + params_source: LLMStructuredCompletionModelDefaultParams, +) -> dict[str, Any]: + final_params: dict[str, Any] = {} + source_dict_to_iterate: dict[str, Any] = params_source.model_dump(exclude_none=True) + + for key, value in source_dict_to_iterate.items(): + if key == "response_format": + litellm_response_format_value = None + if isinstance(value, str) and is_response_format(value): + litellm_response_format_value = {"type": value} + elif ( + isinstance(value, dict) + and "type" in value + and is_response_format(value["type"]) + ): # Pre-formed dict with valid type + litellm_response_format_value = value + + if litellm_response_format_value is not None: + final_params["response_format"] = litellm_response_format_value + elif key == "n_times": + final_params["n"] = value + elif key == "messages_template": + pass + elif key in {"functions", "stop"}: + if isinstance(value, list) and len(value) > 0: + final_params[key] = value + else: + final_params[key] = value + + return final_params diff --git a/weave/shared/builtin_object_classes/provider.py b/weave/shared/builtin_object_classes/provider.py new file mode 100644 index 000000000000..32093cdde1f7 --- /dev/null +++ b/weave/shared/builtin_object_classes/provider.py @@ -0,0 +1,70 @@ +import re +from enum import Enum +from urllib.parse import urlparse + +from pydantic import ConfigDict, Field, field_validator + +from weave.shared.builtin_object_classes import base_object_def +from weave.shared.url_safety import is_publicly_routable_url + +# Headers that must not appear in user-supplied extra_headers. +# https://coreweave.atlassian.net/browse/VULNMGMT-770 +BLOCKED_HEADER_RE = re.compile( + r"^(?:metadata-flavor" + r"|x-aws-ec2-metadata-token(?:-ttl-seconds)?" + r")$", + re.IGNORECASE, +) + + +INVALID_BASE_URL_MSG = "base_url is not a valid provider URL" + + +def _validate_provider_base_url(url: str) -> str: + """Validate that a provider base_url is a well-formed, publicly-routable HTTP(S) URL. + + See https://coreweave.atlassian.net/browse/VULNMGMT-770 + """ + # urlparse silently strips a bare trailing '?', so check the raw string too. + if "?" in url: + raise ValueError(INVALID_BASE_URL_MSG) + try: + parsed = urlparse(url) + except ValueError as exc: + raise ValueError(INVALID_BASE_URL_MSG) from exc + if parsed.fragment: + raise ValueError(INVALID_BASE_URL_MSG) + if not is_publicly_routable_url(url): + raise ValueError(INVALID_BASE_URL_MSG) + return url + + +class ProviderReturnType(str, Enum): + OPENAI = "openai" + + +class Provider(base_object_def.BaseObject): + model_config = ConfigDict(validate_assignment=True) + + base_url: str + api_key_name: str + extra_headers: dict[str, str] = Field(default_factory=dict) + return_type: ProviderReturnType = Field(default=ProviderReturnType.OPENAI) + + @field_validator("base_url") + @classmethod + def validate_base_url(cls, v: str) -> str: + return _validate_provider_base_url(v) + + @field_validator("extra_headers") + @classmethod + def validate_extra_headers(cls, v: dict[str, str]) -> dict[str, str]: + for key in v: + if BLOCKED_HEADER_RE.match(key): + raise ValueError("extra_headers contains a disallowed header") + return v + + +class ProviderModel(base_object_def.BaseObject): + provider: base_object_def.RefStr + max_tokens: int diff --git a/weave/shared/builtin_object_classes/saved_view.py b/weave/shared/builtin_object_classes/saved_view.py new file mode 100644 index 000000000000..a482215b0ba6 --- /dev/null +++ b/weave/shared/builtin_object_classes/saved_view.py @@ -0,0 +1,131 @@ +from typing import Literal + +from pydantic import BaseModel, Field + +from weave.shared.builtin_object_classes import base_object_def +from weave.trace_server import trace_server_interface as tsi +from weave.trace_server.common_interface import SortBy + +PathElement = str | int + + +class Pin(BaseModel): + left: list[str] + right: list[str] + + +class Column(BaseModel): + # Optional in case we want something like computed columns in the future. + path: list[PathElement] | None = Field(default=None) + label: str | None = Field(default=None) + + +class ChartConfig(BaseModel): + x_axis: str = Field(title="XAxis") + y_axis: str = Field(title="YAxis") + plot_type: Literal["scatter", "line", "bar"] | None = Field( + default=None, + ) + bin_count: int | None = Field(default=None) + aggregation: Literal["average", "sum", "min", "max", "p95", "p99"] | None = Field( + default=None + ) + group_keys: list[str] | None = Field(default=None) + custom_name: str | None = Field(default=None) + + +class ObjectVersionGroup(BaseModel): + label: str # label for the combination of the groups + base_ref: str + versions: list[str] | Literal["*"] + show_version_indicator: bool + + +class ObjectConfig(BaseModel): + version_groups: list[ObjectVersionGroup] | None = Field(default=None) + display_name_map: dict[str, str] | None = Field( + default=None + ) # obj -> display name (keys can use "*" wildcards) + deselected: list[str] | None = Field( + default=None + ) # List of dataset refs or patterns to exclude + + +class DynamicLeaderboardColumnConfig(BaseModel): + evaluation_object_ref: base_object_def.RefStr | None = Field(default=None) + scorer_name: str | None = Field(default=None) + display_name: str | None = Field(default=None) + summary_metric_path: str | None = Field(default=None) + should_minimize: bool | None = Field(default=None) + deselected: bool | None = Field( + default=None + ) # If True, this metric is excluded from the leaderboard + + +class DynamicLeaderboardConfig(BaseModel): + # These are initialized to empty lists and dicts by default (show everything) + model_configuration: ObjectConfig | None = Field(default=None) + dataset_configuration: ObjectConfig | None = Field(default=None) + scorer_configuration: ObjectConfig | None = Field(default=None) + # Only has entries when a column is marked as deselected or minimized + columns_configuration: list[DynamicLeaderboardColumnConfig] | None = Field( + default=None + ) + + +class SavedViewDefinition(BaseModel): + filter: tsi.CallsFilter | None = Field(default=None) + + query: tsi.Query | None = Field(default=None) + + # cols is the current UI column visibility config that + # doesn't allow specifying column order - prefer use of + # explicit columns list which is what we should work towards. + cols: dict[str, bool] | None = Field(default=None) + + # columns is specifying exactly which columns to include + # including order. + columns: list[Column] | None = Field(default=None) + + # column_order is a simple ordered list of column field names. + # Used by the frontend to persist user-defined column ordering. + column_order: list[str] | None = Field(default=None) + + # Paths to columns whose values are refs to other objects and that must be + # dereferenced for filtering / sorting / display to work. Mirrors the + # `expand_columns` field of `CallsQueryReq` — when a saved view filters on + # a sub-field of a referenced object (e.g. `inputs.self.base_model_name`), + # the parent ref path (`inputs.self`) must be in this list so the trace + # server joins through to the referenced object at query time. + expand_columns: list[str] | None = Field(default=None) + + header_depth: int | None = Field(default=None) + + pin: Pin | None = Field(default=None) + sort_by: list[SortBy] | None = Field(default=None) + page: int | None = Field(default=None) + page_size: int | None = Field(default=None) + charts: list[ChartConfig] | None = Field(default=None) + + # Evaluations calls table has dataset and evaluation object + # selectors that can be used to filter down evals to those using these objects. + # The selector is an object ref where the version can either be a digest or `*` + # to match all versions. + dataset_selector: str | None = Field(default=None) + evaluation_selector: str | None = Field(default=None) + + # Dynamic leaderboards are populated by the evals in a saved view + dynamic_leaderboard_config: DynamicLeaderboardConfig | None = Field(default=None) + + +class SavedView(base_object_def.BaseObject): + # "traces" or "evaluations", type is str for extensibility + view_type: str + + # Avoiding confusion around object_id + name + label: str + + definition: SavedViewDefinition + + +__all__ = ["SavedView"] diff --git a/weave/shared/builtin_object_classes/test_only_example.py b/weave/shared/builtin_object_classes/test_only_example.py new file mode 100644 index 000000000000..7fce7c45d696 --- /dev/null +++ b/weave/shared/builtin_object_classes/test_only_example.py @@ -0,0 +1,34 @@ +from pydantic import BaseModel, Field + +from weave.shared.builtin_object_classes import base_object_def + + +class TestOnlyNestedBaseModel(BaseModel): + a: int + aliased_property: int = Field(alias="aliased_property_alias") + + +class TestOnlyNestedBaseObject(base_object_def.BaseObject): + b: int + + +class TestOnlyInheritedBaseObject(TestOnlyNestedBaseObject): + """A builtin object that inherits from another builtin object for testing inheritance.""" + + c: int + additional_field: str = "default_value" + + +class TestOnlyExample(base_object_def.BaseObject): + primitive: int + nested_base_model: TestOnlyNestedBaseModel + # Important: `RefStr` is just an alias for `str`. When defining `BaseObject`s, we + # should never have a property point to another `BaseObject`. This is because each + # base object is stored in the database and should be treated like a foreign key. + # + # It would be nice to have a way to ensure that no `BaseObject` has any `BaseObject` + # properties. + nested_base_object: base_object_def.RefStr + + +__all__ = ["TestOnlyExample", "TestOnlyInheritedBaseObject", "TestOnlyNestedBaseObject"] diff --git a/weave/shared/url_safety.py b/weave/shared/url_safety.py new file mode 100644 index 000000000000..70972da86370 --- /dev/null +++ b/weave/shared/url_safety.py @@ -0,0 +1,93 @@ +"""Server-side URL safety helpers. + +Used to gate any place where the trace server fetches a URL whose host +could come from an untrusted source: provider base URLs, upstream image +generation responses, etc. Primary protection belongs at the network +layer (pod egress policy); this module is belt-and-suspenders. + +`is_publicly_routable_url` rejects: +- non-http(s) schemes +- empty / missing host +- localhost variants and reserved cloud-metadata hostnames +- IP literals that are not globally routable, including alternative IPv4 + encodings (decimal, hex, octal) that bypass naive string checks +""" + +from __future__ import annotations + +import ipaddress +import re +import socket +from urllib.parse import urlparse + +ALLOWED_SCHEMES = frozenset({"http", "https"}) + +LOCALHOST_LITERALS = frozenset( + {"localhost", "localhost.localdomain", "ip6-localhost", "ip6-loopback"} +) + +# Symbolic hostnames that resolve to cloud instance-metadata services. +# Bare-IP IMDS (169.254.169.254) is covered by the is_global check below; +# this regex catches the hostnames that resolve to it. +BLOCKED_HOSTNAME_RE = re.compile( + r"(?:^|\.)" + r"(?:metadata\.google\.internal" + r"|metadata\.goog" + r"|metadata\.internal" + r"|metadata\.azure\.com" + r")\.?$", + re.IGNORECASE, +) + + +def is_publicly_routable_url(url: str) -> bool: + """Return True if `url` is safe to fetch from a server context. + + A URL is considered safe when it parses cleanly, uses http or https, + has a host that is neither a localhost literal nor a known cloud- + metadata hostname, and (if the host is an IP literal in any common + encoding) resolves to a globally routable address. + + Non-IP hostnames are accepted on the assumption that DNS resolution + and further egress filtering happen downstream. DNS-rebinding attacks + against hostnames are the egress policy's job, not this function's. + """ + try: + parsed = urlparse(url) + except ValueError: + return False + + if parsed.scheme.lower() not in ALLOWED_SCHEMES: + return False + + host = (parsed.hostname or "").lower().rstrip(".") + if not host: + return False + + if host in LOCALHOST_LITERALS or host.endswith(".localhost"): + return False + + if BLOCKED_HOSTNAME_RE.search(host): + return False + + addr: ipaddress.IPv4Address | ipaddress.IPv6Address | None = None + try: + addr = ipaddress.ip_address(host) + except ValueError: + try: + packed = socket.inet_aton(host) + except OSError: + return True + addr = ipaddress.ip_address(packed) + + # `is_global` alone is too permissive: it returns True for IPv4/IPv6 + # multicast (e.g. 224.0.0.1, ff00::/8). Enumerate the disallowed + # categories explicitly. + return not ( + addr.is_loopback + or addr.is_link_local + or addr.is_private + or addr.is_reserved + or addr.is_multicast + or addr.is_unspecified + ) diff --git a/weave/trace/api.py b/weave/trace/api.py index 65d9163cd0c7..e5d2940fa294 100644 --- a/weave/trace/api.py +++ b/weave/trace/api.py @@ -11,6 +11,7 @@ # TODO: type_handlers is imported here to trigger registration of the image serializer. # There is probably a better place for this, but including here for now to get the fix in. from weave import type_handlers # noqa: F401 +from weave.shared.builtin_object_classes import leaderboard from weave.shared.ids import generate_id from weave.trace import urls, weave_client, weave_init from weave.trace.autopatch import AutopatchSettings @@ -28,7 +29,6 @@ ) from weave.trace.table import Table from weave.trace.view_utils import set_call_view -from weave.trace_server.interface.builtin_object_classes import leaderboard from weave.trace_server_bindings.link_asset_to_registry import LinkAssetToRegistryRes from weave.type_wrappers.Content.content import Content diff --git a/weave/trace/base_objects.py b/weave/trace/base_objects.py index 44995b5b7fa3..1f9f9ad56143 100644 --- a/weave/trace/base_objects.py +++ b/weave/trace/base_objects.py @@ -1,4 +1,4 @@ -from weave.trace_server.interface.builtin_object_classes.builtin_object_registry import ( +from weave.shared.builtin_object_classes.builtin_object_registry import ( BUILTIN_OBJECT_REGISTRY, AlertSpec, AnnotationSpec, diff --git a/weave/trace/serialization/serialize.py b/weave/trace/serialization/serialize.py index 7b487d4d4207..4b3655d6d7bd 100644 --- a/weave/trace/serialization/serialize.py +++ b/weave/trace/serialization/serialize.py @@ -428,7 +428,7 @@ def from_json(obj: Any, project_id: str, server: TraceServerInterface) -> Any: return custom_objs.decode_custom_obj(encoded) elif isinstance(val_type, str) and obj.get("_class_name") == val_type: - from weave.trace_server.interface.builtin_object_classes.builtin_object_registry import ( + from weave.shared.builtin_object_classes.builtin_object_registry import ( BUILTIN_OBJECT_REGISTRY, ) diff --git a/weave/trace_server/constants.py b/weave/trace_server/constants.py index 7e178ea2feaa..86e917f4a593 100644 --- a/weave/trace_server/constants.py +++ b/weave/trace_server/constants.py @@ -1,3 +1,11 @@ """Back-compat shim: implementation moved to weave.shared.constants.""" +from typing import Any + +from weave.shared import constants as _impl from weave.shared.constants import * # noqa: F403 + + +def __getattr__(name: str) -> Any: + # Forward names that star-import misses (e.g. excluded by __all__). + return getattr(_impl, name) diff --git a/weave/trace_server/helpers/url_safety.py b/weave/trace_server/helpers/url_safety.py index 70972da86370..bf5c9ec90184 100644 --- a/weave/trace_server/helpers/url_safety.py +++ b/weave/trace_server/helpers/url_safety.py @@ -1,93 +1,11 @@ -"""Server-side URL safety helpers. +"""Back-compat shim: implementation moved to weave.shared.url_safety.""" -Used to gate any place where the trace server fetches a URL whose host -could come from an untrusted source: provider base URLs, upstream image -generation responses, etc. Primary protection belongs at the network -layer (pod egress policy); this module is belt-and-suspenders. +from typing import Any -`is_publicly_routable_url` rejects: -- non-http(s) schemes -- empty / missing host -- localhost variants and reserved cloud-metadata hostnames -- IP literals that are not globally routable, including alternative IPv4 - encodings (decimal, hex, octal) that bypass naive string checks -""" +from weave.shared import url_safety as _impl +from weave.shared.url_safety import * # noqa: F403 -from __future__ import annotations -import ipaddress -import re -import socket -from urllib.parse import urlparse - -ALLOWED_SCHEMES = frozenset({"http", "https"}) - -LOCALHOST_LITERALS = frozenset( - {"localhost", "localhost.localdomain", "ip6-localhost", "ip6-loopback"} -) - -# Symbolic hostnames that resolve to cloud instance-metadata services. -# Bare-IP IMDS (169.254.169.254) is covered by the is_global check below; -# this regex catches the hostnames that resolve to it. -BLOCKED_HOSTNAME_RE = re.compile( - r"(?:^|\.)" - r"(?:metadata\.google\.internal" - r"|metadata\.goog" - r"|metadata\.internal" - r"|metadata\.azure\.com" - r")\.?$", - re.IGNORECASE, -) - - -def is_publicly_routable_url(url: str) -> bool: - """Return True if `url` is safe to fetch from a server context. - - A URL is considered safe when it parses cleanly, uses http or https, - has a host that is neither a localhost literal nor a known cloud- - metadata hostname, and (if the host is an IP literal in any common - encoding) resolves to a globally routable address. - - Non-IP hostnames are accepted on the assumption that DNS resolution - and further egress filtering happen downstream. DNS-rebinding attacks - against hostnames are the egress policy's job, not this function's. - """ - try: - parsed = urlparse(url) - except ValueError: - return False - - if parsed.scheme.lower() not in ALLOWED_SCHEMES: - return False - - host = (parsed.hostname or "").lower().rstrip(".") - if not host: - return False - - if host in LOCALHOST_LITERALS or host.endswith(".localhost"): - return False - - if BLOCKED_HOSTNAME_RE.search(host): - return False - - addr: ipaddress.IPv4Address | ipaddress.IPv6Address | None = None - try: - addr = ipaddress.ip_address(host) - except ValueError: - try: - packed = socket.inet_aton(host) - except OSError: - return True - addr = ipaddress.ip_address(packed) - - # `is_global` alone is too permissive: it returns True for IPv4/IPv6 - # multicast (e.g. 224.0.0.1, ff00::/8). Enumerate the disallowed - # categories explicitly. - return not ( - addr.is_loopback - or addr.is_link_local - or addr.is_private - or addr.is_reserved - or addr.is_multicast - or addr.is_unspecified - ) +def __getattr__(name: str) -> Any: + # Forward names that star-import misses (e.g. excluded by __all__). + return getattr(_impl, name) diff --git a/weave/trace_server/ids.py b/weave/trace_server/ids.py index 45708dad8f8a..9a5f456ed477 100644 --- a/weave/trace_server/ids.py +++ b/weave/trace_server/ids.py @@ -1,3 +1,11 @@ """Back-compat shim: implementation moved to weave.shared.ids.""" +from typing import Any + +from weave.shared import ids as _impl from weave.shared.ids import * # noqa: F403 + + +def __getattr__(name: str) -> Any: + # Forward names that star-import misses (e.g. excluded by __all__). + return getattr(_impl, name) diff --git a/weave/trace_server/interface/builtin_object_classes/alert_spec.py b/weave/trace_server/interface/builtin_object_classes/alert_spec.py index 601a30002702..8790176c77f0 100644 --- a/weave/trace_server/interface/builtin_object_classes/alert_spec.py +++ b/weave/trace_server/interface/builtin_object_classes/alert_spec.py @@ -1,75 +1,11 @@ -from typing import Literal +"""Back-compat shim: implementation moved to weave.shared.builtin_object_classes.alert_spec.""" -from pydantic import BaseModel, ConfigDict, Field +from typing import Any -from weave.trace_server.interface.builtin_object_classes import base_object_def +from weave.shared.builtin_object_classes import alert_spec as _impl +from weave.shared.builtin_object_classes.alert_spec import * # noqa: F403 -class WeaveMetricThresholdSpec(BaseModel): - """Alert specification for weave metric threshold alerts. - - Fields align with gorilla's WeaveMetricThresholdFilter and the - alert_worker's WeaveMetricFilter. Extra keys are permitted so that - the schema can evolve without breaking existing stored objects. - """ - - model_config = ConfigDict(extra="allow") - - alert_condition: str = Field( - default="THRESHOLD", - description="Trigger condition type (e.g. 'THRESHOLD')", - ) - comparison_operator: Literal["GREATER_THAN", "LESS_THAN", "EQUAL"] = Field( - default="GREATER_THAN", - description="Direction of the threshold comparison", - ) - metric_path: str = Field( - default="", - description="Dot-notation path to the metric value in the call (e.g. 'output.score')", - ) - op_str: str = Field( - default="", - description="Op or scorer ref to scope this alert to; empty means all ops", - ) - threshold: float = Field( - default=0.0, - description="Threshold value that triggers the alert", - ) - window_size: int | None = Field( - default=None, - description="Max number of recent calls to include in the window", - ) - window_duration: int | None = Field( - default=None, - description="Time window in seconds for historical calls", - ) - monitor_ref: str | None = Field( - default=None, - description="Monitor ref; presence indicates alert was created from a monitor", - ) - scorer_ref: str | None = Field( - default=None, - description="Scorer object ref identifying which scorer instance within a monitor to alert on. " - "Disambiguates scorers that share the same op class (e.g. multiple LLMAsAJudgeScorer instances). " - "Filterable via inputs.self on scorer calls or input_refs on CallsFilter.", - ) - aggregation_function: Literal["mean", "median", "min", "max", "mode"] = Field( - default="mean", - description="Aggregation function applied to metric values within the window before threshold comparison", - ) - - -class AlertSpec(base_object_def.BaseObject): - op_scope: list[str] | None = Field( - default=None, - description="If provided, this alert only applies to calls from the given op refs", - examples=[ - ["weave:///entity/project/op/name:digest"], - ["weave:///entity/project/op/name:*"], - ], - ) - - spec: WeaveMetricThresholdSpec = Field( - default_factory=WeaveMetricThresholdSpec, - description="Alert specification (threshold config, window config, etc.)", - ) +def __getattr__(name: str) -> Any: + # Forward names that star-import misses (e.g. excluded by __all__). + return getattr(_impl, name) diff --git a/weave/trace_server/interface/builtin_object_classes/annotation_spec.py b/weave/trace_server/interface/builtin_object_classes/annotation_spec.py index 05c9143be068..6de1c9dad73f 100644 --- a/weave/trace_server/interface/builtin_object_classes/annotation_spec.py +++ b/weave/trace_server/interface/builtin_object_classes/annotation_spec.py @@ -1,126 +1,11 @@ -from typing import Any - -from pydantic import BaseModel, Field, create_model, field_validator, model_validator -from pydantic.fields import FieldInfo - -from weave.trace_server.interface.builtin_object_classes import base_object_def - -SUPPORTED_PRIMITIVES = (int, float, bool, str) - - -class AnnotationSpec(base_object_def.BaseObject): - field_schema: dict[str, Any] = Field( - default={}, - description="Expected to be valid JSON Schema. Can be provided as a dict, a Pydantic model class, a tuple of a primitive type and a Pydantic Field, or primitive type", - examples=[ - # String feedback - {"type": "string", "maxLength": 100}, - # Number feedback - {"type": "number", "minimum": 0, "maximum": 100}, - # Integer feedback - {"type": "integer", "minimum": 0, "maximum": 100}, - # Boolean feedback - {"type": "boolean"}, - # Categorical feedback - {"type": "string", "enum": ["option1", "option2"]}, - ], - ) - - # TODO - # If true, all unique creators will have their - # own value for this feedback type. Otherwise, - # by default, the value is shared and can be edited. - unique_among_creators: bool = False - - # TODO - # If provided, this feedback type will only be shown - # when a call is generated from the given op ref - op_scope: list[str] | None = Field( - default=None, - examples=[ - ["weave:///entity/project/op/name:digest"], - ["weave:///entity/project/op/name:*"], - ], - ) - - @model_validator(mode="before") - @classmethod - def preprocess_field_schema(cls, data: dict[str, Any]) -> dict[str, Any]: - if "field_schema" not in data: - return data - - field_schema = data["field_schema"] +"""Back-compat shim: implementation moved to weave.shared.builtin_object_classes.annotation_spec.""" - temp_field_tuple = None - # Handle Pydantic Field - if isinstance(field_schema, tuple): - if len(field_schema) != 2: - raise ValueError("Expected a tuple of length 2") - annotation, field = field_schema - if ( - not isinstance(annotation, type) - ) or annotation not in SUPPORTED_PRIMITIVES: - raise TypeError("Expected annotation to be a primitive type") - if not isinstance(field, FieldInfo): - raise TypeError("Expected field to be a Pydantic Field") - temp_field_tuple = (annotation, field) - elif field_schema in SUPPORTED_PRIMITIVES: - temp_field_tuple = (field_schema, Field()) - - if temp_field_tuple is not None: - # Create a temporary model to leverage Pydantic's schema generation - TempModel = create_model("TempModel", field=temp_field_tuple) # noqa: N806 - - schema = TempModel.model_json_schema()["properties"]["field"] - - if ( - "title" in schema and schema["title"] == "Field" - ): # default title for Field - schema.pop("title") - - data["field_schema"] = schema - return data - - # Handle Pydantic model - if isinstance(field_schema, type) and issubclass(field_schema, BaseModel): - # Read back through `data` (typed `Any`): pydantic is invisible to - # the type-check env, so `issubclass` cannot narrow off bare `type`. - data["field_schema"] = data["field_schema"].model_json_schema() - return data - - return data - - @field_validator("field_schema") - def validate_field_schema(cls, schema: dict[str, Any]) -> dict[str, Any]: # noqa: N805 - # Imported lazily: `import jsonschema` eagerly loads every installed - # format-checker lib (rfc3987_syntax builds a Lark grammar, ~0.45s), - # so keeping it out of module scope avoids paying that on `import weave`. - import jsonschema - - # Validate the schema - try: - jsonschema.validate(None, schema) - except jsonschema.exceptions.SchemaError: - raise - except jsonschema.exceptions.ValidationError: - pass # we don't care that `None` does not conform - return schema - - def value_is_valid(self, payload: Any) -> bool: - """Validates a payload against this annotation spec's schema. +from typing import Any - Args: - payload: The data to validate against the schema +from weave.shared.builtin_object_classes import annotation_spec as _impl +from weave.shared.builtin_object_classes.annotation_spec import * # noqa: F403 - Returns: - bool: True if validation succeeds, False otherwise - """ - # Lazy import: see note in validate_field_schema (avoids ~0.45s of - # jsonschema format-checker imports at `import weave` time). - import jsonschema - try: - jsonschema.validate(payload, self.field_schema) - except jsonschema.exceptions.ValidationError: - return False - return True +def __getattr__(name: str) -> Any: + # Forward names that star-import misses (e.g. excluded by __all__). + return getattr(_impl, name) diff --git a/weave/trace_server/interface/builtin_object_classes/base_object_def.py b/weave/trace_server/interface/builtin_object_classes/base_object_def.py index 07574e7e547c..8a2571af85de 100644 --- a/weave/trace_server/interface/builtin_object_classes/base_object_def.py +++ b/weave/trace_server/interface/builtin_object_classes/base_object_def.py @@ -1,10 +1,11 @@ -import pydantic +"""Back-compat shim: implementation moved to weave.shared.builtin_object_classes.base_object_def.""" -RefStr = str +from typing import Any +from weave.shared.builtin_object_classes import base_object_def as _impl +from weave.shared.builtin_object_classes.base_object_def import * # noqa: F403 -# This is just an alternative to weave.Object for the server side. -# I _think_ this will go away once we have the full weave system on the server -class BaseObject(pydantic.BaseModel): - name: str | None = None - description: str | None = None + +def __getattr__(name: str) -> Any: + # Forward names that star-import misses (e.g. excluded by __all__). + return getattr(_impl, name) diff --git a/weave/trace_server/interface/builtin_object_classes/builtin_object_registry.py b/weave/trace_server/interface/builtin_object_classes/builtin_object_registry.py index 5b3981b8a565..abeb946b288e 100644 --- a/weave/trace_server/interface/builtin_object_classes/builtin_object_registry.py +++ b/weave/trace_server/interface/builtin_object_classes/builtin_object_registry.py @@ -1,52 +1,11 @@ -from weave.trace_server.interface.builtin_object_classes.alert_spec import AlertSpec -from weave.trace_server.interface.builtin_object_classes.annotation_spec import ( - AnnotationSpec, -) -from weave.trace_server.interface.builtin_object_classes.base_object_def import ( - BaseObject, -) -from weave.trace_server.interface.builtin_object_classes.comparison_view import ( - ComparisonView, -) -from weave.trace_server.interface.builtin_object_classes.leaderboard import Leaderboard -from weave.trace_server.interface.builtin_object_classes.llm_structured_model import ( - LLMStructuredCompletionModel, -) -from weave.trace_server.interface.builtin_object_classes.provider import ( - Provider, - ProviderModel, -) -from weave.trace_server.interface.builtin_object_classes.saved_view import ( - ChartConfig, - SavedView, -) -from weave.trace_server.interface.builtin_object_classes.test_only_example import ( - TestOnlyExample, - TestOnlyInheritedBaseObject, - TestOnlyNestedBaseObject, -) +"""Back-compat shim: implementation moved to weave.shared.builtin_object_classes.builtin_object_registry.""" -BUILTIN_OBJECT_REGISTRY: dict[str, type[BaseObject]] = {} +from typing import Any +from weave.shared.builtin_object_classes import builtin_object_registry as _impl +from weave.shared.builtin_object_classes.builtin_object_registry import * # noqa: F403 -def register_base_object(cls: type[BaseObject]) -> None: - """Register a BaseObject class in the global registry. - Args: - cls: The BaseObject class to register - """ - BUILTIN_OBJECT_REGISTRY[cls.__name__] = cls - - -register_base_object(TestOnlyExample) -register_base_object(TestOnlyNestedBaseObject) -register_base_object(TestOnlyInheritedBaseObject) -register_base_object(Leaderboard) -register_base_object(AlertSpec) -register_base_object(AnnotationSpec) -register_base_object(Provider) -register_base_object(ProviderModel) -register_base_object(SavedView) -register_base_object(ComparisonView) -register_base_object(LLMStructuredCompletionModel) -register_base_object(ChartConfig) +def __getattr__(name: str) -> Any: + # Forward names that star-import misses (e.g. excluded by __all__). + return getattr(_impl, name) diff --git a/weave/trace_server/interface/builtin_object_classes/comparison_view.py b/weave/trace_server/interface/builtin_object_classes/comparison_view.py index ae7d2973768c..f832c9f65c4e 100644 --- a/weave/trace_server/interface/builtin_object_classes/comparison_view.py +++ b/weave/trace_server/interface/builtin_object_classes/comparison_view.py @@ -1,50 +1,11 @@ -"""ComparisonView builtin object class for saving comparison view configurations. +"""Back-compat shim: implementation moved to weave.shared.builtin_object_classes.comparison_view.""" -This allows users to save and restore comparison configurations including -evaluation call IDs and selected metrics. -""" +from typing import Any -from pydantic import BaseModel +from weave.shared.builtin_object_classes import comparison_view as _impl +from weave.shared.builtin_object_classes.comparison_view import * # noqa: F403 -from weave.trace_server.interface.builtin_object_classes import base_object_def - -class ComparisonViewDefinition(BaseModel): - """Definition of a comparison view's configuration. - - Args: - evaluation_call_ids (list[str]): List of evaluation call IDs being compared. - selected_metrics (list[str] | None): List of metrics that are visible in plots. - - Examples: - >>> definition = ComparisonViewDefinition( - ... evaluation_call_ids=["call_1", "call_2"], - ... selected_metrics=["accuracy", "f1_score"] - ... ) - """ - - evaluation_call_ids: list[str] - selected_metrics: list[str] | None = None - - -class ComparisonView(base_object_def.BaseObject): - """A saved comparison view configuration. - - Args: - label (str): Human-readable name for the comparison view. - definition (ComparisonViewDefinition): The view's configuration. - - Examples: - >>> view = ComparisonView( - ... label="My Comparison", - ... definition=ComparisonViewDefinition( - ... evaluation_call_ids=["call_1", "call_2"] - ... ) - ... ) - """ - - label: str - definition: ComparisonViewDefinition - - -__all__ = ["ComparisonView", "ComparisonViewDefinition"] +def __getattr__(name: str) -> Any: + # Forward names that star-import misses (e.g. excluded by __all__). + return getattr(_impl, name) diff --git a/weave/trace_server/interface/builtin_object_classes/leaderboard.py b/weave/trace_server/interface/builtin_object_classes/leaderboard.py index 30171233a459..4ca0daf5bde7 100644 --- a/weave/trace_server/interface/builtin_object_classes/leaderboard.py +++ b/weave/trace_server/interface/builtin_object_classes/leaderboard.py @@ -1,14 +1,11 @@ -from pydantic import BaseModel +"""Back-compat shim: implementation moved to weave.shared.builtin_object_classes.leaderboard.""" -from weave.trace_server.interface.builtin_object_classes import base_object_def +from typing import Any +from weave.shared.builtin_object_classes import leaderboard as _impl +from weave.shared.builtin_object_classes.leaderboard import * # noqa: F403 -class LeaderboardColumn(BaseModel): - evaluation_object_ref: base_object_def.RefStr - scorer_name: str - summary_metric_path: str - should_minimize: bool | None = None - -class Leaderboard(base_object_def.BaseObject): - columns: list[LeaderboardColumn] +def __getattr__(name: str) -> Any: + # Forward names that star-import misses (e.g. excluded by __all__). + return getattr(_impl, name) diff --git a/weave/trace_server/interface/builtin_object_classes/llm_structured_model.py b/weave/trace_server/interface/builtin_object_classes/llm_structured_model.py index 0024ee4c3174..c191da658071 100644 --- a/weave/trace_server/interface/builtin_object_classes/llm_structured_model.py +++ b/weave/trace_server/interface/builtin_object_classes/llm_structured_model.py @@ -1,373 +1,11 @@ -import json -from typing import Annotated, Any, Literal +"""Back-compat shim: implementation moved to weave.shared.builtin_object_classes.llm_structured_model.""" -from pydantic import BaseModel, BeforeValidator, Field +from typing import Any -from weave import Model, op -from weave.prompt.prompt import format_message_with_template_vars -from weave.trace import vals -from weave.trace.context.weave_client_context import WeaveInitError, get_weave_client -from weave.trace_server.interface.builtin_object_classes import base_object_def -from weave.trace_server.trace_server_interface import ( - CompletionsCreateReq, - CompletionsCreateRequestInputs, -) -from weave.utils.project_id import to_project_id +from weave.shared.builtin_object_classes import llm_structured_model as _impl +from weave.shared.builtin_object_classes.llm_structured_model import * # noqa: F403 -ResponseFormat = Literal["json_object", "json_schema", "text"] - -def is_response_format(value: Any) -> bool: - return isinstance(value, str) and value in {"json_object", "text"} - - -class Message(BaseModel): - """A message in a conversation with an LLM. - - Attributes: - role: The role of the message's author. Can be: system, user, assistant, function or tool. - content: The contents of the message. Required for all messages, but may be null for assistant messages with function calls. - name: The name of the author of the message. Required if role is "function". Must match the name of the function represented in content. - Can contain characters (a-z, A-Z, 0-9), and underscores, with a maximum length of 64 characters. - function_call: The name and arguments of a function that should be called, as generated by the model. - tool_call_id: Tool call that this message is responding to. - """ - - role: str - content: str | list[dict] | None = None - name: str | None = None - function_call: dict | None = None - tool_call_id: str | None = None - - -class LLMStructuredCompletionModelDefaultParams(BaseModel): - """Default parameters for LLMStructuredCompletionModel. - - Attributes: - messages_template: A list of Messages to use as a template. Messages can contain - template variables using {variable_name} syntax. These will be substituted - when predict() is called with template_vars. - prompt: A reference string to a MessagesPrompt object. If provided, this takes - precedence over messages_template. The referenced prompt's format() method will - be used to generate messages with template variable substitution. - Example: "weave:///entity/project/object/my_prompt:latest" - """ - - # This is a list of Messages, loosely following litellm's message format - # https://docs.litellm.ai/docs/completion/input#properties-of-messages - messages_template: list[Message] | None = None - prompt: base_object_def.RefStr | None = None - - temperature: float | None = None - top_p: float | None = None - max_tokens: int | None = None - presence_penalty: float | None = None - frequency_penalty: float | None = None - stop: list[str] | None = None - n_times: int | None = None - functions: list[dict] | None = None - - # Either json, text, or json_schema - response_format: ResponseFormat | None = None - - # TODO: Currently not used. Fast follow up with json_schema - # if default_params.response_format is set to JSON_SCHEMA, this will be used - # response_format_schema: dict | None = None - - -def cast_to_message_list(obj: Any) -> list[Message]: - if isinstance(obj, Message): - return [obj] - elif isinstance(obj, dict): - return [Message.model_validate(obj)] - elif isinstance(obj, str): - return [Message(content=obj, role="user")] - elif isinstance(obj, list): - return [cast_to_message(item) for item in obj] - raise TypeError("Unable to cast to Message") - - -def cast_to_message(obj: Any) -> Message: - if isinstance(obj, Message): - return obj - elif isinstance(obj, dict): - return Message.model_validate(obj) - elif isinstance(obj, str): - return Message(content=obj, role="user") - raise TypeError("Unable to cast to Message") - - -def cast_to_llm_structured_model_params( - obj: Any, -) -> LLMStructuredCompletionModelDefaultParams: - if isinstance(obj, LLMStructuredCompletionModelDefaultParams): - return obj - elif isinstance(obj, dict): - return LLMStructuredCompletionModelDefaultParams.model_validate(obj) - elif isinstance(obj, vals.Traceable): - return LLMStructuredCompletionModelDefaultParams.model_validate( - vals.unwrap(obj) # Recursively "unwrap" to a dict with plain python types - ) - - raise TypeError("Unable to cast to LLMStructuredCompletionModelDefaultParams") - - -MessageListLike = Annotated[list[Message], BeforeValidator(cast_to_message_list)] -MessageLike = Annotated[Message, BeforeValidator(cast_to_message)] -LLMStructuredModelParamsLike = Annotated[ - LLMStructuredCompletionModelDefaultParams, - BeforeValidator(cast_to_llm_structured_model_params), -] - - -class LLMStructuredCompletionModel(Model): - # / or ref to a provider model - llm_model_id: str | base_object_def.RefStr - - default_params: LLMStructuredModelParamsLike = Field( - default_factory=LLMStructuredCompletionModelDefaultParams - ) - - @op - def predict( - self, - user_input: MessageListLike | None = None, - config: LLMStructuredModelParamsLike | None = None, - **template_vars: Any, - ) -> Message | str | dict[str, Any]: - """Generates a prediction by preparing messages (template + user_input) - and calling the LLM completions endpoint with overridden config, using the provided client. - - Messages are prepared in one of two ways: - 1. If default_params.prompt is set, the referenced MessagesPrompt object is - loaded and its format() method is called with template_vars to generate messages. - 2. If default_params.messages_template is set (and prompt is not), the template - messages are used with template variable substitution. - - Note: If both prompt and messages_template are provided, prompt takes precedence. - - Args: - user_input: The user input messages to append after template messages - config: Optional configuration to override default parameters - **template_vars: Variables to substitute in the messages template using {variable_name} syntax - """ - if user_input is None: - user_input = [] - - current_client = get_weave_client() - if current_client is None: - raise WeaveInitError( - "You must call `weave.init()` first, to predict with a LLMStructuredCompletionModel" - ) - - req = self.prepare_completion_request( - project_id=to_project_id(current_client.entity, current_client.project), - user_input=user_input, - config=config, - **template_vars, - ) - - # 5. Call the LLM API - try: - api_response = current_client.server.completions_create(req=req) - except Exception as e: - raise RuntimeError("Failed to call LLM completions endpoint.") from e - - # 6. Extract the message from the API response - try: - # The 'response' attribute of CompletionsCreateRes is a dict - response_payload = api_response.response - response_format = ( - req.inputs.response_format.get("type") - if req.inputs.response_format is not None - else None - ) - return parse_response(response_payload, response_format) - except ( - KeyError, - IndexError, - TypeError, - AttributeError, - ValueError, - json.JSONDecodeError, - ) as e: - raise RuntimeError( - f"Failed to extract message from LLM response payload. Response: {api_response.response}" - ) from e - - def prepare_completion_request( - self, - project_id: str, - user_input: MessageListLike, - config: LLMStructuredModelParamsLike | None, - **template_vars: Any, - ) -> CompletionsCreateReq: - # Ensure user_input is properly converted to a list of Message objects - # This is needed because the @op decorator might interfere with Pydantic validation - if not isinstance(user_input, list) or ( - user_input and not isinstance(user_input[0], Message) - ): - user_input = cast_to_message_list(user_input) - - # 1. Prepare messages from messages_template (if no prompt is set) - # Note: If prompt is set, we don't prepare messages here - we pass the prompt - # reference to the completions endpoint which will resolve and substitute it - template_msgs = None - - # Only use messages_template if prompt is NOT set - if ( - self.default_params - and self.default_params.messages_template - and not self.default_params.prompt - ): - template_msgs = self.default_params.messages_template - if template_vars: - # Convert Message objects to dicts, apply template vars, convert back - formatted_dicts = [ - format_message_with_template_vars( - msg.model_dump(exclude_none=True), **template_vars - ) - for msg in template_msgs - ] - template_msgs = [Message.model_validate(d) for d in formatted_dicts] - - prepared_messages_dicts = _prepare_llm_messages(template_msgs, user_input) - - # 2. Prepare completion parameters, starting with defaults from LLMStructuredCompletionModel - completion_params: dict[str, Any] = {} - default_p_model = self.default_params - if default_p_model: - completion_params = parse_params_to_litellm_params(default_p_model) - - # 3. Override parameters with the provided config dictionary - if config: - completion_params = { - **completion_params, - **parse_params_to_litellm_params(config), - } - - # 4. Create the completion inputs - model_id_str = str(self.llm_model_id) - - # Include template_vars if they exist - if template_vars: - completion_params["template_vars"] = template_vars - - completion_inputs = CompletionsCreateRequestInputs( - model=model_id_str, messages=prepared_messages_dicts, **completion_params - ) - req = CompletionsCreateReq( - project_id=project_id, - inputs=completion_inputs, - ) - - return req - - -def parse_response( - response_payload: dict, response_format: ResponseFormat | None -) -> Message | str | dict[str, Any]: - """Extract the model output from an LLM completion response payload. - - Raises: - RuntimeError: the provider returned a top-level `error` field. - ValueError: the payload is malformed (missing choices/message), the - content is None/empty, or json_object parsing failed. - """ - if response_payload.get("error"): - raise RuntimeError(f"LLM API returned an error: {response_payload['error']}") - - choices = response_payload.get("choices") - if not choices: - raise ValueError( - "LLM response is missing 'choices' -> the upstream call likely failed " - "(invalid API key, content filtering, or provider error). " - f"Response keys: {sorted(response_payload.keys())}" - ) - - message = choices[0].get("message") if isinstance(choices[0], dict) else None - if not isinstance(message, dict): - raise TypeError( - f"LLM response choice did not contain a message dict: {choices[0]!r}" - ) - content = message.get("content") - - if response_format == "text": - if content is None: - raise ValueError( - "LLM response content is None -> the model returned no text. " - "Check your API key, model config, and content filtering settings." - ) - return content - elif response_format == "json_object": - if content is None or (isinstance(content, str) and not content.strip()): - raise ValueError( - "LLM response content was empty when JSON output was requested. " - "Check your API key and that the model supports JSON mode." - ) - try: - return json.loads(content) - except json.JSONDecodeError as e: - snippet = content if len(content) <= 200 else content[:200] + "..." - raise ValueError( - f"LLM response was not valid JSON (response_format=json_object). " - f"Content snippet: {snippet!r}" - ) from e - else: - raise ValueError(f"Invalid response_format: {response_format}") - - -def _prepare_llm_messages( - template_messages: list[Message] | None, - user_input: list[Message], -) -> list[dict[str, Any]]: - """Prepares a list of message dictionaries for the LLM API from a message template and user input. - Helper function for PlaygroundModel.predict. - Returns a list of message dictionaries. - """ - final_messages_dicts: list[dict[str, Any]] = [] - - # 1. Initialize messages from template - if template_messages: - for msg_template in template_messages: - msg_dict = msg_template.model_dump(exclude_none=True) - final_messages_dicts.append(msg_dict) - - # 2. Append user_input messages - for u_msg in user_input: - final_messages_dicts.append(u_msg.model_dump(exclude_none=True)) - - return final_messages_dicts - - -def parse_params_to_litellm_params( - params_source: LLMStructuredCompletionModelDefaultParams, -) -> dict[str, Any]: - final_params: dict[str, Any] = {} - source_dict_to_iterate: dict[str, Any] = params_source.model_dump(exclude_none=True) - - for key, value in source_dict_to_iterate.items(): - if key == "response_format": - litellm_response_format_value = None - if isinstance(value, str) and is_response_format(value): - litellm_response_format_value = {"type": value} - elif ( - isinstance(value, dict) - and "type" in value - and is_response_format(value["type"]) - ): # Pre-formed dict with valid type - litellm_response_format_value = value - - if litellm_response_format_value is not None: - final_params["response_format"] = litellm_response_format_value - elif key == "n_times": - final_params["n"] = value - elif key == "messages_template": - pass - elif key in {"functions", "stop"}: - if isinstance(value, list) and len(value) > 0: - final_params[key] = value - else: - final_params[key] = value - - return final_params +def __getattr__(name: str) -> Any: + # Forward names that star-import misses (e.g. excluded by __all__). + return getattr(_impl, name) diff --git a/weave/trace_server/interface/builtin_object_classes/provider.py b/weave/trace_server/interface/builtin_object_classes/provider.py index ffea5de4eef9..5e9c8ffc5543 100644 --- a/weave/trace_server/interface/builtin_object_classes/provider.py +++ b/weave/trace_server/interface/builtin_object_classes/provider.py @@ -1,70 +1,11 @@ -import re -from enum import Enum -from urllib.parse import urlparse +"""Back-compat shim: implementation moved to weave.shared.builtin_object_classes.provider.""" -from pydantic import ConfigDict, Field, field_validator +from typing import Any -from weave.trace_server.helpers.url_safety import is_publicly_routable_url -from weave.trace_server.interface.builtin_object_classes import base_object_def +from weave.shared.builtin_object_classes import provider as _impl +from weave.shared.builtin_object_classes.provider import * # noqa: F403 -# Headers that must not appear in user-supplied extra_headers. -# https://coreweave.atlassian.net/browse/VULNMGMT-770 -BLOCKED_HEADER_RE = re.compile( - r"^(?:metadata-flavor" - r"|x-aws-ec2-metadata-token(?:-ttl-seconds)?" - r")$", - re.IGNORECASE, -) - -INVALID_BASE_URL_MSG = "base_url is not a valid provider URL" - - -def _validate_provider_base_url(url: str) -> str: - """Validate that a provider base_url is a well-formed, publicly-routable HTTP(S) URL. - - See https://coreweave.atlassian.net/browse/VULNMGMT-770 - """ - # urlparse silently strips a bare trailing '?', so check the raw string too. - if "?" in url: - raise ValueError(INVALID_BASE_URL_MSG) - try: - parsed = urlparse(url) - except ValueError as exc: - raise ValueError(INVALID_BASE_URL_MSG) from exc - if parsed.fragment: - raise ValueError(INVALID_BASE_URL_MSG) - if not is_publicly_routable_url(url): - raise ValueError(INVALID_BASE_URL_MSG) - return url - - -class ProviderReturnType(str, Enum): - OPENAI = "openai" - - -class Provider(base_object_def.BaseObject): - model_config = ConfigDict(validate_assignment=True) - - base_url: str - api_key_name: str - extra_headers: dict[str, str] = Field(default_factory=dict) - return_type: ProviderReturnType = Field(default=ProviderReturnType.OPENAI) - - @field_validator("base_url") - @classmethod - def validate_base_url(cls, v: str) -> str: - return _validate_provider_base_url(v) - - @field_validator("extra_headers") - @classmethod - def validate_extra_headers(cls, v: dict[str, str]) -> dict[str, str]: - for key in v: - if BLOCKED_HEADER_RE.match(key): - raise ValueError("extra_headers contains a disallowed header") - return v - - -class ProviderModel(base_object_def.BaseObject): - provider: base_object_def.RefStr - max_tokens: int +def __getattr__(name: str) -> Any: + # Forward names that star-import misses (e.g. excluded by __all__). + return getattr(_impl, name) diff --git a/weave/trace_server/interface/builtin_object_classes/saved_view.py b/weave/trace_server/interface/builtin_object_classes/saved_view.py index d7c51b6c952c..d617dbd51899 100644 --- a/weave/trace_server/interface/builtin_object_classes/saved_view.py +++ b/weave/trace_server/interface/builtin_object_classes/saved_view.py @@ -1,131 +1,11 @@ -from typing import Literal +"""Back-compat shim: implementation moved to weave.shared.builtin_object_classes.saved_view.""" -from pydantic import BaseModel, Field +from typing import Any -from weave.trace_server import trace_server_interface as tsi -from weave.trace_server.common_interface import SortBy -from weave.trace_server.interface.builtin_object_classes import base_object_def +from weave.shared.builtin_object_classes import saved_view as _impl +from weave.shared.builtin_object_classes.saved_view import * # noqa: F403 -PathElement = str | int - -class Pin(BaseModel): - left: list[str] - right: list[str] - - -class Column(BaseModel): - # Optional in case we want something like computed columns in the future. - path: list[PathElement] | None = Field(default=None) - label: str | None = Field(default=None) - - -class ChartConfig(BaseModel): - x_axis: str = Field(title="XAxis") - y_axis: str = Field(title="YAxis") - plot_type: Literal["scatter", "line", "bar"] | None = Field( - default=None, - ) - bin_count: int | None = Field(default=None) - aggregation: Literal["average", "sum", "min", "max", "p95", "p99"] | None = Field( - default=None - ) - group_keys: list[str] | None = Field(default=None) - custom_name: str | None = Field(default=None) - - -class ObjectVersionGroup(BaseModel): - label: str # label for the combination of the groups - base_ref: str - versions: list[str] | Literal["*"] - show_version_indicator: bool - - -class ObjectConfig(BaseModel): - version_groups: list[ObjectVersionGroup] | None = Field(default=None) - display_name_map: dict[str, str] | None = Field( - default=None - ) # obj -> display name (keys can use "*" wildcards) - deselected: list[str] | None = Field( - default=None - ) # List of dataset refs or patterns to exclude - - -class DynamicLeaderboardColumnConfig(BaseModel): - evaluation_object_ref: base_object_def.RefStr | None = Field(default=None) - scorer_name: str | None = Field(default=None) - display_name: str | None = Field(default=None) - summary_metric_path: str | None = Field(default=None) - should_minimize: bool | None = Field(default=None) - deselected: bool | None = Field( - default=None - ) # If True, this metric is excluded from the leaderboard - - -class DynamicLeaderboardConfig(BaseModel): - # These are initialized to empty lists and dicts by default (show everything) - model_configuration: ObjectConfig | None = Field(default=None) - dataset_configuration: ObjectConfig | None = Field(default=None) - scorer_configuration: ObjectConfig | None = Field(default=None) - # Only has entries when a column is marked as deselected or minimized - columns_configuration: list[DynamicLeaderboardColumnConfig] | None = Field( - default=None - ) - - -class SavedViewDefinition(BaseModel): - filter: tsi.CallsFilter | None = Field(default=None) - - query: tsi.Query | None = Field(default=None) - - # cols is the current UI column visibility config that - # doesn't allow specifying column order - prefer use of - # explicit columns list which is what we should work towards. - cols: dict[str, bool] | None = Field(default=None) - - # columns is specifying exactly which columns to include - # including order. - columns: list[Column] | None = Field(default=None) - - # column_order is a simple ordered list of column field names. - # Used by the frontend to persist user-defined column ordering. - column_order: list[str] | None = Field(default=None) - - # Paths to columns whose values are refs to other objects and that must be - # dereferenced for filtering / sorting / display to work. Mirrors the - # `expand_columns` field of `CallsQueryReq` — when a saved view filters on - # a sub-field of a referenced object (e.g. `inputs.self.base_model_name`), - # the parent ref path (`inputs.self`) must be in this list so the trace - # server joins through to the referenced object at query time. - expand_columns: list[str] | None = Field(default=None) - - header_depth: int | None = Field(default=None) - - pin: Pin | None = Field(default=None) - sort_by: list[SortBy] | None = Field(default=None) - page: int | None = Field(default=None) - page_size: int | None = Field(default=None) - charts: list[ChartConfig] | None = Field(default=None) - - # Evaluations calls table has dataset and evaluation object - # selectors that can be used to filter down evals to those using these objects. - # The selector is an object ref where the version can either be a digest or `*` - # to match all versions. - dataset_selector: str | None = Field(default=None) - evaluation_selector: str | None = Field(default=None) - - # Dynamic leaderboards are populated by the evals in a saved view - dynamic_leaderboard_config: DynamicLeaderboardConfig | None = Field(default=None) - - -class SavedView(base_object_def.BaseObject): - # "traces" or "evaluations", type is str for extensibility - view_type: str - - # Avoiding confusion around object_id + name - label: str - - definition: SavedViewDefinition - - -__all__ = ["SavedView"] +def __getattr__(name: str) -> Any: + # Forward names that star-import misses (e.g. excluded by __all__). + return getattr(_impl, name) diff --git a/weave/trace_server/interface/builtin_object_classes/test_only_example.py b/weave/trace_server/interface/builtin_object_classes/test_only_example.py index ff2b197d6241..3b56c262834d 100644 --- a/weave/trace_server/interface/builtin_object_classes/test_only_example.py +++ b/weave/trace_server/interface/builtin_object_classes/test_only_example.py @@ -1,34 +1,11 @@ -from pydantic import BaseModel, Field +"""Back-compat shim: implementation moved to weave.shared.builtin_object_classes.test_only_example.""" -from weave.trace_server.interface.builtin_object_classes import base_object_def +from typing import Any +from weave.shared.builtin_object_classes import test_only_example as _impl +from weave.shared.builtin_object_classes.test_only_example import * # noqa: F403 -class TestOnlyNestedBaseModel(BaseModel): - a: int - aliased_property: int = Field(alias="aliased_property_alias") - -class TestOnlyNestedBaseObject(base_object_def.BaseObject): - b: int - - -class TestOnlyInheritedBaseObject(TestOnlyNestedBaseObject): - """A builtin object that inherits from another builtin object for testing inheritance.""" - - c: int - additional_field: str = "default_value" - - -class TestOnlyExample(base_object_def.BaseObject): - primitive: int - nested_base_model: TestOnlyNestedBaseModel - # Important: `RefStr` is just an alias for `str`. When defining `BaseObject`s, we - # should never have a property point to another `BaseObject`. This is because each - # base object is stored in the database and should be treated like a foreign key. - # - # It would be nice to have a way to ensure that no `BaseObject` has any `BaseObject` - # properties. - nested_base_object: base_object_def.RefStr - - -__all__ = ["TestOnlyExample", "TestOnlyInheritedBaseObject", "TestOnlyNestedBaseObject"] +def __getattr__(name: str) -> Any: + # Forward names that star-import misses (e.g. excluded by __all__). + return getattr(_impl, name) diff --git a/weave/trace_server/interface/feedback_types.py b/weave/trace_server/interface/feedback_types.py index a7a33da3b1fd..985dc8364cb4 100644 --- a/weave/trace_server/interface/feedback_types.py +++ b/weave/trace_server/interface/feedback_types.py @@ -1,3 +1,11 @@ """Back-compat shim: implementation moved to weave.shared.feedback_types.""" +from typing import Any + +from weave.shared import feedback_types as _impl from weave.shared.feedback_types import * # noqa: F403 + + +def __getattr__(name: str) -> Any: + # Forward names that star-import misses (e.g. excluded by __all__). + return getattr(_impl, name) diff --git a/weave/trace_server/trace_server_converter.py b/weave/trace_server/trace_server_converter.py index b12e25d498f3..8d2b56cc332e 100644 --- a/weave/trace_server/trace_server_converter.py +++ b/weave/trace_server/trace_server_converter.py @@ -1,3 +1,11 @@ """Back-compat shim: implementation moved to weave.shared.trace_server_converter.""" +from typing import Any + +from weave.shared import trace_server_converter as _impl from weave.shared.trace_server_converter import * # noqa: F403 + + +def __getattr__(name: str) -> Any: + # Forward names that star-import misses (e.g. excluded by __all__). + return getattr(_impl, name)