From f4a5f3bc3254c2b78bd7ae5c4fc1a85d59b23ea7 Mon Sep 17 00:00:00 2001 From: Nexisato <978452096@qq.com> Date: Tue, 21 Apr 2026 18:37:26 +0800 Subject: [PATCH 01/52] feat: openapi skeleton --- swanlab/api/__init__.py | 0 swanlab/api/base.py | 0 swanlab/api/helper.py | 0 swanlab/api/typings/common.py | 81 ++++++++++++++++++++++++++ swanlab/api/typings/experiment.py | 30 ++++++++++ swanlab/api/typings/project.py | 28 +++++++++ swanlab/api/typings/workspace.py | 19 ++++++ swanlab/cli/__init__.py | 4 ++ swanlab/cli/api/__init__.py | 28 +++++++++ swanlab/cli/api/experiment/__init__.py | 10 ++++ swanlab/cli/api/helper.py | 10 ++++ swanlab/cli/api/project/__init__.py | 10 ++++ swanlab/cli/api/selfhosted/__init__.py | 10 ++++ swanlab/cli/api/user/__init__.py | 10 ++++ swanlab/cli/api/workspace/__init__.py | 10 ++++ 15 files changed, 250 insertions(+) create mode 100644 swanlab/api/__init__.py create mode 100644 swanlab/api/base.py create mode 100644 swanlab/api/helper.py create mode 100644 swanlab/api/typings/common.py create mode 100644 swanlab/api/typings/experiment.py create mode 100644 swanlab/api/typings/project.py create mode 100644 swanlab/api/typings/workspace.py create mode 100644 swanlab/cli/api/__init__.py create mode 100644 swanlab/cli/api/experiment/__init__.py create mode 100644 swanlab/cli/api/helper.py create mode 100644 swanlab/cli/api/project/__init__.py create mode 100644 swanlab/cli/api/selfhosted/__init__.py create mode 100644 swanlab/cli/api/user/__init__.py create mode 100644 swanlab/cli/api/workspace/__init__.py diff --git a/swanlab/api/__init__.py b/swanlab/api/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/swanlab/api/base.py b/swanlab/api/base.py new file mode 100644 index 000000000..e69de29bb diff --git a/swanlab/api/helper.py b/swanlab/api/helper.py new file mode 100644 index 000000000..e69de29bb diff --git a/swanlab/api/typings/common.py b/swanlab/api/typings/common.py new file mode 100644 index 000000000..06c25cf8c --- /dev/null +++ b/swanlab/api/typings/common.py @@ -0,0 +1,81 @@ +""" +@author: caddiesnew +@file: common.py +@time: 2026/4/20 +@description: 公共查询 API 通用类型定义 +""" + +from typing import Any, Dict, List, Literal, TypedDict + +# 实验状态类型 +ApiRunStateType = Literal["RUNNING", "FINISHED", "CRASHED", "ABORTED", "OFFLINE"] + +# 可见性类型 +ApiVisibilityType = Literal["PUBLIC", "PRIVATE"] + +# 工作空间类型 +ApiWorkspaceType = Literal["TEAM", "PERSON"] + +# 工作空间成员类型 +ApiRoleType = Literal["VISITOR", "VIEWER", "MEMBER", "OWNER"] + +# Self-Hosted 身份类型 +ApiIdentityType = Literal["root", "user"] + +# License 许可证类型 +ApiLicensePlanType = Literal["free", "commercial"] + + +class ApiLabelType(TypedDict): + name: str + + +class ApiPaginationType(TypedDict): + list: List + size: int + pages: int + total: int + + +class ApiResponseType: + """ + API 响应的统一封装,保证任何异常都不会导致程序 crash。 + + - ok=True 时 data 持有正常返回值 + - ok=False 时 data 为 None,errmsg 描述失败原因 + """ + + __slots__ = ("ok", "errmsg", "data") + + def __init__(self, *, ok: bool, errmsg: str = "", data: Any = None) -> None: + self.ok = ok + self.errmsg = errmsg + self.data = data + + def to_dict(self) -> Dict[str, Any]: + return {"ok": self.ok, "errmsg": self.errmsg, "data": self.data} + + def to_json_dict(self) -> Dict[str, Any]: + """返回 JSON 可序列化的字典,自动将实体 data 转为 dict。""" + data = self.data + errors: list[str] = [] + if not self.ok and self.errmsg: + errors.append(self.errmsg) + if data is not None and hasattr(data, "to_dict"): + data = data.to_dict() + # 收集实体内部子请求的错误 + if hasattr(data, "__getitem__"): + # to_dict 返回的 dict 不带 _errors,需要从实体取 + pass + entity_errors = getattr(self.data, "_errors", []) + errors.extend(entity_errors) + ok = self.ok and not errors + return {"ok": ok, "errmsg": "; ".join(errors) if errors else "", "data": data} + + def __repr__(self) -> str: + if self.ok: + return f"ApiResponse(ok=True, data={self.data!r})" + return f"ApiResponse(ok=False, errmsg={self.errmsg!r})" + + +__all__ = ["ApiLabelType", "ApiPaginationType", "ApiResponseType"] diff --git a/swanlab/api/typings/experiment.py b/swanlab/api/typings/experiment.py new file mode 100644 index 000000000..7b4db8346 --- /dev/null +++ b/swanlab/api/typings/experiment.py @@ -0,0 +1,30 @@ +""" +@author: caddiesnew +@file: experiment.py +@time: 2026/4/20 +@description: 公共查询 API 实验类型定义 +""" + +from typing import Dict, List, Optional, TypedDict + +from .common import ApiLabelType, ApiRunStateType + + +class ApiExperimentUserType(TypedDict): + username: str + name: str + + +class ApiExperimentType(TypedDict): + cuid: str + name: str + description: str + labels: List[ApiLabelType] + profile: Dict[str, object] + show: bool + state: ApiRunStateType + cluster: str + job: str + user: ApiExperimentUserType + rootExpId: Optional[str] + rootProId: Optional[str] diff --git a/swanlab/api/typings/project.py b/swanlab/api/typings/project.py new file mode 100644 index 000000000..9b9bc5521 --- /dev/null +++ b/swanlab/api/typings/project.py @@ -0,0 +1,28 @@ +""" +@author: caddiesnew +@file: project.py +@time: 2026/4/20 +@description: 公共查询 API 项目类型定义 +""" + +from typing import Dict, List, TypedDict + +from .common import ApiLabelType, ApiVisibilityType + + +class ApiProjectCountType(TypedDict): + experiments: int + contributors: int + collaborators: int + clones: int + + +class ApiProjectType(TypedDict): + name: str + username: str + path: str + visibility: ApiVisibilityType + description: str + group: Dict[str, str] + projectLabels: List[ApiLabelType] + _count: ApiProjectCountType diff --git a/swanlab/api/typings/workspace.py b/swanlab/api/typings/workspace.py new file mode 100644 index 000000000..f769f2326 --- /dev/null +++ b/swanlab/api/typings/workspace.py @@ -0,0 +1,19 @@ +""" +@author: caddiesnew +@file: workspace.py +@time: 2026/4/20 +@description: 公共查询 API 工作空间类型定义 +""" + +from typing import Dict, TypedDict + +from .common import ApiRoleType, ApiWorkspaceType + + +class ApiWorkspaceInfoType(TypedDict): + name: str + username: str + profile: Dict[str, str] + type: ApiWorkspaceType + comment: str + role: ApiRoleType diff --git a/swanlab/cli/__init__.py b/swanlab/cli/__init__.py index 97fe8aa8d..ba8bce93b 100644 --- a/swanlab/cli/__init__.py +++ b/swanlab/cli/__init__.py @@ -9,6 +9,7 @@ from swanlab.sdk.internal.pkg.helper import get_swanlab_version +from .api import api_cli from .auth import login, logout, verify from .converter import convert from .dashboard import watch @@ -54,5 +55,8 @@ def cli(): # noinspection PyTypeChecker cli.add_command(disabled) +# Api Cli +# noinspection PyTypeChecker +cli.add_command(api_cli) __all__ = ["cli"] diff --git a/swanlab/cli/api/__init__.py b/swanlab/cli/api/__init__.py new file mode 100644 index 000000000..08e92a779 --- /dev/null +++ b/swanlab/cli/api/__init__.py @@ -0,0 +1,28 @@ +""" +@author: caddiesnew +@file: __init__.py +@time: 2026/4/20 +@description: CLI API 子命令 — 通过命令行调用 SwanLab 公共查询 API +""" + +import click + +from .experiment import experiment_cli +from .project import project_cli +from .selfhosted import selfhosted_cli +from .workspace import workspace_cli + + +@click.group("api") +def api_cli(): + """Generic SwanLab API requests.""" + pass + + +api_cli.add_command(project_cli) +api_cli.add_command(experiment_cli) +api_cli.add_command(workspace_cli) +api_cli.add_command(selfhosted_cli) + + +__all__ = ["api_cli"] diff --git a/swanlab/cli/api/experiment/__init__.py b/swanlab/cli/api/experiment/__init__.py new file mode 100644 index 000000000..b8e6d9b31 --- /dev/null +++ b/swanlab/cli/api/experiment/__init__.py @@ -0,0 +1,10 @@ +import click + +from swanlab.api.typings.common import ApiResponseType +from swanlab.cli.api.helper import format_output + + +@click.group("run") +def experiment_cli(): + """Experiment(Run) management commands.""" + pass diff --git a/swanlab/cli/api/helper.py b/swanlab/cli/api/helper.py new file mode 100644 index 000000000..4ecd04067 --- /dev/null +++ b/swanlab/cli/api/helper.py @@ -0,0 +1,10 @@ +import json + +import click + +from swanlab.api.typings.common import ApiResponseType + + +def format_output(resp: ApiResponseType) -> None: + """统一输出 ApiResponse JSON。""" + click.echo(json.dumps(resp.to_json_dict())) diff --git a/swanlab/cli/api/project/__init__.py b/swanlab/cli/api/project/__init__.py new file mode 100644 index 000000000..32a8741e5 --- /dev/null +++ b/swanlab/cli/api/project/__init__.py @@ -0,0 +1,10 @@ +import click + +from swanlab.api.typings.common import ApiResponseType +from swanlab.cli.api.helper import format_output + + +@click.group("project") +def project_cli(): + """Project management commands.""" + pass diff --git a/swanlab/cli/api/selfhosted/__init__.py b/swanlab/cli/api/selfhosted/__init__.py new file mode 100644 index 000000000..3003315f0 --- /dev/null +++ b/swanlab/cli/api/selfhosted/__init__.py @@ -0,0 +1,10 @@ +import click + +from swanlab.api.typings.common import ApiResponseType +from swanlab.cli.api.helper import format_output + + +@click.group("selfhosted") +def selfhosted_cli(): + """Self-hosted deployment management commands.""" + pass diff --git a/swanlab/cli/api/user/__init__.py b/swanlab/cli/api/user/__init__.py new file mode 100644 index 000000000..9a8286e14 --- /dev/null +++ b/swanlab/cli/api/user/__init__.py @@ -0,0 +1,10 @@ +import click + +from swanlab.api.typings.common import ApiResponseType +from swanlab.cli.api.helper import format_output + + +@click.group("user") +def user_cli(): + """User management commands.""" + pass diff --git a/swanlab/cli/api/workspace/__init__.py b/swanlab/cli/api/workspace/__init__.py new file mode 100644 index 000000000..2c3de65f3 --- /dev/null +++ b/swanlab/cli/api/workspace/__init__.py @@ -0,0 +1,10 @@ +import click + +from swanlab.api.typings.common import ApiResponseType +from swanlab.cli.api.helper import format_output + + +@click.group("workspace") +def workspace_cli(): + """Workspace management commands.""" + pass From c9409a99e4c24318c0b95e66240e0a97102a83e3 Mon Sep 17 00:00:00 2001 From: Nexisato <978452096@qq.com> Date: Tue, 21 Apr 2026 19:49:13 +0800 Subject: [PATCH 02/52] feat: add typing skeleton --- swanlab/api/helper.py | 0 swanlab/api/typings/common.py | 3 ++ swanlab/api/typings/experiment.py | 8 +-- swanlab/api/typings/selfhosted.py | 24 +++++++++ swanlab/api/typings/user.py | 24 +++++++++ swanlab/api/utils.py | 85 +++++++++++++++++++++++++++++++ 6 files changed, 138 insertions(+), 6 deletions(-) delete mode 100644 swanlab/api/helper.py create mode 100644 swanlab/api/typings/selfhosted.py create mode 100644 swanlab/api/typings/user.py create mode 100644 swanlab/api/utils.py diff --git a/swanlab/api/helper.py b/swanlab/api/helper.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/swanlab/api/typings/common.py b/swanlab/api/typings/common.py index 06c25cf8c..8011634be 100644 --- a/swanlab/api/typings/common.py +++ b/swanlab/api/typings/common.py @@ -7,6 +7,9 @@ from typing import Any, Dict, List, Literal, TypedDict +# 列类型 +ApiColumnType = Literal["SCALAR", "CONFIG", "STABLE"] + # 实验状态类型 ApiRunStateType = Literal["RUNNING", "FINISHED", "CRASHED", "ABORTED", "OFFLINE"] diff --git a/swanlab/api/typings/experiment.py b/swanlab/api/typings/experiment.py index 7b4db8346..5a93a4bdb 100644 --- a/swanlab/api/typings/experiment.py +++ b/swanlab/api/typings/experiment.py @@ -8,11 +8,7 @@ from typing import Dict, List, Optional, TypedDict from .common import ApiLabelType, ApiRunStateType - - -class ApiExperimentUserType(TypedDict): - username: str - name: str +from .user import ApiUserType class ApiExperimentType(TypedDict): @@ -25,6 +21,6 @@ class ApiExperimentType(TypedDict): state: ApiRunStateType cluster: str job: str - user: ApiExperimentUserType + user: ApiUserType rootExpId: Optional[str] rootProId: Optional[str] diff --git a/swanlab/api/typings/selfhosted.py b/swanlab/api/typings/selfhosted.py new file mode 100644 index 000000000..337297165 --- /dev/null +++ b/swanlab/api/typings/selfhosted.py @@ -0,0 +1,24 @@ +""" +@author: caddiesnew +@file: user.py +@time: 2026/4/20 +@description: 公共查询 API self-hosted 类型定义 +""" + +from typing import TypedDict + +from .common import ApiLicensePlanType + + +class ApiApiKeyType(TypedDict): + id: int + name: str + key: str + + +class ApiSelfHostedInfoType(TypedDict): + enabled: bool + expired: bool + root: bool + plan: ApiLicensePlanType + seats: int diff --git a/swanlab/api/typings/user.py b/swanlab/api/typings/user.py new file mode 100644 index 000000000..13eaac06c --- /dev/null +++ b/swanlab/api/typings/user.py @@ -0,0 +1,24 @@ +""" +@author: caddiesnew +@file: user.py +@time: 2026/4/20 +@description: 公共查询 API 用户类型定义 +""" + +from typing import TypedDict + + +class ApiUserType(TypedDict): + name: str + username: str + + +class ApiUserProfileType(TypedDict): + bio: str + institution: str + localtion: str + school: str + email: str + idc: str + url: str + telephone: str diff --git a/swanlab/api/utils.py b/swanlab/api/utils.py new file mode 100644 index 000000000..5ea2e9515 --- /dev/null +++ b/swanlab/api/utils.py @@ -0,0 +1,85 @@ +from typing import Dict, List, Optional, Tuple + + +def parse_column_type(column: str) -> str: + """从前缀中获取指标类型""" + column_type = column.split(".", 1)[0] + if column_type == "summary": + return "SCALAR" + elif column_type == "config": + return "CONFIG" + else: + return "STABLE" + + +def to_camel_case(name: str) -> str: + """将下划线命名转化为驼峰命名""" + return "".join([w.capitalize() if i > 0 else w for i, w in enumerate(name.split("_"))]) + + +_SPECIAL_FILTER_MAP = { + # (backend_key, operator) — 用户侧 key 到后端字段名和操作符的映射 + # backend_key: 后端 API 实际接受的字段名 + # operator: 筛选操作符,EQ=精确匹配,IN=包含匹配(用于 tags 列表) + "group": ("cluster", "EQ"), + "tags": ("labels", "IN"), + "name": ("name", "EQ"), + "username": ("user.username", "EQ"), + "job_type": ("job", "EQ"), +} + + +def parse_filter(key: str, value: object) -> Dict[str, object]: + """将用户侧筛选条件转换为后端 filter 格式。 + + :param key: 筛选字段名。预定义字段(group/tags/name/username/job_type)会映射到后端字段名; + 其他字段按 column type 自动转换:STABLE 类型转 camelCase,其余取最后一段。 + :param value: 筛选值。预定义字段中 tags 接受列表/元组,其余均为单值(内部统一包装为列表)。 + :return: 后端 filter 字典,包含 key / active / value / op / type 五个字段。 + """ + if key in _SPECIAL_FILTER_MAP: + backend_key, op = _SPECIAL_FILTER_MAP[key] + filter_value = list(value) if key == "tags" and isinstance(value, (list, tuple)) else [value] + return {"key": backend_key, "active": True, "value": filter_value, "op": op, "type": "STABLE"} + ct = parse_column_type(key) + return { + "key": to_camel_case(key) if ct == "STABLE" else key.split(".", 1)[-1], + "active": True, + "value": [value], + "op": "EQ", + "type": ct, + } + + +def unwrap_api_payload(data): + """提取 raw resp 的 data 响应.""" + if isinstance(data, dict) and "data" in data and isinstance(data["data"], (dict, list)): + return data["data"] + return data + + +# mulitpart-save +def extract_upload_id(payload: Dict[str, object]) -> Optional[str]: + upload_id = payload.get("uploadId") + if isinstance(upload_id, str) and upload_id: + return upload_id + return None + + +# multipart-save +def extract_part_urls(payload: Dict[str, object]) -> List[Tuple[int, str]]: + parts = payload.get("parts") + if not isinstance(parts, list): + raise ValueError("Multipart upload URLs are missing in prepare response.") + + resolved = [] + for part in parts: + if not isinstance(part, dict): + raise ValueError("Multipart prepare response contains invalid part data.") + part_number = part.get("partNumber") + url = part.get("url") + if not isinstance(part_number, int) or not isinstance(url, str) or not url: + raise ValueError("Invalid partNumber or url in multipart response.") + resolved.append((part_number, url)) + + return sorted(resolved, key=lambda item: item[0]) From 42558fa331963e4012024d537cade0fdf3d51f09 Mon Sep 17 00:00:00 2001 From: Nexisato <978452096@qq.com> Date: Tue, 21 Apr 2026 20:00:41 +0800 Subject: [PATCH 03/52] feat: add base request --- swanlab/api/base.py | 91 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 91 insertions(+) diff --git a/swanlab/api/base.py b/swanlab/api/base.py index e69de29bb..37ed015bc 100644 --- a/swanlab/api/base.py +++ b/swanlab/api/base.py @@ -0,0 +1,91 @@ +""" +@author: caddiesnew +@file: base.py +@time: 2026/4/20 +@description: 所有实体类的公共基类 +""" + +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Any, Callable, Dict, Iterator, Optional + +from swanlab.sdk.internal.pkg import safe + +from .typings.common import ApiPaginationType, ApiResponseType + +if TYPE_CHECKING: + from swanlab.sdk.internal.pkg.client import Client + + +class BaseEntity(ABC): + """ + swanlab/api 实体类公共基类。 + + 统一持有 _client、_web_host 和 _api_host,提供 _get/_post/_put/_delete HTTP 快捷方法和 _paginate 分页迭代。 + 所有 HTTP 请求通过 wrapper_safe_request 包裹,保证任何异常都不会导致程序 crash,统一返回 ApiResponse。 + 子类只需实现 to_dict() 和业务逻辑。 + """ + + def __init__(self, client: "Client", web_host: str, api_host: str) -> None: + self._client: "Client" = client + self._web_host: str = web_host + self._api_host: str = api_host + self._errors: list[str] = [] + + @abstractmethod + def to_dict(self) -> Dict[str, Any]: + """将实体序列化为 JSON 可序列化的字典。""" + + def wrapper_safe_request(self, method: Callable, path: str, **kwargs) -> ApiResponseType: + """安全请求包装:捕获所有异常,始终返回 ApiResponse 而不抛出。""" + _err: list[str] = [] + common_err: str = f"API request failed: {path}" + + @safe.decorator(message=common_err, on_error=lambda e: _err.append(str(e))) + def _do(): + return method(path, **kwargs).data + + data = _do() + if data is not None: + return ApiResponseType(ok=True, data=data) + errmsg = _err[0] if _err else common_err + self._errors.append(errmsg) + return ApiResponseType(ok=False, errmsg=errmsg) + + def _get(self, path: str, **kwargs) -> ApiResponseType: + return self.wrapper_safe_request(self._client.get, path, **kwargs) + + def _post(self, path: str, **kwargs) -> ApiResponseType: + return self.wrapper_safe_request(self._client.post, path, **kwargs) + + def _put(self, path: str, **kwargs) -> ApiResponseType: + return self.wrapper_safe_request(self._client.put, path, **kwargs) + + def _delete(self, path: str, **kwargs) -> ApiResponseType: + return self.wrapper_safe_request(self._client.delete, path, **kwargs) + + def _build_url(self, path: str) -> str: + return f"{self._api_host}/{path}" + + def _paginate(self, path: str, *, page_size: int = 20, params: Optional[dict] = None) -> Iterator[dict]: + """通用分页迭代器,自动处理 page/size 参数。""" + page = 1 + while True: + p = {"page": page, "size": page_size} + if params: + p.update({k: v for k, v in params.items() if v is not None}) + resp = self._get(path, params=p) + if not resp.ok: + return + body: ApiPaginationType = resp.data + items = body["list"] + if not items: + break + yield from items + if page >= body["pages"]: + break + page += 1 + + def __repr__(self) -> str: + cls = self.__class__.__name__ + ident = getattr(self, "_path", None) or getattr(self, "_username", None) or "?" + return f"{cls}('{ident}')" From da03a7da92eff29156cb9568f902e89e7d218fae Mon Sep 17 00:00:00 2001 From: Nexisato <978452096@qq.com> Date: Tue, 21 Apr 2026 20:05:02 +0800 Subject: [PATCH 04/52] feat: add util --- swanlab/api/typings/__init__.py | 6 ++++++ swanlab/api/utils.py | 21 ++++++++++++++++++++- swanlab/api/workspace/__init__.py | 0 3 files changed, 26 insertions(+), 1 deletion(-) create mode 100644 swanlab/api/typings/__init__.py create mode 100644 swanlab/api/workspace/__init__.py diff --git a/swanlab/api/typings/__init__.py b/swanlab/api/typings/__init__.py new file mode 100644 index 000000000..331a23e36 --- /dev/null +++ b/swanlab/api/typings/__init__.py @@ -0,0 +1,6 @@ +""" +@author: caddiesnew +@file: __init__.py +@time: 2026/4/21 18:40 +@description: SwanLab OpenAPI 类型提示, 以 Api 前缀区分 +""" diff --git a/swanlab/api/utils.py b/swanlab/api/utils.py index 5ea2e9515..ce621158d 100644 --- a/swanlab/api/utils.py +++ b/swanlab/api/utils.py @@ -1,4 +1,23 @@ -from typing import Dict, List, Optional, Tuple +from typing import Dict, List, Optional, Set, Tuple + + +def get_properties(obj: object, _visited: Optional[Set[int]] = None) -> Dict[str, object]: + """递归获取实例中所有 property 的值,用于 to_dict() 默认实现。""" + if _visited is None: + _visited = set() + obj_id = id(obj) + if obj_id in _visited: + return {} + _visited = _visited | {obj_id} + + result = {} + for name in dir(obj): + if name.startswith("_"): + continue + if isinstance(getattr(type(obj), name, None), property): + value = getattr(obj, name, None) + result[name] = value if type(value).__module__ == "builtins" else get_properties(value, _visited) + return result def parse_column_type(column: str) -> str: diff --git a/swanlab/api/workspace/__init__.py b/swanlab/api/workspace/__init__.py new file mode 100644 index 000000000..e69de29bb From 4b2be0bdc0707c8602cebd2667656f9b6f3c81e1 Mon Sep 17 00:00:00 2001 From: Nexisato <978452096@qq.com> Date: Tue, 21 Apr 2026 20:23:26 +0800 Subject: [PATCH 05/52] feat: add implementation skeleton --- swanlab/api/experiment/__init__.py | 279 +++++++++++++++++++++++++++++ swanlab/api/project/__init__.py | 134 ++++++++++++++ swanlab/api/selfhosted/__init__.py | 0 swanlab/api/typings/common.py | 18 +- swanlab/api/typings/experiment.py | 10 +- swanlab/api/typings/project.py | 10 +- swanlab/api/typings/selfhosted.py | 4 +- swanlab/api/typings/workspace.py | 6 +- swanlab/api/user/__init__.py | 0 swanlab/api/workspace/__init__.py | 118 ++++++++++++ 10 files changed, 557 insertions(+), 22 deletions(-) create mode 100644 swanlab/api/experiment/__init__.py create mode 100644 swanlab/api/project/__init__.py create mode 100644 swanlab/api/selfhosted/__init__.py create mode 100644 swanlab/api/user/__init__.py diff --git a/swanlab/api/experiment/__init__.py b/swanlab/api/experiment/__init__.py new file mode 100644 index 000000000..0c089e725 --- /dev/null +++ b/swanlab/api/experiment/__init__.py @@ -0,0 +1,279 @@ +""" +@author: caddiesnew +@file: experiment.py +@time: 2026/4/20 +@description: Experiment 实体类 — 单个实验的查询与操作 +""" + +from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional, Union, cast + +from swanlab.api.base import BaseEntity +from swanlab.api.typings.experiment import ApiExperimentLabelType, ApiExperimentType +from swanlab.api.typings.user import ApiUserType +from swanlab.api.utils import get_properties, parse_column_type, to_camel_case + +if TYPE_CHECKING: + from swanlab.sdk.internal.pkg.client import Client + + +class Profile: + """Experiment profile containing config, metadata, requirements, and conda info.""" + + def __init__(self, data: Dict) -> None: + self._data = data + + @staticmethod + def _clean_field(value: Any) -> Any: + """Recursively clean config field, removing desc/sort and keeping value.""" + if isinstance(value, dict): + if "value" in value: + return Profile._clean_field(value["value"]) + else: + return {k: Profile._clean_field(v) for k, v in value.items()} + elif isinstance(value, list): + return [Profile._clean_field(item) for item in value] + return value + + @property + def config(self) -> Dict: + """Experiment configuration (cleaned, without desc/sort fields).""" + raw_config = self._data.get("config", {}) + return {k: Profile._clean_field(v) for k, v in raw_config.items()} if isinstance(raw_config, dict) else {} + + @property + def metadata(self) -> Dict: + return self._data.get("metadata", {}) + + @property + def requirements(self) -> str: + return self._data.get("requirements", "") + + @property + def conda(self) -> str: + return self._data.get("conda", "") + + +class Experiment(BaseEntity): + """ + 表示一个 SwanLab 实验。 + + 支持双模式:构造时传入 data,或 data=None(按需懒加载)。 + """ + + def __init__( + self, + client: "Client", + web_host: str, + api_host: str, + *, + path: str, + data: Optional[ApiExperimentType] = None, + ) -> None: + super().__init__(client, web_host, api_host) + self._path = path # 'username/project-name' + self._data = data + + def _ensure_data(self) -> ApiExperimentType: + if self._data is None: + resp = self._get(f"/project/{self._path}/runs/{self.id}") + self._data = resp.data if resp.ok and resp.data else cast(ApiExperimentType, {}) + return self._data + + @property + def id(self) -> str: + return self._data.get("cuid", "") if self._data is not None else self._ensure_data().get("cuid", "") + + @property + def name(self) -> str: + return self._ensure_data().get("name", "") + + @property + def description(self) -> str: + return self._ensure_data().get("description", "") + + @property + def state(self) -> str: + return self._ensure_data().get("state", "") + + @property + def url(self) -> str: + return self._build_url(f"@{self._path}/runs/{self.id}/chart") + + @property + def show(self) -> bool: + return self._ensure_data().get("show", True) + + @property + def labels(self) -> List[ApiExperimentLabelType]: + return [label for label in self._ensure_data().get("labels", [])] + + @property + def group(self) -> str: + return self._ensure_data().get("cluster", "") + + @property + def job_type(self) -> str: + return self._ensure_data().get("job", "") + + @property + def user(self) -> ApiUserType: + user_data = self._ensure_data().get("user", {}) + return user_data if isinstance(user_data, dict) else cast(ApiUserType, {}) + + @property + def created_at(self) -> str: + return self._ensure_data().get("createdAt", "") + + @property + def finished_at(self) -> str: + return self._ensure_data().get("finishedAt", "") + + @property + def profile(self) -> Profile: + """Experiment profile containing config, metadata, requirements, and conda.""" + data = self._ensure_data() + if "profile" not in data and self.id: + resp = self._get(f"/project/{self._path}/runs/{self.id}") + if resp.ok and resp.data: + self._data = resp.data + data = self._data + return Profile(data.get("profile", {})) + + def metrics( + self, keys: Optional[List[str]] = None, x_axis: Optional[str] = None, sample: Optional[int] = None + ) -> Any: + """ + 获取实验指标数据,返回 pandas DataFrame。 + + :param keys: 指标 key 列表 + :param x_axis: x 轴指标,默认 step + :param sample: 均匀采样 N 条数据(等间距采样,保留整体趋势) + """ + from swanlab.vendor import pd + + if not keys: + return pd.DataFrame() + + fetch_keys = list(keys) + use_x_axis = x_axis is not None and x_axis != "step" + if use_x_axis and x_axis is not None: + fetch_keys.append(x_axis) + + dfs = [] + prefix = "" + for idx, key in enumerate(fetch_keys): + resp = self._get(f"/experiment/{self.id}/column/csv", params={"key": key}) + if not resp.ok: + continue + data = resp.data + csv_url = data[0].get("url", "") if isinstance(data, list) and data else "" + if not csv_url: + continue + df = pd.read_csv(csv_url, index_col=0) + + if idx == 0: + first_col = str(df.columns[0]) + suffix = f"{key}_" + prefix = first_col.split(suffix)[0] if suffix in first_col else "" + + def strip_suffix(col, suffix="_step"): + return col[: -len(suffix)] if col.endswith(suffix) else col + + df.columns = [ + strip_suffix(col[len(prefix) :]) if prefix and col.startswith(prefix) else strip_suffix(col) + for col in df.columns + ] + dfs.append(df) + + if not dfs: + return pd.DataFrame() + + result_df = dfs[0].join(dfs[1:], how="outer") if len(dfs) > 1 else dfs[0] + result_df = result_df.sort_index() + + if use_x_axis: + result_df = result_df.drop( + columns=[c for c in result_df.columns if c.endswith("_timestamp")], errors="ignore" + ) + if x_axis not in result_df.columns: + return pd.DataFrame() + cols = [x_axis] + [c for c in result_df.columns if c != x_axis] + result_df = result_df[cols].dropna(subset=[x_axis]) + + if sample is not None and len(result_df) > sample: + indices = [int(i * (len(result_df) - 1) / (sample - 1)) for i in range(sample)] + result_df = result_df.iloc[indices] + + return result_df + + def delete(self) -> bool: + """删除此实验。""" + resp = self._delete(f"/project/{self._path}/runs/{self.id}") + return resp.ok + + def to_dict(self) -> Dict[str, Any]: + return get_properties(self) + + +def _flatten_runs(runs: Union[list, Dict]) -> list: + """展开分组后的实验数据,返回一个包含所有实验的列表。""" + if isinstance(runs, dict): + return [item for v in runs.values() for item in _flatten_runs(v)] + if isinstance(runs, list): + return list(runs) + return [runs] + + +class Experiments(BaseEntity): + """ + 项目下实验集合的迭代器。 + + 用法:: + + for run in api.runs("username/project"): + print(run.name) + """ + + def __init__( + self, + client: "Client", + web_host: str, + api_host: str, + *, + path: str, + filters: Optional[Dict[str, object]] = None, + ) -> None: + super().__init__(client, web_host, api_host) + self._path = path + self._filters = filters + + def __iter__(self) -> Iterator[Experiment]: + parsed_filters = ( + [ + { + "key": to_camel_case(key) if parse_column_type(key) == "STABLE" else key.split(".", 1)[-1], + "active": True, + "value": [value], + "op": "EQ", + "type": parse_column_type(key), + } + for key, value in self._filters.items() + ] + if self._filters + else [] + ) + resp = self._post(f"/project/{self._path}/runs/shows", data={"filters": parsed_filters}) + if not resp.ok: + return + body = resp.data + runs: list = [] + if isinstance(body, list): + runs = body + elif isinstance(body, dict): + runs = _flatten_runs(body) + + for run_data in runs: + yield Experiment(self._client, self._web_host, self._api_host, path=self._path, data=run_data) + + def to_dict(self) -> Dict[str, Any]: + return {"path": self._path} diff --git a/swanlab/api/project/__init__.py b/swanlab/api/project/__init__.py new file mode 100644 index 000000000..a8b28970d --- /dev/null +++ b/swanlab/api/project/__init__.py @@ -0,0 +1,134 @@ +""" +@author: caddiesnew +@file: project.py +@time: 2026/4/20 +@description: Project 实体类 — 单个项目的查询与操作 +""" + +from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional, cast + +from swanlab.api.base import BaseEntity +from swanlab.api.typings.project import ApiProjectCountType, ApiProjectLabelType, ApiProjectType +from swanlab.api.utils import get_properties + +if TYPE_CHECKING: + from swanlab.sdk.internal.pkg.client import Client + + +class Project(BaseEntity): + """ + 表示一个 SwanLab 项目。 + + 支持双模式:构造时传入 data(列表迭代注入),或 data=None(按需懒加载)。 + """ + + def __init__( + self, + client: "Client", + web_host: str, + api_host: str, + *, + path: str, + data: Optional[ApiProjectType] = None, + ) -> None: + super().__init__(client, web_host, api_host) + self._path = path + self._data = data + + def _ensure_data(self) -> ApiProjectType: + if self._data is None: + resp = self._get(f"/project/{self._path}") + self._data = resp.data if resp.ok and resp.data else cast(ApiProjectType, {}) + return self._data + + @property + def name(self) -> str: + return self._ensure_data().get("name", "") + + @property + def path(self) -> str: + return self._ensure_data().get("path", "") + + @property + def url(self) -> str: + return self._build_url(f"@{self.path}") + + @property + def description(self) -> str: + return self._ensure_data().get("description", "") + + @property + def visibility(self) -> str: + return self._ensure_data().get("visibility", "PUBLIC") + + @property + def created_at(self) -> str: + return self._ensure_data().get("createdAt", "") + + @property + def updated_at(self) -> str: + return self._ensure_data().get("updatedAt", "") + + @property + def labels(self) -> List[ApiProjectLabelType]: + return [label for label in self._ensure_data().get("projectLabels", [])] + + @property + def count(self) -> ApiProjectCountType: + return self._ensure_data().get("_count", {}) + + def runs(self, filters: Optional[Dict[str, object]] = None): + """获取项目下的实验列表。""" + from swanlab.api.experiment import Experiments + + return Experiments(self._client, self._web_host, self._api_host, path=self.path, filters=filters) + + def delete(self) -> bool: + """删除此项目。""" + resp = self._delete(f"/project/{self.path}") + return resp.ok + + def to_dict(self) -> Dict[str, Any]: + return get_properties(self) + + +class Projects(BaseEntity): + """ + 工作空间下项目集合的分页迭代器。 + + 用法:: + + for project in api.projects("username"): + print(project.name) + """ + + def __init__( + self, + client: "Client", + web_host: str, + api_host: str, + *, + path: str, + sort: Optional[str] = None, + search: Optional[str] = None, + detail: Optional[bool] = True, + ) -> None: + super().__init__(client, web_host, api_host) + self._path = path + self._sort = sort + self._search = search + self._detail = detail + + def __iter__(self) -> Iterator[Project]: + params = {"sort": self._sort, "search": self._search, "detail": self._detail} + for item in self._paginate(f"/project/{self._path}", params=params): + yield Project( + self._client, + self._web_host, + self._api_host, + path=str(item.get("path", "")), + data=cast(ApiProjectType, item), + ) + + def to_dict(self) -> Dict[str, Any]: + return {"path": self._path} diff --git a/swanlab/api/selfhosted/__init__.py b/swanlab/api/selfhosted/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/swanlab/api/typings/common.py b/swanlab/api/typings/common.py index 8011634be..a48b81eb4 100644 --- a/swanlab/api/typings/common.py +++ b/swanlab/api/typings/common.py @@ -8,29 +8,25 @@ from typing import Any, Dict, List, Literal, TypedDict # 列类型 -ApiColumnType = Literal["SCALAR", "CONFIG", "STABLE"] +ApiColumnEum = Literal["SCALAR", "CONFIG", "STABLE"] # 实验状态类型 -ApiRunStateType = Literal["RUNNING", "FINISHED", "CRASHED", "ABORTED", "OFFLINE"] +ApiRunStateEnum = Literal["RUNNING", "FINISHED", "CRASHED", "ABORTED", "OFFLINE"] # 可见性类型 -ApiVisibilityType = Literal["PUBLIC", "PRIVATE"] +ApiVisibilityEnum = Literal["PUBLIC", "PRIVATE"] # 工作空间类型 -ApiWorkspaceType = Literal["TEAM", "PERSON"] +ApiWorkspaceEnum = Literal["TEAM", "PERSON"] # 工作空间成员类型 -ApiRoleType = Literal["VISITOR", "VIEWER", "MEMBER", "OWNER"] +ApiRoleEnum = Literal["VISITOR", "VIEWER", "MEMBER", "OWNER"] # Self-Hosted 身份类型 -ApiIdentityType = Literal["root", "user"] +ApiIdentityEnum = Literal["root", "user"] # License 许可证类型 -ApiLicensePlanType = Literal["free", "commercial"] - - -class ApiLabelType(TypedDict): - name: str +ApiLicensePlanEnum = Literal["free", "commercial"] class ApiPaginationType(TypedDict): diff --git a/swanlab/api/typings/experiment.py b/swanlab/api/typings/experiment.py index 5a93a4bdb..bbecb3f36 100644 --- a/swanlab/api/typings/experiment.py +++ b/swanlab/api/typings/experiment.py @@ -7,18 +7,22 @@ from typing import Dict, List, Optional, TypedDict -from .common import ApiLabelType, ApiRunStateType +from .common import ApiRunStateEnum from .user import ApiUserType +class ApiExperimentLabelType(TypedDict): + name: str + + class ApiExperimentType(TypedDict): cuid: str name: str description: str - labels: List[ApiLabelType] + labels: List[ApiExperimentLabelType] profile: Dict[str, object] show: bool - state: ApiRunStateType + state: ApiRunStateEnum cluster: str job: str user: ApiUserType diff --git a/swanlab/api/typings/project.py b/swanlab/api/typings/project.py index 9b9bc5521..57e55d40c 100644 --- a/swanlab/api/typings/project.py +++ b/swanlab/api/typings/project.py @@ -7,7 +7,11 @@ from typing import Dict, List, TypedDict -from .common import ApiLabelType, ApiVisibilityType +from .common import ApiVisibilityEnum + + +class ApiProjectLabelType(TypedDict): + name: str class ApiProjectCountType(TypedDict): @@ -21,8 +25,8 @@ class ApiProjectType(TypedDict): name: str username: str path: str - visibility: ApiVisibilityType + visibility: ApiVisibilityEnum description: str group: Dict[str, str] - projectLabels: List[ApiLabelType] + projectLabels: List[ApiProjectLabelType] _count: ApiProjectCountType diff --git a/swanlab/api/typings/selfhosted.py b/swanlab/api/typings/selfhosted.py index 337297165..feb60defc 100644 --- a/swanlab/api/typings/selfhosted.py +++ b/swanlab/api/typings/selfhosted.py @@ -7,7 +7,7 @@ from typing import TypedDict -from .common import ApiLicensePlanType +from .common import ApiLicensePlanEnum class ApiApiKeyType(TypedDict): @@ -20,5 +20,5 @@ class ApiSelfHostedInfoType(TypedDict): enabled: bool expired: bool root: bool - plan: ApiLicensePlanType + plan: ApiLicensePlanEnum seats: int diff --git a/swanlab/api/typings/workspace.py b/swanlab/api/typings/workspace.py index f769f2326..5494186be 100644 --- a/swanlab/api/typings/workspace.py +++ b/swanlab/api/typings/workspace.py @@ -7,13 +7,13 @@ from typing import Dict, TypedDict -from .common import ApiRoleType, ApiWorkspaceType +from .common import ApiRoleEnum, ApiWorkspaceEnum class ApiWorkspaceInfoType(TypedDict): name: str username: str profile: Dict[str, str] - type: ApiWorkspaceType + type: ApiWorkspaceEnum comment: str - role: ApiRoleType + role: ApiRoleEnum diff --git a/swanlab/api/user/__init__.py b/swanlab/api/user/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/swanlab/api/workspace/__init__.py b/swanlab/api/workspace/__init__.py index e69de29bb..7589d19b8 100644 --- a/swanlab/api/workspace/__init__.py +++ b/swanlab/api/workspace/__init__.py @@ -0,0 +1,118 @@ +""" +@author: caddiesnew +@file: workspace.py +@time: 2026/4/20 +@description: Workspace 实体类 — 工作空间的查询 +""" + +from typing import TYPE_CHECKING, Any, Dict, Iterator, Optional, cast + +from swanlab.api.base import BaseEntity +from swanlab.api.typings.workspace import ApiWorkspaceEnum, ApiWorkspaceInfoType +from swanlab.api.utils import get_properties + +if TYPE_CHECKING: + from swanlab.sdk.internal.pkg.client import Client + + +class Workspace(BaseEntity): + """ + 表示一个 SwanLab 工作空间(个人或团队)。 + """ + + def __init__( + self, + client: "Client", + web_host: str, + api_host: str, + *, + username: str, + data: Optional[ApiWorkspaceInfoType] = None, + ) -> None: + super().__init__(client, web_host, api_host) + self._username = username + self._data = data + + def _ensure_data(self) -> ApiWorkspaceInfoType: + if self._data is None: + resp = self._get(f"/group/{self._username}") + self._data = resp.data if resp.ok and resp.data else cast(ApiWorkspaceInfoType, {}) + return self._data + + @property + def name(self) -> str: + return self._ensure_data().get("name", "") + + @property + def username(self) -> str: + return self._ensure_data().get("username", "") + + @property + def workspace_type(self) -> ApiWorkspaceEnum: + return self._ensure_data().get("type", "") + + @property + def profile(self) -> Dict[str, str]: + return self._ensure_data().get("profile", {}) + + @property + def comment(self) -> str: + return self._ensure_data().get("comment", "") + + @property + def role(self) -> str: + return self._ensure_data().get("role", "") + + def projects( + self, + sort: Optional[str] = None, + search: Optional[str] = None, + detail: Optional[bool] = True, + ): + """获取工作空间下的项目列表。""" + from swanlab.api.project import Projects + + return Projects( + self._client, + self._web_host, + self._api_host, + path=self.username, + sort=sort, + search=search, + detail=detail, + ) + + def to_dict(self) -> Dict[str, Any]: + return get_properties(self) + + +class Workspaces(BaseEntity): + """ + 用户工作空间集合的迭代器。 + + 用法:: + + for ws in api.workspaces("username"): + print(ws.name) + """ + + def __init__(self, client: "Client", web_host: str, api_host: str, *, username: str) -> None: + super().__init__(client, web_host, api_host) + self._username = username + + def _get_all_workspace_names(self) -> list[str]: + """获取用户个人空间 + 所属团队空间名称列表。""" + resp = self._get(f"/user/{self._username}/groups") + if not resp.ok: + return [self._username] + group_names = [r["username"] for r in resp.data] + return [self._username] + group_names + + def __iter__(self) -> Iterator[Workspace]: + for name in self._get_all_workspace_names(): + resp = self._get(f"/group/{name}") + data = resp.data if resp.ok else None + yield Workspace(self._client, self._web_host, self._api_host, username=name, data=data) + + def to_dict(self) -> Dict[str, Any]: + return {"username": self._username} From 48a18844ce47be044d344b9e6d3d20c9f6e9f375 Mon Sep 17 00:00:00 2001 From: Nexisato <978452096@qq.com> Date: Tue, 21 Apr 2026 20:48:40 +0800 Subject: [PATCH 06/52] feat: add api base entity --- swanlab/api/__init__.py | 202 +++++++++++++++++++++++++++++ swanlab/api/base.py | 9 +- swanlab/api/selfhosted/__init__.py | 77 +++++++++++ swanlab/api/typings/common.py | 4 +- swanlab/api/utils.py | 42 ++++++ 5 files changed, 329 insertions(+), 5 deletions(-) diff --git a/swanlab/api/__init__.py b/swanlab/api/__init__.py index e69de29bb..b68171f47 100644 --- a/swanlab/api/__init__.py +++ b/swanlab/api/__init__.py @@ -0,0 +1,202 @@ +""" +@author: caddiesnew +@file: __init__.py +@time: 2026/4/20 +@description: SwanLab 公共查询 API 入口,面向用户的 OOP 查询接口 +""" + +from typing import Optional + +from swanlab.exceptions import AuthenticationError +from swanlab.sdk.internal.pkg import nrc, scope +from swanlab.sdk.internal.pkg.client import Client +from swanlab.sdk.internal.settings import settings as global_settings + +from .base import BaseEntity +from .experiment import Experiment, Experiments +from .project import Project, Projects +from .typings.common import ApiResponseType +from .workspace import Workspace, Workspaces + + +class Api(BaseEntity): + """ + SwanLab 公共查询 API 入口。 + + 通过独立的 Client 实例与 SwanLab 云端交互,不与 SDK 运行时单例共享。 + 继承 BaseEntity 以复用 _get/_post/_put/_delete/_paginate 等安全 HTTP 方法。 + + 用法:: + + from swanlab import Api + + api = Api() # 自动从 .netrc 读取凭证 + api = Api(api_key="...", host="...") # 显式传入凭证 + + resp = api.project("username/project") + if resp.ok: + project = resp.data + print(project.name) + """ + + def __init__( + self, + api_key: Optional[str] = None, + host: Optional[str] = None, + web_host: Optional[str] = None, + ) -> None: + """ + 初始化 Api 实例。 + + 认证优先级:显式参数 > Settings(含 .netrc / 环境变量) + + :param api_key: API 密钥,为 None 时从 Settings / .netrc / 环境变量读取 + :param host: API 主机地址,为 None 时从 Settings 读取 + :param web_host: Web 面板地址,为 None 时从 Settings 读取 + """ + api_key, api_host, resolved_web_host = self._resolve_credentials(api_key, host, web_host) + client = Client(api_key=str(api_key), base_url=api_host) + super().__init__(client, resolved_web_host, api_host) + self._login_resp = scope.get_context("login_resp") + self._username: str = self._login_resp["userInfo"]["username"] if self._login_resp else "" + + def to_dict(self) -> dict: + """Api 非数据实体,返回空字典。""" + return {} + + @staticmethod + def _resolve_credentials( + api_key: Optional[str], + host: Optional[str], + web_host: Optional[str], + ) -> tuple[str, str, str]: + """ + 按优先级解析凭证:显式参数 > Settings(含 .netrc / 环境变量)。 + 返回 (api_key, api_host, web_host)。 + """ + if api_key is None: + api_key = global_settings.api_key + if api_key is None: + raise AuthenticationError("No API key found. Please login with `swanlab login` or pass api_key parameter.") + + api_host: str = nrc.fmt(host) if host is not None else global_settings.api_host + resolved_web_host: str = nrc.fmt(web_host) if web_host is not None else global_settings.web_host + + return api_key, api_host, resolved_web_host + + # ------------------------------------------------------------------ + # 实体查询方法 — 统一返回 ApiResponse + # ------------------------------------------------------------------ + + def workspace(self, username: Optional[str] = None) -> ApiResponseType: + """ + 获取工作空间信息,默认为当前登录用户的工作空间。 + + :param username: 指定工作空间用户名,为 None 时使用当前登录用户 + """ + if username is None: + username = self._username + resp = self._get(f"/group/{username}") + if resp.ok: + return ApiResponseType( + ok=True, + data=Workspace( + self._client, + self._web_host, + self._api_host, + username=username, + data=resp.data, + ), + ) + return resp + + def workspaces(self, username: Optional[str] = None) -> ApiResponseType: + """ + 获取工作空间列表迭代器。 + + :param username: 指定用户名,为 None 时使用当前登录用户 + """ + if username is None: + username = self._username + return ApiResponseType( + ok=True, + data=Workspaces(self._client, self._web_host, self._api_host, username=username), + ) + + def project(self, path: str) -> ApiResponseType: + """ + 获取项目信息。 + + :param path: 项目路径,格式为 'username/project-name' + """ + resp = self._get(f"/project/{path}") + if resp.ok: + return ApiResponseType( + ok=True, + data=Project(self._client, self._web_host, self._api_host, path=path, data=resp.data), + ) + return resp + + def projects( + self, + path: str, + sort: Optional[str] = None, + search: Optional[str] = None, + detail: Optional[bool] = True, + ) -> ApiResponseType: + """ + 获取工作空间下的项目列表迭代器。 + + :param path: 工作空间名称 'username' + :param sort: 排序方式 + :param search: 搜索关键词 + :param detail: 是否返回详细信息 + """ + return ApiResponseType( + ok=True, + data=Projects( + self._client, self._web_host, self._api_host, path=path, sort=sort, search=search, detail=detail + ), + ) + + def run(self, path: str) -> ApiResponseType: + """ + 获取单个实验。 + + :param path: 实验路径,格式为 'username/project/run_id' + """ + parts = path.split("/") + if len(parts) != 3: + return ApiResponseType( + ok=False, errmsg=f"Invalid path '{path}'. Expected format: 'username/project/run_id'" + ) + proj_path = path.rsplit("/", 1)[0] + expid = parts[2] + resp = self._get(f"/project/{proj_path}/runs/{expid}") + if resp.ok: + return ApiResponseType( + ok=True, + data=Experiment( + self._client, + self._web_host, + self._api_host, + path=proj_path, + data=resp.data, + ), + ) + return resp + + def runs(self, path: str, filters: Optional[dict] = None) -> ApiResponseType: + """ + 获取项目下的实验列表迭代器。 + + :param path: 项目路径,格式为 'username/project' + :param filters: 筛选条件 + """ + return ApiResponseType( + ok=True, + data=Experiments(self._client, self._web_host, self._api_host, path=path, filters=filters), + ) + + +__all__ = ["Api"] \ No newline at end of file diff --git a/swanlab/api/base.py b/swanlab/api/base.py index 37ed015bc..5f24f241c 100644 --- a/swanlab/api/base.py +++ b/swanlab/api/base.py @@ -10,7 +10,7 @@ from swanlab.sdk.internal.pkg import safe -from .typings.common import ApiPaginationType, ApiResponseType +from .typings.common import ApiResponseType if TYPE_CHECKING: from swanlab.sdk.internal.pkg.client import Client @@ -76,12 +76,13 @@ def _paginate(self, path: str, *, page_size: int = 20, params: Optional[dict] = resp = self._get(path, params=p) if not resp.ok: return - body: ApiPaginationType = resp.data - items = body["list"] + body = resp.data + items = body.get("list", []) if isinstance(body, dict) else body if not items: break yield from items - if page >= body["pages"]: + total_pages = body.get("pages", 1) if isinstance(body, dict) else 1 + if page >= total_pages: break page += 1 diff --git a/swanlab/api/selfhosted/__init__.py b/swanlab/api/selfhosted/__init__.py index e69de29bb..2cec5f87f 100644 --- a/swanlab/api/selfhosted/__init__.py +++ b/swanlab/api/selfhosted/__init__.py @@ -0,0 +1,77 @@ +""" +@author: caddiesnew +@file: project.py +@time: 2026/4/20 +@description: Project 实体类 — 单个项目的查询与操作 +""" + +from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional, cast + +from swanlab.api.base import BaseEntity +from swanlab.api.typings.common import ApiResponseType +from swanlab.api.typings.selfhosted import ApiApiKeyType, ApiLicensePlanEnum, ApiSelfHostedInfoType +from swanlab.api.utils import get_properties + +if TYPE_CHECKING: + from swanlab.sdk.internal.pkg.client import Client + + +class SelfHosted(BaseEntity): + """ + 表示一个 SwanLab 项目。 + + 支持双模式:构造时传入 data(列表迭代注入),或 data=None(按需懒加载)。 + """ + + def __init__( + self, + client: "Client", + web_host: str, + api_host: str, + ) -> None: + super().__init__(client, web_host, api_host) + + def _ensure_data(self) -> ApiSelfHostedInfoType: + if self._data is None: + resp = self._get("/self_hosted/info") + self._data = resp.data if resp.ok and resp.data else cast(ApiSelfHostedInfoType, {}) + return self._data + + @property + def enabled(self) -> bool: + return self._ensure_data().get("enabled", False) + + @property + def expired(self) -> bool: + return self._ensure_data().get("expired", False) + + @property + def root(self) -> bool: + return self._ensure_data().get("root", False) + + @property + def plan(self) -> ApiLicensePlanEnum: + return self._ensure_data().get("plan", "free") + + @property + def seats(self) -> int: + return self._ensure_data().get("seats", 0) + + def create_user(self, username: str, password: str) -> None: + """ + 添加用户(私有化管理员限定) + :param username: 待创建用户名 + :param password: 待创建用户密码 + """ + data = {"users": [{"username": username, "password": password}]} + self._post("/self_hosted/users", data=data) + + def get_users(self, page_num: int = 1, page_size: int = 20) -> ApiResponseType: + """ + 分页获取用户(管理员限定) + :param client: 已登录的客户端实例 + :param page: 页码 + :param size: 每页大小 + """ + params = {"page": page_num, "size": page_size} + return self._get("/self_hosted/users", params=params) diff --git a/swanlab/api/typings/common.py b/swanlab/api/typings/common.py index a48b81eb4..4d3d945c7 100644 --- a/swanlab/api/typings/common.py +++ b/swanlab/api/typings/common.py @@ -8,7 +8,7 @@ from typing import Any, Dict, List, Literal, TypedDict # 列类型 -ApiColumnEum = Literal["SCALAR", "CONFIG", "STABLE"] +ApiColumnEnum = Literal["SCALAR", "CONFIG", "STABLE"] # 实验状态类型 ApiRunStateEnum = Literal["RUNNING", "FINISHED", "CRASHED", "ABORTED", "OFFLINE"] @@ -29,6 +29,8 @@ ApiLicensePlanEnum = Literal["free", "commercial"] + + class ApiPaginationType(TypedDict): list: List size: int diff --git a/swanlab/api/utils.py b/swanlab/api/utils.py index ce621158d..f88bc48be 100644 --- a/swanlab/api/utils.py +++ b/swanlab/api/utils.py @@ -1,5 +1,8 @@ +from functools import wraps from typing import Dict, List, Optional, Set, Tuple +from swanlab.api.typings.common import ApiIdentityEnum + def get_properties(obj: object, _visited: Optional[Set[int]] = None) -> Dict[str, object]: """递归获取实例中所有 property 的值,用于 to_dict() 默认实现。""" @@ -36,6 +39,45 @@ def to_camel_case(name: str) -> str: return "".join([w.capitalize() if i > 0 else w for i, w in enumerate(name.split("_"))]) +#TODO: 私有化接口装饰器 +# def with_self_hosted(identity: ApiIdentityEnum = "user"): +# """ +# 用于需要在私有化环境下使用的接口的装饰器。 +# :param identity: 用户身份,默认为 "user",如果为 "root",则会额外验证是否为根用户。 +# """ + +# def decorator(func): +# @wraps(func) +# def wrapper(self, *args, **kwargs): +# client = getattr(self, "_client", None) +# if not isinstance(client, Client): +# raise AttributeError("There is no SwanLab client instance.") + +# # 1. 尝试获取私有化服务信息 +# try: +# self_hosted_info = get_self_hosted_init(client) +# except ApiError: +# raise ValueError("You haven't launched a swanlab self-hosted instance. This usages are not available.") + +# if not self_hosted_info.get("enabled", False): +# raise ValueError("SwanLab self-hosted instance hasn't been ready yet.") +# if self_hosted_info.get("expired", True): +# raise ValueError("SwanLab self-hosted instance has expired.") + +# # 2. 检测用户权限(商业版root用户功能) +# if identity == "root": +# if not self_hosted_info.get("root", False): +# raise ValueError("You don't have permission to perform this action. Please login as a root user") +# if not getattr(self, "is_self", True): +# raise ValueError("This root-only action can only be performed by the logged-in root user.") + +# return func(self, *args, **kwargs) + +# return wrapper + +# return decorator + + _SPECIAL_FILTER_MAP = { # (backend_key, operator) — 用户侧 key 到后端字段名和操作符的映射 # backend_key: 后端 API 实际接受的字段名 From 4a653c8fddbb7d132551be01b15d6eeee9d22df9 Mon Sep 17 00:00:00 2001 From: Nexisato <978452096@qq.com> Date: Wed, 22 Apr 2026 10:48:31 +0800 Subject: [PATCH 07/52] chore: remove unused typing --- swanlab/api/__init__.py | 2 +- swanlab/api/typings/common.py | 4 +--- swanlab/api/utils.py | 2 +- 3 files changed, 3 insertions(+), 5 deletions(-) diff --git a/swanlab/api/__init__.py b/swanlab/api/__init__.py index b68171f47..42a2e9e56 100644 --- a/swanlab/api/__init__.py +++ b/swanlab/api/__init__.py @@ -199,4 +199,4 @@ def runs(self, path: str, filters: Optional[dict] = None) -> ApiResponseType: ) -__all__ = ["Api"] \ No newline at end of file +__all__ = ["Api"] diff --git a/swanlab/api/typings/common.py b/swanlab/api/typings/common.py index 4d3d945c7..c1748693a 100644 --- a/swanlab/api/typings/common.py +++ b/swanlab/api/typings/common.py @@ -29,8 +29,6 @@ ApiLicensePlanEnum = Literal["free", "commercial"] - - class ApiPaginationType(TypedDict): list: List size: int @@ -79,4 +77,4 @@ def __repr__(self) -> str: return f"ApiResponse(ok=False, errmsg={self.errmsg!r})" -__all__ = ["ApiLabelType", "ApiPaginationType", "ApiResponseType"] +__all__ = ["ApiPaginationType", "ApiResponseType"] diff --git a/swanlab/api/utils.py b/swanlab/api/utils.py index f88bc48be..398592404 100644 --- a/swanlab/api/utils.py +++ b/swanlab/api/utils.py @@ -39,7 +39,7 @@ def to_camel_case(name: str) -> str: return "".join([w.capitalize() if i > 0 else w for i, w in enumerate(name.split("_"))]) -#TODO: 私有化接口装饰器 +# TODO: 私有化接口装饰器 # def with_self_hosted(identity: ApiIdentityEnum = "user"): # """ # 用于需要在私有化环境下使用的接口的装饰器。 From eb275b63ee74eb65a6ebe6344338237807f33519 Mon Sep 17 00:00:00 2001 From: Nexisato <978452096@qq.com> Date: Wed, 22 Apr 2026 12:32:05 +0800 Subject: [PATCH 08/52] refactor: simplify code --- swanlab/api/__init__.py | 1 + swanlab/api/base.py | 28 +++---- .../{experiment/__init__.py => experiment.py} | 44 +++++----- .../api/{project/__init__.py => project.py} | 6 +- .../{selfhosted/__init__.py => selfhosted.py} | 35 ++++---- swanlab/api/typings/__init__.py | 41 +++++++++ swanlab/api/typings/common.py | 18 ++-- swanlab/api/typings/experiment.py | 4 +- swanlab/api/typings/project.py | 4 +- swanlab/api/typings/selfhosted.py | 6 +- swanlab/api/typings/user.py | 2 +- swanlab/api/typings/workspace.py | 6 +- swanlab/api/user/__init__.py | 0 swanlab/api/utils.py | 83 ++----------------- .../{workspace/__init__.py => workspace.py} | 10 +-- 15 files changed, 129 insertions(+), 159 deletions(-) rename swanlab/api/{experiment/__init__.py => experiment.py} (88%) rename swanlab/api/{project/__init__.py => project.py} (96%) rename swanlab/api/{selfhosted/__init__.py => selfhosted.py} (61%) delete mode 100644 swanlab/api/user/__init__.py rename swanlab/api/{workspace/__init__.py => workspace.py} (92%) diff --git a/swanlab/api/__init__.py b/swanlab/api/__init__.py index 42a2e9e56..189c25b94 100644 --- a/swanlab/api/__init__.py +++ b/swanlab/api/__init__.py @@ -181,6 +181,7 @@ def run(self, path: str) -> ApiResponseType: self._web_host, self._api_host, path=proj_path, + cuid=expid, data=resp.data, ), ) diff --git a/swanlab/api/base.py b/swanlab/api/base.py index 5f24f241c..e533a4805 100644 --- a/swanlab/api/base.py +++ b/swanlab/api/base.py @@ -10,7 +10,7 @@ from swanlab.sdk.internal.pkg import safe -from .typings.common import ApiResponseType +from .typings.common import ApiPaginationType, ApiResponseType if TYPE_CHECKING: from swanlab.sdk.internal.pkg.client import Client @@ -21,7 +21,7 @@ class BaseEntity(ABC): swanlab/api 实体类公共基类。 统一持有 _client、_web_host 和 _api_host,提供 _get/_post/_put/_delete HTTP 快捷方法和 _paginate 分页迭代。 - 所有 HTTP 请求通过 wrapper_safe_request 包裹,保证任何异常都不会导致程序 crash,统一返回 ApiResponse。 + 所有 HTTP 请求通过 _safe_request 包裹,保证任何异常都不会导致程序 crash,统一返回 ApiResponse。 子类只需实现 to_dict() 和业务逻辑。 """ @@ -32,10 +32,10 @@ def __init__(self, client: "Client", web_host: str, api_host: str) -> None: self._errors: list[str] = [] @abstractmethod - def to_dict(self) -> Dict[str, Any]: + def json(self) -> Dict[str, Any]: """将实体序列化为 JSON 可序列化的字典。""" - def wrapper_safe_request(self, method: Callable, path: str, **kwargs) -> ApiResponseType: + def _safe_request(self, method: Callable, path: str, **kwargs) -> ApiResponseType: """安全请求包装:捕获所有异常,始终返回 ApiResponse 而不抛出。""" _err: list[str] = [] common_err: str = f"API request failed: {path}" @@ -52,19 +52,20 @@ def _do(): return ApiResponseType(ok=False, errmsg=errmsg) def _get(self, path: str, **kwargs) -> ApiResponseType: - return self.wrapper_safe_request(self._client.get, path, **kwargs) + return self._safe_request(self._client.get, path, **kwargs) def _post(self, path: str, **kwargs) -> ApiResponseType: - return self.wrapper_safe_request(self._client.post, path, **kwargs) + return self._safe_request(self._client.post, path, **kwargs) def _put(self, path: str, **kwargs) -> ApiResponseType: - return self.wrapper_safe_request(self._client.put, path, **kwargs) + return self._safe_request(self._client.put, path, **kwargs) def _delete(self, path: str, **kwargs) -> ApiResponseType: - return self.wrapper_safe_request(self._client.delete, path, **kwargs) + return self._safe_request(self._client.delete, path, **kwargs) - def _build_url(self, path: str) -> str: - return f"{self._api_host}/{path}" + def _build_web_url(self, path: str) -> str: + """构建前端 Web 页面 URL(使用 _web_host 而非 _api_host)。""" + return f"{self._web_host}/{path}" def _paginate(self, path: str, *, page_size: int = 20, params: Optional[dict] = None) -> Iterator[dict]: """通用分页迭代器,自动处理 page/size 参数。""" @@ -76,13 +77,12 @@ def _paginate(self, path: str, *, page_size: int = 20, params: Optional[dict] = resp = self._get(path, params=p) if not resp.ok: return - body = resp.data - items = body.get("list", []) if isinstance(body, dict) else body + body: ApiPaginationType = resp.data + items = body.get("list", []) if not items: break yield from items - total_pages = body.get("pages", 1) if isinstance(body, dict) else 1 - if page >= total_pages: + if page >= body.get("pages", 1): break page += 1 diff --git a/swanlab/api/experiment/__init__.py b/swanlab/api/experiment.py similarity index 88% rename from swanlab/api/experiment/__init__.py rename to swanlab/api/experiment.py index 0c089e725..bc67cc4df 100644 --- a/swanlab/api/experiment/__init__.py +++ b/swanlab/api/experiment.py @@ -10,7 +10,7 @@ from swanlab.api.base import BaseEntity from swanlab.api.typings.experiment import ApiExperimentLabelType, ApiExperimentType from swanlab.api.typings.user import ApiUserType -from swanlab.api.utils import get_properties, parse_column_type, to_camel_case +from swanlab.api.utils import get_properties, parse_filter if TYPE_CHECKING: from swanlab.sdk.internal.pkg.client import Client @@ -58,6 +58,7 @@ class Experiment(BaseEntity): 表示一个 SwanLab 实验。 支持双模式:构造时传入 data,或 data=None(按需懒加载)。 + 构造时从 data 中提取 _cuid 缓存,避免 _ensure_data 与 id 属性的循环调用。 """ def __init__( @@ -67,21 +68,27 @@ def __init__( api_host: str, *, path: str, + cuid: str = "", data: Optional[ApiExperimentType] = None, ) -> None: super().__init__(client, web_host, api_host) self._path = path # 'username/project-name' + self._cuid: str = cuid or (data.get("cuid", "") if data else "") self._data = data def _ensure_data(self) -> ApiExperimentType: if self._data is None: - resp = self._get(f"/project/{self._path}/runs/{self.id}") + resp = self._get(f"/project/{self._path}/runs/{self._cuid}") self._data = resp.data if resp.ok and resp.data else cast(ApiExperimentType, {}) + if not self._cuid and self._data: + self._cuid = self._data.get("cuid", "") return self._data @property - def id(self) -> str: - return self._data.get("cuid", "") if self._data is not None else self._ensure_data().get("cuid", "") + def run_id(self) -> str: + if self._cuid: + return self._cuid + return self._ensure_data().get("cuid", "") @property def name(self) -> str: @@ -97,7 +104,7 @@ def state(self) -> str: @property def url(self) -> str: - return self._build_url(f"@{self._path}/runs/{self.id}/chart") + return self._build_web_url(f"@{self._path}/runs/{self.run_id}/chart") @property def show(self) -> bool: @@ -132,8 +139,8 @@ def finished_at(self) -> str: def profile(self) -> Profile: """Experiment profile containing config, metadata, requirements, and conda.""" data = self._ensure_data() - if "profile" not in data and self.id: - resp = self._get(f"/project/{self._path}/runs/{self.id}") + if "profile" not in data and self._cuid: + resp = self._get(f"/project/{self._path}/runs/{self._cuid}") if resp.ok and resp.data: self._data = resp.data data = self._data @@ -162,7 +169,7 @@ def metrics( dfs = [] prefix = "" for idx, key in enumerate(fetch_keys): - resp = self._get(f"/experiment/{self.id}/column/csv", params={"key": key}) + resp = self._get(f"/experiment/{self.run_id}/column/csv", params={"key": key}) if not resp.ok: continue data = resp.data @@ -208,10 +215,10 @@ def strip_suffix(col, suffix="_step"): def delete(self) -> bool: """删除此实验。""" - resp = self._delete(f"/project/{self._path}/runs/{self.id}") + resp = self._delete(f"/project/{self._path}/runs/{self._cuid}") return resp.ok - def to_dict(self) -> Dict[str, Any]: + def json(self) -> Dict[str, Any]: return get_properties(self) @@ -248,20 +255,7 @@ def __init__( self._filters = filters def __iter__(self) -> Iterator[Experiment]: - parsed_filters = ( - [ - { - "key": to_camel_case(key) if parse_column_type(key) == "STABLE" else key.split(".", 1)[-1], - "active": True, - "value": [value], - "op": "EQ", - "type": parse_column_type(key), - } - for key, value in self._filters.items() - ] - if self._filters - else [] - ) + parsed_filters = [parse_filter(k, v) for k, v in self._filters.items()] if self._filters else [] resp = self._post(f"/project/{self._path}/runs/shows", data={"filters": parsed_filters}) if not resp.ok: return @@ -275,5 +269,5 @@ def __iter__(self) -> Iterator[Experiment]: for run_data in runs: yield Experiment(self._client, self._web_host, self._api_host, path=self._path, data=run_data) - def to_dict(self) -> Dict[str, Any]: + def json(self) -> Dict[str, Any]: return {"path": self._path} diff --git a/swanlab/api/project/__init__.py b/swanlab/api/project.py similarity index 96% rename from swanlab/api/project/__init__.py rename to swanlab/api/project.py index a8b28970d..785bf0bb6 100644 --- a/swanlab/api/project/__init__.py +++ b/swanlab/api/project.py @@ -51,7 +51,7 @@ def path(self) -> str: @property def url(self) -> str: - return self._build_url(f"@{self.path}") + return self._build_web_url(f"@{self.path}") @property def description(self) -> str: @@ -88,7 +88,7 @@ def delete(self) -> bool: resp = self._delete(f"/project/{self.path}") return resp.ok - def to_dict(self) -> Dict[str, Any]: + def json(self) -> Dict[str, Any]: return get_properties(self) @@ -130,5 +130,5 @@ def __iter__(self) -> Iterator[Project]: data=cast(ApiProjectType, item), ) - def to_dict(self) -> Dict[str, Any]: + def json(self) -> Dict[str, Any]: return {"path": self._path} diff --git a/swanlab/api/selfhosted/__init__.py b/swanlab/api/selfhosted.py similarity index 61% rename from swanlab/api/selfhosted/__init__.py rename to swanlab/api/selfhosted.py index 2cec5f87f..7d1e49445 100644 --- a/swanlab/api/selfhosted/__init__.py +++ b/swanlab/api/selfhosted.py @@ -1,15 +1,15 @@ """ @author: caddiesnew -@file: project.py +@file: selfhosted.py @time: 2026/4/20 -@description: Project 实体类 — 单个项目的查询与操作 +@description: SelfHosted 实体类 — 私有化部署实例的查询与管理 """ -from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional, cast +from typing import TYPE_CHECKING, Any, Dict, Optional, cast from swanlab.api.base import BaseEntity from swanlab.api.typings.common import ApiResponseType -from swanlab.api.typings.selfhosted import ApiApiKeyType, ApiLicensePlanEnum, ApiSelfHostedInfoType +from swanlab.api.typings.selfhosted import ApiLicensePlanLiteral, ApiSelfHostedInfoType from swanlab.api.utils import get_properties if TYPE_CHECKING: @@ -18,9 +18,9 @@ class SelfHosted(BaseEntity): """ - 表示一个 SwanLab 项目。 + 表示一个 SwanLab 私有化部署实例。 - 支持双模式:构造时传入 data(列表迭代注入),或 data=None(按需懒加载)。 + 支持双模式:构造时传入 data,或 data=None(按需懒加载)。 """ def __init__( @@ -28,8 +28,11 @@ def __init__( client: "Client", web_host: str, api_host: str, + *, + data: Optional[ApiSelfHostedInfoType] = None, ) -> None: super().__init__(client, web_host, api_host) + self._data = data def _ensure_data(self) -> ApiSelfHostedInfoType: if self._data is None: @@ -50,28 +53,32 @@ def root(self) -> bool: return self._ensure_data().get("root", False) @property - def plan(self) -> ApiLicensePlanEnum: + def plan(self) -> ApiLicensePlanLiteral: return self._ensure_data().get("plan", "free") @property def seats(self) -> int: return self._ensure_data().get("seats", 0) - def create_user(self, username: str, password: str) -> None: + def create_user(self, username: str, password: str) -> ApiResponseType: """ - 添加用户(私有化管理员限定) + 添加用户(私有化管理员限定)。 + :param username: 待创建用户名 :param password: 待创建用户密码 """ data = {"users": [{"username": username, "password": password}]} - self._post("/self_hosted/users", data=data) + return self._post("/self_hosted/users", data=data) - def get_users(self, page_num: int = 1, page_size: int = 20) -> ApiResponseType: + def get_users(self, page: int = 1, size: int = 20) -> ApiResponseType: """ - 分页获取用户(管理员限定) - :param client: 已登录的客户端实例 + 分页获取用户(管理员限定)。 + :param page: 页码 :param size: 每页大小 """ - params = {"page": page_num, "size": page_size} + params = {"page": page, "size": size} return self._get("/self_hosted/users", params=params) + + def json(self) -> Dict[str, Any]: + return get_properties(self) diff --git a/swanlab/api/typings/__init__.py b/swanlab/api/typings/__init__.py index 331a23e36..ab1a8ffd0 100644 --- a/swanlab/api/typings/__init__.py +++ b/swanlab/api/typings/__init__.py @@ -4,3 +4,44 @@ @time: 2026/4/21 18:40 @description: SwanLab OpenAPI 类型提示, 以 Api 前缀区分 """ + +from .common import ( + ApiColumnLiteral, + ApiIdentityLiteral, + ApiLicensePlanLiteral, + ApiPaginationType, + ApiResponseType, + ApiRoleLiteral, + ApiRunStateLiteral, + ApiVisibilityLiteral, + ApiWorkspaceLiteral, +) +from .experiment import ApiExperimentLabelType, ApiExperimentType +from .project import ApiProjectCountType, ApiProjectLabelType, ApiProjectType +from .selfhosted import ApiApiKeyType, ApiSelfHostedInfoType +from .user import ApiUserProfileType, ApiUserType +from .workspace import ApiWorkspaceInfoType + +__all__ = [ + # Kinds (preferred) + "ApiColumnLiteral", + "ApiRunStateLiteral", + "ApiVisibilityLiteral", + "ApiWorkspaceLiteral", + "ApiRoleLiteral", + "ApiIdentityLiteral", + "ApiLicensePlanLiteral", + # TypedDicts + "ApiPaginationType", + "ApiResponseType", + "ApiExperimentLabelType", + "ApiExperimentType", + "ApiProjectCountType", + "ApiProjectLabelType", + "ApiProjectType", + "ApiApiKeyType", + "ApiSelfHostedInfoType", + "ApiUserType", + "ApiUserProfileType", + "ApiWorkspaceInfoType", +] diff --git a/swanlab/api/typings/common.py b/swanlab/api/typings/common.py index c1748693a..d0fbb0f31 100644 --- a/swanlab/api/typings/common.py +++ b/swanlab/api/typings/common.py @@ -8,25 +8,25 @@ from typing import Any, Dict, List, Literal, TypedDict # 列类型 -ApiColumnEnum = Literal["SCALAR", "CONFIG", "STABLE"] +ApiColumnLiteral = Literal["SCALAR", "CONFIG", "STABLE"] # 实验状态类型 -ApiRunStateEnum = Literal["RUNNING", "FINISHED", "CRASHED", "ABORTED", "OFFLINE"] +ApiRunStateLiteral = Literal["RUNNING", "FINISHED", "CRASHED", "ABORTED", "OFFLINE"] # 可见性类型 -ApiVisibilityEnum = Literal["PUBLIC", "PRIVATE"] +ApiVisibilityLiteral = Literal["PUBLIC", "PRIVATE"] # 工作空间类型 -ApiWorkspaceEnum = Literal["TEAM", "PERSON"] +ApiWorkspaceLiteral = Literal["TEAM", "PERSON"] # 工作空间成员类型 -ApiRoleEnum = Literal["VISITOR", "VIEWER", "MEMBER", "OWNER"] +ApiRoleLiteral = Literal["VISITOR", "VIEWER", "MEMBER", "OWNER"] # Self-Hosted 身份类型 -ApiIdentityEnum = Literal["root", "user"] +ApiIdentityLiteral = Literal["root", "user"] # License 许可证类型 -ApiLicensePlanEnum = Literal["free", "commercial"] +ApiLicensePlanLiteral = Literal["free", "commercial"] class ApiPaginationType(TypedDict): @@ -62,10 +62,6 @@ def to_json_dict(self) -> Dict[str, Any]: errors.append(self.errmsg) if data is not None and hasattr(data, "to_dict"): data = data.to_dict() - # 收集实体内部子请求的错误 - if hasattr(data, "__getitem__"): - # to_dict 返回的 dict 不带 _errors,需要从实体取 - pass entity_errors = getattr(self.data, "_errors", []) errors.extend(entity_errors) ok = self.ok and not errors diff --git a/swanlab/api/typings/experiment.py b/swanlab/api/typings/experiment.py index bbecb3f36..8bac27c1c 100644 --- a/swanlab/api/typings/experiment.py +++ b/swanlab/api/typings/experiment.py @@ -7,7 +7,7 @@ from typing import Dict, List, Optional, TypedDict -from .common import ApiRunStateEnum +from .common import ApiRunStateLiteral from .user import ApiUserType @@ -22,7 +22,7 @@ class ApiExperimentType(TypedDict): labels: List[ApiExperimentLabelType] profile: Dict[str, object] show: bool - state: ApiRunStateEnum + state: ApiRunStateLiteral cluster: str job: str user: ApiUserType diff --git a/swanlab/api/typings/project.py b/swanlab/api/typings/project.py index 57e55d40c..5b759d2a5 100644 --- a/swanlab/api/typings/project.py +++ b/swanlab/api/typings/project.py @@ -7,7 +7,7 @@ from typing import Dict, List, TypedDict -from .common import ApiVisibilityEnum +from .common import ApiVisibilityLiteral class ApiProjectLabelType(TypedDict): @@ -25,7 +25,7 @@ class ApiProjectType(TypedDict): name: str username: str path: str - visibility: ApiVisibilityEnum + visibility: ApiVisibilityLiteral description: str group: Dict[str, str] projectLabels: List[ApiProjectLabelType] diff --git a/swanlab/api/typings/selfhosted.py b/swanlab/api/typings/selfhosted.py index feb60defc..05956d363 100644 --- a/swanlab/api/typings/selfhosted.py +++ b/swanlab/api/typings/selfhosted.py @@ -2,12 +2,12 @@ @author: caddiesnew @file: user.py @time: 2026/4/20 -@description: 公共查询 API self-hosted 类型定义 +@description: 公共查询 API 私有化实例类型定义 """ from typing import TypedDict -from .common import ApiLicensePlanEnum +from .common import ApiLicensePlanLiteral class ApiApiKeyType(TypedDict): @@ -20,5 +20,5 @@ class ApiSelfHostedInfoType(TypedDict): enabled: bool expired: bool root: bool - plan: ApiLicensePlanEnum + plan: ApiLicensePlanLiteral seats: int diff --git a/swanlab/api/typings/user.py b/swanlab/api/typings/user.py index 13eaac06c..b4ff66bbc 100644 --- a/swanlab/api/typings/user.py +++ b/swanlab/api/typings/user.py @@ -16,7 +16,7 @@ class ApiUserType(TypedDict): class ApiUserProfileType(TypedDict): bio: str institution: str - localtion: str + location: str school: str email: str idc: str diff --git a/swanlab/api/typings/workspace.py b/swanlab/api/typings/workspace.py index 5494186be..fff0ffc4c 100644 --- a/swanlab/api/typings/workspace.py +++ b/swanlab/api/typings/workspace.py @@ -7,13 +7,13 @@ from typing import Dict, TypedDict -from .common import ApiRoleEnum, ApiWorkspaceEnum +from .common import ApiRoleLiteral, ApiWorkspaceLiteral class ApiWorkspaceInfoType(TypedDict): name: str username: str profile: Dict[str, str] - type: ApiWorkspaceEnum + type: ApiWorkspaceLiteral comment: str - role: ApiRoleEnum + role: ApiRoleLiteral diff --git a/swanlab/api/user/__init__.py b/swanlab/api/user/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/swanlab/api/utils.py b/swanlab/api/utils.py index 398592404..710d1f3f6 100644 --- a/swanlab/api/utils.py +++ b/swanlab/api/utils.py @@ -1,7 +1,11 @@ -from functools import wraps -from typing import Dict, List, Optional, Set, Tuple +""" +@author: caddiesnew +@file: utils.py +@time: 2026/4/20 +@description: swanlab/api 实体层工具函数 +""" -from swanlab.api.typings.common import ApiIdentityEnum +from typing import Dict, List, Optional, Set def get_properties(obj: object, _visited: Optional[Set[int]] = None) -> Dict[str, object]: @@ -39,45 +43,6 @@ def to_camel_case(name: str) -> str: return "".join([w.capitalize() if i > 0 else w for i, w in enumerate(name.split("_"))]) -# TODO: 私有化接口装饰器 -# def with_self_hosted(identity: ApiIdentityEnum = "user"): -# """ -# 用于需要在私有化环境下使用的接口的装饰器。 -# :param identity: 用户身份,默认为 "user",如果为 "root",则会额外验证是否为根用户。 -# """ - -# def decorator(func): -# @wraps(func) -# def wrapper(self, *args, **kwargs): -# client = getattr(self, "_client", None) -# if not isinstance(client, Client): -# raise AttributeError("There is no SwanLab client instance.") - -# # 1. 尝试获取私有化服务信息 -# try: -# self_hosted_info = get_self_hosted_init(client) -# except ApiError: -# raise ValueError("You haven't launched a swanlab self-hosted instance. This usages are not available.") - -# if not self_hosted_info.get("enabled", False): -# raise ValueError("SwanLab self-hosted instance hasn't been ready yet.") -# if self_hosted_info.get("expired", True): -# raise ValueError("SwanLab self-hosted instance has expired.") - -# # 2. 检测用户权限(商业版root用户功能) -# if identity == "root": -# if not self_hosted_info.get("root", False): -# raise ValueError("You don't have permission to perform this action. Please login as a root user") -# if not getattr(self, "is_self", True): -# raise ValueError("This root-only action can only be performed by the logged-in root user.") - -# return func(self, *args, **kwargs) - -# return wrapper - -# return decorator - - _SPECIAL_FILTER_MAP = { # (backend_key, operator) — 用户侧 key 到后端字段名和操作符的映射 # backend_key: 后端 API 实际接受的字段名 @@ -110,37 +75,3 @@ def parse_filter(key: str, value: object) -> Dict[str, object]: "op": "EQ", "type": ct, } - - -def unwrap_api_payload(data): - """提取 raw resp 的 data 响应.""" - if isinstance(data, dict) and "data" in data and isinstance(data["data"], (dict, list)): - return data["data"] - return data - - -# mulitpart-save -def extract_upload_id(payload: Dict[str, object]) -> Optional[str]: - upload_id = payload.get("uploadId") - if isinstance(upload_id, str) and upload_id: - return upload_id - return None - - -# multipart-save -def extract_part_urls(payload: Dict[str, object]) -> List[Tuple[int, str]]: - parts = payload.get("parts") - if not isinstance(parts, list): - raise ValueError("Multipart upload URLs are missing in prepare response.") - - resolved = [] - for part in parts: - if not isinstance(part, dict): - raise ValueError("Multipart prepare response contains invalid part data.") - part_number = part.get("partNumber") - url = part.get("url") - if not isinstance(part_number, int) or not isinstance(url, str) or not url: - raise ValueError("Invalid partNumber or url in multipart response.") - resolved.append((part_number, url)) - - return sorted(resolved, key=lambda item: item[0]) diff --git a/swanlab/api/workspace/__init__.py b/swanlab/api/workspace.py similarity index 92% rename from swanlab/api/workspace/__init__.py rename to swanlab/api/workspace.py index 7589d19b8..321d3cc23 100644 --- a/swanlab/api/workspace/__init__.py +++ b/swanlab/api/workspace.py @@ -8,7 +8,7 @@ from typing import TYPE_CHECKING, Any, Dict, Iterator, Optional, cast from swanlab.api.base import BaseEntity -from swanlab.api.typings.workspace import ApiWorkspaceEnum, ApiWorkspaceInfoType +from swanlab.api.typings.workspace import ApiWorkspaceInfoType, ApiWorkspaceLiteral from swanlab.api.utils import get_properties if TYPE_CHECKING: @@ -48,8 +48,8 @@ def username(self) -> str: return self._ensure_data().get("username", "") @property - def workspace_type(self) -> ApiWorkspaceEnum: - return self._ensure_data().get("type", "") + def workspace_type(self) -> ApiWorkspaceLiteral: + return self._ensure_data().get("type", "PERSON") @property def profile(self) -> Dict[str, str]: @@ -82,7 +82,7 @@ def projects( detail=detail, ) - def to_dict(self) -> Dict[str, Any]: + def json(self) -> Dict[str, Any]: return get_properties(self) @@ -114,5 +114,5 @@ def __iter__(self) -> Iterator[Workspace]: data = resp.data if resp.ok else None yield Workspace(self._client, self._web_host, self._api_host, username=name, data=data) - def to_dict(self) -> Dict[str, Any]: + def json(self) -> Dict[str, Any]: return {"username": self._username} From 953ff24c2b8fae27bf527364cfd9dad734140cb5 Mon Sep 17 00:00:00 2001 From: Nexisato <978452096@qq.com> Date: Wed, 22 Apr 2026 13:02:06 +0800 Subject: [PATCH 09/52] refactor: set api client context var --- swanlab/api/__init__.py | 54 +++++++++++++++++++-------------------- swanlab/api/base.py | 32 +++++++++++++++-------- swanlab/api/experiment.py | 21 +++++---------- swanlab/api/project.py | 25 ++++++------------ swanlab/api/selfhosted.py | 13 +++------- swanlab/api/workspace.py | 23 ++++++----------- 6 files changed, 75 insertions(+), 93 deletions(-) diff --git a/swanlab/api/__init__.py b/swanlab/api/__init__.py index 189c25b94..481ba090f 100644 --- a/swanlab/api/__init__.py +++ b/swanlab/api/__init__.py @@ -12,7 +12,7 @@ from swanlab.sdk.internal.pkg.client import Client from swanlab.sdk.internal.settings import settings as global_settings -from .base import BaseEntity +from .base import ApiClientContext, BaseEntity from .experiment import Experiment, Experiments from .project import Project, Projects from .typings.common import ApiResponseType @@ -23,7 +23,7 @@ class Api(BaseEntity): """ SwanLab 公共查询 API 入口。 - 通过独立的 Client 实例与 SwanLab 云端交互,不与 SDK 运行时单例共享。 + 通过独立的 Client 实例与 SwanLab 云端交互,与 SDK 运行时客户端完全隔离。 继承 BaseEntity 以复用 _get/_post/_put/_delete/_paginate 等安全 HTTP 方法。 用法:: @@ -48,19 +48,29 @@ def __init__( """ 初始化 Api 实例。 - 认证优先级:显式参数 > Settings(含 .netrc / 环境变量) + 认证优先级: + 1. 显式参数 (api_key / host / web_host) + 2. scope 登录态(进程内已调用 swanlab.login 时可用) + 3. Settings(含 .netrc / 环境变量) + + 始终创建独立的 Client 实例,与 SDK 运行时单例互不干扰。 :param api_key: API 密钥,为 None 时从 Settings / .netrc / 环境变量读取 :param host: API 主机地址,为 None 时从 Settings 读取 :param web_host: Web 面板地址,为 None 时从 Settings 读取 """ - api_key, api_host, resolved_web_host = self._resolve_credentials(api_key, host, web_host) - client = Client(api_key=str(api_key), base_url=api_host) - super().__init__(client, resolved_web_host, api_host) - self._login_resp = scope.get_context("login_resp") - self._username: str = self._login_resp["userInfo"]["username"] if self._login_resp else "" + # 优先从 scope 获取已有登录态(如进程内已调用 swanlab.login),直接复用凭证 + login_resp = scope.get_context("login_resp") + api_key_resolved, api_host, resolved_web_host = self._resolve_credentials(api_key, host, web_host) + _client = Client(api_key=str(api_key_resolved), base_url=api_host) + if login_resp is None: + login_resp = scope.get_context("login_resp") + + ctx = ApiClientContext(client=_client, web_host=resolved_web_host, api_host=api_host) + super().__init__(ctx) + self._username: str = login_resp["userInfo"]["username"] if login_resp else "" - def to_dict(self) -> dict: + def json(self) -> dict: """Api 非数据实体,返回空字典。""" return {} @@ -71,7 +81,7 @@ def _resolve_credentials( web_host: Optional[str], ) -> tuple[str, str, str]: """ - 按优先级解析凭证:显式参数 > Settings(含 .netrc / 环境变量)。 + 按优先级解析凭证:显式参数 > scope 登录态 > Settings(含 .netrc / 环境变量)。 返回 (api_key, api_host, web_host)。 """ if api_key is None: @@ -85,7 +95,7 @@ def _resolve_credentials( return api_key, api_host, resolved_web_host # ------------------------------------------------------------------ - # 实体查询方法 — 统一返回 ApiResponse + # 实体查询方法 — 统一返回 ApiResponseType # ------------------------------------------------------------------ def workspace(self, username: Optional[str] = None) -> ApiResponseType: @@ -100,13 +110,7 @@ def workspace(self, username: Optional[str] = None) -> ApiResponseType: if resp.ok: return ApiResponseType( ok=True, - data=Workspace( - self._client, - self._web_host, - self._api_host, - username=username, - data=resp.data, - ), + data=Workspace(self._ctx, username=username, data=resp.data), ) return resp @@ -120,7 +124,7 @@ def workspaces(self, username: Optional[str] = None) -> ApiResponseType: username = self._username return ApiResponseType( ok=True, - data=Workspaces(self._client, self._web_host, self._api_host, username=username), + data=Workspaces(self._ctx, username=username), ) def project(self, path: str) -> ApiResponseType: @@ -133,7 +137,7 @@ def project(self, path: str) -> ApiResponseType: if resp.ok: return ApiResponseType( ok=True, - data=Project(self._client, self._web_host, self._api_host, path=path, data=resp.data), + data=Project(self._ctx, path=path, data=resp.data), ) return resp @@ -154,9 +158,7 @@ def projects( """ return ApiResponseType( ok=True, - data=Projects( - self._client, self._web_host, self._api_host, path=path, sort=sort, search=search, detail=detail - ), + data=Projects(self._ctx, path=path, sort=sort, search=search, detail=detail), ) def run(self, path: str) -> ApiResponseType: @@ -177,9 +179,7 @@ def run(self, path: str) -> ApiResponseType: return ApiResponseType( ok=True, data=Experiment( - self._client, - self._web_host, - self._api_host, + self._ctx, path=proj_path, cuid=expid, data=resp.data, @@ -196,7 +196,7 @@ def runs(self, path: str, filters: Optional[dict] = None) -> ApiResponseType: """ return ApiResponseType( ok=True, - data=Experiments(self._client, self._web_host, self._api_host, path=path, filters=filters), + data=Experiments(self._ctx, path=path, filters=filters), ) diff --git a/swanlab/api/base.py b/swanlab/api/base.py index e533a4805..aad275a01 100644 --- a/swanlab/api/base.py +++ b/swanlab/api/base.py @@ -5,7 +5,10 @@ @description: 所有实体类的公共基类 """ +from __future__ import annotations + from abc import ABC, abstractmethod +from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Callable, Dict, Iterator, Optional from swanlab.sdk.internal.pkg import safe @@ -16,19 +19,26 @@ from swanlab.sdk.internal.pkg.client import Client +@dataclass(frozen=True) +class ApiClientContext: + """共享上下文:所有子实体复用同一个实例,避免 (client, web_host, api_host) 三元组透传。""" + + client: "Client" + web_host: str + api_host: str + + class BaseEntity(ABC): """ swanlab/api 实体类公共基类。 - 统一持有 _client、_web_host 和 _api_host,提供 _get/_post/_put/_delete HTTP 快捷方法和 _paginate 分页迭代。 + 统一持有 _ctx(_ApiContext),提供 _get/_post/_put/_delete HTTP 快捷方法和 _paginate 分页迭代。 所有 HTTP 请求通过 _safe_request 包裹,保证任何异常都不会导致程序 crash,统一返回 ApiResponse。 子类只需实现 to_dict() 和业务逻辑。 """ - def __init__(self, client: "Client", web_host: str, api_host: str) -> None: - self._client: "Client" = client - self._web_host: str = web_host - self._api_host: str = api_host + def __init__(self, ctx: ApiClientContext) -> None: + self._ctx: ApiClientContext = ctx self._errors: list[str] = [] @abstractmethod @@ -36,7 +46,7 @@ def json(self) -> Dict[str, Any]: """将实体序列化为 JSON 可序列化的字典。""" def _safe_request(self, method: Callable, path: str, **kwargs) -> ApiResponseType: - """安全请求包装:捕获所有异常,始终返回 ApiResponse 而不抛出。""" + """安全请求包装:捕获所有异常,始终返回 ApiResponseType 而不抛出。""" _err: list[str] = [] common_err: str = f"API request failed: {path}" @@ -52,20 +62,20 @@ def _do(): return ApiResponseType(ok=False, errmsg=errmsg) def _get(self, path: str, **kwargs) -> ApiResponseType: - return self._safe_request(self._client.get, path, **kwargs) + return self._safe_request(self._ctx.client.get, path, **kwargs) def _post(self, path: str, **kwargs) -> ApiResponseType: - return self._safe_request(self._client.post, path, **kwargs) + return self._safe_request(self._ctx.client.post, path, **kwargs) def _put(self, path: str, **kwargs) -> ApiResponseType: - return self._safe_request(self._client.put, path, **kwargs) + return self._safe_request(self._ctx.client.put, path, **kwargs) def _delete(self, path: str, **kwargs) -> ApiResponseType: - return self._safe_request(self._client.delete, path, **kwargs) + return self._safe_request(self._ctx.client.delete, path, **kwargs) def _build_web_url(self, path: str) -> str: """构建前端 Web 页面 URL(使用 _web_host 而非 _api_host)。""" - return f"{self._web_host}/{path}" + return f"{self._ctx.web_host}/{path}" def _paginate(self, path: str, *, page_size: int = 20, params: Optional[dict] = None) -> Iterator[dict]: """通用分页迭代器,自动处理 page/size 参数。""" diff --git a/swanlab/api/experiment.py b/swanlab/api/experiment.py index bc67cc4df..afca2cdd9 100644 --- a/swanlab/api/experiment.py +++ b/swanlab/api/experiment.py @@ -5,16 +5,13 @@ @description: Experiment 实体类 — 单个实验的查询与操作 """ -from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional, Union, cast +from typing import Any, Dict, Iterator, List, Optional, Union, cast -from swanlab.api.base import BaseEntity +from swanlab.api.base import ApiClientContext, BaseEntity from swanlab.api.typings.experiment import ApiExperimentLabelType, ApiExperimentType from swanlab.api.typings.user import ApiUserType from swanlab.api.utils import get_properties, parse_filter -if TYPE_CHECKING: - from swanlab.sdk.internal.pkg.client import Client - class Profile: """Experiment profile containing config, metadata, requirements, and conda info.""" @@ -63,15 +60,13 @@ class Experiment(BaseEntity): def __init__( self, - client: "Client", - web_host: str, - api_host: str, + ctx: ApiClientContext, *, path: str, cuid: str = "", data: Optional[ApiExperimentType] = None, ) -> None: - super().__init__(client, web_host, api_host) + super().__init__(ctx) self._path = path # 'username/project-name' self._cuid: str = cuid or (data.get("cuid", "") if data else "") self._data = data @@ -243,14 +238,12 @@ class Experiments(BaseEntity): def __init__( self, - client: "Client", - web_host: str, - api_host: str, + ctx: ApiClientContext, *, path: str, filters: Optional[Dict[str, object]] = None, ) -> None: - super().__init__(client, web_host, api_host) + super().__init__(ctx) self._path = path self._filters = filters @@ -267,7 +260,7 @@ def __iter__(self) -> Iterator[Experiment]: runs = _flatten_runs(body) for run_data in runs: - yield Experiment(self._client, self._web_host, self._api_host, path=self._path, data=run_data) + yield Experiment(self._ctx, path=self._path, data=run_data) def json(self) -> Dict[str, Any]: return {"path": self._path} diff --git a/swanlab/api/project.py b/swanlab/api/project.py index 785bf0bb6..5308f7fa1 100644 --- a/swanlab/api/project.py +++ b/swanlab/api/project.py @@ -5,15 +5,12 @@ @description: Project 实体类 — 单个项目的查询与操作 """ -from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional, cast +from typing import Any, Dict, Iterator, List, Optional, cast -from swanlab.api.base import BaseEntity +from swanlab.api.base import ApiClientContext, BaseEntity from swanlab.api.typings.project import ApiProjectCountType, ApiProjectLabelType, ApiProjectType from swanlab.api.utils import get_properties -if TYPE_CHECKING: - from swanlab.sdk.internal.pkg.client import Client - class Project(BaseEntity): """ @@ -24,14 +21,12 @@ class Project(BaseEntity): def __init__( self, - client: "Client", - web_host: str, - api_host: str, + ctx: ApiClientContext, *, path: str, data: Optional[ApiProjectType] = None, ) -> None: - super().__init__(client, web_host, api_host) + super().__init__(ctx) self._path = path self._data = data @@ -81,7 +76,7 @@ def runs(self, filters: Optional[Dict[str, object]] = None): """获取项目下的实验列表。""" from swanlab.api.experiment import Experiments - return Experiments(self._client, self._web_host, self._api_host, path=self.path, filters=filters) + return Experiments(self._ctx, path=self.path, filters=filters) def delete(self) -> bool: """删除此项目。""" @@ -104,16 +99,14 @@ class Projects(BaseEntity): def __init__( self, - client: "Client", - web_host: str, - api_host: str, + ctx: ApiClientContext, *, path: str, sort: Optional[str] = None, search: Optional[str] = None, detail: Optional[bool] = True, ) -> None: - super().__init__(client, web_host, api_host) + super().__init__(ctx) self._path = path self._sort = sort self._search = search @@ -123,9 +116,7 @@ def __iter__(self) -> Iterator[Project]: params = {"sort": self._sort, "search": self._search, "detail": self._detail} for item in self._paginate(f"/project/{self._path}", params=params): yield Project( - self._client, - self._web_host, - self._api_host, + self._ctx, path=str(item.get("path", "")), data=cast(ApiProjectType, item), ) diff --git a/swanlab/api/selfhosted.py b/swanlab/api/selfhosted.py index 7d1e49445..075ce2e44 100644 --- a/swanlab/api/selfhosted.py +++ b/swanlab/api/selfhosted.py @@ -5,16 +5,13 @@ @description: SelfHosted 实体类 — 私有化部署实例的查询与管理 """ -from typing import TYPE_CHECKING, Any, Dict, Optional, cast +from typing import Any, Dict, Optional, cast -from swanlab.api.base import BaseEntity +from swanlab.api.base import ApiClientContext, BaseEntity from swanlab.api.typings.common import ApiResponseType from swanlab.api.typings.selfhosted import ApiLicensePlanLiteral, ApiSelfHostedInfoType from swanlab.api.utils import get_properties -if TYPE_CHECKING: - from swanlab.sdk.internal.pkg.client import Client - class SelfHosted(BaseEntity): """ @@ -25,13 +22,11 @@ class SelfHosted(BaseEntity): def __init__( self, - client: "Client", - web_host: str, - api_host: str, + ctx: ApiClientContext, *, data: Optional[ApiSelfHostedInfoType] = None, ) -> None: - super().__init__(client, web_host, api_host) + super().__init__(ctx) self._data = data def _ensure_data(self) -> ApiSelfHostedInfoType: diff --git a/swanlab/api/workspace.py b/swanlab/api/workspace.py index 321d3cc23..d00666221 100644 --- a/swanlab/api/workspace.py +++ b/swanlab/api/workspace.py @@ -5,15 +5,12 @@ @description: Workspace 实体类 — 工作空间的查询 """ -from typing import TYPE_CHECKING, Any, Dict, Iterator, Optional, cast +from typing import Any, Dict, Iterator, Optional, cast -from swanlab.api.base import BaseEntity +from swanlab.api.base import ApiClientContext, BaseEntity from swanlab.api.typings.workspace import ApiWorkspaceInfoType, ApiWorkspaceLiteral from swanlab.api.utils import get_properties -if TYPE_CHECKING: - from swanlab.sdk.internal.pkg.client import Client - class Workspace(BaseEntity): """ @@ -22,14 +19,12 @@ class Workspace(BaseEntity): def __init__( self, - client: "Client", - web_host: str, - api_host: str, + ctx: ApiClientContext, *, username: str, data: Optional[ApiWorkspaceInfoType] = None, ) -> None: - super().__init__(client, web_host, api_host) + super().__init__(ctx) self._username = username self._data = data @@ -73,9 +68,7 @@ def projects( from swanlab.api.project import Projects return Projects( - self._client, - self._web_host, - self._api_host, + self._ctx, path=self.username, sort=sort, search=search, @@ -96,8 +89,8 @@ class Workspaces(BaseEntity): print(ws.name) """ - def __init__(self, client: "Client", web_host: str, api_host: str, *, username: str) -> None: - super().__init__(client, web_host, api_host) + def __init__(self, ctx: ApiClientContext, *, username: str) -> None: + super().__init__(ctx) self._username = username def _get_all_workspace_names(self) -> list[str]: @@ -112,7 +105,7 @@ def __iter__(self) -> Iterator[Workspace]: for name in self._get_all_workspace_names(): resp = self._get(f"/group/{name}") data = resp.data if resp.ok else None - yield Workspace(self._client, self._web_host, self._api_host, username=name, data=data) + yield Workspace(self._ctx, username=name, data=data) def json(self) -> Dict[str, Any]: return {"username": self._username} From 2f87458fe0d3207d0c6ff8363a475875a3a1a8f1 Mon Sep 17 00:00:00 2001 From: Nexisato <978452096@qq.com> Date: Wed, 22 Apr 2026 14:38:04 +0800 Subject: [PATCH 10/52] feat: add api --- swanlab/api/__init__.py | 48 ++++++----------------------- swanlab/api/base.py | 13 +++++++- swanlab/api/typings/common.py | 9 ++---- swanlab/api/utils.py | 2 +- swanlab/cli/api/helper.py | 4 +-- swanlab/cli/api/project/__init__.py | 10 ++++++ 6 files changed, 37 insertions(+), 49 deletions(-) diff --git a/swanlab/api/__init__.py b/swanlab/api/__init__.py index 481ba090f..bc1970d93 100644 --- a/swanlab/api/__init__.py +++ b/swanlab/api/__init__.py @@ -95,7 +95,9 @@ def _resolve_credentials( return api_key, api_host, resolved_web_host # ------------------------------------------------------------------ - # 实体查询方法 — 统一返回 ApiResponseType + # 实体工厂方法 + # - 单实体(workspace/project/run):构造后调用 _fetch() 立即加载并返回 ok/not-ok + # - 列表迭代器(workspaces/projects/runs):惰性构造,迭代时按需分页请求 # ------------------------------------------------------------------ def workspace(self, username: Optional[str] = None) -> ApiResponseType: @@ -106,13 +108,7 @@ def workspace(self, username: Optional[str] = None) -> ApiResponseType: """ if username is None: username = self._username - resp = self._get(f"/group/{username}") - if resp.ok: - return ApiResponseType( - ok=True, - data=Workspace(self._ctx, username=username, data=resp.data), - ) - return resp + return Workspace(self._ctx, username=username)._fetch() def workspaces(self, username: Optional[str] = None) -> ApiResponseType: """ @@ -122,10 +118,7 @@ def workspaces(self, username: Optional[str] = None) -> ApiResponseType: """ if username is None: username = self._username - return ApiResponseType( - ok=True, - data=Workspaces(self._ctx, username=username), - ) + return ApiResponseType(ok=True, data=Workspaces(self._ctx, username=username)) def project(self, path: str) -> ApiResponseType: """ @@ -133,13 +126,7 @@ def project(self, path: str) -> ApiResponseType: :param path: 项目路径,格式为 'username/project-name' """ - resp = self._get(f"/project/{path}") - if resp.ok: - return ApiResponseType( - ok=True, - data=Project(self._ctx, path=path, data=resp.data), - ) - return resp + return Project(self._ctx, path=path)._fetch() def projects( self, @@ -156,10 +143,7 @@ def projects( :param search: 搜索关键词 :param detail: 是否返回详细信息 """ - return ApiResponseType( - ok=True, - data=Projects(self._ctx, path=path, sort=sort, search=search, detail=detail), - ) + return ApiResponseType(ok=True, data=Projects(self._ctx, path=path, sort=sort, search=search, detail=detail)) def run(self, path: str) -> ApiResponseType: """ @@ -174,18 +158,7 @@ def run(self, path: str) -> ApiResponseType: ) proj_path = path.rsplit("/", 1)[0] expid = parts[2] - resp = self._get(f"/project/{proj_path}/runs/{expid}") - if resp.ok: - return ApiResponseType( - ok=True, - data=Experiment( - self._ctx, - path=proj_path, - cuid=expid, - data=resp.data, - ), - ) - return resp + return Experiment(self._ctx, path=proj_path, cuid=expid)._fetch() def runs(self, path: str, filters: Optional[dict] = None) -> ApiResponseType: """ @@ -194,10 +167,7 @@ def runs(self, path: str, filters: Optional[dict] = None) -> ApiResponseType: :param path: 项目路径,格式为 'username/project' :param filters: 筛选条件 """ - return ApiResponseType( - ok=True, - data=Experiments(self._ctx, path=path, filters=filters), - ) + return ApiResponseType(ok=True, data=Experiments(self._ctx, path=path, filters=filters)) __all__ = ["Api"] diff --git a/swanlab/api/base.py b/swanlab/api/base.py index aad275a01..4bb75534e 100644 --- a/swanlab/api/base.py +++ b/swanlab/api/base.py @@ -34,13 +34,17 @@ class BaseEntity(ABC): 统一持有 _ctx(_ApiContext),提供 _get/_post/_put/_delete HTTP 快捷方法和 _paginate 分页迭代。 所有 HTTP 请求通过 _safe_request 包裹,保证任何异常都不会导致程序 crash,统一返回 ApiResponse。 - 子类只需实现 to_dict() 和业务逻辑。 + 子类只需实现 json() 和业务逻辑。 """ def __init__(self, ctx: ApiClientContext) -> None: self._ctx: ApiClientContext = ctx self._errors: list[str] = [] + def _ensure_data(self) -> Any: + """按需加载数据。单实体子类重写此方法;迭代器子类无需重写。""" + return None + @abstractmethod def json(self) -> Dict[str, Any]: """将实体序列化为 JSON 可序列化的字典。""" @@ -77,6 +81,13 @@ def _build_web_url(self, path: str) -> str: """构建前端 Web 页面 URL(使用 _web_host 而非 _api_host)。""" return f"{self._ctx.web_host}/{path}" + def _fetch(self) -> ApiResponseType: + """Eager 模式:触发子类 _ensure_data 加载数据,根据 _errors 返回 ApiResponseType。""" + self._ensure_data() + if self._errors: + return ApiResponseType(ok=False, errmsg=self._errors[-1]) + return ApiResponseType(ok=True, data=self) + def _paginate(self, path: str, *, page_size: int = 20, params: Optional[dict] = None) -> Iterator[dict]: """通用分页迭代器,自动处理 page/size 参数。""" page = 1 diff --git a/swanlab/api/typings/common.py b/swanlab/api/typings/common.py index d0fbb0f31..2ff848287 100644 --- a/swanlab/api/typings/common.py +++ b/swanlab/api/typings/common.py @@ -51,17 +51,14 @@ def __init__(self, *, ok: bool, errmsg: str = "", data: Any = None) -> None: self.errmsg = errmsg self.data = data - def to_dict(self) -> Dict[str, Any]: - return {"ok": self.ok, "errmsg": self.errmsg, "data": self.data} - - def to_json_dict(self) -> Dict[str, Any]: + def json(self) -> Dict[str, Any]: """返回 JSON 可序列化的字典,自动将实体 data 转为 dict。""" data = self.data errors: list[str] = [] if not self.ok and self.errmsg: errors.append(self.errmsg) - if data is not None and hasattr(data, "to_dict"): - data = data.to_dict() + if data is not None and hasattr(data, "json"): + data = data.json() entity_errors = getattr(self.data, "_errors", []) errors.extend(entity_errors) ok = self.ok and not errors diff --git a/swanlab/api/utils.py b/swanlab/api/utils.py index 710d1f3f6..9db28757c 100644 --- a/swanlab/api/utils.py +++ b/swanlab/api/utils.py @@ -9,7 +9,7 @@ def get_properties(obj: object, _visited: Optional[Set[int]] = None) -> Dict[str, object]: - """递归获取实例中所有 property 的值,用于 to_dict() 默认实现。""" + """递归获取实例中所有 property 的值,用于 json() 默认实现。""" if _visited is None: _visited = set() obj_id = id(obj) diff --git a/swanlab/cli/api/helper.py b/swanlab/cli/api/helper.py index 4ecd04067..f2af6f0f3 100644 --- a/swanlab/cli/api/helper.py +++ b/swanlab/cli/api/helper.py @@ -6,5 +6,5 @@ def format_output(resp: ApiResponseType) -> None: - """统一输出 ApiResponse JSON。""" - click.echo(json.dumps(resp.to_json_dict())) + """统一输出 ApiResponseType JSON。""" + click.echo(json.dumps(resp.json())) diff --git a/swanlab/cli/api/project/__init__.py b/swanlab/cli/api/project/__init__.py index 32a8741e5..8672c34d3 100644 --- a/swanlab/cli/api/project/__init__.py +++ b/swanlab/cli/api/project/__init__.py @@ -1,5 +1,6 @@ import click +from swanlab.api import Api from swanlab.api.typings.common import ApiResponseType from swanlab.cli.api.helper import format_output @@ -8,3 +9,12 @@ def project_cli(): """Project management commands.""" pass + + +@project_cli.command("info") +@click.argument("path", required=True) +def get_project(path: str): + """Get project info by path (username/project).""" + api = Api() + resp = api.project(path) + format_output(resp) From ab69545d57f3ae82dc9dd59caea6204a087d1f72 Mon Sep 17 00:00:00 2001 From: Nexisato <978452096@qq.com> Date: Wed, 22 Apr 2026 15:48:27 +0800 Subject: [PATCH 11/52] feat: add save_option --- swanlab/api/base.py | 4 +-- swanlab/cli/api/experiment/__init__.py | 13 ++++++-- swanlab/cli/api/helper.py | 43 ++++++++++++++++++++++++-- swanlab/cli/api/project/__init__.py | 7 ++--- swanlab/cli/api/selfhosted/__init__.py | 1 + swanlab/cli/api/workspace/__init__.py | 13 ++++++-- 6 files changed, 68 insertions(+), 13 deletions(-) diff --git a/swanlab/api/base.py b/swanlab/api/base.py index 4bb75534e..bb22180a4 100644 --- a/swanlab/api/base.py +++ b/swanlab/api/base.py @@ -52,9 +52,9 @@ def json(self) -> Dict[str, Any]: def _safe_request(self, method: Callable, path: str, **kwargs) -> ApiResponseType: """安全请求包装:捕获所有异常,始终返回 ApiResponseType 而不抛出。""" _err: list[str] = [] - common_err: str = f"API request failed: {path}" + common_err: str = f"API Request Failed: {path}" - @safe.decorator(message=common_err, on_error=lambda e: _err.append(str(e))) + @safe.decorator(message=None, on_error=lambda e: _err.append(str(e))) def _do(): return method(path, **kwargs).data diff --git a/swanlab/cli/api/experiment/__init__.py b/swanlab/cli/api/experiment/__init__.py index b8e6d9b31..cfdc29ce6 100644 --- a/swanlab/cli/api/experiment/__init__.py +++ b/swanlab/cli/api/experiment/__init__.py @@ -1,10 +1,19 @@ import click -from swanlab.api.typings.common import ApiResponseType -from swanlab.cli.api.helper import format_output +from swanlab.api import Api +from swanlab.cli.api.helper import with_save_option @click.group("run") def experiment_cli(): """Experiment(Run) management commands.""" pass + + +@experiment_cli.command("info") +@click.argument("path", required=True) +@with_save_option +def get_experiment(path: str): + """Get Experiment(Run) info by path (username/project/run_id).""" + api = Api() + return api.run(path) diff --git a/swanlab/cli/api/helper.py b/swanlab/cli/api/helper.py index f2af6f0f3..5da9d83d5 100644 --- a/swanlab/cli/api/helper.py +++ b/swanlab/cli/api/helper.py @@ -1,10 +1,47 @@ +import functools import json +from datetime import datetime import click +import nanoid from swanlab.api.typings.common import ApiResponseType -def format_output(resp: ApiResponseType) -> None: - """统一输出 ApiResponseType JSON。""" - click.echo(json.dumps(resp.json())) +def _save_json(content: str) -> None: + """将 JSON 内容保存到当前目录。""" + filename = f"swanlab-{datetime.now().strftime('%Y%m%d_%H%M%S')}-{nanoid.generate(size=4)}.json" + with open(filename, "w", encoding="utf-8") as f: + f.write(content) + click.echo(f"Saved to {filename}") + + +def format_output(resp: ApiResponseType, save: bool = False) -> None: + """统一输出 ApiResponseType JSON,可选保存到文件。""" + data = resp.json() + click.echo(json.dumps(data, ensure_ascii=False)) + if save: + _save_json(json.dumps(data, ensure_ascii=False, indent=2)) + + +def with_save_option(f): + """ + 装饰器:为 CLI 命令添加 --save 选项并自动输出/保存响应。 + + 被装饰的函数应返回 ApiResponseType,装饰器负责 format_output 和可选的文件保存。 + """ + + @click.option( + "--save", + "-s", + is_flag=True, + default=False, + help="Save output as JSON to current directory.", + ) + @functools.wraps(f) + def wrapper(*args, save: bool, **kwargs): + resp = f(*args, **kwargs) + if resp is not None: + format_output(resp, save=save) + + return wrapper diff --git a/swanlab/cli/api/project/__init__.py b/swanlab/cli/api/project/__init__.py index 8672c34d3..0f75a57bd 100644 --- a/swanlab/cli/api/project/__init__.py +++ b/swanlab/cli/api/project/__init__.py @@ -1,8 +1,7 @@ import click from swanlab.api import Api -from swanlab.api.typings.common import ApiResponseType -from swanlab.cli.api.helper import format_output +from swanlab.cli.api.helper import with_save_option @click.group("project") @@ -13,8 +12,8 @@ def project_cli(): @project_cli.command("info") @click.argument("path", required=True) +@with_save_option def get_project(path: str): """Get project info by path (username/project).""" api = Api() - resp = api.project(path) - format_output(resp) + return api.project(path) diff --git a/swanlab/cli/api/selfhosted/__init__.py b/swanlab/cli/api/selfhosted/__init__.py index 3003315f0..306252d4a 100644 --- a/swanlab/cli/api/selfhosted/__init__.py +++ b/swanlab/cli/api/selfhosted/__init__.py @@ -1,5 +1,6 @@ import click +from swanlab.api import Api from swanlab.api.typings.common import ApiResponseType from swanlab.cli.api.helper import format_output diff --git a/swanlab/cli/api/workspace/__init__.py b/swanlab/cli/api/workspace/__init__.py index 2c3de65f3..4a7f38e7a 100644 --- a/swanlab/cli/api/workspace/__init__.py +++ b/swanlab/cli/api/workspace/__init__.py @@ -1,10 +1,19 @@ import click -from swanlab.api.typings.common import ApiResponseType -from swanlab.cli.api.helper import format_output +from swanlab.api import Api +from swanlab.cli.api.helper import with_save_option @click.group("workspace") def workspace_cli(): """Workspace management commands.""" pass + + +@workspace_cli.command("info") +@click.argument("username", required=True) +@with_save_option +def get_workspace(username: str): + """Get Workspace info.""" + api = Api() + return api.workspace(username) From ed8bcd7b6ac2612ba4ae09c30f745cdfa41dc849 Mon Sep 17 00:00:00 2001 From: Nexisato <978452096@qq.com> Date: Wed, 22 Apr 2026 16:04:28 +0800 Subject: [PATCH 12/52] feat: flatten module file --- swanlab/cli/api/{experiment/__init__.py => experiment.py} | 0 swanlab/cli/api/helper.py | 2 +- swanlab/cli/api/{project/__init__.py => project.py} | 0 swanlab/cli/api/{selfhosted/__init__.py => selfhosted.py} | 0 swanlab/cli/api/{user/__init__.py => user.py} | 0 swanlab/cli/api/{workspace/__init__.py => workspace.py} | 0 6 files changed, 1 insertion(+), 1 deletion(-) rename swanlab/cli/api/{experiment/__init__.py => experiment.py} (100%) rename swanlab/cli/api/{project/__init__.py => project.py} (100%) rename swanlab/cli/api/{selfhosted/__init__.py => selfhosted.py} (100%) rename swanlab/cli/api/{user/__init__.py => user.py} (100%) rename swanlab/cli/api/{workspace/__init__.py => workspace.py} (100%) diff --git a/swanlab/cli/api/experiment/__init__.py b/swanlab/cli/api/experiment.py similarity index 100% rename from swanlab/cli/api/experiment/__init__.py rename to swanlab/cli/api/experiment.py diff --git a/swanlab/cli/api/helper.py b/swanlab/cli/api/helper.py index 5da9d83d5..b7dabab30 100644 --- a/swanlab/cli/api/helper.py +++ b/swanlab/cli/api/helper.py @@ -20,7 +20,7 @@ def format_output(resp: ApiResponseType, save: bool = False) -> None: """统一输出 ApiResponseType JSON,可选保存到文件。""" data = resp.json() click.echo(json.dumps(data, ensure_ascii=False)) - if save: + if save and resp.ok: _save_json(json.dumps(data, ensure_ascii=False, indent=2)) diff --git a/swanlab/cli/api/project/__init__.py b/swanlab/cli/api/project.py similarity index 100% rename from swanlab/cli/api/project/__init__.py rename to swanlab/cli/api/project.py diff --git a/swanlab/cli/api/selfhosted/__init__.py b/swanlab/cli/api/selfhosted.py similarity index 100% rename from swanlab/cli/api/selfhosted/__init__.py rename to swanlab/cli/api/selfhosted.py diff --git a/swanlab/cli/api/user/__init__.py b/swanlab/cli/api/user.py similarity index 100% rename from swanlab/cli/api/user/__init__.py rename to swanlab/cli/api/user.py diff --git a/swanlab/cli/api/workspace/__init__.py b/swanlab/cli/api/workspace.py similarity index 100% rename from swanlab/cli/api/workspace/__init__.py rename to swanlab/cli/api/workspace.py From 27e5463889703c070f4732da43b35686a4617cd1 Mon Sep 17 00:00:00 2001 From: Nexisato <978452096@qq.com> Date: Wed, 22 Apr 2026 16:38:42 +0800 Subject: [PATCH 13/52] fix: export open api --- swanlab/__init__.py | 3 +++ swanlab/api/__init__.py | 32 ++++++++++++---------------- swanlab/api/base.py | 14 ++++++------ swanlab/api/experiment.py | 40 +++++++++++++++++++++++------------ swanlab/cli/api/experiment.py | 2 +- swanlab/cli/api/project.py | 2 +- swanlab/cli/api/workspace.py | 2 +- 7 files changed, 53 insertions(+), 42 deletions(-) diff --git a/swanlab/__init__.py b/swanlab/__init__.py index e8039572d..18f7597e3 100644 --- a/swanlab/__init__.py +++ b/swanlab/__init__.py @@ -24,6 +24,7 @@ ) from . import utils +from .api import Api __version__ = helper.get_swanlab_version() @@ -56,6 +57,8 @@ "utils", "Settings", "Callback", + # Api + "Api", ] diff --git a/swanlab/api/__init__.py b/swanlab/api/__init__.py index bc1970d93..0ebcb427d 100644 --- a/swanlab/api/__init__.py +++ b/swanlab/api/__init__.py @@ -100,7 +100,7 @@ def _resolve_credentials( # - 列表迭代器(workspaces/projects/runs):惰性构造,迭代时按需分页请求 # ------------------------------------------------------------------ - def workspace(self, username: Optional[str] = None) -> ApiResponseType: + def workspace(self, username: Optional[str] = None) -> Workspace: """ 获取工作空间信息,默认为当前登录用户的工作空间。 @@ -108,9 +108,9 @@ def workspace(self, username: Optional[str] = None) -> ApiResponseType: """ if username is None: username = self._username - return Workspace(self._ctx, username=username)._fetch() + return Workspace(self._ctx, username=username) - def workspaces(self, username: Optional[str] = None) -> ApiResponseType: + def workspaces(self, username: Optional[str] = None) -> Workspaces: """ 获取工作空间列表迭代器。 @@ -118,15 +118,15 @@ def workspaces(self, username: Optional[str] = None) -> ApiResponseType: """ if username is None: username = self._username - return ApiResponseType(ok=True, data=Workspaces(self._ctx, username=username)) + return Workspaces(self._ctx, username=username) - def project(self, path: str) -> ApiResponseType: + def project(self, path: str) -> Project: """ 获取项目信息。 :param path: 项目路径,格式为 'username/project-name' """ - return Project(self._ctx, path=path)._fetch() + return Project(self._ctx, path=path) def projects( self, @@ -134,7 +134,7 @@ def projects( sort: Optional[str] = None, search: Optional[str] = None, detail: Optional[bool] = True, - ) -> ApiResponseType: + ) -> Projects: """ 获取工作空间下的项目列表迭代器。 @@ -143,31 +143,25 @@ def projects( :param search: 搜索关键词 :param detail: 是否返回详细信息 """ - return ApiResponseType(ok=True, data=Projects(self._ctx, path=path, sort=sort, search=search, detail=detail)) + return Projects(self._ctx, path=path, sort=sort, search=search, detail=detail) - def run(self, path: str) -> ApiResponseType: + def run(self, path: str) -> Experiment: """ 获取单个实验。 :param path: 实验路径,格式为 'username/project/run_id' """ - parts = path.split("/") - if len(parts) != 3: - return ApiResponseType( - ok=False, errmsg=f"Invalid path '{path}'. Expected format: 'username/project/run_id'" - ) - proj_path = path.rsplit("/", 1)[0] - expid = parts[2] - return Experiment(self._ctx, path=proj_path, cuid=expid)._fetch() - def runs(self, path: str, filters: Optional[dict] = None) -> ApiResponseType: + return Experiment(self._ctx, path=path) + + def runs(self, path: str, filters: Optional[dict] = None) -> Experiments: """ 获取项目下的实验列表迭代器。 :param path: 项目路径,格式为 'username/project' :param filters: 筛选条件 """ - return ApiResponseType(ok=True, data=Experiments(self._ctx, path=path, filters=filters)) + return Experiments(self._ctx, proj_path=path) __all__ = ["Api"] diff --git a/swanlab/api/base.py b/swanlab/api/base.py index bb22180a4..c10b26d3e 100644 --- a/swanlab/api/base.py +++ b/swanlab/api/base.py @@ -45,6 +45,13 @@ def _ensure_data(self) -> Any: """按需加载数据。单实体子类重写此方法;迭代器子类无需重写。""" return None + def wrapper(self) -> ApiResponseType: + """Eager 模式:触发子类 _ensure_data 加载数据,根据 _errors 返回 ApiResponseType。""" + self._ensure_data() + if self._errors: + return ApiResponseType(ok=False, errmsg=self._errors[-1]) + return ApiResponseType(ok=True, data=self) + @abstractmethod def json(self) -> Dict[str, Any]: """将实体序列化为 JSON 可序列化的字典。""" @@ -81,13 +88,6 @@ def _build_web_url(self, path: str) -> str: """构建前端 Web 页面 URL(使用 _web_host 而非 _api_host)。""" return f"{self._ctx.web_host}/{path}" - def _fetch(self) -> ApiResponseType: - """Eager 模式:触发子类 _ensure_data 加载数据,根据 _errors 返回 ApiResponseType。""" - self._ensure_data() - if self._errors: - return ApiResponseType(ok=False, errmsg=self._errors[-1]) - return ApiResponseType(ok=True, data=self) - def _paginate(self, path: str, *, page_size: int = 20, params: Optional[dict] = None) -> Iterator[dict]: """通用分页迭代器,自动处理 page/size 参数。""" page = 1 diff --git a/swanlab/api/experiment.py b/swanlab/api/experiment.py index afca2cdd9..a52425932 100644 --- a/swanlab/api/experiment.py +++ b/swanlab/api/experiment.py @@ -5,7 +5,7 @@ @description: Experiment 实体类 — 单个实验的查询与操作 """ -from typing import Any, Dict, Iterator, List, Optional, Union, cast +from typing import Any, Dict, Iterator, List, Optional, Tuple, Union, cast from swanlab.api.base import ApiClientContext, BaseEntity from swanlab.api.typings.experiment import ApiExperimentLabelType, ApiExperimentType @@ -13,6 +13,20 @@ from swanlab.api.utils import get_properties, parse_filter +def _resovle_path(path: str) -> Tuple[str, str]: + """ "path like: user/proj_name/run_id""" + proj_path, cuid = "", "" + parts = path.split("/") + if len(parts) != 3: + return proj_path, cuid + cuid = parts[-1] + proj_path = path.rsplit("/", 1)[0] + return ( + proj_path, + cuid, + ) + + class Profile: """Experiment profile containing config, metadata, requirements, and conda info.""" @@ -63,17 +77,15 @@ def __init__( ctx: ApiClientContext, *, path: str, - cuid: str = "", data: Optional[ApiExperimentType] = None, ) -> None: super().__init__(ctx) - self._path = path # 'username/project-name' - self._cuid: str = cuid or (data.get("cuid", "") if data else "") + self._proj_path, self._cuid = _resovle_path(path=path) self._data = data def _ensure_data(self) -> ApiExperimentType: if self._data is None: - resp = self._get(f"/project/{self._path}/runs/{self._cuid}") + resp = self._get(f"/project/{self._proj_path}/runs/{self._cuid}") self._data = resp.data if resp.ok and resp.data else cast(ApiExperimentType, {}) if not self._cuid and self._data: self._cuid = self._data.get("cuid", "") @@ -99,7 +111,7 @@ def state(self) -> str: @property def url(self) -> str: - return self._build_web_url(f"@{self._path}/runs/{self.run_id}/chart") + return self._build_web_url(f"@{self._proj_path}/runs/{self.run_id}/chart") @property def show(self) -> bool: @@ -135,7 +147,7 @@ def profile(self) -> Profile: """Experiment profile containing config, metadata, requirements, and conda.""" data = self._ensure_data() if "profile" not in data and self._cuid: - resp = self._get(f"/project/{self._path}/runs/{self._cuid}") + resp = self._get(f"/project/{self._proj_path}/runs/{self._cuid}") if resp.ok and resp.data: self._data = resp.data data = self._data @@ -210,7 +222,7 @@ def strip_suffix(col, suffix="_step"): def delete(self) -> bool: """删除此实验。""" - resp = self._delete(f"/project/{self._path}/runs/{self._cuid}") + resp = self._delete(f"/project/{self._proj_path}/runs/{self._cuid}") return resp.ok def json(self) -> Dict[str, Any]: @@ -240,16 +252,16 @@ def __init__( self, ctx: ApiClientContext, *, - path: str, + proj_path: str, filters: Optional[Dict[str, object]] = None, ) -> None: super().__init__(ctx) - self._path = path + self._proj_path = proj_path self._filters = filters def __iter__(self) -> Iterator[Experiment]: parsed_filters = [parse_filter(k, v) for k, v in self._filters.items()] if self._filters else [] - resp = self._post(f"/project/{self._path}/runs/shows", data={"filters": parsed_filters}) + resp = self._post(f"/project/{self._proj_path}/runs/shows", data={"filters": parsed_filters}) if not resp.ok: return body = resp.data @@ -260,7 +272,9 @@ def __iter__(self) -> Iterator[Experiment]: runs = _flatten_runs(body) for run_data in runs: - yield Experiment(self._ctx, path=self._path, data=run_data) + cuid = run_data.get("cuid", "") + full_path = f"{self._proj_path}/{cuid}" + yield Experiment(self._ctx, path=full_path, data=run_data) def json(self) -> Dict[str, Any]: - return {"path": self._path} + return {"path": self._proj_path} diff --git a/swanlab/cli/api/experiment.py b/swanlab/cli/api/experiment.py index cfdc29ce6..876b9e0c4 100644 --- a/swanlab/cli/api/experiment.py +++ b/swanlab/cli/api/experiment.py @@ -16,4 +16,4 @@ def experiment_cli(): def get_experiment(path: str): """Get Experiment(Run) info by path (username/project/run_id).""" api = Api() - return api.run(path) + return api.run(path).wrapper() diff --git a/swanlab/cli/api/project.py b/swanlab/cli/api/project.py index 0f75a57bd..66fdccdb5 100644 --- a/swanlab/cli/api/project.py +++ b/swanlab/cli/api/project.py @@ -16,4 +16,4 @@ def project_cli(): def get_project(path: str): """Get project info by path (username/project).""" api = Api() - return api.project(path) + return api.project(path).wrapper() diff --git a/swanlab/cli/api/workspace.py b/swanlab/cli/api/workspace.py index 4a7f38e7a..c53e52914 100644 --- a/swanlab/cli/api/workspace.py +++ b/swanlab/cli/api/workspace.py @@ -16,4 +16,4 @@ def workspace_cli(): def get_workspace(username: str): """Get Workspace info.""" api = Api() - return api.workspace(username) + return api.workspace(username).wrapper() From c968696c94f635415a878e963c6c9fed1d8c0382 Mon Sep 17 00:00:00 2001 From: Nexisato <978452096@qq.com> Date: Wed, 22 Apr 2026 18:58:35 +0800 Subject: [PATCH 14/52] feat: add user and group skeleton --- swanlab/__init__.pyi | 3 ++ swanlab/api/__init__.py | 15 ++++++++-- swanlab/api/group.py | 56 +++++++++++++++++++++++++++++++++++ swanlab/api/project.py | 2 +- swanlab/api/typings/common.py | 6 ++++ swanlab/api/typings/group.py | 29 ++++++++++++++++++ swanlab/api/typings/user.py | 15 ++++++---- swanlab/api/user.py | 55 ++++++++++++++++++++++++++++++++++ swanlab/api/utils.py | 10 ++++++- swanlab/api/workspace.py | 2 +- 10 files changed, 182 insertions(+), 11 deletions(-) create mode 100644 swanlab/api/group.py create mode 100644 swanlab/api/typings/group.py create mode 100644 swanlab/api/user.py diff --git a/swanlab/__init__.pyi b/swanlab/__init__.pyi index 8d35e3c57..23a870d4b 100644 --- a/swanlab/__init__.pyi +++ b/swanlab/__init__.pyi @@ -10,6 +10,7 @@ from concurrent.futures import Future from typing import Any, Callable, List, Mapping, Optional, Union from . import utils +from .api import Api from .sdk import Audio, Callback, Image, Run, Settings, Text, Video, config from .sdk.typings.cmd import ConfigLike from .sdk.typings.run import AsyncLogType, FinishType, ModeType, ResumeType @@ -51,6 +52,8 @@ __all__ = [ "utils", "Settings", "Callback", + # Api + "Api", ] # ── lifecycle ────────────────────────────────────────────────────────────────── diff --git a/swanlab/api/__init__.py b/swanlab/api/__init__.py index 0ebcb427d..3377bcc02 100644 --- a/swanlab/api/__init__.py +++ b/swanlab/api/__init__.py @@ -14,8 +14,10 @@ from .base import ApiClientContext, BaseEntity from .experiment import Experiment, Experiments +from .group import Group from .project import Project, Projects from .typings.common import ApiResponseType +from .user import User from .workspace import Workspace, Workspaces @@ -63,12 +65,15 @@ def __init__( login_resp = scope.get_context("login_resp") api_key_resolved, api_host, resolved_web_host = self._resolve_credentials(api_key, host, web_host) _client = Client(api_key=str(api_key_resolved), base_url=api_host) + if login_resp is None: - login_resp = scope.get_context("login_resp") + from swanlab.sdk.internal.pkg.client.bootstrap import login_by_api_key + + login_resp = login_by_api_key(base_url=api_host + "/api", api_key=api_key_resolved) ctx = ApiClientContext(client=_client, web_host=resolved_web_host, api_host=api_host) super().__init__(ctx) - self._username: str = login_resp["userInfo"]["username"] if login_resp else "" + self._username: str = login_resp.get("userInfo", {}).get("username", "") if login_resp else "" def json(self) -> dict: """Api 非数据实体,返回空字典。""" @@ -163,5 +168,11 @@ def runs(self, path: str, filters: Optional[dict] = None) -> Experiments: """ return Experiments(self._ctx, proj_path=path) + def user(self) -> User: + return User(self._ctx) + + def group(self) -> Group: + return Group(self._ctx, username=self._username) + __all__ = ["Api"] diff --git a/swanlab/api/group.py b/swanlab/api/group.py new file mode 100644 index 000000000..41ed5616f --- /dev/null +++ b/swanlab/api/group.py @@ -0,0 +1,56 @@ +""" +@author: caddiesnew +@file: group.py +@time: 2026/4/20 +@description: Group 实体类 — 组织信息的查询 +""" + +from typing import Any, Dict, Optional, cast + +from swanlab.api.base import ApiClientContext, BaseEntity +from swanlab.api.typings.group import ApiGroupProfileType, ApiGroupType +from swanlab.api.utils import get_properties, strip_dict + + +class Group(BaseEntity): + """ + 表示一个 SwanLab 组织。 + """ + + def __init__(self, ctx: ApiClientContext, username: str, data: Optional[ApiGroupType] = None) -> None: + super().__init__(ctx) + self._username = username + self._data = data + + def _ensure_data(self) -> ApiGroupType: + if self._data is None: + resp = self._get(f"/group/{self._username}") + self._data = resp.data if resp.ok and resp.data else cast(ApiGroupType, {}) + return self._data + + @property + def name(self) -> str: + return self._ensure_data().get("name", "") + + @property + def username(self) -> str: + return self._ensure_data().get("username", "") + + @property + def comment(self) -> str: + return self._ensure_data().get("comment", "") + + @property + def group_type(self) -> str: + return self._ensure_data().get("type", "TEAM") + + @property + def status(self) -> str: + return self._ensure_data().get("status", "ACTIVE") + + @property + def profile(self) -> Dict[str, Any]: + return strip_dict(self._ensure_data().get("profile", {}), ApiGroupProfileType) + + def json(self) -> Dict[str, Any]: + return get_properties(self) diff --git a/swanlab/api/project.py b/swanlab/api/project.py index 5308f7fa1..3fc8ab57d 100644 --- a/swanlab/api/project.py +++ b/swanlab/api/project.py @@ -76,7 +76,7 @@ def runs(self, filters: Optional[Dict[str, object]] = None): """获取项目下的实验列表。""" from swanlab.api.experiment import Experiments - return Experiments(self._ctx, path=self.path, filters=filters) + return Experiments(self._ctx, proj_path=self.path, filters=filters) def delete(self) -> bool: """删除此项目。""" diff --git a/swanlab/api/typings/common.py b/swanlab/api/typings/common.py index 2ff848287..538703564 100644 --- a/swanlab/api/typings/common.py +++ b/swanlab/api/typings/common.py @@ -7,6 +7,9 @@ from typing import Any, Dict, List, Literal, TypedDict +# 启用/停用 +ApiStatusLiteral = Literal["ENABLED", "DISABLED"] + # 列类型 ApiColumnLiteral = Literal["SCALAR", "CONFIG", "STABLE"] @@ -16,6 +19,9 @@ # 可见性类型 ApiVisibilityLiteral = Literal["PUBLIC", "PRIVATE"] +# 组织类型 +ApiGroupLiteral = Literal["PERSON", "TEAM"] + # 工作空间类型 ApiWorkspaceLiteral = Literal["TEAM", "PERSON"] diff --git a/swanlab/api/typings/group.py b/swanlab/api/typings/group.py new file mode 100644 index 000000000..ae1cbb3a3 --- /dev/null +++ b/swanlab/api/typings/group.py @@ -0,0 +1,29 @@ +""" +@author: caddiesnew +@file: group.py +@time: 2026/4/22 +@description: 公共查询 API 组织类型定义 +""" + +from typing import TypedDict + +from .common import ApiGroupLiteral, ApiStatusLiteral + + +class ApiGroupProfileType(TypedDict): + bio: str + url: str + institution: str + school: str + email: str + location: str + + +# 在项目信息和用户信息的返回结果中,该类型的字段含义不同,注意区分 +class ApiGroupType(TypedDict): + name: str # 组织名称 (用于user.teams) + username: str + comment: str + type: ApiGroupLiteral + status: ApiStatusLiteral + profile: ApiGroupProfileType diff --git a/swanlab/api/typings/user.py b/swanlab/api/typings/user.py index b4ff66bbc..c34291103 100644 --- a/swanlab/api/typings/user.py +++ b/swanlab/api/typings/user.py @@ -7,10 +7,7 @@ from typing import TypedDict - -class ApiUserType(TypedDict): - name: str - username: str +from .common import ApiStatusLiteral class ApiUserProfileType(TypedDict): @@ -19,6 +16,12 @@ class ApiUserProfileType(TypedDict): location: str school: str email: str - idc: str url: str - telephone: str + + +class ApiUserType(TypedDict): + name: str + username: str + verified: bool + status: ApiStatusLiteral + profile: ApiUserProfileType diff --git a/swanlab/api/user.py b/swanlab/api/user.py new file mode 100644 index 000000000..a85334b87 --- /dev/null +++ b/swanlab/api/user.py @@ -0,0 +1,55 @@ +""" +@author: caddiesnew +@file: user.py +@time: 2026/4/20 +@description: User 实体类 — 用户信息的查询 +""" + +from typing import Any, Dict, Optional, cast + +from swanlab.api.base import ApiClientContext, BaseEntity +from swanlab.api.typings.user import ApiUserProfileType, ApiUserType +from swanlab.api.utils import get_properties, strip_dict + + +class User(BaseEntity): + """ + 表示一个 SwanLab 用户。 + """ + + def __init__( + self, + ctx: ApiClientContext, + data: Optional[ApiUserType] = None, + ) -> None: + super().__init__(ctx) + self._data = data + + def _ensure_data(self) -> ApiUserType: + if self._data is None: + resp = self._get("/user/profile") + self._data = resp.data if resp.ok and resp.data else cast(ApiUserType, {}) + return self._data + + @property + def name(self) -> str: + return self._ensure_data().get("name", "") + + @property + def username(self) -> str: + return self._ensure_data().get("username", "") + + @property + def verified(self) -> bool: + return self._ensure_data().get("verified", False) + + @property + def status(self) -> str: + return self._ensure_data().get("status", "DISABLED") + + @property + def profile(self) -> Dict[str, Any]: + return strip_dict(self._ensure_data().get("profile", {}), ApiUserProfileType) + + def json(self) -> Dict[str, Any]: + return get_properties(self) diff --git a/swanlab/api/utils.py b/swanlab/api/utils.py index 9db28757c..ad4a8f01a 100644 --- a/swanlab/api/utils.py +++ b/swanlab/api/utils.py @@ -5,7 +5,15 @@ @description: swanlab/api 实体层工具函数 """ -from typing import Dict, List, Optional, Set +from typing import Any, Dict, Optional, Set, Type, get_type_hints + + +def strip_dict(data: Any, typed_cls: Type) -> Dict[str, Any]: + """将原始 API 响应字典裁剪为只保留 TypedDict 中声明的字段。""" + if not data: + return {} + hints = get_type_hints(typed_cls) if hasattr(typed_cls, "__annotations__") else {} + return {k: data[k] for k in hints if k in data} def get_properties(obj: object, _visited: Optional[Set[int]] = None) -> Dict[str, object]: diff --git a/swanlab/api/workspace.py b/swanlab/api/workspace.py index d00666221..e1546e66e 100644 --- a/swanlab/api/workspace.py +++ b/swanlab/api/workspace.py @@ -98,7 +98,7 @@ def _get_all_workspace_names(self) -> list[str]: resp = self._get(f"/user/{self._username}/groups") if not resp.ok: return [self._username] - group_names = [r["username"] for r in resp.data] + group_names = resp.data if isinstance(resp.data, list) else [] return [self._username] + group_names def __iter__(self) -> Iterator[Workspace]: From 033c63cdac8aa021e7f3caf6f6b239134dd0eab8 Mon Sep 17 00:00:00 2001 From: Nexisato <978452096@qq.com> Date: Wed, 22 Apr 2026 19:38:35 +0800 Subject: [PATCH 15/52] fix: user info --- swanlab/api/__init__.py | 26 ++++++++++++++---------- swanlab/api/base.py | 4 +++- swanlab/api/typings/user.py | 15 +++++++------- swanlab/api/user.py | 40 ++++++++++++++++++++++++------------- 4 files changed, 53 insertions(+), 32 deletions(-) diff --git a/swanlab/api/__init__.py b/swanlab/api/__init__.py index 3377bcc02..fa5dea1a8 100644 --- a/swanlab/api/__init__.py +++ b/swanlab/api/__init__.py @@ -63,17 +63,18 @@ def __init__( """ # 优先从 scope 获取已有登录态(如进程内已调用 swanlab.login),直接复用凭证 login_resp = scope.get_context("login_resp") - api_key_resolved, api_host, resolved_web_host = self._resolve_credentials(api_key, host, web_host) - _client = Client(api_key=str(api_key_resolved), base_url=api_host) + api_key, api_host, web_host = self._resolve_credentials(api_key, host, web_host) + _client = Client(api_key=str(api_key), base_url=api_host) if login_resp is None: from swanlab.sdk.internal.pkg.client.bootstrap import login_by_api_key - login_resp = login_by_api_key(base_url=api_host + "/api", api_key=api_key_resolved) - - ctx = ApiClientContext(client=_client, web_host=resolved_web_host, api_host=api_host) + login_resp = login_by_api_key(base_url=api_host + "/api", api_key=api_key) + user_info = login_resp.get("userInfo", {}) if login_resp else {} + username = user_info.get("username", "") + name = user_info.get("name", "") or "" + ctx = ApiClientContext(client=_client, web_host=web_host, api_host=api_host, username=username, name=name) super().__init__(ctx) - self._username: str = login_resp.get("userInfo", {}).get("username", "") if login_resp else "" def json(self) -> dict: """Api 非数据实体,返回空字典。""" @@ -112,7 +113,7 @@ def workspace(self, username: Optional[str] = None) -> Workspace: :param username: 指定工作空间用户名,为 None 时使用当前登录用户 """ if username is None: - username = self._username + username = self._ctx.username return Workspace(self._ctx, username=username) def workspaces(self, username: Optional[str] = None) -> Workspaces: @@ -122,7 +123,7 @@ def workspaces(self, username: Optional[str] = None) -> Workspaces: :param username: 指定用户名,为 None 时使用当前登录用户 """ if username is None: - username = self._username + username = self._ctx.username return Workspaces(self._ctx, username=username) def project(self, path: str) -> Project: @@ -171,8 +172,13 @@ def runs(self, path: str, filters: Optional[dict] = None) -> Experiments: def user(self) -> User: return User(self._ctx) - def group(self) -> Group: - return Group(self._ctx, username=self._username) + def group(self, username: Optional[str] = None) -> Group: + """ + :param username: 指定用户名,为 None 时使用当前登录用户 + """ + if username is None: + username = self._ctx.username + return Group(self._ctx, username=username) __all__ = ["Api"] diff --git a/swanlab/api/base.py b/swanlab/api/base.py index c10b26d3e..c11ffd4e9 100644 --- a/swanlab/api/base.py +++ b/swanlab/api/base.py @@ -21,11 +21,13 @@ @dataclass(frozen=True) class ApiClientContext: - """共享上下文:所有子实体复用同一个实例,避免 (client, web_host, api_host) 三元组透传。""" + """共享上下文:所有子实体复用同一个登录态实例。""" client: "Client" web_host: str api_host: str + username: str + name: str class BaseEntity(ABC): diff --git a/swanlab/api/typings/user.py b/swanlab/api/typings/user.py index c34291103..b9e11d4c2 100644 --- a/swanlab/api/typings/user.py +++ b/swanlab/api/typings/user.py @@ -7,21 +7,22 @@ from typing import TypedDict -from .common import ApiStatusLiteral - class ApiUserProfileType(TypedDict): + # 简介 bio: str + # 个人链接 + url: str + # 机构 institution: str - location: str + # 学校 school: str + # 邮箱 email: str - url: str + # 地址 + location: str class ApiUserType(TypedDict): name: str username: str - verified: bool - status: ApiStatusLiteral - profile: ApiUserProfileType diff --git a/swanlab/api/user.py b/swanlab/api/user.py index a85334b87..67f6afd07 100644 --- a/swanlab/api/user.py +++ b/swanlab/api/user.py @@ -5,51 +5,63 @@ @description: User 实体类 — 用户信息的查询 """ -from typing import Any, Dict, Optional, cast +from typing import Any, Dict, Optional from swanlab.api.base import ApiClientContext, BaseEntity -from swanlab.api.typings.user import ApiUserProfileType, ApiUserType +from swanlab.api.typings.user import ApiUserProfileType from swanlab.api.utils import get_properties, strip_dict class User(BaseEntity): """ - 表示一个 SwanLab 用户。 + 表示一个 SwanLab 用户, 限定为通过 sdk 登录的用户。 """ def __init__( self, ctx: ApiClientContext, - data: Optional[ApiUserType] = None, + data: Optional[Dict[str, Any]] = None, ) -> None: super().__init__(ctx) self._data = data - def _ensure_data(self) -> ApiUserType: + def _ensure_data(self) -> Dict[str, Any]: if self._data is None: resp = self._get("/user/profile") - self._data = resp.data if resp.ok and resp.data else cast(ApiUserType, {}) + self._data = strip_dict(resp.data, ApiUserProfileType) if resp.ok and resp.data else {} return self._data @property def name(self) -> str: - return self._ensure_data().get("name", "") + return self._ctx.name @property def username(self) -> str: - return self._ensure_data().get("username", "") + return self._ctx.username @property - def verified(self) -> bool: - return self._ensure_data().get("verified", False) + def bio(self) -> str: + return self._ensure_data().get("bio", "") @property - def status(self) -> str: - return self._ensure_data().get("status", "DISABLED") + def institution(self) -> str: + return self._ensure_data().get("institution", "") @property - def profile(self) -> Dict[str, Any]: - return strip_dict(self._ensure_data().get("profile", {}), ApiUserProfileType) + def school(self) -> str: + return self._ensure_data().get("school", "") or "" + + @property + def email(self) -> str: + return self._ensure_data().get("email", "") or "" + + @property + def location(self) -> str: + return self._ensure_data().get("location", "") + + @property + def url(self) -> str: + return self._ensure_data().get("url", "") def json(self) -> Dict[str, Any]: return get_properties(self) From e1c0bcd9460c49b1c1c91be9fa6639745eb82e5b Mon Sep 17 00:00:00 2001 From: Nexisato <978452096@qq.com> Date: Wed, 22 Apr 2026 20:07:48 +0800 Subject: [PATCH 16/52] fix: workspace as group --- swanlab/api/__init__.py | 9 ----- swanlab/api/group.py | 56 -------------------------------- swanlab/api/typings/__init__.py | 18 ++++++---- swanlab/api/typings/common.py | 3 -- swanlab/api/typings/group.py | 29 ----------------- swanlab/api/typings/workspace.py | 17 ++++++++-- swanlab/api/workspace.py | 39 +++++++++++----------- 7 files changed, 46 insertions(+), 125 deletions(-) delete mode 100644 swanlab/api/group.py delete mode 100644 swanlab/api/typings/group.py diff --git a/swanlab/api/__init__.py b/swanlab/api/__init__.py index fa5dea1a8..14668a0b5 100644 --- a/swanlab/api/__init__.py +++ b/swanlab/api/__init__.py @@ -14,7 +14,6 @@ from .base import ApiClientContext, BaseEntity from .experiment import Experiment, Experiments -from .group import Group from .project import Project, Projects from .typings.common import ApiResponseType from .user import User @@ -172,13 +171,5 @@ def runs(self, path: str, filters: Optional[dict] = None) -> Experiments: def user(self) -> User: return User(self._ctx) - def group(self, username: Optional[str] = None) -> Group: - """ - :param username: 指定用户名,为 None 时使用当前登录用户 - """ - if username is None: - username = self._ctx.username - return Group(self._ctx, username=username) - __all__ = ["Api"] diff --git a/swanlab/api/group.py b/swanlab/api/group.py deleted file mode 100644 index 41ed5616f..000000000 --- a/swanlab/api/group.py +++ /dev/null @@ -1,56 +0,0 @@ -""" -@author: caddiesnew -@file: group.py -@time: 2026/4/20 -@description: Group 实体类 — 组织信息的查询 -""" - -from typing import Any, Dict, Optional, cast - -from swanlab.api.base import ApiClientContext, BaseEntity -from swanlab.api.typings.group import ApiGroupProfileType, ApiGroupType -from swanlab.api.utils import get_properties, strip_dict - - -class Group(BaseEntity): - """ - 表示一个 SwanLab 组织。 - """ - - def __init__(self, ctx: ApiClientContext, username: str, data: Optional[ApiGroupType] = None) -> None: - super().__init__(ctx) - self._username = username - self._data = data - - def _ensure_data(self) -> ApiGroupType: - if self._data is None: - resp = self._get(f"/group/{self._username}") - self._data = resp.data if resp.ok and resp.data else cast(ApiGroupType, {}) - return self._data - - @property - def name(self) -> str: - return self._ensure_data().get("name", "") - - @property - def username(self) -> str: - return self._ensure_data().get("username", "") - - @property - def comment(self) -> str: - return self._ensure_data().get("comment", "") - - @property - def group_type(self) -> str: - return self._ensure_data().get("type", "TEAM") - - @property - def status(self) -> str: - return self._ensure_data().get("status", "ACTIVE") - - @property - def profile(self) -> Dict[str, Any]: - return strip_dict(self._ensure_data().get("profile", {}), ApiGroupProfileType) - - def json(self) -> Dict[str, Any]: - return get_properties(self) diff --git a/swanlab/api/typings/__init__.py b/swanlab/api/typings/__init__.py index ab1a8ffd0..35fce418a 100644 --- a/swanlab/api/typings/__init__.py +++ b/swanlab/api/typings/__init__.py @@ -20,10 +20,10 @@ from .project import ApiProjectCountType, ApiProjectLabelType, ApiProjectType from .selfhosted import ApiApiKeyType, ApiSelfHostedInfoType from .user import ApiUserProfileType, ApiUserType -from .workspace import ApiWorkspaceInfoType +from .workspace import ApiWorkspaceProfileType, ApiWorkspaceType __all__ = [ - # Kinds (preferred) + # Literal Definition "ApiColumnLiteral", "ApiRunStateLiteral", "ApiVisibilityLiteral", @@ -31,17 +31,23 @@ "ApiRoleLiteral", "ApiIdentityLiteral", "ApiLicensePlanLiteral", - # TypedDicts + # General TypedDicts "ApiPaginationType", "ApiResponseType", + # Experiment/Run "ApiExperimentLabelType", "ApiExperimentType", + # Project "ApiProjectCountType", "ApiProjectLabelType", "ApiProjectType", - "ApiApiKeyType", - "ApiSelfHostedInfoType", + # User "ApiUserType", "ApiUserProfileType", - "ApiWorkspaceInfoType", + # Worksapce/Group + "ApiWorkspaceType", + "ApiWorkspaceProfileType", + # Misc + "ApiApiKeyType", + "ApiSelfHostedInfoType", ] diff --git a/swanlab/api/typings/common.py b/swanlab/api/typings/common.py index 538703564..87c65b6ec 100644 --- a/swanlab/api/typings/common.py +++ b/swanlab/api/typings/common.py @@ -19,9 +19,6 @@ # 可见性类型 ApiVisibilityLiteral = Literal["PUBLIC", "PRIVATE"] -# 组织类型 -ApiGroupLiteral = Literal["PERSON", "TEAM"] - # 工作空间类型 ApiWorkspaceLiteral = Literal["TEAM", "PERSON"] diff --git a/swanlab/api/typings/group.py b/swanlab/api/typings/group.py deleted file mode 100644 index ae1cbb3a3..000000000 --- a/swanlab/api/typings/group.py +++ /dev/null @@ -1,29 +0,0 @@ -""" -@author: caddiesnew -@file: group.py -@time: 2026/4/22 -@description: 公共查询 API 组织类型定义 -""" - -from typing import TypedDict - -from .common import ApiGroupLiteral, ApiStatusLiteral - - -class ApiGroupProfileType(TypedDict): - bio: str - url: str - institution: str - school: str - email: str - location: str - - -# 在项目信息和用户信息的返回结果中,该类型的字段含义不同,注意区分 -class ApiGroupType(TypedDict): - name: str # 组织名称 (用于user.teams) - username: str - comment: str - type: ApiGroupLiteral - status: ApiStatusLiteral - profile: ApiGroupProfileType diff --git a/swanlab/api/typings/workspace.py b/swanlab/api/typings/workspace.py index fff0ffc4c..4d3edefcc 100644 --- a/swanlab/api/typings/workspace.py +++ b/swanlab/api/typings/workspace.py @@ -9,11 +9,22 @@ from .common import ApiRoleLiteral, ApiWorkspaceLiteral +# 工作空间即 Group 组织 -class ApiWorkspaceInfoType(TypedDict): - name: str + +class ApiWorkspaceProfileType(TypedDict): + bio: str + url: str + institution: str + school: str + email: str + location: str + + +class ApiWorkspaceType(TypedDict): username: str - profile: Dict[str, str] + name: str type: ApiWorkspaceLiteral comment: str role: ApiRoleLiteral + profile: ApiWorkspaceProfileType diff --git a/swanlab/api/workspace.py b/swanlab/api/workspace.py index e1546e66e..e340a383c 100644 --- a/swanlab/api/workspace.py +++ b/swanlab/api/workspace.py @@ -5,11 +5,11 @@ @description: Workspace 实体类 — 工作空间的查询 """ -from typing import Any, Dict, Iterator, Optional, cast +from typing import Any, Dict, Iterator, List, Optional, cast from swanlab.api.base import ApiClientContext, BaseEntity -from swanlab.api.typings.workspace import ApiWorkspaceInfoType, ApiWorkspaceLiteral -from swanlab.api.utils import get_properties +from swanlab.api.typings.workspace import ApiWorkspaceLiteral, ApiWorkspaceProfileType, ApiWorkspaceType +from swanlab.api.utils import get_properties, strip_dict class Workspace(BaseEntity): @@ -22,16 +22,16 @@ def __init__( ctx: ApiClientContext, *, username: str, - data: Optional[ApiWorkspaceInfoType] = None, + data: Optional[ApiWorkspaceType] = None, ) -> None: super().__init__(ctx) self._username = username self._data = data - def _ensure_data(self) -> ApiWorkspaceInfoType: + def _ensure_data(self) -> ApiWorkspaceType: if self._data is None: resp = self._get(f"/group/{self._username}") - self._data = resp.data if resp.ok and resp.data else cast(ApiWorkspaceInfoType, {}) + self._data = resp.data if resp.ok and resp.data else cast(ApiWorkspaceType, {}) return self._data @property @@ -47,8 +47,8 @@ def workspace_type(self) -> ApiWorkspaceLiteral: return self._ensure_data().get("type", "PERSON") @property - def profile(self) -> Dict[str, str]: - return self._ensure_data().get("profile", {}) + def profile(self) -> Dict[str, Any]: + return strip_dict(self._ensure_data().get("profile", {}), ApiWorkspaceProfileType) @property def comment(self) -> str: @@ -92,20 +92,21 @@ class Workspaces(BaseEntity): def __init__(self, ctx: ApiClientContext, *, username: str) -> None: super().__init__(ctx) self._username = username + self._data: Optional[List[ApiWorkspaceType]] = None - def _get_all_workspace_names(self) -> list[str]: - """获取用户个人空间 + 所属团队空间名称列表。""" - resp = self._get(f"/user/{self._username}/groups") - if not resp.ok: - return [self._username] - group_names = resp.data if isinstance(resp.data, list) else [] - return [self._username] + group_names + def _ensure_data(self) -> List[ApiWorkspaceType]: + if self._data is None: + resp = self._get(f"/user/{self._username}/groups") + self._data = resp.data if resp.ok and resp.data else [] + assert self._data is not None + return self._data def __iter__(self) -> Iterator[Workspace]: - for name in self._get_all_workspace_names(): - resp = self._get(f"/group/{name}") - data = resp.data if resp.ok else None - yield Workspace(self._ctx, username=name, data=data) + for item in self._ensure_data(): + yield Workspace(self._ctx, username=item["username"], data=item) + + def __len__(self) -> int: + return len(self._ensure_data()) def json(self) -> Dict[str, Any]: return {"username": self._username} From d60f0c7f7ce1dcc5ee9b9945c39744df8a7a0bb0 Mon Sep 17 00:00:00 2001 From: Nexisato <978452096@qq.com> Date: Wed, 22 Apr 2026 21:03:09 +0800 Subject: [PATCH 17/52] fix: pagination param --- swanlab/api/__init__.py | 15 ++++++++++++--- swanlab/api/base.py | 24 +++++++++++++++++++----- swanlab/api/experiment.py | 16 +++++++++++++++- swanlab/api/project.py | 25 ++++++++++++++++++++----- swanlab/api/typings/project.py | 1 + 5 files changed, 67 insertions(+), 14 deletions(-) diff --git a/swanlab/api/__init__.py b/swanlab/api/__init__.py index 14668a0b5..55b49f28c 100644 --- a/swanlab/api/__init__.py +++ b/swanlab/api/__init__.py @@ -139,6 +139,9 @@ def projects( sort: Optional[str] = None, search: Optional[str] = None, detail: Optional[bool] = True, + page: int = 1, + size: int = 20, + all: bool = False, ) -> Projects: """ 获取工作空间下的项目列表迭代器。 @@ -147,8 +150,11 @@ def projects( :param sort: 排序方式 :param search: 搜索关键词 :param detail: 是否返回详细信息 + :param page: 起始页码,默认 1 + :param size: 每页数量,默认 20 + :param all: 是否获取全部数据,默认 False """ - return Projects(self._ctx, path=path, sort=sort, search=search, detail=detail) + return Projects(self._ctx, path=path, sort=sort, search=search, detail=detail, page=page, size=size, all=all) def run(self, path: str) -> Experiment: """ @@ -159,14 +165,17 @@ def run(self, path: str) -> Experiment: return Experiment(self._ctx, path=path) - def runs(self, path: str, filters: Optional[dict] = None) -> Experiments: + def runs( + self, path: str, filters: Optional[dict] = None, page: int = 1, size: int = 20, all: bool = False + ) -> Experiments: """ 获取项目下的实验列表迭代器。 :param path: 项目路径,格式为 'username/project' :param filters: 筛选条件 + :param all: 是否获取全部数据,默认 False """ - return Experiments(self._ctx, proj_path=path) + return Experiments(self._ctx, proj_path=path, filters=filters, page=page, size=size, all=all) def user(self) -> User: return User(self._ctx) diff --git a/swanlab/api/base.py b/swanlab/api/base.py index c11ffd4e9..3b1fce5eb 100644 --- a/swanlab/api/base.py +++ b/swanlab/api/base.py @@ -5,8 +5,6 @@ @description: 所有实体类的公共基类 """ -from __future__ import annotations - from abc import ABC, abstractmethod from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Callable, Dict, Iterator, Optional @@ -90,9 +88,18 @@ def _build_web_url(self, path: str) -> str: """构建前端 Web 页面 URL(使用 _web_host 而非 _api_host)。""" return f"{self._ctx.web_host}/{path}" - def _paginate(self, path: str, *, page_size: int = 20, params: Optional[dict] = None) -> Iterator[dict]: - """通用分页迭代器,自动处理 page/size 参数。""" - page = 1 + def _paginate( + self, + path: str, + *, + page_num: int = 1, + page_size: int = 20, + fetch_all: bool = False, + params: Optional[dict] = None, + page_info: Dict[str, Any], + ) -> Iterator[dict]: + """通用分页迭代器。fetch_all=False 时只取一页,True 时自动翻页取全部。""" + page = page_num while True: p = {"page": page, "size": page_size} if params: @@ -101,12 +108,19 @@ def _paginate(self, path: str, *, page_size: int = 20, params: Optional[dict] = if not resp.ok: return body: ApiPaginationType = resp.data + + if page_info["total"] == 0: + page_info["total"] = body.get("total", 0) + if page_info["pages"] == 0: + page_info["pages"] = body.get("pages", 1) items = body.get("list", []) if not items: break yield from items if page >= body.get("pages", 1): break + if not fetch_all: + break page += 1 def __repr__(self) -> str: diff --git a/swanlab/api/experiment.py b/swanlab/api/experiment.py index a52425932..41e05e837 100644 --- a/swanlab/api/experiment.py +++ b/swanlab/api/experiment.py @@ -254,10 +254,15 @@ def __init__( *, proj_path: str, filters: Optional[Dict[str, object]] = None, + page: int = 1, + size: int = 20, + all: bool = False, ) -> None: super().__init__(ctx) self._proj_path = proj_path self._filters = filters + self._all = all + self._page_info: Dict[str, Any] = {"page": page, "size": size, "total": 0, "pages": 0, "list": []} def __iter__(self) -> Iterator[Experiment]: parsed_filters = [parse_filter(k, v) for k, v in self._filters.items()] if self._filters else [] @@ -271,10 +276,19 @@ def __iter__(self) -> Iterator[Experiment]: elif isinstance(body, dict): runs = _flatten_runs(body) + total = len(runs) + self._page_info.update({"total": total, "page": 1, "size": total}) + for run_data in runs: cuid = run_data.get("cuid", "") full_path = f"{self._proj_path}/{cuid}" yield Experiment(self._ctx, path=full_path, data=run_data) def json(self) -> Dict[str, Any]: - return {"path": self._proj_path} + info = { + "total": self._page_info.get("total", 0), + "page": self._page_info.get("page", 1), + "size": self._page_info.get("size", 20), + } + info["list"] = [r.json() for r in self] + return info diff --git a/swanlab/api/project.py b/swanlab/api/project.py index 3fc8ab57d..dcff78a3c 100644 --- a/swanlab/api/project.py +++ b/swanlab/api/project.py @@ -54,7 +54,7 @@ def description(self) -> str: @property def visibility(self) -> str: - return self._ensure_data().get("visibility", "PUBLIC") + return self._ensure_data().get("visibility", "PRIVATE") @property def created_at(self) -> str: @@ -72,11 +72,11 @@ def labels(self) -> List[ApiProjectLabelType]: def count(self) -> ApiProjectCountType: return self._ensure_data().get("_count", {}) - def runs(self, filters: Optional[Dict[str, object]] = None): + def runs(self, filters: Optional[Dict[str, object]] = None, all: bool = False): """获取项目下的实验列表。""" from swanlab.api.experiment import Experiments - return Experiments(self._ctx, proj_path=self.path, filters=filters) + return Experiments(self._ctx, proj_path=self.path, filters=filters, all=all) def delete(self) -> bool: """删除此项目。""" @@ -105,16 +105,30 @@ def __init__( sort: Optional[str] = None, search: Optional[str] = None, detail: Optional[bool] = True, + page: int = 1, + size: int = 20, + all: bool = False, ) -> None: super().__init__(ctx) self._path = path self._sort = sort self._search = search self._detail = detail + self._page = page + self._size = size + self._all = all + self._page_info: Dict[str, Any] = {"page": page, "size": size, "total": 0, "pages": 0, "list": []} def __iter__(self) -> Iterator[Project]: params = {"sort": self._sort, "search": self._search, "detail": self._detail} - for item in self._paginate(f"/project/{self._path}", params=params): + for item in self._paginate( + f"/project/{self._path}", + page_num=self._page, + page_size=self._size, + fetch_all=self._all, + params=params, + page_info=self._page_info, + ): yield Project( self._ctx, path=str(item.get("path", "")), @@ -122,4 +136,5 @@ def __iter__(self) -> Iterator[Project]: ) def json(self) -> Dict[str, Any]: - return {"path": self._path} + self._page_info["list"] = [p.json() for p in self] + return self._page_info diff --git a/swanlab/api/typings/project.py b/swanlab/api/typings/project.py index 5b759d2a5..76477f31b 100644 --- a/swanlab/api/typings/project.py +++ b/swanlab/api/typings/project.py @@ -8,6 +8,7 @@ from typing import Dict, List, TypedDict from .common import ApiVisibilityLiteral +from .workspace import ApiWorkspaceType class ApiProjectLabelType(TypedDict): From 4f834e807677077b96608f28fcdf7c356ca3e168 Mon Sep 17 00:00:00 2001 From: Nexisato <978452096@qq.com> Date: Wed, 22 Apr 2026 21:12:05 +0800 Subject: [PATCH 18/52] chore: update orjson dependency --- pyproject.toml | 2 + swanlab/cli/api/helper.py | 10 +- uv.lock | 199 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 206 insertions(+), 5 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index c2c37e3d9..2d92bc13d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,6 +48,8 @@ dependencies = [ "protobuf>=3.19.0,!=4.21.0,!=5.28.0,<7; sys_platform != 'linux'", "rich>=13.6.0,<14.0.0", "pydantic-settings>=2.8.1", + "orjson<=3.11.5; python_version == '3.9'", + "orjson; python_version > '3.9'", ] # 对应 requirements-media.txt 和 requirements-dashboard.txt [cite: 3, 5] diff --git a/swanlab/cli/api/helper.py b/swanlab/cli/api/helper.py index b7dabab30..19f08fd3a 100644 --- a/swanlab/cli/api/helper.py +++ b/swanlab/cli/api/helper.py @@ -1,17 +1,17 @@ import functools -import json from datetime import datetime import click import nanoid +import orjson from swanlab.api.typings.common import ApiResponseType -def _save_json(content: str) -> None: +def _save_json(content: bytes) -> None: """将 JSON 内容保存到当前目录。""" filename = f"swanlab-{datetime.now().strftime('%Y%m%d_%H%M%S')}-{nanoid.generate(size=4)}.json" - with open(filename, "w", encoding="utf-8") as f: + with open(filename, "wb") as f: f.write(content) click.echo(f"Saved to {filename}") @@ -19,9 +19,9 @@ def _save_json(content: str) -> None: def format_output(resp: ApiResponseType, save: bool = False) -> None: """统一输出 ApiResponseType JSON,可选保存到文件。""" data = resp.json() - click.echo(json.dumps(data, ensure_ascii=False)) + click.echo(orjson.dumps(data, option=orjson.OPT_INDENT_2).decode()) if save and resp.ok: - _save_json(json.dumps(data, ensure_ascii=False, indent=2)) + _save_json(orjson.dumps(data, option=orjson.OPT_INDENT_2)) def with_save_option(f): diff --git a/uv.lock b/uv.lock index 745120723..ed4c9363d 100644 --- a/uv.lock +++ b/uv.lock @@ -2790,6 +2790,201 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/e9/a5/1be1516390333ff9be3a9cb648c9f33df79d5096e5884b5df71a588af463/opencv_python-4.13.0.92-cp37-abi3-win_amd64.whl", hash = "sha256:423d934c9fafb91aad38edf26efb46da91ffbc05f3f59c4b0c72e699720706f5", size = 40212062, upload-time = "2026-02-05T07:02:12.724Z" }, ] +[[package]] +name = "orjson" +version = "3.11.5" +source = { registry = "https://pypi.org/simple" } +resolution-markers = [ + "python_full_version < '3.10' and sys_platform == 'linux'", + "python_full_version < '3.10' and sys_platform != 'linux'", +] +sdist = { url = "https://files.pythonhosted.org/packages/04/b8/333fdb27840f3bf04022d21b654a35f58e15407183aeb16f3b41aa053446/orjson-3.11.5.tar.gz", hash = "sha256:82393ab47b4fe44ffd0a7659fa9cfaacc717eb617c93cde83795f14af5c2e9d5", size = 5972347, upload-time = "2025-12-06T15:55:39.458Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/79/19/b22cf9dad4db20c8737041046054cbd4f38bb5a2d0e4bb60487832ce3d76/orjson-3.11.5-cp310-cp310-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:df9eadb2a6386d5ea2bfd81309c505e125cfc9ba2b1b99a97e60985b0b3665d1", size = 245719, upload-time = "2025-12-06T15:53:43.877Z" }, + { url = "https://files.pythonhosted.org/packages/03/2e/b136dd6bf30ef5143fbe76a4c142828b55ccc618be490201e9073ad954a1/orjson-3.11.5-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ccc70da619744467d8f1f49a8cadae5ec7bbe054e5232d95f92ed8737f8c5870", size = 132467, upload-time = "2025-12-06T15:53:45.379Z" }, + { url = "https://files.pythonhosted.org/packages/ae/fc/ae99bfc1e1887d20a0268f0e2686eb5b13d0ea7bbe01de2b566febcd2130/orjson-3.11.5-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:073aab025294c2f6fc0807201c76fdaed86f8fc4be52c440fb78fbb759a1ac09", size = 130702, upload-time = "2025-12-06T15:53:46.659Z" }, + { url = "https://files.pythonhosted.org/packages/6e/43/ef7912144097765997170aca59249725c3ab8ef6079f93f9d708dd058df5/orjson-3.11.5-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:835f26fa24ba0bb8c53ae2a9328d1706135b74ec653ed933869b74b6909e63fd", size = 135907, upload-time = "2025-12-06T15:53:48.487Z" }, + { url = "https://files.pythonhosted.org/packages/3f/da/24d50e2d7f4092ddd4d784e37a3fa41f22ce8ed97abc9edd222901a96e74/orjson-3.11.5-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:667c132f1f3651c14522a119e4dd631fad98761fa960c55e8e7430bb2a1ba4ac", size = 139935, upload-time = "2025-12-06T15:53:49.88Z" }, + { url = "https://files.pythonhosted.org/packages/02/4a/b4cb6fcbfff5b95a3a019a8648255a0fac9b221fbf6b6e72be8df2361feb/orjson-3.11.5-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:42e8961196af655bb5e63ce6c60d25e8798cd4dfbc04f4203457fa3869322c2e", size = 137541, upload-time = "2025-12-06T15:53:51.226Z" }, + { url = "https://files.pythonhosted.org/packages/a5/99/a11bd129f18c2377c27b2846a9d9be04acec981f770d711ba0aaea563984/orjson-3.11.5-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:75412ca06e20904c19170f8a24486c4e6c7887dea591ba18a1ab572f1300ee9f", size = 139031, upload-time = "2025-12-06T15:53:52.309Z" }, + { url = "https://files.pythonhosted.org/packages/64/29/d7b77d7911574733a036bb3e8ad7053ceb2b7d6ea42208b9dbc55b23b9ed/orjson-3.11.5-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:6af8680328c69e15324b5af3ae38abbfcf9cbec37b5346ebfd52339c3d7e8a18", size = 141622, upload-time = "2025-12-06T15:53:53.606Z" }, + { url = "https://files.pythonhosted.org/packages/93/41/332db96c1de76b2feda4f453e91c27202cd092835936ce2b70828212f726/orjson-3.11.5-cp310-cp310-musllinux_1_2_armv7l.whl", hash = "sha256:a86fe4ff4ea523eac8f4b57fdac319faf037d3c1be12405e6a7e86b3fbc4756a", size = 413800, upload-time = "2025-12-06T15:53:54.866Z" }, + { url = "https://files.pythonhosted.org/packages/76/e1/5a0d148dd1f89ad2f9651df67835b209ab7fcb1118658cf353425d7563e9/orjson-3.11.5-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:e607b49b1a106ee2086633167033afbd63f76f2999e9236f638b06b112b24ea7", size = 151198, upload-time = "2025-12-06T15:53:56.383Z" }, + { url = "https://files.pythonhosted.org/packages/0d/96/8db67430d317a01ae5cf7971914f6775affdcfe99f5bff9ef3da32492ecc/orjson-3.11.5-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:7339f41c244d0eea251637727f016b3d20050636695bc78345cce9029b189401", size = 141984, upload-time = "2025-12-06T15:53:57.746Z" }, + { url = "https://files.pythonhosted.org/packages/71/49/40d21e1aa1ac569e521069228bb29c9b5a350344ccf922a0227d93c2ed44/orjson-3.11.5-cp310-cp310-win32.whl", hash = "sha256:8be318da8413cdbbce77b8c5fac8d13f6eb0f0db41b30bb598631412619572e8", size = 135272, upload-time = "2025-12-06T15:53:59.769Z" }, + { url = "https://files.pythonhosted.org/packages/c4/7e/d0e31e78be0c100e08be64f48d2850b23bcb4d4c70d114f4e43b39f6895a/orjson-3.11.5-cp310-cp310-win_amd64.whl", hash = "sha256:b9f86d69ae822cabc2a0f6c099b43e8733dda788405cba2665595b7e8dd8d167", size = 133360, upload-time = "2025-12-06T15:54:01.25Z" }, + { url = "https://files.pythonhosted.org/packages/fd/68/6b3659daec3a81aed5ab47700adb1a577c76a5452d35b91c88efee89987f/orjson-3.11.5-cp311-cp311-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:9c8494625ad60a923af6b2b0bd74107146efe9b55099e20d7740d995f338fcd8", size = 245318, upload-time = "2025-12-06T15:54:02.355Z" }, + { url = "https://files.pythonhosted.org/packages/e9/00/92db122261425f61803ccf0830699ea5567439d966cbc35856fe711bfe6b/orjson-3.11.5-cp311-cp311-macosx_15_0_arm64.whl", hash = "sha256:7bb2ce0b82bc9fd1168a513ddae7a857994b780b2945a8c51db4ab1c4b751ebc", size = 129491, upload-time = "2025-12-06T15:54:03.877Z" }, + { url = "https://files.pythonhosted.org/packages/94/4f/ffdcb18356518809d944e1e1f77589845c278a1ebbb5a8297dfefcc4b4cb/orjson-3.11.5-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:67394d3becd50b954c4ecd24ac90b5051ee7c903d167459f93e77fc6f5b4c968", size = 132167, upload-time = "2025-12-06T15:54:04.944Z" }, + { url = "https://files.pythonhosted.org/packages/97/c6/0a8caff96f4503f4f7dd44e40e90f4d14acf80d3b7a97cb88747bb712d3e/orjson-3.11.5-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:298d2451f375e5f17b897794bcc3e7b821c0f32b4788b9bcae47ada24d7f3cf7", size = 130516, upload-time = "2025-12-06T15:54:06.274Z" }, + { url = "https://files.pythonhosted.org/packages/4d/63/43d4dc9bd9954bff7052f700fdb501067f6fb134a003ddcea2a0bb3854ed/orjson-3.11.5-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:aa5e4244063db8e1d87e0f54c3f7522f14b2dc937e65d5241ef0076a096409fd", size = 135695, upload-time = "2025-12-06T15:54:07.702Z" }, + { url = "https://files.pythonhosted.org/packages/87/6f/27e2e76d110919cb7fcb72b26166ee676480a701bcf8fc53ac5d0edce32f/orjson-3.11.5-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1db2088b490761976c1b2e956d5d4e6409f3732e9d79cfa69f876c5248d1baf9", size = 139664, upload-time = "2025-12-06T15:54:08.828Z" }, + { url = "https://files.pythonhosted.org/packages/d4/f8/5966153a5f1be49b5fbb8ca619a529fde7bc71aa0a376f2bb83fed248bcd/orjson-3.11.5-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c2ed66358f32c24e10ceea518e16eb3549e34f33a9d51f99ce23b0251776a1ef", size = 137289, upload-time = "2025-12-06T15:54:09.898Z" }, + { url = "https://files.pythonhosted.org/packages/a7/34/8acb12ff0299385c8bbcbb19fbe40030f23f15a6de57a9c587ebf71483fb/orjson-3.11.5-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c2021afda46c1ed64d74b555065dbd4c2558d510d8cec5ea6a53001b3e5e82a9", size = 138784, upload-time = "2025-12-06T15:54:11.022Z" }, + { url = "https://files.pythonhosted.org/packages/ee/27/910421ea6e34a527f73d8f4ee7bdffa48357ff79c7b8d6eb6f7b82dd1176/orjson-3.11.5-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:b42ffbed9128e547a1647a3e50bc88ab28ae9daa61713962e0d3dd35e820c125", size = 141322, upload-time = "2025-12-06T15:54:12.427Z" }, + { url = "https://files.pythonhosted.org/packages/87/a3/4b703edd1a05555d4bb1753d6ce44e1a05b7a6d7c164d5b332c795c63d70/orjson-3.11.5-cp311-cp311-musllinux_1_2_armv7l.whl", hash = "sha256:8d5f16195bb671a5dd3d1dbea758918bada8f6cc27de72bd64adfbd748770814", size = 413612, upload-time = "2025-12-06T15:54:13.858Z" }, + { url = "https://files.pythonhosted.org/packages/1b/36/034177f11d7eeea16d3d2c42a1883b0373978e08bc9dad387f5074c786d8/orjson-3.11.5-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:c0e5d9f7a0227df2927d343a6e3859bebf9208b427c79bd31949abcc2fa32fa5", size = 150993, upload-time = "2025-12-06T15:54:15.189Z" }, + { url = "https://files.pythonhosted.org/packages/44/2f/ea8b24ee046a50a7d141c0227c4496b1180b215e728e3b640684f0ea448d/orjson-3.11.5-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:23d04c4543e78f724c4dfe656b3791b5f98e4c9253e13b2636f1af5d90e4a880", size = 141774, upload-time = "2025-12-06T15:54:16.451Z" }, + { url = "https://files.pythonhosted.org/packages/8a/12/cc440554bf8200eb23348a5744a575a342497b65261cd65ef3b28332510a/orjson-3.11.5-cp311-cp311-win32.whl", hash = "sha256:c404603df4865f8e0afe981aa3c4b62b406e6d06049564d58934860b62b7f91d", size = 135109, upload-time = "2025-12-06T15:54:17.73Z" }, + { url = "https://files.pythonhosted.org/packages/a3/83/e0c5aa06ba73a6760134b169f11fb970caa1525fa4461f94d76e692299d9/orjson-3.11.5-cp311-cp311-win_amd64.whl", hash = "sha256:9645ef655735a74da4990c24ffbd6894828fbfa117bc97c1edd98c282ecb52e1", size = 133193, upload-time = "2025-12-06T15:54:19.426Z" }, + { url = "https://files.pythonhosted.org/packages/cb/35/5b77eaebc60d735e832c5b1a20b155667645d123f09d471db0a78280fb49/orjson-3.11.5-cp311-cp311-win_arm64.whl", hash = "sha256:1cbf2735722623fcdee8e712cbaaab9e372bbcb0c7924ad711b261c2eccf4a5c", size = 126830, upload-time = "2025-12-06T15:54:20.836Z" }, + { url = "https://files.pythonhosted.org/packages/ef/a4/8052a029029b096a78955eadd68ab594ce2197e24ec50e6b6d2ab3f4e33b/orjson-3.11.5-cp312-cp312-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:334e5b4bff9ad101237c2d799d9fd45737752929753bf4faf4b207335a416b7d", size = 245347, upload-time = "2025-12-06T15:54:22.061Z" }, + { url = "https://files.pythonhosted.org/packages/64/67/574a7732bd9d9d79ac620c8790b4cfe0717a3d5a6eb2b539e6e8995e24a0/orjson-3.11.5-cp312-cp312-macosx_15_0_arm64.whl", hash = "sha256:ff770589960a86eae279f5d8aa536196ebda8273a2a07db2a54e82b93bc86626", size = 129435, upload-time = "2025-12-06T15:54:23.615Z" }, + { url = "https://files.pythonhosted.org/packages/52/8d/544e77d7a29d90cf4d9eecd0ae801c688e7f3d1adfa2ebae5e1e94d38ab9/orjson-3.11.5-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ed24250e55efbcb0b35bed7caaec8cedf858ab2f9f2201f17b8938c618c8ca6f", size = 132074, upload-time = "2025-12-06T15:54:24.694Z" }, + { url = "https://files.pythonhosted.org/packages/6e/57/b9f5b5b6fbff9c26f77e785baf56ae8460ef74acdb3eae4931c25b8f5ba9/orjson-3.11.5-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:a66d7769e98a08a12a139049aac2f0ca3adae989817f8c43337455fbc7669b85", size = 130520, upload-time = "2025-12-06T15:54:26.185Z" }, + { url = "https://files.pythonhosted.org/packages/f6/6d/d34970bf9eb33f9ec7c979a262cad86076814859e54eb9a059a52f6dc13d/orjson-3.11.5-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:86cfc555bfd5794d24c6a1903e558b50644e5e68e6471d66502ce5cb5fdef3f9", size = 136209, upload-time = "2025-12-06T15:54:27.264Z" }, + { url = "https://files.pythonhosted.org/packages/e7/39/bc373b63cc0e117a105ea12e57280f83ae52fdee426890d57412432d63b3/orjson-3.11.5-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a230065027bc2a025e944f9d4714976a81e7ecfa940923283bca7bbc1f10f626", size = 139837, upload-time = "2025-12-06T15:54:28.75Z" }, + { url = "https://files.pythonhosted.org/packages/cb/aa/7c4818c8d7d324da220f4f1af55c343956003aa4d1ce1857bdc1d396ba69/orjson-3.11.5-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b29d36b60e606df01959c4b982729c8845c69d1963f88686608be9ced96dbfaa", size = 137307, upload-time = "2025-12-06T15:54:29.856Z" }, + { url = "https://files.pythonhosted.org/packages/46/bf/0993b5a056759ba65145effe3a79dd5a939d4a070eaa5da2ee3180fbb13f/orjson-3.11.5-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c74099c6b230d4261fdc3169d50efc09abf38ace1a42ea2f9994b1d79153d477", size = 139020, upload-time = "2025-12-06T15:54:31.024Z" }, + { url = "https://files.pythonhosted.org/packages/65/e8/83a6c95db3039e504eda60fc388f9faedbb4f6472f5aba7084e06552d9aa/orjson-3.11.5-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:e697d06ad57dd0c7a737771d470eedc18e68dfdefcdd3b7de7f33dfda5b6212e", size = 141099, upload-time = "2025-12-06T15:54:32.196Z" }, + { url = "https://files.pythonhosted.org/packages/b9/b4/24fdc024abfce31c2f6812973b0a693688037ece5dc64b7a60c1ce69e2f2/orjson-3.11.5-cp312-cp312-musllinux_1_2_armv7l.whl", hash = "sha256:e08ca8a6c851e95aaecc32bc44a5aa75d0ad26af8cdac7c77e4ed93acf3d5b69", size = 413540, upload-time = "2025-12-06T15:54:33.361Z" }, + { url = "https://files.pythonhosted.org/packages/d9/37/01c0ec95d55ed0c11e4cae3e10427e479bba40c77312b63e1f9665e0737d/orjson-3.11.5-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:e8b5f96c05fce7d0218df3fdfeb962d6b8cfff7e3e20264306b46dd8b217c0f3", size = 151530, upload-time = "2025-12-06T15:54:34.6Z" }, + { url = "https://files.pythonhosted.org/packages/f9/d4/f9ebc57182705bb4bbe63f5bbe14af43722a2533135e1d2fb7affa0c355d/orjson-3.11.5-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:ddbfdb5099b3e6ba6d6ea818f61997bb66de14b411357d24c4612cf1ebad08ca", size = 141863, upload-time = "2025-12-06T15:54:35.801Z" }, + { url = "https://files.pythonhosted.org/packages/0d/04/02102b8d19fdcb009d72d622bb5781e8f3fae1646bf3e18c53d1bc8115b5/orjson-3.11.5-cp312-cp312-win32.whl", hash = "sha256:9172578c4eb09dbfcf1657d43198de59b6cef4054de385365060ed50c458ac98", size = 135255, upload-time = "2025-12-06T15:54:37.209Z" }, + { url = "https://files.pythonhosted.org/packages/d4/fb/f05646c43d5450492cb387de5549f6de90a71001682c17882d9f66476af5/orjson-3.11.5-cp312-cp312-win_amd64.whl", hash = "sha256:2b91126e7b470ff2e75746f6f6ee32b9ab67b7a93c8ba1d15d3a0caaf16ec875", size = 133252, upload-time = "2025-12-06T15:54:38.401Z" }, + { url = "https://files.pythonhosted.org/packages/dc/a6/7b8c0b26ba18c793533ac1cd145e131e46fcf43952aa94c109b5b913c1f0/orjson-3.11.5-cp312-cp312-win_arm64.whl", hash = "sha256:acbc5fac7e06777555b0722b8ad5f574739e99ffe99467ed63da98f97f9ca0fe", size = 126777, upload-time = "2025-12-06T15:54:39.515Z" }, + { url = "https://files.pythonhosted.org/packages/10/43/61a77040ce59f1569edf38f0b9faadc90c8cf7e9bec2e0df51d0132c6bb7/orjson-3.11.5-cp313-cp313-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:3b01799262081a4c47c035dd77c1301d40f568f77cc7ec1bb7db5d63b0a01629", size = 245271, upload-time = "2025-12-06T15:54:40.878Z" }, + { url = "https://files.pythonhosted.org/packages/55/f9/0f79be617388227866d50edd2fd320cb8fb94dc1501184bb1620981a0aba/orjson-3.11.5-cp313-cp313-macosx_15_0_arm64.whl", hash = "sha256:61de247948108484779f57a9f406e4c84d636fa5a59e411e6352484985e8a7c3", size = 129422, upload-time = "2025-12-06T15:54:42.403Z" }, + { url = "https://files.pythonhosted.org/packages/77/42/f1bf1549b432d4a78bfa95735b79b5dac75b65b5bb815bba86ad406ead0a/orjson-3.11.5-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:894aea2e63d4f24a7f04a1908307c738d0dce992e9249e744b8f4e8dd9197f39", size = 132060, upload-time = "2025-12-06T15:54:43.531Z" }, + { url = "https://files.pythonhosted.org/packages/25/49/825aa6b929f1a6ed244c78acd7b22c1481fd7e5fda047dc8bf4c1a807eb6/orjson-3.11.5-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:ddc21521598dbe369d83d4d40338e23d4101dad21dae0e79fa20465dbace019f", size = 130391, upload-time = "2025-12-06T15:54:45.059Z" }, + { url = "https://files.pythonhosted.org/packages/42/ec/de55391858b49e16e1aa8f0bbbb7e5997b7345d8e984a2dec3746d13065b/orjson-3.11.5-cp313-cp313-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7cce16ae2f5fb2c53c3eafdd1706cb7b6530a67cc1c17abe8ec747f5cd7c0c51", size = 135964, upload-time = "2025-12-06T15:54:46.576Z" }, + { url = "https://files.pythonhosted.org/packages/1c/40/820bc63121d2d28818556a2d0a09384a9f0262407cf9fa305e091a8048df/orjson-3.11.5-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e46c762d9f0e1cfb4ccc8515de7f349abbc95b59cb5a2bd68df5973fdef913f8", size = 139817, upload-time = "2025-12-06T15:54:48.084Z" }, + { url = "https://files.pythonhosted.org/packages/09/c7/3a445ca9a84a0d59d26365fd8898ff52bdfcdcb825bcc6519830371d2364/orjson-3.11.5-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d7345c759276b798ccd6d77a87136029e71e66a8bbf2d2755cbdde1d82e78706", size = 137336, upload-time = "2025-12-06T15:54:49.426Z" }, + { url = "https://files.pythonhosted.org/packages/9a/b3/dc0d3771f2e5d1f13368f56b339c6782f955c6a20b50465a91acb79fe961/orjson-3.11.5-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:75bc2e59e6a2ac1dd28901d07115abdebc4563b5b07dd612bf64260a201b1c7f", size = 138993, upload-time = "2025-12-06T15:54:50.939Z" }, + { url = "https://files.pythonhosted.org/packages/d1/a2/65267e959de6abe23444659b6e19c888f242bf7725ff927e2292776f6b89/orjson-3.11.5-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:54aae9b654554c3b4edd61896b978568c6daa16af96fa4681c9b5babd469f863", size = 141070, upload-time = "2025-12-06T15:54:52.414Z" }, + { url = "https://files.pythonhosted.org/packages/63/c9/da44a321b288727a322c6ab17e1754195708786a04f4f9d2220a5076a649/orjson-3.11.5-cp313-cp313-musllinux_1_2_armv7l.whl", hash = "sha256:4bdd8d164a871c4ec773f9de0f6fe8769c2d6727879c37a9666ba4183b7f8228", size = 413505, upload-time = "2025-12-06T15:54:53.67Z" }, + { url = "https://files.pythonhosted.org/packages/7f/17/68dc14fa7000eefb3d4d6d7326a190c99bb65e319f02747ef3ebf2452f12/orjson-3.11.5-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:a261fef929bcf98a60713bf5e95ad067cea16ae345d9a35034e73c3990e927d2", size = 151342, upload-time = "2025-12-06T15:54:55.113Z" }, + { url = "https://files.pythonhosted.org/packages/c4/c5/ccee774b67225bed630a57478529fc026eda33d94fe4c0eac8fe58d4aa52/orjson-3.11.5-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:c028a394c766693c5c9909dec76b24f37e6a1b91999e8d0c0d5feecbe93c3e05", size = 141823, upload-time = "2025-12-06T15:54:56.331Z" }, + { url = "https://files.pythonhosted.org/packages/67/80/5d00e4155d0cd7390ae2087130637671da713959bb558db9bac5e6f6b042/orjson-3.11.5-cp313-cp313-win32.whl", hash = "sha256:2cc79aaad1dfabe1bd2d50ee09814a1253164b3da4c00a78c458d82d04b3bdef", size = 135236, upload-time = "2025-12-06T15:54:57.507Z" }, + { url = "https://files.pythonhosted.org/packages/95/fe/792cc06a84808dbdc20ac6eab6811c53091b42f8e51ecebf14b540e9cfe4/orjson-3.11.5-cp313-cp313-win_amd64.whl", hash = "sha256:ff7877d376add4e16b274e35a3f58b7f37b362abf4aa31863dadacdd20e3a583", size = 133167, upload-time = "2025-12-06T15:54:58.71Z" }, + { url = "https://files.pythonhosted.org/packages/46/2c/d158bd8b50e3b1cfdcf406a7e463f6ffe3f0d167b99634717acdaf5e299f/orjson-3.11.5-cp313-cp313-win_arm64.whl", hash = "sha256:59ac72ea775c88b163ba8d21b0177628bd015c5dd060647bbab6e22da3aad287", size = 126712, upload-time = "2025-12-06T15:54:59.892Z" }, + { url = "https://files.pythonhosted.org/packages/c2/60/77d7b839e317ead7bb225d55bb50f7ea75f47afc489c81199befc5435b50/orjson-3.11.5-cp314-cp314-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:e446a8ea0a4c366ceafc7d97067bfd55292969143b57e3c846d87fc701e797a0", size = 245252, upload-time = "2025-12-06T15:55:01.127Z" }, + { url = "https://files.pythonhosted.org/packages/f1/aa/d4639163b400f8044cef0fb9aa51b0337be0da3a27187a20d1166e742370/orjson-3.11.5-cp314-cp314-macosx_15_0_arm64.whl", hash = "sha256:53deb5addae9c22bbe3739298f5f2196afa881ea75944e7720681c7080909a81", size = 129419, upload-time = "2025-12-06T15:55:02.723Z" }, + { url = "https://files.pythonhosted.org/packages/30/94/9eabf94f2e11c671111139edf5ec410d2f21e6feee717804f7e8872d883f/orjson-3.11.5-cp314-cp314-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:82cd00d49d6063d2b8791da5d4f9d20539c5951f965e45ccf4e96d33505ce68f", size = 132050, upload-time = "2025-12-06T15:55:03.918Z" }, + { url = "https://files.pythonhosted.org/packages/3d/c8/ca10f5c5322f341ea9a9f1097e140be17a88f88d1cfdd29df522970d9744/orjson-3.11.5-cp314-cp314-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:3fd15f9fc8c203aeceff4fda211157fad114dde66e92e24097b3647a08f4ee9e", size = 130370, upload-time = "2025-12-06T15:55:05.173Z" }, + { url = "https://files.pythonhosted.org/packages/25/d4/e96824476d361ee2edd5c6290ceb8d7edf88d81148a6ce172fc00278ca7f/orjson-3.11.5-cp314-cp314-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9df95000fbe6777bf9820ae82ab7578e8662051bb5f83d71a28992f539d2cda7", size = 136012, upload-time = "2025-12-06T15:55:06.402Z" }, + { url = "https://files.pythonhosted.org/packages/85/8e/9bc3423308c425c588903f2d103cfcfe2539e07a25d6522900645a6f257f/orjson-3.11.5-cp314-cp314-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:92a8d676748fca47ade5bc3da7430ed7767afe51b2f8100e3cd65e151c0eaceb", size = 139809, upload-time = "2025-12-06T15:55:07.656Z" }, + { url = "https://files.pythonhosted.org/packages/e9/3c/b404e94e0b02a232b957c54643ce68d0268dacb67ac33ffdee24008c8b27/orjson-3.11.5-cp314-cp314-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:aa0f513be38b40234c77975e68805506cad5d57b3dfd8fe3baa7f4f4051e15b4", size = 137332, upload-time = "2025-12-06T15:55:08.961Z" }, + { url = "https://files.pythonhosted.org/packages/51/30/cc2d69d5ce0ad9b84811cdf4a0cd5362ac27205a921da524ff42f26d65e0/orjson-3.11.5-cp314-cp314-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fa1863e75b92891f553b7922ce4ee10ed06db061e104f2b7815de80cdcb135ad", size = 138983, upload-time = "2025-12-06T15:55:10.595Z" }, + { url = "https://files.pythonhosted.org/packages/0e/87/de3223944a3e297d4707d2fe3b1ffb71437550e165eaf0ca8bbe43ccbcb1/orjson-3.11.5-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:d4be86b58e9ea262617b8ca6251a2f0d63cc132a6da4b5fcc8e0a4128782c829", size = 141069, upload-time = "2025-12-06T15:55:11.832Z" }, + { url = "https://files.pythonhosted.org/packages/65/30/81d5087ae74be33bcae3ff2d80f5ccaa4a8fedc6d39bf65a427a95b8977f/orjson-3.11.5-cp314-cp314-musllinux_1_2_armv7l.whl", hash = "sha256:b923c1c13fa02084eb38c9c065afd860a5cff58026813319a06949c3af5732ac", size = 413491, upload-time = "2025-12-06T15:55:13.314Z" }, + { url = "https://files.pythonhosted.org/packages/d0/6f/f6058c21e2fc1efaf918986dbc2da5cd38044f1a2d4b7b91ad17c4acf786/orjson-3.11.5-cp314-cp314-musllinux_1_2_i686.whl", hash = "sha256:1b6bd351202b2cd987f35a13b5e16471cf4d952b42a73c391cc537974c43ef6d", size = 151375, upload-time = "2025-12-06T15:55:14.715Z" }, + { url = "https://files.pythonhosted.org/packages/54/92/c6921f17d45e110892899a7a563a925b2273d929959ce2ad89e2525b885b/orjson-3.11.5-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:bb150d529637d541e6af06bbe3d02f5498d628b7f98267ff87647584293ab439", size = 141850, upload-time = "2025-12-06T15:55:15.94Z" }, + { url = "https://files.pythonhosted.org/packages/88/86/cdecb0140a05e1a477b81f24739da93b25070ee01ce7f7242f44a6437594/orjson-3.11.5-cp314-cp314-win32.whl", hash = "sha256:9cc1e55c884921434a84a0c3dd2699eb9f92e7b441d7f53f3941079ec6ce7499", size = 135278, upload-time = "2025-12-06T15:55:17.202Z" }, + { url = "https://files.pythonhosted.org/packages/e4/97/b638d69b1e947d24f6109216997e38922d54dcdcdb1b11c18d7efd2d3c59/orjson-3.11.5-cp314-cp314-win_amd64.whl", hash = "sha256:a4f3cb2d874e03bc7767c8f88adaa1a9a05cecea3712649c3b58589ec7317310", size = 133170, upload-time = "2025-12-06T15:55:18.468Z" }, + { url = "https://files.pythonhosted.org/packages/8f/dd/f4fff4a6fe601b4f8f3ba3aa6da8ac33d17d124491a3b804c662a70e1636/orjson-3.11.5-cp314-cp314-win_arm64.whl", hash = "sha256:38b22f476c351f9a1c43e5b07d8b5a02eb24a6ab8e75f700f7d479d4568346a5", size = 126713, upload-time = "2025-12-06T15:55:19.738Z" }, + { url = "https://files.pythonhosted.org/packages/50/c7/7b682849dd4c9fb701a981669b964ea700516ecbd8e88f62aae07c6852bd/orjson-3.11.5-cp39-cp39-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:1b280e2d2d284a6713b0cfec7b08918ebe57df23e3f76b27586197afca3cb1e9", size = 245298, upload-time = "2025-12-06T15:55:20.984Z" }, + { url = "https://files.pythonhosted.org/packages/1b/3f/194355a9335707a15fdc79ddc670148987b43d04712dd26898a694539ce6/orjson-3.11.5-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3c8d8a112b274fae8c5f0f01954cb0480137072c271f3f4958127b010dfefaec", size = 132150, upload-time = "2025-12-06T15:55:22.364Z" }, + { url = "https://files.pythonhosted.org/packages/e9/08/d74b3a986d37e6c2e04b8821c62927620c9a1924bb49ea51519a87751b86/orjson-3.11.5-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:5f0a2ae6f09ac7bd47d2d5a5305c1d9ed08ac057cda55bb0a49fa506f0d2da00", size = 130490, upload-time = "2025-12-06T15:55:23.619Z" }, + { url = "https://files.pythonhosted.org/packages/b2/16/ebd04c38c1db01e493a68eee442efdffc505a43112eccd481e0146c6acc2/orjson-3.11.5-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c0d87bd1896faac0d10b4f849016db81a63e4ec5df38757ffae84d45ab38aa71", size = 135726, upload-time = "2025-12-06T15:55:24.912Z" }, + { url = "https://files.pythonhosted.org/packages/06/64/2ce4b2c09a099403081c37639c224bdcdfe401138bd66fed5c96d4f8dbd3/orjson-3.11.5-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:801a821e8e6099b8c459ac7540b3c32dba6013437c57fdcaec205b169754f38c", size = 139640, upload-time = "2025-12-06T15:55:26.535Z" }, + { url = "https://files.pythonhosted.org/packages/cd/e2/425796df8ee1d7cea3a7edf868920121dd09162859dbb76fffc9a5c37fd3/orjson-3.11.5-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:69a0f6ac618c98c74b7fbc8c0172ba86f9e01dbf9f62aa0b1776c2231a7bffe5", size = 137289, upload-time = "2025-12-06T15:55:27.78Z" }, + { url = "https://files.pythonhosted.org/packages/32/a2/88e482eb8e899a037dcc9eff85ef117a568e6ca1ffa1a2b2be3fcb51b7bb/orjson-3.11.5-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fea7339bdd22e6f1060c55ac31b6a755d86a5b2ad3657f2669ec243f8e3b2bdb", size = 138761, upload-time = "2025-12-06T15:55:29.388Z" }, + { url = "https://files.pythonhosted.org/packages/f1/fd/131dd6d32eeb74c513bfa487f434a2150811d0fbd9cb06689284f2f21b34/orjson-3.11.5-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:4dad582bc93cef8f26513e12771e76385a7e6187fd713157e971c784112aad56", size = 141357, upload-time = "2025-12-06T15:55:31.064Z" }, + { url = "https://files.pythonhosted.org/packages/7a/90/e4a0abbcca7b53e9098ac854f27f5ed9949c796f3c760bc04af997da0eb2/orjson-3.11.5-cp39-cp39-musllinux_1_2_armv7l.whl", hash = "sha256:0522003e9f7fba91982e83a97fec0708f5a714c96c4209db7104e6b9d132f111", size = 413638, upload-time = "2025-12-06T15:55:32.344Z" }, + { url = "https://files.pythonhosted.org/packages/d1/c2/df91e385514924120001ade9cd52d6295251023d3bfa2c0a01f38cfc485a/orjson-3.11.5-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:7403851e430a478440ecc1258bcbacbfbd8175f9ac1e39031a7121dd0de05ff8", size = 150972, upload-time = "2025-12-06T15:55:33.725Z" }, + { url = "https://files.pythonhosted.org/packages/a6/ff/c76cc5a30a4451191ff1b868a331ad1354433335277fc40931f5fc3cab9d/orjson-3.11.5-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:5f691263425d3177977c8d1dd896cde7b98d93cbf390b2544a090675e83a6a0a", size = 141729, upload-time = "2025-12-06T15:55:35.317Z" }, + { url = "https://files.pythonhosted.org/packages/27/c3/7830bf74389ea1eaab2b017d8b15d1cab2bb0737d9412dfa7fb8644f7d78/orjson-3.11.5-cp39-cp39-win32.whl", hash = "sha256:61026196a1c4b968e1b1e540563e277843082e9e97d78afa03eb89315af531f1", size = 135100, upload-time = "2025-12-06T15:55:36.57Z" }, + { url = "https://files.pythonhosted.org/packages/69/e6/babf31154e047e465bc194eb72d1326d7c52ad4d7f50bf92b02b3cacda5c/orjson-3.11.5-cp39-cp39-win_amd64.whl", hash = "sha256:09b94b947ac08586af635ef922d69dc9bc63321527a3a04647f4986a73f4bd30", size = 133189, upload-time = "2025-12-06T15:55:38.143Z" }, +] + +[[package]] +name = "orjson" +version = "3.11.8" +source = { registry = "https://pypi.org/simple" } +resolution-markers = [ + "python_full_version >= '3.14' and sys_platform == 'linux'", + "python_full_version >= '3.12' and python_full_version < '3.14' and sys_platform == 'linux'", + "python_full_version == '3.11.*' and sys_platform == 'linux'", + "python_full_version == '3.10.*' and sys_platform == 'linux'", + "python_full_version >= '3.14' and sys_platform == 'win32'", + "python_full_version >= '3.14' and sys_platform == 'emscripten'", + "python_full_version >= '3.14' and sys_platform != 'emscripten' and sys_platform != 'linux' and sys_platform != 'win32'", + "python_full_version >= '3.12' and python_full_version < '3.14' and sys_platform == 'win32'", + "python_full_version == '3.11.*' and sys_platform == 'win32'", + "python_full_version >= '3.12' and python_full_version < '3.14' and sys_platform == 'emscripten'", + "python_full_version == '3.11.*' and sys_platform == 'emscripten'", + "python_full_version >= '3.12' and python_full_version < '3.14' and sys_platform != 'emscripten' and sys_platform != 'linux' and sys_platform != 'win32'", + "python_full_version == '3.11.*' and sys_platform != 'emscripten' and sys_platform != 'linux' and sys_platform != 'win32'", + "python_full_version == '3.10.*' and sys_platform != 'linux'", +] +sdist = { url = "https://files.pythonhosted.org/packages/9d/1b/2024d06792d0779f9dbc51531b61c24f76c75b9f4ce05e6f3377a1814cea/orjson-3.11.8.tar.gz", hash = "sha256:96163d9cdc5a202703e9ad1b9ae757d5f0ca62f4fa0cc93d1f27b0e180cc404e", size = 5603832, upload-time = "2026-03-31T16:16:27.878Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/2f/90/5d81f61fe3e4270da80c71442864c091cee3003cc8984c75f413fe742a07/orjson-3.11.8-cp310-cp310-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:e6693ff90018600c72fd18d3d22fa438be26076cd3c823da5f63f7bab28c11cb", size = 229663, upload-time = "2026-03-31T16:14:30.708Z" }, + { url = "https://files.pythonhosted.org/packages/6c/ef/85e06b0eb11de6fb424120fd5788a07035bd4c5e6bb7841ae9972a0526d1/orjson-3.11.8-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:93de06bc920854552493c81f1f729fab7213b7db4b8195355db5fda02c7d1363", size = 132321, upload-time = "2026-03-31T16:14:32.317Z" }, + { url = "https://files.pythonhosted.org/packages/86/71/089338ee51b3132f050db0864a7df9bdd5e94c2a03820ab8a91e8f655618/orjson-3.11.8-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:fe0b8c83e0f36247fc9431ce5425a5d95f9b3a689133d494831bdbd6f0bceb13", size = 130658, upload-time = "2026-03-31T16:14:33.935Z" }, + { url = "https://files.pythonhosted.org/packages/10/0d/f39d8802345d0ad65f7fd4374b29b9b59f98656dc30f21ca5c773265b2f0/orjson-3.11.8-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:97d823831105c01f6c8029faf297633dbeb30271892bd430e9c24ceae3734744", size = 135708, upload-time = "2026-03-31T16:14:35.224Z" }, + { url = "https://files.pythonhosted.org/packages/ff/b5/40aae576b3473511696dcffea84fde638b2b64774eb4dcb8b2c262729f8a/orjson-3.11.8-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c60c0423f15abb6cf78f56dff00168a1b582f7a1c23f114036e2bfc697814d5f", size = 147047, upload-time = "2026-03-31T16:14:36.489Z" }, + { url = "https://files.pythonhosted.org/packages/7b/f0/778a84458d1fdaa634b2e572e51ce0b354232f580b2327e1f00a8d88c38c/orjson-3.11.8-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:01928d0476b216ad2201823b0a74000440360cef4fed1912d297b8d84718f277", size = 133072, upload-time = "2026-03-31T16:14:37.715Z" }, + { url = "https://files.pythonhosted.org/packages/bf/d3/1bbf2fc3ffcc4b829ade554b574af68cec898c9b5ad6420a923c75a073d3/orjson-3.11.8-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6a4a639049c44d36a6d1ae0f4a94b271605c745aee5647fa8ffaabcdc01b69a6", size = 133867, upload-time = "2026-03-31T16:14:39.356Z" }, + { url = "https://files.pythonhosted.org/packages/08/94/6413da22edc99a69a8d0c2e83bf42973b8aa94d83ef52a6d39ac85da00bc/orjson-3.11.8-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:3222adff1e1ff0dce93c16146b93063a7793de6c43d52309ae321234cdaf0f4d", size = 142268, upload-time = "2026-03-31T16:14:40.972Z" }, + { url = "https://files.pythonhosted.org/packages/4a/5f/aa5dbaa6136d7ba55f5461ac2e885efc6e6349424a428927fd46d68f4396/orjson-3.11.8-cp310-cp310-musllinux_1_2_armv7l.whl", hash = "sha256:3223665349bbfb68da234acd9846955b1a0808cbe5520ff634bf253a4407009b", size = 424008, upload-time = "2026-03-31T16:14:42.637Z" }, + { url = "https://files.pythonhosted.org/packages/fa/aa/2c1962d108c7fe5e27aa03a354b378caf56d8eafdef15fd83dec081ce45a/orjson-3.11.8-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:61c9d357a59465736022d5d9ba06687afb7611dfb581a9d2129b77a6fcf78e59", size = 147942, upload-time = "2026-03-31T16:14:44.256Z" }, + { url = "https://files.pythonhosted.org/packages/47/d1/65f404f4c47eb1b0b4476f03ec838cac0c4aa933920ff81e5dda4dee14e7/orjson-3.11.8-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:58fb9b17b4472c7b1dcf1a54583629e62e23779b2331052f09a9249edf81675b", size = 136640, upload-time = "2026-03-31T16:14:45.884Z" }, + { url = "https://files.pythonhosted.org/packages/90/5f/7b784aea98bdb125a2f2da7c27d6c2d2f6d943d96ef0278bae596d563f85/orjson-3.11.8-cp310-cp310-win32.whl", hash = "sha256:b43dc2a391981d36c42fa57747a49dae793ef1d2e43898b197925b5534abd10a", size = 132066, upload-time = "2026-03-31T16:14:47.397Z" }, + { url = "https://files.pythonhosted.org/packages/92/ec/2e284af8d6c9478df5ef938917743f61d68f4c70d17f1b6e82f7e3b8dba1/orjson-3.11.8-cp310-cp310-win_amd64.whl", hash = "sha256:c98121237fea2f679480765abd566f7713185897f35c9e6c2add7e3a9900eb61", size = 127609, upload-time = "2026-03-31T16:14:48.78Z" }, + { url = "https://files.pythonhosted.org/packages/67/41/5aa7fa3b0f4dc6b47dcafc3cea909299c37e40e9972feabc8b6a74e2730d/orjson-3.11.8-cp311-cp311-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:003646067cc48b7fcab2ae0c562491c9b5d2cbd43f1e5f16d98fd118c5522d34", size = 229229, upload-time = "2026-03-31T16:14:50.424Z" }, + { url = "https://files.pythonhosted.org/packages/0a/d7/57e7f2458e0a2c41694f39fc830030a13053a84f837a5b73423dca1f0938/orjson-3.11.8-cp311-cp311-macosx_15_0_arm64.whl", hash = "sha256:ed193ce51d77a3830cad399a529cd4ef029968761f43ddc549e1bc62b40d88f8", size = 128871, upload-time = "2026-03-31T16:14:51.888Z" }, + { url = "https://files.pythonhosted.org/packages/53/4a/e0fdb9430983e6c46e0299559275025075568aad5d21dd606faee3703924/orjson-3.11.8-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f30491bc4f862aa15744b9738517454f1e46e56c972a2be87d70d727d5b2a8f8", size = 132104, upload-time = "2026-03-31T16:14:53.142Z" }, + { url = "https://files.pythonhosted.org/packages/08/4a/2025a60ff3f5c8522060cda46612d9b1efa653de66ed2908591d8d82f22d/orjson-3.11.8-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:6eda5b8b6be91d3f26efb7dc6e5e68ee805bc5617f65a328587b35255f138bf4", size = 130483, upload-time = "2026-03-31T16:14:54.605Z" }, + { url = "https://files.pythonhosted.org/packages/2d/3c/b9cde05bdc7b2385c66014e0620627da638d3d04e4954416ab48c31196c5/orjson-3.11.8-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ee8db7bfb6fe03581bbab54d7c4124a6dd6a7f4273a38f7267197890f094675f", size = 135481, upload-time = "2026-03-31T16:14:55.901Z" }, + { url = "https://files.pythonhosted.org/packages/ff/f2/a8238e7734de7cb589fed319857a8025d509c89dc52fdcc88f39c6d03d5a/orjson-3.11.8-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5d8b5231de76c528a46b57010bbd83fb51e056aa0220a372fd5065e978406f1c", size = 146819, upload-time = "2026-03-31T16:14:57.548Z" }, + { url = "https://files.pythonhosted.org/packages/db/10/dbf1e2a3cafea673b1b4350e371877b759060d6018a998643b7040e5de48/orjson-3.11.8-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:58a4a208a6fbfdb7a7327b8f201c6014f189f721fd55d047cafc4157af1bc62a", size = 132846, upload-time = "2026-03-31T16:14:58.91Z" }, + { url = "https://files.pythonhosted.org/packages/f8/fc/55e667ec9c85694038fcff00573d221b085d50777368ee3d77f38668bf3c/orjson-3.11.8-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5f8952d6d2505c003e8f0224ff7858d341fa4e33fef82b91c4ff0ef070f2393c", size = 133580, upload-time = "2026-03-31T16:15:00.519Z" }, + { url = "https://files.pythonhosted.org/packages/7e/a6/c08c589a9aad0cb46c4831d17de212a2b6901f9d976814321ff8e69e8785/orjson-3.11.8-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:0022bb50f90da04b009ce32c512dc1885910daa7cb10b7b0cba4505b16db82a8", size = 142042, upload-time = "2026-03-31T16:15:01.906Z" }, + { url = "https://files.pythonhosted.org/packages/5c/cc/2f78ea241d52b717d2efc38878615fe80425bf2beb6e68c984dde257a766/orjson-3.11.8-cp311-cp311-musllinux_1_2_armv7l.whl", hash = "sha256:ff51f9d657d1afb6f410cb435792ce4e1fe427aab23d2fcd727a2876e21d4cb6", size = 423845, upload-time = "2026-03-31T16:15:03.703Z" }, + { url = "https://files.pythonhosted.org/packages/70/07/c17dcf05dd8045457538428a983bf1f1127928df5bf328cb24d2b7cddacb/orjson-3.11.8-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:6dbe9a97bdb4d8d9d5367b52a7c32549bba70b2739c58ef74a6964a6d05ae054", size = 147729, upload-time = "2026-03-31T16:15:05.203Z" }, + { url = "https://files.pythonhosted.org/packages/90/6c/0fb6e8a24e682e0958d71711ae6f39110e4b9cd8cab1357e2a89cb8e1951/orjson-3.11.8-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:a5c370674ebabe16c6ccac33ff80c62bf8a6e59439f5e9d40c1f5ab8fd2215b7", size = 136425, upload-time = "2026-03-31T16:15:07.052Z" }, + { url = "https://files.pythonhosted.org/packages/b2/35/4d3cc3a3d616035beb51b24a09bb872942dc452cf2df0c1d11ab35046d9f/orjson-3.11.8-cp311-cp311-win32.whl", hash = "sha256:0e32f7154299f42ae66f13488963269e5eccb8d588a65bc839ed986919fc9fac", size = 131870, upload-time = "2026-03-31T16:15:08.678Z" }, + { url = "https://files.pythonhosted.org/packages/13/26/9fe70f81d16b702f8c3a775e8731b50ad91d22dacd14c7599b60a0941cd1/orjson-3.11.8-cp311-cp311-win_amd64.whl", hash = "sha256:25e0c672a2e32348d2eb33057b41e754091f2835f87222e4675b796b92264f06", size = 127440, upload-time = "2026-03-31T16:15:09.994Z" }, + { url = "https://files.pythonhosted.org/packages/e8/c6/b038339f4145efd2859c1ca53097a52c0bb9cbdd24f947ebe146da1ad067/orjson-3.11.8-cp311-cp311-win_arm64.whl", hash = "sha256:9185589c1f2a944c17e26c9925dcdbc2df061cc4a145395c57f0c51f9b5dbfcd", size = 127399, upload-time = "2026-03-31T16:15:11.412Z" }, + { url = "https://files.pythonhosted.org/packages/01/f6/8d58b32ab32d9215973a1688aebd098252ee8af1766c0e4e36e7831f0295/orjson-3.11.8-cp312-cp312-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:1cd0b77e77c95758f8e1100139844e99f3ccc87e71e6fc8e1c027e55807c549f", size = 229233, upload-time = "2026-03-31T16:15:12.762Z" }, + { url = "https://files.pythonhosted.org/packages/a9/8b/2ffe35e71f6b92622e8ea4607bf33ecf7dfb51b3619dcfabfd36cbe2d0a5/orjson-3.11.8-cp312-cp312-macosx_15_0_arm64.whl", hash = "sha256:6a3d159d5ffa0e3961f353c4b036540996bf8b9697ccc38261c0eac1fd3347a6", size = 128772, upload-time = "2026-03-31T16:15:14.237Z" }, + { url = "https://files.pythonhosted.org/packages/27/d2/1f8682ae50d5c6897a563cb96bc106da8c9cb5b7b6e81a52e4cc086679b9/orjson-3.11.8-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:76070a76e9c5ae661e2d9848f216980d8d533e0f8143e6ed462807b242e3c5e8", size = 131946, upload-time = "2026-03-31T16:15:15.607Z" }, + { url = "https://files.pythonhosted.org/packages/52/4b/5500f76f0eece84226e0689cb48dcde081104c2fa6e2483d17ca13685ffb/orjson-3.11.8-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:54153d21520a71a4c82a0dbb4523e468941d549d221dc173de0f019678cf3813", size = 130368, upload-time = "2026-03-31T16:15:17.066Z" }, + { url = "https://files.pythonhosted.org/packages/da/4e/58b927e08fbe9840e6c920d9e299b051ea667463b1f39a56e668669f8508/orjson-3.11.8-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:469ac2125611b7c5741a0b3798cd9e5786cbad6345f9f400c77212be89563bec", size = 135540, upload-time = "2026-03-31T16:15:18.404Z" }, + { url = "https://files.pythonhosted.org/packages/56/7c/ba7cb871cba1bcd5cd02ee34f98d894c6cea96353ad87466e5aef2429c60/orjson-3.11.8-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:14778ffd0f6896aa613951a7fbf4690229aa7a543cb2bfbe9f358e08aafa9546", size = 146877, upload-time = "2026-03-31T16:15:19.833Z" }, + { url = "https://files.pythonhosted.org/packages/0b/5d/eb9c25fc1386696c6a342cd361c306452c75e0b55e86ad602dd4827a7fd7/orjson-3.11.8-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:ea56a955056a6d6c550cf18b3348656a9d9a4f02e2d0c02cabf3c73f1055d506", size = 132837, upload-time = "2026-03-31T16:15:21.282Z" }, + { url = "https://files.pythonhosted.org/packages/37/87/5ddeb7fc1fbd9004aeccab08426f34c81a5b4c25c7061281862b015fce2b/orjson-3.11.8-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:53a0f57e59a530d18a142f4d4ba6dfc708dc5fdedce45e98ff06b44930a2a48f", size = 133624, upload-time = "2026-03-31T16:15:22.641Z" }, + { url = "https://files.pythonhosted.org/packages/22/09/90048793db94ee4b2fcec4ac8e5ddb077367637d6650be896b3494b79bb7/orjson-3.11.8-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:9b48e274f8824567d74e2158199e269597edf00823a1b12b63d48462bbf5123e", size = 141904, upload-time = "2026-03-31T16:15:24.435Z" }, + { url = "https://files.pythonhosted.org/packages/c0/cf/eb284847487821a5d415e54149a6449ba9bfc5872ce63ab7be41b8ec401c/orjson-3.11.8-cp312-cp312-musllinux_1_2_armv7l.whl", hash = "sha256:3f262401086a3960586af06c054609365e98407151f5ea24a62893a40d80dbbb", size = 423742, upload-time = "2026-03-31T16:15:26.155Z" }, + { url = "https://files.pythonhosted.org/packages/44/09/e12423d327071c851c13e76936f144a96adacfc037394dec35ac3fc8d1e8/orjson-3.11.8-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:8e8c6218b614badf8e229b697865df4301afa74b791b6c9ade01d19a9953a942", size = 147806, upload-time = "2026-03-31T16:15:27.909Z" }, + { url = "https://files.pythonhosted.org/packages/b3/6d/37c2589ba864e582ffe7611643314785c6afb1f83c701654ef05daa8fcc7/orjson-3.11.8-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:093d489fa039ddade2db541097dbb484999fcc65fc2b0ff9819141e2ab364f25", size = 136485, upload-time = "2026-03-31T16:15:29.749Z" }, + { url = "https://files.pythonhosted.org/packages/be/c9/135194a02ab76b04ed9a10f68624b7ebd238bbe55548878b11ff15a0f352/orjson-3.11.8-cp312-cp312-win32.whl", hash = "sha256:e0950ed1bcb9893f4293fd5c5a7ee10934fbf82c4101c70be360db23ce24b7d2", size = 131966, upload-time = "2026-03-31T16:15:31.687Z" }, + { url = "https://files.pythonhosted.org/packages/ed/9a/9796f8fbe3cf30ce9cb696748dbb535e5c87be4bf4fe2e9ca498ef1fa8cf/orjson-3.11.8-cp312-cp312-win_amd64.whl", hash = "sha256:3cf17c141617b88ced4536b2135c552490f07799f6ad565948ea07bef0dcb9a6", size = 127441, upload-time = "2026-03-31T16:15:33.333Z" }, + { url = "https://files.pythonhosted.org/packages/cc/47/5aaf54524a7a4a0dd09dd778f3fa65dd2108290615b652e23d944152bc8e/orjson-3.11.8-cp312-cp312-win_arm64.whl", hash = "sha256:48854463b0572cc87dac7d981aa72ed8bf6deedc0511853dc76b8bbd5482d36d", size = 127364, upload-time = "2026-03-31T16:15:34.748Z" }, + { url = "https://files.pythonhosted.org/packages/66/7f/95fba509bb2305fab0073558f1e8c3a2ec4b2afe58ed9fcb7d3b8beafe94/orjson-3.11.8-cp313-cp313-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:3f23426851d98478c8970da5991f84784a76682213cd50eb73a1da56b95239dc", size = 229180, upload-time = "2026-03-31T16:15:36.426Z" }, + { url = "https://files.pythonhosted.org/packages/f6/9d/b237215c743ca073697d759b5503abd2cb8a0d7b9c9e21f524bcf176ab66/orjson-3.11.8-cp313-cp313-macosx_15_0_arm64.whl", hash = "sha256:ebaed4cef74a045b83e23537b52ef19a367c7e3f536751e355a2a394f8648559", size = 128754, upload-time = "2026-03-31T16:15:38.049Z" }, + { url = "https://files.pythonhosted.org/packages/42/3d/27d65b6d11e63f133781425f132807aef793ed25075fec686fc8e46dd528/orjson-3.11.8-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:97c8f5d3b62380b70c36ffacb2a356b7c6becec86099b177f73851ba095ef623", size = 131877, upload-time = "2026-03-31T16:15:39.484Z" }, + { url = "https://files.pythonhosted.org/packages/dd/cc/faee30cd8f00421999e40ef0eba7332e3a625ce91a58200a2f52c7fef235/orjson-3.11.8-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:436c4922968a619fb7fef1ccd4b8b3a76c13b67d607073914d675026e911a65c", size = 130361, upload-time = "2026-03-31T16:15:41.274Z" }, + { url = "https://files.pythonhosted.org/packages/5c/bb/a6c55896197f97b6d4b4e7c7fd77e7235517c34f5d6ad5aadd43c54c6d7c/orjson-3.11.8-cp313-cp313-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1ab359aff0436d80bfe8a23b46b5fea69f1e18aaf1760a709b4787f1318b317f", size = 135521, upload-time = "2026-03-31T16:15:42.758Z" }, + { url = "https://files.pythonhosted.org/packages/9c/7c/ca3a3525aa32ff636ebb1778e77e3587b016ab2edb1b618b36ba96f8f2c0/orjson-3.11.8-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f89b6d0b3a8d81e1929d3ab3d92bbc225688bd80a770c49432543928fe09ac55", size = 146862, upload-time = "2026-03-31T16:15:44.341Z" }, + { url = "https://files.pythonhosted.org/packages/3c/0c/18a9d7f18b5edd37344d1fd5be17e94dc652c67826ab749c6e5948a78112/orjson-3.11.8-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:29c009e7a2ca9ad0ed1376ce20dd692146a5d9fe4310848904b6b4fee5c5c137", size = 132847, upload-time = "2026-03-31T16:15:46.368Z" }, + { url = "https://files.pythonhosted.org/packages/23/91/7e722f352ad67ca573cee44de2a58fb810d0f4eb4e33276c6a557979fd8a/orjson-3.11.8-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:705b895b781b3e395c067129d8551655642dfe9437273211d5404e87ac752b53", size = 133637, upload-time = "2026-03-31T16:15:48.123Z" }, + { url = "https://files.pythonhosted.org/packages/af/04/32845ce13ac5bd1046ddb02ac9432ba856cc35f6d74dde95864fe0ad5523/orjson-3.11.8-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:88006eda83858a9fdf73985ce3804e885c2befb2f506c9a3723cdeb5a2880e3e", size = 141906, upload-time = "2026-03-31T16:15:49.626Z" }, + { url = "https://files.pythonhosted.org/packages/02/5e/c551387ddf2d7106d9039369862245c85738b828844d13b99ccb8d61fd06/orjson-3.11.8-cp313-cp313-musllinux_1_2_armv7l.whl", hash = "sha256:55120759e61309af7fcf9e961c6f6af3dde5921cdb3ee863ef63fd9db126cae6", size = 423722, upload-time = "2026-03-31T16:15:51.176Z" }, + { url = "https://files.pythonhosted.org/packages/00/a3/ecfe62434096f8a794d4976728cb59bcfc4a643977f21c2040545d37eb4c/orjson-3.11.8-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:98bdc6cb889d19bed01de46e67574a2eab61f5cc6b768ed50e8ac68e9d6ffab6", size = 147801, upload-time = "2026-03-31T16:15:52.939Z" }, + { url = "https://files.pythonhosted.org/packages/18/6d/0dce10b9f6643fdc59d99333871a38fa5a769d8e2fc34a18e5d2bfdee900/orjson-3.11.8-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:708c95f925a43ab9f34625e45dcdadf09ec8a6e7b664a938f2f8d5650f6c090b", size = 136460, upload-time = "2026-03-31T16:15:54.431Z" }, + { url = "https://files.pythonhosted.org/packages/01/d6/6dde4f31842d87099238f1f07b459d24edc1a774d20687187443ab044191/orjson-3.11.8-cp313-cp313-win32.whl", hash = "sha256:01c4e5a6695dc09098f2e6468a251bc4671c50922d4d745aff1a0a33a0cf5b8d", size = 131956, upload-time = "2026-03-31T16:15:56.081Z" }, + { url = "https://files.pythonhosted.org/packages/c1/f9/4e494a56e013db957fb77186b818b916d4695b8fa2aa612364974160e91b/orjson-3.11.8-cp313-cp313-win_amd64.whl", hash = "sha256:c154a35dd1330707450bb4d4e7dd1f17fa6f42267a40c1e8a1daa5e13719b4b8", size = 127410, upload-time = "2026-03-31T16:15:57.54Z" }, + { url = "https://files.pythonhosted.org/packages/57/7f/803203d00d6edb6e9e7eef421d4e1adbb5ea973e40b3533f3cfd9aeb374e/orjson-3.11.8-cp313-cp313-win_arm64.whl", hash = "sha256:4861bde57f4d253ab041e374f44023460e60e71efaa121f3c5f0ed457c3a701e", size = 127338, upload-time = "2026-03-31T16:15:59.106Z" }, + { url = "https://files.pythonhosted.org/packages/6d/35/b01910c3d6b85dc882442afe5060cbf719c7d1fc85749294beda23d17873/orjson-3.11.8-cp314-cp314-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:ec795530a73c269a55130498842aaa762e4a939f6ce481a7e986eeaa790e9da4", size = 229171, upload-time = "2026-03-31T16:16:00.651Z" }, + { url = "https://files.pythonhosted.org/packages/c2/56/c9ec97bd11240abef39b9e5d99a15462809c45f677420fd148a6c5e6295e/orjson-3.11.8-cp314-cp314-macosx_15_0_arm64.whl", hash = "sha256:c492a0e011c0f9066e9ceaa896fbc5b068c54d365fea5f3444b697ee01bc8625", size = 128746, upload-time = "2026-03-31T16:16:02.673Z" }, + { url = "https://files.pythonhosted.org/packages/3b/e4/66d4f30a90de45e2f0cbd9623588e8ae71eef7679dbe2ae954ed6d66a41f/orjson-3.11.8-cp314-cp314-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:883206d55b1bd5f5679ad5e6ddd3d1a5e3cac5190482927fdb8c78fb699193b5", size = 131867, upload-time = "2026-03-31T16:16:04.342Z" }, + { url = "https://files.pythonhosted.org/packages/19/30/2a645fc9286b928675e43fa2a3a16fb7b6764aa78cc719dc82141e00f30b/orjson-3.11.8-cp314-cp314-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:5774c1fdcc98b2259800b683b19599c133baeb11d60033e2095fd9d4667b82db", size = 124664, upload-time = "2026-03-31T16:16:05.837Z" }, + { url = "https://files.pythonhosted.org/packages/db/44/77b9a86d84a28d52ba3316d77737f6514e17118119ade3f91b639e859029/orjson-3.11.8-cp314-cp314-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8ac7381c83dd3d4a6347e6635950aa448f54e7b8406a27c7ecb4a37e9f1ae08b", size = 129701, upload-time = "2026-03-31T16:16:07.407Z" }, + { url = "https://files.pythonhosted.org/packages/b3/ea/eff3d9bfe47e9bc6969c9181c58d9f71237f923f9c86a2d2f490cd898c82/orjson-3.11.8-cp314-cp314-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:14439063aebcb92401c11afc68ee4e407258d2752e62d748b6942dad20d2a70d", size = 141202, upload-time = "2026-03-31T16:16:09.48Z" }, + { url = "https://files.pythonhosted.org/packages/52/c8/90d4b4c60c84d62068d0cf9e4d8f0a4e05e76971d133ac0c60d818d4db20/orjson-3.11.8-cp314-cp314-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:fa72e71977bff96567b0f500fc5bfd2fdf915f34052c782a4c6ebbdaa97aa858", size = 127194, upload-time = "2026-03-31T16:16:11.02Z" }, + { url = "https://files.pythonhosted.org/packages/8d/c7/ea9e08d1f0ba981adffb629811148b44774d935171e7b3d780ae43c4c254/orjson-3.11.8-cp314-cp314-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7679bc2f01bb0d219758f1a5f87bb7c8a81c0a186824a393b366876b4948e14f", size = 133639, upload-time = "2026-03-31T16:16:13.434Z" }, + { url = "https://files.pythonhosted.org/packages/6c/8c/ddbbfd6ba59453c8fc7fe1d0e5983895864e264c37481b2a791db635f046/orjson-3.11.8-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:14f7b8fcb35ef403b42fa5ecfa4ed032332a91f3dc7368fbce4184d59e1eae0d", size = 141914, upload-time = "2026-03-31T16:16:14.955Z" }, + { url = "https://files.pythonhosted.org/packages/4e/31/dbfbefec9df060d34ef4962cd0afcb6fa7a9ec65884cb78f04a7859526c3/orjson-3.11.8-cp314-cp314-musllinux_1_2_armv7l.whl", hash = "sha256:c2bdf7b2facc80b5e34f48a2d557727d5c5c57a8a450de122ae81fa26a81c1bc", size = 423800, upload-time = "2026-03-31T16:16:16.594Z" }, + { url = "https://files.pythonhosted.org/packages/87/cf/f74e9ae9803d4ab46b163494adba636c6d7ea955af5cc23b8aaa94cfd528/orjson-3.11.8-cp314-cp314-musllinux_1_2_i686.whl", hash = "sha256:ccd7ba1b0605813a0715171d39ec4c314cb97a9c85893c2c5c0c3a3729df38bf", size = 147837, upload-time = "2026-03-31T16:16:18.585Z" }, + { url = "https://files.pythonhosted.org/packages/64/e6/9214f017b5db85e84e68602792f742e5dc5249e963503d1b356bee611e01/orjson-3.11.8-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:cdbc8c9c02463fef4d3c53a9ba3336d05496ec8e1f1c53326a1e4acc11f5c600", size = 136441, upload-time = "2026-03-31T16:16:20.151Z" }, + { url = "https://files.pythonhosted.org/packages/24/dd/3590348818f58f837a75fb969b04cdf187ae197e14d60b5e5a794a38b79d/orjson-3.11.8-cp314-cp314-win32.whl", hash = "sha256:0b57f67710a8cd459e4e54eb96d5f77f3624eba0c661ba19a525807e42eccade", size = 131983, upload-time = "2026-03-31T16:16:21.823Z" }, + { url = "https://files.pythonhosted.org/packages/3f/0f/b6cb692116e05d058f31ceee819c70f097fa9167c82f67fabe7516289abc/orjson-3.11.8-cp314-cp314-win_amd64.whl", hash = "sha256:735e2262363dcbe05c35e3a8869898022af78f89dde9e256924dc02e99fe69ca", size = 127396, upload-time = "2026-03-31T16:16:23.685Z" }, + { url = "https://files.pythonhosted.org/packages/c0/d1/facb5b5051fabb0ef9d26c6544d87ef19a939a9a001198655d0d891062dd/orjson-3.11.8-cp314-cp314-win_arm64.whl", hash = "sha256:6ccdea2c213cf9f3d9490cbd5d427693c870753df41e6cb375bd79bcbafc8817", size = 127330, upload-time = "2026-03-31T16:16:25.496Z" }, +] + [[package]] name = "packaging" version = "26.0" @@ -4419,6 +4614,8 @@ dependencies = [ { name = "click", version = "8.1.8", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.10'" }, { name = "click", version = "8.3.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.10'" }, { name = "nvidia-ml-py" }, + { name = "orjson", version = "3.11.5", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.10'" }, + { name = "orjson", version = "3.11.8", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.10'" }, { name = "platformdirs", version = "4.4.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.10'" }, { name = "platformdirs", version = "4.9.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.10'" }, { name = "protobuf" }, @@ -4499,6 +4696,8 @@ requires-dist = [ { name = "moviepy", marker = "extra == 'media'" }, { name = "numpy", marker = "extra == 'media'" }, { name = "nvidia-ml-py" }, + { name = "orjson", marker = "python_full_version == '3.9.*'", specifier = "<=3.11.5" }, + { name = "orjson", marker = "python_full_version >= '3.10'" }, { name = "pillow", marker = "extra == 'media'" }, { name = "platformdirs", specifier = ">=4.2.0" }, { name = "protobuf", marker = "sys_platform != 'linux'", specifier = ">=3.19.0,!=4.21.0,!=5.28.0,<7" }, From 333ced5a5c47d5801b16372ad07c88fd00cc18b7 Mon Sep 17 00:00:00 2001 From: Nexisato <978452096@qq.com> Date: Thu, 23 Apr 2026 10:44:25 +0800 Subject: [PATCH 19/52] refactor: save format definition --- swanlab/cli/api/experiment.py | 20 ++++++++++--- swanlab/cli/api/helper.py | 53 +++++++++++++---------------------- swanlab/cli/api/project.py | 20 ++++++++++--- swanlab/cli/api/workspace.py | 20 ++++++++++--- 4 files changed, 67 insertions(+), 46 deletions(-) diff --git a/swanlab/cli/api/experiment.py b/swanlab/cli/api/experiment.py index 876b9e0c4..bcea5e83b 100644 --- a/swanlab/cli/api/experiment.py +++ b/swanlab/cli/api/experiment.py @@ -1,7 +1,8 @@ import click +import orjson from swanlab.api import Api -from swanlab.cli.api.helper import with_save_option +from swanlab.cli.api.helper import format_output, save_output @click.group("run") @@ -12,8 +13,19 @@ def experiment_cli(): @experiment_cli.command("info") @click.argument("path", required=True) -@with_save_option -def get_experiment(path: str): +@click.option( + "--save", + "-s", + "name", + is_flag=False, + flag_value=".", + default=None, + help="Save output as JSON to current directory.", +) +def get_experiment(path: str, name): """Get Experiment(Run) info by path (username/project/run_id).""" api = Api() - return api.run(path).wrapper() + resp = api.run(path).wrapper() + format_output(resp) + if resp.ok and name is not None: + save_output(orjson.dumps(resp.json(), option=orjson.OPT_INDENT_2), name=name) diff --git a/swanlab/cli/api/helper.py b/swanlab/cli/api/helper.py index 19f08fd3a..8b3b921fc 100644 --- a/swanlab/cli/api/helper.py +++ b/swanlab/cli/api/helper.py @@ -1,5 +1,6 @@ -import functools +import enum from datetime import datetime +from typing import Optional import click import nanoid @@ -8,40 +9,24 @@ from swanlab.api.typings.common import ApiResponseType -def _save_json(content: bytes) -> None: - """将 JSON 内容保存到当前目录。""" - filename = f"swanlab-{datetime.now().strftime('%Y%m%d_%H%M%S')}-{nanoid.generate(size=4)}.json" - with open(filename, "wb") as f: - f.write(content) - click.echo(f"Saved to {filename}") - - -def format_output(resp: ApiResponseType, save: bool = False) -> None: - """统一输出 ApiResponseType JSON,可选保存到文件。""" - data = resp.json() - click.echo(orjson.dumps(data, option=orjson.OPT_INDENT_2).decode()) - if save and resp.ok: - _save_json(orjson.dumps(data, option=orjson.OPT_INDENT_2)) +class _SaveFormatEnum(enum.Enum): + JSON = "json" -def with_save_option(f): - """ - 装饰器:为 CLI 命令添加 --save 选项并自动输出/保存响应。 +def format_output(resp: ApiResponseType, fmt: _SaveFormatEnum = _SaveFormatEnum.JSON) -> None: + if fmt == _SaveFormatEnum.JSON: + click.echo(orjson.dumps(resp.json(), option=orjson.OPT_INDENT_2).decode()) - 被装饰的函数应返回 ApiResponseType,装饰器负责 format_output 和可选的文件保存。 - """ - @click.option( - "--save", - "-s", - is_flag=True, - default=False, - help="Save output as JSON to current directory.", - ) - @functools.wraps(f) - def wrapper(*args, save: bool, **kwargs): - resp = f(*args, **kwargs) - if resp is not None: - format_output(resp, save=save) - - return wrapper +def save_output(content: bytes, name: Optional[str] = None, fmt: _SaveFormatEnum = _SaveFormatEnum.JSON) -> None: + if name and name != ".": + ext = name.rsplit(".", 1)[-1].lower() if "." in name else None + if ext and ext not in {f.value for f in _SaveFormatEnum}: + click.echo(f"Warning: unsupported file extension .{ext}, skipped saving.") + return + filename = name + else: + filename = f"swanlab-{datetime.now().strftime('%Y%m%d_%H%M%S')}-{nanoid.generate(size=4)}.{fmt.value}" + with open(filename, "wb") as f: + f.write(content) + click.echo(f"Saved to {filename}") diff --git a/swanlab/cli/api/project.py b/swanlab/cli/api/project.py index 66fdccdb5..86771321c 100644 --- a/swanlab/cli/api/project.py +++ b/swanlab/cli/api/project.py @@ -1,7 +1,8 @@ import click +import orjson from swanlab.api import Api -from swanlab.cli.api.helper import with_save_option +from swanlab.cli.api.helper import format_output, save_output @click.group("project") @@ -12,8 +13,19 @@ def project_cli(): @project_cli.command("info") @click.argument("path", required=True) -@with_save_option -def get_project(path: str): +@click.option( + "--save", + "-s", + "name", + is_flag=False, + flag_value=".", + default=None, + help="Save output as JSON to current directory.", +) +def get_project(path: str, name): """Get project info by path (username/project).""" api = Api() - return api.project(path).wrapper() + resp = api.project(path).wrapper() + format_output(resp) + if resp.ok and name is not None: + save_output(orjson.dumps(resp.json(), option=orjson.OPT_INDENT_2), name=name) diff --git a/swanlab/cli/api/workspace.py b/swanlab/cli/api/workspace.py index c53e52914..ab1723560 100644 --- a/swanlab/cli/api/workspace.py +++ b/swanlab/cli/api/workspace.py @@ -1,7 +1,8 @@ import click +import orjson from swanlab.api import Api -from swanlab.cli.api.helper import with_save_option +from swanlab.cli.api.helper import format_output, save_output @click.group("workspace") @@ -12,8 +13,19 @@ def workspace_cli(): @workspace_cli.command("info") @click.argument("username", required=True) -@with_save_option -def get_workspace(username: str): +@click.option( + "--save", + "-s", + "name", + is_flag=False, + flag_value=".", + default=None, + help="Save output as JSON to current directory.", +) +def get_workspace(username: str, name): """Get Workspace info.""" api = Api() - return api.workspace(username).wrapper() + resp = api.workspace(username).wrapper() + format_output(resp) + if resp.ok and name is not None: + save_output(orjson.dumps(resp.json(), option=orjson.OPT_INDENT_2), name=name) From 18e755355fcf1e441bda1e5bfc71e670a4cb98c5 Mon Sep 17 00:00:00 2001 From: Nexisato <978452096@qq.com> Date: Thu, 23 Apr 2026 11:05:34 +0800 Subject: [PATCH 20/52] fix: workspace projects --- swanlab/api/workspace.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/swanlab/api/workspace.py b/swanlab/api/workspace.py index e340a383c..c77adc454 100644 --- a/swanlab/api/workspace.py +++ b/swanlab/api/workspace.py @@ -63,8 +63,10 @@ def projects( sort: Optional[str] = None, search: Optional[str] = None, detail: Optional[bool] = True, + page: int = 1, + size: int = 20, + all: bool = False, ): - """获取工作空间下的项目列表。""" from swanlab.api.project import Projects return Projects( @@ -73,6 +75,9 @@ def projects( sort=sort, search=search, detail=detail, + page=page, + size=size, + all=all, ) def json(self) -> Dict[str, Any]: @@ -81,7 +86,7 @@ def json(self) -> Dict[str, Any]: class Workspaces(BaseEntity): """ - 用户工作空间集合的迭代器。 + 用户工作空间集合的分页迭代器。 用法:: From 3416aed30a9b0d0f70c66b5194ebd0b17c5d125a Mon Sep 17 00:00:00 2001 From: Nexisato <978452096@qq.com> Date: Thu, 23 Apr 2026 11:27:48 +0800 Subject: [PATCH 21/52] refactor: pagination query dataclass --- swanlab/api/__init__.py | 17 ++++++++--- swanlab/api/base.py | 20 ++++++------- swanlab/api/experiment.py | 15 ++++++---- swanlab/api/project.py | 39 +++++++++++++------------ swanlab/api/typings/__init__.py | 4 +-- swanlab/api/typings/common.py | 51 ++++++++++++++++++++++++++++++--- swanlab/api/workspace.py | 8 ++---- 7 files changed, 105 insertions(+), 49 deletions(-) diff --git a/swanlab/api/__init__.py b/swanlab/api/__init__.py index 55b49f28c..faf4b67a0 100644 --- a/swanlab/api/__init__.py +++ b/swanlab/api/__init__.py @@ -15,7 +15,7 @@ from .base import ApiClientContext, BaseEntity from .experiment import Experiment, Experiments from .project import Project, Projects -from .typings.common import ApiResponseType +from .typings.common import ApiResponseType, PaginatedQuery from .user import User from .workspace import Workspace, Workspaces @@ -154,7 +154,8 @@ def projects( :param size: 每页数量,默认 20 :param all: 是否获取全部数据,默认 False """ - return Projects(self._ctx, path=path, sort=sort, search=search, detail=detail, page=page, size=size, all=all) + query = PaginatedQuery(page=page, size=size, search=search, sort=sort, all=all) + return Projects(self._ctx, path=path, query=query, detail=detail) def run(self, path: str) -> Experiment: """ @@ -166,16 +167,24 @@ def run(self, path: str) -> Experiment: return Experiment(self._ctx, path=path) def runs( - self, path: str, filters: Optional[dict] = None, page: int = 1, size: int = 20, all: bool = False + self, + path: str, + filters: Optional[dict] = None, + page: int = 1, + size: int = 20, + all: bool = False, ) -> Experiments: """ 获取项目下的实验列表迭代器。 :param path: 项目路径,格式为 'username/project' :param filters: 筛选条件 + :param page: 起始页码,默认 1 + :param size: 每页数量,默认 20 :param all: 是否获取全部数据,默认 False """ - return Experiments(self._ctx, proj_path=path, filters=filters, page=page, size=size, all=all) + query = PaginatedQuery(page=page, size=size, all=all) + return Experiments(self._ctx, proj_path=path, filters=filters, query=query) def user(self) -> User: return User(self._ctx) diff --git a/swanlab/api/base.py b/swanlab/api/base.py index 3b1fce5eb..3fbab5627 100644 --- a/swanlab/api/base.py +++ b/swanlab/api/base.py @@ -11,7 +11,7 @@ from swanlab.sdk.internal.pkg import safe -from .typings.common import ApiPaginationType, ApiResponseType +from .typings.common import ApiPaginationType, ApiResponseType, PaginatedQuery if TYPE_CHECKING: from swanlab.sdk.internal.pkg.client import Client @@ -91,19 +91,17 @@ def _build_web_url(self, path: str) -> str: def _paginate( self, path: str, + query: PaginatedQuery, *, - page_num: int = 1, - page_size: int = 20, - fetch_all: bool = False, - params: Optional[dict] = None, page_info: Dict[str, Any], + extra: Optional[Dict[str, Any]] = None, ) -> Iterator[dict]: - """通用分页迭代器。fetch_all=False 时只取一页,True 时自动翻页取全部。""" - page = page_num + """通用分页迭代器,基于 PaginatedQuery 驱动翻页逻辑。""" + page = query.page while True: - p = {"page": page, "size": page_size} - if params: - p.update({k: v for k, v in params.items() if v is not None}) + p = query.to_params(**(extra or {})) + # 覆盖当前页码(翻页时自增) + p["page"] = page resp = self._get(path, params=p) if not resp.ok: return @@ -119,7 +117,7 @@ def _paginate( yield from items if page >= body.get("pages", 1): break - if not fetch_all: + if not query.all: break page += 1 diff --git a/swanlab/api/experiment.py b/swanlab/api/experiment.py index 41e05e837..189eae941 100644 --- a/swanlab/api/experiment.py +++ b/swanlab/api/experiment.py @@ -8,6 +8,7 @@ from typing import Any, Dict, Iterator, List, Optional, Tuple, Union, cast from swanlab.api.base import ApiClientContext, BaseEntity +from swanlab.api.typings.common import PaginatedQuery from swanlab.api.typings.experiment import ApiExperimentLabelType, ApiExperimentType from swanlab.api.typings.user import ApiUserType from swanlab.api.utils import get_properties, parse_filter @@ -254,15 +255,19 @@ def __init__( *, proj_path: str, filters: Optional[Dict[str, object]] = None, - page: int = 1, - size: int = 20, - all: bool = False, + query: Optional[PaginatedQuery] = None, ) -> None: super().__init__(ctx) self._proj_path = proj_path self._filters = filters - self._all = all - self._page_info: Dict[str, Any] = {"page": page, "size": size, "total": 0, "pages": 0, "list": []} + self._query = query or PaginatedQuery() + self._page_info: Dict[str, Any] = { + "page": self._query.page, + "size": self._query.size, + "total": 0, + "pages": 0, + "list": [], + } def __iter__(self) -> Iterator[Experiment]: parsed_filters = [parse_filter(k, v) for k, v in self._filters.items()] if self._filters else [] diff --git a/swanlab/api/project.py b/swanlab/api/project.py index dcff78a3c..d698807f5 100644 --- a/swanlab/api/project.py +++ b/swanlab/api/project.py @@ -8,6 +8,7 @@ from typing import Any, Dict, Iterator, List, Optional, cast from swanlab.api.base import ApiClientContext, BaseEntity +from swanlab.api.typings.common import PaginatedQuery from swanlab.api.typings.project import ApiProjectCountType, ApiProjectLabelType, ApiProjectType from swanlab.api.utils import get_properties @@ -72,11 +73,18 @@ def labels(self) -> List[ApiProjectLabelType]: def count(self) -> ApiProjectCountType: return self._ensure_data().get("_count", {}) - def runs(self, filters: Optional[Dict[str, object]] = None, all: bool = False): + def runs( + self, + filters: Optional[Dict[str, object]] = None, + page: int = 1, + size: int = 20, + all: bool = False, + ): """获取项目下的实验列表。""" from swanlab.api.experiment import Experiments - return Experiments(self._ctx, proj_path=self.path, filters=filters, all=all) + query = PaginatedQuery(page=page, size=size, all=all) + return Experiments(self._ctx, proj_path=self.path, filters=filters, query=query) def delete(self) -> bool: """删除此项目。""" @@ -102,32 +110,27 @@ def __init__( ctx: ApiClientContext, *, path: str, - sort: Optional[str] = None, - search: Optional[str] = None, + query: Optional[PaginatedQuery] = None, detail: Optional[bool] = True, - page: int = 1, - size: int = 20, - all: bool = False, ) -> None: super().__init__(ctx) self._path = path - self._sort = sort - self._search = search + self._query = query or PaginatedQuery() self._detail = detail - self._page = page - self._size = size - self._all = all - self._page_info: Dict[str, Any] = {"page": page, "size": size, "total": 0, "pages": 0, "list": []} + self._page_info: Dict[str, Any] = { + "page": self._query.page, + "size": self._query.size, + "total": 0, + "pages": 0, + "list": [], + } def __iter__(self) -> Iterator[Project]: - params = {"sort": self._sort, "search": self._search, "detail": self._detail} for item in self._paginate( f"/project/{self._path}", - page_num=self._page, - page_size=self._size, - fetch_all=self._all, - params=params, + self._query, page_info=self._page_info, + extra={"detail": self._detail}, ): yield Project( self._ctx, diff --git a/swanlab/api/typings/__init__.py b/swanlab/api/typings/__init__.py index 35fce418a..0876198eb 100644 --- a/swanlab/api/typings/__init__.py +++ b/swanlab/api/typings/__init__.py @@ -6,13 +6,13 @@ """ from .common import ( - ApiColumnLiteral, ApiIdentityLiteral, ApiLicensePlanLiteral, ApiPaginationType, ApiResponseType, ApiRoleLiteral, ApiRunStateLiteral, + ApiSidebarLiteral, ApiVisibilityLiteral, ApiWorkspaceLiteral, ) @@ -24,7 +24,7 @@ __all__ = [ # Literal Definition - "ApiColumnLiteral", + "ApiSidebarLiteral", "ApiRunStateLiteral", "ApiVisibilityLiteral", "ApiWorkspaceLiteral", diff --git a/swanlab/api/typings/common.py b/swanlab/api/typings/common.py index 87c65b6ec..8bd9ef0a6 100644 --- a/swanlab/api/typings/common.py +++ b/swanlab/api/typings/common.py @@ -5,13 +5,14 @@ @description: 公共查询 API 通用类型定义 """ -from typing import Any, Dict, List, Literal, TypedDict +from dataclasses import dataclass +from typing import Any, Dict, List, Literal, Optional, TypedDict # 启用/停用 ApiStatusLiteral = Literal["ENABLED", "DISABLED"] -# 列类型 -ApiColumnLiteral = Literal["SCALAR", "CONFIG", "STABLE"] +# 侧边列类型 +ApiSidebarLiteral = Literal["SCALAR", "CONFIG", "STABLE"] # 实验状态类型 ApiRunStateLiteral = Literal["RUNNING", "FINISHED", "CRASHED", "ABORTED", "OFFLINE"] @@ -31,6 +32,48 @@ # License 许可证类型 ApiLicensePlanLiteral = Literal["free", "commercial"] +# 排序规则 +ApiSortOrderLiteral = Literal["ASC", "DESC"] + + +# 后端允许的每页条数 +_VALID_PAGE_SIZES = (10, 12, 15, 20, 24, 27, 50, 100) + + +@dataclass(frozen=True) +class PaginatedQuery: + """ + 通用分页查询参数,与后端 pagination_query 对齐。 + + page: 当前页码,≥1 + size: 每页条数,必须为后端允许值之一 + search: 搜索关键词 + sort: 排序字段 + all: 是否拉取全部分页(客户端侧自动翻页) + """ + + page: int = 1 + size: int = 20 + search: Optional[str] = None + sort: Optional[ApiSortOrderLiteral] = None + all: bool = False + + def __post_init__(self) -> None: + if self.page < 1: + raise ValueError(f"page must be >= 1, got {self.page}") + if self.size not in _VALID_PAGE_SIZES: + raise ValueError(f"size must be one of {_VALID_PAGE_SIZES}, got {self.size}") + + def to_params(self, **extra: Optional[Any]) -> Dict[str, Any]: + """转换为查询参数字典,自动过滤 None 值。""" + params: Dict[str, Any] = {"page": self.page, "size": self.size} + if self.search is not None: + params["search"] = self.search + if self.sort is not None: + params["sort"] = self.sort + params.update({k: v for k, v in extra.items() if v is not None}) + return params + class ApiPaginationType(TypedDict): list: List @@ -73,4 +116,4 @@ def __repr__(self) -> str: return f"ApiResponse(ok=False, errmsg={self.errmsg!r})" -__all__ = ["ApiPaginationType", "ApiResponseType"] +__all__ = ["ApiPaginationType", "ApiResponseType", "PaginatedQuery"] diff --git a/swanlab/api/workspace.py b/swanlab/api/workspace.py index c77adc454..2a3583ef7 100644 --- a/swanlab/api/workspace.py +++ b/swanlab/api/workspace.py @@ -8,6 +8,7 @@ from typing import Any, Dict, Iterator, List, Optional, cast from swanlab.api.base import ApiClientContext, BaseEntity +from swanlab.api.typings.common import PaginatedQuery from swanlab.api.typings.workspace import ApiWorkspaceLiteral, ApiWorkspaceProfileType, ApiWorkspaceType from swanlab.api.utils import get_properties, strip_dict @@ -69,15 +70,12 @@ def projects( ): from swanlab.api.project import Projects + query = PaginatedQuery(page=page, size=size, search=search, sort=sort, all=all) return Projects( self._ctx, path=self.username, - sort=sort, - search=search, + query=query, detail=detail, - page=page, - size=size, - all=all, ) def json(self) -> Dict[str, Any]: From 0d6085d8089557930dd256ae2025561e5b869be5 Mon Sep 17 00:00:00 2001 From: Nexisato <978452096@qq.com> Date: Thu, 23 Apr 2026 13:05:44 +0800 Subject: [PATCH 22/52] feat: support pagination get runs --- swanlab/api/__init__.py | 17 ++++- swanlab/api/experiment.py | 101 +++++++++++++++--------------- swanlab/api/project.py | 22 ++++++- swanlab/api/typings/common.py | 8 +-- swanlab/api/typings/experiment.py | 15 ++++- 5 files changed, 99 insertions(+), 64 deletions(-) diff --git a/swanlab/api/__init__.py b/swanlab/api/__init__.py index faf4b67a0..c788afdf6 100644 --- a/swanlab/api/__init__.py +++ b/swanlab/api/__init__.py @@ -170,21 +170,32 @@ def runs( self, path: str, filters: Optional[dict] = None, + ) -> Experiments: + """ + 通过条件过滤获取项目下的实验列表。 + + :param path: 项目路径,格式为 'username/project' + :param filters: 筛选条件 + """ + return Experiments(self._ctx, path=path, filters=filters, mode="post") + + def runs_get( + self, + path: str, page: int = 1, size: int = 20, all: bool = False, ) -> Experiments: """ - 获取项目下的实验列表迭代器。 + 通过分页获取项目下的实验列表。 :param path: 项目路径,格式为 'username/project' - :param filters: 筛选条件 :param page: 起始页码,默认 1 :param size: 每页数量,默认 20 :param all: 是否获取全部数据,默认 False """ query = PaginatedQuery(page=page, size=size, all=all) - return Experiments(self._ctx, proj_path=path, filters=filters, query=query) + return Experiments(self._ctx, path=path, query=query, mode="get") def user(self) -> User: return User(self._ctx) diff --git a/swanlab/api/experiment.py b/swanlab/api/experiment.py index 189eae941..0ccc3f981 100644 --- a/swanlab/api/experiment.py +++ b/swanlab/api/experiment.py @@ -9,7 +9,7 @@ from swanlab.api.base import ApiClientContext, BaseEntity from swanlab.api.typings.common import PaginatedQuery -from swanlab.api.typings.experiment import ApiExperimentLabelType, ApiExperimentType +from swanlab.api.typings.experiment import ApiExperimentLabelType, ApiExperimentProfileType, ApiExperimentType from swanlab.api.typings.user import ApiUserType from swanlab.api.utils import get_properties, parse_filter @@ -28,46 +28,9 @@ def _resovle_path(path: str) -> Tuple[str, str]: ) -class Profile: - """Experiment profile containing config, metadata, requirements, and conda info.""" - - def __init__(self, data: Dict) -> None: - self._data = data - - @staticmethod - def _clean_field(value: Any) -> Any: - """Recursively clean config field, removing desc/sort and keeping value.""" - if isinstance(value, dict): - if "value" in value: - return Profile._clean_field(value["value"]) - else: - return {k: Profile._clean_field(v) for k, v in value.items()} - elif isinstance(value, list): - return [Profile._clean_field(item) for item in value] - return value - - @property - def config(self) -> Dict: - """Experiment configuration (cleaned, without desc/sort fields).""" - raw_config = self._data.get("config", {}) - return {k: Profile._clean_field(v) for k, v in raw_config.items()} if isinstance(raw_config, dict) else {} - - @property - def metadata(self) -> Dict: - return self._data.get("metadata", {}) - - @property - def requirements(self) -> str: - return self._data.get("requirements", "") - - @property - def conda(self) -> str: - return self._data.get("conda", "") - - class Experiment(BaseEntity): """ - 表示一个 SwanLab 实验。 + 表示一个 SwanLab 实验(完整信息,通过 POST /runs/shows 或单实验详情接口获取)。 支持双模式:构造时传入 data,或 data=None(按需懒加载)。 构造时从 data 中提取 _cuid 缓存,避免 _ensure_data 与 id 属性的循环调用。 @@ -106,6 +69,10 @@ def name(self) -> str: def description(self) -> str: return self._ensure_data().get("description", "") + @property + def type(self) -> str: + return self._ensure_data().get("type", "") + @property def state(self) -> str: return self._ensure_data().get("state", "") @@ -144,7 +111,7 @@ def finished_at(self) -> str: return self._ensure_data().get("finishedAt", "") @property - def profile(self) -> Profile: + def profile(self) -> ApiExperimentProfileType: """Experiment profile containing config, metadata, requirements, and conda.""" data = self._ensure_data() if "profile" not in data and self._cuid: @@ -152,7 +119,7 @@ def profile(self) -> Profile: if resp.ok and resp.data: self._data = resp.data data = self._data - return Profile(data.get("profile", {})) + return ApiExperimentProfileType(self._ensure_data().get("profile", {})) def metrics( self, keys: Optional[List[str]] = None, x_axis: Optional[str] = None, sample: Optional[int] = None @@ -243,24 +210,35 @@ class Experiments(BaseEntity): """ 项目下实验集合的迭代器。 + 支持两种模式: + - POST 模式(默认):通过 /runs/shows 接口获取,支持复杂过滤,不支持分页 + - GET 模式:通过 /runs 接口获取,支持标准分页,返回精简信息 + 用法:: - for run in api.runs("username/project"): + # POST 复杂过滤 + for run in api.runs(path="username/project"): print(run.name) + + # GET 分页 + for run in api.list_runs_simple(path="username/project"): + print(run.name, run.state) """ def __init__( self, ctx: ApiClientContext, *, - proj_path: str, + path: str, filters: Optional[Dict[str, object]] = None, query: Optional[PaginatedQuery] = None, + mode: str = "post", ) -> None: super().__init__(ctx) - self._proj_path = proj_path + self._proj_path = path self._filters = filters self._query = query or PaginatedQuery() + self._mode = mode self._page_info: Dict[str, Any] = { "page": self._query.page, "size": self._query.size, @@ -270,8 +248,17 @@ def __init__( } def __iter__(self) -> Iterator[Experiment]: + if self._mode == "get": + yield from self._iter_paginated() + else: + yield from self._iter_filtered() + + def _iter_filtered(self) -> Iterator[Experiment]: + """POST /runs/shows 模式:复杂过滤,不支持分页。""" parsed_filters = [parse_filter(k, v) for k, v in self._filters.items()] if self._filters else [] - resp = self._post(f"/project/{self._proj_path}/runs/shows", data={"filters": parsed_filters}) + resp = self._post( + f"/project/{self._proj_path}/runs/shows", data={"filters": parsed_filters, "groups": [], "shows": []} + ) if not resp.ok: return body = resp.data @@ -289,11 +276,21 @@ def __iter__(self) -> Iterator[Experiment]: full_path = f"{self._proj_path}/{cuid}" yield Experiment(self._ctx, path=full_path, data=run_data) + def _iter_paginated(self) -> Iterator[Experiment]: + """GET /runs 模式:标准分页,返回精简信息。""" + for item in self._paginate( + f"/project/{self._proj_path}/runs", + self._query, + page_info=self._page_info, + ): + cuid = item.get("cuid", "") + full_path = f"{self._proj_path}/{cuid}" + yield Experiment( + self._ctx, + path=full_path, + data=cast(ApiExperimentType, item), + ) + def json(self) -> Dict[str, Any]: - info = { - "total": self._page_info.get("total", 0), - "page": self._page_info.get("page", 1), - "size": self._page_info.get("size", 20), - } - info["list"] = [r.json() for r in self] - return info + self._page_info["list"] = [r.json() for r in self] + return self._page_info diff --git a/swanlab/api/project.py b/swanlab/api/project.py index d698807f5..7df11de1f 100644 --- a/swanlab/api/project.py +++ b/swanlab/api/project.py @@ -76,15 +76,33 @@ def count(self) -> ApiProjectCountType: def runs( self, filters: Optional[Dict[str, object]] = None, + ): + """ + 获取项目下的实验列表(POST 模式,支持复杂过滤)。 + + :param filters: 筛选条件 + """ + from swanlab.api.experiment import Experiments + + return Experiments(self._ctx, path=self.path, filters=filters, mode="post") + + def runs_get( + self, page: int = 1, size: int = 20, all: bool = False, ): - """获取项目下的实验列表。""" + """ + 获取项目下的实验列表(GET 模式,标准分页,返回精简信息)。 + + :param page: 起始页码,默认 1 + :param size: 每页数量,默认 20 + :param all: 是否获取全部数据,默认 False + """ from swanlab.api.experiment import Experiments query = PaginatedQuery(page=page, size=size, all=all) - return Experiments(self._ctx, proj_path=self.path, filters=filters, query=query) + return Experiments(self._ctx, path=self.path, query=query, mode="get") def delete(self) -> bool: """删除此项目。""" diff --git a/swanlab/api/typings/common.py b/swanlab/api/typings/common.py index 8bd9ef0a6..7f205263e 100644 --- a/swanlab/api/typings/common.py +++ b/swanlab/api/typings/common.py @@ -14,6 +14,9 @@ # 侧边列类型 ApiSidebarLiteral = Literal["SCALAR", "CONFIG", "STABLE"] +# 实验类型: 运行中/总览 +ApiExperimentTypeLiteral = Literal["CHAPTER", "SUMMARY"] + # 实验状态类型 ApiRunStateLiteral = Literal["RUNNING", "FINISHED", "CRASHED", "ABORTED", "OFFLINE"] @@ -32,9 +35,6 @@ # License 许可证类型 ApiLicensePlanLiteral = Literal["free", "commercial"] -# 排序规则 -ApiSortOrderLiteral = Literal["ASC", "DESC"] - # 后端允许的每页条数 _VALID_PAGE_SIZES = (10, 12, 15, 20, 24, 27, 50, 100) @@ -55,7 +55,7 @@ class PaginatedQuery: page: int = 1 size: int = 20 search: Optional[str] = None - sort: Optional[ApiSortOrderLiteral] = None + sort: Optional[str] = None all: bool = False def __post_init__(self) -> None: diff --git a/swanlab/api/typings/experiment.py b/swanlab/api/typings/experiment.py index 8bac27c1c..beb1769b7 100644 --- a/swanlab/api/typings/experiment.py +++ b/swanlab/api/typings/experiment.py @@ -5,9 +5,9 @@ @description: 公共查询 API 实验类型定义 """ -from typing import Dict, List, Optional, TypedDict +from typing import Any, Dict, List, Optional, TypedDict -from .common import ApiRunStateLiteral +from .common import ApiExperimentTypeLiteral, ApiRunStateLiteral from .user import ApiUserType @@ -15,12 +15,21 @@ class ApiExperimentLabelType(TypedDict): name: str +# 实验配置 +class ApiExperimentProfileType(TypedDict): + config: Dict[str, Any] + metadata: Dict[str, Any] + requirements: str + conda: str + + class ApiExperimentType(TypedDict): cuid: str name: str + type: ApiExperimentTypeLiteral description: str labels: List[ApiExperimentLabelType] - profile: Dict[str, object] + profile: ApiExperimentProfileType show: bool state: ApiRunStateLiteral cluster: str From cd71acac7afb964da32be35b03d75537cde1c32e Mon Sep 17 00:00:00 2001 From: Nexisato <978452096@qq.com> Date: Thu, 23 Apr 2026 15:14:23 +0800 Subject: [PATCH 23/52] feat: support filter runs --- swanlab/api/__init__.py | 12 ++- swanlab/api/experiment.py | 22 ++++-- swanlab/api/project.py | 10 ++- swanlab/api/typings/common.py | 3 + swanlab/api/typings/experiment.py | 113 +++++++++++++++++++++++++++- swanlab/api/utils.py | 120 ++++++++++++++++++------------ 6 files changed, 216 insertions(+), 64 deletions(-) diff --git a/swanlab/api/__init__.py b/swanlab/api/__init__.py index c788afdf6..005d238c0 100644 --- a/swanlab/api/__init__.py +++ b/swanlab/api/__init__.py @@ -5,7 +5,7 @@ @description: SwanLab 公共查询 API 入口,面向用户的 OOP 查询接口 """ -from typing import Optional +from typing import Any, Dict, List, Optional from swanlab.exceptions import AuthenticationError from swanlab.sdk.internal.pkg import nrc, scope @@ -169,15 +169,19 @@ def run(self, path: str) -> Experiment: def runs( self, path: str, - filters: Optional[dict] = None, + filters: Optional[List[Dict[str, Any]]] = None, + groups: Optional[List[Dict[str, Any]]] = None, + sorts: Optional[List[Dict[str, Any]]] = None, ) -> Experiments: """ 通过条件过滤获取项目下的实验列表。 :param path: 项目路径,格式为 'username/project' - :param filters: 筛选条件 + :param filters: 过滤规则列表,每项为 {key, type, op, value} + :param groups: 分组规则列表,每项为 {key, type} + :param sorts: 排序规则列表,每项为 {key, type, order} """ - return Experiments(self._ctx, path=path, filters=filters, mode="post") + return Experiments(self._ctx, path=path, filters=filters, groups=groups, sorts=sorts, mode="post") def runs_get( self, diff --git a/swanlab/api/experiment.py b/swanlab/api/experiment.py index 0ccc3f981..adef0a37d 100644 --- a/swanlab/api/experiment.py +++ b/swanlab/api/experiment.py @@ -9,9 +9,13 @@ from swanlab.api.base import ApiClientContext, BaseEntity from swanlab.api.typings.common import PaginatedQuery -from swanlab.api.typings.experiment import ApiExperimentLabelType, ApiExperimentProfileType, ApiExperimentType +from swanlab.api.typings.experiment import ( + ApiExperimentLabelType, + ApiExperimentProfileType, + ApiExperimentType, +) from swanlab.api.typings.user import ApiUserType -from swanlab.api.utils import get_properties, parse_filter +from swanlab.api.utils import _validate_and_build, get_properties, validate_filter, validate_group, validate_sort def _resovle_path(path: str) -> Tuple[str, str]: @@ -230,13 +234,17 @@ def __init__( ctx: ApiClientContext, *, path: str, - filters: Optional[Dict[str, object]] = None, + filters: Optional[List[Dict[str, Any]]] = None, + groups: Optional[List[Dict[str, Any]]] = None, + sorts: Optional[List[Dict[str, Any]]] = None, query: Optional[PaginatedQuery] = None, mode: str = "post", ) -> None: super().__init__(ctx) self._proj_path = path self._filters = filters + self._groups = groups + self._sorts = sorts self._query = query or PaginatedQuery() self._mode = mode self._page_info: Dict[str, Any] = { @@ -255,9 +263,13 @@ def __iter__(self) -> Iterator[Experiment]: def _iter_filtered(self) -> Iterator[Experiment]: """POST /runs/shows 模式:复杂过滤,不支持分页。""" - parsed_filters = [parse_filter(k, v) for k, v in self._filters.items()] if self._filters else [] resp = self._post( - f"/project/{self._proj_path}/runs/shows", data={"filters": parsed_filters, "groups": [], "shows": []} + f"/project/{self._proj_path}/runs/shows", + data={ + "filters": _validate_and_build(self._filters, validate_filter), + "groups": _validate_and_build(self._groups, validate_group), + "sorts": _validate_and_build(self._sorts, validate_sort), + }, ) if not resp.ok: return diff --git a/swanlab/api/project.py b/swanlab/api/project.py index 7df11de1f..0b557e4b4 100644 --- a/swanlab/api/project.py +++ b/swanlab/api/project.py @@ -75,16 +75,20 @@ def count(self) -> ApiProjectCountType: def runs( self, - filters: Optional[Dict[str, object]] = None, + filters: Optional[List[Dict[str, Any]]] = None, + groups: Optional[List[Dict[str, Any]]] = None, + sorts: Optional[List[Dict[str, Any]]] = None, ): """ 获取项目下的实验列表(POST 模式,支持复杂过滤)。 - :param filters: 筛选条件 + :param filters: 过滤规则列表,每项为 {key, type, op, value} + :param groups: 分组规则列表,每项为 {key, type} + :param sorts: 排序规则列表,每项为 {key, type, order} """ from swanlab.api.experiment import Experiments - return Experiments(self._ctx, path=self.path, filters=filters, mode="post") + return Experiments(self._ctx, path=self.path, filters=filters, groups=groups, sorts=sorts, mode="post") def runs_get( self, diff --git a/swanlab/api/typings/common.py b/swanlab/api/typings/common.py index 7f205263e..b16d69df4 100644 --- a/swanlab/api/typings/common.py +++ b/swanlab/api/typings/common.py @@ -12,6 +12,9 @@ ApiStatusLiteral = Literal["ENABLED", "DISABLED"] # 侧边列类型 +# STABLE: Experiment 的固有字段,如 state, name, labels, colors 等 +# CONFIG: 动态生成的实验配置字段,如 learning_rate, batch_size 等 +# SCALAR: 动态生成的标量字段,一般用于标量图展示,如 train/loss 等 ApiSidebarLiteral = Literal["SCALAR", "CONFIG", "STABLE"] # 实验类型: 运行中/总览 diff --git a/swanlab/api/typings/experiment.py b/swanlab/api/typings/experiment.py index beb1769b7..f8063993e 100644 --- a/swanlab/api/typings/experiment.py +++ b/swanlab/api/typings/experiment.py @@ -3,19 +3,126 @@ @file: experiment.py @time: 2026/4/20 @description: 公共查询 API 实验类型定义 + +POST /runs/shows 接口支持三个维度的筛选和组织:过滤 (filters)、分组 (groups)、排序 (sorts)。 +每项都有一个 type 字段,取值取决于数据来源: + - STABLE: 实验表固有字段(固定枚举) + - CONFIG: experimentProfile.config 中用户定义的超参(动态 key,如 learning_rate) + - SCALAR: 训练过程中记录的标量指标最新值(动态 key,如 train/loss) """ -from typing import Any, Dict, List, Optional, TypedDict +from typing import Any, Dict, List, Literal, Optional, TypedDict -from .common import ApiExperimentTypeLiteral, ApiRunStateLiteral +from .common import ApiExperimentTypeLiteral, ApiRunStateLiteral, ApiSidebarLiteral from .user import ApiUserType +# --------------------------------------------------------------------------- +# STABLE 字段 key 枚举 +# 对应 experiment 表的直接字段或嵌套字段,来自 sidebar.js stableFieldSelect。 +# --------------------------------------------------------------------------- +ApiStableKeyLiteral = Literal[ + # 实验状态 RUNNING / FINISHED / CRASHED / ABORTED + "state", + # 实验名称 + "name", + # 实验描述 + "description", + # 是否可见 + "show", + # 是否收藏 + "pin", + # 是否为基线 + "baseline", + # 颜色 + "colors", + # 实验分组名 + "cluster", + # 分布式任务类型 + "job", + # 创建时间 + "createdAt", + # 更新时间 + "updatedAt", + # 完成时间 + "finishedAt", + # 收藏时间 + "pinnedAt", + # 标签名数组 + "labels", +] + +# --------------------------------------------------------------------------- +# 过滤操作符 +# --------------------------------------------------------------------------- +# EQ : 等于 +# NEQ : 不等于 +# GTE : 大于等于(数值 / 日期 / 字符串) +# LTE : 小于等于(数值 / 日期 / 字符串) +# IN : 在给定值列表中 +# NOT IN : 不在给定值列表中 +# CONTAIN : 模糊包含 +# +# 注意: +# - 数组类型(如 labels)仅支持 EQ / NEQ / IN / NOT IN / CONTAIN +# - 日期类型 GTE/LTE 用 Date 对象比较;其余用 ISO 字符串比较 +# - 数值类型优先数值比较,失败回退字符串比较 +# --------------------------------------------------------------------------- +ApiFilterOpLiteral = Literal["EQ", "NEQ", "GTE", "LTE", "IN", "NOT IN", "CONTAIN"] + +# --------------------------------------------------------------------------- +# 排序方向 +# --------------------------------------------------------------------------- +ApiSortOrderLiteral = Literal["ASC", "DESC"] + + +# --------------------------------------------------------------------------- +# filter / group / sort item +# POST /runs/shows 请求体中 filters / groups / sorts 数组的元素类型。 +# 用户传入时不需要 active 字段,由 SDK 内部自动补充 active: True。 +# --------------------------------------------------------------------------- +class ApiFilterItem(TypedDict): + """POST /runs/shows 请求体中的过滤项。 + + 多个 filter 之间为 AND 关系。 + 收藏的实验(pin: true)永远不会被过滤掉。 + """ + + key: str # STABLE 时为 ApiStableKeyLiteral 枚举;CONFIG/SCALAR 时为动态字段名 + type: ApiSidebarLiteral # STABLE | CONFIG | SCALAR + op: ApiFilterOpLiteral # 过滤操作符 + value: List[str] # 过滤值列表(空值统一视为空字符串 "") + + +class ApiGroupItem(TypedDict): + """POST /runs/shows 请求体中的分组项。 + + 多个 group 形成多层嵌套(外层 group 为第一层)。 + 数组类型值(如 labels)会排序后用 ", " 连接成字符串作为分组 key。 + """ + + key: str # STABLE 时为 ApiStableKeyLiteral 枚举;CONFIG/SCALAR 时为动态字段名 + type: ApiSidebarLiteral # STABLE | CONFIG | SCALAR + + +class ApiSortItem(TypedDict): + """POST /runs/shows 请求体中的排序项。 + + 排序与分组联动:有 order 的字段按方向排序后平铺,无 order 的保留嵌套结构。 + 后端自动追加兜底排序:pin DESC > pinnedAt DESC > createdAt DESC。 + """ + + key: str # STABLE 时为 ApiStableKeyLiteral 枚举;CONFIG/SCALAR 时为动态字段名 + type: ApiSidebarLiteral # STABLE | CONFIG | SCALAR + order: ApiSortOrderLiteral # ASC | DESC + +# --------------------------------------------------------------------------- +# 实验实体 +# --------------------------------------------------------------------------- class ApiExperimentLabelType(TypedDict): name: str -# 实验配置 class ApiExperimentProfileType(TypedDict): config: Dict[str, Any] metadata: Dict[str, Any] diff --git a/swanlab/api/utils.py b/swanlab/api/utils.py index ad4a8f01a..eb1dcbd50 100644 --- a/swanlab/api/utils.py +++ b/swanlab/api/utils.py @@ -5,7 +5,7 @@ @description: swanlab/api 实体层工具函数 """ -from typing import Any, Dict, Optional, Set, Type, get_type_hints +from typing import Any, Dict, List, Optional, Set, Type, get_args, get_type_hints def strip_dict(data: Any, typed_cls: Type) -> Dict[str, Any]: @@ -35,51 +35,73 @@ def get_properties(obj: object, _visited: Optional[Set[int]] = None) -> Dict[str return result -def parse_column_type(column: str) -> str: - """从前缀中获取指标类型""" - column_type = column.split(".", 1)[0] - if column_type == "summary": - return "SCALAR" - elif column_type == "config": - return "CONFIG" - else: - return "STABLE" - - -def to_camel_case(name: str) -> str: - """将下划线命名转化为驼峰命名""" - return "".join([w.capitalize() if i > 0 else w for i, w in enumerate(name.split("_"))]) - - -_SPECIAL_FILTER_MAP = { - # (backend_key, operator) — 用户侧 key 到后端字段名和操作符的映射 - # backend_key: 后端 API 实际接受的字段名 - # operator: 筛选操作符,EQ=精确匹配,IN=包含匹配(用于 tags 列表) - "group": ("cluster", "EQ"), - "tags": ("labels", "IN"), - "name": ("name", "EQ"), - "username": ("user.username", "EQ"), - "job_type": ("job", "EQ"), -} - - -def parse_filter(key: str, value: object) -> Dict[str, object]: - """将用户侧筛选条件转换为后端 filter 格式。 - - :param key: 筛选字段名。预定义字段(group/tags/name/username/job_type)会映射到后端字段名; - 其他字段按 column type 自动转换:STABLE 类型转 camelCase,其余取最后一段。 - :param value: 筛选值。预定义字段中 tags 接受列表/元组,其余均为单值(内部统一包装为列表)。 - :return: 后端 filter 字典,包含 key / active / value / op / type 五个字段。 - """ - if key in _SPECIAL_FILTER_MAP: - backend_key, op = _SPECIAL_FILTER_MAP[key] - filter_value = list(value) if key == "tags" and isinstance(value, (list, tuple)) else [value] - return {"key": backend_key, "active": True, "value": filter_value, "op": op, "type": "STABLE"} - ct = parse_column_type(key) - return { - "key": to_camel_case(key) if ct == "STABLE" else key.split(".", 1)[-1], - "active": True, - "value": [value], - "op": "EQ", - "type": ct, - } +# --------------------------------------------------------------------------- +# POST /runs/shows 参数校验常量(从 typings 中的 Literal 类型提取,避免重复定义) +# --------------------------------------------------------------------------- +from swanlab.api.typings.common import ApiSidebarLiteral +from swanlab.api.typings.experiment import ( + ApiFilterOpLiteral, + ApiSortOrderLiteral, + ApiStableKeyLiteral, +) + +_VALID_SIDEBAR_TYPES = frozenset(get_args(ApiSidebarLiteral)) +_VALID_OPS = frozenset(get_args(ApiFilterOpLiteral)) +_VALID_ORDERS = frozenset(get_args(ApiSortOrderLiteral)) +_STABLE_KEYS = frozenset(get_args(ApiStableKeyLiteral)) + + +def _check_required(item: Dict[str, Any], keys: Set[str]) -> None: + missing = keys - item.keys() + if missing: + raise ValueError(f"Missing required fields: {sorted(missing)}, got {sorted(item.keys())}") + + +def _check_type_field(item: Dict[str, Any]) -> None: + t = item.get("type", "") + if t not in _VALID_SIDEBAR_TYPES: + raise ValueError(f"Invalid type: {t!r}, expected one of {sorted(_VALID_SIDEBAR_TYPES)}") + + +def _check_stable_key(item: Dict[str, Any]) -> None: + if item.get("type") == "STABLE" and item["key"] not in _STABLE_KEYS: + raise ValueError(f"Invalid STABLE key: {item['key']!r}, expected one of {sorted(_STABLE_KEYS)}") + + +def validate_filter(item: Dict[str, Any]) -> None: + """校验单个 filter item 的合法性。""" + _check_required(item, {"key", "type", "op", "value"}) + _check_type_field(item) + _check_stable_key(item) + if item["op"] not in _VALID_OPS: + raise ValueError(f"Invalid filter op: {item['op']!r}, expected one of {sorted(_VALID_OPS)}") + if not isinstance(item["value"], list): + raise ValueError(f"filter value must be a list, got {type(item['value']).__name__}") + + +def validate_group(item: Dict[str, Any]) -> None: + """校验单个 group item 的合法性。""" + _check_required(item, {"key", "type"}) + _check_type_field(item) + _check_stable_key(item) + + +def validate_sort(item: Dict[str, Any]) -> None: + """校验单个 sort item 的合法性。""" + _check_required(item, {"key", "type", "order"}) + _check_type_field(item) + _check_stable_key(item) + if item["order"] not in _VALID_ORDERS: + raise ValueError(f"Invalid sort order: {item['order']!r}, expected one of {sorted(_VALID_ORDERS)}") + + +def _validate_and_build( + items: Optional[List[Dict[str, Any]]], + validator, +) -> List[Dict[str, Any]]: + """校验每个 item 并补充 active: True,返回可直接发送的列表。""" + if not items: + return [] + for item in items: + validator(item) + return [{**item, "active": True} for item in items] From 6bbf719603054f8e26ff1036fed23f4587cbc1a6 Mon Sep 17 00:00:00 2001 From: Nexisato <978452096@qq.com> Date: Thu, 23 Apr 2026 17:24:00 +0800 Subject: [PATCH 24/52] feat: support pagination column --- swanlab/api/__init__.py | 50 ++++++++- swanlab/api/base.py | 4 + swanlab/api/column.py | 192 +++++++++++++++++++++++++++++++++ swanlab/api/experiment.py | 30 ++++++ swanlab/api/project.py | 4 + swanlab/api/typings/column.py | 48 +++++++++ swanlab/api/typings/common.py | 22 ++++ swanlab/api/typings/project.py | 1 + swanlab/api/utils.py | 31 ++++-- 9 files changed, 375 insertions(+), 7 deletions(-) create mode 100644 swanlab/api/column.py create mode 100644 swanlab/api/typings/column.py diff --git a/swanlab/api/__init__.py b/swanlab/api/__init__.py index 005d238c0..ca11e769e 100644 --- a/swanlab/api/__init__.py +++ b/swanlab/api/__init__.py @@ -13,9 +13,10 @@ from swanlab.sdk.internal.settings import settings as global_settings from .base import ApiClientContext, BaseEntity +from .column import Column, Columns from .experiment import Experiment, Experiments from .project import Project, Projects -from .typings.common import ApiResponseType, PaginatedQuery +from .typings.common import PaginatedQuery from .user import User from .workspace import Workspace, Workspaces @@ -204,5 +205,52 @@ def runs_get( def user(self) -> User: return User(self._ctx) + def columns( + self, + run_id: str, + page: int = 1, + size: int = 20, + search: Optional[str] = None, + column_class: str = "CUSTOM", + column_type: Optional[str] = None, + all: bool = False, + ) -> Columns: + """ + 获取实验下的列列表(分页查询,支持搜索)。 + + :param run_id: 实验 ID(cuid) + :param page: 起始页码,默认 1 + :param size: 每页数量,默认 20 + :param search: 搜索关键词,搜索的是列的 name + :param column_class: 列的分类,CUSTOM 或 SYSTEM, 默认为 CUSTOM + :param column_type: 列的类型,如 FLOAT、STRING、IMAGE 等 + :param all: 是否获取全部数据,默认 False + """ + query = PaginatedQuery(page=page, size=size, search=search, all=all) + return Columns( + self._ctx, + run_id=run_id, + query=query, + column_type=column_type, + column_class=column_class, + ) + + def column( + self, + run_id: str, + key: str, + column_class: Optional[str] = "CUSTOM", + column_type: Optional[str] = None, + ) -> Column: + """ + 获取单个列(通过搜索 key 匹配)。 + + :param run_id: 实验 ID(run_id) + :param key: 列的键名, 输入不完整则模糊匹配 name 为首个 key. + :param column_class: 列的分类,CUSTOM 或 SYSTEM,默认 CUSTOM + :param column_type: 列的类型,如 FLOAT、STRING、IMAGE 等,默认为 None + """ + return Column(self._ctx, run_id=run_id, key=key, column_class=column_class, column_type=column_type) + __all__ = ["Api"] diff --git a/swanlab/api/base.py b/swanlab/api/base.py index 3fbab5627..c9cecbad9 100644 --- a/swanlab/api/base.py +++ b/swanlab/api/base.py @@ -5,6 +5,8 @@ @description: 所有实体类的公共基类 """ +import random +import time from abc import ABC, abstractmethod from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Callable, Dict, Iterator, Optional @@ -115,6 +117,8 @@ def _paginate( if not items: break yield from items + # 随机休眠控制 qps + time.sleep(random.random()) if page >= body.get("pages", 1): break if not query.all: diff --git a/swanlab/api/column.py b/swanlab/api/column.py new file mode 100644 index 000000000..275bbb521 --- /dev/null +++ b/swanlab/api/column.py @@ -0,0 +1,192 @@ +""" +@author: caddiesnew +@file: column.py +@time: 2026/4/20 +@description: Column 实体类 — 实验列的查询与操作 +""" + +from typing import Any, Dict, Iterator, Optional, cast + +from swanlab.api.base import ApiClientContext, BaseEntity +from swanlab.api.typings.column import ApiColumnCsvExportType, ApiColumnType +from swanlab.api.typings.common import ApiResponseType, PaginatedQuery +from swanlab.api.utils import get_properties, validate_column_params + + +class Column(BaseEntity): + """ + 表示一个 SwanLab 实验列。 + + 支持双模式:构造时传入 data(列表迭代注入),或 data=None(按需懒加载)。 + 注意:列不支持单个获取 API,只能通过列表接口获取。 + """ + + def __init__( + self, + ctx: ApiClientContext, + *, + run_id: str, + key: str, + column_class: Optional[str] = "CUSTOM", + column_type: Optional[str] = None, + data: Optional[ApiColumnType] = None, + ) -> None: + super().__init__(ctx) + self._run_id = run_id + self._key = key + self._column_class = column_class + self._column_type = column_type + self._data = data + + def _ensure_data(self) -> ApiColumnType: + if self._data is None: + validate_column_params(column_class=self._column_class) + extra: Dict[str, Any] = {"search": self._key} + if self._column_class: + extra["class"] = self._column_class + resp = self._get( + f"/experiment/{self._run_id}/column", + params={"page": 1, "size": 10, **extra}, + ) + if resp.data: + items = resp.data.get("list", []) if isinstance(resp.data, dict) else [] + if items: + self._data = cast(ApiColumnType, items[0]) + if self._data is None: + self._data = cast(ApiColumnType, {}) + return self._data + + @property + def key(self) -> str: + res_key = self._ensure_data().get("key", "") + if res_key and res_key != self._key: + self._key = res_key + return res_key + + @property + def name(self) -> str: + """列的显示名称,默认为 key 的值。""" + return self._ensure_data().get("name", "") + + @property + def column_class(self) -> str: + """列的分类:CUSTOM 或 SYSTEM。""" + return self._ensure_data().get("class", "") + + @property + def column_type(self) -> str: + """列的数据类型,如 FLOAT、STRING、IMAGE 等。""" + return self._ensure_data().get("type", "") + + @property + def created_at(self) -> int: + """列的创建时间戳。""" + return self._ensure_data().get("createdAt", 0) + + @property + def error(self) -> Optional[Dict[str, Any]]: + """列的错误信息。""" + return self._ensure_data().get("error", {}) + + def export_csv(self) -> ApiResponseType: + """ + 导出列数据为 CSV。 + + :return: ApiResponseType,成功时 data 包含临时下载 URL + """ + resp = self._get(f"/experiment/{self._run_id}/column/csv", params={"key": self.key}) + if not resp.ok: + return resp + + data = resp.data + if isinstance(data, list) and data: + url = data[0].get("url", "") + return ApiResponseType(ok=True, data=ApiColumnCsvExportType(url=url)) + elif isinstance(data, dict): + url = data.get("url", "") + return ApiResponseType(ok=True, data=ApiColumnCsvExportType(url=url)) + + return ApiResponseType(ok=False, errmsg="Invalid response format", data=None) + + def json(self) -> Dict[str, Any]: + return get_properties(self) + + +class Columns(BaseEntity): + """ + 实验下列集合的分页迭代器。 + + 用法:: + + # 获取所有列 + for column in experiment.columns(): + print(column.name, column.data_type) + + # 分页获取列(支持搜索) + for column in experiment.columns(page=1, size=20, search="loss"): + print(column.name) + + # 获取全部列(自动翻页) + for column in experiment.columns(all=True): + print(column.name) + """ + + def __init__( + self, + ctx: ApiClientContext, + *, + run_id: str, + query: PaginatedQuery, + column_class: Optional[str] = None, + column_type: Optional[str] = None, + ) -> None: + super().__init__(ctx) + self._run_id = run_id + self._query = query + # 校验 column_type 和 column_class 的合法性 + validate_column_params(column_type=column_type, column_class=column_class) + self._column_class = column_class + self._column_type = column_type + self._page_info: Dict[str, Any] = { + "page": query.page, + "size": query.size, + "total": 0, + "pages": 0, + "list": [], + } + + def __iter__(self) -> Iterator[Column]: + """迭代分页获取列。""" + extra: Dict[str, Any] = {} + if self._column_type: + extra["type"] = self._column_type + if self._column_class: + extra["class"] = self._column_class + + for item in self._paginate( + f"/experiment/{self._run_id}/column", + self._query, + page_info=self._page_info, + extra=extra, + ): + yield Column( + self._ctx, + run_id=self._run_id, + key=item.get("key", ""), + data=cast(ApiColumnType, item), + ) + + @property + def total(self) -> int: + """获取总数(触发一次请求)。""" + # 触发一次迭代来获取总数 + if self._page_info["total"] == 0: + try: + next(iter(self)) + except StopIteration: + pass + return self._page_info["total"] + + def json(self) -> Dict[str, Any]: + self._page_info["list"] = [c.json() for c in self] + return self._page_info diff --git a/swanlab/api/experiment.py b/swanlab/api/experiment.py index adef0a37d..87f1e84ef 100644 --- a/swanlab/api/experiment.py +++ b/swanlab/api/experiment.py @@ -192,6 +192,36 @@ def strip_suffix(col, suffix="_step"): return result_df + def columns( + self, + page: int = 1, + size: int = 20, + search: Optional[str] = None, + column_type: Optional[str] = None, + column_class: Optional[str] = None, + all: bool = False, + ): + """ + 获取实验下的列列表(分页查询,支持搜索)。 + + :param page: 起始页码,默认 1 + :param size: 每页数量,默认 20 + :param search: 搜索关键词,搜索的是列的 name + :param column_type: 列的类型,如 FLOAT、STRING、IMAGE 等 + :param column_class: 列的分类,CUSTOM 或 SYSTEM + :param all: 是否获取全部数据,默认 False + """ + from swanlab.api.column import Columns + + query = PaginatedQuery(page=page, size=size, search=search, all=all) + return Columns( + self._ctx, + run_id=self._cuid, + query=query, + column_type=column_type, + column_class=column_class, + ) + def delete(self) -> bool: """删除此实验。""" resp = self._delete(f"/project/{self._proj_path}/runs/{self._cuid}") diff --git a/swanlab/api/project.py b/swanlab/api/project.py index 0b557e4b4..2bcdb0114 100644 --- a/swanlab/api/project.py +++ b/swanlab/api/project.py @@ -37,6 +37,10 @@ def _ensure_data(self) -> ApiProjectType: self._data = resp.data if resp.ok and resp.data else cast(ApiProjectType, {}) return self._data + @property + def project_id(self) -> str: + return self._ensure_data().get("cuid", "") + @property def name(self) -> str: return self._ensure_data().get("name", "") diff --git a/swanlab/api/typings/column.py b/swanlab/api/typings/column.py new file mode 100644 index 000000000..c90dc06de --- /dev/null +++ b/swanlab/api/typings/column.py @@ -0,0 +1,48 @@ +""" +@author: caddiesnew +@file: column.py +@time: 2026/4/20 +@description: 公共查询 API 实验列类型定义 +""" + +from typing import Any, Dict, Optional, TypedDict + +from .common import ApiColumnClassLiteral, ApiColumnDataTypeLiteral + + +class ApiColumnErrorType(TypedDict, total=False): + """列错误信息""" + + message: str + code: str + + +class ApiColumnType(TypedDict, total=False): + """ + 实验列数据类型 + + 注意:后端响应使用以下字段名: + - class: 列的分类 (CUSTOM/SYSTEM) + - type: 列的数据类型 (FLOAT/STRING/IMAGE等) + - createdAt: 创建时间戳(蛇峰命名) + """ + + # 列的分类:CUSTOM 或 SYSTEM + column_class: ApiColumnClassLiteral + # 列的数据类型 + column_type: ApiColumnDataTypeLiteral + # 列的键名(唯一标识) + key: str + # 列的显示名称,默认为 key 的值 + name: str + # 创建时间戳 + createdAt: int + # 错误信息 + error: Optional[Dict[str, Any]] + + +class ApiColumnCsvExportType(TypedDict): + """列 CSV 导出响应类型""" + + # 临时下载 URL + url: str diff --git a/swanlab/api/typings/common.py b/swanlab/api/typings/common.py index b16d69df4..cf4eefd34 100644 --- a/swanlab/api/typings/common.py +++ b/swanlab/api/typings/common.py @@ -32,6 +32,28 @@ # 工作空间成员类型 ApiRoleLiteral = Literal["VISITOR", "VIEWER", "MEMBER", "OWNER"] +# 列种类 +ApiColumnClassLiteral = Literal["CUSTOM", "SYSTEM"] +# 列数据类型 +ApiColumnDataTypeLiteral = Literal[ + "FLOAT", + "BOOLEAN", + "STRING" + # media 类型 + "IMAGE", + "AUDIO", + "VIDEO", + # 3D点云 (json) + "OBJECT3D", + # 生物化学分子 + "MOLECULE", + # (js/ ts 文件) + "ECHARTS", + # 表格类型 + "TABLE", + "TEXT", +] + # Self-Hosted 身份类型 ApiIdentityLiteral = Literal["root", "user"] diff --git a/swanlab/api/typings/project.py b/swanlab/api/typings/project.py index 76477f31b..f6028961b 100644 --- a/swanlab/api/typings/project.py +++ b/swanlab/api/typings/project.py @@ -23,6 +23,7 @@ class ApiProjectCountType(TypedDict): class ApiProjectType(TypedDict): + cuid: str name: str username: str path: str diff --git a/swanlab/api/utils.py b/swanlab/api/utils.py index eb1dcbd50..0025e1c19 100644 --- a/swanlab/api/utils.py +++ b/swanlab/api/utils.py @@ -7,6 +7,13 @@ from typing import Any, Dict, List, Optional, Set, Type, get_args, get_type_hints +from swanlab.api.typings.common import ApiColumnClassLiteral, ApiColumnDataTypeLiteral, ApiSidebarLiteral +from swanlab.api.typings.experiment import ( + ApiFilterOpLiteral, + ApiSortOrderLiteral, + ApiStableKeyLiteral, +) + def strip_dict(data: Any, typed_cls: Type) -> Dict[str, Any]: """将原始 API 响应字典裁剪为只保留 TypedDict 中声明的字段。""" @@ -38,18 +45,16 @@ def get_properties(obj: object, _visited: Optional[Set[int]] = None) -> Dict[str # --------------------------------------------------------------------------- # POST /runs/shows 参数校验常量(从 typings 中的 Literal 类型提取,避免重复定义) # --------------------------------------------------------------------------- -from swanlab.api.typings.common import ApiSidebarLiteral -from swanlab.api.typings.experiment import ( - ApiFilterOpLiteral, - ApiSortOrderLiteral, - ApiStableKeyLiteral, -) _VALID_SIDEBAR_TYPES = frozenset(get_args(ApiSidebarLiteral)) _VALID_OPS = frozenset(get_args(ApiFilterOpLiteral)) _VALID_ORDERS = frozenset(get_args(ApiSortOrderLiteral)) _STABLE_KEYS = frozenset(get_args(ApiStableKeyLiteral)) +# 列相关校验常量 +_VALID_COLUMN_CLASSES = frozenset(get_args(ApiColumnClassLiteral)) +_VALID_COLUMN_DATA_TYPES = frozenset(get_args(ApiColumnDataTypeLiteral)) + def _check_required(item: Dict[str, Any], keys: Set[str]) -> None: missing = keys - item.keys() @@ -105,3 +110,17 @@ def _validate_and_build( for item in items: validator(item) return [{**item, "active": True} for item in items] + + +def validate_column_params(column_type: Optional[str] = None, column_class: Optional[str] = None) -> None: + """ + 校验列查询参数的合法性。 + + :param column_type: 列的数据类型 + :param column_class: 列的分类 + :raises ValueError: 当参数不在允许的枚举值中时 + """ + if column_type is not None and column_type not in _VALID_COLUMN_DATA_TYPES: + raise ValueError(f"Invalid column_type: {column_type!r}, expected one of {sorted(_VALID_COLUMN_DATA_TYPES)}") + if column_class is not None and column_class not in _VALID_COLUMN_CLASSES: + raise ValueError(f"Invalid column_class: {column_class!r}, expected one of {sorted(_VALID_COLUMN_CLASSES)}") From 1ba66cd22f364c218b9ecdc5b856bdc0ad2792f4 Mon Sep 17 00:00:00 2001 From: Nexisato <978452096@qq.com> Date: Thu, 23 Apr 2026 17:31:00 +0800 Subject: [PATCH 25/52] chore: update method name --- swanlab/api/experiment.py | 8 ++++---- swanlab/api/utils.py | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/swanlab/api/experiment.py b/swanlab/api/experiment.py index 87f1e84ef..7a8e87059 100644 --- a/swanlab/api/experiment.py +++ b/swanlab/api/experiment.py @@ -15,7 +15,7 @@ ApiExperimentType, ) from swanlab.api.typings.user import ApiUserType -from swanlab.api.utils import _validate_and_build, get_properties, validate_filter, validate_group, validate_sort +from swanlab.api.utils import get_properties, validate_filter, validate_group, validate_sort, validate_update_active def _resovle_path(path: str) -> Tuple[str, str]: @@ -296,9 +296,9 @@ def _iter_filtered(self) -> Iterator[Experiment]: resp = self._post( f"/project/{self._proj_path}/runs/shows", data={ - "filters": _validate_and_build(self._filters, validate_filter), - "groups": _validate_and_build(self._groups, validate_group), - "sorts": _validate_and_build(self._sorts, validate_sort), + "filters": validate_update_active(self._filters, validate_filter), + "groups": validate_update_active(self._groups, validate_group), + "sorts": validate_update_active(self._sorts, validate_sort), }, ) if not resp.ok: diff --git a/swanlab/api/utils.py b/swanlab/api/utils.py index 0025e1c19..45a85f611 100644 --- a/swanlab/api/utils.py +++ b/swanlab/api/utils.py @@ -100,7 +100,7 @@ def validate_sort(item: Dict[str, Any]) -> None: raise ValueError(f"Invalid sort order: {item['order']!r}, expected one of {sorted(_VALID_ORDERS)}") -def _validate_and_build( +def validate_update_active( items: Optional[List[Dict[str, Any]]], validator, ) -> List[Dict[str, Any]]: From 3e7b7618306dc7b2cccd9b5abcd119ee0c4046cb Mon Sep 17 00:00:00 2001 From: Nexisato <978452096@qq.com> Date: Thu, 23 Apr 2026 20:39:52 +0800 Subject: [PATCH 26/52] feat: add metric type --- swanlab/api/typings/__init__.py | 11 ++++++ swanlab/api/typings/experiment.py | 2 - swanlab/api/typings/metrics.py | 63 +++++++++++++++++++++++++++++++ 3 files changed, 74 insertions(+), 2 deletions(-) create mode 100644 swanlab/api/typings/metrics.py diff --git a/swanlab/api/typings/__init__.py b/swanlab/api/typings/__init__.py index 0876198eb..9d05d2ef2 100644 --- a/swanlab/api/typings/__init__.py +++ b/swanlab/api/typings/__init__.py @@ -5,6 +5,7 @@ @description: SwanLab OpenAPI 类型提示, 以 Api 前缀区分 """ +from .column import ApiColumnCsvExportType, ApiColumnErrorType, ApiColumnType from .common import ( ApiIdentityLiteral, ApiLicensePlanLiteral, @@ -17,6 +18,7 @@ ApiWorkspaceLiteral, ) from .experiment import ApiExperimentLabelType, ApiExperimentType +from .metrics import ApiLogType, ApiMediaType, ApiScalarType from .project import ApiProjectCountType, ApiProjectLabelType, ApiProjectType from .selfhosted import ApiApiKeyType, ApiSelfHostedInfoType from .user import ApiUserProfileType, ApiUserType @@ -50,4 +52,13 @@ # Misc "ApiApiKeyType", "ApiSelfHostedInfoType", + # Column + "ApiColumnErrorType", + "ApiColumnType", + "ApiColumnCsvExportType", + # Metrics + "ApiMetricsColumnType", + "ApiScalarType", + "ApiMediaMetricType", + "ApiLogMetricType", ] diff --git a/swanlab/api/typings/experiment.py b/swanlab/api/typings/experiment.py index f8063993e..3575bd9fe 100644 --- a/swanlab/api/typings/experiment.py +++ b/swanlab/api/typings/experiment.py @@ -142,5 +142,3 @@ class ApiExperimentType(TypedDict): cluster: str job: str user: ApiUserType - rootExpId: Optional[str] - rootProId: Optional[str] diff --git a/swanlab/api/typings/metrics.py b/swanlab/api/typings/metrics.py new file mode 100644 index 000000000..245995ca8 --- /dev/null +++ b/swanlab/api/typings/metrics.py @@ -0,0 +1,63 @@ +""" +@author: caddiesnew +@file: metrics.py +@time: 2026/4/23 +@description: 指标数据类型定义(用于 column 采样值获取) +""" + +from typing import Any, Dict, List, TypedDict + + +# --------------------------------------------------------------------------- +# Scalar — 标量 item 数据 +# --------------------------------------------------------------------------- +# 一个 key 下单个 scalar 的数据值包装 +class ApiScalarType(TypedDict, total=False): + step: int + data: float + timestamp: int + + +# 一个 key 下批量 scalar 的响应值包装 +# 需要请求两次 +class ApiScalarListType(TypedDict, total=False): + min: ApiScalarType + max: ApiScalarType + avg: ApiScalarType + median: ApiScalarType + latest: ApiScalarType + metrics: List[ApiScalarType] + + +# 指标概要 +# summary[run_id][key] 为下面一个 item 项 +class ApiSummaryItemType(TypedDict, total=False): + step: int + value: Any + minMax: List[Any] + min: Any + max: Any + avg: Any + median: Any + stdDev: Any + + +# --------------------------------------------------------------------------- +# Media — 媒体 item 数据 +# --------------------------------------------------------------------------- +class ApiMediaType(TypedDict, total=False): + # 项目路径: proj_id/run_id 拼接而成 + prefix: str + data: List[str] + more: List[Dict[str, Any]] + + +# --------------------------------------------------------------------------- +# Log — 日志 item 数据 +# --------------------------------------------------------------------------- +class ApiLogType(TypedDict, total=False): + epoch: int + level: str + message: str + tag: str + timestamp: str From 192c4c5adf220eef91bf4d9a1c57aa6057566e4f Mon Sep 17 00:00:00 2001 From: Nexisato <978452096@qq.com> Date: Thu, 23 Apr 2026 20:41:48 +0800 Subject: [PATCH 27/52] chore: rename metric type --- swanlab/api/typings/__init__.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/swanlab/api/typings/__init__.py b/swanlab/api/typings/__init__.py index 9d05d2ef2..adb3f5865 100644 --- a/swanlab/api/typings/__init__.py +++ b/swanlab/api/typings/__init__.py @@ -57,8 +57,7 @@ "ApiColumnType", "ApiColumnCsvExportType", # Metrics - "ApiMetricsColumnType", + "ApiMediaType", "ApiScalarType", - "ApiMediaMetricType", - "ApiLogMetricType", + "ApiLogType", ] From 90e59d93e1c4a4c95e6d076ceaf4e3493eb68565 Mon Sep 17 00:00:00 2001 From: Nexisato <978452096@qq.com> Date: Fri, 24 Apr 2026 11:35:42 +0800 Subject: [PATCH 28/52] chore: hiden pin property --- swanlab/api/typings/experiment.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/swanlab/api/typings/experiment.py b/swanlab/api/typings/experiment.py index 3575bd9fe..123207434 100644 --- a/swanlab/api/typings/experiment.py +++ b/swanlab/api/typings/experiment.py @@ -29,8 +29,8 @@ "description", # 是否可见 "show", - # 是否收藏 - "pin", + # TODO: experiment 被设置为 pin 时强制返回 + # "pin", # 是否为基线 "baseline", # 颜色 From 618e971890af94631083ca65f7390fc53c5369a7 Mon Sep 17 00:00:00 2001 From: Nexisato <978452096@qq.com> Date: Fri, 24 Apr 2026 14:42:15 +0800 Subject: [PATCH 29/52] fix: add project_id and run_id for column --- swanlab/api/__init__.py | 12 +++--- swanlab/api/column.py | 32 +++++++++++++--- swanlab/api/experiment.py | 45 ++++++++++++---------- swanlab/api/typings/__init__.py | 5 --- swanlab/api/typings/column.py | 3 ++ swanlab/api/typings/experiment.py | 3 +- swanlab/api/typings/metrics.py | 63 ------------------------------- swanlab/api/utils.py | 17 ++++++++- 8 files changed, 78 insertions(+), 102 deletions(-) delete mode 100644 swanlab/api/typings/metrics.py diff --git a/swanlab/api/__init__.py b/swanlab/api/__init__.py index ca11e769e..04790fb87 100644 --- a/swanlab/api/__init__.py +++ b/swanlab/api/__init__.py @@ -207,7 +207,7 @@ def user(self) -> User: def columns( self, - run_id: str, + path: str, page: int = 1, size: int = 20, search: Optional[str] = None, @@ -218,7 +218,7 @@ def columns( """ 获取实验下的列列表(分页查询,支持搜索)。 - :param run_id: 实验 ID(cuid) + :param path: 实验路径,格式为 'username/project/run_id' :param page: 起始页码,默认 1 :param size: 每页数量,默认 20 :param search: 搜索关键词,搜索的是列的 name @@ -229,7 +229,7 @@ def columns( query = PaginatedQuery(page=page, size=size, search=search, all=all) return Columns( self._ctx, - run_id=run_id, + path=path, query=query, column_type=column_type, column_class=column_class, @@ -237,7 +237,7 @@ def columns( def column( self, - run_id: str, + path: str, key: str, column_class: Optional[str] = "CUSTOM", column_type: Optional[str] = None, @@ -245,12 +245,12 @@ def column( """ 获取单个列(通过搜索 key 匹配)。 - :param run_id: 实验 ID(run_id) + :param path: 实验路径,格式为 'username/project/run_id' :param key: 列的键名, 输入不完整则模糊匹配 name 为首个 key. :param column_class: 列的分类,CUSTOM 或 SYSTEM,默认 CUSTOM :param column_type: 列的类型,如 FLOAT、STRING、IMAGE 等,默认为 None """ - return Column(self._ctx, run_id=run_id, key=key, column_class=column_class, column_type=column_type) + return Column(self._ctx, path=path, key=key, column_class=column_class, column_type=column_type) __all__ = ["Api"] diff --git a/swanlab/api/column.py b/swanlab/api/column.py index 275bbb521..dea2b296d 100644 --- a/swanlab/api/column.py +++ b/swanlab/api/column.py @@ -10,7 +10,7 @@ from swanlab.api.base import ApiClientContext, BaseEntity from swanlab.api.typings.column import ApiColumnCsvExportType, ApiColumnType from swanlab.api.typings.common import ApiResponseType, PaginatedQuery -from swanlab.api.utils import get_properties, validate_column_params +from swanlab.api.utils import get_properties, resovle_run_path, validate_column_params class Column(BaseEntity): @@ -25,18 +25,19 @@ def __init__( self, ctx: ApiClientContext, *, - run_id: str, + path: str, key: str, column_class: Optional[str] = "CUSTOM", column_type: Optional[str] = None, data: Optional[ApiColumnType] = None, ) -> None: super().__init__(ctx) - self._run_id = run_id + self._proj_path, self._run_id = resovle_run_path(path=path) self._key = key self._column_class = column_class self._column_type = column_type self._data = data + self._project_id = None def _ensure_data(self) -> ApiColumnType: if self._data is None: @@ -54,8 +55,26 @@ def _ensure_data(self) -> ApiColumnType: self._data = cast(ApiColumnType, items[0]) if self._data is None: self._data = cast(ApiColumnType, {}) + self._data["run_id"] = self._run_id + if self._project_id is None: + resp = self._get(f"/project/{self._proj_path}") + proj_data = resp.data if resp.ok else {} + self._project_id = proj_data.get("cuid", "") + self._data["project_id"] = self._project_id return self._data + @property + def project_id(self) -> str: + if self._project_id: + return self._project_id + return self._ensure_data().get("project_id", "") + + @property + def run_id(self) -> str: + if self._run_id: + return self._run_id + return self._ensure_data().get("run_id", "") + @property def key(self) -> str: res_key = self._ensure_data().get("key", "") @@ -135,13 +154,14 @@ def __init__( self, ctx: ApiClientContext, *, - run_id: str, + path: str, query: PaginatedQuery, column_class: Optional[str] = None, column_type: Optional[str] = None, ) -> None: super().__init__(ctx) - self._run_id = run_id + self._run_path = path + self._proj_path, self._run_id = resovle_run_path(path=path) self._query = query # 校验 column_type 和 column_class 的合法性 validate_column_params(column_type=column_type, column_class=column_class) @@ -171,7 +191,7 @@ def __iter__(self) -> Iterator[Column]: ): yield Column( self._ctx, - run_id=self._run_id, + path=self._run_path, key=item.get("key", ""), data=cast(ApiColumnType, item), ) diff --git a/swanlab/api/experiment.py b/swanlab/api/experiment.py index 7a8e87059..592f599cd 100644 --- a/swanlab/api/experiment.py +++ b/swanlab/api/experiment.py @@ -5,7 +5,7 @@ @description: Experiment 实体类 — 单个实验的查询与操作 """ -from typing import Any, Dict, Iterator, List, Optional, Tuple, Union, cast +from typing import Any, Dict, Iterator, List, Optional, Union, cast from swanlab.api.base import ApiClientContext, BaseEntity from swanlab.api.typings.common import PaginatedQuery @@ -15,21 +15,14 @@ ApiExperimentType, ) from swanlab.api.typings.user import ApiUserType -from swanlab.api.utils import get_properties, validate_filter, validate_group, validate_sort, validate_update_active - - -def _resovle_path(path: str) -> Tuple[str, str]: - """ "path like: user/proj_name/run_id""" - proj_path, cuid = "", "" - parts = path.split("/") - if len(parts) != 3: - return proj_path, cuid - cuid = parts[-1] - proj_path = path.rsplit("/", 1)[0] - return ( - proj_path, - cuid, - ) +from swanlab.api.utils import ( + get_properties, + resovle_run_path, + validate_filter, + validate_group, + validate_sort, + validate_update_active, +) class Experiment(BaseEntity): @@ -48,8 +41,9 @@ def __init__( data: Optional[ApiExperimentType] = None, ) -> None: super().__init__(ctx) - self._proj_path, self._cuid = _resovle_path(path=path) + self._proj_path, self._cuid = resovle_run_path(path=path) self._data = data + self._project_id = None def _ensure_data(self) -> ApiExperimentType: if self._data is None: @@ -57,8 +51,19 @@ def _ensure_data(self) -> ApiExperimentType: self._data = resp.data if resp.ok and resp.data else cast(ApiExperimentType, {}) if not self._cuid and self._data: self._cuid = self._data.get("cuid", "") + if self._project_id is None: + resp = self._get(f"/project/{self._proj_path}") + proj_data = resp.data if resp.ok else {} + self._project_id = proj_data.get("cuid", "") + self._data["project_id"] = self._project_id return self._data + @property + def project_id(self) -> str: + if self._project_id: + return self._project_id + return self._ensure_data().get("project_id", "") + @property def run_id(self) -> str: if self._cuid: @@ -104,7 +109,7 @@ def job_type(self) -> str: @property def user(self) -> ApiUserType: user_data = self._ensure_data().get("user", {}) - return user_data if isinstance(user_data, dict) else cast(ApiUserType, {}) + return cast(ApiUserType, user_data) @property def created_at(self) -> str: @@ -123,7 +128,7 @@ def profile(self) -> ApiExperimentProfileType: if resp.ok and resp.data: self._data = resp.data data = self._data - return ApiExperimentProfileType(self._ensure_data().get("profile", {})) + return cast(ApiExperimentProfileType, self._ensure_data().get("profile", {})) def metrics( self, keys: Optional[List[str]] = None, x_axis: Optional[str] = None, sample: Optional[int] = None @@ -216,7 +221,7 @@ def columns( query = PaginatedQuery(page=page, size=size, search=search, all=all) return Columns( self._ctx, - run_id=self._cuid, + path=f"{self._proj_path}/{self._cuid}", query=query, column_type=column_type, column_class=column_class, diff --git a/swanlab/api/typings/__init__.py b/swanlab/api/typings/__init__.py index adb3f5865..35e30108a 100644 --- a/swanlab/api/typings/__init__.py +++ b/swanlab/api/typings/__init__.py @@ -18,7 +18,6 @@ ApiWorkspaceLiteral, ) from .experiment import ApiExperimentLabelType, ApiExperimentType -from .metrics import ApiLogType, ApiMediaType, ApiScalarType from .project import ApiProjectCountType, ApiProjectLabelType, ApiProjectType from .selfhosted import ApiApiKeyType, ApiSelfHostedInfoType from .user import ApiUserProfileType, ApiUserType @@ -56,8 +55,4 @@ "ApiColumnErrorType", "ApiColumnType", "ApiColumnCsvExportType", - # Metrics - "ApiMediaType", - "ApiScalarType", - "ApiLogType", ] diff --git a/swanlab/api/typings/column.py b/swanlab/api/typings/column.py index c90dc06de..f119a3696 100644 --- a/swanlab/api/typings/column.py +++ b/swanlab/api/typings/column.py @@ -27,6 +27,9 @@ class ApiColumnType(TypedDict, total=False): - createdAt: 创建时间戳(蛇峰命名) """ + # 每个 column 与一个项目和实验绑定 + project_id: str + run_id: str # 列的分类:CUSTOM 或 SYSTEM column_class: ApiColumnClassLiteral # 列的数据类型 diff --git a/swanlab/api/typings/experiment.py b/swanlab/api/typings/experiment.py index 123207434..93915b937 100644 --- a/swanlab/api/typings/experiment.py +++ b/swanlab/api/typings/experiment.py @@ -130,7 +130,8 @@ class ApiExperimentProfileType(TypedDict): conda: str -class ApiExperimentType(TypedDict): +class ApiExperimentType(TypedDict, total=False): + project_id: str cuid: str name: str type: ApiExperimentTypeLiteral diff --git a/swanlab/api/typings/metrics.py b/swanlab/api/typings/metrics.py deleted file mode 100644 index 245995ca8..000000000 --- a/swanlab/api/typings/metrics.py +++ /dev/null @@ -1,63 +0,0 @@ -""" -@author: caddiesnew -@file: metrics.py -@time: 2026/4/23 -@description: 指标数据类型定义(用于 column 采样值获取) -""" - -from typing import Any, Dict, List, TypedDict - - -# --------------------------------------------------------------------------- -# Scalar — 标量 item 数据 -# --------------------------------------------------------------------------- -# 一个 key 下单个 scalar 的数据值包装 -class ApiScalarType(TypedDict, total=False): - step: int - data: float - timestamp: int - - -# 一个 key 下批量 scalar 的响应值包装 -# 需要请求两次 -class ApiScalarListType(TypedDict, total=False): - min: ApiScalarType - max: ApiScalarType - avg: ApiScalarType - median: ApiScalarType - latest: ApiScalarType - metrics: List[ApiScalarType] - - -# 指标概要 -# summary[run_id][key] 为下面一个 item 项 -class ApiSummaryItemType(TypedDict, total=False): - step: int - value: Any - minMax: List[Any] - min: Any - max: Any - avg: Any - median: Any - stdDev: Any - - -# --------------------------------------------------------------------------- -# Media — 媒体 item 数据 -# --------------------------------------------------------------------------- -class ApiMediaType(TypedDict, total=False): - # 项目路径: proj_id/run_id 拼接而成 - prefix: str - data: List[str] - more: List[Dict[str, Any]] - - -# --------------------------------------------------------------------------- -# Log — 日志 item 数据 -# --------------------------------------------------------------------------- -class ApiLogType(TypedDict, total=False): - epoch: int - level: str - message: str - tag: str - timestamp: str diff --git a/swanlab/api/utils.py b/swanlab/api/utils.py index 45a85f611..02dddf018 100644 --- a/swanlab/api/utils.py +++ b/swanlab/api/utils.py @@ -5,7 +5,7 @@ @description: swanlab/api 实体层工具函数 """ -from typing import Any, Dict, List, Optional, Set, Type, get_args, get_type_hints +from typing import Any, Dict, List, Optional, Set, Tuple, Type, get_args, get_type_hints from swanlab.api.typings.common import ApiColumnClassLiteral, ApiColumnDataTypeLiteral, ApiSidebarLiteral from swanlab.api.typings.experiment import ( @@ -42,6 +42,21 @@ def get_properties(obj: object, _visited: Optional[Set[int]] = None) -> Dict[str return result +# 路径解析 +def resovle_run_path(path: str) -> Tuple[str, str]: + """ "path like: user/proj_name/run_id""" + proj_path, cuid = "", "" + parts = path.split("/") + if len(parts) != 3: + return proj_path, cuid + cuid = parts[-1] + proj_path = path.rsplit("/", 1)[0] + return ( + proj_path, + cuid, + ) + + # --------------------------------------------------------------------------- # POST /runs/shows 参数校验常量(从 typings 中的 Literal 类型提取,避免重复定义) # --------------------------------------------------------------------------- From fde9a82fc6c79e82cdac38daa6071dc1e599ef3d Mon Sep 17 00:00:00 2001 From: Nexisato <978452096@qq.com> Date: Fri, 24 Apr 2026 16:18:56 +0800 Subject: [PATCH 30/52] feat: add metric skeleton --- swanlab/api/column.py | 9 ++- swanlab/api/metric.py | 103 ++++++++++++++++++++++++++++++ swanlab/api/typings/__init__.py | 1 + swanlab/api/typings/common.py | 67 +++++++++++++++++++ swanlab/api/typings/experiment.py | 66 ++----------------- swanlab/api/typings/metric.py | 90 ++++++++++++++++++++++++++ swanlab/api/utils.py | 33 ++++++++-- 7 files changed, 305 insertions(+), 64 deletions(-) create mode 100644 swanlab/api/metric.py create mode 100644 swanlab/api/typings/metric.py diff --git a/swanlab/api/column.py b/swanlab/api/column.py index dea2b296d..f4c517bb4 100644 --- a/swanlab/api/column.py +++ b/swanlab/api/column.py @@ -10,7 +10,7 @@ from swanlab.api.base import ApiClientContext, BaseEntity from swanlab.api.typings.column import ApiColumnCsvExportType, ApiColumnType from swanlab.api.typings.common import ApiResponseType, PaginatedQuery -from swanlab.api.utils import get_properties, resovle_run_path, validate_column_params +from swanlab.api.utils import get_properties, parse_column_data_type, resovle_run_path, validate_column_params class Column(BaseEntity): @@ -127,6 +127,13 @@ def export_csv(self) -> ApiResponseType: return ApiResponseType(ok=False, errmsg="Invalid response format", data=None) + def metric(self): + from swanlab.api.metric import Metric + + metric_type = parse_column_data_type(self.column_type) + + cur_metric = Metric(ctx=self._ctx, project_id=self.project_id, run_id=self.run_id, metric_type=metric_type) + def json(self) -> Dict[str, Any]: return get_properties(self) diff --git a/swanlab/api/metric.py b/swanlab/api/metric.py new file mode 100644 index 000000000..3318abc96 --- /dev/null +++ b/swanlab/api/metric.py @@ -0,0 +1,103 @@ +""" +@author: caddiesnew +@file: column.py +@time: 2026/4/20 +@description: Column 实体类 — 实验列的查询与操作 +""" + +from typing import Any, Dict, Iterator, List, Optional, cast + +from swanlab.api.base import ApiClientContext, BaseEntity +from swanlab.api.typings import ApiColumnCsvExportType, ApiMetricTypeLiteral, ApiResponseType +from swanlab.api.typings.metric import ApiLogType, ApiMediaType, ApiMetricType, ApiScalarType +from swanlab.api.utils import get_properties, resovle_run_path, validate_column_params, validate_metric_type + + +class Metric(BaseEntity): + """ + 表示一个 SwanLab 指标列 (非单个数值,而是一组序列) + """ + + def __init__( + self, + ctx: ApiClientContext, + *, + project_id: str, + run_id: str, + key: Optional[str] = "", + metric_type: str = "SCALAR", + data: Optional[Any] = None, + ) -> None: + super().__init__(ctx) + validate_metric_type(metric_type, key) + self._project_id = project_id + self._run_id = run_id + self._key = key + + self._data = data + self._metric_type = metric_type + + def _ensure_data(self) -> ApiMetricType: + if self._data is None: + if self._metric_type == "SCALAR": + self._data = cast(ApiScalarType, {}) + elif self._metric_type == "MEDIA": + self._data = cast(ApiScalarType, {}) + elif self._metric_type == "LOG": + self._data = cast(ApiScalarType, {}) + else: + # 默认兜底到 scalar,实际上在实例化时被拦截 + self._data = cast(ApiScalarType, {}) + return self._data + + @property + def project_id(self) -> str: + return self._project_id + + @property + def run_id(self) -> str: + return self._run_id + + @property + def key(self) -> str: + return self._key if self._key else "" + + @property + def metric_type(self) -> str: + return self._metric_type + + @property + def metrics(self) -> List[Any]: + return [] + + def _fetch_scalar(self): + return + + def _fetch_media(self): + pass + + def _fetch_logs(self): + pass + + def export_csv(self) -> ApiResponseType: + """ + 导出列数据为 CSV。(同时支持 column 和 csv 导出) + + :return: ApiResponseType,成功时 data 包含临时下载 URL + """ + resp = self._get(f"/experiment/{self._run_id}/column/csv", params={"key": self.key}) + if not resp.ok: + return resp + + data = resp.data + if isinstance(data, list) and data: + url = data[0].get("url", "") + return ApiResponseType(ok=True, data=ApiColumnCsvExportType(url=url)) + elif isinstance(data, dict): + url = data.get("url", "") + return ApiResponseType(ok=True, data=ApiColumnCsvExportType(url=url)) + + return ApiResponseType(ok=False, errmsg="Invalid response format", data=None) + + def json(self) -> Dict[str, Any]: + return get_properties(self) diff --git a/swanlab/api/typings/__init__.py b/swanlab/api/typings/__init__.py index 35e30108a..134c19fd2 100644 --- a/swanlab/api/typings/__init__.py +++ b/swanlab/api/typings/__init__.py @@ -18,6 +18,7 @@ ApiWorkspaceLiteral, ) from .experiment import ApiExperimentLabelType, ApiExperimentType +from .metric import * from .project import ApiProjectCountType, ApiProjectLabelType, ApiProjectType from .selfhosted import ApiApiKeyType, ApiSelfHostedInfoType from .user import ApiUserProfileType, ApiUserType diff --git a/swanlab/api/typings/common.py b/swanlab/api/typings/common.py index cf4eefd34..d778cf42f 100644 --- a/swanlab/api/typings/common.py +++ b/swanlab/api/typings/common.py @@ -54,12 +54,79 @@ "TEXT", ] +# 列数据非 media 类型,方便过滤 +ApiColumnScalarTypeLiteral = Literal["FLOAT", "BOOLEAN", "STRING"] + # Self-Hosted 身份类型 ApiIdentityLiteral = Literal["root", "user"] # License 许可证类型 ApiLicensePlanLiteral = Literal["free", "commercial"] +# 指标类型(log 不属于 column-backed metrics,使用独立查询方法) +ApiMetricTypeLiteral = Literal["SCALAR", "MEDIA", "LOG"] + +# X 轴类型 +ApiMetricXAxisLiteral = Literal["step", "time", "relative_time"] + + +# --------------------------------------------------------------------------- +# STABLE 字段 key 枚举 +# 对应 experiment 表的直接字段或嵌套字段,来自 sidebar.js stableFieldSelect。 +# --------------------------------------------------------------------------- +ApiFilterStableKeyLiteral = Literal[ + # 实验状态 RUNNING / FINISHED / CRASHED / ABORTED + "state", + # 实验名称 + "name", + # 实验描述 + "description", + # 是否可见 + "show", + # TODO: experiment 被设置为 pin 时强制返回 + # "pin", + # 是否为基线 + "baseline", + # 颜色 + "colors", + # 实验分组名 + "cluster", + # 分布式任务类型 + "job", + # 创建时间 + "createdAt", + # 更新时间 + "updatedAt", + # 完成时间 + "finishedAt", + # 收藏时间 + "pinnedAt", + # 标签名数组 + "labels", +] + +# --------------------------------------------------------------------------- +# 过滤操作符 +# --------------------------------------------------------------------------- +# EQ : 等于 +# NEQ : 不等于 +# GTE : 大于等于(数值 / 日期 / 字符串) +# LTE : 小于等于(数值 / 日期 / 字符串) +# IN : 在给定值列表中 +# NOT IN : 不在给定值列表中 +# CONTAIN : 模糊包含 +# +# 注意: +# - 数组类型(如 labels)仅支持 EQ / NEQ / IN / NOT IN / CONTAIN +# - 日期类型 GTE/LTE 用 Date 对象比较;其余用 ISO 字符串比较 +# - 数值类型优先数值比较,失败回退字符串比较 +# --------------------------------------------------------------------------- +ApiFilterOpLiteral = Literal["EQ", "NEQ", "GTE", "LTE", "IN", "NOT IN", "CONTAIN"] + +# --------------------------------------------------------------------------- +# 排序方向 +# --------------------------------------------------------------------------- +ApiSortOrderLiteral = Literal["ASC", "DESC"] # 后端允许的每页条数 _VALID_PAGE_SIZES = (10, 12, 15, 20, 24, 27, 50, 100) diff --git a/swanlab/api/typings/experiment.py b/swanlab/api/typings/experiment.py index 93915b937..eb8f6a47e 100644 --- a/swanlab/api/typings/experiment.py +++ b/swanlab/api/typings/experiment.py @@ -13,67 +13,15 @@ from typing import Any, Dict, List, Literal, Optional, TypedDict -from .common import ApiExperimentTypeLiteral, ApiRunStateLiteral, ApiSidebarLiteral +from .common import ( + ApiExperimentTypeLiteral, + ApiFilterOpLiteral, + ApiRunStateLiteral, + ApiSidebarLiteral, + ApiSortOrderLiteral, +) from .user import ApiUserType -# --------------------------------------------------------------------------- -# STABLE 字段 key 枚举 -# 对应 experiment 表的直接字段或嵌套字段,来自 sidebar.js stableFieldSelect。 -# --------------------------------------------------------------------------- -ApiStableKeyLiteral = Literal[ - # 实验状态 RUNNING / FINISHED / CRASHED / ABORTED - "state", - # 实验名称 - "name", - # 实验描述 - "description", - # 是否可见 - "show", - # TODO: experiment 被设置为 pin 时强制返回 - # "pin", - # 是否为基线 - "baseline", - # 颜色 - "colors", - # 实验分组名 - "cluster", - # 分布式任务类型 - "job", - # 创建时间 - "createdAt", - # 更新时间 - "updatedAt", - # 完成时间 - "finishedAt", - # 收藏时间 - "pinnedAt", - # 标签名数组 - "labels", -] - -# --------------------------------------------------------------------------- -# 过滤操作符 -# --------------------------------------------------------------------------- -# EQ : 等于 -# NEQ : 不等于 -# GTE : 大于等于(数值 / 日期 / 字符串) -# LTE : 小于等于(数值 / 日期 / 字符串) -# IN : 在给定值列表中 -# NOT IN : 不在给定值列表中 -# CONTAIN : 模糊包含 -# -# 注意: -# - 数组类型(如 labels)仅支持 EQ / NEQ / IN / NOT IN / CONTAIN -# - 日期类型 GTE/LTE 用 Date 对象比较;其余用 ISO 字符串比较 -# - 数值类型优先数值比较,失败回退字符串比较 -# --------------------------------------------------------------------------- -ApiFilterOpLiteral = Literal["EQ", "NEQ", "GTE", "LTE", "IN", "NOT IN", "CONTAIN"] - -# --------------------------------------------------------------------------- -# 排序方向 -# --------------------------------------------------------------------------- -ApiSortOrderLiteral = Literal["ASC", "DESC"] - # --------------------------------------------------------------------------- # filter / group / sort item diff --git a/swanlab/api/typings/metric.py b/swanlab/api/typings/metric.py new file mode 100644 index 000000000..47d3fdf59 --- /dev/null +++ b/swanlab/api/typings/metric.py @@ -0,0 +1,90 @@ +""" +@author: caddiesnew +@file: metric.py +@time: 2026/4/23 +@description: 指标数据类型定义(用于 column 采样值) +""" + +from typing import Any, Dict, List, Literal, TypedDict, Union + +from .common import ApiMetricTypeLiteral, ApiMetricXAxisLiteral + +# --------------------------------------------------------------------------- +# Common — 通用指标类型定义 +# --------------------------------------------------------------------------- + + +# 指标值类型("NaN", "INF", "-INF") +ApiMetricValueType = Union[int, float, str] + + +# --------------------------------------------------------------------------- +# Column Reference — 指标列引用,标识要查询的指标列 +# --------------------------------------------------------------------------- +class ApiMetricColumnRefType(TypedDict, total=False): + projectId: str + experimentId: str + key: str + rootProId: str + rootExpId: str + + +# --------------------------------------------------------------------------- +# Scalar — 标量指标类型 +# --------------------------------------------------------------------------- +# 使用 index 因为 x 轴可以是 step / time / relative_time / 自定义列 +class ApiScalarType(TypedDict, total=False): + index: float + data: ApiMetricValueType + timestamp: int + + +# 组合 /metrics/scalar 和 /metrics/scalar/value 的标量序列 +class ApiScalarSeriesType(ApiMetricColumnRefType, total=False): + """标量指标序列,包含折线数据和聚合值""" + + metrics: List[ApiScalarType] + minMax: List[Any] + min: ApiScalarType + max: ApiScalarType + avg: ApiScalarType + median: ApiScalarType + latest: ApiScalarType + + +# 指标概要 +# summary[run_id][key] 为下面一个 item 项 +class ApiScalarSummaryItemType(TypedDict, total=False): + step: int + value: Any + minMax: List[Any] + min: Any + max: Any + avg: Any + median: Any + stdDev: Any + + +# --------------------------------------------------------------------------- +# Media — 媒体 item 数据 +# --------------------------------------------------------------------------- +class ApiMediaType(TypedDict, total=False): + # 项目路径: proj_id/run_id 拼接而成 + prefix: str + data: List[str] + more: List[Dict[str, Any]] + + +# --------------------------------------------------------------------------- +# Log — 日志 item 数据 +# --------------------------------------------------------------------------- +class ApiLogType(TypedDict, total=False): + epoch: int + level: str + message: str + tag: str + timestamp: str + + +# 统一数据类型定义用于类型提示 +ApiMetricType = Union[ApiScalarType, ApiMediaType, ApiLogType] diff --git a/swanlab/api/utils.py b/swanlab/api/utils.py index 02dddf018..46e2ba0a8 100644 --- a/swanlab/api/utils.py +++ b/swanlab/api/utils.py @@ -7,11 +7,15 @@ from typing import Any, Dict, List, Optional, Set, Tuple, Type, get_args, get_type_hints -from swanlab.api.typings.common import ApiColumnClassLiteral, ApiColumnDataTypeLiteral, ApiSidebarLiteral -from swanlab.api.typings.experiment import ( +from swanlab.api.typings.common import ( + ApiColumnClassLiteral, + ApiColumnDataTypeLiteral, + ApiColumnScalarTypeLiteral, ApiFilterOpLiteral, + ApiFilterStableKeyLiteral, + ApiMetricTypeLiteral, + ApiSidebarLiteral, ApiSortOrderLiteral, - ApiStableKeyLiteral, ) @@ -64,11 +68,15 @@ def resovle_run_path(path: str) -> Tuple[str, str]: _VALID_SIDEBAR_TYPES = frozenset(get_args(ApiSidebarLiteral)) _VALID_OPS = frozenset(get_args(ApiFilterOpLiteral)) _VALID_ORDERS = frozenset(get_args(ApiSortOrderLiteral)) -_STABLE_KEYS = frozenset(get_args(ApiStableKeyLiteral)) +_STABLE_KEYS = frozenset(get_args(ApiFilterStableKeyLiteral)) # 列相关校验常量 _VALID_COLUMN_CLASSES = frozenset(get_args(ApiColumnClassLiteral)) _VALID_COLUMN_DATA_TYPES = frozenset(get_args(ApiColumnDataTypeLiteral)) +_VALID_COLUMN_SCALAR_TYPES = frozenset(get_args(ApiColumnScalarTypeLiteral)) + +# 指标相关校验常量 +_VALID_METRIC_TYPES = frozenset(get_args(ApiMetricTypeLiteral)) def _check_required(item: Dict[str, Any], keys: Set[str]) -> None: @@ -99,6 +107,14 @@ def validate_filter(item: Dict[str, Any]) -> None: raise ValueError(f"filter value must be a list, got {type(item['value']).__name__}") +def validate_metric_type(item: str, key: Optional[str] = None): + """校验 metric_type 的合法性""" + if item not in _VALID_METRIC_TYPES: + raise ValueError(f"Invalid metric_type: {item!r}, expected one of {sorted(_VALID_METRIC_TYPES)}") + if key is None and item != "LOG": + raise ValueError("key must NOT be None if metric_type != LOG") + + def validate_group(item: Dict[str, Any]) -> None: """校验单个 group item 的合法性。""" _check_required(item, {"key", "type"}) @@ -139,3 +155,12 @@ def validate_column_params(column_type: Optional[str] = None, column_class: Opti raise ValueError(f"Invalid column_type: {column_type!r}, expected one of {sorted(_VALID_COLUMN_DATA_TYPES)}") if column_class is not None and column_class not in _VALID_COLUMN_CLASSES: raise ValueError(f"Invalid column_class: {column_class!r}, expected one of {sorted(_VALID_COLUMN_CLASSES)}") + + +def parse_column_data_type(column_type: str): + """解析列类型。""" + validate_column_params(column_type=column_type) + if column_type in _VALID_COLUMN_SCALAR_TYPES: + return "SCALAR" + # 新加入的类型默认指定为 media + return "MEDIA" From 61131086b0229d57687dbdb3115f1189e9282e4c Mon Sep 17 00:00:00 2001 From: Nexisato <978452096@qq.com> Date: Fri, 24 Apr 2026 17:07:24 +0800 Subject: [PATCH 31/52] feat: add scalar retrieve --- swanlab/api/column.py | 12 +++++-- swanlab/api/metric.py | 63 ++++++++++++++++++++++++++--------- swanlab/api/typings/metric.py | 1 - swanlab/api/utils.py | 2 +- 4 files changed, 59 insertions(+), 19 deletions(-) diff --git a/swanlab/api/column.py b/swanlab/api/column.py index f4c517bb4..99457f5dc 100644 --- a/swanlab/api/column.py +++ b/swanlab/api/column.py @@ -61,6 +61,11 @@ def _ensure_data(self) -> ApiColumnType: proj_data = resp.data if resp.ok else {} self._project_id = proj_data.get("cuid", "") self._data["project_id"] = self._project_id + # 这里要确保是 cuid 而非 slug + run_resp = self._get(f"/project/{self._proj_path}/runs/{self._run_id}") + run_data = run_resp.data if run_resp.ok else {} + run_cuid = run_data.get("cuid", "") + self._run_id = run_cuid return self._data @property @@ -71,7 +76,7 @@ def project_id(self) -> str: @property def run_id(self) -> str: - if self._run_id: + if self._project_id: return self._run_id return self._ensure_data().get("run_id", "") @@ -132,7 +137,10 @@ def metric(self): metric_type = parse_column_data_type(self.column_type) - cur_metric = Metric(ctx=self._ctx, project_id=self.project_id, run_id=self.run_id, metric_type=metric_type) + cur_metric = Metric( + ctx=self._ctx, project_id=self.project_id, run_id=self.run_id, key=self.key, metric_type=metric_type + ) + return cur_metric.json() def json(self) -> Dict[str, Any]: return get_properties(self) diff --git a/swanlab/api/metric.py b/swanlab/api/metric.py index 3318abc96..fe3782256 100644 --- a/swanlab/api/metric.py +++ b/swanlab/api/metric.py @@ -9,7 +9,7 @@ from swanlab.api.base import ApiClientContext, BaseEntity from swanlab.api.typings import ApiColumnCsvExportType, ApiMetricTypeLiteral, ApiResponseType -from swanlab.api.typings.metric import ApiLogType, ApiMediaType, ApiMetricType, ApiScalarType +from swanlab.api.typings.metric import ApiLogType, ApiMediaType, ApiMetricType, ApiScalarSeriesType, ApiScalarType from swanlab.api.utils import get_properties, resovle_run_path, validate_column_params, validate_metric_type @@ -25,6 +25,7 @@ def __init__( project_id: str, run_id: str, key: Optional[str] = "", + sample: int = 1500, metric_type: str = "SCALAR", data: Optional[Any] = None, ) -> None: @@ -33,14 +34,17 @@ def __init__( self._project_id = project_id self._run_id = run_id self._key = key - self._data = data self._metric_type = metric_type - def _ensure_data(self) -> ApiMetricType: + # TODO: 采样值,仅在 scalar 时生效, 待接入 + self._sample = sample + + def _ensure_data(self) -> Dict[str, Any]: if self._data is None: if self._metric_type == "SCALAR": - self._data = cast(ApiScalarType, {}) + self._data = self._fetch_scalar() + print(self._data) elif self._metric_type == "MEDIA": self._data = cast(ApiScalarType, {}) elif self._metric_type == "LOG": @@ -48,7 +52,7 @@ def _ensure_data(self) -> ApiMetricType: else: # 默认兜底到 scalar,实际上在实例化时被拦截 self._data = cast(ApiScalarType, {}) - return self._data + return cast(dict, self._data) @property def project_id(self) -> str: @@ -68,16 +72,45 @@ def metric_type(self) -> str: @property def metrics(self) -> List[Any]: - return [] - - def _fetch_scalar(self): - return - - def _fetch_media(self): - pass - - def _fetch_logs(self): - pass + return self._ensure_data().get("metrics", []) + + def _fetch_scalar(self) -> ApiScalarSeriesType: + res = ApiScalarSeriesType(projectId=self.project_id, experimentId=self.run_id, key=self.key) + # 1. 获取单指标列 + payload = { + "projectId": self.project_id, + "xType": "step", + "range": [0, 0], + "columns": [{"experimentId": self.run_id, "key": self.key}], + } + raw_resp = self._post("/house/metrics/scalar", data=payload) + resp_list = ( + raw_resp.data if raw_resp.ok and isinstance(raw_resp.data, list) and len(raw_resp.data) > 0 else None + ) + if resp_list is None: + return res + raw_data = resp_list[0] + res["metrics"] = raw_data.get("metrics", {}) + # 2. 获取统计值列 + stat_resp = self._post("/house/metrics/scalar/value", data=payload) + stat_list = ( + stat_resp.data if stat_resp.ok and isinstance(stat_resp.data, list) and len(stat_resp.data) > 0 else None + ) + if stat_list is None: + return res + stat_data = stat_list[0] + res["min"] = stat_data.get("min", {}) + res["max"] = stat_data.get("max", {}) + res["avg"] = stat_data.get("avg", {}) + res["median"] = stat_data.get("median", {}) + res["latest"] = stat_data.get("latest", {}) + return res + + def _fetch_media(self) -> Dict[str, Any]: + return {} + + def _fetch_logs(self) -> Dict[str, Any]: + return {} def export_csv(self) -> ApiResponseType: """ diff --git a/swanlab/api/typings/metric.py b/swanlab/api/typings/metric.py index 47d3fdf59..6dc286bf4 100644 --- a/swanlab/api/typings/metric.py +++ b/swanlab/api/typings/metric.py @@ -44,7 +44,6 @@ class ApiScalarSeriesType(ApiMetricColumnRefType, total=False): """标量指标序列,包含折线数据和聚合值""" metrics: List[ApiScalarType] - minMax: List[Any] min: ApiScalarType max: ApiScalarType avg: ApiScalarType diff --git a/swanlab/api/utils.py b/swanlab/api/utils.py index 46e2ba0a8..6cc3aceda 100644 --- a/swanlab/api/utils.py +++ b/swanlab/api/utils.py @@ -111,7 +111,7 @@ def validate_metric_type(item: str, key: Optional[str] = None): """校验 metric_type 的合法性""" if item not in _VALID_METRIC_TYPES: raise ValueError(f"Invalid metric_type: {item!r}, expected one of {sorted(_VALID_METRIC_TYPES)}") - if key is None and item != "LOG": + if not key and item != "LOG": raise ValueError("key must NOT be None if metric_type != LOG") From abf5d34d071793e9301c221ccce8af663cc05ede Mon Sep 17 00:00:00 2001 From: Nexisato <978452096@qq.com> Date: Fri, 24 Apr 2026 17:48:16 +0800 Subject: [PATCH 32/52] refactor: support sampler --- swanlab/api/column.py | 56 +++++++++------- swanlab/api/experiment.py | 83 +++++++----------------- swanlab/api/metric.py | 133 +++++++++++++++++++++++--------------- 3 files changed, 134 insertions(+), 138 deletions(-) diff --git a/swanlab/api/column.py b/swanlab/api/column.py index 99457f5dc..5d5ea1bdd 100644 --- a/swanlab/api/column.py +++ b/swanlab/api/column.py @@ -8,7 +8,7 @@ from typing import Any, Dict, Iterator, Optional, cast from swanlab.api.base import ApiClientContext, BaseEntity -from swanlab.api.typings.column import ApiColumnCsvExportType, ApiColumnType +from swanlab.api.typings.column import ApiColumnType from swanlab.api.typings.common import ApiResponseType, PaginatedQuery from swanlab.api.utils import get_properties, parse_column_data_type, resovle_run_path, validate_column_params @@ -112,35 +112,41 @@ def error(self) -> Optional[Dict[str, Any]]: """列的错误信息。""" return self._ensure_data().get("error", {}) - def export_csv(self) -> ApiResponseType: - """ - 导出列数据为 CSV。 - - :return: ApiResponseType,成功时 data 包含临时下载 URL - """ - resp = self._get(f"/experiment/{self._run_id}/column/csv", params={"key": self.key}) - if not resp.ok: - return resp - - data = resp.data - if isinstance(data, list) and data: - url = data[0].get("url", "") - return ApiResponseType(ok=True, data=ApiColumnCsvExportType(url=url)) - elif isinstance(data, dict): - url = data.get("url", "") - return ApiResponseType(ok=True, data=ApiColumnCsvExportType(url=url)) - - return ApiResponseType(ok=False, errmsg="Invalid response format", data=None) - - def metric(self): + def _require_found(self) -> None: + """确保列数据已加载且存在,否则抛出清晰错误。""" + self._ensure_data() + if not self.key: + raise ValueError(f"Column '{self._key}' not found in the experiment") + + def metric(self, sample: int = 1500, ignore_timestamp: bool = False) -> Dict[str, Any]: from swanlab.api.metric import Metric + self._require_found() metric_type = parse_column_data_type(self.column_type) + metric = Metric( + ctx=self._ctx, + project_id=self.project_id, + run_id=self.run_id, + key=self.key, + sample=sample, + metric_type=metric_type, + ignore_timestamp=ignore_timestamp, + ) + return metric.json() - cur_metric = Metric( - ctx=self._ctx, project_id=self.project_id, run_id=self.run_id, key=self.key, metric_type=metric_type + def export_csv(self) -> ApiResponseType: + from swanlab.api.metric import Metric + + self._require_found() + metric_type = parse_column_data_type(self.column_type) + metric = Metric( + ctx=self._ctx, + project_id=self.project_id, + run_id=self.run_id, + key=self.key, + metric_type=metric_type, ) - return cur_metric.json() + return metric.export_csv() def json(self) -> Dict[str, Any]: return get_properties(self) diff --git a/swanlab/api/experiment.py b/swanlab/api/experiment.py index 592f599cd..eaafc6b8e 100644 --- a/swanlab/api/experiment.py +++ b/swanlab/api/experiment.py @@ -130,72 +130,33 @@ def profile(self) -> ApiExperimentProfileType: data = self._data return cast(ApiExperimentProfileType, self._ensure_data().get("profile", {})) - def metrics( - self, keys: Optional[List[str]] = None, x_axis: Optional[str] = None, sample: Optional[int] = None - ) -> Any: + def column(self, key: str, column_class: Optional[str] = "CUSTOM", column_type: Optional[str] = "FLOAT"): """ - 获取实验指标数据,返回 pandas DataFrame。 + 获取实验下指定 key 的单个列。 - :param keys: 指标 key 列表 - :param x_axis: x 轴指标,默认 step - :param sample: 均匀采样 N 条数据(等间距采样,保留整体趋势) + :param key: 列的 key,如 "loss"、"acc" + :param column_class: 列的分类,CUSTOM 或 SYSTEM + :param column_type: 列的数据类型,如 FLOAT、STRING、IMAGE 等 """ - from swanlab.vendor import pd - - if not keys: - return pd.DataFrame() - - fetch_keys = list(keys) - use_x_axis = x_axis is not None and x_axis != "step" - if use_x_axis and x_axis is not None: - fetch_keys.append(x_axis) - - dfs = [] - prefix = "" - for idx, key in enumerate(fetch_keys): - resp = self._get(f"/experiment/{self.run_id}/column/csv", params={"key": key}) - if not resp.ok: - continue - data = resp.data - csv_url = data[0].get("url", "") if isinstance(data, list) and data else "" - if not csv_url: - continue - df = pd.read_csv(csv_url, index_col=0) - - if idx == 0: - first_col = str(df.columns[0]) - suffix = f"{key}_" - prefix = first_col.split(suffix)[0] if suffix in first_col else "" - - def strip_suffix(col, suffix="_step"): - return col[: -len(suffix)] if col.endswith(suffix) else col - - df.columns = [ - strip_suffix(col[len(prefix) :]) if prefix and col.startswith(prefix) else strip_suffix(col) - for col in df.columns - ] - dfs.append(df) - - if not dfs: - return pd.DataFrame() - - result_df = dfs[0].join(dfs[1:], how="outer") if len(dfs) > 1 else dfs[0] - result_df = result_df.sort_index() - - if use_x_axis: - result_df = result_df.drop( - columns=[c for c in result_df.columns if c.endswith("_timestamp")], errors="ignore" - ) - if x_axis not in result_df.columns: - return pd.DataFrame() - cols = [x_axis] + [c for c in result_df.columns if c != x_axis] - result_df = result_df[cols].dropna(subset=[x_axis]) + from swanlab.api.column import Column + + return Column( + self._ctx, + path=f"{self._proj_path}/{self._cuid}", + key=key, + column_class=column_class, + column_type=column_type, + ) - if sample is not None and len(result_df) > sample: - indices = [int(i * (len(result_df) - 1) / (sample - 1)) for i in range(sample)] - result_df = result_df.iloc[indices] + def metric(self, key: str, sample: int = 1500, ignore_timestamp: bool = False) -> Dict[str, Any]: + """ + 获取实验下指定列的指标数据,最大返回 1500 条。 - return result_df + :param key: 列的 key + :param sample: 采样条数 + :param ignore_timestamp: 是否过滤 timestamp 字段 + """ + return self.column(key=key).metric(sample=sample, ignore_timestamp=ignore_timestamp) def columns( self, diff --git a/swanlab/api/metric.py b/swanlab/api/metric.py index fe3782256..fa44e6d16 100644 --- a/swanlab/api/metric.py +++ b/swanlab/api/metric.py @@ -1,21 +1,23 @@ """ @author: caddiesnew -@file: column.py +@file: metric.py @time: 2026/4/20 -@description: Column 实体类 — 实验列的查询与操作 +@description: Metric 实体类 — 指标序列的查询与操作 """ -from typing import Any, Dict, Iterator, List, Optional, cast +from typing import Any, Dict, List, Optional, cast from swanlab.api.base import ApiClientContext, BaseEntity -from swanlab.api.typings import ApiColumnCsvExportType, ApiMetricTypeLiteral, ApiResponseType -from swanlab.api.typings.metric import ApiLogType, ApiMediaType, ApiMetricType, ApiScalarSeriesType, ApiScalarType -from swanlab.api.utils import get_properties, resovle_run_path, validate_column_params, validate_metric_type +from swanlab.api.typings import ApiColumnCsvExportType, ApiResponseType +from swanlab.api.typings.metric import ApiScalarSeriesType +from swanlab.api.utils import get_properties, validate_metric_type class Metric(BaseEntity): """ - 表示一个 SwanLab 指标列 (非单个数值,而是一组序列) + 表示一个 SwanLab 指标列(非单个数值,而是一组序列)。 + + 支持 SCALAR / MEDIA / LOG 三种类型,按需 Lazy Loading。 """ def __init__( @@ -27,32 +29,33 @@ def __init__( key: Optional[str] = "", sample: int = 1500, metric_type: str = "SCALAR", - data: Optional[Any] = None, + data: Optional[Dict[str, Any]] = None, + ignore_timestamp: bool = False, ) -> None: super().__init__(ctx) validate_metric_type(metric_type, key) self._project_id = project_id self._run_id = run_id self._key = key - self._data = data + self._data: Optional[Dict[str, Any]] = data self._metric_type = metric_type - - # TODO: 采样值,仅在 scalar 时生效, 待接入 + self._ignore_timestamp = ignore_timestamp + # TODO: 采样值,仅在 scalar 时生效,待接入 self._sample = sample + # 类型 → 加载方法 的分发表,新增类型只需在此注册 + _FETCH_DISPATCH = { + "SCALAR": "_fetch_scalar", + "MEDIA": "_fetch_media", + "LOG": "_fetch_logs", + } + def _ensure_data(self) -> Dict[str, Any]: if self._data is None: - if self._metric_type == "SCALAR": - self._data = self._fetch_scalar() - print(self._data) - elif self._metric_type == "MEDIA": - self._data = cast(ApiScalarType, {}) - elif self._metric_type == "LOG": - self._data = cast(ApiScalarType, {}) - else: - # 默认兜底到 scalar,实际上在实例化时被拦截 - self._data = cast(ApiScalarType, {}) - return cast(dict, self._data) + method_name = self._FETCH_DISPATCH.get(self._metric_type, "_fetch_scalar") + self._data = getattr(self, method_name)() + assert self._data is not None + return self._data @property def project_id(self) -> str: @@ -64,7 +67,7 @@ def run_id(self) -> str: @property def key(self) -> str: - return self._key if self._key else "" + return self._key or "" @property def metric_type(self) -> str: @@ -74,36 +77,45 @@ def metric_type(self) -> str: def metrics(self) -> List[Any]: return self._ensure_data().get("metrics", []) - def _fetch_scalar(self) -> ApiScalarSeriesType: - res = ApiScalarSeriesType(projectId=self.project_id, experimentId=self.run_id, key=self.key) - # 1. 获取单指标列 - payload = { + # ------------------------------------------------------------------ + # 请求辅助函数 + # ------------------------------------------------------------------ + + @staticmethod + def _extract_first(resp: ApiResponseType) -> Optional[Dict[str, Any]]: + """从列表型 API 响应中提取第一个元素,失败返回 None。""" + if resp.ok and isinstance(resp.data, list) and resp.data: + return resp.data[0] + return None + + def _build_scalar_payload(self) -> Dict[str, Any]: + return { "projectId": self.project_id, "xType": "step", "range": [0, 0], "columns": [{"experimentId": self.run_id, "key": self.key}], } - raw_resp = self._post("/house/metrics/scalar", data=payload) - resp_list = ( - raw_resp.data if raw_resp.ok and isinstance(raw_resp.data, list) and len(raw_resp.data) > 0 else None - ) - if resp_list is None: + + # ------------------------------------------------------------------ + # 类型专属加载 + # ------------------------------------------------------------------ + + def _fetch_scalar(self) -> ApiScalarSeriesType: + res = ApiScalarSeriesType(projectId=self.project_id, experimentId=self.run_id, key=self.key) + payload = self._build_scalar_payload() + + # 1. 获取折线数据 + raw_data = self._extract_first(self._post("/house/metrics/scalar", data=payload)) + if raw_data is None: return res - raw_data = resp_list[0] - res["metrics"] = raw_data.get("metrics", {}) - # 2. 获取统计值列 - stat_resp = self._post("/house/metrics/scalar/value", data=payload) - stat_list = ( - stat_resp.data if stat_resp.ok and isinstance(stat_resp.data, list) and len(stat_resp.data) > 0 else None - ) - if stat_list is None: + res["metrics"] = raw_data.get("metrics", []) + + # 2. 获取统计值 + stat_data = self._extract_first(self._post("/house/metrics/scalar/value", data=payload)) + if stat_data is None: return res - stat_data = stat_list[0] - res["min"] = stat_data.get("min", {}) - res["max"] = stat_data.get("max", {}) - res["avg"] = stat_data.get("avg", {}) - res["median"] = stat_data.get("median", {}) - res["latest"] = stat_data.get("latest", {}) + for field in ("min", "max", "avg", "median", "latest"): + res[field] = stat_data.get(field, {}) return res def _fetch_media(self) -> Dict[str, Any]: @@ -112,9 +124,13 @@ def _fetch_media(self) -> Dict[str, Any]: def _fetch_logs(self) -> Dict[str, Any]: return {} + # ------------------------------------------------------------------ + # 导出 + # ------------------------------------------------------------------ + def export_csv(self) -> ApiResponseType: """ - 导出列数据为 CSV。(同时支持 column 和 csv 导出) + 导出列数据为 CSV。 :return: ApiResponseType,成功时 data 包含临时下载 URL """ @@ -125,12 +141,25 @@ def export_csv(self) -> ApiResponseType: data = resp.data if isinstance(data, list) and data: url = data[0].get("url", "") - return ApiResponseType(ok=True, data=ApiColumnCsvExportType(url=url)) elif isinstance(data, dict): url = data.get("url", "") - return ApiResponseType(ok=True, data=ApiColumnCsvExportType(url=url)) - - return ApiResponseType(ok=False, errmsg="Invalid response format", data=None) + else: + return ApiResponseType(ok=False, errmsg="Invalid response format", data=None) + return ApiResponseType(ok=True, data=ApiColumnCsvExportType(url=url)) def json(self) -> Dict[str, Any]: - return get_properties(self) + result = get_properties(self) + data = self._ensure_data() + + if self._metric_type == "SCALAR": + for field in ("min", "max", "avg", "median", "latest"): + val = data.get(field) + if val: + result[field] = val + + if self._ignore_timestamp: + for item in cast(List[Dict[str, Any]], result.get("metrics", [])): + if isinstance(item, dict): + item.pop("timestamp", None) + + return result From 8cf9aff29dfed836b2edaaea4396f6303e39ee10 Mon Sep 17 00:00:00 2001 From: Nexisato <978452096@qq.com> Date: Fri, 24 Apr 2026 20:34:52 +0800 Subject: [PATCH 33/52] feat: support media metrics --- swanlab/api/column.py | 26 ++++++++++------------ swanlab/api/experiment.py | 24 ++++++++++++-------- swanlab/api/metric.py | 41 ++++++++++++++++++++++++++++++----- swanlab/api/typings/metric.py | 22 ++++++++++++++----- swanlab/api/utils.py | 12 +++++----- 5 files changed, 85 insertions(+), 40 deletions(-) diff --git a/swanlab/api/column.py b/swanlab/api/column.py index 5d5ea1bdd..70854d5e2 100644 --- a/swanlab/api/column.py +++ b/swanlab/api/column.py @@ -9,7 +9,7 @@ from swanlab.api.base import ApiClientContext, BaseEntity from swanlab.api.typings.column import ApiColumnType -from swanlab.api.typings.common import ApiResponseType, PaginatedQuery +from swanlab.api.typings.common import ApiMetricTypeLiteral, ApiResponseType, PaginatedQuery from swanlab.api.utils import get_properties, parse_column_data_type, resovle_run_path, validate_column_params @@ -41,7 +41,7 @@ def __init__( def _ensure_data(self) -> ApiColumnType: if self._data is None: - validate_column_params(column_class=self._column_class) + validate_column_params(column_class=self._column_class, column_type=self._column_type) extra: Dict[str, Any] = {"search": self._key} if self._column_class: extra["class"] = self._column_class @@ -82,10 +82,9 @@ def run_id(self) -> str: @property def key(self) -> str: - res_key = self._ensure_data().get("key", "") - if res_key and res_key != self._key: - self._key = res_key - return res_key + if self._key: + return self._key + return self._ensure_data().get("key", "") @property def name(self) -> str: @@ -112,16 +111,11 @@ def error(self) -> Optional[Dict[str, Any]]: """列的错误信息。""" return self._ensure_data().get("error", {}) - def _require_found(self) -> None: - """确保列数据已加载且存在,否则抛出清晰错误。""" - self._ensure_data() - if not self.key: - raise ValueError(f"Column '{self._key}' not found in the experiment") - - def metric(self, sample: int = 1500, ignore_timestamp: bool = False) -> Dict[str, Any]: + def metric( + self, sample: int = 1500, metric_type: ApiMetricTypeLiteral = "SCALAR", ignore_timestamp: bool = False + ) -> Dict[str, Any]: from swanlab.api.metric import Metric - self._require_found() metric_type = parse_column_data_type(self.column_type) metric = Metric( ctx=self._ctx, @@ -137,8 +131,10 @@ def metric(self, sample: int = 1500, ignore_timestamp: bool = False) -> Dict[str def export_csv(self) -> ApiResponseType: from swanlab.api.metric import Metric - self._require_found() metric_type = parse_column_data_type(self.column_type) + if metric_type != "SCALAR": + err_msg = "export_csv() only support SCALAR metric_type" + return ApiResponseType(ok=False, errmsg=err_msg, data=None) metric = Metric( ctx=self._ctx, project_id=self.project_id, diff --git a/swanlab/api/experiment.py b/swanlab/api/experiment.py index eaafc6b8e..0c56675d9 100644 --- a/swanlab/api/experiment.py +++ b/swanlab/api/experiment.py @@ -8,7 +8,7 @@ from typing import Any, Dict, Iterator, List, Optional, Union, cast from swanlab.api.base import ApiClientContext, BaseEntity -from swanlab.api.typings.common import PaginatedQuery +from swanlab.api.typings.common import ApiMetricTypeLiteral, PaginatedQuery from swanlab.api.typings.experiment import ( ApiExperimentLabelType, ApiExperimentProfileType, @@ -148,15 +148,21 @@ def column(self, key: str, column_class: Optional[str] = "CUSTOM", column_type: column_type=column_type, ) - def metric(self, key: str, sample: int = 1500, ignore_timestamp: bool = False) -> Dict[str, Any]: - """ - 获取实验下指定列的指标数据,最大返回 1500 条。 + def metric( + self, key: str, sample: int = 1500, metric_type: ApiMetricTypeLiteral = "SCALAR", ignore_timestamp: bool = False + ) -> Dict[str, Any]: + from swanlab.api.metric import Metric - :param key: 列的 key - :param sample: 采样条数 - :param ignore_timestamp: 是否过滤 timestamp 字段 - """ - return self.column(key=key).metric(sample=sample, ignore_timestamp=ignore_timestamp) + metric = Metric( + ctx=self._ctx, + project_id=self.project_id, + run_id=self.run_id, + key=key, + sample=sample, + metric_type=metric_type, + ignore_timestamp=ignore_timestamp, + ) + return metric.json() def columns( self, diff --git a/swanlab/api/metric.py b/swanlab/api/metric.py index fa44e6d16..2691b3253 100644 --- a/swanlab/api/metric.py +++ b/swanlab/api/metric.py @@ -9,7 +9,7 @@ from swanlab.api.base import ApiClientContext, BaseEntity from swanlab.api.typings import ApiColumnCsvExportType, ApiResponseType -from swanlab.api.typings.metric import ApiScalarSeriesType +from swanlab.api.typings.metric import ApiLogSeriesType, ApiMediaSeriesType, ApiMediaType, ApiScalarSeriesType from swanlab.api.utils import get_properties, validate_metric_type @@ -96,6 +96,12 @@ def _build_scalar_payload(self) -> Dict[str, Any]: "columns": [{"experimentId": self.run_id, "key": self.key}], } + def _build_media_payload(self) -> Dict[str, Any]: + return { + "projectId": self.project_id, + "columns": [{"experimentId": self.run_id, "key": self.key}], + } + # ------------------------------------------------------------------ # 类型专属加载 # ------------------------------------------------------------------ @@ -118,11 +124,33 @@ def _fetch_scalar(self) -> ApiScalarSeriesType: res[field] = stat_data.get(field, {}) return res - def _fetch_media(self) -> Dict[str, Any]: - return {} + def _fetch_media(self) -> ApiMediaSeriesType: + res = ApiMediaSeriesType(projectId=self.project_id, experimentId=self.run_id, key=self.key) + payload = self._build_media_payload() + raw_resp = self._post("/house/metrics/f_media", data=payload) + raw_data = self._extract_first(raw_resp) + if raw_data is None: + return res + # print(raw_data) + metrics: List[ApiMediaType] = [] + prefix = f"{self.project_id}/{self.run_id}" + for entry in raw_data.get("metrics", []): + paths = entry.get("data", []) + mores = entry.get("more", []) + items = [] + for i, path in enumerate(paths): + item = {"path": path} + if i < len(mores) and isinstance(mores[i], dict): + item.update(mores[i]) + items.append(item) + metrics.append({"index": entry.get("index", 0), "prefix": prefix, "items": items}) + + res["metrics"] = metrics + return res - def _fetch_logs(self) -> Dict[str, Any]: - return {} + def _fetch_logs(self) -> ApiLogSeriesType: + res = ApiLogSeriesType(projectId=self.project_id, experimentId=self.run_id, key="LOG") + return res # ------------------------------------------------------------------ # 导出 @@ -134,6 +162,9 @@ def export_csv(self) -> ApiResponseType: :return: ApiResponseType,成功时 data 包含临时下载 URL """ + if self.metric_type != "SCALAR": + err_msg = "export_csv() only support SCALAR metric_type" + return ApiResponseType(ok=False, errmsg=err_msg, data=None) resp = self._get(f"/experiment/{self._run_id}/column/csv", params={"key": self.key}) if not resp.ok: return resp diff --git a/swanlab/api/typings/metric.py b/swanlab/api/typings/metric.py index 6dc286bf4..954de5e75 100644 --- a/swanlab/api/typings/metric.py +++ b/swanlab/api/typings/metric.py @@ -65,17 +65,24 @@ class ApiScalarSummaryItemType(TypedDict, total=False): # --------------------------------------------------------------------------- -# Media — 媒体 item 数据 +# Media — 媒体数据 # --------------------------------------------------------------------------- +class ApiMediaItemDataType(TypedDict, total=False): + path: str + + class ApiMediaType(TypedDict, total=False): - # 项目路径: proj_id/run_id 拼接而成 + index: int prefix: str - data: List[str] - more: List[Dict[str, Any]] + items: List[ApiMediaItemDataType] + + +class ApiMediaSeriesType(ApiMetricColumnRefType, total=False): + metrics: List[ApiMediaType] # --------------------------------------------------------------------------- -# Log — 日志 item 数据 +# Log — 日志数据 # --------------------------------------------------------------------------- class ApiLogType(TypedDict, total=False): epoch: int @@ -85,5 +92,10 @@ class ApiLogType(TypedDict, total=False): timestamp: str +class ApiLogSeriesType(ApiMetricColumnRefType, total=False): + logs: List[ApiLogType] + count: int + + # 统一数据类型定义用于类型提示 ApiMetricType = Union[ApiScalarType, ApiMediaType, ApiLogType] diff --git a/swanlab/api/utils.py b/swanlab/api/utils.py index 6cc3aceda..49ebdbf34 100644 --- a/swanlab/api/utils.py +++ b/swanlab/api/utils.py @@ -107,12 +107,12 @@ def validate_filter(item: Dict[str, Any]) -> None: raise ValueError(f"filter value must be a list, got {type(item['value']).__name__}") -def validate_metric_type(item: str, key: Optional[str] = None): - """校验 metric_type 的合法性""" - if item not in _VALID_METRIC_TYPES: - raise ValueError(f"Invalid metric_type: {item!r}, expected one of {sorted(_VALID_METRIC_TYPES)}") - if not key and item != "LOG": - raise ValueError("key must NOT be None if metric_type != LOG") +def validate_metric_type(metric_type: str, key: Optional[str] = None) -> None: + """校验 metric_type 的合法性。非 LOG 类型必须提供非空 key。""" + if metric_type not in _VALID_METRIC_TYPES: + raise ValueError(f"Invalid metric_type: {metric_type!r}, expected one of {sorted(_VALID_METRIC_TYPES)}") + if metric_type != "LOG" and not key: + raise ValueError(f"key is required for metric_type {metric_type!r}, got key={key!r}") def validate_group(item: Dict[str, Any]) -> None: From cd0d262ade8cfa0a8b148827441d40dd18376c98 Mon Sep 17 00:00:00 2001 From: Nexisato <978452096@qq.com> Date: Sat, 25 Apr 2026 19:25:12 +0800 Subject: [PATCH 34/52] feat: support logs metric --- swanlab/api/experiment.py | 22 ++++++++++++++- swanlab/api/metric.py | 48 +++++++++++++++++++++++++++++---- swanlab/api/typings/__init__.py | 16 ++++++++++- swanlab/api/typings/common.py | 5 +++- swanlab/api/utils.py | 10 ++++++- 5 files changed, 92 insertions(+), 9 deletions(-) diff --git a/swanlab/api/experiment.py b/swanlab/api/experiment.py index 0c56675d9..43ad99ffb 100644 --- a/swanlab/api/experiment.py +++ b/swanlab/api/experiment.py @@ -8,7 +8,7 @@ from typing import Any, Dict, Iterator, List, Optional, Union, cast from swanlab.api.base import ApiClientContext, BaseEntity -from swanlab.api.typings.common import ApiMetricTypeLiteral, PaginatedQuery +from swanlab.api.typings.common import ApiMetricLogLevelLiteral, ApiMetricTypeLiteral, PaginatedQuery from swanlab.api.typings.experiment import ( ApiExperimentLabelType, ApiExperimentProfileType, @@ -164,6 +164,26 @@ def metric( ) return metric.json() + def logs( + self, + offset: Optional[int] = 0, + level: ApiMetricLogLevelLiteral = "INFO", + ignore_timestamp: bool = False, + ) -> Dict[str, Any]: + from swanlab.api.metric import Metric + + logs = Metric( + ctx=self._ctx, + project_id=self.project_id, + run_id=self.run_id, + key="LOG", + log_offset=offset, + log_level=level, + metric_type="LOG", + ignore_timestamp=ignore_timestamp, + ) + return logs.json() + def columns( self, page: int = 1, diff --git a/swanlab/api/metric.py b/swanlab/api/metric.py index 2691b3253..991f72ed3 100644 --- a/swanlab/api/metric.py +++ b/swanlab/api/metric.py @@ -10,7 +10,7 @@ from swanlab.api.base import ApiClientContext, BaseEntity from swanlab.api.typings import ApiColumnCsvExportType, ApiResponseType from swanlab.api.typings.metric import ApiLogSeriesType, ApiMediaSeriesType, ApiMediaType, ApiScalarSeriesType -from swanlab.api.utils import get_properties, validate_metric_type +from swanlab.api.utils import get_properties, validate_metric_log_level, validate_metric_type class Metric(BaseEntity): @@ -27,21 +27,28 @@ def __init__( project_id: str, run_id: str, key: Optional[str] = "", - sample: int = 1500, + sample: int = 1000, + log_offset: Optional[int] = 0, # 标记第几个分片,仅对 Log metric_type 有效 + log_level: str = "INFO", metric_type: str = "SCALAR", data: Optional[Dict[str, Any]] = None, ignore_timestamp: bool = False, ) -> None: super().__init__(ctx) validate_metric_type(metric_type, key) + if metric_type == "LOG": + validate_metric_log_level(log_level) self._project_id = project_id self._run_id = run_id self._key = key self._data: Optional[Dict[str, Any]] = data self._metric_type = metric_type self._ignore_timestamp = ignore_timestamp - # TODO: 采样值,仅在 scalar 时生效,待接入 + # TODO: 采样值, scalar 时生效,logs 时降级到 1000 self._sample = sample + # 偏移量,仅对 Log metric_type 有效, 默认为 0 + self._offset = log_offset + self._log_level = log_level # 类型 → 加载方法 的分发表,新增类型只需在此注册 _FETCH_DISPATCH = { @@ -77,6 +84,14 @@ def metric_type(self) -> str: def metrics(self) -> List[Any]: return self._ensure_data().get("metrics", []) + @property + def logs(self) -> List[Any]: + return self._ensure_data().get("logs", []) + + @property + def count(self) -> int: + return self._ensure_data().get("count", 0) + # ------------------------------------------------------------------ # 请求辅助函数 # ------------------------------------------------------------------ @@ -102,6 +117,15 @@ def _build_media_payload(self) -> Dict[str, Any]: "columns": [{"experimentId": self.run_id, "key": self.key}], } + def _build_log_params(self) -> Dict[str, Any]: + return { + "projectId": self.project_id, + "experimentId": self.run_id, + "size": 1000, # 硬编码为 1000 + "epoch": self._offset, + "level": self._log_level, + } + # ------------------------------------------------------------------ # 类型专属加载 # ------------------------------------------------------------------ @@ -131,7 +155,6 @@ def _fetch_media(self) -> ApiMediaSeriesType: raw_data = self._extract_first(raw_resp) if raw_data is None: return res - # print(raw_data) metrics: List[ApiMediaType] = [] prefix = f"{self.project_id}/{self.run_id}" for entry in raw_data.get("metrics", []): @@ -150,6 +173,14 @@ def _fetch_media(self) -> ApiMediaSeriesType: def _fetch_logs(self) -> ApiLogSeriesType: res = ApiLogSeriesType(projectId=self.project_id, experimentId=self.run_id, key="LOG") + params = self._build_log_params() + raw_resp = self._get("/house/metrics/log", params=params) + if not raw_resp.ok or not raw_resp.data: + return res + data = raw_resp.data + if isinstance(data, dict): + res["logs"] = data.get("logs", []) + res["count"] = data.get("count", 0) return res # ------------------------------------------------------------------ @@ -188,8 +219,15 @@ def json(self) -> Dict[str, Any]: if val: result[field] = val + if self._metric_type == "LOG": + result.pop("metrics", None) + else: + result.pop("logs", None) + result.pop("count", None) + if self._ignore_timestamp: - for item in cast(List[Dict[str, Any]], result.get("metrics", [])): + timestamp_items = result.get("metrics", []) or result.get("logs", []) + for item in cast(List[Dict[str, Any]], timestamp_items): if isinstance(item, dict): item.pop("timestamp", None) diff --git a/swanlab/api/typings/__init__.py b/swanlab/api/typings/__init__.py index 134c19fd2..e1ba1bb28 100644 --- a/swanlab/api/typings/__init__.py +++ b/swanlab/api/typings/__init__.py @@ -9,6 +9,9 @@ from .common import ( ApiIdentityLiteral, ApiLicensePlanLiteral, + ApiMetricLogLevelLiteral, + ApiMetricTypeLiteral, + ApiMetricXAxisLiteral, ApiPaginationType, ApiResponseType, ApiRoleLiteral, @@ -18,7 +21,11 @@ ApiWorkspaceLiteral, ) from .experiment import ApiExperimentLabelType, ApiExperimentType -from .metric import * +from .metric import ( + ApiLogSeriesType, + ApiMediaSeriesType, + ApiScalarSeriesType, +) from .project import ApiProjectCountType, ApiProjectLabelType, ApiProjectType from .selfhosted import ApiApiKeyType, ApiSelfHostedInfoType from .user import ApiUserProfileType, ApiUserType @@ -33,6 +40,9 @@ "ApiRoleLiteral", "ApiIdentityLiteral", "ApiLicensePlanLiteral", + "ApiMetricLogLevelLiteral", + "ApiMetricTypeLiteral", + "ApiMetricXAxisLiteral", # General TypedDicts "ApiPaginationType", "ApiResponseType", @@ -56,4 +66,8 @@ "ApiColumnErrorType", "ApiColumnType", "ApiColumnCsvExportType", + # Metric + "ApiLogSeriesType", + "ApiMediaSeriesType", + "ApiScalarSeriesType", ] diff --git a/swanlab/api/typings/common.py b/swanlab/api/typings/common.py index d778cf42f..a95eecc02 100644 --- a/swanlab/api/typings/common.py +++ b/swanlab/api/typings/common.py @@ -64,7 +64,10 @@ ApiLicensePlanLiteral = Literal["free", "commercial"] # 指标类型(log 不属于 column-backed metrics,使用独立查询方法) -ApiMetricTypeLiteral = Literal["SCALAR", "MEDIA", "LOG"] +ApiMetricTypeLiteral = Literal["SCALAR", "MEDIA"] + +# 指标日志级别 +ApiMetricLogLevelLiteral = Literal["DEBUG", "INFO", "WARN", "ERROR"] # X 轴类型 ApiMetricXAxisLiteral = Literal["step", "time", "relative_time"] diff --git a/swanlab/api/utils.py b/swanlab/api/utils.py index 49ebdbf34..ae199ada4 100644 --- a/swanlab/api/utils.py +++ b/swanlab/api/utils.py @@ -13,6 +13,7 @@ ApiColumnScalarTypeLiteral, ApiFilterOpLiteral, ApiFilterStableKeyLiteral, + ApiMetricLogLevelLiteral, ApiMetricTypeLiteral, ApiSidebarLiteral, ApiSortOrderLiteral, @@ -77,6 +78,7 @@ def resovle_run_path(path: str) -> Tuple[str, str]: # 指标相关校验常量 _VALID_METRIC_TYPES = frozenset(get_args(ApiMetricTypeLiteral)) +_VALID_METRIC_LOG_LEVELS = frozenset(get_args(ApiMetricLogLevelLiteral)) def _check_required(item: Dict[str, Any], keys: Set[str]) -> None: @@ -109,12 +111,18 @@ def validate_filter(item: Dict[str, Any]) -> None: def validate_metric_type(metric_type: str, key: Optional[str] = None) -> None: """校验 metric_type 的合法性。非 LOG 类型必须提供非空 key。""" - if metric_type not in _VALID_METRIC_TYPES: + if metric_type not in _VALID_METRIC_TYPES and metric_type != "LOG": raise ValueError(f"Invalid metric_type: {metric_type!r}, expected one of {sorted(_VALID_METRIC_TYPES)}") if metric_type != "LOG" and not key: raise ValueError(f"key is required for metric_type {metric_type!r}, got key={key!r}") +def validate_metric_log_level(level: str) -> None: + """校验 metric log level 的合法性。""" + if level not in _VALID_METRIC_LOG_LEVELS: + raise ValueError(f"Invalid metric log level: {level!r}, expected one of {sorted(_VALID_METRIC_LOG_LEVELS)}") + + def validate_group(item: Dict[str, Any]) -> None: """校验单个 group item 的合法性。""" _check_required(item, {"key", "type"}) From 311e91041ca00fbe966db5a05c3fe603451cbf0d Mon Sep 17 00:00:00 2001 From: Nexisato <978452096@qq.com> Date: Sat, 25 Apr 2026 19:59:15 +0800 Subject: [PATCH 35/52] feat: support metrics method --- swanlab/api/experiment.py | 19 +++++ swanlab/api/metric.py | 143 +++++++++++++++++++++++++++++++++++--- 2 files changed, 153 insertions(+), 9 deletions(-) diff --git a/swanlab/api/experiment.py b/swanlab/api/experiment.py index 43ad99ffb..dca577abf 100644 --- a/swanlab/api/experiment.py +++ b/swanlab/api/experiment.py @@ -164,6 +164,25 @@ def metric( ) return metric.json() + def metrics( + self, + keys: List[str], + metric_type: ApiMetricTypeLiteral = "SCALAR", + sample: int = 1500, + ignore_timestamp: bool = False, + ) -> Dict[str, Any]: + from swanlab.api.metric import Metrics + + return Metrics( + ctx=self._ctx, + project_id=self.project_id, + run_id=self.run_id, + keys=keys, + sample=sample, + metric_type=metric_type, + ignore_timestamp=ignore_timestamp, + ).json() + def logs( self, offset: Optional[int] = 0, diff --git a/swanlab/api/metric.py b/swanlab/api/metric.py index 991f72ed3..1721091b0 100644 --- a/swanlab/api/metric.py +++ b/swanlab/api/metric.py @@ -5,10 +5,11 @@ @description: Metric 实体类 — 指标序列的查询与操作 """ -from typing import Any, Dict, List, Optional, cast +from typing import Any, Dict, Iterator, List, Optional, cast from swanlab.api.base import ApiClientContext, BaseEntity from swanlab.api.typings import ApiColumnCsvExportType, ApiResponseType +from swanlab.api.typings.common import ApiMetricTypeLiteral from swanlab.api.typings.metric import ApiLogSeriesType, ApiMediaSeriesType, ApiMediaType, ApiScalarSeriesType from swanlab.api.utils import get_properties, validate_metric_log_level, validate_metric_type @@ -103,18 +104,20 @@ def _extract_first(resp: ApiResponseType) -> Optional[Dict[str, Any]]: return resp.data[0] return None - def _build_scalar_payload(self) -> Dict[str, Any]: + @staticmethod + def _build_scalar_payload(project_id: str, run_id: str, keys: List[str]) -> Dict[str, Any]: return { - "projectId": self.project_id, + "projectId": project_id, "xType": "step", "range": [0, 0], - "columns": [{"experimentId": self.run_id, "key": self.key}], + "columns": [{"experimentId": run_id, "key": key} for key in keys], } - def _build_media_payload(self) -> Dict[str, Any]: + @staticmethod + def _build_media_payload(project_id: str, run_id: str, keys: List[str]) -> Dict[str, Any]: return { - "projectId": self.project_id, - "columns": [{"experimentId": self.run_id, "key": self.key}], + "projectId": project_id, + "columns": [{"experimentId": run_id, "key": key} for key in keys], } def _build_log_params(self) -> Dict[str, Any]: @@ -132,7 +135,7 @@ def _build_log_params(self) -> Dict[str, Any]: def _fetch_scalar(self) -> ApiScalarSeriesType: res = ApiScalarSeriesType(projectId=self.project_id, experimentId=self.run_id, key=self.key) - payload = self._build_scalar_payload() + payload = self._build_scalar_payload(self.project_id, self.run_id, [self.key]) # 1. 获取折线数据 raw_data = self._extract_first(self._post("/house/metrics/scalar", data=payload)) @@ -150,7 +153,7 @@ def _fetch_scalar(self) -> ApiScalarSeriesType: def _fetch_media(self) -> ApiMediaSeriesType: res = ApiMediaSeriesType(projectId=self.project_id, experimentId=self.run_id, key=self.key) - payload = self._build_media_payload() + payload = self._build_media_payload(self.project_id, self.run_id, [self.key]) raw_resp = self._post("/house/metrics/f_media", data=payload) raw_data = self._extract_first(raw_resp) if raw_data is None: @@ -232,3 +235,125 @@ def json(self) -> Dict[str, Any]: item.pop("timestamp", None) return result + + +class Metrics(BaseEntity): + """ + 批量指标数据的迭代器。 + + 一次 metrics 查询只支持一种 metric_type(SCALAR 或 MEDIA),不支持 LOG。 + 通过 payload 的 columns 数组一次性传递多个 key,减少网络请求。 + + 用法:: + + for m in experiment.metrics(keys=["loss", "acc"], metric_type="SCALAR"): + print(m.key, m.metrics) + """ + + def __init__( + self, + ctx: ApiClientContext, + *, + project_id: str, + run_id: str, + keys: List[str], + metric_type: ApiMetricTypeLiteral, + sample: int = 1500, + ignore_timestamp: bool = False, + ) -> None: + super().__init__(ctx) + if metric_type == "LOG": + raise ValueError("Metrics does not support LOG metric_type, use Experiment.logs() instead") + if not keys: + raise ValueError("keys must be a non-empty list") + self._project_id = project_id + self._run_id = run_id + self._keys = keys + self._metric_type = metric_type + self._sample = sample + self._ignore_timestamp = ignore_timestamp + self._page_info: Dict[str, Any] = { + "keys": keys, + "metricType": metric_type, + "list": [], + } + + def __iter__(self) -> Iterator[Metric]: + if self._metric_type == "SCALAR": + yield from self._fetch_scalars() + else: + yield from self._fetch_medias() + + def _build_metric(self, key: str, data: Dict[str, Any]) -> Metric: + return Metric( + ctx=self._ctx, + project_id=self._project_id, + run_id=self._run_id, + key=key, + metric_type=self._metric_type, + sample=self._sample, + ignore_timestamp=self._ignore_timestamp, + data=data, + ) + + def _fetch_scalars(self) -> Iterator[Metric]: + payload = Metric._build_scalar_payload(self._project_id, self._run_id, self._keys) + + # 1. 获取折线数据 + scalar_resp = self._post("/house/metrics/scalar", data=payload) + scalar_list: List[Dict[str, Any]] = ( + scalar_resp.data if scalar_resp.ok and isinstance(scalar_resp.data, list) else [] + ) + + # 2. 获取统计值 + value_resp = self._post("/house/metrics/scalar/value", data=payload) + value_list: List[Dict[str, Any]] = value_resp.ok and isinstance(value_resp.data, list) and value_resp.data or [] + + for i, key in enumerate(self._keys): + data: Dict[str, Any] = { + "projectId": self._project_id, + "experimentId": self._run_id, + "key": key, + "metrics": [], + } + if i < len(scalar_list): + data["metrics"] = scalar_list[i].get("metrics", []) + if i < len(value_list): + for field in ("min", "max", "avg", "median", "latest"): + val = value_list[i].get(field) + if val is not None: + data[field] = val + yield self._build_metric(key, data) + + def _fetch_medias(self) -> Iterator[Metric]: + payload = Metric._build_media_payload(self._project_id, self._run_id, self._keys) + raw_resp = self._post("/house/metrics/f_media", data=payload) + raw_list: List[Dict[str, Any]] = raw_resp.ok and isinstance(raw_resp.data, list) and raw_resp.data or [] + + for i, key in enumerate(self._keys): + data: Dict[str, Any] = { + "projectId": self._project_id, + "experimentId": self._run_id, + "key": key, + "metrics": [], + } + if i < len(raw_list): + raw_data = raw_list[i] + metrics: List[ApiMediaType] = [] + prefix = f"{self._project_id}/{self._run_id}" + for entry in raw_data.get("metrics", []): + paths = entry.get("data", []) + mores = entry.get("more", []) + items = [] + for j, path in enumerate(paths): + item: Dict[str, Any] = {"path": path} + if j < len(mores) and isinstance(mores[j], dict): + item.update(mores[j]) + items.append(item) + metrics.append({"index": entry.get("index", 0), "prefix": prefix, "items": items}) + data["metrics"] = metrics + yield self._build_metric(key, data) + + def json(self) -> Dict[str, Any]: + self._page_info["list"] = [m.json() for m in self] + return self._page_info From c4502adcbef3fa5bb2ff7ffdc17ff4964da2cb67 Mon Sep 17 00:00:00 2001 From: Nexisato <978452096@qq.com> Date: Sun, 26 Apr 2026 00:10:58 +0800 Subject: [PATCH 36/52] feat: compose all media urls --- swanlab/api/metric.py | 76 ++++++++++++++++++++++++++--------- swanlab/api/typings/metric.py | 3 +- 2 files changed, 57 insertions(+), 22 deletions(-) diff --git a/swanlab/api/metric.py b/swanlab/api/metric.py index 1721091b0..c7ad3b419 100644 --- a/swanlab/api/metric.py +++ b/swanlab/api/metric.py @@ -10,8 +10,15 @@ from swanlab.api.base import ApiClientContext, BaseEntity from swanlab.api.typings import ApiColumnCsvExportType, ApiResponseType from swanlab.api.typings.common import ApiMetricTypeLiteral -from swanlab.api.typings.metric import ApiLogSeriesType, ApiMediaSeriesType, ApiMediaType, ApiScalarSeriesType +from swanlab.api.typings.metric import ( + ApiLogSeriesType, + ApiMediaItemDataType, + ApiMediaSeriesType, + ApiMediaType, + ApiScalarSeriesType, +) from swanlab.api.utils import get_properties, validate_metric_log_level, validate_metric_type +from swanlab.sdk.internal.pkg import console class Metric(BaseEntity): @@ -151,6 +158,35 @@ def _fetch_scalar(self) -> ApiScalarSeriesType: res[field] = stat_data.get(field, {}) return res + @staticmethod + def _fetch_presigned_urls(entity: BaseEntity, prefix: str, paths: List[str]) -> Dict[str, str]: + """批量获取预签名下载链接,返回 path → url 映射。""" + if not paths: + return {} + resp = entity._post("/resources/presigned/get", data={"prefix": prefix, "paths": paths}) + if not resp.ok or not isinstance(resp.data, dict): + return {} + urls = resp.data.get("urls", []) + return dict(zip(paths, urls)) if urls else {} + + @staticmethod + def _build_media_items( + entry: Dict[str, Any], + url_map: Dict[str, str], + ) -> List[ApiMediaItemDataType]: + """将单个 metric entry 的 data/more 合并为 items,注入预签名 url。""" + paths = entry.get("data", []) + mores = entry.get("more", []) + items: List[ApiMediaItemDataType] = [] + for i, path in enumerate(paths): + item: ApiMediaItemDataType = {} + if path in url_map: + item["url"] = url_map[path] + if i < len(mores) and isinstance(mores[i], dict): + item.update(mores[i]) + items.append(item) + return items + def _fetch_media(self) -> ApiMediaSeriesType: res = ApiMediaSeriesType(projectId=self.project_id, experimentId=self.run_id, key=self.key) payload = self._build_media_payload(self.project_id, self.run_id, [self.key]) @@ -160,16 +196,15 @@ def _fetch_media(self) -> ApiMediaSeriesType: return res metrics: List[ApiMediaType] = [] prefix = f"{self.project_id}/{self.run_id}" + all_paths = [p for entry in raw_data.get("metrics", []) for p in entry.get("data", [])] + if all_paths: + console.info( + f"Media fetched: run_id[{self.run_id}], key[{self.key}] - {len(all_paths)} items, requesting presigned urls..." + ) + url_map = self._fetch_presigned_urls(self, prefix, all_paths) for entry in raw_data.get("metrics", []): - paths = entry.get("data", []) - mores = entry.get("more", []) - items = [] - for i, path in enumerate(paths): - item = {"path": path} - if i < len(mores) and isinstance(mores[i], dict): - item.update(mores[i]) - items.append(item) - metrics.append({"index": entry.get("index", 0), "prefix": prefix, "items": items}) + items = self._build_media_items(entry, url_map) + metrics.append({"index": entry.get("index", 0), "items": items}) res["metrics"] = metrics return res @@ -330,6 +365,15 @@ def _fetch_medias(self) -> Iterator[Metric]: raw_resp = self._post("/house/metrics/f_media", data=payload) raw_list: List[Dict[str, Any]] = raw_resp.ok and isinstance(raw_resp.data, list) and raw_resp.data or [] + # collect all paths across all keys for a single presigned URL batch call + prefix = f"{self._project_id}/{self._run_id}" + all_paths = [p for raw_data in raw_list for entry in raw_data.get("metrics", []) for p in entry.get("data", [])] + url_map = Metric._fetch_presigned_urls(self, prefix, all_paths) if all_paths else {} + if all_paths: + console.info( + f"Media fetched: run_id[{self._run_id}] - {len(all_paths)} items across {len(self._keys)} keys, requesting presigned urls..." + ) + for i, key in enumerate(self._keys): data: Dict[str, Any] = { "projectId": self._project_id, @@ -340,17 +384,9 @@ def _fetch_medias(self) -> Iterator[Metric]: if i < len(raw_list): raw_data = raw_list[i] metrics: List[ApiMediaType] = [] - prefix = f"{self._project_id}/{self._run_id}" for entry in raw_data.get("metrics", []): - paths = entry.get("data", []) - mores = entry.get("more", []) - items = [] - for j, path in enumerate(paths): - item: Dict[str, Any] = {"path": path} - if j < len(mores) and isinstance(mores[j], dict): - item.update(mores[j]) - items.append(item) - metrics.append({"index": entry.get("index", 0), "prefix": prefix, "items": items}) + items = Metric._build_media_items(entry, url_map) + metrics.append({"index": entry.get("index", 0), "items": items}) data["metrics"] = metrics yield self._build_metric(key, data) diff --git a/swanlab/api/typings/metric.py b/swanlab/api/typings/metric.py index 954de5e75..617e2e6f0 100644 --- a/swanlab/api/typings/metric.py +++ b/swanlab/api/typings/metric.py @@ -68,12 +68,11 @@ class ApiScalarSummaryItemType(TypedDict, total=False): # Media — 媒体数据 # --------------------------------------------------------------------------- class ApiMediaItemDataType(TypedDict, total=False): - path: str + url: str class ApiMediaType(TypedDict, total=False): index: int - prefix: str items: List[ApiMediaItemDataType] From a0bdfb7bea470228e549c1926e273f36765398fe Mon Sep 17 00:00:00 2001 From: Nexisato <978452096@qq.com> Date: Sun, 26 Apr 2026 09:38:41 +0800 Subject: [PATCH 37/52] feat: constraint media to single post --- swanlab/api/column.py | 7 ++- swanlab/api/experiment.py | 10 +++- swanlab/api/metric.py | 93 +++++++++++++++++++++++++---------- swanlab/api/typings/metric.py | 2 + 4 files changed, 83 insertions(+), 29 deletions(-) diff --git a/swanlab/api/column.py b/swanlab/api/column.py index 70854d5e2..a0935a959 100644 --- a/swanlab/api/column.py +++ b/swanlab/api/column.py @@ -112,7 +112,11 @@ def error(self) -> Optional[Dict[str, Any]]: return self._ensure_data().get("error", {}) def metric( - self, sample: int = 1500, metric_type: ApiMetricTypeLiteral = "SCALAR", ignore_timestamp: bool = False + self, + sample: int = 1500, + metric_type: ApiMetricTypeLiteral = "SCALAR", + ignore_timestamp: bool = False, + media_step: Optional[int] = None, ) -> Dict[str, Any]: from swanlab.api.metric import Metric @@ -125,6 +129,7 @@ def metric( sample=sample, metric_type=metric_type, ignore_timestamp=ignore_timestamp, + media_step=media_step, ) return metric.json() diff --git a/swanlab/api/experiment.py b/swanlab/api/experiment.py index dca577abf..dafce508e 100644 --- a/swanlab/api/experiment.py +++ b/swanlab/api/experiment.py @@ -149,7 +149,12 @@ def column(self, key: str, column_class: Optional[str] = "CUSTOM", column_type: ) def metric( - self, key: str, sample: int = 1500, metric_type: ApiMetricTypeLiteral = "SCALAR", ignore_timestamp: bool = False + self, + key: str, + sample: int = 1500, + metric_type: ApiMetricTypeLiteral = "SCALAR", + ignore_timestamp: bool = False, + media_step: Optional[int] = 0, ) -> Dict[str, Any]: from swanlab.api.metric import Metric @@ -161,6 +166,7 @@ def metric( sample=sample, metric_type=metric_type, ignore_timestamp=ignore_timestamp, + media_step=media_step, ) return metric.json() @@ -170,6 +176,7 @@ def metrics( metric_type: ApiMetricTypeLiteral = "SCALAR", sample: int = 1500, ignore_timestamp: bool = False, + media_step: Optional[int] = 0, ) -> Dict[str, Any]: from swanlab.api.metric import Metrics @@ -181,6 +188,7 @@ def metrics( sample=sample, metric_type=metric_type, ignore_timestamp=ignore_timestamp, + media_step=media_step, ).json() def logs( diff --git a/swanlab/api/metric.py b/swanlab/api/metric.py index c7ad3b419..c01a3d180 100644 --- a/swanlab/api/metric.py +++ b/swanlab/api/metric.py @@ -41,6 +41,7 @@ def __init__( metric_type: str = "SCALAR", data: Optional[Dict[str, Any]] = None, ignore_timestamp: bool = False, + media_step: Optional[int] = None, ) -> None: super().__init__(ctx) validate_metric_type(metric_type, key) @@ -57,6 +58,7 @@ def __init__( # 偏移量,仅对 Log metric_type 有效, 默认为 0 self._offset = log_offset self._log_level = log_level + self._media_step = media_step # 类型 → 加载方法 的分发表,新增类型只需在此注册 _FETCH_DISPATCH = { @@ -100,6 +102,15 @@ def logs(self) -> List[Any]: def count(self) -> int: return self._ensure_data().get("count", 0) + @property + def steps(self) -> List[int]: + return self._ensure_data().get("steps", []) + + # only available for media type + @property + def step(self) -> Optional[int]: + return self._ensure_data().get("step") + # ------------------------------------------------------------------ # 请求辅助函数 # ------------------------------------------------------------------ @@ -121,11 +132,16 @@ def _build_scalar_payload(project_id: str, run_id: str, keys: List[str]) -> Dict } @staticmethod - def _build_media_payload(project_id: str, run_id: str, keys: List[str]) -> Dict[str, Any]: - return { + def _build_media_payload( + project_id: str, run_id: str, keys: List[str], step: Optional[int] = None + ) -> Dict[str, Any]: + payload: Dict[str, Any] = { "projectId": project_id, "columns": [{"experimentId": run_id, "key": key} for key in keys], } + if step is not None: + payload["step"] = step + return payload def _build_log_params(self) -> Dict[str, Any]: return { @@ -189,24 +205,33 @@ def _build_media_items( def _fetch_media(self) -> ApiMediaSeriesType: res = ApiMediaSeriesType(projectId=self.project_id, experimentId=self.run_id, key=self.key) - payload = self._build_media_payload(self.project_id, self.run_id, [self.key]) - raw_resp = self._post("/house/metrics/f_media", data=payload) - raw_data = self._extract_first(raw_resp) - if raw_data is None: + payload = self._build_media_payload(self.project_id, self.run_id, [self.key], step=self._media_step) + raw_resp = self._post("/house/metrics/media", data=payload) + if not raw_resp.ok or not raw_resp.data: + return res + data = raw_resp.data + if not isinstance(data, dict): return res - metrics: List[ApiMediaType] = [] + + res["steps"] = data.get("steps", []) + step_val = data.get("step") + if step_val is not None: + res["step"] = step_val + + metrics_raw: List[Dict[str, Any]] = data.get("metrics", []) + metric_entry = next((m for m in metrics_raw if m.get("key") == self.key), None) + if metric_entry is None: + return res + prefix = f"{self.project_id}/{self.run_id}" - all_paths = [p for entry in raw_data.get("metrics", []) for p in entry.get("data", [])] + all_paths = metric_entry.get("data", []) + url_map = self._fetch_presigned_urls(self, prefix, all_paths) if all_paths else {} if all_paths: console.info( f"Media fetched: run_id[{self.run_id}], key[{self.key}] - {len(all_paths)} items, requesting presigned urls..." ) - url_map = self._fetch_presigned_urls(self, prefix, all_paths) - for entry in raw_data.get("metrics", []): - items = self._build_media_items(entry, url_map) - metrics.append({"index": entry.get("index", 0), "items": items}) - - res["metrics"] = metrics + items = self._build_media_items(metric_entry, url_map) + res["metrics"] = [{"index": data.get("step", 0), "items": items}] return res def _fetch_logs(self) -> ApiLogSeriesType: @@ -263,6 +288,10 @@ def json(self) -> Dict[str, Any]: result.pop("logs", None) result.pop("count", None) + if self._metric_type != "MEDIA": + result.pop("steps", None) + result.pop("step", None) + if self._ignore_timestamp: timestamp_items = result.get("metrics", []) or result.get("logs", []) for item in cast(List[Dict[str, Any]], timestamp_items): @@ -295,6 +324,7 @@ def __init__( metric_type: ApiMetricTypeLiteral, sample: int = 1500, ignore_timestamp: bool = False, + media_step: Optional[int] = None, ) -> None: super().__init__(ctx) if metric_type == "LOG": @@ -307,6 +337,7 @@ def __init__( self._metric_type = metric_type self._sample = sample self._ignore_timestamp = ignore_timestamp + self._media_step = media_step self._page_info: Dict[str, Any] = { "keys": keys, "metricType": metric_type, @@ -328,6 +359,7 @@ def _build_metric(self, key: str, data: Dict[str, Any]) -> Metric: metric_type=self._metric_type, sample=self._sample, ignore_timestamp=self._ignore_timestamp, + media_step=self._media_step, data=data, ) @@ -361,33 +393,40 @@ def _fetch_scalars(self) -> Iterator[Metric]: yield self._build_metric(key, data) def _fetch_medias(self) -> Iterator[Metric]: - payload = Metric._build_media_payload(self._project_id, self._run_id, self._keys) - raw_resp = self._post("/house/metrics/f_media", data=payload) - raw_list: List[Dict[str, Any]] = raw_resp.ok and isinstance(raw_resp.data, list) and raw_resp.data or [] + payload = Metric._build_media_payload(self._project_id, self._run_id, self._keys, step=self._media_step) + raw_resp = self._post("/house/metrics/media", data=payload) + if not raw_resp.ok or not raw_resp.data: + return + resp_data = raw_resp.data + if not isinstance(resp_data, dict): + return + + steps = resp_data.get("steps", []) + current_step = resp_data.get("step") + metrics_raw: List[Dict[str, Any]] = resp_data.get("metrics", []) - # collect all paths across all keys for a single presigned URL batch call prefix = f"{self._project_id}/{self._run_id}" - all_paths = [p for raw_data in raw_list for entry in raw_data.get("metrics", []) for p in entry.get("data", [])] + all_paths = [p for entry in metrics_raw for p in entry.get("data", [])] url_map = Metric._fetch_presigned_urls(self, prefix, all_paths) if all_paths else {} if all_paths: console.info( f"Media fetched: run_id[{self._run_id}] - {len(all_paths)} items across {len(self._keys)} keys, requesting presigned urls..." ) - for i, key in enumerate(self._keys): + key_to_entry: Dict[str, Dict[str, Any]] = {e.get("key", ""): e for e in metrics_raw} + for key in self._keys: data: Dict[str, Any] = { "projectId": self._project_id, "experimentId": self._run_id, "key": key, + "steps": steps, + "step": current_step, "metrics": [], } - if i < len(raw_list): - raw_data = raw_list[i] - metrics: List[ApiMediaType] = [] - for entry in raw_data.get("metrics", []): - items = Metric._build_media_items(entry, url_map) - metrics.append({"index": entry.get("index", 0), "items": items}) - data["metrics"] = metrics + entry = key_to_entry.get(key) + if entry: + items = Metric._build_media_items(entry, url_map) + data["metrics"] = [{"index": current_step or 0, "items": items}] yield self._build_metric(key, data) def json(self) -> Dict[str, Any]: diff --git a/swanlab/api/typings/metric.py b/swanlab/api/typings/metric.py index 617e2e6f0..ef0e99536 100644 --- a/swanlab/api/typings/metric.py +++ b/swanlab/api/typings/metric.py @@ -77,6 +77,8 @@ class ApiMediaType(TypedDict, total=False): class ApiMediaSeriesType(ApiMetricColumnRefType, total=False): + steps: List[int] + step: int metrics: List[ApiMediaType] From 24a0a1f60f4442327dd48fa7580e8e771ffb1295 Mon Sep 17 00:00:00 2001 From: Nexisato <978452096@qq.com> Date: Sun, 26 Apr 2026 10:44:51 +0800 Subject: [PATCH 38/52] feat: split medias and scalars --- swanlab/api/experiment.py | 44 +++++++++++++++++++++++++++++++-------- 1 file changed, 35 insertions(+), 9 deletions(-) diff --git a/swanlab/api/experiment.py b/swanlab/api/experiment.py index dafce508e..794d92556 100644 --- a/swanlab/api/experiment.py +++ b/swanlab/api/experiment.py @@ -8,7 +8,7 @@ from typing import Any, Dict, Iterator, List, Optional, Union, cast from swanlab.api.base import ApiClientContext, BaseEntity -from swanlab.api.typings.common import ApiMetricLogLevelLiteral, ApiMetricTypeLiteral, PaginatedQuery +from swanlab.api.typings.common import ApiMetricLogLevelLiteral, PaginatedQuery from swanlab.api.typings.experiment import ( ApiExperimentLabelType, ApiExperimentProfileType, @@ -152,9 +152,7 @@ def metric( self, key: str, sample: int = 1500, - metric_type: ApiMetricTypeLiteral = "SCALAR", ignore_timestamp: bool = False, - media_step: Optional[int] = 0, ) -> Dict[str, Any]: from swanlab.api.metric import Metric @@ -164,19 +162,15 @@ def metric( run_id=self.run_id, key=key, sample=sample, - metric_type=metric_type, ignore_timestamp=ignore_timestamp, - media_step=media_step, ) return metric.json() def metrics( self, keys: List[str], - metric_type: ApiMetricTypeLiteral = "SCALAR", sample: int = 1500, ignore_timestamp: bool = False, - media_step: Optional[int] = 0, ) -> Dict[str, Any]: from swanlab.api.metric import Metrics @@ -186,9 +180,41 @@ def metrics( run_id=self.run_id, keys=keys, sample=sample, - metric_type=metric_type, + metric_type="SCALAR", ignore_timestamp=ignore_timestamp, - media_step=media_step, + ).json() + + def media( + self, + key: str, + step: Optional[int] = 0, + ) -> Dict[str, Any]: + from swanlab.api.metric import Metric + + metric = Metric( + ctx=self._ctx, + project_id=self.project_id, + run_id=self.run_id, + key=key, + metric_type="MEDIA", + media_step=step, + ) + return metric.json() + + def medias( + self, + keys: List[str], + step: Optional[int] = 0, + ) -> Dict[str, Any]: + from swanlab.api.metric import Metrics + + return Metrics( + ctx=self._ctx, + project_id=self.project_id, + run_id=self.run_id, + keys=keys, + metric_type="MEDIA", + media_step=step, ).json() def logs( From 58e44af124c53d617575cf778d180be84968b29a Mon Sep 17 00:00:00 2001 From: Nexisato <978452096@qq.com> Date: Sun, 26 Apr 2026 10:58:41 +0800 Subject: [PATCH 39/52] chore: remove single key signature --- swanlab/api/experiment.py | 35 ----------------------------------- swanlab/api/metric.py | 5 ++++- 2 files changed, 4 insertions(+), 36 deletions(-) diff --git a/swanlab/api/experiment.py b/swanlab/api/experiment.py index 794d92556..924fde060 100644 --- a/swanlab/api/experiment.py +++ b/swanlab/api/experiment.py @@ -148,24 +148,6 @@ def column(self, key: str, column_class: Optional[str] = "CUSTOM", column_type: column_type=column_type, ) - def metric( - self, - key: str, - sample: int = 1500, - ignore_timestamp: bool = False, - ) -> Dict[str, Any]: - from swanlab.api.metric import Metric - - metric = Metric( - ctx=self._ctx, - project_id=self.project_id, - run_id=self.run_id, - key=key, - sample=sample, - ignore_timestamp=ignore_timestamp, - ) - return metric.json() - def metrics( self, keys: List[str], @@ -184,23 +166,6 @@ def metrics( ignore_timestamp=ignore_timestamp, ).json() - def media( - self, - key: str, - step: Optional[int] = 0, - ) -> Dict[str, Any]: - from swanlab.api.metric import Metric - - metric = Metric( - ctx=self._ctx, - project_id=self.project_id, - run_id=self.run_id, - key=key, - metric_type="MEDIA", - media_step=step, - ) - return metric.json() - def medias( self, keys: List[str], diff --git a/swanlab/api/metric.py b/swanlab/api/metric.py index c01a3d180..e03ddd153 100644 --- a/swanlab/api/metric.py +++ b/swanlab/api/metric.py @@ -335,7 +335,7 @@ def __init__( self._run_id = run_id self._keys = keys self._metric_type = metric_type - self._sample = sample + self._ignore_timestamp = ignore_timestamp self._media_step = media_step self._page_info: Dict[str, Any] = { @@ -343,6 +343,9 @@ def __init__( "metricType": metric_type, "list": [], } + if sample > 1500: + console.warning(f"Get sample = [{sample}], expected <= 1500, will be constrainted automatically..") + self._sample = sample def __iter__(self) -> Iterator[Metric]: if self._metric_type == "SCALAR": From 807bb21af0f2234c609615a771da4cf6ab5a8635 Mon Sep 17 00:00:00 2001 From: Nexisato <978452096@qq.com> Date: Mon, 27 Apr 2026 10:59:01 +0800 Subject: [PATCH 40/52] feat: support all params control --- swanlab/api/experiment.py | 4 +++ swanlab/api/metric.py | 60 ++++++++++++++++++++++++++++++++------- 2 files changed, 54 insertions(+), 10 deletions(-) diff --git a/swanlab/api/experiment.py b/swanlab/api/experiment.py index 924fde060..8a4a8a0bc 100644 --- a/swanlab/api/experiment.py +++ b/swanlab/api/experiment.py @@ -153,6 +153,7 @@ def metrics( keys: List[str], sample: int = 1500, ignore_timestamp: bool = False, + all: bool = False, ) -> Dict[str, Any]: from swanlab.api.metric import Metrics @@ -164,12 +165,14 @@ def metrics( sample=sample, metric_type="SCALAR", ignore_timestamp=ignore_timestamp, + all=all, ).json() def medias( self, keys: List[str], step: Optional[int] = 0, + all: bool = False, ) -> Dict[str, Any]: from swanlab.api.metric import Metrics @@ -180,6 +183,7 @@ def medias( keys=keys, metric_type="MEDIA", media_step=step, + all=all, ).json() def logs( diff --git a/swanlab/api/metric.py b/swanlab/api/metric.py index e03ddd153..2e1d162f2 100644 --- a/swanlab/api/metric.py +++ b/swanlab/api/metric.py @@ -5,7 +5,7 @@ @description: Metric 实体类 — 指标序列的查询与操作 """ -from typing import Any, Dict, Iterator, List, Optional, cast +from typing import Any, Dict, Iterator, List, Optional from swanlab.api.base import ApiClientContext, BaseEntity from swanlab.api.typings import ApiColumnCsvExportType, ApiResponseType @@ -20,6 +20,8 @@ from swanlab.api.utils import get_properties, validate_metric_log_level, validate_metric_type from swanlab.sdk.internal.pkg import console +_SCALAR_STATISTIC_FIELDS = ("min", "max", "avg", "median", "latest") + class Metric(BaseEntity): """ @@ -170,7 +172,7 @@ def _fetch_scalar(self) -> ApiScalarSeriesType: stat_data = self._extract_first(self._post("/house/metrics/scalar/value", data=payload)) if stat_data is None: return res - for field in ("min", "max", "avg", "median", "latest"): + for field in _SCALAR_STATISTIC_FIELDS: res[field] = stat_data.get(field, {}) return res @@ -277,7 +279,10 @@ def json(self) -> Dict[str, Any]: data = self._ensure_data() if self._metric_type == "SCALAR": - for field in ("min", "max", "avg", "median", "latest"): + if "url" in data: + result.pop("metrics", None) + result["url"] = data["url"] + for field in _SCALAR_STATISTIC_FIELDS: val = data.get(field) if val: result[field] = val @@ -293,10 +298,11 @@ def json(self) -> Dict[str, Any]: result.pop("step", None) if self._ignore_timestamp: - timestamp_items = result.get("metrics", []) or result.get("logs", []) - for item in cast(List[Dict[str, Any]], timestamp_items): - if isinstance(item, dict): - item.pop("timestamp", None) + items = result.get("metrics", []) or result.get("logs", []) + if isinstance(items, list): + for item in items: + if isinstance(item, dict): + item.pop("timestamp", None) return result @@ -325,6 +331,7 @@ def __init__( sample: int = 1500, ignore_timestamp: bool = False, media_step: Optional[int] = None, + all: bool = False, ) -> None: super().__init__(ctx) if metric_type == "LOG": @@ -338,18 +345,23 @@ def __init__( self._ignore_timestamp = ignore_timestamp self._media_step = media_step + self._all = all self._page_info: Dict[str, Any] = { "keys": keys, "metricType": metric_type, "list": [], } + self._sample = sample if sample > 1500: console.warning(f"Get sample = [{sample}], expected <= 1500, will be constrainted automatically..") - self._sample = sample + self._sample = 1500 def __iter__(self) -> Iterator[Metric]: if self._metric_type == "SCALAR": - yield from self._fetch_scalars() + if self._all: + yield from self._fetch_scalars_all() + else: + yield from self._fetch_scalars() else: yield from self._fetch_medias() @@ -389,7 +401,35 @@ def _fetch_scalars(self) -> Iterator[Metric]: if i < len(scalar_list): data["metrics"] = scalar_list[i].get("metrics", []) if i < len(value_list): - for field in ("min", "max", "avg", "median", "latest"): + for field in _SCALAR_STATISTIC_FIELDS: + val = value_list[i].get(field) + if val is not None: + data[field] = val + yield self._build_metric(key, data) + + def _fetch_scalars_all(self) -> Iterator[Metric]: + urls: Dict[str, str] = {} + for key in self._keys: + resp = self._get(f"/experiment/{self._run_id}/column/csv", params={"key": key}) + if resp.ok and resp.data: + if isinstance(resp.data, list) and resp.data: + urls[key] = resp.data[0].get("url", "") + elif isinstance(resp.data, dict): + urls[key] = resp.data.get("url", "") + + payload = Metric._build_scalar_payload(self._project_id, self._run_id, self._keys) + value_resp = self._post("/house/metrics/scalar/value", data=payload) + value_list: List[Dict[str, Any]] = value_resp.ok and isinstance(value_resp.data, list) and value_resp.data or [] + + for i, key in enumerate(self._keys): + data: Dict[str, Any] = { + "projectId": self._project_id, + "experimentId": self._run_id, + "key": key, + "url": urls.get(key, ""), + } + if i < len(value_list): + for field in _SCALAR_STATISTIC_FIELDS: val = value_list[i].get(field) if val is not None: data[field] = val From 9e7ac63b358ea9dc3fa9901fa0b96fe2a993c4a6 Mon Sep 17 00:00:00 2001 From: Nexisato <978452096@qq.com> Date: Mon, 27 Apr 2026 11:19:32 +0800 Subject: [PATCH 41/52] refactor: simplify metrics code --- swanlab/api/experiment.py | 4 ++-- swanlab/api/metric.py | 39 +++++++++++++++++------------------ swanlab/api/typings/metric.py | 1 + 3 files changed, 22 insertions(+), 22 deletions(-) diff --git a/swanlab/api/experiment.py b/swanlab/api/experiment.py index 8a4a8a0bc..91b94aac3 100644 --- a/swanlab/api/experiment.py +++ b/swanlab/api/experiment.py @@ -152,7 +152,7 @@ def metrics( self, keys: List[str], sample: int = 1500, - ignore_timestamp: bool = False, + ignore_timestamp: bool = True, all: bool = False, ) -> Dict[str, Any]: from swanlab.api.metric import Metrics @@ -190,7 +190,7 @@ def logs( self, offset: Optional[int] = 0, level: ApiMetricLogLevelLiteral = "INFO", - ignore_timestamp: bool = False, + ignore_timestamp: bool = True, ) -> Dict[str, Any]: from swanlab.api.metric import Metric diff --git a/swanlab/api/metric.py b/swanlab/api/metric.py index 2e1d162f2..ab0b3d99c 100644 --- a/swanlab/api/metric.py +++ b/swanlab/api/metric.py @@ -5,6 +5,8 @@ @description: Metric 实体类 — 指标序列的查询与操作 """ +from __future__ import annotations + from typing import Any, Dict, Iterator, List, Optional from swanlab.api.base import ApiClientContext, BaseEntity @@ -14,13 +16,21 @@ ApiLogSeriesType, ApiMediaItemDataType, ApiMediaSeriesType, - ApiMediaType, ApiScalarSeriesType, ) from swanlab.api.utils import get_properties, validate_metric_log_level, validate_metric_type from swanlab.sdk.internal.pkg import console _SCALAR_STATISTIC_FIELDS = ("min", "max", "avg", "median", "latest") +_METRIC_SHARED_KEYS = frozenset({"project_id", "run_id", "metric_type"}) + + +def _extract_csv_url(data: Any) -> str: + if isinstance(data, list) and data: + return data[0].get("url", "") + if isinstance(data, dict): + return data.get("url", "") + return "" class Metric(BaseEntity): @@ -253,24 +263,14 @@ def _fetch_logs(self) -> ApiLogSeriesType: # ------------------------------------------------------------------ def export_csv(self) -> ApiResponseType: - """ - 导出列数据为 CSV。 - - :return: ApiResponseType,成功时 data 包含临时下载 URL - """ + """导出列数据为 CSV。""" if self.metric_type != "SCALAR": - err_msg = "export_csv() only support SCALAR metric_type" - return ApiResponseType(ok=False, errmsg=err_msg, data=None) + return ApiResponseType(ok=False, errmsg="export_csv() only support SCALAR metric_type", data=None) resp = self._get(f"/experiment/{self._run_id}/column/csv", params={"key": self.key}) if not resp.ok: return resp - - data = resp.data - if isinstance(data, list) and data: - url = data[0].get("url", "") - elif isinstance(data, dict): - url = data.get("url", "") - else: + url = _extract_csv_url(resp.data) + if not url: return ApiResponseType(ok=False, errmsg="Invalid response format", data=None) return ApiResponseType(ok=True, data=ApiColumnCsvExportType(url=url)) @@ -349,6 +349,8 @@ def __init__( self._page_info: Dict[str, Any] = { "keys": keys, "metricType": metric_type, + "projectId": project_id, + "experimentId": run_id, "list": [], } self._sample = sample @@ -412,10 +414,7 @@ def _fetch_scalars_all(self) -> Iterator[Metric]: for key in self._keys: resp = self._get(f"/experiment/{self._run_id}/column/csv", params={"key": key}) if resp.ok and resp.data: - if isinstance(resp.data, list) and resp.data: - urls[key] = resp.data[0].get("url", "") - elif isinstance(resp.data, dict): - urls[key] = resp.data.get("url", "") + urls[key] = _extract_csv_url(resp.data) payload = Metric._build_scalar_payload(self._project_id, self._run_id, self._keys) value_resp = self._post("/house/metrics/scalar/value", data=payload) @@ -473,5 +472,5 @@ def _fetch_medias(self) -> Iterator[Metric]: yield self._build_metric(key, data) def json(self) -> Dict[str, Any]: - self._page_info["list"] = [m.json() for m in self] + self._page_info["list"] = [{k: v for k, v in m.json().items() if k not in _METRIC_SHARED_KEYS} for m in self] return self._page_info diff --git a/swanlab/api/typings/metric.py b/swanlab/api/typings/metric.py index ef0e99536..b478c457d 100644 --- a/swanlab/api/typings/metric.py +++ b/swanlab/api/typings/metric.py @@ -44,6 +44,7 @@ class ApiScalarSeriesType(ApiMetricColumnRefType, total=False): """标量指标序列,包含折线数据和聚合值""" metrics: List[ApiScalarType] + url: str min: ApiScalarType max: ApiScalarType avg: ApiScalarType From 5e18f5e735c9f2a359bff150d8ea275c56e639b2 Mon Sep 17 00:00:00 2001 From: Nexisato <978452096@qq.com> Date: Mon, 27 Apr 2026 11:40:29 +0800 Subject: [PATCH 42/52] feat: support fetch all medias --- swanlab/api/metric.py | 70 +++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 67 insertions(+), 3 deletions(-) diff --git a/swanlab/api/metric.py b/swanlab/api/metric.py index ab0b3d99c..c6fe466c4 100644 --- a/swanlab/api/metric.py +++ b/swanlab/api/metric.py @@ -54,6 +54,7 @@ def __init__( data: Optional[Dict[str, Any]] = None, ignore_timestamp: bool = False, media_step: Optional[int] = None, + all: bool = False, ) -> None: super().__init__(ctx) validate_metric_type(metric_type, key) @@ -71,6 +72,7 @@ def __init__( self._offset = log_offset self._log_level = log_level self._media_step = media_step + self._all = all # 类型 → 加载方法 的分发表,新增类型只需在此注册 _FETCH_DISPATCH = { @@ -81,7 +83,10 @@ def __init__( def _ensure_data(self) -> Dict[str, Any]: if self._data is None: - method_name = self._FETCH_DISPATCH.get(self._metric_type, "_fetch_scalar") + if self._metric_type == "MEDIA" and self._all: + method_name = "_fetch_media_all" + else: + method_name = self._FETCH_DISPATCH.get(self._metric_type, "_fetch_scalar") self._data = getattr(self, method_name)() assert self._data is not None return self._data @@ -246,6 +251,27 @@ def _fetch_media(self) -> ApiMediaSeriesType: res["metrics"] = [{"index": data.get("step", 0), "items": items}] return res + def _fetch_media_all(self) -> ApiMediaSeriesType: + res = ApiMediaSeriesType(projectId=self.project_id, experimentId=self.run_id, key=self.key) + payload = self._build_media_payload(self.project_id, self.run_id, [self.key]) + raw_resp = self._post("/house/metrics/f_media", data=payload) + raw_data = self._extract_first(raw_resp) + if raw_data is None: + return res + + prefix = f"{self.project_id}/{self.run_id}" + all_paths = [p for entry in raw_data.get("metrics", []) for p in entry.get("data", [])] + url_map = self._fetch_presigned_urls(self, prefix, all_paths) if all_paths else {} + if all_paths: + console.info( + f"Media fetched (all): run_id[{self.run_id}], key[{self.key}] - {len(all_paths)} items, requesting presigned urls..." + ) + res["metrics"] = [ + {"index": entry.get("index", 0), "items": self._build_media_items(entry, url_map)} + for entry in raw_data.get("metrics", []) + ] + return res + def _fetch_logs(self) -> ApiLogSeriesType: res = ApiLogSeriesType(projectId=self.project_id, experimentId=self.run_id, key="LOG") params = self._build_log_params() @@ -293,7 +319,7 @@ def json(self) -> Dict[str, Any]: result.pop("logs", None) result.pop("count", None) - if self._metric_type != "MEDIA": + if self._metric_type != "MEDIA" or "steps" not in data: result.pop("steps", None) result.pop("step", None) @@ -365,7 +391,10 @@ def __iter__(self) -> Iterator[Metric]: else: yield from self._fetch_scalars() else: - yield from self._fetch_medias() + if self._all: + yield from self._fetch_medias_all() + else: + yield from self._fetch_medias() def _build_metric(self, key: str, data: Dict[str, Any]) -> Metric: return Metric( @@ -378,6 +407,7 @@ def _build_metric(self, key: str, data: Dict[str, Any]) -> Metric: ignore_timestamp=self._ignore_timestamp, media_step=self._media_step, data=data, + all=self._all, ) def _fetch_scalars(self) -> Iterator[Metric]: @@ -471,6 +501,40 @@ def _fetch_medias(self) -> Iterator[Metric]: data["metrics"] = [{"index": current_step or 0, "items": items}] yield self._build_metric(key, data) + def _fetch_medias_all(self) -> Iterator[Metric]: + payload = Metric._build_media_payload(self._project_id, self._run_id, self._keys) + raw_resp = self._post("/house/metrics/f_media", data=payload) + if not raw_resp.ok or not raw_resp.data: + return + raw_list = raw_resp.data + if not isinstance(raw_list, list): + return + + prefix = f"{self._project_id}/{self._run_id}" + all_paths = [p for entry in raw_list for m in entry.get("metrics", []) for p in m.get("data", [])] + url_map = Metric._fetch_presigned_urls(self, prefix, all_paths) if all_paths else {} + if all_paths: + console.info( + f"Media fetched (all): run_id[{self._run_id}] - {len(all_paths)} items across {len(self._keys)} keys, requesting presigned urls..." + ) + + key_to_entry: Dict[str, Dict[str, Any]] = {e.get("key", ""): e for e in raw_list} + for key in self._keys: + data: Dict[str, Any] = { + "projectId": self._project_id, + "experimentId": self._run_id, + "key": key, + "metrics": [], + } + entry = key_to_entry.get(key) + if entry: + metrics_list: List[Dict[str, Any]] = [] + for m in entry.get("metrics", []): + items = Metric._build_media_items(m, url_map) + metrics_list.append({"index": m.get("index", 0), "items": items}) + data["metrics"] = metrics_list + yield self._build_metric(key, data) + def json(self) -> Dict[str, Any]: self._page_info["list"] = [{k: v for k, v in m.json().items() if k not in _METRIC_SHARED_KEYS} for m in self] return self._page_info From 78918ecc11bf1d2c9e61a7ab551dcc9123ce931d Mon Sep 17 00:00:00 2001 From: Nexisato <978452096@qq.com> Date: Mon, 27 Apr 2026 12:57:08 +0800 Subject: [PATCH 43/52] feat: add self_hosted interceptor method --- swanlab/api/__init__.py | 39 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 39 insertions(+) diff --git a/swanlab/api/__init__.py b/swanlab/api/__init__.py index 04790fb87..a50df9b20 100644 --- a/swanlab/api/__init__.py +++ b/swanlab/api/__init__.py @@ -16,7 +16,9 @@ from .column import Column, Columns from .experiment import Experiment, Experiments from .project import Project, Projects +from .selfhosted import SelfHosted from .typings.common import PaginatedQuery +from .typings.selfhosted import ApiSelfHostedInfoType from .user import User from .workspace import Workspace, Workspaces @@ -76,10 +78,41 @@ def __init__( ctx = ApiClientContext(client=_client, web_host=web_host, api_host=api_host, username=username, name=name) super().__init__(ctx) + # 私有化信息 + self._self_hosted_info: Optional[ApiSelfHostedInfoType] = None + def json(self) -> dict: """Api 非数据实体,返回空字典。""" return {} + # ------------------------------------------------------------------ + # 私有化校验 + # ------------------------------------------------------------------ + + def _fetch_self_hosted_info(self) -> ApiSelfHostedInfoType: + """获取并缓存私有化实例信息。""" + if self._self_hosted_info is None: + resp = self._get("/self_hosted/info") + if not resp.ok or not resp.data: + raise ValueError("Failed to get self-hosted instance info.") + self._self_hosted_info = resp.data + assert self._self_hosted_info is not None + return self._self_hosted_info + + def _validate_self_hosted(self) -> None: + """校验私有化实例未过期。""" + info = self._fetch_self_hosted_info() + if info.get("expired", True): + raise ValueError("SwanLab self-hosted instance has expired.") + + def _validate_self_hosted_root(self) -> None: + """校验私有化实例未过期且当前用户拥有 root 权限。""" + info = self._fetch_self_hosted_info() + if info.get("expired", True): + raise ValueError("SwanLab self-hosted instance has expired.") + if not info.get("root", False): + raise ValueError("You don't have permission to perform this action. Please login as a root user.") + @staticmethod def _resolve_credentials( api_key: Optional[str], @@ -252,5 +285,11 @@ def column( """ return Column(self._ctx, path=path, key=key, column_class=column_class, column_type=column_type) + # ------- + # 私有化相关接口 + # -------- + def self_hosted(self) -> SelfHosted: + return SelfHosted(self._ctx) + __all__ = ["Api"] From 09c2418afda396b397e35ba30671413e5c4deeb4 Mon Sep 17 00:00:00 2001 From: Nexisato <978452096@qq.com> Date: Mon, 27 Apr 2026 15:49:55 +0800 Subject: [PATCH 44/52] feat: split self-hosted info --- swanlab/api/__init__.py | 32 -------------------------------- swanlab/api/selfhosted.py | 37 ++++++++++++++++++++++++++++++------- 2 files changed, 30 insertions(+), 39 deletions(-) diff --git a/swanlab/api/__init__.py b/swanlab/api/__init__.py index a50df9b20..9da084111 100644 --- a/swanlab/api/__init__.py +++ b/swanlab/api/__init__.py @@ -18,7 +18,6 @@ from .project import Project, Projects from .selfhosted import SelfHosted from .typings.common import PaginatedQuery -from .typings.selfhosted import ApiSelfHostedInfoType from .user import User from .workspace import Workspace, Workspaces @@ -78,41 +77,10 @@ def __init__( ctx = ApiClientContext(client=_client, web_host=web_host, api_host=api_host, username=username, name=name) super().__init__(ctx) - # 私有化信息 - self._self_hosted_info: Optional[ApiSelfHostedInfoType] = None - def json(self) -> dict: """Api 非数据实体,返回空字典。""" return {} - # ------------------------------------------------------------------ - # 私有化校验 - # ------------------------------------------------------------------ - - def _fetch_self_hosted_info(self) -> ApiSelfHostedInfoType: - """获取并缓存私有化实例信息。""" - if self._self_hosted_info is None: - resp = self._get("/self_hosted/info") - if not resp.ok or not resp.data: - raise ValueError("Failed to get self-hosted instance info.") - self._self_hosted_info = resp.data - assert self._self_hosted_info is not None - return self._self_hosted_info - - def _validate_self_hosted(self) -> None: - """校验私有化实例未过期。""" - info = self._fetch_self_hosted_info() - if info.get("expired", True): - raise ValueError("SwanLab self-hosted instance has expired.") - - def _validate_self_hosted_root(self) -> None: - """校验私有化实例未过期且当前用户拥有 root 权限。""" - info = self._fetch_self_hosted_info() - if info.get("expired", True): - raise ValueError("SwanLab self-hosted instance has expired.") - if not info.get("root", False): - raise ValueError("You don't have permission to perform this action. Please login as a root user.") - @staticmethod def _resolve_credentials( api_key: Optional[str], diff --git a/swanlab/api/selfhosted.py b/swanlab/api/selfhosted.py index 075ce2e44..df745f745 100644 --- a/swanlab/api/selfhosted.py +++ b/swanlab/api/selfhosted.py @@ -5,10 +5,10 @@ @description: SelfHosted 实体类 — 私有化部署实例的查询与管理 """ -from typing import Any, Dict, Optional, cast +from typing import Any, Dict, Iterator, Optional, cast from swanlab.api.base import ApiClientContext, BaseEntity -from swanlab.api.typings.common import ApiResponseType +from swanlab.api.typings.common import ApiResponseType, PaginatedQuery from swanlab.api.typings.selfhosted import ApiLicensePlanLiteral, ApiSelfHostedInfoType from swanlab.api.utils import get_properties @@ -55,6 +55,25 @@ def plan(self) -> ApiLicensePlanLiteral: def seats(self) -> int: return self._ensure_data().get("seats", 0) + # ================================ + # 权限校验 + # ================================ + + @staticmethod + def validate_expire(info: ApiSelfHostedInfoType) -> None: + if info.get("expired", True): + raise ValueError("SwanLab self-hosted instance has expired.") + + @staticmethod + def validate_root(info: ApiSelfHostedInfoType) -> None: + SelfHosted.validate_expire(info) + if not info.get("root", False): + raise ValueError("You don't have permission to perform this action. Please login as a root user.") + + # ================================ + # 管理操作(root 限定) + # ================================ + def create_user(self, username: str, password: str) -> ApiResponseType: """ 添加用户(私有化管理员限定)。 @@ -62,18 +81,22 @@ def create_user(self, username: str, password: str) -> ApiResponseType: :param username: 待创建用户名 :param password: 待创建用户密码 """ + SelfHosted.validate_root(self._ensure_data()) data = {"users": [{"username": username, "password": password}]} return self._post("/self_hosted/users", data=data) - def get_users(self, page: int = 1, size: int = 20) -> ApiResponseType: + def get_users(self, page: int = 1, size: int = 20, all: bool = False) -> Iterator[dict]: """ 分页获取用户(管理员限定)。 - :param page: 页码 - :param size: 每页大小 + :param page: 起始页码,默认 1 + :param size: 每页大小,默认 20 + :param all: 是否获取全部数据,默认 False """ - params = {"page": page, "size": size} - return self._get("/self_hosted/users", params=params) + SelfHosted.validate_root(self._ensure_data()) + query = PaginatedQuery(page=page, size=size, all=all) + page_info: Dict[str, Any] = {"total": 0, "pages": 0} + yield from self._paginate("/self_hosted/users", query, page_info=page_info) def json(self) -> Dict[str, Any]: return get_properties(self) From c8c923d4f23f96f1d34ee25fb9d7f52e7767779c Mon Sep 17 00:00:00 2001 From: Nexisato <978452096@qq.com> Date: Mon, 27 Apr 2026 18:30:22 +0800 Subject: [PATCH 45/52] feat: add create project --- swanlab/api/__init__.py | 1 - swanlab/api/project.py | 2 +- swanlab/api/typings/experiment.py | 3 ++- swanlab/api/typings/project.py | 9 +++++-- swanlab/api/utils.py | 15 +++++++++++ swanlab/api/workspace.py | 41 ++++++++++++++++++++++++++++--- 6 files changed, 63 insertions(+), 8 deletions(-) diff --git a/swanlab/api/__init__.py b/swanlab/api/__init__.py index 9da084111..a7dd42bc5 100644 --- a/swanlab/api/__init__.py +++ b/swanlab/api/__init__.py @@ -165,7 +165,6 @@ def run(self, path: str) -> Experiment: :param path: 实验路径,格式为 'username/project/run_id' """ - return Experiment(self._ctx, path=path) def runs( diff --git a/swanlab/api/project.py b/swanlab/api/project.py index 2bcdb0114..e27dd92d7 100644 --- a/swanlab/api/project.py +++ b/swanlab/api/project.py @@ -75,7 +75,7 @@ def labels(self) -> List[ApiProjectLabelType]: @property def count(self) -> ApiProjectCountType: - return self._ensure_data().get("_count", {}) + return cast(ApiProjectCountType, self._ensure_data().get("_count", {})) def runs( self, diff --git a/swanlab/api/typings/experiment.py b/swanlab/api/typings/experiment.py index eb8f6a47e..a2a1c5fe4 100644 --- a/swanlab/api/typings/experiment.py +++ b/swanlab/api/typings/experiment.py @@ -67,8 +67,9 @@ class ApiSortItem(TypedDict): # --------------------------------------------------------------------------- # 实验实体 # --------------------------------------------------------------------------- -class ApiExperimentLabelType(TypedDict): +class ApiExperimentLabelType(TypedDict, total=False): name: str + colors: List[str] class ApiExperimentProfileType(TypedDict): diff --git a/swanlab/api/typings/project.py b/swanlab/api/typings/project.py index f6028961b..d1d30762f 100644 --- a/swanlab/api/typings/project.py +++ b/swanlab/api/typings/project.py @@ -11,8 +11,10 @@ from .workspace import ApiWorkspaceType -class ApiProjectLabelType(TypedDict): +class ApiProjectLabelType(TypedDict, total=False): name: str + colors: List[str] + cuid: str class ApiProjectCountType(TypedDict): @@ -22,7 +24,7 @@ class ApiProjectCountType(TypedDict): clones: int -class ApiProjectType(TypedDict): +class ApiProjectType(TypedDict, total=False): cuid: str name: str username: str @@ -32,3 +34,6 @@ class ApiProjectType(TypedDict): group: Dict[str, str] projectLabels: List[ApiProjectLabelType] _count: ApiProjectCountType + createdAt: str + updatedAt: str + role: str diff --git a/swanlab/api/utils.py b/swanlab/api/utils.py index ae199ada4..71e7ff6d7 100644 --- a/swanlab/api/utils.py +++ b/swanlab/api/utils.py @@ -5,6 +5,7 @@ @description: swanlab/api 实体层工具函数 """ +import re from typing import Any, Dict, List, Optional, Set, Tuple, Type, get_args, get_type_hints from swanlab.api.typings.common import ( @@ -71,6 +72,8 @@ def resovle_run_path(path: str) -> Tuple[str, str]: _VALID_ORDERS = frozenset(get_args(ApiSortOrderLiteral)) _STABLE_KEYS = frozenset(get_args(ApiFilterStableKeyLiteral)) +_PROJECT_NAME_RE = re.compile(r"^[0-9a-zA-Z\-_.+]+$") + # 列相关校验常量 _VALID_COLUMN_CLASSES = frozenset(get_args(ApiColumnClassLiteral)) _VALID_COLUMN_DATA_TYPES = frozenset(get_args(ApiColumnDataTypeLiteral)) @@ -172,3 +175,15 @@ def parse_column_data_type(column_type: str): return "SCALAR" # 新加入的类型默认指定为 media return "MEDIA" + + +# --------------------------------------------------------------------------- +# 创建项目 / 实验的参数校验 +# --------------------------------------------------------------------------- + + +def validate_project_name(name: str) -> None: + if not 1 <= len(name) <= 100: + raise ValueError("Project name must be between 1 and 100 characters.") + if not _PROJECT_NAME_RE.match(name): + raise ValueError("Project name can only contain 0-9, a-z, A-Z, -, _, ., +") diff --git a/swanlab/api/workspace.py b/swanlab/api/workspace.py index 2a3583ef7..76d3dd8a5 100644 --- a/swanlab/api/workspace.py +++ b/swanlab/api/workspace.py @@ -5,12 +5,17 @@ @description: Workspace 实体类 — 工作空间的查询 """ -from typing import Any, Dict, Iterator, List, Optional, cast +from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional, cast from swanlab.api.base import ApiClientContext, BaseEntity -from swanlab.api.typings.common import PaginatedQuery +from swanlab.api.typings.common import ApiVisibilityLiteral, PaginatedQuery +from swanlab.api.typings.project import ApiProjectType from swanlab.api.typings.workspace import ApiWorkspaceLiteral, ApiWorkspaceProfileType, ApiWorkspaceType -from swanlab.api.utils import get_properties, strip_dict +from swanlab.api.utils import get_properties, strip_dict, validate_project_name +from swanlab.sdk.internal.pkg import safe + +if TYPE_CHECKING: + from swanlab.api.project import Project class Workspace(BaseEntity): @@ -78,6 +83,36 @@ def projects( detail=detail, ) + def create_project( + self, + name: str, + *, + visibility: ApiVisibilityLiteral = "PRIVATE", + description: Optional[str] = None, + ) -> Optional["Project"]: + """ + 在此工作空间下创建项目。 + + :param name: 项目名称 (1-100 字符,仅支持 0-9a-zA-Z-_.+) + :param visibility: 可见性,PUBLIC 或 PRIVATE,默认 PRIVATE + :param description: 项目描述 + """ + from swanlab.api.project import Project + + with safe.block(message=None): + validate_project_name(name) + + body: Dict[str, Any] = {"name": name, "visibility": visibility, "username": self.username} + if description: + body["description"] = description + resp = self._post("/project", data=body) + if not resp.ok: + return None + data = resp.data + path = data.get("path", "") + return Project(self._ctx, path=path, data=cast(ApiProjectType, data)) + return None + def json(self) -> Dict[str, Any]: return get_properties(self) From fe0fc65ffe8101d79400889b84fd4471b9a8a3c3 Mon Sep 17 00:00:00 2001 From: Nexisato <978452096@qq.com> Date: Mon, 27 Apr 2026 19:12:00 +0800 Subject: [PATCH 46/52] feat(ut): add param validation --- swanlab/api/__init__.py | 14 +- swanlab/api/experiment.py | 6 +- swanlab/api/metric.py | 5 +- swanlab/api/selfhosted.py | 4 + swanlab/api/typings/common.py | 2 +- swanlab/api/typings/experiment.py | 2 +- swanlab/api/typings/metric.py | 4 +- swanlab/api/typings/project.py | 1 - swanlab/api/typings/workspace.py | 2 +- swanlab/api/utils.py | 27 ++- swanlab/api/workspace.py | 2 + tests/unit/api/__init__.py | 0 tests/unit/api/conftest.py | 17 ++ tests/unit/api/test_api.py | 305 ++++++++++++++++++++++++++++++ tests/unit/api/test_utils.py | 193 +++++++++++++++++++ 15 files changed, 570 insertions(+), 14 deletions(-) create mode 100644 tests/unit/api/__init__.py create mode 100644 tests/unit/api/conftest.py create mode 100644 tests/unit/api/test_api.py create mode 100644 tests/unit/api/test_utils.py diff --git a/swanlab/api/__init__.py b/swanlab/api/__init__.py index a7dd42bc5..acc62ceb3 100644 --- a/swanlab/api/__init__.py +++ b/swanlab/api/__init__.py @@ -19,6 +19,7 @@ from .selfhosted import SelfHosted from .typings.common import PaginatedQuery from .user import User +from .utils import validate_api_path, validate_non_empty_string from .workspace import Workspace, Workspaces @@ -93,8 +94,9 @@ def _resolve_credentials( """ if api_key is None: api_key = global_settings.api_key - if api_key is None: + if not isinstance(api_key, str) or not api_key.strip(): raise AuthenticationError("No API key found. Please login with `swanlab login` or pass api_key parameter.") + api_key = api_key.strip() api_host: str = nrc.fmt(host) if host is not None else global_settings.api_host resolved_web_host: str = nrc.fmt(web_host) if web_host is not None else global_settings.web_host @@ -115,6 +117,7 @@ def workspace(self, username: Optional[str] = None) -> Workspace: """ if username is None: username = self._ctx.username + validate_api_path(username, segments=1, label="workspace") return Workspace(self._ctx, username=username) def workspaces(self, username: Optional[str] = None) -> Workspaces: @@ -125,6 +128,7 @@ def workspaces(self, username: Optional[str] = None) -> Workspaces: """ if username is None: username = self._ctx.username + validate_api_path(username, segments=1, label="workspace") return Workspaces(self._ctx, username=username) def project(self, path: str) -> Project: @@ -133,6 +137,7 @@ def project(self, path: str) -> Project: :param path: 项目路径,格式为 'username/project-name' """ + validate_api_path(path, segments=2, label="project") return Project(self._ctx, path=path) def projects( @@ -156,6 +161,7 @@ def projects( :param size: 每页数量,默认 20 :param all: 是否获取全部数据,默认 False """ + validate_api_path(path, segments=1, label="workspace") query = PaginatedQuery(page=page, size=size, search=search, sort=sort, all=all) return Projects(self._ctx, path=path, query=query, detail=detail) @@ -165,6 +171,7 @@ def run(self, path: str) -> Experiment: :param path: 实验路径,格式为 'username/project/run_id' """ + validate_api_path(path, segments=3, label="run") return Experiment(self._ctx, path=path) def runs( @@ -182,6 +189,7 @@ def runs( :param groups: 分组规则列表,每项为 {key, type} :param sorts: 排序规则列表,每项为 {key, type, order} """ + validate_api_path(path, segments=2, label="project") return Experiments(self._ctx, path=path, filters=filters, groups=groups, sorts=sorts, mode="post") def runs_get( @@ -199,6 +207,7 @@ def runs_get( :param size: 每页数量,默认 20 :param all: 是否获取全部数据,默认 False """ + validate_api_path(path, segments=2, label="project") query = PaginatedQuery(page=page, size=size, all=all) return Experiments(self._ctx, path=path, query=query, mode="get") @@ -226,6 +235,7 @@ def columns( :param column_type: 列的类型,如 FLOAT、STRING、IMAGE 等 :param all: 是否获取全部数据,默认 False """ + validate_api_path(path, segments=3, label="run") query = PaginatedQuery(page=page, size=size, search=search, all=all) return Columns( self._ctx, @@ -250,6 +260,8 @@ def column( :param column_class: 列的分类,CUSTOM 或 SYSTEM,默认 CUSTOM :param column_type: 列的类型,如 FLOAT、STRING、IMAGE 等,默认为 None """ + validate_api_path(path, segments=3, label="run") + validate_non_empty_string(key, label="column key") return Column(self._ctx, path=path, key=key, column_class=column_class, column_type=column_type) # ------- diff --git a/swanlab/api/experiment.py b/swanlab/api/experiment.py index 91b94aac3..8f2e03d97 100644 --- a/swanlab/api/experiment.py +++ b/swanlab/api/experiment.py @@ -310,9 +310,9 @@ def _iter_filtered(self) -> Iterator[Experiment]: resp = self._post( f"/project/{self._proj_path}/runs/shows", data={ - "filters": validate_update_active(self._filters, validate_filter), - "groups": validate_update_active(self._groups, validate_group), - "sorts": validate_update_active(self._sorts, validate_sort), + "filters": validate_update_active(self._filters, validate_filter, label="filters"), + "groups": validate_update_active(self._groups, validate_group, label="groups"), + "sorts": validate_update_active(self._sorts, validate_sort, label="sorts"), }, ) if not resp.ok: diff --git a/swanlab/api/metric.py b/swanlab/api/metric.py index c6fe466c4..2dedfba56 100644 --- a/swanlab/api/metric.py +++ b/swanlab/api/metric.py @@ -360,10 +360,11 @@ def __init__( all: bool = False, ) -> None: super().__init__(ctx) + if not isinstance(keys, list) or not keys or any(not isinstance(key, str) or not key.strip() for key in keys): + raise ValueError("keys must be a non-empty list") + validate_metric_type(metric_type, keys[0]) if metric_type == "LOG": raise ValueError("Metrics does not support LOG metric_type, use Experiment.logs() instead") - if not keys: - raise ValueError("keys must be a non-empty list") self._project_id = project_id self._run_id = run_id self._keys = keys diff --git a/swanlab/api/selfhosted.py b/swanlab/api/selfhosted.py index df745f745..fde6b904b 100644 --- a/swanlab/api/selfhosted.py +++ b/swanlab/api/selfhosted.py @@ -82,6 +82,10 @@ def create_user(self, username: str, password: str) -> ApiResponseType: :param password: 待创建用户密码 """ SelfHosted.validate_root(self._ensure_data()) + if not isinstance(username, str) or not username.strip(): + raise ValueError("username must be a non-empty string") + if not isinstance(password, str) or not password.strip(): + raise ValueError("password must be a non-empty string") data = {"users": [{"username": username, "password": password}]} return self._post("/self_hosted/users", data=data) diff --git a/swanlab/api/typings/common.py b/swanlab/api/typings/common.py index a95eecc02..b79833a67 100644 --- a/swanlab/api/typings/common.py +++ b/swanlab/api/typings/common.py @@ -38,7 +38,7 @@ ApiColumnDataTypeLiteral = Literal[ "FLOAT", "BOOLEAN", - "STRING" + "STRING", # media 类型 "IMAGE", "AUDIO", diff --git a/swanlab/api/typings/experiment.py b/swanlab/api/typings/experiment.py index a2a1c5fe4..958f725c8 100644 --- a/swanlab/api/typings/experiment.py +++ b/swanlab/api/typings/experiment.py @@ -11,7 +11,7 @@ - SCALAR: 训练过程中记录的标量指标最新值(动态 key,如 train/loss) """ -from typing import Any, Dict, List, Literal, Optional, TypedDict +from typing import Any, Dict, List, TypedDict from .common import ( ApiExperimentTypeLiteral, diff --git a/swanlab/api/typings/metric.py b/swanlab/api/typings/metric.py index b478c457d..e4a6fb3b8 100644 --- a/swanlab/api/typings/metric.py +++ b/swanlab/api/typings/metric.py @@ -5,9 +5,7 @@ @description: 指标数据类型定义(用于 column 采样值) """ -from typing import Any, Dict, List, Literal, TypedDict, Union - -from .common import ApiMetricTypeLiteral, ApiMetricXAxisLiteral +from typing import Any, List, TypedDict, Union # --------------------------------------------------------------------------- # Common — 通用指标类型定义 diff --git a/swanlab/api/typings/project.py b/swanlab/api/typings/project.py index d1d30762f..1012299da 100644 --- a/swanlab/api/typings/project.py +++ b/swanlab/api/typings/project.py @@ -8,7 +8,6 @@ from typing import Dict, List, TypedDict from .common import ApiVisibilityLiteral -from .workspace import ApiWorkspaceType class ApiProjectLabelType(TypedDict, total=False): diff --git a/swanlab/api/typings/workspace.py b/swanlab/api/typings/workspace.py index 4d3edefcc..c05db67fc 100644 --- a/swanlab/api/typings/workspace.py +++ b/swanlab/api/typings/workspace.py @@ -5,7 +5,7 @@ @description: 公共查询 API 工作空间类型定义 """ -from typing import Dict, TypedDict +from typing import TypedDict from .common import ApiRoleLiteral, ApiWorkspaceLiteral diff --git a/swanlab/api/utils.py b/swanlab/api/utils.py index 71e7ff6d7..bf6bfe574 100644 --- a/swanlab/api/utils.py +++ b/swanlab/api/utils.py @@ -63,6 +63,21 @@ def resovle_run_path(path: str) -> Tuple[str, str]: ) +def validate_api_path(path: str, *, segments: int, label: str) -> None: + """校验公开 API 入口 path 参数的段数和空白字符。""" + if not isinstance(path, str): + raise ValueError(f"{label} path must be a string") + parts = path.split("/") + if path != path.strip() or len(parts) != segments or any(part != part.strip() or not part for part in parts): + raise ValueError(f"{label} path must contain {segments} non-empty segment(s), got {path!r}") + + +def validate_non_empty_string(value: str, *, label: str) -> None: + """校验公开 API 入口中的非空字符串参数。""" + if not isinstance(value, str) or not value.strip(): + raise ValueError(f"{label} must be a non-empty string") + + # --------------------------------------------------------------------------- # POST /runs/shows 参数校验常量(从 typings 中的 Literal 类型提取,避免重复定义) # --------------------------------------------------------------------------- @@ -85,6 +100,8 @@ def resovle_run_path(path: str) -> Tuple[str, str]: def _check_required(item: Dict[str, Any], keys: Set[str]) -> None: + if not isinstance(item, dict): + raise ValueError(f"Expected dict item, got {type(item).__name__}") missing = keys - item.keys() if missing: raise ValueError(f"Missing required fields: {sorted(missing)}, got {sorted(item.keys())}") @@ -116,7 +133,7 @@ def validate_metric_type(metric_type: str, key: Optional[str] = None) -> None: """校验 metric_type 的合法性。非 LOG 类型必须提供非空 key。""" if metric_type not in _VALID_METRIC_TYPES and metric_type != "LOG": raise ValueError(f"Invalid metric_type: {metric_type!r}, expected one of {sorted(_VALID_METRIC_TYPES)}") - if metric_type != "LOG" and not key: + if metric_type != "LOG" and (not isinstance(key, str) or not key.strip()): raise ValueError(f"key is required for metric_type {metric_type!r}, got key={key!r}") @@ -145,11 +162,19 @@ def validate_sort(item: Dict[str, Any]) -> None: def validate_update_active( items: Optional[List[Dict[str, Any]]], validator, + *, + label: str = "items", ) -> List[Dict[str, Any]]: """校验每个 item 并补充 active: True,返回可直接发送的列表。""" + if items is None: + return [] + if not isinstance(items, list): + raise ValueError(f"{label} must be a list") if not items: return [] for item in items: + if not isinstance(item, dict): + raise ValueError(f"{label} items must be dicts") validator(item) return [{**item, "active": True} for item in items] diff --git a/swanlab/api/workspace.py b/swanlab/api/workspace.py index 76d3dd8a5..cabca484b 100644 --- a/swanlab/api/workspace.py +++ b/swanlab/api/workspace.py @@ -101,6 +101,8 @@ def create_project( with safe.block(message=None): validate_project_name(name) + if visibility not in ("PUBLIC", "PRIVATE"): + raise ValueError("Invalid visibility, expected PUBLIC or PRIVATE.") body: Dict[str, Any] = {"name": name, "visibility": visibility, "username": self.username} if description: diff --git a/tests/unit/api/__init__.py b/tests/unit/api/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/unit/api/conftest.py b/tests/unit/api/conftest.py new file mode 100644 index 000000000..ffb871ccc --- /dev/null +++ b/tests/unit/api/conftest.py @@ -0,0 +1,17 @@ +from unittest.mock import MagicMock + +import pytest + +from swanlab.api.base import ApiClientContext + + +@pytest.fixture +def mock_ctx(): + client = MagicMock() + return ApiClientContext( + client=client, + web_host="https://swanlab.cn", + api_host="https://api.swanlab.cn", + username="testuser", + name="Test User", + ) diff --git a/tests/unit/api/test_api.py b/tests/unit/api/test_api.py new file mode 100644 index 000000000..94f438e79 --- /dev/null +++ b/tests/unit/api/test_api.py @@ -0,0 +1,305 @@ +""" +@author: caddiesnew +@time: 2026/4/27 +@description: swanlab/api 实体类 4xx / 错误场景单测 +""" + +import importlib +from types import SimpleNamespace +from typing import Any, List, cast +from unittest.mock import MagicMock + +import pytest +import requests + +from swanlab.api import Api +from swanlab.api.base import ApiClientContext, BaseEntity +from swanlab.api.column import Columns +from swanlab.api.experiment import Experiment, Experiments +from swanlab.api.metric import Metric, Metrics +from swanlab.api.project import Project +from swanlab.api.selfhosted import SelfHosted +from swanlab.api.typings.common import PaginatedQuery +from swanlab.api.typings.selfhosted import ApiSelfHostedInfoType +from swanlab.api.workspace import Workspace +from swanlab.exceptions import AuthenticationError + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + +MockResponse = MagicMock + + +def _api_response(data=None): + """构造 Client.get/post 返回值。""" + r = MagicMock() + r.data = data + r.raw = MagicMock() + return r + + +@pytest.fixture +def ctx(): + client = MagicMock() + return ApiClientContext( + client=client, + web_host="https://swanlab.cn", + api_host="https://api.swanlab.cn", + username="testuser", + name="Test User", + ) + + +@pytest.fixture +def ctx_404(ctx): + """Client 所有 HTTP 方法均抛出 HTTPError(模拟 4xx)。""" + err = requests.exceptions.HTTPError("404 Not Found") + ctx.client.get.side_effect = err + ctx.client.post.side_effect = err + ctx.client.put.side_effect = err + ctx.client.delete.side_effect = err + return ctx + + +@pytest.fixture +def api(ctx): + instance = Api.__new__(Api) + BaseEntity.__init__(instance, ctx) + return instance + + +# --------------------------------------------------------------------------- +# Api 入口 — 参数校验 +# --------------------------------------------------------------------------- +class TestApiEntryValidation: + def test_missing_api_key_raises(self, monkeypatch): + api_module = importlib.import_module("swanlab.api") + monkeypatch.setattr( + api_module, + "global_settings", + SimpleNamespace(api_key=None, api_host="https://api.swanlab.cn", web_host="https://swanlab.cn"), + ) + + with pytest.raises(AuthenticationError, match="No API key"): + Api._resolve_credentials(None, None, None) + + @pytest.mark.parametrize("api_key", ["", " "]) + def test_blank_api_key_raises(self, api_key): + with pytest.raises(AuthenticationError, match="No API key"): + Api._resolve_credentials(api_key, "https://api.swanlab.cn", "https://swanlab.cn") + + def test_blank_host_raises(self): + with pytest.raises(ValueError, match="Host cannot be empty"): + Api._resolve_credentials("test-key", " ", "https://swanlab.cn") + + def test_projects_invalid_page_raises(self, api): + with pytest.raises(ValueError, match="page must be >= 1"): + api.projects("testuser", page=0) + + @pytest.mark.parametrize( + ("method_name", "path"), + [ + ("project", "testuser"), + ("run", "testuser/project"), + ("runs", "testuser/project/run1"), + ("columns", "testuser/project"), + ], + ) + def test_factory_methods_reject_invalid_path_shapes(self, api, method_name, path): + with pytest.raises(ValueError, match="path"): + getattr(api, method_name)(path) + + def test_column_rejects_empty_key(self, api): + with pytest.raises(ValueError, match="key"): + api.column("testuser/project/run1", key="") + + @pytest.mark.parametrize("column_type", ["STRING", "IMAGE"]) + def test_columns_accept_documented_column_types(self, api, column_type): + columns = api.columns("testuser/project/run1", column_type=column_type) + assert isinstance(columns, Columns) + + +# --------------------------------------------------------------------------- +# SelfHosted — 权限拒绝 +# --------------------------------------------------------------------------- +def _sh_info(**overrides) -> ApiSelfHostedInfoType: + base = {"enabled": True, "expired": False, "root": True, "plan": "free", "seats": 10} + base.update(overrides) + return cast(ApiSelfHostedInfoType, base) + + +class TestSelfHostedPermission: + def test_create_user_expired(self, ctx): + sh = SelfHosted(ctx, data=_sh_info(expired=True)) + with pytest.raises(ValueError, match="expired"): + sh.create_user("newuser", "pass123") + + def test_create_user_not_root(self, ctx): + sh = SelfHosted(ctx, data=_sh_info(root=False)) + with pytest.raises(ValueError, match="root"): + sh.create_user("newuser", "pass123") + + def test_get_users_not_root(self, ctx): + sh = SelfHosted(ctx, data=_sh_info(root=False)) + with pytest.raises(ValueError, match="root"): + list(sh.get_users()) + + def test_create_user_4xx(self, ctx): + sh = SelfHosted(ctx, data=_sh_info()) + ctx.client.post.side_effect = requests.exceptions.HTTPError("400 Bad Request") + resp = sh.create_user("newuser", "pass123") + assert not resp.ok + + @pytest.mark.parametrize(("username", "password"), [("", "pass123"), ("newuser", "")]) + def test_create_user_rejects_blank_credentials(self, ctx, username, password): + sh = SelfHosted(ctx, data=_sh_info()) + with pytest.raises(ValueError, match="username|password"): + sh.create_user(username, password) + + def test_get_users_4xx_yields_nothing(self, ctx): + sh = SelfHosted(ctx, data=_sh_info()) + ctx.client.get.side_effect = requests.exceptions.HTTPError("500 Internal") + result = list(sh.get_users()) + assert result == [] + + +# --------------------------------------------------------------------------- +# Workspace — create_project 错误 +# --------------------------------------------------------------------------- +class TestWorkspaceCreateProject: + def test_invalid_name_returns_none(self, ctx): + ws = Workspace(ctx, username="testuser") + result = ws.create_project("bad name!") + assert result is None + + def test_empty_name_returns_none(self, ctx): + ws = Workspace(ctx, username="testuser") + result = ws.create_project("") + assert result is None + + def test_invalid_visibility_returns_none(self, ctx): + ws = Workspace(ctx, username="testuser") + result = ws.create_project("valid-name", visibility=cast(Any, "SECRET")) + assert result is None + ctx.client.post.assert_not_called() + + def test_api_error_returns_none(self, ctx): + ws = Workspace(ctx, username="testuser") + ctx.client.post.side_effect = requests.exceptions.HTTPError("500") + result = ws.create_project("valid-name") + assert result is None + + def test_success(self, ctx): + ws = Workspace(ctx, username="testuser") + ctx.client.post.return_value = _api_response( + {"path": "testuser/valid-name", "name": "valid-name", "cuid": "cuid123"} + ) + proj = ws.create_project("valid-name") + assert proj is not None + assert proj.name == "valid-name" + + +# --------------------------------------------------------------------------- +# Entity lazy-load 4xx — 确保不 crash,返回空默认值 +# --------------------------------------------------------------------------- +class TestEntityLazyLoad4xx: + def test_workspace_returns_empty(self, ctx_404): + ws = Workspace(ctx_404, username="testuser") + assert ws.name == "" + assert ws.username == "" + + def test_project_returns_empty(self, ctx_404): + proj = Project(ctx_404, path="user/proj") + assert proj.name == "" + assert proj.path == "" + + def test_experiment_returns_empty(self, ctx_404): + exp = Experiment(ctx_404, path="user/proj/run123") + assert exp.name == "" + assert exp.state == "" + + def test_selfhosted_returns_defaults(self, ctx_404): + sh = SelfHosted(ctx_404) + assert sh.enabled is False + assert sh.expired is False + + def test_project_delete_4xx_returns_false(self, ctx_404): + proj = Project(ctx_404, path="user/proj") + assert proj.delete() is False + + def test_experiment_delete_4xx_returns_false(self, ctx_404): + exp = Experiment(ctx_404, path="user/proj/run123") + assert exp.delete() is False + + +# --------------------------------------------------------------------------- +# Column / Columns — 校验 + 4xx +# --------------------------------------------------------------------------- +class TestColumnValidation: + def test_columns_invalid_type_raises(self, ctx): + with pytest.raises(ValueError, match="Invalid column_type"): + Columns(ctx, path="user/proj/run1", query=PaginatedQuery(), column_type="INVALID") + + def test_columns_invalid_class_raises(self, ctx): + with pytest.raises(ValueError, match="Invalid column_class"): + Columns(ctx, path="user/proj/run1", query=PaginatedQuery(), column_class="INVALID") + + +# --------------------------------------------------------------------------- +# Metric / Metrics — 校验 +# --------------------------------------------------------------------------- +class TestMetricValidation: + def test_metric_invalid_type_raises(self, ctx): + with pytest.raises(ValueError, match="Invalid metric_type"): + Metric(ctx, project_id="p1", run_id="r1", key="loss", metric_type="INVALID") + + def test_metric_invalid_log_level_raises(self, ctx): + with pytest.raises(ValueError, match="Invalid metric log level"): + Metric(ctx, project_id="p1", run_id="r1", key="LOG", metric_type="LOG", log_level="VERBOSE") + + def test_metric_scalar_no_key_raises(self, ctx): + with pytest.raises(ValueError, match="key is required"): + Metric(ctx, project_id="p1", run_id="r1", key="", metric_type="SCALAR") + + def test_metrics_empty_keys_raises(self, ctx): + with pytest.raises(ValueError, match="non-empty"): + Metrics(ctx, project_id="p1", run_id="r1", keys=[], metric_type="SCALAR") + + def test_metrics_invalid_type_raises(self, ctx): + with pytest.raises(ValueError, match="Invalid metric_type"): + Metrics(ctx, project_id="p1", run_id="r1", keys=["loss"], metric_type=cast(Any, "INVALID")) + + @pytest.mark.parametrize("keys", ["loss", [""], None]) + def test_metrics_invalid_keys_raises(self, ctx, keys): + with pytest.raises(ValueError, match="keys must be a non-empty list"): + Metrics(ctx, project_id="p1", run_id="r1", keys=cast(List[str], keys), metric_type="SCALAR") + + +# --------------------------------------------------------------------------- +# Experiments POST 过滤 — 校验 +# --------------------------------------------------------------------------- +class TestExperimentsFilterValidation: + def test_invalid_filter_raises_on_iter(self, ctx): + bad_filters = [{"key": "name"}] # missing type, op, value + exps = Experiments(ctx, path="user/proj", filters=bad_filters, mode="post") + with pytest.raises(ValueError, match="Missing required"): + list(exps) + + def test_non_list_filters_raise_on_iter(self, ctx): + bad_filters = {"key": "name", "type": "STABLE", "op": "EQ", "value": ["test"]} + exps = Experiments(ctx, path="user/proj", filters=cast(Any, bad_filters), mode="post") + with pytest.raises(ValueError, match="filters must be a list"): + list(exps) + + def test_invalid_group_raises_on_iter(self, ctx): + bad_groups = [{"key": "cluster", "type": "INVALID"}] + exps = Experiments(ctx, path="user/proj", groups=bad_groups, mode="post") + with pytest.raises(ValueError, match="Invalid type"): + list(exps) + + def test_invalid_sort_raises_on_iter(self, ctx): + bad_sorts = [{"key": "name", "type": "STABLE", "order": "RANDOM"}] + exps = Experiments(ctx, path="user/proj", sorts=bad_sorts, mode="post") + with pytest.raises(ValueError, match="Invalid sort order"): + list(exps) diff --git a/tests/unit/api/test_utils.py b/tests/unit/api/test_utils.py new file mode 100644 index 000000000..12378758f --- /dev/null +++ b/tests/unit/api/test_utils.py @@ -0,0 +1,193 @@ +""" +@author: caddiesnew +@time: 2026/4/27 +@description: swanlab/api 校验函数单测 +""" + +from typing import cast + +import pytest + +from swanlab.api.selfhosted import SelfHosted +from swanlab.api.typings.common import PaginatedQuery +from swanlab.api.typings.selfhosted import ApiSelfHostedInfoType +from swanlab.api.utils import ( + validate_column_params, + validate_filter, + validate_group, + validate_metric_log_level, + validate_metric_type, + validate_project_name, + validate_sort, +) + + +# --------------------------------------------------------------------------- +# validate_project_name +# --------------------------------------------------------------------------- +class TestValidateProjectName: + def test_valid(self): + validate_project_name("my-project_1.0+beta") + + @pytest.mark.parametrize("name", ["", "x" * 101]) + def test_length_invalid(self, name): + with pytest.raises(ValueError, match="1 and 100"): + validate_project_name(name) + + @pytest.mark.parametrize("name", ["hello world", "中文项目", "a/b", "a@b"]) + def test_invalid_chars(self, name): + with pytest.raises(ValueError, match="0-9"): + validate_project_name(name) + + +# --------------------------------------------------------------------------- +# validate_column_params +# --------------------------------------------------------------------------- +class TestValidateColumnParams: + def test_valid_type_and_class(self): + validate_column_params(column_type="FLOAT", column_class="CUSTOM") + + def test_invalid_type(self): + with pytest.raises(ValueError, match="Invalid column_type"): + validate_column_params(column_type="INVALID") + + def test_invalid_class(self): + with pytest.raises(ValueError, match="Invalid column_class"): + validate_column_params(column_class="INVALID") + + +# --------------------------------------------------------------------------- +# validate_metric_type / validate_metric_log_level +# --------------------------------------------------------------------------- +class TestValidateMetricType: + def test_valid_scalar(self): + validate_metric_type("SCALAR", key="loss") + + def test_log_no_key_ok(self): + validate_metric_type("LOG") + + def test_scalar_without_key_raises(self): + with pytest.raises(ValueError, match="key is required"): + validate_metric_type("SCALAR", key="") + + def test_invalid_type(self): + with pytest.raises(ValueError, match="Invalid metric_type"): + validate_metric_type("INVALID", key="x") + + +class TestValidateMetricLogLevel: + def test_valid(self): + validate_metric_log_level("INFO") + + def test_invalid(self): + with pytest.raises(ValueError, match="Invalid metric log level"): + validate_metric_log_level("VERBOSE") + + +# --------------------------------------------------------------------------- +# validate_filter / validate_group / validate_sort +# --------------------------------------------------------------------------- +class TestValidateFilter: + def test_valid(self): + validate_filter({"key": "name", "type": "STABLE", "op": "EQ", "value": ["test"]}) + + def test_missing_fields(self): + with pytest.raises(ValueError, match="Missing required"): + validate_filter({"key": "name"}) + + def test_invalid_type(self): + with pytest.raises(ValueError, match="Invalid type"): + validate_filter({"key": "name", "type": "INVALID", "op": "EQ", "value": ["x"]}) + + def test_invalid_op(self): + with pytest.raises(ValueError, match="Invalid filter op"): + validate_filter({"key": "name", "type": "STABLE", "op": "LIKE", "value": ["x"]}) + + def test_value_not_list(self): + with pytest.raises(ValueError, match="must be a list"): + validate_filter({"key": "name", "type": "STABLE", "op": "EQ", "value": "not_list"}) + + def test_invalid_stable_key(self): + with pytest.raises(ValueError, match="Invalid STABLE key"): + validate_filter({"key": "invalid_key", "type": "STABLE", "op": "EQ", "value": ["x"]}) + + +class TestValidateGroup: + def test_valid(self): + validate_group({"key": "cluster", "type": "STABLE"}) + + def test_missing_fields(self): + with pytest.raises(ValueError, match="Missing required"): + validate_group({"key": "name"}) + + +class TestValidateSort: + def test_valid(self): + validate_sort({"key": "name", "type": "STABLE", "order": "ASC"}) + + def test_invalid_order(self): + with pytest.raises(ValueError, match="Invalid sort order"): + validate_sort({"key": "name", "type": "STABLE", "order": "RANDOM"}) + + +# --------------------------------------------------------------------------- +# PaginatedQuery +# --------------------------------------------------------------------------- +class TestPaginatedQuery: + def test_valid_defaults(self): + q = PaginatedQuery() + assert q.page == 1 and q.size == 20 + + def test_page_less_than_1(self): + with pytest.raises(ValueError, match="page must be >= 1"): + PaginatedQuery(page=0) + + def test_invalid_size(self): + with pytest.raises(ValueError, match="size must be one of"): + PaginatedQuery(size=42) + + def test_to_params_filters_none(self): + q = PaginatedQuery() + params = q.to_params(search=None, sort=None) + assert "search" not in params + assert "sort" not in params + + def test_to_params_includes_extras(self): + q = PaginatedQuery() + params = q.to_params(detail=True, extra_key="val") + assert params["detail"] is True + assert params["extra_key"] == "val" + + +# --------------------------------------------------------------------------- +# SelfHosted validation +# --------------------------------------------------------------------------- +class TestSelfHostedValidation: + def _make_info(self, **overrides) -> ApiSelfHostedInfoType: + base: dict = { + "enabled": True, + "expired": False, + "root": True, + "plan": "free", + "seats": 10, + } + base.update(overrides) + return cast(ApiSelfHostedInfoType, base) + + def test_validate_expire_ok(self): + SelfHosted.validate_expire(self._make_info(expired=False)) + + def test_validate_expire_raises(self): + with pytest.raises(ValueError, match="expired"): + SelfHosted.validate_expire(self._make_info(expired=True)) + + def test_validate_root_ok(self): + SelfHosted.validate_root(self._make_info(expired=False, root=True)) + + def test_validate_root_not_root(self): + with pytest.raises(ValueError, match="root"): + SelfHosted.validate_root(self._make_info(expired=False, root=False)) + + def test_validate_root_expired(self): + with pytest.raises(ValueError, match="expired"): + SelfHosted.validate_root(self._make_info(expired=True, root=True)) From 8ba8b345fd53717cd90b7011d32d621bf11a15a5 Mon Sep 17 00:00:00 2001 From: Nexisato <978452096@qq.com> Date: Mon, 27 Apr 2026 19:59:24 +0800 Subject: [PATCH 47/52] refactor: resolve path --- swanlab/api/column.py | 107 +++++++++++++++++++++++------- swanlab/api/experiment.py | 62 +++++++++++------ swanlab/api/typings/experiment.py | 3 +- swanlab/api/utils.py | 12 ++-- tests/unit/api/test_api.py | 81 +++++++++++++++++++++- 5 files changed, 213 insertions(+), 52 deletions(-) diff --git a/swanlab/api/column.py b/swanlab/api/column.py index a0935a959..a65124ae5 100644 --- a/swanlab/api/column.py +++ b/swanlab/api/column.py @@ -5,12 +5,12 @@ @description: Column 实体类 — 实验列的查询与操作 """ -from typing import Any, Dict, Iterator, Optional, cast +from typing import Any, Callable, Dict, Iterator, Optional, cast from swanlab.api.base import ApiClientContext, BaseEntity from swanlab.api.typings.column import ApiColumnType from swanlab.api.typings.common import ApiMetricTypeLiteral, ApiResponseType, PaginatedQuery -from swanlab.api.utils import get_properties, parse_column_data_type, resovle_run_path, validate_column_params +from swanlab.api.utils import get_properties, parse_column_data_type, resolve_run_path, validate_column_params class Column(BaseEntity): @@ -21,6 +21,20 @@ class Column(BaseEntity): 注意:列不支持单个获取 API,只能通过列表接口获取。 """ + @staticmethod + def _resolve_cuid(entity: BaseEntity, path: str, fallback: str = "") -> str: + resp = entity._get(path) + data = resp.data if resp.ok and isinstance(resp.data, dict) else {} + return data.get("cuid", "") or fallback + + @staticmethod + def _resolve_run_cuid(entity: BaseEntity, project_path: str, run_slug: str) -> str: + return Column._resolve_cuid(entity, f"/project/{project_path}/runs/{run_slug}", fallback=run_slug) + + @staticmethod + def _resolve_project_cuid(entity: BaseEntity, project_path: str) -> str: + return Column._resolve_cuid(entity, f"/project/{project_path}") + def __init__( self, ctx: ApiClientContext, @@ -30,23 +44,29 @@ def __init__( column_class: Optional[str] = "CUSTOM", column_type: Optional[str] = None, data: Optional[ApiColumnType] = None, + project_id: Optional[str] = None, + run_id: Optional[str] = None, + project_id_getter: Optional[Callable[[], str]] = None, ) -> None: super().__init__(ctx) - self._proj_path, self._run_id = resovle_run_path(path=path) + self._proj_path, self._run_slug = resolve_run_path(path=path) self._key = key self._column_class = column_class self._column_type = column_type self._data = data - self._project_id = None + self._project_id = project_id or (data or {}).get("project_id", "") or None + self._run_id = run_id or (data or {}).get("run_id", "") or "" + self._project_id_getter = project_id_getter def _ensure_data(self) -> ApiColumnType: if self._data is None: validate_column_params(column_class=self._column_class, column_type=self._column_type) extra: Dict[str, Any] = {"search": self._key} + run_id = self._ensure_run_id() if self._column_class: extra["class"] = self._column_class resp = self._get( - f"/experiment/{self._run_id}/column", + f"/experiment/{run_id}/column", params={"page": 1, "size": 10, **extra}, ) if resp.data: @@ -55,30 +75,44 @@ def _ensure_data(self) -> ApiColumnType: self._data = cast(ApiColumnType, items[0]) if self._data is None: self._data = cast(ApiColumnType, {}) + self._data.setdefault("run_id", self._ensure_run_id()) + return self._data + + def _ensure_run_id(self) -> str: + if self._run_id: + return self._run_id + if self._data and (run_id := self._data.get("run_id", "")): + self._run_id = run_id + return run_id + self._run_id = Column._resolve_run_cuid(self, self._proj_path, self._run_slug) + if self._data is not None: self._data["run_id"] = self._run_id + return self._run_id + + def _ensure_project_id(self) -> str: + if self._project_id: + return self._project_id + if self._data and (project_id := self._data.get("project_id", "")): + self._project_id = project_id + return project_id + if self._project_id_getter is not None: + self._project_id = self._project_id_getter() + if self._data is not None: + self._data["project_id"] = self._project_id + return self._project_id if self._project_id is None: - resp = self._get(f"/project/{self._proj_path}") - proj_data = resp.data if resp.ok else {} - self._project_id = proj_data.get("cuid", "") - self._data["project_id"] = self._project_id - # 这里要确保是 cuid 而非 slug - run_resp = self._get(f"/project/{self._proj_path}/runs/{self._run_id}") - run_data = run_resp.data if run_resp.ok else {} - run_cuid = run_data.get("cuid", "") - self._run_id = run_cuid - return self._data + self._project_id = Column._resolve_project_cuid(self, self._proj_path) + if self._data is not None: + self._data["project_id"] = self._project_id + return self._project_id @property def project_id(self) -> str: - if self._project_id: - return self._project_id - return self._ensure_data().get("project_id", "") + return self._ensure_project_id() @property def run_id(self) -> str: - if self._project_id: - return self._run_id - return self._ensure_data().get("run_id", "") + return self._ensure_run_id() @property def key(self) -> str: @@ -180,10 +214,16 @@ def __init__( query: PaginatedQuery, column_class: Optional[str] = None, column_type: Optional[str] = None, + project_id: Optional[str] = None, + run_id: Optional[str] = None, + project_id_getter: Optional[Callable[[], str]] = None, ) -> None: super().__init__(ctx) self._run_path = path - self._proj_path, self._run_id = resovle_run_path(path=path) + self._proj_path, self._run_slug = resolve_run_path(path=path) + self._run_id = run_id or "" + self._project_id = project_id + self._project_id_getter = project_id_getter self._query = query # 校验 column_type 和 column_class 的合法性 validate_column_params(column_type=column_type, column_class=column_class) @@ -197,25 +237,44 @@ def __init__( "list": [], } + def _ensure_run_id(self) -> str: + if self._run_id: + return self._run_id + self._run_id = Column._resolve_run_cuid(self, self._proj_path, self._run_slug) + return self._run_id + + def _ensure_project_id(self) -> str: + if self._project_id: + return self._project_id + if self._project_id_getter is not None: + self._project_id = self._project_id_getter() + return self._project_id + self._project_id = Column._resolve_project_cuid(self, self._proj_path) + return self._project_id + def __iter__(self) -> Iterator[Column]: """迭代分页获取列。""" extra: Dict[str, Any] = {} + run_id = self._ensure_run_id() if self._column_type: extra["type"] = self._column_type if self._column_class: extra["class"] = self._column_class for item in self._paginate( - f"/experiment/{self._run_id}/column", + f"/experiment/{run_id}/column", self._query, page_info=self._page_info, extra=extra, ): + data = {**item, "run_id": run_id} yield Column( self._ctx, path=self._run_path, key=item.get("key", ""), - data=cast(ApiColumnType, item), + data=cast(ApiColumnType, data), + run_id=run_id, + project_id_getter=self._ensure_project_id, ) @property diff --git a/swanlab/api/experiment.py b/swanlab/api/experiment.py index 8f2e03d97..70c5847fd 100644 --- a/swanlab/api/experiment.py +++ b/swanlab/api/experiment.py @@ -17,7 +17,7 @@ from swanlab.api.typings.user import ApiUserType from swanlab.api.utils import ( get_properties, - resovle_run_path, + resolve_run_path, validate_filter, validate_group, validate_sort, @@ -41,34 +41,49 @@ def __init__( data: Optional[ApiExperimentType] = None, ) -> None: super().__init__(ctx) - self._proj_path, self._cuid = resovle_run_path(path=path) - self._data = data - self._project_id = None + self._proj_path, self._run_slug = resolve_run_path(path=path) + self._cuid: str = (data or {}).get("cuid", "") or self._run_slug + self._data: Optional[ApiExperimentType] = data + self._project_id = "" + + def _refresh_cuid(self) -> None: + if self._data: + self._cuid = self._data.get("cuid", "") or self._cuid def _ensure_data(self) -> ApiExperimentType: if self._data is None: resp = self._get(f"/project/{self._proj_path}/runs/{self._cuid}") self._data = resp.data if resp.ok and resp.data else cast(ApiExperimentType, {}) - if not self._cuid and self._data: - self._cuid = self._data.get("cuid", "") - if self._project_id is None: + self._refresh_cuid() + assert self._data is not None + return self._data + + def _ensure_project_id(self) -> str: + if self._project_id: + return self._project_id + if self._data and (project_id := str(self._data.get("project_id") or "")): + self._project_id = project_id + return project_id + if not self._project_id: resp = self._get(f"/project/{self._proj_path}") proj_data = resp.data if resp.ok else {} - self._project_id = proj_data.get("cuid", "") - self._data["project_id"] = self._project_id - return self._data + self._project_id = str(proj_data.get("cuid") or "") + if self._data is not None: + self._data["project_id"] = self._project_id + return self._project_id @property def project_id(self) -> str: - if self._project_id: - return self._project_id - return self._ensure_data().get("project_id", "") + return self._ensure_project_id() @property def run_id(self) -> str: - if self._cuid: - return self._cuid - return self._ensure_data().get("cuid", "") + self._ensure_data() + return self._cuid + + def _run_url_ref(self) -> str: + data = self._ensure_data() + return data.get("slug") or self._run_slug or self.run_id @property def name(self) -> str: @@ -88,7 +103,7 @@ def state(self) -> str: @property def url(self) -> str: - return self._build_web_url(f"@{self._proj_path}/runs/{self.run_id}/chart") + return self._build_web_url(f"@{self._proj_path}/runs/{self._run_url_ref()}/chart") @property def show(self) -> bool: @@ -127,6 +142,7 @@ def profile(self) -> ApiExperimentProfileType: resp = self._get(f"/project/{self._proj_path}/runs/{self._cuid}") if resp.ok and resp.data: self._data = resp.data + self._refresh_cuid() data = self._data return cast(ApiExperimentProfileType, self._ensure_data().get("profile", {})) @@ -140,12 +156,15 @@ def column(self, key: str, column_class: Optional[str] = "CUSTOM", column_type: """ from swanlab.api.column import Column + run_id = self.run_id return Column( self._ctx, - path=f"{self._proj_path}/{self._cuid}", + path=f"{self._proj_path}/{run_id}", key=key, column_class=column_class, column_type=column_type, + run_id=run_id, + project_id_getter=lambda: self.project_id, ) def metrics( @@ -228,17 +247,20 @@ def columns( from swanlab.api.column import Columns query = PaginatedQuery(page=page, size=size, search=search, all=all) + run_id = self.run_id return Columns( self._ctx, - path=f"{self._proj_path}/{self._cuid}", + path=f"{self._proj_path}/{run_id}", query=query, column_type=column_type, column_class=column_class, + run_id=run_id, + project_id_getter=lambda: self.project_id, ) def delete(self) -> bool: """删除此实验。""" - resp = self._delete(f"/project/{self._proj_path}/runs/{self._cuid}") + resp = self._delete(f"/project/{self._proj_path}/runs/{self.run_id}") return resp.ok def json(self) -> Dict[str, Any]: diff --git a/swanlab/api/typings/experiment.py b/swanlab/api/typings/experiment.py index 958f725c8..a20eeec4e 100644 --- a/swanlab/api/typings/experiment.py +++ b/swanlab/api/typings/experiment.py @@ -11,7 +11,7 @@ - SCALAR: 训练过程中记录的标量指标最新值(动态 key,如 train/loss) """ -from typing import Any, Dict, List, TypedDict +from typing import Any, Dict, List, Optional, TypedDict from .common import ( ApiExperimentTypeLiteral, @@ -82,6 +82,7 @@ class ApiExperimentProfileType(TypedDict): class ApiExperimentType(TypedDict, total=False): project_id: str cuid: str + slug: Optional[str] name: str type: ApiExperimentTypeLiteral description: str diff --git a/swanlab/api/utils.py b/swanlab/api/utils.py index bf6bfe574..fda085961 100644 --- a/swanlab/api/utils.py +++ b/swanlab/api/utils.py @@ -49,17 +49,17 @@ def get_properties(obj: object, _visited: Optional[Set[int]] = None) -> Dict[str # 路径解析 -def resovle_run_path(path: str) -> Tuple[str, str]: - """ "path like: user/proj_name/run_id""" - proj_path, cuid = "", "" +def resolve_run_path(path: str) -> Tuple[str, str]: + """ "path like: user/proj_name/run_slug""" + proj_path, run_slug = "", "" parts = path.split("/") if len(parts) != 3: - return proj_path, cuid - cuid = parts[-1] + return proj_path, run_slug + run_slug = parts[-1] proj_path = path.rsplit("/", 1)[0] return ( proj_path, - cuid, + run_slug, ) diff --git a/tests/unit/api/test_api.py b/tests/unit/api/test_api.py index 94f438e79..01d57c39e 100644 --- a/tests/unit/api/test_api.py +++ b/tests/unit/api/test_api.py @@ -14,7 +14,7 @@ from swanlab.api import Api from swanlab.api.base import ApiClientContext, BaseEntity -from swanlab.api.column import Columns +from swanlab.api.column import Column, Columns from swanlab.api.experiment import Experiment, Experiments from swanlab.api.metric import Metric, Metrics from swanlab.api.project import Project @@ -233,6 +233,33 @@ def test_experiment_delete_4xx_returns_false(self, ctx_404): assert exp.delete() is False +# --------------------------------------------------------------------------- +# Experiment — run slug / cuid 解析 +# --------------------------------------------------------------------------- +class TestExperimentRunIdResolution: + def test_run_id_resolves_slug_to_cuid(self, ctx): + ctx.client.get.side_effect = [ + _api_response({"cuid": "run-cuid", "slug": "run-slug", "name": "test-run"}), + ] + + exp = Experiment(ctx, path="user/proj/run-slug") + + assert exp.run_id == "run-cuid" + assert [call.args[0] for call in ctx.client.get.call_args_list] == [ + "/project/user/proj/runs/run-slug", + ] + + def test_column_created_from_experiment_uses_run_cuid(self, ctx): + ctx.client.get.side_effect = [ + _api_response({"cuid": "run-cuid", "slug": "run-slug", "name": "test-run"}), + ] + exp = Experiment(ctx, path="user/proj/run-slug") + + column = exp.column("loss") + + assert column.run_id == "run-cuid" + + # --------------------------------------------------------------------------- # Column / Columns — 校验 + 4xx # --------------------------------------------------------------------------- @@ -245,6 +272,58 @@ def test_columns_invalid_class_raises(self, ctx): with pytest.raises(ValueError, match="Invalid column_class"): Columns(ctx, path="user/proj/run1", query=PaginatedQuery(), column_class="INVALID") + def test_iterated_columns_resolve_run_cuid_once_for_local_fields(self, ctx): + ctx.client.get.side_effect = [ + _api_response({"cuid": "run-cuid", "slug": "run-slug"}), + _api_response( + { + "list": [ + {"key": "loss", "name": "loss", "type": "FLOAT", "class": "CUSTOM"}, + {"key": "acc", "name": "acc", "type": "FLOAT", "class": "CUSTOM"}, + ], + "total": 2, + "pages": 1, + } + ), + ] + + columns = Columns(ctx, path="user/proj/run-slug", query=PaginatedQuery()) + + assert [column.name for column in columns] == ["loss", "acc"] + assert [call.args[0] for call in ctx.client.get.call_args_list] == [ + "/project/user/proj/runs/run-slug", + "/experiment/run-cuid/column", + ] + + def test_column_resolves_run_cuid_before_fetching_column_data(self, ctx): + ctx.client.get.side_effect = [ + _api_response({"cuid": "run-cuid", "slug": "run-slug"}), + _api_response( + { + "list": [{"key": "loss", "name": "loss", "type": "FLOAT", "class": "CUSTOM"}], + "total": 1, + "pages": 1, + } + ), + ] + + col = Column(ctx, path="user/proj/run-slug", key="loss") + + assert col.name == "loss" + assert col.run_id == "run-cuid" + assert [call.args[0] for call in ctx.client.get.call_args_list] == [ + "/project/user/proj/runs/run-slug", + "/experiment/run-cuid/column", + ] + + def test_column_project_id_fetches_project_lazily(self, ctx): + item = {"key": "loss", "name": "loss", "type": "FLOAT", "class": "CUSTOM"} + col = Column(ctx, path="user/proj/run1", key="loss", data=cast(Any, item)) + ctx.client.get.return_value = _api_response({"cuid": "project-cuid"}) + + assert col.project_id == "project-cuid" + assert [call.args[0] for call in ctx.client.get.call_args_list] == ["/project/user/proj"] + # --------------------------------------------------------------------------- # Metric / Metrics — 校验 From 869b8122254e4226187c7c3220429cf65fcdc094 Mon Sep 17 00:00:00 2001 From: Nexisato <978452096@qq.com> Date: Mon, 27 Apr 2026 20:08:43 +0800 Subject: [PATCH 48/52] chore: downgrade log level for media metrics --- swanlab/api/metric.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/swanlab/api/metric.py b/swanlab/api/metric.py index 2dedfba56..c13d7113f 100644 --- a/swanlab/api/metric.py +++ b/swanlab/api/metric.py @@ -244,7 +244,7 @@ def _fetch_media(self) -> ApiMediaSeriesType: all_paths = metric_entry.get("data", []) url_map = self._fetch_presigned_urls(self, prefix, all_paths) if all_paths else {} if all_paths: - console.info( + console.debug( f"Media fetched: run_id[{self.run_id}], key[{self.key}] - {len(all_paths)} items, requesting presigned urls..." ) items = self._build_media_items(metric_entry, url_map) @@ -263,7 +263,7 @@ def _fetch_media_all(self) -> ApiMediaSeriesType: all_paths = [p for entry in raw_data.get("metrics", []) for p in entry.get("data", [])] url_map = self._fetch_presigned_urls(self, prefix, all_paths) if all_paths else {} if all_paths: - console.info( + console.debug( f"Media fetched (all): run_id[{self.run_id}], key[{self.key}] - {len(all_paths)} items, requesting presigned urls..." ) res["metrics"] = [ @@ -482,7 +482,7 @@ def _fetch_medias(self) -> Iterator[Metric]: all_paths = [p for entry in metrics_raw for p in entry.get("data", [])] url_map = Metric._fetch_presigned_urls(self, prefix, all_paths) if all_paths else {} if all_paths: - console.info( + console.debug( f"Media fetched: run_id[{self._run_id}] - {len(all_paths)} items across {len(self._keys)} keys, requesting presigned urls..." ) @@ -515,7 +515,7 @@ def _fetch_medias_all(self) -> Iterator[Metric]: all_paths = [p for entry in raw_list for m in entry.get("metrics", []) for p in m.get("data", [])] url_map = Metric._fetch_presigned_urls(self, prefix, all_paths) if all_paths else {} if all_paths: - console.info( + console.debug( f"Media fetched (all): run_id[{self._run_id}] - {len(all_paths)} items across {len(self._keys)} keys, requesting presigned urls..." ) From e4959da1bd3c271ee842218aca01c87b9340a26e Mon Sep 17 00:00:00 2001 From: Nexisato <978452096@qq.com> Date: Mon, 27 Apr 2026 22:51:08 +0800 Subject: [PATCH 49/52] feat: optimize validation --- swanlab/api/__init__.py | 10 +++++----- swanlab/api/column.py | 18 ++++++++++++------ swanlab/api/experiment.py | 18 ++++++++++++++---- swanlab/api/metric.py | 11 +++++------ swanlab/api/selfhosted.py | 8 +++----- swanlab/api/typings/__init__.py | 6 ++++-- swanlab/api/typings/common.py | 5 ++++- swanlab/api/typings/workspace.py | 3 +-- swanlab/api/utils.py | 22 ++++++++++++++++++---- swanlab/api/workspace.py | 5 ++--- 10 files changed, 68 insertions(+), 38 deletions(-) diff --git a/swanlab/api/__init__.py b/swanlab/api/__init__.py index acc62ceb3..62a9061b3 100644 --- a/swanlab/api/__init__.py +++ b/swanlab/api/__init__.py @@ -17,7 +17,7 @@ from .experiment import Experiment, Experiments from .project import Project, Projects from .selfhosted import SelfHosted -from .typings.common import PaginatedQuery +from .typings.common import ApiColumnClassLiteral, ApiColumnDataTypeLiteral, PaginatedQuery from .user import User from .utils import validate_api_path, validate_non_empty_string from .workspace import Workspace, Workspaces @@ -220,8 +220,8 @@ def columns( page: int = 1, size: int = 20, search: Optional[str] = None, - column_class: str = "CUSTOM", - column_type: Optional[str] = None, + column_class: ApiColumnClassLiteral = "CUSTOM", + column_type: Optional[ApiColumnDataTypeLiteral] = None, all: bool = False, ) -> Columns: """ @@ -249,8 +249,8 @@ def column( self, path: str, key: str, - column_class: Optional[str] = "CUSTOM", - column_type: Optional[str] = None, + column_class: Optional[ApiColumnClassLiteral] = "CUSTOM", + column_type: Optional[ApiColumnDataTypeLiteral] = None, ) -> Column: """ 获取单个列(通过搜索 key 匹配)。 diff --git a/swanlab/api/column.py b/swanlab/api/column.py index a65124ae5..f8ff82c8d 100644 --- a/swanlab/api/column.py +++ b/swanlab/api/column.py @@ -9,7 +9,13 @@ from swanlab.api.base import ApiClientContext, BaseEntity from swanlab.api.typings.column import ApiColumnType -from swanlab.api.typings.common import ApiMetricTypeLiteral, ApiResponseType, PaginatedQuery +from swanlab.api.typings.common import ( + ApiColumnClassLiteral, + ApiColumnDataTypeLiteral, + ApiMetricColumnTypeLiteral, + ApiResponseType, + PaginatedQuery, +) from swanlab.api.utils import get_properties, parse_column_data_type, resolve_run_path, validate_column_params @@ -41,8 +47,8 @@ def __init__( *, path: str, key: str, - column_class: Optional[str] = "CUSTOM", - column_type: Optional[str] = None, + column_class: Optional[ApiColumnClassLiteral] = "CUSTOM", + column_type: Optional[ApiColumnDataTypeLiteral] = None, data: Optional[ApiColumnType] = None, project_id: Optional[str] = None, run_id: Optional[str] = None, @@ -148,7 +154,7 @@ def error(self) -> Optional[Dict[str, Any]]: def metric( self, sample: int = 1500, - metric_type: ApiMetricTypeLiteral = "SCALAR", + metric_type: ApiMetricColumnTypeLiteral = "SCALAR", ignore_timestamp: bool = False, media_step: Optional[int] = None, ) -> Dict[str, Any]: @@ -212,8 +218,8 @@ def __init__( *, path: str, query: PaginatedQuery, - column_class: Optional[str] = None, - column_type: Optional[str] = None, + column_class: Optional[ApiColumnClassLiteral] = None, + column_type: Optional[ApiColumnDataTypeLiteral] = None, project_id: Optional[str] = None, run_id: Optional[str] = None, project_id_getter: Optional[Callable[[], str]] = None, diff --git a/swanlab/api/experiment.py b/swanlab/api/experiment.py index 70c5847fd..ee8943987 100644 --- a/swanlab/api/experiment.py +++ b/swanlab/api/experiment.py @@ -8,7 +8,12 @@ from typing import Any, Dict, Iterator, List, Optional, Union, cast from swanlab.api.base import ApiClientContext, BaseEntity -from swanlab.api.typings.common import ApiMetricLogLevelLiteral, PaginatedQuery +from swanlab.api.typings.common import ( + ApiColumnClassLiteral, + ApiColumnDataTypeLiteral, + ApiMetricLogLevelLiteral, + PaginatedQuery, +) from swanlab.api.typings.experiment import ( ApiExperimentLabelType, ApiExperimentProfileType, @@ -146,7 +151,12 @@ def profile(self) -> ApiExperimentProfileType: data = self._data return cast(ApiExperimentProfileType, self._ensure_data().get("profile", {})) - def column(self, key: str, column_class: Optional[str] = "CUSTOM", column_type: Optional[str] = "FLOAT"): + def column( + self, + key: str, + column_class: Optional[ApiColumnClassLiteral] = "CUSTOM", + column_type: Optional[ApiColumnDataTypeLiteral] = "FLOAT", + ): """ 获取实验下指定 key 的单个列。 @@ -230,8 +240,8 @@ def columns( page: int = 1, size: int = 20, search: Optional[str] = None, - column_type: Optional[str] = None, - column_class: Optional[str] = None, + column_type: Optional[ApiColumnDataTypeLiteral] = None, + column_class: Optional[ApiColumnClassLiteral] = None, all: bool = False, ): """ diff --git a/swanlab/api/metric.py b/swanlab/api/metric.py index c13d7113f..f4cc066f0 100644 --- a/swanlab/api/metric.py +++ b/swanlab/api/metric.py @@ -11,14 +11,14 @@ from swanlab.api.base import ApiClientContext, BaseEntity from swanlab.api.typings import ApiColumnCsvExportType, ApiResponseType -from swanlab.api.typings.common import ApiMetricTypeLiteral +from swanlab.api.typings.common import ApiMetricColumnTypeLiteral, ApiMetricLogLevelLiteral from swanlab.api.typings.metric import ( ApiLogSeriesType, ApiMediaItemDataType, ApiMediaSeriesType, ApiScalarSeriesType, ) -from swanlab.api.utils import get_properties, validate_metric_log_level, validate_metric_type +from swanlab.api.utils import get_properties, validate_metric_keys, validate_metric_log_level, validate_metric_type from swanlab.sdk.internal.pkg import console _SCALAR_STATISTIC_FIELDS = ("min", "max", "avg", "median", "latest") @@ -49,7 +49,7 @@ def __init__( key: Optional[str] = "", sample: int = 1000, log_offset: Optional[int] = 0, # 标记第几个分片,仅对 Log metric_type 有效 - log_level: str = "INFO", + log_level: ApiMetricLogLevelLiteral = "INFO", metric_type: str = "SCALAR", data: Optional[Dict[str, Any]] = None, ignore_timestamp: bool = False, @@ -353,15 +353,14 @@ def __init__( project_id: str, run_id: str, keys: List[str], - metric_type: ApiMetricTypeLiteral, + metric_type: ApiMetricColumnTypeLiteral, sample: int = 1500, ignore_timestamp: bool = False, media_step: Optional[int] = None, all: bool = False, ) -> None: super().__init__(ctx) - if not isinstance(keys, list) or not keys or any(not isinstance(key, str) or not key.strip() for key in keys): - raise ValueError("keys must be a non-empty list") + validate_metric_keys(keys) validate_metric_type(metric_type, keys[0]) if metric_type == "LOG": raise ValueError("Metrics does not support LOG metric_type, use Experiment.logs() instead") diff --git a/swanlab/api/selfhosted.py b/swanlab/api/selfhosted.py index fde6b904b..2280defa9 100644 --- a/swanlab/api/selfhosted.py +++ b/swanlab/api/selfhosted.py @@ -10,7 +10,7 @@ from swanlab.api.base import ApiClientContext, BaseEntity from swanlab.api.typings.common import ApiResponseType, PaginatedQuery from swanlab.api.typings.selfhosted import ApiLicensePlanLiteral, ApiSelfHostedInfoType -from swanlab.api.utils import get_properties +from swanlab.api.utils import get_properties, validate_non_empty_string class SelfHosted(BaseEntity): @@ -82,10 +82,8 @@ def create_user(self, username: str, password: str) -> ApiResponseType: :param password: 待创建用户密码 """ SelfHosted.validate_root(self._ensure_data()) - if not isinstance(username, str) or not username.strip(): - raise ValueError("username must be a non-empty string") - if not isinstance(password, str) or not password.strip(): - raise ValueError("password must be a non-empty string") + validate_non_empty_string(username, label="username") + validate_non_empty_string(password, label="password") data = {"users": [{"username": username, "password": password}]} return self._post("/self_hosted/users", data=data) diff --git a/swanlab/api/typings/__init__.py b/swanlab/api/typings/__init__.py index e1ba1bb28..5b60c1c9e 100644 --- a/swanlab/api/typings/__init__.py +++ b/swanlab/api/typings/__init__.py @@ -9,8 +9,9 @@ from .common import ( ApiIdentityLiteral, ApiLicensePlanLiteral, + ApiMetricAllTypeLiteral, + ApiMetricColumnTypeLiteral, ApiMetricLogLevelLiteral, - ApiMetricTypeLiteral, ApiMetricXAxisLiteral, ApiPaginationType, ApiResponseType, @@ -41,7 +42,8 @@ "ApiIdentityLiteral", "ApiLicensePlanLiteral", "ApiMetricLogLevelLiteral", - "ApiMetricTypeLiteral", + "ApiMetricAllTypeLiteral", + "ApiMetricColumnTypeLiteral", "ApiMetricXAxisLiteral", # General TypedDicts "ApiPaginationType", diff --git a/swanlab/api/typings/common.py b/swanlab/api/typings/common.py index b79833a67..0b1a07195 100644 --- a/swanlab/api/typings/common.py +++ b/swanlab/api/typings/common.py @@ -64,7 +64,10 @@ ApiLicensePlanLiteral = Literal["free", "commercial"] # 指标类型(log 不属于 column-backed metrics,使用独立查询方法) -ApiMetricTypeLiteral = Literal["SCALAR", "MEDIA"] +ApiMetricColumnTypeLiteral = Literal["SCALAR", "MEDIA"] + +# 指标扩展类型(包含 LOG,用于内部 Metric 调度) +ApiMetricAllTypeLiteral = Literal["SCALAR", "MEDIA", "LOG"] # 指标日志级别 ApiMetricLogLevelLiteral = Literal["DEBUG", "INFO", "WARN", "ERROR"] diff --git a/swanlab/api/typings/workspace.py b/swanlab/api/typings/workspace.py index c05db67fc..5ee279885 100644 --- a/swanlab/api/typings/workspace.py +++ b/swanlab/api/typings/workspace.py @@ -9,9 +9,8 @@ from .common import ApiRoleLiteral, ApiWorkspaceLiteral -# 工作空间即 Group 组织 - +# 工作空间即 Group 组织 class ApiWorkspaceProfileType(TypedDict): bio: str url: str diff --git a/swanlab/api/utils.py b/swanlab/api/utils.py index fda085961..27869cf57 100644 --- a/swanlab/api/utils.py +++ b/swanlab/api/utils.py @@ -14,10 +14,11 @@ ApiColumnScalarTypeLiteral, ApiFilterOpLiteral, ApiFilterStableKeyLiteral, + ApiMetricAllTypeLiteral, ApiMetricLogLevelLiteral, - ApiMetricTypeLiteral, ApiSidebarLiteral, ApiSortOrderLiteral, + ApiVisibilityLiteral, ) @@ -86,6 +87,7 @@ def validate_non_empty_string(value: str, *, label: str) -> None: _VALID_OPS = frozenset(get_args(ApiFilterOpLiteral)) _VALID_ORDERS = frozenset(get_args(ApiSortOrderLiteral)) _STABLE_KEYS = frozenset(get_args(ApiFilterStableKeyLiteral)) +_VALID_VISIBILITIES = frozenset(get_args(ApiVisibilityLiteral)) _PROJECT_NAME_RE = re.compile(r"^[0-9a-zA-Z\-_.+]+$") @@ -95,7 +97,7 @@ def validate_non_empty_string(value: str, *, label: str) -> None: _VALID_COLUMN_SCALAR_TYPES = frozenset(get_args(ApiColumnScalarTypeLiteral)) # 指标相关校验常量 -_VALID_METRIC_TYPES = frozenset(get_args(ApiMetricTypeLiteral)) +_VALID_METRIC_ALL_TYPES = frozenset(get_args(ApiMetricAllTypeLiteral)) _VALID_METRIC_LOG_LEVELS = frozenset(get_args(ApiMetricLogLevelLiteral)) @@ -131,8 +133,8 @@ def validate_filter(item: Dict[str, Any]) -> None: def validate_metric_type(metric_type: str, key: Optional[str] = None) -> None: """校验 metric_type 的合法性。非 LOG 类型必须提供非空 key。""" - if metric_type not in _VALID_METRIC_TYPES and metric_type != "LOG": - raise ValueError(f"Invalid metric_type: {metric_type!r}, expected one of {sorted(_VALID_METRIC_TYPES)}") + if metric_type not in _VALID_METRIC_ALL_TYPES: + raise ValueError(f"Invalid metric_type: {metric_type!r}, expected one of {sorted(_VALID_METRIC_ALL_TYPES)}") if metric_type != "LOG" and (not isinstance(key, str) or not key.strip()): raise ValueError(f"key is required for metric_type {metric_type!r}, got key={key!r}") @@ -212,3 +214,15 @@ def validate_project_name(name: str) -> None: raise ValueError("Project name must be between 1 and 100 characters.") if not _PROJECT_NAME_RE.match(name): raise ValueError("Project name can only contain 0-9, a-z, A-Z, -, _, ., +") + + +def validate_visibility(visibility: str) -> None: + """校验 visibility 的合法性。""" + if visibility not in _VALID_VISIBILITIES: + raise ValueError(f"Invalid visibility: {visibility!r}, expected one of {sorted(_VALID_VISIBILITIES)}") + + +def validate_metric_keys(keys: List[str]) -> None: + """校验 metric keys 列表的合法性。""" + if not isinstance(keys, list) or not keys or any(not isinstance(key, str) or not key.strip() for key in keys): + raise ValueError("keys must be a non-empty list of non-empty strings") diff --git a/swanlab/api/workspace.py b/swanlab/api/workspace.py index cabca484b..15238008a 100644 --- a/swanlab/api/workspace.py +++ b/swanlab/api/workspace.py @@ -11,7 +11,7 @@ from swanlab.api.typings.common import ApiVisibilityLiteral, PaginatedQuery from swanlab.api.typings.project import ApiProjectType from swanlab.api.typings.workspace import ApiWorkspaceLiteral, ApiWorkspaceProfileType, ApiWorkspaceType -from swanlab.api.utils import get_properties, strip_dict, validate_project_name +from swanlab.api.utils import get_properties, strip_dict, validate_project_name, validate_visibility from swanlab.sdk.internal.pkg import safe if TYPE_CHECKING: @@ -101,8 +101,7 @@ def create_project( with safe.block(message=None): validate_project_name(name) - if visibility not in ("PUBLIC", "PRIVATE"): - raise ValueError("Invalid visibility, expected PUBLIC or PRIVATE.") + validate_visibility(visibility) body: Dict[str, Any] = {"name": name, "visibility": visibility, "username": self.username} if description: From fd095c3e25ab5b8fee21ba555a8cf377ff8d6107 Mon Sep 17 00:00:00 2001 From: Nexisato <978452096@qq.com> Date: Mon, 27 Apr 2026 23:10:41 +0800 Subject: [PATCH 50/52] feat: add metric sample placeholder --- swanlab/api/metric.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/swanlab/api/metric.py b/swanlab/api/metric.py index f4cc066f0..8b2bff884 100644 --- a/swanlab/api/metric.py +++ b/swanlab/api/metric.py @@ -66,7 +66,7 @@ def __init__( self._data: Optional[Dict[str, Any]] = data self._metric_type = metric_type self._ignore_timestamp = ignore_timestamp - # TODO: 采样值, scalar 时生效,logs 时降级到 1000 + # 采样值, scalar 时生效 self._sample = sample # 偏移量,仅对 Log metric_type 有效, 默认为 0 self._offset = log_offset @@ -140,12 +140,13 @@ def _extract_first(resp: ApiResponseType) -> Optional[Dict[str, Any]]: return None @staticmethod - def _build_scalar_payload(project_id: str, run_id: str, keys: List[str]) -> Dict[str, Any]: + def _build_scalar_payload(project_id: str, run_id: str, keys: List[str], sample: int = 1500) -> Dict[str, Any]: return { "projectId": project_id, "xType": "step", "range": [0, 0], "columns": [{"experimentId": run_id, "key": key} for key in keys], + "num": sample if sample <= 1500 else 1500, } @staticmethod @@ -175,7 +176,7 @@ def _build_log_params(self) -> Dict[str, Any]: def _fetch_scalar(self) -> ApiScalarSeriesType: res = ApiScalarSeriesType(projectId=self.project_id, experimentId=self.run_id, key=self.key) - payload = self._build_scalar_payload(self.project_id, self.run_id, [self.key]) + payload = self._build_scalar_payload(self.project_id, self.run_id, [self.key], self._sample) # 1. 获取折线数据 raw_data = self._extract_first(self._post("/house/metrics/scalar", data=payload)) @@ -411,7 +412,7 @@ def _build_metric(self, key: str, data: Dict[str, Any]) -> Metric: ) def _fetch_scalars(self) -> Iterator[Metric]: - payload = Metric._build_scalar_payload(self._project_id, self._run_id, self._keys) + payload = Metric._build_scalar_payload(self._project_id, self._run_id, self._keys, self._sample) # 1. 获取折线数据 scalar_resp = self._post("/house/metrics/scalar", data=payload) @@ -446,7 +447,7 @@ def _fetch_scalars_all(self) -> Iterator[Metric]: if resp.ok and resp.data: urls[key] = _extract_csv_url(resp.data) - payload = Metric._build_scalar_payload(self._project_id, self._run_id, self._keys) + payload = Metric._build_scalar_payload(self._project_id, self._run_id, self._keys, self._sample) value_resp = self._post("/house/metrics/scalar/value", data=payload) value_list: List[Dict[str, Any]] = value_resp.ok and isinstance(value_resp.data, list) and value_resp.data or [] From 6b7b7c72ccf23f84308fd16304ac1000585e31b5 Mon Sep 17 00:00:00 2001 From: neyuki778 <2597605722@qq.com> Date: Tue, 28 Apr 2026 11:26:45 +0800 Subject: [PATCH 51/52] feat: add cli -h/-k interceptor --- swanlab/cli/api/experiment.py | 7 +++---- swanlab/cli/api/helper.py | 39 ++++++++++++++++++++++++++++++++++- swanlab/cli/api/project.py | 7 +++---- swanlab/cli/api/workspace.py | 7 +++---- 4 files changed, 47 insertions(+), 13 deletions(-) diff --git a/swanlab/cli/api/experiment.py b/swanlab/cli/api/experiment.py index bcea5e83b..c241b568e 100644 --- a/swanlab/cli/api/experiment.py +++ b/swanlab/cli/api/experiment.py @@ -1,8 +1,7 @@ import click import orjson -from swanlab.api import Api -from swanlab.cli.api.helper import format_output, save_output +from swanlab.cli.api.helper import format_output, save_output, with_custom_host @click.group("run") @@ -22,9 +21,9 @@ def experiment_cli(): default=None, help="Save output as JSON to current directory.", ) -def get_experiment(path: str, name): +@with_custom_host +def get_experiment(path: str, name, api): """Get Experiment(Run) info by path (username/project/run_id).""" - api = Api() resp = api.run(path).wrapper() format_output(resp) if resp.ok and name is not None: diff --git a/swanlab/cli/api/helper.py b/swanlab/cli/api/helper.py index 8b3b921fc..5acc176a6 100644 --- a/swanlab/cli/api/helper.py +++ b/swanlab/cli/api/helper.py @@ -1,11 +1,13 @@ import enum from datetime import datetime -from typing import Optional +from functools import wraps +from typing import Callable, Optional import click import nanoid import orjson +from swanlab.api import Api from swanlab.api.typings.common import ApiResponseType @@ -13,6 +15,41 @@ class _SaveFormatEnum(enum.Enum): JSON = "json" +def with_custom_host(func: Callable) -> Callable: + """ + Add common SwanLab API host/auth options to a CLI command. + + The wrapped command receives an `api` keyword argument. When no option is + provided, the default local login settings are used. + """ + + @click.option( + "--host", + "-h", + default=None, + type=str, + help="The host of the SwanLab server.", + ) + @click.option( + "--api-key", + "--api_key", + "-k", + "api_key", + default=None, + type=str, + help="The API key to use for authentication.", + ) + @wraps(func) + def wrapper(*args, host: Optional[str], api_key: Optional[str], **kwargs): + if host is None and api_key is None: + api = Api() + else: + api = Api(host=host, api_key=api_key) + return func(*args, api=api, **kwargs) + + return wrapper + + def format_output(resp: ApiResponseType, fmt: _SaveFormatEnum = _SaveFormatEnum.JSON) -> None: if fmt == _SaveFormatEnum.JSON: click.echo(orjson.dumps(resp.json(), option=orjson.OPT_INDENT_2).decode()) diff --git a/swanlab/cli/api/project.py b/swanlab/cli/api/project.py index 86771321c..05ffb8f60 100644 --- a/swanlab/cli/api/project.py +++ b/swanlab/cli/api/project.py @@ -1,8 +1,7 @@ import click import orjson -from swanlab.api import Api -from swanlab.cli.api.helper import format_output, save_output +from swanlab.cli.api.helper import format_output, save_output, with_custom_host @click.group("project") @@ -22,9 +21,9 @@ def project_cli(): default=None, help="Save output as JSON to current directory.", ) -def get_project(path: str, name): +@with_custom_host +def get_project(path: str, name, api): """Get project info by path (username/project).""" - api = Api() resp = api.project(path).wrapper() format_output(resp) if resp.ok and name is not None: diff --git a/swanlab/cli/api/workspace.py b/swanlab/cli/api/workspace.py index ab1723560..9119fdd13 100644 --- a/swanlab/cli/api/workspace.py +++ b/swanlab/cli/api/workspace.py @@ -1,8 +1,7 @@ import click import orjson -from swanlab.api import Api -from swanlab.cli.api.helper import format_output, save_output +from swanlab.cli.api.helper import format_output, save_output, with_custom_host @click.group("workspace") @@ -22,9 +21,9 @@ def workspace_cli(): default=None, help="Save output as JSON to current directory.", ) -def get_workspace(username: str, name): +@with_custom_host +def get_workspace(username: str, name, api): """Get Workspace info.""" - api = Api() resp = api.workspace(username).wrapper() format_output(resp) if resp.ok and name is not None: From 70437ef048ab801ca2aab6220fadd21055ae4766 Mon Sep 17 00:00:00 2001 From: neyuki778 <2597605722@qq.com> Date: Tue, 28 Apr 2026 11:38:08 +0800 Subject: [PATCH 52/52] Revert "feat: add cli -h/-k interceptor" This reverts commit 6b7b7c72ccf23f84308fd16304ac1000585e31b5. --- swanlab/cli/api/experiment.py | 7 ++++--- swanlab/cli/api/helper.py | 39 +---------------------------------- swanlab/cli/api/project.py | 7 ++++--- swanlab/cli/api/workspace.py | 7 ++++--- 4 files changed, 13 insertions(+), 47 deletions(-) diff --git a/swanlab/cli/api/experiment.py b/swanlab/cli/api/experiment.py index c241b568e..bcea5e83b 100644 --- a/swanlab/cli/api/experiment.py +++ b/swanlab/cli/api/experiment.py @@ -1,7 +1,8 @@ import click import orjson -from swanlab.cli.api.helper import format_output, save_output, with_custom_host +from swanlab.api import Api +from swanlab.cli.api.helper import format_output, save_output @click.group("run") @@ -21,9 +22,9 @@ def experiment_cli(): default=None, help="Save output as JSON to current directory.", ) -@with_custom_host -def get_experiment(path: str, name, api): +def get_experiment(path: str, name): """Get Experiment(Run) info by path (username/project/run_id).""" + api = Api() resp = api.run(path).wrapper() format_output(resp) if resp.ok and name is not None: diff --git a/swanlab/cli/api/helper.py b/swanlab/cli/api/helper.py index 5acc176a6..8b3b921fc 100644 --- a/swanlab/cli/api/helper.py +++ b/swanlab/cli/api/helper.py @@ -1,13 +1,11 @@ import enum from datetime import datetime -from functools import wraps -from typing import Callable, Optional +from typing import Optional import click import nanoid import orjson -from swanlab.api import Api from swanlab.api.typings.common import ApiResponseType @@ -15,41 +13,6 @@ class _SaveFormatEnum(enum.Enum): JSON = "json" -def with_custom_host(func: Callable) -> Callable: - """ - Add common SwanLab API host/auth options to a CLI command. - - The wrapped command receives an `api` keyword argument. When no option is - provided, the default local login settings are used. - """ - - @click.option( - "--host", - "-h", - default=None, - type=str, - help="The host of the SwanLab server.", - ) - @click.option( - "--api-key", - "--api_key", - "-k", - "api_key", - default=None, - type=str, - help="The API key to use for authentication.", - ) - @wraps(func) - def wrapper(*args, host: Optional[str], api_key: Optional[str], **kwargs): - if host is None and api_key is None: - api = Api() - else: - api = Api(host=host, api_key=api_key) - return func(*args, api=api, **kwargs) - - return wrapper - - def format_output(resp: ApiResponseType, fmt: _SaveFormatEnum = _SaveFormatEnum.JSON) -> None: if fmt == _SaveFormatEnum.JSON: click.echo(orjson.dumps(resp.json(), option=orjson.OPT_INDENT_2).decode()) diff --git a/swanlab/cli/api/project.py b/swanlab/cli/api/project.py index 05ffb8f60..86771321c 100644 --- a/swanlab/cli/api/project.py +++ b/swanlab/cli/api/project.py @@ -1,7 +1,8 @@ import click import orjson -from swanlab.cli.api.helper import format_output, save_output, with_custom_host +from swanlab.api import Api +from swanlab.cli.api.helper import format_output, save_output @click.group("project") @@ -21,9 +22,9 @@ def project_cli(): default=None, help="Save output as JSON to current directory.", ) -@with_custom_host -def get_project(path: str, name, api): +def get_project(path: str, name): """Get project info by path (username/project).""" + api = Api() resp = api.project(path).wrapper() format_output(resp) if resp.ok and name is not None: diff --git a/swanlab/cli/api/workspace.py b/swanlab/cli/api/workspace.py index 9119fdd13..ab1723560 100644 --- a/swanlab/cli/api/workspace.py +++ b/swanlab/cli/api/workspace.py @@ -1,7 +1,8 @@ import click import orjson -from swanlab.cli.api.helper import format_output, save_output, with_custom_host +from swanlab.api import Api +from swanlab.cli.api.helper import format_output, save_output @click.group("workspace") @@ -21,9 +22,9 @@ def workspace_cli(): default=None, help="Save output as JSON to current directory.", ) -@with_custom_host -def get_workspace(username: str, name, api): +def get_workspace(username: str, name): """Get Workspace info.""" + api = Api() resp = api.workspace(username).wrapper() format_output(resp) if resp.ok and name is not None: