99
1010from swanlab .api .base import ApiClientContext , BaseEntity
1111from swanlab .api .typings import ApiColumnCsvExportType , ApiResponseType
12- from swanlab .api .typings .metric import ApiScalarSeriesType
12+ from swanlab .api .typings .metric import ApiLogSeriesType , ApiMediaSeriesType , ApiMediaType , ApiScalarSeriesType
1313from swanlab .api .utils import get_properties , validate_metric_type
1414
1515
@@ -96,6 +96,12 @@ def _build_scalar_payload(self) -> Dict[str, Any]:
9696 "columns" : [{"experimentId" : self .run_id , "key" : self .key }],
9797 }
9898
99+ def _build_media_payload (self ) -> Dict [str , Any ]:
100+ return {
101+ "projectId" : self .project_id ,
102+ "columns" : [{"experimentId" : self .run_id , "key" : self .key }],
103+ }
104+
99105 # ------------------------------------------------------------------
100106 # 类型专属加载
101107 # ------------------------------------------------------------------
@@ -118,11 +124,33 @@ def _fetch_scalar(self) -> ApiScalarSeriesType:
118124 res [field ] = stat_data .get (field , {})
119125 return res
120126
121- def _fetch_media (self ) -> Dict [str , Any ]:
122- return {}
127+ def _fetch_media (self ) -> ApiMediaSeriesType :
128+ res = ApiMediaSeriesType (projectId = self .project_id , experimentId = self .run_id , key = self .key )
129+ payload = self ._build_media_payload ()
130+ raw_resp = self ._post ("/house/metrics/f_media" , data = payload )
131+ raw_data = self ._extract_first (raw_resp )
132+ if raw_data is None :
133+ return res
134+ # print(raw_data)
135+ metrics : List [ApiMediaType ] = []
136+ prefix = f"{ self .project_id } /{ self .run_id } "
137+ for entry in raw_data .get ("metrics" , []):
138+ paths = entry .get ("data" , [])
139+ mores = entry .get ("more" , [])
140+ items = []
141+ for i , path in enumerate (paths ):
142+ item = {"path" : path }
143+ if i < len (mores ) and isinstance (mores [i ], dict ):
144+ item .update (mores [i ])
145+ items .append (item )
146+ metrics .append ({"index" : entry .get ("index" , 0 ), "prefix" : prefix , "items" : items })
147+
148+ res ["metrics" ] = metrics
149+ return res
123150
124- def _fetch_logs (self ) -> Dict [str , Any ]:
125- return {}
151+ def _fetch_logs (self ) -> ApiLogSeriesType :
152+ res = ApiLogSeriesType (projectId = self .project_id , experimentId = self .run_id , key = "LOG" )
153+ return res
126154
127155 # ------------------------------------------------------------------
128156 # 导出
@@ -134,6 +162,9 @@ def export_csv(self) -> ApiResponseType:
134162
135163 :return: ApiResponseType,成功时 data 包含临时下载 URL
136164 """
165+ if self .metric_type != "SCALAR" :
166+ err_msg = "export_csv() only support SCALAR metric_type"
167+ return ApiResponseType (ok = False , errmsg = err_msg , data = None )
137168 resp = self ._get (f"/experiment/{ self ._run_id } /column/csv" , params = {"key" : self .key })
138169 if not resp .ok :
139170 return resp
0 commit comments