diff --git a/music_assistant/providers/yandex_music/__init__.py b/music_assistant/providers/yandex_music/__init__.py index 7563d07e9b..105632fac0 100644 --- a/music_assistant/providers/yandex_music/__init__.py +++ b/music_assistant/providers/yandex_music/__init__.py @@ -2,26 +2,230 @@ from __future__ import annotations +import json from typing import TYPE_CHECKING, cast from music_assistant_models.config_entries import ConfigEntry, ConfigValueOption, ConfigValueType from music_assistant_models.enums import ConfigEntryType, ProviderFeature +from music_assistant_models.errors import InvalidDataError +from .auth import perform_device_auth, perform_qr_auth from .constants import ( + CONF_ACTION_AUTH_DEVICE, + CONF_ACTION_AUTH_QR, CONF_ACTION_CLEAR_AUTH, + CONF_ACTION_DELETE_WAVE_PRESET, + CONF_ACTION_SAVE_WAVE_PRESET, CONF_BASE_URL, CONF_LIKED_TRACKS_MAX_TRACKS, CONF_MY_WAVE_MAX_TRACKS, CONF_QUALITY, + CONF_REFRESH_TOKEN, + CONF_REMEMBER_SESSION, CONF_TOKEN, + CONF_WAVE_PRESET_DRAFT_DIVERSITY, + CONF_WAVE_PRESET_DRAFT_LANGUAGE, + CONF_WAVE_PRESET_DRAFT_MOOD, + CONF_WAVE_PRESET_DRAFT_NAME, + CONF_WAVE_PRESET_TO_DELETE, + CONF_WAVE_PRESETS_DATA, + CONF_X_TOKEN, DEFAULT_BASE_URL, QUALITY_BALANCED, QUALITY_EFFICIENT, QUALITY_HIGH, QUALITY_SUPERB, + WAVE_PRESET_DIVERSITY_VALUES, + WAVE_PRESET_LANGUAGE_VALUES, + WAVE_PRESET_MOOD_VALUES, ) +from .presets import parse_stored_presets as _parse_stored_presets from .provider import YandexMusicProvider + +def _save_wave_preset_action(values: dict[str, ConfigValueType]) -> None: + """Merge the current draft fields into the stored preset list. + + Overwrites an existing preset with the same name instead of creating a + duplicate. Clears draft fields after persisting so the UI returns to a + blank state. Raises ``InvalidDataError`` when the name is blank. + """ + name_raw = values.get(CONF_WAVE_PRESET_DRAFT_NAME) + name = name_raw.strip() if isinstance(name_raw, str) else "" + if not name: + raise InvalidDataError("Please fill the preset name before saving.") + presets = _parse_stored_presets(values.get(CONF_WAVE_PRESETS_DATA)) + presets = [p for p in presets if p["name"] != name] + new_preset: dict[str, str] = {"name": name} + for conf_key, api_key in ( + (CONF_WAVE_PRESET_DRAFT_DIVERSITY, "diversity"), + (CONF_WAVE_PRESET_DRAFT_MOOD, "moodEnergy"), + (CONF_WAVE_PRESET_DRAFT_LANGUAGE, "language"), + ): + val = values.get(conf_key) + if isinstance(val, str) and val: + new_preset[api_key] = val + presets.append(new_preset) + values[CONF_WAVE_PRESETS_DATA] = json.dumps(presets, ensure_ascii=False) + # Clear draft so the UI is ready for the next preset + values[CONF_WAVE_PRESET_DRAFT_NAME] = None + values[CONF_WAVE_PRESET_DRAFT_DIVERSITY] = "" + values[CONF_WAVE_PRESET_DRAFT_MOOD] = "" + values[CONF_WAVE_PRESET_DRAFT_LANGUAGE] = "" + + +def _delete_wave_preset_action(values: dict[str, ConfigValueType]) -> None: + """Remove the preset named by CONF_WAVE_PRESET_TO_DELETE from the store. + + Raises ``InvalidDataError`` when no name is selected. Idempotent — absent + names simply rewrite an unchanged list. + """ + target_raw = values.get(CONF_WAVE_PRESET_TO_DELETE) + target = target_raw.strip() if isinstance(target_raw, str) else "" + if not target: + raise InvalidDataError("Please select a preset to delete.") + presets = _parse_stored_presets(values.get(CONF_WAVE_PRESETS_DATA)) + presets = [p for p in presets if p["name"] != target] + values[CONF_WAVE_PRESETS_DATA] = json.dumps(presets, ensure_ascii=False) + values[CONF_WAVE_PRESET_TO_DELETE] = "" + + +def _wave_preset_config_entries(values: dict[str, ConfigValueType]) -> list[ConfigEntry]: + """Return the wave-preset builder UI (all advanced settings). + + Layout: + - Section label showing how many presets are saved. + - Four "draft" fields (name + three dropdowns) the user fills in. + - "Save preset" action → copies draft into the JSON store. + - "Delete preset" dropdown + action (hidden when no presets exist). + - Hidden STRING carrying the JSON store itself. + + Number of presets is unbounded; the user never edits JSON directly. + """ + empty_title = "— Default —" + diversity_options = [ + ConfigValueOption(title=empty_title if not v else v.title(), value=v) + for v in WAVE_PRESET_DIVERSITY_VALUES + ] + mood_options = [ + ConfigValueOption(title=empty_title if not v else v.title(), value=v) + for v in WAVE_PRESET_MOOD_VALUES + ] + language_options = [ + ConfigValueOption(title=empty_title if not v else v.replace("-", " ").title(), value=v) + for v in WAVE_PRESET_LANGUAGE_VALUES + ] + + presets = _parse_stored_presets(values.get(CONF_WAVE_PRESETS_DATA)) + has_presets = bool(presets) + delete_options = [ConfigValueOption(title=p["name"], value=p["name"]) for p in presets] + if not delete_options: + # Empty options can break some frontends; supply a no-op placeholder. + delete_options = [ConfigValueOption(title="(no presets saved)", value="")] + + def _str_value(key: str) -> str | None: + v = values.get(key) + return v if isinstance(v, str) else None + + return [ + ConfigEntry( + key="wave_preset_section_label", + type=ConfigEntryType.LABEL, + label=(f"My Wave presets ({len(presets)} saved)" if has_presets else "My Wave presets"), + advanced=True, + ), + ConfigEntry( + key=CONF_WAVE_PRESET_DRAFT_NAME, + type=ConfigEntryType.STRING, + label="New preset name", + description=( + "Give the preset a short name, pick up to three dropdowns " + "below and click Save. Saving the same name again overwrites." + ), + default_value=None, + required=False, + advanced=True, + value=_str_value(CONF_WAVE_PRESET_DRAFT_NAME), + ), + ConfigEntry( + key=CONF_WAVE_PRESET_DRAFT_DIVERSITY, + type=ConfigEntryType.STRING, + label="New preset: diversity", + description="How broadly the wave explores.", + options=diversity_options, + default_value="", + required=False, + advanced=True, + value=_str_value(CONF_WAVE_PRESET_DRAFT_DIVERSITY), + ), + ConfigEntry( + key=CONF_WAVE_PRESET_DRAFT_MOOD, + type=ConfigEntryType.STRING, + label="New preset: mood", + description="Energy and mood of the tracks.", + options=mood_options, + default_value="", + required=False, + advanced=True, + value=_str_value(CONF_WAVE_PRESET_DRAFT_MOOD), + ), + ConfigEntry( + key=CONF_WAVE_PRESET_DRAFT_LANGUAGE, + type=ConfigEntryType.STRING, + label="New preset: language", + description="Lyrics language filter.", + options=language_options, + default_value="", + required=False, + advanced=True, + value=_str_value(CONF_WAVE_PRESET_DRAFT_LANGUAGE), + ), + ConfigEntry( + key=CONF_ACTION_SAVE_WAVE_PRESET, + type=ConfigEntryType.ACTION, + label="Save preset", + description=( + "Adds the values above to Saved presets. The list is shown " + "under Radio then My Presets in Browse." + ), + action=CONF_ACTION_SAVE_WAVE_PRESET, + action_label="Save preset", + advanced=True, + ), + ConfigEntry( + key=CONF_WAVE_PRESET_TO_DELETE, + type=ConfigEntryType.STRING, + label="Select preset to delete", + options=delete_options, + default_value="", + required=False, + advanced=True, + hidden=not has_presets, + value=_str_value(CONF_WAVE_PRESET_TO_DELETE), + ), + ConfigEntry( + key=CONF_ACTION_DELETE_WAVE_PRESET, + type=ConfigEntryType.ACTION, + label="Delete selected preset", + description="Removes the selected preset from the Saved presets list.", + action=CONF_ACTION_DELETE_WAVE_PRESET, + action_label="Delete", + advanced=True, + hidden=not has_presets, + ), + ConfigEntry( + key=CONF_WAVE_PRESETS_DATA, + type=ConfigEntryType.STRING, + label="Saved presets (internal)", + default_value="", + required=False, + advanced=True, + hidden=True, + value=_str_value(CONF_WAVE_PRESETS_DATA) or "", + ), + ] + + if TYPE_CHECKING: from music_assistant_models.config_entries import ProviderConfig from music_assistant_models.provider import ProviderManifest @@ -34,14 +238,19 @@ ProviderFeature.LIBRARY_ALBUMS, ProviderFeature.LIBRARY_TRACKS, ProviderFeature.LIBRARY_PLAYLISTS, + ProviderFeature.LIBRARY_PODCASTS, + ProviderFeature.LIBRARY_AUDIOBOOKS, ProviderFeature.ARTIST_ALBUMS, ProviderFeature.ARTIST_TOPTRACKS, ProviderFeature.SEARCH, ProviderFeature.LIBRARY_ARTISTS_EDIT, ProviderFeature.LIBRARY_ALBUMS_EDIT, ProviderFeature.LIBRARY_TRACKS_EDIT, + ProviderFeature.LIBRARY_PODCASTS_EDIT, + ProviderFeature.LIBRARY_AUDIOBOOKS_EDIT, ProviderFeature.BROWSE, ProviderFeature.SIMILAR_TRACKS, + ProviderFeature.SIMILAR_ARTISTS, ProviderFeature.RECOMMENDATIONS, ProviderFeature.LYRICS, } @@ -55,7 +264,7 @@ async def setup( async def get_config_entries( - mass: MusicAssistant, # noqa: ARG001 + mass: MusicAssistant, instance_id: str | None = None, # noqa: ARG001 action: str | None = None, values: dict[str, ConfigValueType] | None = None, @@ -64,33 +273,142 @@ async def get_config_entries( if values is None: values = {} + # Handle QR auth action + if action == CONF_ACTION_AUTH_QR: + session_id = values.get("session_id") + if not session_id: + raise InvalidDataError("Missing session_id for QR authentication") + x_token, music_token = await perform_qr_auth(mass, str(session_id)) + values[CONF_TOKEN] = music_token + if values.get(CONF_REMEMBER_SESSION, True): + values[CONF_X_TOKEN] = x_token + else: + values[CONF_X_TOKEN] = None + # QR flow never yields a refresh_token — clear any stale one from a + # prior device-flow login so we don't leave dead credentials behind + values[CONF_REFRESH_TOKEN] = None + + # Handle Device Flow auth action (yields x_token + refresh_token, + # so we get silent auto-refresh on music-token AND x_token expiry) + if action == CONF_ACTION_AUTH_DEVICE: + session_id = values.get("session_id") + if not session_id: + raise InvalidDataError("Missing session_id for device authentication") + x_token, music_token, refresh_token = await perform_device_auth(mass, str(session_id)) + values[CONF_TOKEN] = music_token + if values.get(CONF_REMEMBER_SESSION, True): + values[CONF_X_TOKEN] = x_token + values[CONF_REFRESH_TOKEN] = refresh_token + else: + values[CONF_X_TOKEN] = None + values[CONF_REFRESH_TOKEN] = None + # Handle clear auth action if action == CONF_ACTION_CLEAR_AUTH: values[CONF_TOKEN] = None + values[CONF_X_TOKEN] = None + values[CONF_REFRESH_TOKEN] = None + + # Wave-preset save/delete actions mutate the hidden JSON store and clear + # the draft / selection fields so the UI re-renders in a clean state. + if action == CONF_ACTION_SAVE_WAVE_PRESET: + _save_wave_preset_action(values) + if action == CONF_ACTION_DELETE_WAVE_PRESET: + _delete_wave_preset_action(values) # Check if user is authenticated is_authenticated = bool(values.get(CONF_TOKEN)) + # Dynamic label text + if not is_authenticated: + label_text = ( + "Open a verification URL on any device and enter the short code, " + "or scan a QR code with the Yandex app on your phone.\n\n" + "Alternatively, you can enter a music token manually in the advanced settings." + ) + elif action in (CONF_ACTION_AUTH_QR, CONF_ACTION_AUTH_DEVICE): + label_text = "Authenticated to Yandex Music. Don't forget to save to complete setup." + else: + label_text = "Authenticated to Yandex Music." + return ( - # Authentication + # Status label ConfigEntry( - key=CONF_TOKEN, - type=ConfigEntryType.SECURE_STRING, - label="Yandex Music Token", - description="Enter your Yandex Music OAuth token. " - "See the documentation for how to obtain it.", - required=True, + key="label_text", + type=ConfigEntryType.LABEL, + label=label_text, + ), + # Device Flow authentication (primary) + ConfigEntry( + key=CONF_ACTION_AUTH_DEVICE, + type=ConfigEntryType.ACTION, + label="Login with device code", + description=("Open a verification URL on any device and enter the short code."), + action=CONF_ACTION_AUTH_DEVICE, + action_label="Login with device code", hidden=is_authenticated, - value=cast("str", values.get(CONF_TOKEN)) if values else None, ), + # QR authentication (alternative) + ConfigEntry( + key=CONF_ACTION_AUTH_QR, + type=ConfigEntryType.ACTION, + label="Login with QR code", + description="Opens a QR code page — scan it with the Yandex app on your phone.", + action=CONF_ACTION_AUTH_QR, + action_label="Login with QR code", + hidden=is_authenticated, + ), + # Remember session toggle + ConfigEntry( + key=CONF_REMEMBER_SESSION, + type=ConfigEntryType.BOOLEAN, + label="Remember session (auto-refresh token)", + description="When enabled, stores a long-lived session token to automatically " + "refresh your music token when it expires. When disabled, you must " + "re-authenticate manually when the token expires.", + default_value=True, + hidden=is_authenticated, + ), + # Clear auth ConfigEntry( key=CONF_ACTION_CLEAR_AUTH, type=ConfigEntryType.ACTION, label="Reset authentication", description="Clear the current authentication details.", action=CONF_ACTION_CLEAR_AUTH, + action_label="Reset authentication", hidden=not is_authenticated, ), + # Token storage (populated by QR action or manual entry) + ConfigEntry( + key=CONF_TOKEN, + type=ConfigEntryType.SECURE_STRING, + label="Yandex Music Token (manual)", + description="Advanced: manually enter a music token. " + "See the documentation for how to obtain it.", + required=True, + hidden=is_authenticated, + advanced=True, + value=cast("str", values.get(CONF_TOKEN)) if values else None, + ), + # x_token (internal storage, always hidden) + ConfigEntry( + key=CONF_X_TOKEN, + type=ConfigEntryType.SECURE_STRING, + label="Session token", + hidden=True, + required=False, + value=cast("str", values.get(CONF_X_TOKEN)) if values else None, + ), + # refresh_token (internal storage, always hidden — device flow only) + ConfigEntry( + key=CONF_REFRESH_TOKEN, + type=ConfigEntryType.SECURE_STRING, + label="Refresh token", + hidden=True, + required=False, + value=cast("str", values.get(CONF_REFRESH_TOKEN)) if values else None, + ), # Quality ConfigEntry( key=CONF_QUALITY, @@ -117,6 +435,8 @@ async def get_config_entries( required=False, advanced=True, ), + # User-defined wave presets: builder + save/delete actions (dynamic list) + *_wave_preset_config_entries(values), # Liked Tracks maximum tracks (advanced) ConfigEntry( key=CONF_LIKED_TRACKS_MAX_TRACKS, diff --git a/music_assistant/providers/yandex_music/api_client.py b/music_assistant/providers/yandex_music/api_client.py index b55bf31115..6e6d94b126 100644 --- a/music_assistant/providers/yandex_music/api_client.py +++ b/music_assistant/providers/yandex_music/api_client.py @@ -29,6 +29,7 @@ from music_assistant.helpers.throttle_retry import BYPASS_THROTTLER, Throttler if TYPE_CHECKING: + from ya_passport_auth import SecretStr from yandex_music import DownloadInfo from yandex_music.feed.feed import Feed from yandex_music.landing.chart_info import ChartInfo @@ -37,7 +38,7 @@ from yandex_music.rotor.dashboard import Dashboard from yandex_music.rotor.station_result import StationResult -from .constants import DEFAULT_LIMIT, ROTOR_STATION_MY_WAVE +from .constants import DEFAULT_LIMIT # get-file-info with quality=lossless returns FLAC; default /tracks/.../download-info often does not # Prefer flac-mp4/aac-mp4 (Yandex API moved to these formats around 2025) @@ -51,10 +52,10 @@ class YandexMusicClient: """Wrapper around yandex-music-api ClientAsync.""" - def __init__(self, token: str, base_url: str | None = None) -> None: + def __init__(self, token: SecretStr, base_url: str | None = None) -> None: """Initialize the Yandex Music client. - :param token: Yandex Music OAuth token. + :param token: Yandex Music OAuth token (wrapped in SecretStr). :param base_url: Optional API base URL (defaults to Yandex Music API). """ self._token = token @@ -79,7 +80,9 @@ async def connect(self) -> bool: :raises LoginFailed: If the token is invalid. """ try: - self._client = await ClientAsync(self._token, base_url=self._base_url).init() + self._client = await ClientAsync( + self._token.get_secret(), base_url=self._base_url + ).init() if self._client.me is None or self._client.me.account is None: raise LoginFailed("Failed to get account info") self._user_id = self._client.me.account.uid @@ -228,17 +231,6 @@ async def get_rotor_station_tracks( ordered = [order_map[tid] for tid in track_ids if tid in order_map] return (ordered, result.batch_id if result else None) - async def get_my_wave_tracks( - self, queue: str | int | None = None - ) -> tuple[list[YandexTrack], str | None]: - """Get tracks from the My Wave radio station. - - :param queue: Optional track ID of the last track from the previous batch (API uses it for - pagination; do not pass batch_id). - :return: Tuple of (list of track objects, batch_id for feedback). - """ - return await self.get_rotor_station_tracks(ROTOR_STATION_MY_WAVE, queue=queue) - async def send_rotor_station_feedback( self, station_id: str, @@ -260,26 +252,66 @@ async def send_rotor_station_feedback( :param total_played_seconds: Seconds played (for trackFinished, skip). :return: True if the request succeeded. """ - payload: dict[str, Any] = { - "type": feedback_type, - "timestamp": datetime.now(UTC).isoformat().replace("+00:00", "Z"), - } - if feedback_type == "radioStarted": - payload["from"] = "YandexMusicDesktopAppWindows" - if track_id is not None: - payload["trackId"] = track_id - if total_played_seconds is not None: - payload["totalPlayedSeconds"] = total_played_seconds - if batch_id is not None: - payload["batchId"] = batch_id - - async def _post(c: ClientAsync) -> bool: - url = f"{c.base_url}/rotor/station/{station_id}/feedback" - await c._request.post(url, payload) - return True + timestamp = datetime.now(UTC).isoformat().replace("+00:00", "Z") + + async def _send(c: ClientAsync) -> bool: + if feedback_type == "radioStarted": + return bool( + await c.rotor_station_feedback_radio_started( + station_id, + from_="YandexMusicDesktopAppWindows", + batch_id=batch_id, + timestamp=timestamp, + ) + ) + if feedback_type == "trackStarted": + if track_id is None: + return False + return bool( + await c.rotor_station_feedback_track_started( + station_id, + track_id=track_id, + batch_id=batch_id, + timestamp=timestamp, + ) + ) + if feedback_type == "trackFinished": + if track_id is None: + return False + return bool( + await c.rotor_station_feedback_track_finished( + station_id, + track_id=track_id, + total_played_seconds=float(total_played_seconds or 0), + batch_id=batch_id, + timestamp=timestamp, + ) + ) + if feedback_type == "skip": + if track_id is None: + return False + return bool( + await c.rotor_station_feedback_skip( + station_id, + track_id=track_id, + total_played_seconds=float(total_played_seconds or 0), + batch_id=batch_id, + timestamp=timestamp, + ) + ) + return bool( + await c.rotor_station_feedback( + station_id, + type_=feedback_type, + timestamp=timestamp, + track_id=track_id, + total_played_seconds=total_played_seconds, + batch_id=batch_id, + ) + ) try: - result = await self._call_no_retry(_post) + result = await self._call_no_retry(_send) LOGGER.debug( "Rotor feedback %s track_id=%s total_played_seconds=%s", feedback_type, @@ -294,6 +326,257 @@ async def _post(c: ClientAsync) -> bool: LOGGER.warning("Rotor feedback %s failed: %s", feedback_type, err) return False + # Rotor session API (new session-based endpoints) + # + # Yandex's newer rotor API models a wave as a long-lived session: + # POST /rotor/session/new → {radioSessionId, sequence, batchId} + # POST /rotor/session/{sessionId}/tracks → {sequence, batchId} + # POST /rotor/session/{sessionId}/feedback → {result: "ok"} + # All feedback events carry the same sessionId, so we no longer need to + # thread per-batch batch_ids through call sites the way the stations-based + # API forced us to. + + async def _rotor_session_request( + self, path: str, body: dict[str, Any], *, with_retry: bool = True + ) -> dict[str, Any] | None: + """POST a JSON body to /rotor/session/{path} and return parsed result. + + Reuses the MarshalX ClientAsync internal request object so we inherit + its auth headers and parsing. `json=` is forwarded to `aiohttp.request` + by MarshalX's `**kwargs` passthrough. + + :param path: Path suffix after /rotor/session/ (e.g. "new", + "{session_id}/tracks", "{session_id}/feedback"). + :param body: JSON body to send. + :param with_retry: When True (default), uses the same reconnect-on- + transient-connection-error path as normal data fetches — + appropriate for ``new`` and ``tracks`` which sit on the + user-facing browse/play path. Set to False for ``feedback``, + where a dropped request should be silently lost rather than + hammered against a potentially rate-limiting server. + :return: Parsed result dict, or None on failure. + """ + + async def _do(c: ClientAsync) -> dict[str, Any] | None: + base = getattr(c, "base_url", "https://api.music.yandex.net") + url = f"{base}/rotor/session/{path}" + LOGGER.debug("Rotor session POST %s body_keys=%s", path, list(body.keys())) + try: + result = await c._request.post(url, json=body) + except NetworkError: + # Let the outer retry wrapper see transient drops. On the + # no-retry path the outer `except` below swallows it silently. + if with_retry: + raise + LOGGER.debug("Rotor session POST %s: network error (no retry)", path) + return None + except BadRequestError as err: + # 4xx is terminal — server rejected the body; retry would only + # reproduce the same failure. + LOGGER.warning("Rotor session POST %s failed: %s", path, err) + return None + if isinstance(result, dict): + LOGGER.debug("Rotor session POST %s → result keys=%s", path, list(result.keys())) + return result + LOGGER.debug("Rotor session POST %s → non-dict result: %r", path, result) + return None + + runner = self._call_with_retry if with_retry else self._call_no_retry + try: + return await runner(_do) + except UnauthorizedError as err: + # Expired/invalidated token. Surface as LoginFailed so MA prompts + # for re-auth instead of the raw yandex_music exception bubbling + # through browse / play and crashing the caller. + LOGGER.warning("Rotor session POST %s: token no longer valid", path) + raise LoginFailed("Invalid Yandex Music token") from err + except (NetworkError, ProviderUnavailableError) as err: + LOGGER.warning("Rotor session POST %s failed: %s", path, err) + return None + + async def rotor_session_new( + self, + station_id: str, + *, + settings: dict[str, str] | None = None, + queue: list[str] | None = None, + ) -> tuple[str | None, list[YandexTrack], str | None]: + """Create a new rotor session. + + Sends `includeWaveModel: true` so Yandex applies its wave ML model and + `interactive: true` so the session is treated as foreground user play. + + :param station_id: Station ID (e.g. "user:onyourwave" or "track:123"). + :param settings: Optional {diversity, moodEnergy, language} — each + becomes an additional seed like "settingDiversity:discover". + :param queue: Optional initial track IDs in the queue; usually empty. + :return: Tuple of (radio_session_id, list of tracks, batch_id). + Any element may be None/[] on failure. + """ + seeds: list[str] = [station_id] + if settings: + for key, seed_name in ( + ("diversity", "settingDiversity"), + ("moodEnergy", "settingMoodEnergy"), + ("language", "settingLanguage"), + ): + val = settings.get(key) + if val: + seeds.append(f"{seed_name}:{val}") + body: dict[str, Any] = { + "seeds": seeds, + "queue": queue or [], + "includeTracksInResponse": True, + "includeWaveModel": True, + "interactive": True, + } + result = await self._rotor_session_request("new", body) + if not result: + return (None, [], None) + session_id = result.get("radioSessionId") + batch_id = result.get("batchId") + tracks = await self._hydrate_session_tracks(result.get("sequence") or []) + return (session_id, tracks, batch_id) + + async def rotor_session_tracks( + self, session_id: str, *, current_track_id: str + ) -> tuple[list[YandexTrack], str | None]: + """Fetch the next batch of tracks for an active rotor session. + + :param session_id: radioSessionId from rotor_session_new(). + :param current_track_id: Track ID just consumed from the previous batch + (Yandex uses it to decide what to return next). + :return: Tuple of (list of tracks, new batch_id). + """ + body = {"queue": [str(current_track_id)]} + result = await self._rotor_session_request(f"{session_id}/tracks", body) + if not result: + return ([], None) + batch_id = result.get("batchId") + tracks = await self._hydrate_session_tracks(result.get("sequence") or []) + return (tracks, batch_id) + + async def rotor_session_feedback( + self, + session_id: str, + event_type: str, + *, + track_id: str | None = None, + total_played_seconds: int | None = None, + batch_id: str | None = None, + ) -> bool: + """Send a feedback event for an active rotor session. + + Supports the Yandex rotor event types: radioStarted, trackStarted, + trackFinished, skip, like, dislike. For radioStarted the track_id goes + into `event.from`; all other types use `event.trackId`. Only + trackFinished and skip carry `totalPlayedSeconds`. + + :param session_id: radioSessionId. + :param event_type: rotor event type string. + :param track_id: Yandex track ID the event refers to (required for + everything except radioStarted without a seed). + :param total_played_seconds: seconds of the track that were played + (only meaningful for trackFinished / skip). + :param batch_id: batchId from the most recent rotor_session_{new,tracks} + response; anchors the event to a specific batch. + :return: True if the POST succeeded. + """ + timestamp = datetime.now(UTC).isoformat().replace("+00:00", "Z") + event: dict[str, Any] = {"type": event_type, "timestamp": timestamp} + if event_type == "radioStarted": + if track_id is not None: + event["from"] = str(track_id) + elif track_id is not None: + event["trackId"] = str(track_id) + if event_type in ("trackFinished", "skip") and total_played_seconds is not None: + event["totalPlayedSeconds"] = int(total_played_seconds) + body: dict[str, Any] = {"event": event} + if batch_id: + body["batchId"] = batch_id + LOGGER.debug( + "Rotor session feedback: session=%s event=%s track=%s secs=%s batch=%s", + session_id, + event_type, + track_id, + total_played_seconds, + batch_id, + ) + result = await self._rotor_session_request(f"{session_id}/feedback", body, with_retry=False) + return result is not None + + async def _hydrate_session_tracks(self, sequence: list[dict[str, Any]]) -> list[YandexTrack]: + """Extract track IDs from a rotor session sequence and hydrate via get_tracks. + + The session endpoints return tracks inline when includeTracksInResponse + is true, but full track objects (with download info, covers, etc.) are + fetched separately so parsed Track objects have the same shape as in + the rest of the provider. + + :param sequence: List of sequence items from a rotor session response. + :return: List of full track objects in the same order as `sequence`. + """ + track_ids: list[str] = [] + for seq in sequence: + tr = seq.get("track") if isinstance(seq, dict) else None + tid = None + if isinstance(tr, dict): + tid = tr.get("id") or tr.get("track_id") + if tid is not None: + track_ids.append(str(tid)) + if not track_ids: + return [] + try: + full_tracks = await self.get_tracks(track_ids) + except ResourceTemporarilyUnavailable as err: + LOGGER.warning("Rotor session track hydration failed: %s", err) + return [] + order_map = {str(t.id): t for t in full_tracks if hasattr(t, "id") and t.id} + return [order_map[tid] for tid in track_ids if tid in order_map] + + async def play_audio( + self, + *, + track_id: str, + album_id: str, + play_id: str, + track_length_seconds: int, + total_played_seconds: int, + end_position_seconds: int, + from_: str = "music_assistant-audiobook", + ) -> bool: + """Report playback progress for an audiobook chapter or podcast episode. + + Yandex persists this server-side so progress is visible across its + other clients. Failures are swallowed — progress sync is advisory and + must never abort pause/stop handling — so auth failures, rate-limits + and network blips all log at debug and return False. + """ + try: + return bool( + await self._call_no_retry( + lambda c: c.play_audio( + track_id=track_id, + album_id=album_id, + from_=from_, + play_id=play_id, + track_length_seconds=track_length_seconds, + total_played_seconds=total_played_seconds, + end_position_seconds=end_position_seconds, + ) + ) + ) + except ( + BadRequestError, + NetworkError, + ProviderUnavailableError, + UnauthorizedError, + LoginFailed, + ResourceTemporarilyUnavailable, + ) as err: + LOGGER.debug("play_audio failed for %s: %s", track_id, err) + return False + # Library methods async def get_liked_tracks(self) -> list[TrackShort]: @@ -558,27 +841,22 @@ async def get_album_with_tracks(self, album_id: str) -> YandexAlbum | None: """Get an album with its tracks. Uses the same semantics as the web client: albums/{id}/with-tracks - with resumeStream, richTracks, withListeningFinished when the library - passes them through. + with resumeStream, richTracks, withListeningFinished. :param album_id: Album ID. :return: Album object with tracks or None if not found. """ - - async def _fetch(c: ClientAsync) -> YandexAlbum | None: - try: - return await c.albums_with_tracks( + try: + return await self._call_with_retry( + lambda c: c.albums_with_tracks( album_id, - resumeStream=True, - richTracks=True, - withListeningFinished=True, + params={ + "resumeStream": "true", + "richTracks": "true", + "withListeningFinished": "true", + }, ) - except TypeError: - # Older yandex-music may not accept these kwargs - return await c.albums_with_tracks(album_id) - - try: - return await self._call_with_retry(_fetch) + ) except (BadRequestError, NetworkError, ProviderUnavailableError) as err: LOGGER.error("Error fetching album with tracks %s: %s", album_id, err) return None @@ -616,6 +894,59 @@ async def get_artist_albums( LOGGER.error("Error fetching artist albums %s: %s", artist_id, err) return [] + async def get_pins(self) -> Any | None: + """Get the user's pinned items (artists/albums/playlists/waves). + + :return: PinsList object or None on error. + """ + try: + return await self._call_with_retry(lambda c: c.pins()) + except (BadRequestError, NetworkError, ProviderUnavailableError) as err: + LOGGER.error("Error fetching pins: %s", err) + return None + + async def get_music_history(self) -> Any | None: + """Get the user's listening history (grouped by day). + + :return: MusicHistory object or None on error. + """ + try: + return await self._call_with_retry(lambda c: c.music_history()) + except (BadRequestError, NetworkError, ProviderUnavailableError) as err: + LOGGER.error("Error fetching music history: %s", err) + return None + + async def get_artist_about(self, artist_id: str) -> Any | None: + """Get artist enrichment info: description, monthly listeners, links. + + :param artist_id: Artist ID. + :return: ArtistAbout object or None on error/missing. + """ + try: + return await self._call_with_retry(lambda c: c.artists_about(artist_id)) + except (BadRequestError, NetworkError, ProviderUnavailableError) as err: + LOGGER.error("Error fetching artist about %s: %s", artist_id, err) + return None + + async def get_similar_artists( + self, artist_id: str, limit: int = DEFAULT_LIMIT + ) -> list[YandexArtist]: + """Get artists similar to the given one. + + :param artist_id: Artist ID. + :param limit: Maximum number of artists. + :return: List of similar artist objects. + """ + try: + result = await self._call_with_retry(lambda c: c.artists_similar(artist_id)) + if result is None or not result.similar_artists: + return [] + similar: list[YandexArtist] = result.similar_artists + return similar[:limit] + except (BadRequestError, NetworkError, ProviderUnavailableError) as err: + LOGGER.error("Error fetching similar artists %s: %s", artist_id, err) + return [] + async def get_artist_tracks( self, artist_id: str, limit: int = DEFAULT_LIMIT ) -> list[YandexTrack]: @@ -678,18 +1009,31 @@ async def get_track_download_info( LOGGER.error("Error fetching download info for track %s: %s", track_id, err) return [] - async def get_track_file_info_lossless(self, track_id: str) -> dict[str, Any] | None: - """Request lossless stream via get-file-info (quality=lossless). + async def get_track_file_info( + self, + track_id: str, + quality: str = "lossless", + codecs: str = GET_FILE_INFO_CODECS, + transport: str = "raw", + ) -> dict[str, Any] | None: + """Request stream via get-file-info for any quality tier. + + The /get-file-info endpoint supports all quality tiers (lossless, nq, lq) + and returns the best available codec based on the codecs parameter order. - The /tracks/{id}/download-info endpoint often returns only MP3; get-file-info - with quality=lossless and codecs=flac,... returns FLAC when available. + With transport="raw", returns a direct unencrypted URL. + With transport="encraw", returns an AES-CTR encrypted URL with decryption key. - Uses manual sign calculation matching yandex-music-downloader-realflac. Uses _call_with_retry for automatic reconnection on transient failures. :param track_id: Track ID. - :return: Parsed downloadInfo dict (url, codec, urls, ...) or None on error. + :param quality: Quality tier ("lossless", "nq", "lq"). + :param codecs: Comma-separated codec preference list. + :param transport: Transport mode ("raw" or "encraw"). + :return: Parsed downloadInfo dict (url, codec, key?, ...) or None on error. """ + # Normalize codecs: strip whitespace from each token to prevent HMAC mismatches + codecs = ",".join(c.strip() for c in codecs.split(",") if c.strip()) def _build_signed_params(client: ClientAsync) -> tuple[str, dict[str, Any]]: """Build URL and signed params using current client and timestamp. @@ -701,16 +1045,13 @@ def _build_signed_params(client: ClientAsync) -> tuple[str, dict[str, Any]]: params = { "ts": timestamp, "trackId": track_id, - "quality": "lossless", - "codecs": GET_FILE_INFO_CODECS, - "transports": "encraw", + "quality": quality, + "codecs": codecs, + "transports": transport, } - # Build sign string explicitly matching Yandex API specification: - # concatenate ts + trackId + quality + codecs (commas stripped) + transports. - # Comma stripping matches yandex-music-downloader-realflac reference implementation - # (see get_file_info signing in that project). - codecs_for_sign = GET_FILE_INFO_CODECS.replace(",", "") - param_string = f"{timestamp}{track_id}lossless{codecs_for_sign}encraw" + # Build sign string: ts + trackId + quality + codecs (commas stripped) + transports. + codecs_for_sign = codecs.replace(",", "") + param_string = f"{timestamp}{track_id}{quality}{codecs_for_sign}{transport}" hmac_sign = hmac.new( DEFAULT_SIGN_KEY.encode(), param_string.encode(), @@ -718,7 +1059,6 @@ def _build_signed_params(client: ClientAsync) -> tuple[str, dict[str, Any]]: ) # SHA-256 (32 bytes) -> base64 = 44 chars with "=" padding. # Yandex API expects exactly 43 chars (one "=" removed). - # Matches yandex-music-downloader-realflac reference implementation. params["sign"] = base64.b64encode(hmac_sign.digest()).decode()[:-1] url = f"{client.base_url}/get-file-info" return url, params @@ -726,7 +1066,9 @@ def _build_signed_params(client: ClientAsync) -> tuple[str, dict[str, Any]]: def _parse_file_info_result(raw: dict[str, Any] | None) -> dict[str, Any] | None: if not raw or not isinstance(raw, dict): return None - download_info = raw.get("download_info") + # yandex-music v3 no longer normalises camelCase keys inside + # Response.result, so /get-file-info returns "downloadInfo" as-is. + download_info = raw.get("download_info") or raw.get("downloadInfo") if not download_info or not download_info.get("url"): return None @@ -752,30 +1094,38 @@ async def _do_request(c: ClientAsync) -> dict[str, Any] | None: parsed = _parse_file_info_result(result) if parsed: LOGGER.debug( - "get-file-info lossless for track %s: Success, codec=%s", + "get-file-info for track %s: Success, codec=%s, transport=%s", track_id, parsed.get("codec"), + transport, ) return parsed - except (BadRequestError, NetworkError) as err: + except ( + BadRequestError, + NetworkError, + ProviderUnavailableError, + ResourceTemporarilyUnavailable, + ) as err: LOGGER.debug( - "get-file-info lossless for track %s: %s %s", + "get-file-info for track %s: %s %s", track_id, type(err).__name__, getattr(err, "message", str(err)) or repr(err), ) except UnauthorizedError as err: LOGGER.debug( - "get-file-info lossless for track %s: UnauthorizedError %s", + "get-file-info for track %s: UnauthorizedError %s", track_id, getattr(err, "message", str(err)) or repr(err), ) + except asyncio.CancelledError: + raise except Exception as err: LOGGER.warning( - "get-file-info lossless for track %s: Unexpected error: %s", + "get-file-info for track %s: Unexpected %s: %s", track_id, + type(err).__name__, err, - exc_info=True, ) return None diff --git a/music_assistant/providers/yandex_music/auth.py b/music_assistant/providers/yandex_music/auth.py new file mode 100644 index 0000000000..920fc75319 --- /dev/null +++ b/music_assistant/providers/yandex_music/auth.py @@ -0,0 +1,362 @@ +"""Yandex Music authentication flows. + +Two user-facing login paths, both backed by ``ya-passport-auth``: + +* **QR flow** — :func:`perform_qr_auth` opens a QR popup via the MA frontend + and polls Passport until the user scans/confirms. Yields + ``(x_token, music_token)``. +* **Device Flow** — :func:`perform_device_auth` serves a short user code on + an MA-hosted intermediate page and polls Passport until confirmation. + Yields the full ``(x_token, music_token, refresh_token)`` triple thanks + to ``ya-passport-auth`` v1.3.0 reusing the same Passport Android + ``client_id`` as the QR flow. + +Token maintenance helpers (:func:`refresh_music_token`, +:func:`refresh_credentials_via_passport`, :func:`validate_x_token`) live +alongside the login flows. +""" + +from __future__ import annotations + +import asyncio +import html +import json +import logging +from typing import TYPE_CHECKING + +from aiohttp import web +from music_assistant_models.errors import LoginFailed, ResourceTemporarilyUnavailable +from ya_passport_auth import Credentials, PassportClient, SecretStr +from ya_passport_auth.exceptions import ( + DeviceCodeTimeoutError, + NetworkError, + QRTimeoutError, + RateLimitedError, + YaPassportError, +) + +from music_assistant.helpers.auth import AuthenticationHelper + +if TYPE_CHECKING: + from music_assistant import MusicAssistant + +_LOGGER = logging.getLogger(__name__) + +_DEVICE_CODE_PAGE_PATH = "/yandex_music/device_code" +# Seconds to keep the status endpoint alive after the flow finishes so the +# intermediate page has a chance to poll once more and close itself. +_POST_AUTH_GRACE_SECONDS = 3 + + +def _build_device_code_page( + user_code: str, + verification_url: str, + status_url: str, +) -> str: + """Render the HTML page shown to the user during Device Flow login. + + Yandex's verification page does not pre-fill the code from query params, + and the MA frontend opens auth URLs in a new tab, so the user would + otherwise have no signal that authorization succeeded. The page polls the + status endpoint and closes itself (or shows a success message) when the + backend signals completion. + """ + safe_code = html.escape(user_code) + safe_url = html.escape(verification_url, quote=True) + # json.dumps emits a JS string literal, but `` would still break + # out of the surrounding + + +""" + + +async def perform_device_auth(mass: MusicAssistant, session_id: str) -> tuple[str, str, str]: + """Perform Yandex OAuth Device Flow and return credential tokens. + + Asks Yandex for a device code, presents it to the user via an intermediate + HTML page served from MA's own webserver, then polls until the user + confirms or the code expires. + + Returns (x_token, music_token, refresh_token) as plain strings for MA + config storage. + """ + try: + async with PassportClient.create() as client: + session = await client.start_device_login() + + _LOGGER.info( + "Device flow started: open %s (expires in %ss)", + session.verification_url, + session.expires_in, + ) + _LOGGER.debug("Device flow user_code issued") + + page_path = f"{_DEVICE_CODE_PAGE_PATH}/{session_id}" + status_path = f"{page_path}/status" + status_url = f"{mass.webserver.base_url}{status_path}" + state = {"value": "pending"} + + page_html = _build_device_code_page( + session.user_code, session.verification_url, status_url + ) + + async def _serve_page(_request: web.Request) -> web.Response: + return web.Response( + text=page_html, + content_type="text/html", + charset="utf-8", + headers={ + "Cache-Control": "no-store", + "Pragma": "no-cache", + "Expires": "0", + }, + ) + + async def _serve_status(_request: web.Request) -> web.Response: + return web.json_response( + {"state": state["value"]}, + headers={"Cache-Control": "no-store"}, + ) + + mass.webserver.register_dynamic_route(page_path, _serve_page, "GET") + mass.webserver.register_dynamic_route(status_path, _serve_status, "GET") + try: + async with AuthenticationHelper(mass, session_id) as auth_helper: + auth_helper.send_url(f"{mass.webserver.base_url}{page_path}") + try: + creds = await client.poll_device_until_confirmed(session) + except asyncio.CancelledError: + # Don't mark cancellations as auth failures. + raise + except Exception: + state["value"] = "failed" + # Give the page one more poll to surface the failure + # message before we tear the status route down. + await asyncio.sleep(_POST_AUTH_GRACE_SECONDS) + raise + state["value"] = "done" + # Give the intermediate page one more poll to pick up "done" + # and close itself before we tear the status route down. + await asyncio.sleep(_POST_AUTH_GRACE_SECONDS) + finally: + mass.webserver.unregister_dynamic_route(page_path, "GET") + mass.webserver.unregister_dynamic_route(status_path, "GET") + + music_token = creds.music_token + if music_token is None: + raise LoginFailed("Device auth succeeded but no music token was returned") + refresh_token = creds.refresh_token + if refresh_token is None: + raise LoginFailed("Device auth succeeded but no refresh token was returned") + + _LOGGER.debug("Device flow complete, obtained full credential triple") + return ( + creds.x_token.get_secret(), + music_token.get_secret(), + refresh_token.get_secret(), + ) + + except DeviceCodeTimeoutError as err: + raise LoginFailed("Device authentication timed out. Please try again.") from err + except YaPassportError as err: + raise LoginFailed(f"Yandex device auth error: {err}") from err + + +async def perform_qr_auth(mass: MusicAssistant, session_id: str) -> tuple[str, str]: + """Perform full QR authentication flow. + + Opens a QR code popup via MA frontend, polls for scan confirmation, + then returns tokens as plain strings for MA config storage. + + Returns (x_token, music_token). + """ + try: + async with PassportClient.create() as client: + qr = await client.start_qr_login() + + async with AuthenticationHelper(mass, session_id) as auth_helper: + auth_helper.send_url(qr.qr_url) + creds = await client.poll_qr_until_confirmed(qr) + + x_token = creds.x_token.get_secret() + music_token = creds.music_token + if music_token is None: + raise LoginFailed("QR auth succeeded but no music token was returned") + + _LOGGER.debug("QR auth complete, obtained both tokens") + return x_token, music_token.get_secret() + + except QRTimeoutError as err: + raise LoginFailed("QR authentication timed out. Please try again.") from err + except YaPassportError as err: + raise LoginFailed(f"Yandex auth error: {err}") from err + + +async def refresh_music_token(x_token: SecretStr) -> SecretStr: + """Exchange an x_token for a fresh music-scoped OAuth token. + + Distinguishes transient Passport failures (network/rate limiting) from + credential-invalid errors: only the latter raise ``LoginFailed``, so + callers don't clear stored tokens on a Passport blip. + """ + try: + async with PassportClient.create() as client: + return await client.refresh_music_token(x_token) + except (NetworkError, RateLimitedError) as err: + raise ResourceTemporarilyUnavailable( + f"Yandex Passport temporarily unavailable: {err}" + ) from err + except YaPassportError as err: + raise LoginFailed(f"Failed to refresh music token: {err}") from err + + +async def refresh_credentials_via_passport( + x_token: SecretStr, refresh_token: SecretStr +) -> Credentials: + """Silently re-issue the full credential triple using a refresh token. + + Only available for accounts authenticated via the Device Flow (QR login + does not yield a ``refresh_token``). Rotates both ``x_token`` and + ``refresh_token`` server-side, so callers must persist the returned + Credentials. + """ + try: + async with PassportClient.create() as client: + return await client.refresh_credentials( + Credentials(x_token=x_token, refresh_token=refresh_token) + ) + except (NetworkError, RateLimitedError) as err: + raise ResourceTemporarilyUnavailable( + f"Yandex Passport temporarily unavailable: {err}" + ) from err + except YaPassportError as err: + raise LoginFailed(f"Failed to refresh credentials: {err}") from err + + +async def validate_x_token(x_token: SecretStr) -> bool: + """Return True if *x_token* is still accepted by Yandex Passport.""" + try: + async with PassportClient.create() as client: + return bool(await client.validate_x_token(x_token)) + except YaPassportError: + return False diff --git a/music_assistant/providers/yandex_music/constants.py b/music_assistant/providers/yandex_music/constants.py index e86ed1d59b..0d4a1fa7e1 100644 --- a/music_assistant/providers/yandex_music/constants.py +++ b/music_assistant/providers/yandex_music/constants.py @@ -11,8 +11,15 @@ # Actions CONF_ACTION_AUTH = "auth" +CONF_ACTION_AUTH_QR = "auth_qr" +CONF_ACTION_AUTH_DEVICE = "auth_device" CONF_ACTION_CLEAR_AUTH = "clear_auth" +# QR authentication config keys +CONF_X_TOKEN = "x_token" +CONF_REFRESH_TOKEN: Final[str] = "refresh_token" +CONF_REMEMBER_SESSION = "remember_session" + # Labels LABEL_TOKEN = "token_label" LABEL_AUTH_INSTRUCTIONS = "auth_instructions_label" @@ -28,6 +35,35 @@ QUALITY_HIGH = "high" # High quality, lossy (~320kbps MP3) QUALITY_SUPERB = "superb" # Highest quality, lossless (FLAC) +# Transport modes for get-file-info API +CONF_TRANSPORT = "transport" +TRANSPORT_RAW = "raw" # Direct unencrypted stream (default) +TRANSPORT_ENCRAW = "encraw" # AES-CTR encrypted stream + +# Custom codecs override (empty = use quality-based default) +CONF_CODECS = "codecs" + +# Quality → get-file-info parameter mapping +# Codecs order determines API priority (first codec = preferred by server) +QUALITY_FILE_INFO_PARAMS: Final[dict[str, dict[str, str]]] = { + QUALITY_SUPERB: { + "quality": "lossless", + "codecs": "flac-mp4,flac,aac-mp4,aac,he-aac,mp3,he-aac-mp4", + }, + QUALITY_HIGH: { + "quality": "lossless", + "codecs": "mp3", + }, + QUALITY_BALANCED: { + "quality": "nq", + "codecs": "aac-mp4,aac,mp3,he-aac,he-aac-mp4", + }, + QUALITY_EFFICIENT: { + "quality": "lq", + "codecs": "he-aac-mp4,he-aac,aac,mp3", + }, +} + # Configuration keys for My Wave behavior (kept) CONF_MY_WAVE_MAX_TRACKS: Final[str] = "my_wave_max_tracks" @@ -76,6 +112,87 @@ # Composite item_id for My Wave tracks: track_id + separator + station_id (for rotor feedback) RADIO_TRACK_ID_SEP: Final[str] = "@" +# Wave-mode suffix separator: station keys like "user:onyourwave#discover" identify +# a specific preset (diversity/moodEnergy/language) on top of the base My Wave station. +# Chosen because # is not part of any rotor station ID format. +WAVE_MODE_SEP: Final[str] = "#" + +# Known wave-mode presets: preset key (suffix after WAVE_MODE_SEP) → rotor session +# settings dict. Names match the LMS YandexMusic plugin and the Desktop client UI. +MY_WAVE_MODES_FOLDER_ID: Final[str] = "my_wave_modes" +MY_WAVE_PRESETS_FOLDER_ID: Final[str] = "my_wave_presets" + +# User-defined wave presets are now stored in a single hidden JSON config key. +# The UI shows a small "builder" (name + three dropdowns) + Save / Delete +# action buttons, so the user never has to edit JSON by hand but has no fixed +# upper bound on preset count either. + +# Hidden JSON store. Shape: [{"name": str, "diversity"?: str, +# "moodEnergy"?: str, "language"?: str}, ...] +CONF_WAVE_PRESETS_DATA: Final[str] = "wave_presets_data" + +# Visible "working preset" fields — filled in, then copied into the JSON list +# by the save action and cleared afterwards. +CONF_WAVE_PRESET_DRAFT_NAME: Final[str] = "wave_preset_draft_name" +CONF_WAVE_PRESET_DRAFT_DIVERSITY: Final[str] = "wave_preset_draft_diversity" +CONF_WAVE_PRESET_DRAFT_MOOD: Final[str] = "wave_preset_draft_mood" +CONF_WAVE_PRESET_DRAFT_LANGUAGE: Final[str] = "wave_preset_draft_language" + +# Dropdown of saved preset names for the delete flow. +CONF_WAVE_PRESET_TO_DELETE: Final[str] = "wave_preset_to_delete" + +# Action button ids. +CONF_ACTION_SAVE_WAVE_PRESET: Final[str] = "save_wave_preset" +CONF_ACTION_DELETE_WAVE_PRESET: Final[str] = "delete_wave_preset" + +# Allowed per-dimension values (plus "" to mean "use wave default"). +WAVE_PRESET_DIVERSITY_VALUES: Final[tuple[str, ...]] = ( + "", + "discover", + "favorite", + "popular", +) +WAVE_PRESET_MOOD_VALUES: Final[tuple[str, ...]] = ( + "", + "active", + "fun", + "calm", + "sad", +) +WAVE_PRESET_LANGUAGE_VALUES: Final[tuple[str, ...]] = ( + "", + "russian", + "not-russian", + "without-words", +) + +WAVE_MODE_PRESETS: Final[dict[str, dict[str, str]]] = { + "discover": {"diversity": "discover"}, + "favorite": {"diversity": "favorite"}, + "popular": {"diversity": "popular"}, + "calm": {"moodEnergy": "calm"}, + "active": {"moodEnergy": "active"}, + "fun": {"moodEnergy": "fun"}, + "sad": {"moodEnergy": "sad"}, + "russian": {"language": "russian"}, + "not_russian": {"language": "not-russian"}, + "without_words": {"language": "without-words"}, +} + +# Ordered list of preset keys for Browse display. +WAVE_MODE_ORDER: Final[tuple[str, ...]] = ( + "discover", + "favorite", + "popular", + "calm", + "active", + "fun", + "sad", + "russian", + "not_russian", + "without_words", +) + # Browse folder names by locale (item_id -> display name) BROWSE_NAMES_RU: Final[dict[str, str]] = { "my_wave": "Моя волна", @@ -83,6 +200,8 @@ "albums": "Мои альбомы", "tracks": "Мне нравится", "playlists": "Мои плейлисты", + "audiobooks": "Мои аудиокниги", + "podcasts": "Мои подкасты", "feed": "Для вас", "chart": "Чарт", "new_releases": "Новинки", @@ -139,6 +258,8 @@ # Top-level browse groups "for_you": "Для вас", "collection": "Коллекция", + "pinned": "Закреплённое", + "history": "История прослушиваний", # Waves / Radio (rotor station categories) "waves": "Радио", "radio": "Радио", @@ -148,6 +269,19 @@ "genre": "Жанры", "epoch": "Эпоха", "local": "Местное", + # Wave-mode folder + presets (P4) + "my_wave_modes": "Режимы волны", + "my_wave_presets": "Мои пресеты", + "wave_mode_discover": "Открытия", + "wave_mode_favorite": "Любимое", + "wave_mode_popular": "Популярное", + "wave_mode_calm": "Спокойнее", + "wave_mode_active": "Активнее", + "wave_mode_fun": "Весёлое", + "wave_mode_sad": "Грустное", + "wave_mode_russian": "Русское", + "wave_mode_not_russian": "Не русское", # noqa: RUF001 + "wave_mode_without_words": "Без слов", } BROWSE_NAMES_EN: Final[dict[str, str]] = { "my_wave": "My Wave", @@ -155,6 +289,8 @@ "albums": "My Albums", "tracks": "My Favorites", "playlists": "My Playlists", + "audiobooks": "My Audiobooks", + "podcasts": "My Podcasts", "feed": "Made for You", "chart": "Chart", "new_releases": "New Releases", @@ -211,6 +347,8 @@ # Top-level browse groups "for_you": "For You", "collection": "Collection", + "pinned": "Pinned", + "history": "Listening History", # Waves / Radio (rotor station categories) "waves": "Radio", "radio": "Radio", @@ -220,6 +358,19 @@ "genre": "Genres", "epoch": "Era", "local": "Local", + # Wave-mode folder + presets (P4) + "my_wave_modes": "Wave Modes", + "my_wave_presets": "My Presets", + "wave_mode_discover": "Discover", + "wave_mode_favorite": "Favorites", + "wave_mode_popular": "Popular", + "wave_mode_calm": "Calm", + "wave_mode_active": "Active", + "wave_mode_fun": "Fun", + "wave_mode_sad": "Sad", + "wave_mode_russian": "Russian", + "wave_mode_not_russian": "Non-Russian", + "wave_mode_without_words": "Without Words", } # Tag categories for Picks and Recommendations @@ -335,6 +486,8 @@ # Top-level browse group folders FOR_YOU_FOLDER_ID: Final[str] = "for_you" COLLECTION_FOLDER_ID: Final[str] = "collection" +PINNED_ITEMS_FOLDER_ID: Final[str] = "pinned" +LISTENING_HISTORY_FOLDER_ID: Final[str] = "history" # Preferred display order for wave categories (rotor station types) WAVE_CATEGORY_DISPLAY_ORDER: Final[list[str]] = [ diff --git a/music_assistant/providers/yandex_music/manifest.json b/music_assistant/providers/yandex_music/manifest.json index 8fb8a38bb6..a96a5dc410 100644 --- a/music_assistant/providers/yandex_music/manifest.json +++ b/music_assistant/providers/yandex_music/manifest.json @@ -7,6 +7,6 @@ "codeowners": ["@TrudenBoy"], "credits": ["[yandex-music-api](https://github.com/MarshalX/yandex-music-api)"], "documentation": "https://music-assistant.io/music-providers/yandex-music/", - "requirements": ["yandex-music==2.2.0"], + "requirements": ["yandex-music==3.0.0", "ya-passport-auth==1.3.0"], "multi_instance": true } diff --git a/music_assistant/providers/yandex_music/parsers.py b/music_assistant/providers/yandex_music/parsers.py index fab0b27c75..8453f284a2 100644 --- a/music_assistant/providers/yandex_music/parsers.py +++ b/music_assistant/providers/yandex_music/parsers.py @@ -4,19 +4,23 @@ from contextlib import suppress from datetime import datetime -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Literal from music_assistant_models.enums import ( AlbumType, ContentType, ImageType, ) +from music_assistant_models.errors import InvalidDataError from music_assistant_models.media_items import ( Album, Artist, + Audiobook, AudioFormat, MediaItemImage, Playlist, + Podcast, + PodcastEpisode, ProviderMapping, Track, UniqueList, @@ -41,6 +45,32 @@ from .provider import YandexMusicProvider +AlbumKind = Literal["music", "podcast", "audiobook"] + + +def classify_album(album_obj: YandexAlbum) -> AlbumKind: + """Classify a Yandex album as music / podcast / audiobook. + + Checks both ``meta_type`` and ``type`` for the substrings "audiobook" / + "podcast". The more specific "audiobook" signal wins over "podcast" on any + field because Yandex tags audiobooks with ``meta_type="podcast"`` *and* + ``type="audiobook"`` — empirically observed in production libraries. + Values are not documented in the yandex_music SDK. + + :param album_obj: Yandex album object. + :return: One of "music", "podcast", "audiobook". + """ + fields = [ + (getattr(album_obj, "meta_type", None) or "").lower(), + (getattr(album_obj, "type", None) or "").lower(), + ] + if any("audiobook" in f for f in fields): + return "audiobook" + if any("podcast" in f for f in fields): + return "podcast" + return "music" + + def get_canonical_provider_name(provider: YandexMusicProvider) -> str: """Return the locale-aware canonical display name for the Yandex Music system account. @@ -68,13 +98,21 @@ def _get_image_url(cover_uri: str | None, size: str = IMAGE_SIZE_LARGE) -> str | return f"https://{cover_uri.replace('%%', size)}" -def parse_artist(provider: YandexMusicProvider, artist_obj: YandexArtist) -> Artist: +def parse_artist( + provider: YandexMusicProvider, + artist_obj: YandexArtist, + *, + about: object | None = None, +) -> Artist: """Parse Yandex artist object to MA Artist model. :param provider: The Yandex Music provider instance. :param artist_obj: Yandex artist object. + :param about: Optional ArtistAbout enrichment (description + listener stats). :return: Music Assistant Artist model. """ + if artist_obj.id is None: + raise InvalidDataError("Yandex artist missing id") artist_id = str(artist_obj.id) artist = Artist( item_id=artist_id, @@ -118,9 +156,45 @@ def parse_artist(provider: YandexMusicProvider, artist_obj: YandexArtist) -> Art ] ) + if about is not None: + description = getattr(about, "description", None) + if description: + artist.metadata.description = description + stats = getattr(about, "stats", None) + monthly = getattr(stats, "last_month_listeners", None) if stats else None + if monthly is not None: + artist.metadata.popularity = max(0, min(100, monthly // 10000)) + return artist +def _album_cover_images( + provider: YandexMusicProvider, album_obj: YandexAlbum +) -> UniqueList[MediaItemImage]: + """Build the UniqueList of images for an album-like object. + + Prefers the templated ``cover_uri`` and falls back to ``og_image`` — matches + the selection rules used for podcasts and audiobooks so all album-like + parsers stay in sync. + """ + images: UniqueList[MediaItemImage] = UniqueList() + image_url: str | None = None + if album_obj.cover_uri: + image_url = _get_image_url(album_obj.cover_uri) + elif album_obj.og_image: + image_url = _get_image_url(album_obj.og_image) + if image_url: + images.append( + MediaItemImage( + type=ImageType.THUMB, + path=image_url, + provider=provider.instance_id, + remotely_accessible=True, + ) + ) + return images + + def parse_album(provider: YandexMusicProvider, album_obj: YandexAlbum) -> Album: """Parse Yandex album object to MA Album model. @@ -128,6 +202,8 @@ def parse_album(provider: YandexMusicProvider, album_obj: YandexAlbum) -> Album: :param album_obj: Yandex album object. :return: Music Assistant Album model. """ + if album_obj.id is None: + raise InvalidDataError("Yandex album missing id") name, version = parse_title_and_version( album_obj.title or "Unknown Album", album_obj.version or None, @@ -184,33 +260,9 @@ def parse_album(provider: YandexMusicProvider, album_obj: YandexAlbum) -> Album: if album_obj.genre: album.metadata.genres = {album_obj.genre} - # Add cover image - if album_obj.cover_uri: - image_url = _get_image_url(album_obj.cover_uri) - if image_url: - album.metadata.images = UniqueList( - [ - MediaItemImage( - type=ImageType.THUMB, - path=image_url, - provider=provider.instance_id, - remotely_accessible=True, - ) - ] - ) - elif album_obj.og_image: - image_url = _get_image_url(album_obj.og_image) - if image_url: - album.metadata.images = UniqueList( - [ - MediaItemImage( - type=ImageType.THUMB, - path=image_url, - provider=provider.instance_id, - remotely_accessible=True, - ) - ] - ) + images = _album_cover_images(provider, album_obj) + if images: + album.metadata.images = images return album @@ -229,6 +281,8 @@ def parse_track( :param lyrics_synced: Whether lyrics are in synced LRC format. :return: Music Assistant Track model. """ + if track_obj.id is None: + raise InvalidDataError("Yandex track missing id") name, version = parse_title_and_version( track_obj.title or "Unknown Track", track_obj.version or None, @@ -306,13 +360,21 @@ def parse_track( def parse_playlist( - provider: YandexMusicProvider, playlist_obj: YandexPlaylist, owner_name: str | None = None + provider: YandexMusicProvider, + playlist_obj: YandexPlaylist, + owner_name: str | None = None, + *, + is_dynamic: bool = False, ) -> Playlist: """Parse Yandex playlist object to MA Playlist model. :param provider: The Yandex Music provider instance. :param playlist_obj: Yandex playlist object. :param owner_name: Optional owner name override. + :param is_dynamic: Mark the playlist as dynamic so Music Assistant does + not long-cache its content. Yandex regenerates "Playlist of the Day", + "DejaVu", "Premiere" etc. on a schedule, and those need a fresh read + on every browse so users actually see the updated selection. :return: Music Assistant Playlist model. """ # Playlist ID in Yandex is a combination of owner uid and playlist kind @@ -351,6 +413,7 @@ def parse_playlist( ) }, is_editable=is_editable, + is_dynamic=is_dynamic, ) # Metadata @@ -389,3 +452,228 @@ def parse_playlist( ) return playlist + + +def parse_podcast(provider: YandexMusicProvider, album_obj: YandexAlbum) -> Podcast: + """Parse Yandex album (meta_type=podcast) to MA Podcast model. + + :param provider: The Yandex Music provider instance. + :param album_obj: Yandex album object classified as a podcast. + :return: Music Assistant Podcast model. + """ + if album_obj.id is None: + raise InvalidDataError("Yandex podcast missing id") + name, _ = parse_title_and_version( + album_obj.title or "Unknown Podcast", + album_obj.version or None, + ) + podcast_id = str(album_obj.id) + available = album_obj.available or False + + # Publisher: prefer labels[0].name; fall back to first artist name + publisher: str | None = None + labels = getattr(album_obj, "labels", None) + if labels: + first = labels[0] + label_name = getattr(first, "name", None) if not isinstance(first, str) else first + if label_name: + publisher = label_name + if not publisher and album_obj.artists: + first_artist = album_obj.artists[0] + if first_artist.name: + publisher = first_artist.name + + podcast = Podcast( + item_id=podcast_id, + provider=provider.instance_id, + name=name, + provider_mappings={ + ProviderMapping( + item_id=podcast_id, + provider_domain=provider.domain, + provider_instance=provider.instance_id, + audio_format=AudioFormat(content_type=ContentType.UNKNOWN), + url=f"{WEB_BASE_URL}/album/{podcast_id}", + available=available, + ) + }, + publisher=publisher, + total_episodes=album_obj.track_count, + ) + + description = album_obj.description or album_obj.short_description + if description: + podcast.metadata.description = description + if album_obj.content_warning: + podcast.metadata.explicit = album_obj.content_warning == "explicit" + + images = _album_cover_images(provider, album_obj) + if images: + podcast.metadata.images = images + + if album_obj.genre: + podcast.metadata.genres = {album_obj.genre} + else: + podcast.metadata.genres = {"Spoken Word"} + + if album_obj.release_date: + with suppress(ValueError): + podcast.metadata.release_date = datetime.fromisoformat(album_obj.release_date) + + return podcast + + +def parse_podcast_episode( + provider: YandexMusicProvider, + track_obj: YandexTrack, + podcast: Podcast, + position: int = 0, +) -> PodcastEpisode: + """Parse Yandex track (episode of a podcast album) to MA PodcastEpisode. + + :param provider: The Yandex Music provider instance. + :param track_obj: Yandex track object. + :param podcast: Parent Podcast object. + :param position: 1-based episode index (0 if unknown). + :return: Music Assistant PodcastEpisode model. + """ + if track_obj.id is None: + raise InvalidDataError("Yandex podcast episode missing id") + episode_id = str(track_obj.id) + available = track_obj.available or False + duration = (track_obj.duration_ms or 0) // 1000 + + episode_name = track_obj.title or (f"Episode {position}" if position else "Unknown Episode") + episode = PodcastEpisode( + item_id=episode_id, + provider=provider.instance_id, + name=episode_name, + duration=duration, + podcast=podcast, + position=position, + provider_mappings={ + ProviderMapping( + item_id=episode_id, + provider_domain=provider.domain, + provider_instance=provider.instance_id, + audio_format=AudioFormat(content_type=ContentType.UNKNOWN), + url=f"{WEB_BASE_URL}/track/{episode_id}", + available=available, + ) + }, + ) + + if track_obj.short_description: + episode.metadata.description = track_obj.short_description + if track_obj.content_warning: + episode.metadata.explicit = track_obj.content_warning == "explicit" + + # Track cover → fall back to podcast cover + if track_obj.cover_uri: + image_url = _get_image_url(track_obj.cover_uri) + if image_url: + episode.metadata.images = UniqueList( + [ + MediaItemImage( + type=ImageType.THUMB, + path=image_url, + provider=provider.instance_id, + remotely_accessible=True, + ) + ] + ) + elif track_obj.og_image: + image_url = _get_image_url(track_obj.og_image) + if image_url: + episode.metadata.images = UniqueList( + [ + MediaItemImage( + type=ImageType.THUMB, + path=image_url, + provider=provider.instance_id, + remotely_accessible=True, + ) + ] + ) + if not episode.metadata.images and podcast.metadata.images: + episode.metadata.images = UniqueList(podcast.metadata.images) + + return episode + + +def parse_audiobook(provider: YandexMusicProvider, album_obj: YandexAlbum) -> Audiobook: + """Parse Yandex album (meta_type=audiobook) to MA Audiobook model. + + :param provider: The Yandex Music provider instance. + :param album_obj: Yandex album object classified as an audiobook. + :return: Music Assistant Audiobook model. Chapters and duration are filled + by the provider's get_audiobook() method after loading album tracks. + """ + if album_obj.id is None: + raise InvalidDataError("Yandex audiobook missing id") + name, _ = parse_title_and_version( + album_obj.title or "Unknown Audiobook", + album_obj.version or None, + ) + audiobook_id = str(album_obj.id) + available = album_obj.available or False + + # Publisher: prefer labels[0]; fall back to nothing (authors sit on artists) + publisher: str | None = None + labels = getattr(album_obj, "labels", None) + if labels: + first = labels[0] + label_name = getattr(first, "name", None) if not isinstance(first, str) else first + if label_name: + publisher = label_name + + authors: UniqueList[str] = UniqueList() + if album_obj.artists: + for artist in album_obj.artists: + if artist.name: + authors.append(artist.name) + + audiobook = Audiobook( + item_id=audiobook_id, + provider=provider.instance_id, + name=name, + provider_mappings={ + ProviderMapping( + item_id=audiobook_id, + provider_domain=provider.domain, + provider_instance=provider.instance_id, + audio_format=AudioFormat(content_type=ContentType.UNKNOWN), + url=f"{WEB_BASE_URL}/album/{audiobook_id}", + available=available, + ) + }, + publisher=publisher, + authors=authors, + narrators=UniqueList(), + duration=0, + ) + + description = album_obj.description or album_obj.short_description + if description: + audiobook.metadata.description = description + if album_obj.content_warning: + audiobook.metadata.explicit = album_obj.content_warning == "explicit" + + images = _album_cover_images(provider, album_obj) + if images: + audiobook.metadata.images = images + + if album_obj.genre: + audiobook.metadata.genres = {album_obj.genre} + else: + audiobook.metadata.genres = {"Spoken Word"} + + if album_obj.release_date: + with suppress(ValueError): + audiobook.metadata.release_date = datetime.fromisoformat(album_obj.release_date) + + listening_finished = getattr(album_obj, "listening_finished", None) + if listening_finished is not None: + audiobook.fully_played = bool(listening_finished) + + return audiobook diff --git a/music_assistant/providers/yandex_music/presets.py b/music_assistant/providers/yandex_music/presets.py new file mode 100644 index 0000000000..49e085208d --- /dev/null +++ b/music_assistant/providers/yandex_music/presets.py @@ -0,0 +1,51 @@ +"""Shared helpers for user-defined wave presets. + +Both the settings UI (in ``__init__.py``) and the Browse handler (in +``provider.py``) need to read the same JSON-encoded preset store. +Keeping the decoding + validation in one place avoids schema drift. +""" + +from __future__ import annotations + +import json + + +def parse_stored_presets(raw: object) -> list[dict[str, str]]: + """Decode the hidden JSON wave-presets store into a sanitised list. + + Only entries with a non-empty ``name`` string are kept. The optional + ``diversity`` / ``moodEnergy`` / ``language`` fields are carried through + as long as they are non-empty *after stripping whitespace* — Yandex + would reject rotor seeds like ``settingDiversity: `` (space) with a 4xx, + so we never pass such values through. Stripped form is stored so the + downstream seed builder gets the canonical value. Any other keys are + dropped. Malformed JSON, non-list roots or non-dict items yield an + empty list — the UI treats that as "no presets yet". + + :param raw: Value read from the ``wave_presets_data`` config entry. + :return: List of sanitised preset dicts in source order. + """ + if not isinstance(raw, str) or not raw.strip(): + return [] + try: + parsed = json.loads(raw) + except json.JSONDecodeError: + return [] + if not isinstance(parsed, list): + return [] + result: list[dict[str, str]] = [] + for item in parsed: + if not isinstance(item, dict): + continue + name = item.get("name") + if not isinstance(name, str) or not name.strip(): + continue + clean: dict[str, str] = {"name": name.strip()} + for key in ("diversity", "moodEnergy", "language"): + val = item.get(key) + if isinstance(val, str): + val = val.strip() + if val: + clean[key] = val + result.append(clean) + return result diff --git a/music_assistant/providers/yandex_music/provider.py b/music_assistant/providers/yandex_music/provider.py index 5f013008e5..3f65636585 100644 --- a/music_assistant/providers/yandex_music/provider.py +++ b/music_assistant/providers/yandex_music/provider.py @@ -5,12 +5,13 @@ import asyncio import logging import random +import uuid from collections.abc import AsyncGenerator, Sequence from datetime import UTC, datetime from io import BytesIO from typing import TYPE_CHECKING, Any -from music_assistant_models.enums import ImageType, MediaType, ProviderFeature +from music_assistant_models.enums import ImageType, MediaType, ProviderFeature, StreamType from music_assistant_models.errors import ( InvalidDataError, LoginFailed, @@ -21,23 +22,30 @@ from music_assistant_models.media_items import ( Album, Artist, + Audiobook, BrowseFolder, ItemMapping, + MediaItemChapter, MediaItemImage, MediaItemType, Playlist, + Podcast, + PodcastEpisode, ProviderMapping, RecommendationFolder, SearchResults, Track, UniqueList, ) +from music_assistant_models.streamdetails import StreamDetails from PIL import Image as PilImage +from ya_passport_auth import SecretStr from music_assistant.controllers.cache import use_cache from music_assistant.models.music_provider import MusicProvider from .api_client import YandexMusicClient +from .auth import refresh_credentials_via_passport, refresh_music_token from .constants import ( BROWSE_INITIAL_TRACKS, BROWSE_NAMES_EN, @@ -46,17 +54,27 @@ CONF_BASE_URL, CONF_LIKED_TRACKS_MAX_TRACKS, CONF_MY_WAVE_MAX_TRACKS, + CONF_QUALITY, + CONF_REFRESH_TOKEN, CONF_TOKEN, + CONF_WAVE_PRESETS_DATA, + CONF_X_TOKEN, DEFAULT_BASE_URL, DISCOVERY_INITIAL_TRACKS, FOR_YOU_FOLDER_ID, IMAGE_SIZE_MEDIUM, LIKED_TRACKS_PLAYLIST_ID, + LISTENING_HISTORY_FOLDER_ID, MY_WAVE_BATCH_SIZE, + MY_WAVE_MODES_FOLDER_ID, MY_WAVE_PLAYLIST_ID, + MY_WAVE_PRESETS_FOLDER_ID, MY_WAVES_FOLDER_ID, MY_WAVES_SET_FOLDER_ID, + PINNED_ITEMS_FOLDER_ID, PLAYLIST_ID_SPLITTER, + QUALITY_BALANCED, + QUALITY_SUPERB, RADIO_FOLDER_ID, RADIO_TRACK_ID_SEP, ROTOR_STATION_MY_WAVE, @@ -70,6 +88,9 @@ TAG_SLUG_CATEGORY, TRACK_BATCH_SIZE, WAVE_CATEGORY_DISPLAY_ORDER, + WAVE_MODE_ORDER, + WAVE_MODE_PRESETS, + WAVE_MODE_SEP, WAVES_FOLDER_ID, WAVES_LANDING_FOLDER_ID, ) @@ -77,16 +98,49 @@ _get_image_url as get_image_url, ) from .parsers import ( + classify_album, get_canonical_provider_name, parse_album, parse_artist, + parse_audiobook, parse_playlist, + parse_podcast, + parse_podcast_episode, parse_track, ) +from .presets import parse_stored_presets from .streaming import YandexMusicStreamingManager if TYPE_CHECKING: - from music_assistant_models.streamdetails import StreamDetails + from yandex_music import Album as YandexAlbum + from yandex_music import Track as YandexTrack + + +# MediaType sub-paths that MA's default MusicProvider.browse() understands. +# Used by the Collection dispatcher to delegate nested paths back to core. +_COLLECTION_SUB_FOLDERS: frozenset[str] = frozenset( + {"tracks", "artists", "albums", "playlists", "audiobooks", "podcasts"} +) + + +def _split_wave_mode(station_id: str) -> tuple[str, dict[str, str]]: + """Split a wave-mode station key into its base station ID and preset settings. + + Keys like ``user:onyourwave#discover`` encode a specific preset on top of + the base rotor station. The part before ``#`` is the station ID that goes + to Yandex; the part after is a key into WAVE_MODE_PRESETS. + + :param station_id: Station key, with or without a ``#preset`` suffix. + :return: Tuple of (base_station_id, settings_dict). The suffix, if + present, is always stripped — only the base station goes to + Yandex. ``settings_dict`` is the preset's settings when the suffix + matches a known WAVE_MODE_PRESETS key, or an empty dict otherwise + (unknown suffix → base station fired with no extra seeds). + """ + if WAVE_MODE_SEP not in station_id: + return (station_id, {}) + base, preset = station_id.split(WAVE_MODE_SEP, 1) + return (base, dict(WAVE_MODE_PRESETS.get(preset, {}))) def _parse_radio_item_id(item_id: str) -> tuple[str, str | None]: @@ -104,14 +158,39 @@ def _parse_radio_item_id(item_id: str) -> tuple[str, str | None]: return (item_id, None) +def _extract_chapter_map_from_album(album: YandexAlbum) -> tuple[list[str], list[int]]: + """Flatten an audiobook album's volumes into (chapter_track_ids, chapter_durations_ms). + + Shared by ``_get_audiobook_stream_details`` and ``_resolve_audiobook_chapter_map`` + so the two code paths can't drift (e.g. when we later filter bad tracks). + """ + chapter_ids: list[str] = [] + chapter_durations_ms: list[int] = [] + for disc in album.volumes or []: + for track_obj in disc: + chapter_ids.append(str(track_obj.id)) + chapter_durations_ms.append(int(track_obj.duration_ms or 0)) + return chapter_ids, chapter_durations_ms + + class _WaveState: - """Per-station mutable state for rotor wave playback.""" + """Per-station mutable state for rotor wave playback. + + Holds both the new session-based rotor identifiers (`session_id`) and the + legacy stations-based ones (`batch_id`). Call sites prefer `session_id` + when present; `batch_id` is still carried because feedback events anchor + to a specific batch within the session. + """ def __init__(self) -> None: + self.session_id: str | None = None self.batch_id: str | None = None self.last_track_id: str | None = None + self.playlist_next_cursor: str | None = None self.seen_track_ids: set[str] = set() self.radio_started_sent: bool = False + self.prefetched: list[Any] = [] + self.settings: dict[str, str] = {} self.lock: asyncio.Lock = asyncio.Lock() @@ -120,14 +199,17 @@ class YandexMusicProvider(MusicProvider): _client: YandexMusicClient | None = None _streaming: YandexMusicStreamingManager | None = None - _my_wave_batch_id: str | None = None - _my_wave_last_track_id: str | None = None # last track id for "Load more" (API queue param) - _my_wave_playlist_next_cursor: str | None = None # first_track_id for next playlist page - _my_wave_radio_started_sent: bool = False - _my_wave_seen_track_ids: set[str] # Track IDs seen in current My Wave session - _my_wave_lock: asyncio.Lock # Protects My Wave mutable state - _wave_states: dict[str, _WaveState] # Per-station state for tagged wave stations + _wave_states: dict[str, _WaveState] # Per-station state (incl. My Wave) _wave_bg_colors: dict[str, str] # image_url -> hex bg color for transparent covers + # Short-lived cache to dedupe the three library syncs (albums/podcasts/audiobooks) + # that all derive from the same liked-albums endpoint. + _liked_albums_cache: tuple[float, list[YandexAlbum]] | None = None + _liked_albums_lock: asyncio.Lock + # Per-audiobook cache of (chapter_track_ids, chapter_durations_ms) used to + # report playback progress per chapter via play_audio. + _audiobook_chapter_cache: dict[str, tuple[list[str], list[int]]] + # Stable play_id per audiobook session, cleared in on_streamed. + _audiobook_play_ids: dict[str, str] @property def client(self) -> YandexMusicClient: @@ -154,24 +236,127 @@ def _get_browse_names(self) -> dict[str, str]: use_russian = False return BROWSE_NAMES_RU if use_russian else BROWSE_NAMES_EN + async def _reauth_via_refresh_token( + self, x_token: str, refresh_token: str, base_url: str, original_err: Exception + ) -> None: + """Silently re-issue full credentials when x_token refresh fails. + + Device-flow accounts have a refresh_token that can mint a new + x_token + refresh_token + music_token without any user interaction. + Persists the rotated triple and connects the client. Any failure + here is terminal — clears all credentials and forces re-auth. + """ + try: + new_creds = await refresh_credentials_via_passport( + SecretStr(x_token), SecretStr(refresh_token) + ) + except ResourceTemporarilyUnavailable as err2: + # Transient Passport failure — keep creds, let MA retry later + self.logger.warning( + "Credential refresh temporarily unavailable: %s", type(err2).__name__ + ) + raise ProviderUnavailableError( + "Unable to refresh credentials right now. Please try again later." + ) from err2 + except LoginFailed as err2: + self.logger.warning("Session and refresh tokens are both expired") + self._update_config_value(CONF_TOKEN, None, encrypted=True) + self._update_config_value(CONF_X_TOKEN, None, encrypted=True) + self._update_config_value(CONF_REFRESH_TOKEN, None, encrypted=True) + raise LoginFailed("Session expired. Please re-authenticate.") from err2 + + new_music_token = new_creds.music_token + new_refresh_token = new_creds.refresh_token + if new_music_token is None or new_refresh_token is None: + self._update_config_value(CONF_TOKEN, None, encrypted=True) + self._update_config_value(CONF_X_TOKEN, None, encrypted=True) + self._update_config_value(CONF_REFRESH_TOKEN, None, encrypted=True) + raise LoginFailed( + "Credential refresh returned an incomplete response." + ) from original_err + + self._update_config_value(CONF_TOKEN, new_music_token.get_secret(), encrypted=True) + self._update_config_value(CONF_X_TOKEN, new_creds.x_token.get_secret(), encrypted=True) + self._update_config_value( + CONF_REFRESH_TOKEN, new_refresh_token.get_secret(), encrypted=True + ) + self._client = YandexMusicClient(new_music_token, base_url=base_url) + await self._client.connect() + self.logger.info("Re-issued credentials silently from refresh token") + async def handle_async_init(self) -> None: """Handle async initialization of the provider.""" token = self.config.get_value(CONF_TOKEN) - if not token: - raise LoginFailed("No Yandex Music token provided") - + x_token = self.config.get_value(CONF_X_TOKEN) + refresh_token = self.config.get_value(CONF_REFRESH_TOKEN) base_url = self.config.get_value(CONF_BASE_URL, DEFAULT_BASE_URL) - self._client = YandexMusicClient(str(token), base_url=str(base_url)) - await self._client.connect() + + if not token and not x_token: + raise LoginFailed("No Yandex Music token provided. Please authenticate.") + + # Try existing music token first (fast path) + if token: + try: + self._client = YandexMusicClient(SecretStr(str(token)), base_url=str(base_url)) + await self._client.connect() + except LoginFailed: + self.logger.warning("Music token is invalid or expired") + # Clear the dead token so restarts go straight to refresh + self._update_config_value(CONF_TOKEN, None, encrypted=True) + if x_token: + self.logger.info("Attempting to refresh from session token") + token = None + self._client = None + else: + raise + + # Refresh from x_token if music token absent or failed + if not token and x_token: + try: + new_music_token = await refresh_music_token(SecretStr(str(x_token))) + self._update_config_value(CONF_TOKEN, new_music_token.get_secret(), encrypted=True) + self._client = YandexMusicClient(new_music_token, base_url=str(base_url)) + await self._client.connect() + self.logger.info("Refreshed music token from session token") + except LoginFailed as err: + # x_token refresh failed. If a refresh_token is available + # (device-flow accounts), try silent re-issue of the full + # credential triple before giving up. + if refresh_token: + await self._reauth_via_refresh_token( + str(x_token), str(refresh_token), str(base_url), err + ) + else: + # Definitive auth failure — clear dead credentials + self.logger.warning("Session token is invalid or expired") + self._update_config_value(CONF_TOKEN, None, encrypted=True) + self._update_config_value(CONF_X_TOKEN, None, encrypted=True) + raise LoginFailed("Session token expired. Please re-authenticate.") from err + except asyncio.CancelledError: + raise + except Exception as err: + # Transient/network failure — keep credentials for retry + self.logger.warning( + "Session token refresh failed (network): %s", + type(err).__name__, + ) + raise ProviderUnavailableError( + "Unable to refresh music token right now. Please try again later." + ) from err + # Suppress yandex_music library DEBUG dumps (full API request/response JSON) logging.getLogger("yandex_music").setLevel(self.logger.level + 10) + # Propagate the MA instance log level to our per-module loggers + # (api_client, streaming, parsers, auth) so DEBUG hooks there actually + # print when MA is set to DEBUG for this provider. + logging.getLogger("music_assistant.providers.yandex_music").setLevel(self.logger.level) self._streaming = YandexMusicStreamingManager(self) - # Initialize My Wave duplicate tracking - self._my_wave_seen_track_ids = set() - self._my_wave_lock = asyncio.Lock() - # Initialize per-station wave state dict + # Per-station wave state (incl. My Wave under ROTOR_STATION_MY_WAVE). + # Entries are created lazily by _get_wave_state() on first access. self._wave_states = {} self._wave_bg_colors = {} + self._liked_albums_lock, self._liked_albums_cache = asyncio.Lock(), None + self._audiobook_chapter_cache, self._audiobook_play_ids = {}, {} self.logger.info("Successfully connected to Yandex Music") async def unload(self, is_removed: bool = False) -> None: @@ -183,6 +368,8 @@ async def unload(self, is_removed: bool = False) -> None: await self._client.disconnect() self._client = None self._streaming = None + self._audiobook_chapter_cache.clear() + self._audiobook_play_ids.clear() await super().unload(is_removed) def get_item_mapping(self, media_type: MediaType | str, key: str, name: str) -> ItemMapping: @@ -202,7 +389,9 @@ def get_item_mapping(self, media_type: MediaType | str, key: str, name: str) -> name=name, ) - async def browse(self, path: str) -> Sequence[MediaItemType | ItemMapping | BrowseFolder]: + async def browse( # noqa: PLR0911, PLR0915 + self, path: str + ) -> Sequence[MediaItemType | ItemMapping | BrowseFolder]: """Browse provider items with locale-based folder names and My Wave. Root level shows My Wave, artists, albums, liked tracks, playlists. Names @@ -219,15 +408,76 @@ async def browse(self, path: str) -> Sequence[MediaItemType | ItemMapping | Brow sub_subpath = path_parts[1] if len(path_parts) > 1 else None if subpath == MY_WAVE_PLAYLIST_ID: - async with self._my_wave_lock: + async with self._get_wave_state(ROTOR_STATION_MY_WAVE).lock: return await self._browse_my_wave(path, sub_subpath) + # Wave modes — accept two equivalent URL forms so both browse + # navigation (slash form "my_wave_modes/", emitted by our + # listing) and MA's play-time reconstruction (underscore form + # "my_wave_modes_", built as "://") work. + mode_preset: str | None = None + if subpath == MY_WAVE_MODES_FOLDER_ID and sub_subpath is None: + return self._browse_my_wave_modes_list(path) + if subpath == MY_WAVE_MODES_FOLDER_ID and sub_subpath is not None: + mode_preset = sub_subpath if sub_subpath != "next" else None + if mode_preset is None: + return [] + load_more_modes = len(path_parts) > 2 and path_parts[2] == "next" + elif subpath and subpath.startswith(f"{MY_WAVE_MODES_FOLDER_ID}_"): + mode_preset = subpath[len(MY_WAVE_MODES_FOLDER_ID) + 1 :] + load_more_modes = sub_subpath == "next" + if mode_preset is not None: + if mode_preset not in WAVE_MODE_PRESETS: + return [] + station_key = f"{ROTOR_STATION_MY_WAVE}{WAVE_MODE_SEP}{mode_preset}" + async with self._get_wave_state(station_key).lock: + return await self._browse_my_wave_mode(path, station_key, load_more_modes) + + # User-saved wave presets — same dual-form handling. + preset_idx: int | None = None + load_more_presets = False + if subpath == MY_WAVE_PRESETS_FOLDER_ID and sub_subpath is None: + return self._browse_user_presets_list(path, self._get_user_wave_presets()) + if subpath == MY_WAVE_PRESETS_FOLDER_ID and sub_subpath is not None: + try: + preset_idx = int(sub_subpath) + except ValueError: + return [] + load_more_presets = len(path_parts) > 2 and path_parts[2] == "next" + elif subpath and subpath.startswith(f"{MY_WAVE_PRESETS_FOLDER_ID}_"): + try: + preset_idx = int(subpath[len(MY_WAVE_PRESETS_FOLDER_ID) + 1 :]) + except ValueError: + return [] + load_more_presets = sub_subpath == "next" + if preset_idx is not None: + user_presets = self._get_user_wave_presets() + if not 0 <= preset_idx < len(user_presets): + return [] + preset_data = user_presets[preset_idx] + station_key = f"{ROTOR_STATION_MY_WAVE}{WAVE_MODE_SEP}preset_{preset_idx}" + wave = self._get_wave_state(station_key) + # Stash user-chosen settings so _fetch_rotor_session_batch sends them + wave.settings = { + k: v + for k, v in preset_data.items() + if k in ("diversity", "moodEnergy", "language") and v + } + async with wave.lock: + return await self._browse_my_wave_mode(path, station_key, load_more_presets) + # For You folder (picks + mixes) if subpath == FOR_YOU_FOLDER_ID: return await self._browse_for_you(path, path_parts) - # Collection folder (library items) + # Collection folder (library items). Two shapes: + # ://collection → listing of library sub-folders + # ://collection/ → delegate to MA's library handler + # The nested form is what lets MA's "back" button return here (strip + # last /-segment) instead of dumping the user at the provider root. if subpath == COLLECTION_FOLDER_ID: + if sub_subpath in _COLLECTION_SUB_FOLDERS: + return await super().browse(f"{self.instance_id}://{sub_subpath}") return await self._browse_collection(path) # Handle picks/ path (mood, activity, era, genres) @@ -246,6 +496,14 @@ async def browse(self, path: str) -> Sequence[MediaItemType | ItemMapping | Brow if subpath == MY_WAVES_SET_FOLDER_ID: return await self._browse_vibe_sets(path, path_parts) + # Pinned items folder + if subpath == PINNED_ITEMS_FOLDER_ID: + return await self._browse_pins() + + # Listening history folder + if subpath == LISTENING_HISTORY_FOLDER_ID: + return await self._browse_history() + # Handle waves_landing/ path (Featured Waves from /landing-blocks/waves) if subpath == WAVES_LANDING_FOLDER_ID: return await self._browse_waves_landing(path, path_parts) @@ -258,6 +516,8 @@ async def browse(self, path: str) -> Sequence[MediaItemType | ItemMapping | Brow "albums", "tracks", "playlists", + "audiobooks", + "podcasts", LIKED_TRACKS_PLAYLIST_ID, WAVES_FOLDER_ID, RADIO_FOLDER_ID, @@ -266,6 +526,8 @@ async def browse(self, path: str) -> Sequence[MediaItemType | ItemMapping | Brow WAVES_LANDING_FOLDER_ID, FOR_YOU_FOLDER_ID, COLLECTION_FOLDER_ID, + PINNED_ITEMS_FOLDER_ID, + LISTENING_HISTORY_FOLDER_ID, } if subpath and subpath not in _known_folders: # Handle direct wave station_id (e.g. "activity:workout") passed when @@ -297,6 +559,27 @@ async def browse(self, path: str) -> Sequence[MediaItemType | ItemMapping | Brow is_playable=True, ) ) + # Wave modes folder (P4): discover / calm / active / language presets + folders.append( + BrowseFolder( + item_id=MY_WAVE_MODES_FOLDER_ID, + provider=self.instance_id, + path=f"{base}{MY_WAVE_MODES_FOLDER_ID}", + name=names.get(MY_WAVE_MODES_FOLDER_ID, "Wave Modes"), + is_playable=False, + ) + ) + # User-defined wave presets (P8) — shown only when any configured. + if self._get_user_wave_presets(): + folders.append( + BrowseFolder( + item_id=MY_WAVE_PRESETS_FOLDER_ID, + provider=self.instance_id, + path=f"{base}{MY_WAVE_PRESETS_FOLDER_ID}", + name=names.get(MY_WAVE_PRESETS_FOLDER_ID, "My Presets"), + is_playable=False, + ) + ) # For You folder — Picks + Mixes (Яндекс «Для вас») folders.append( BrowseFolder( @@ -347,6 +630,26 @@ async def browse(self, path: str) -> Sequence[MediaItemType | ItemMapping | Brow is_playable=False, ) ) + # Pinned items — user-pinned artists/albums/playlists/waves + folders.append( + BrowseFolder( + item_id=PINNED_ITEMS_FOLDER_ID, + provider=self.instance_id, + path=f"{base}{PINNED_ITEMS_FOLDER_ID}", + name=names.get(PINNED_ITEMS_FOLDER_ID, "Pinned"), + is_playable=False, + ) + ) + # Listening history — recently played tracks/albums + folders.append( + BrowseFolder( + item_id=LISTENING_HISTORY_FOLDER_ID, + provider=self.instance_id, + path=f"{base}{LISTENING_HISTORY_FOLDER_ID}", + name=names.get(LISTENING_HISTORY_FOLDER_ID, "Listening History"), + is_playable=False, + ) + ) if len(folders) == 1: return await self.browse(folders[0].path) return folders @@ -354,12 +657,13 @@ async def browse(self, path: str) -> Sequence[MediaItemType | ItemMapping | Brow async def _browse_my_wave( self, path: str, sub_subpath: str | None ) -> list[Track | BrowseFolder]: - """Browse My Wave tracks (must be called under _my_wave_lock). + """Browse My Wave tracks (must be called under the My Wave state lock). :param path: Full browse path. :param sub_subpath: Sub-path part ('next' for load more, or track_id cursor). :return: List of Track and optional BrowseFolder for "Load more". """ + wave = self._get_wave_state(ROTOR_STATION_MY_WAVE) max_tracks_config = int( self.config.get_value(CONF_MY_WAVE_MAX_TRACKS) or 150 # type: ignore[arg-type] ) @@ -379,11 +683,11 @@ async def _browse_my_wave( # Reset seen tracks on fresh browse (not "load more") if sub_subpath != "next": - self._my_wave_seen_track_ids = set() + wave.seen_track_ids = set() queue: str | int | None = None if sub_subpath == "next": - queue = self._my_wave_last_track_id + queue = wave.last_track_id elif sub_subpath: queue = sub_subpath @@ -396,24 +700,25 @@ async def _browse_my_wave( if total_track_count >= effective_limit: break - yandex_tracks, batch_id = await self.client.get_my_wave_tracks(queue=queue) + # On a fresh browse (non-"next"), honour any sub_subpath cursor override + # by seeding wave.last_track_id so the helper picks it up. + if queue is not None: + wave.last_track_id = str(queue) + yandex_tracks, batch_id = await self._fetch_rotor_session_batch( + wave, ROTOR_STATION_MY_WAVE + ) if batch_id: - self._my_wave_batch_id = batch_id last_batch_id = batch_id - if not self._my_wave_radio_started_sent and yandex_tracks: - sent = await self.client.send_rotor_station_feedback( - ROTOR_STATION_MY_WAVE, - "radioStarted", - batch_id=batch_id, - ) + if not wave.radio_started_sent and yandex_tracks: + sent = await self._send_wave_feedback(wave, ROTOR_STATION_MY_WAVE, "radioStarted") if sent: - self._my_wave_radio_started_sent = True + wave.radio_started_sent = True first_track_id_this_batch = None for yt in yandex_tracks: if total_track_count >= effective_limit: break - track = self._parse_my_wave_track(yt, self._my_wave_seen_track_ids) + track = self._parse_my_wave_track(yt, wave.seen_track_ids) if track is None: continue all_tracks.append(track) @@ -424,7 +729,7 @@ async def _browse_my_wave( first_track_id_this_batch = track_id if first_track_id_this_batch is not None: - self._my_wave_last_track_id = first_track_id_this_batch + wave.last_track_id = first_track_id_this_batch if ( first_track_id_this_batch is None or not batch_id @@ -449,15 +754,171 @@ async def _browse_my_wave( ) return all_tracks - def _parse_my_wave_track(self, yt: Any, seen_ids: set[str]) -> Track | None: + def _get_user_wave_presets(self) -> list[dict[str, str]]: + """Decode user-defined wave presets from the hidden JSON config key. + + Thin wrapper around :func:`presets.parse_stored_presets` so browse + code and settings actions use the exact same parsing — avoids schema + drift when preset fields are added or renamed. + """ + return parse_stored_presets(self.config.get_value(CONF_WAVE_PRESETS_DATA)) + + def _browse_user_presets_list( + self, path: str, presets: list[dict[str, str]] + ) -> list[BrowseFolder]: + """Return one playable BrowseFolder per configured user preset. + + ``path`` is nested (``my_wave_presets/``) so MA's back-nav — + which strips the last ``/``-segment — returns the user to the + listing instead of the provider root. ``item_id`` uses the + underscore form (``my_wave_presets_``) because MA rebuilds a + playable folder's path from its item_id at play time. The browse + dispatcher accepts both forms. + + :param path: Current browse path. + :param presets: Sanitized presets from ``_get_user_wave_presets``. + :return: List of playable BrowseFolder entries. + """ + base = path if path.endswith("/") else f"{path}/" + folders: list[BrowseFolder] = [] + for idx, preset in enumerate(presets): + folders.append( + BrowseFolder( + item_id=f"{MY_WAVE_PRESETS_FOLDER_ID}_{idx}", + provider=self.instance_id, + path=f"{base}{idx}", + name=preset.get("name", f"Preset {idx + 1}"), + is_playable=True, + ) + ) + return folders + + def _browse_my_wave_modes_list(self, path: str) -> list[BrowseFolder]: + """Return the 11 wave-mode entries as playable browse folders. + + Same dual-form contract as user presets: nested ``path`` keeps + back-navigation intact, underscore ``item_id`` survives MA's + play-time reconstruction. + + :param path: Browse path the user navigated into. + :return: Ordered list of BrowseFolder entries, one per preset. + """ + names = self._get_browse_names() + base = path if path.endswith("/") else f"{path}/" + folders: list[BrowseFolder] = [] + for preset in WAVE_MODE_ORDER: + name_key = f"wave_mode_{preset}" + folders.append( + BrowseFolder( + item_id=f"{MY_WAVE_MODES_FOLDER_ID}_{preset}", + provider=self.instance_id, + path=f"{base}{preset}", + name=names.get(name_key, preset.replace("_", " ").title()), + is_playable=True, + ) + ) + return folders + + async def _browse_my_wave_mode( + self, path: str, station_key: str, load_more: bool + ) -> list[Track | BrowseFolder]: + """Fetch a batch of tracks for a specific wave-mode preset. + + Reuses the session-API machinery: tracks live in + ``_wave_states[station_key]`` where station_key is + ``user:onyourwave#{preset}``. Tracks carry composite item_ids that + route feedback back to this state. + + :param path: Full browse path to this preset. + :param station_key: Station key with a ``#preset`` suffix. + :param load_more: True when called for ``.../next`` pagination. + :return: Tracks + optional "Load more" folder. + """ + wave = self._get_wave_state(station_key) + max_tracks_config = int( + self.config.get_value(CONF_MY_WAVE_MAX_TRACKS) or 150 # type: ignore[arg-type] + ) + batch_size_config = MY_WAVE_BATCH_SIZE + effective_limit = min( + BROWSE_INITIAL_TRACKS if not load_more else max_tracks_config, + max_tracks_config, + ) + max_batches = batch_size_config if not load_more else 1 + + if not load_more: + wave.seen_track_ids = set() + + all_tracks: list[Track | BrowseFolder] = [] + last_batch_id: str | None = None + total_track_count = 0 + + for _ in range(max_batches): + if total_track_count >= effective_limit: + break + yandex_tracks, batch_id = await self._fetch_rotor_session_batch(wave, station_key) + if batch_id: + last_batch_id = batch_id + if not wave.radio_started_sent and yandex_tracks: + sent = await self._send_wave_feedback(wave, station_key, "radioStarted") + if sent: + wave.radio_started_sent = True + first_track_id_this_batch: str | None = None + for yt in yandex_tracks: + if total_track_count >= effective_limit: + break + track = self._parse_my_wave_track(yt, wave.seen_track_ids, station_key=station_key) + if track is None: + continue + all_tracks.append(track) + total_track_count += 1 + track_id = track.item_id.split(RADIO_TRACK_ID_SEP, 1)[0] + if first_track_id_this_batch is None: + first_track_id_this_batch = track_id + if first_track_id_this_batch is not None: + wave.last_track_id = first_track_id_this_batch + if ( + first_track_id_this_batch is None + or not batch_id + or not yandex_tracks + or total_track_count >= effective_limit + ): + break + + if last_batch_id and total_track_count < max_tracks_config: + names = self._get_browse_names() + next_name = "Ещё" if names == BROWSE_NAMES_RU else "Load more" + all_tracks.append( + BrowseFolder( + item_id="next", + provider=self.instance_id, + path=f"{path.rstrip('/')}/next", + name=next_name, + is_playable=False, + ) + ) + return all_tracks + + def _parse_my_wave_track( + self, + yt: Any, + seen_ids: set[str], + *, + station_key: str = ROTOR_STATION_MY_WAVE, + ) -> Track | None: """Parse a Yandex track into a My Wave Track with composite item_id. Extracts the track_id, checks for duplicates in the seen_ids set, - sets composite item_id (track_id@station_id), and updates provider_mappings. - Callers using shared state must hold _my_wave_lock. + sets composite item_id (track_id@station_key) and updates + provider_mappings. `station_key` is the key in `_wave_states` under + which the matching session lives; for preset modes it carries a + `#preset` suffix so `on_played`/`on_streamed` find the right session. + + Callers using shared state must hold the My Wave state lock. :param yt: Yandex track object from rotor station response. :param seen_ids: Set of already-seen track IDs to check and update. + :param station_key: Station key to embed in the composite item_id. + Defaults to the plain My Wave station. :return: Parsed Track with composite item_id, or None if duplicate/invalid. """ try: @@ -475,7 +936,7 @@ def _parse_my_wave_track(self, yt: Any, seen_ids: set[str]) -> Track | None: return None seen_ids.add(track_id) - t.item_id = f"{track_id}{RADIO_TRACK_ID_SEP}{ROTOR_STATION_MY_WAVE}" + t.item_id = f"{track_id}{RADIO_TRACK_ID_SEP}{station_key}" for pm in t.provider_mappings: if pm.provider_instance == self.instance_id: pm.item_id = t.item_id @@ -626,56 +1087,139 @@ async def _browse_collection( ) -> Sequence[MediaItemType | ItemMapping | BrowseFolder]: """Browse «Collection» folder — shows library sub-folders (tracks/artists/albums/playlists). + Child ``path`` is nested (``…/collection/tracks``) so MA's "back" + button lands on this listing instead of the provider root. The + dispatcher then strips the ``collection/`` prefix and hands off to + core's default library handler. + :param path: Full browse path. :return: List of library sub-folders. """ names = self._get_browse_names() - base_parts = path.split("//", 1) - root_base = (base_parts[0] + "//") if len(base_parts) > 1 else path.rstrip("/") + "/" + base = path if path.endswith("/") else f"{path}/" folders: list[BrowseFolder] = [] - if ProviderFeature.LIBRARY_TRACKS in self.supported_features: - folders.append( - BrowseFolder( - item_id="tracks", - provider=self.instance_id, - path=f"{root_base}tracks", - name=names["tracks"], - is_playable=True, - ) - ) - if ProviderFeature.LIBRARY_ARTISTS in self.supported_features: - folders.append( - BrowseFolder( - item_id="artists", - provider=self.instance_id, - path=f"{root_base}artists", - name=names["artists"], - is_playable=True, - ) - ) - if ProviderFeature.LIBRARY_ALBUMS in self.supported_features: - folders.append( - BrowseFolder( - item_id="albums", - provider=self.instance_id, - path=f"{root_base}albums", - name=names["albums"], - is_playable=True, - ) - ) - if ProviderFeature.LIBRARY_PLAYLISTS in self.supported_features: + feature_map: tuple[tuple[ProviderFeature, str, bool], ...] = ( + (ProviderFeature.LIBRARY_TRACKS, "tracks", True), + (ProviderFeature.LIBRARY_ARTISTS, "artists", True), + (ProviderFeature.LIBRARY_ALBUMS, "albums", True), + (ProviderFeature.LIBRARY_PLAYLISTS, "playlists", True), + (ProviderFeature.LIBRARY_PODCASTS, "podcasts", False), + (ProviderFeature.LIBRARY_AUDIOBOOKS, "audiobooks", False), + ) + for feature, sub_id, is_playable in feature_map: + if feature not in self.supported_features: + continue folders.append( BrowseFolder( - item_id="playlists", + item_id=sub_id, provider=self.instance_id, - path=f"{root_base}playlists", - name=names["playlists"], - is_playable=True, + path=f"{base}{sub_id}", + name=names[sub_id], + is_playable=is_playable, ) ) return folders + async def _browse_pins(self) -> Sequence[MediaItemType | ItemMapping | BrowseFolder]: + """Browse user's pinned items (artists/albums/playlists from Yandex Pins). + + Resolves each pin to its full media item via existing single-item lookups. + Wave pins are skipped — MA has no native concept for them. + + :return: List of resolved media items. + """ + pins_list = await self.client.get_pins() + pins = getattr(pins_list, "pins", None) if pins_list else None + if not pins: + return [] + + items: list[MediaItemType] = [] + for pin in pins: + pin_type = getattr(pin, "type", None) + data = getattr(pin, "data", None) + if data is None: + continue + try: + if pin_type == "artist_item" and getattr(data, "id", None) is not None: + items.append(await self.get_artist(str(data.id))) + elif pin_type == "album_item" and getattr(data, "id", None) is not None: + items.append(await self.get_album(str(data.id))) + elif pin_type == "playlist_item": + uid = getattr(data, "uid", None) + kind = getattr(data, "kind", None) + if uid is not None and kind is not None: + items.append(await self.get_playlist(f"{uid}:{kind}")) + except (MediaNotFoundError, InvalidDataError) as err: + self.logger.debug("Skipping pin %s: %s", pin_type, err) + return items + + async def _browse_history(self) -> Sequence[MediaItemType | ItemMapping | BrowseFolder]: + """Browse user's recent listening history (flattened across days). + + Collects ``track_id`` values from each history entry's ``item_id`` + sub-object (``full_model`` is not populated by the current API + response — MarshalX exposes the IDs separately), dedupes, and + batch-resolves them via ``get_tracks`` so the returned Track objects + carry full artist/album/cover metadata. + + Entries without a resolvable ``track_id`` (e.g. album-only context + rows) are skipped silently. Order is preserved — most recent first — + by collecting unique IDs in response order into ``ordered_ids``, + then rebuilding the final list by iterating ``ordered_ids`` and + looking up each batch-fetched track in an id→track map. + + :return: List of recently played Track items. + """ + history = await self.client.get_music_history() + tabs = getattr(history, "history_tabs", None) if history else None + if not tabs: + return [] + + seen_track_ids: set[str] = set() + ordered_ids: list[str] = [] + for tab in tabs: + for group in getattr(tab, "items", None) or []: + for hist_item in getattr(group, "tracks", None) or []: + if getattr(hist_item, "type", None) != "track": + continue + item_id_obj = getattr(getattr(hist_item, "data", None), "item_id", None) + track_key: str | None = None + if isinstance(item_id_obj, dict): + track_key = item_id_obj.get("track_id") or item_id_obj.get("id") + else: + track_key = getattr(item_id_obj, "track_id", None) or getattr( + item_id_obj, "id", None + ) + if not track_key: + continue + track_key = str(track_key) + if track_key in seen_track_ids: + continue + seen_track_ids.add(track_key) + ordered_ids.append(track_key) + + if not ordered_ids: + return [] + + try: + fetched = await self.client.get_tracks(ordered_ids) + except ResourceTemporarilyUnavailable as err: + self.logger.warning("Failed to hydrate history tracks: %s", err) + return [] + + by_id = {str(t.id): t for t in fetched if getattr(t, "id", None) is not None} + tracks: list[Track] = [] + for tid in ordered_ids: + yt = by_id.get(tid) + if yt is None: + continue + try: + tracks.append(parse_track(self, yt)) + except InvalidDataError as err: + self.logger.debug("Skipping history track %s: %s", tid, err) + return tracks + async def _browse_picks( self, path: str, path_parts: list[str] ) -> Sequence[MediaItemType | ItemMapping | BrowseFolder]: @@ -831,6 +1375,142 @@ def _get_wave_state(self, station_id: str) -> _WaveState: """ return self._wave_states.setdefault(station_id, _WaveState()) + async def _send_wave_feedback( + self, + wave: _WaveState, + station_id: str, + event_type: str, + *, + track_id: str | None = None, + total_played_seconds: int | None = None, + ) -> bool: + """Route rotor feedback to the session endpoint. + + Requires an active ``wave.session_id`` — rotor feedback is only + meaningful inside the session it originated from. The legacy + stations-based endpoint (``/rotor/station/{id}/feedback``) is no + longer reachable (returns 404 "not-found"), so when there's no + session we skip silently rather than spamming the log. + + This happens when the track's composite item_id was parsed in a + previous provider run (e.g. loaded from MA's library cache) and + the corresponding session_id is not in memory any more. History + reporting via ``play_audio`` still works in that case — only the + rotor recommendation signal is lost. + + :param wave: Station state carrying session_id + batch_id. + :param station_id: Rotor station ID (used only for logging here). + :param event_type: Rotor event type (radioStarted, trackStarted, …). + :param track_id: Yandex track ID the event refers to. + :param total_played_seconds: Seconds played (trackFinished / skip only). + :return: True if the feedback POST succeeded, False when skipped. + """ + if not wave.session_id: + self.logger.debug( + "Skipping rotor feedback %s for %s: no active session", + event_type, + station_id, + ) + return False + return await self.client.rotor_session_feedback( + wave.session_id, + event_type, + track_id=track_id, + total_played_seconds=total_played_seconds, + batch_id=wave.batch_id, + ) + + async def _prefetch_rotor_session(self, station_key: str) -> None: + """Fire-and-forget: fetch the next batch for an active wave session. + + Called from ``on_played`` while a wave track starts playing, so by the + time Music Assistant's DSTM asks for more via ``get_similar_tracks``, + we already have Yandex-curated wave tracks sitting in + ``wave.prefetched`` ready to serve (no extra round-trip). + + No-op when the station has no active session yet (prefetch cannot + safely create one — that requires holding the lock across the + network call and would stall readers), or when the buffer already + has items (avoids burning rate limit). + + Three-phase lock discipline so the network round-trip does not + block browse / drain paths that share the lock: + + 1. Acquire, verify session + empty buffer, snapshot + ``session_id`` and ``last_track_id``, release. + 2. Call ``client.rotor_session_tracks`` **directly** (no + ``_fetch_rotor_session_batch``) — that helper mutates shared + state (session creation, batch_id write) and would race with + other callers now that we hold no lock. The raw client call + only reads the arguments we pass in. + 3. Re-acquire, verify the session hasn't been recycled and the + buffer is still empty, then ``extend``. + + :param station_key: Station key whose state to top up. + """ + wave = self._wave_states.get(station_key) + if wave is None: + return + + async with wave.lock: + if wave.session_id is None or wave.prefetched: + return + session_id = wave.session_id + cursor = wave.last_track_id + + if not cursor: + return # No anchor for the next batch yet; try again later. + + tracks, _ = await self.client.rotor_session_tracks(session_id, current_track_id=str(cursor)) + if not tracks: + return + + async with wave.lock: + # Another task could have restarted the session or filled the + # buffer while we were awaiting the network call; bail in both + # cases to avoid stale extends. + if wave.session_id != session_id or wave.prefetched: + return + wave.prefetched.extend(tracks) + + async def _fetch_rotor_session_batch( + self, wave: _WaveState, station_id: str + ) -> tuple[list[YandexTrack], str | None]: + """Fetch the next rotor-session batch for any station. + + On first call (wave.session_id is None), starts a new rotor session + and records session_id + batch_id on the wave state. On subsequent + calls, paginates via rotor_session_tracks using wave.last_track_id. + + If station_id carries a wave-mode suffix (e.g. "user:onyourwave#discover"), + the suffix maps to a preset in WAVE_MODE_PRESETS and its settings are + merged with wave.settings (wave.settings wins on key conflict). The + base station ID (before "#") is what actually goes to Yandex. + + :param wave: The _WaveState for this station (persists across calls). + :param station_id: Rotor station key (may include a "#preset" suffix). + :return: Tuple of (list of yandex tracks, batch_id or None). + """ + # Session-creation path: no session yet, or we have a session but no + # cursor yet (`tracks` with an empty queue returns a hard-to-debug + # empty batch — starting a fresh session is the same latency but + # actually yields tracks). + if wave.session_id is None or not wave.last_track_id: + base_station, preset_settings = _split_wave_mode(station_id) + merged = {**preset_settings, **wave.settings} + session_id, tracks, batch_id = await self.client.rotor_session_new( + base_station, settings=merged or None + ) + if session_id: + wave.session_id = session_id + else: + tracks, batch_id = await self.client.rotor_session_tracks( + wave.session_id, current_track_id=str(wave.last_track_id) + ) + if batch_id: + wave.batch_id = batch_id + return (tracks, batch_id) + async def _browse_waves( self, path: str, path_parts: list[str] ) -> Sequence[MediaItemType | ItemMapping | BrowseFolder]: @@ -1039,23 +1719,20 @@ async def _browse_wave_station( ) self.logger.debug( - "Browse wave station: station_id=%s path=%s last_track_id=%s", + "Browse wave station: station_id=%s path=%s last_track_id=%s session=%s", station_id, path, state.last_track_id, + state.session_id, ) - yandex_tracks, batch_id = await self.client.get_rotor_station_tracks( - station_id, queue=state.last_track_id - ) - if batch_id: - state.batch_id = batch_id + # Tagged stations (genre:*, mood:*, activity:*, epoch:*) accept the + # same /rotor/session/* endpoint as user:onyourwave / track:{id}, + # verified against the live Yandex API. Reuse the session helper so + # batch_id + session_id stay anchored across browse/play/feedback. + yandex_tracks, _ = await self._fetch_rotor_session_batch(state, station_id) if not state.radio_started_sent and yandex_tracks: - sent = await self.client.send_rotor_station_feedback( - station_id, - "radioStarted", - batch_id=batch_id, - ) + sent = await self._send_wave_feedback(state, station_id, "radioStarted") if sent: state.radio_started_sent = True @@ -1117,11 +1794,15 @@ async def _browse_wave_station( def _extract_wave_item_cover(item: dict[str, Any]) -> tuple[str | None, str | None]: """Extract cover URI and background color from a wave/mix item. + Accepts both camelCase (``compactImageUrl`` — what /landing-blocks/ + actually returns) and snake_case (``compact_image_url`` — retained + for safety if MarshalX ever normalises the payload). + :param item: Wave or mix item dict from the API. :return: (cover_uri, bg_color) tuple where bg_color is a hex string or None. """ agent_uri = item.get("agent", {}).get("cover", {}).get("uri", "") - cover_uri = agent_uri or item.get("compact_image_url") + cover_uri = agent_uri or item.get("compactImageUrl") or item.get("compact_image_url") bg_color = item.get("colors", {}).get("average") return cover_uri, bg_color @@ -1216,7 +1897,9 @@ async def _browse_wave_categories( items = wave_category.get("items", []) result: list[BrowseFolder] = [] for item in items: - station_id = item.get("station_id", "") + # API returns camelCase (`stationId`); keep snake_case as a + # safety net if the payload is ever normalised upstream. + station_id = item.get("stationId") or item.get("station_id") or "" title = item.get("title", "") if not station_id or not title: continue @@ -1307,14 +1990,19 @@ async def search( result = SearchResults() # Determine search type based on requested media types - # Map MediaType to Yandex API search type + # Map MediaType to Yandex API search type. AUDIOBOOK has no dedicated + # Yandex type — it maps to "album" and is filtered by classify_album below. type_mapping = { MediaType.TRACK: "track", MediaType.ALBUM: "album", + MediaType.AUDIOBOOK: "album", MediaType.ARTIST: "artist", MediaType.PLAYLIST: "playlist", + MediaType.PODCAST: "podcast", } - requested_types = [type_mapping[mt] for mt in media_types if mt in type_mapping] + requested_types = list( + dict.fromkeys(type_mapping[mt] for mt in media_types if mt in type_mapping) + ) # Use specific type if only one requested, otherwise search all search_type = requested_types[0] if len(requested_types) == 1 else "all" @@ -1331,13 +2019,34 @@ async def search( except InvalidDataError as err: self.logger.debug("Error parsing track: %s", err) - # Parse albums - if MediaType.ALBUM in media_types and search_result.albums: - for album in search_result.albums.results[:limit]: + # Parse albums — audiobooks are split into the audiobooks bucket via + # classify_album. Yandex-returned podcast albums are handled separately + # through the dedicated `.podcasts` node below. ``limit`` is applied per + # bucket AFTER classification — slicing first would drop audiobooks when + # the first ``limit`` results happen to be music albums (or vice versa). + want_album = MediaType.ALBUM in media_types + want_audiobook = MediaType.AUDIOBOOK in media_types + if (want_album or want_audiobook) and search_result.albums: + album_count = 0 + audiobook_count = 0 + for album in search_result.albums.results: + album_full = not want_album or album_count >= limit + audiobook_full = not want_audiobook or audiobook_count >= limit + if album_full and audiobook_full: + break + kind = classify_album(album) try: - result.albums = [*result.albums, parse_album(self, album)] + if kind == "audiobook" and want_audiobook and not audiobook_full: + result.audiobooks = [ + *result.audiobooks, + parse_audiobook(self, album), + ] + audiobook_count += 1 + elif kind == "music" and want_album and not album_full: + result.albums = [*result.albums, parse_album(self, album)] + album_count += 1 except InvalidDataError as err: - self.logger.debug("Error parsing album: %s", err) + self.logger.debug("Error parsing %s album: %s", kind, err) # Parse artists if MediaType.ARTIST in media_types and search_result.artists: @@ -1355,22 +2064,34 @@ async def search( except InvalidDataError as err: self.logger.debug("Error parsing playlist: %s", err) + # Parse podcasts (Yandex returns them as albums under .podcasts) + podcasts_node = getattr(search_result, "podcasts", None) + if MediaType.PODCAST in media_types and podcasts_node: + for album in podcasts_node.results[:limit]: + try: + result.podcasts = [*result.podcasts, parse_podcast(self, album)] + except InvalidDataError as err: + self.logger.debug("Error parsing podcast: %s", err) + return result # Get single items @use_cache(3600 * 24 * 30) async def get_artist(self, prov_artist_id: str) -> Artist: - """Get artist details by ID. + """Get artist details by ID, enriched with description and listener stats. :param prov_artist_id: The provider artist ID. :return: Artist object. :raises MediaNotFoundError: If artist not found. """ - artist = await self.client.get_artist(prov_artist_id) + artist, about = await asyncio.gather( + self.client.get_artist(prov_artist_id), + self.client.get_artist_about(prov_artist_id), + ) if not artist: raise MediaNotFoundError(f"Artist {prov_artist_id} not found") - return parse_artist(self, artist) + return parse_artist(self, artist, about=about) @use_cache(3600 * 24 * 30) async def get_album(self, prov_album_id: str) -> Album: @@ -1385,6 +2106,87 @@ async def get_album(self, prov_album_id: str) -> Album: raise MediaNotFoundError(f"Album {prov_album_id} not found") return parse_album(self, album) + @use_cache(3600 * 24) + async def get_podcast(self, prov_podcast_id: str) -> Podcast: + """Get podcast details by ID (backed by a Yandex album). + + :param prov_podcast_id: The provider podcast (album) ID. + :return: Podcast object. + :raises MediaNotFoundError: If not found. + """ + album = await self.client.get_album(prov_podcast_id) + if not album: + raise MediaNotFoundError(f"Podcast {prov_podcast_id} not found") + return parse_podcast(self, album) + + async def get_podcast_episodes( + self, prov_podcast_id: str + ) -> AsyncGenerator[PodcastEpisode, None]: + """Iterate podcast episodes for a given podcast (album) ID.""" + album = await self.client.get_album_with_tracks(prov_podcast_id) + if not album: + raise MediaNotFoundError(f"Podcast {prov_podcast_id} not found") + podcast = parse_podcast(self, album) + position = 1 + for disc in album.volumes or []: + for track_obj in disc: + try: + yield parse_podcast_episode(self, track_obj, podcast, position=position) + except InvalidDataError as err: + self.logger.debug("Error parsing podcast episode: %s", err) + position += 1 + + async def get_podcast_episode(self, prov_episode_id: str) -> PodcastEpisode: + """Get a single podcast episode by ID. + + The parent Podcast is reconstructed from the track's parent album. If + the album isn't present on the track, the episode cannot be converted + into a valid MA model and InvalidDataError is raised. + """ + tracks = await self.client.get_tracks([prov_episode_id]) + if not tracks: + raise MediaNotFoundError(f"Podcast episode {prov_episode_id} not found") + track_obj = tracks[0] + if not track_obj.albums: + raise InvalidDataError( + f"Podcast episode {prov_episode_id} is missing parent podcast album data" + ) + podcast = parse_podcast(self, track_obj.albums[0]) + return parse_podcast_episode(self, track_obj, podcast, position=0) + + @use_cache(3600 * 24) + async def get_audiobook(self, prov_audiobook_id: str) -> Audiobook: + """Get audiobook details by ID, including chapters built from tracks. + + :param prov_audiobook_id: The provider audiobook (album) ID. + :return: Audiobook object. + :raises MediaNotFoundError: If not found. + """ + album = await self.client.get_album_with_tracks(prov_audiobook_id) + if not album: + raise MediaNotFoundError(f"Audiobook {prov_audiobook_id} not found") + audiobook = parse_audiobook(self, album) + + chapters: list[MediaItemChapter] = [] + start = 0.0 + pos = 1 + for disc in album.volumes or []: + for track_obj in disc: + dur_s = (track_obj.duration_ms or 0) / 1000.0 + chapters.append( + MediaItemChapter( + position=pos, + name=track_obj.title or f"Chapter {pos}", + start=start, + end=start + dur_s, + ) + ) + start += dur_s + pos += 1 + audiobook.metadata.chapters = chapters + audiobook.duration = int(start) + return audiobook + async def get_track(self, prov_track_id: str) -> Track: """Get track details by ID. @@ -1496,23 +2298,24 @@ async def _get_my_wave_playlist_tracks(self, page: int) -> list[Track]: :param page: Page number (0 = first batch, 1+ = next batches via queue cursor). :return: List of Track objects for this page. """ - async with self._my_wave_lock: + wave = self._get_wave_state(ROTOR_STATION_MY_WAVE) + async with wave.lock: max_tracks_config = int( self.config.get_value(CONF_MY_WAVE_MAX_TRACKS) or 150 # type: ignore[arg-type] ) # Reset seen tracks on first page if page == 0: - self._my_wave_seen_track_ids = set() + wave.seen_track_ids = set() queue: str | int | None = None if page > 0: - queue = self._my_wave_playlist_next_cursor + queue = wave.playlist_next_cursor if not queue: return [] # Check if we've already reached the limit - if len(self._my_wave_seen_track_ids) >= max_tracks_config: + if len(wave.seen_track_ids) >= max_tracks_config: return [] tracks: list[Track] = [] @@ -1520,30 +2323,30 @@ async def _get_my_wave_playlist_tracks(self, page: int) -> list[Track]: # Fetch MY_WAVE_BATCH_SIZE Rotor API batches per page to reduce API round-trips for _ in range(MY_WAVE_BATCH_SIZE): - if len(self._my_wave_seen_track_ids) >= max_tracks_config: + if len(wave.seen_track_ids) >= max_tracks_config: break - yandex_tracks, batch_id = await self.client.get_my_wave_tracks(queue=queue) - if batch_id: - self._my_wave_batch_id = batch_id - if not self._my_wave_radio_started_sent and yandex_tracks: - sent = await self.client.send_rotor_station_feedback( - ROTOR_STATION_MY_WAVE, - "radioStarted", - batch_id=batch_id, + if queue is not None: + wave.last_track_id = str(queue) + yandex_tracks, _ = await self._fetch_rotor_session_batch( + wave, ROTOR_STATION_MY_WAVE + ) + if not wave.radio_started_sent and yandex_tracks: + sent = await self._send_wave_feedback( + wave, ROTOR_STATION_MY_WAVE, "radioStarted" ) if sent: - self._my_wave_radio_started_sent = True + wave.radio_started_sent = True if not yandex_tracks: break first_track_id_this_batch = None for yt in yandex_tracks: - if len(self._my_wave_seen_track_ids) >= max_tracks_config: + if len(wave.seen_track_ids) >= max_tracks_config: break - track = self._parse_my_wave_track(yt, self._my_wave_seen_track_ids) + track = self._parse_my_wave_track(yt, wave.seen_track_ids) if track is None: continue @@ -1560,7 +2363,7 @@ async def _get_my_wave_playlist_tracks(self, page: int) -> list[Track]: break # Store cursor for next page call (None clears pagination so next call returns []) - self._my_wave_playlist_next_cursor = next_cursor + wave.playlist_next_cursor = next_cursor return tracks async def _get_liked_tracks_playlist_tracks(self, page: int) -> list[Track]: @@ -1642,27 +2445,98 @@ async def get_album_tracks(self, prov_album_id: str) -> list[Track]: self.logger.debug("Error parsing album track: %s", err) return tracks - @use_cache(3600 * 3) async def get_similar_tracks(self, prov_track_id: str, limit: int = 25) -> list[Track]: - """Get similar tracks using Yandex Rotor station for this track. + """Get similar tracks, preferring pre-fetched wave tracks when available. + + Split in two paths with different caching policies: - Uses rotor station track:{id} so MA radio mode gets Yandex recommendations. + - **Wave-drain path** (the seed carries a station suffix and + ``wave.prefetched`` is non-empty). Uncached by design: it mutates + state, a cache hit would replay the same drained tracks forever and + the prefetch buffer would never advance. + - **Fallback path** (plain track_id, no active wave, or empty buffer). + Creates a per-seed rotor session under ``track:{id}`` and is cached + for 3 hours — this is pure and safe to memoise. :param prov_track_id: Provider track ID (plain or track_id@station_id). :param limit: Maximum number of tracks to return. :return: List of similar Track objects. """ - track_id, _ = _parse_radio_item_id(prov_track_id) - station_id = f"track:{track_id}" - yandex_tracks, _ = await self.client.get_rotor_station_tracks(station_id, queue=None) - tracks = [] - for yt in yandex_tracks[:limit]: + track_id, station_key = _parse_radio_item_id(prov_track_id) + + if station_key: + drained = await self._drain_prefetched_wave_tracks(station_key, limit) + if drained: + return drained + + return await self._fetch_similar_tracks_for_seed(track_id, limit) + + async def _drain_prefetched_wave_tracks(self, station_key: str, limit: int) -> list[Track]: + """Pop up to ``limit`` prefetched tracks off the wave state. + + Runs under ``wave.lock`` so it doesn't race with + ``_prefetch_rotor_session`` which extends the same list under the + same lock. Returns an empty list when there's no active session or + nothing prefetched; callers then fall through to the cached fetch. + + This method is intentionally not cached — it mutates wave state. + """ + wave = self._wave_states.get(station_key) + if not (wave and wave.session_id and wave.prefetched): + return [] + async with wave.lock: + if not wave.prefetched: + return [] + drained_yt = wave.prefetched[:limit] + wave.prefetched = wave.prefetched[limit:] + tracks: list[Track] = [] + for yt in drained_yt: try: tracks.append(parse_track(self, yt)) except InvalidDataError as err: - self.logger.debug("Error parsing similar track: %s", err) + self.logger.debug("Error parsing prefetched wave track: %s", err) return tracks + @use_cache(3600 * 3) + async def _fetch_similar_tracks_for_seed(self, track_id: str, limit: int) -> list[Track]: + """Create a one-off rotor session for ``track:{id}`` and return up to ``limit`` tracks. + + Stateless by design: similar-tracks results don't participate in + playback feedback or prefetch, so there is no need to keep a + ``_WaveState`` entry around. Going through ``_fetch_rotor_session_batch`` + would create one per unique seed and grow ``_wave_states`` without + bound under normal DSTM usage; call ``rotor_session_new`` directly + instead. + + Pure function of ``track_id`` / ``limit``, hence safe to memoise + via ``@use_cache``. + """ + _, yandex_tracks, _ = await self.client.rotor_session_new(f"track:{track_id}") + similar_tracks: list[Track] = [] + for yt in yandex_tracks[:limit]: + try: + similar_tracks.append(parse_track(self, yt)) + except InvalidDataError as err: + self.logger.debug("Error parsing similar track: %s", err) + return similar_tracks + + @use_cache(3600 * 3) + async def get_similar_artists(self, prov_artist_id: str, limit: int = 25) -> list[Artist]: + """Get artists similar to the given one via Yandex artists/similar endpoint. + + :param prov_artist_id: Provider artist ID. + :param limit: Maximum number of artists to return. + :return: List of similar Artist objects. + """ + yandex_artists = await self.client.get_similar_artists(prov_artist_id, limit=limit) + artists: list[Artist] = [] + for ya in yandex_artists: + try: + artists.append(parse_artist(self, ya)) + except InvalidDataError as err: + self.logger.debug("Error parsing similar artist: %s", err) + return artists + async def recommendations(self) -> list[RecommendationFolder]: """Get recommendations with multiple discovery folders. @@ -1722,6 +2596,11 @@ async def recommendations(self) -> list[RecommendationFolder]: async def _get_my_wave_recommendations(self) -> RecommendationFolder | None: """Get My Wave recommendation folder with personalized tracks. + Shares the same `_WaveState(ROTOR_STATION_MY_WAVE)` with browse and + virtual-playlist flows, so session_id + batch_id established here + carry into `on_played`/`on_streamed` feedback even when the user + starts playback from this discovery card. + :return: RecommendationFolder with My Wave tracks, or None if empty. """ max_tracks_config = int( @@ -1729,35 +2608,46 @@ async def _get_my_wave_recommendations(self) -> RecommendationFolder | None: ) batch_size_config = MY_WAVE_BATCH_SIZE + wave = self._get_wave_state(ROTOR_STATION_MY_WAVE) + # Local dedup so the recommendations card stays independent from the + # browse/virtual-playlist dedup set (which may be larger and stale). + # Only session_id + batch_id + last_track_id are shared with `wave`. seen_track_ids: set[str] = set() items: list[Track] = [] - queue: str | int | None = None - for _ in range(batch_size_config): - if len(seen_track_ids) >= max_tracks_config: - break - - yandex_tracks, _ = await self.client.get_my_wave_tracks(queue=queue) - if not yandex_tracks: - break - - first_track_id_this_batch = None - for yt in yandex_tracks: + # Hold the wave lock across the whole fetch chain — we mutate shared + # session_id/batch_id/last_track_id via _fetch_rotor_session_batch, + # and other call sites (browse, virtual-playlist) guard the same + # state with this lock. Concurrent calls without the lock would + # interleave cursor updates and leave the session inconsistent. + async with wave.lock: + for _ in range(batch_size_config): if len(seen_track_ids) >= max_tracks_config: break - track = self._parse_my_wave_track(yt, seen_ids=seen_track_ids) - if track is None: - continue + yandex_tracks, _ = await self._fetch_rotor_session_batch( + wave, ROTOR_STATION_MY_WAVE + ) + if not yandex_tracks: + break - items.append(track) - track_id = track.item_id.split(RADIO_TRACK_ID_SEP, 1)[0] - if first_track_id_this_batch is None: - first_track_id_this_batch = track_id + first_track_id_this_batch: str | None = None + for yt in yandex_tracks: + if len(seen_track_ids) >= max_tracks_config: + break - queue = first_track_id_this_batch - if not queue: - break + track = self._parse_my_wave_track(yt, seen_ids=seen_track_ids) + if track is None: + continue + + items.append(track) + track_id = track.item_id.split(RADIO_TRACK_ID_SEP, 1)[0] + if first_track_id_this_batch is None: + first_track_id_this_batch = track_id + + if first_track_id_this_batch is None: + break + wave.last_track_id = first_track_id_this_batch if not items: return None @@ -1788,7 +2678,10 @@ async def _get_feed_recommendations(self) -> RecommendationFolder | None: for gen_playlist in feed.generated_playlists: if gen_playlist.data and gen_playlist.ready: try: - items.append(parse_playlist(self, gen_playlist.data)) + # Mark feed-generated playlists (Playlist of the Day, DejaVu, + # Premiere, Missed Likes) as dynamic — Yandex regenerates them + # on a schedule so MA must not long-cache the track list. + items.append(parse_playlist(self, gen_playlist.data, is_dynamic=True)) except InvalidDataError as err: self.logger.debug("Error parsing feed playlist: %s", err) if not items: @@ -2182,16 +3075,59 @@ async def get_library_artists(self) -> AsyncGenerator[Artist, None]: except InvalidDataError as err: self.logger.debug("Error parsing library artist: %s", err) + async def _get_liked_albums_cached(self, ttl: float = 30.0) -> list[YandexAlbum]: + """Return liked albums with a short in-process TTL cache + lock. + + Albums, podcasts and audiobooks are all derived from the same + ``users/{uid}/likes/albums`` endpoint, so a full library sync would + otherwise trigger three sequential (or concurrent) identical calls. + The lock serializes refreshes so only one request hits the API when + multiple library syncs start together. + """ + async with self._liked_albums_lock: + now = asyncio.get_running_loop().time() + if self._liked_albums_cache is not None: + cached_at, cached = self._liked_albums_cache + if now - cached_at < ttl: + return cached + albums = await self.client.get_liked_albums(batch_size=TRACK_BATCH_SIZE) + self._liked_albums_cache = (now, albums) + return albums + async def get_library_albums(self) -> AsyncGenerator[Album, None]: - """Retrieve library albums from Yandex Music.""" - batch_size = TRACK_BATCH_SIZE - albums = await self.client.get_liked_albums(batch_size=batch_size) - for album in albums: + """Retrieve library albums from Yandex Music. + + Excludes entries classified as podcasts or audiobooks so they don't + duplicate into the Albums library view. + """ + for album in await self._get_liked_albums_cached(): + if classify_album(album) != "music": + continue try: yield parse_album(self, album) except InvalidDataError as err: self.logger.debug("Error parsing library album: %s", err) + async def get_library_podcasts(self) -> AsyncGenerator[Podcast, None]: + """Retrieve library podcasts from Yandex Music (filtered liked albums).""" + for album in await self._get_liked_albums_cached(): + if classify_album(album) != "podcast": + continue + try: + yield parse_podcast(self, album) + except InvalidDataError as err: + self.logger.debug("Error parsing library podcast: %s", err) + + async def get_library_audiobooks(self) -> AsyncGenerator[Audiobook, None]: + """Retrieve library audiobooks from Yandex Music (filtered liked albums).""" + for album in await self._get_liked_albums_cached(): + if classify_album(album) != "audiobook": + continue + try: + yield parse_audiobook(self, album) + except InvalidDataError as err: + self.logger.debug("Error parsing library audiobook: %s", err) + async def get_library_tracks(self) -> AsyncGenerator[Track, None]: """Retrieve library tracks from Yandex Music.""" track_shorts = await self.client.get_liked_tracks() @@ -2243,17 +3179,27 @@ async def get_library_playlists(self) -> AsyncGenerator[Playlist, None]: async def library_add(self, item: MediaItemType) -> bool: """Add item to library. + For tracks carrying a wave station context in the item_id (e.g. when + the user adds a My Wave track to favourites during playback), also + fires a rotor ``like`` feedback on the active session so the wave + algorithm biases toward similar tracks immediately. + :param item: The media item to add. :return: True if successful. """ prov_item_id = self._get_provider_item_id(item) if not prov_item_id: return False - track_id, _ = _parse_radio_item_id(prov_item_id) + track_id, station_key = _parse_radio_item_id(prov_item_id) if item.media_type == MediaType.TRACK: - return await self.client.like_track(track_id) - if item.media_type == MediaType.ALBUM: + ok = await self.client.like_track(track_id) + if ok and station_key: + wave = self._wave_states.get(station_key) + if wave and wave.session_id: + await self._send_wave_feedback(wave, station_key, "like", track_id=track_id) + return ok + if item.media_type in (MediaType.ALBUM, MediaType.PODCAST, MediaType.AUDIOBOOK): return await self.client.like_album(prov_item_id) if item.media_type == MediaType.ARTIST: return await self.client.like_artist(prov_item_id) @@ -2269,7 +3215,7 @@ async def library_remove(self, prov_item_id: str, media_type: MediaType) -> bool track_id, _ = _parse_radio_item_id(prov_item_id) if media_type == MediaType.TRACK: return await self.client.unlike_track(track_id) - if media_type == MediaType.ALBUM: + if media_type in (MediaType.ALBUM, MediaType.PODCAST, MediaType.AUDIOBOOK): return await self.client.unlike_album(prov_item_id) if media_type == MediaType.ARTIST: return await self.client.unlike_artist(prov_item_id) @@ -2287,29 +3233,243 @@ def _get_provider_item_id(self, item: MediaItemType) -> str | None: async def get_stream_details( self, item_id: str, media_type: MediaType = MediaType.TRACK ) -> StreamDetails: - """Get stream details for a track. + """Get stream details for a track, podcast episode, or audiobook. + + A podcast episode is a track underneath the Yandex API, so it flows + through the same per-track streaming path. An audiobook is an album + with multiple tracks (chapters) — returned as a CUSTOM stream whose + generator concatenates each chapter's bytes in order. - :param item_id: The track ID (or track_id@station_id for My Wave). - :param media_type: The media type (should be TRACK). - :return: StreamDetails for the track. + :param item_id: The track / episode ID (or track_id@station_id for My Wave), + or the audiobook (album) ID when ``media_type`` is AUDIOBOOK. + :param media_type: The media type. + :return: StreamDetails for the item. """ + if media_type == MediaType.AUDIOBOOK: + return await self._get_audiobook_stream_details(item_id) return await self.streaming.get_stream_details(item_id) + async def _get_audiobook_stream_details(self, audiobook_id: str) -> StreamDetails: + """Build StreamDetails for an audiobook as a chapter-concatenated CUSTOM stream. + + Loads the album's tracks, uses the first chapter to establish the audio + format, and stores the per-chapter track-IDs + durations in ``data`` so + ``get_audio_stream`` can iterate them. ``can_seek=True`` so MA routes + ``seek_position`` into ``get_audio_stream``, where the provider translates + it into ``(start_chapter, in_chapter_offset)``. In-chapter precision + requires a byte-seekable chapter codec (raw MP3); otherwise the chapter + is restarted from its beginning. + """ + album = await self.client.get_album_with_tracks(audiobook_id) + if not album or not (album.volumes or []): + raise MediaNotFoundError(f"Audiobook {audiobook_id} has no chapters") + + chapter_ids, chapter_durations_ms = _extract_chapter_map_from_album(album) + if not chapter_ids: + raise MediaNotFoundError(f"Audiobook {audiobook_id} has no chapters") + + self._audiobook_chapter_cache[audiobook_id] = (chapter_ids, chapter_durations_ms) + + # Resolve first-chapter format so MA/ffmpeg know what it's decoding + first = await self.streaming.get_stream_details(chapter_ids[0]) + total_duration = sum(chapter_durations_ms) // 1000 + + return StreamDetails( + item_id=audiobook_id, + provider=self.instance_id, + media_type=MediaType.AUDIOBOOK, + audio_format=first.audio_format, + stream_type=StreamType.CUSTOM, + duration=total_duration, + data={ + "chapter_ids": chapter_ids, + "chapter_durations_ms": chapter_durations_ms, + }, + can_seek=True, + allow_seek=True, + ) + async def get_audio_stream( self, streamdetails: StreamDetails, seek_position: int = 0 ) -> AsyncGenerator[bytes, None]: """Return the audio stream for the provider item. - This method is called when StreamType.CUSTOM is used, enabling on-the-fly - decryption of encrypted FLAC streams without disk I/O. + For tracks and podcast episodes, streams via windowed Range requests + (raw or AES-CTR encrypted). For audiobooks, iterates chapters: each + chapter's bytes are streamed through the per-track path and concatenated. - :param streamdetails: Stream details containing encrypted URL and decryption key. - :param seek_position: Seek position in seconds (not supported for encrypted streams). - :return: Async generator yielding decrypted audio chunks. + :param streamdetails: Stream details with URL and optional decryption key. + :param seek_position: Seek position in seconds (handled by provider for raw transport). + :return: Async generator yielding audio chunks. """ + data = streamdetails.data if isinstance(streamdetails.data, dict) else None + if streamdetails.media_type == MediaType.AUDIOBOOK and data and "chapter_ids" in data: + async for chunk in self._stream_audiobook_chapters(data, seek_position): + yield chunk + return async for chunk in self.streaming.get_audio_stream(streamdetails, seek_position): yield chunk + def _resolve_audiobook_seek( + self, chapter_durations_ms: list[int], seek_position: int, n_chapters: int + ) -> tuple[int, int]: + """Map an audiobook ``seek_position`` (seconds) to (start_idx, chapter_seek).""" + if seek_position <= 0 or not chapter_durations_ms: + return 0, 0 + accumulated_ms = 0 + seek_ms = seek_position * 1000 + for idx, dur_ms in enumerate(chapter_durations_ms): + if accumulated_ms + dur_ms > seek_ms: + return idx, (seek_ms - accumulated_ms) // 1000 + accumulated_ms += dur_ms + # Seek past end — start at last chapter from 0 + return max(n_chapters - 1, 0), 0 + + async def _resolve_audiobook_chapter_map( + self, audiobook_id: str + ) -> tuple[list[str], list[int]]: + """Return (chapter_track_ids, chapter_durations_ms) for an audiobook. + + Served from an in-memory cache populated by ``_get_audiobook_stream_details``. + On a miss (e.g. ``on_played`` fires before streaming has started), falls back + to a fresh ``get_album_with_tracks`` call and refills the cache. + """ + cached = self._audiobook_chapter_cache.get(audiobook_id) + if cached is not None: + return cached + album = await self.client.get_album_with_tracks(audiobook_id) + if not album or not (album.volumes or []): + return [], [] + chapter_ids, chapter_durations_ms = _extract_chapter_map_from_album(album) + self._audiobook_chapter_cache[audiobook_id] = (chapter_ids, chapter_durations_ms) + return chapter_ids, chapter_durations_ms + + async def _stream_audiobook_chapters( + self, data: dict[str, Any], seek_position: int + ) -> AsyncGenerator[bytes, None]: + """Concatenate per-chapter streams of an audiobook. + + Translates ``seek_position`` into (start_chapter, in_chapter_offset) and + delegates each chapter to the per-track streaming path. In-chapter offset + is only applied when the chapter codec is byte-seekable (``can_seek``); + otherwise the chapter is restarted from its beginning. Tracks consecutive + chapter failures and raises ``MediaNotFoundError`` once the threshold is + exceeded, so playback never silently truncates. + """ + chapter_ids: list[str] = list(data.get("chapter_ids") or []) + chapter_durations_ms: list[int] = list(data.get("chapter_durations_ms") or []) + if not chapter_ids: + return + + start_idx, chapter_seek = self._resolve_audiobook_seek( + chapter_durations_ms, seek_position, len(chapter_ids) + ) + + max_consecutive_failures = 3 + consecutive_failures = 0 + has_yielded_audio = False + last_error: Exception | None = None + + for idx in range(start_idx, len(chapter_ids)): + chapter_id = chapter_ids[idx] + requested_offset = chapter_seek if idx == start_idx else 0 + chapter_details: StreamDetails | None = None + try: + chapter_details = await self.streaming.get_stream_details(chapter_id) + except asyncio.CancelledError: + raise + except Exception as err: + last_error = err + self.logger.warning( + "Audiobook chapter %d (%s) stream-details failed: %s", + idx + 1, + chapter_id, + err, + ) + + if chapter_details is None: + consecutive_failures += 1 + if consecutive_failures >= max_consecutive_failures: + raise MediaNotFoundError( + "Unable to stream audiobook: too many consecutive chapter failures" + ) from last_error + continue + + # Apply the in-chapter offset only when the chapter codec supports + # byte-offset seeking; otherwise restart the chapter from 0 to avoid + # decoding garbled bytes from mid-file of a container format. + offset = requested_offset if chapter_details.can_seek else 0 + chapter_had_audio = False + try: + async for chunk in self.streaming.get_audio_stream(chapter_details, offset): + chapter_had_audio = True + has_yielded_audio = True + yield chunk + except asyncio.CancelledError: + raise + except Exception as err: + last_error = err + self.logger.warning( + "Audiobook chapter %d (%s) stream failed mid-play: %s", + idx + 1, + chapter_id, + err, + ) + + if chapter_had_audio: + consecutive_failures = 0 + last_error = None + else: + consecutive_failures += 1 + if consecutive_failures >= max_consecutive_failures: + raise MediaNotFoundError( + "Unable to stream audiobook: too many consecutive chapter failures" + ) from last_error + + if not has_yielded_audio: + raise MediaNotFoundError( + "Unable to stream audiobook: no playable chapters found" + ) from last_error + + async def get_rotor_station_tracks( + self, station_id: str, queue: str | int | None = None + ) -> tuple[list[Any], str | None]: + """Fetch tracks from a rotor station using the session API. + + Public surface — pinned by the ynison plugin + (`YandexMusicProviderLike.get_rotor_station_tracks`). The + ``(tracks, batch_id)`` return contract is kept for that caller even + though batch_id is now a session-scoped identifier. + + Routes to ``_fetch_rotor_session_batch`` so the wave session state + (`session_id`, seen tracks, prefetch) is shared with our own Browse / + on_played / on_streamed flows. ``queue`` is the most recently played + track ID the external caller observed — we record it as the + pagination cursor before calling through. + + :param station_id: Rotor station ID (e.g. "user:onyourwave", + "genre:rock", "mood:calm", "track:1234"). + :param queue: Last-played track ID for pagination. Ignored on the + very first call (no session yet) but still recorded. + :return: Tuple of (list of yandex tracks, batch_id or None). + """ + wave = self._get_wave_state(station_id) + # Cursor update + batch fetch run under the station's lock, matching + # the discipline in browse / recommendations / prefetch. Without it, + # ynison replenish racing with a concurrent MA browse could interleave + # last_track_id writes and leave session_id / batch_id out of sync. + async with wave.lock: + if queue is not None: + wave.last_track_id = str(queue) + return await self._fetch_rotor_session_batch(wave, station_id) + + def get_quality(self) -> str: + """Return the configured audio quality tier (e.g. 'balanced', 'superb').""" + quality = str(self.config.get_value(CONF_QUALITY) or QUALITY_BALANCED).strip().lower() + if quality == "lossless": + quality = QUALITY_SUPERB + return quality + async def resolve_image(self, path: str) -> str | bytes: """Resolve wave cover image with background color fill for transparent PNGs. @@ -2364,147 +3524,165 @@ async def on_played( media_item: MediaItemType, is_playing: bool = False, ) -> None: - """Report playback for rotor feedback when the track is from My Wave. - - Sends trackStarted when the track is currently playing (is_playing=True). - trackFinished/skip are sent from on_streamed to use accurate seconds_streamed. - - Also auto-enables "Don't stop the music" for any queue playing a radio track - so that MA refills the queue via get_similar_tracks when < 5 tracks remain. + """Report periodic playback updates. + + - Audiobooks: persist chapter progress via play_audio so Yandex's + own clients resume at the right point. + - Wave tracks: send rotor ``trackStarted`` while actively playing and + kick off a background prefetch so DSTM refill serves wave-curated + tracks with no extra round-trip. DSTM itself is the user's toggle — + the provider does not flip it. + + Generic track history reporting is not attempted here — the only + known channel Yandex writes into ``/handlers/music-history`` is a + long-lived Ynison WebSocket session, which lives in the sibling + yandex_ynison plugin. Regular tracks played through MA are therefore + invisible to Listening History unless that plugin is also active. """ - # Radio feedback always enabled + if media_type == MediaType.AUDIOBOOK: + await self._report_audiobook_progress(prov_item_id, position) + return if media_type != MediaType.TRACK: return - track_id, station_id = _parse_radio_item_id(prov_item_id) + _, station_id = _parse_radio_item_id(prov_item_id) + if station_id and is_playing: + track_id, _ = _parse_radio_item_id(prov_item_id) + wave = self._wave_states.get(station_id) or self._get_wave_state(station_id) + await self._send_wave_feedback(wave, station_id, "trackStarted", track_id=track_id) + self.mass.create_task(self._prefetch_rotor_session(station_id)) + + async def on_streamed(self, streamdetails: StreamDetails) -> None: + """Report stream completion to Yandex. + + - Audiobooks: a final ``play_audio`` with the absolute stream + position so the last listening point is preserved across Yandex + clients. Cleans up session state even when ``data`` was stripped. + - Wave tracks (composite item_id carries a station suffix): a rotor + ``trackFinished`` or ``skip`` event with the actual seconds streamed + so Yandex can improve recommendations. + """ + data = streamdetails.data if isinstance(streamdetails.data, dict) else None + if streamdetails.media_type == MediaType.AUDIOBOOK: + await self._report_audiobook_final(streamdetails, data or {}) + return + if streamdetails.media_type != MediaType.TRACK: + return + track_id, station_id = _parse_radio_item_id(streamdetails.item_id) if not station_id: return - # Auto-enable "Don't stop the music" on every on_played call for radio tracks. - # Calling on every invocation (not just is_playing=True) ensures it fires even - # for short tracks that finish before the 30-second periodic callback. - self._ensure_dont_stop_the_music(prov_item_id) - if is_playing: - if station_id == ROTOR_STATION_MY_WAVE: - batch_id = self._my_wave_batch_id - else: - state = self._wave_states.get(station_id) - batch_id = state.batch_id if state else None - await self.client.send_rotor_station_feedback( - station_id, - "trackStarted", - track_id=track_id, - batch_id=batch_id, - ) - # Remove duplicate call that was under is_playing guard. - # _ensure_dont_stop_the_music is now called unconditionally above. - - def _ensure_dont_stop_the_music(self, prov_item_id: str) -> None: - """Enable 'Don't stop the music' on queues playing this specific radio item. - - Iterates all queues and enables the setting on queues whose current track - mapping matches this exact composite item_id (track_id@station_id) for this - provider instance. + seconds = int(streamdetails.seconds_streamed or 0) + duration = int(streamdetails.duration or 0) + feedback_type = "trackFinished" if duration and seconds >= max(0, duration - 10) else "skip" + wave = self._wave_states.get(station_id) or self._get_wave_state(station_id) + await self._send_wave_feedback( + wave, station_id, feedback_type, track_id=track_id, total_played_seconds=seconds + ) - Also sets queue.radio_source directly to the current track because - enqueued_media_items is empty for BrowseFolder-initiated playback, which - normally prevents MA's auto-fill from triggering. Setting radio_source - directly bypasses that gap so _fill_radio_tracks runs when < 5 tracks remain. + def _audiobook_progress_point( + self, + chapter_durations_ms: list[int], + n_chapters: int, + absolute_sec: int, + ) -> tuple[int, int, int]: + """Resolve an absolute book position into a play_audio-ready tuple. + + Returns ``(chapter_idx, track_length_seconds, offset_seconds)``, applying + two invariants Yandex cares about and that ``_resolve_audiobook_seek`` + alone doesn't guarantee: + + - At/beyond end-of-book, map to end of the last chapter (not start), + so Yandex's resume point doesn't rewind to the start of the final + chapter on natural completion. + - ``track_length_seconds`` is clamped to at least 1 and ``offset`` to + ``[0, track_length_seconds]`` — a chapter with ``duration_ms=None`` + (coerced to 0 by the chapter-map builder) would otherwise send + ``track_length_seconds=0`` and block progress from syncing. """ - for queue in self.mass.player_queues: - current = queue.current_item - if current is None or current.media_item is None: - continue - item = current.media_item - # Match by provider instance and exact composite item_id - for mapping in getattr(item, "provider_mappings", []): - if ( - mapping.provider_instance == self.instance_id - and mapping.item_id == prov_item_id - ): - # Set radio_source directly so MA's fill mechanism works even when - # the queue was started from a BrowseFolder (enqueued_media_items empty). - if not queue.radio_source and isinstance(item, Track): - queue.radio_source = [item] - if not queue.dont_stop_the_music_enabled: - try: - self.mass.player_queues.set_dont_stop_the_music( - queue.queue_id, dont_stop_the_music_enabled=True - ) - self.logger.info( - "Auto-enabled 'Don't stop the music' for queue %s (radio station)", - queue.display_name, - ) - except Exception as err: - self.logger.debug( - "Could not enable 'Don't stop the music' for queue %s: %s", - queue.display_name, - err, - ) - break + absolute_sec = max(0, absolute_sec) + total_duration_sec = sum(chapter_durations_ms) // 1000 + last_idx = max(n_chapters - 1, 0) + if absolute_sec >= total_duration_sec > 0: + idx = last_idx + track_length_sec = max(1, chapter_durations_ms[idx] // 1000) + offset = track_length_sec + else: + idx, offset_raw = self._resolve_audiobook_seek( + chapter_durations_ms, absolute_sec, n_chapters + ) + track_length_sec = max(1, chapter_durations_ms[idx] // 1000) + offset = max(0, min(int(offset_raw), track_length_sec)) + return idx, track_length_sec, offset - def _ensure_dont_stop_the_music_for_queue(self, queue_id: str | None) -> None: - """Enable 'Don't stop the music' for a specific queue by ID. + async def _report_audiobook_progress(self, audiobook_id: str, position_sec: int) -> None: + """Push current listening position of an audiobook to Yandex. - Faster variant of _ensure_dont_stop_the_music used from on_streamed where - queue_id is available directly, avoiding iteration over all queues. + Resolves the playing chapter + offset from the cached chapter map, then + calls play_audio so Yandex persists the position for cross-client resume. + + Best-effort: any non-cancellation failure while resolving the chapter + map (rate-limit, network blip, auth edge case bubbling out of + ``_call_with_retry``) must never break pause/stop, so it is swallowed + here in addition to the errors already absorbed inside + ``api_client.play_audio``. """ - if not queue_id: - return - queue = self.mass.player_queues.get(queue_id) - if queue is None: + try: + chapter_ids, chapter_durations_ms = await self._resolve_audiobook_chapter_map( + audiobook_id + ) + except asyncio.CancelledError: + raise + except Exception as err: + self.logger.debug( + "Skipping audiobook progress report for %s (chapter map resolution failed): %s", + audiobook_id, + err, + ) return - current = queue.current_item - if current is None or current.media_item is None: + if not chapter_ids: + self.logger.debug( + "Audiobook %s has no chapter map; skipping progress report", audiobook_id + ) return - item = current.media_item - for mapping in getattr(item, "provider_mappings", []): - if ( - mapping.provider_instance == self.instance_id - and RADIO_TRACK_ID_SEP in mapping.item_id - ): - if not queue.radio_source and isinstance(item, Track): - queue.radio_source = [item] - if not queue.dont_stop_the_music_enabled: - try: - self.mass.player_queues.set_dont_stop_the_music( - queue_id, dont_stop_the_music_enabled=True - ) - self.logger.info( - "Auto-enabled 'Don't stop the music' for queue %s (radio)", - queue.display_name, - ) - except Exception as err: - self.logger.debug( - "Could not enable 'Don't stop the music' for queue %s: %s", - queue.display_name, - err, - ) - break + idx, track_length_sec, offset = self._audiobook_progress_point( + chapter_durations_ms, len(chapter_ids), int(position_sec) + ) + play_id = self._audiobook_play_ids.setdefault(audiobook_id, uuid.uuid4().hex) + await self.client.play_audio( + track_id=chapter_ids[idx], + album_id=audiobook_id, + play_id=play_id, + track_length_seconds=track_length_sec, + total_played_seconds=offset, + end_position_seconds=offset, + ) - async def on_streamed(self, streamdetails: StreamDetails) -> None: - """Report stream completion for My Wave rotor feedback. + async def _report_audiobook_final( + self, streamdetails: StreamDetails, data: dict[str, Any] + ) -> None: + """Send a closing play_audio for an audiobook stream. - Sends trackFinished or skip with actual seconds_streamed so Yandex - can improve recommendations. + Uses the streamdetails' own ``chapter_ids`` / ``chapter_durations_ms`` + (populated when the StreamDetails was created) to stay consistent with + what was actually played, then clears the session play_id and drops + the chapter-map cache entry so long-running instances can't grow the + cache without bound as users play more audiobooks. """ - # Radio feedback always enabled - track_id, station_id = _parse_radio_item_id(streamdetails.item_id) - if not station_id: + audiobook_id = streamdetails.item_id + chapter_ids = data.get("chapter_ids") or [] + chapter_durations_ms = data.get("chapter_durations_ms") or [] + play_id = self._audiobook_play_ids.pop(audiobook_id, None) or uuid.uuid4().hex + self._audiobook_chapter_cache.pop(audiobook_id, None) + if not chapter_ids or not chapter_durations_ms: return - # Also ensure Don't stop the music is active — on_streamed fires even for - # very short tracks and we have queue_id here directly. - self._ensure_dont_stop_the_music_for_queue(streamdetails.queue_id) - seconds = int(streamdetails.seconds_streamed or 0) - duration = streamdetails.duration or 0 - feedback_type = "trackFinished" if duration and seconds >= max(0, duration - 10) else "skip" - if station_id == ROTOR_STATION_MY_WAVE: - batch_id = self._my_wave_batch_id - else: - state = self._wave_states.get(station_id) - batch_id = state.batch_id if state else None - await self.client.send_rotor_station_feedback( - station_id, - feedback_type, - track_id=track_id, - total_played_seconds=seconds, - batch_id=batch_id, + absolute_sec = int(streamdetails.seek_position + (streamdetails.seconds_streamed or 0)) + idx, track_length_sec, offset = self._audiobook_progress_point( + chapter_durations_ms, len(chapter_ids), absolute_sec + ) + await self.client.play_audio( + track_id=chapter_ids[idx], + album_id=audiobook_id, + play_id=play_id, + track_length_seconds=track_length_sec, + total_played_seconds=offset, + end_position_seconds=offset, ) diff --git a/music_assistant/providers/yandex_music/streaming.py b/music_assistant/providers/yandex_music/streaming.py index 28766cf189..2effd6ac78 100644 --- a/music_assistant/providers/yandex_music/streaming.py +++ b/music_assistant/providers/yandex_music/streaming.py @@ -4,7 +4,7 @@ import asyncio from collections.abc import AsyncGenerator -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Final import aiohttp from aiohttp import ClientPayloadError, ServerDisconnectedError @@ -17,11 +17,16 @@ from music_assistant.helpers.throttle_retry import BYPASS_THROTTLER from .constants import ( + CONF_CODECS, CONF_QUALITY, + CONF_TRANSPORT, + QUALITY_BALANCED, QUALITY_EFFICIENT, + QUALITY_FILE_INFO_PARAMS, QUALITY_HIGH, QUALITY_SUPERB, RADIO_TRACK_ID_SEP, + TRANSPORT_RAW, ) if TYPE_CHECKING: @@ -30,18 +35,26 @@ from .provider import YandexMusicProvider -# Encrypted-stream tuning constants +# Windowed-stream tuning constants _CHUNK_SIZE = 16384 # smaller than default 65536 for faster first-byte after retry _STREAM_TIMEOUT = aiohttp.ClientTimeout(total=None, sock_read=30) -# Yandex CDN drops TCP every ~6-7 MB per connection (observed via live traffic capture). -# By capping each Range request to 4 MB we stay well below that limit, so CDN drops -# should never occur during normal windowed playback. -_RANGE_WINDOW = 4 * 1024 * 1024 # 4 MB — must be a multiple of AES block size (16) -# Flat short delays for any residual TCP drops (network glitches within a 4 MB window) +# Yandex CDN drops TCP connections for slow consumers (observed at ~45s for raw transport +# at real-time playback rate ~200 KB/s). By capping each Range request to 4 MB we download +# each window quickly, preventing CDN drops for both raw and encrypted transports. +_RANGE_WINDOW = 4 * 1024 * 1024 # 4 MB per Range request +# AES-CTR block size in bytes (used for block-aligned Range requests in encrypted transport) +_AES_BLOCK_SIZE = 16 +# Flat short delays for TCP drops (network glitches within a 4 MB window) _TCP_DROP_DELAYS = (0.5, 1.0, 2.0) # Exponential delays for true network stalls (read timeout) _STALL_DELAYS = (2.0, 4.0, 8.0) +# Normalize Yandex codec names to MA ContentType values +_CODEC_ALIASES: Final[dict[str, str]] = { + "he-aac": "aac", + "mpeg": "mp3", +} + class YandexMusicStreamingManager: """Manages Yandex Music streaming operations.""" @@ -65,6 +78,9 @@ def _track_id_from_item_id(self, item_id: str) -> str: async def get_stream_details(self, item_id: str) -> StreamDetails: """Get stream details for a track. + Uses the unified /get-file-info endpoint for all quality tiers. + Falls back to /tracks/{id}/download-info if get-file-info fails. + :param item_id: Track ID or composite track_id@station_id for My Wave. :return: StreamDetails for the track (item_id preserved for on_streamed). :raises MediaNotFoundError: If stream URL cannot be obtained. @@ -74,103 +90,129 @@ async def get_stream_details(self, item_id: str) -> StreamDetails: if not track: raise MediaNotFoundError(f"Track {item_id} not found") - quality = self.provider.config.get_value(CONF_QUALITY) - quality_str = str(quality) if quality is not None else None - preferred_normalized = (quality_str or "").strip().lower() - - # Check for superb (lossless) quality - want_lossless = preferred_normalized in (QUALITY_SUPERB, "superb") - - # Backward compatibility: also check old "lossless" value (exact match) - if preferred_normalized == "lossless": - want_lossless = True - - # When user wants lossless, try get-file-info first (FLAC; download-info often MP3 only) - if want_lossless: - self.logger.debug("Requesting lossless via get-file-info for track %s", track_id) - file_info = await self.client.get_track_file_info_lossless(track_id) - if file_info: - url = file_info.get("url") - codec = file_info.get("codec") or "" - needs_decryption = file_info.get("needs_decryption", False) - - if url and codec.lower() in ("flac", "flac-mp4"): - audio_format = self._build_audio_format(codec) - - # Handle encrypted URLs from encraw transport - if needs_decryption and "key" in file_info: - self.logger.info( - "Streaming encrypted %s for track %s - will decrypt on-the-fly", - codec, - track_id, - ) - # Return StreamType.CUSTOM for streaming decryption. - # can_seek=False: provider always streams from position 0; - # allow_seek=True: ffmpeg handles seek with -ss input flag. - return StreamDetails( - item_id=item_id, - provider=self.provider.instance_id, - audio_format=audio_format, - stream_type=StreamType.CUSTOM, - duration=track.duration, - data={ - "encrypted_url": url, - "decryption_key": file_info["key"], - "codec": codec, - }, - can_seek=False, - allow_seek=True, - ) - # Unencrypted URL, use directly - self.logger.debug( - "Unencrypted stream for track %s: codec=%s", - item_id, - codec, - ) - return StreamDetails( - item_id=item_id, - provider=self.provider.instance_id, - audio_format=audio_format, - stream_type=StreamType.HTTP, - duration=track.duration, - path=url, - can_seek=True, - allow_seek=True, - expiration=50, # get-file-info URLs expire; force MA to re-fetch - ) - - # Default: use /tracks/.../download-info and select best quality - download_infos = await self.client.get_track_download_info(track_id, get_direct_links=True) - if not download_infos: - raise MediaNotFoundError(f"No stream info available for track {item_id}") + quality = ( + str(self.provider.config.get_value(CONF_QUALITY) or QUALITY_BALANCED).strip().lower() + ) + transport = ( + str(self.provider.config.get_value(CONF_TRANSPORT) or TRANSPORT_RAW).strip().lower() + ) + + # Backward compatibility: old "lossless" config value + if quality == "lossless": + quality = QUALITY_SUPERB + + fi_params = QUALITY_FILE_INFO_PARAMS.get( + quality, QUALITY_FILE_INFO_PARAMS[QUALITY_BALANCED] + ) + + # Allow advanced users to override codecs + codecs_override = str(self.provider.config.get_value(CONF_CODECS) or "").strip() + codecs = codecs_override or fi_params["codecs"] - codecs_available = [ - (getattr(i, "codec", None), getattr(i, "bitrate_in_kbps", None)) for i in download_infos - ] self.logger.debug( - "Stream quality for track %s: config quality=%s, available codecs=%s", + "Requesting stream for track %s: quality=%s, transport=%s, codecs=%s", + track_id, + quality, + transport, + codecs, + ) + + file_info = await self.client.get_track_file_info( track_id, - quality_str, - codecs_available, + quality=fi_params["quality"], + codecs=codecs, + transport=transport, + ) + + if file_info and file_info.get("url"): + url = file_info["url"] + codec = file_info.get("codec") or "" + needs_decryption = file_info.get("needs_decryption", False) + + # Gather audio params: API response first, then probe container + bit_rate = file_info.get("bitrate") or file_info.get("bitrate_in_kbps") or 0 + sample_rate = file_info.get("sample_rate") or 0 + bit_depth = file_info.get("bit_depth") or 0 + + if (not sample_rate or not bit_depth) and not needs_decryption: + # Probe raw stream headers for real sample_rate/bit_depth + probed_sr, probed_bd = await self._probe_stream_params(url, codec) + sample_rate = sample_rate or probed_sr + bit_depth = bit_depth or probed_bd + + self.logger.debug( + "Audio params for track %s: codec=%s, bit_rate=%s, sample_rate=%s, bit_depth=%s", + track_id, + codec, + bit_rate, + sample_rate, + bit_depth, + ) + + audio_format = self._build_audio_format( + codec, + bit_rate=bit_rate, + sample_rate=sample_rate, + bit_depth=bit_depth, + ) + + # Always use StreamType.CUSTOM with windowed Range requests to prevent CDN drops. + # can_seek=True only for codecs where bitrate * time yields a decodable byte + # offset — i.e. raw MP3. MP4-container codecs (aac-mp4, flac-mp4) need the + # ftyp/moov init atoms at the file start, so byte-offset seeks land in mdat + # with no codec config and produce undecodable data. Raw FLAC frames aren't + # fixed-size either, so byte-rate math doesn't land on a frame boundary. + # allow_seek=True lets ffmpeg handle time-based seeking via -ss in those cases. + byte_seekable = codec.lower() in ("mp3", "mpeg") + can_seek = not needs_decryption and bit_rate > 0 and byte_seekable + data: dict[str, Any] = { + "url": url, + "codec": codec, + "transport": transport, + "bit_rate": bit_rate, + # Stored for URL refresh on 4xx: + "fi_quality": fi_params["quality"], + "fi_codecs": codecs, + } + if needs_decryption and "key" in file_info: + data["decryption_key"] = file_info["key"] + + return StreamDetails( + item_id=item_id, + provider=self.provider.instance_id, + audio_format=audio_format, + stream_type=StreamType.CUSTOM, + duration=track.duration, + data=data, + can_seek=can_seek, + allow_seek=not needs_decryption, + ) + + # Fallback: /tracks/{id}/download-info (defensive, should rarely trigger) + self.logger.warning( + "get-file-info failed for track %s, falling back to download-info", track_id ) - selected_info = self._select_best_quality(download_infos, quality_str) + download_infos = await self.client.get_track_download_info(track_id, get_direct_links=True) + if not download_infos: + raise MediaNotFoundError(f"No stream info available for track {item_id}") + selected_info = self._select_best_quality(download_infos, quality) if not selected_info or not selected_info.direct_link: raise MediaNotFoundError(f"No stream URL available for track {item_id}") self.logger.debug( - "Stream selected for track %s: codec=%s, bitrate=%s", + "Fallback stream for track %s: codec=%s, bitrate=%s", track_id, getattr(selected_info, "codec", None), getattr(selected_info, "bitrate_in_kbps", None), ) - bitrate = selected_info.bitrate_in_kbps or 0 - return StreamDetails( item_id=item_id, provider=self.provider.instance_id, - audio_format=self._build_audio_format(selected_info.codec, bit_rate=bitrate), + audio_format=self._build_audio_format( + selected_info.codec, bit_rate=selected_info.bitrate_in_kbps or 0 + ), stream_type=StreamType.HTTP, duration=track.duration, path=selected_info.direct_link, @@ -184,6 +226,8 @@ def _select_best_quality( ) -> DownloadInfo | None: """Select the best quality download info based on user preference. + Used as fallback when get-file-info is unavailable. + :param download_infos: List of DownloadInfo objects. :param preferred_quality: User's quality preference (efficient/high/balanced/superb). :return: Best matching DownloadInfo or None. @@ -200,10 +244,10 @@ def _select_best_quality( reverse=True, ) - # Superb: Prefer FLAC (backward compatibility with "lossless") - if preferred_normalized == QUALITY_SUPERB or "lossless" in preferred_normalized: - # Note: flac-mp4 typically comes from get-file-info API, not download-info, - # but we check here for forward compatibility in case the API changes. + # Superb: Prefer FLAC. The legacy "lossless" alias still maps to Superb, + # but we use an exact-match set so a stray value like "lossless_foo" + # doesn't sneak in. + if preferred_normalized in {QUALITY_SUPERB, "lossless"}: for codec in ("flac-mp4", "flac"): for info in sorted_infos: if info.codec and info.codec.lower() == codec: @@ -215,12 +259,10 @@ def _select_best_quality( # Efficient: Prefer lowest bitrate AAC/MP3 if preferred_normalized == QUALITY_EFFICIENT: - # Sort ascending for lowest bitrate sorted_infos_asc = sorted( download_infos, key=lambda x: x.bitrate_in_kbps or 999, ) - # Prefer AAC for efficiency, then MP3 (include MP4 container variants) for codec in ("aac-mp4", "aac", "he-aac-mp4", "he-aac", "mp3"): for info in sorted_infos_asc: if info.codec and info.codec.lower() == codec: @@ -229,7 +271,6 @@ def _select_best_quality( # High: Prefer high bitrate MP3 (~320kbps) if preferred_normalized == QUALITY_HIGH: - # Look for MP3 with bitrate >= 256kbps high_quality_mp3 = [ info for info in sorted_infos @@ -239,122 +280,233 @@ def _select_best_quality( and info.bitrate_in_kbps >= 256 ] if high_quality_mp3: - return high_quality_mp3[0] # Already sorted by bitrate descending + return high_quality_mp3[0] - # Fallback: any MP3 available (highest bitrate) for info in sorted_infos: if info.codec and info.codec.lower() == "mp3": return info - # If no MP3, use highest available (excluding FLAC) for info in sorted_infos: if info.codec and info.codec.lower() not in ("flac", "flac-mp4"): return info - # Last resort: highest available return sorted_infos[0] - # Balanced (default): Prefer ~192kbps AAC, or medium quality MP3 - # Look for bitrate around 192kbps (within range 128-256) + # Balanced (default): Prefer ~192kbps AAC balanced_infos = [ info for info in sorted_infos if info.bitrate_in_kbps and 128 <= info.bitrate_in_kbps <= 256 ] if balanced_infos: - # Prefer AAC over MP3 at similar bitrate (include MP4 container variants) for codec in ("aac-mp4", "aac", "he-aac-mp4", "he-aac", "mp3"): for info in balanced_infos: if info.codec and info.codec.lower() == codec: return info return balanced_infos[0] - # Fallback to highest available if no balanced option return sorted_infos[0] if sorted_infos else None def _get_content_type(self, codec: str | None) -> tuple[ContentType, ContentType]: - """Determine container and codec type from Yandex API codec string. + """Determine content_type and codec_type from Yandex API codec string. - Yandex API returns codec strings like "flac-mp4" (FLAC in MP4 container), - "aac-mp4" (AAC in MP4 container), or plain "flac", "mp3", "aac". + Parses the codec string automatically: + - Simple codecs ("flac", "mp3", "aac") → (ContentType., UNKNOWN) + - Compound "codec-container" ("flac-mp4", "aac-mp4") → + (ContentType., ContentType.) - :param codec: Codec string from Yandex API. - :return: Tuple of (content_type/container, codec_type). + content_type always reflects the audio codec (not the container), + so MA's is_lossless() correctly identifies lossless streams and + ffmpeg gets the right decoder name via codec_type. + + :param codec: Codec string from Yandex API (e.g. "flac-mp4", "mp3"). + :return: Tuple of (content_type, codec_type). """ if not codec: return ContentType.UNKNOWN, ContentType.UNKNOWN codec_lower = codec.lower() - # MP4 container variants: codec is inside an MP4 container - if codec_lower == "flac-mp4": - return ContentType.MP4, ContentType.FLAC - if codec_lower in ("aac-mp4", "he-aac-mp4"): - return ContentType.MP4, ContentType.AAC - - # Plain single-codec formats: codec is implied by content_type, no separate codec_type - if codec_lower == "flac": - return ContentType.FLAC, ContentType.UNKNOWN - if codec_lower in ("mp3", "mpeg"): - return ContentType.MP3, ContentType.UNKNOWN - if codec_lower in ("aac", "he-aac"): - return ContentType.AAC, ContentType.UNKNOWN + # Strip container suffix: "flac-mp4" → "flac", "he-aac-mp4" → "he-aac" + has_container = codec_lower.endswith("-mp4") + audio_part = codec_lower[:-4] if has_container else codec_lower - return ContentType.UNKNOWN, ContentType.UNKNOWN + # Normalize aliases (he-aac → aac, mpeg → mp3) + audio_part = _CODEC_ALIASES.get(audio_part, audio_part) - def _get_audio_params(self, codec: str | None) -> tuple[int, int]: - """Return (sample_rate, bit_depth) defaults based on codec string. + try: + content_type = ContentType(audio_part) + except ValueError: + self.logger.debug("Unknown codec from Yandex API: %s", codec) + return ContentType.UNKNOWN, ContentType.UNKNOWN - The Yandex get-file-info API does not return sample rate or bit depth, - so we use codec-based defaults. These values help the core select the - correct PCM output format and avoid unnecessary resampling. + # For compound formats, set codec_type so ffmpeg knows the decoder + codec_type = content_type if has_container else ContentType.UNKNOWN + return content_type, codec_type - :param codec: Codec string from Yandex API (e.g. "flac-mp4", "flac", "mp3"). - :return: Tuple of (sample_rate, bit_depth). - """ - if codec and codec.lower() == "flac-mp4": - return 48000, 24 - # CD-quality defaults for all other codecs - return 44100, 16 + def _build_audio_format( + self, + codec: str | None, + *, + bit_rate: int = 0, + sample_rate: int = 0, + bit_depth: int = 0, + ) -> AudioFormat: + """Build AudioFormat from codec string and optional stream metadata. - def _build_audio_format(self, codec: str | None, bit_rate: int = 0) -> AudioFormat: - """Build AudioFormat with content type and codec-based audio params. + Values of 0 mean "unknown — let MA/ffmpeg detect from the actual stream". + Pass real values from the API response when available. - :param codec: Codec string from Yandex API (e.g. "flac-mp4", "flac", "mp3"). - :param bit_rate: Bitrate in kbps (0 for variable/unknown). + :param codec: Codec string from Yandex API. + :param bit_rate: Bitrate in kbps (0 = unknown). + :param sample_rate: Sample rate in Hz (0 = unknown, detect from stream). + :param bit_depth: Bit depth (0 = unknown, detect from stream). :return: Configured AudioFormat instance. """ content_type, codec_type = self._get_content_type(codec) - sample_rate, bit_depth = self._get_audio_params(codec) - return AudioFormat( - content_type=content_type, - codec_type=codec_type, - bit_rate=bit_rate, - sample_rate=sample_rate, - bit_depth=bit_depth, - ) + kwargs: dict[str, Any] = { + "content_type": content_type, + "codec_type": codec_type, + } + # Only pass non-zero values; AudioFormat defaults to 44100/16 which + # MA/ffmpeg rely on. Passing 0 would override those defaults. + if bit_rate: + kwargs["bit_rate"] = bit_rate + if sample_rate: + kwargs["sample_rate"] = sample_rate + if bit_depth: + kwargs["bit_depth"] = bit_depth + return AudioFormat(**kwargs) + + @staticmethod + def _parse_flac_streaminfo(header: bytes) -> tuple[int, int]: + """Extract sample_rate and bit_depth from FLAC STREAMINFO block. + + FLAC format: 4-byte magic "fLaC", then metadata blocks. + STREAMINFO is always the first block (type 0), 34 bytes payload. + Bytes 10-17 of STREAMINFO contain sample_rate (20 bits), + channels (3 bits), bit_depth (5 bits), total samples (36 bits). + + :param header: First 42+ bytes of the FLAC stream. + :return: (sample_rate, bit_depth) or (0, 0) on parse failure. + """ + if len(header) < 42 or header[:4] != b"fLaC": + return 0, 0 + # STREAMINFO payload: 4 magic + 4 block header = 8 byte offset, 34 bytes long + # Bytes 10-13 of payload: sample_rate(20) | channels(3) | bps(5) | total(4 high) + payload = header[8:] # skip "fLaC" + block header + if len(payload) < 34: + return 0, 0 + val = int.from_bytes(payload[10:14], "big") + sample_rate = (val >> 12) & 0xFFFFF + bit_depth = ((val >> 4) & 0x1F) + 1 + return sample_rate, bit_depth - async def _refresh_encrypted_url( + @staticmethod + def _parse_mp4_audio_params(header: bytes) -> tuple[int, int]: + """Extract sample_rate and bit_depth from MP4/fMP4 container. + + Scans for the 'dfLa' (FLAC-in-MP4) box, or falls back to parsing + the AudioSampleEntry in an 'mp4a' box to read sample size and sample rate. + + :param header: First 8-32 KB of the MP4 stream. + :return: (sample_rate, bit_depth) or (0, 0) if not found. + """ + # Quick scan for dfLa box (FLAC-in-MP4: contains FLAC STREAMINFO) + dfl_pos = header.find(b"dfLa") + if dfl_pos >= 4: + # dfLa box layout after "dfLa" type: + # 4 bytes version/flags + # 4 bytes STREAMINFO block header (type byte + 3-byte length) + # 34 bytes STREAMINFO payload + payload_start = dfl_pos + 4 + 4 + 4 # after type + version/flags + block header + payload = header[payload_start:] + if len(payload) >= 34: + val = int.from_bytes(payload[10:14], "big") + sample_rate = (val >> 12) & 0xFFFFF + bit_depth = ((val >> 4) & 0x1F) + 1 + if 8000 <= sample_rate <= 384000 and 1 <= bit_depth <= 32: + return sample_rate, bit_depth + + # Scan for mp4a AudioSampleEntry (AAC/generic audio in MP4) + mp4a_pos = header.find(b"mp4a") + if mp4a_pos >= 4: + # AudioSampleEntry: 4-byte size, "mp4a", 6 reserved, 2 data_ref, + # 2 version, 2 revision, 4 vendor, 2 channels, 2 sample_size, + # 2 compression_id, 2 packet_size, 4 sample_rate (16.16 fixed-point) + entry_start = mp4a_pos + 4 # after "mp4a" + entry = header[entry_start:] + if len(entry) >= 28: + sample_size = int.from_bytes(entry[18:20], "big") + sr_fixed = int.from_bytes(entry[24:28], "big") + sample_rate = sr_fixed >> 16 + bit_depth = max(0, sample_size) + if 8000 <= sample_rate <= 384000: + return sample_rate, bit_depth + + return 0, 0 + + async def _probe_stream_params(self, url: str, codec: str) -> tuple[int, int]: + """Probe audio params by reading the first bytes of the stream. + + Makes a small Range request to read container/stream headers, + then parses FLAC STREAMINFO or MP4 box structure. + + :param url: Stream URL. + :param codec: Codec string from API (e.g. "flac-mp4", "flac"). + :return: (sample_rate, bit_depth) or (0, 0) if probing fails. + """ + codec_lower = (codec or "").lower() + # Determine how many bytes to read and which parser to use + if codec_lower == "flac": + probe_size = 64 + parser = self._parse_flac_streaminfo + elif "-mp4" in codec_lower: + probe_size = 32768 # MP4 moov/stsd can be further in + parser = self._parse_mp4_audio_params + else: + return 0, 0 # lossy without container — let MA detect + + try: + headers = {"Range": f"bytes=0-{probe_size - 1}"} + async with self.mass.http_session.get( + url, + headers=headers, + timeout=aiohttp.ClientTimeout(total=10), + ) as resp: + if resp.status not in (200, 206): + return 0, 0 + header_bytes = await resp.content.read(probe_size) + self.logger.debug("Probe read %d bytes for codec=%s", len(header_bytes), codec) + result = parser(header_bytes) + self.logger.debug("Probe result: sample_rate=%d, bit_depth=%d", *result) + return result + except asyncio.CancelledError: + raise + except Exception: + self.logger.debug("Stream probe failed for codec=%s", codec) + return 0, 0 + + async def _refresh_stream_url( self, - track_item_id: str, - current_url: str, - current_key_hex: str, + streamdetails: StreamDetails, http_status: int, bytes_yielded: int, attempt: int, max_retries: int, - ) -> tuple[str, str] | None: - """Re-fetch an expired encrypted stream URL. + ) -> bool: + """Re-fetch an expired stream URL (works for both raw and encraw). - Called when the CDN responds with 4xx (URL expired or access revoked). + Updates streamdetails.data in-place with new URL (and key for encraw). - :return: (new_url, new_key_hex) on success, or None if retries exhausted. + :return: True on success, False if retries exhausted. """ if attempt >= max_retries: - return None - raw_track_id = self._track_id_from_item_id(track_item_id) + return False + data = streamdetails.data + track_id = self._track_id_from_item_id(streamdetails.item_id) self.logger.warning( - "Encrypted stream URL expired (HTTP %d) at %d bytes (attempt %d/%d) — re-fetching", + "Stream URL expired (HTTP %d) at %d bytes (attempt %d/%d) — re-fetching", http_status, bytes_yielded, attempt + 1, @@ -362,12 +514,20 @@ async def _refresh_encrypted_url( ) token = BYPASS_THROTTLER.set(True) try: - file_info = await self.client.get_track_file_info_lossless(raw_track_id) + file_info = await self.client.get_track_file_info( + track_id, + quality=data["fi_quality"], + codecs=data["fi_codecs"], + transport=data.get("transport", TRANSPORT_RAW), + ) finally: BYPASS_THROTTLER.reset(token) if file_info and file_info.get("url"): - return file_info["url"], file_info.get("key", current_key_hex) - return None + data["url"] = file_info["url"] + if "decryption_key" in data and file_info.get("key"): + data["decryption_key"] = file_info["key"] + return True + return False async def _decrypt_response_stream( self, @@ -441,14 +601,14 @@ def _handle_stream_error( attempt += 1 if attempt <= max_retries: self.logger.warning( - "Encrypted stream %s at %d bytes (attempt %d/%d) — retrying", + "Stream %s at %d bytes (attempt %d/%d) — retrying", label, bytes_yielded, attempt, max_retries, ) return attempt, delay - raise MediaNotFoundError(f"Encrypted stream {label} after retries were exhausted") from err + raise MediaNotFoundError(f"Stream {label} after retries were exhausted") from err @staticmethod def _is_content_range_eof(headers: Any, window_end: int) -> bool: @@ -469,110 +629,217 @@ def _is_content_range_eof(headers: Any, window_end: int) -> bool: except ValueError: return False - async def get_audio_stream( - self, streamdetails: StreamDetails, seek_position: int = 0 + async def _iter_raw_response( + self, + response: Any, + bytes_delivered: int, + block_start: int, ) -> AsyncGenerator[bytes, None]: - """Return the audio stream for the provider item with on-the-fly decryption. + """Yield raw (unencrypted) chunks from one HTTP response. + + If the server ignored the Range header (200 instead of 206), skips the + already-delivered prefix transparently. - Downloads and decrypts the encrypted stream in windowed Range requests of - _RANGE_WINDOW bytes each. Yandex CDN drops TCP every ~6-7 MB per connection; - keeping each request at 4 MB prevents that limit from being reached. + :param response: aiohttp ClientResponse (open context manager). + :param bytes_delivered: Total bytes already sent to the caller. + :param block_start: Requested Range start offset. + :return: Async generator yielding raw audio bytes. + """ + range_ignored = response.status == 200 and block_start > 0 + skip_bytes = bytes_delivered if range_ignored else 0 + async for raw_chunk in response.content.iter_chunked(_CHUNK_SIZE): + if skip_bytes > 0: + if len(raw_chunk) <= skip_bytes: + skip_bytes -= len(raw_chunk) + continue + raw_chunk = raw_chunk[skip_bytes:] # noqa: PLW2901 + skip_bytes = 0 + if raw_chunk: + yield raw_chunk + + async def _handle_expired_url( + self, + streamdetails: StreamDetails, + response_status: int, + bytes_yielded: int, + attempt: int, + max_retries: int, + ) -> bytes | None: + """Handle URL expiry (401/403/410) by refreshing and returning updated key. - On connection drop (ClientPayloadError, ServerDisconnectedError), the current - window is retried with a flat short backoff (0.5s/1.0s/2.0s). - On read stall (asyncio.TimeoutError), the current window is retried with - exponential backoff (2s/4s/8s). - On URL expiry (HTTP 4xx), re-fetches the URL and resumes from bytes_yielded. - Up to max_retries retries per window; the retry counter resets on each - successful window so long tracks get the same protection as short ones. + :return: Updated AES key bytes (or empty bytes for raw), None if exhausted. + :raises MediaNotFoundError: When refresh fails after retries exhausted. + """ + if not await self._refresh_stream_url( + streamdetails, + response_status, + bytes_yielded, + attempt, + max_retries, + ): + raise MediaNotFoundError( + f"Stream URL expired (HTTP {response_status}) after retries exhausted" + ) + data = streamdetails.data + if "decryption_key" in data: + try: + return bytes.fromhex(data["decryption_key"]) + except ValueError as err: + raise MediaNotFoundError(f"Invalid decryption key: {err}") from err + return b"" - If the server ignores a Range header (returns 200 instead of 206), the decryptor - is reset to position 0 so decryption stays consistent with the restarted byte stream. + @staticmethod + def _validate_encryption_key(data: dict[str, Any]) -> tuple[bool, bytes | None]: + """Validate and extract encryption parameters from stream data. - :param streamdetails: Stream details containing encrypted URL and key. - :param seek_position: Always 0 (seeking delegated to ffmpeg via allow_seek=True). - :return: Async generator yielding decrypted audio bytes. + :return: (is_encrypted, key_bytes) tuple. + :raises MediaNotFoundError: If AES key length is invalid. """ - encrypted_url: str = streamdetails.data["encrypted_url"] - track_item_id: str = streamdetails.item_id - key_hex: str = streamdetails.data["decryption_key"] - key_bytes = bytes.fromhex(key_hex) + if "decryption_key" not in data: + return False, None + try: + key_bytes = bytes.fromhex(data["decryption_key"]) + except ValueError as err: + raise MediaNotFoundError(f"Invalid decryption key: {err}") from err if len(key_bytes) not in (16, 24, 32): raise MediaNotFoundError(f"Unsupported AES key length: {len(key_bytes)} bytes") + return True, key_bytes + + def _calculate_seek_offset( + self, data: dict[str, Any], seek_position: int, is_encrypted: bool + ) -> int: + """Calculate initial byte offset for raw transport seeking. + + :param data: Stream data dict (must contain 'bit_rate' in kbps). + :param seek_position: Seek offset in seconds. + :param is_encrypted: Whether the stream uses AES encryption. + :return: Byte offset to start streaming from (0 if not applicable). + """ + if seek_position <= 0 or is_encrypted: + return 0 + bit_rate = data.get("bit_rate") or 0 + if not bit_rate: + return 0 + byte_offset = seek_position * bit_rate * 1000 // 8 + self.logger.debug( + "Seeking to %ds: byte offset %d (bitrate %d kbps)", + seek_position, + byte_offset, + bit_rate, + ) + return byte_offset + + async def get_audio_stream( + self, streamdetails: StreamDetails, seek_position: int = 0 + ) -> AsyncGenerator[bytes, None]: + """Return the audio stream via windowed Range requests. + + Handles both raw (direct) and encraw (AES-CTR encrypted) transports. + Downloads in windowed Range requests of _RANGE_WINDOW bytes each to prevent + Yandex CDN from dropping slow-consumer TCP connections. + + On connection drop: flat short backoff (0.5s/1.0s/2.0s). + On read stall: exponential backoff (2s/4s/8s). + On URL expiry (HTTP 4xx): re-fetches URL and resumes from bytes_yielded. + Retry counter resets after each successful window. + + :param streamdetails: Stream details with URL (and optional decryption key). + :param seek_position: Seek offset in seconds for raw transport (0 = from start). + :return: Async generator yielding audio bytes. + """ + data = streamdetails.data + is_encrypted, key_bytes = self._validate_encryption_key(data) + initial_byte_offset = self._calculate_seek_offset(data, seek_position, is_encrypted) - block_size = 16 # AES-CTR block size in bytes max_retries = 6 - bytes_yielded = 0 # total decrypted bytes delivered to caller - attempt = 0 # retry counter; resets to 0 after each successful window + bytes_yielded = initial_byte_offset + attempt = 0 retry_delay: float = 0.0 while True: if attempt > 0: await asyncio.sleep(retry_delay) - block_start = (bytes_yielded // block_size) * block_size + block_start = ( + (bytes_yielded // _AES_BLOCK_SIZE) * _AES_BLOCK_SIZE + if is_encrypted + else bytes_yielded + ) window_end = block_start + _RANGE_WINDOW - 1 - headers = {"Range": f"bytes={block_start}-{window_end}"} try: async with self.mass.http_session.get( - encrypted_url, headers=headers, timeout=_STREAM_TIMEOUT + data["url"], + headers={"Range": f"bytes={block_start}-{window_end}"}, + timeout=_STREAM_TIMEOUT, ) as response: if response.status in (401, 403, 410): - # URL expired — re-fetch via helper and retry - refreshed = await self._refresh_encrypted_url( - track_item_id, - encrypted_url, - key_hex, + new_key = await self._handle_expired_url( + streamdetails, response.status, bytes_yielded, attempt, max_retries, ) - if refreshed is None: - raise MediaNotFoundError( - f"Encrypted stream URL expired (HTTP {response.status}) " - "after retries exhausted" - ) - encrypted_url, key_hex = refreshed - key_bytes = bytes.fromhex(key_hex) + if is_encrypted: + key_bytes = new_key + attempt += 1 retry_delay = 0.0 - attempt += 1 # consume one retry slot, same as TCP-drop path continue try: response.raise_for_status() except Exception as err: - raise MediaNotFoundError( - f"Failed to fetch encrypted stream: {err}" - ) from err + raise MediaNotFoundError(f"Failed to fetch stream: {err}") from err bytes_before = bytes_yielded - async for chunk in self._decrypt_response_stream( - response, key_bytes, block_size, bytes_yielded - ): - bytes_yielded += len(chunk) - yield chunk - - # window complete — check if EOF - window_got = bytes_yielded - bytes_before - if response.status == 200 or window_got < _RANGE_WINDOW: - return # full file received or last partial window - # Exact-boundary guard: if file size is an exact multiple of - # _RANGE_WINDOW the size check above won't catch EOF. - # Use Content-Range to confirm no bytes remain. + if is_encrypted: + if key_bytes is None: + raise MediaNotFoundError("Missing decryption key") + block_skip = bytes_before - block_start + async for chunk in self._decrypt_response_stream( + response, + key_bytes, + _AES_BLOCK_SIZE, + bytes_yielded, + ): + bytes_yielded += len(chunk) + yield chunk + else: + range_ignored = response.status == 200 and block_start > 0 + block_skip = bytes_before if range_ignored else 0 + async for chunk in self._iter_raw_response( + response, + bytes_before, + block_start, + ): + bytes_yielded += len(chunk) + yield chunk + + received = (bytes_yielded - bytes_before) + block_skip + if response.status == 200 or received < _RANGE_WINDOW: + return if self._is_content_range_eof(response.headers, window_end): return - # more data expected: advance to next window attempt = 0 retry_delay = 0.0 except asyncio.CancelledError: - raise # propagate cancellation immediately, do not retry + raise except (ClientPayloadError, ServerDisconnectedError) as err: attempt, retry_delay = self._handle_stream_error( - err, attempt, max_retries, bytes_yielded, _TCP_DROP_DELAYS, "dropped" + err, + attempt, + max_retries, + bytes_yielded, + _TCP_DROP_DELAYS, + "dropped", ) except TimeoutError as err: attempt, retry_delay = self._handle_stream_error( - err, attempt, max_retries, bytes_yielded, _STALL_DELAYS, "stalled" + err, + attempt, + max_retries, + bytes_yielded, + _STALL_DELAYS, + "stalled", ) diff --git a/tests/providers/kion_music/test_integration.py b/tests/providers/kion_music/test_integration.py deleted file mode 100644 index e3e969b726..0000000000 --- a/tests/providers/kion_music/test_integration.py +++ /dev/null @@ -1,354 +0,0 @@ -"""Integration tests for the KION Music provider with in-process Music Assistant.""" - -from __future__ import annotations - -import json -import pathlib -from collections.abc import AsyncGenerator -from typing import TYPE_CHECKING, Any, cast -from unittest import mock - -import pytest -from music_assistant_models.enums import ContentType, MediaType, StreamType -from yandex_music import Album as YandexAlbum -from yandex_music import Artist as YandexArtist -from yandex_music import Playlist as YandexPlaylist -from yandex_music import Track as YandexTrack - -from music_assistant.mass import MusicAssistant -from music_assistant.models.music_provider import MusicProvider -from tests.common import wait_for_sync_completion - -if TYPE_CHECKING: - from music_assistant_models.config_entries import ProviderConfig - -FIXTURES_DIR = pathlib.Path(__file__).parent / "fixtures" -_DE_JSON_CLIENT = type("ClientStub", (), {"report_unknown_fields": False})() - - -def _load_json(path: pathlib.Path) -> dict[str, Any]: - """Load JSON fixture.""" - with open(path) as f: - return cast("dict[str, Any]", json.load(f)) - - -def _load_kion_objects() -> tuple[Any, Any, Any, Any]: - """Load Artist, Album, Track, Playlist from fixtures for mock client.""" - artist = YandexArtist.de_json( - _load_json(FIXTURES_DIR / "artists" / "minimal.json"), _DE_JSON_CLIENT - ) - album = YandexAlbum.de_json( - _load_json(FIXTURES_DIR / "albums" / "minimal.json"), _DE_JSON_CLIENT - ) - track = YandexTrack.de_json( - _load_json(FIXTURES_DIR / "tracks" / "minimal.json"), _DE_JSON_CLIENT - ) - playlist = YandexPlaylist.de_json( - _load_json(FIXTURES_DIR / "playlists" / "minimal.json"), _DE_JSON_CLIENT - ) - return artist, album, track, playlist - - -def _make_search_result(track: Any, album: Any, artist: Any, playlist: Any) -> Any: - """Build a Search-like object with .tracks.results, .albums.results, etc.""" - return type( - "Search", - (), - { - "tracks": type("TracksResult", (), {"results": [track]})(), - "albums": type("AlbumsResult", (), {"results": [album]})(), - "artists": type("ArtistsResult", (), {"results": [artist]})(), - "playlists": type("PlaylistsResult", (), {"results": [playlist]})(), - }, - )() - - -def _make_download_info( - codec: str = "mp3", - direct_link: str = "https://example.com/kion_track.mp3", - bitrate_in_kbps: int = 320, -) -> Any: - """Build DownloadInfo-like object for streaming.""" - return type( - "DownloadInfo", - (), - { - "direct_link": direct_link, - "codec": codec, - "bitrate_in_kbps": bitrate_in_kbps, - }, - )() - - -@pytest.fixture -async def kion_music_provider( - mass: MusicAssistant, -) -> AsyncGenerator[ProviderConfig, None]: - """Configure KION Music provider with mocked API client and add to mass.""" - artist, album, track, playlist = _load_kion_objects() - search_result = _make_search_result(track, album, artist, playlist) - download_info = _make_download_info() - - # Album with volumes for get_album_tracks - album_with_volumes = type( - "AlbumWithVolumes", - (), - { - "id": album.id, - "title": album.title, - "volumes": [[track]], - "artists": album.artists if hasattr(album, "artists") else [], - "year": getattr(album, "year", None), - "release_date": getattr(album, "release_date", None), - "genre": getattr(album, "genre", None), - "cover_uri": getattr(album, "cover_uri", None), - "og_image": getattr(album, "og_image", None), - "type": getattr(album, "type", "album"), - "available": getattr(album, "available", True), - }, - )() - - with mock.patch( - "music_assistant.providers.kion_music.provider.KionMusicClient" - ) as mock_client_class: - mock_client = mock.AsyncMock() - mock_client_class.return_value = mock_client - - mock_client.connect = mock.AsyncMock(return_value=True) - mock_client.user_id = 12345 - - mock_client.get_liked_tracks = mock.AsyncMock(return_value=[]) - mock_client.get_liked_albums = mock.AsyncMock(return_value=[]) - mock_client.get_liked_artists = mock.AsyncMock(return_value=[]) - mock_client.get_user_playlists = mock.AsyncMock(return_value=[playlist]) - - mock_client.search = mock.AsyncMock(return_value=search_result) - mock_client.get_track = mock.AsyncMock(return_value=track) - mock_client.get_tracks = mock.AsyncMock(return_value=[track]) - mock_client.get_album = mock.AsyncMock(return_value=album) - mock_client.get_album_with_tracks = mock.AsyncMock(return_value=album_with_volumes) - mock_client.get_artist = mock.AsyncMock(return_value=artist) - mock_client.get_artist_albums = mock.AsyncMock(return_value=[album]) - mock_client.get_artist_tracks = mock.AsyncMock(return_value=[track]) - mock_client.get_playlist = mock.AsyncMock(return_value=playlist) - mock_client.get_track_download_info = mock.AsyncMock(return_value=[download_info]) - - async with wait_for_sync_completion(mass): - config = await mass.config.save_provider_config( - "kion_music", - {"token": "mock_kion_token", "quality": "high"}, - ) - await mass.music.start_sync() - - yield config - - -@pytest.fixture -async def kion_music_provider_lossless( - mass: MusicAssistant, -) -> AsyncGenerator[ProviderConfig, None]: - """Configure KION Music with quality=lossless and mock returning MP3 + FLAC.""" - artist, album, track, playlist = _load_kion_objects() - search_result = _make_search_result(track, album, artist, playlist) - mp3_info = _make_download_info( - codec="mp3", - direct_link="https://example.com/kion_track.mp3", - bitrate_in_kbps=320, - ) - flac_info = _make_download_info( - codec="flac", - direct_link="https://example.com/kion_track.flac", - bitrate_in_kbps=0, - ) - download_infos = [mp3_info, flac_info] - - album_with_volumes = type( - "AlbumWithVolumes", - (), - { - "id": album.id, - "title": album.title, - "volumes": [[track]], - "artists": album.artists if hasattr(album, "artists") else [], - "year": getattr(album, "year", None), - "release_date": getattr(album, "release_date", None), - "genre": getattr(album, "genre", None), - "cover_uri": getattr(album, "cover_uri", None), - "og_image": getattr(album, "og_image", None), - "type": getattr(album, "type", "album"), - "available": getattr(album, "available", True), - }, - )() - - with mock.patch( - "music_assistant.providers.kion_music.provider.KionMusicClient" - ) as mock_client_class: - mock_client = mock.AsyncMock() - mock_client_class.return_value = mock_client - - mock_client.connect = mock.AsyncMock(return_value=True) - mock_client.user_id = 12345 - - mock_client.get_liked_tracks = mock.AsyncMock(return_value=[]) - mock_client.get_liked_albums = mock.AsyncMock(return_value=[]) - mock_client.get_liked_artists = mock.AsyncMock(return_value=[]) - mock_client.get_user_playlists = mock.AsyncMock(return_value=[playlist]) - - mock_client.search = mock.AsyncMock(return_value=search_result) - mock_client.get_track = mock.AsyncMock(return_value=track) - mock_client.get_tracks = mock.AsyncMock(return_value=[track]) - mock_client.get_album = mock.AsyncMock(return_value=album) - mock_client.get_album_with_tracks = mock.AsyncMock(return_value=album_with_volumes) - mock_client.get_artist = mock.AsyncMock(return_value=artist) - mock_client.get_artist_albums = mock.AsyncMock(return_value=[album]) - mock_client.get_artist_tracks = mock.AsyncMock(return_value=[track]) - mock_client.get_playlist = mock.AsyncMock(return_value=playlist) - mock_client.get_track_file_info_lossless = mock.AsyncMock(return_value=None) - mock_client.get_track_download_info = mock.AsyncMock(return_value=download_infos) - - async with wait_for_sync_completion(mass): - config = await mass.config.save_provider_config( - "kion_music", - {"token": "mock_kion_token", "quality": "lossless"}, - ) - await mass.music.start_sync() - - yield config - - -def _get_kion_provider(mass: MusicAssistant) -> MusicProvider | None: - """Get KION Music provider instance from mass.""" - for provider in mass.music.providers: - if provider.domain == "kion_music": - return provider - return None - - -@pytest.mark.usefixtures("kion_music_provider") -async def test_registration_and_sync(mass: MusicAssistant) -> None: - """Test that provider is registered and sync completes.""" - prov = _get_kion_provider(mass) - assert prov is not None - assert prov.domain == "kion_music" - assert prov.instance_id - - -@pytest.mark.usefixtures("kion_music_provider") -async def test_search(mass: MusicAssistant) -> None: - """Test search returns results from kion_music.""" - results = await mass.music.search("test query", [MediaType.TRACK], limit=5) - kion_tracks = [t for t in results.tracks if t.provider and "kion_music" in t.provider] - assert len(kion_tracks) >= 0 - - -@pytest.mark.usefixtures("kion_music_provider") -async def test_get_artist(mass: MusicAssistant) -> None: - """Test getting artist by id.""" - prov = _get_kion_provider(mass) - assert prov is not None - artist = await prov.get_artist("100") - assert artist is not None - assert artist.name - assert artist.provider == prov.instance_id - assert artist.item_id == "100" - - -@pytest.mark.usefixtures("kion_music_provider") -async def test_get_album(mass: MusicAssistant) -> None: - """Test getting album by id.""" - prov = _get_kion_provider(mass) - assert prov is not None - album = await prov.get_album("300") - assert album is not None - assert album.name - assert album.provider == prov.instance_id - assert album.item_id == "300" - - -@pytest.mark.usefixtures("kion_music_provider") -async def test_get_track(mass: MusicAssistant) -> None: - """Test getting track by id.""" - prov = _get_kion_provider(mass) - assert prov is not None - track = await prov.get_track("400") - assert track is not None - assert track.name - assert track.provider == prov.instance_id - assert track.item_id == "400" - - -@pytest.mark.usefixtures("kion_music_provider") -async def test_get_album_tracks(mass: MusicAssistant) -> None: - """Test getting album tracks.""" - prov = _get_kion_provider(mass) - assert prov is not None - tracks = await prov.get_album_tracks("300") - assert isinstance(tracks, list) - assert len(tracks) >= 0 - - -@pytest.mark.usefixtures("kion_music_provider") -async def test_get_playlist_tracks(mass: MusicAssistant) -> None: - """Test getting playlist tracks.""" - prov = _get_kion_provider(mass) - assert prov is not None - tracks = await prov.get_playlist_tracks("12345:3", page=0) - assert isinstance(tracks, list) - assert len(tracks) >= 0 - - -@pytest.mark.usefixtures("kion_music_provider") -async def test_get_playlist_tracks_page_gt_zero_returns_empty(mass: MusicAssistant) -> None: - """Test that page > 0 returns empty list (no server-side pagination).""" - prov = _get_kion_provider(mass) - assert prov is not None - tracks = await prov.get_playlist_tracks("12345:3", page=1) - assert tracks == [] - - -@pytest.mark.usefixtures("kion_music_provider") -async def test_get_stream_details(mass: MusicAssistant) -> None: - """Test stream details retrieval.""" - prov = _get_kion_provider(mass) - assert prov is not None - stream_details = await prov.get_stream_details("400", MediaType.TRACK) - assert stream_details is not None - assert stream_details.stream_type == StreamType.HTTP - assert stream_details.path == "https://example.com/kion_track.mp3" - - -@pytest.mark.usefixtures("kion_music_provider_lossless") -async def test_get_stream_details_returns_flac_when_lossless_selected( - mass: MusicAssistant, -) -> None: - """When quality=lossless and API returns MP3+FLAC, stream details use FLAC.""" - prov = _get_kion_provider(mass) - assert prov is not None - stream_details = await prov.get_stream_details("400", MediaType.TRACK) - assert stream_details is not None - assert stream_details.audio_format.content_type == ContentType.FLAC - assert stream_details.path == "https://example.com/kion_track.flac" - - -@pytest.mark.usefixtures("kion_music_provider") -async def test_library_items(mass: MusicAssistant) -> None: - """Test library artists, albums, tracks, playlists.""" - prov = _get_kion_provider(mass) - assert prov is not None - instance_id = prov.instance_id - - artists = await mass.music.artists.library_items() - kion_artists = [a for a in artists if a.provider == instance_id] - assert len(kion_artists) >= 0 - - albums = await mass.music.albums.library_items() - kion_albums = [a for a in albums if a.provider == instance_id] - assert len(kion_albums) >= 0 - - tracks = await mass.music.tracks.library_items() - kion_tracks = [t for t in tracks if t.provider == instance_id] - assert len(kion_tracks) >= 0 - - playlists = await mass.music.playlists.library_items() - kion_playlists = [p for p in playlists if p.provider == instance_id] - assert len(kion_playlists) >= 0 diff --git a/tests/providers/yandex_music/conftest.py b/tests/providers/yandex_music/conftest.py index cada4cc804..c0ead409a8 100644 --- a/tests/providers/yandex_music/conftest.py +++ b/tests/providers/yandex_music/conftest.py @@ -32,6 +32,18 @@ def get_item_mapping(self, media_type: MediaType | str, key: str, name: str) -> ) +class ConfigStub: + """Minimal config stub for provider tests.""" + + def __init__(self, values: dict[str, object] | None = None) -> None: + """Initialize with optional config values.""" + self._values = values or {} + + def get_value(self, key: str, default: object = None) -> object: + """Return config value or default.""" + return self._values.get(key, default) + + class StreamingProviderStub: """Minimal provider stub for streaming tests (no Mock). @@ -46,6 +58,7 @@ def __init__(self) -> None: """Initialize stub with minimal client.""" self.client = type("ClientStub", (), {"user_id": 12345})() self.mass = type("MassStub", (), {})() + self.config = ConfigStub() self._warning_count = 0 def _count_warning(self, *args: object, **kwargs: object) -> None: @@ -93,6 +106,7 @@ def __init__(self) -> None: """Initialize stub with tracking logger.""" self.client = type("ClientStub", (), {"user_id": 12345})() self.mass = type("MassStub", (), {})() + self.config = ConfigStub() self.logger = TrackingLogger() diff --git a/tests/providers/yandex_music/fixtures/audiobooks/basic.json b/tests/providers/yandex_music/fixtures/audiobooks/basic.json new file mode 100644 index 0000000000..984b1fba9e --- /dev/null +++ b/tests/providers/yandex_music/fixtures/audiobooks/basic.json @@ -0,0 +1,25 @@ +{ + "id": 800, + "title": "Sample Audiobook", + "available": true, + "artists": [ + { + "id": 81, + "name": "Book Author" + }, + { + "id": 82, + "name": "Co-Author" + } + ], + "labels": [ + { + "id": 2, + "name": "Audio Publisher" + } + ], + "type": "audiobook", + "meta_type": "podcast", + "description": "A sample audiobook description.", + "cover_uri": "avatars.yandex.net/get-music-content/ab/cover/%%" +} diff --git a/tests/providers/yandex_music/fixtures/podcast_episodes/basic.json b/tests/providers/yandex_music/fixtures/podcast_episodes/basic.json new file mode 100644 index 0000000000..b3eeb6b564 --- /dev/null +++ b/tests/providers/yandex_music/fixtures/podcast_episodes/basic.json @@ -0,0 +1,24 @@ +{ + "id": 900, + "title": "Episode One", + "available": true, + "duration_ms": 1800000, + "artists": [ + { + "id": 71, + "name": "Podcast Author" + } + ], + "albums": [ + { + "id": 700, + "title": "Sample Podcast", + "type": "podcast", + "meta_type": "podcast", + "cover_uri": "avatars.yandex.net/get-music-content/pod/cover/%%" + } + ], + "short_description": "Episode summary goes here.", + "content_warning": "explicit", + "cover_uri": "avatars.yandex.net/get-music-content/pod/ep1/%%" +} diff --git a/tests/providers/yandex_music/fixtures/podcasts/basic.json b/tests/providers/yandex_music/fixtures/podcasts/basic.json new file mode 100644 index 0000000000..500208831e --- /dev/null +++ b/tests/providers/yandex_music/fixtures/podcasts/basic.json @@ -0,0 +1,25 @@ +{ + "id": 700, + "title": "Sample Podcast", + "available": true, + "artists": [ + { + "id": 71, + "name": "Podcast Author" + } + ], + "labels": [ + { + "id": 1, + "name": "Podcast Studio" + } + ], + "type": "podcast", + "meta_type": "podcast", + "track_count": 42, + "description": "A sample podcast description.", + "short_description": "Short desc", + "content_warning": "explicit", + "cover_uri": "avatars.yandex.net/get-music-content/pod/cover/%%", + "release_date": "2024-03-15T00:00:00+00:00" +} diff --git a/tests/providers/yandex_music/test_api_client.py b/tests/providers/yandex_music/test_api_client.py index 9e8dcb29b2..92379cbbd5 100644 --- a/tests/providers/yandex_music/test_api_client.py +++ b/tests/providers/yandex_music/test_api_client.py @@ -6,11 +6,14 @@ import hashlib import hmac import re +from collections.abc import Mapping +from typing import Any from unittest import mock import pytest -from music_assistant_models.errors import ResourceTemporarilyUnavailable -from yandex_music.exceptions import NetworkError +from music_assistant_models.errors import LoginFailed, ResourceTemporarilyUnavailable +from ya_passport_auth import SecretStr +from yandex_music.exceptions import NetworkError, UnauthorizedError from yandex_music.rotor.dashboard import Dashboard from yandex_music.rotor.station_result import StationResult from yandex_music.utils.sign_request import DEFAULT_SIGN_KEY @@ -29,7 +32,7 @@ def _make_client() -> tuple[YandexMusicClient, mock.AsyncMock]: :return: Tuple of (YandexMusicClient, mock_underlying_client). """ - client = YandexMusicClient(token="fake_token") + client = YandexMusicClient(token=SecretStr("fake_token")) mock_underlying = mock.AsyncMock() client._client = mock_underlying client._user_id = 12345 @@ -124,69 +127,442 @@ async def test_get_tracks_retry_on_network_error_both_fail() -> None: assert underlying.tracks.await_count == 2 -# -- get_my_wave_tracks -------------------------------------------------------- +async def test_send_rotor_station_feedback_track_started() -> None: + """send_rotor_station_feedback delegates trackStarted to public helper.""" + client, underlying = _make_client() + underlying.rotor_station_feedback_track_started = mock.AsyncMock(return_value=True) + + result = await client.send_rotor_station_feedback( + "user:onyourwave", + "trackStarted", + track_id="12345", + batch_id="batch_xyz", + ) + + assert result is True + underlying.rotor_station_feedback_track_started.assert_awaited_once() + args, kwargs = underlying.rotor_station_feedback_track_started.await_args + assert args[0] == "user:onyourwave" + assert kwargs["track_id"] == "12345" + assert kwargs["batch_id"] == "batch_xyz" + assert "timestamp" in kwargs + + +async def test_send_rotor_station_feedback_radio_started() -> None: + """send_rotor_station_feedback delegates radioStarted to public helper with from_.""" + client, underlying = _make_client() + underlying.rotor_station_feedback_radio_started = mock.AsyncMock(return_value=True) + + result = await client.send_rotor_station_feedback( + "user:onyourwave", + "radioStarted", + batch_id="batch_xyz", + ) + + assert result is True + underlying.rotor_station_feedback_radio_started.assert_awaited_once() + _, kwargs = underlying.rotor_station_feedback_radio_started.await_args + assert kwargs["from_"] == "YandexMusicDesktopAppWindows" + assert kwargs["batch_id"] == "batch_xyz" -async def test_get_my_wave_tracks_returns_tracks_and_batch_id() -> None: - """get_my_wave_tracks calls rotor_station_tracks and returns ordered tracks and batch_id.""" +async def test_send_rotor_station_feedback_track_finished() -> None: + """send_rotor_station_feedback delegates trackFinished with total_played_seconds.""" client, underlying = _make_client() + underlying.rotor_station_feedback_track_finished = mock.AsyncMock(return_value=True) + + result = await client.send_rotor_station_feedback( + "user:onyourwave", + "trackFinished", + track_id="12345", + total_played_seconds=42, + batch_id="batch_xyz", + ) + + assert result is True + underlying.rotor_station_feedback_track_finished.assert_awaited_once() + _, kwargs = underlying.rotor_station_feedback_track_finished.await_args + assert kwargs["track_id"] == "12345" + assert kwargs["total_played_seconds"] == 42.0 + assert kwargs["batch_id"] == "batch_xyz" + + +async def test_send_rotor_station_feedback_skip() -> None: + """send_rotor_station_feedback delegates skip to public helper.""" + client, underlying = _make_client() + underlying.rotor_station_feedback_skip = mock.AsyncMock(return_value=True) + + result = await client.send_rotor_station_feedback( + "user:onyourwave", + "skip", + track_id="12345", + total_played_seconds=10, + ) + + assert result is True + underlying.rotor_station_feedback_skip.assert_awaited_once() + _, kwargs = underlying.rotor_station_feedback_skip.await_args + assert kwargs["track_id"] == "12345" + assert kwargs["total_played_seconds"] == 10.0 + + +# -- rotor session API (/rotor/session/*) -------------------------------------- - seq_track = type("TrackShort", (), {"id": 100, "track_id": 100})() - sequence_item = type("SequenceItem", (), {"track": seq_track})() - result_obj = type( - "StationTracksResult", - (), - {"sequence": [sequence_item], "batch_id": "batch_abc"}, - )() - underlying.rotor_station_tracks = mock.AsyncMock(return_value=result_obj) - full_track = type("Track", (), {"id": 100, "title": "My Wave Track"})() - underlying.tracks = mock.AsyncMock(return_value=[full_track]) +def _patch_rotor_session_request(client: YandexMusicClient, response: object) -> mock.AsyncMock: + """Install a mocked _rotor_session_request on the client and return the mock.""" + req_mock = mock.AsyncMock(return_value=response) + client._rotor_session_request = req_mock # type: ignore[method-assign] + return req_mock - tracks, batch_id = await client.get_my_wave_tracks() - underlying.rotor_station_tracks.assert_awaited_once() - assert batch_id == "batch_abc" +def _patch_get_tracks(client: YandexMusicClient, tracks: list[object]) -> mock.AsyncMock: + """Install a mocked get_tracks on the client and return the mock.""" + tracks_mock = mock.AsyncMock(return_value=tracks) + client.get_tracks = tracks_mock # type: ignore[method-assign] + return tracks_mock + + +def _call_args(m: mock.AsyncMock) -> tuple[tuple[Any, ...], Mapping[str, Any]]: + """Return (args, kwargs) from the most recent await on ``m``. + + Raises AssertionError when the mock was never awaited — intentionally + surfacing missed setup rather than letting mypy's `None is not iterable` + propagate into destructuring sites. + """ + call = m.await_args + assert call is not None, "mock was not awaited" + return call.args, call.kwargs + + +async def test_rotor_session_new_posts_expected_body_and_returns_session() -> None: + """rotor_session_new POSTs to /rotor/session/new with wave-model flags and parses result.""" + client, underlying = _make_client() + del underlying # unused; session API bypasses MarshalX client + response = { + "radioSessionId": "sess_abc", + "batchId": "batch_1", + "sequence": [{"track": {"id": 100, "title": "T"}, "liked": False}], + } + req_mock = _patch_rotor_session_request(client, response) + _patch_get_tracks(client, [type("T", (), {"id": 100})()]) + + session_id, tracks, batch_id = await client.rotor_session_new("user:onyourwave") + + req_mock.assert_awaited_once() + args, _ = _call_args(req_mock) + path, body = args[0], args[1] + assert path == "new" + assert body["seeds"] == ["user:onyourwave"] + assert body["queue"] == [] + assert body["includeTracksInResponse"] is True + assert body["includeWaveModel"] is True + assert body["interactive"] is True + assert session_id == "sess_abc" + assert batch_id == "batch_1" assert len(tracks) == 1 assert tracks[0].id == 100 -async def test_get_my_wave_tracks_empty_sequence_returns_empty() -> None: - """When rotor returns no sequence, get_my_wave_tracks returns ([], batch_id or None).""" +async def test_rotor_session_new_appends_settings_as_seeds() -> None: + """rotor_session_new appends settingDiversity / settingMoodEnergy / settingLanguage seeds.""" client, underlying = _make_client() + del underlying + req_mock = _patch_rotor_session_request( + client, {"radioSessionId": "s1", "batchId": "b1", "sequence": []} + ) + _patch_get_tracks(client, []) + + await client.rotor_session_new( + "user:onyourwave", + settings={"diversity": "discover", "moodEnergy": "calm", "language": "russian"}, + ) + + args, _ = _call_args(req_mock) + body = args[1] + assert body["seeds"] == [ + "user:onyourwave", + "settingDiversity:discover", + "settingMoodEnergy:calm", + "settingLanguage:russian", + ] + - result_obj = type("StationTracksResult", (), {"sequence": [], "batch_id": None})() - underlying.rotor_station_tracks = mock.AsyncMock(return_value=result_obj) +async def test_rotor_session_new_returns_empty_on_missing_session_id() -> None: + """If the response lacks radioSessionId the call returns (None, [], None) without raising.""" + client, underlying = _make_client() + del underlying + _patch_rotor_session_request(client, None) - tracks, batch_id = await client.get_my_wave_tracks() + session_id, tracks, batch_id = await client.rotor_session_new("user:onyourwave") + assert session_id is None assert tracks == [] assert batch_id is None - underlying.tracks.assert_not_awaited() -async def test_send_rotor_station_feedback_posts() -> None: - """send_rotor_station_feedback POSTs to rotor feedback endpoint.""" +async def test_rotor_session_tracks_posts_current_track_queue() -> None: + """rotor_session_tracks POSTs {queue: [current_track_id]} and returns tracks + batch_id.""" client, underlying = _make_client() + del underlying + response = { + "batchId": "batch_2", + "sequence": [{"track": {"id": 200}}, {"track": {"id": 201}}], + } + req_mock = _patch_rotor_session_request(client, response) + _patch_get_tracks(client, [type("T", (), {"id": 200})(), type("T", (), {"id": 201})()]) + + tracks, batch_id = await client.rotor_session_tracks("sess_abc", current_track_id="100") + + args, _ = _call_args(req_mock) + path, body = args[0], args[1] + assert path == "sess_abc/tracks" + assert body == {"queue": ["100"]} + assert batch_id == "batch_2" + assert [t.id for t in tracks] == [200, 201] + + +async def test_rotor_session_feedback_radio_started_sends_from_field() -> None: + """RadioStarted event uses event.from=track_id (not trackId).""" + client, underlying = _make_client() + del underlying + req_mock = _patch_rotor_session_request(client, {"result": "ok"}) - underlying._request = mock.AsyncMock() - underlying.base_url = "https://api.music.yandex.net" - - result = await client.send_rotor_station_feedback( - "user:onyourwave", - "trackStarted", - track_id="12345", - batch_id="batch_xyz", + result = await client.rotor_session_feedback( + "sess_abc", "radioStarted", track_id="100", batch_id="batch_1" ) assert result is True - underlying._request.post.assert_awaited_once() - call_args = underlying._request.post.await_args - assert "rotor/station/user:onyourwave/feedback" in call_args[0][0] - body = call_args[0][1] - assert body["type"] == "trackStarted" - assert body["trackId"] == "12345" - assert body["batchId"] == "batch_xyz" + args, _ = _call_args(req_mock) + path, body = args[0], args[1] + assert path == "sess_abc/feedback" + assert body["batchId"] == "batch_1" + event = body["event"] + assert event["type"] == "radioStarted" + assert event["from"] == "100" + assert "trackId" not in event + assert "timestamp" in event + assert re.match(r"^\d{4}-\d{2}-\d{2}T", event["timestamp"]) + + +async def test_rotor_session_feedback_track_started_sends_track_id() -> None: + """TrackStarted event uses event.trackId (not from).""" + client, underlying = _make_client() + del underlying + req_mock = _patch_rotor_session_request(client, {"result": "ok"}) + + await client.rotor_session_feedback( + "sess_abc", "trackStarted", track_id="100", batch_id="batch_1" + ) + + args, _ = _call_args(req_mock) + body = args[1] + event = body["event"] + assert event["type"] == "trackStarted" + assert event["trackId"] == "100" + assert "from" not in event + assert "totalPlayedSeconds" not in event + + +async def test_rotor_session_feedback_track_finished_includes_seconds() -> None: + """TrackFinished event includes totalPlayedSeconds.""" + client, underlying = _make_client() + del underlying + req_mock = _patch_rotor_session_request(client, {"result": "ok"}) + + await client.rotor_session_feedback( + "sess_abc", + "trackFinished", + track_id="100", + total_played_seconds=42, + batch_id="batch_1", + ) + + args, _ = _call_args(req_mock) + body = args[1] + event = body["event"] + assert event["type"] == "trackFinished" + assert event["trackId"] == "100" + assert event["totalPlayedSeconds"] == 42 + + +async def test_rotor_session_feedback_skip_includes_seconds() -> None: + """Skip event includes totalPlayedSeconds and trackId.""" + client, underlying = _make_client() + del underlying + req_mock = _patch_rotor_session_request(client, {"result": "ok"}) + + await client.rotor_session_feedback( + "sess_abc", "skip", track_id="100", total_played_seconds=10, batch_id="batch_1" + ) + + args, _ = _call_args(req_mock) + body = args[1] + event = body["event"] + assert event["type"] == "skip" + assert event["trackId"] == "100" + assert event["totalPlayedSeconds"] == 10 + + +async def test_rotor_session_feedback_like_uses_trackid_without_seconds() -> None: + """like/dislike events use trackId but do NOT include totalPlayedSeconds.""" + client, underlying = _make_client() + del underlying + req_mock = _patch_rotor_session_request(client, {"result": "ok"}) + + await client.rotor_session_feedback("sess_abc", "like", track_id="100", batch_id="batch_1") + + args, _ = _call_args(req_mock) + body = args[1] + event = body["event"] + assert event["type"] == "like" + assert event["trackId"] == "100" + assert "totalPlayedSeconds" not in event + + +async def test_rotor_session_request_maps_unauthorized_to_login_failed() -> None: + """Expired/invalid token during /rotor/session/* surfaces as LoginFailed. + + Without this mapping the raw ``UnauthorizedError`` from the MarshalX + client would bubble up through browse / play paths and crash the + provider instead of triggering MA's re-auth prompt. + """ + client, underlying = _make_client() + # _do is awaited via _call_with_retry → _ensure_connected → returns our + # AsyncMock underlying client. The underlying client's ._request.post is + # what actually raises. + underlying._request = mock.MagicMock() + underlying._request.post = mock.AsyncMock(side_effect=UnauthorizedError("stale token")) + + with pytest.raises(LoginFailed): + await client._rotor_session_request("new", {"seeds": ["user:onyourwave"]}) + + +# -- get_similar_artists ------------------------------------------------------ + + +async def test_get_similar_artists_returns_list() -> None: + """get_similar_artists returns the similar_artists list from artists_similar().""" + client, underlying = _make_client() + similar = [type("Artist", (), {"id": i, "name": f"A{i}"})() for i in (1, 2, 3)] + result_obj = type("ArtistSimilar", (), {"similar_artists": similar})() + underlying.artists_similar = mock.AsyncMock(return_value=result_obj) + + result = await client.get_similar_artists("100") + + underlying.artists_similar.assert_awaited_once_with("100") + assert result == similar + + +async def test_get_similar_artists_respects_limit() -> None: + """get_similar_artists truncates results to the requested limit.""" + client, underlying = _make_client() + similar = [type("Artist", (), {"id": i})() for i in range(10)] + result_obj = type("ArtistSimilar", (), {"similar_artists": similar})() + underlying.artists_similar = mock.AsyncMock(return_value=result_obj) + + result = await client.get_similar_artists("100", limit=3) + + assert len(result) == 3 + assert [a.id for a in result] == [0, 1, 2] + + +async def test_get_similar_artists_handles_none_response() -> None: + """get_similar_artists returns [] when underlying call returns None.""" + client, underlying = _make_client() + underlying.artists_similar = mock.AsyncMock(return_value=None) + + result = await client.get_similar_artists("100") + + assert result == [] + + +async def test_get_similar_artists_handles_empty_field() -> None: + """get_similar_artists returns [] when similar_artists is empty/None.""" + client, underlying = _make_client() + result_obj = type("ArtistSimilar", (), {"similar_artists": None})() + underlying.artists_similar = mock.AsyncMock(return_value=result_obj) + + result = await client.get_similar_artists("100") + + assert result == [] + + +async def test_get_similar_artists_returns_empty_on_network_error() -> None: + """get_similar_artists returns [] when underlying raises NetworkError.""" + client, underlying = _make_client() + underlying.artists_similar = mock.AsyncMock( + side_effect=[NetworkError("timeout"), NetworkError("again")] + ) + + result = await client.get_similar_artists("100") + + assert result == [] + + +# -- get_pins / get_music_history / get_artist_about ------------------------- + + +async def test_get_pins_returns_list_object() -> None: + """get_pins forwards the underlying pins() result.""" + client, underlying = _make_client() + pins_obj = type("PinsList", (), {"pins": [type("Pin", (), {"type": "album_item"})()]})() + underlying.pins = mock.AsyncMock(return_value=pins_obj) + + result = await client.get_pins() + + underlying.pins.assert_awaited_once_with() + assert result is pins_obj + + +async def test_get_pins_returns_none_on_network_error() -> None: + """get_pins returns None when retries are exhausted.""" + client, underlying = _make_client() + underlying.pins = mock.AsyncMock(side_effect=NetworkError("boom")) + + result = await client.get_pins() + + assert result is None + + +async def test_get_music_history_returns_object() -> None: + """get_music_history forwards the underlying music_history() result.""" + client, underlying = _make_client() + history = type("MusicHistory", (), {"history_tabs": []})() + underlying.music_history = mock.AsyncMock(return_value=history) + + result = await client.get_music_history() + + underlying.music_history.assert_awaited_once_with() + assert result is history + + +async def test_get_music_history_returns_none_on_network_error() -> None: + """get_music_history returns None on persistent NetworkError.""" + client, underlying = _make_client() + underlying.music_history = mock.AsyncMock(side_effect=NetworkError("boom")) + + assert await client.get_music_history() is None + + +async def test_get_artist_about_returns_object() -> None: + """get_artist_about forwards the underlying artists_about() result.""" + client, underlying = _make_client() + about = type("ArtistAbout", (), {"description": "x", "stats": None})() + underlying.artists_about = mock.AsyncMock(return_value=about) + + result = await client.get_artist_about("42") + + underlying.artists_about.assert_awaited_once_with("42") + assert result is about + + +async def test_get_artist_about_returns_none_on_network_error() -> None: + """get_artist_about returns None on persistent NetworkError.""" + client, underlying = _make_client() + underlying.artists_about = mock.AsyncMock(side_effect=NetworkError("boom")) + + assert await client.get_artist_about("42") is None # -- LRC regex tests --------------------------------------------------------- @@ -361,6 +737,41 @@ async def test_get_dashboard_stations_returns_personalized_stations() -> None: underlying.rotor_stations_dashboard.assert_called_once() +# -- get_track_file_info: response key normalization ------------------------- + + +async def test_get_track_file_info_parses_camelcase_download_info() -> None: + """get_track_file_info parses the v3-style camelCase ``downloadInfo`` key. + + The yandex-music v3 client no longer recursively normalises camelCase keys + inside ``Response.result``. The raw JSON for /get-file-info comes back as + ``{"downloadInfo": {...}}`` — the provider must accept both shapes. + """ + client, underlying = _make_client() + + raw_response = { + "downloadInfo": { + "trackId": "132401416", + "quality": "lossless", + "codec": "flac-mp4", + "bitrate": 0, + "transport": "raw", + "url": "https://example.com/flac-mp4.bin", + "realId": "132401416", + } + } + underlying._request = mock.MagicMock() + underlying._request.get = mock.AsyncMock(return_value=raw_response) + underlying.base_url = "https://api.music.yandex.net" + + result = await client.get_track_file_info("132401416") + + assert result is not None + assert result["url"] == "https://example.com/flac-mp4.bin" + assert result["codec"] == "flac-mp4" + assert result["needs_decryption"] is False + + async def test_get_dashboard_stations_empty_on_error() -> None: """get_dashboard_stations() returns empty list on network error.""" client, underlying = _make_client() diff --git a/tests/providers/yandex_music/test_audiobook_progress.py b/tests/providers/yandex_music/test_audiobook_progress.py new file mode 100644 index 0000000000..eb009eede3 --- /dev/null +++ b/tests/providers/yandex_music/test_audiobook_progress.py @@ -0,0 +1,311 @@ +"""Tests for audiobook progress sync via play_audio in on_played/on_streamed.""" + +from __future__ import annotations + +import asyncio +from unittest.mock import AsyncMock, Mock + +import pytest +from music_assistant_models.enums import MediaType +from music_assistant_models.errors import ResourceTemporarilyUnavailable +from music_assistant_models.streamdetails import StreamDetails + +from music_assistant.providers.yandex_music.provider import YandexMusicProvider + + +@pytest.fixture +def provider_mock() -> Mock: + """Return a provider mock wired for audiobook progress reporting.""" + provider = Mock(spec=YandexMusicProvider) + provider.domain = "yandex_music" + provider.instance_id = "yandex_music_instance" + provider.logger = Mock() + provider.client = AsyncMock() + provider.client.play_audio = AsyncMock(return_value=True) + provider._audiobook_chapter_cache = {} + provider._audiobook_play_ids = {} + # real method so we don't have to replicate seek math in tests + provider._resolve_audiobook_seek = YandexMusicProvider._resolve_audiobook_seek.__get__( + provider, YandexMusicProvider + ) + provider._audiobook_progress_point = YandexMusicProvider._audiobook_progress_point.__get__( + provider, YandexMusicProvider + ) + provider._resolve_audiobook_chapter_map = AsyncMock() + return provider + + +@pytest.mark.asyncio +async def test_on_played_audiobook_reports_progress(provider_mock: Mock) -> None: + """on_played(AUDIOBOOK) resolves chapter from cache and calls play_audio.""" + chapter_ids = ["c1", "c2", "c3"] + chapter_durations_ms = [60_000, 120_000, 90_000] + provider_mock._resolve_audiobook_chapter_map.return_value = ( + chapter_ids, + chapter_durations_ms, + ) + + # position = 60 (end of c1) + 30 → middle of c2 + await YandexMusicProvider._report_audiobook_progress(provider_mock, "abook-42", 90) + + provider_mock.client.play_audio.assert_awaited_once() + kwargs = provider_mock.client.play_audio.await_args.kwargs + assert kwargs["track_id"] == "c2" + assert kwargs["album_id"] == "abook-42" + assert kwargs["track_length_seconds"] == 120 + assert kwargs["total_played_seconds"] == 30 + assert kwargs["end_position_seconds"] == 30 + # play_id created and persisted for the session + assert provider_mock._audiobook_play_ids["abook-42"] == kwargs["play_id"] + + +@pytest.mark.asyncio +async def test_on_played_audiobook_skips_when_no_chapter_map(provider_mock: Mock) -> None: + """No chapter map → skip play_audio (log at debug).""" + provider_mock._resolve_audiobook_chapter_map.return_value = ([], []) + + await YandexMusicProvider._report_audiobook_progress(provider_mock, "abook-42", 90) + + provider_mock.client.play_audio.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_on_played_audiobook_reuses_play_id(provider_mock: Mock) -> None: + """Successive on_played calls for the same audiobook share one play_id.""" + provider_mock._resolve_audiobook_chapter_map.return_value = (["c1"], [60_000]) + + await YandexMusicProvider._report_audiobook_progress(provider_mock, "abook-1", 10) + await YandexMusicProvider._report_audiobook_progress(provider_mock, "abook-1", 20) + + calls = provider_mock.client.play_audio.await_args_list + assert calls[0].kwargs["play_id"] == calls[1].kwargs["play_id"] + + +@pytest.mark.asyncio +async def test_on_streamed_audiobook_sends_final_position(provider_mock: Mock) -> None: + """on_streamed uses seek_position + seconds_streamed as absolute position.""" + sd = StreamDetails( + provider="yandex_music", + item_id="abook-7", + audio_format=Mock(), + media_type=MediaType.AUDIOBOOK, + data={ + "chapter_ids": ["c1", "c2", "c3"], + "chapter_durations_ms": [60_000, 120_000, 600_000], + }, + ) + sd.seek_position = 300 # 5 minutes in — inside c3 (starts at 180s, 10min long) + sd.seconds_streamed = 15.0 # stopped at 315s total + # pre-stash a play_id so we can assert it pops + provider_mock._audiobook_play_ids["abook-7"] = "stable-id" + + await YandexMusicProvider._report_audiobook_final(provider_mock, sd, sd.data) + + provider_mock.client.play_audio.assert_awaited_once() + kwargs = provider_mock.client.play_audio.await_args.kwargs + assert kwargs["track_id"] == "c3" + assert kwargs["album_id"] == "abook-7" + # absolute 315s → chapter 3 (starts at 180s) offset 135s + assert kwargs["total_played_seconds"] == 135 + assert kwargs["end_position_seconds"] == 135 + assert kwargs["play_id"] == "stable-id" + # session play_id cleared after final report + assert "abook-7" not in provider_mock._audiobook_play_ids + + +@pytest.mark.asyncio +async def test_on_streamed_audiobook_ignores_empty_data(provider_mock: Mock) -> None: + """Missing chapter_ids in data → no play_audio call.""" + sd = StreamDetails( + provider="yandex_music", + item_id="abook-9", + audio_format=Mock(), + media_type=MediaType.AUDIOBOOK, + data={}, + ) + + await YandexMusicProvider._report_audiobook_final(provider_mock, sd, sd.data) + + provider_mock.client.play_audio.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_on_played_audiobook_swallows_upstream_unavailable( + provider_mock: Mock, +) -> None: + """ResourceTemporarilyUnavailable from chapter_map resolution must not propagate.""" + provider_mock._resolve_audiobook_chapter_map.side_effect = ResourceTemporarilyUnavailable( + "rate limited" + ) + + # Must not raise; progress report is advisory. + await YandexMusicProvider._report_audiobook_progress(provider_mock, "abook-42", 30) + + provider_mock.client.play_audio.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_on_played_audiobook_swallows_unexpected_exception( + provider_mock: Mock, +) -> None: + """Any non-cancellation exception while resolving the chapter map is swallowed.""" + + class SomeAuthError(Exception): + pass + + provider_mock._resolve_audiobook_chapter_map.side_effect = SomeAuthError("token expired") + + await YandexMusicProvider._report_audiobook_progress(provider_mock, "abook-42", 30) + + provider_mock.client.play_audio.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_on_played_audiobook_propagates_cancellation( + provider_mock: Mock, +) -> None: + """asyncio.CancelledError must propagate — never suppressed.""" + provider_mock._resolve_audiobook_chapter_map.side_effect = asyncio.CancelledError() + + with pytest.raises(asyncio.CancelledError): + await YandexMusicProvider._report_audiobook_progress(provider_mock, "abook-42", 30) + + +@pytest.mark.asyncio +async def test_on_streamed_audiobook_evicts_cache_entry(provider_mock: Mock) -> None: + """After on_streamed, the chapter-map cache entry for this book is dropped.""" + sd = StreamDetails( + provider="yandex_music", + item_id="abook-7", + audio_format=Mock(), + media_type=MediaType.AUDIOBOOK, + data={"chapter_ids": ["c1"], "chapter_durations_ms": [60_000]}, + ) + provider_mock._audiobook_chapter_cache["abook-7"] = (["c1"], [60_000]) + provider_mock._audiobook_chapter_cache["abook-OTHER"] = (["x"], [1000]) + + await YandexMusicProvider._report_audiobook_final(provider_mock, sd, sd.data) + + assert "abook-7" not in provider_mock._audiobook_chapter_cache + # Other audiobooks in cache are untouched + assert "abook-OTHER" in provider_mock._audiobook_chapter_cache + + +@pytest.mark.asyncio +async def test_on_streamed_audiobook_reports_end_of_last_chapter_at_eof( + provider_mock: Mock, +) -> None: + """Natural EOF must report end_position_seconds = last chapter's length, not 0.""" + sd = StreamDetails( + provider="yandex_music", + item_id="abook-eof", + audio_format=Mock(), + media_type=MediaType.AUDIOBOOK, + data={ + "chapter_ids": ["c1", "c2"], + "chapter_durations_ms": [60_000, 120_000], + }, + ) + # Total duration = 180s. Reached exactly the end. + sd.seek_position = 0 + sd.seconds_streamed = 180.0 + + await YandexMusicProvider._report_audiobook_final(provider_mock, sd, sd.data) + + provider_mock.client.play_audio.assert_awaited_once() + kwargs = provider_mock.client.play_audio.await_args.kwargs + assert kwargs["track_id"] == "c2" + assert kwargs["track_length_seconds"] == 120 + assert kwargs["end_position_seconds"] == 120 + assert kwargs["total_played_seconds"] == 120 + + +@pytest.mark.asyncio +async def test_on_streamed_audiobook_clamps_zero_duration_chapter( + provider_mock: Mock, +) -> None: + """Missing duration_ms (coerced to 0) must never send track_length_seconds=0.""" + sd = StreamDetails( + provider="yandex_music", + item_id="abook-bad", + audio_format=Mock(), + media_type=MediaType.AUDIOBOOK, + data={"chapter_ids": ["c1"], "chapter_durations_ms": [0]}, + ) + sd.seek_position = 0 + sd.seconds_streamed = 0.0 + + await YandexMusicProvider._report_audiobook_final(provider_mock, sd, sd.data) + + kwargs = provider_mock.client.play_audio.await_args.kwargs + assert kwargs["track_length_seconds"] >= 1 + assert 0 <= kwargs["end_position_seconds"] <= kwargs["track_length_seconds"] + + +@pytest.mark.asyncio +async def test_on_streamed_audiobook_without_data_still_cleans_up( + provider_mock: Mock, +) -> None: + """StreamDetails.data missing or stripped → caches still evicted, no play_audio call.""" + sd = StreamDetails( + provider="yandex_music", + item_id="abook-x", + audio_format=Mock(), + media_type=MediaType.AUDIOBOOK, + ) + provider_mock._audiobook_chapter_cache["abook-x"] = (["c"], [1000]) + provider_mock._audiobook_play_ids["abook-x"] = "sess-id" + + await YandexMusicProvider._report_audiobook_final(provider_mock, sd, {}) + + assert "abook-x" not in provider_mock._audiobook_chapter_cache + assert "abook-x" not in provider_mock._audiobook_play_ids + provider_mock.client.play_audio.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_on_streamed_audiobook_branch_taken_even_without_chapter_data( + provider_mock: Mock, +) -> None: + """AUDIOBOOK stream without chapter_ids still routes to audiobook cleanup. + + Previously the gate required ``"chapter_ids" in data`` and fell through + to the radio path when data was missing, leaving caches stale. + """ + provider_mock._report_audiobook_final = AsyncMock() + provider_mock.client.send_rotor_station_feedback = AsyncMock() + sd = StreamDetails( + provider="yandex_music", + item_id="abook-y", + audio_format=Mock(), + media_type=MediaType.AUDIOBOOK, + # data=None (not a dict) — previously would fall through to radio + ) + + await YandexMusicProvider.on_streamed(provider_mock, sd) + + provider_mock._report_audiobook_final.assert_awaited_once() + provider_mock.client.send_rotor_station_feedback.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_on_played_routes_audiobook_and_skips_radio_branch( + provider_mock: Mock, +) -> None: + """on_played with MediaType.AUDIOBOOK early-returns before radio feedback.""" + provider_mock._resolve_audiobook_chapter_map.return_value = (["c1"], [60_000]) + provider_mock._report_audiobook_progress = AsyncMock() + provider_mock.client.send_rotor_station_feedback = AsyncMock() + + await YandexMusicProvider.on_played( + provider_mock, + MediaType.AUDIOBOOK, + "abook-1", + fully_played=False, + position=30, + media_item=Mock(), + is_playing=True, + ) + + provider_mock._report_audiobook_progress.assert_awaited_once_with("abook-1", 30) + provider_mock.client.send_rotor_station_feedback.assert_not_awaited() diff --git a/tests/providers/yandex_music/test_auth.py b/tests/providers/yandex_music/test_auth.py new file mode 100644 index 0000000000..e55daf122a --- /dev/null +++ b/tests/providers/yandex_music/test_auth.py @@ -0,0 +1,765 @@ +"""Unit tests for auth.py (ya-passport-auth QR + Device Flow).""" + +from __future__ import annotations + +import asyncio +import json +from collections.abc import Awaitable, Callable, Generator +from typing import TYPE_CHECKING +from unittest import mock + +import pytest +from music_assistant_models.errors import LoginFailed, ResourceTemporarilyUnavailable +from ya_passport_auth import Credentials, DeviceCodeSession, QrSession, SecretStr +from ya_passport_auth.exceptions import ( + DeviceCodeTimeoutError, + InvalidCredentialsError, + QRTimeoutError, + RateLimitedError, +) +from ya_passport_auth.exceptions import ( + NetworkError as PassportNetworkError, +) + +from music_assistant.providers.yandex_music.auth import ( + perform_device_auth, + perform_qr_auth, + refresh_credentials_via_passport, + refresh_music_token, + validate_x_token, +) + +if TYPE_CHECKING: + from aiohttp import web + + +@pytest.fixture(autouse=True) +def skip_grace_sleep() -> Generator[mock.AsyncMock, None, None]: + """Bypass the post-auth grace ``asyncio.sleep`` so tests run instantly.""" + with mock.patch( + "music_assistant.providers.yandex_music.auth.asyncio.sleep", + new=mock.AsyncMock(), + ) as patched: + yield patched + + +# -- helpers ------------------------------------------------------------------- + + +def _make_device_session( + user_code: str = "ABCD-1234", + verification_url: str = "https://oauth.yandex.ru/device", + interval: int = 1, + expires_in: int = 600, +) -> DeviceCodeSession: + """Build a DeviceCodeSession for testing.""" + return DeviceCodeSession( + device_code=SecretStr("dev-code-xyz"), + user_code=user_code, + verification_url=verification_url, + expires_in=expires_in, + interval=interval, + ) + + +def _make_credentials( + x_token: str = "test_x_token", # noqa: S107 + music_token: str | None = "test_music_token", # noqa: S107 + refresh_token: str | None = "test_refresh_token", # noqa: S107 +) -> Credentials: + """Build a Credentials dataclass for testing.""" + return Credentials( + x_token=SecretStr(x_token), + music_token=SecretStr(music_token) if music_token else None, + refresh_token=SecretStr(refresh_token) if refresh_token else None, + ) + + +def _make_qr_session() -> QrSession: + """Build a QrSession for testing.""" + return QrSession( + track_id="track123", + csrf_token="csrf_abc", + qr_url="https://passport.yandex.ru/auth/magic/code/?track_id=track123", + ) + + +# -- perform_device_auth ------------------------------------------------------- + + +async def test_perform_device_auth_returns_three_tokens() -> None: + """Device flow returns (x_token, music_token, refresh_token).""" + session = _make_device_session() + creds = _make_credentials() + mock_client = mock.AsyncMock() + mock_client.start_device_login.return_value = session + mock_client.poll_device_until_confirmed.return_value = creds + + mock_mass = mock.MagicMock() + mock_mass.webserver.base_url = "http://ma.local:8095" + mock_auth_helper = mock.AsyncMock() + + with ( + mock.patch( + "music_assistant.providers.yandex_music.auth.PassportClient.create", + ) as mock_create, + mock.patch( + "music_assistant.providers.yandex_music.auth.AuthenticationHelper", + return_value=mock_auth_helper, + ), + ): + mock_create.return_value.__aenter__ = mock.AsyncMock(return_value=mock_client) + mock_create.return_value.__aexit__ = mock.AsyncMock(return_value=False) + + x_token, music_token, refresh_token = await perform_device_auth(mock_mass, "session_1") + + assert x_token == "test_x_token" + assert music_token == "test_music_token" + assert refresh_token == "test_refresh_token" + mock_client.start_device_login.assert_awaited_once() + mock_client.poll_device_until_confirmed.assert_awaited_once_with(session) + + +async def test_perform_device_auth_serves_intermediate_page_and_cleans_up() -> None: + """A temporary HTML page + status endpoint are registered and unregistered after.""" + session = _make_device_session( + user_code="WXYZ-9999", + verification_url="https://oauth.yandex.ru/device", + ) + creds = _make_credentials() + mock_client = mock.AsyncMock() + mock_client.start_device_login.return_value = session + mock_client.poll_device_until_confirmed.return_value = creds + + mock_mass = mock.MagicMock() + mock_mass.webserver.base_url = "http://ma.local:8095" + mock_auth_helper = mock.AsyncMock() + + with ( + mock.patch( + "music_assistant.providers.yandex_music.auth.PassportClient.create", + ) as mock_create, + mock.patch( + "music_assistant.providers.yandex_music.auth.AuthenticationHelper", + return_value=mock_auth_helper, + ), + ): + mock_create.return_value.__aenter__ = mock.AsyncMock(return_value=mock_client) + mock_create.return_value.__aexit__ = mock.AsyncMock(return_value=False) + + await perform_device_auth(mock_mass, "session_1") + + expected_path = "/yandex_music/device_code/session_1" + expected_status_path = f"{expected_path}/status" + + registered_paths = [ + (c.args[0], c.args[2]) for c in mock_mass.webserver.register_dynamic_route.call_args_list + ] + assert (expected_path, "GET") in registered_paths + assert (expected_status_path, "GET") in registered_paths + + unregistered_paths = [ + c.args for c in mock_mass.webserver.unregister_dynamic_route.call_args_list + ] + assert (expected_path, "GET") in unregistered_paths + assert (expected_status_path, "GET") in unregistered_paths + + mock_auth_helper.__aenter__.return_value.send_url.assert_called_once_with( + f"http://ma.local:8095{expected_path}" + ) + + +async def test_perform_device_auth_status_endpoint_reports_done_after_success() -> None: + """The status endpoint reports state=done after the device flow completes. + + Without this the popup window (opened via target=_blank) has no signal to + close itself after the user confirms the code. + """ + session = _make_device_session() + creds = _make_credentials() + mock_client = mock.AsyncMock() + mock_client.start_device_login.return_value = session + mock_client.poll_device_until_confirmed.return_value = creds + + mock_mass = mock.MagicMock() + mock_mass.webserver.base_url = "http://ma.local:8095" + mock_auth_helper = mock.AsyncMock() + + with ( + mock.patch( + "music_assistant.providers.yandex_music.auth.PassportClient.create", + ) as mock_create, + mock.patch( + "music_assistant.providers.yandex_music.auth.AuthenticationHelper", + return_value=mock_auth_helper, + ), + ): + mock_create.return_value.__aenter__ = mock.AsyncMock(return_value=mock_client) + mock_create.return_value.__aexit__ = mock.AsyncMock(return_value=False) + + await perform_device_auth(mock_mass, "session_xyz") + + status_call = next( + c + for c in mock_mass.webserver.register_dynamic_route.call_args_list + if c.args[0].endswith("/status") + ) + status_handler = status_call.args[1] + response = await status_handler(mock.MagicMock()) + assert isinstance(response.body, bytes) + payload = json.loads(response.body) + assert payload["state"] == "done" + + +async def test_perform_device_auth_status_reports_failed_on_error( + skip_grace_sleep: mock.AsyncMock, +) -> None: + """When poll fails, status endpoint reports failed and grace sleep still fires. + + Otherwise the page would race with route teardown and only ever see 404s + instead of the 'failed' message. + """ + session = _make_device_session() + mock_client = mock.AsyncMock() + mock_client.start_device_login.return_value = session + mock_client.poll_device_until_confirmed.side_effect = DeviceCodeTimeoutError("expired") + + mock_mass = mock.MagicMock() + mock_mass.webserver.base_url = "http://ma.local:8095" + mock_auth_helper = mock.AsyncMock() + + status_handlers: list[Callable[[web.Request], Awaitable[web.Response]]] = [] + + def _capture( + path: str, + handler: Callable[[web.Request], Awaitable[web.Response]], + _method: str, + ) -> None: + if path.endswith("/status"): + status_handlers.append(handler) + + mock_mass.webserver.register_dynamic_route.side_effect = _capture + + with ( + mock.patch( + "music_assistant.providers.yandex_music.auth.PassportClient.create", + ) as mock_create, + mock.patch( + "music_assistant.providers.yandex_music.auth.AuthenticationHelper", + return_value=mock_auth_helper, + ), + ): + mock_create.return_value.__aenter__ = mock.AsyncMock(return_value=mock_client) + mock_create.return_value.__aexit__ = mock.AsyncMock(return_value=False) + + with pytest.raises(LoginFailed, match="timed out"): + await perform_device_auth(mock_mass, "session_fail") + + assert status_handlers, "status handler should have been registered" + response = await status_handlers[0](mock.MagicMock()) + assert isinstance(response.body, bytes) + payload = json.loads(response.body) + assert payload["state"] == "failed" + # Grace sleep must fire on failure so the page can observe "failed" before teardown. + skip_grace_sleep.assert_awaited() + + +async def test_perform_device_auth_does_not_mark_cancellation_as_failure( + skip_grace_sleep: mock.AsyncMock, +) -> None: + """CancelledError must propagate without marking state as 'failed' or sleeping.""" + session = _make_device_session() + mock_client = mock.AsyncMock() + mock_client.start_device_login.return_value = session + mock_client.poll_device_until_confirmed.side_effect = asyncio.CancelledError() + + mock_mass = mock.MagicMock() + mock_mass.webserver.base_url = "http://ma.local:8095" + mock_auth_helper = mock.AsyncMock() + + status_handlers: list[Callable[[web.Request], Awaitable[web.Response]]] = [] + + def _capture( + path: str, + handler: Callable[[web.Request], Awaitable[web.Response]], + _method: str, + ) -> None: + if path.endswith("/status"): + status_handlers.append(handler) + + mock_mass.webserver.register_dynamic_route.side_effect = _capture + + with ( + mock.patch( + "music_assistant.providers.yandex_music.auth.PassportClient.create", + ) as mock_create, + mock.patch( + "music_assistant.providers.yandex_music.auth.AuthenticationHelper", + return_value=mock_auth_helper, + ), + ): + mock_create.return_value.__aenter__ = mock.AsyncMock(return_value=mock_client) + mock_create.return_value.__aexit__ = mock.AsyncMock(return_value=False) + + with pytest.raises(asyncio.CancelledError): + await perform_device_auth(mock_mass, "session_cancel") + + assert status_handlers, "status handler should have been registered" + response = await status_handlers[0](mock.MagicMock()) + assert isinstance(response.body, bytes) + payload = json.loads(response.body) + assert payload["state"] == "pending" + skip_grace_sleep.assert_not_awaited() + + +async def test_perform_device_auth_route_handler_renders_code_and_url() -> None: + """The registered route handler returns HTML containing the code + verification URL.""" + session = _make_device_session( + user_code="ABCD-1234", + verification_url="https://oauth.yandex.ru/device", + ) + creds = _make_credentials() + mock_client = mock.AsyncMock() + mock_client.start_device_login.return_value = session + mock_client.poll_device_until_confirmed.return_value = creds + + mock_mass = mock.MagicMock() + mock_mass.webserver.base_url = "http://ma.local:8095" + mock_auth_helper = mock.AsyncMock() + + with ( + mock.patch( + "music_assistant.providers.yandex_music.auth.PassportClient.create", + ) as mock_create, + mock.patch( + "music_assistant.providers.yandex_music.auth.AuthenticationHelper", + return_value=mock_auth_helper, + ), + ): + mock_create.return_value.__aenter__ = mock.AsyncMock(return_value=mock_client) + mock_create.return_value.__aexit__ = mock.AsyncMock(return_value=False) + + await perform_device_auth(mock_mass, "session_1") + + page_call = next( + c + for c in mock_mass.webserver.register_dynamic_route.call_args_list + if not c.args[0].endswith("/status") + ) + handler = page_call.args[1] + response = await handler(mock.MagicMock()) + body = response.text + assert body is not None + assert "ABCD-1234" in body + assert "https://oauth.yandex.ru/device" in body + assert response.content_type == "text/html" + + +async def test_perform_device_auth_timeout_raises_login_failed() -> None: + """DeviceCodeTimeoutError from library is mapped to LoginFailed and the route is freed.""" + session = _make_device_session() + mock_client = mock.AsyncMock() + mock_client.start_device_login.return_value = session + mock_client.poll_device_until_confirmed.side_effect = DeviceCodeTimeoutError("expired") + + mock_mass = mock.MagicMock() + mock_mass.webserver.base_url = "http://ma.local:8095" + mock_auth_helper = mock.AsyncMock() + + with ( + mock.patch( + "music_assistant.providers.yandex_music.auth.PassportClient.create", + ) as mock_create, + mock.patch( + "music_assistant.providers.yandex_music.auth.AuthenticationHelper", + return_value=mock_auth_helper, + ), + ): + mock_create.return_value.__aenter__ = mock.AsyncMock(return_value=mock_client) + mock_create.return_value.__aexit__ = mock.AsyncMock(return_value=False) + + with pytest.raises(LoginFailed, match="timed out"): + await perform_device_auth(mock_mass, "session_1") + + unregistered_paths = [ + c.args for c in mock_mass.webserver.unregister_dynamic_route.call_args_list + ] + assert ("/yandex_music/device_code/session_1", "GET") in unregistered_paths + assert ("/yandex_music/device_code/session_1/status", "GET") in unregistered_paths + + +async def test_perform_device_auth_ya_passport_error_raises_login_failed() -> None: + """Generic YaPassportError from library is mapped to LoginFailed.""" + mock_client = mock.AsyncMock() + mock_client.start_device_login.side_effect = PassportNetworkError("offline") + + mock_mass = mock.MagicMock() + mock_auth_helper = mock.AsyncMock() + + with ( + mock.patch( + "music_assistant.providers.yandex_music.auth.PassportClient.create", + ) as mock_create, + mock.patch( + "music_assistant.providers.yandex_music.auth.AuthenticationHelper", + return_value=mock_auth_helper, + ), + ): + mock_create.return_value.__aenter__ = mock.AsyncMock(return_value=mock_client) + mock_create.return_value.__aexit__ = mock.AsyncMock(return_value=False) + + with pytest.raises(LoginFailed, match="device auth error"): + await perform_device_auth(mock_mass, "session_1") + + +async def test_perform_device_auth_no_music_token_raises_login_failed() -> None: + """Credentials without music_token raises LoginFailed.""" + session = _make_device_session() + creds = _make_credentials(music_token=None) + mock_client = mock.AsyncMock() + mock_client.start_device_login.return_value = session + mock_client.poll_device_until_confirmed.return_value = creds + + mock_mass = mock.MagicMock() + mock_auth_helper = mock.AsyncMock() + + with ( + mock.patch( + "music_assistant.providers.yandex_music.auth.PassportClient.create", + ) as mock_create, + mock.patch( + "music_assistant.providers.yandex_music.auth.AuthenticationHelper", + return_value=mock_auth_helper, + ), + ): + mock_create.return_value.__aenter__ = mock.AsyncMock(return_value=mock_client) + mock_create.return_value.__aexit__ = mock.AsyncMock(return_value=False) + + with pytest.raises(LoginFailed, match="no music token"): + await perform_device_auth(mock_mass, "session_1") + + +async def test_perform_device_auth_no_refresh_token_raises_login_failed() -> None: + """Credentials without refresh_token raises LoginFailed.""" + session = _make_device_session() + creds = _make_credentials(refresh_token=None) + mock_client = mock.AsyncMock() + mock_client.start_device_login.return_value = session + mock_client.poll_device_until_confirmed.return_value = creds + + mock_mass = mock.MagicMock() + mock_auth_helper = mock.AsyncMock() + + with ( + mock.patch( + "music_assistant.providers.yandex_music.auth.PassportClient.create", + ) as mock_create, + mock.patch( + "music_assistant.providers.yandex_music.auth.AuthenticationHelper", + return_value=mock_auth_helper, + ), + ): + mock_create.return_value.__aenter__ = mock.AsyncMock(return_value=mock_client) + mock_create.return_value.__aexit__ = mock.AsyncMock(return_value=False) + + with pytest.raises(LoginFailed, match="no refresh token"): + await perform_device_auth(mock_mass, "session_1") + + +# -- perform_qr_auth ---------------------------------------------------------- + + +async def test_perform_qr_auth_success() -> None: + """QR flow returns (x_token, music_token) as plain strings.""" + qr = _make_qr_session() + creds = _make_credentials() + mock_client = mock.AsyncMock() + mock_client.start_qr_login.return_value = qr + mock_client.poll_qr_until_confirmed.return_value = creds + + mock_mass = mock.MagicMock() + mock_auth_helper = mock.AsyncMock() + + with ( + mock.patch( + "music_assistant.providers.yandex_music.auth.PassportClient.create", + ) as mock_create, + mock.patch( + "music_assistant.providers.yandex_music.auth.AuthenticationHelper", + return_value=mock_auth_helper, + ), + ): + mock_create.return_value.__aenter__ = mock.AsyncMock(return_value=mock_client) + mock_create.return_value.__aexit__ = mock.AsyncMock(return_value=False) + + x_token, music_token = await perform_qr_auth(mock_mass, "session_1") + + assert x_token == "test_x_token" + assert music_token == "test_music_token" + mock_client.start_qr_login.assert_awaited_once() + mock_client.poll_qr_until_confirmed.assert_awaited_once_with(qr) + + +async def test_perform_qr_auth_sends_qr_url() -> None: + """QR URL is sent to the AuthenticationHelper.""" + qr = _make_qr_session() + creds = _make_credentials() + mock_client = mock.AsyncMock() + mock_client.start_qr_login.return_value = qr + mock_client.poll_qr_until_confirmed.return_value = creds + + mock_mass = mock.MagicMock() + mock_auth_helper = mock.AsyncMock() + + with ( + mock.patch( + "music_assistant.providers.yandex_music.auth.PassportClient.create", + ) as mock_create, + mock.patch( + "music_assistant.providers.yandex_music.auth.AuthenticationHelper", + return_value=mock_auth_helper, + ), + ): + mock_create.return_value.__aenter__ = mock.AsyncMock(return_value=mock_client) + mock_create.return_value.__aexit__ = mock.AsyncMock(return_value=False) + + await perform_qr_auth(mock_mass, "session_1") + + mock_auth_helper.__aenter__.return_value.send_url.assert_called_once_with(qr.qr_url) + + +async def test_perform_qr_auth_timeout_raises_login_failed() -> None: + """QRTimeoutError from library is mapped to LoginFailed.""" + qr = _make_qr_session() + mock_client = mock.AsyncMock() + mock_client.start_qr_login.return_value = qr + mock_client.poll_qr_until_confirmed.side_effect = QRTimeoutError("timed out") + + mock_mass = mock.MagicMock() + mock_auth_helper = mock.AsyncMock() + + with ( + mock.patch( + "music_assistant.providers.yandex_music.auth.PassportClient.create", + ) as mock_create, + mock.patch( + "music_assistant.providers.yandex_music.auth.AuthenticationHelper", + return_value=mock_auth_helper, + ), + ): + mock_create.return_value.__aenter__ = mock.AsyncMock(return_value=mock_client) + mock_create.return_value.__aexit__ = mock.AsyncMock(return_value=False) + + with pytest.raises(LoginFailed, match="timed out"): + await perform_qr_auth(mock_mass, "session_1") + + +async def test_perform_qr_auth_passport_error_raises_login_failed() -> None: + """Generic YaPassportError is mapped to LoginFailed.""" + mock_client = mock.AsyncMock() + mock_client.start_qr_login.side_effect = PassportNetworkError("connection lost") + + mock_mass = mock.MagicMock() + mock_auth_helper = mock.AsyncMock() + + with ( + mock.patch( + "music_assistant.providers.yandex_music.auth.PassportClient.create", + ) as mock_create, + mock.patch( + "music_assistant.providers.yandex_music.auth.AuthenticationHelper", + return_value=mock_auth_helper, + ), + ): + mock_create.return_value.__aenter__ = mock.AsyncMock(return_value=mock_client) + mock_create.return_value.__aexit__ = mock.AsyncMock(return_value=False) + + with pytest.raises(LoginFailed, match="Yandex auth error"): + await perform_qr_auth(mock_mass, "session_1") + + +async def test_perform_qr_auth_no_music_token_raises() -> None: + """Credentials without music_token raises LoginFailed.""" + qr = _make_qr_session() + creds = _make_credentials(music_token=None) + mock_client = mock.AsyncMock() + mock_client.start_qr_login.return_value = qr + mock_client.poll_qr_until_confirmed.return_value = creds + + mock_mass = mock.MagicMock() + mock_auth_helper = mock.AsyncMock() + + with ( + mock.patch( + "music_assistant.providers.yandex_music.auth.PassportClient.create", + ) as mock_create, + mock.patch( + "music_assistant.providers.yandex_music.auth.AuthenticationHelper", + return_value=mock_auth_helper, + ), + ): + mock_create.return_value.__aenter__ = mock.AsyncMock(return_value=mock_client) + mock_create.return_value.__aexit__ = mock.AsyncMock(return_value=False) + + with pytest.raises(LoginFailed, match="no music token"): + await perform_qr_auth(mock_mass, "session_1") + + +# -- refresh_music_token ------------------------------------------------------- + + +async def test_refresh_music_token_success() -> None: + """Successful refresh returns a SecretStr.""" + mock_client = mock.AsyncMock() + mock_client.refresh_music_token.return_value = SecretStr("new_music_token") + + with mock.patch( + "music_assistant.providers.yandex_music.auth.PassportClient.create", + ) as mock_create: + mock_create.return_value.__aenter__ = mock.AsyncMock(return_value=mock_client) + mock_create.return_value.__aexit__ = mock.AsyncMock(return_value=False) + + result = await refresh_music_token(SecretStr("my_x_token")) + + assert result.get_secret() == "new_music_token" + mock_client.refresh_music_token.assert_awaited_once() + + +async def test_refresh_music_token_auth_error_raises_login_failed() -> None: + """Auth failure during refresh is mapped to LoginFailed.""" + mock_client = mock.AsyncMock() + mock_client.refresh_music_token.side_effect = InvalidCredentialsError("bad token") + + with mock.patch( + "music_assistant.providers.yandex_music.auth.PassportClient.create", + ) as mock_create: + mock_create.return_value.__aenter__ = mock.AsyncMock(return_value=mock_client) + mock_create.return_value.__aexit__ = mock.AsyncMock(return_value=False) + + with pytest.raises(LoginFailed, match="Failed to refresh"): + await refresh_music_token(SecretStr("bad_x_token")) + + +@pytest.mark.parametrize( + "exc", + [PassportNetworkError("offline"), RateLimitedError("429")], + ids=["network", "rate_limited"], +) +async def test_refresh_music_token_transient_error_raises_temporarily_unavailable( + exc: Exception, +) -> None: + """Transient Passport failures don't masquerade as LoginFailed.""" + mock_client = mock.AsyncMock() + mock_client.refresh_music_token.side_effect = exc + + with mock.patch( + "music_assistant.providers.yandex_music.auth.PassportClient.create", + ) as mock_create: + mock_create.return_value.__aenter__ = mock.AsyncMock(return_value=mock_client) + mock_create.return_value.__aexit__ = mock.AsyncMock(return_value=False) + + with pytest.raises(ResourceTemporarilyUnavailable, match="temporarily unavailable"): + await refresh_music_token(SecretStr("my_x_token")) + + +# -- validate_x_token ---------------------------------------------------------- + + +async def test_validate_x_token_valid() -> None: + """Valid x_token returns True.""" + mock_client = mock.AsyncMock() + mock_client.validate_x_token.return_value = True + + with mock.patch( + "music_assistant.providers.yandex_music.auth.PassportClient.create", + ) as mock_create: + mock_create.return_value.__aenter__ = mock.AsyncMock(return_value=mock_client) + mock_create.return_value.__aexit__ = mock.AsyncMock(return_value=False) + + result = await validate_x_token(SecretStr("good_token")) + + assert result is True + + +async def test_validate_x_token_error_returns_false() -> None: + """Any YaPassportError returns False (graceful degradation).""" + mock_client = mock.AsyncMock() + mock_client.validate_x_token.side_effect = RateLimitedError("429") + + with mock.patch( + "music_assistant.providers.yandex_music.auth.PassportClient.create", + ) as mock_create: + mock_create.return_value.__aenter__ = mock.AsyncMock(return_value=mock_client) + mock_create.return_value.__aexit__ = mock.AsyncMock(return_value=False) + + result = await validate_x_token(SecretStr("some_token")) + + assert result is False + + +# -- refresh_credentials_via_passport ------------------------------------------ + + +async def test_refresh_credentials_via_passport_success() -> None: + """Successful refresh returns full Credentials triple.""" + new_creds = _make_credentials( + x_token="new_x", + music_token="new_music", + refresh_token="new_refresh", + ) + mock_client = mock.AsyncMock() + mock_client.refresh_credentials.return_value = new_creds + + with mock.patch( + "music_assistant.providers.yandex_music.auth.PassportClient.create", + ) as mock_create: + mock_create.return_value.__aenter__ = mock.AsyncMock(return_value=mock_client) + mock_create.return_value.__aexit__ = mock.AsyncMock(return_value=False) + + result = await refresh_credentials_via_passport( + SecretStr("old_x"), SecretStr("old_refresh") + ) + + assert result.x_token.get_secret() == "new_x" + assert result.music_token is not None + assert result.music_token.get_secret() == "new_music" + assert result.refresh_token is not None + assert result.refresh_token.get_secret() == "new_refresh" + mock_client.refresh_credentials.assert_awaited_once() + + +async def test_refresh_credentials_via_passport_error_raises_login_failed() -> None: + """Auth failure during credential refresh is mapped to LoginFailed.""" + mock_client = mock.AsyncMock() + mock_client.refresh_credentials.side_effect = InvalidCredentialsError("dead") + + with mock.patch( + "music_assistant.providers.yandex_music.auth.PassportClient.create", + ) as mock_create: + mock_create.return_value.__aenter__ = mock.AsyncMock(return_value=mock_client) + mock_create.return_value.__aexit__ = mock.AsyncMock(return_value=False) + + with pytest.raises(LoginFailed, match="Failed to refresh credentials"): + await refresh_credentials_via_passport(SecretStr("bad_x"), SecretStr("bad_refresh")) + + +@pytest.mark.parametrize( + "exc", + [PassportNetworkError("offline"), RateLimitedError("429")], + ids=["network", "rate_limited"], +) +async def test_refresh_credentials_via_passport_transient_error_raises_temporarily_unavailable( + exc: Exception, +) -> None: + """Transient Passport failures don't masquerade as LoginFailed.""" + mock_client = mock.AsyncMock() + mock_client.refresh_credentials.side_effect = exc + + with mock.patch( + "music_assistant.providers.yandex_music.auth.PassportClient.create", + ) as mock_create: + mock_create.return_value.__aenter__ = mock.AsyncMock(return_value=mock_client) + mock_create.return_value.__aexit__ = mock.AsyncMock(return_value=False) + + with pytest.raises(ResourceTemporarilyUnavailable, match="temporarily unavailable"): + await refresh_credentials_via_passport(SecretStr("x"), SecretStr("refresh")) diff --git a/tests/providers/yandex_music/test_browse_collection.py b/tests/providers/yandex_music/test_browse_collection.py new file mode 100644 index 0000000000..f05f769bb4 --- /dev/null +++ b/tests/providers/yandex_music/test_browse_collection.py @@ -0,0 +1,108 @@ +"""Tests that Collection folder renders audiobooks/podcasts sub-folders.""" + +from __future__ import annotations + +from unittest.mock import Mock + +import pytest +from music_assistant_models.enums import ProviderFeature +from music_assistant_models.media_items import BrowseFolder + +from music_assistant.providers.yandex_music.constants import BROWSE_NAMES_EN, BROWSE_NAMES_RU +from music_assistant.providers.yandex_music.provider import YandexMusicProvider + + +def _make_provider_mock(features: set[ProviderFeature], *, locale: str = "en_US") -> Mock: + provider = Mock(spec=YandexMusicProvider) + provider.instance_id = "yandex_music_instance" + provider.domain = "yandex_music" + provider.supported_features = features + provider.mass = Mock() + provider.mass.metadata = Mock() + provider.mass.metadata.locale = locale + # real method so locale mapping runs + provider._get_browse_names = YandexMusicProvider._get_browse_names.__get__( + provider, YandexMusicProvider + ) + provider.logger = Mock() + return provider + + +@pytest.mark.asyncio +async def test_collection_shows_audiobooks_folder_when_feature_enabled() -> None: + """LIBRARY_AUDIOBOOKS enabled → BrowseFolder for audiobooks is returned.""" + features = { + ProviderFeature.LIBRARY_TRACKS, + ProviderFeature.LIBRARY_ALBUMS, + ProviderFeature.LIBRARY_AUDIOBOOKS, + } + provider = _make_provider_mock(features) + + folders = await YandexMusicProvider._browse_collection( + provider, "yandex_music_instance://collection" + ) + + item_ids = [f.item_id for f in folders if isinstance(f, BrowseFolder)] + assert "audiobooks" in item_ids + audiobook_folder = next( + f for f in folders if isinstance(f, BrowseFolder) and f.item_id == "audiobooks" + ) + assert audiobook_folder.is_playable is False + assert audiobook_folder.path.endswith("audiobooks") + assert audiobook_folder.name == BROWSE_NAMES_EN["audiobooks"] + + +@pytest.mark.asyncio +async def test_collection_shows_podcasts_folder_when_feature_enabled() -> None: + """LIBRARY_PODCASTS enabled → BrowseFolder for podcasts is returned.""" + features = { + ProviderFeature.LIBRARY_TRACKS, + ProviderFeature.LIBRARY_PODCASTS, + } + provider = _make_provider_mock(features) + + folders = await YandexMusicProvider._browse_collection( + provider, "yandex_music_instance://collection" + ) + + item_ids = [f.item_id for f in folders if isinstance(f, BrowseFolder)] + assert "podcasts" in item_ids + + +@pytest.mark.asyncio +async def test_collection_hides_audiobooks_folder_when_feature_disabled() -> None: + """Disabling LIBRARY_AUDIOBOOKS removes the folder from Collection.""" + features = { + ProviderFeature.LIBRARY_TRACKS, + ProviderFeature.LIBRARY_ALBUMS, + } + provider = _make_provider_mock(features) + + folders = await YandexMusicProvider._browse_collection( + provider, "yandex_music_instance://collection" + ) + + item_ids = [f.item_id for f in folders if isinstance(f, BrowseFolder)] + assert "audiobooks" not in item_ids + assert "podcasts" not in item_ids + + +@pytest.mark.asyncio +async def test_collection_audiobooks_folder_russian_locale() -> None: + """Russian locale uses Russian folder names.""" + features = { + ProviderFeature.LIBRARY_AUDIOBOOKS, + ProviderFeature.LIBRARY_PODCASTS, + } + provider = _make_provider_mock(features, locale="ru_RU") + + folders = await YandexMusicProvider._browse_collection( + provider, "yandex_music_instance://collection" + ) + + audiobook = next( + f for f in folders if isinstance(f, BrowseFolder) and f.item_id == "audiobooks" + ) + podcast = next(f for f in folders if isinstance(f, BrowseFolder) and f.item_id == "podcasts") + assert audiobook.name == BROWSE_NAMES_RU["audiobooks"] + assert podcast.name == BROWSE_NAMES_RU["podcasts"] diff --git a/tests/providers/yandex_music/test_browse_pins_history.py b/tests/providers/yandex_music/test_browse_pins_history.py new file mode 100644 index 0000000000..381e35a092 --- /dev/null +++ b/tests/providers/yandex_music/test_browse_pins_history.py @@ -0,0 +1,217 @@ +"""Tests for the Pins and Listening History browse handlers.""" + +from __future__ import annotations + +from unittest.mock import AsyncMock, Mock, patch + +import pytest +from music_assistant_models.errors import InvalidDataError, MediaNotFoundError +from music_assistant_models.media_items import Album, Artist, Playlist, Track + +from music_assistant.providers.yandex_music.provider import YandexMusicProvider + + +@pytest.fixture +def provider_mock() -> Mock: + """Return a mock Yandex Music provider with cache + client stubs.""" + provider = Mock(spec=YandexMusicProvider) + provider.domain = "yandex_music" + provider.instance_id = "yandex_music_instance" + provider.logger = Mock() + provider.client = AsyncMock() + provider.client.user_id = 12345 + provider.mass = Mock() + provider.mass.cache = AsyncMock() + provider.mass.cache.get = AsyncMock(return_value=None) + provider.mass.cache.set = AsyncMock() + return provider + + +@pytest.mark.asyncio +async def test_browse_pins_returns_empty_when_no_pins(provider_mock: Mock) -> None: + """_browse_pins returns [] when client returns None.""" + provider_mock.client.get_pins = AsyncMock(return_value=None) + + result = await YandexMusicProvider._browse_pins(provider_mock) + + assert result == [] + + +@pytest.mark.asyncio +async def test_browse_pins_returns_empty_when_pins_field_missing( + provider_mock: Mock, +) -> None: + """_browse_pins returns [] when PinsList.pins is None.""" + provider_mock.client.get_pins = AsyncMock(return_value=type("PinsList", (), {"pins": None})()) + + result = await YandexMusicProvider._browse_pins(provider_mock) + + assert result == [] + + +@pytest.mark.asyncio +async def test_browse_pins_resolves_artist_album_playlist(provider_mock: Mock) -> None: + """_browse_pins routes each pin type to the corresponding lookup.""" + artist_pin = type( + "Pin", + (), + {"type": "artist_item", "data": type("D", (), {"id": 11})()}, + )() + album_pin = type( + "Pin", + (), + {"type": "album_item", "data": type("D", (), {"id": 22})()}, + )() + playlist_pin = type( + "Pin", + (), + {"type": "playlist_item", "data": type("D", (), {"uid": 33, "kind": 44})()}, + )() + pins = type("PinsList", (), {"pins": [artist_pin, album_pin, playlist_pin]})() + provider_mock.client.get_pins = AsyncMock(return_value=pins) + + artist = Mock(spec=Artist) + album = Mock(spec=Album) + playlist = Mock(spec=Playlist) + provider_mock.get_artist = AsyncMock(return_value=artist) + provider_mock.get_album = AsyncMock(return_value=album) + provider_mock.get_playlist = AsyncMock(return_value=playlist) + + result = await YandexMusicProvider._browse_pins(provider_mock) + + provider_mock.get_artist.assert_awaited_once_with("11") + provider_mock.get_album.assert_awaited_once_with("22") + provider_mock.get_playlist.assert_awaited_once_with("33:44") + assert result == [artist, album, playlist] + + +@pytest.mark.asyncio +async def test_browse_pins_skips_wave_and_missing_data(provider_mock: Mock) -> None: + """_browse_pins ignores wave pins and pins with missing data.""" + wave_pin = type( + "Pin", + (), + {"type": "wave_item", "data": type("D", (), {})()}, + )() + bad_pin = type("Pin", (), {"type": "album_item", "data": None})() + pins = type("PinsList", (), {"pins": [wave_pin, bad_pin]})() + provider_mock.client.get_pins = AsyncMock(return_value=pins) + provider_mock.get_album = AsyncMock() + + result = await YandexMusicProvider._browse_pins(provider_mock) + + assert result == [] + provider_mock.get_album.assert_not_called() + + +@pytest.mark.asyncio +async def test_browse_pins_skips_lookup_errors(provider_mock: Mock) -> None: + """_browse_pins survives MediaNotFoundError during single-item lookups.""" + album_pin = type( + "Pin", + (), + {"type": "album_item", "data": type("D", (), {"id": 22})()}, + )() + pins = type("PinsList", (), {"pins": [album_pin]})() + provider_mock.client.get_pins = AsyncMock(return_value=pins) + provider_mock.get_album = AsyncMock(side_effect=MediaNotFoundError("gone")) + + result = await YandexMusicProvider._browse_pins(provider_mock) + + assert result == [] + + +@pytest.mark.asyncio +async def test_browse_history_returns_empty_when_no_history( + provider_mock: Mock, +) -> None: + """_browse_history returns [] when client returns None.""" + provider_mock.client.get_music_history = AsyncMock(return_value=None) + + result = await YandexMusicProvider._browse_history(provider_mock) + + assert result == [] + + +def _hist_item(track_id: int) -> object: + """Build a history entry the way MarshalX actually returns it. + + `data.item_id` is a dict containing track_id, album_id, etc.; `full_model` + is not populated by the live API. Callers batch-resolve via get_tracks. + """ + data = type("D", (), {"item_id": {"track_id": str(track_id)}, "full_model": None})() + return type("HistItem", (), {"type": "track", "data": data})() + + +@pytest.mark.asyncio +async def test_browse_history_flattens_and_deduplicates(provider_mock: Mock) -> None: + """_browse_history flattens days→groups→tracks, de-dupes by track id, preserves order.""" + group1 = type("Group", (), {"tracks": [_hist_item(1), _hist_item(2)]})() + group2 = type("Group", (), {"tracks": [_hist_item(2), _hist_item(3)]})() # dup id=2 + tab1 = type("Tab", (), {"items": [group1]})() + tab2 = type("Tab", (), {"items": [group2]})() + history = type("MusicHistory", (), {"history_tabs": [tab1, tab2]})() + provider_mock.client.get_music_history = AsyncMock(return_value=history) + + # Batch-hydrate returns the yandex tracks in their own order; the provider + # re-orders them to match the de-duplicated id list. + yt1 = type("Yt", (), {"id": 1})() + yt2 = type("Yt", (), {"id": 2})() + yt3 = type("Yt", (), {"id": 3})() + provider_mock.client.get_tracks = AsyncMock(return_value=[yt3, yt1, yt2]) + + parsed = [Mock(spec=Track, name="p1"), Mock(spec=Track, name="p2"), Mock(spec=Track, name="p3")] + with patch( + "music_assistant.providers.yandex_music.provider.parse_track", + side_effect=parsed, + ): + result = await YandexMusicProvider._browse_history(provider_mock) + + provider_mock.client.get_tracks.assert_awaited_once_with(["1", "2", "3"]) + assert result == parsed + + +@pytest.mark.asyncio +async def test_browse_history_skips_non_track_items(provider_mock: Mock) -> None: + """_browse_history ignores items with type != 'track'.""" + album_item = type( + "HistItem", + (), + { + "type": "album", + "data": type("D", (), {"item_id": {"track_id": "99"}})(), + }, + )() + group = type("Group", (), {"tracks": [album_item]})() + tab = type("Tab", (), {"items": [group]})() + history = type("MusicHistory", (), {"history_tabs": [tab]})() + provider_mock.client.get_music_history = AsyncMock(return_value=history) + provider_mock.client.get_tracks = AsyncMock() + + with patch("music_assistant.providers.yandex_music.provider.parse_track") as parse_track: + result = await YandexMusicProvider._browse_history(provider_mock) + parse_track.assert_not_called() + + assert result == [] + # No IDs collected → no hydration round-trip at all + provider_mock.client.get_tracks.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_browse_history_skips_invalid_track(provider_mock: Mock) -> None: + """_browse_history drops tracks where parse_track raises InvalidDataError.""" + group = type("Group", (), {"tracks": [_hist_item(1)]})() + tab = type("Tab", (), {"items": [group]})() + history = type("MusicHistory", (), {"history_tabs": [tab]})() + provider_mock.client.get_music_history = AsyncMock(return_value=history) + + yt1 = type("Yt", (), {"id": 1})() + provider_mock.client.get_tracks = AsyncMock(return_value=[yt1]) + + with patch( + "music_assistant.providers.yandex_music.provider.parse_track", + side_effect=InvalidDataError("nope"), + ): + result = await YandexMusicProvider._browse_history(provider_mock) + + assert result == [] diff --git a/tests/providers/yandex_music/test_integration.py b/tests/providers/yandex_music/test_integration.py deleted file mode 100644 index a984f25eda..0000000000 --- a/tests/providers/yandex_music/test_integration.py +++ /dev/null @@ -1,446 +0,0 @@ -"""Integration tests for the Yandex Music provider with in-process Music Assistant.""" - -from __future__ import annotations - -import json -import pathlib -from collections.abc import AsyncGenerator -from typing import TYPE_CHECKING, Any, cast -from unittest import mock - -import pytest -from music_assistant_models.enums import ContentType, MediaType, StreamType -from music_assistant_models.errors import ResourceTemporarilyUnavailable -from yandex_music import Album as YandexAlbum -from yandex_music import Artist as YandexArtist -from yandex_music import Playlist as YandexPlaylist -from yandex_music import Track as YandexTrack - -from music_assistant.mass import MusicAssistant -from music_assistant.models.music_provider import MusicProvider -from music_assistant.providers.yandex_music.constants import BROWSE_NAMES_EN, BROWSE_NAMES_RU -from tests.common import wait_for_sync_completion - -if TYPE_CHECKING: - from music_assistant_models.config_entries import ProviderConfig - -FIXTURES_DIR = pathlib.Path(__file__).parent / "fixtures" -_DE_JSON_CLIENT = type("ClientStub", (), {"report_unknown_fields": False})() - - -def _load_json(path: pathlib.Path) -> dict[str, Any]: - """Load JSON fixture.""" - with open(path) as f: - return cast("dict[str, Any]", json.load(f)) - - -def _load_yandex_objects() -> tuple[Any, Any, Any, Any]: - """Load Yandex Artist, Album, Track, Playlist from fixtures for mock client.""" - artist = YandexArtist.de_json( - _load_json(FIXTURES_DIR / "artists" / "minimal.json"), _DE_JSON_CLIENT - ) - album = YandexAlbum.de_json( - _load_json(FIXTURES_DIR / "albums" / "minimal.json"), _DE_JSON_CLIENT - ) - track = YandexTrack.de_json( - _load_json(FIXTURES_DIR / "tracks" / "minimal.json"), _DE_JSON_CLIENT - ) - playlist = YandexPlaylist.de_json( - _load_json(FIXTURES_DIR / "playlists" / "minimal.json"), _DE_JSON_CLIENT - ) - return artist, album, track, playlist - - -def _make_search_result(track: Any, album: Any, artist: Any, playlist: Any) -> Any: - """Build a Search-like object with .tracks.results, .albums.results, etc.""" - return type( - "Search", - (), - { - "tracks": type("TracksResult", (), {"results": [track]})(), - "albums": type("AlbumsResult", (), {"results": [album]})(), - "artists": type("ArtistsResult", (), {"results": [artist]})(), - "playlists": type("PlaylistsResult", (), {"results": [playlist]})(), - }, - )() - - -def _make_download_info( - codec: str = "mp3", - direct_link: str = "https://example.com/yandex_track.mp3", - bitrate_in_kbps: int = 320, -) -> Any: - """Build DownloadInfo-like object for streaming.""" - return type( - "DownloadInfo", - (), - { - "direct_link": direct_link, - "codec": codec, - "bitrate_in_kbps": bitrate_in_kbps, - }, - )() - - -@pytest.fixture -async def yandex_music_provider( - mass: MusicAssistant, -) -> AsyncGenerator[ProviderConfig, None]: - """Configure Yandex Music provider with mocked API client and add to mass.""" - artist, album, track, playlist = _load_yandex_objects() - search_result = _make_search_result(track, album, artist, playlist) - download_info = _make_download_info() - - # Album with volumes for get_album_tracks - album_with_volumes = type( - "AlbumWithVolumes", - (), - { - "id": album.id, - "title": album.title, - "volumes": [[track]], - "artists": album.artists if hasattr(album, "artists") else [], - "year": getattr(album, "year", None), - "release_date": getattr(album, "release_date", None), - "genre": getattr(album, "genre", None), - "cover_uri": getattr(album, "cover_uri", None), - "og_image": getattr(album, "og_image", None), - "type": getattr(album, "type", "album"), - "available": getattr(album, "available", True), - }, - )() - - with mock.patch( - "music_assistant.providers.yandex_music.provider.YandexMusicClient" - ) as mock_client_class: - mock_client = mock.AsyncMock() - mock_client_class.return_value = mock_client - - mock_client.connect = mock.AsyncMock(return_value=True) - mock_client.user_id = 12345 - - mock_client.get_liked_tracks = mock.AsyncMock(return_value=[]) - mock_client.get_liked_albums = mock.AsyncMock(return_value=[]) - mock_client.get_liked_artists = mock.AsyncMock(return_value=[]) - mock_client.get_user_playlists = mock.AsyncMock(return_value=[playlist]) - - mock_client.search = mock.AsyncMock(return_value=search_result) - mock_client.get_track = mock.AsyncMock(return_value=track) - mock_client.get_tracks = mock.AsyncMock(return_value=[track]) - mock_client.get_album = mock.AsyncMock(return_value=album) - mock_client.get_album_with_tracks = mock.AsyncMock(return_value=album_with_volumes) - mock_client.get_artist = mock.AsyncMock(return_value=artist) - mock_client.get_artist_albums = mock.AsyncMock(return_value=[album]) - mock_client.get_artist_tracks = mock.AsyncMock(return_value=[track]) - mock_client.get_playlist = mock.AsyncMock(return_value=playlist) - mock_client.get_track_download_info = mock.AsyncMock(return_value=[download_info]) - mock_client.get_track_lyrics = mock.AsyncMock(return_value=(None, False)) - mock_client.get_track_lyrics_from_track = mock.AsyncMock(return_value=(None, False)) - - async with wait_for_sync_completion(mass): - config = await mass.config.save_provider_config( - "yandex_music", - {"token": "mock_yandex_token", "quality": "high"}, - ) - await mass.music.start_sync() - - yield config - - -@pytest.fixture -async def yandex_music_provider_lossless( - mass: MusicAssistant, -) -> AsyncGenerator[ProviderConfig, None]: - """Configure Yandex Music with quality=lossless and mock returning MP3 + FLAC.""" - artist, album, track, playlist = _load_yandex_objects() - search_result = _make_search_result(track, album, artist, playlist) - mp3_info = _make_download_info( - codec="mp3", - direct_link="https://example.com/yandex_track.mp3", - bitrate_in_kbps=320, - ) - flac_info = _make_download_info( - codec="flac", - direct_link="https://example.com/yandex_track.flac", - bitrate_in_kbps=0, - ) - download_infos = [mp3_info, flac_info] - - album_with_volumes = type( - "AlbumWithVolumes", - (), - { - "id": album.id, - "title": album.title, - "volumes": [[track]], - "artists": album.artists if hasattr(album, "artists") else [], - "year": getattr(album, "year", None), - "release_date": getattr(album, "release_date", None), - "genre": getattr(album, "genre", None), - "cover_uri": getattr(album, "cover_uri", None), - "og_image": getattr(album, "og_image", None), - "type": getattr(album, "type", "album"), - "available": getattr(album, "available", True), - }, - )() - - with mock.patch( - "music_assistant.providers.yandex_music.provider.YandexMusicClient" - ) as mock_client_class: - mock_client = mock.AsyncMock() - mock_client_class.return_value = mock_client - - mock_client.connect = mock.AsyncMock(return_value=True) - mock_client.user_id = 12345 - - mock_client.get_liked_tracks = mock.AsyncMock(return_value=[]) - mock_client.get_liked_albums = mock.AsyncMock(return_value=[]) - mock_client.get_liked_artists = mock.AsyncMock(return_value=[]) - mock_client.get_user_playlists = mock.AsyncMock(return_value=[playlist]) - - mock_client.search = mock.AsyncMock(return_value=search_result) - mock_client.get_track = mock.AsyncMock(return_value=track) - mock_client.get_tracks = mock.AsyncMock(return_value=[track]) - mock_client.get_album = mock.AsyncMock(return_value=album) - mock_client.get_album_with_tracks = mock.AsyncMock(return_value=album_with_volumes) - mock_client.get_artist = mock.AsyncMock(return_value=artist) - mock_client.get_artist_albums = mock.AsyncMock(return_value=[album]) - mock_client.get_artist_tracks = mock.AsyncMock(return_value=[track]) - mock_client.get_playlist = mock.AsyncMock(return_value=playlist) - # get-file-info lossless is tried first; mock returns None so we use download_info path - mock_client.get_track_file_info_lossless = mock.AsyncMock(return_value=None) - mock_client.get_track_download_info = mock.AsyncMock(return_value=download_infos) - mock_client.get_track_lyrics = mock.AsyncMock(return_value=(None, False)) - mock_client.get_track_lyrics_from_track = mock.AsyncMock(return_value=(None, False)) - - async with wait_for_sync_completion(mass): - config = await mass.config.save_provider_config( - "yandex_music", - {"token": "mock_yandex_token", "quality": "lossless"}, - ) - await mass.music.start_sync() - - yield config - - -def _get_yandex_provider(mass: MusicAssistant) -> MusicProvider | None: - """Get Yandex Music provider instance from mass.""" - for provider in mass.music.providers: - if provider.domain == "yandex_music": - return provider - return None - - -@pytest.mark.usefixtures("yandex_music_provider") -async def test_registration_and_sync(mass: MusicAssistant) -> None: - """Test that provider is registered and sync completes.""" - prov = _get_yandex_provider(mass) - assert prov is not None - assert prov.domain == "yandex_music" - assert prov.instance_id - - -@pytest.mark.usefixtures("yandex_music_provider") -async def test_search(mass: MusicAssistant) -> None: - """Test search returns results from yandex_music.""" - results = await mass.music.search("test query", [MediaType.TRACK], limit=5) - yandex_tracks = [t for t in results.tracks if t.provider and "yandex_music" in t.provider] - assert len(yandex_tracks) >= 0 - - -@pytest.mark.usefixtures("yandex_music_provider") -async def test_get_artist(mass: MusicAssistant) -> None: - """Test getting artist by id.""" - prov = _get_yandex_provider(mass) - assert prov is not None - artist = await prov.get_artist("100") - assert artist is not None - assert artist.name - assert artist.provider == prov.instance_id - assert artist.item_id == "100" - - -@pytest.mark.usefixtures("yandex_music_provider") -async def test_get_album(mass: MusicAssistant) -> None: - """Test getting album by id.""" - prov = _get_yandex_provider(mass) - assert prov is not None - album = await prov.get_album("300") - assert album is not None - assert album.name - assert album.provider == prov.instance_id - assert album.item_id == "300" - - -@pytest.mark.usefixtures("yandex_music_provider") -async def test_get_track(mass: MusicAssistant) -> None: - """Test getting track by id.""" - prov = _get_yandex_provider(mass) - assert prov is not None - track = await prov.get_track("400") - assert track is not None - assert track.name - assert track.provider == prov.instance_id - assert track.item_id == "400" - - -@pytest.mark.usefixtures("yandex_music_provider") -async def test_get_album_tracks(mass: MusicAssistant) -> None: - """Test getting album tracks.""" - prov = _get_yandex_provider(mass) - assert prov is not None - tracks = await prov.get_album_tracks("300") - assert isinstance(tracks, list) - assert len(tracks) >= 0 - - -@pytest.mark.usefixtures("yandex_music_provider") -async def test_get_playlist_tracks(mass: MusicAssistant) -> None: - """Test getting playlist tracks.""" - prov = _get_yandex_provider(mass) - assert prov is not None - tracks = await prov.get_playlist_tracks("12345:3", page=0) - assert isinstance(tracks, list) - assert len(tracks) >= 0 - - -@pytest.mark.usefixtures("yandex_music_provider") -async def test_get_stream_details(mass: MusicAssistant) -> None: - """Test stream details retrieval.""" - prov = _get_yandex_provider(mass) - assert prov is not None - stream_details = await prov.get_stream_details("400", MediaType.TRACK) - assert stream_details is not None - assert stream_details.stream_type == StreamType.HTTP - assert stream_details.path == "https://example.com/yandex_track.mp3" - - -@pytest.mark.usefixtures("yandex_music_provider_lossless") -async def test_get_stream_details_returns_flac_when_lossless_selected( - mass: MusicAssistant, -) -> None: - """When quality=lossless and API returns MP3+FLAC, stream details use FLAC.""" - prov = _get_yandex_provider(mass) - assert prov is not None - stream_details = await prov.get_stream_details("400", MediaType.TRACK) - assert stream_details is not None - assert stream_details.audio_format.content_type == ContentType.FLAC - assert stream_details.path == "https://example.com/yandex_track.flac" - - -@pytest.mark.usefixtures("yandex_music_provider") -async def test_library_items(mass: MusicAssistant) -> None: - """Test library artists, albums, tracks, playlists.""" - prov = _get_yandex_provider(mass) - assert prov is not None - instance_id = prov.instance_id - - artists = await mass.music.artists.library_items() - yandex_artists = [a for a in artists if a.provider == instance_id] - assert len(yandex_artists) >= 0 - - albums = await mass.music.albums.library_items() - yandex_albums = [a for a in albums if a.provider == instance_id] - assert len(yandex_albums) >= 0 - - tracks = await mass.music.tracks.library_items() - yandex_tracks = [t for t in tracks if t.provider == instance_id] - assert len(yandex_tracks) >= 0 - - playlists = await mass.music.playlists.library_items() - yandex_playlists = [p for p in playlists if p.provider == instance_id] - assert len(yandex_playlists) >= 0 - - -@pytest.mark.usefixtures("yandex_music_provider") -async def test_browse(mass: MusicAssistant) -> None: - """Test browse root and subpaths.""" - prov = _get_yandex_provider(mass) - assert prov is not None - base_path = f"{prov.instance_id}://" - root_items = await prov.browse(path=base_path) - assert root_items is not None - assert isinstance(root_items, (list, tuple)) - all_names = set(BROWSE_NAMES_RU.values()) | set(BROWSE_NAMES_EN.values()) - if root_items: - first_name = getattr(root_items[0], "name", None) - assert first_name in all_names, ( - f"First folder name {first_name!r} should be from locale mapping" - ) - - artists_path = f"{prov.instance_id}://artists" - artists_items = await prov.browse(path=artists_path) - assert artists_items is not None - assert isinstance(artists_items, (list, tuple)) - - -# -- Playlist edge-case tests -------------------------------------------------- - - -@pytest.mark.usefixtures("yandex_music_provider") -async def test_get_playlist_tracks_page_gt_zero_returns_empty(mass: MusicAssistant) -> None: - """Page > 0 returns empty list (Yandex returns all tracks in one call).""" - prov = _get_yandex_provider(mass) - assert prov is not None - # Use a different playlist ID to avoid cache collision with test_get_playlist_tracks - result = await prov.get_playlist_tracks("12345:99", page=1) - assert result == [] - - -@pytest.mark.usefixtures("yandex_music_provider") -async def test_get_playlist_tracks_fetch_tracks_async_fallback(mass: MusicAssistant) -> None: - """When playlist.tracks is None but track_count > 0, fetch_tracks_async is used.""" - prov = _get_yandex_provider(mass) - assert prov is not None - - _, _, track, _ = _load_yandex_objects() - - # Build a playlist object with tracks=None and track_count=5 - track_short = type("TrackShort", (), {"track_id": 400, "id": 400})() - playlist_no_tracks = type( - "Playlist", - (), - { - "owner": type("Owner", (), {"uid": 12345})(), - "kind": 77, - "title": "Fallback Playlist", - "tracks": None, - "track_count": 5, - "fetch_tracks_async": mock.AsyncMock(return_value=[track_short]), - }, - )() - - prov.client.get_playlist = mock.AsyncMock(return_value=playlist_no_tracks) # type: ignore[attr-defined] - prov.client.get_tracks = mock.AsyncMock(return_value=[track]) # type: ignore[attr-defined] - - result = await prov.get_playlist_tracks("12345:77", page=0) - assert isinstance(result, list) - assert len(result) >= 1 - playlist_no_tracks.fetch_tracks_async.assert_awaited_once() - - -@pytest.mark.usefixtures("yandex_music_provider") -async def test_get_playlist_tracks_empty_batch_raises(mass: MusicAssistant) -> None: - """Empty batch result from get_tracks raises ResourceTemporarilyUnavailable.""" - prov = _get_yandex_provider(mass) - assert prov is not None - - # Build a playlist with tracks that have track_ids - track_short = type("TrackShort", (), {"track_id": 400, "id": 400})() - playlist_with_tracks = type( - "Playlist", - (), - { - "owner": type("Owner", (), {"uid": 12345})(), - "kind": 88, - "title": "Batch Fail Playlist", - "tracks": [track_short], - "track_count": 1, - }, - )() - - prov.client.get_playlist = mock.AsyncMock(return_value=playlist_with_tracks) # type: ignore[attr-defined] - prov.client.get_tracks = mock.AsyncMock(return_value=[]) # type: ignore[attr-defined] - - with pytest.raises(ResourceTemporarilyUnavailable): - await prov.get_playlist_tracks("12345:88", page=0) diff --git a/tests/providers/yandex_music/test_my_wave.py b/tests/providers/yandex_music/test_my_wave.py index 57af434de0..3fd0b4fbd7 100644 --- a/tests/providers/yandex_music/test_my_wave.py +++ b/tests/providers/yandex_music/test_my_wave.py @@ -2,11 +2,33 @@ from __future__ import annotations +import json +from typing import TYPE_CHECKING +from unittest.mock import AsyncMock, Mock, patch + +import pytest +from music_assistant_models.enums import MediaType +from music_assistant_models.errors import InvalidDataError +from music_assistant_models.media_items import ProviderMapping +from music_assistant_models.media_items import Track as MATrack + +from music_assistant.providers.yandex_music import ( + _delete_wave_preset_action, + _save_wave_preset_action, +) from music_assistant.providers.yandex_music.constants import ( RADIO_TRACK_ID_SEP, ROTOR_STATION_MY_WAVE, ) -from music_assistant.providers.yandex_music.provider import _parse_radio_item_id +from music_assistant.providers.yandex_music.parsers import parse_playlist +from music_assistant.providers.yandex_music.provider import ( + YandexMusicProvider, + _parse_radio_item_id, + _WaveState, +) + +if TYPE_CHECKING: + from music_assistant_models.config_entries import ConfigValueType def test_parse_radio_item_id_plain_track_id() -> None: @@ -22,3 +44,672 @@ def test_parse_radio_item_id_composite() -> None: ROTOR_STATION_MY_WAVE, ) assert _parse_radio_item_id("99@user:custom") == ("99", "user:custom") + + +def test_wave_state_has_session_fields() -> None: + """_WaveState exposes session_id, playlist_next_cursor, prefetched, settings.""" + state = _WaveState() + # Session-based rotor API identifiers + assert state.session_id is None + # Legacy stations-based identifier retained during migration + assert state.batch_id is None + # Pagination cursor for virtual playlist pages + assert state.playlist_next_cursor is None + # Prefetch buffer for future-batch tracks + assert state.prefetched == [] + # Persistent station settings (diversity/moodEnergy/language) + assert state.settings == {} + # Once-per-session flag + assert state.radio_started_sent is False + + +def test_wave_state_is_per_instance_isolated() -> None: + """Each _WaveState has its own mutable containers (no shared class state).""" + a, b = _WaveState(), _WaveState() + a.seen_track_ids.add("1") + a.prefetched.append("x") + a.settings["diversity"] = "discover" + assert b.seen_track_ids == set() + assert b.prefetched == [] + assert b.settings == {} + + +# -- _fetch_rotor_session_batch (session-API helper) -------------------------- + + +@pytest.mark.asyncio +async def test_fetch_rotor_session_batch_starts_session_on_first_call() -> None: + """First call creates a rotor session and records session_id + batch_id on wave state.""" + provider = Mock(spec=YandexMusicProvider) + provider.client = AsyncMock() + provider.client.rotor_session_new = AsyncMock( + return_value=("sess_1", ["track1", "track2"], "batch_a") + ) + provider.client.rotor_session_tracks = AsyncMock() + wave = _WaveState() + + tracks, batch_id = await YandexMusicProvider._fetch_rotor_session_batch( + provider, wave, ROTOR_STATION_MY_WAVE + ) + + provider.client.rotor_session_new.assert_awaited_once_with(ROTOR_STATION_MY_WAVE, settings=None) + provider.client.rotor_session_tracks.assert_not_awaited() + assert wave.session_id == "sess_1" + assert wave.batch_id == "batch_a" + assert tracks == ["track1", "track2"] + assert batch_id == "batch_a" + + +@pytest.mark.asyncio +async def test_fetch_rotor_session_batch_passes_wave_settings_to_session_new() -> None: + """Session creation forwards wave.settings (diversity/moodEnergy/language) as seeds.""" + provider = Mock(spec=YandexMusicProvider) + provider.client = AsyncMock() + provider.client.rotor_session_new = AsyncMock(return_value=("s", [], "b")) + wave = _WaveState() + wave.settings = {"diversity": "discover", "moodEnergy": "calm"} + + await YandexMusicProvider._fetch_rotor_session_batch(provider, wave, ROTOR_STATION_MY_WAVE) + + _, kwargs = provider.client.rotor_session_new.await_args + assert kwargs["settings"] == {"diversity": "discover", "moodEnergy": "calm"} + + +@pytest.mark.asyncio +async def test_fetch_rotor_session_batch_paginates_via_session_tracks_after_first_call() -> None: + """Once session_id is set, subsequent calls use rotor_session_tracks with last_track_id.""" + provider = Mock(spec=YandexMusicProvider) + provider.client = AsyncMock() + provider.client.rotor_session_new = AsyncMock() + provider.client.rotor_session_tracks = AsyncMock(return_value=(["t3"], "batch_b")) + wave = _WaveState() + wave.session_id = "sess_1" + wave.last_track_id = "42" + + tracks, _batch_id = await YandexMusicProvider._fetch_rotor_session_batch( + provider, wave, ROTOR_STATION_MY_WAVE + ) + + provider.client.rotor_session_new.assert_not_awaited() + provider.client.rotor_session_tracks.assert_awaited_once_with("sess_1", current_track_id="42") + assert wave.batch_id == "batch_b" + assert tracks == ["t3"] + + +@pytest.mark.asyncio +async def test_fetch_rotor_session_batch_returns_empty_when_session_new_fails() -> None: + """When session creation returns None session_id, wave is not mutated and result is empty.""" + provider = Mock(spec=YandexMusicProvider) + provider.client = AsyncMock() + provider.client.rotor_session_new = AsyncMock(return_value=(None, [], None)) + wave = _WaveState() + + tracks, batch_id = await YandexMusicProvider._fetch_rotor_session_batch( + provider, wave, ROTOR_STATION_MY_WAVE + ) + + assert wave.session_id is None + assert tracks == [] + assert batch_id is None + + +@pytest.mark.asyncio +async def test_fetch_rotor_session_batch_works_with_track_seed_station() -> None: + """get_similar_tracks uses station 'track:{id}' — same session machinery.""" + provider = Mock(spec=YandexMusicProvider) + provider.client = AsyncMock() + provider.client.rotor_session_new = AsyncMock(return_value=("s", ["t"], "b")) + wave = _WaveState() + + await YandexMusicProvider._fetch_rotor_session_batch(provider, wave, "track:9999") + + provider.client.rotor_session_new.assert_awaited_once_with("track:9999", settings=None) + assert wave.session_id == "s" + + +# -- ynison compatibility wrapper --------------------------------------------- + + +@pytest.mark.asyncio +async def test_get_rotor_station_tracks_wrapper_delegates_to_session_batch() -> None: + """Ynison-facing wrapper routes through _fetch_rotor_session_batch. + + This keeps ynison on the session API (long-lived radioSessionId, shared + wave state, prefetch) without any code change on its side — the + ``(tracks, batch_id)`` shape stays the same. + """ + wave = _WaveState() + provider = Mock(spec=YandexMusicProvider) + provider._get_wave_state = Mock(return_value=wave) + provider._fetch_rotor_session_batch = AsyncMock(return_value=(["t1", "t2"], "batch_1")) + + tracks, batch_id = await YandexMusicProvider.get_rotor_station_tracks( + provider, "genre:rock", queue=None + ) + + provider._get_wave_state.assert_called_once_with("genre:rock") + provider._fetch_rotor_session_batch.assert_awaited_once_with(wave, "genre:rock") + assert tracks == ["t1", "t2"] + assert batch_id == "batch_1" + + +@pytest.mark.asyncio +async def test_get_rotor_station_tracks_wrapper_records_queue_as_cursor() -> None: + """Ynison's queue= arg becomes wave.last_track_id so the next call paginates.""" + wave = _WaveState() + provider = Mock(spec=YandexMusicProvider) + provider._get_wave_state = Mock(return_value=wave) + provider._fetch_rotor_session_batch = AsyncMock(return_value=([], None)) + + await YandexMusicProvider.get_rotor_station_tracks(provider, "mood:calm", queue="42") + + assert wave.last_track_id == "42" + provider._fetch_rotor_session_batch.assert_awaited_once_with(wave, "mood:calm") + + +# -- wave-mode preset routing ------------------------------------------------- + + +@pytest.mark.asyncio +async def test_fetch_rotor_session_batch_resolves_wave_mode_preset_settings() -> None: + """A station key like 'user:onyourwave#discover' translates to settingDiversity=discover.""" + provider = Mock(spec=YandexMusicProvider) + provider.client = AsyncMock() + provider.client.rotor_session_new = AsyncMock(return_value=("sess_1", [], "batch_a")) + wave = _WaveState() + + await YandexMusicProvider._fetch_rotor_session_batch( + provider, wave, f"{ROTOR_STATION_MY_WAVE}#discover" + ) + + provider.client.rotor_session_new.assert_awaited_once_with( + ROTOR_STATION_MY_WAVE, settings={"diversity": "discover"} + ) + assert wave.session_id == "sess_1" + + +@pytest.mark.asyncio +async def test_fetch_rotor_session_batch_preset_merges_with_explicit_wave_settings() -> None: + """Explicit wave.settings overrides preset settings on the same key.""" + provider = Mock(spec=YandexMusicProvider) + provider.client = AsyncMock() + provider.client.rotor_session_new = AsyncMock(return_value=("s", [], "b")) + wave = _WaveState() + wave.settings = {"diversity": "popular"} # overrides preset + + await YandexMusicProvider._fetch_rotor_session_batch( + provider, wave, f"{ROTOR_STATION_MY_WAVE}#discover" + ) + + _, kwargs = provider.client.rotor_session_new.await_args + # wave.settings wins over preset + assert kwargs["settings"] == {"diversity": "popular"} + + +@pytest.mark.asyncio +async def test_fetch_rotor_session_batch_unknown_preset_strips_suffix_no_settings() -> None: + """Unknown '#' suffix is stripped from the station key and no extra settings are sent.""" + provider = Mock(spec=YandexMusicProvider) + provider.client = AsyncMock() + provider.client.rotor_session_new = AsyncMock(return_value=(None, [], None)) + wave = _WaveState() + + await YandexMusicProvider._fetch_rotor_session_batch( + provider, wave, f"{ROTOR_STATION_MY_WAVE}#does_not_exist" + ) + + # Base station is used; unknown preset yields empty settings → settings=None. + provider.client.rotor_session_new.assert_awaited_once_with(ROTOR_STATION_MY_WAVE, settings=None) + + +# -- _parse_my_wave_track with explicit station_key -------------------------- + + +# -- prefetch next batch (P6) ------------------------------------------------- + + +@pytest.mark.asyncio +async def test_prefetch_rotor_session_fills_prefetched_when_idle() -> None: + """With an active session + cursor and no prefetched tracks, fills wave.prefetched.""" + provider = Mock(spec=YandexMusicProvider) + provider.client = AsyncMock() + provider.client.rotor_session_tracks = AsyncMock(return_value=(["t1", "t2"], "batch_b")) + wave = _WaveState() + wave.session_id = "sess_1" + wave.last_track_id = "42" + provider._wave_states = {ROTOR_STATION_MY_WAVE: wave} + + await YandexMusicProvider._prefetch_rotor_session(provider, ROTOR_STATION_MY_WAVE) + + provider.client.rotor_session_tracks.assert_awaited_once_with("sess_1", current_track_id="42") + assert wave.prefetched == ["t1", "t2"] + + +@pytest.mark.asyncio +async def test_prefetch_rotor_session_noop_without_session() -> None: + """Prefetch does nothing when the station has no active session_id.""" + provider = Mock(spec=YandexMusicProvider) + provider.client = AsyncMock() + provider.client.rotor_session_tracks = AsyncMock() + wave = _WaveState() + provider._wave_states = {ROTOR_STATION_MY_WAVE: wave} + + await YandexMusicProvider._prefetch_rotor_session(provider, ROTOR_STATION_MY_WAVE) + + provider.client.rotor_session_tracks.assert_not_awaited() + assert wave.prefetched == [] + + +@pytest.mark.asyncio +async def test_prefetch_rotor_session_noop_without_cursor() -> None: + """Prefetch bails when session exists but no last_track_id cursor yet.""" + provider = Mock(spec=YandexMusicProvider) + provider.client = AsyncMock() + provider.client.rotor_session_tracks = AsyncMock() + wave = _WaveState() + wave.session_id = "sess_1" # but last_track_id still None + provider._wave_states = {ROTOR_STATION_MY_WAVE: wave} + + await YandexMusicProvider._prefetch_rotor_session(provider, ROTOR_STATION_MY_WAVE) + + provider.client.rotor_session_tracks.assert_not_awaited() + assert wave.prefetched == [] + + +@pytest.mark.asyncio +async def test_prefetch_rotor_session_noop_when_already_prefilled() -> None: + """Prefetch skips work when wave.prefetched already has items (avoid rate burn).""" + provider = Mock(spec=YandexMusicProvider) + provider.client = AsyncMock() + provider.client.rotor_session_tracks = AsyncMock() + wave = _WaveState() + wave.session_id = "sess_1" + wave.last_track_id = "42" + wave.prefetched = ["existing_track"] + provider._wave_states = {ROTOR_STATION_MY_WAVE: wave} + + await YandexMusicProvider._prefetch_rotor_session(provider, ROTOR_STATION_MY_WAVE) + + provider.client.rotor_session_tracks.assert_not_awaited() + + +# -- rotor feedback on library_add (P5) --------------------------------------- + + +@pytest.mark.asyncio +async def test_library_add_track_from_wave_also_sends_rotor_like() -> None: + """library_add for a track from a wave session sends both users.like and rotor.like.""" + provider = Mock(spec=YandexMusicProvider) + provider.instance_id = "yandex_music_instance" + provider.logger = Mock() + provider.client = AsyncMock() + provider.client.like_track = AsyncMock(return_value=True) + composite = f"12345{RADIO_TRACK_ID_SEP}{ROTOR_STATION_MY_WAVE}" + provider._get_provider_item_id = Mock(return_value=composite) + # Share a session so like is routed to rotor_session_feedback + wave = _WaveState() + wave.session_id = "sess_1" + wave.batch_id = "batch_a" + provider._wave_states = {ROTOR_STATION_MY_WAVE: wave} + provider._get_wave_state = Mock(return_value=wave) + provider._send_wave_feedback = AsyncMock(return_value=True) + + item = MATrack( + item_id=composite, + provider="yandex_music_instance", + name="Test", + provider_mappings={ + ProviderMapping( + item_id=composite, + provider_domain="yandex_music", + provider_instance="yandex_music_instance", + ) + }, + ) + item.media_type = MediaType.TRACK + + result = await YandexMusicProvider.library_add(provider, item) + + assert result is True + provider.client.like_track.assert_awaited_once_with("12345") + provider._send_wave_feedback.assert_awaited_once() + args, kwargs = provider._send_wave_feedback.await_args + assert args[0] is wave + assert args[1] == ROTOR_STATION_MY_WAVE + assert args[2] == "like" + assert kwargs == {"track_id": "12345"} + + +@pytest.mark.asyncio +async def test_library_add_track_without_station_skips_rotor_feedback() -> None: + """Plain track_id (no station suffix) does NOT trigger rotor feedback.""" + provider = Mock(spec=YandexMusicProvider) + provider.instance_id = "yandex_music_instance" + provider.logger = Mock() + provider.client = AsyncMock() + provider.client.like_track = AsyncMock(return_value=True) + provider._get_provider_item_id = Mock(return_value="12345") + provider._send_wave_feedback = AsyncMock() + + item = MATrack( + item_id="12345", + provider="yandex_music_instance", + name="Test", + provider_mappings={ + ProviderMapping( + item_id="12345", + provider_domain="yandex_music", + provider_instance="yandex_music_instance", + ) + }, + ) + item.media_type = MediaType.TRACK + + await YandexMusicProvider.library_add(provider, item) + + provider.client.like_track.assert_awaited_once_with("12345") + provider._send_wave_feedback.assert_not_awaited() + + +# -- user wave presets (P8) --------------------------------------------------- + + +def _preset_config(values: dict[str, str]) -> Mock: + """Build a config stub whose get_value looks up keys in the given dict. + + Non-listed keys return None, matching MA's ``ConfigValueType | None`` contract. + """ + config = Mock() + config.get_value = Mock(side_effect=values.get) + return config + + +def test_get_user_wave_presets_decodes_stored_json() -> None: + """A valid JSON list in CONF_WAVE_PRESETS_DATA yields the same presets out.""" + provider = Mock(spec=YandexMusicProvider) + provider.config = _preset_config( + { + "wave_presets_data": ( + '[{"name": "Morning", "diversity": "discover", "moodEnergy": "calm"},' + ' {"name": "Evening", "language": "russian"}]' + ), + } + ) + provider.logger = Mock() + + result = YandexMusicProvider._get_user_wave_presets(provider) + + assert result == [ + {"name": "Morning", "diversity": "discover", "moodEnergy": "calm"}, + {"name": "Evening", "language": "russian"}, + ] + + +def test_get_user_wave_presets_empty_store_returns_empty() -> None: + """No stored data / empty string / None → empty list.""" + provider = Mock(spec=YandexMusicProvider) + provider.config = _preset_config({"wave_presets_data": ""}) + provider.logger = Mock() + + assert YandexMusicProvider._get_user_wave_presets(provider) == [] + + +def test_get_user_wave_presets_invalid_json_returns_empty() -> None: + """Malformed JSON → empty list (silent; matches the settings-UI parser).""" + provider = Mock(spec=YandexMusicProvider) + provider.config = _preset_config({"wave_presets_data": "not-json {{{"}) + provider.logger = Mock() + + assert YandexMusicProvider._get_user_wave_presets(provider) == [] + + +def test_get_user_wave_presets_skips_items_without_name() -> None: + """Entries missing a name or with non-string values are silently skipped.""" + provider = Mock(spec=YandexMusicProvider) + provider.config = _preset_config( + { + "wave_presets_data": ( + '[{"diversity": "discover"}, {"name": ""}, ' + '{"name": "Good", "moodEnergy": "active"}]' + ), + } + ) + provider.logger = Mock() + + assert YandexMusicProvider._get_user_wave_presets(provider) == [ + {"name": "Good", "moodEnergy": "active"}, + ] + + +def test_get_user_wave_presets_drops_whitespace_only_values() -> None: + """Whitespace-only dropdown values (e.g. hand-edited JSON) are treated as empty. + + Yandex rejects ``settingDiversity:`` with a 4xx, so the parser must not + propagate such values. Valid values are also stripped to their canonical + form so the downstream rotor seed builder always gets the stored string + without surrounding whitespace. + """ + provider = Mock(spec=YandexMusicProvider) + provider.config = _preset_config( + { + "wave_presets_data": ( + '[{"name": "WS-only", "diversity": " ",' + ' "moodEnergy": "\\t", "language": ""},' + ' {"name": "Trim", "diversity": " discover "}]' + ), + } + ) + provider.logger = Mock() + + assert YandexMusicProvider._get_user_wave_presets(provider) == [ + {"name": "WS-only"}, + {"name": "Trim", "diversity": "discover"}, + ] + + +# -- save / delete preset actions -------------------------------------------- + + +def test_save_wave_preset_action_appends_and_clears_draft() -> None: + """Save action writes the draft into JSON storage and clears draft fields.""" + values: dict[str, ConfigValueType] = { + "wave_preset_draft_name": "Morning", + "wave_preset_draft_diversity": "discover", + "wave_preset_draft_mood": "calm", + "wave_preset_draft_language": "", # "default" dropdown → skipped + "wave_presets_data": "", + } + + _save_wave_preset_action(values) + + stored_raw = values["wave_presets_data"] + assert isinstance(stored_raw, str) + assert json.loads(stored_raw) == [ + {"name": "Morning", "diversity": "discover", "moodEnergy": "calm"}, + ] + assert values["wave_preset_draft_name"] is None + assert values["wave_preset_draft_diversity"] == "" + assert values["wave_preset_draft_mood"] == "" + assert values["wave_preset_draft_language"] == "" + + +def test_save_wave_preset_action_overwrites_same_name() -> None: + """Saving with an existing name replaces the prior entry — no duplicates.""" + values: dict[str, ConfigValueType] = { + "wave_preset_draft_name": "Morning", + "wave_preset_draft_diversity": "favorite", + "wave_preset_draft_mood": "", + "wave_preset_draft_language": "", + "wave_presets_data": ( + '[{"name": "Morning", "diversity": "discover"},' + ' {"name": "Evening", "language": "russian"}]' + ), + } + + _save_wave_preset_action(values) + + stored_raw = values["wave_presets_data"] + assert isinstance(stored_raw, str) + stored = json.loads(stored_raw) + assert {p["name"] for p in stored} == {"Morning", "Evening"} + morning = next(p for p in stored if p["name"] == "Morning") + assert morning == {"name": "Morning", "diversity": "favorite"} + + +def test_save_wave_preset_action_rejects_blank_name() -> None: + """Save without a preset name raises InvalidDataError and changes nothing.""" + values: dict[str, ConfigValueType] = { + "wave_preset_draft_name": " ", + "wave_presets_data": "", + } + + with pytest.raises(InvalidDataError): + _save_wave_preset_action(values) + assert values["wave_presets_data"] == "" + + +def test_delete_wave_preset_action_removes_by_name() -> None: + """Delete action drops the selected preset and clears the selector.""" + values: dict[str, ConfigValueType] = { + "wave_preset_to_delete": "Morning", + "wave_presets_data": ( + '[{"name": "Morning", "diversity": "discover"},' + ' {"name": "Evening", "language": "russian"}]' + ), + } + + _delete_wave_preset_action(values) + + stored_raw = values["wave_presets_data"] + assert isinstance(stored_raw, str) + assert json.loads(stored_raw) == [{"name": "Evening", "language": "russian"}] + assert values["wave_preset_to_delete"] == "" + + +def test_delete_wave_preset_action_requires_selection() -> None: + """No selection → InvalidDataError; storage untouched.""" + values: dict[str, ConfigValueType] = { + "wave_preset_to_delete": "", + "wave_presets_data": '[{"name": "Keep"}]', + } + + with pytest.raises(InvalidDataError): + _delete_wave_preset_action(values) + assert values["wave_presets_data"] == '[{"name": "Keep"}]' + + +def test_parse_playlist_is_dynamic_flag_propagates() -> None: + """parse_playlist honours is_dynamic=True so feed autoplaylists skip MA cache.""" + provider = Mock(spec=YandexMusicProvider) + provider.instance_id = "yandex_music_instance" + provider.domain = "yandex_music" + provider.client = Mock() + provider.client.user_id = 12345 + + playlist_obj = Mock() + playlist_obj.owner = Mock(uid=67890, name="Яндекс") + playlist_obj.kind = 42 + playlist_obj.title = "Плейлист дня" + playlist_obj.description = None + playlist_obj.cover = None + playlist_obj.track_count = 50 + playlist_obj.modified = None + playlist_obj.created = None + playlist_obj.tags = [] + + result_dynamic = parse_playlist(provider, playlist_obj, is_dynamic=True) + result_static = parse_playlist(provider, playlist_obj) + + assert result_dynamic.is_dynamic is True + assert result_static.is_dynamic is False + + +def test_parse_my_wave_track_uses_provided_station_key_for_item_id() -> None: + """_parse_my_wave_track stamps the supplied station_key on composite item_id.""" + # Build a minimal provider instance with the attributes _parse_my_wave_track + # reads; don't use Mock(spec=...) because we call the real method. + provider = Mock(spec=YandexMusicProvider) + provider.instance_id = "yandex_music_instance" + provider.logger = Mock() + + # Fake yandex track object + yt = type("YTrack", (), {"id": "12345", "track_id": "12345"})() + + # Return a minimal MA Track from parse_track; _parse_my_wave_track rewrites + # its item_id in-place to the composite form. + base_track = MATrack( + item_id="12345", + provider="yandex_music_instance", + name="Test", + provider_mappings={ + ProviderMapping( + item_id="12345", + provider_domain="yandex_music", + provider_instance="yandex_music_instance", + ) + }, + ) + with patch( + "music_assistant.providers.yandex_music.provider.parse_track", + return_value=base_track, + ): + station_key = f"{ROTOR_STATION_MY_WAVE}#discover" + seen: set[str] = set() + result = YandexMusicProvider._parse_my_wave_track( + provider, yt, seen, station_key=station_key + ) + + assert result is not None + assert result.item_id == f"12345{RADIO_TRACK_ID_SEP}{station_key}" + # And round-trip via _parse_radio_item_id + assert _parse_radio_item_id(result.item_id) == ("12345", station_key) + assert "12345" in seen + + +# -- _send_wave_feedback (session vs. stations API router) --------------------- + + +@pytest.mark.asyncio +async def test_send_wave_feedback_uses_session_api_when_session_id_present() -> None: + """When wave.session_id is set, feedback is routed to rotor_session_feedback.""" + provider = Mock(spec=YandexMusicProvider) + provider.client = AsyncMock() + provider.client.rotor_session_feedback = AsyncMock(return_value=True) + provider.client.send_rotor_station_feedback = AsyncMock() + wave = _WaveState() + wave.session_id = "sess_1" + wave.batch_id = "batch_a" + + result = await YandexMusicProvider._send_wave_feedback( + provider, wave, "user:onyourwave", "trackStarted", track_id="100" + ) + + assert result is True + provider.client.rotor_session_feedback.assert_awaited_once_with( + "sess_1", "trackStarted", track_id="100", total_played_seconds=None, batch_id="batch_a" + ) + provider.client.send_rotor_station_feedback.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_send_wave_feedback_skips_silently_without_session() -> None: + """Without ``wave.session_id`` the call is a silent no-op returning False. + + The legacy stations-based feedback endpoint is gone (returns 404), so we + can't usefully fall back there. Callers treat the False result as + "signal was dropped" — history reporting via play_audio still fires. + """ + provider = Mock(spec=YandexMusicProvider) + provider.logger = Mock() + provider.client = AsyncMock() + provider.client.rotor_session_feedback = AsyncMock() + wave = _WaveState() + wave.batch_id = "batch_a" # session_id still None + + result = await YandexMusicProvider._send_wave_feedback( + provider, wave, "genre:rock", "skip", track_id="9", total_played_seconds=7 + ) + + assert result is False + provider.client.rotor_session_feedback.assert_not_awaited() + provider.logger.debug.assert_called_once() diff --git a/tests/providers/yandex_music/test_parsers.py b/tests/providers/yandex_music/test_parsers.py index 18294ae056..6d6f162f53 100644 --- a/tests/providers/yandex_music/test_parsers.py +++ b/tests/providers/yandex_music/test_parsers.py @@ -13,9 +13,13 @@ from yandex_music import Track as YandexTrack from music_assistant.providers.yandex_music.parsers import ( + classify_album, parse_album, parse_artist, + parse_audiobook, parse_playlist, + parse_podcast, + parse_podcast_episode, parse_track, ) from music_assistant.providers.yandex_music.provider import YandexMusicProvider @@ -32,6 +36,8 @@ ALBUM_FIXTURES = list(FIXTURES_DIR.glob("albums/*.json")) TRACK_FIXTURES = list(FIXTURES_DIR.glob("tracks/*.json")) PLAYLIST_FIXTURES = list(FIXTURES_DIR.glob("playlists/*.json")) +PODCAST_FIXTURES = list(FIXTURES_DIR.glob("podcasts/*.json")) +AUDIOBOOK_FIXTURES = list(FIXTURES_DIR.glob("audiobooks/*.json")) def _load_json(path: pathlib.Path) -> dict[str, Any]: @@ -95,6 +101,56 @@ def test_parse_artist_with_cover(provider_stub: ProviderStub) -> None: assert "avatars.yandex.net" in (result.metadata.images[0].path or "") +def test_parse_artist_with_about(provider_stub: ProviderStub) -> None: + """parse_artist enriches description and popularity from ArtistAbout.""" + artist_obj = _artist_from_fixture(FIXTURES_DIR / "artists" / "with_cover.json") + assert artist_obj is not None + + about = type( + "ArtistAbout", + (), + { + "description": "Singer-songwriter from somewhere.", + "stats": type("Stats", (), {"last_month_listeners": 250_000})(), + }, + )() + + result = parse_artist(cast("YandexMusicProvider", provider_stub), artist_obj, about=about) + assert result.metadata.description == "Singer-songwriter from somewhere." + # 250000 // 10000 == 25 + assert result.metadata.popularity == 25 + + +def test_parse_artist_about_missing_fields(provider_stub: ProviderStub) -> None: + """parse_artist tolerates ArtistAbout with missing description/stats.""" + artist_obj = _artist_from_fixture(FIXTURES_DIR / "artists" / "with_cover.json") + assert artist_obj is not None + + about = type("ArtistAbout", (), {"description": None, "stats": None})() + + result = parse_artist(cast("YandexMusicProvider", provider_stub), artist_obj, about=about) + assert result.metadata.description is None + assert result.metadata.popularity is None + + +def test_parse_artist_about_clamps_popularity(provider_stub: ProviderStub) -> None: + """parse_artist caps very large monthly listeners at popularity 100.""" + artist_obj = _artist_from_fixture(FIXTURES_DIR / "artists" / "with_cover.json") + assert artist_obj is not None + + about = type( + "ArtistAbout", + (), + { + "description": "", + "stats": type("Stats", (), {"last_month_listeners": 50_000_000})(), + }, + )() + + result = parse_artist(cast("YandexMusicProvider", provider_stub), artist_obj, about=about) + assert result.metadata.popularity == 100 + + @pytest.mark.parametrize("example", ALBUM_FIXTURES, ids=lambda val: val.stem) def test_parse_album(example: pathlib.Path, provider_stub: ProviderStub) -> None: """Test we can parse albums from fixture JSON.""" @@ -245,3 +301,146 @@ def test_parse_playlist_snapshot( result = parse_playlist(cast("YandexMusicProvider", provider_stub), playlist_obj) parsed = _sort_for_snapshot(result.to_dict()) assert snapshot == parsed + + +# --- classify_album --- + + +@pytest.mark.parametrize( + ("meta_type", "type_", "expected"), + [ + ("podcast", None, "podcast"), + (None, "podcast", "podcast"), + ("Podcast", None, "podcast"), + ("podcast_episode", None, "podcast"), + ("audiobook", None, "audiobook"), + (None, "audiobook", "audiobook"), + ("AUDIOBOOK", None, "audiobook"), + # audiobook wins over podcast on any field — empirically observed: + # Yandex tags audiobooks as meta_type="podcast" + type="audiobook" + ("podcast", "audiobook", "audiobook"), + ("audiobook", "podcast", "audiobook"), + ("audiobook", "music", "audiobook"), + # plain music + (None, None, "music"), + ("music", "album", "music"), + ("", "", "music"), + ], +) +def test_classify_album( + meta_type: str | None, + type_: str | None, + expected: str, +) -> None: + """classify_album maps meta_type/type variants to music/podcast/audiobook.""" + album_obj = YandexAlbum.de_json( + {"id": 1, "title": "x", "meta_type": meta_type, "type": type_}, + DE_JSON_CLIENT, + ) + assert album_obj is not None + assert classify_album(album_obj) == expected + + +# --- Podcast / Audiobook / PodcastEpisode parsers --- + + +@pytest.mark.parametrize("example", PODCAST_FIXTURES, ids=lambda val: val.stem) +def test_parse_podcast(example: pathlib.Path, provider_stub: ProviderStub) -> None: + """parse_podcast extracts basic fields from a podcast-typed album fixture.""" + album_obj = _album_from_fixture(example) + assert album_obj is not None + result = parse_podcast(cast("YandexMusicProvider", provider_stub), album_obj) + assert result.item_id == str(album_obj.id) + assert result.name + assert result.provider == provider_stub.instance_id + mapping = next(iter(result.provider_mappings)) + assert f"music.yandex.ru/album/{album_obj.id}" in (mapping.url or "") + # publisher resolves from labels[0].name when present + if album_obj.labels: + first = album_obj.labels[0] + label_name = first if isinstance(first, str) else getattr(first, "name", None) + if label_name: + assert result.publisher == label_name + if album_obj.track_count is not None: + assert result.total_episodes == album_obj.track_count + + +@pytest.mark.parametrize("example", AUDIOBOOK_FIXTURES, ids=lambda val: val.stem) +def test_parse_audiobook(example: pathlib.Path, provider_stub: ProviderStub) -> None: + """parse_audiobook extracts authors from artists and publisher from labels.""" + album_obj = _album_from_fixture(example) + assert album_obj is not None + result = parse_audiobook(cast("YandexMusicProvider", provider_stub), album_obj) + assert result.item_id == str(album_obj.id) + assert result.name + assert result.duration == 0 # filled in later by get_audiobook() + # authors come from album artists + expected_authors = [a.name for a in (album_obj.artists or []) if a.name] + assert list(result.authors) == expected_authors + assert list(result.narrators) == [] + + +def test_parse_audiobook_fully_played_true(provider_stub: ProviderStub) -> None: + """parse_audiobook propagates album.listening_finished=True to fully_played.""" + album_obj = _album_from_fixture(FIXTURES_DIR / "audiobooks" / "basic.json") + assert album_obj is not None + album_obj.listening_finished = True + result = parse_audiobook(cast("YandexMusicProvider", provider_stub), album_obj) + assert result.fully_played is True + + +def test_parse_audiobook_fully_played_false(provider_stub: ProviderStub) -> None: + """parse_audiobook propagates album.listening_finished=False to fully_played.""" + album_obj = _album_from_fixture(FIXTURES_DIR / "audiobooks" / "basic.json") + assert album_obj is not None + album_obj.listening_finished = False + result = parse_audiobook(cast("YandexMusicProvider", provider_stub), album_obj) + assert result.fully_played is False + + +def test_parse_audiobook_fully_played_none(provider_stub: ProviderStub) -> None: + """parse_audiobook leaves fully_played=None when the flag is missing.""" + album_obj = _album_from_fixture(FIXTURES_DIR / "audiobooks" / "basic.json") + assert album_obj is not None + album_obj.listening_finished = None + result = parse_audiobook(cast("YandexMusicProvider", provider_stub), album_obj) + assert result.fully_played is None + + +def test_parse_podcast_episode(provider_stub: ProviderStub) -> None: + """parse_podcast_episode links episode to its parent podcast.""" + podcast_album = _album_from_fixture(FIXTURES_DIR / "podcasts" / "basic.json") + assert podcast_album is not None + podcast = parse_podcast(cast("YandexMusicProvider", provider_stub), podcast_album) + + track_obj = _track_from_fixture(FIXTURES_DIR / "podcast_episodes" / "basic.json") + assert track_obj is not None + episode = parse_podcast_episode( + cast("YandexMusicProvider", provider_stub), track_obj, podcast, position=1 + ) + assert episode.item_id == str(track_obj.id) + assert episode.name == track_obj.title + assert episode.position == 1 + assert episode.duration == (track_obj.duration_ms or 0) // 1000 + assert episode.podcast is podcast + mapping = next(iter(episode.provider_mappings)) + assert f"music.yandex.ru/track/{track_obj.id}" in (mapping.url or "") + + +def test_parse_podcast_episode_inherits_podcast_image(provider_stub: ProviderStub) -> None: + """Episode image falls back to parent podcast image when track has none.""" + podcast_album = _album_from_fixture(FIXTURES_DIR / "podcasts" / "basic.json") + assert podcast_album is not None + podcast = parse_podcast(cast("YandexMusicProvider", provider_stub), podcast_album) + # strip cover on the track so the fallback kicks in + track_obj = _track_from_fixture(FIXTURES_DIR / "podcast_episodes" / "basic.json") + assert track_obj is not None + track_obj.cover_uri = None + track_obj.og_image = None + episode = parse_podcast_episode( + cast("YandexMusicProvider", provider_stub), track_obj, podcast, position=1 + ) + assert episode.metadata.images is not None + assert episode.metadata.images == podcast.metadata.images + # Must be a separate list — mutating one shouldn't affect the other. + assert episode.metadata.images is not podcast.metadata.images diff --git a/tests/providers/yandex_music/test_recommendations.py b/tests/providers/yandex_music/test_recommendations.py index 1f09250aa7..e375f5864d 100644 --- a/tests/providers/yandex_music/test_recommendations.py +++ b/tests/providers/yandex_music/test_recommendations.py @@ -7,7 +7,13 @@ import pytest from music_assistant_models.errors import InvalidDataError -from music_assistant_models.media_items import Album, Playlist, RecommendationFolder, Track +from music_assistant_models.media_items import ( + Album, + Artist, + Playlist, + RecommendationFolder, + Track, +) from music_assistant.providers.yandex_music.constants import ( BROWSE_NAMES_EN, @@ -15,7 +21,7 @@ RADIO_TRACK_ID_SEP, ROTOR_STATION_MY_WAVE, ) -from music_assistant.providers.yandex_music.provider import YandexMusicProvider +from music_assistant.providers.yandex_music.provider import YandexMusicProvider, _WaveState @pytest.fixture @@ -48,18 +54,26 @@ def provider_mock() -> Mock: return provider +def _install_wave_state(provider_mock: Mock) -> _WaveState: + """Stub _get_wave_state to return a fresh in-memory _WaveState per provider_mock.""" + wave = _WaveState() + provider_mock._get_wave_state = Mock(return_value=wave) + return wave + + @pytest.mark.asyncio async def test_get_my_wave_recommendations_success(provider_mock: Mock) -> None: - """Test _get_my_wave_recommendations returns data when API provides tracks.""" - # Create mock track with required attributes + """Test _get_my_wave_recommendations returns data when session API provides tracks.""" + _install_wave_state(provider_mock) mock_track = Mock() mock_track.id = "12345" mock_track.track_id = "12345" - # Mock get_my_wave_tracks to return tracks - provider_mock.client.get_my_wave_tracks = AsyncMock(return_value=([mock_track], None)) + # Mock the session-API helper; return the same track every time — matches + # the old single-track-per-batch test intent where the fake rotor returns + # the same shape across repeated batch calls. + provider_mock._fetch_rotor_session_batch = AsyncMock(return_value=([mock_track], "batch_a")) - # Mock _parse_my_wave_track to return a Track object with composite item_id mock_parsed_track = Mock(spec=Track) mock_parsed_track.item_id = f"12345{RADIO_TRACK_ID_SEP}{ROTOR_STATION_MY_WAVE}" mock_parsed_track.name = "Test Track" @@ -79,8 +93,9 @@ async def test_get_my_wave_recommendations_success(provider_mock: Mock) -> None: @pytest.mark.asyncio async def test_get_my_wave_recommendations_empty(provider_mock: Mock) -> None: - """Test _get_my_wave_recommendations returns None when API returns no tracks.""" - provider_mock.client.get_my_wave_tracks = AsyncMock(return_value=([], None)) + """Test _get_my_wave_recommendations returns None when session API yields no tracks.""" + _install_wave_state(provider_mock) + provider_mock._fetch_rotor_session_batch = AsyncMock(return_value=([], None)) result = await YandexMusicProvider._get_my_wave_recommendations(provider_mock) @@ -89,8 +104,8 @@ async def test_get_my_wave_recommendations_empty(provider_mock: Mock) -> None: @pytest.mark.asyncio async def test_get_my_wave_recommendations_duplicate_filtering(provider_mock: Mock) -> None: - """Test _get_my_wave_recommendations filters duplicate tracks.""" - # Create mock tracks with same ID + """Test _get_my_wave_recommendations filters duplicate tracks across batches.""" + _install_wave_state(provider_mock) mock_track1 = Mock() mock_track1.id = "12345" mock_track1.track_id = "12345" @@ -99,11 +114,11 @@ async def test_get_my_wave_recommendations_duplicate_filtering(provider_mock: Mo mock_track2.id = "12345" # Same ID mock_track2.track_id = "12345" - # First call returns track1, second call returns track2 (duplicate) - provider_mock.client.get_my_wave_tracks = AsyncMock( + # First batch returns track1, second batch returns track2 (duplicate) + provider_mock._fetch_rotor_session_batch = AsyncMock( side_effect=[ - ([mock_track1], None), - ([mock_track2], None), + ([mock_track1], "batch_a"), + ([mock_track2], "batch_b"), ] ) @@ -124,12 +139,13 @@ async def test_get_my_wave_recommendations_duplicate_filtering(provider_mock: Mo @pytest.mark.asyncio async def test_get_my_wave_recommendations_invalid_data_error(provider_mock: Mock) -> None: - """Test _get_my_wave_recommendations handles InvalidDataError gracefully.""" + """Test _get_my_wave_recommendations handles parse failures gracefully.""" + _install_wave_state(provider_mock) mock_track = Mock() mock_track.id = "12345" mock_track.track_id = "12345" - provider_mock.client.get_my_wave_tracks = AsyncMock(return_value=([mock_track], None)) + provider_mock._fetch_rotor_session_batch = AsyncMock(return_value=([mock_track], "batch_a")) # _parse_my_wave_track returns None (simulates parse error handled internally) provider_mock._parse_my_wave_track = Mock(return_value=None) @@ -851,3 +867,46 @@ async def return_no_tag(_category: str) -> None: result = await YandexMusicProvider.recommendations(provider_mock) assert result == [] + + +@pytest.mark.asyncio +async def test_get_similar_artists_returns_parsed(provider_mock: Mock) -> None: + """get_similar_artists parses each artist from the underlying client.""" + yandex_artists = [Mock(), Mock(), Mock()] + provider_mock.client.get_similar_artists = AsyncMock(return_value=yandex_artists) + + parsed = [Mock(spec=Artist) for _ in yandex_artists] + with patch( + "music_assistant.providers.yandex_music.provider.parse_artist", + side_effect=parsed, + ): + result = await YandexMusicProvider.get_similar_artists(provider_mock, "42", limit=10) + + provider_mock.client.get_similar_artists.assert_awaited_once_with("42", limit=10) + assert result == parsed + + +@pytest.mark.asyncio +async def test_get_similar_artists_skips_invalid(provider_mock: Mock) -> None: + """get_similar_artists skips artists that fail to parse.""" + yandex_artists = [Mock(), Mock()] + provider_mock.client.get_similar_artists = AsyncMock(return_value=yandex_artists) + + parsed_ok = Mock(spec=Artist) + with patch( + "music_assistant.providers.yandex_music.provider.parse_artist", + side_effect=[InvalidDataError("missing id"), parsed_ok], + ): + result = await YandexMusicProvider.get_similar_artists(provider_mock, "99") + + assert result == [parsed_ok] + + +@pytest.mark.asyncio +async def test_get_similar_artists_empty(provider_mock: Mock) -> None: + """get_similar_artists returns [] when client returns no artists.""" + provider_mock.client.get_similar_artists = AsyncMock(return_value=[]) + + result = await YandexMusicProvider.get_similar_artists(provider_mock, "42") + + assert result == [] diff --git a/tests/providers/yandex_music/test_search_audiobooks.py b/tests/providers/yandex_music/test_search_audiobooks.py new file mode 100644 index 0000000000..5d817394b7 --- /dev/null +++ b/tests/providers/yandex_music/test_search_audiobooks.py @@ -0,0 +1,157 @@ +"""Tests for audiobook search routing.""" + +from __future__ import annotations + +from unittest.mock import AsyncMock, Mock + +import pytest +from music_assistant_models.enums import MediaType + +from music_assistant.providers.yandex_music.provider import YandexMusicProvider + + +def _fake_album(*, album_id: int, title: str, meta_type: str | None, type_: str | None) -> Mock: + """Minimal Yandex Album stand-in sufficient for classify_album + parse_audiobook.""" + album = Mock() + album.id = album_id + album.title = title + album.version = None + album.available = True + album.meta_type = meta_type + album.type = type_ + album.artists = [] + album.labels = [] + album.description = None + album.short_description = None + album.content_warning = None + album.genre = None + album.release_date = None + album.cover_uri = None + album.og_image = None + album.listening_finished = None + album.track_count = None + return album + + +def _fake_search_result(albums: list[Mock]) -> Mock: + """Build a search-result stub with an `albums.results` list and empty siblings.""" + result = Mock() + result.tracks = None + result.artists = None + result.playlists = None + result.podcasts = None + result.albums = Mock() + result.albums.results = albums + return result + + +@pytest.fixture +def provider_mock() -> Mock: + """Return a provider mock with a stubbed api_client.search.""" + provider = Mock(spec=YandexMusicProvider) + provider.domain = "yandex_music" + provider.instance_id = "yandex_music_instance" + provider.logger = Mock() + provider.client = AsyncMock() + # @use_cache decorator reads self.mass.cache — stub returning None (cache miss) + provider.mass = Mock() + provider.mass.cache = AsyncMock() + provider.mass.cache.get = AsyncMock(return_value=None) + provider.mass.cache.set = AsyncMock() + return provider + + +@pytest.mark.asyncio +async def test_search_audiobook_only_filters_albums(provider_mock: Mock) -> None: + """Requesting AUDIOBOOK only routes audiobook albums and drops music ones.""" + music = _fake_album(album_id=1, title="Plain Music", meta_type="music", type_="music") + book = _fake_album(album_id=2, title="Cool Book", meta_type="podcast", type_="audiobook") + provider_mock.client.search = AsyncMock(return_value=_fake_search_result([music, book])) + + result = await YandexMusicProvider.search( + provider_mock, "query", [MediaType.AUDIOBOOK], limit=5 + ) + + # Single Yandex API call with type_="album" + provider_mock.client.search.assert_awaited_once() + assert provider_mock.client.search.await_args.kwargs["search_type"] == "album" + + assert [a.item_id for a in result.audiobooks] == ["2"] + assert list(result.albums) == [] + + +@pytest.mark.asyncio +async def test_search_album_and_audiobook_split(provider_mock: Mock) -> None: + """Requesting both ALBUM and AUDIOBOOK splits the albums bucket cleanly.""" + music = _fake_album(album_id=10, title="Music", meta_type="music", type_="music") + book = _fake_album(album_id=20, title="Book", meta_type="podcast", type_="audiobook") + podcast = _fake_album(album_id=30, title="Podcast", meta_type="podcast", type_="podcast") + provider_mock.client.search = AsyncMock( + return_value=_fake_search_result([music, book, podcast]) + ) + + result = await YandexMusicProvider.search( + provider_mock, "q", [MediaType.ALBUM, MediaType.AUDIOBOOK], limit=5 + ) + + assert [a.item_id for a in result.albums] == ["10"] + assert [a.item_id for a in result.audiobooks] == ["20"] + + +@pytest.mark.asyncio +async def test_search_audiobook_not_dropped_by_limit_when_music_dominates( + provider_mock: Mock, +) -> None: + """Limit applied per bucket after classification, not before. + + Audiobooks tail-listed by Yandex must still appear when top ``limit`` + results are music albums. + """ + music_albums = [ + _fake_album(album_id=i, title=f"Music {i}", meta_type="music", type_="music") + for i in range(5) + ] + tail_audiobook = _fake_album( + album_id=99, title="Tail Book", meta_type="podcast", type_="audiobook" + ) + provider_mock.client.search = AsyncMock( + return_value=_fake_search_result([*music_albums, tail_audiobook]) + ) + + result = await YandexMusicProvider.search(provider_mock, "q", [MediaType.AUDIOBOOK], limit=3) + + # Even with only 3 results requested and 5 music albums ahead of it, + # the audiobook tail entry still lands in the audiobooks bucket. + assert [a.item_id for a in result.audiobooks] == ["99"] + + +@pytest.mark.asyncio +async def test_search_album_bucket_respects_limit_independently( + provider_mock: Mock, +) -> None: + """Albums bucket is capped at ``limit`` regardless of audiobook count.""" + albums = [ + _fake_album(album_id=i, title=f"M{i}", meta_type="music", type_="music") for i in range(10) + ] + provider_mock.client.search = AsyncMock(return_value=_fake_search_result(albums)) + + result = await YandexMusicProvider.search( + provider_mock, "q", [MediaType.ALBUM, MediaType.AUDIOBOOK], limit=3 + ) + + assert len(result.albums) == 3 + assert list(result.audiobooks) == [] + + +@pytest.mark.asyncio +async def test_search_albums_type_mapping_dedupe(provider_mock: Mock) -> None: + """ALBUM + AUDIOBOOK both map to Yandex 'album' — dedup keeps a single call type.""" + provider_mock.client.search = AsyncMock(return_value=_fake_search_result([])) + + await YandexMusicProvider.search( + provider_mock, "q", [MediaType.ALBUM, MediaType.AUDIOBOOK], limit=3 + ) + + provider_mock.client.search.assert_awaited_once() + # both map to "album"; with dedup there's a single requested_type → search_type='album' + assert provider_mock.client.search.await_args.kwargs["search_type"] == "album" diff --git a/tests/providers/yandex_music/test_streaming.py b/tests/providers/yandex_music/test_streaming.py index 1b72f869d9..814209bdf9 100644 --- a/tests/providers/yandex_music/test_streaming.py +++ b/tests/providers/yandex_music/test_streaming.py @@ -92,15 +92,19 @@ def test_select_best_quality_balanced_falls_back_to_highest( assert result.bitrate_in_kbps == 320 -def test_select_best_quality_label_lossless_flac_returns_flac( +def test_select_best_quality_legacy_lossless_alias_returns_flac( streaming_manager: YandexMusicStreamingManager, ) -> None: - """When preferred_quality is UI label 'Lossless (FLAC)', FLAC is selected.""" + """Legacy stored value 'lossless' (pre-Superb rename) still maps to FLAC. + + Current UI writes ``superb``; older configs may still hold the literal + ``lossless`` string. The selector must treat the two as synonyms. + """ mp3 = _make_download_info("mp3", 320, "https://example.com/track.mp3") flac = _make_download_info("flac", 0, "https://example.com/track.flac") download_infos = [mp3, flac] - result = streaming_manager._select_best_quality(download_infos, "Lossless (FLAC)") + result = streaming_manager._select_best_quality(download_infos, "lossless") assert result is not None assert result.codec == "flac" @@ -143,12 +147,12 @@ def test_select_best_quality_none_preferred_returns_highest_bitrate( assert result.bitrate_in_kbps == 320 -def test_get_content_type_flac_mp4_returns_mp4_container_with_flac_codec( +def test_get_content_type_flac_mp4_returns_flac_with_flac_codec( streaming_manager: YandexMusicStreamingManager, ) -> None: - """flac-mp4 codec from get-file-info is mapped to MP4 container with FLAC codec.""" - assert streaming_manager._get_content_type("flac-mp4") == (ContentType.MP4, ContentType.FLAC) - assert streaming_manager._get_content_type("FLAC-MP4") == (ContentType.MP4, ContentType.FLAC) + """flac-mp4 codec: content_type=FLAC (lossless), codec_type=FLAC (ffmpeg decoder).""" + assert streaming_manager._get_content_type("flac-mp4") == (ContentType.FLAC, ContentType.FLAC) + assert streaming_manager._get_content_type("FLAC-MP4") == (ContentType.FLAC, ContentType.FLAC) def test_get_content_type_flac_returns_flac_container_with_unknown_codec( @@ -168,11 +172,11 @@ def test_get_content_type_aac_variants_return_aac( assert streaming_manager._get_content_type("AAC") == (ContentType.AAC, ContentType.UNKNOWN) assert streaming_manager._get_content_type("he-aac") == (ContentType.AAC, ContentType.UNKNOWN) assert streaming_manager._get_content_type("HE-AAC") == (ContentType.AAC, ContentType.UNKNOWN) - # MP4 container variants - assert streaming_manager._get_content_type("aac-mp4") == (ContentType.MP4, ContentType.AAC) - assert streaming_manager._get_content_type("AAC-MP4") == (ContentType.MP4, ContentType.AAC) - assert streaming_manager._get_content_type("he-aac-mp4") == (ContentType.MP4, ContentType.AAC) - assert streaming_manager._get_content_type("HE-AAC-MP4") == (ContentType.MP4, ContentType.AAC) + # MP4 container variants — content_type=AAC (audio codec), codec_type=AAC (ffmpeg decoder) + assert streaming_manager._get_content_type("aac-mp4") == (ContentType.AAC, ContentType.AAC) + assert streaming_manager._get_content_type("AAC-MP4") == (ContentType.AAC, ContentType.AAC) + assert streaming_manager._get_content_type("he-aac-mp4") == (ContentType.AAC, ContentType.AAC) + assert streaming_manager._get_content_type("HE-AAC-MP4") == (ContentType.AAC, ContentType.AAC) # --- Efficient quality tests --- @@ -291,42 +295,87 @@ def test_select_best_quality_high_only_flac_returns_flac( assert result.codec == "flac" -# --- Audio params tests --- +# --- _build_audio_format tests --- -def test_get_audio_params_flac_mp4( +def test_build_audio_format_passes_api_params( streaming_manager: YandexMusicStreamingManager, ) -> None: - """flac-mp4 returns 48kHz/24bit.""" - assert streaming_manager._get_audio_params("flac-mp4") == (48000, 24) + """_build_audio_format forwards API-provided params to AudioFormat.""" + fmt = streaming_manager._build_audio_format( + "flac-mp4", + bit_rate=0, + sample_rate=48000, + bit_depth=24, + ) + assert fmt.content_type == ContentType.FLAC + assert fmt.sample_rate == 48000 + assert fmt.bit_depth == 24 -def test_get_audio_params_flac_mp4_case_insensitive( +def test_build_audio_format_keeps_defaults_when_zero( streaming_manager: YandexMusicStreamingManager, ) -> None: - """flac-mp4 matching is case-insensitive.""" - assert streaming_manager._get_audio_params("FLAC-MP4") == (48000, 24) + """Without explicit params, AudioFormat keeps its defaults (44100/16).""" + fmt = streaming_manager._build_audio_format("mp3") + assert fmt.content_type == ContentType.MP3 + assert fmt.sample_rate == 44100 + assert fmt.bit_depth == 16 -def test_get_audio_params_flac( - streaming_manager: YandexMusicStreamingManager, -) -> None: - """Plain FLAC returns CD-quality defaults.""" - assert streaming_manager._get_audio_params("flac") == (44100, 16) +# --- Container probe parser tests --- -def test_get_audio_params_mp3( +def test_parse_flac_streaminfo_valid( streaming_manager: YandexMusicStreamingManager, ) -> None: - """MP3 returns CD-quality defaults.""" - assert streaming_manager._get_audio_params("mp3") == (44100, 16) + """Parse real FLAC STREAMINFO: 48kHz, 24-bit.""" + # Build a minimal FLAC header: magic + block header + 34-byte STREAMINFO + # STREAMINFO bytes 10-13: sample_rate(20) | channels(3) | bps(5) | total(36 high bits) + # 48000 Hz = 0xBB80, 24-bit = 23 (stored as bps-1), stereo = 1 (channels-1) + # bits: 00001011101110000000 001 10111 0000... + # = 0x0BB80 << 12 | 0x1 << 9 | 23 << 4 | 0x0 = 0x0BB80BE0 ... but let's compute: + sr = 48000 + channels_minus1 = 1 # stereo + bps_minus1 = 23 # 24-bit + val = (sr << 12) | (channels_minus1 << 9) | (bps_minus1 << 4) + # Build 34-byte STREAMINFO payload + payload = bytearray(34) + payload[10:14] = val.to_bytes(4, "big") + # Full header: "fLaC" + block header (type=0, length=34) + payload + block_header = b"\x80" + (34).to_bytes(3, "big") # last-metadata-block flag + length + header = b"fLaC" + block_header + bytes(payload) + + result = streaming_manager._parse_flac_streaminfo(header) + assert result == (48000, 24) + + +def test_parse_flac_streaminfo_invalid( + streaming_manager: YandexMusicStreamingManager, +) -> None: + """Non-FLAC data returns (0, 0).""" + assert streaming_manager._parse_flac_streaminfo(b"not flac data") == (0, 0) + assert streaming_manager._parse_flac_streaminfo(b"") == (0, 0) -def test_get_audio_params_none( +def test_parse_mp4_dfla_box( streaming_manager: YandexMusicStreamingManager, ) -> None: - """None codec returns CD-quality defaults.""" - assert streaming_manager._get_audio_params(None) == (44100, 16) + """Parse dfLa box (FLAC-in-MP4) with STREAMINFO inside.""" + sr = 48000 + bps_minus1 = 23 + val = (sr << 12) | (1 << 9) | (bps_minus1 << 4) + streaminfo = bytearray(34) + streaminfo[10:14] = val.to_bytes(4, "big") + # dfLa box: size(4) + "dfLa" + version/flags(4) + block_header(4) + STREAMINFO(34) + block_header = b"\x80\x00\x00\x22" # type=0 (last), length=34 + box_size = (4 + 4 + 4 + 4 + 34).to_bytes(4, "big") + dfla_box = box_size + b"dfLa" + b"\x00\x00\x00\x00" + block_header + bytes(streaminfo) + # Wrap in some padding to simulate real MP4 structure + header = b"\x00" * 100 + dfla_box + b"\x00" * 100 + + result = streaming_manager._parse_mp4_audio_params(header) + assert result == (48000, 24) # --- get_audio_stream tests --- @@ -340,12 +389,15 @@ def _make_encrypted_stream_details( return StreamDetails( item_id="test_track_123", provider="yandex_music_instance", - audio_format=AudioFormat(content_type=ContentType.MP4), + audio_format=AudioFormat(content_type=ContentType.FLAC), stream_type=StreamType.CUSTOM, data={ - "encrypted_url": url, + "url": url, "decryption_key": key_hex, "codec": "flac-mp4", + "transport": "encraw", + "fi_quality": "lossless", + "fi_codecs": "flac-mp4,flac,aac-mp4,aac,he-aac,mp3,he-aac-mp4", }, ) @@ -450,7 +502,7 @@ async def test_get_audio_stream_http_error_raises_media_not_found( _MockResponse([], error=RuntimeError("403 Forbidden")) ) - with pytest.raises(MediaNotFoundError, match="Failed to fetch encrypted stream"): + with pytest.raises(MediaNotFoundError, match="Failed to fetch stream"): async for _ in streaming_manager.get_audio_stream(sd): pass @@ -538,9 +590,9 @@ def _get(_url: str, **_kwargs: object) -> _MockResponse: streaming_provider_stub.mass.http_session = unittest.mock.MagicMock() streaming_provider_stub.mass.http_session.get = _get - # Mock get_track_file_info_lossless to return a fresh URL + # Mock get_track_file_info to return a fresh URL streaming_provider_stub.client = unittest.mock.AsyncMock() - streaming_provider_stub.client.get_track_file_info_lossless = unittest.mock.AsyncMock( + streaming_provider_stub.client.get_track_file_info = unittest.mock.AsyncMock( return_value={"url": fresh_url, "codec": "flac-mp4", "key": key.hex()} ) streaming_manager.client = streaming_provider_stub.client @@ -553,8 +605,11 @@ def _get(_url: str, **_kwargs: object) -> _MockResponse: result += chunk assert result == plaintext - streaming_provider_stub.client.get_track_file_info_lossless.assert_called_once_with( - "test_track_123" + streaming_provider_stub.client.get_track_file_info.assert_called_once_with( + "test_track_123", + quality="lossless", + codecs="flac-mp4,flac,aac-mp4,aac,he-aac,mp3,he-aac-mp4", + transport="encraw", ) @@ -567,7 +622,7 @@ async def test_get_audio_stream_raises_after_all_retries_on_410( sd = _make_encrypted_stream_details(key.hex()) streaming_provider_stub.mass.http_session = _MockHttpSession(_MockResponse([], status=410)) streaming_provider_stub.client = unittest.mock.AsyncMock() - streaming_provider_stub.client.get_track_file_info_lossless = unittest.mock.AsyncMock( + streaming_provider_stub.client.get_track_file_info = unittest.mock.AsyncMock( return_value={"url": "https://cdn.example.com/still-expired.flac", "key": key.hex()} ) streaming_manager.client = streaming_provider_stub.client @@ -676,15 +731,13 @@ async def test_get_audio_stream_fails_immediately_when_url_refresh_returns_nothi streaming_manager: YandexMusicStreamingManager, streaming_provider_stub: StreamingProviderStub, ) -> None: - """If get_track_file_info_lossless returns no URL, stream fails without wasting retries.""" + """If get_track_file_info returns no URL, stream fails without wasting retries.""" key = b"\x88" * 32 sd = _make_encrypted_stream_details(key.hex()) streaming_provider_stub.mass.http_session = _MockHttpSession(_MockResponse([], status=410)) streaming_provider_stub.client = unittest.mock.AsyncMock() # Simulate API returning no usable URL (None result) - streaming_provider_stub.client.get_track_file_info_lossless = unittest.mock.AsyncMock( - return_value=None - ) + streaming_provider_stub.client.get_track_file_info = unittest.mock.AsyncMock(return_value=None) streaming_manager.client = streaming_provider_stub.client with ( @@ -695,7 +748,7 @@ async def test_get_audio_stream_fails_immediately_when_url_refresh_returns_nothi pass # Should have given up after attempt 0 (refresh returned None → no stale URL reuse) - assert streaming_provider_stub.client.get_track_file_info_lossless.call_count == 1 + assert streaming_provider_stub.client.get_track_file_info.call_count == 1 async def test_get_audio_stream_exact_window_boundary( @@ -738,3 +791,282 @@ async def test_get_audio_stream_exact_window_boundary( assert result == plaintext assert len(session.calls) == 1, "second window must not be requested when EOF is detected" + + +async def test_get_audio_stream_continues_after_non_block_boundary_drop( + streaming_manager: YandexMusicStreamingManager, + streaming_provider_stub: StreamingProviderStub, +) -> None: + """TCP drop at a non-AES-block boundary must not cause premature EOF on reconnect. + + Scenario (patched window = 32 bytes = 2 AES blocks, 50-byte file): + - Window 1 (bytes=0-31) drops at byte 17 (not on a 16-byte AES boundary). + - Reconnect re-requests from block_start=16; server returns full 32 bytes. + - Old bug: window_got = 31 < _RANGE_WINDOW = 32 → stream terminates at byte 48, + losing the final 2 bytes of the file. + - Fixed: received = window_got + block_skip = 31 + 1 = 32 = _RANGE_WINDOW + → stream continues to window 2, which delivers the remaining 2 bytes. + """ + small_window = 32 # 2 AES blocks + key = b"\xcc" * 32 + plaintext = b"X" * 50 # 50 bytes → two windows (32 + 2 remaining) + + nonce_16 = bytes(16) + encryptor = Cipher(algorithms.AES(key), modes.CTR(nonce_16)).encryptor() + ciphertext = encryptor.update(plaintext) + encryptor.finalize() + + drop_at = 17 # non-block boundary (17 % 16 != 0) + + # Window 1: bytes=0-31, drops after delivering 17 bytes + resp1 = _MockResponse([ciphertext[:drop_at]], drop_payload_error=True) + # Reconnect: block_start=16, requests bytes=16-47, server returns full 32 bytes + resp2 = _MockResponse([ciphertext[16:48]], status=206) + # Window 2: bytes=48-79, only 2 bytes remain in the file + resp3 = _MockResponse([ciphertext[48:50]], status=206) + + session = _MultiCallHttpSession([resp1, resp2, resp3]) + streaming_provider_stub.mass.http_session = session + + result = b"" + with ( + unittest.mock.patch.object(_streaming_mod, "_RANGE_WINDOW", small_window), + unittest.mock.patch("asyncio.sleep"), + ): + async for chunk in streaming_manager.get_audio_stream( + _make_encrypted_stream_details(key.hex()) + ): + result += chunk + + assert result == plaintext, f"Expected {len(plaintext)} bytes, got {len(result)}" + assert len(session.calls) == 3 + assert session.calls[0]["headers"] == {"Range": "bytes=0-31"} + assert session.calls[1]["headers"] == {"Range": "bytes=16-47"} # AES-aligned reconnect + assert session.calls[2]["headers"] == {"Range": "bytes=48-79"} # second window + + +# --- Raw (unencrypted) windowed streaming tests --- + + +def _make_raw_stream_details( + url: str = "https://cdn.example.com/track.flac", + codec: str = "flac-mp4", + bit_rate: int = 0, +) -> StreamDetails: + """Build StreamDetails for raw (unencrypted) windowed stream tests.""" + return StreamDetails( + item_id="test_track_123", + provider="yandex_music_instance", + audio_format=AudioFormat(content_type=ContentType.FLAC), + stream_type=StreamType.CUSTOM, + data={ + "url": url, + "codec": codec, + "transport": "raw", + "bit_rate": bit_rate, + "fi_quality": "lossless", + "fi_codecs": "flac-mp4,flac,aac-mp4,aac,he-aac,mp3,he-aac-mp4", + }, + ) + + +async def test_get_audio_stream_raw_single_window( + streaming_manager: YandexMusicStreamingManager, + streaming_provider_stub: StreamingProviderStub, +) -> None: + """Raw stream smaller than _RANGE_WINDOW is fetched in one request.""" + plaintext = b"Hello raw FLAC data!" * 50 # 1000 bytes + sd = _make_raw_stream_details() + streaming_provider_stub.mass.http_session = _MockHttpSession(_MockResponse([plaintext])) + + result = b"" + async for chunk in streaming_manager.get_audio_stream(sd): + result += chunk + + assert result == plaintext + + +async def test_get_audio_stream_raw_multi_window( + streaming_manager: YandexMusicStreamingManager, + streaming_provider_stub: StreamingProviderStub, +) -> None: + """Raw stream larger than _RANGE_WINDOW uses multiple windowed requests.""" + small_window = 32 + plaintext = b"A" * 50 # 50 bytes → two windows (32 + 18) + + resp1 = _MockResponse([plaintext[:small_window]], status=206) + resp2 = _MockResponse([plaintext[small_window:]], status=206) + session = _MultiCallHttpSession([resp1, resp2]) + streaming_provider_stub.mass.http_session = session + + result = b"" + with unittest.mock.patch.object(_streaming_mod, "_RANGE_WINDOW", small_window): + async for chunk in streaming_manager.get_audio_stream(_make_raw_stream_details()): + result += chunk + + assert result == plaintext + assert len(session.calls) == 2 + assert session.calls[0]["headers"] == {"Range": "bytes=0-31"} + assert session.calls[1]["headers"] == {"Range": "bytes=32-63"} + + +async def test_get_audio_stream_raw_retry_on_drop( + streaming_manager: YandexMusicStreamingManager, + streaming_provider_stub: StreamingProviderStub, +) -> None: + """Raw stream reconnects with correct Range header after TCP drop.""" + plaintext = b"B" * 96 + drop_at = 48 + + first_resp = _MockResponse([plaintext[:drop_at]], drop_payload_error=True) + second_resp = _MockResponse([plaintext[drop_at:]], status=206) + session = _MultiCallHttpSession([first_resp, second_resp]) + streaming_provider_stub.mass.http_session = session + + result = b"" + with unittest.mock.patch("asyncio.sleep"): + async for chunk in streaming_manager.get_audio_stream(_make_raw_stream_details()): + result += chunk + + assert result == plaintext + assert len(session.calls) == 2 + # Raw uses exact byte offset (no AES block alignment) + assert session.calls[1]["headers"] == {"Range": f"bytes={drop_at}-{drop_at + 4194304 - 1}"} + + +async def test_get_audio_stream_raw_url_refresh_on_403( + streaming_manager: YandexMusicStreamingManager, + streaming_provider_stub: StreamingProviderStub, +) -> None: + """Raw stream refreshes URL on 403 and continues.""" + plaintext = b"C" * 64 + fresh_url = "https://cdn.example.com/refreshed-track.flac" + + expired_resp = _MockResponse([], status=403) + fresh_resp = _MockResponse([plaintext]) + + call_count = 0 + + def _get(_url: str, **_kwargs: object) -> _MockResponse: + nonlocal call_count + call_count += 1 + return expired_resp if call_count == 1 else fresh_resp + + streaming_provider_stub.mass.http_session = unittest.mock.MagicMock() + streaming_provider_stub.mass.http_session.get = _get + + streaming_provider_stub.client = unittest.mock.AsyncMock() + streaming_provider_stub.client.get_track_file_info = unittest.mock.AsyncMock( + return_value={"url": fresh_url, "codec": "flac-mp4"} + ) + streaming_manager.client = streaming_provider_stub.client + + result = b"" + with unittest.mock.patch("asyncio.sleep"): + async for chunk in streaming_manager.get_audio_stream(_make_raw_stream_details()): + result += chunk + + assert result == plaintext + streaming_provider_stub.client.get_track_file_info.assert_called_once_with( + "test_track_123", + quality="lossless", + codecs="flac-mp4,flac,aac-mp4,aac,he-aac,mp3,he-aac-mp4", + transport="raw", + ) + + +async def test_get_audio_stream_raw_resets_on_range_ignored( + streaming_manager: YandexMusicStreamingManager, + streaming_provider_stub: StreamingProviderStub, +) -> None: + """If server returns 200 instead of 206 after raw reconnect, skip already-delivered bytes.""" + small_window = 32 + plaintext = b"G" * 96 # 96 bytes + + drop_at = 48 + # First response drops after 48 bytes + first_resp = _MockResponse([plaintext[:drop_at]], drop_payload_error=True) + # Second response ignores Range and returns full file with 200 + second_resp = _MockResponse([plaintext], status=200) + session = _MultiCallHttpSession([first_resp, second_resp]) + streaming_provider_stub.mass.http_session = session + + result = b"" + with ( + unittest.mock.patch.object(_streaming_mod, "_RANGE_WINDOW", small_window), + unittest.mock.patch("asyncio.sleep"), + ): + async for chunk in streaming_manager.get_audio_stream(_make_raw_stream_details()): + result += chunk + + # Should get the full plaintext without duplication + assert result == plaintext + + +async def test_get_audio_stream_raw_seek_starts_from_byte_offset( + streaming_manager: YandexMusicStreamingManager, + streaming_provider_stub: StreamingProviderStub, +) -> None: + """Raw stream with seek_position starts Range requests from calculated byte offset.""" + # 320 kbps = 40000 bytes/sec; seek to 10s → offset 400000 + bit_rate = 320 + seek_seconds = 10 + expected_offset = int(seek_seconds * bit_rate * 1000 / 8) # 400000 + + plaintext = b"S" * 64 + resp = _MockResponse([plaintext], status=206) + session = _MultiCallHttpSession([resp]) + streaming_provider_stub.mass.http_session = session + + sd = _make_raw_stream_details(bit_rate=bit_rate) + result = b"" + async for chunk in streaming_manager.get_audio_stream(sd, seek_position=seek_seconds): + result += chunk + + assert result == plaintext + assert len(session.calls) == 1 + range_header = session.calls[0]["headers"]["Range"] + assert range_header.startswith(f"bytes={expected_offset}-") + + +async def test_get_audio_stream_raw_seek_zero_bitrate_starts_from_zero( + streaming_manager: YandexMusicStreamingManager, + streaming_provider_stub: StreamingProviderStub, +) -> None: + """When bit_rate is 0, seek_position is ignored and stream starts from byte 0.""" + plaintext = b"Z" * 64 + resp = _MockResponse([plaintext]) + session = _MultiCallHttpSession([resp]) + streaming_provider_stub.mass.http_session = session + + sd = _make_raw_stream_details(bit_rate=0) + result = b"" + async for chunk in streaming_manager.get_audio_stream(sd, seek_position=30): + result += chunk + + assert result == plaintext + assert session.calls[0]["headers"]["Range"].startswith("bytes=0-") + + +async def test_get_audio_stream_encrypted_ignores_seek_position( + streaming_manager: YandexMusicStreamingManager, + streaming_provider_stub: StreamingProviderStub, +) -> None: + """Encrypted stream always starts from byte 0 regardless of seek_position.""" + key = b"\x01" * 16 + key_hex = key.hex() + plaintext = b"E" * 64 + cipher = Cipher(algorithms.AES(key), modes.CTR(b"\x00" * 16)) + encryptor = cipher.encryptor() + ciphertext = encryptor.update(plaintext) + encryptor.finalize() + + resp = _MockResponse([ciphertext]) + session = _MultiCallHttpSession([resp]) + streaming_provider_stub.mass.http_session = session + + sd = _make_encrypted_stream_details(key_hex) + result = b"" + async for chunk in streaming_manager.get_audio_stream(sd, seek_position=30): + result += chunk + + assert result == plaintext + assert session.calls[0]["headers"]["Range"].startswith("bytes=0-")