|
5 | 5 | @description: SwanLab 运行时实验API |
6 | 6 | """ |
7 | 7 |
|
8 | | -from typing import List, Optional |
| 8 | +from typing import Dict, List, Optional, Union |
9 | 9 |
|
10 | 10 | from swanlab.exceptions import ApiError |
11 | 11 | from swanlab.sdk.internal.core_python import client |
12 | 12 | from swanlab.sdk.internal.pkg import helper |
13 | | -from swanlab.sdk.typings.run import ResumeType |
| 13 | +from swanlab.sdk.typings.core_python.api.experiment import RunType |
| 14 | +from swanlab.sdk.typings.run import ResumeType, RunStateType |
| 15 | + |
| 16 | +from .utils import parse_column_type, to_camel_case |
14 | 17 |
|
15 | 18 |
|
16 | 19 | def create_or_resume_experiment( |
@@ -65,3 +68,92 @@ def create_or_resume_experiment( |
65 | 68 | # 200代表实验已存在,开启更新模式 |
66 | 69 | # 201代表实验不存在,新建实验 |
67 | 70 | return resp.raw.status_code == 201 |
| 71 | + |
| 72 | + |
| 73 | +def send_experiment_heartbeat(*, cuid: str, flag_id: str) -> None: |
| 74 | + """ |
| 75 | + 发送实验心跳,保持实验处于活跃状态 |
| 76 | + :param cuid: 实验唯一标识符 |
| 77 | + :param flag_id: 实验标记ID |
| 78 | + """ |
| 79 | + client.post(f"/house/experiments/{cuid}/heartbeat", {"flagId": flag_id}) |
| 80 | + |
| 81 | + |
| 82 | +def update_experiment_state( |
| 83 | + *, |
| 84 | + username: str, |
| 85 | + projname: str, |
| 86 | + cuid: str, |
| 87 | + state: RunStateType, |
| 88 | + finished_at: Optional[str] = None, |
| 89 | +) -> None: |
| 90 | + """ |
| 91 | + 更新实验状态 |
| 92 | + :param username: 实验所属用户名 |
| 93 | + :param projname: 实验所属项目名称 |
| 94 | + :param cuid: 实验唯一标识符 |
| 95 | + :param state: 实验状态 |
| 96 | + :param finished_at: 实验结束时间,格式为 ISO 8601,如果不提供则使用当前时间 |
| 97 | + """ |
| 98 | + put_data = { |
| 99 | + "state": state, |
| 100 | + "finishedAt": finished_at, |
| 101 | + "from": "sdk", |
| 102 | + } |
| 103 | + put_data = {k: v for k, v in put_data.items() if v is not None} |
| 104 | + client.put(f"/project/{username}/{projname}/runs/{cuid}/state", put_data) |
| 105 | + |
| 106 | + |
| 107 | +def get_project_experiments( |
| 108 | + *, |
| 109 | + path: str, |
| 110 | + filters: Optional[Dict[str, object]] = None, |
| 111 | +) -> Union[List[RunType], Dict[str, List[RunType]]]: |
| 112 | + """ |
| 113 | + 获取指定项目下的所有实验信息 |
| 114 | + 若有实验分组,则返回一个字典,使用时需递归展平实验数据 |
| 115 | + :param path: 项目路径 username/project |
| 116 | + :param filters: 筛选实验的条件,可选 |
| 117 | + """ |
| 118 | + parsed_filters = ( |
| 119 | + [ |
| 120 | + { |
| 121 | + "key": to_camel_case(key) if parse_column_type(key) == "STABLE" else key.split(".", 1)[-1], |
| 122 | + "active": True, |
| 123 | + "value": [value], |
| 124 | + "op": "EQ", |
| 125 | + "type": parse_column_type(key), |
| 126 | + } |
| 127 | + for key, value in filters.items() |
| 128 | + ] |
| 129 | + if filters |
| 130 | + else [] |
| 131 | + ) |
| 132 | + return client.post(f"/project/{path}/runs/shows", data={"filters": parsed_filters}).data |
| 133 | + |
| 134 | + |
| 135 | +def get_single_experiment(*, path: str) -> RunType: |
| 136 | + """ |
| 137 | + 获取指定实验信息 |
| 138 | + :param path: 实验路径 username/project/expid |
| 139 | + """ |
| 140 | + proj_path, expid = path.rsplit("/", 1) |
| 141 | + return client.get(f"/project/{proj_path}/runs/{expid}").data |
| 142 | + |
| 143 | + |
| 144 | +def get_experiment_metrics(*, expid: str, key: str) -> Dict[str, str]: |
| 145 | + """ |
| 146 | + 获取指定字段的指标数据,返回csv网址 |
| 147 | + :param expid: 实验cuid |
| 148 | + :param key: 指定字段列表 |
| 149 | + """ |
| 150 | + return client.get(f"/experiment/{expid}/column/csv", params={"key": key}).data |
| 151 | + |
| 152 | + |
| 153 | +def delete_experiment(*, path: str) -> None: |
| 154 | + """ |
| 155 | + 删除指定实验 |
| 156 | + :param path: 实验路径 'username/project/expid' |
| 157 | + """ |
| 158 | + proj_path, expid = path.rsplit("/", 1) |
| 159 | + client.delete(f"/project/{proj_path}/runs/{expid}") |
0 commit comments