diff --git a/music_assistant/providers/yandex_smarthome/__init__.py b/music_assistant/providers/yandex_smarthome/__init__.py new file mode 100644 index 0000000000..8e02b43438 --- /dev/null +++ b/music_assistant/providers/yandex_smarthome/__init__.py @@ -0,0 +1,592 @@ +""" +Yandex Smart Home Plugin Provider for Music Assistant. + +Exposes Music Assistant players to Yandex Alice via the Yandex Smart Home API. +Allows voice control of MA players through Alice commands like +"Алиса, включи музыку на [имя плеера]". + +Architecture: + Alice voice command → Yandex Cloud → Smart Home API callback → this plugin → MA Player + +The plugin registers MA players as media_device in Yandex Smart Home, +mapping capabilities (on_off, volume, pause) to MA player controls. + +Reference: https://github.com/dext0r/yandex_smart_home +""" + +from __future__ import annotations + +import logging +import uuid +from typing import TYPE_CHECKING, cast + +import aiohttp +from music_assistant_models.config_entries import ConfigEntry, ConfigValueOption +from music_assistant_models.enums import ConfigEntryType, ProviderFeature + +from ._compat import SecretStr +from .cloud import get_cloud_otp, register_cloud_instance +from .constants import ( + CLOUD_OAUTH_AUTHORIZE_URL, + CLOUD_OAUTH_TOKEN_URL, + CLOUD_SKILL_CLIENT_ID_TEMPLATE, + CLOUD_SKILL_CLIENT_SECRET, + CLOUD_SKILL_WEBHOOK_TEMPLATE, + CONF_ACTION_GET_OTP, + CONF_ACTION_REGISTER, + CONF_CLOUD_CONNECTION_TOKEN, + CONF_CLOUD_INSTANCE_ID, + CONF_CLOUD_INSTANCE_PASSWORD, + CONF_CONNECTION_TYPE, + CONF_DIRECT_ACCESS_TOKEN, + CONF_DIRECT_CLIENT_SECRET, + CONF_EXPOSED_PLAYERS, + CONF_INSTANCE_NAME, + CONF_SKILL_ID, + CONF_SKILL_TOKEN, + CONNECTION_TYPE_CLOUD, + CONNECTION_TYPE_CLOUD_PLUS, + CONNECTION_TYPE_DIRECT, + DIRECT_API_BASE_PATH, + DIRECT_AUTH_BASE_PATH, + DIRECT_OAUTH_CLIENT_ID, + YANDEX_DIALOGS_DEVELOPER_URL, + YANDEX_OAUTH_URL, +) +from .plugin import YandexSmartHomePlugin + +if TYPE_CHECKING: + from music_assistant_models.config_entries import ConfigValueType, ProviderConfig + from music_assistant_models.provider import ProviderManifest + + from music_assistant.mass import MusicAssistant + from music_assistant.models import ProviderInstanceType + +_LOGGER = logging.getLogger(__name__) + +SUPPORTED_FEATURES: set[ProviderFeature] = set() + + +def _build_status_label(otp_code: str | None, is_cloud_plus: bool, is_registered: bool) -> str: + """Build the status label text based on registration state.""" + if otp_code and is_cloud_plus: + return ( + "✅ Cloud instance registered! " + "Open Yandex app → Devices → Add device → Smart Home → " + "find your private skill → enter OTP code below → " + "then click Save to complete setup." + ) + if otp_code: + return ( + "✅ Cloud instance registered! " + "Open Yandex app → Devices → Add device → Smart Home → " + "find 'Yaha Cloud' skill → enter OTP code below → " + "then click Save to complete setup." + ) + if is_registered: + return ( + "✅ Cloud instance is configured. " + "Use 'Get OTP code' if you need to re-link with Yandex." + ) + return ( + "Register a cloud instance to connect with Yandex Alice. " + "This is free and uses the yaha-cloud.ru relay service (no public URL needed)." + ) + + +def _build_cloud_plus_label(is_cloud_plus: bool, is_registered: bool) -> str: + """Build the Cloud Plus instruction label.""" + if not is_cloud_plus: + return "" + if is_registered: + return ( + "Cloud Plus setup: " + "1) Open Yandex.Dialogs console (link below) → Smart Home → Create skill. " + "2) Fill 'Basic info': Backend URL = webhook URL below, Access = Private. " + "3) Save, then fill 'Account linking' section with values below. " + "4) Save & Publish. " + "5) Get OAuth token → enter skill_id and token → Save." + ) + return ( + "Cloud Plus mode requires a private skill in Yandex.Dialogs. " + "First register a cloud instance, then follow the setup instructions." + ) + + +async def setup( + mass: MusicAssistant, manifest: ProviderManifest, config: ProviderConfig +) -> ProviderInstanceType: + """Initialize provider(instance) with given configuration.""" + return YandexSmartHomePlugin(mass, manifest, config, SUPPORTED_FEATURES) + + +async def _handle_config_actions( + mass: MusicAssistant, + action: str | None, + values: dict[str, ConfigValueType], + instance_id: str | None, + is_cloud_plus: bool, +) -> str | None: + """Execute register/OTP actions and return OTP code if obtained.""" + saved_config = None + if instance_id: + prov = mass.get_provider(instance_id) + if prov: + saved_config = prov.config + + if action == CONF_ACTION_REGISTER: + try: + platform = "yandex" if is_cloud_plus else None + async with aiohttp.ClientSession() as session: + data = await register_cloud_instance(session, platform=platform) + values[CONF_CLOUD_INSTANCE_ID] = data["id"] + values[CONF_CLOUD_INSTANCE_PASSWORD] = data["password"] + values[CONF_CLOUD_CONNECTION_TOKEN] = data["connection_token"] + _LOGGER.info("Auto-registered cloud instance: %s", data["id"]) + except Exception: + _LOGGER.exception("Failed to register cloud instance") + + otp_code: str | None = None + if action == CONF_ACTION_GET_OTP: + cloud_id = str(values.get(CONF_CLOUD_INSTANCE_ID, "")) + cloud_token = "" + if saved_config: + cloud_token = str(saved_config.get_value(CONF_CLOUD_CONNECTION_TOKEN) or "") + if not cloud_token: + cloud_token = str(values.get(CONF_CLOUD_CONNECTION_TOKEN, "")) + if cloud_id and cloud_token: + try: + async with aiohttp.ClientSession() as session: + otp_code = await get_cloud_otp(session, cloud_id, SecretStr(cloud_token)) + except Exception: + _LOGGER.exception("Failed to get OTP code") + + if action == CONF_ACTION_REGISTER and not otp_code: + cloud_id = str(values.get(CONF_CLOUD_INSTANCE_ID, "")) + cloud_token = str(values.get(CONF_CLOUD_CONNECTION_TOKEN, "")) + if cloud_id and cloud_token: + try: + async with aiohttp.ClientSession() as session: + otp_code = await get_cloud_otp(session, cloud_id, SecretStr(cloud_token)) + except Exception: + _LOGGER.exception("Failed to get OTP after registration") + + return otp_code + + +async def get_config_entries( + mass: MusicAssistant, + instance_id: str | None = None, + action: str | None = None, + values: dict[str, ConfigValueType] | None = None, +) -> tuple[ConfigEntry, ...]: + """Return Config entries to setup this provider.""" + if values is None: + values = {} + + connection_type = str(values.get(CONF_CONNECTION_TYPE, CONNECTION_TYPE_CLOUD)) + is_cloud_plus = connection_type == CONNECTION_TYPE_CLOUD_PLUS + is_direct = connection_type == CONNECTION_TYPE_DIRECT + + otp_code = await _handle_config_actions(mass, action, values, instance_id, is_cloud_plus) + + is_registered = bool(values.get(CONF_CLOUD_INSTANCE_ID)) and bool( + values.get(CONF_CLOUD_CONNECTION_TOKEN) + ) + cloud_instance_id = str(values.get(CONF_CLOUD_INSTANCE_ID, "")) + + label_text = _build_status_label(otp_code, is_cloud_plus, is_registered) + cloud_plus_label = _build_cloud_plus_label(is_cloud_plus, is_registered) + + # Compute copyable values for Cloud Plus mode + webhook_url = "" + client_id = "" + if is_cloud_plus and is_registered: + webhook_url = CLOUD_SKILL_WEBHOOK_TEMPLATE + client_id = CLOUD_SKILL_CLIENT_ID_TEMPLATE.format(instance_id=cloud_instance_id) + + # Compute direct mode endpoint URLs + direct_base_url = "" + direct_auth_url = "" + direct_token_url = "" + if is_direct: + try: + ma_base_url = mass.webserver.base_url.rstrip("/") + except Exception: + ma_base_url = "https://" + direct_base_url = f"{ma_base_url}{DIRECT_API_BASE_PATH}" + direct_auth_url = f"{ma_base_url}{DIRECT_AUTH_BASE_PATH}/authorize" + direct_token_url = f"{ma_base_url}{DIRECT_AUTH_BASE_PATH}/token" + + # Build player options for exposed players filter + player_options: list[ConfigValueOption] = [] + try: + for player in mass.players.all_players(): + state = player.state + player_options.append( + ConfigValueOption(title=state.name or state.player_id, value=state.player_id) + ) + except Exception: # noqa: S110 + pass + + return ( + # Instance name + ConfigEntry( + key=CONF_INSTANCE_NAME, + type=ConfigEntryType.STRING, + label="Instance Name", + description=( + "Name of this MA instance as it will appear in Yandex Smart Home. " + "Alice will use this name for voice commands, e.g. " + '"Алиса, включи музыку на [имя]".' + ), + required=False, + default_value="Music Assistant", + ), + # Connection type selector + ConfigEntry( + key=CONF_CONNECTION_TYPE, + type=ConfigEntryType.STRING, + label="Connection Type", + description=( + '"cloud" — public Yaha Cloud skill (simple setup). ' + '"cloud_plus" — private skill via cloud relay (for multi-platform setups). ' + '"direct" — Yandex calls your MA server directly (requires public HTTPS URL).' + ), + required=False, + default_value=CONNECTION_TYPE_CLOUD, + options=[ + ConfigValueOption(title="Cloud (public Yaha Cloud skill)", value="cloud"), + ConfigValueOption(title="Cloud Plus (private skill)", value="cloud_plus"), + ConfigValueOption(title="Direct (no relay, requires public URL)", value="direct"), + ], + advanced=True, + ), + # Status label (cloud modes only) + ConfigEntry( + key="label_status", + type=ConfigEntryType.LABEL, + label=label_text, + hidden=is_direct, + ), + # OTP code — copyable text field (shown only when OTP is available) + ConfigEntry( + key="otp_code", + type=ConfigEntryType.STRING, + label="OTP Code", + description="Copy this code and enter it in the Yandex app.", + required=False, + value=otp_code, + hidden=not otp_code or is_direct, + ), + # Register action (hidden after registration or in direct mode) + ConfigEntry( + key=CONF_ACTION_REGISTER, + type=ConfigEntryType.ACTION, + label="Register cloud instance", + description="Register a new instance on yaha-cloud.ru relay service.", + action=CONF_ACTION_REGISTER, + action_label="Register with cloud", + hidden=is_registered or is_direct, + ), + # Get OTP action (shown after registration, hidden in direct mode) + ConfigEntry( + key=CONF_ACTION_GET_OTP, + type=ConfigEntryType.ACTION, + label="Get OTP code", + description="Get a fresh one-time password to link with Yandex Smart Home app.", + action=CONF_ACTION_GET_OTP, + action_label="Get OTP code", + hidden=not is_registered or is_direct, + ), + # --- Direct connection section --- + ConfigEntry( + key="label_direct", + type=ConfigEntryType.LABEL, + label=( + "Direct connection setup: " + "1) Create a private skill in Yandex.Dialogs (Smart Home type). " + "2) Set Backend URL, Authorization URL, Token URL from values below. " + "3) Set Client ID and Client Secret from values below. " + "4) Publish skill, then link account in Yandex app. " + "5) Fill Skill ID and Skill Token below and Save." + ), + depends_on=CONF_CONNECTION_TYPE, + depends_on_value=CONNECTION_TYPE_DIRECT, + category="Direct Connection Setup", + ), + # Yandex Dialogs developer console link (direct) + ConfigEntry( + key="direct_dialogs_url", + type=ConfigEntryType.STRING, + label="Yandex.Dialogs Console (create skill here)", + required=False, + default_value=YANDEX_DIALOGS_DEVELOPER_URL, + help_link=YANDEX_DIALOGS_DEVELOPER_URL, + depends_on=CONF_CONNECTION_TYPE, + depends_on_value=CONNECTION_TYPE_DIRECT, + category="Direct Connection Setup", + ), + # Backend URL (for Yandex.Dialogs skill config) + ConfigEntry( + key="direct_backend_url", + type=ConfigEntryType.STRING, + label="Backend URL (→ Basic info)", + description="Copy to your skill's Backend URL field in Yandex.Dialogs.", + required=False, + value=direct_base_url or None, + depends_on=CONF_CONNECTION_TYPE, + depends_on_value=CONNECTION_TYPE_DIRECT, + category="Copy to Yandex.Dialogs skill", + ), + # Authorization URL (direct) + ConfigEntry( + key="direct_auth_url", + type=ConfigEntryType.STRING, + label="Authorization URL (→ Account linking)", + description="Copy to 'Account linking' → 'Authorization URL' field.", + required=False, + value=direct_auth_url or None, + depends_on=CONF_CONNECTION_TYPE, + depends_on_value=CONNECTION_TYPE_DIRECT, + category="Copy to Yandex.Dialogs skill", + ), + # Token URL (direct) + ConfigEntry( + key="direct_token_url", + type=ConfigEntryType.STRING, + label="Token URL (→ Account linking, both fields)", + description=("Copy to both 'Token endpoint' and 'Refresh token URL' fields."), + required=False, + value=direct_token_url or None, + depends_on=CONF_CONNECTION_TYPE, + depends_on_value=CONNECTION_TYPE_DIRECT, + category="Copy to Yandex.Dialogs skill", + ), + # Client ID (direct — always the same) + ConfigEntry( + key="direct_client_id", + type=ConfigEntryType.STRING, + label="Client ID (→ Account linking)", + description="Copy to 'Account linking' → 'Client identifier' field.", + required=False, + default_value=DIRECT_OAUTH_CLIENT_ID, + depends_on=CONF_CONNECTION_TYPE, + depends_on_value=CONNECTION_TYPE_DIRECT, + category="Copy to Yandex.Dialogs skill", + ), + # Client Secret (direct — auto-generated per install) + ConfigEntry( + key=CONF_DIRECT_CLIENT_SECRET, + type=ConfigEntryType.SECURE_STRING, + label="Client Secret (→ Account linking)", + description=( + "Copy to 'Account linking' → 'Client secret' field. Auto-generated on first setup." + ), + required=False, + default_value=( + cast("str", values.get(CONF_DIRECT_CLIENT_SECRET)) + if values and values.get(CONF_DIRECT_CLIENT_SECRET) + else uuid.uuid4().hex + ), + depends_on=CONF_CONNECTION_TYPE, + depends_on_value=CONNECTION_TYPE_DIRECT, + category="Copy to Yandex.Dialogs skill", + ), + # OAuth URL for getting skill token (direct) + ConfigEntry( + key="direct_oauth_url", + type=ConfigEntryType.STRING, + label="OAuth URL (open to get skill token)", + required=False, + default_value=YANDEX_OAUTH_URL, + help_link=YANDEX_OAUTH_URL, + depends_on=CONF_CONNECTION_TYPE, + depends_on_value=CONNECTION_TYPE_DIRECT, + category="Fill in from Yandex.Dialogs", + ), + # Skill ID (cloud_plus and direct) + ConfigEntry( + key=CONF_SKILL_ID, + type=ConfigEntryType.STRING, + label="Skill ID", + description=( + "UUID of your private Smart Home skill from Yandex.Dialogs. " + "Find it in the skill URL: /developer/skills/{skill_id}/" + ), + required=False, + depends_on=CONF_CONNECTION_TYPE, + depends_on_value_not=CONNECTION_TYPE_CLOUD, + category="Fill in from Yandex.Dialogs", + ), + # Skill OAuth Token (cloud_plus and direct) + ConfigEntry( + key=CONF_SKILL_TOKEN, + type=ConfigEntryType.SECURE_STRING, + label="Skill OAuth Token", + description="Paste the OAuth token obtained from the URL above.", + required=False, + depends_on=CONF_CONNECTION_TYPE, + depends_on_value_not=CONNECTION_TYPE_CLOUD, + category="Fill in from Yandex.Dialogs", + ), + # --- Cloud Plus section (advanced) --- + # Cloud Plus instructions + ConfigEntry( + key="label_cloud_plus", + type=ConfigEntryType.LABEL, + label=cloud_plus_label, + depends_on=CONF_CONNECTION_TYPE, + depends_on_value=CONNECTION_TYPE_CLOUD_PLUS, + advanced=True, + category="Cloud Plus Setup", + ), + # Yandex Dialogs developer console link + ConfigEntry( + key="dialogs_url", + type=ConfigEntryType.STRING, + label="Yandex.Dialogs Console (create skill here)", + required=False, + default_value=YANDEX_DIALOGS_DEVELOPER_URL, + help_link=YANDEX_DIALOGS_DEVELOPER_URL, + depends_on=CONF_CONNECTION_TYPE, + depends_on_value=CONNECTION_TYPE_CLOUD_PLUS, + advanced=True, + category="Cloud Plus Setup", + ), + # --- Copy to Yandex.Dialogs --- + # Webhook URL + ConfigEntry( + key="webhook_url", + type=ConfigEntryType.STRING, + label="Backend URL (→ Basic info)", + description="Copy and paste into your private skill's Backend URL field.", + required=False, + value=webhook_url or None, + hidden=not webhook_url, + depends_on=CONF_CONNECTION_TYPE, + depends_on_value=CONNECTION_TYPE_CLOUD_PLUS, + advanced=True, + category="Copy to Yandex.Dialogs skill", + ), + # Client ID + ConfigEntry( + key="skill_client_id", + type=ConfigEntryType.STRING, + label="Client ID (→ Account linking)", + description="Copy to 'Account linking' → 'Client identifier' field.", + required=False, + value=client_id or None, + hidden=not client_id, + depends_on=CONF_CONNECTION_TYPE, + depends_on_value=CONNECTION_TYPE_CLOUD_PLUS, + advanced=True, + category="Copy to Yandex.Dialogs skill", + ), + # Client Secret + ConfigEntry( + key="skill_client_secret", + type=ConfigEntryType.STRING, + label="Client Secret (→ Account linking)", + description="Copy to 'Account linking' → 'Client secret' field.", + required=False, + default_value=CLOUD_SKILL_CLIENT_SECRET, + hidden=not is_registered, + depends_on=CONF_CONNECTION_TYPE, + depends_on_value=CONNECTION_TYPE_CLOUD_PLUS, + advanced=True, + category="Copy to Yandex.Dialogs skill", + ), + # Authorization URL + ConfigEntry( + key="skill_auth_url", + type=ConfigEntryType.STRING, + label="Authorization URL (→ Account linking)", + description="Copy to 'Account linking' → 'Authorization URL' field.", + required=False, + default_value=CLOUD_OAUTH_AUTHORIZE_URL, + hidden=not is_registered, + depends_on=CONF_CONNECTION_TYPE, + depends_on_value=CONNECTION_TYPE_CLOUD_PLUS, + advanced=True, + category="Copy to Yandex.Dialogs skill", + ), + # Token URL + ConfigEntry( + key="skill_token_url", + type=ConfigEntryType.STRING, + label="Token URL (→ Account linking, both fields)", + description=( + "Copy to both 'Token endpoint' and 'Refresh token URL' fields " + "in the 'Account linking' section." + ), + required=False, + default_value=CLOUD_OAUTH_TOKEN_URL, + hidden=not is_registered, + depends_on=CONF_CONNECTION_TYPE, + depends_on_value=CONNECTION_TYPE_CLOUD_PLUS, + advanced=True, + category="Copy to Yandex.Dialogs skill", + ), + # OAuth URL — link to get skill token (Cloud Plus) + ConfigEntry( + key="oauth_url", + type=ConfigEntryType.STRING, + label="OAuth URL (open to get token)", + required=False, + default_value=YANDEX_OAUTH_URL, + help_link=YANDEX_OAUTH_URL, + hidden=not is_registered, + depends_on=CONF_CONNECTION_TYPE, + depends_on_value=CONNECTION_TYPE_CLOUD_PLUS, + advanced=True, + category="Fill in from Yandex.Dialogs", + ), + # --- Player filter --- + ConfigEntry( + key=CONF_EXPOSED_PLAYERS, + type=ConfigEntryType.STRING, + label="Exposed Players", + description=( + "Select which MA players to expose to Yandex Smart Home. " + "Leave empty to expose all players." + ), + required=False, + multi_value=True, + default_value=[], + options=list(player_options) if player_options else [], + ), + # --- Auto-managed fields (hidden, populated by actions) --- + ConfigEntry( + key=CONF_CLOUD_INSTANCE_ID, + type=ConfigEntryType.STRING, + label="Cloud Instance ID", + hidden=True, + required=False, + value=cast("str", values.get(CONF_CLOUD_INSTANCE_ID)) if values else None, + ), + ConfigEntry( + key=CONF_CLOUD_INSTANCE_PASSWORD, + type=ConfigEntryType.SECURE_STRING, + label="Cloud Instance Password", + hidden=True, + required=False, + value=(cast("str", values.get(CONF_CLOUD_INSTANCE_PASSWORD)) if values else None), + ), + ConfigEntry( + key=CONF_CLOUD_CONNECTION_TOKEN, + type=ConfigEntryType.SECURE_STRING, + label="Cloud Connection Token", + hidden=True, + required=False, + value=(cast("str", values.get(CONF_CLOUD_CONNECTION_TOKEN)) if values else None), + ), + ConfigEntry( + key=CONF_DIRECT_ACCESS_TOKEN, + type=ConfigEntryType.SECURE_STRING, + label="Direct Access Token", + hidden=True, + required=False, + value=(cast("str", values.get(CONF_DIRECT_ACCESS_TOKEN)) if values else None), + ), + ) diff --git a/music_assistant/providers/yandex_smarthome/_compat.py b/music_assistant/providers/yandex_smarthome/_compat.py new file mode 100644 index 0000000000..6b747cbd01 --- /dev/null +++ b/music_assistant/providers/yandex_smarthome/_compat.py @@ -0,0 +1,31 @@ +"""Compatibility shim for ya-passport-auth. + +Provides a single source of `SecretStr` for the provider. When `ya-passport-auth` +is installed (the normal runtime case — declared in manifest.json), we re-export +the real implementation. When it's missing (bare test envs, pre-install linting) +we expose a minimal drop-in so importing the provider package doesn't crash. + +Centralized here to avoid duplicating the fallback across modules. +""" + +from __future__ import annotations + +try: + from ya_passport_auth import SecretStr +except ImportError: + + class SecretStr: # type: ignore[no-redef] + """Minimal fallback when ya-passport-auth is not yet installed.""" + + def __init__(self, value: str) -> None: + """Initialize with a secret value.""" + if not value: + raise ValueError("SecretStr value must not be empty") + self._value = value + + def get_secret(self) -> str: + """Return the secret value.""" + return self._value + + +__all__ = ["SecretStr"] diff --git a/music_assistant/providers/yandex_smarthome/cloud.py b/music_assistant/providers/yandex_smarthome/cloud.py new file mode 100644 index 0000000000..35483504e4 --- /dev/null +++ b/music_assistant/providers/yandex_smarthome/cloud.py @@ -0,0 +1,198 @@ +"""Cloud connection manager for Yandex Smart Home via yaha-cloud.ru relay. + +Manages a persistent WebSocket connection to the yaha-cloud.ru relay service. +Incoming Yandex Smart Home API requests are received over WS, processed by +the on_request callback, and the response is sent back over WS. + +Adapted from dext0r/yandex_smart_home cloud.py, stripped of HA dependencies. +""" + +from __future__ import annotations + +import asyncio +import json +import logging +from collections.abc import Awaitable, Callable +from typing import TYPE_CHECKING, Any + +import aiohttp + +if TYPE_CHECKING: + from ._compat import SecretStr + +from .constants import ( + CLOUD_BASE_URL, + CLOUD_HEARTBEAT_INTERVAL, + CLOUD_RECONNECT_MAX, + CLOUD_RECONNECT_MIN, + CLOUD_REGISTER_URL, + CLOUD_WS_URL, +) +from .schema import CloudRequest + +_LOGGER = logging.getLogger(__name__) + + +class CloudManager: + """Manages WebSocket connection to yaha-cloud.ru for Smart Home API relay.""" + + def __init__( + self, + session: aiohttp.ClientSession, + connection_token: SecretStr, + on_request: Callable[[CloudRequest], Awaitable[dict[str, Any]]], + logger: logging.Logger | None = None, + ) -> None: + """Initialize cloud relay manager.""" + self._session = session + self._token = connection_token + self._on_request = on_request + self._logger = logger or _LOGGER + self._ws: aiohttp.ClientWebSocketResponse | None = None + self._running = False + self._reconnect_delay = CLOUD_RECONNECT_MIN + + @property + def connected(self) -> bool: + """Return True if WebSocket is connected.""" + return self._ws is not None and not self._ws.closed + + async def connect(self) -> None: + """Start the WebSocket connection loop (runs until disconnect is called).""" + self._running = True + while self._running: + try: + await self._connect_once() + except asyncio.CancelledError: + break + except Exception: + if not self._running: + break # type: ignore[unreachable] + self._logger.exception( + "Cloud connection error, reconnecting in %ds", self._reconnect_delay + ) + if not self._running: + break # type: ignore[unreachable] + # Backoff before reconnect (both after errors and clean disconnects) + await asyncio.sleep(self._reconnect_delay) + self._reconnect_delay = min(self._reconnect_delay * 2, CLOUD_RECONNECT_MAX) + + async def _connect_once(self) -> None: + """Single WebSocket connection attempt + message loop.""" + headers = {"Authorization": f"Bearer {self._token.get_secret()}"} + async with self._session.ws_connect( + CLOUD_WS_URL, + headers=headers, + heartbeat=CLOUD_HEARTBEAT_INTERVAL, + ) as ws: + self._ws = ws + self._reconnect_delay = CLOUD_RECONNECT_MIN + self._logger.info("Connected to cloud relay at %s", CLOUD_WS_URL) + + async for msg in ws: + if not self._running: + break + + if msg.type == aiohttp.WSMsgType.TEXT: + try: + data = json.loads(msg.data) + except json.JSONDecodeError: + self._logger.warning("Received invalid JSON from cloud relay: %r", msg.data) + continue + await self._handle_message(ws, data) + elif msg.type == aiohttp.WSMsgType.ERROR: + self._logger.error("WebSocket error: %s", ws.exception()) + break + elif msg.type in ( + aiohttp.WSMsgType.CLOSE, + aiohttp.WSMsgType.CLOSING, + aiohttp.WSMsgType.CLOSED, + ): + break + + self._ws = None + self._logger.info("Cloud relay connection closed") + + async def _handle_message( + self, ws: aiohttp.ClientWebSocketResponse, data: dict[str, Any] + ) -> None: + """Parse incoming WS message, call handler, and send response.""" + try: + # message may be a JSON string or already parsed dict + raw_message = data.get("message") + if isinstance(raw_message, str) and raw_message: + raw_message = json.loads(raw_message) + request = CloudRequest( + request_id=data["request_id"], + action=data["action"], + message=raw_message if isinstance(raw_message, dict) else None, + ) + self._logger.debug("Cloud request: action=%s", request.action) + response = await self._on_request(request) + await ws.send_json(response) + except Exception: + self._logger.exception("Error handling cloud message: %s", data) + # Send best-effort error response so the relay doesn't hang + request_id = data.get("request_id") if isinstance(data, dict) else None + if request_id and ws and not ws.closed: + try: + await ws.send_json( + {"request_id": request_id, "payload": {"error": "INTERNAL_ERROR"}} + ) + except Exception: + self._logger.debug("Failed to send error response for %s", request_id) + + async def disconnect(self) -> None: + """Stop the connection loop and close WebSocket.""" + self._running = False + if self._ws and not self._ws.closed: + await self._ws.close() + self._ws = None + self._logger.info("Cloud relay disconnected") + + +# --------------------------------------------------------------------------- +# Cloud instance registration helpers +# --------------------------------------------------------------------------- + + +async def register_cloud_instance( + session: aiohttp.ClientSession, + platform: str | None = None, +) -> dict[str, str]: + """Register a new cloud instance on yaha-cloud.ru. + + Returns dict with 'id', 'password', 'connection_token'. + No authentication is required — the relay auto-generates credentials. + + For Cloud Plus mode, pass platform="yandex" so the relay can validate + the client_id during OAuth account linking. + """ + kwargs: dict[str, Any] = {} + if platform: + kwargs["json"] = {"platform": platform} + async with session.post(CLOUD_REGISTER_URL, **kwargs) as resp: + resp.raise_for_status() + # yaha-cloud.ru may return text/plain content-type for JSON + data = await resp.json(content_type=None) + _LOGGER.info("Registered cloud instance: %s", data.get("id")) + return dict(data) + + +async def get_cloud_otp( + session: aiohttp.ClientSession, + instance_id: str, + token: SecretStr, +) -> str: + """Get a one-time password for linking the instance in the Yandex app. + + User enters this OTP in the Yandex Smart Home app to link their account. + The token parameter is the connection_token from registration. + """ + url = f"{CLOUD_BASE_URL}/api/home_assistant/v1/instance/{instance_id}/otp" + headers = {"Authorization": f"Bearer {token.get_secret()}"} + async with session.post(url, headers=headers) as resp: + resp.raise_for_status() + # yaha-cloud.ru may return text/plain content-type for JSON + data = await resp.json(content_type=None) + return str(data["code"]) diff --git a/music_assistant/providers/yandex_smarthome/constants.py b/music_assistant/providers/yandex_smarthome/constants.py new file mode 100644 index 0000000000..3f70005dab --- /dev/null +++ b/music_assistant/providers/yandex_smarthome/constants.py @@ -0,0 +1,122 @@ +"""Constants for Yandex Smart Home provider.""" + +from __future__ import annotations + +# --------------------------------------------------------------------------- +# Config entry keys +# --------------------------------------------------------------------------- +CONF_INSTANCE_NAME = "instance_name" +CONF_CONNECTION_TYPE = "connection_type" +CONF_CLOUD_INSTANCE_ID = "cloud_instance_id" +CONF_CLOUD_INSTANCE_PASSWORD = "cloud_instance_password" +CONF_CLOUD_CONNECTION_TOKEN = "cloud_connection_token" +CONF_SKILL_ID = "skill_id" +CONF_SKILL_TOKEN = "skill_token" +CONF_EXPOSED_PLAYERS = "exposed_players" + +# --------------------------------------------------------------------------- +# Config actions +# --------------------------------------------------------------------------- +CONF_ACTION_REGISTER = "register_cloud" +CONF_ACTION_GET_OTP = "get_otp" + +# --------------------------------------------------------------------------- +# Connection types +# --------------------------------------------------------------------------- +CONNECTION_TYPE_CLOUD = "cloud" +CONNECTION_TYPE_CLOUD_PLUS = "cloud_plus" +CONNECTION_TYPE_DIRECT = "direct" + +# --------------------------------------------------------------------------- +# Cloud relay — yaha-cloud.ru (dext0r's relay service) +# --------------------------------------------------------------------------- +CLOUD_BASE_URL = "https://yaha-cloud.ru" +CLOUD_WS_URL = "wss://yaha-cloud.ru/api/home_assistant/v1/connect" +CLOUD_REGISTER_URL = f"{CLOUD_BASE_URL}/api/home_assistant/v1/instance/register" +CLOUD_CALLBACK_URL = f"{CLOUD_BASE_URL}/api/home_assistant/v2/callback" +CLOUD_OAUTH_AUTHORIZE_URL = f"{CLOUD_BASE_URL}/oauth/authorize" +CLOUD_OAUTH_TOKEN_URL = f"{CLOUD_BASE_URL}/oauth/token" + +# Platform identifier sent to the cloud relay +CLOUD_PLATFORM = "music_assistant" + +# Account linking template: client_id = "yandex_smart_home:{instance_id}" +CLOUD_SKILL_CLIENT_ID_TEMPLATE = "yandex_smart_home:{instance_id}" +# Account linking: fixed client_secret required by yaha-cloud.ru relay protocol. +# This is NOT a per-install secret — the relay expects exactly this value. +# Direct mode uses a per-install auto-generated secret instead (see CONF_DIRECT_CLIENT_SECRET). +CLOUD_SKILL_CLIENT_SECRET = "secret" + +# --------------------------------------------------------------------------- +# Cloud Plus / Direct mode — Yandex Dialogs API +# --------------------------------------------------------------------------- +YANDEX_DIALOGS_CALLBACK_BASE = "https://dialogs.yandex.net/api/v1/skills" +YANDEX_DIALOGS_DEVELOPER_URL = "https://dialogs.yandex.ru/developer/smart-home" +YANDEX_OAUTH_URL = "https://oauth.yandex.ru/authorize?response_type=token&client_id=c473ca268cd749d3a8371351a8f2bcbd" + +# Webhook URL template for yaha-cloud relay (private skill points here) +CLOUD_SKILL_WEBHOOK_TEMPLATE = "https://yaha-cloud.ru/api/yandex_smart_home" + +# --------------------------------------------------------------------------- +# Direct connection — HTTP endpoints on MA webserver +# --------------------------------------------------------------------------- +DIRECT_API_BASE_PATH = "/api/yandex_smarthome/v1.0" +DIRECT_AUTH_BASE_PATH = "/api/yandex_smarthome/auth" +DIRECT_HEALTH_RESPONSE = "Yandex Smart Home for Music Assistant" +CONF_DIRECT_ACCESS_TOKEN = "direct_access_token" +CONF_DIRECT_CLIENT_SECRET = "direct_client_secret" +DIRECT_OAUTH_CLIENT_ID = "https://social.yandex.net/" +OAUTH_CODE_EXPIRY = 300 # pending authorization codes expire after 5 minutes +MAX_PENDING_CODES = 20 # hard cap on concurrent pending authorization codes (DoS protection) + +# --------------------------------------------------------------------------- +# Timing (seconds) +# --------------------------------------------------------------------------- +STATE_REPORT_DELAY = 1.0 # debounce window for batched state reports +STATE_HEARTBEAT_INTERVAL = 3600 # report all states hourly +STATE_INITIAL_REPORT_DELAY = 15 # initial report after startup +CLOUD_RECONNECT_MIN = 2 # initial reconnect delay +CLOUD_RECONNECT_MAX = 180 # max reconnect delay (exponential backoff cap) +CLOUD_HEARTBEAT_INTERVAL = 45 # WebSocket heartbeat + +# --------------------------------------------------------------------------- +# Yandex Smart Home API — device & capability constants +# --------------------------------------------------------------------------- +YANDEX_DEVICE_TYPE_MEDIA = "devices.types.media_device" +YANDEX_DEVICE_TYPE_RECEIVER = "devices.types.media_device.receiver" + +CAPABILITY_ON_OFF = "devices.capabilities.on_off" +CAPABILITY_RANGE = "devices.capabilities.range" +CAPABILITY_TOGGLE = "devices.capabilities.toggle" + +INSTANCE_ON = "on" +INSTANCE_VOLUME = "volume" +INSTANCE_MUTE = "mute" +INSTANCE_PAUSE = "pause" +INSTANCE_CHANNEL = "channel" +INSTANCE_INPUT_SOURCE = "input_source" + +UNIT_PERCENT = "unit.percent" + +# Yandex mode values for input_source mapping (by index position) +YANDEX_MODE_VALUES = ( + "one", + "two", + "three", + "four", + "five", + "six", + "seven", + "eight", + "nine", + "ten", +) + +# --------------------------------------------------------------------------- +# Yandex Smart Home API — response codes +# --------------------------------------------------------------------------- +RESPONSE_OK = "DONE" +ERROR_DEVICE_UNREACHABLE = "DEVICE_UNREACHABLE" +ERROR_INVALID_ACTION = "INVALID_ACTION" +ERROR_INTERNAL_ERROR = "INTERNAL_ERROR" +ERROR_DEVICE_NOT_FOUND = "DEVICE_NOT_FOUND" diff --git a/music_assistant/providers/yandex_smarthome/device.py b/music_assistant/providers/yandex_smarthome/device.py new file mode 100644 index 0000000000..74340f7183 --- /dev/null +++ b/music_assistant/providers/yandex_smarthome/device.py @@ -0,0 +1,523 @@ +"""MA Player ↔ Yandex Smart Home device mapper. + +Maps Music Assistant Player state to Yandex Smart Home device descriptions, +capability states, and action execution. +""" + +from __future__ import annotations + +import logging +import re +from typing import TYPE_CHECKING, Any + +from music_assistant_models.enums import PlaybackState + +from .constants import ( + ERROR_DEVICE_UNREACHABLE, + ERROR_INTERNAL_ERROR, + ERROR_INVALID_ACTION, + INSTANCE_CHANNEL, + INSTANCE_INPUT_SOURCE, + INSTANCE_MUTE, + INSTANCE_ON, + INSTANCE_PAUSE, + INSTANCE_VOLUME, + UNIT_PERCENT, + YANDEX_DEVICE_TYPE_MEDIA, + YANDEX_MODE_VALUES, +) +from .schema import ( + ActionResult, + CapabilityAction, + CapabilityActionResult, + CapabilityActionResultState, + CapabilityDescription, + CapabilityInstanceState, + CapabilityParameters, + CapabilityState, + DeviceDescription, + DeviceState, + ModeValue, + RangeParameters, + YandexCapabilityType, + YandexDeviceInfo, +) + +if TYPE_CHECKING: + from music_assistant_models.player import Player, PlayerSource + +_LOGGER = logging.getLogger(__name__) + + +# --------------------------------------------------------------------------- +# Source list helpers (for input_source capability) +# --------------------------------------------------------------------------- + + +def _is_group_player(player: Player) -> bool: + """Check if player is a group (has child members).""" + group_members = getattr(player, "group_members", None) + return bool(group_members) + + +def _has_feature(player: Player, feature_name: str) -> bool: + """Check if player supports a given PlayerFeature by name.""" + features = getattr(player, "supported_features", None) + if not features: + return False + return any( + str(f) == feature_name or getattr(f, "value", None) == feature_name for f in features + ) + + +def _supports_select_source(player: Player) -> bool: + """Check if player natively supports source selection.""" + return _has_feature(player, "select_source") + + +def _get_source_list(player: Player) -> list[PlayerSource]: + """Get the source list from a player, or empty list if not available.""" + if not _supports_select_source(player): + return [] + source_list = getattr(player, "source_list", None) + if source_list: + return list(source_list) + return [] + + +def _build_source_modes(source_list: list[PlayerSource]) -> list[ModeValue]: + """Build Yandex mode values from an MA source list (max 10).""" + return [ModeValue(value=YANDEX_MODE_VALUES[i]) for i in range(min(len(source_list), 10))] + + +def _source_to_mode(active_source: str | None, source_list: list[PlayerSource]) -> str | None: + """Map active MA source name/id to a Yandex mode value.""" + if not active_source or not source_list: + return None + for i, source in enumerate(source_list[:10]): + source_name = getattr(source, "name", str(source)) + source_id = getattr(source, "id", str(source)) + if active_source in (source_name, source_id): + return YANDEX_MODE_VALUES[i] + return None + + +def _mode_to_source(mode_value: str, source_list: list[PlayerSource]) -> str | None: + """Resolve a Yandex mode value to an MA source id.""" + try: + idx = list(YANDEX_MODE_VALUES).index(mode_value) + except ValueError: + return None + if idx >= len(source_list): + return None + return source_list[idx].id + + +# --------------------------------------------------------------------------- +# Device description & state +# --------------------------------------------------------------------------- + + +def _volume_range_params() -> CapabilityParameters: + """Build range parameters for volume capability.""" + return CapabilityParameters( + instance=INSTANCE_VOLUME, + range=RangeParameters(min=0, max=100, precision=1), + unit=UNIT_PERCENT, + ) + + +# Yandex Smart Home allows only Russian/English letters, digits, and spaces. +_RE_DISALLOWED = re.compile(r"[^a-zA-Zа-яА-ЯёЁ0-9 ]") # noqa: RUF001 +_RE_LETTER_DIGIT = re.compile(r"([a-zA-Zа-яА-ЯёЁ])(\d)") # noqa: RUF001 +_RE_DIGIT_LETTER = re.compile(r"(\d)([a-zA-Zа-яА-ЯёЁ])") # noqa: RUF001 +_RE_MULTI_SPACE = re.compile(r" {2,}") + + +def normalize_device_name(name: str) -> str: + """Normalize player name for Yandex Smart Home. + + Rules: only Russian/English letters, digits, and spaces; + mandatory space between letters and digits. + """ + result = _RE_DISALLOWED.sub(" ", name) + result = _RE_LETTER_DIGIT.sub(r"\1 \2", result) + result = _RE_DIGIT_LETTER.sub(r"\1 \2", result) + result = _RE_MULTI_SPACE.sub(" ", result).strip() + return result or name + + +def get_device_description(player: Player) -> DeviceDescription: + """Build a Yandex Smart Home device description from an MA player.""" + capabilities = [ + CapabilityDescription(type=YandexCapabilityType.ON_OFF), + CapabilityDescription( + type=YandexCapabilityType.RANGE, + parameters=_volume_range_params(), + ), + CapabilityDescription( + type=YandexCapabilityType.TOGGLE, + parameters=CapabilityParameters(instance=INSTANCE_PAUSE), + ), + CapabilityDescription( + type=YandexCapabilityType.RANGE, + parameters=CapabilityParameters( + instance=INSTANCE_CHANNEL, + range=RangeParameters(min=0, max=999, precision=1), + random_access=False, + ), + ), + ] + + # toggle(mute) — if player supports VOLUME_MUTE or is a group + if _has_feature(player, "volume_mute") or _is_group_player(player): + capabilities.append( + CapabilityDescription( + type=YandexCapabilityType.TOGGLE, + parameters=CapabilityParameters(instance=INSTANCE_MUTE), + ) + ) + + # mode(input_source) — only if player has sources + source_list = _get_source_list(player) + if source_list: + modes = _build_source_modes(source_list) + if modes: + capabilities.append( + CapabilityDescription( + type=YandexCapabilityType.MODE, + parameters=CapabilityParameters( + instance=INSTANCE_INPUT_SOURCE, + modes=modes, + ), + ) + ) + + model = "MA Player" + if hasattr(player, "device_info") and player.device_info: + model = getattr(player.device_info, "model", model) or model + + return DeviceDescription( + id=player.player_id, + name=normalize_device_name(player.name), + type=YANDEX_DEVICE_TYPE_MEDIA, + capabilities=capabilities, + device_info=YandexDeviceInfo(model=model), + ) + + +def get_device_state(player: Player) -> DeviceState: + """Read current MA player state and convert to Yandex capability states.""" + # on = player is powered on (or available if power state unknown) + powered = getattr(player, "powered", None) + is_on = powered if powered is not None else getattr(player, "available", True) + is_paused = player.playback_state != PlaybackState.PLAYING + is_group = _is_group_player(player) + + # For groups use group_volume/group_volume_muted which aggregate children + if is_group: + volume = getattr(player, "group_volume", None) + if volume is None: + volume = player.volume_level if player.volume_level is not None else 0 + else: + volume = player.volume_level if player.volume_level is not None else 0 + + capabilities = [ + CapabilityState( + type=YandexCapabilityType.ON_OFF, + state=CapabilityInstanceState(instance=INSTANCE_ON, value=is_on), + ), + CapabilityState( + type=YandexCapabilityType.RANGE, + state=CapabilityInstanceState(instance=INSTANCE_VOLUME, value=volume), + ), + CapabilityState( + type=YandexCapabilityType.TOGGLE, + state=CapabilityInstanceState(instance=INSTANCE_PAUSE, value=is_paused), + ), + CapabilityState( + type=YandexCapabilityType.RANGE, + state=CapabilityInstanceState(instance=INSTANCE_CHANNEL, value=0), + ), + ] + + # mute state — if player supports VOLUME_MUTE or is a group + if _has_feature(player, "volume_mute") or is_group: + if is_group: + muted = getattr(player, "group_volume_muted", None) + if muted is None: + muted = player.volume_muted if player.volume_muted is not None else False + else: + muted = player.volume_muted if player.volume_muted is not None else False + capabilities.append( + CapabilityState( + type=YandexCapabilityType.TOGGLE, + state=CapabilityInstanceState(instance=INSTANCE_MUTE, value=muted), + ) + ) + + # input_source state — only if player has sources + source_list = _get_source_list(player) + if source_list: + active = getattr(player, "active_source", None) + mode_value = _source_to_mode(active, source_list) + if mode_value: + capabilities.append( + CapabilityState( + type=YandexCapabilityType.MODE, + state=CapabilityInstanceState(instance=INSTANCE_INPUT_SOURCE, value=mode_value), + ) + ) + + return DeviceState(id=player.player_id, capabilities=capabilities) + + +async def _execute_input_source( + mass: Any, player_id: str, player: Player | None, instance: str, value: Any +) -> CapabilityActionResult | None: + """Handle input_source mode action. Returns error result or None on success.""" + if player is None: + return CapabilityActionResult( + type=YandexCapabilityType.MODE, + state=CapabilityActionResultState( + instance=instance, + action_result=ActionResult( + status="ERROR", + error_code=ERROR_DEVICE_UNREACHABLE, + error_message=f"Player {player_id} not found", + ), + ), + ) + p_state = player.state if hasattr(player, "state") else player + source_list = _get_source_list(p_state) + source = _mode_to_source(str(value), source_list) + if source: + await mass.players.select_source(player_id, source) + return None + return CapabilityActionResult( + type=YandexCapabilityType.MODE, + state=CapabilityActionResultState( + instance=instance, + action_result=ActionResult( + status="ERROR", + error_code=ERROR_INVALID_ACTION, + error_message=f"Unknown source mode: {value}", + ), + ), + ) + + +def _invalid_bool_result(cap_type: str, instance: str, value: Any) -> CapabilityActionResult: + """Build an INVALID_ACTION result for a capability that requires a boolean value.""" + return CapabilityActionResult( + type=cap_type, + state=CapabilityActionResultState( + instance=instance, + action_result=ActionResult( + status="ERROR", + error_code=ERROR_INVALID_ACTION, + error_message=( + f"Expected boolean value for {cap_type}/{instance}, got {type(value).__name__}" + ), + ), + ), + ) + + +def _invalid_numeric_result(cap_type: str, instance: str, value: Any) -> CapabilityActionResult: + """Build an INVALID_ACTION result for a capability that requires a numeric value.""" + return CapabilityActionResult( + type=cap_type, + state=CapabilityActionResultState( + instance=instance, + action_result=ActionResult( + status="ERROR", + error_code=ERROR_INVALID_ACTION, + error_message=( + f"Expected numeric value for {cap_type}/{instance}, got {type(value).__name__}" + ), + ), + ), + ) + + +async def execute_capability_action( # noqa: PLR0915 + mass: Any, + player_id: str, + action: CapabilityAction, + current_volume: int = 0, +) -> CapabilityActionResult: + """Execute a Yandex capability action by calling the corresponding MA player command. + + Returns a CapabilityActionResult with success or error status. + """ + instance = action.state.instance + value = action.state.value + player = mass.players.get_player(player_id) + + if player is None: + return CapabilityActionResult( + type=action.type, + state=CapabilityActionResultState( + instance=instance, + action_result=ActionResult( + status="ERROR", + error_code=ERROR_DEVICE_UNREACHABLE, + error_message=f"Player {player_id} not found", + ), + ), + ) + + is_group = _is_group_player(player) + + # Bool-typed capabilities: reject non-bool payloads up-front so truthiness + # on strings like "false"/"0" can't trigger the wrong command. + requires_bool = action.type == YandexCapabilityType.ON_OFF or ( + action.type == YandexCapabilityType.TOGGLE and instance in (INSTANCE_MUTE, INSTANCE_PAUSE) + ) + if requires_bool and not isinstance(value, bool): + return _invalid_bool_result(action.type, instance, value) + + # Numeric RANGE capabilities: reject bool explicitly (bool is a subclass of + # int in Python, so float(True)/int(False) would silently change volume or + # skip tracks) and any other non-numeric type. + if action.type == YandexCapabilityType.RANGE and ( + isinstance(value, bool) or not isinstance(value, (int, float)) + ): + return _invalid_numeric_result(action.type, instance, value) + + try: + if action.type == YandexCapabilityType.ON_OFF: + if value: + # Power on if supported, then play + if _has_feature(player, "power"): + await mass.players.cmd_power(player_id, True) + await mass.players.cmd_play(player_id) + else: + await mass.players.cmd_stop(player_id) + if _has_feature(player, "power"): + await mass.players.cmd_power(player_id, False) + + elif action.type == YandexCapabilityType.RANGE and instance == INSTANCE_VOLUME: + if action.state.relative: + target = max(0, min(100, current_volume + int(float(value)))) + else: + target = max(0, min(100, int(float(value)))) + if is_group: + await mass.players.cmd_group_volume(player_id, target) + else: + await mass.players.cmd_volume_set(player_id, target) + value = target + + elif action.type == YandexCapabilityType.TOGGLE and instance == INSTANCE_MUTE: + if is_group: + await mass.players.cmd_group_volume_mute(player_id, value) + else: + await mass.players.cmd_volume_mute(player_id, value) + + elif action.type == YandexCapabilityType.TOGGLE and instance == INSTANCE_PAUSE: + if value: + await mass.players.cmd_pause(player_id) + else: + await mass.players.cmd_play(player_id) + + elif action.type == YandexCapabilityType.RANGE and instance == INSTANCE_CHANNEL: + if action.state.relative: + if int(float(value)) > 0: + await mass.players.cmd_next_track(player_id) + elif int(float(value)) < 0: + await mass.players.cmd_previous_track(player_id) + # Non-relative channel set is ignored (no concept of channel number in MA) + + elif action.type == YandexCapabilityType.MODE and instance == INSTANCE_INPUT_SOURCE: + result = await _execute_input_source(mass, player_id, player, instance, value) + if result: + return result + + else: + return CapabilityActionResult( + type=action.type, + state=CapabilityActionResultState( + instance=instance, + action_result=ActionResult( + status="ERROR", + error_code=ERROR_INVALID_ACTION, + error_message=f"Unknown capability: {action.type}/{instance}", + ), + ), + ) + + except (ValueError, TypeError): + return CapabilityActionResult( + type=action.type, + state=CapabilityActionResultState( + instance=instance, + action_result=ActionResult( + status="ERROR", + error_code=ERROR_INVALID_ACTION, + error_message=f"Invalid value for {action.type}/{instance}: {value}", + ), + ), + ) + except Exception: + _LOGGER.exception("Error executing action %s/%s on %s", action.type, instance, player_id) + return CapabilityActionResult( + type=action.type, + state=CapabilityActionResultState( + instance=instance, + action_result=ActionResult( + status="ERROR", + error_code=ERROR_INTERNAL_ERROR, + ), + ), + ) + + return CapabilityActionResult( + type=action.type, + state=CapabilityActionResultState( + instance=instance, + value=value, + action_result=ActionResult(status="DONE"), + ), + ) + + +def is_player_exposable(player: Player, exposed_ids: set[str] | None = None) -> bool: + """Determine whether an MA player should be exposed to Yandex Smart Home.""" + if not player.available: + return False + if not player.enabled: + return False + # Don't expose players that are synced to another player (they are controlled via leader) + if player.synced_to: + return False + # If a filter is set, only expose selected players + return not (exposed_ids and player.player_id not in exposed_ids) + + +def make_error_device_state(device_id: str) -> DeviceState: + """Create an error DeviceState for an unreachable device.""" + return DeviceState( + id=device_id, + error_code=ERROR_DEVICE_UNREACHABLE, + error_message="Device is not available", + ) + + +def make_error_action_result( + _device_id: str, actions: list[CapabilityAction] +) -> list[CapabilityActionResult]: + """Create error action results for all capabilities of an unreachable device.""" + return [ + CapabilityActionResult( + type=a.type, + state=CapabilityActionResultState( + instance=a.state.instance, + action_result=ActionResult( + status="ERROR", + error_code=ERROR_DEVICE_UNREACHABLE, + ), + ), + ) + for a in actions + ] diff --git a/music_assistant/providers/yandex_smarthome/direct.py b/music_assistant/providers/yandex_smarthome/direct.py new file mode 100644 index 0000000000..1d0025ec3b --- /dev/null +++ b/music_assistant/providers/yandex_smarthome/direct.py @@ -0,0 +1,402 @@ +"""HTTP handlers for Yandex Smart Home direct connection. + +Registers dynamic routes on the MA webserver to handle Yandex Smart Home API +requests directly (without the yaha-cloud.ru WebSocket relay). + +Routes: + Health: + HEAD/GET /api/yandex_smarthome/v1.0 — health check + GET /api/yandex_smarthome/v1.0/ping — health check + + API (Bearer auth required): + POST /api/yandex_smarthome/v1.0/user/devices — list devices + POST /api/yandex_smarthome/v1.0/user/devices/query — query states + POST /api/yandex_smarthome/v1.0/user/devices/action — execute actions + POST /api/yandex_smarthome/v1.0/user/unlink — user unlink + + OAuth (account linking): + GET /api/yandex_smarthome/auth/authorize — authorization page + POST /api/yandex_smarthome/auth/token — token exchange +""" + +from __future__ import annotations + +import html as html_module +import logging +import secrets +import time +import urllib.parse +import uuid +from collections.abc import Callable +from dataclasses import asdict +from typing import TYPE_CHECKING, Any + +from aiohttp import web + +from .constants import ( + DIRECT_API_BASE_PATH, + DIRECT_AUTH_BASE_PATH, + DIRECT_HEALTH_RESPONSE, + DIRECT_OAUTH_CLIENT_ID, + MAX_PENDING_CODES, + OAUTH_CODE_EXPIRY, +) +from .handlers import ( + build_response, + handle_device_list, + handle_devices_action, + handle_devices_query, + handle_user_unlink, + parse_action_payload, +) + +if TYPE_CHECKING: + from music_assistant.mass import MusicAssistant + +_LOGGER = logging.getLogger(__name__) + +# Minimal HTML for the OAuth authorize page +_AUTHORIZE_HTML = """ + + + + + Music Assistant — Yandex Smart Home + + + +
+

🎵 Music Assistant

+

Привязать аккаунт к Яндекс Умному Дому?

+ Привязать +
+ +""" + + +class DirectConnectionHandler: + """Handles Yandex Smart Home HTTP requests for direct connection mode.""" + + def __init__( + self, + mass: MusicAssistant, + user_id: str, + access_token: str, + client_secret: str, + exposed_ids: set[str] | None = None, + logger: logging.Logger | None = None, + on_token_created: Callable[[str], None] | None = None, + ) -> None: + """Initialize the handler. + + Args: + mass: MusicAssistant instance. + user_id: User identifier for Yandex API responses. + access_token: Current Bearer access token (may be empty on first run). + client_secret: OAuth client secret for account linking validation. + exposed_ids: Set of player IDs to expose, or None for all. + logger: Optional logger instance. + on_token_created: Callback invoked with new access token when generated + via OAuth flow (to persist in config). + """ + self._mass = mass + self._user_id = user_id + self._access_token = access_token + self._client_secret = client_secret + self._exposed_ids = exposed_ids + self._logger = logger or _LOGGER + self._on_token_created = on_token_created + self._unregister_callbacks: list[Callable[[], None]] = [] + # Pending OAuth authorization codes: {code: expiry_timestamp} + self._pending_codes: dict[str, float] = {} + + @property + def access_token(self) -> str: + """Return the current access token.""" + return self._access_token + + def register_routes(self) -> None: + """Register all HTTP routes on the MA webserver.""" + base = DIRECT_API_BASE_PATH + auth_base = DIRECT_AUTH_BASE_PATH + register = self._mass.webserver.register_dynamic_route + + # Health check endpoints (no auth) + routes: list[tuple[str, str, Any]] = [ + (f"{base}", "HEAD", self._handle_health), + (f"{base}", "GET", self._handle_health), + (f"{base}/ping", "GET", self._handle_health), + # API endpoints (auth required) + (f"{base}/user/devices", "POST", self._handle_devices), + (f"{base}/user/devices/query", "POST", self._handle_query), + (f"{base}/user/devices/action", "POST", self._handle_action), + (f"{base}/user/unlink", "POST", self._handle_unlink), + # Also support GET for devices (Yandex may use either) + (f"{base}/user/devices", "GET", self._handle_devices), + # OAuth account linking endpoints (no auth) + (f"{auth_base}/authorize", "GET", self._handle_oauth_authorize), + (f"{auth_base}/token", "POST", self._handle_oauth_token), + ] + + for path, method, handler in routes: + try: + unregister = register(path, handler, method) + except RuntimeError: + self._logger.error("Failed to register route %s %s; rolling back", method, path) + self.unregister_routes() + raise + self._unregister_callbacks.append(unregister) + + self._logger.info( + "Direct connection: registered %d routes on MA webserver", + len(self._unregister_callbacks), + ) + + def unregister_routes(self) -> None: + """Unregister all HTTP routes from the MA webserver.""" + for cb in self._unregister_callbacks: + try: + cb() + except Exception: + self._logger.debug("Error unregistering route", exc_info=True) + count = len(self._unregister_callbacks) + self._unregister_callbacks.clear() + self._logger.info("Direct connection: unregistered %d routes", count) + + # ------------------------------------------------------------------- + # Auth helpers + # ------------------------------------------------------------------- + + def _validate_auth(self, request: web.Request) -> bool: + """Validate Bearer token from Authorization header.""" + if not self._access_token: + return False + auth = request.headers.get("Authorization", "") + if auth.startswith("Bearer "): + return secrets.compare_digest(auth[7:], self._access_token) + return False + + def _unauthorized_response(self, request_id: str = "") -> web.Response: + """Return a 401 Unauthorized response in the Smart Home API envelope.""" + return web.json_response( + build_response(request_id, {"error": "unauthorized"}), + status=401, + headers={"WWW-Authenticate": "Bearer"}, + ) + + # ------------------------------------------------------------------- + # Health check + # ------------------------------------------------------------------- + + async def _handle_health(self, request: web.Request) -> web.Response: + """Handle health check (HEAD/GET /v1.0 and /v1.0/ping).""" + if request.method == "HEAD": + return web.Response(status=200) + return web.Response(text=DIRECT_HEALTH_RESPONSE, status=200) + + # ------------------------------------------------------------------- + # Smart Home API endpoints + # ------------------------------------------------------------------- + + async def _handle_devices(self, request: web.Request) -> web.Response: + """Handle /user/devices — list all exposed devices.""" + request_id = request.headers.get("X-Request-Id", "") + if not self._validate_auth(request): + return self._unauthorized_response(request_id) + + try: + device_list = await handle_device_list( + self._mass, self._user_id, exposed_ids=self._exposed_ids + ) + return web.json_response(build_response(request_id, asdict(device_list))) + except Exception: + self._logger.exception("Error handling /user/devices") + return web.json_response(build_response(request_id, {}), status=500) + + async def _handle_query(self, request: web.Request) -> web.Response: + """Handle /user/devices/query — return device states.""" + request_id = request.headers.get("X-Request-Id", "") + if not self._validate_auth(request): + return self._unauthorized_response(request_id) + + try: + body = await request.json() + except Exception: + return web.json_response(build_response(request_id, {}), status=400) + + try: + if not isinstance(body, dict): + body = {} + devices_raw = body.get("devices", []) + if not isinstance(devices_raw, list): + devices_raw = [] + device_ids = [ + device_id for d in devices_raw if isinstance(d, dict) and (device_id := d.get("id")) + ] + states = await handle_devices_query( + self._mass, device_ids, exposed_ids=self._exposed_ids + ) + return web.json_response(build_response(request_id, asdict(states))) + except Exception: + self._logger.exception("Error handling /user/devices/query") + return web.json_response(build_response(request_id, {}), status=500) + + async def _handle_action(self, request: web.Request) -> web.Response: + """Handle /user/devices/action — execute capability actions.""" + request_id = request.headers.get("X-Request-Id", "") + if not self._validate_auth(request): + return self._unauthorized_response(request_id) + + try: + body = await request.json() + except Exception: + return web.json_response(build_response(request_id, {}), status=400) + + try: + action_payload = parse_action_payload(body) + result = await handle_devices_action( + self._mass, action_payload, exposed_ids=self._exposed_ids + ) + return web.json_response(build_response(request_id, asdict(result))) + except Exception: + self._logger.exception("Error handling /user/devices/action") + return web.json_response(build_response(request_id, {}), status=500) + + async def _handle_unlink(self, request: web.Request) -> web.Response: + """Handle /user/unlink — user disconnected account.""" + request_id = request.headers.get("X-Request-Id", "") + if not self._validate_auth(request): + return self._unauthorized_response(request_id) + + try: + result = await handle_user_unlink() + return web.json_response(build_response(request_id, result)) + except Exception: + self._logger.exception("Error handling /user/unlink") + return web.json_response(build_response(request_id, {}), status=500) + + # ------------------------------------------------------------------- + # OAuth account linking + # ------------------------------------------------------------------- + + def _cleanup_expired_codes(self) -> None: + """Remove expired authorization codes.""" + now = time.time() + expired = [code for code, exp in self._pending_codes.items() if now > exp] + for code in expired: + del self._pending_codes[code] + + async def _handle_oauth_authorize(self, request: web.Request) -> web.Response: + """Handle GET /auth/authorize — show authorization page. + + Yandex opens this URL in the user's browser during account linking. + Parameters: client_id, redirect_uri, state, response_type=code + """ + client_id = request.query.get("client_id", "") + response_type = request.query.get("response_type", "") + redirect_uri = request.query.get("redirect_uri", "") + state = request.query.get("state", "") + + # Validate required OAuth parameters + if client_id != DIRECT_OAUTH_CLIENT_ID: + return web.Response(text="Invalid client_id", status=400) + if response_type != "code": + return web.Response(text="Invalid response_type", status=400) + if not redirect_uri: + return web.Response(text="Missing redirect_uri", status=400) + + # Validate redirect_uri is the expected Yandex HTTPS endpoint (prevent open redirect) + parsed = urllib.parse.urlparse(redirect_uri) + if parsed.scheme != "https" or parsed.hostname != "social.yandex.net": + return web.Response(text="Invalid redirect_uri", status=400) + + # Generate authorization code (with DoS cap on pending codes) + self._cleanup_expired_codes() + if len(self._pending_codes) >= MAX_PENDING_CODES: + return web.Response(text="Too many pending authorization requests", status=429) + code = uuid.uuid4().hex + self._pending_codes[code] = time.time() + OAUTH_CODE_EXPIRY + + # Build redirect URL with code and state + params = {"code": code} + if state: + params["state"] = state + separator = "&" if "?" in redirect_uri else "?" + redirect_url = f"{redirect_uri}{separator}{urllib.parse.urlencode(params)}" + + html = _AUTHORIZE_HTML.format(redirect_url=html_module.escape(redirect_url, quote=True)) + return web.Response(text=html, content_type="text/html", status=200) + + async def _handle_oauth_token(self, request: web.Request) -> web.Response: + """Handle POST /auth/token — exchange code for access token. + + Supports: + - grant_type=authorization_code: exchange code for new token + - grant_type=refresh_token: return same access token + """ + try: + data = await request.post() + except Exception: + return web.json_response({"error": "invalid_request"}, status=400) + + # Validate client credentials + client_id = str(data.get("client_id", "")) + client_secret = str(data.get("client_secret", "")) + if client_id != DIRECT_OAUTH_CLIENT_ID or not secrets.compare_digest( + client_secret, self._client_secret + ): + return web.json_response({"error": "invalid_client"}, status=401) + + grant_type = str(data.get("grant_type", "")) + + if grant_type == "authorization_code": + code = str(data.get("code", "")) + self._cleanup_expired_codes() + + if not code or code not in self._pending_codes: + return web.json_response({"error": "invalid_grant"}, status=400) + + # Consume the code + del self._pending_codes[code] + + # Generate or reuse access token + if not self._access_token: + self._access_token = uuid.uuid4().hex + if self._on_token_created: + self._on_token_created(self._access_token) + self._logger.info("Generated new access token for direct connection") + + return web.json_response( + { + "access_token": self._access_token, + "token_type": "bearer", + "refresh_token": self._access_token, + } + ) + + if grant_type == "refresh_token": + refresh_token = str(data.get("refresh_token", "")) + if refresh_token and secrets.compare_digest(refresh_token, self._access_token): + return web.json_response( + { + "access_token": self._access_token, + "token_type": "bearer", + "refresh_token": self._access_token, + } + ) + return web.json_response({"error": "invalid_grant"}, status=400) + + return web.json_response({"error": "unsupported_grant_type"}, status=400) diff --git a/music_assistant/providers/yandex_smarthome/handlers.py b/music_assistant/providers/yandex_smarthome/handlers.py new file mode 100644 index 0000000000..2607e188f2 --- /dev/null +++ b/music_assistant/providers/yandex_smarthome/handlers.py @@ -0,0 +1,210 @@ +"""Yandex Smart Home API request handlers. + +Pure async functions handling the 4 Smart Home API actions: +- /user/devices — list all exposed MA players +- /user/devices/query — return states of requested devices +- /user/devices/action — execute capability actions +- /user/unlink — user disconnected their account +""" + +from __future__ import annotations + +import logging +from dataclasses import asdict +from typing import TYPE_CHECKING, Any + +from .device import ( + execute_capability_action, + get_device_description, + get_device_state, + is_player_exposable, + make_error_action_result, + make_error_device_state, +) +from .schema import ( + ActionRequestPayload, + ActionResultPayload, + CapabilityAction, + CapabilityActionState, + DeviceAction, + DeviceActionResult, + DeviceListPayload, + DeviceState, + DeviceStatesPayload, +) + +if TYPE_CHECKING: + from music_assistant.mass import MusicAssistant + +_LOGGER = logging.getLogger(__name__) + + +async def handle_device_list( + mass: MusicAssistant, + user_id: str, + exposed_ids: set[str] | None = None, +) -> DeviceListPayload: + """Handle /user/devices — return list of all MA players as Yandex devices.""" + devices = [] + for player in mass.players.all_players(): + state = player.state + if not is_player_exposable(state, exposed_ids=exposed_ids): + continue + devices.append(get_device_description(state)) + _LOGGER.debug("Device list: %d devices exposed", len(devices)) + return DeviceListPayload(user_id=user_id, devices=devices) + + +async def handle_devices_query( + mass: MusicAssistant, + device_ids: list[str], + exposed_ids: set[str] | None = None, +) -> DeviceStatesPayload: + """Handle /user/devices/query — return current states for requested devices.""" + states: list[DeviceState] = [] + for device_id in device_ids: + try: + player = mass.players.get_player(device_id) + except Exception: + player = None + + if player is None: + states.append(make_error_device_state(device_id)) + continue + + player_state = player.state if hasattr(player, "state") else player + if not is_player_exposable(player_state, exposed_ids=exposed_ids): # type: ignore[arg-type] + states.append(make_error_device_state(device_id)) + continue + + states.append(get_device_state(player_state)) # type: ignore[arg-type] + + return DeviceStatesPayload(devices=states) + + +async def handle_devices_action( + mass: MusicAssistant, + payload: ActionRequestPayload, + exposed_ids: set[str] | None = None, +) -> ActionResultPayload: + """Handle /user/devices/action — execute capability actions on devices.""" + results: list[DeviceActionResult] = [] + + for device_action in payload.devices: + try: + player = mass.players.get_player(device_action.id) + except Exception: + player = None + + if player is None: + results.append( + DeviceActionResult( + id=device_action.id, + capabilities=make_error_action_result( + device_action.id, device_action.capabilities + ), + ) + ) + continue + + player_state = player.state if hasattr(player, "state") else player + if not is_player_exposable(player_state, exposed_ids=exposed_ids): # type: ignore[arg-type] + results.append( + DeviceActionResult( + id=device_action.id, + capabilities=make_error_action_result( + device_action.id, device_action.capabilities + ), + ) + ) + continue + + current_volume = player_state.volume_level or 0 + if getattr(player_state, "group_members", None): + current_volume = getattr(player_state, "group_volume", None) or current_volume + cap_results = [] + for cap_action in device_action.capabilities: + result = await execute_capability_action( + mass, device_action.id, cap_action, current_volume + ) + cap_results.append(result) + + results.append(DeviceActionResult(id=device_action.id, capabilities=cap_results)) + + return ActionResultPayload(devices=results) + + +async def handle_user_unlink() -> dict[str, Any]: + """Handle /user/unlink — user disconnected their Yandex account.""" + _LOGGER.info("User unlinked Yandex Smart Home account") + return {} + + +def parse_action_payload(raw: dict[str, Any]) -> ActionRequestPayload: + """Parse a raw /user/devices/action message into ActionRequestPayload. + + Defensively handles malformed input: non-list devices, non-dict entries, + missing/non-dict state objects are all silently skipped. + """ + devices = [] + payload_obj = raw.get("payload", raw) + if not isinstance(payload_obj, dict): + payload_obj = {} + devices_raw = payload_obj.get("devices", []) + if not isinstance(devices_raw, list): + devices_raw = [] + for dev_raw in devices_raw: + if not isinstance(dev_raw, dict): + continue + dev_id = dev_raw.get("id") + if not dev_id: + continue + capabilities = [] + caps_raw = dev_raw.get("capabilities", []) + if not isinstance(caps_raw, list): + caps_raw = [] + for cap_raw in caps_raw: + if not isinstance(cap_raw, dict): + continue + cap_type = cap_raw.get("type") + if not cap_type: + continue + state_raw = cap_raw.get("state") + if not isinstance(state_raw, dict): + continue + instance_raw = state_raw.get("instance") + if not isinstance(instance_raw, str) or not instance_raw.strip(): + continue + relative_raw = state_raw.get("relative", False) + if not isinstance(relative_raw, bool): + continue + capabilities.append( + CapabilityAction( + type=cap_type, + state=CapabilityActionState( + instance=instance_raw.strip(), + value=state_raw.get("value"), + relative=relative_raw, + ), + ) + ) + devices.append(DeviceAction(id=dev_id, capabilities=capabilities)) + return ActionRequestPayload(devices=devices) + + +def _strip_none(obj: Any) -> Any: + """Recursively remove None values from dicts (Yandex API rejects null fields).""" + if isinstance(obj, dict): + return {k: _strip_none(v) for k, v in obj.items() if v is not None} + if isinstance(obj, list): + return [_strip_none(item) for item in obj] + return obj + + +def build_response(request_id: str, payload: Any) -> dict[str, Any]: + """Wrap a handler result in the Yandex Smart Home API response envelope.""" + if payload is None: + return {"request_id": request_id, "payload": {}} + if isinstance(payload, dict): + return {"request_id": request_id, "payload": _strip_none(payload)} + return {"request_id": request_id, "payload": _strip_none(asdict(payload))} diff --git a/music_assistant/providers/yandex_smarthome/icon.svg b/music_assistant/providers/yandex_smarthome/icon.svg new file mode 100644 index 0000000000..0a6f433c3c --- /dev/null +++ b/music_assistant/providers/yandex_smarthome/icon.svg @@ -0,0 +1,3 @@ + + + diff --git a/music_assistant/providers/yandex_smarthome/manifest.json b/music_assistant/providers/yandex_smarthome/manifest.json new file mode 100644 index 0000000000..c821b99756 --- /dev/null +++ b/music_assistant/providers/yandex_smarthome/manifest.json @@ -0,0 +1,15 @@ +{ + "type": "plugin", + "domain": "yandex_smarthome", + "name": "Yandex Smart Home", + "description": "Expose Music Assistant players to Yandex Alice via Yandex Smart Home API.", + "codeowners": ["@trudenboy"], + "credits": [ + "[dext0r/yandex_smart_home](https://github.com/dext0r/yandex_smart_home)" + ], + "requirements": ["ya-passport-auth==1.2.3"], + "documentation": "https://github.com/trudenboy/ma-provider-yandex-smarthome", + "stage": "beta", + "multi_instance": false, + "builtin": false +} diff --git a/music_assistant/providers/yandex_smarthome/notifier.py b/music_assistant/providers/yandex_smarthome/notifier.py new file mode 100644 index 0000000000..3a2a3aad6e --- /dev/null +++ b/music_assistant/providers/yandex_smarthome/notifier.py @@ -0,0 +1,258 @@ +"""State notifier — reports MA player state changes to Yandex Smart Home. + +Watches MA player events and pushes state updates to Yandex via the +callback/state API endpoint (cloud or direct). Uses a 1-second debounce +window to batch rapid state changes into a single callback. +""" + +from __future__ import annotations + +import asyncio +import contextlib +import logging +import time +from collections.abc import Callable +from dataclasses import asdict +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + import aiohttp + +from music_assistant_models.enums import EventType + +from .constants import ( + STATE_HEARTBEAT_INTERVAL, + STATE_INITIAL_REPORT_DELAY, + STATE_REPORT_DELAY, +) +from .device import get_device_state, is_player_exposable +from .handlers import _strip_none +from .schema import CallbackPayload, CallbackRequest, DeviceState + +if TYPE_CHECKING: + from music_assistant_models.event import MassEvent + + from music_assistant.mass import MusicAssistant + +_LOGGER = logging.getLogger(__name__) + + +class StateNotifier: + """Watches MA player events and reports state changes to Yandex.""" + + def __init__( + self, + mass: MusicAssistant, + session: aiohttp.ClientSession, + user_id: str, + callback_url: str, + auth_header: dict[str, str], + logger: logging.Logger | None = None, + exposed_ids: set[str] | None = None, + ) -> None: + """Initialize state notifier.""" + self._mass = mass + self._session = session + self._user_id = user_id + self._callback_url = callback_url + self._auth_header = auth_header + self._logger = logger or _LOGGER + self._exposed_ids = exposed_ids + + self._dirty_player_ids: set[str] = set() + self._flush_handle: asyncio.TimerHandle | None = None + self._initial_report_handle: asyncio.TimerHandle | None = None + self._heartbeat_task: asyncio.Task[None] | None = None + self._unsub: Callable[[], None] | None = None + + async def start(self) -> None: + """Subscribe to player events and start background tasks.""" + self._unsub = self._mass.subscribe( + self._on_player_event, + event_filter=( + EventType.PLAYER_UPDATED, + EventType.PLAYER_ADDED, + EventType.PLAYER_REMOVED, + ), + ) + + # Schedule initial full state report after startup delay + self._initial_report_handle = self._mass.loop.call_later( + STATE_INITIAL_REPORT_DELAY, + lambda: self._mass.create_task(self._report_all_states()), + ) + + # Periodic heartbeat + self._heartbeat_task = self._mass.create_task( + self._heartbeat_loop(), task_id="yandex_smarthome_heartbeat" + ) + + self._logger.info("State notifier started (callback=%s)", self._callback_url) + + async def stop(self) -> None: + """Unsubscribe from events and cancel background tasks.""" + if self._unsub: + self._unsub() + self._unsub = None + if self._initial_report_handle: + self._initial_report_handle.cancel() + self._initial_report_handle = None + if self._flush_handle: + self._flush_handle.cancel() + self._flush_handle = None + if self._heartbeat_task is not None: + if not self._heartbeat_task.done(): + self._heartbeat_task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await self._heartbeat_task + self._heartbeat_task = None + self._dirty_player_ids.clear() + self._logger.info("State notifier stopped") + + # ----------------------------------------------------------------------- + # Event handling + # ----------------------------------------------------------------------- + + def _on_player_event(self, event: MassEvent) -> None: + """Handle player state change — mark player as dirty for batched reporting.""" + if event.event in (EventType.PLAYER_ADDED, EventType.PLAYER_REMOVED): + self._schedule_discovery() + return + + # PLAYER_UPDATED — event.data is the Player (state) dataclass + player_state = event.data + if player_state is None: + return + + # If this player is synced to a group, propagate the event to the group + # (child volume/mute changes should update the group's state in Yandex) + synced_to = getattr(player_state, "synced_to", None) + if synced_to: + self._dirty_player_ids.add(synced_to) + self._schedule_flush() + return + + if not is_player_exposable(player_state, exposed_ids=self._exposed_ids): + return + + self._dirty_player_ids.add(player_state.player_id) + self._schedule_flush() + + def _schedule_flush(self) -> None: + """Schedule a batched state flush after the debounce window.""" + if self._flush_handle is not None: + return # already scheduled + self._flush_handle = self._mass.loop.call_later( + STATE_REPORT_DELAY, + lambda: self._mass.create_task(self._flush_pending()), + ) + + async def _flush_pending(self) -> None: + """Send all pending state changes to Yandex. + + Reads the fresh player state at flush time (not at event time) + so transient states during track transitions are not reported. + """ + self._flush_handle = None + if not self._dirty_player_ids: + return + dirty = self._dirty_player_ids + self._dirty_player_ids = set() + + devices: list[DeviceState] = [] + for player_id in dirty: + player = self._mass.players.get_player(player_id) + if player is None: + continue + state = player.state + if is_player_exposable(state, exposed_ids=self._exposed_ids): + devices.append(get_device_state(state)) + + if not devices: + return + try: + await self._send_state_callback(devices) + except Exception: + # Re-queue failed player IDs + self._dirty_player_ids |= dirty + self._schedule_flush() + raise + + # ----------------------------------------------------------------------- + # State reporting + # ----------------------------------------------------------------------- + + async def _send_state_callback(self, devices: list[DeviceState]) -> None: + """POST state callback to Yandex.""" + payload = CallbackRequest( + ts=time.time(), + payload=CallbackPayload(user_id=self._user_id, devices=devices), + ) + try: + async with self._session.post( + self._callback_url, + json=_strip_none(asdict(payload)), + headers=self._auth_header, + ) as resp: + if resp.status not in (200, 202): + body = await resp.text() + raise RuntimeError( + f"State callback failed with HTTP {resp.status}: {body[:200]}" + ) + self._logger.debug("State callback sent: %d device(s)", len(devices)) + except Exception: + self._logger.exception("State callback error") + raise + + async def _report_all_states(self) -> None: + """Report states for all currently exposed players.""" + devices: list[DeviceState] = [] + for player in self._mass.players.all_players(): + state = player.state + if is_player_exposable(state, exposed_ids=self._exposed_ids): + devices.append(get_device_state(state)) + if devices: + self._logger.info("Reporting all states: %d device(s)", len(devices)) + await self._send_state_callback(devices) + + async def _heartbeat_loop(self) -> None: + """Periodically report all states as a heartbeat.""" + while True: + try: + await asyncio.sleep(STATE_HEARTBEAT_INTERVAL) + await self._report_all_states() + except asyncio.CancelledError: + raise + except Exception: + self._logger.exception("Heartbeat state report failed, will retry next interval") + + # ----------------------------------------------------------------------- + # Discovery notification + # ----------------------------------------------------------------------- + + def _schedule_discovery(self) -> None: + """Notify Yandex that the device list has changed.""" + self._mass.create_task(self._send_discovery()) + + async def _send_discovery(self) -> None: + """POST discovery notification to Yandex.""" + discovery_url = self._callback_url.removesuffix("/state") + "/discovery" + payload = { + "ts": time.time(), + "payload": {"user_id": self._user_id}, + } + try: + async with self._session.post( + discovery_url, + json=payload, + headers=self._auth_header, + ) as resp: + if resp.status not in (200, 202): + body = await resp.text() + self._logger.warning( + "Discovery callback failed (HTTP %d): %s", resp.status, body[:200] + ) + else: + self._logger.debug("Discovery notification sent") + except Exception: + self._logger.exception("Discovery callback error") diff --git a/music_assistant/providers/yandex_smarthome/plugin.py b/music_assistant/providers/yandex_smarthome/plugin.py new file mode 100644 index 0000000000..335e3813e0 --- /dev/null +++ b/music_assistant/providers/yandex_smarthome/plugin.py @@ -0,0 +1,314 @@ +""" +Yandex Smart Home Plugin Provider. + +Bridges Music Assistant players to the Yandex Smart Home ecosystem, +allowing Alice voice control of playback, volume, and transport. + +The plugin: +1. Listens for MA player events (added, removed, updated) +2. Exposes them as Yandex Smart Home media_device devices +3. Handles capability actions (on_off, volume, pause) from Alice +4. Reports state changes back to Yandex + +Connection modes: +- Cloud: WebSocket relay through yaha-cloud.ru (no public URL needed) +- Cloud Plus: Private skill via yaha-cloud.ru relay (custom Yandex.Dialogs skill) +- Direct: HTTP endpoints on MA webserver that Yandex calls directly (requires public URL) +""" + +from __future__ import annotations + +from dataclasses import asdict +from typing import Any + +from music_assistant.models.plugin import PluginProvider + +from ._compat import SecretStr +from .cloud import CloudManager +from .constants import ( + CLOUD_CALLBACK_URL, + CONF_CLOUD_CONNECTION_TOKEN, + CONF_CLOUD_INSTANCE_ID, + CONF_CLOUD_INSTANCE_PASSWORD, + CONF_CONNECTION_TYPE, + CONF_DIRECT_ACCESS_TOKEN, + CONF_DIRECT_CLIENT_SECRET, + CONF_EXPOSED_PLAYERS, + CONF_INSTANCE_NAME, + CONF_SKILL_ID, + CONF_SKILL_TOKEN, + CONNECTION_TYPE_CLOUD, + CONNECTION_TYPE_CLOUD_PLUS, + CONNECTION_TYPE_DIRECT, + YANDEX_DIALOGS_CALLBACK_BASE, +) +from .direct import DirectConnectionHandler +from .handlers import ( + build_response, + handle_device_list, + handle_devices_action, + handle_devices_query, + handle_user_unlink, + parse_action_payload, +) +from .notifier import StateNotifier +from .schema import CloudRequest + + +class YandexSmartHomePlugin(PluginProvider): + """Plugin provider that exposes MA players to Yandex Alice via Smart Home API. + + Follows the same pattern as the HASS plugin provider: subscribes to MA events, + maintains a mapping of MA players to Yandex Smart Home devices, and handles + capability actions from Alice by translating them to MA player commands. + """ + + _cloud_manager: CloudManager | None = None + _state_notifier: StateNotifier | None = None + _direct_handler: DirectConnectionHandler | None = None + _cloud_task: Any = None + _user_id: str = "" + + async def handle_async_init(self) -> None: + """Handle async initialization of the plugin.""" + self._connection_type = str( + self.config.get_value(CONF_CONNECTION_TYPE) or CONNECTION_TYPE_CLOUD + ) + self._instance_name = str(self.config.get_value(CONF_INSTANCE_NAME) or "Music Assistant") + cloud_token_raw = str(self.config.get_value(CONF_CLOUD_INSTANCE_PASSWORD) or "") + self._cloud_token: SecretStr | None = ( + SecretStr(cloud_token_raw) if cloud_token_raw else None + ) + conn_token_raw = str(self.config.get_value(CONF_CLOUD_CONNECTION_TOKEN) or "") + self._connection_token: SecretStr | None = ( + SecretStr(conn_token_raw) if conn_token_raw else None + ) + self._cloud_instance_id = str(self.config.get_value(CONF_CLOUD_INSTANCE_ID) or "") + self._skill_id = str(self.config.get_value(CONF_SKILL_ID) or "") + skill_token_raw = str(self.config.get_value(CONF_SKILL_TOKEN) or "") + self._skill_token: SecretStr | None = ( + SecretStr(skill_token_raw) if skill_token_raw else None + ) + self._direct_access_token = str(self.config.get_value(CONF_DIRECT_ACCESS_TOKEN) or "") + self._direct_client_secret = str(self.config.get_value(CONF_DIRECT_CLIENT_SECRET) or "") + + # Parse exposed players filter + exposed_raw = self.config.get_value(CONF_EXPOSED_PLAYERS) or [] + if isinstance(exposed_raw, str): + exposed_raw = [x.strip() for x in exposed_raw.split(",") if x.strip()] + elif isinstance(exposed_raw, list): + exposed_raw = [str(x) for x in exposed_raw if x] + else: + exposed_raw = [] + self._exposed_ids: set[str] | None = set(exposed_raw) if exposed_raw else None + + self.logger.info( + "Yandex Smart Home plugin init (mode=%s, name=%s)", + self._connection_type, + self._instance_name, + ) + + async def loaded_in_mass(self) -> None: + """Call after the provider has been loaded. + + Starts cloud WebSocket connection and state notifier. + """ + self.logger.info("Yandex Smart Home plugin loaded") + + if self._connection_type in (CONNECTION_TYPE_CLOUD, CONNECTION_TYPE_CLOUD_PLUS): + await self._start_cloud_mode() + elif self._connection_type == CONNECTION_TYPE_DIRECT: + await self._start_direct_mode() + else: + self.logger.error("Unknown connection type: %s", self._connection_type) + + async def _start_cloud_mode(self) -> None: + """Initialize and start cloud relay connection + state notifier.""" + if not self._connection_token or not self._connection_token.get_secret(): + self.logger.error( + "Cloud connection token not configured — " + "register an instance at yaha-cloud.ru and set the connection token" + ) + return + + # Validate Cloud Plus credentials before starting any tasks + if self._connection_type == CONNECTION_TYPE_CLOUD_PLUS: + if not self._skill_id or not self._skill_token or not self._skill_token.get_secret(): + self.logger.error("Cloud Plus mode requires skill_id and skill_token") + return + + # Validate cloud password (used for callback auth in basic cloud mode) + if self._connection_type == CONNECTION_TYPE_CLOUD and ( + not self._cloud_token or not self._cloud_token.get_secret() + ): + self.logger.error( + "Cloud instance password not configured — " + "set the password from yaha-cloud.ru instance settings" + ) + return + + # Determine user_id once — used in both API responses and state callbacks + self._user_id = self._cloud_instance_id or self._instance_name + + session = self.mass.http_session + + # Cloud WebSocket manager + self._cloud_manager = CloudManager( + session=session, + connection_token=self._connection_token, + on_request=self._handle_cloud_request, + logger=self.logger, + ) + self._cloud_task = self.mass.create_task( + self._cloud_manager.connect(), + task_id="yandex_smarthome_cloud", + ) + + # State notifier — different callback URL/auth for cloud_plus + if self._connection_type == CONNECTION_TYPE_CLOUD_PLUS: + assert self._skill_token is not None # validated above + callback_url = f"{YANDEX_DIALOGS_CALLBACK_BASE}/{self._skill_id}/callback/state" + auth_header = {"Authorization": f"OAuth {self._skill_token.get_secret()}"} + else: + assert self._cloud_token is not None # validated above + callback_url = f"{CLOUD_CALLBACK_URL}/state" + auth_header = {"Authorization": f"Bearer {self._cloud_token.get_secret()}"} + + self._state_notifier = StateNotifier( + mass=self.mass, + session=session, + user_id=self._user_id, + callback_url=callback_url, + auth_header=auth_header, + logger=self.logger, + exposed_ids=self._exposed_ids, + ) + await self._state_notifier.start() + + async def _start_direct_mode(self) -> None: + """Initialize direct connection mode — HTTP endpoints + state notifier.""" + if not self._skill_id or not self._skill_token or not self._skill_token.get_secret(): + self.logger.error( + "Direct mode requires skill_id and skill_token — " + "create a private skill in Yandex.Dialogs and configure the tokens" + ) + return + + if not self._direct_client_secret: + self.logger.error("Direct mode requires a client secret for OAuth account linking") + return + + self._user_id = self._instance_name + + def _on_token_created(token: str) -> None: + """Persist new access token generated during OAuth flow.""" + self._direct_access_token = token + self._update_config_value(CONF_DIRECT_ACCESS_TOKEN, token, encrypted=True) + + self._direct_handler = DirectConnectionHandler( + mass=self.mass, + user_id=self._user_id, + access_token=self._direct_access_token, + client_secret=self._direct_client_secret, + exposed_ids=self._exposed_ids, + logger=self.logger, + on_token_created=_on_token_created, + ) + self._direct_handler.register_routes() + + # State notifier — callback to Yandex Dialogs (same as Cloud Plus) + session = self.mass.http_session + callback_url = f"{YANDEX_DIALOGS_CALLBACK_BASE}/{self._skill_id}/callback/state" + auth_header = {"Authorization": f"OAuth {self._skill_token.get_secret()}"} + + self._state_notifier = StateNotifier( + mass=self.mass, + session=session, + user_id=self._user_id, + callback_url=callback_url, + auth_header=auth_header, + logger=self.logger, + exposed_ids=self._exposed_ids, + ) + await self._state_notifier.start() + + self.logger.info("Direct connection mode started") + + async def _handle_cloud_request(self, request: CloudRequest) -> dict[str, Any]: + """Route incoming cloud WS request to the appropriate handler.""" + action = request.action + request_id = request.request_id + message = request.message or {} + + # Normalize action path — relay may send with or without /v1.0 prefix + normalized = action.removeprefix("/v1.0") + + self.logger.debug( + "Cloud request: action=%s, request_id=%s", + action, + request_id, + ) + + try: + if normalized == "/user/devices": + device_list = await handle_device_list( + self.mass, + self._user_id, + exposed_ids=self._exposed_ids, + ) + return build_response(request_id, asdict(device_list)) + + if normalized == "/user/devices/query": + device_ids = [ + device_id + for d in message.get("devices", []) + if isinstance(d, dict) and (device_id := d.get("id")) + ] + states = await handle_devices_query( + self.mass, device_ids, exposed_ids=self._exposed_ids + ) + return build_response(request_id, asdict(states)) + + if normalized == "/user/devices/action": + action_payload = parse_action_payload(message) + action_result = await handle_devices_action( + self.mass, action_payload, exposed_ids=self._exposed_ids + ) + return build_response(request_id, asdict(action_result)) + + if normalized == "/user/unlink": + unlink_result = await handle_user_unlink() + return build_response(request_id, unlink_result) + + self.logger.warning("Unknown cloud action: %s", action) + return build_response(request_id, {}) + + except Exception: + self.logger.exception("Error handling cloud request: %s", action) + return build_response(request_id, {}) + + async def unload(self, is_removed: bool = False) -> None: + """Handle unload/close of the provider. + + Called when provider is deregistered (e.g. MA exiting or config reloading). + is_removed will be set to True when the provider is removed from the configuration. + """ + self.logger.info("Yandex Smart Home plugin unloading (removed=%s)", is_removed) + + if self._state_notifier: + await self._state_notifier.stop() + self._state_notifier = None + + if self._direct_handler: + self._direct_handler.unregister_routes() + self._direct_handler = None + + if self._cloud_manager: + await self._cloud_manager.disconnect() + self._cloud_manager = None + + if self._cloud_task: + cloud_task = self._cloud_task + self._cloud_task = None + if not cloud_task.done(): + cloud_task.cancel() diff --git a/music_assistant/providers/yandex_smarthome/schema.py b/music_assistant/providers/yandex_smarthome/schema.py new file mode 100644 index 0000000000..61a8dc446f --- /dev/null +++ b/music_assistant/providers/yandex_smarthome/schema.py @@ -0,0 +1,316 @@ +"""Dataclass models for the Yandex Smart Home API. + +Covers device descriptions, capability states, action requests/results, +callback payloads, and cloud WebSocket messages. + +Reference: https://yandex.ru/dev/dialogs/smart-home/doc/concepts/platform-protocol.html +Reference: https://github.com/dext0r/yandex_smart_home +""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from enum import Enum +from typing import Any + +try: + from enum import StrEnum +except ImportError: + # Python < 3.11 fallback (needed for local dev; upstream requires >=3.12) + class StrEnum(str, Enum): # type: ignore[no-redef] # noqa: UP042 + """Backport of StrEnum for Python < 3.11.""" + + +# --------------------------------------------------------------------------- +# Enums +# --------------------------------------------------------------------------- + + +class YandexDeviceType(StrEnum): + """Yandex Smart Home device types relevant to MA players.""" + + MEDIA_DEVICE = "devices.types.media_device" + MEDIA_DEVICE_RECEIVER = "devices.types.media_device.receiver" + + +class YandexCapabilityType(StrEnum): + """Yandex Smart Home capability types.""" + + ON_OFF = "devices.capabilities.on_off" + RANGE = "devices.capabilities.range" + TOGGLE = "devices.capabilities.toggle" + MODE = "devices.capabilities.mode" + + +class YandexRangeInstance(StrEnum): + """Range capability instances.""" + + VOLUME = "volume" + CHANNEL = "channel" + + +class YandexModeInstance(StrEnum): + """Mode capability instances.""" + + INPUT_SOURCE = "input_source" + + +class YandexToggleInstance(StrEnum): + """Toggle capability instances.""" + + MUTE = "mute" + PAUSE = "pause" + + +class YandexResponseCode(StrEnum): + """Yandex Smart Home API response/error codes.""" + + DONE = "DONE" + DEVICE_UNREACHABLE = "DEVICE_UNREACHABLE" + INVALID_ACTION = "INVALID_ACTION" + INTERNAL_ERROR = "INTERNAL_ERROR" + DEVICE_NOT_FOUND = "DEVICE_NOT_FOUND" + + +# --------------------------------------------------------------------------- +# Device description — returned by /user/devices +# --------------------------------------------------------------------------- + + +@dataclass +class RangeParameters: + """Range capability parameters (min/max/precision).""" + + min: float = 0 + max: float = 100 + precision: float = 1 + + +@dataclass +class ModeValue: + """A single mode value for mode capabilities.""" + + value: str + + +@dataclass +class CapabilityParameters: + """Parameters block inside a capability description.""" + + instance: str + range: RangeParameters | None = None + unit: str | None = None + random_access: bool | None = None + modes: list[ModeValue] | None = None + + +@dataclass +class CapabilityDescription: + """A single capability in a device description.""" + + type: str + retrievable: bool = True + reportable: bool = True + parameters: CapabilityParameters | None = None + + +@dataclass +class YandexDeviceInfo: + """Device info block.""" + + manufacturer: str = "Music Assistant" + model: str = "MA Player" + sw_version: str | None = None + + +@dataclass +class DeviceDescription: + """Full device description for /user/devices response.""" + + id: str + name: str + type: str + capabilities: list[CapabilityDescription] = field(default_factory=list) + device_info: YandexDeviceInfo | None = None + room: str | None = None + description: str | None = None + + +# --------------------------------------------------------------------------- +# Capability state — for /user/devices/query and state callbacks +# --------------------------------------------------------------------------- + + +@dataclass +class CapabilityInstanceState: + """State of a specific capability instance.""" + + instance: str + value: Any + + +@dataclass +class CapabilityState: + """A capability with its current state.""" + + type: str + state: CapabilityInstanceState + + +@dataclass +class DeviceState: + """State of a single device (for query or callback).""" + + id: str + capabilities: list[CapabilityState] = field(default_factory=list) + error_code: str | None = None + error_message: str | None = None + + +# --------------------------------------------------------------------------- +# Action request — from /user/devices/action +# --------------------------------------------------------------------------- + + +@dataclass +class CapabilityActionState: + """State portion of an action request capability.""" + + instance: str + value: Any + relative: bool = False + + +@dataclass +class CapabilityAction: + """A single capability action from Yandex.""" + + type: str + state: CapabilityActionState + + +@dataclass +class DeviceAction: + """Action request for a single device.""" + + id: str + capabilities: list[CapabilityAction] = field(default_factory=list) + + +@dataclass +class ActionRequestPayload: + """Payload of /user/devices/action request.""" + + devices: list[DeviceAction] = field(default_factory=list) + + +# --------------------------------------------------------------------------- +# Action result +# --------------------------------------------------------------------------- + + +@dataclass +class ActionResult: + """Result of executing a single capability action.""" + + status: str = "DONE" + error_code: str | None = None + error_message: str | None = None + + +@dataclass +class CapabilityActionResultState: + """State with action result for a single capability in an action response. + + Per Yandex Smart Home API, action_result goes inside 'state' alongside instance. + """ + + instance: str + value: Any = None + action_result: ActionResult = field(default_factory=ActionResult) + + +@dataclass +class CapabilityActionResult: + """Result for a single capability in an action response.""" + + type: str + state: CapabilityActionResultState + + +@dataclass +class DeviceActionResult: + """Action results for a single device.""" + + id: str + capabilities: list[CapabilityActionResult] = field(default_factory=list) + + +# --------------------------------------------------------------------------- +# Response payloads +# --------------------------------------------------------------------------- + + +@dataclass +class DeviceListPayload: + """Payload for /user/devices response.""" + + user_id: str + devices: list[DeviceDescription] = field(default_factory=list) + + +@dataclass +class DeviceStatesPayload: + """Payload for /user/devices/query response.""" + + devices: list[DeviceState] = field(default_factory=list) + + +@dataclass +class ActionResultPayload: + """Payload for /user/devices/action response.""" + + devices: list[DeviceActionResult] = field(default_factory=list) + + +# --------------------------------------------------------------------------- +# Callback — state reporting to Yandex +# --------------------------------------------------------------------------- + + +@dataclass +class CallbackPayload: + """Payload for callback/state POST.""" + + user_id: str + devices: list[DeviceState] = field(default_factory=list) + + +@dataclass +class CallbackRequest: + """Full callback state request body.""" + + ts: float + payload: CallbackPayload + + +# --------------------------------------------------------------------------- +# Cloud WebSocket messages +# --------------------------------------------------------------------------- + + +@dataclass +class CloudRequest: + """Incoming message from yaha-cloud.ru WebSocket.""" + + request_id: str + action: str + message: dict[str, Any] | None = None + + +@dataclass +class CloudResponse: + """Outgoing response to yaha-cloud.ru WebSocket.""" + + request_id: str + payload: dict[str, Any] = field(default_factory=dict) diff --git a/requirements_all.txt b/requirements_all.txt index a452ebd986..8de605f398 100644 --- a/requirements_all.txt +++ b/requirements_all.txt @@ -87,6 +87,7 @@ unidecode==1.4.0 uv>=0.8.0 websocket-client==1.9.0 xmltodict==1.0.4 +ya-passport-auth==1.2.3 yandex-music==2.2.0 ytmusicapi==1.11.5 zeroconf==0.148.0 diff --git a/tests/providers/yandex_smarthome/__init__.py b/tests/providers/yandex_smarthome/__init__.py new file mode 100644 index 0000000000..90e43be1ad --- /dev/null +++ b/tests/providers/yandex_smarthome/__init__.py @@ -0,0 +1 @@ +"""Yandex Smart Home provider test suite.""" diff --git a/tests/providers/yandex_smarthome/test_basic.py b/tests/providers/yandex_smarthome/test_basic.py new file mode 100644 index 0000000000..5f4a224111 --- /dev/null +++ b/tests/providers/yandex_smarthome/test_basic.py @@ -0,0 +1,77 @@ +"""Basic tests for Yandex Smart Home plugin provider.""" + +from __future__ import annotations + +import json +from pathlib import Path + +import music_assistant.providers.yandex_smarthome.constants as _cmod +from music_assistant.providers.yandex_smarthome.constants import ( + CAPABILITY_ON_OFF, + CAPABILITY_RANGE, + CAPABILITY_TOGGLE, + CLOUD_SKILL_WEBHOOK_TEMPLATE, + CONF_CLOUD_INSTANCE_PASSWORD, + CONF_INSTANCE_NAME, + CONF_SKILL_TOKEN, + CONNECTION_TYPE_CLOUD_PLUS, + INSTANCE_MUTE, + INSTANCE_PAUSE, + INSTANCE_VOLUME, + YANDEX_DEVICE_TYPE_RECEIVER, + YANDEX_DIALOGS_CALLBACK_BASE, + YANDEX_DIALOGS_DEVELOPER_URL, + YANDEX_OAUTH_URL, +) + +# Locate manifest.json relative to the provider module (works in both local and upstream layouts) +_MANIFEST_PATH = Path(_cmod.__file__).parent / "manifest.json" + + +def test_manifest_valid() -> None: + """Manifest should be valid JSON with required fields.""" + data = json.loads(_MANIFEST_PATH.read_text()) + + assert data["type"] == "plugin" + assert data["domain"] == "yandex_smarthome" + assert data["name"] == "Yandex Smart Home" + assert data["stage"] == "beta" + assert data["multi_instance"] is False + assert data["builtin"] is False + assert isinstance(data["requirements"], list) + assert "ya-passport-auth==1.2.3" in data["requirements"] + + +def test_manifest_has_codeowners() -> None: + """Manifest should declare codeowners.""" + data = json.loads(_MANIFEST_PATH.read_text()) + + assert "codeowners" in data + assert len(data["codeowners"]) > 0 + + +def test_constants_defined() -> None: + """Core constants should be importable and non-empty.""" + assert CONF_INSTANCE_NAME + assert CONF_CLOUD_INSTANCE_PASSWORD + assert YANDEX_DEVICE_TYPE_RECEIVER + + +def test_cloud_plus_constants() -> None: + """Cloud Plus constants should be importable and well-formed.""" + assert CONNECTION_TYPE_CLOUD_PLUS == "cloud_plus" + assert CONF_SKILL_TOKEN == "skill_token" + assert "dialogs.yandex.net" in YANDEX_DIALOGS_CALLBACK_BASE + assert "dialogs.yandex.ru" in YANDEX_DIALOGS_DEVELOPER_URL + assert "oauth.yandex.ru" in YANDEX_OAUTH_URL + assert "yaha-cloud.ru" in CLOUD_SKILL_WEBHOOK_TEMPLATE + + +def test_constants_capability_types() -> None: + """Yandex capability constants should be properly defined.""" + assert "on_off" in CAPABILITY_ON_OFF + assert "range" in CAPABILITY_RANGE + assert "toggle" in CAPABILITY_TOGGLE + assert INSTANCE_VOLUME == "volume" + assert INSTANCE_MUTE == "mute" + assert INSTANCE_PAUSE == "pause" diff --git a/tests/providers/yandex_smarthome/test_cloud.py b/tests/providers/yandex_smarthome/test_cloud.py new file mode 100644 index 0000000000..03507b53c2 --- /dev/null +++ b/tests/providers/yandex_smarthome/test_cloud.py @@ -0,0 +1,229 @@ +"""Tests for provider/cloud.py — CloudManager WebSocket and registration helpers.""" + +from __future__ import annotations + +from unittest.mock import AsyncMock, MagicMock + +import aiohttp +import pytest + +from music_assistant.providers.yandex_smarthome._compat import SecretStr +from music_assistant.providers.yandex_smarthome.cloud import ( + CloudManager, + get_cloud_otp, + register_cloud_instance, +) +from music_assistant.providers.yandex_smarthome.constants import CLOUD_RECONNECT_MIN +from music_assistant.providers.yandex_smarthome.schema import CloudRequest + +# --------------------------------------------------------------------------- +# CloudManager tests +# --------------------------------------------------------------------------- + + +class TestCloudManager: + """Tests for CloudManager WebSocket client.""" + + def _make_manager(self, on_request: AsyncMock | None = None) -> CloudManager: + """Create make manager helper.""" + session = MagicMock(spec=aiohttp.ClientSession) + if on_request is None: + on_request = AsyncMock(return_value={"request_id": "r1", "payload": {}}) + return CloudManager( + session=session, + connection_token=SecretStr("test-token"), + on_request=on_request, + ) + + def test_initial_state(self) -> None: + """Test initial state.""" + mgr = self._make_manager() + assert mgr.connected is False + assert mgr._running is False + + def test_connected_property(self) -> None: + """Test connected property.""" + mgr = self._make_manager() + assert mgr.connected is False + + # Simulate connected WS + ws = MagicMock() + ws.closed = False + mgr._ws = ws + assert mgr.connected is True + + # Simulate closed WS + ws.closed = True # type: ignore[unreachable] + assert mgr.connected is False + + @pytest.mark.asyncio + async def test_handle_message_calls_callback(self) -> None: + """Test handle message calls callback.""" + callback = AsyncMock(return_value={"request_id": "r1", "payload": {}}) + mgr = self._make_manager(on_request=callback) + + ws = AsyncMock() + data = {"request_id": "r1", "action": "/v1.0/user/devices", "message": {}} + await mgr._handle_message(ws, data) + + callback.assert_awaited_once() + args = callback.call_args[0][0] + assert isinstance(args, CloudRequest) + assert args.request_id == "r1" + assert args.action == "/v1.0/user/devices" + ws.send_json.assert_awaited_once() + + @pytest.mark.asyncio + async def test_handle_message_exception_logged(self) -> None: + """Test handle message exception sends error response.""" + callback = AsyncMock(side_effect=RuntimeError("boom")) + mgr = self._make_manager(on_request=callback) + + ws = AsyncMock() + ws.closed = False + # Should not raise + await mgr._handle_message(ws, {"request_id": "r1", "action": "test"}) + # Should send error response so relay doesn't hang + ws.send_json.assert_awaited_once_with( + {"request_id": "r1", "payload": {"error": "INTERNAL_ERROR"}} + ) + + @pytest.mark.asyncio + async def test_disconnect(self) -> None: + """Test disconnect.""" + mgr = self._make_manager() + mgr._running = True + ws = AsyncMock() + ws.closed = False + mgr._ws = ws + + await mgr.disconnect() + + assert mgr._running is False + ws.close.assert_awaited_once() + assert mgr._ws is None + + @pytest.mark.asyncio + async def test_disconnect_when_already_closed(self) -> None: + """Test disconnect when already closed.""" + mgr = self._make_manager() + mgr._running = True + ws = AsyncMock() + ws.closed = True + mgr._ws = ws + + await mgr.disconnect() + # Should not call close on already closed WS + ws.close.assert_not_awaited() + + @pytest.mark.asyncio + async def test_disconnect_when_no_ws(self) -> None: + """Test disconnect when no ws.""" + mgr = self._make_manager() + mgr._running = True + # No WS at all + await mgr.disconnect() + assert mgr._running is False + + def test_reconnect_delay_reset_logic(self) -> None: + """Verify that _reconnect_delay is set to min by default.""" + mgr = self._make_manager() + assert mgr._reconnect_delay == CLOUD_RECONNECT_MIN + + +# --------------------------------------------------------------------------- +# Registration helpers +# --------------------------------------------------------------------------- + + +class TestRegisterCloudInstance: + """Tests for register_cloud_instance helper.""" + + @pytest.mark.asyncio + async def test_register(self) -> None: + """Test register.""" + mock_resp = AsyncMock() + mock_resp.status = 200 + mock_resp.raise_for_status = MagicMock() + mock_resp.json = AsyncMock( + return_value={ + "id": "inst-123", + "password": "pwd-xyz", + "connection_token": "tok-abc", + } + ) + + session = MagicMock(spec=aiohttp.ClientSession) + ctx = MagicMock() + ctx.__aenter__ = AsyncMock(return_value=mock_resp) + ctx.__aexit__ = AsyncMock(return_value=False) + session.post.return_value = ctx + + result = await register_cloud_instance(session) + assert result["id"] == "inst-123" + assert result["password"] == "pwd-xyz" + assert result["connection_token"] == "tok-abc" + + @pytest.mark.asyncio + async def test_register_no_platform_param(self) -> None: + """Test register no platform param.""" + mock_resp = AsyncMock() + mock_resp.status = 200 + mock_resp.raise_for_status = MagicMock() + mock_resp.json = AsyncMock( + return_value={ + "id": "inst-1", + "password": "p", + "connection_token": "t", + } + ) + + session = MagicMock(spec=aiohttp.ClientSession) + ctx = MagicMock() + ctx.__aenter__ = AsyncMock(return_value=mock_resp) + ctx.__aexit__ = AsyncMock(return_value=False) + session.post.return_value = ctx + + await register_cloud_instance(session) + # Standard cloud mode: no json body (compatible with yaha-cloud.ru) + call_kwargs = session.post.call_args + assert call_kwargs.kwargs.get("json") is None + + +class TestGetCloudOtp: + """Tests for get_cloud_otp helper.""" + + @pytest.mark.asyncio + async def test_get_otp(self) -> None: + """Test get otp.""" + mock_resp = AsyncMock() + mock_resp.status = 200 + mock_resp.raise_for_status = MagicMock() + mock_resp.json = AsyncMock(return_value={"code": "123456"}) + + session = MagicMock(spec=aiohttp.ClientSession) + ctx = MagicMock() + ctx.__aenter__ = AsyncMock(return_value=mock_resp) + ctx.__aexit__ = AsyncMock(return_value=False) + session.post.return_value = ctx + + otp = await get_cloud_otp(session, "inst-123", SecretStr("tok-abc")) + assert otp == "123456" + + @pytest.mark.asyncio + async def test_get_otp_uses_post(self) -> None: + """Test get otp uses post.""" + mock_resp = AsyncMock() + mock_resp.status = 200 + mock_resp.raise_for_status = MagicMock() + mock_resp.json = AsyncMock(return_value={"code": "999"}) + + session = MagicMock(spec=aiohttp.ClientSession) + ctx = MagicMock() + ctx.__aenter__ = AsyncMock(return_value=mock_resp) + ctx.__aexit__ = AsyncMock(return_value=False) + session.post.return_value = ctx + + await get_cloud_otp(session, "inst-1", SecretStr("tok-1")) + session.post.assert_called_once() + assert "inst-1" in str(session.post.call_args) diff --git a/tests/providers/yandex_smarthome/test_device.py b/tests/providers/yandex_smarthome/test_device.py new file mode 100644 index 0000000000..f373d8b28f --- /dev/null +++ b/tests/providers/yandex_smarthome/test_device.py @@ -0,0 +1,868 @@ +"""Tests for provider/device.py — MA Player ↔ Yandex device mapper.""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from unittest.mock import AsyncMock + +import pytest + +# Use the PlaybackState from conftest's mock enums +from music_assistant_models.enums import PlaybackState + +from music_assistant.providers.yandex_smarthome.constants import ( + INSTANCE_CHANNEL, + INSTANCE_INPUT_SOURCE, + INSTANCE_MUTE, + INSTANCE_ON, + INSTANCE_PAUSE, + INSTANCE_VOLUME, + YANDEX_DEVICE_TYPE_MEDIA, +) +from music_assistant.providers.yandex_smarthome.device import ( + execute_capability_action, + get_device_description, + get_device_state, + is_player_exposable, + make_error_action_result, + make_error_device_state, + normalize_device_name, +) +from music_assistant.providers.yandex_smarthome.schema import ( + CapabilityAction, + CapabilityActionState, + YandexCapabilityType, +) + + +@dataclass +class MockDeviceInfo: + """Mock device info for testing.""" + + model: str = "Test Speaker" + + +@dataclass +class MockPlayerSource: + """Minimal mock of music_assistant_models.player.PlayerSource.""" + + id: str = "source_1" + name: str = "Source 1" + + +@dataclass +class MockPlayer: + """Minimal mock of music_assistant_models.player.Player.""" + + player_id: str = "test_player_1" + name: str = "Living Room Speaker" + available: bool = True + enabled: bool = True + powered: bool | None = True + playback_state: PlaybackState = PlaybackState.IDLE + volume_level: int | None = 50 + volume_muted: bool | None = False + synced_to: str | None = None + device_info: MockDeviceInfo | None = None + supported_features: set[str] = field(default_factory=set) + source_list: list[MockPlayerSource] = field(default_factory=list) + active_source: str | None = None + group_members: list[str] = field(default_factory=list) + group_volume: int | None = None + group_volume_muted: bool | None = None + + +class MockPlayers: + """Mock of mass.players controller.""" + + def __init__(self) -> None: + """Initialize mock players controller.""" + self.cmd_play = AsyncMock() + self.cmd_stop = AsyncMock() + self.cmd_pause = AsyncMock() + self.cmd_power = AsyncMock() + self.cmd_volume_set = AsyncMock() + self.cmd_volume_mute = AsyncMock() + self.cmd_group_volume = AsyncMock() + self.cmd_group_volume_mute = AsyncMock() + self.cmd_next_track = AsyncMock() + self.cmd_previous_track = AsyncMock() + self.select_source = AsyncMock() + self._players: dict[str, MockPlayer] = {} + + def get_player(self, player_id: str) -> MockPlayer | None: + """Return player by ID.""" + return self._players.get(player_id) + + +@dataclass +class MockMass: + """Mock MusicAssistant for testing.""" + + players: MockPlayers = field(default_factory=MockPlayers) + + +# --------------------------------------------------------------------------- +# Tests: get_device_description +# --------------------------------------------------------------------------- + + +class TestGetDeviceDescription: + """Tests for get_device_description.""" + + def test_basic_description(self) -> None: + """Test basic description without mute support.""" + player = MockPlayer() + desc = get_device_description(player) # type: ignore[arg-type] + assert desc.id == "test_player_1" + assert desc.name == "Living Room Speaker" + assert desc.type == YANDEX_DEVICE_TYPE_MEDIA + # 4 base capabilities: on_off, volume, pause, channel (no mute without feature) + assert len(desc.capabilities) == 4 + instances = {c.parameters.instance for c in desc.capabilities if c.parameters} + assert INSTANCE_MUTE not in instances + + def test_description_with_mute(self) -> None: + """Test description includes mute toggle when VOLUME_MUTE feature is set.""" + player = MockPlayer(supported_features={"volume_mute"}) + desc = get_device_description(player) # type: ignore[arg-type] + assert len(desc.capabilities) == 5 + instances = {c.parameters.instance for c in desc.capabilities if c.parameters} + assert INSTANCE_MUTE in instances + + def test_description_group_has_mute(self) -> None: + """Group players always get mute toggle even without VOLUME_MUTE feature.""" + player = MockPlayer(group_members=["child1", "child2"]) + desc = get_device_description(player) # type: ignore[arg-type] + assert len(desc.capabilities) == 5 + instances = {c.parameters.instance for c in desc.capabilities if c.parameters} + assert INSTANCE_MUTE in instances + + def test_capability_types(self) -> None: + """Test capability types.""" + player = MockPlayer() + desc = get_device_description(player) # type: ignore[arg-type] + types = [c.type for c in desc.capabilities] + assert YandexCapabilityType.ON_OFF in types + assert YandexCapabilityType.RANGE in types + assert YandexCapabilityType.TOGGLE in types + + def test_volume_range_params(self) -> None: + """Test volume range params.""" + player = MockPlayer() + desc = get_device_description(player) # type: ignore[arg-type] + range_cap = next(c for c in desc.capabilities if c.type == YandexCapabilityType.RANGE) + assert range_cap.parameters is not None + assert range_cap.parameters.instance == "volume" + assert range_cap.parameters.range is not None + assert range_cap.parameters.range.min == 0 + assert range_cap.parameters.range.max == 100 + + def test_device_info_model(self) -> None: + """Test device info model.""" + player = MockPlayer(device_info=MockDeviceInfo(model="KEF LS50")) + desc = get_device_description(player) # type: ignore[arg-type] + assert desc.device_info is not None + assert desc.device_info.model == "KEF LS50" + + def test_device_info_default(self) -> None: + """Test device info default.""" + player = MockPlayer() + desc = get_device_description(player) # type: ignore[arg-type] + assert desc.device_info is not None + assert desc.device_info.model == "MA Player" + + def test_name_normalized(self) -> None: + """Test that device name is normalized for Yandex.""" + player = MockPlayer(name="KEF-LS50 (Kitchen)") + desc = get_device_description(player) # type: ignore[arg-type] + assert desc.name == "KEF LS 50 Kitchen" + + +# --------------------------------------------------------------------------- +# Tests: normalize_device_name +# --------------------------------------------------------------------------- + + +class TestNormalizeDeviceName: + """Tests for normalize_device_name.""" + + def test_passthrough_clean_name(self) -> None: + """Clean names pass through unchanged.""" + assert normalize_device_name("Living Room Speaker") == "Living Room Speaker" + + def test_russian_name(self) -> None: + """Russian names pass through.""" + assert normalize_device_name("Колонка в гостиной") == "Колонка в гостиной" + + def test_strip_special_chars(self) -> None: + """Special characters replaced with spaces.""" + assert normalize_device_name("KEF-LS50 (Kitchen)") == "KEF LS 50 Kitchen" + + def test_space_between_letters_and_digits(self) -> None: + """Mandatory space between letters and digits.""" + assert normalize_device_name("Sonos5") == "Sonos 5" + assert normalize_device_name("3колонка") == "3 колонка" # noqa: RUF001 + + def test_collapse_multiple_spaces(self) -> None: + """Multiple spaces collapsed to one.""" + assert normalize_device_name("KEF LS50") == "KEF LS 50" + + def test_strip_edges(self) -> None: + """Leading/trailing spaces stripped.""" + assert normalize_device_name(" Speaker ") == "Speaker" + + def test_mixed_russian_english(self) -> None: + """Mixed Russian and English.""" + assert normalize_device_name("Колонка JBL5") == "Колонка JBL 5" + + def test_empty_fallback(self) -> None: + """If normalization produces empty string, return original.""" + assert normalize_device_name("---") == "---" + + def test_digits_only(self) -> None: + """Digit-only names preserved.""" + assert normalize_device_name("123") == "123" + + +class TestGetDeviceState: + """Tests for get_device_state.""" + + def test_idle_state(self) -> None: + """Test idle state without mute support.""" + player = MockPlayer(playback_state=PlaybackState.IDLE, volume_level=30, volume_muted=False) + state = get_device_state(player) # type: ignore[arg-type] + assert state.id == "test_player_1" + + by_instance = {c.state.instance: c.state.value for c in state.capabilities} + assert by_instance[INSTANCE_ON] is True # powered on = "on" regardless of playback + assert by_instance[INSTANCE_VOLUME] == 30 + assert INSTANCE_MUTE not in by_instance + assert by_instance[INSTANCE_PAUSE] is True + + def test_idle_state_with_mute(self) -> None: + """Test idle state with mute support.""" + player = MockPlayer( + playback_state=PlaybackState.IDLE, + volume_level=30, + volume_muted=False, + supported_features={"volume_mute"}, + ) + state = get_device_state(player) # type: ignore[arg-type] + + by_instance = {c.state.instance: c.state.value for c in state.capabilities} + assert by_instance[INSTANCE_MUTE] is False + + def test_playing_state(self) -> None: + """Test playing state.""" + player = MockPlayer(playback_state=PlaybackState.PLAYING, volume_level=75) + state = get_device_state(player) # type: ignore[arg-type] + + by_instance = {c.state.instance: c.state.value for c in state.capabilities} + assert by_instance[INSTANCE_ON] is True + assert by_instance[INSTANCE_VOLUME] == 75 + assert by_instance[INSTANCE_PAUSE] is False + + def test_paused_state(self) -> None: + """Test paused state.""" + player = MockPlayer(playback_state=PlaybackState.PAUSED, volume_level=50) + state = get_device_state(player) # type: ignore[arg-type] + + by_instance = {c.state.instance: c.state.value for c in state.capabilities} + assert by_instance[INSTANCE_ON] is True # paused is still "on" + assert by_instance[INSTANCE_PAUSE] is True + + def test_none_volume(self) -> None: + """Test none volume.""" + player = MockPlayer( + volume_level=None, volume_muted=None, supported_features={"volume_mute"} + ) + state = get_device_state(player) # type: ignore[arg-type] + + by_instance = {c.state.instance: c.state.value for c in state.capabilities} + assert by_instance[INSTANCE_VOLUME] == 0 + assert by_instance[INSTANCE_MUTE] is False + + def test_group_state_includes_mute(self) -> None: + """Group players should include mute state using group_volume_muted.""" + player = MockPlayer( + playback_state=PlaybackState.PLAYING, + volume_level=40, + volume_muted=False, + group_members=["c1", "c2"], + group_volume=75, + group_volume_muted=True, + ) + state = get_device_state(player) # type: ignore[arg-type] + + by_instance = {c.state.instance: c.state.value for c in state.capabilities} + assert by_instance[INSTANCE_VOLUME] == 75 + assert by_instance[INSTANCE_MUTE] is True + + +# --------------------------------------------------------------------------- +# Tests: execute_capability_action +# --------------------------------------------------------------------------- + + +class TestExecuteCapabilityAction: + """Tests for execute_capability_action.""" + + @pytest.mark.asyncio + async def test_on_off_true_plays(self) -> None: + """Test on off true plays.""" + mass = MockMass() + mass.players._players["p1"] = MockPlayer(player_id="p1") + action = CapabilityAction( + type=YandexCapabilityType.ON_OFF, + state=CapabilityActionState(instance="on", value=True), + ) + result = await execute_capability_action(mass, "p1", action) + mass.players.cmd_play.assert_awaited_once_with("p1") + mass.players.cmd_power.assert_not_awaited() + assert result.state.action_result.status == "DONE" + + @pytest.mark.asyncio + async def test_on_off_false_stops(self) -> None: + """Test on off false stops.""" + mass = MockMass() + mass.players._players["p1"] = MockPlayer(player_id="p1") + action = CapabilityAction( + type=YandexCapabilityType.ON_OFF, + state=CapabilityActionState(instance="on", value=False), + ) + result = await execute_capability_action(mass, "p1", action) + mass.players.cmd_stop.assert_awaited_once_with("p1") + mass.players.cmd_power.assert_not_awaited() + assert result.state.action_result.status == "DONE" + + @pytest.mark.asyncio + async def test_on_powers_on_when_supported(self) -> None: + """ON with power feature should power on then play.""" + mass = MockMass() + player = MockPlayer(player_id="p1", supported_features={"power"}) + mass.players._players["p1"] = player + action = CapabilityAction( + type=YandexCapabilityType.ON_OFF, + state=CapabilityActionState(instance="on", value=True), + ) + result = await execute_capability_action(mass, "p1", action) + mass.players.cmd_power.assert_awaited_once_with("p1", True) + mass.players.cmd_play.assert_awaited_once_with("p1") + assert result.state.action_result.status == "DONE" + + @pytest.mark.asyncio + async def test_off_powers_off_when_supported(self) -> None: + """OFF with power feature should stop then power off.""" + mass = MockMass() + player = MockPlayer(player_id="p1", supported_features={"power"}) + mass.players._players["p1"] = player + action = CapabilityAction( + type=YandexCapabilityType.ON_OFF, + state=CapabilityActionState(instance="on", value=False), + ) + result = await execute_capability_action(mass, "p1", action) + mass.players.cmd_stop.assert_awaited_once_with("p1") + mass.players.cmd_power.assert_awaited_once_with("p1", False) + assert result.state.action_result.status == "DONE" + + @pytest.mark.asyncio + async def test_volume_absolute(self) -> None: + """Test volume absolute.""" + mass = MockMass() + mass.players._players["p1"] = MockPlayer(player_id="p1") + action = CapabilityAction( + type=YandexCapabilityType.RANGE, + state=CapabilityActionState(instance="volume", value=65), + ) + result = await execute_capability_action(mass, "p1", action) + mass.players.cmd_volume_set.assert_awaited_once_with("p1", 65) + assert result.state.action_result.status == "DONE" + + @pytest.mark.asyncio + async def test_volume_relative_up(self) -> None: + """Test volume relative up.""" + mass = MockMass() + mass.players._players["p1"] = MockPlayer(player_id="p1") + action = CapabilityAction( + type=YandexCapabilityType.RANGE, + state=CapabilityActionState(instance="volume", value=10, relative=True), + ) + result = await execute_capability_action(mass, "p1", action, current_volume=50) + mass.players.cmd_volume_set.assert_awaited_once_with("p1", 60) + assert result.state.value == 60 + + @pytest.mark.asyncio + async def test_volume_relative_clamp_max(self) -> None: + """Test volume relative clamp max.""" + mass = MockMass() + mass.players._players["p1"] = MockPlayer(player_id="p1") + action = CapabilityAction( + type=YandexCapabilityType.RANGE, + state=CapabilityActionState(instance="volume", value=20, relative=True), + ) + result = await execute_capability_action(mass, "p1", action, current_volume=90) + mass.players.cmd_volume_set.assert_awaited_once_with("p1", 100) + assert result.state.value == 100 + + @pytest.mark.asyncio + async def test_volume_relative_clamp_min(self) -> None: + """Test volume relative clamp min.""" + mass = MockMass() + mass.players._players["p1"] = MockPlayer(player_id="p1") + action = CapabilityAction( + type=YandexCapabilityType.RANGE, + state=CapabilityActionState(instance="volume", value=-20, relative=True), + ) + result = await execute_capability_action(mass, "p1", action, current_volume=10) + mass.players.cmd_volume_set.assert_awaited_once_with("p1", 0) + assert result.state.value == 0 + + @pytest.mark.asyncio + async def test_mute_toggle(self) -> None: + """Test mute toggle.""" + mass = MockMass() + mass.players._players["p1"] = MockPlayer(player_id="p1") + action = CapabilityAction( + type=YandexCapabilityType.TOGGLE, + state=CapabilityActionState(instance="mute", value=True), + ) + result = await execute_capability_action(mass, "p1", action) + mass.players.cmd_volume_mute.assert_awaited_once_with("p1", True) + assert result.state.action_result.status == "DONE" + + @pytest.mark.asyncio + async def test_mute_toggle_group(self) -> None: + """Group mute should use cmd_group_volume_mute.""" + mass = MockMass() + group = MockPlayer(player_id="grp", group_members=["c1", "c2"]) + mass.players._players["grp"] = group + action = CapabilityAction( + type=YandexCapabilityType.TOGGLE, + state=CapabilityActionState(instance="mute", value=True), + ) + result = await execute_capability_action(mass, "grp", action) + assert result.state.action_result.status == "DONE" + mass.players.cmd_group_volume_mute.assert_awaited_once_with("grp", True) + + @pytest.mark.asyncio + async def test_volume_set_group(self) -> None: + """Group volume should use cmd_group_volume.""" + mass = MockMass() + group = MockPlayer(player_id="grp", group_members=["c1", "c2"]) + mass.players._players["grp"] = group + action = CapabilityAction( + type=YandexCapabilityType.RANGE, + state=CapabilityActionState(instance="volume", value=70), + ) + result = await execute_capability_action(mass, "grp", action) + assert result.state.action_result.status == "DONE" + mass.players.cmd_group_volume.assert_awaited_once_with("grp", 70) + mass.players.cmd_volume_set.assert_not_awaited() + + @pytest.mark.asyncio + async def test_pause_true(self) -> None: + """Test pause true.""" + mass = MockMass() + mass.players._players["p1"] = MockPlayer(player_id="p1") + action = CapabilityAction( + type=YandexCapabilityType.TOGGLE, + state=CapabilityActionState(instance="pause", value=True), + ) + result = await execute_capability_action(mass, "p1", action) + mass.players.cmd_pause.assert_awaited_once_with("p1") + assert result.state.action_result.status == "DONE" + + @pytest.mark.asyncio + async def test_pause_false_plays(self) -> None: + """Test pause false plays.""" + mass = MockMass() + mass.players._players["p1"] = MockPlayer(player_id="p1") + action = CapabilityAction( + type=YandexCapabilityType.TOGGLE, + state=CapabilityActionState(instance="pause", value=False), + ) + await execute_capability_action(mass, "p1", action) + mass.players.cmd_play.assert_awaited_once_with("p1") + + @pytest.mark.asyncio + async def test_unknown_capability_returns_error(self) -> None: + """Test unknown capability returns error.""" + mass = MockMass() + mass.players._players["p1"] = MockPlayer(player_id="p1") + action = CapabilityAction( + type="devices.capabilities.unknown", + state=CapabilityActionState(instance="foo", value=42), + ) + result = await execute_capability_action(mass, "p1", action) + assert result.state.action_result.status == "ERROR" + assert result.state.action_result.error_code == "INVALID_ACTION" + + @pytest.mark.asyncio + async def test_command_exception_returns_error(self) -> None: + """Test command exception returns error.""" + mass = MockMass() + mass.players._players["p1"] = MockPlayer(player_id="p1") + mass.players.cmd_play.side_effect = RuntimeError("Connection lost") + action = CapabilityAction( + type=YandexCapabilityType.ON_OFF, + state=CapabilityActionState(instance="on", value=True), + ) + result = await execute_capability_action(mass, "p1", action) + assert result.state.action_result.status == "ERROR" + assert result.state.action_result.error_code == "INTERNAL_ERROR" + + @pytest.mark.asyncio + async def test_on_off_non_bool_value_returns_invalid_action(self) -> None: + """ON_OFF with non-bool value (e.g. string 'false') must not power on.""" + mass = MockMass() + mass.players._players["p1"] = MockPlayer(player_id="p1") + action = CapabilityAction( + type=YandexCapabilityType.ON_OFF, + state=CapabilityActionState(instance="on", value="false"), + ) + result = await execute_capability_action(mass, "p1", action) + assert result.state.action_result.status == "ERROR" + assert result.state.action_result.error_code == "INVALID_ACTION" + mass.players.cmd_play.assert_not_awaited() + mass.players.cmd_stop.assert_not_awaited() + + @pytest.mark.asyncio + async def test_mute_non_bool_value_returns_invalid_action(self) -> None: + """Mute toggle with non-bool value must not call cmd_volume_mute.""" + mass = MockMass() + mass.players._players["p1"] = MockPlayer(player_id="p1") + action = CapabilityAction( + type=YandexCapabilityType.TOGGLE, + state=CapabilityActionState(instance="mute", value="false"), + ) + result = await execute_capability_action(mass, "p1", action) + assert result.state.action_result.status == "ERROR" + assert result.state.action_result.error_code == "INVALID_ACTION" + mass.players.cmd_volume_mute.assert_not_awaited() + + @pytest.mark.asyncio + async def test_pause_non_bool_value_returns_invalid_action(self) -> None: + """Pause toggle with non-bool value must not call cmd_pause/cmd_play.""" + mass = MockMass() + mass.players._players["p1"] = MockPlayer(player_id="p1") + action = CapabilityAction( + type=YandexCapabilityType.TOGGLE, + state=CapabilityActionState(instance="pause", value="true"), + ) + result = await execute_capability_action(mass, "p1", action) + assert result.state.action_result.status == "ERROR" + assert result.state.action_result.error_code == "INVALID_ACTION" + mass.players.cmd_pause.assert_not_awaited() + mass.players.cmd_play.assert_not_awaited() + + @pytest.mark.asyncio + async def test_volume_range_bool_value_returns_invalid_action(self) -> None: + """Volume RANGE with bool value must not silently set volume to 0/1.""" + mass = MockMass() + mass.players._players["p1"] = MockPlayer(player_id="p1", volume_level=50) + action = CapabilityAction( + type=YandexCapabilityType.RANGE, + state=CapabilityActionState(instance="volume", value=True), + ) + result = await execute_capability_action(mass, "p1", action, current_volume=50) + assert result.state.action_result.status == "ERROR" + assert result.state.action_result.error_code == "INVALID_ACTION" + mass.players.cmd_volume_set.assert_not_awaited() + mass.players.cmd_group_volume.assert_not_awaited() + + @pytest.mark.asyncio + async def test_channel_range_bool_value_returns_invalid_action(self) -> None: + """Channel RANGE with bool value must not silently skip tracks.""" + mass = MockMass() + mass.players._players["p1"] = MockPlayer(player_id="p1") + action = CapabilityAction( + type=YandexCapabilityType.RANGE, + state=CapabilityActionState(instance="channel", value=True, relative=True), + ) + result = await execute_capability_action(mass, "p1", action) + assert result.state.action_result.status == "ERROR" + assert result.state.action_result.error_code == "INVALID_ACTION" + mass.players.cmd_next_track.assert_not_awaited() + mass.players.cmd_previous_track.assert_not_awaited() + + @pytest.mark.asyncio + async def test_missing_player_returns_device_unreachable(self) -> None: + """Player not found should return DEVICE_UNREACHABLE.""" + mass = MockMass() + action = CapabilityAction( + type=YandexCapabilityType.ON_OFF, + state=CapabilityActionState(instance="on", value=True), + ) + result = await execute_capability_action(mass, "nonexistent", action) + assert result.state.action_result.status == "ERROR" + assert result.state.action_result.error_code == "DEVICE_UNREACHABLE" + + +# --------------------------------------------------------------------------- +# Tests: is_player_exposable +# --------------------------------------------------------------------------- + + +class TestIsPlayerExposable: + """Tests for is_player_exposable.""" + + def test_normal_player(self) -> None: + """Test normal player.""" + assert is_player_exposable(MockPlayer()) is True # type: ignore[arg-type] + + def test_unavailable(self) -> None: + """Test unavailable.""" + assert is_player_exposable(MockPlayer(available=False)) is False # type: ignore[arg-type] + + def test_disabled(self) -> None: + """Test disabled.""" + assert is_player_exposable(MockPlayer(enabled=False)) is False # type: ignore[arg-type] + + def test_synced_to_another(self) -> None: + """Test synced to another.""" + assert is_player_exposable(MockPlayer(synced_to="other_player")) is False # type: ignore[arg-type] + + +# --------------------------------------------------------------------------- +# Tests: error helpers +# --------------------------------------------------------------------------- + + +class TestErrorHelpers: + """Tests for error helper functions.""" + + def test_make_error_device_state(self) -> None: + """Test make error device state.""" + state = make_error_device_state("p1") + assert state.id == "p1" + assert state.error_code == "DEVICE_UNREACHABLE" + assert state.capabilities == [] + + def test_make_error_action_result(self) -> None: + """Test make error action result.""" + actions = [ + CapabilityAction( + type=YandexCapabilityType.ON_OFF, + state=CapabilityActionState(instance="on", value=True), + ), + CapabilityAction( + type=YandexCapabilityType.RANGE, + state=CapabilityActionState(instance="volume", value=50), + ), + ] + results = make_error_action_result("p1", actions) + assert len(results) == 2 + assert all(r.state.action_result.status == "ERROR" for r in results) + assert all(r.state.action_result.error_code == "DEVICE_UNREACHABLE" for r in results) + + +# --------------------------------------------------------------------------- +# Tests: channel capability (next/previous track) +# --------------------------------------------------------------------------- + + +class TestChannelCapability: + """Tests for channel capability handling.""" + + def test_channel_in_description(self) -> None: + """Channel capability should always be present in device description.""" + player = MockPlayer() + desc = get_device_description(player) # type: ignore[arg-type] + channel_caps = [ + c + for c in desc.capabilities + if c.type == YandexCapabilityType.RANGE + and c.parameters + and c.parameters.instance == INSTANCE_CHANNEL + ] + assert len(channel_caps) == 1 + cap = channel_caps[0] + assert cap.parameters.random_access is False # type: ignore[union-attr] + assert cap.parameters.range is not None # type: ignore[union-attr] + assert cap.parameters.range.min == 0 # type: ignore[union-attr] + assert cap.parameters.range.max == 999 # type: ignore[union-attr] + + def test_channel_state_always_zero(self) -> None: + """Channel state should always report value 0.""" + player = MockPlayer(playback_state=PlaybackState.PLAYING) + state = get_device_state(player) # type: ignore[arg-type] + channel_states = [c for c in state.capabilities if c.state.instance == INSTANCE_CHANNEL] + assert len(channel_states) == 1 + assert channel_states[0].state.value == 0 + + @pytest.mark.asyncio + async def test_channel_relative_positive_next_track(self) -> None: + """Relative +1 channel → cmd_next_track.""" + mass = MockMass() + mass.players._players["p1"] = MockPlayer(player_id="p1") + action = CapabilityAction( + type=YandexCapabilityType.RANGE, + state=CapabilityActionState(instance="channel", value=1, relative=True), + ) + result = await execute_capability_action(mass, "p1", action) + mass.players.cmd_next_track.assert_awaited_once_with("p1") + assert result.state.action_result.status == "DONE" + + @pytest.mark.asyncio + async def test_channel_relative_negative_prev_track(self) -> None: + """Relative -1 channel → cmd_previous_track.""" + mass = MockMass() + mass.players._players["p1"] = MockPlayer(player_id="p1") + action = CapabilityAction( + type=YandexCapabilityType.RANGE, + state=CapabilityActionState(instance="channel", value=-1, relative=True), + ) + result = await execute_capability_action(mass, "p1", action) + mass.players.cmd_previous_track.assert_awaited_once_with("p1") + assert result.state.action_result.status == "DONE" + + @pytest.mark.asyncio + async def test_channel_non_relative_ignored(self) -> None: + """Non-relative channel set is a no-op (returns DONE).""" + mass = MockMass() + mass.players._players["p1"] = MockPlayer(player_id="p1") + action = CapabilityAction( + type=YandexCapabilityType.RANGE, + state=CapabilityActionState(instance="channel", value=5, relative=False), + ) + result = await execute_capability_action(mass, "p1", action) + mass.players.cmd_next_track.assert_not_awaited() + mass.players.cmd_previous_track.assert_not_awaited() + assert result.state.action_result.status == "DONE" + + +# --------------------------------------------------------------------------- +# Tests: input_source capability (mode/input_source) +# --------------------------------------------------------------------------- + + +class TestInputSourceCapability: + """Tests for input source capability handling.""" + + def test_no_source_list_no_mode_cap(self) -> None: + """Player without source_list should not have mode capability.""" + player = MockPlayer(source_list=[]) + desc = get_device_description(player) # type: ignore[arg-type] + mode_caps = [c for c in desc.capabilities if c.type == YandexCapabilityType.MODE] + assert len(mode_caps) == 0 + + def test_with_sources_has_mode_cap(self) -> None: + """Player with source_list should have mode(input_source) capability.""" + sources = [ + MockPlayerSource(id="hdmi1", name="HDMI 1"), + MockPlayerSource(id="optical", name="Optical"), + ] + player = MockPlayer(source_list=sources, supported_features={"select_source"}) + desc = get_device_description(player) # type: ignore[arg-type] + mode_caps = [c for c in desc.capabilities if c.type == YandexCapabilityType.MODE] + assert len(mode_caps) == 1 + cap = mode_caps[0] + assert cap.parameters.instance == INSTANCE_INPUT_SOURCE # type: ignore[union-attr] + assert cap.parameters.modes is not None # type: ignore[union-attr] + assert len(cap.parameters.modes) == 2 # type: ignore[union-attr] + assert cap.parameters.modes[0].value == "one" # type: ignore[union-attr] + assert cap.parameters.modes[1].value == "two" # type: ignore[union-attr] + + def test_max_10_sources(self) -> None: + """Only the first 10 sources should be mapped.""" + sources = [MockPlayerSource(id=f"s{i}", name=f"Source {i}") for i in range(15)] + player = MockPlayer(source_list=sources, supported_features={"select_source"}) + desc = get_device_description(player) # type: ignore[arg-type] + mode_caps = [c for c in desc.capabilities if c.type == YandexCapabilityType.MODE] + assert len(mode_caps[0].parameters.modes) == 10 # type: ignore[arg-type,union-attr] + + def test_state_with_active_source(self) -> None: + """State should report current source as mode value.""" + sources = [ + MockPlayerSource(id="hdmi1", name="HDMI 1"), + MockPlayerSource(id="optical", name="Optical"), + ] + player = MockPlayer( + source_list=sources, + active_source="Optical", + playback_state=PlaybackState.PLAYING, + supported_features={"select_source"}, + ) + state = get_device_state(player) # type: ignore[arg-type] + mode_states = [c for c in state.capabilities if c.state.instance == INSTANCE_INPUT_SOURCE] + assert len(mode_states) == 1 + assert mode_states[0].state.value == "two" # index 1 → "two" + + def test_state_no_active_source(self) -> None: + """No active source → no input_source state reported.""" + sources = [MockPlayerSource(id="hdmi1", name="HDMI 1")] + player = MockPlayer( + source_list=sources, active_source=None, supported_features={"select_source"} + ) + state = get_device_state(player) # type: ignore[arg-type] + mode_states = [c for c in state.capabilities if c.state.instance == INSTANCE_INPUT_SOURCE] + assert len(mode_states) == 0 + + @pytest.mark.asyncio + async def test_select_source_action(self) -> None: + """Mode action should call select_source with resolved source id.""" + sources = [ + MockPlayerSource(id="hdmi1", name="HDMI 1"), + MockPlayerSource(id="optical", name="Optical"), + ] + player = MockPlayer( + player_id="p1", source_list=sources, supported_features={"select_source"} + ) + mass = MockMass() + mass.players._players["p1"] = player + + action = CapabilityAction( + type=YandexCapabilityType.MODE, + state=CapabilityActionState(instance="input_source", value="two"), + ) + result = await execute_capability_action(mass, "p1", action) + mass.players.select_source.assert_awaited_once_with("p1", "optical") + assert result.state.action_result.status == "DONE" + + @pytest.mark.asyncio + async def test_unknown_source_mode_returns_error(self) -> None: + """Invalid mode value should return INVALID_ACTION error.""" + player = MockPlayer(player_id="p1", source_list=[]) + mass = MockMass() + mass.players._players["p1"] = player + + action = CapabilityAction( + type=YandexCapabilityType.MODE, + state=CapabilityActionState(instance="input_source", value="five"), + ) + result = await execute_capability_action(mass, "p1", action) + assert result.state.action_result.status == "ERROR" + assert result.state.action_result.error_code == "INVALID_ACTION" + + +# --------------------------------------------------------------------------- +# Tests: player filter (exposed_ids) +# --------------------------------------------------------------------------- + + +class TestPlayerFilter: + """Tests for player filtering with exposed_ids.""" + + def test_no_filter_exposes_all(self) -> None: + """Without exposed_ids, all valid players are exposed.""" + assert is_player_exposable(MockPlayer()) is True # type: ignore[arg-type] + + def test_filter_includes_player(self) -> None: + """Player in the filter set is exposed.""" + assert is_player_exposable(MockPlayer(player_id="p1"), exposed_ids={"p1", "p2"}) is True # type: ignore[arg-type] + + def test_filter_excludes_player(self) -> None: + """Player not in the filter set is NOT exposed.""" + assert is_player_exposable(MockPlayer(player_id="p3"), exposed_ids={"p1", "p2"}) is False # type: ignore[arg-type] + + def test_empty_filter_exposes_all(self) -> None: + """Empty set filter should expose all players (same as None).""" + assert is_player_exposable(MockPlayer(player_id="p1"), exposed_ids=set()) is True # type: ignore[arg-type] + + def test_filter_still_checks_available(self) -> None: + """Even in filter, unavailable players are not exposed.""" + assert ( + is_player_exposable(MockPlayer(player_id="p1", available=False), exposed_ids={"p1"}) # type: ignore[arg-type] + is False + ) diff --git a/tests/providers/yandex_smarthome/test_direct.py b/tests/providers/yandex_smarthome/test_direct.py new file mode 100644 index 0000000000..7323b68e91 --- /dev/null +++ b/tests/providers/yandex_smarthome/test_direct.py @@ -0,0 +1,829 @@ +"""Tests for the DirectConnectionHandler (direct connection mode).""" + +from __future__ import annotations + +import json +import time +from typing import TYPE_CHECKING, Any +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from aiohttp.test_utils import make_mocked_request + +if TYPE_CHECKING: + from aiohttp import web + +from music_assistant.providers.yandex_smarthome.constants import ( + CONNECTION_TYPE_DIRECT, + DIRECT_API_BASE_PATH, + DIRECT_AUTH_BASE_PATH, + DIRECT_HEALTH_RESPONSE, + DIRECT_OAUTH_CLIENT_ID, + OAUTH_CODE_EXPIRY, +) +from music_assistant.providers.yandex_smarthome.direct import DirectConnectionHandler +from music_assistant.providers.yandex_smarthome.plugin import YandexSmartHomePlugin + +TEST_CLIENT_SECRET = "test-client-secret-abc123" + +# token_store lists shared between fixtures and tests +_handler_tokens: list[str] = [] +_handler_no_token_tokens: list[str] = [] + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture +def mock_mass() -> MagicMock: + """Return a mock MusicAssistant with a webserver stub.""" + mass = MagicMock() + mass.webserver.base_url = "https://my-ma.example.com" + mass.webserver.register_dynamic_route = MagicMock(return_value=MagicMock()) + mass.players = [] + mass.http_session = MagicMock() + return mass + + +@pytest.fixture +def handler(mock_mass: MagicMock) -> DirectConnectionHandler: + """Return a DirectConnectionHandler with a known token.""" + _handler_tokens.clear() + + def on_token(t: str) -> None: + _handler_tokens.append(t) + + return DirectConnectionHandler( + mass=mock_mass, + user_id="test_user", + access_token="test-token-abc", + client_secret=TEST_CLIENT_SECRET, + exposed_ids=None, + on_token_created=on_token, + ) + + +@pytest.fixture +def handler_no_token(mock_mass: MagicMock) -> DirectConnectionHandler: + """Return a handler with no initial access token (first-time OAuth flow).""" + _handler_no_token_tokens.clear() + + def on_token(t: str) -> None: + _handler_no_token_tokens.append(t) + + return DirectConnectionHandler( + mass=mock_mass, + user_id="test_user", + access_token="", + client_secret=TEST_CLIENT_SECRET, + exposed_ids=None, + on_token_created=on_token, + ) + + +def _make_request( + method: str = "GET", + path: str = "/", + headers: dict[str, str] | None = None, + payload: dict[str, Any] | None = None, + query: dict[str, str] | None = None, + post_data: dict[str, str] | None = None, +) -> web.Request: + """Build a mock aiohttp Request.""" + req = make_mocked_request( + method, + path, + headers=headers or {}, + ) + if query: + object.__setattr__(req, "_rel_url", req._rel_url.with_query(query)) + if payload is not None: + object.__setattr__(req, "_payload_writer", None) + + async def _json(**_kwargs: Any) -> dict[str, Any]: + return payload + + object.__setattr__(req, "json", _json) + if post_data is not None: + + async def _post() -> dict[str, str]: + return post_data + + object.__setattr__(req, "post", _post) + return req + + +def _get_pending_code(h: DirectConnectionHandler) -> str: + """Return the first pending authorization code from a handler.""" + return next(iter(h._pending_codes.keys())) + + +# --------------------------------------------------------------------------- +# Constants +# --------------------------------------------------------------------------- + + +def test_connection_type_direct() -> None: + """CONNECTION_TYPE_DIRECT should be 'direct'.""" + assert CONNECTION_TYPE_DIRECT == "direct" + + +def test_api_base_path() -> None: + """DIRECT_API_BASE_PATH should be under /api/yandex_smarthome.""" + assert DIRECT_API_BASE_PATH.startswith("/api/yandex_smarthome") + + +def test_auth_base_path() -> None: + """DIRECT_AUTH_BASE_PATH should be under /api/yandex_smarthome.""" + assert DIRECT_AUTH_BASE_PATH.startswith("/api/yandex_smarthome") + + +def test_health_response() -> None: + """DIRECT_HEALTH_RESPONSE should be non-empty.""" + assert len(DIRECT_HEALTH_RESPONSE) > 0 + + +def test_oauth_constants() -> None: + """OAuth constants should match Yandex Smart Home spec.""" + assert DIRECT_OAUTH_CLIENT_ID == "https://social.yandex.net/" + assert OAUTH_CODE_EXPIRY == 300 + + +# --------------------------------------------------------------------------- +# Route registration +# --------------------------------------------------------------------------- + + +def test_register_routes(handler: DirectConnectionHandler, mock_mass: MagicMock) -> None: + """register_routes should register all 10 HTTP routes.""" + handler.register_routes() + assert mock_mass.webserver.register_dynamic_route.call_count == 10 + + +def test_unregister_routes(handler: DirectConnectionHandler) -> None: + """unregister_routes should call all stored callbacks and clear.""" + handler.register_routes() + handler.unregister_routes() + assert len(handler._unregister_callbacks) == 0 + + +def test_register_routes_rolls_back_on_failure( + handler: DirectConnectionHandler, mock_mass: MagicMock +) -> None: + """register_routes must unregister partial routes and re-raise on RuntimeError.""" + unregister_cb = MagicMock() + call_count = {"n": 0} + + def register(_path: str, _handler_fn: Any, _method: str) -> Any: + call_count["n"] += 1 + if call_count["n"] == 3: + raise RuntimeError("already registered") + return unregister_cb + + mock_mass.webserver.register_dynamic_route.side_effect = register + + with pytest.raises(RuntimeError): + handler.register_routes() + + # 2 successful registrations must be rolled back via unregister_cb + assert unregister_cb.call_count == 2 + assert handler._unregister_callbacks == [] + + +# --------------------------------------------------------------------------- +# Auth validation +# --------------------------------------------------------------------------- + + +def test_auth_valid(handler: DirectConnectionHandler) -> None: + """Valid Bearer token should pass validation.""" + req = _make_request(headers={"Authorization": "Bearer test-token-abc"}) + assert handler._validate_auth(req) is True + + +def test_auth_invalid(handler: DirectConnectionHandler) -> None: + """Wrong Bearer token should fail validation.""" + req = _make_request(headers={"Authorization": "Bearer wrong-token"}) + assert handler._validate_auth(req) is False + + +def test_auth_missing(handler: DirectConnectionHandler) -> None: + """Missing Authorization header should fail validation.""" + req = _make_request() + assert handler._validate_auth(req) is False + + +def test_auth_non_bearer(handler: DirectConnectionHandler) -> None: + """Non-Bearer auth scheme should fail validation.""" + req = _make_request(headers={"Authorization": "Basic dXNlcjpwYXNz"}) + assert handler._validate_auth(req) is False + + +def test_auth_empty_token_rejects(handler_no_token: DirectConnectionHandler) -> None: + """Handler with no access token should reject any Bearer token.""" + req = _make_request(headers={"Authorization": "Bearer anything"}) + assert handler_no_token._validate_auth(req) is False + + +# --------------------------------------------------------------------------- +# Health check +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_health_get(handler: DirectConnectionHandler) -> None: + """GET health check should return 200 with health text.""" + req = _make_request(method="GET", path="/v1.0") + resp = await handler._handle_health(req) + assert resp.status == 200 + assert resp.text == DIRECT_HEALTH_RESPONSE + + +@pytest.mark.asyncio +async def test_health_head(handler: DirectConnectionHandler) -> None: + """HEAD health check should return 200.""" + req = _make_request(method="HEAD", path="/v1.0") + resp = await handler._handle_health(req) + assert resp.status == 200 + + +# --------------------------------------------------------------------------- +# API auth rejection +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_devices_unauthorized(handler: DirectConnectionHandler) -> None: + """POST /user/devices with bad token should return 401.""" + req = _make_request(method="POST", headers={"Authorization": "Bearer bad"}) + resp = await handler._handle_devices(req) + assert resp.status == 401 + + +@pytest.mark.asyncio +async def test_query_unauthorized(handler: DirectConnectionHandler) -> None: + """POST /user/devices/query with bad token should return 401.""" + req = _make_request(method="POST", headers={"Authorization": "Bearer bad"}) + resp = await handler._handle_query(req) + assert resp.status == 401 + + +@pytest.mark.asyncio +async def test_action_unauthorized(handler: DirectConnectionHandler) -> None: + """POST /user/devices/action with bad token should return 401.""" + req = _make_request(method="POST", headers={"Authorization": "Bearer bad"}) + resp = await handler._handle_action(req) + assert resp.status == 401 + + +@pytest.mark.asyncio +async def test_unlink_unauthorized(handler: DirectConnectionHandler) -> None: + """POST /user/unlink with bad token should return 401.""" + req = _make_request(method="POST", headers={"Authorization": "Bearer bad"}) + resp = await handler._handle_unlink(req) + assert resp.status == 401 + + +# --------------------------------------------------------------------------- +# API authorized calls +# --------------------------------------------------------------------------- + +_AUTH_HEADERS = {"Authorization": "Bearer test-token-abc"} + + +@pytest.mark.asyncio +async def test_devices_success(handler: DirectConnectionHandler) -> None: + """Authorized /user/devices should return 200 with device list.""" + req = _make_request( + method="POST", + headers={**_AUTH_HEADERS, "X-Request-Id": "req-1"}, + ) + mock_result = MagicMock() + resp_payload = {"request_id": "req-1", "payload": {"devices": []}} + mock_hdl = patch( + "music_assistant.providers.yandex_smarthome.direct.handle_device_list", + new_callable=AsyncMock, + return_value=mock_result, + ) + with ( + mock_hdl, + patch( + "music_assistant.providers.yandex_smarthome.direct.asdict", return_value={"devices": []} + ), + patch( + "music_assistant.providers.yandex_smarthome.direct.build_response", + return_value=resp_payload, + ), + ): + resp = await handler._handle_devices(req) + assert resp.status == 200 + body = json.loads(resp.body) # type: ignore[arg-type] + assert body["request_id"] == "req-1" + + +@pytest.mark.asyncio +async def test_query_success(handler: DirectConnectionHandler) -> None: + """Authorized /user/devices/query should return 200.""" + req = _make_request( + method="POST", + headers={**_AUTH_HEADERS, "X-Request-Id": "req-2"}, + payload={"devices": [{"id": "player1"}]}, + ) + mock_result = MagicMock() + resp_payload = {"request_id": "req-2", "payload": {"devices": []}} + mock_query = patch( + "music_assistant.providers.yandex_smarthome.direct.handle_devices_query", + new_callable=AsyncMock, + return_value=mock_result, + ) + with ( + mock_query, + patch( + "music_assistant.providers.yandex_smarthome.direct.asdict", return_value={"devices": []} + ), + patch( + "music_assistant.providers.yandex_smarthome.direct.build_response", + return_value=resp_payload, + ), + ): + resp = await handler._handle_query(req) + assert resp.status == 200 + + +@pytest.mark.asyncio +async def test_action_success(handler: DirectConnectionHandler) -> None: + """Authorized /user/devices/action should return 200.""" + req = _make_request( + method="POST", + headers={**_AUTH_HEADERS, "X-Request-Id": "req-3"}, + payload={"payload": {"devices": []}}, + ) + resp_payload = {"request_id": "req-3", "payload": {"devices": []}} + mock_action = patch( + "music_assistant.providers.yandex_smarthome.direct.handle_devices_action", + new_callable=AsyncMock, + return_value=MagicMock(), + ) + with ( + patch( + "music_assistant.providers.yandex_smarthome.direct.parse_action_payload", + return_value=MagicMock(), + ), + mock_action, + patch( + "music_assistant.providers.yandex_smarthome.direct.asdict", return_value={"devices": []} + ), + patch( + "music_assistant.providers.yandex_smarthome.direct.build_response", + return_value=resp_payload, + ), + ): + resp = await handler._handle_action(req) + assert resp.status == 200 + + +@pytest.mark.asyncio +async def test_unlink_success(handler: DirectConnectionHandler) -> None: + """Authorized /user/unlink should return 200.""" + req = _make_request( + method="POST", + headers={**_AUTH_HEADERS, "X-Request-Id": "req-4"}, + ) + resp_payload = {"request_id": "req-4"} + with ( + patch( + "music_assistant.providers.yandex_smarthome.direct.handle_user_unlink", + new_callable=AsyncMock, + return_value={}, + ), + patch( + "music_assistant.providers.yandex_smarthome.direct.build_response", + return_value=resp_payload, + ), + ): + resp = await handler._handle_unlink(req) + assert resp.status == 200 + + +# --------------------------------------------------------------------------- +# OAuth authorize +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_authorize_returns_html(handler: DirectConnectionHandler) -> None: + """GET /auth/authorize should return HTML with link button.""" + req = _make_request( + method="GET", + path="/auth/authorize", + query={ + "client_id": DIRECT_OAUTH_CLIENT_ID, + "redirect_uri": "https://social.yandex.net/broker/redirect", + "state": "abc123", + "response_type": "code", + }, + ) + resp = await handler._handle_oauth_authorize(req) + assert resp.status == 200 + assert resp.content_type == "text/html" + assert resp.text is not None + assert "Music Assistant" in resp.text + assert "abc123" in resp.text + + +@pytest.mark.asyncio +async def test_authorize_missing_redirect_uri(handler: DirectConnectionHandler) -> None: + """GET /auth/authorize without redirect_uri should return 400.""" + req = _make_request(method="GET", path="/auth/authorize", query={}) + resp = await handler._handle_oauth_authorize(req) + assert resp.status == 400 + + +@pytest.mark.asyncio +async def test_authorize_creates_pending_code(handler: DirectConnectionHandler) -> None: + """GET /auth/authorize should create a pending authorization code.""" + req = _make_request( + method="GET", + path="/auth/authorize", + query={ + "client_id": DIRECT_OAUTH_CLIENT_ID, + "redirect_uri": "https://social.yandex.net/broker/redirect", + "state": "s1", + "response_type": "code", + }, + ) + assert len(handler._pending_codes) == 0 + await handler._handle_oauth_authorize(req) + assert len(handler._pending_codes) == 1 + + +@pytest.mark.asyncio +async def test_authorize_invalid_client_id(handler: DirectConnectionHandler) -> None: + """GET /auth/authorize with wrong client_id should return 400.""" + req = _make_request( + method="GET", + path="/auth/authorize", + query={ + "client_id": "wrong-client-id", + "redirect_uri": "https://social.yandex.net/broker/redirect", + "state": "s1", + "response_type": "code", + }, + ) + resp = await handler._handle_oauth_authorize(req) + assert resp.status == 400 + + +@pytest.mark.asyncio +async def test_authorize_invalid_response_type(handler: DirectConnectionHandler) -> None: + """GET /auth/authorize with wrong response_type should return 400.""" + req = _make_request( + method="GET", + path="/auth/authorize", + query={ + "client_id": DIRECT_OAUTH_CLIENT_ID, + "redirect_uri": "https://social.yandex.net/broker/redirect", + "state": "s1", + "response_type": "token", + }, + ) + resp = await handler._handle_oauth_authorize(req) + assert resp.status == 400 + + +@pytest.mark.asyncio +async def test_authorize_invalid_redirect_uri_domain(handler: DirectConnectionHandler) -> None: + """GET /auth/authorize with non-Yandex redirect_uri should return 400.""" + req = _make_request( + method="GET", + path="/auth/authorize", + query={ + "client_id": DIRECT_OAUTH_CLIENT_ID, + "redirect_uri": "https://evil.example.com/steal", + "state": "s1", + "response_type": "code", + }, + ) + resp = await handler._handle_oauth_authorize(req) + assert resp.status == 400 + + +# --------------------------------------------------------------------------- +# OAuth token +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_token_exchange_valid_code(handler: DirectConnectionHandler) -> None: + """Token exchange with valid code should return access_token.""" + req_auth = _make_request( + method="GET", + path="/auth/authorize", + query={ + "client_id": DIRECT_OAUTH_CLIENT_ID, + "redirect_uri": "https://social.yandex.net/broker/redirect", + "state": "s1", + "response_type": "code", + }, + ) + await handler._handle_oauth_authorize(req_auth) + code = _get_pending_code(handler) + + req_token = _make_request( + method="POST", + path="/auth/token", + post_data={ + "grant_type": "authorization_code", + "code": code, + "client_id": DIRECT_OAUTH_CLIENT_ID, + "client_secret": TEST_CLIENT_SECRET, + }, + ) + resp = await handler._handle_oauth_token(req_token) + assert resp.status == 200 + body = json.loads(resp.body) # type: ignore[arg-type] + assert body["access_token"] == "test-token-abc" + assert body["token_type"] == "bearer" + assert "refresh_token" in body + + +@pytest.mark.asyncio +async def test_token_exchange_generates_new_token( + handler_no_token: DirectConnectionHandler, +) -> None: + """When no token exists, OAuth should generate a new one.""" + req_auth = _make_request( + method="GET", + path="/auth/authorize", + query={ + "client_id": DIRECT_OAUTH_CLIENT_ID, + "redirect_uri": "https://social.yandex.net/broker/redirect", + "state": "s1", + "response_type": "code", + }, + ) + await handler_no_token._handle_oauth_authorize(req_auth) + code = _get_pending_code(handler_no_token) + + req_token = _make_request( + method="POST", + path="/auth/token", + post_data={ + "grant_type": "authorization_code", + "code": code, + "client_id": DIRECT_OAUTH_CLIENT_ID, + "client_secret": TEST_CLIENT_SECRET, + }, + ) + resp = await handler_no_token._handle_oauth_token(req_token) + assert resp.status == 200 + body = json.loads(resp.body) # type: ignore[arg-type] + assert body["access_token"] + assert len(body["access_token"]) == 32 # uuid4().hex + assert len(_handler_no_token_tokens) == 1 + assert _handler_no_token_tokens[0] == body["access_token"] + + +@pytest.mark.asyncio +async def test_token_exchange_invalid_client_secret(handler: DirectConnectionHandler) -> None: + """Token exchange with wrong client_secret should return 401.""" + req = _make_request( + method="POST", + path="/auth/token", + post_data={ + "grant_type": "authorization_code", + "code": "any", + "client_id": DIRECT_OAUTH_CLIENT_ID, + "client_secret": "wrong-secret", + }, + ) + resp = await handler._handle_oauth_token(req) + assert resp.status == 401 + body = json.loads(resp.body) # type: ignore[arg-type] + assert body["error"] == "invalid_client" + + +@pytest.mark.asyncio +async def test_token_exchange_invalid_client_id(handler: DirectConnectionHandler) -> None: + """Token exchange with wrong client_id should return 401.""" + req = _make_request( + method="POST", + path="/auth/token", + post_data={ + "grant_type": "authorization_code", + "code": "any", + "client_id": "wrong-client-id", + "client_secret": TEST_CLIENT_SECRET, + }, + ) + resp = await handler._handle_oauth_token(req) + assert resp.status == 401 + body = json.loads(resp.body) # type: ignore[arg-type] + assert body["error"] == "invalid_client" + + +@pytest.mark.asyncio +async def test_token_exchange_invalid_code(handler: DirectConnectionHandler) -> None: + """Token exchange with invalid code should return 400.""" + req = _make_request( + method="POST", + path="/auth/token", + post_data={ + "grant_type": "authorization_code", + "code": "nonexistent", + "client_id": DIRECT_OAUTH_CLIENT_ID, + "client_secret": TEST_CLIENT_SECRET, + }, + ) + resp = await handler._handle_oauth_token(req) + assert resp.status == 400 + body = json.loads(resp.body) # type: ignore[arg-type] + assert body["error"] == "invalid_grant" + + +@pytest.mark.asyncio +async def test_token_exchange_expired_code(handler: DirectConnectionHandler) -> None: + """Expired authorization codes should be rejected.""" + handler._pending_codes["expired-code"] = time.time() - 10 + req = _make_request( + method="POST", + path="/auth/token", + post_data={ + "grant_type": "authorization_code", + "code": "expired-code", + "client_id": DIRECT_OAUTH_CLIENT_ID, + "client_secret": TEST_CLIENT_SECRET, + }, + ) + resp = await handler._handle_oauth_token(req) + assert resp.status == 400 + + +@pytest.mark.asyncio +async def test_refresh_token_valid(handler: DirectConnectionHandler) -> None: + """Refresh token with correct token should return 200.""" + req = _make_request( + method="POST", + path="/auth/token", + post_data={ + "grant_type": "refresh_token", + "refresh_token": "test-token-abc", + "client_id": DIRECT_OAUTH_CLIENT_ID, + "client_secret": TEST_CLIENT_SECRET, + }, + ) + resp = await handler._handle_oauth_token(req) + assert resp.status == 200 + body = json.loads(resp.body) # type: ignore[arg-type] + assert body["access_token"] == "test-token-abc" + + +@pytest.mark.asyncio +async def test_refresh_token_invalid(handler: DirectConnectionHandler) -> None: + """Refresh token with wrong token should return 400.""" + req = _make_request( + method="POST", + path="/auth/token", + post_data={ + "grant_type": "refresh_token", + "refresh_token": "wrong", + "client_id": DIRECT_OAUTH_CLIENT_ID, + "client_secret": TEST_CLIENT_SECRET, + }, + ) + resp = await handler._handle_oauth_token(req) + assert resp.status == 400 + + +@pytest.mark.asyncio +async def test_unsupported_grant_type(handler: DirectConnectionHandler) -> None: + """Unsupported grant_type should return 400.""" + req = _make_request( + method="POST", + path="/auth/token", + post_data={ + "grant_type": "client_credentials", + "client_id": DIRECT_OAUTH_CLIENT_ID, + "client_secret": TEST_CLIENT_SECRET, + }, + ) + resp = await handler._handle_oauth_token(req) + assert resp.status == 400 + body = json.loads(resp.body) # type: ignore[arg-type] + assert body["error"] == "unsupported_grant_type" + + +@pytest.mark.asyncio +async def test_code_consumed_after_use(handler: DirectConnectionHandler) -> None: + """Authorization codes should be single-use.""" + req_auth = _make_request( + method="GET", + path="/auth/authorize", + query={ + "client_id": DIRECT_OAUTH_CLIENT_ID, + "redirect_uri": "https://social.yandex.net/broker/redirect", + "state": "s1", + "response_type": "code", + }, + ) + await handler._handle_oauth_authorize(req_auth) + code = _get_pending_code(handler) + + token_post = { + "grant_type": "authorization_code", + "code": code, + "client_id": DIRECT_OAUTH_CLIENT_ID, + "client_secret": TEST_CLIENT_SECRET, + } + + # First exchange — success + req1 = _make_request(method="POST", path="/auth/token", post_data=token_post) + resp1 = await handler._handle_oauth_token(req1) + assert resp1.status == 200 + + # Second exchange — code consumed, should fail + req2 = _make_request(method="POST", path="/auth/token", post_data=token_post) + resp2 = await handler._handle_oauth_token(req2) + assert resp2.status == 400 + + +# --------------------------------------------------------------------------- +# Plugin integration +# --------------------------------------------------------------------------- + + +def _make_direct_config(**overrides: Any) -> MagicMock: + """Create a mock config for direct mode with sensible defaults.""" + defaults: dict[str, Any] = { + "instance_name": "TestMA", + "connection_type": CONNECTION_TYPE_DIRECT, + "skill_id": "test-skill-id", + "skill_token": "test-skill-token", + "direct_access_token": "existing-token", + "direct_client_secret": TEST_CLIENT_SECRET, + "exposed_players": None, + "cloud_instance_id": "", + "cloud_instance_password": "", + "cloud_connection_token": "", + "log_level": "GLOBAL", + } + defaults.update(overrides) + config = MagicMock() + config.get_value = MagicMock(side_effect=lambda key: defaults.get(key, "")) + return config + + +@pytest.mark.asyncio +async def test_start_direct_mode_registers_routes(mock_mass: MagicMock) -> None: + """_start_direct_mode should create handler and register routes.""" + config = _make_direct_config() + plugin = YandexSmartHomePlugin( + mass=mock_mass, + manifest=MagicMock(domain="yandex_smarthome"), + config=config, + supported_features=set(), + ) + await plugin.handle_async_init() + await plugin.loaded_in_mass() + + assert plugin._direct_handler is not None + assert mock_mass.webserver.register_dynamic_route.call_count == 10 + assert plugin._state_notifier is not None + + +@pytest.mark.asyncio +async def test_start_direct_mode_missing_skill_id(mock_mass: MagicMock) -> None: + """Direct mode without skill_id should log error and not start.""" + config = _make_direct_config(skill_id="") + plugin = YandexSmartHomePlugin( + mass=mock_mass, + manifest=MagicMock(domain="yandex_smarthome"), + config=config, + supported_features=set(), + ) + plugin.logger = MagicMock() + await plugin.handle_async_init() + await plugin.loaded_in_mass() + + assert plugin._direct_handler is None + plugin.logger.error.assert_called() + + +@pytest.mark.asyncio +async def test_unload_cleans_up_direct(mock_mass: MagicMock) -> None: + """unload() should unregister routes and stop notifier.""" + config = _make_direct_config() + plugin = YandexSmartHomePlugin( + mass=mock_mass, + manifest=MagicMock(domain="yandex_smarthome"), + config=config, + supported_features=set(), + ) + await plugin.handle_async_init() + await plugin.loaded_in_mass() + await plugin.unload() + + assert plugin._direct_handler is None + assert plugin._state_notifier is None diff --git a/tests/providers/yandex_smarthome/test_handlers.py b/tests/providers/yandex_smarthome/test_handlers.py new file mode 100644 index 0000000000..3f547d2c19 --- /dev/null +++ b/tests/providers/yandex_smarthome/test_handlers.py @@ -0,0 +1,401 @@ +"""Tests for provider/handlers.py — Yandex Smart Home API request handlers.""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any +from unittest.mock import AsyncMock, MagicMock + +import pytest + +# Use the PlaybackState from conftest's mock enums +from music_assistant_models.enums import PlaybackState + +from music_assistant.providers.yandex_smarthome.handlers import ( + build_response, + handle_device_list, + handle_devices_action, + handle_devices_query, + handle_user_unlink, + parse_action_payload, +) +from music_assistant.providers.yandex_smarthome.schema import DeviceListPayload + + +@dataclass +class MockPlayer: + """Mock player for handler tests.""" + + player_id: str = "p1" + name: str = "Speaker" + available: bool = True + enabled: bool = True + powered: bool | None = True + playback_state: Any = PlaybackState.PLAYING + volume_level: int | None = 50 + volume_muted: bool | None = False + synced_to: str | None = None + device_info: Any = None + supported_features: set[str] = field(default_factory=set) + source_list: list[str] = field(default_factory=list) + active_source: str | None = None + + @property + def state(self) -> MockPlayer: + """Return self as state (mirrors real Player.state).""" + return self + + +def _make_mass(players: list[MockPlayer]) -> MagicMock: + """Create a mock MusicAssistant with given players.""" + mass = MagicMock() + mass.players.__iter__ = MagicMock(return_value=iter(players)) + mass.players.all_players = MagicMock(return_value=players) + + player_map = {p.player_id: p for p in players} + mass.players.get_player = MagicMock(side_effect=player_map.get) + + mass.players.cmd_play = AsyncMock() + mass.players.cmd_stop = AsyncMock() + mass.players.cmd_pause = AsyncMock() + mass.players.cmd_power = AsyncMock() + mass.players.cmd_volume_set = AsyncMock() + mass.players.cmd_volume_mute = AsyncMock() + return mass + + +# --------------------------------------------------------------------------- +# Tests: handle_device_list +# --------------------------------------------------------------------------- + + +class TestHandleDeviceList: + """Tests for handle_device_list.""" + + @pytest.mark.asyncio + async def test_empty(self) -> None: + """Test empty.""" + mass = _make_mass([]) + result = await handle_device_list(mass, "user1") + assert result.user_id == "user1" + assert result.devices == [] + + @pytest.mark.asyncio + async def test_exposes_available_players(self) -> None: + """Test exposes available players.""" + players = [ + MockPlayer(player_id="p1", name="Speaker 1"), + MockPlayer(player_id="p2", name="Speaker 2"), + ] + mass = _make_mass(players) + result = await handle_device_list(mass, "user1") + assert len(result.devices) == 2 + ids = {d.id for d in result.devices} + assert ids == {"p1", "p2"} + + @pytest.mark.asyncio + async def test_filters_unavailable(self) -> None: + """Test filters unavailable.""" + players = [ + MockPlayer(player_id="p1", available=True), + MockPlayer(player_id="p2", available=False), + ] + mass = _make_mass(players) + result = await handle_device_list(mass, "user1") + assert len(result.devices) == 1 + assert result.devices[0].id == "p1" + + @pytest.mark.asyncio + async def test_filters_synced(self) -> None: + """Test filters synced.""" + players = [ + MockPlayer(player_id="leader"), + MockPlayer(player_id="follower", synced_to="leader"), + ] + mass = _make_mass(players) + result = await handle_device_list(mass, "user1") + assert len(result.devices) == 1 + assert result.devices[0].id == "leader" + + @pytest.mark.asyncio + async def test_filters_by_exposed_ids(self) -> None: + """Test filters by exposed ids.""" + players = [ + MockPlayer(player_id="p1", name="Speaker 1"), + MockPlayer(player_id="p2", name="Speaker 2"), + MockPlayer(player_id="p3", name="Speaker 3"), + ] + mass = _make_mass(players) + result = await handle_device_list(mass, "user1", exposed_ids={"p1", "p3"}) + assert len(result.devices) == 2 + ids = {d.id for d in result.devices} + assert ids == {"p1", "p3"} + + +# --------------------------------------------------------------------------- +# Tests: handle_devices_query +# --------------------------------------------------------------------------- + + +class TestHandleDevicesQuery: + """Tests for handle_devices_query.""" + + @pytest.mark.asyncio + async def test_returns_states(self) -> None: + """Test returns states.""" + mass = _make_mass([MockPlayer(player_id="p1", volume_level=75)]) + result = await handle_devices_query(mass, ["p1"]) + assert len(result.devices) == 1 + assert result.devices[0].id == "p1" + assert result.devices[0].error_code is None + + @pytest.mark.asyncio + async def test_unknown_device_returns_error(self) -> None: + """Test unknown device returns error.""" + mass = _make_mass([]) + result = await handle_devices_query(mass, ["missing"]) + assert len(result.devices) == 1 + assert result.devices[0].error_code == "DEVICE_UNREACHABLE" + + @pytest.mark.asyncio + async def test_unavailable_device_returns_error(self) -> None: + """Test unavailable device returns error.""" + mass = _make_mass([MockPlayer(player_id="p1", available=False)]) + result = await handle_devices_query(mass, ["p1"]) + assert result.devices[0].error_code == "DEVICE_UNREACHABLE" + + +# --------------------------------------------------------------------------- +# Tests: handle_devices_action +# --------------------------------------------------------------------------- + + +class TestHandleDevicesAction: + """Tests for handle_devices_action.""" + + @pytest.mark.asyncio + async def test_executes_on_off(self) -> None: + """Test executes on off.""" + mass = _make_mass([MockPlayer(player_id="p1")]) + payload = parse_action_payload( + { + "payload": { + "devices": [ + { + "id": "p1", + "capabilities": [ + { + "type": "devices.capabilities.on_off", + "state": {"instance": "on", "value": True}, + } + ], + } + ], + }, + } + ) + result = await handle_devices_action(mass, payload) + assert len(result.devices) == 1 + mass.players.cmd_play.assert_awaited_once_with("p1") + + @pytest.mark.asyncio + async def test_missing_device_returns_error(self) -> None: + """Test missing device returns error.""" + mass = _make_mass([]) + payload = parse_action_payload( + { + "payload": { + "devices": [ + { + "id": "missing", + "capabilities": [ + { + "type": "devices.capabilities.on_off", + "state": {"instance": "on", "value": True}, + } + ], + } + ], + }, + } + ) + result = await handle_devices_action(mass, payload) + assert ( + result.devices[0].capabilities[0].state.action_result.error_code == "DEVICE_UNREACHABLE" + ) + + +# --------------------------------------------------------------------------- +# Tests: handle_user_unlink +# --------------------------------------------------------------------------- + + +class TestHandleUserUnlink: + """Tests for handle_user_unlink.""" + + @pytest.mark.asyncio + async def test_returns_empty(self) -> None: + """Test returns empty.""" + result = await handle_user_unlink() + assert result == {} + + +# --------------------------------------------------------------------------- +# Tests: parse_action_payload +# --------------------------------------------------------------------------- + + +class TestParseActionPayload: + """Tests for parse_action_payload.""" + + def test_basic_payload(self) -> None: + """Test basic payload.""" + raw = { + "payload": { + "devices": [ + { + "id": "p1", + "capabilities": [ + { + "type": "devices.capabilities.on_off", + "state": {"instance": "on", "value": True}, + } + ], + } + ], + }, + } + payload = parse_action_payload(raw) + assert len(payload.devices) == 1 + assert payload.devices[0].id == "p1" + assert len(payload.devices[0].capabilities) == 1 + cap = payload.devices[0].capabilities[0] + assert cap.type == "devices.capabilities.on_off" + assert cap.state.instance == "on" + assert cap.state.value is True + assert cap.state.relative is False + + def test_relative_volume(self) -> None: + """Test relative volume.""" + raw = { + "payload": { + "devices": [ + { + "id": "p1", + "capabilities": [ + { + "type": "devices.capabilities.range", + "state": {"instance": "volume", "value": 10, "relative": True}, + } + ], + } + ], + }, + } + payload = parse_action_payload(raw) + cap = payload.devices[0].capabilities[0] + assert cap.state.relative is True + assert cap.state.value == 10 + + def test_multiple_devices(self) -> None: + """Test multiple devices.""" + raw = { + "payload": { + "devices": [ + {"id": "p1", "capabilities": []}, + {"id": "p2", "capabilities": []}, + ], + }, + } + payload = parse_action_payload(raw) + assert len(payload.devices) == 2 + + def test_unwrapped_payload(self) -> None: + """parse_action_payload should also handle messages without outer 'payload' key.""" + raw = { + "devices": [ + { + "id": "p1", + "capabilities": [ + { + "type": "devices.capabilities.toggle", + "state": {"instance": "pause", "value": True}, + } + ], + } + ], + } + payload = parse_action_payload(raw) + assert len(payload.devices) == 1 + + def test_skips_missing_or_empty_instance(self) -> None: + """Capability with missing/empty/non-string instance should be skipped.""" + raw = { + "payload": { + "devices": [ + { + "id": "p1", + "capabilities": [ + {"type": "devices.capabilities.on_off", "state": {"value": True}}, + { + "type": "devices.capabilities.on_off", + "state": {"instance": " ", "value": True}, + }, + { + "type": "devices.capabilities.on_off", + "state": {"instance": 42, "value": True}, + }, + ], + } + ], + }, + } + payload = parse_action_payload(raw) + assert len(payload.devices) == 1 + assert payload.devices[0].capabilities == [] + + def test_skips_non_bool_relative(self) -> None: + """Capability with non-bool `relative` field should be skipped.""" + raw = { + "payload": { + "devices": [ + { + "id": "p1", + "capabilities": [ + { + "type": "devices.capabilities.range", + "state": {"instance": "volume", "value": 10, "relative": "yes"}, + }, + ], + } + ], + }, + } + payload = parse_action_payload(raw) + assert payload.devices[0].capabilities == [] + + +# --------------------------------------------------------------------------- +# Tests: build_response +# --------------------------------------------------------------------------- + + +class TestBuildResponse: + """Tests for build_response.""" + + def test_dict_payload(self) -> None: + """Test dict payload.""" + resp = build_response("req-1", {"key": "val"}) + assert resp == {"request_id": "req-1", "payload": {"key": "val"}} + + def test_none_payload(self) -> None: + """Test none payload.""" + resp = build_response("req-1", None) + assert resp == {"request_id": "req-1", "payload": {}} + + def test_dataclass_payload(self) -> None: + """Test dataclass payload.""" + payload = DeviceListPayload(user_id="u1", devices=[]) + resp = build_response("req-1", payload) + assert resp["request_id"] == "req-1" + assert resp["payload"]["user_id"] == "u1" diff --git a/tests/providers/yandex_smarthome/test_notifier.py b/tests/providers/yandex_smarthome/test_notifier.py new file mode 100644 index 0000000000..f802035c08 --- /dev/null +++ b/tests/providers/yandex_smarthome/test_notifier.py @@ -0,0 +1,449 @@ +"""Tests for provider/notifier.py — StateNotifier batching and lifecycle.""" + +from __future__ import annotations + +import asyncio +from dataclasses import dataclass, field +from typing import Any +from unittest.mock import AsyncMock, MagicMock + +import aiohttp +import pytest + +# Use mock enums from conftest +from music_assistant_models.enums import EventType, PlaybackState + +from music_assistant.providers.yandex_smarthome.notifier import StateNotifier + + +@dataclass +class MockPlayer: + """Mock player for notifier tests.""" + + player_id: str = "p1" + name: str = "Speaker" + available: bool = True + enabled: bool = True + powered: bool | None = True + playback_state: Any = PlaybackState.PLAYING + volume_level: int | None = 50 + volume_muted: bool | None = False + synced_to: str | None = None + device_info: Any = None + supported_features: set[str] = field(default_factory=set) + source_list: list[str] = field(default_factory=list) + active_source: str | None = None + group_members: list[str] = field(default_factory=list) + + @property + def state(self) -> MockPlayer: + """Return self as state (mirrors real Player.state).""" + return self + + +@dataclass +class MockEvent: + """Mock event for notifier tests.""" + + event: str + data: Any = None + + +def _make_mass(players: list[MockPlayer] | None = None) -> MagicMock: + """Create a mock MusicAssistant.""" + mass = MagicMock() + mass.loop = MagicMock() + + if players is None: + players = [MockPlayer()] + mass.players.__iter__ = MagicMock(return_value=iter(players)) + mass.players.all_players = MagicMock(return_value=players) + + # subscribe returns an unsubscribe callable + mass.subscribe = MagicMock(return_value=MagicMock()) + + # create_task returns a mock Task that can be awaited + class _MockTask: + """Minimal awaitable mock task for testing.""" + + def __init__(self) -> None: + self._cancelled = False + self._done = False + + def done(self) -> bool: + return self._done or self._cancelled + + def cancel(self) -> bool: + self._cancelled = True + return True + + def __await__(self): # type: ignore[no-untyped-def] + if self._cancelled: + raise asyncio.CancelledError + yield + + mass.create_task = MagicMock(return_value=_MockTask()) + + return mass + + +def _make_notifier( + mass: MagicMock | None = None, + session: MagicMock | None = None, +) -> StateNotifier: + """Create a StateNotifier with mocks.""" + if mass is None: + mass = _make_mass() + if session is None: + session = MagicMock(spec=aiohttp.ClientSession) + + return StateNotifier( + mass=mass, + session=session, + user_id="test_user", + callback_url="https://example.com/callback/state", + auth_header={"Authorization": "Bearer test-token"}, + ) + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +class TestStateNotifierLifecycle: + """Tests for StateNotifier lifecycle methods.""" + + @pytest.mark.asyncio + async def test_start_subscribes(self) -> None: + """Test start subscribes.""" + mass = _make_mass() + notifier = _make_notifier(mass=mass) + + await notifier.start() + + mass.subscribe.assert_called_once() + assert notifier._unsub is not None + assert notifier._heartbeat_task is not None + + await notifier.stop() + + @pytest.mark.asyncio + async def test_stop_unsubscribes(self) -> None: + """Test stop unsubscribes.""" + mass = _make_mass() + notifier = _make_notifier(mass=mass) + + await notifier.start() + unsub = notifier._unsub + await notifier.stop() + + unsub.assert_called_once() # type: ignore[union-attr] + assert notifier._unsub is None + assert notifier._heartbeat_task is None + + @pytest.mark.asyncio + async def test_stop_without_start(self) -> None: + """Test stop without start.""" + notifier = _make_notifier() + # Should not raise + await notifier.stop() + + +class TestStateNotifierEvents: + """Tests for StateNotifier event handling.""" + + def test_on_player_updated_queues_state(self) -> None: + """Test on player updated marks player as dirty.""" + mass = _make_mass() + notifier = _make_notifier(mass=mass) + + player = MockPlayer(player_id="p1", playback_state=PlaybackState.PLAYING) + event = MockEvent(event=EventType.PLAYER_UPDATED, data=player) + + notifier._on_player_event(event) # type: ignore[arg-type] + + assert "p1" in notifier._dirty_player_ids + + def test_on_player_updated_unavailable_ignored(self) -> None: + """Test on player updated unavailable ignored.""" + mass = _make_mass() + notifier = _make_notifier(mass=mass) + + player = MockPlayer(player_id="p1", available=False) + event = MockEvent(event=EventType.PLAYER_UPDATED, data=player) + + notifier._on_player_event(event) # type: ignore[arg-type] + + assert "p1" not in notifier._dirty_player_ids + + def test_on_player_added_triggers_discovery(self) -> None: + """Test on player added triggers discovery.""" + mass = _make_mass() + notifier = _make_notifier(mass=mass) + + event = MockEvent(event=EventType.PLAYER_ADDED, data=MockPlayer()) + notifier._on_player_event(event) # type: ignore[arg-type] + + # Discovery triggers create_task + mass.create_task.assert_called() + + def test_on_player_removed_triggers_discovery(self) -> None: + """Test on player removed triggers discovery.""" + mass = _make_mass() + notifier = _make_notifier(mass=mass) + + event = MockEvent(event=EventType.PLAYER_REMOVED, data="p1") + notifier._on_player_event(event) # type: ignore[arg-type] + + mass.create_task.assert_called() + + def test_on_none_data_ignored(self) -> None: + """Test on none data ignored.""" + mass = _make_mass() + notifier = _make_notifier(mass=mass) + + event = MockEvent(event=EventType.PLAYER_UPDATED, data=None) + notifier._on_player_event(event) # type: ignore[arg-type] + + assert len(notifier._dirty_player_ids) == 0 + + def test_on_player_filtered_by_exposed_ids(self) -> None: + """Player not in exposed_ids should be ignored.""" + mass = _make_mass() + notifier = _make_notifier(mass=mass) + notifier._exposed_ids = {"p2", "p3"} + + player = MockPlayer(player_id="p1", playback_state=PlaybackState.PLAYING) + event = MockEvent(event=EventType.PLAYER_UPDATED, data=player) + notifier._on_player_event(event) # type: ignore[arg-type] + + assert "p1" not in notifier._dirty_player_ids + + def test_on_player_included_by_exposed_ids(self) -> None: + """Player in exposed_ids should be marked dirty.""" + mass = _make_mass() + notifier = _make_notifier(mass=mass) + notifier._exposed_ids = {"p1", "p2"} + + player = MockPlayer(player_id="p1", playback_state=PlaybackState.PLAYING) + event = MockEvent(event=EventType.PLAYER_UPDATED, data=player) + notifier._on_player_event(event) # type: ignore[arg-type] + + assert "p1" in notifier._dirty_player_ids + + def test_child_event_propagates_to_group(self) -> None: + """When a synced child fires PLAYER_UPDATED, the parent group is marked dirty.""" + mass = _make_mass() + notifier = _make_notifier(mass=mass) + + child = MockPlayer(player_id="child1", synced_to="grp1") + event = MockEvent(event=EventType.PLAYER_UPDATED, data=child) + notifier._on_player_event(event) # type: ignore[arg-type] + + assert "grp1" in notifier._dirty_player_ids + assert "child1" not in notifier._dirty_player_ids + + +class TestStateNotifierFlush: + """Tests for StateNotifier flush mechanism.""" + + @pytest.mark.asyncio + async def test_flush_sends_callback(self) -> None: + """Test flush reads fresh state and sends callback.""" + mock_resp = AsyncMock() + mock_resp.status = 200 + + session = MagicMock(spec=aiohttp.ClientSession) + ctx = MagicMock() + ctx.__aenter__ = AsyncMock(return_value=mock_resp) + ctx.__aexit__ = AsyncMock(return_value=False) + session.post.return_value = ctx + + player = MockPlayer(player_id="p1", volume_level=75) + mass = _make_mass([player]) + mass.players.get_player = MagicMock(return_value=player) + notifier = _make_notifier(mass=mass, session=session) + + notifier._dirty_player_ids.add("p1") + + await notifier._flush_pending() + + session.post.assert_called_once() + assert len(notifier._dirty_player_ids) == 0 + + @pytest.mark.asyncio + async def test_flush_empty_noop(self) -> None: + """Test flush empty noop.""" + session = MagicMock(spec=aiohttp.ClientSession) + mass = _make_mass() + notifier = _make_notifier(mass=mass, session=session) + + await notifier._flush_pending() + + session.post.assert_not_called() + + @pytest.mark.asyncio + async def test_flush_reads_fresh_volume(self) -> None: + """Volume should be read at flush time, not at event time.""" + mock_resp = AsyncMock() + mock_resp.status = 200 + + session = MagicMock(spec=aiohttp.ClientSession) + ctx = MagicMock() + ctx.__aenter__ = AsyncMock(return_value=mock_resp) + ctx.__aexit__ = AsyncMock(return_value=False) + session.post.return_value = ctx + + # Event arrives with volume 0 (transient state) + player_event = MockPlayer(player_id="p1", volume_level=0) + # But by flush time, player has correct volume + player_live = MockPlayer(player_id="p1", volume_level=75) + + mass = _make_mass([player_live]) + mass.players.get_player = MagicMock(return_value=player_live) + notifier = _make_notifier(mass=mass, session=session) + + # Simulate event with transient volume=0 + event = MockEvent(event=EventType.PLAYER_UPDATED, data=player_event) + notifier._on_player_event(event) # type: ignore[arg-type] + + # Flush should use live player state (volume=75) + await notifier._flush_pending() + + call_kwargs = session.post.call_args + json_body = call_kwargs.kwargs.get("json") or call_kwargs[1].get("json") + devices = json_body["payload"]["devices"] + volume_cap = next( + c for c in devices[0]["capabilities"] if c["state"]["instance"] == "volume" + ) + assert volume_cap["state"]["value"] == 75 + + +class TestStateNotifierReportAll: + """Tests for StateNotifier report-all-states.""" + + @pytest.mark.asyncio + async def test_report_all_states(self) -> None: + """Test report all states.""" + mock_resp = AsyncMock() + mock_resp.status = 200 + + session = MagicMock(spec=aiohttp.ClientSession) + ctx = MagicMock() + ctx.__aenter__ = AsyncMock(return_value=mock_resp) + ctx.__aexit__ = AsyncMock(return_value=False) + session.post.return_value = ctx + + players = [MockPlayer(player_id="p1"), MockPlayer(player_id="p2")] + mass = _make_mass(players) + notifier = _make_notifier(mass=mass, session=session) + + await notifier._report_all_states() + + session.post.assert_called_once() + # Verify the payload contains both devices + call_kwargs = session.post.call_args + json_body = call_kwargs.kwargs.get("json") or call_kwargs[1].get("json") + assert len(json_body["payload"]["devices"]) == 2 + + @pytest.mark.asyncio + async def test_report_all_no_players(self) -> None: + """Test report all no players.""" + session = MagicMock(spec=aiohttp.ClientSession) + mass = _make_mass([]) + notifier = _make_notifier(mass=mass, session=session) + + await notifier._report_all_states() + + session.post.assert_not_called() + + +class TestStateNotifierCloudPlus: + """Tests specific to Cloud Plus (Yandex Dialogs API) behaviour.""" + + @pytest.mark.asyncio + async def test_accepts_http_202(self) -> None: + """Yandex Dialogs returns 202 on successful callback — should not warn.""" + mock_resp = AsyncMock() + mock_resp.status = 202 + + session = MagicMock(spec=aiohttp.ClientSession) + ctx = MagicMock() + ctx.__aenter__ = AsyncMock(return_value=mock_resp) + ctx.__aexit__ = AsyncMock(return_value=False) + session.post.return_value = ctx + + mass = _make_mass() + notifier = StateNotifier( + mass=mass, + session=session, + user_id="cloud-instance-id", + callback_url="https://dialogs.yandex.net/api/v1/skills/test-uuid/callback/state", + auth_header={"Authorization": "OAuth test-oauth-token"}, + ) + + player = MockPlayer(player_id="p1") + mass.players.get_player = MagicMock(return_value=player) + notifier._dirty_player_ids.add("p1") + + await notifier._flush_pending() + + session.post.assert_called_once() + call_kwargs = session.post.call_args + assert "dialogs.yandex.net" in call_kwargs[0][0] + assert call_kwargs.kwargs["headers"]["Authorization"] == "OAuth test-oauth-token" + + @pytest.mark.asyncio + async def test_rejects_http_500(self) -> None: + """Non-success status codes should re-queue dirty IDs and raise.""" + mock_resp = AsyncMock() + mock_resp.status = 500 + mock_resp.text = AsyncMock(return_value="Internal Server Error") + + session = MagicMock(spec=aiohttp.ClientSession) + ctx = MagicMock() + ctx.__aenter__ = AsyncMock(return_value=mock_resp) + ctx.__aexit__ = AsyncMock(return_value=False) + session.post.return_value = ctx + + player = MockPlayer(player_id="p1") + mass = _make_mass([player]) + mass.players.get_player = MagicMock(return_value=player) + notifier = _make_notifier(mass=mass, session=session) + + notifier._dirty_player_ids.add("p1") + + with pytest.raises(RuntimeError, match="State callback failed"): + await notifier._flush_pending() + + session.post.assert_called_once() + # Player IDs should be re-queued after failure + assert "p1" in notifier._dirty_player_ids + + @pytest.mark.asyncio + async def test_discovery_url_cloud_plus(self) -> None: + """Discovery URL should use replace('/state', '/discovery') for Dialogs API.""" + mock_resp = AsyncMock() + mock_resp.status = 202 + + session = MagicMock(spec=aiohttp.ClientSession) + ctx = MagicMock() + ctx.__aenter__ = AsyncMock(return_value=mock_resp) + ctx.__aexit__ = AsyncMock(return_value=False) + session.post.return_value = ctx + + mass = _make_mass() + notifier = StateNotifier( + mass=mass, + session=session, + user_id="cloud-instance-id", + callback_url="https://dialogs.yandex.net/api/v1/skills/test-uuid/callback/state", + auth_header={"Authorization": "OAuth test-token"}, + ) + + await notifier._send_discovery() + + session.post.assert_called_once() + url = session.post.call_args[0][0] + assert url == "https://dialogs.yandex.net/api/v1/skills/test-uuid/callback/discovery" diff --git a/tests/providers/yandex_smarthome/test_schema.py b/tests/providers/yandex_smarthome/test_schema.py new file mode 100644 index 0000000000..ce46d436b0 --- /dev/null +++ b/tests/providers/yandex_smarthome/test_schema.py @@ -0,0 +1,285 @@ +"""Tests for provider/schema.py — dataclass models and enums.""" + +from __future__ import annotations + +from dataclasses import asdict + +from music_assistant.providers.yandex_smarthome.schema import ( + ActionRequestPayload, + ActionResult, + CallbackPayload, + CallbackRequest, + CapabilityAction, + CapabilityActionResult, + CapabilityActionResultState, + CapabilityActionState, + CapabilityDescription, + CapabilityInstanceState, + CapabilityParameters, + CapabilityState, + CloudRequest, + CloudResponse, + DeviceAction, + DeviceDescription, + DeviceListPayload, + DeviceState, + DeviceStatesPayload, + RangeParameters, + YandexCapabilityType, + YandexDeviceInfo, + YandexDeviceType, + YandexRangeInstance, + YandexResponseCode, + YandexToggleInstance, +) + +# --------------------------------------------------------------------------- +# Enum values +# --------------------------------------------------------------------------- + + +class TestEnums: + """Test enum string values match Yandex Smart Home API.""" + + def test_device_types(self) -> None: + """Test device types.""" + assert YandexDeviceType.MEDIA_DEVICE.value == "devices.types.media_device" + assert YandexDeviceType.MEDIA_DEVICE_RECEIVER.value == "devices.types.media_device.receiver" + + def test_capability_types(self) -> None: + """Test capability types.""" + assert YandexCapabilityType.ON_OFF.value == "devices.capabilities.on_off" + assert YandexCapabilityType.RANGE.value == "devices.capabilities.range" + assert YandexCapabilityType.TOGGLE.value == "devices.capabilities.toggle" + + def test_range_instances(self) -> None: + """Test range instances.""" + assert YandexRangeInstance.VOLUME.value == "volume" + + def test_toggle_instances(self) -> None: + """Test toggle instances.""" + assert YandexToggleInstance.MUTE.value == "mute" + assert YandexToggleInstance.PAUSE.value == "pause" + + def test_response_codes(self) -> None: + """Test response codes.""" + assert YandexResponseCode.DONE == "DONE" + assert YandexResponseCode.DEVICE_UNREACHABLE == "DEVICE_UNREACHABLE" + assert YandexResponseCode.INVALID_ACTION == "INVALID_ACTION" + assert YandexResponseCode.INTERNAL_ERROR == "INTERNAL_ERROR" + assert YandexResponseCode.DEVICE_NOT_FOUND == "DEVICE_NOT_FOUND" + + +# --------------------------------------------------------------------------- +# Serialization roundtrips +# --------------------------------------------------------------------------- + + +class TestDeviceDescription: + """Test DeviceDescription serialization.""" + + def test_minimal(self) -> None: + """Test minimal.""" + desc = DeviceDescription( + id="p1", name="Living Room", type=YandexDeviceType.MEDIA_DEVICE_RECEIVER + ) + data = asdict(desc) + assert data["id"] == "p1" + assert data["name"] == "Living Room" + assert data["type"] == "devices.types.media_device.receiver" + assert data["capabilities"] == [] + + def test_with_capabilities(self) -> None: + """Test with capabilities.""" + desc = DeviceDescription( + id="p1", + name="Speaker", + type=YandexDeviceType.MEDIA_DEVICE_RECEIVER, + capabilities=[ + CapabilityDescription(type=YandexCapabilityType.ON_OFF), + CapabilityDescription( + type=YandexCapabilityType.RANGE, + parameters=CapabilityParameters( + instance="volume", + range=RangeParameters(min=0, max=100, precision=1), + unit="unit.percent", + ), + ), + ], + device_info=YandexDeviceInfo(manufacturer="Test", model="X1"), + ) + data = asdict(desc) + assert len(data["capabilities"]) == 2 + assert data["capabilities"][0]["type"] == "devices.capabilities.on_off" + vol_cap = data["capabilities"][1] + assert vol_cap["parameters"]["instance"] == "volume" + assert vol_cap["parameters"]["range"]["max"] == 100 + assert data["device_info"]["manufacturer"] == "Test" + + +class TestDeviceState: + """Test DeviceState serialization.""" + + def test_with_capabilities(self) -> None: + """Test with capabilities.""" + state = DeviceState( + id="p1", + capabilities=[ + CapabilityState( + type=YandexCapabilityType.ON_OFF, + state=CapabilityInstanceState(instance="on", value=True), + ), + ], + ) + data = asdict(state) + assert data["id"] == "p1" + assert data["capabilities"][0]["state"]["value"] is True + assert data["error_code"] is None + + def test_error_state(self) -> None: + """Test error state.""" + state = DeviceState(id="p1", error_code="DEVICE_UNREACHABLE", error_message="Offline") + data = asdict(state) + assert data["error_code"] == "DEVICE_UNREACHABLE" + assert data["capabilities"] == [] + + +class TestActionRequestPayload: + """Test action request parsing structures.""" + + def test_single_device_action(self) -> None: + """Test single device action.""" + payload = ActionRequestPayload( + devices=[ + DeviceAction( + id="p1", + capabilities=[ + CapabilityAction( + type=YandexCapabilityType.ON_OFF, + state=CapabilityActionState(instance="on", value=True), + ), + ], + ), + ], + ) + data = asdict(payload) + assert len(data["devices"]) == 1 + assert data["devices"][0]["id"] == "p1" + cap = data["devices"][0]["capabilities"][0] + assert cap["state"]["relative"] is False + + def test_relative_volume(self) -> None: + """Test relative volume.""" + action = CapabilityAction( + type=YandexCapabilityType.RANGE, + state=CapabilityActionState(instance="volume", value=10, relative=True), + ) + data = asdict(action) + assert data["state"]["relative"] is True + assert data["state"]["value"] == 10 + + +class TestActionResult: + """Test action result structures.""" + + def test_success(self) -> None: + """Test success.""" + result = ActionResult(status="DONE") + assert asdict(result) == {"status": "DONE", "error_code": None, "error_message": None} + + def test_error(self) -> None: + """Test error.""" + result = ActionResult(status="ERROR", error_code="INVALID_ACTION", error_message="Oops") + data = asdict(result) + assert data["status"] == "ERROR" + assert data["error_code"] == "INVALID_ACTION" + + +class TestCallbackRequest: + """Test callback state request.""" + + def test_serialization(self) -> None: + """Test serialization.""" + req = CallbackRequest( + ts=1234567890.0, + payload=CallbackPayload( + user_id="test_user", + devices=[DeviceState(id="p1")], + ), + ) + data = asdict(req) + assert data["ts"] == 1234567890.0 + assert data["payload"]["user_id"] == "test_user" + assert len(data["payload"]["devices"]) == 1 + + +class TestCloudMessages: + """Test cloud WebSocket message models.""" + + def test_cloud_request(self) -> None: + """Test cloud request.""" + req = CloudRequest( + request_id="abc-123", action="/v1.0/user/devices", message={"key": "val"} + ) + assert req.request_id == "abc-123" + assert req.action == "/v1.0/user/devices" + assert req.message == {"key": "val"} + + def test_cloud_request_no_message(self) -> None: + """Test cloud request no message.""" + req = CloudRequest(request_id="abc", action="/v1.0/user/unlink") + assert req.message is None + + def test_cloud_response(self) -> None: + """Test cloud response.""" + resp = CloudResponse(request_id="abc", payload={"user_id": "u1"}) + data = asdict(resp) + assert data["request_id"] == "abc" + assert data["payload"]["user_id"] == "u1" + + +class TestDeviceListPayload: + """Test response payload structures.""" + + def test_empty(self) -> None: + """Test empty.""" + payload = DeviceListPayload(user_id="u1") + data = asdict(payload) + assert data["user_id"] == "u1" + assert data["devices"] == [] + + def test_with_devices(self) -> None: + """Test with devices.""" + payload = DeviceListPayload( + user_id="u1", + devices=[DeviceDescription(id="p1", name="Test", type="devices.types.media_device")], + ) + assert len(payload.devices) == 1 + + +class TestDeviceStatesPayload: + """Test query response payload.""" + + def test_serialization(self) -> None: + """Test serialization.""" + payload = DeviceStatesPayload(devices=[DeviceState(id="p1")]) + data = asdict(payload) + assert len(data["devices"]) == 1 + assert data["devices"][0]["id"] == "p1" + + +class TestCapabilityActionResult: + """Test action result with default factory.""" + + def test_default_result(self) -> None: + """Test default result.""" + result = CapabilityActionResult( + type=YandexCapabilityType.ON_OFF, + state=CapabilityActionResultState( + instance="on", + value=True, + action_result=ActionResult(status="DONE"), + ), + ) + assert result.state.action_result.status == "DONE" + assert result.state.action_result.error_code is None