Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions weave/trace_server/base64_content_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import re
from typing import Any, TypeVar

import ddtrace

from weave.trace_server.trace_server_interface import (
CallEndReq,
Expand All @@ -20,6 +19,7 @@
TraceServerInterface,
)
from weave.type_wrappers.Content.content import Content
from weave.trace_server.tracing import traced

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -66,7 +66,7 @@ def is_data_uri(data_uri: str) -> bool:
return DATA_URI_PATTERN.match(data_uri) is not None


@ddtrace.tracer.wrap(name="store_content_object")
@traced(name="store_content_object")
def store_content_object(
content_obj: Content,
project_id: str,
Expand Down
4 changes: 2 additions & 2 deletions weave/trace_server/clickhouse/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
from collections.abc import Sequence
from typing import Any, TypeVar, cast

import ddtrace
import sqlparse
from clickhouse_connect.driver.client import Client as CHClient
from clickhouse_connect.driver.exceptions import DatabaseError
Expand All @@ -23,6 +22,7 @@
from weave.trace_server.datadog import set_current_span_dd_tags
from weave.trace_server.errors import InsertTooLarge
from weave.trace_server.kafka import KafkaProducer
from weave.trace_server.tracing import traced

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -380,7 +380,7 @@ def maybe_enqueue_minimal_call_end(
# ---------------------------------------------------------------------------


@ddtrace.tracer.wrap(name="clickhouse_trace_server_batched.find_call_descendants")
@traced(name="clickhouse_trace_server_batched.find_call_descendants")
def find_call_descendants(
root_ids: list[str],
all_calls: list[tsi.CallSchema],
Expand Down
14 changes: 7 additions & 7 deletions weave/trace_server/eval_results_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
from collections.abc import Callable, Iterable
from typing import Any

import ddtrace
from pydantic import ValidationError

from weave.shared import refs_internal as ri
Expand All @@ -22,6 +21,7 @@
from weave.trace_server import trace_server_common as tsc
from weave.trace_server import trace_server_interface as tsi
from weave.trace_server.errors import InvalidRequest
from weave.trace_server.tracing import traced

_SUPPORTED_SORT_PREFIXES = (
"scores.",
Expand Down Expand Up @@ -367,7 +367,7 @@ def _build_trial(
)


@ddtrace.tracer.wrap(name="eval_results_helpers.build_eval_rows_from_calls")
@traced(name="eval_results_helpers.build_eval_rows_from_calls")
def build_eval_rows_from_calls(
predict_and_score_calls: list[tsi.CallSchema],
child_by_parent: dict[str, list[tsi.CallSchema]],
Expand Down Expand Up @@ -433,7 +433,7 @@ def finalize_rows(
return apply_row_selection(rows, eval_root_ids, require_intersection, offset, limit)


@ddtrace.tracer.wrap(name="eval_results_helpers.build_eval_rows")
@traced(name="eval_results_helpers.build_eval_rows")
def build_eval_rows(
page_calls: list[tsi.CallSchema],
eval_root_ids: list[str],
Expand Down Expand Up @@ -542,7 +542,7 @@ def resolve_eval_row_refs(
return []


@ddtrace.tracer.wrap(name="eval_results_helpers.eval_results_grouped_rows")
@traced(name="eval_results_helpers.eval_results_grouped_rows")
def eval_results_grouped_rows(
req: tsi.EvalResultsQueryReq,
eval_root_ids: list[str],
Expand Down Expand Up @@ -576,7 +576,7 @@ def eval_results_grouped_rows(
)


@ddtrace.tracer.wrap(name="eval_results_helpers.fetch_eval_root_metadata")
@traced(name="eval_results_helpers.fetch_eval_root_metadata")
def fetch_eval_root_metadata(
server: tsi.TraceServerInterface,
project_id: str,
Expand Down Expand Up @@ -618,7 +618,7 @@ def validate_eval_results_request(req: tsi.EvalResultsQueryReq) -> None:
)


@ddtrace.tracer.wrap(name="eval_results_helpers.eval_results_query")
@traced(name="eval_results_helpers.eval_results_query")
def eval_results_query(
server: tsi.TraceServerInterface,
req: tsi.EvalResultsQueryReq,
Expand Down Expand Up @@ -744,7 +744,7 @@ def _process_scorer_output(
)


@ddtrace.tracer.wrap(name="eval_results_helpers.compute_summary_from_rows")
@traced(name="eval_results_helpers.compute_summary_from_rows")
def compute_summary_from_rows(
rows: list[tsi.EvalResultsRow],
eval_call_metadata: dict[str, dict[str, Any]] | None = None,
Expand Down
4 changes: 2 additions & 2 deletions weave/trace_server/kafka.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import socket
from typing import Any

import ddtrace
from confluent_kafka import (
Consumer as ConfluentKafkaConsumer,
)
Expand All @@ -26,6 +25,7 @@
kafka_client_user,
kafka_producer_max_buffer_size,
)
from weave.trace_server.tracing import traced

CALL_ENDED_TOPIC = "weave.call_ended"
SCORE_CALLS_TOPIC = "weave.score_calls"
Expand Down Expand Up @@ -149,7 +149,7 @@ def produce_score_calls(
if flush_immediately:
self.flush(0)

@ddtrace.tracer.wrap(name="kafka_producer.produce_score_agent_spans")
@traced(name="kafka_producer.produce_score_agent_spans")
def produce_score_agent_spans(self, event: ScoreAgentSpansEvent) -> None:
"""Produce a weave.score_agent_spans event to Kafka.

Expand Down
4 changes: 2 additions & 2 deletions weave/trace_server/parallel_bucket_uploads.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@
from concurrent.futures import ThreadPoolExecutor, as_completed
from dataclasses import dataclass

import ddtrace

from weave.trace_server import clickhouse_trace_server_settings as ch_settings
from weave.trace_server import trace_server_interface as tsi
Expand All @@ -57,6 +56,7 @@
key_for_project_digest,
store_in_bucket,
)
from weave.trace_server.tracing import traced

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -129,7 +129,7 @@ def has(self, project_id: str, digest: str) -> bool:
def __bool__(self) -> bool:
return bool(self._pending)

@ddtrace.tracer.wrap(name="bucket_upload_batch.flush")
@traced(name="bucket_upload_batch.flush")
def flush(
self, client: FileStorageClient | None
) -> list[FileChunkCreateCHInsertable]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,18 @@

import logging

import ddtrace
from clickhouse_connect.driver.client import Client as CHClient

from weave.trace_server.calls_query_builder.utils import param_slot
from weave.trace_server.orm import ParamBuilder
from weave.trace_server.project_version.types import ProjectDataResidence
from weave.trace_server.tracing import traced
from opentelemetry import trace

logger = logging.getLogger(__name__)


@ddtrace.tracer.wrap(name="clickhouse_project_version.get_project_data_residence")
@traced(name="clickhouse_project_version.get_project_data_residence")
def get_project_data_residence(
project_id: str, ch_client: CHClient
) -> ProjectDataResidence:
Expand Down Expand Up @@ -43,9 +44,11 @@ def get_project_data_residence(
has_complete = row[0]
has_merged = row[1]

root_span = ddtrace.tracer.current_root_span()
if root_span:
root_span.set_tags({"has_complete": has_complete, "has_merged": has_merged})
root_span = trace.get_current_span()
if root_span.is_recording():
root_span.set_attributes(
{"has_complete": has_complete, "has_merged": has_merged}
)

if has_complete and has_merged:
return ProjectDataResidence.BOTH
Expand Down
4 changes: 2 additions & 2 deletions weave/trace_server/project_version/project_version.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import logging
import threading

import ddtrace
from cachetools import LRUCache
from clickhouse_connect.driver.client import Client as CHClient

Expand All @@ -18,6 +17,7 @@
ReadTable,
WriteTarget,
)
from weave.trace_server.tracing import _tracer # noqa: PLC2701

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -60,7 +60,7 @@ def _get_residence(

# Only span the cache-miss path. Cache-hit calls are extremely high
# volume and produce noisy DD spans with no useful information.
with ddtrace.tracer.trace("table_routing.fetch_residence"):
with _tracer.start_as_current_span("table_routing.fetch_residence"):
residence = get_project_data_residence(project_id, ch_client)

set_root_span_dd_tags({"project_version.fetch_residence": residence.value})
Expand Down
12 changes: 6 additions & 6 deletions weave/trace_server/ttl_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,13 @@
import logging
import threading

import ddtrace
import redis
from cachetools import TTLCache
from clickhouse_connect.driver.client import Client as CHClient

from weave.trace_server.datadog import set_current_span_dd_tags
from weave.trace_server.redis_client import get_redis_client
from weave.trace_server.tracing import _tracer, traced # noqa: PLC2701

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -66,18 +66,18 @@ def get_project_retention_days(
set_current_span_dd_tags({"ttl.cache_hit": "L1"})
return cached

with ddtrace.tracer.trace("ttl_settings.get_project_retention_days") as span:
with _tracer.start_as_current_span("ttl_settings.get_project_retention_days") as span:
redis_client = get_redis_client()
if redis_client is not None:
redis_val = _l2_get(redis_client, project_id)
if redis_val is not None:
_l1_set(project_id, redis_val)
span.set_tag("ttl.cache_hit", "L2")
span.set_attribute("ttl.cache_hit", "L2")
return redis_val

retention_days = _query_clickhouse(ch_client, project_id)
span.set_tag("ttl.cache_hit", "clickhouse")
span.set_tag("ttl.retention_days", retention_days)
span.set_attribute("ttl.cache_hit", "clickhouse")
span.set_attribute("ttl.retention_days", retention_days)

if redis_client is not None:
_l2_set(redis_client, project_id, retention_days)
Expand Down Expand Up @@ -179,7 +179,7 @@ def _l2_delete(redis_client: redis.Redis, project_id: str) -> None:
logger.exception("Redis L2 cache delete failed for project %s", project_id)


@ddtrace.tracer.wrap(name="ttl_settings.query_clickhouse")
@traced(name="ttl_settings.query_clickhouse")
def _query_clickhouse(ch_client: CHClient, project_id: str) -> int:
"""Query ClickHouse for the latest retention_days via argMax.

Expand Down
4 changes: 2 additions & 2 deletions weave/trace_server/usage_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
from collections.abc import Iterable
from typing import Any

import ddtrace

from weave.trace_server import trace_server_interface as tsi
from weave.trace_server.tracing import traced


@dataclasses.dataclass(frozen=True)
Expand All @@ -17,7 +17,7 @@ class UsageCall:
summary: dict[str, Any] | None


@ddtrace.tracer.wrap(name="usage_utils.aggregate_usage_with_descendants")
@traced(name="usage_utils.aggregate_usage_with_descendants")
def aggregate_usage_with_descendants(
calls: Iterable[UsageCall],
include_costs: bool,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import asyncio
from abc import ABC, abstractmethod

import ddtrace

import weave
from weave.evaluation.eval import Evaluation
Expand All @@ -13,6 +12,7 @@
LLMStructuredCompletionModel,
)
from weave.trace_server.trace_server_interface import EvaluateModelArgs
from weave.trace_server.tracing import traced

EVALUATE_MODEL_WORKER_MARKER = {"_weave_eval_meta": {"evaluate_model_worker": True}}

Expand All @@ -34,7 +34,7 @@ def evaluate_model(args: EvaluateModelArgs) -> None:
_evaluate_model(args)


@ddtrace.tracer.wrap(name="evaluate_model_worker.evaluate_model")
@traced(name="evaluate_model_worker.evaluate_model")
def _evaluate_model(args: EvaluateModelArgs) -> None:
# This worker reconstructs user-supplied objects; it must never deserialize
# code-bearing custom objects (Op / load_op). The secure client locks the decode
Expand All @@ -51,7 +51,7 @@ def _evaluate_model(args: EvaluateModelArgs) -> None:
_run_evaluation(loaded_evaluation, loaded_model, args.evaluation_call_id)


@ddtrace.tracer.wrap(name="evaluate_model_worker.evaluate_model.get_valid_evaluation")
@traced(name="evaluate_model_worker.evaluate_model.get_valid_evaluation")
def _get_valid_evaluation(client: WeaveClient, evaluation_ref: str) -> Evaluation:
loaded_evaluation = client.get(Ref.parse_uri(evaluation_ref))

Expand All @@ -73,7 +73,7 @@ def _get_valid_evaluation(client: WeaveClient, evaluation_ref: str) -> Evaluatio
return loaded_evaluation


@ddtrace.tracer.wrap(name="evaluate_model_worker.evaluate_model.get_valid_model")
@traced(name="evaluate_model_worker.evaluate_model.get_valid_model")
def _get_valid_model(
client: WeaveClient, model_ref: str
) -> LLMStructuredCompletionModel:
Expand All @@ -87,7 +87,7 @@ def _get_valid_model(
return loaded_model


@ddtrace.tracer.wrap(name="evaluate_model_worker.evaluate_model.run_evaluation")
@traced(name="evaluate_model_worker.evaluate_model.run_evaluation")
def _run_evaluation(
loaded_evaluation: Evaluation,
loaded_model: LLMStructuredCompletionModel,
Expand Down
Loading