Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -56,3 +56,4 @@ FEATURE_SUMMARY.md
GEMINI.md
QWEN.md
.omx/
CODEBUDDY.md
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ dependencies = [
"protobuf>=3.19.0,!=4.21.0,!=5.28.0,<7; sys_platform != 'linux'",
"rich>=13.6.0,<14.0.0",
"pydantic-settings>=2.8.1",
"orjson<=3.11.5; python_version == '3.9'",
"orjson; python_version > '3.9'",
]

# 对应 requirements-media.txt 和 requirements-dashboard.txt [cite: 3, 5]
Expand Down
3 changes: 3 additions & 0 deletions swanlab/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from swanlab.api import Api
from swanlab.sdk import (
Audio,
Callback,
Expand Down Expand Up @@ -28,6 +29,8 @@
__version__ = helper.get_swanlab_version()

__all__ = [
# api
"Api",
# cmd
"merge_settings",
"init",
Expand Down
229 changes: 229 additions & 0 deletions swanlab/api/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,229 @@
"""
@author: caddiesnew
@file: __init__.py
@time: 2026/4/20
@description: SwanLab 公共查询 API 入口,面向用户的 OOP 查询接口
"""

from typing import Optional

from swanlab.exceptions import AuthenticationError
from swanlab.sdk.internal.pkg import nrc, scope
from swanlab.sdk.internal.pkg.client import Client
from swanlab.sdk.internal.settings import settings as global_settings

from .base import BaseEntity
from .experiment import Experiment, Experiments
from .project import Project, Projects
from .selfhosted import User, Users
from .typings.common import ApiResponseType
from .workspace import Workspace, Workspaces


class Api(BaseEntity):
"""
SwanLab 公共查询 API 入口。

通过独立的 Client 实例与 SwanLab 云端交互,不与 SDK 运行时单例共享。
继承 BaseEntity 以复用 _get/_post/_put/_delete/_paginate 等安全 HTTP 方法。

用法::

from swanlab import Api

api = Api() # 自动从 .netrc 读取凭证
api = Api(api_key="...", host="...") # 显式传入凭证

resp = api.project("username/project")
if resp.ok:
project = resp.data
print(project.name)
"""

def __init__(
self,
api_key: Optional[str] = None,
host: Optional[str] = None,
web_host: Optional[str] = None,
) -> None:
"""
初始化 Api 实例。

认证优先级:显式参数 > Settings(含 .netrc / 环境变量)

:param api_key: API 密钥,为 None 时从 Settings / .netrc / 环境变量读取
:param host: API 主机地址,为 None 时从 Settings 读取
:param web_host: Web 面板地址,为 None 时从 Settings 读取
"""
api_key, api_host, resolved_web_host = self._resolve_credentials(api_key, host, web_host)
client = Client(api_key=str(api_key), base_url=api_host)
super().__init__(client, resolved_web_host, api_host)
self._login_resp = scope.get_context("login_resp")
self._username: str = self._login_resp["userInfo"]["username"] if self._login_resp else ""

def to_dict(self) -> dict:
"""Api 非数据实体,返回空字典。"""
return {}

@staticmethod
def _resolve_credentials(
api_key: Optional[str],
host: Optional[str],
web_host: Optional[str],
) -> tuple[str, str, str]:
"""
按优先级解析凭证:显式参数 > Settings(含 .netrc / 环境变量)。
返回 (api_key, api_host, web_host)。
"""
if api_key is None:
api_key = global_settings.api_key
if api_key is None:
raise AuthenticationError("No API key found. Please login with `swanlab login` or pass api_key parameter.")

api_host: str = nrc.fmt(host) if host is not None else global_settings.api_host
resolved_web_host: str = nrc.fmt(web_host) if web_host is not None else global_settings.web_host

return api_key, api_host, resolved_web_host

# ------------------------------------------------------------------
# 实体查询方法 — 统一返回 ApiResponse
# ------------------------------------------------------------------

def workspace(self, username: Optional[str] = None) -> ApiResponseType:
"""
获取工作空间信息,默认为当前登录用户的工作空间。

:param username: 指定工作空间用户名,为 None 时使用当前登录用户
"""
if username is None:
username = self._username
resp = self._get(f"/group/{username}")
if resp.ok:
return ApiResponseType(
ok=True,
data=Workspace(
self._client,
self._web_host,
self._api_host,
username=username,
data=resp.data,
),
)
return resp

def workspaces(self, username: Optional[str] = None) -> ApiResponseType:
"""
获取工作空间列表迭代器。

:param username: 指定用户名,为 None 时使用当前登录用户
"""
if username is None:
username = self._username
return ApiResponseType(
ok=True,
data=Workspaces(self._client, self._web_host, self._api_host, username=username),
)

def project(self, path: str) -> ApiResponseType:
"""
获取项目信息。

:param path: 项目路径,格式为 'username/project-name'
"""
resp = self._get(f"/project/{path}")
if resp.ok:
return ApiResponseType(
ok=True,
data=Project(self._client, self._web_host, self._api_host, path=path, data=resp.data),
)
return resp

def projects(
self,
path: str,
sort: Optional[str] = None,
search: Optional[str] = None,
detail: Optional[bool] = True,
) -> ApiResponseType:
"""
获取工作空间下的项目列表迭代器。

:param path: 工作空间名称 'username'
:param sort: 排序方式
:param search: 搜索关键词
:param detail: 是否返回详细信息
"""
return ApiResponseType(
ok=True,
data=Projects(
self._client, self._web_host, self._api_host, path=path, sort=sort, search=search, detail=detail
),
)

def run(self, path: str) -> ApiResponseType:
"""
获取单个实验。

:param path: 实验路径,格式为 'username/project/run_id'
"""
parts = path.split("/")
if len(parts) != 3:
return ApiResponseType(
ok=False, errmsg=f"Invalid path '{path}'. Expected format: 'username/project/run_id'"
)
proj_path = path.rsplit("/", 1)[0]
expid = parts[2]
resp = self._get(f"/project/{proj_path}/runs/{expid}")
if resp.ok:
return ApiResponseType(
ok=True,
data=Experiment(
self._client,
self._web_host,
self._api_host,
path=proj_path,
data=resp.data,
),
)
return resp

def runs(self, path: str, filters: Optional[dict] = None) -> ApiResponseType:
"""
获取项目下的实验列表迭代器。

:param path: 项目路径,格式为 'username/project'
:param filters: 筛选条件
"""
return ApiResponseType(
ok=True,
data=Experiments(self._client, self._web_host, self._api_host, path=path, filters=filters),
)

def user(self, username: Optional[str] = None) -> ApiResponseType:
"""
获取用户信息,默认为当前登录用户。

:param username: 指定用户名
"""
return ApiResponseType(
ok=True,
data=User(
self._client,
self._web_host,
self._api_host,
username=username or self._username,
login_user=self._username,
),
)

def users(self) -> ApiResponseType:
"""
获取用户列表迭代器(私有化部署管理员限定)。
"""
return ApiResponseType(
ok=True,
data=Users(self._client, self._web_host, self._api_host, login_user=self._username),
)


__all__ = ["Api"]
90 changes: 90 additions & 0 deletions swanlab/api/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
"""
@author: caddiesnew
@file: base.py
@time: 2026/4/20
@description: 所有实体类的公共基类
"""

from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterator, Optional

from swanlab.sdk.internal.pkg import safe

from .typings.common import ApiResponseType

if TYPE_CHECKING:
from swanlab.sdk.internal.pkg.client import Client


class BaseEntity(ABC):
"""
swanlab/api 实体类公共基类。

统一持有 _client、_web_host 和 _api_host,提供 _get/_post/_put/_delete HTTP 快捷方法和 _paginate 分页迭代。
所有 HTTP 请求通过 _safe_request 包裹,保证任何异常都不会导致程序 crash,统一返回 ApiResponse。
子类只需实现 to_dict() 和业务逻辑。
"""

def __init__(self, client: "Client", web_host: str, api_host: str) -> None:
self._client: "Client" = client
self._web_host: str = web_host
self._api_host: str = api_host
self._errors: list[str] = []

@abstractmethod
def to_dict(self) -> Dict[str, Any]:
"""将实体序列化为 JSON 可序列化的字典。"""

def _safe_request(self, method: Callable, path: str, **kwargs) -> ApiResponseType:
"""安全请求包装:捕获所有异常,始终返回 ApiResponse 而不抛出。"""

def _on_error(e: BaseException) -> None:
_err_msg[0] = str(e)

_err_msg: list[Optional[str]] = [None]
with safe.block(message=f"API request failed: {path}", on_error=_on_error):
data = method(path, **kwargs).data
return ApiResponseType(ok=True, data=data)
result = ApiResponseType(ok=False, errmsg=_err_msg[0] or "request failed")
self._errors.append(result.errmsg)
return result

def _get(self, path: str, **kwargs) -> ApiResponseType:
return self._safe_request(self._client.get, path, **kwargs)

def _post(self, path: str, **kwargs) -> ApiResponseType:
return self._safe_request(self._client.post, path, **kwargs)

def _put(self, path: str, **kwargs) -> ApiResponseType:
return self._safe_request(self._client.put, path, **kwargs)

def _delete(self, path: str, **kwargs) -> ApiResponseType:
return self._safe_request(self._client.delete, path, **kwargs)

def _build_url(self, path: str) -> str:
return f"{self._api_host}/{path}"

def _paginate(self, path: str, *, page_size: int = 20, params: Optional[dict] = None) -> Iterator[dict]:
"""通用分页迭代器,自动处理 page/size 参数。"""
page = 1
while True:
p = {"page": page, "size": page_size}
if params:
p.update({k: v for k, v in params.items() if v is not None})
resp = self._get(path, params=p)
if not resp.ok:
return
body = resp.data
items = body.get("list", []) if isinstance(body, dict) else body
if not items:
break
yield from items
total_pages = body.get("pages", 1) if isinstance(body, dict) else 1
if page >= total_pages:
break
page += 1

def __repr__(self) -> str:
cls = self.__class__.__name__
ident = getattr(self, "_path", None) or getattr(self, "_username", None) or "?"
return f"{cls}('{ident}')"
Loading
Loading