diff --git a/weave/trace_server/clickhouse_trace_server_batched.py b/weave/trace_server/clickhouse_trace_server_batched.py index e025b166d200..9498119cc441 100644 --- a/weave/trace_server/clickhouse_trace_server_batched.py +++ b/weave/trace_server/clickhouse_trace_server_batched.py @@ -15,7 +15,6 @@ from zoneinfo import ZoneInfo import clickhouse_connect -import ddtrace from cachetools import TTLCache from clickhouse_connect import common as ch_common from clickhouse_connect.driver.client import Client as CHClient @@ -57,6 +56,7 @@ # GenAI / Agent observability imports from weave.trace_server.agents.clickhouse import AgentQueryHandler, AgentWriteHandler +from weave.trace_server.tracing import traced, traced_generator from weave.trace_server.agents.completion_spans import build_completion_span from weave.trace_server.agents.kafka_events import ScoreAgentSpansEvent from weave.trace_server.agents.types import ( @@ -164,7 +164,6 @@ IMAGE_GENERATION_CREATE_OP_NAME, ) from weave.trace_server.datadog import ( - generator_trace, record_db_insert, set_current_span_dd_tags, set_root_span_dd_tags, @@ -646,7 +645,7 @@ def _ensure_placeholder_file_exists(self, project_id: str) -> None: ) @tag_db_insert_path("otel_export") - @ddtrace.tracer.wrap(name="clickhouse_trace_server_batched.otel_export") + @traced(name="clickhouse_trace_server_batched.otel_export") def otel_export(self, req: tsi.OTelExportReq) -> tsi.OTelExportRes: assert_non_null_wb_user_id(req) calls, rejected_spans, error_messages = self._otel_proto_to_calls(req) @@ -743,9 +742,7 @@ def otel_export(self, req: tsi.OTelExportReq) -> tsi.OTelExportRes: ) return tsi.OTelExportRes() - @ddtrace.tracer.wrap( - name="clickhouse_trace_server_batched.otel_export.proto_to_calls" - ) + @traced(name="clickhouse_trace_server_batched.otel_export.proto_to_calls") def _otel_proto_to_calls( self, req: tsi.OTelExportReq ) -> tuple[ @@ -802,7 +799,7 @@ def _otel_proto_to_calls( set_current_span_dd_tags({"call_count": len(calls)}) return calls, rejected_spans, error_messages - @ddtrace.tracer.wrap(name="clickhouse_trace_server_batched.otel_export.build_rows") + @traced(name="clickhouse_trace_server_batched.otel_export.build_rows") def _otel_build_rows( self, calls: list[ @@ -836,7 +833,7 @@ def _otel_build_rows( set_current_span_dd_tags({"row_count": len(rows)}) return rows - @ddtrace.tracer.wrap(name="clickhouse_trace_server_batched.kafka_producer.flush") + @traced(name="clickhouse_trace_server_batched.kafka_producer.flush") def _flush_kafka_producer(self) -> None: producer = self.kafka_producer if producer is not None: @@ -1113,9 +1110,7 @@ def call_end_v2(self, req: tsi.CallEndV2Req) -> tsi.CallEndV2Res: return tsi.CallEndV2Res() - @ddtrace.tracer.wrap( - name="clickhouse_trace_server_batched._update_call_end_in_calls_complete" - ) + @traced(name="clickhouse_trace_server_batched._update_call_end_in_calls_complete") def _update_call_end_in_calls_complete( self, end_call: tsi.EndedCallSchemaForInsert ) -> None: @@ -1258,7 +1253,7 @@ def calls_query(self, req: tsi.CallsQueryReq) -> tsi.CallsQueryRes: stream = self.calls_query_stream(req) return tsi.CallsQueryRes(calls=list(stream)) - @ddtrace.tracer.wrap(name="clickhouse_trace_server_batched.calls_query_stats") + @traced(name="clickhouse_trace_server_batched.calls_query_stats") def calls_query_stats(self, req: tsi.CallsQueryStatsReq) -> tsi.CallsQueryStatsRes: """Returns a stats object for the given query. This is useful for counts or other aggregate statistics that are not directly queryable from the calls themselves. @@ -1476,7 +1471,7 @@ def feedback_payload_schema( """Discover feedback payload schema from sample rows.""" return feedback_payload_schema_handler(self, req) - @ddtrace.tracer.wrap(name="clickhouse_trace_server_batched.trace_usage") + @traced(name="clickhouse_trace_server_batched.trace_usage") def trace_usage(self, req: tsi.TraceUsageReq) -> tsi.TraceUsageRes: """Compute per-call usage for a trace, with descendant rollup. @@ -1516,7 +1511,7 @@ def trace_usage(self, req: tsi.TraceUsageReq) -> tsi.TraceUsageRes: unfinished_call_ids=sorted(unfinished_call_ids), ) - @ddtrace.tracer.wrap(name="clickhouse_trace_server_batched.calls_usage") + @traced(name="clickhouse_trace_server_batched.calls_usage") def calls_usage(self, req: tsi.CallsUsageReq) -> tsi.CallsUsageRes: """Compute aggregated usage for multiple root calls. @@ -1580,7 +1575,7 @@ def calls_usage(self, req: tsi.CallsUsageReq) -> tsi.CallsUsageRes: unfinished_call_ids=sorted(unfinished_call_ids), ) - @generator_trace("clickhouse_trace_server_batched.calls_query_stream") + @traced_generator(name="clickhouse_trace_server_batched.calls_query_stream") def calls_query_stream(self, req: tsi.CallsQueryReq) -> Iterator[tsi.CallSchema]: """Returns a stream of calls that match the given query.""" read_table = self.table_routing_resolver.resolve_read_table( @@ -1732,7 +1727,7 @@ def row_to_call_schema_dict(row: tuple[Any, ...]) -> dict[str, Any]: if hasattr(raw_res, "close"): raw_res.close() - @ddtrace.tracer.wrap(name="clickhouse_trace_server_batched._add_feedback_to_calls") + @traced(name="clickhouse_trace_server_batched._add_feedback_to_calls") def _add_feedback_to_calls( self, project_id: str, calls: list[dict[str, Any]] ) -> None: @@ -1766,7 +1761,7 @@ def _get_refs_to_resolve( refs_to_resolve[i, col] = ref return refs_to_resolve - @ddtrace.tracer.wrap(name="clickhouse_trace_server_batched._expand_call_refs") + @traced(name="clickhouse_trace_server_batched._expand_call_refs") def _expand_call_refs( self, project_id: str, @@ -1814,7 +1809,7 @@ def _expand_call_refs( val["_ref"] = ref.uri set_nested_key(calls[i], col, val) - @ddtrace.tracer.wrap(name="clickhouse_trace_server_batched.calls_delete") + @traced(name="clickhouse_trace_server_batched.calls_delete") def calls_delete(self, req: tsi.CallsDeleteReq) -> tsi.CallsDeleteRes: assert_non_null_wb_user_id(req) if len(req.call_ids) > ch_settings.MAX_DELETE_CALLS_COUNT: @@ -1885,7 +1880,7 @@ def calls_delete(self, req: tsi.CallsDeleteReq) -> tsi.CallsDeleteRes: return tsi.CallsDeleteRes(num_deleted=len(all_descendants)) - @ddtrace.tracer.wrap(name="clickhouse_trace_server_batched._delete_calls_complete") + @traced(name="clickhouse_trace_server_batched._delete_calls_complete") def _delete_calls_complete(self, project_id: str, call_ids: list[str]) -> None: pb = ParamBuilder() project_id_param = pb.add_param(project_id) @@ -1912,7 +1907,7 @@ def _ensure_valid_update_field(self, req: tsi.CallUpdateReq) -> None: f"One of [{', '.join(valid_update_fields)}] is required for call update" ) - @ddtrace.tracer.wrap(name="clickhouse_trace_server_batched.call_update") + @traced(name="clickhouse_trace_server_batched.call_update") @tag_db_insert_path("call_update") def call_update(self, req: tsi.CallUpdateReq) -> tsi.CallUpdateRes: assert_non_null_wb_user_id(req) @@ -1936,7 +1931,7 @@ def call_update(self, req: tsi.CallUpdateReq) -> tsi.CallUpdateRes: return tsi.CallUpdateRes() - @ddtrace.tracer.wrap(name="clickhouse_trace_server_batched._update_calls_complete") + @traced(name="clickhouse_trace_server_batched._update_calls_complete") def _update_calls_complete( self, project_id: str, call_id: str, display_name: str ) -> None: @@ -1959,7 +1954,7 @@ def _update_calls_complete( settings=ch_settings.CLICKHOUSE_LIGHTWEIGHT_UPDATE_SETTINGS, ) - @ddtrace.tracer.wrap(name="clickhouse_trace_server_batched.obj_create") + @traced(name="clickhouse_trace_server_batched.obj_create") @tag_db_insert_path("obj_create") def obj_create(self, req: tsi.ObjCreateReq) -> tsi.ObjCreateRes: # Partial-failure semantics: ClickHouse cannot atomically write @@ -2055,7 +2050,7 @@ def _reject_obj_name_type_collision( existing_base_object_classes=mismatched, ) - @ddtrace.tracer.wrap(name="clickhouse_trace_server_batched.create_obj_batch") + @traced(name="clickhouse_trace_server_batched.create_obj_batch") @tag_db_insert_path("obj_create_batch") def obj_create_batch( self, batch: list[tsi.ObjSchemaForInsert] @@ -2184,7 +2179,7 @@ def obj_read(self, req: tsi.ObjReadReq) -> tsi.ObjReadRes: self._enrich_objs_with_tags_and_aliases(req.project_id, [obj_schema]) return tsi.ObjReadRes(obj=obj_schema) - @ddtrace.tracer.wrap(name="clickhouse_trace_server_batched.objs_query") + @traced(name="clickhouse_trace_server_batched.objs_query") def objs_query(self, req: tsi.ObjQueryReq) -> tsi.ObjQueryRes: object_query_builder = ObjectMetadataQueryBuilder(req.project_id) if req.filter: @@ -2457,7 +2452,7 @@ def aliases_list(self, req: tsi.AliasesListReq) -> tsi.AliasesListRes: aliases = [row[0] for row in query_result.result_rows] return tsi.AliasesListRes(aliases=aliases) - @ddtrace.tracer.wrap(name="clickhouse_trace_server_batched._get_tags_for_objects") + @traced(name="clickhouse_trace_server_batched._get_tags_for_objects") def _get_tags_for_objects( self, project_id: str, @@ -2476,9 +2471,7 @@ def _get_tags_for_objects( result.setdefault(key, []).append(row[2]) return result - @ddtrace.tracer.wrap( - name="clickhouse_trace_server_batched._get_aliases_for_objects" - ) + @traced(name="clickhouse_trace_server_batched._get_aliases_for_objects") def _get_aliases_for_objects( self, project_id: str, @@ -2497,7 +2490,7 @@ def _get_aliases_for_objects( result.setdefault(key, []).append(row[2]) return result - @ddtrace.tracer.wrap(name="clickhouse_trace_server_batched._maybe_resolve_alias") + @traced(name="clickhouse_trace_server_batched._maybe_resolve_alias") def _maybe_resolve_alias( self, project_id: str, @@ -2735,7 +2728,7 @@ def table_query_stream( offset=req.offset, ) - @ddtrace.tracer.wrap(name="clickhouse_trace_server_batched._table_query_stream") + @traced_generator(name="clickhouse_trace_server_batched._table_query_stream") def _table_query_stream( self, project_id: str, @@ -2890,9 +2883,7 @@ def _default_true(val: bool | None) -> bool: **dict(zip(columns, query_result.result_rows[0], strict=False)) ) - @ddtrace.tracer.wrap( - name="clickhouse_trace_server_batched.project_ttl_settings_read" - ) + @traced(name="clickhouse_trace_server_batched.project_ttl_settings_read") def project_ttl_settings_read( self, req: tsi.ProjectTTLSettingsReadReq ) -> tsi.ProjectTTLSettingsReadRes: @@ -2902,9 +2893,7 @@ def project_ttl_settings_read( ) @tag_db_insert_path("project_ttl_settings_update") - @ddtrace.tracer.wrap( - name="clickhouse_trace_server_batched.project_ttl_settings_update" - ) + @traced(name="clickhouse_trace_server_batched.project_ttl_settings_update") def project_ttl_settings_update( self, req: tsi.ProjectTTLSettingsUpdateReq ) -> tsi.ProjectTTLSettingsUpdateRes: @@ -3002,7 +2991,7 @@ def threads_query_stream( ) # Annotation Queue API - @ddtrace.tracer.wrap(name="clickhouse_trace_server_batched.annotation_queue_create") + @traced(name="clickhouse_trace_server_batched.annotation_queue_create") def annotation_queue_create( self, req: tsi.AnnotationQueueCreateReq ) -> tsi.AnnotationQueueCreateRes: @@ -3032,9 +3021,7 @@ def annotation_queue_create( return tsi.AnnotationQueueCreateRes(id=queue_id) - @ddtrace.tracer.wrap( - name="clickhouse_trace_server_batched.annotation_queues_query_stream" - ) + @traced_generator(name="clickhouse_trace_server_batched.annotation_queues_query_stream") def annotation_queues_query_stream( self, req: tsi.AnnotationQueuesQueryReq ) -> Iterator[tsi.AnnotationQueueSchema]: @@ -3087,7 +3074,7 @@ def annotation_queues_query_stream( deleted_at=deleted_at_with_tz, ) - @ddtrace.tracer.wrap(name="clickhouse_trace_server_batched.annotation_queue_read") + @traced(name="clickhouse_trace_server_batched.annotation_queue_read") def annotation_queue_read( self, req: tsi.AnnotationQueueReadReq ) -> tsi.AnnotationQueueReadRes: @@ -3121,7 +3108,7 @@ def annotation_queue_read( return tsi.AnnotationQueueReadRes(queue=queue) - @ddtrace.tracer.wrap(name="clickhouse_trace_server_batched.annotation_queue_update") + @traced(name="clickhouse_trace_server_batched.annotation_queue_update") def annotation_queue_update( self, req: tsi.AnnotationQueueUpdateReq ) -> tsi.AnnotationQueueUpdateRes: @@ -3218,7 +3205,7 @@ def annotation_queue_update( return tsi.AnnotationQueueUpdateRes(queue=queue) - @ddtrace.tracer.wrap(name="clickhouse_trace_server_batched.annotation_queue_delete") + @traced(name="clickhouse_trace_server_batched.annotation_queue_delete") def annotation_queue_delete( self, req: tsi.AnnotationQueueDeleteReq ) -> tsi.AnnotationQueueDeleteRes: @@ -3275,9 +3262,7 @@ def annotation_queue_delete( return tsi.AnnotationQueueDeleteRes(queue=queue) @tag_db_insert_path("annotation_queue_add_calls") - @ddtrace.tracer.wrap( - name="clickhouse_trace_server_batched.annotation_queue_add_calls" - ) + @traced(name="clickhouse_trace_server_batched.annotation_queue_add_calls") def annotation_queue_add_calls( self, req: tsi.AnnotationQueueAddCallsReq ) -> tsi.AnnotationQueueAddCallsRes: @@ -3372,9 +3357,7 @@ def annotation_queue_add_calls( added_count=len(calls_data), duplicates=len(existing_call_ids) ) - @ddtrace.tracer.wrap( - name="clickhouse_trace_server_batched.annotation_queue_items_query" - ) + @traced(name="clickhouse_trace_server_batched.annotation_queue_items_query") def annotation_queue_items_query( self, req: tsi.AnnotationQueueItemsQueryReq ) -> tsi.AnnotationQueueItemsQueryRes: @@ -3420,7 +3403,7 @@ def annotation_queue_items_query( return tsi.AnnotationQueueItemsQueryRes(items=items) - @ddtrace.tracer.wrap(name="clickhouse_trace_server_batched.annotation_queues_stats") + @traced(name="clickhouse_trace_server_batched.annotation_queues_stats") def annotation_queues_stats( self, req: tsi.AnnotationQueuesStatsReq ) -> tsi.AnnotationQueuesStatsRes: @@ -3494,9 +3477,7 @@ def _fetch_queue_item_for_progress_update( raise ValueError(f"Failed to fetch queue item '{item_id}'") - @ddtrace.tracer.wrap( - name="clickhouse_trace_server_batched.annotator_queue_items_progress_update" - ) + @traced(name="clickhouse_trace_server_batched.annotator_queue_items_progress_update") def annotator_queue_items_progress_update( self, req: tsi.AnnotatorQueueItemsProgressUpdateReq ) -> tsi.AnnotatorQueueItemsProgressUpdateRes: @@ -5697,7 +5678,7 @@ def _do_read() -> T: return _do_read() - @ddtrace.tracer.wrap(name="clickhouse_trace_server_batched._obj_read_with_retry") + @traced(name="clickhouse_trace_server_batched._obj_read_with_retry") def _obj_read_with_retry( self, req: tsi.ObjReadReq, max_attempts: int = 2 ) -> tsi.ObjReadRes: @@ -5706,9 +5687,7 @@ def _obj_read_with_retry( lambda: self.obj_read(req), max_attempts=max_attempts ) - @ddtrace.tracer.wrap( - name="clickhouse_trace_server_batched._file_content_read_with_retry" - ) + @traced(name="clickhouse_trace_server_batched._file_content_read_with_retry") def _file_content_read_with_retry( self, req: tsi.FileContentReadReq, max_attempts: int = 2 ) -> tsi.FileContentReadRes: @@ -5717,7 +5696,7 @@ def _file_content_read_with_retry( lambda: self._file_content_read_once(req), max_attempts=max_attempts ) - @ddtrace.tracer.wrap(name="clickhouse_trace_server_batched._parsed_refs_read_batch") + @traced(name="clickhouse_trace_server_batched._parsed_refs_read_batch") def _parsed_refs_read_batch( self, parsed_refs: ObjRefListType, @@ -5746,9 +5725,7 @@ def make_ref_cache_key(ref: ri.InternalObjectRef) -> str: # Return the final data payload return [final_result_cache[make_ref_cache_key(ref)] for ref in parsed_refs] - @ddtrace.tracer.wrap( - name="clickhouse_trace_server_batched._refs_read_batch_within_project" - ) + @traced(name="clickhouse_trace_server_batched._refs_read_batch_within_project") def _refs_read_batch_within_project( self, project_id_scope: str, @@ -6026,12 +6003,12 @@ def file_create(self, req: tsi.FileCreateReq) -> tsi.FileCreateRes: set_root_span_dd_tags({"write_bytes": len(req.content)}) return tsi.FileCreateRes(digest=digest) - @ddtrace.tracer.wrap(name="clickhouse_trace_server_batched._file_create_clickhouse") + @traced(name="clickhouse_trace_server_batched._file_create_clickhouse") def _file_create_clickhouse(self, req: tsi.FileCreateReq, digest: str) -> None: set_root_span_dd_tags({"storage_provider": "clickhouse"}) self._insert_file_chunks(file_chunks_for(req, digest)) - @ddtrace.tracer.wrap(name="clickhouse_trace_server_batched._file_create_bucket") + @traced(name="clickhouse_trace_server_batched._file_create_bucket") def _file_create_bucket( self, req: tsi.FileCreateReq, digest: str, client: FileStorageClient ) -> None: @@ -6063,7 +6040,7 @@ def _file_create_bucket( ] ) - @ddtrace.tracer.wrap(name="clickhouse_trace_server_batched._flush_file_chunks") + @traced(name="clickhouse_trace_server_batched._flush_file_chunks") def _flush_file_chunks(self) -> None: if not self._flush_immediately: raise ValueError("File chunks must be flushed immediately") @@ -6200,7 +6177,7 @@ def _file_content_read_once( set_root_span_dd_tags({"read_bytes": len(bytes)}) return tsi.FileContentReadRes(content=bytes) - @ddtrace.tracer.wrap(name="clickhouse_trace_server_batched._file_read_bucket") + @traced(name="clickhouse_trace_server_batched._file_read_bucket") def _file_read_bucket(self, file_storage_uri: FileStorageURI) -> bytes: set_root_span_dd_tags({"storage_provider": "bucket"}) client = self.file_storage_client @@ -6337,7 +6314,7 @@ def feedback_create(self, req: tsi.FeedbackCreateReq) -> tsi.FeedbackCreateRes: return format_feedback_to_res(row) - @ddtrace.tracer.wrap(name="clickhouse_trace_server_batched.feedback_create_batch") + @traced(name="clickhouse_trace_server_batched.feedback_create_batch") @tag_db_insert_path("feedback_create_batch") def feedback_create_batch( self, req: tsi.FeedbackCreateBatchReq @@ -7068,7 +7045,7 @@ def _run_migrations(self) -> None: ) migrator.apply_migrations(self._database) - @generator_trace("clickhouse_trace_server_batched._query_stream") + @traced_generator(name="clickhouse_trace_server_batched._query_stream") def _query_stream( self, query: str, @@ -7117,7 +7094,7 @@ def _query_stream( # always raises, optionally with custom error class handle_clickhouse_query_error(e) - @ddtrace.tracer.wrap(name="clickhouse_trace_server_batched._query") + @traced(name="clickhouse_trace_server_batched._query") def _query( self, query: str, @@ -7165,7 +7142,7 @@ def _query( ) return res - @ddtrace.tracer.wrap(name="clickhouse_trace_server_batched._command") + @traced(name="clickhouse_trace_server_batched._command") def _command( self, command: str, @@ -7214,7 +7191,7 @@ def _command( ) return - @ddtrace.tracer.wrap(name="clickhouse_trace_server_batched._insert") + @traced(name="clickhouse_trace_server_batched._insert") def _insert( self, table: str, @@ -7290,7 +7267,7 @@ def _insert_call(self, ch_call: CallCHInsertable) -> None: if self._flush_immediately: self._flush_calls() - @ddtrace.tracer.wrap(name="clickhouse_trace_server_batched._flush_calls") + @traced(name="clickhouse_trace_server_batched._flush_calls") def _flush_calls(self) -> None: try: self._insert_call_batch(self._call_batch) @@ -7374,7 +7351,7 @@ def _insert_call_complete_batch( do_sync_insert=do_sync_insert, ) - @ddtrace.tracer.wrap(name="clickhouse_trace_server_batched._flush_calls_complete") + @traced(name="clickhouse_trace_server_batched._flush_calls_complete") def _flush_calls_complete(self) -> None: """Flush the calls_complete batch to the database.""" if not self._calls_complete_batch: @@ -7389,7 +7366,7 @@ def _flush_calls_complete(self) -> None: finally: self._calls_complete_batch = [] - @ddtrace.tracer.wrap(name="clickhouse_trace_server_batched._strip_large_values") + @traced(name="clickhouse_trace_server_batched._strip_large_values") def _strip_large_values(self, batch: list[list[Any]]) -> list[list[Any]]: """Iterate through the batch and replace large JSON values with placeholders.