|
5 | 5 | @description: SwanLab 运行时实验API |
6 | 6 | """ |
7 | 7 |
|
8 | | -from typing import List, Literal, Optional |
| 8 | +from typing import Dict, List, Literal, Optional, Union |
9 | 9 |
|
10 | 10 | from google.protobuf.timestamp_pb2 import Timestamp |
11 | 11 |
|
|
14 | 14 | from swanlab.sdk.internal.core_python import client |
15 | 15 | from swanlab.sdk.internal.pkg import helper |
16 | 16 | from swanlab.sdk.typings.core_python.api.experiment import InitExperimentType |
17 | | -from swanlab.sdk.typings.run import ResumeType |
| 17 | +from swanlab.sdk.typings.core_python.api.experiment import RunType |
| 18 | +from swanlab.sdk.typings.run import ResumeType, RunStateType |
| 19 | + |
| 20 | +from .utils import parse_column_type, to_camel_case |
18 | 21 |
|
19 | 22 |
|
20 | 23 | def create_or_resume_experiment( |
@@ -91,3 +94,93 @@ def stop_experiment(username: str, project: str, cuid: str, *, state: RunState, |
91 | 94 | "from": "sdk", |
92 | 95 | }, |
93 | 96 | ) |
| 97 | + return resp.raw.status_code == 201 |
| 98 | + |
| 99 | + |
| 100 | +def send_experiment_heartbeat(*, cuid: str, flag_id: str) -> None: |
| 101 | + """ |
| 102 | + 发送实验心跳,保持实验处于活跃状态 |
| 103 | + :param cuid: 实验唯一标识符 |
| 104 | + :param flag_id: 实验标记ID |
| 105 | + """ |
| 106 | + client.post(f"/house/experiments/{cuid}/heartbeat", {"flagId": flag_id}) |
| 107 | + |
| 108 | + |
| 109 | +def update_experiment_state( |
| 110 | + *, |
| 111 | + username: str, |
| 112 | + projname: str, |
| 113 | + cuid: str, |
| 114 | + state: RunStateType, |
| 115 | + finished_at: Optional[str] = None, |
| 116 | +) -> None: |
| 117 | + """ |
| 118 | + 更新实验状态 |
| 119 | + :param username: 实验所属用户名 |
| 120 | + :param projname: 实验所属项目名称 |
| 121 | + :param cuid: 实验唯一标识符 |
| 122 | + :param state: 实验状态 |
| 123 | + :param finished_at: 实验结束时间,格式为 ISO 8601,如果不提供则使用当前时间 |
| 124 | + """ |
| 125 | + put_data = { |
| 126 | + "state": state, |
| 127 | + "finishedAt": finished_at, |
| 128 | + "from": "sdk", |
| 129 | + } |
| 130 | + put_data = {k: v for k, v in put_data.items() if v is not None} |
| 131 | + client.put(f"/project/{username}/{projname}/runs/{cuid}/state", put_data) |
| 132 | + |
| 133 | + |
| 134 | +def get_project_experiments( |
| 135 | + *, |
| 136 | + path: str, |
| 137 | + filters: Optional[Dict[str, object]] = None, |
| 138 | +) -> Union[List[RunType], Dict[str, List[RunType]]]: |
| 139 | + """ |
| 140 | + 获取指定项目下的所有实验信息 |
| 141 | + 若有实验分组,则返回一个字典,使用时需递归展平实验数据 |
| 142 | + :param path: 项目路径 username/project |
| 143 | + :param filters: 筛选实验的条件,可选 |
| 144 | + """ |
| 145 | + parsed_filters = ( |
| 146 | + [ |
| 147 | + { |
| 148 | + "key": to_camel_case(key) if parse_column_type(key) == "STABLE" else key.split(".", 1)[-1], |
| 149 | + "active": True, |
| 150 | + "value": [value], |
| 151 | + "op": "EQ", |
| 152 | + "type": parse_column_type(key), |
| 153 | + } |
| 154 | + for key, value in filters.items() |
| 155 | + ] |
| 156 | + if filters |
| 157 | + else [] |
| 158 | + ) |
| 159 | + return client.post(f"/project/{path}/runs/shows", data={"filters": parsed_filters}).data |
| 160 | + |
| 161 | + |
| 162 | +def get_single_experiment(*, path: str) -> RunType: |
| 163 | + """ |
| 164 | + 获取指定实验信息 |
| 165 | + :param path: 实验路径 username/project/expid |
| 166 | + """ |
| 167 | + proj_path, expid = path.rsplit("/", 1) |
| 168 | + return client.get(f"/project/{proj_path}/runs/{expid}").data |
| 169 | + |
| 170 | + |
| 171 | +def get_experiment_metrics(*, expid: str, key: str) -> Dict[str, str]: |
| 172 | + """ |
| 173 | + 获取指定字段的指标数据,返回csv网址 |
| 174 | + :param expid: 实验cuid |
| 175 | + :param key: 指定字段列表 |
| 176 | + """ |
| 177 | + return client.get(f"/experiment/{expid}/column/csv", params={"key": key}).data |
| 178 | + |
| 179 | + |
| 180 | +def delete_experiment(*, path: str) -> None: |
| 181 | + """ |
| 182 | + 删除指定实验 |
| 183 | + :param path: 实验路径 'username/project/expid' |
| 184 | + """ |
| 185 | + proj_path, expid = path.rsplit("/", 1) |
| 186 | + client.delete(f"/project/{proj_path}/runs/{expid}") |
0 commit comments