Skip to content

Commit ca80b03

Browse files
committed
feat: support media metrics
1 parent f38d19a commit ca80b03

5 files changed

Lines changed: 85 additions & 40 deletions

File tree

swanlab/api/column.py

Lines changed: 11 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
from swanlab.api.base import ApiClientContext, BaseEntity
1111
from swanlab.api.typings.column import ApiColumnType
12-
from swanlab.api.typings.common import ApiResponseType, PaginatedQuery
12+
from swanlab.api.typings.common import ApiMetricTypeLiteral, ApiResponseType, PaginatedQuery
1313
from swanlab.api.utils import get_properties, parse_column_data_type, resovle_run_path, validate_column_params
1414

1515

@@ -41,7 +41,7 @@ def __init__(
4141

4242
def _ensure_data(self) -> ApiColumnType:
4343
if self._data is None:
44-
validate_column_params(column_class=self._column_class)
44+
validate_column_params(column_class=self._column_class, column_type=self._column_type)
4545
extra: Dict[str, Any] = {"search": self._key}
4646
if self._column_class:
4747
extra["class"] = self._column_class
@@ -82,10 +82,9 @@ def run_id(self) -> str:
8282

8383
@property
8484
def key(self) -> str:
85-
res_key = self._ensure_data().get("key", "")
86-
if res_key and res_key != self._key:
87-
self._key = res_key
88-
return res_key
85+
if self._key:
86+
return self._key
87+
return self._ensure_data().get("key", "")
8988

9089
@property
9190
def name(self) -> str:
@@ -112,16 +111,11 @@ def error(self) -> Optional[Dict[str, Any]]:
112111
"""列的错误信息。"""
113112
return self._ensure_data().get("error", {})
114113

115-
def _require_found(self) -> None:
116-
"""确保列数据已加载且存在,否则抛出清晰错误。"""
117-
self._ensure_data()
118-
if not self.key:
119-
raise ValueError(f"Column '{self._key}' not found in the experiment")
120-
121-
def metric(self, sample: int = 1500, ignore_timestamp: bool = False) -> Dict[str, Any]:
114+
def metric(
115+
self, sample: int = 1500, metric_type: ApiMetricTypeLiteral = "SCALAR", ignore_timestamp: bool = False
116+
) -> Dict[str, Any]:
122117
from swanlab.api.metric import Metric
123118

124-
self._require_found()
125119
metric_type = parse_column_data_type(self.column_type)
126120
metric = Metric(
127121
ctx=self._ctx,
@@ -137,8 +131,10 @@ def metric(self, sample: int = 1500, ignore_timestamp: bool = False) -> Dict[str
137131
def export_csv(self) -> ApiResponseType:
138132
from swanlab.api.metric import Metric
139133

140-
self._require_found()
141134
metric_type = parse_column_data_type(self.column_type)
135+
if metric_type != "SCALAR":
136+
err_msg = "export_csv() only support SCALAR metric_type"
137+
return ApiResponseType(ok=False, errmsg=err_msg, data=None)
142138
metric = Metric(
143139
ctx=self._ctx,
144140
project_id=self.project_id,

swanlab/api/experiment.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from typing import Any, Dict, Iterator, List, Optional, Union, cast
99

1010
from swanlab.api.base import ApiClientContext, BaseEntity
11-
from swanlab.api.typings.common import PaginatedQuery
11+
from swanlab.api.typings.common import ApiMetricTypeLiteral, PaginatedQuery
1212
from swanlab.api.typings.experiment import (
1313
ApiExperimentLabelType,
1414
ApiExperimentProfileType,
@@ -148,15 +148,21 @@ def column(self, key: str, column_class: Optional[str] = "CUSTOM", column_type:
148148
column_type=column_type,
149149
)
150150

151-
def metric(self, key: str, sample: int = 1500, ignore_timestamp: bool = False) -> Dict[str, Any]:
152-
"""
153-
获取实验下指定列的指标数据,最大返回 1500 条。
151+
def metric(
152+
self, key: str, sample: int = 1500, metric_type: ApiMetricTypeLiteral = "SCALAR", ignore_timestamp: bool = False
153+
) -> Dict[str, Any]:
154+
from swanlab.api.metric import Metric
154155

155-
:param key: 列的 key
156-
:param sample: 采样条数
157-
:param ignore_timestamp: 是否过滤 timestamp 字段
158-
"""
159-
return self.column(key=key).metric(sample=sample, ignore_timestamp=ignore_timestamp)
156+
metric = Metric(
157+
ctx=self._ctx,
158+
project_id=self.project_id,
159+
run_id=self.run_id,
160+
key=key,
161+
sample=sample,
162+
metric_type=metric_type,
163+
ignore_timestamp=ignore_timestamp,
164+
)
165+
return metric.json()
160166

161167
def columns(
162168
self,

swanlab/api/metric.py

Lines changed: 36 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
from swanlab.api.base import ApiClientContext, BaseEntity
1111
from swanlab.api.typings import ApiColumnCsvExportType, ApiResponseType
12-
from swanlab.api.typings.metric import ApiScalarSeriesType
12+
from swanlab.api.typings.metric import ApiLogSeriesType, ApiMediaSeriesType, ApiMediaType, ApiScalarSeriesType
1313
from swanlab.api.utils import get_properties, validate_metric_type
1414

1515

@@ -96,6 +96,12 @@ def _build_scalar_payload(self) -> Dict[str, Any]:
9696
"columns": [{"experimentId": self.run_id, "key": self.key}],
9797
}
9898

99+
def _build_media_payload(self) -> Dict[str, Any]:
100+
return {
101+
"projectId": self.project_id,
102+
"columns": [{"experimentId": self.run_id, "key": self.key}],
103+
}
104+
99105
# ------------------------------------------------------------------
100106
# 类型专属加载
101107
# ------------------------------------------------------------------
@@ -118,11 +124,33 @@ def _fetch_scalar(self) -> ApiScalarSeriesType:
118124
res[field] = stat_data.get(field, {})
119125
return res
120126

121-
def _fetch_media(self) -> Dict[str, Any]:
122-
return {}
127+
def _fetch_media(self) -> ApiMediaSeriesType:
128+
res = ApiMediaSeriesType(projectId=self.project_id, experimentId=self.run_id, key=self.key)
129+
payload = self._build_media_payload()
130+
raw_resp = self._post("/house/metrics/f_media", data=payload)
131+
raw_data = self._extract_first(raw_resp)
132+
if raw_data is None:
133+
return res
134+
# print(raw_data)
135+
metrics: List[ApiMediaType] = []
136+
prefix = f"{self.project_id}/{self.run_id}"
137+
for entry in raw_data.get("metrics", []):
138+
paths = entry.get("data", [])
139+
mores = entry.get("more", [])
140+
items = []
141+
for i, path in enumerate(paths):
142+
item = {"path": path}
143+
if i < len(mores) and isinstance(mores[i], dict):
144+
item.update(mores[i])
145+
items.append(item)
146+
metrics.append({"index": entry.get("index", 0), "prefix": prefix, "items": items})
147+
148+
res["metrics"] = metrics
149+
return res
123150

124-
def _fetch_logs(self) -> Dict[str, Any]:
125-
return {}
151+
def _fetch_logs(self) -> ApiLogSeriesType:
152+
res = ApiLogSeriesType(projectId=self.project_id, experimentId=self.run_id, key="LOG")
153+
return res
126154

127155
# ------------------------------------------------------------------
128156
# 导出
@@ -134,6 +162,9 @@ def export_csv(self) -> ApiResponseType:
134162
135163
:return: ApiResponseType,成功时 data 包含临时下载 URL
136164
"""
165+
if self.metric_type != "SCALAR":
166+
err_msg = "export_csv() only support SCALAR metric_type"
167+
return ApiResponseType(ok=False, errmsg=err_msg, data=None)
137168
resp = self._get(f"/experiment/{self._run_id}/column/csv", params={"key": self.key})
138169
if not resp.ok:
139170
return resp

swanlab/api/typings/metric.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -65,17 +65,24 @@ class ApiScalarSummaryItemType(TypedDict, total=False):
6565

6666

6767
# ---------------------------------------------------------------------------
68-
# Media — 媒体 item 数据
68+
# Media — 媒体数据
6969
# ---------------------------------------------------------------------------
70+
class ApiMediaItemDataType(TypedDict, total=False):
71+
path: str
72+
73+
7074
class ApiMediaType(TypedDict, total=False):
71-
# 项目路径: proj_id/run_id 拼接而成
75+
index: int
7276
prefix: str
73-
data: List[str]
74-
more: List[Dict[str, Any]]
77+
items: List[ApiMediaItemDataType]
78+
79+
80+
class ApiMediaSeriesType(ApiMetricColumnRefType, total=False):
81+
metrics: List[ApiMediaType]
7582

7683

7784
# ---------------------------------------------------------------------------
78-
# Log — 日志 item 数据
85+
# Log — 日志数据
7986
# ---------------------------------------------------------------------------
8087
class ApiLogType(TypedDict, total=False):
8188
epoch: int
@@ -85,5 +92,10 @@ class ApiLogType(TypedDict, total=False):
8592
timestamp: str
8693

8794

95+
class ApiLogSeriesType(ApiMetricColumnRefType, total=False):
96+
logs: List[ApiLogType]
97+
count: int
98+
99+
88100
# 统一数据类型定义用于类型提示
89101
ApiMetricType = Union[ApiScalarType, ApiMediaType, ApiLogType]

swanlab/api/utils.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -107,12 +107,12 @@ def validate_filter(item: Dict[str, Any]) -> None:
107107
raise ValueError(f"filter value must be a list, got {type(item['value']).__name__}")
108108

109109

110-
def validate_metric_type(item: str, key: Optional[str] = None):
111-
"""校验 metric_type 的合法性"""
112-
if item not in _VALID_METRIC_TYPES:
113-
raise ValueError(f"Invalid metric_type: {item!r}, expected one of {sorted(_VALID_METRIC_TYPES)}")
114-
if not key and item != "LOG":
115-
raise ValueError("key must NOT be None if metric_type != LOG")
110+
def validate_metric_type(metric_type: str, key: Optional[str] = None) -> None:
111+
"""校验 metric_type 的合法性。非 LOG 类型必须提供非空 key。"""
112+
if metric_type not in _VALID_METRIC_TYPES:
113+
raise ValueError(f"Invalid metric_type: {metric_type!r}, expected one of {sorted(_VALID_METRIC_TYPES)}")
114+
if metric_type != "LOG" and not key:
115+
raise ValueError(f"key is required for metric_type {metric_type!r}, got key={key!r}")
116116

117117

118118
def validate_group(item: Dict[str, Any]) -> None:

0 commit comments

Comments
 (0)