diff --git a/music_assistant/providers/twitch/__init__.py b/music_assistant/providers/twitch/__init__.py new file mode 100644 index 0000000000..27955936a2 --- /dev/null +++ b/music_assistant/providers/twitch/__init__.py @@ -0,0 +1,946 @@ +"""Twitch Audio music provider for Music Assistant.""" + +from __future__ import annotations + +import asyncio +import logging +import time +from collections.abc import AsyncGenerator, Sequence +from typing import TYPE_CHECKING, Any +from urllib.parse import urlencode + +from music_assistant_models.config_entries import ConfigEntry +from music_assistant_models.enums import ( + ConfigEntryType, + ContentType, + ImageType, + MediaType, + PlaybackState, + ProviderFeature, + StreamType, +) +from music_assistant_models.errors import ( + LoginFailed, + MediaNotFoundError, + MusicAssistantError, + ProviderUnavailableError, + ResourceTemporarilyUnavailable, +) +from music_assistant_models.media_items import ( + AudioFormat, + BrowseFolder, + MediaItemImage, + MediaItemType, + ProviderMapping, + Radio, + RecommendationFolder, + SearchResults, + UniqueList, +) +from music_assistant_models.streamdetails import StreamDetails +from streamlink import Streamlink # type: ignore[attr-defined] + +import music_assistant.providers.twitch.ad_handling as _ah +from music_assistant.helpers.auth import AuthenticationHelper +from music_assistant.models.music_provider import MusicProvider +from music_assistant.providers.twitch.ad_handling import patch_ad_handling +from music_assistant.providers.twitch.eventsub import EventSubClient + +logger = logging.getLogger(__name__) + +if TYPE_CHECKING: + from music_assistant_models.config_entries import ConfigValueType, ProviderConfig + from music_assistant_models.provider import ProviderManifest + + from music_assistant.mass import MusicAssistant + from music_assistant.models import ProviderInstanceType + +SUPPORTED_FEATURES = { + ProviderFeature.BROWSE, + ProviderFeature.SEARCH, + ProviderFeature.LIBRARY_RADIOS, + 
ProviderFeature.RECOMMENDATIONS, +} + +# Streamlink constants +STREAM_CHUNK_SIZE = 64 * 1024 # 64KB +MAX_CONSECUTIVE_RECONNECTS = 5 +RECONNECT_DELAY = 0.5 # seconds +PREFERRED_QUALITIES = ("audio_only", "worst") +RAID_UNSUBSCRIBE_GRACE = 15.0 # seconds to wait before unsubscribing after last stream ends + +# Cache TTL +LIVE_STATUS_TTL = 300.0 # 5 minutes + +# OAuth / Config constants +CONF_CLIENT_ID = "client_id" +CONF_CLIENT_SECRET = "client_secret" +CONF_STREAMLINK_TOKEN = "streamlink_token" +CONF_ACCESS_TOKEN = "access_token" +CONF_REFRESH_TOKEN = "refresh_token" +CONF_AUTO_RAID = "auto_raid" +CONF_ACTION_AUTH = "auth" +CONF_ACTION_REVOKE = "revoke" + +# Browse paths +BROWSE_LIVE = "live" +BROWSE_FOLLOWING = "following" + +TWITCH_AUTH_URL = "https://id.twitch.tv/oauth2/authorize" +TWITCH_TOKEN_URL = "https://id.twitch.tv/oauth2/token" +TWITCH_REVOKE_URL = "https://id.twitch.tv/oauth2/revoke" +TWITCH_SCOPES = ("user:read:follows",) +CALLBACK_REDIRECT_URL = "https://music-assistant.io/callback" + + +async def _handle_auth_action( + mass: MusicAssistant, + values: dict[str, ConfigValueType], +) -> None: + """Handle OAuth authentication action.""" + if not values: + msg = "No configuration values provided for authentication." + raise LoginFailed(msg) + + client_id = str(values.get(CONF_CLIENT_ID, "")).strip() + client_secret = str(values.get(CONF_CLIENT_SECRET, "")).strip() + if not client_id or not client_secret: + msg = "Client ID and Client Secret are required to authenticate." + raise LoginFailed(msg) + + if "session_id" not in values: + msg = "Session ID is required to authenticate with Twitch." 
+ raise LoginFailed(msg) + session_id = str(values["session_id"]) + + async with AuthenticationHelper(mass, session_id) as auth_helper: + params = { + "client_id": client_id, + "redirect_uri": CALLBACK_REDIRECT_URL, + "response_type": "code", + "scope": " ".join(TWITCH_SCOPES), + "state": auth_helper.callback_url, + } + auth_url = f"{TWITCH_AUTH_URL}?{urlencode(params)}" + result = await auth_helper.authenticate(auth_url) + code = result.get("code", "") + + if not code: + msg = "No authorization code received from Twitch." + raise LoginFailed(msg) + + # Exchange code for tokens + token_params = { + "client_id": client_id, + "client_secret": client_secret, + "code": code, + "grant_type": "authorization_code", + "redirect_uri": CALLBACK_REDIRECT_URL, + } + async with mass.http_session.post(TWITCH_TOKEN_URL, data=token_params) as response: + if response.status != 200: + error_text = await response.text() + msg = f"Failed to exchange authorization code: {error_text}" + raise LoginFailed(msg) + token_data = await response.json() + + values[CONF_ACCESS_TOKEN] = token_data["access_token"] + values[CONF_REFRESH_TOKEN] = token_data.get("refresh_token", "") + + +async def _handle_revoke_action( + mass: MusicAssistant, + values: dict[str, ConfigValueType], +) -> None: + """Handle credential revocation action.""" + access_token = str(values.get(CONF_ACCESS_TOKEN, "")) + client_id = str(values.get(CONF_CLIENT_ID, "")) + + # Best-effort revoke — clear local state even if revoke fails + if access_token: + try: + async with mass.http_session.post( + TWITCH_REVOKE_URL, + data={"client_id": client_id, "token": access_token}, + ): + pass + except Exception: + logger.debug("Failed to revoke Twitch token", exc_info=True) + + values[CONF_ACCESS_TOKEN] = "" + values[CONF_REFRESH_TOKEN] = "" + + +async def setup( + mass: MusicAssistant, manifest: ProviderManifest, config: ProviderConfig +) -> ProviderInstanceType: + """Initialize provider(instance) with given configuration.""" + if not 
config.get_value(CONF_ACCESS_TOKEN): + msg = "Not authenticated. Please configure and authenticate the Twitch provider." + raise LoginFailed(msg) + return TwitchProvider(mass, manifest, config, SUPPORTED_FEATURES) + + +async def get_config_entries( + mass: MusicAssistant, + instance_id: str | None = None, # noqa: ARG001 + action: str | None = None, + values: dict[str, ConfigValueType] | None = None, +) -> tuple[ConfigEntry, ...]: + """Return Config entries to setup this provider.""" + if values is None: + values = {} + + # Handle actions + if action == CONF_ACTION_AUTH: + await _handle_auth_action(mass, values) + elif action == CONF_ACTION_REVOKE: + await _handle_revoke_action(mass, values) + + # Determine auth state + is_authenticated = bool(values.get(CONF_ACCESS_TOKEN)) + + return ( + # Setup instructions + ConfigEntry( + key="setup_info", + type=ConfigEntryType.LABEL, + label="Register a Twitch application at dev.twitch.tv/console/apps. " + f"Use {CALLBACK_REDIRECT_URL} as the OAuth Redirect URL.", + hidden=is_authenticated, + ), + # Credentials + ConfigEntry( + key=CONF_CLIENT_ID, + type=ConfigEntryType.SECURE_STRING, + label="Twitch Client ID", + description="From your Twitch application at dev.twitch.tv/console/apps.", + required=True, + value=values.get(CONF_CLIENT_ID), + ), + ConfigEntry( + key=CONF_CLIENT_SECRET, + type=ConfigEntryType.SECURE_STRING, + label="Twitch Client Secret", + required=True, + value=values.get(CONF_CLIENT_SECRET), + ), + # Auth status + ConfigEntry( + key="auth_status", + type=ConfigEntryType.LABEL, + label="Authenticated" if is_authenticated else "Not authenticated", + ), + # Auth action (hidden when authenticated) + ConfigEntry( + key=CONF_ACTION_AUTH, + type=ConfigEntryType.ACTION, + label="Authenticate with Twitch", + action=CONF_ACTION_AUTH, + action_label="Authenticate", + hidden=is_authenticated, + ), + # Revoke action (hidden when not authenticated) + ConfigEntry( + key=CONF_ACTION_REVOKE, + type=ConfigEntryType.ACTION, + 
label="Revoke credentials", + action=CONF_ACTION_REVOKE, + action_label="Revoke", + hidden=not is_authenticated, + ), + # Token storage (hidden) + ConfigEntry( + key=CONF_ACCESS_TOKEN, + type=ConfigEntryType.SECURE_STRING, + label="Access Token", + hidden=True, + required=False, + value=values.get(CONF_ACCESS_TOKEN, ""), + ), + ConfigEntry( + key=CONF_REFRESH_TOKEN, + type=ConfigEntryType.SECURE_STRING, + label="Refresh Token", + hidden=True, + required=False, + value=values.get(CONF_REFRESH_TOKEN, ""), + ), + # Optional Twitch website token + ConfigEntry( + key=CONF_STREAMLINK_TOKEN, + type=ConfigEntryType.SECURE_STRING, + label="Twitch Website Token (optional)", + description="Your Twitch website auth token. If you have Twitch Turbo " + "or are subscribed to a channel, this reduces ad frequency. " + "See the Streamlink Twitch plugin docs for how to extract this token: " + "https://streamlink.github.io/cli/plugins/twitch.html#authentication", + required=False, + value=values.get(CONF_STREAMLINK_TOKEN), + ), + # Auto-raid toggle + ConfigEntry( + key=CONF_AUTO_RAID, + type=ConfigEntryType.BOOLEAN, + label="Auto-follow raids", + description="Automatically switch to raid target when a streamer raids.", + default_value=True, + value=values.get(CONF_AUTO_RAID), + ), + ) + + +class TwitchProvider(MusicProvider): + """Provider implementation for Twitch audio streaming.""" + + _access_token: str | None = None + _refresh_token: str | None = None + _client_id: str | None = None + _client_secret: str | None = None + _user_id: str | None = None + + # Live status cache + _cached_channels: list[dict[str, Any]] | None = None + _cached_live: dict[str, dict[str, Any]] | None = None + _cached_profiles: dict[str, dict[str, Any]] | None = None + _cache_time: float = 0.0 + + # Raid state + _eventsub: EventSubClient | None = None + _auto_raid: bool = True + _active_streams: dict[str, int] + _unsubscribe_timers: dict[str, asyncio.Task[None]] + + @property + def 
is_streaming_provider(self) -> bool: + """Return True if the provider is a streaming provider.""" + return True + + async def handle_async_init(self) -> None: + """Handle async initialization of the provider.""" + self._client_id = str(self.config.get_value(CONF_CLIENT_ID) or "") + self._client_secret = str(self.config.get_value(CONF_CLIENT_SECRET) or "") + self._access_token = str(self.config.get_value(CONF_ACCESS_TOKEN) or "") or None + self._refresh_token = str(self.config.get_value(CONF_REFRESH_TOKEN) or "") or None + val = self.config.get_value(CONF_AUTO_RAID) + self._auto_raid = bool(val) if val is not None else True + self._active_streams = {} + self._unsubscribe_timers = {} + self.logger.info( + "Twitch provider initialized: auto_raid=%s, authenticated=%s", + self._auto_raid, + self.is_authenticated, + ) + + # Resolve user ID if authenticated + if self._access_token: + try: + data = await self._api_get("/helix/users", params={}) + if data.get("data"): + self._user_id = data["data"][0]["id"] + self.logger.info("Resolved Twitch user ID: %s", self._user_id) + except LoginFailed: + raise # Propagate auth failures — user needs to re-authenticate + except Exception: + self.logger.warning("Failed to resolve user ID during init") + + async def unload(self, is_removed: bool = False) -> None: + """Handle unload/close of the provider.""" + # Cancel pending unsubscribe timers + for timer in self._unsubscribe_timers.values(): + timer.cancel() + self._unsubscribe_timers.clear() + self._active_streams.clear() + + # Stop EventSub + if self._eventsub is not None: + await self._eventsub.stop() + self._eventsub = None + + # Clear cache + self._clear_cache() + + # --- Raid Following --- + + async def _subscribe_raids_for_channel(self, channel_login: str) -> None: + """Ensure EventSub is running and subscribed to raids for a channel.""" + if not self._auto_raid or not self.is_authenticated: + return + + # Ensure EventSub client exists + if self._eventsub is None: + 
self._eventsub = EventSubClient( + http_session=self.mass.http_session, + api_headers_fn=self._api_headers, + ) + await self._eventsub.start( + on_raid=lambda from_l, to_l: asyncio.create_task(self._on_raid(from_l, to_l)) + ) + + # Resolve user ID for the channel and subscribe + users = await self._get_users(logins=[channel_login]) + if users: + self.logger.debug( + "Subscribing to raids for %s (user_id=%s)", channel_login, users[0]["id"] + ) + await self._eventsub.subscribe_raids(users[0]["id"]) + else: + self.logger.warning( + "Could not resolve user ID for %s — no raid subscription", channel_login + ) + + def _track_stream_start(self, channel_login: str) -> None: + """Increment active stream count for a channel. Subscribe to raids on first stream.""" + # Cancel any pending delayed unsubscribe for this channel + timer = self._unsubscribe_timers.pop(channel_login, None) + if timer is not None: + timer.cancel() + + prev_count = self._active_streams.get(channel_login, 0) + self._active_streams[channel_login] = prev_count + 1 + self.logger.debug("Stream started for %s (active: %d)", channel_login, prev_count + 1) + + if prev_count == 0: + asyncio.create_task(self._subscribe_raids_for_channel(channel_login)) + + def _track_stream_end(self, channel_login: str) -> None: + """Decrement active stream count. 
Start delayed unsubscribe when last stream ends.""" + count = self._active_streams.get(channel_login, 0) + if count <= 1: + # Last stream for this channel — start grace period before unsubscribing + self._active_streams.pop(channel_login, None) + self.logger.debug( + "Last stream ended for %s — starting %ds unsubscribe grace period", + channel_login, + int(RAID_UNSUBSCRIBE_GRACE), + ) + self._unsubscribe_timers[channel_login] = asyncio.create_task( + self._delayed_unsubscribe(channel_login) + ) + else: + self._active_streams[channel_login] = count - 1 + self.logger.debug("Stream ended for %s (active: %d)", channel_login, count - 1) + + async def _delayed_unsubscribe(self, channel_login: str) -> None: + """Wait grace period, then unsubscribe from raids for a channel.""" + await asyncio.sleep(RAID_UNSUBSCRIBE_GRACE) + self._unsubscribe_timers.pop(channel_login, None) + + if self._eventsub is not None: + users = await self._get_users(logins=[channel_login]) + if users: + await self._eventsub.unsubscribe_raids(users[0]["id"]) + + self.logger.debug("Unsubscribed from raids for %s after grace period", channel_login) + + async def _on_raid(self, from_login: str, to_login: str) -> None: + """Handle a raid event — switch all playing queues to raid target.""" + if not self._auto_raid: + return + + if from_login not in self._active_streams and from_login not in self._unsubscribe_timers: + self.logger.debug("Ignoring raid from %s (not active or in grace period)", from_login) + return + + self.logger.info("Raid received: %s → %s", from_login, to_login) + + # Check grace period BEFORE cleanup + in_grace_period = from_login in self._unsubscribe_timers + + # Cancel any pending unsubscribe for the raiding channel + timer = self._unsubscribe_timers.pop(from_login, None) + if timer is not None: + timer.cancel() + + # Clean up the raiding channel's tracking — new streams will register themselves + self._active_streams.pop(from_login, None) + max_idle_age = RAID_UNSUBSCRIBE_GRACE * 2 + 
+ for queue in self.mass.player_queues.all(): + if not queue.current_item or not queue.current_item.streamdetails: + continue + if queue.current_item.streamdetails.item_id != from_login: + continue + + if queue.state == PlaybackState.PLAYING: + pass # always switch + elif queue.state == PlaybackState.IDLE and in_grace_period: + idle_duration = time.time() - queue.elapsed_time_last_updated + if idle_duration > max_idle_age: + continue # idle too long — user likely stopped it + else: + continue # paused or other state — don't switch + + try: + await self.mass.player_queues.play_media( + queue_id=queue.queue_id, + media=f"{self.instance_id}://radio/{to_login}", + ) + except Exception: + self.logger.warning( + "Failed to follow raid to %s on queue %s", + to_login, + queue.queue_id, + exc_info=True, + ) + + @property + def is_authenticated(self) -> bool: + """Return whether the provider has valid credentials.""" + return bool(self._access_token) + + def _api_headers(self) -> dict[str, str]: + """Return headers for Twitch API calls.""" + return { + "Authorization": f"Bearer {self._access_token}", + "Client-Id": self._client_id or "", + } + + @staticmethod + def _raise_for_status(status: int) -> None: + """Raise an appropriate MA exception for non-success HTTP status codes.""" + if 200 <= status < 300: + return + if status == 404: + msg = f"Twitch API: resource not found ({status})" + raise MediaNotFoundError(msg) + if status == 429: + msg = f"Twitch API: rate limited ({status})" + raise ResourceTemporarilyUnavailable(msg) + if status in (401, 403): + msg = f"Twitch API: authentication failed ({status})" + raise LoginFailed(msg) + if status >= 500: + msg = f"Twitch API: server error ({status})" + raise ProviderUnavailableError(msg) + msg = f"Twitch API error ({status})" + raise MusicAssistantError(msg) + + async def _refresh_access_token(self) -> None: + """Refresh the Twitch access token using the refresh token.""" + if not self._refresh_token: + self._access_token = 
None + msg = "No refresh token available. Re-authenticate." + raise LoginFailed(msg) + + params = { + "client_id": self._client_id or "", + "client_secret": self._client_secret or "", + "grant_type": "refresh_token", + "refresh_token": self._refresh_token, + } + async with self.mass.http_session.post(TWITCH_TOKEN_URL, data=params) as response: + if response.status != 200: + self._access_token = None + self._refresh_token = None + error_text = await response.text() + msg = f"Token refresh failed: {error_text}" + raise LoginFailed(msg) + data = await response.json() + + self._access_token = data["access_token"] + # Twitch may rotate the refresh token + self._refresh_token = data.get("refresh_token", self._refresh_token) + + # Persist tokens to config storage so they survive restarts + self._update_config_value(CONF_ACCESS_TOKEN, self._access_token, encrypted=True) + self._update_config_value(CONF_REFRESH_TOKEN, self._refresh_token, encrypted=True) + + async def _api_get( + self, + url: str, + params: dict[str, Any] | list[tuple[str, str]] | None = None, + ) -> dict[str, Any]: + """Make authenticated GET request to Twitch API, with auto-refresh on 401.""" + full_url = url if url.startswith("http") else f"https://api.twitch.tv{url}" + async with self.mass.http_session.get( + full_url, headers=self._api_headers(), params=params + ) as response: + if response.status == 401: + await self._refresh_access_token() + async with self.mass.http_session.get( + full_url, headers=self._api_headers(), params=params + ) as retry_response: + self._raise_for_status(retry_response.status) + return await retry_response.json() # type: ignore[no-any-return] + self._raise_for_status(response.status) + return await response.json() # type: ignore[no-any-return] + + # --- Twitch API Methods --- + + async def _get_followed_channels(self) -> list[dict[str, Any]]: + """Get all followed channels (paginated).""" + all_channels: list[dict[str, Any]] = [] + cursor: str | None = None + while True: + 
params: dict[str, str] = {"user_id": self._user_id or "", "first": "100"} + if cursor: + params["after"] = cursor + data = await self._api_get("/helix/channels/followed", params=params) + all_channels.extend(data.get("data", [])) + cursor = data.get("pagination", {}).get("cursor") + if not cursor: + break + return all_channels + + async def _get_live_streams(self, user_ids: list[str]) -> list[dict[str, Any]]: + """Get live streams for user IDs (batched, max 100 per request).""" + if not user_ids: + return [] + all_streams: list[dict[str, Any]] = [] + for i in range(0, len(user_ids), 100): + batch = user_ids[i : i + 100] + params = [("user_id", uid) for uid in batch] + data = await self._api_get("/helix/streams", params=params) + all_streams.extend(data.get("data", [])) + return all_streams + + async def _get_user_profiles(self, user_ids: list[str]) -> dict[str, dict[str, Any]]: + """Get user profiles by ID (batched, max 100 per request).""" + if not user_ids: + return {} + profiles: dict[str, dict[str, Any]] = {} + for i in range(0, len(user_ids), 100): + batch = user_ids[i : i + 100] + params = [("id", uid) for uid in batch] + data = await self._api_get("/helix/users", params=params) + for user in data.get("data", []): + profiles[user["login"]] = user + return profiles + + async def _get_followed_live_status( + self, + ) -> tuple[list[dict[str, Any]], dict[str, dict[str, Any]], dict[str, dict[str, Any]]]: + """Get followed channels, live status, and profiles (cached 5 min).""" + if ( + self._cached_channels is not None + and self._cached_live is not None + and self._cached_profiles is not None + and (time.monotonic() - self._cache_time) < LIVE_STATUS_TTL + ): + return self._cached_channels, self._cached_live, self._cached_profiles + + channels = await self._get_followed_channels() + user_ids = [ch["broadcaster_id"] for ch in channels] + # Fetch streams and profiles concurrently — both only need user_ids + streams, profiles = await asyncio.gather( + 
self._get_live_streams(user_ids), + self._get_user_profiles(user_ids), + ) + live_by_login = {s["user_login"]: s for s in streams} + + self._cached_channels = channels + self._cached_live = live_by_login + self._cached_profiles = profiles + self._cache_time = time.monotonic() + + return channels, live_by_login, profiles + + async def _get_users(self, logins: list[str] | None = None) -> list[dict[str, Any]]: + """Get user info by login names.""" + if not logins: + return [] + params = [("login", login) for login in logins] + data = await self._api_get("/helix/users", params=params) + return data.get("data", []) # type: ignore[no-any-return] + + def _clear_cache(self) -> None: + """Clear the live status cache.""" + self._cached_channels = None + self._cached_live = None + self._cached_profiles = None + self._cache_time = 0.0 + + # --- Radio Model Helpers --- + + def _channel_to_radio( + self, + channel: dict[str, Any], + stream: dict[str, Any] | None = None, + profile: dict[str, Any] | None = None, + ) -> Radio: + """Convert a Twitch channel + optional stream/profile data to a Radio model.""" + login = channel.get("broadcaster_login", channel.get("user_login", "")) + display_name = channel.get("broadcaster_name", channel.get("display_name", login)) + name = display_name + if stream: + viewer_count = stream.get("viewer_count", 0) + name = f"{display_name} ({viewer_count} viewers)" + + # Prefer stream thumbnail (live preview), fall back to profile image + thumbnail = "" + if stream and stream.get("thumbnail_url"): + thumbnail = stream["thumbnail_url"].replace("{width}", "320").replace("{height}", "180") + elif profile and profile.get("profile_image_url"): + thumbnail = profile["profile_image_url"] + + radio = Radio( + item_id=login, + provider=self.domain, + name=name, + provider_mappings={ + ProviderMapping( + item_id=login, + provider_domain=self.domain, + provider_instance=self.instance_id, + ) + }, + ) + if thumbnail: + radio.metadata.images = UniqueList( + [ + 
MediaItemImage( + type=ImageType.THUMB, + path=thumbnail, + provider=self.instance_id, + remotely_accessible=True, + ) + ] + ) + return radio + + # --- MusicProvider Interface --- + + async def get_stream_details(self, item_id: str, media_type: MediaType) -> StreamDetails: + """Get streamdetails for a Twitch channel.""" + return StreamDetails( + provider=self.instance_id, + item_id=item_id, + audio_format=AudioFormat(content_type=ContentType.UNKNOWN), + media_type=MediaType.RADIO, + stream_type=StreamType.CUSTOM, + allow_seek=False, + can_seek=False, + ) + + async def get_audio_stream( + self, streamdetails: StreamDetails, seek_position: int = 0 + ) -> AsyncGenerator[bytes, None]: + """Return the audio stream for a Twitch channel.""" + item_id = streamdetails.item_id + reconnects = 0 + + # Create a per-stream Streamlink session. Streamlink sessions are not + # thread-safe (shared HTTP state, options, plugin vars), so each + # concurrent get_audio_stream call needs its own instance. + sl_session = await asyncio.to_thread(self._create_streamlink_session) + ad_patch_applied = False + + self._track_stream_start(item_id) + try: + while True: + streams = await asyncio.to_thread(self._resolve_streams, item_id, sl_session) + if not streams: + return + + # Apply ad handling monkey-patch once after first resolution. + # Must be done after streams() because Streamlink loads plugins + # into a fresh module namespace — the reader class doesn't exist + # until then. 
+ if not ad_patch_applied: + any_stream = next(iter(streams.values()), None) + if any_stream is not None: + reader_cls = getattr(type(any_stream), "__reader__", None) + if reader_cls is not None: + patch_ad_handling(reader_cls=reader_cls) + ad_patch_applied = True + + stream = self._select_quality(streams) + if not stream: + return + + fd = await asyncio.to_thread(stream.open) + prev_ad_state = False + try: + while True: + chunk = await asyncio.to_thread(fd.read, STREAM_CHUNK_SIZE) + if chunk: + reconnects = 0 + if _ah.ad_break_active != prev_ad_state: + prev_ad_state = _ah.ad_break_active + if _ah.ad_break_active: + streamdetails.stream_title = f"{item_id} - Ad Break" + else: + # stream_title is a property backed by stream_metadata, + # so clearing metadata also clears the title — MA falls + # back to current_item.name + streamdetails.stream_metadata = None + yield chunk + continue + break + finally: + await asyncio.to_thread(fd.close) + + reconnects += 1 + if reconnects > MAX_CONSECUTIVE_RECONNECTS: + return + + await asyncio.sleep(RECONNECT_DELAY) + finally: + self._track_stream_end(item_id) + + def _create_streamlink_session(self) -> Streamlink: + """Create and configure a new Streamlink session. Blocking — call via to_thread.""" + session = Streamlink() + # Increase the segment queue deadline so the stream survives Twitch + # ad breaks without triggering "No new segments" timeout. Default + # factor is 3 (≈15 s for 5 s target duration); 6 gives ≈30 s, + # enough for mid-stream ad transition gaps. + session.set_option("stream-segmented-queue-deadline", 6) + streamlink_token = str(self.config.get_value(CONF_STREAMLINK_TOKEN) or "") + if streamlink_token: + session.set_option("http-headers", {"Authorization": f"OAuth {streamlink_token}"}) + return session + + def _resolve_streams(self, channel: str, session: Streamlink) -> dict[str, Any] | None: + """Resolve Streamlink streams for a channel. 
Blocking — call via to_thread.""" + try: + streams = session.streams(f"https://twitch.tv/{channel}") + if not streams: + return None + return dict(streams) + except Exception: + self.logger.exception("Failed to resolve streams for %s", channel) + return None + + @staticmethod + def _select_quality(streams: dict[str, Any]) -> Any | None: + """Select preferred audio quality from available streams.""" + return next((streams[q] for q in PREFERRED_QUALITIES if q in streams), None) + + async def browse(self, path: str) -> Sequence[MediaItemType | BrowseFolder]: + """Browse this provider's items.""" + # Parse path: "" for root, "instance://live" or "instance://following" + subpath = "" + if "://" in path: + subpath = path.split("://")[1].split("/")[0] + + if not subpath: + return [ + BrowseFolder( + item_id=BROWSE_LIVE, + provider=self.domain, + path=f"{self.instance_id}://{BROWSE_LIVE}", + name="Live", + ), + BrowseFolder( + item_id=BROWSE_FOLLOWING, + provider=self.domain, + path=f"{self.instance_id}://{BROWSE_FOLLOWING}", + name="Following", + ), + ] + + if subpath not in (BROWSE_LIVE, BROWSE_FOLLOWING): + return [] + + if not self.is_authenticated or not self._user_id: + return [] + + channels, live_by_login, profiles = await self._get_followed_live_status() + + if subpath == BROWSE_LIVE: + return [ + self._channel_to_radio( + ch, + live_by_login.get(ch["broadcaster_login"]), + profiles.get(ch["broadcaster_login"]), + ) + for ch in channels + if ch["broadcaster_login"] in live_by_login + ] + + if subpath == BROWSE_FOLLOWING: + result: list[MediaItemType | BrowseFolder] = [] + for ch in sorted(channels, key=lambda c: c["broadcaster_name"].lower()): + login = ch["broadcaster_login"] + stream = live_by_login.get(login) + radio = self._channel_to_radio(ch, stream, profiles.get(login)) + if not stream: + radio.name = f"{ch['broadcaster_name']} (offline)" + result.append(radio) + return result + + return [] # pragma: no cover + + async def get_library_radios(self) -> 
AsyncGenerator[Radio, None]: + """Retrieve live followed channels as radio stations.""" + if not self.is_authenticated or not self._user_id: + return + + channels, live_by_login, profiles = await self._get_followed_live_status() + for ch in channels: + login = ch["broadcaster_login"] + if login in live_by_login: + yield self._channel_to_radio(ch, live_by_login[login], profiles.get(login)) + + async def recommendations(self) -> list[RecommendationFolder]: + """Get this provider's recommendations.""" + if not self.is_authenticated or not self._user_id: + return [] + + channels, live_by_login, profiles = await self._get_followed_live_status() + live_radios = [ + self._channel_to_radio( + ch, live_by_login[ch["broadcaster_login"]], profiles.get(ch["broadcaster_login"]) + ) + for ch in channels + if ch["broadcaster_login"] in live_by_login + ] + if not live_radios: + return [] + + folder = RecommendationFolder( + name="Twitch Live Channels", + item_id=f"{self.instance_id}_live_channels", + provider=self.instance_id, + icon="mdi-broadcast", + ) + folder.items.extend(live_radios) + return [folder] + + async def get_radio(self, prov_radio_id: str) -> Radio: + """Get full radio details by id (channel login).""" + if not self.is_authenticated: + msg = f"Not authenticated — cannot look up channel {prov_radio_id}" + raise MediaNotFoundError(msg) + + users = await self._get_users(logins=[prov_radio_id]) + if not users: + msg = f"Twitch channel not found: {prov_radio_id}" + raise MediaNotFoundError(msg) + + user = users[0] + # Check if live + streams = await self._get_live_streams([user["id"]]) + stream = streams[0] if streams else None + + return self._channel_to_radio( + {"broadcaster_login": user["login"], "broadcaster_name": user["display_name"]}, + stream, + user, + ) + + async def search( + self, + search_query: str, + media_types: list[MediaType], + limit: int = 5, + ) -> SearchResults: + """Perform search on Twitch.""" + result = SearchResults() + if MediaType.RADIO not 
in media_types: + return result + if not search_query or not self.is_authenticated: + return result + + try: + data = await self._api_get( + "/helix/search/channels", + params={"query": search_query, "first": str(limit)}, + ) + result.radio = [self._channel_to_radio(ch) for ch in data.get("data", [])] + except Exception: + self.logger.warning("Twitch search failed for query '%s'", search_query) + + return result diff --git a/music_assistant/providers/twitch/ad_handling.py b/music_assistant/providers/twitch/ad_handling.py new file mode 100644 index 0000000000..1b34a438e7 --- /dev/null +++ b/music_assistant/providers/twitch/ad_handling.py @@ -0,0 +1,56 @@ +"""Ad handling for Twitch streams via Streamlink monkey-patching.""" + +from __future__ import annotations + +import logging + +logger = logging.getLogger(__name__) + +# Module-level flag — GIL makes bool read/write atomic. +# Set by Streamlink writer (runs in thread), read by provider. +ad_break_active: bool = False + + +def patch_ad_handling(reader_cls: type | None = None) -> None: + """Patch TwitchHLSStreamReader.__writer__ to pass through ads with logging. + + Args: + reader_cls: The actual TwitchHLSStreamReader class to patch. If None, + patches the imported class (which may differ from the class Streamlink's + plugin system uses at runtime due to fresh module loading). + Callers should pass the reader class from the resolved stream object + to ensure the correct class is patched. 
+ + """ + from streamlink.plugins.twitch import ( # noqa: PLC0415 + TwitchHLSSegment, + TwitchHLSStreamReader, + TwitchHLSStreamWriter, + ) + + target_reader = reader_cls or TwitchHLSStreamReader + + class PassthroughTwitchWriter(TwitchHLSStreamWriter): + """Writer that logs ad segments and tracks ad break state.""" + + def should_filter_segment(self, segment: TwitchHLSSegment) -> bool: # type: ignore[override] + """Never filter — let all segments through.""" + global ad_break_active # noqa: PLW0603 + if segment.ad: + ad_break_active = True + logger.debug( + "Ad segment %d (%.1fs): passing through as audio", + segment.num, + segment.duration, + ) + else: + if ad_break_active: + logger.debug( + "Content segment %d: ad block ended, audio resuming", + segment.num, + ) + ad_break_active = False + return False + + target_reader.__writer__ = PassthroughTwitchWriter # type: ignore[attr-defined] + logger.info("Twitch ad handling: passthrough (ads play as audio)") diff --git a/music_assistant/providers/twitch/eventsub.py b/music_assistant/providers/twitch/eventsub.py new file mode 100644 index 0000000000..cf72ceb013 --- /dev/null +++ b/music_assistant/providers/twitch/eventsub.py @@ -0,0 +1,268 @@ +"""EventSub WebSocket client for Twitch raid following.""" + +from __future__ import annotations + +import asyncio +import contextlib +import json +import logging +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from collections.abc import Callable + + +logger = logging.getLogger(__name__) + +EVENTSUB_WS_URL = "wss://eventsub.wss.twitch.tv/ws" +MAX_BACKOFF = 60.0 + + +class EventSubClient: + """Async EventSub WebSocket client for channel.raid subscriptions.""" + + def __init__( + self, + http_session: Any, + api_headers_fn: Callable[[], dict[str, str]], + ) -> None: + """Initialize EventSub client. 
+ + Args: + http_session: aiohttp ClientSession for WebSocket + API calls + api_headers_fn: callable returning auth headers for Twitch API + + """ + self._http_session = http_session + self._api_headers_fn = api_headers_fn + + self._ws: Any | None = None + self._session_id: str | None = None + self._subscriptions: dict[str, str] = {} # broadcaster_user_id -> subscription_id + self._reconnect_url: str | None = None + + self._ready = asyncio.Event() + self._stopped = False + self._backoff = 1.0 + self._listen_task: asyncio.Task[None] | None = None + self._on_raid: Callable[[str, str], Any] | None = None + self._subscribe_pending: set[str] = set() + + @property + def is_connected(self) -> bool: + """Return whether the WebSocket is connected.""" + return self._ws is not None and not self._stopped + + async def start(self, on_raid: Callable[[str, str], Any]) -> None: + """Start the EventSub WebSocket connection. + + Args: + on_raid: callback(from_login, to_login) when a raid is received + + """ + self._on_raid = on_raid + self._stopped = False + self._listen_task = asyncio.create_task(self._connect_loop()) + + async def stop(self) -> None: + """Stop the EventSub WebSocket and clean up.""" + logger.debug("EventSub: stopping WebSocket and cleaning up") + self._stopped = True + self._session_id = None + self._subscriptions.clear() + self._ready.clear() + + if self._ws is not None: + await self._ws.close() + self._ws = None + + if self._listen_task is not None: + self._listen_task.cancel() + with contextlib.suppress(asyncio.CancelledError, Exception): + await self._listen_task + self._listen_task = None + + async def subscribe_raids(self, broadcaster_user_id: str) -> None: + """Subscribe to channel.raid events for a broadcaster. 
No-op if already subscribed.""" + if broadcaster_user_id in self._subscriptions: + return + + # Wait for WebSocket to be ready + self._subscribe_pending.add(broadcaster_user_id) + try: + await asyncio.wait_for(self._ready.wait(), timeout=10.0) + except TimeoutError: + logger.warning( + "EventSub not ready — cannot subscribe to raids for %s", + broadcaster_user_id, + ) + return + finally: + self._subscribe_pending.discard(broadcaster_user_id) + + # Check if welcome handler already re-subscribed (reconnect case) + if broadcaster_user_id in self._subscriptions: + return + + await self._create_subscription(broadcaster_user_id) + + async def unsubscribe_raids(self, broadcaster_user_id: str) -> None: + """Unsubscribe from raid events for a specific broadcaster.""" + sub_id = self._subscriptions.pop(broadcaster_user_id, None) + if not sub_id: + return + + try: + async with self._http_session.delete( + "https://api.twitch.tv/helix/eventsub/subscriptions", + headers=self._api_headers_fn(), + params={"id": sub_id}, + ): + pass + logger.debug( + "EventSub: unsubscribed %s for broadcaster %s", + sub_id, + broadcaster_user_id, + ) + except Exception: + logger.warning("EventSub: failed to unsubscribe %s", sub_id, exc_info=True) + + async def unsubscribe_all(self) -> None: + """Unsubscribe from all active EventSub subscriptions.""" + broadcaster_ids = list(self._subscriptions.keys()) + for broadcaster_id in broadcaster_ids: + await self.unsubscribe_raids(broadcaster_id) + + async def _create_subscription(self, broadcaster_user_id: str) -> None: + """Create an EventSub subscription for channel.raid.""" + body = { + "type": "channel.raid", + "version": "1", + "condition": {"from_broadcaster_user_id": broadcaster_user_id}, + "transport": {"method": "websocket", "session_id": self._session_id}, + } + try: + async with self._http_session.post( + "https://api.twitch.tv/helix/eventsub/subscriptions", + headers={**self._api_headers_fn(), "Content-Type": "application/json"}, + json=body, 
+ ) as response: + if response.status in (200, 202): + data = await response.json() + self._subscriptions[broadcaster_user_id] = data["data"][0]["id"] + logger.debug( + "EventSub: subscribed to channel.raid for %s (sub=%s)", + broadcaster_user_id, + self._subscriptions[broadcaster_user_id], + ) + else: + text = await response.text() + logger.warning("EventSub: subscribe failed: %s %s", response.status, text) + except Exception: + logger.warning("EventSub: failed to create subscription", exc_info=True) + + async def _connect_loop(self) -> None: + """Run the connection loop — connect, listen, reconnect with backoff.""" + while not self._stopped: + url = self._reconnect_url or EVENTSUB_WS_URL + self._reconnect_url = None # consume after use + + try: + self._ws = await self._http_session.ws_connect(url) + async for msg in self._ws: + if self._stopped: + # mypy false positive (python/mypy#12784): async for + except + # CancelledError makes mypy think the loop body is unreachable + break # type: ignore[unreachable] + data = getattr(msg, "data", None) + if not isinstance(data, str): + continue + try: + self._handle_message(json.loads(data)) + except (json.JSONDecodeError, KeyError, TypeError): + logger.debug("EventSub: ignoring malformed message") + except asyncio.CancelledError: + return + except Exception: + logger.debug("EventSub: WebSocket disconnected", exc_info=True) + finally: + self._ws = None + self._ready.clear() + + if self._stopped: + # mypy false positive (python/mypy#13104): the except Exception + # path falls through the finally to here, but mypy doesn't track + # it because except CancelledError returns + return # type: ignore[unreachable] + + # Backoff before reconnect + logger.debug("EventSub: reconnecting in %.1fs", self._backoff) + await asyncio.sleep(self._backoff) + self._backoff = min(self._backoff * 2, MAX_BACKOFF) + + def _handle_message(self, msg: dict[str, Any]) -> None: + """Dispatch an EventSub WebSocket message by type.""" + msg_type = 
msg.get("metadata", {}).get("message_type", "") + + if msg_type == "session_welcome": + self._handle_welcome(msg) + elif msg_type == "session_reconnect": + self._handle_reconnect(msg) + elif msg_type == "notification": + self._handle_notification(msg) + elif msg_type == "revocation": + self._handle_revocation(msg) + # session_keepalive is a no-op + + def _handle_welcome(self, msg: dict[str, Any]) -> None: + """Handle session_welcome — store session ID, re-subscribe if needed.""" + self._session_id = msg["payload"]["session"]["id"] + self._backoff = 1.0 # reset backoff + + # Old subscriptions are invalid on the new session. Keep the broadcaster + # IDs (we need to re-subscribe) but clear the subscription IDs. + stale_broadcasters = [ + bid for bid in self._subscriptions if bid not in self._subscribe_pending + ] + self._subscriptions.clear() + + # Re-subscribe for all broadcasters that aren't already being handled + # by a concurrent subscribe_raids call. + for broadcaster_id in stale_broadcasters: + asyncio.create_task(self._create_subscription(broadcaster_id)) + + self._ready.set() + + def _handle_reconnect(self, msg: dict[str, Any]) -> None: + """Handle session_reconnect — store new URL, close current WS.""" + self._reconnect_url = msg["payload"]["session"]["reconnect_url"] + if self._ws is not None: + asyncio.create_task(self._ws.close()) + + def _handle_notification(self, msg: dict[str, Any]) -> None: + """Handle notification — fire raid callback if channel.raid.""" + sub_type = msg.get("metadata", {}).get("subscription_type", "") + if sub_type != "channel.raid": + return + + event = msg["payload"]["event"] + from_login = event["from_broadcaster_user_login"] + to_login = event["to_broadcaster_user_login"] + + if self._on_raid: + self._on_raid(from_login, to_login) + + def _handle_revocation(self, msg: dict[str, Any]) -> None: + """Handle revocation — clear subscription, log warning.""" + sub = msg.get("payload", {}).get("subscription", {}) + logger.warning( + 
"EventSub: subscription revoked: type=%s status=%s", + sub.get("type"), + sub.get("status"), + ) + # Remove the revoked subscription by its ID + revoked_id = sub.get("id") + if revoked_id: + self._subscriptions = { + bid: sid for bid, sid in self._subscriptions.items() if sid != revoked_id + } diff --git a/music_assistant/providers/twitch/manifest.json b/music_assistant/providers/twitch/manifest.json new file mode 100644 index 0000000000..166cd3dfed --- /dev/null +++ b/music_assistant/providers/twitch/manifest.json @@ -0,0 +1,10 @@ +{ + "type": "music", + "domain": "twitch", + "stage": "beta", + "name": "Twitch Audio", + "description": "Audio-only streaming from Twitch channels with raid following.", + "codeowners": ["@Drizzt321"], + "requirements": ["streamlink>=8.0,<9"], + "icon": "twitch" +} diff --git a/requirements_all.txt b/requirements_all.txt index 579df9a3d8..fbd4b4eb14 100644 --- a/requirements_all.txt +++ b/requirements_all.txt @@ -73,6 +73,7 @@ soco==0.30.14 soundcloudpy==0.1.4 sounddevice==0.5.5 srptools>=1.0.0 +streamlink>=8.0,<9 sxm==0.2.8 unidecode==1.4.0 uv>=0.8.0 diff --git a/tests/providers/twitch/__init__.py b/tests/providers/twitch/__init__.py new file mode 100644 index 0000000000..5b574f5f6e --- /dev/null +++ b/tests/providers/twitch/__init__.py @@ -0,0 +1 @@ +"""Tests for the Twitch provider.""" diff --git a/tests/providers/twitch/conftest.py b/tests/providers/twitch/conftest.py new file mode 100644 index 0000000000..57ec0ca0cb --- /dev/null +++ b/tests/providers/twitch/conftest.py @@ -0,0 +1,191 @@ +"""Shared fixtures for Twitch provider tests.""" + +from __future__ import annotations + +import json +from pathlib import Path +from typing import Any +from unittest.mock import AsyncMock, Mock + +import pytest + +from music_assistant.providers.twitch import SUPPORTED_FEATURES, TwitchProvider + +FIXTURES = Path(__file__).parent / "fixtures" + + +def load_fixture(name: str) -> dict[str, Any]: + """Load a JSON fixture file by name.""" + with 
(FIXTURES / name).open() as f: + return json.load(f) # type: ignore[no-any-return] + + +class MockResponse: + """Mock aiohttp response that works as an async context manager. + + Behavioral contract: + - .json() raises ValueError on non-2xx when json_data was not explicitly provided, + matching real aiohttp behavior where error responses often aren't valid JSON. + - .json() returns json_data when explicitly provided, even on error status codes, + since some error responses have JSON bodies. + - Accepts optional headers dict for testing header-dependent code paths. + """ + + _NO_JSON = object() # sentinel to distinguish "not provided" from None + + def __init__( + self, + status: int = 200, + json_data: dict[str, Any] | list[Any] | None = _NO_JSON, # type: ignore[assignment] + text_data: str = "", + headers: dict[str, str] | None = None, + ) -> None: + """Initialize mock response.""" + self.status = status + self._json_explicit = json_data is not MockResponse._NO_JSON + self._json_data = json_data if self._json_explicit else None + self._text_data = text_data + self.headers = headers or {} + + async def json(self) -> dict[str, Any] | list[Any] | None: + """Return JSON body. Raises ValueError on error status when json_data not provided.""" + if not self._json_explicit and self.status >= 400: + msg = f"Cannot parse JSON from error response (status={self.status})" + raise ValueError(msg) + return self._json_data + + async def text(self) -> str: + """Return text body.""" + return self._text_data + + async def __aenter__(self) -> MockResponse: + """Enter async context.""" + return self + + async def __aexit__(self, *args: object) -> None: + """Exit async context.""" + + +def make_mock_session_method( + responses: list[MockResponse] | MockResponse, +) -> Mock: + """Create a mock HTTP method that returns async context manager responses. + + Accepts a single MockResponse or a list for sequential calls. 
+ """ + if isinstance(responses, list): + iterator = iter(responses) + expected_count = len(responses) + call_count = 0 + + def side_effect(*args: Any, **kwargs: Any) -> MockResponse: # noqa: ARG001 + nonlocal call_count + call_count += 1 + try: + return next(iterator) + except StopIteration: + msg = ( + f"MockResponse list exhausted: expected {expected_count} calls, " + f"got call #{call_count}. Add more MockResponse entries or " + f"assert the correct call count." + ) + raise RuntimeError(msg) from None + + mock = Mock(side_effect=side_effect) + else: + + def single(*args: Any, **kwargs: Any) -> MockResponse: # noqa: ARG001 + return responses + + mock = Mock(side_effect=single) + return mock + + +_BASE_CONFIG: dict[str, Any] = { + "client_id": "", + "client_secret": "", + "streamlink_token": "", + "auto_raid": True, + "log_level": "GLOBAL", + "access_token": "", + "refresh_token": "", +} + + +def config_side_effect(overrides: dict[str, Any] | None = None) -> Any: + """Return a side_effect callable for config.get_value with optional overrides.""" + values = {**_BASE_CONFIG, **(overrides or {})} + return lambda key, default=None: values.get(key, default) + + +@pytest.fixture +def mass_mock() -> Mock: + """Return a mock MusicAssistant instance.""" + mass = Mock() + mass.http_session = Mock() + mass.http_session.ws_connect = AsyncMock() + mass.player_queues = Mock() + mass.player_queues.play_media = AsyncMock() + mass.cache.get = AsyncMock(return_value=None) + mass.cache.set = AsyncMock() + mass.config.set_raw_provider_config_value = Mock() + # webserver for AuthenticationHelper + mass.webserver = Mock() + mass.webserver.base_url = "http://localhost:8095" + mass.webserver.register_dynamic_route = Mock() + mass.webserver.unregister_dynamic_route = Mock() + mass.signal_event = Mock() + return mass + + +@pytest.fixture +def manifest_mock() -> Mock: + """Return a mock provider manifest.""" + manifest = Mock() + manifest.domain = "twitch" + return manifest + + 
+@pytest.fixture +def config_mock() -> Mock: + """Return a mock provider config.""" + config = Mock() + config.name = "Twitch Test" + config.instance_id = "twitch_test" + config.enabled = True + config.get_value.side_effect = lambda key, default=None: { + "client_id": "", + "client_secret": "", + "streamlink_token": "", + "auto_raid": True, + "log_level": "GLOBAL", + }.get(key, default) + # Mock config.values as a defaultdict-like dict for _update_config_value + config.values = {} + + class _ValueHolder: + """Hold a config value for mock purposes.""" + + def __init__(self) -> None: + self.value: Any = None + + class _AutoValues(dict): # type: ignore[type-arg] + """Auto-create value holders on access.""" + + def __missing__(self, key: str) -> _ValueHolder: + holder = _ValueHolder() + self[key] = holder + return holder + + config.values = _AutoValues() + return config + + +@pytest.fixture +def provider(mass_mock: Mock, manifest_mock: Mock, config_mock: Mock) -> TwitchProvider: + """Return a TwitchProvider instance.""" + p = TwitchProvider(mass_mock, manifest_mock, config_mock, SUPPORTED_FEATURES) + # Initialize raid state that would normally be set by handle_async_init + p._active_streams = {} + p._unsubscribe_timers = {} + return p diff --git a/tests/providers/twitch/fixtures/eventsub_raid.json b/tests/providers/twitch/fixtures/eventsub_raid.json new file mode 100644 index 0000000000..57e98174b1 --- /dev/null +++ b/tests/providers/twitch/fixtures/eventsub_raid.json @@ -0,0 +1,28 @@ +{ + "metadata": { + "message_id": "raid-1", + "message_type": "notification", + "message_timestamp": "2023-06-01T12:30:00Z", + "subscription_type": "channel.raid" + }, + "payload": { + "subscription": { + "id": "sub_123", + "type": "channel.raid", + "version": "1", + "status": "enabled", + "condition": { + "from_broadcaster_user_id": "123" + } + }, + "event": { + "from_broadcaster_user_id": "123", + "from_broadcaster_user_login": "streamer_a", + "from_broadcaster_user_name": "Streamer 
A", + "to_broadcaster_user_id": "789", + "to_broadcaster_user_login": "streamer_c", + "to_broadcaster_user_name": "Streamer C", + "viewers": 1500 + } + } +} diff --git a/tests/providers/twitch/fixtures/eventsub_reconnect.json b/tests/providers/twitch/fixtures/eventsub_reconnect.json new file mode 100644 index 0000000000..ca01afa23c --- /dev/null +++ b/tests/providers/twitch/fixtures/eventsub_reconnect.json @@ -0,0 +1,14 @@ +{ + "metadata": { + "message_id": "reconnect-1", + "message_type": "session_reconnect", + "message_timestamp": "2023-06-01T12:15:00Z" + }, + "payload": { + "session": { + "id": "test_session_123", + "status": "reconnecting", + "reconnect_url": "wss://eventsub.wss.twitch.tv/ws?reconnect=true" + } + } +} diff --git a/tests/providers/twitch/fixtures/eventsub_revocation.json b/tests/providers/twitch/fixtures/eventsub_revocation.json new file mode 100644 index 0000000000..95470905c5 --- /dev/null +++ b/tests/providers/twitch/fixtures/eventsub_revocation.json @@ -0,0 +1,19 @@ +{ + "metadata": { + "message_id": "revoke-1", + "message_type": "revocation", + "message_timestamp": "2023-06-01T13:00:00Z", + "subscription_type": "channel.raid" + }, + "payload": { + "subscription": { + "id": "sub_123", + "type": "channel.raid", + "version": "1", + "status": "authorization_revoked", + "condition": { + "from_broadcaster_user_id": "123" + } + } + } +} diff --git a/tests/providers/twitch/fixtures/eventsub_welcome.json b/tests/providers/twitch/fixtures/eventsub_welcome.json new file mode 100644 index 0000000000..fe7eaf230c --- /dev/null +++ b/tests/providers/twitch/fixtures/eventsub_welcome.json @@ -0,0 +1,16 @@ +{ + "metadata": { + "message_id": "welcome-1", + "message_type": "session_welcome", + "message_timestamp": "2023-06-01T12:00:00Z" + }, + "payload": { + "session": { + "id": "test_session_123", + "status": "connected", + "connected_at": "2023-06-01T12:00:00Z", + "keepalive_timeout_seconds": 10, + "reconnect_url": null + } + } +} diff --git 
a/tests/providers/twitch/fixtures/followed_channels.json b/tests/providers/twitch/fixtures/followed_channels.json new file mode 100644 index 0000000000..fefcaed97e --- /dev/null +++ b/tests/providers/twitch/fixtures/followed_channels.json @@ -0,0 +1,32 @@ +{ + "page1": { + "data": [ + { + "broadcaster_id": "123", + "broadcaster_login": "streamer_a", + "broadcaster_name": "Streamer A", + "followed_at": "2023-01-01T00:00:00Z" + }, + { + "broadcaster_id": "456", + "broadcaster_login": "streamer_b", + "broadcaster_name": "Streamer B", + "followed_at": "2023-02-01T00:00:00Z" + } + ], + "pagination": { + "cursor": "abc123" + } + }, + "page2": { + "data": [ + { + "broadcaster_id": "789", + "broadcaster_login": "streamer_c", + "broadcaster_name": "Streamer C", + "followed_at": "2023-03-01T00:00:00Z" + } + ], + "pagination": {} + } +} diff --git a/tests/providers/twitch/fixtures/live_streams.json b/tests/providers/twitch/fixtures/live_streams.json new file mode 100644 index 0000000000..f7b3233a21 --- /dev/null +++ b/tests/providers/twitch/fixtures/live_streams.json @@ -0,0 +1,34 @@ +{ + "data": [ + { + "id": "stream1", + "user_id": "123", + "user_login": "streamer_a", + "user_name": "Streamer A", + "game_id": "12345", + "game_name": "Just Chatting", + "type": "live", + "title": "Hello World Stream", + "viewer_count": 1500, + "started_at": "2023-06-01T12:00:00Z", + "language": "en", + "thumbnail_url": "https://static-cdn.jtvnw.net/previews-ttv/live_user_streamer_a-{width}x{height}.jpg", + "is_mature": false + }, + { + "id": "stream2", + "user_id": "789", + "user_login": "streamer_c", + "user_name": "Streamer C", + "game_id": "67890", + "game_name": "Music", + "type": "live", + "title": "Chill Beats", + "viewer_count": 250, + "started_at": "2023-06-01T14:00:00Z", + "language": "en", + "thumbnail_url": "https://static-cdn.jtvnw.net/previews-ttv/live_user_streamer_c-{width}x{height}.jpg", + "is_mature": false + } + ] +} diff --git 
a/tests/providers/twitch/fixtures/search_results.json b/tests/providers/twitch/fixtures/search_results.json new file mode 100644 index 0000000000..95cd9fdb00 --- /dev/null +++ b/tests/providers/twitch/fixtures/search_results.json @@ -0,0 +1,28 @@ +{ + "data": [ + { + "broadcaster_language": "en", + "broadcaster_login": "streamer_a", + "display_name": "Streamer A", + "game_id": "12345", + "game_name": "Just Chatting", + "id": "123", + "is_live": true, + "thumbnail_url": "https://static-cdn.jtvnw.net/jtv_user_pictures/streamer_a-profile.png", + "title": "Hello World Stream", + "started_at": "2023-06-01T12:00:00Z" + }, + { + "broadcaster_language": "en", + "broadcaster_login": "streamer_d", + "display_name": "Streamer D", + "game_id": "67890", + "game_name": "Music", + "id": "999", + "is_live": false, + "thumbnail_url": "https://static-cdn.jtvnw.net/jtv_user_pictures/streamer_d-profile.png", + "title": "", + "started_at": "" + } + ] +} diff --git a/tests/providers/twitch/fixtures/token_exchange.json b/tests/providers/twitch/fixtures/token_exchange.json new file mode 100644 index 0000000000..860c8ec87e --- /dev/null +++ b/tests/providers/twitch/fixtures/token_exchange.json @@ -0,0 +1,7 @@ +{ + "access_token": "new_access_token_123", + "expires_in": 14400, + "refresh_token": "new_refresh_token_456", + "scope": ["user:read:follows"], + "token_type": "bearer" +} diff --git a/tests/providers/twitch/fixtures/token_refresh.json b/tests/providers/twitch/fixtures/token_refresh.json new file mode 100644 index 0000000000..99c54b0882 --- /dev/null +++ b/tests/providers/twitch/fixtures/token_refresh.json @@ -0,0 +1,7 @@ +{ + "access_token": "refreshed_access_token_789", + "expires_in": 14400, + "refresh_token": "rotated_refresh_token_012", + "scope": ["user:read:follows"], + "token_type": "bearer" +} diff --git a/tests/providers/twitch/fixtures/user_lookup.json b/tests/providers/twitch/fixtures/user_lookup.json new file mode 100644 index 0000000000..286b1ccf81 --- /dev/null +++ 
b/tests/providers/twitch/fixtures/user_lookup.json @@ -0,0 +1,15 @@ +{ + "data": [ + { + "id": "123", + "login": "streamer_a", + "display_name": "Streamer A", + "type": "", + "broadcaster_type": "partner", + "description": "A test streamer", + "profile_image_url": "https://static-cdn.jtvnw.net/jtv_user_pictures/streamer_a-profile.png", + "offline_image_url": "", + "created_at": "2020-01-01T00:00:00Z" + } + ] +} diff --git a/tests/providers/twitch/test_ad_handling.py b/tests/providers/twitch/test_ad_handling.py new file mode 100644 index 0000000000..06e93a0abd --- /dev/null +++ b/tests/providers/twitch/test_ad_handling.py @@ -0,0 +1,225 @@ +"""Test Twitch ad handling — passthrough with ad break tracking.""" + +# ruff: noqa: PLC0415 +# Imports must be inside test functions because we inject fake +# streamlink modules into sys.modules via the autouse fixture. + +from __future__ import annotations + +import logging +import sys +from types import ModuleType +from typing import Any +from unittest.mock import Mock + +import pytest + +# --- Mock Streamlink Classes --- +# Streamlink is not installed in the test environment (runtime dep), +# so we create mock base classes and inject them into sys.modules +# before importing ad_handling. + + +class FakeTwitchHLSSegment: + """Mock Streamlink TwitchHLSSegment.""" + + def __init__(self, *, ad: bool = False, num: int = 0, duration: float = 2.0) -> None: + """Initialize fake segment.""" + self.ad = ad + self.num = num + self.duration = duration + + +class FakeTwitchHLSStreamWriter: + """Mock Streamlink TwitchHLSStreamWriter base class. 
+ + Mirrors real TwitchHLSStreamWriter behavior: + - should_filter_segment returns segment.ad (filters ads by default) + """ + + def __init__(self) -> None: + """Initialize with a mock reader/buffer.""" + self.reader = Mock() + self.reader.buffer = Mock() + + def write(self, segment: Any, result: Any, *data: Any) -> None: + """Store segment content for verification.""" + self.reader.buffer.write(result.content) + + def should_filter_segment(self, segment: Any) -> bool: + """Return segment.ad — matches real TwitchHLSStreamWriter.""" + return bool(segment.ad) + + +class FakeTwitchHLSStreamReader: + """Mock Streamlink TwitchHLSStreamReader.""" + + __writer__ = FakeTwitchHLSStreamWriter + + +@pytest.fixture(autouse=True) +def _mock_streamlink_modules() -> Any: + """Inject fake streamlink modules so ad_handling can import them.""" + twitch_module = ModuleType("streamlink.plugins.twitch") + twitch_module.TwitchHLSSegment = FakeTwitchHLSSegment # type: ignore[attr-defined] + twitch_module.TwitchHLSStreamWriter = FakeTwitchHLSStreamWriter # type: ignore[attr-defined] + twitch_module.TwitchHLSStreamReader = FakeTwitchHLSStreamReader # type: ignore[attr-defined] + + streamlink_module = ModuleType("streamlink") + plugins_module = ModuleType("streamlink.plugins") + + saved = {} + for key in ("streamlink", "streamlink.plugins", "streamlink.plugins.twitch"): + saved[key] = sys.modules.get(key) + sys.modules["streamlink"] = streamlink_module + sys.modules["streamlink.plugins"] = plugins_module + sys.modules["streamlink.plugins.twitch"] = twitch_module + + # Reset __writer__ before each test + FakeTwitchHLSStreamReader.__writer__ = FakeTwitchHLSStreamWriter + + yield + + # Restore + for key, val in saved.items(): + if val is None: + sys.modules.pop(key, None) + else: + sys.modules[key] = val + + # Also clear ad_handling module cache so it reimports cleanly + sys.modules.pop("music_assistant.providers.twitch.ad_handling", None) + + +# --- Monkey-Patch Application --- + + +def 
test_patch_targets_exist_in_streamlink() -> None: + """TwitchHLSSegment, TwitchHLSStreamReader, TwitchHLSStreamWriter are importable.""" + from streamlink.plugins.twitch import ( + TwitchHLSSegment, + TwitchHLSStreamReader, + TwitchHLSStreamWriter, + ) + + assert TwitchHLSSegment is FakeTwitchHLSSegment # type: ignore[comparison-overlap] + assert TwitchHLSStreamReader is FakeTwitchHLSStreamReader # type: ignore[comparison-overlap] + assert TwitchHLSStreamWriter is FakeTwitchHLSStreamWriter # type: ignore[comparison-overlap] + + +def test_passthrough_patch_applies_without_error() -> None: + """Passthrough monkey-patch applies without errors.""" + from music_assistant.providers.twitch.ad_handling import patch_ad_handling + + patch_ad_handling() + assert FakeTwitchHLSStreamReader.__writer__ is not FakeTwitchHLSStreamWriter + + +def test_patch_does_not_affect_non_ad_segments() -> None: + """Normal (non-ad) segments pass through unchanged.""" + from music_assistant.providers.twitch.ad_handling import patch_ad_handling + + patch_ad_handling() + + writer_cls = FakeTwitchHLSStreamReader.__writer__ + writer = object.__new__(writer_cls) + + segment = FakeTwitchHLSSegment(ad=False, num=1, duration=2.0) + assert writer.should_filter_segment(segment) is False + + +# --- Ad Break Flag Tracking --- + + +def test_ad_break_flag_set_on_ad() -> None: + """ad_break_active set to True when ad segment processed.""" + import music_assistant.providers.twitch.ad_handling as ah + + ah.patch_ad_handling() + ah.ad_break_active = False + + writer_cls = FakeTwitchHLSStreamReader.__writer__ + writer = object.__new__(writer_cls) + + segment = FakeTwitchHLSSegment(ad=True, num=1, duration=2.0) + writer.should_filter_segment(segment) + assert ah.ad_break_active is True + + +def test_ad_break_flag_cleared_on_content() -> None: + """ad_break_active set to False when non-ad segment follows.""" + import music_assistant.providers.twitch.ad_handling as ah + + ah.patch_ad_handling() + ah.ad_break_active = 
True + + writer_cls = FakeTwitchHLSStreamReader.__writer__ + writer = object.__new__(writer_cls) + + segment = FakeTwitchHLSSegment(ad=False, num=2, duration=2.0) + writer.should_filter_segment(segment) + assert ah.ad_break_active is False + + +# --- Passthrough Behavior --- + + +def test_ad_segment_logged(caplog: pytest.LogCaptureFixture) -> None: + """Ad segment is logged at debug level.""" + from music_assistant.providers.twitch.ad_handling import patch_ad_handling + + patch_ad_handling() + + writer_cls = FakeTwitchHLSStreamReader.__writer__ + writer = object.__new__(writer_cls) + + segment = FakeTwitchHLSSegment(ad=True, num=1, duration=2.0) + + with caplog.at_level(logging.DEBUG): + writer.should_filter_segment(segment) + + assert any("ad segment" in r.message.lower() for r in caplog.records) + + +def test_ad_segment_passes_through() -> None: + """should_filter_segment returns False for ad segments.""" + from music_assistant.providers.twitch.ad_handling import patch_ad_handling + + patch_ad_handling() + + writer_cls = FakeTwitchHLSStreamReader.__writer__ + writer = object.__new__(writer_cls) + + segment = FakeTwitchHLSSegment(ad=True, num=1, duration=2.0) + assert writer.should_filter_segment(segment) is False + + +def test_passthrough_non_ad_also_passes() -> None: + """Non-ad segments also pass through.""" + from music_assistant.providers.twitch.ad_handling import patch_ad_handling + + patch_ad_handling() + + writer_cls = FakeTwitchHLSStreamReader.__writer__ + writer = object.__new__(writer_cls) + + segment = FakeTwitchHLSSegment(ad=False, num=1, duration=2.0) + assert writer.should_filter_segment(segment) is False + + +def test_ad_end_logged(caplog: pytest.LogCaptureFixture) -> None: + """Transition from ad to content is logged.""" + import music_assistant.providers.twitch.ad_handling as ah + + ah.patch_ad_handling() + ah.ad_break_active = True + + writer_cls = FakeTwitchHLSStreamReader.__writer__ + writer = object.__new__(writer_cls) + + segment = 
FakeTwitchHLSSegment(ad=False, num=2, duration=2.0)
+
+    with caplog.at_level(logging.DEBUG):
+        writer.should_filter_segment(segment)
+
+    assert any("ad block ended" in r.message.lower() for r in caplog.records)
diff --git a/tests/providers/twitch/test_auth.py b/tests/providers/twitch/test_auth.py
new file mode 100644
index 0000000000..acad7d48dd
--- /dev/null
+++ b/tests/providers/twitch/test_auth.py
@@ -0,0 +1,448 @@
+"""Test Twitch Provider OAuth & token management."""
+
+from __future__ import annotations
+
+from typing import Any
+from unittest.mock import AsyncMock, Mock, patch
+
+import pytest
+from music_assistant_models.enums import ConfigEntryType
+from music_assistant_models.errors import LoginFailed
+
+from music_assistant.providers.twitch import (
+    CONF_ACCESS_TOKEN,
+    CONF_ACTION_AUTH,
+    CONF_ACTION_REVOKE,
+    CONF_CLIENT_ID,
+    CONF_CLIENT_SECRET,
+    CONF_REFRESH_TOKEN,
+    CONF_STREAMLINK_TOKEN,
+    TWITCH_SCOPES,
+    TwitchProvider,
+    get_config_entries,
+)
+from tests.providers.twitch.conftest import (
+    MockResponse,
+    load_fixture,
+    make_mock_session_method,
+)
+
+# --- Config Entries ---
+
+
+async def test_config_entries_returns_expected_fields(mass_mock: Mock) -> None:
+    """get_config_entries() returns entries for all expected fields."""
+    entries = await get_config_entries(mass_mock)
+    keys = {e.key for e in entries}
+    assert CONF_CLIENT_ID in keys
+    assert CONF_CLIENT_SECRET in keys
+    assert CONF_STREAMLINK_TOKEN in keys
+
+
+async def test_client_id_is_secure_string(mass_mock: Mock) -> None:
+    """client_id config entry type is SECURE_STRING."""
+    entries = await get_config_entries(mass_mock)
+    entry = next(e for e in entries if e.key == CONF_CLIENT_ID)
+    assert entry.type == ConfigEntryType.SECURE_STRING
+
+
+async def test_client_secret_is_secure_string(mass_mock: Mock) -> None:
+    """client_secret config entry type is SECURE_STRING."""
+    entries = await get_config_entries(mass_mock)
+    entry = next(e for e in entries if e.key == CONF_CLIENT_SECRET)
+    assert entry.type == ConfigEntryType.SECURE_STRING
+
+
+async def test_streamlink_token_is_optional_secure_string(mass_mock: Mock) -> None:
+    """streamlink_token is SECURE_STRING, not required."""
+    entries = await get_config_entries(mass_mock)
+    entry = next(e for e in entries if e.key == CONF_STREAMLINK_TOKEN)
+    assert entry.type == ConfigEntryType.SECURE_STRING
+    assert entry.required is False
+
+
+async def test_auth_action_present(mass_mock: Mock) -> None:
+    """An ACTION type config entry exists for triggering OAuth."""
+    entries = await get_config_entries(mass_mock)
+    action_entries = [e for e in entries if e.type == ConfigEntryType.ACTION]
+    action_keys = {e.action for e in action_entries}
+    assert CONF_ACTION_AUTH in action_keys
+
+
+async def test_auth_status_label_present(mass_mock: Mock) -> None:
+    """A LABEL type config entry exists showing auth status."""
+    entries = await get_config_entries(mass_mock)
+    label_entries = [e for e in entries if e.type == ConfigEntryType.LABEL]
+    assert len(label_entries) >= 1
+
+
+async def test_not_authenticated_label(mass_mock: Mock) -> None:
+    """Before auth, label shows 'Not authenticated'."""
+    entries = await get_config_entries(mass_mock)
+    label_entries = [e for e in entries if e.type == ConfigEntryType.LABEL]
+    label_text = " ".join(e.label for e in label_entries).lower()
+    assert "not authenticated" in label_text
+
+
+async def test_authenticated_label(mass_mock: Mock) -> None:
+    """After auth, label shows 'Authenticated' (not 'Not authenticated')."""
+    values: dict[str, Any] = {
+        CONF_ACCESS_TOKEN: "test_access_token",
+        CONF_REFRESH_TOKEN: "test_refresh_token",
+    }
+    entries = await get_config_entries(mass_mock, values=values)
+    label_entries = [e for e in entries if e.type == ConfigEntryType.LABEL]
+    label_text = " ".join(e.label for e in label_entries).lower()
+    assert "authenticated" in label_text
+    assert "not authenticated" not in label_text
+
+
+async def test_revoke_action_hidden_when_not_authenticated(mass_mock: Mock) -> None:
+    """Revoke action is hidden when not authenticated."""
+    entries = await get_config_entries(mass_mock)
+    revoke_entries = [e for e in entries if e.action == CONF_ACTION_REVOKE]
+    assert all(e.hidden for e in revoke_entries)
+
+
+async def test_revoke_action_visible_when_authenticated(mass_mock: Mock) -> None:
+    """Revoke action is visible when authenticated."""
+    values: dict[str, Any] = {
+        CONF_ACCESS_TOKEN: "test_access_token",
+        CONF_REFRESH_TOKEN: "test_refresh_token",
+    }
+    entries = await get_config_entries(mass_mock, values=values)
+    revoke_entries = [e for e in entries if e.action == CONF_ACTION_REVOKE]
+    assert any(not e.hidden for e in revoke_entries)
+
+
+# --- Config Validation — Bad/Missing Values ---
+
+
+async def test_empty_client_id_provider_loads(
+    mass_mock: Mock, manifest_mock: Mock, config_mock: Mock
+) -> None:
+    """Provider loads without crash when client_id is empty."""
+    provider = TwitchProvider(mass_mock, manifest_mock, config_mock)
+    # Should not raise
+    assert provider is not None
+
+
+async def test_empty_credentials_shows_not_authenticated(mass_mock: Mock) -> None:
+    """Config label shows 'Not authenticated' state, not error/crash."""
+    values: dict[str, Any] = {
+        CONF_CLIENT_ID: "",
+        CONF_CLIENT_SECRET: "",
+    }
+    entries = await get_config_entries(mass_mock, values=values)
+    label_entries = [e for e in entries if e.type == ConfigEntryType.LABEL]
+    label_text = " ".join(e.label for e in label_entries).lower()
+    assert "not authenticated" in label_text
+
+
+# --- OAuth Flow ---
+
+
+async def test_auth_action_with_empty_client_id(mass_mock: Mock) -> None:
+    """Authenticate with no client_id raises clear error, not crash."""
+    values: dict[str, Any] = {
+        "session_id": "test_session",
+        CONF_CLIENT_ID: "",
+        CONF_CLIENT_SECRET: "secret",
+    }
+    with pytest.raises(LoginFailed, match=r"(?i)client"):
+        await get_config_entries(mass_mock, action=CONF_ACTION_AUTH, values=values)
+
+
+async def test_auth_action_with_empty_client_secret(mass_mock: Mock) -> None:
+    """Authenticate with no client_secret raises clear error, not crash."""
+    values: dict[str, Any] = {
+        "session_id": "test_session",
+        CONF_CLIENT_ID: "client_id",
+        CONF_CLIENT_SECRET: "",
+    }
+    with pytest.raises(LoginFailed, match=r"(?i)client"):
+        await get_config_entries(mass_mock, action=CONF_ACTION_AUTH, values=values)
+
+
+async def test_auth_callback_exchanges_code_for_tokens(mass_mock: Mock) -> None:
+    """Happy-path OAuth: code exchanged for tokens, provider becomes authenticated."""
+    values: dict[str, Any] = {
+        "session_id": "test_session",
+        CONF_CLIENT_ID: "test_client",
+        CONF_CLIENT_SECRET: "test_secret",
+    }
+
+    mock_auth = AsyncMock()
+    mock_auth.__aenter__ = AsyncMock(return_value=mock_auth)
+    mock_auth.__aexit__ = AsyncMock(return_value=None)
+    mock_auth.callback_url = "http://localhost:8095/callback/test"
+    mock_auth.authenticate = AsyncMock(return_value={"code": "valid_code"})
+
+    # Token exchange succeeds — use fixture for response data
+    fixture = load_fixture("token_exchange.json")
+    mass_mock.http_session.post = make_mock_session_method(
+        MockResponse(status=200, json_data=fixture)
+    )
+
+    with patch("music_assistant.providers.twitch.AuthenticationHelper", return_value=mock_auth):
+        entries = await get_config_entries(mass_mock, action=CONF_ACTION_AUTH, values=values)
+
+    # Tokens should be stored in values
+    assert values[CONF_ACCESS_TOKEN] == fixture["access_token"]
+    assert values[CONF_REFRESH_TOKEN] == fixture["refresh_token"]
+
+    # Config entries should show authenticated state
+    label_entries = [e for e in entries if e.type == ConfigEntryType.LABEL]
+    label_text = " ".join(e.label for e in label_entries).lower()
+    assert "not authenticated" not in label_text
+
+
+async def test_auth_scope_includes_user_read_follows() -> None:
+    """OAuth scope includes user:read:follows."""
+    assert "user:read:follows" in TWITCH_SCOPES
+
+
+# --- Token Refresh ---
+
+
+async def test_401_triggers_refresh(provider: TwitchProvider) -> None:
+    """API call returning 401 triggers token refresh, then retries."""
+    provider._access_token = "expired_token"
+    provider._refresh_token = "valid_refresh"
+    provider._client_id = "test_client_id"
+    provider._client_secret = "test_client_secret"
+
+    provider.mass.http_session.get = make_mock_session_method(  # type: ignore[method-assign]
+        [
+            MockResponse(status=401),
+            MockResponse(status=200, json_data={"data": []}),
+        ]
+    )
+    fixture = load_fixture("token_refresh.json")
+    provider.mass.http_session.post = make_mock_session_method(  # type: ignore[method-assign]
+        MockResponse(status=200, json_data=fixture)
+    )
+
+    result = await provider._api_get("/helix/streams")
+    assert result == {"data": []}
+
+
+async def test_refresh_no_refresh_token_raises(provider: TwitchProvider) -> None:
+    """Refresh with no stored refresh token raises LoginFailed."""
+    provider._access_token = "some_token"
+    provider._refresh_token = None
+    provider._client_id = "test_client"
+    provider._client_secret = "test_secret"
+
+    with pytest.raises(LoginFailed, match=r"(?i)refresh"):
+        await provider._refresh_access_token()
+
+
+async def test_refresh_saves_new_refresh_token(provider: TwitchProvider) -> None:
+    """When refresh response includes new refresh_token, it's saved (token rotation)."""
+    provider._access_token = "old_access"
+    provider._refresh_token = "old_refresh"
+    provider._client_id = "test_client_id"
+    provider._client_secret = "test_client_secret"
+
+    fixture = load_fixture("token_refresh.json")
+    provider.mass.http_session.post = make_mock_session_method(  # type: ignore[method-assign]
+        MockResponse(status=200, json_data=fixture)
+    )
+
+    await provider._refresh_access_token()
+    assert provider._access_token == fixture["access_token"]
+    assert provider._refresh_token == fixture["refresh_token"]
+
+
+async def test_refresh_preserves_old_refresh_token_if_not_rotated(
+    provider: TwitchProvider,
+) -> None:
+    """When refresh response omits refresh_token, old one is preserved."""
+    provider._access_token = "old_access"
+    provider._refresh_token = "old_refresh"
+    provider._client_id = "test_client_id"
+    provider._client_secret = "test_client_secret"
+
+    # Use fixture but remove refresh_token to simulate non-rotation response
+    fixture = load_fixture("token_refresh.json")
+    fixture_no_rotate = {k: v for k, v in fixture.items() if k != "refresh_token"}
+    provider.mass.http_session.post = make_mock_session_method(  # type: ignore[method-assign]
+        MockResponse(status=200, json_data=fixture_no_rotate)
+    )
+
+    await provider._refresh_access_token()
+    assert provider._access_token == fixture["access_token"]
+    assert provider._refresh_token == "old_refresh"
+
+
+async def test_refresh_failure_raises_login_failed(provider: TwitchProvider) -> None:
+    """On refresh failure, LoginFailed is raised."""
+    provider._access_token = "old_access"
+    provider._refresh_token = "old_refresh"
+    provider._client_id = "test_client_id"
+    provider._client_secret = "test_client_secret"
+
+    provider.mass.http_session.post = make_mock_session_method(  # type: ignore[method-assign]
+        MockResponse(status=401, text_data="Invalid refresh token")
+    )
+
+    with pytest.raises(LoginFailed):
+        await provider._refresh_access_token()
+
+
+async def test_refresh_failure_clears_both_tokens(provider: TwitchProvider) -> None:
+    """On refresh failure, both access and refresh tokens are cleared."""
+    provider._access_token = "old_access"
+    provider._refresh_token = "old_refresh"
+    provider._client_id = "test_client_id"
+    provider._client_secret = "test_client_secret"
+
+    provider.mass.http_session.post = make_mock_session_method(  # type: ignore[method-assign]
+        MockResponse(status=401, text_data="Invalid refresh token")
+    )
+
+    cleared_access: str | None = "sentinel"
+    cleared_refresh: str | None = "sentinel"
+    try:
+        await provider._refresh_access_token()
+    except LoginFailed:
+        cleared_access = provider._access_token
+        cleared_refresh = provider._refresh_token
+
+    assert cleared_access is None
+    assert cleared_refresh is None
+
+
+# --- Token Exchange Errors ---
+
+
+async def test_token_exchange_fails_invalid_code(mass_mock: Mock) -> None:
+    """Twitch rejects authorization code — LoginFailed raised."""
+    values: dict[str, Any] = {
+        "session_id": "test_session",
+        CONF_CLIENT_ID: "test_client",
+        CONF_CLIENT_SECRET: "test_secret",
+    }
+
+    # Mock AuthenticationHelper to return a code
+    mock_auth = AsyncMock()
+    mock_auth.__aenter__ = AsyncMock(return_value=mock_auth)
+    mock_auth.__aexit__ = AsyncMock(return_value=None)
+    mock_auth.callback_url = "http://localhost:8095/callback/test"
+    mock_auth.authenticate = AsyncMock(return_value={"code": "bad_code"})
+
+    # Token exchange fails
+    mass_mock.http_session.post = make_mock_session_method(
+        MockResponse(status=400, text_data="Invalid authorization code")
+    )
+
+    with (
+        patch("music_assistant.providers.twitch.AuthenticationHelper", return_value=mock_auth),
+        pytest.raises(LoginFailed),
+    ):
+        await get_config_entries(mass_mock, action=CONF_ACTION_AUTH, values=values)
+
+
+async def test_token_exchange_fails_network_error(mass_mock: Mock) -> None:
+    """Network failure during token exchange — LoginFailed raised."""
+    values: dict[str, Any] = {
+        "session_id": "test_session",
+        CONF_CLIENT_ID: "test_client",
+        CONF_CLIENT_SECRET: "test_secret",
+    }
+
+    mock_auth = AsyncMock()
+    mock_auth.__aenter__ = AsyncMock(return_value=mock_auth)
+    mock_auth.__aexit__ = AsyncMock(return_value=None)
+    mock_auth.callback_url = "http://localhost:8095/callback/test"
+    mock_auth.authenticate = AsyncMock(return_value={"code": "valid_code"})
+
+    def raise_error(*_args: Any, **_kwargs: Any) -> None:
+        msg = "connection refused"
+        raise ConnectionError(msg)
+
+    mass_mock.http_session.post = Mock(side_effect=raise_error)
+
+    with (
+        patch("music_assistant.providers.twitch.AuthenticationHelper", return_value=mock_auth),
+        pytest.raises(ConnectionError),
+    ):
+        await get_config_entries(mass_mock, action=CONF_ACTION_AUTH, values=values)
+
+
+# --- Logout / Revoke ---
+
+
+async def test_revoke_noop_when_not_authenticated(mass_mock: Mock) -> None:
+    """Revoke with no tokens is a no-op — no API call made."""
+    values: dict[str, Any] = {
+        "session_id": "test_session",
+        CONF_ACCESS_TOKEN: "",
+        CONF_REFRESH_TOKEN: "",
+        CONF_CLIENT_ID: "test_client",
+    }
+    mass_mock.http_session.post = make_mock_session_method(MockResponse(status=200))
+
+    await get_config_entries(mass_mock, action=CONF_ACTION_REVOKE, values=values)
+
+    # post should NOT have been called — no token to revoke
+    mass_mock.http_session.post.assert_not_called()
+
+
+async def test_revoke_invalidates_live_status_cache(mass_mock: Mock) -> None:
+    """After revoke, tokens are cleared in the values dict (no cached auth state survives)."""
+    values: dict[str, Any] = {
+        "session_id": "test_session",
+        CONF_ACCESS_TOKEN: "test_token",
+        CONF_REFRESH_TOKEN: "test_refresh",
+        CONF_CLIENT_ID: "test_client",
+    }
+    mass_mock.http_session.post = make_mock_session_method(MockResponse(status=200))
+
+    await get_config_entries(mass_mock, action=CONF_ACTION_REVOKE, values=values)
+
+    # Values dict should have tokens cleared
+    assert values[CONF_ACCESS_TOKEN] == ""
+    assert values[CONF_REFRESH_TOKEN] == ""
+
+
+async def test_revoke_action_clears_tokens(mass_mock: Mock) -> None:
+    """Revoke action clears stored tokens."""
+    values: dict[str, Any] = {
+        "session_id": "test_session",
+        CONF_ACCESS_TOKEN: "test_token",
+        CONF_REFRESH_TOKEN: "test_refresh",
+        CONF_CLIENT_ID: "test_client",
+    }
+    mass_mock.http_session.post = make_mock_session_method(MockResponse(status=200))
+
+    entries = await get_config_entries(mass_mock, action=CONF_ACTION_REVOKE, values=values)
+    token_entries = [e for e in entries if e.key == CONF_ACCESS_TOKEN]
+    if token_entries:
+        assert token_entries[0].value in (None, "")
+    refresh_entries = [e for e in entries if e.key == CONF_REFRESH_TOKEN]
+    if refresh_entries:
+        assert refresh_entries[0].value in (None, "")
+
+
+async def test_revoke_tolerates_network_error(mass_mock: Mock) -> None:
+    """Network error during revoke still clears local credentials."""
+    values: dict[str, Any] = {
+        "session_id": "test_session",
+        CONF_ACCESS_TOKEN: "test_token",
+        CONF_REFRESH_TOKEN: "test_refresh",
+        CONF_CLIENT_ID: "test_client",
+    }
+
+    def raise_error(*args: Any, **kwargs: Any) -> None:  # noqa: ARG001
+        msg = "network error"
+        raise ConnectionError(msg)
+
+    mass_mock.http_session.post = Mock(side_effect=raise_error)
+
+    # Should not raise — revoke is best-effort
+    entries = await get_config_entries(mass_mock, action=CONF_ACTION_REVOKE, values=values)
+    token_entries = [e for e in entries if e.key == CONF_ACCESS_TOKEN]
+    if token_entries:
+        assert token_entries[0].value in (None, "")
diff --git a/tests/providers/twitch/test_browse.py b/tests/providers/twitch/test_browse.py
new file mode 100644
index 0000000000..3e5634b112
--- /dev/null
+++ b/tests/providers/twitch/test_browse.py
@@ -0,0 +1,428 @@
+"""Test Twitch Provider browse, library radios, and search."""
+
+from __future__ import annotations
+
+import pytest
+from music_assistant_models.enums import MediaType
+from music_assistant_models.errors import ProviderUnavailableError
+from music_assistant_models.media_items import BrowseFolder, Radio
+
+from music_assistant.providers.twitch import TwitchProvider
+from tests.providers.twitch.conftest import (
+    MockResponse,
+    load_fixture,
+    make_mock_session_method,
+)
+
+
+def _users_response() -> MockResponse:
+    """Return a mock users API response."""
+    return MockResponse(status=200, json_data=load_fixture("user_lookup.json"))
+
+
+def _setup_authenticated_provider(provider: TwitchProvider) -> None:
+    """Configure provider with test credentials and cached data."""
+    provider._access_token = "test_token"
+    provider._client_id = "test_client"
+    provider._user_id = "99"
+
+
+# --- Library Radios (Live Only) ---
+
+
+async def test_library_radios_yields_radio_items(provider: TwitchProvider) -> None:
+    """get_library_radios() yields Radio objects."""
+    _setup_authenticated_provider(provider)
+
+    fixture_channels = load_fixture("followed_channels.json")
+    fixture_streams = load_fixture("live_streams.json")
+
+    provider.mass.http_session.get = make_mock_session_method(  # type: ignore[method-assign]
+        [
+            MockResponse(status=200, json_data=fixture_channels["page2"]),
+            MockResponse(status=200, json_data=fixture_streams),
+            _users_response(),
+        ]
+    )
+
+    radios = [item async for item in provider.get_library_radios()]
+    assert len(radios) > 0
+    assert all(isinstance(r, Radio) for r in radios)
+
+
+async def test_library_radios_only_live(provider: TwitchProvider) -> None:
+    """Offline followed channels are not yielded."""
+    _setup_authenticated_provider(provider)
+
+    # Channel B (456) is followed but not in live_streams fixture
+    fixture_channels = load_fixture("followed_channels.json")
+    fixture_streams = load_fixture("live_streams.json")
+
+    provider.mass.http_session.get = make_mock_session_method(  # type: ignore[method-assign]
+        [
+            # Return all 3 channels (page1 + page2)
+            MockResponse(status=200, json_data=fixture_channels["page1"]),
+            MockResponse(status=200, json_data=fixture_channels["page2"]),
+            # Only streamer_a and streamer_c are live
+            MockResponse(status=200, json_data=fixture_streams),
+            _users_response(),
+        ]
+    )
+
+    radios = [item async for item in provider.get_library_radios()]
+    logins = [r.item_id for r in radios]
+    assert "streamer_a" in logins
+    assert "streamer_c" in logins
+    assert "streamer_b" not in logins  # offline — not in library
+
+
+async def test_library_radios_item_fields(provider: TwitchProvider) -> None:
+    """Each Radio has correct item_id (login), name (display_name), provider."""
+    _setup_authenticated_provider(provider)
+
+    fixture_channels = load_fixture("followed_channels.json")
+    fixture_streams = load_fixture("live_streams.json")
+
+    provider.mass.http_session.get = make_mock_session_method(  # type: ignore[method-assign]
+        [
+            MockResponse(status=200, json_data=fixture_channels["page2"]),
+            MockResponse(status=200, json_data=fixture_streams),
+            _users_response(),
+        ]
+    )
+
+    radios = [item async for item in provider.get_library_radios()]
+    assert len(radios) > 0
+    radio = radios[0]
+    assert radio.item_id  # has an ID
+    assert radio.name  # has a name
+    assert radio.provider == provider.domain
+
+
+async def test_library_radios_empty_when_none_live(provider: TwitchProvider) -> None:
+    """Returns empty when no followed channels are live."""
+    _setup_authenticated_provider(provider)
+
+    fixture_channels = load_fixture("followed_channels.json")
+
+    provider.mass.http_session.get = make_mock_session_method(  # type: ignore[method-assign]
+        [
+            MockResponse(status=200, json_data=fixture_channels["page2"]),
+            MockResponse(status=200, json_data={"data": []}),  # nobody live
+            _users_response(),
+        ]
+    )
+
+    radios = [item async for item in provider.get_library_radios()]
+    assert radios == []
+
+
+async def test_library_radios_requires_auth(provider: TwitchProvider) -> None:
+    """Returns empty when not authenticated."""
+    provider._access_token = None
+    provider._user_id = None
+
+    radios = [item async for item in provider.get_library_radios()]
+    assert radios == []
+
+
+# --- Browse Structure ---
+
+
+async def test_browse_root_returns_two_folders(provider: TwitchProvider) -> None:
+    """browse("") returns "Live" and "Following" BrowseFolder items."""
+    _setup_authenticated_provider(provider)
+    items = await provider.browse("")
+    folder_names = [f.name for f in items if isinstance(f, BrowseFolder)]
+    assert "Live" in folder_names
+    assert "Following" in folder_names
+
+
+async def test_browse_live_returns_only_live_channels(provider: TwitchProvider) -> None:
+    """browse("Live") returns only currently live channels."""
+    _setup_authenticated_provider(provider)
+
+    fixture_channels = load_fixture("followed_channels.json")
+    fixture_streams = load_fixture("live_streams.json")
+
+    provider.mass.http_session.get = make_mock_session_method(  # type: ignore[method-assign]
+        [
+            MockResponse(status=200, json_data=fixture_channels["page1"]),
+            MockResponse(status=200, json_data=fixture_channels["page2"]),
+            MockResponse(status=200, json_data=fixture_streams),
+            _users_response(),
+        ]
+    )
+
+    items = await provider.browse(f"{provider.instance_id}://live")
+    assert len(items) > 0
+    # Should only contain live channels
+    item_ids = [getattr(r, "item_id", None) for r in items]
+    assert "streamer_a" in item_ids
+    assert "streamer_b" not in item_ids  # offline
+
+
+async def test_browse_following_returns_all_channels(
+    provider: TwitchProvider,
+) -> None:
+    """browse("Following") returns all followed channels."""
+    _setup_authenticated_provider(provider)
+
+    fixture_channels = load_fixture("followed_channels.json")
+    fixture_streams = load_fixture("live_streams.json")
+
+    provider.mass.http_session.get = make_mock_session_method(  # type: ignore[method-assign]
+        [
+            MockResponse(status=200, json_data=fixture_channels["page1"]),
+            MockResponse(status=200, json_data=fixture_channels["page2"]),
+            MockResponse(status=200, json_data=fixture_streams),
+            _users_response(),
+        ]
+    )
+
+    items = await provider.browse(f"{provider.instance_id}://following")
+    item_ids = [getattr(r, "item_id", None) for r in items]
+    assert "streamer_a" in item_ids
+    assert "streamer_b" in item_ids  # offline but still in Following
+    assert "streamer_c" in item_ids
+
+
+async def test_browse_following_marks_offline(provider: TwitchProvider) -> None:
+    """Offline channels in Following browse have '(offline)' in name."""
+    _setup_authenticated_provider(provider)
+
+    fixture_channels = load_fixture("followed_channels.json")
+    fixture_streams = load_fixture("live_streams.json")
+
+    provider.mass.http_session.get = make_mock_session_method(  # type: ignore[method-assign]
+        [
+            MockResponse(status=200, json_data=fixture_channels["page1"]),
+            MockResponse(status=200, json_data=fixture_channels["page2"]),
+            MockResponse(status=200, json_data=fixture_streams),
+            _users_response(),
+        ]
+    )
+
+    items = await provider.browse(f"{provider.instance_id}://following")
+    # Find streamer_b (offline)
+    offline_items = [r for r in items if getattr(r, "item_id", None) == "streamer_b"]
+    assert len(offline_items) == 1
+    assert "(offline)" in offline_items[0].name.lower()
+
+
+async def test_browse_invalid_path_returns_empty(provider: TwitchProvider) -> None:
+    """Unknown browse path returns empty list."""
+    _setup_authenticated_provider(provider)
+    items = await provider.browse(f"{provider.instance_id}://nonexistent")
+    assert items == []
+
+
+# --- Search ---
+
+
+async def test_search_returns_matching_channels(provider: TwitchProvider) -> None:
+    """search() returns channels matching query from Twitch search API."""
+    _setup_authenticated_provider(provider)
+
+    fixture = load_fixture("search_results.json")
+    provider.mass.http_session.get = make_mock_session_method(  # type: ignore[method-assign]
+        MockResponse(status=200, json_data=fixture)
+    )
+
+    results = await provider.search("streamer", [MediaType.RADIO])
+    assert len(results.radio) > 0
+
+
+async def test_search_results_are_radio_type(provider: TwitchProvider) -> None:
+    """Search results contain Radio items."""
+    _setup_authenticated_provider(provider)
+
+    fixture = load_fixture("search_results.json")
+    provider.mass.http_session.get = make_mock_session_method(  # type: ignore[method-assign]
+        MockResponse(status=200, json_data=fixture)
+    )
+
+    results = await provider.search("streamer", [MediaType.RADIO])
+    assert all(isinstance(r, Radio) for r in results.radio)
+
+
+async def test_search_respects_limit(provider: TwitchProvider) -> None:
+    """Search passes limit param to API."""
+    _setup_authenticated_provider(provider)
+
+    fixture = load_fixture("search_results.json")
+    provider.mass.http_session.get = make_mock_session_method(  # type: ignore[method-assign]
+        MockResponse(status=200, json_data=fixture)
+    )
+
+    await provider.search("streamer", [MediaType.RADIO], limit=3)
+
+    call_kwargs = provider.mass.http_session.get.call_args
+    params = call_kwargs.kwargs.get("params", {})
+    assert params.get("first") == "3"
+
+
+async def test_search_empty_query_returns_empty(provider: TwitchProvider) -> None:
+    """Empty search query returns empty results."""
+    _setup_authenticated_provider(provider)
+
+    results = await provider.search("", [MediaType.RADIO])
+    assert len(results.radio) == 0
+
+
+async def test_search_api_error_returns_empty(provider: TwitchProvider) -> None:
+    """API failure during search returns empty, not exception."""
+    _setup_authenticated_provider(provider)
+
+    provider.mass.http_session.get = make_mock_session_method(  # type: ignore[method-assign]
+        MockResponse(status=500)
+    )
+
+    results = await provider.search("test", [MediaType.RADIO])
+    assert len(results.radio) == 0
+
+
+async def test_search_unauthenticated_returns_empty(provider: TwitchProvider) -> None:
+    """Search without auth returns empty result, not crash."""
+    provider._access_token = None
+
+    results = await provider.search("test", [MediaType.RADIO])
+    assert len(results.radio) == 0
+
+
+# --- Browse Metadata ---
+
+
+async def test_library_radios_includes_thumbnail(provider: TwitchProvider) -> None:
+    """Radio metadata includes channel thumbnail image."""
+    _setup_authenticated_provider(provider)
+
+    fixture_channels = load_fixture("followed_channels.json")
+    fixture_streams = load_fixture("live_streams.json")
+
+    provider.mass.http_session.get = make_mock_session_method(  # type: ignore[method-assign]
+        [
+            MockResponse(status=200, json_data=fixture_channels["page2"]),
+            MockResponse(status=200, json_data=fixture_streams),
+            _users_response(),
+        ]
+    )
+
+    radios = [item async for item in provider.get_library_radios()]
+    assert len(radios) > 0
+    # At least one radio should have images in metadata
+    assert any(r.metadata.images for r in radios)
+
+
+async def test_library_radios_includes_viewer_count(provider: TwitchProvider) -> None:
+    """Radio name includes viewer count for live channels."""
+    _setup_authenticated_provider(provider)
+
+    fixture_channels = load_fixture("followed_channels.json")
+    fixture_streams = load_fixture("live_streams.json")
+
+    provider.mass.http_session.get = make_mock_session_method(  # type: ignore[method-assign]
+        [
+            MockResponse(status=200, json_data=fixture_channels["page2"]),
+            MockResponse(status=200, json_data=fixture_streams),
+            _users_response(),
+        ]
+    )
+
+    radios = [item async for item in provider.get_library_radios()]
+    assert len(radios) > 0
+    # Live radios should have viewer count in name
+    assert any("viewers" in r.name.lower() for r in radios)
+
+
+async def test_browse_live_includes_viewer_counts(provider: TwitchProvider) -> None:
+    """Live browse items include viewer count in name."""
+    _setup_authenticated_provider(provider)
+
+    fixture_channels = load_fixture("followed_channels.json")
+    fixture_streams = load_fixture("live_streams.json")
+
+    provider.mass.http_session.get = make_mock_session_method(  # type: ignore[method-assign]
+        [
+            MockResponse(status=200, json_data=fixture_channels["page1"]),
+            MockResponse(status=200, json_data=fixture_channels["page2"]),
+            MockResponse(status=200, json_data=fixture_streams),
+            _users_response(),
+        ]
+    )
+
+    items = await provider.browse(f"{provider.instance_id}://live")
+    assert len(items) > 0
+    assert any("viewers" in getattr(r, "name", "").lower() for r in items)
+
+
+async def test_browse_following_sorts_alphabetically(provider: TwitchProvider) -> None:
+    """Following list sorted by display name."""
+    _setup_authenticated_provider(provider)
+
+    fixture_channels = load_fixture("followed_channels.json")
+    fixture_streams = load_fixture("live_streams.json")
+
+    provider.mass.http_session.get = make_mock_session_method(  # type: ignore[method-assign]
+        [
+            MockResponse(status=200, json_data=fixture_channels["page1"]),
+            MockResponse(status=200, json_data=fixture_channels["page2"]),
+            MockResponse(status=200, json_data=fixture_streams),
+            _users_response(),
+        ]
+    )
+
+    items = await provider.browse(f"{provider.instance_id}://following")
+    names = [getattr(r, "name", "").lower() for r in items]
+    # Strip "(offline)" and "viewers" suffixes for comparison
+    base_names = [n.split(" (")[0] for n in names]
+    assert base_names == sorted(base_names)
+
+
+# --- Browse Error Handling ---
+
+
+async def test_browse_api_error_raises_provider_unavailable(provider: TwitchProvider) -> None:
+    """API failure during browse raises ProviderUnavailableError."""
+    _setup_authenticated_provider(provider)
+
+    provider.mass.http_session.get = make_mock_session_method(  # type: ignore[method-assign]
+        MockResponse(status=500)
+    )
+
+    with pytest.raises(ProviderUnavailableError):
+        await provider.browse(f"{provider.instance_id}://live")
+
+
+async def test_browse_live_empty_when_none_live(provider: TwitchProvider) -> None:
+    """All followed channels offline — Live folder returns empty list."""
+    _setup_authenticated_provider(provider)
+
+    fixture_channels = load_fixture("followed_channels.json")
+
+    provider.mass.http_session.get = make_mock_session_method(  # type: ignore[method-assign]
+        [
+            MockResponse(status=200, json_data=fixture_channels["page2"]),
+            MockResponse(status=200, json_data={"data": []}),  # nobody live
+            _users_response(),
+        ]
+    )
+
+    items = await provider.browse(f"{provider.instance_id}://live")
+    assert items == []
+
+
+async def test_browse_following_empty_when_no_follows(provider: TwitchProvider) -> None:
+    """User follows nobody — Following folder returns empty list."""
+    _setup_authenticated_provider(provider)
+
+    provider.mass.http_session.get = make_mock_session_method(  # type: ignore[method-assign]
+        [
+            MockResponse(status=200, json_data={"data": [], "pagination": {}}),
+            MockResponse(status=200, json_data={"data": []}),
+            MockResponse(status=200, json_data={"data": []}),
+        ]
+    )
+
+    items = await provider.browse(f"{provider.instance_id}://following")
+    assert items == []
diff --git a/tests/providers/twitch/test_eventsub.py b/tests/providers/twitch/test_eventsub.py
new file mode 100644
index 0000000000..d3feb4d703
--- /dev/null
+++ b/tests/providers/twitch/test_eventsub.py
@@ -0,0 +1,457 @@
+"""Test EventSub WebSocket client."""
+
+from __future__ import annotations
+
+import asyncio
+import logging
+from typing import Any
+from unittest.mock import AsyncMock, Mock, patch
+
+import pytest
+
+from music_assistant.providers.twitch.eventsub import EVENTSUB_WS_URL, MAX_BACKOFF, EventSubClient
+from tests.providers.twitch.conftest import MockResponse, load_fixture
+
+
+@pytest.fixture
+def http_session() -> Mock:
+    """Return a mock aiohttp session."""
+    session = Mock()
+    session.ws_connect = AsyncMock()
+    session.post = Mock(
+        return_value=MockResponse(status=202, json_data={"data": [{"id": "sub_999"}]})
+    )
+    session.delete = Mock(return_value=MockResponse(status=204))
+    return session
+
+
+@pytest.fixture
+def client(http_session: Mock) -> EventSubClient:
+    """Return an EventSubClient instance."""
+    return EventSubClient(
+        http_session=http_session,
+        api_headers_fn=lambda: {"Authorization": "Bearer test", "Client-Id": "test_client"},
+    )
+
+
+# --- Connection Lifecycle ---
+
+
+def test_connect_to_default_url() -> None:
+    """Initial connection targets the standard EventSub URL."""
+    assert EVENTSUB_WS_URL == "wss://eventsub.wss.twitch.tv/ws"
+
+
+def test_welcome_stores_session_id(client: EventSubClient) -> None:
+    """session_welcome message stores session ID."""
+    msg = load_fixture("eventsub_welcome.json")
+    client._handle_message(msg)
+    assert client._session_id == "test_session_123"
+
+
+def test_welcome_signals_ready(client: EventSubClient) -> None:
+    """session_welcome sets the ready event."""
+    assert not client._ready.is_set()
+    msg = load_fixture("eventsub_welcome.json")
+    client._handle_message(msg)
+    assert client._ready.is_set()
+
+
+def test_welcome_resets_backoff(client: EventSubClient) -> None:
+    """Successful welcome resets backoff to 1.0s."""
+    client._backoff = 32.0
+    msg = load_fixture("eventsub_welcome.json")
+    client._handle_message(msg)
+    assert client._backoff == 1.0
+
+
+async def test_stop_prevents_reconnect(client: EventSubClient) -> None:
+    """After stop(), _stopped flag is set."""
+    await client.stop()
+    assert client._stopped is True
+
+
+async def test_stop_clears_session_state(client: EventSubClient) -> None:
+    """stop() clears session_id, subscriptions, ready event."""
+    client._session_id = "test"
+    client._subscriptions = {"123": "sub_1"}
+    client._ready.set()
+
+    await client.stop()
+
+    assert client._session_id is None
+    assert len(client._subscriptions) == 0  # type: ignore[unreachable]
+    assert not client._ready.is_set()
+
+
+async def test_disconnect_triggers_reconnect(client: EventSubClient) -> None:
+    """WebSocket disconnect increases backoff, indicating reconnect will happen."""
+    initial_backoff = client._backoff
+    assert initial_backoff == 1.0
+
+    client._backoff = min(client._backoff * 2, MAX_BACKOFF)
+    assert client._backoff == 2.0
+
+    welcome = load_fixture("eventsub_welcome.json")
+    client._handle_message(welcome)
+    assert client._backoff == 1.0
+
+    assert client._stopped is False
+    await client.stop()
+    assert client._stopped is True
+
+
+# --- Twitch-Requested Reconnect ---
+
+
+def test_reconnect_message_stores_url(client: EventSubClient) -> None:
+    """session_reconnect stores the new URL."""
+    msg = load_fixture("eventsub_reconnect.json")
+    client._handle_message(msg)
+    assert client._reconnect_url == "wss://eventsub.wss.twitch.tv/ws?reconnect=true"
+
+
+def test_reconnect_url_consumed(client: EventSubClient) -> None:
+    """Reconnect URL is stored and available for next connect attempt."""
+    msg = load_fixture("eventsub_reconnect.json")
+    client._handle_message(msg)
+    assert client._reconnect_url is not None
+    url = client._reconnect_url
+    assert url == "wss://eventsub.wss.twitch.tv/ws?reconnect=true"
+
+
+# --- Re-subscription on Reconnect ---
+
+
+async def test_welcome_clears_old_subscriptions(client: EventSubClient) -> None:
+    """Welcome clears old subscription IDs (they're invalid on new session)."""
+    client._subscriptions = {"123": "old_sub"}
+    msg = load_fixture("eventsub_welcome.json")
+    client._handle_message(msg)
+    # Old subs cleared synchronously — re-subscription tasks created but not yet run
+    assert len(client._subscriptions) == 0
+
+
+def test_welcome_does_not_resubscribe_if_no_active(
+    client: EventSubClient,
+) -> None:
+    """If no active broadcasters, welcome does not create subscription tasks."""
+    client._subscriptions = {}
+    msg = load_fixture("eventsub_welcome.json")
+    client._handle_message(msg)
+    assert client._ready.is_set()
+
+
+# --- Subscription Management ---
+
+
+async def test_post_includes_auth_headers(client: EventSubClient, http_session: Mock) -> None:
+    """POST to EventSub includes Authorization and Client-Id headers."""
+    client._session_id = "test_session"
+    client._ready.set()
+
+    await client.subscribe_raids("123")
+
+    call_kwargs = http_session.post.call_args
+    headers = call_kwargs.kwargs.get("headers", {})
+    assert "Authorization" in headers
+    assert headers["Authorization"] == "Bearer test"
+    assert "Client-Id" in headers
+    assert headers["Client-Id"] == "test_client"
+
+
+async def test_subscribe_creates_raid_subscription(
+    client: EventSubClient, http_session: Mock
+) -> None:
+    """subscribe_raids() calls EventSub create API with correct params."""
+    client._session_id = "test_session"
+    client._ready.set()
+
+    await client.subscribe_raids("123")
+
+    http_session.post.assert_called_once()
+    call_kwargs = http_session.post.call_args
+    assert "eventsub/subscriptions" in call_kwargs.args[0]
+    body = call_kwargs.kwargs["json"]
+    assert body["type"] == "channel.raid"
+    assert body["condition"]["from_broadcaster_user_id"] == "123"
+
+
+async def test_subscribe_stores_subscription(client: EventSubClient) -> None:
+    """subscribe_raids() stores the subscription ID in _subscriptions."""
+    client._session_id = "test_session"
+    client._ready.set()
+
+    await client.subscribe_raids("123")
+
+    assert client._subscriptions["123"] == "sub_999"
+
+
+async def test_subscribe_noop_if_already_subscribed(
+    client: EventSubClient, http_session: Mock
+) -> None:
+    """subscribe_raids() is a no-op for an already-subscribed broadcaster."""
+    client._session_id = "test_session"
+    client._ready.set()
+    client._subscriptions = {"123": "existing_sub"}
+
+    await client.subscribe_raids("123")
+
+    http_session.post.assert_not_called()
+
+
+async def test_subscribe_waits_for_ready(client: EventSubClient, http_session: Mock) -> None:
+    """Subscribe blocks until ready event is set."""
+    client._session_id = "test_session"
+
+    async def set_ready() -> None:
+        await asyncio.sleep(0.01)
+        client._ready.set()
+
+    asyncio.create_task(set_ready())
+
+    await client.subscribe_raids("123")
+    http_session.post.assert_called_once()
+
+
+async def test_unsubscribe_raids_calls_delete_api(
+    client: EventSubClient, http_session: Mock
+) -> None:
+    """unsubscribe_raids() calls EventSub delete API for specific broadcaster."""
+    client._subscriptions = {"123": "sub_123"}
+
+    await client.unsubscribe_raids("123")
+
+    http_session.delete.assert_called_once()
+    assert "123" not in client._subscriptions
+
+
+async def test_unsubscribe_raids_noop_if_not_subscribed(
+    client: EventSubClient, http_session: Mock
+) -> None:
+    """unsubscribe_raids() is a no-op for a non-subscribed broadcaster."""
+    client._subscriptions = {}
+
+    await client.unsubscribe_raids("123")
+
+    http_session.delete.assert_not_called()
+
+
+async def test_unsubscribe_all_clears_all(client: EventSubClient, http_session: Mock) -> None:
+    """unsubscribe_all() unsubscribes all broadcasters."""
+    client._subscriptions = {"123": "sub_1", "456": "sub_2"}
+
+    await client.unsubscribe_all()
+
+    assert http_session.delete.call_count == 2
+    assert len(client._subscriptions) == 0
+
+
+async def test_unsubscribe_all_noop_when_empty(client: EventSubClient, http_session: Mock) -> None:
+    """Unsubscribe with no active subscriptions doesn't call API."""
+    client._subscriptions = {}
+
+    await client.unsubscribe_all()
+
+    http_session.delete.assert_not_called()
+
+
+async def test_unsubscribe_tolerates_api_error(client: EventSubClient, http_session: Mock) -> None:
+    """API error during unsubscribe is logged, not raised."""
+    client._subscriptions = {"123": "sub_123"}
+
+    def raise_err(*_args: Any, **_kwargs: Any) -> None:
+        msg = "API error"
+        raise ConnectionError(msg)
+
+    http_session.delete = Mock(side_effect=raise_err)
+
+    # Should not raise
+    await client.unsubscribe_raids("123")
+    assert "123" not in client._subscriptions
+
+
+# --- Subscription Revocation ---
+
+
+def test_revocation_clears_subscription(client: EventSubClient) -> None:
+    """Revocation message removes the revoked subscription."""
+    client._subscriptions = {"123": "sub_123"}
+    msg = load_fixture("eventsub_revocation.json")
+    # Align the fixture's subscription id with the tracked one so removal matches
+    msg["payload"]["subscription"]["id"] = "sub_123"
+    client._handle_message(msg)
+    assert "123" not in client._subscriptions
+
+
+def test_revocation_logged(client: EventSubClient, caplog: pytest.LogCaptureFixture) -> None:
+    """Revocation is logged as warning."""
+    client._subscriptions = {"123": "sub_123"}
+    msg = load_fixture("eventsub_revocation.json")
+    with caplog.at_level(logging.WARNING):
+        client._handle_message(msg)
+    assert any("revoked" in r.message.lower() for r in caplog.records)
+
+
+# --- Backoff ---
+
+
+def test_backoff_doubles_on_reconnect(client: EventSubClient) -> None:
+    """Consecutive reconnects double backoff: 1->2->4->8->16->32->60->60."""
client._backoff = 1.0 + expected = [2.0, 4.0, 8.0, 16.0, 32.0, 60.0, 60.0] + for exp in expected: + client._backoff = min(client._backoff * 2, 60.0) + assert client._backoff == exp + + +def test_backoff_caps_at_60s(client: EventSubClient) -> None: + """Backoff never exceeds 60s.""" + client._backoff = 60.0 + client._backoff = min(client._backoff * 2, 60.0) + assert client._backoff == 60.0 + + +# --- Twitch-Requested Reconnect (extended) --- + + +async def test_reconnect_message_closes_current_ws(client: EventSubClient) -> None: + """session_reconnect triggers close on current WebSocket.""" + mock_ws = AsyncMock() + client._ws = mock_ws + + msg = load_fixture("eventsub_reconnect.json") + client._handle_message(msg) + + assert client._reconnect_url is not None + await asyncio.sleep(0.01) + mock_ws.close.assert_called_once() + + +def test_reconnect_uses_new_url(client: EventSubClient) -> None: + """After reconnect message, stored URL is used for next connection.""" + msg = load_fixture("eventsub_reconnect.json") + client._handle_message(msg) + + expected_url = "wss://eventsub.wss.twitch.tv/ws?reconnect=true" + assert client._reconnect_url == expected_url + + +# --- Re-subscription on Reconnect (extended) --- + + +async def test_welcome_resubscribes_active_broadcasters( + client: EventSubClient, http_session: Mock +) -> None: + """If subscriptions exist, welcome re-creates them on new session.""" + client._subscriptions = {"123": "old_sub", "456": "old_sub2"} + expected_posts = len(client._subscriptions) + msg = load_fixture("eventsub_welcome.json") + + all_posted = asyncio.Event() + post_count = 0 + original_post = http_session.post + + def counting_post(*args: object, **kwargs: object) -> object: + nonlocal post_count + result = original_post(*args, **kwargs) + post_count += 1 + if post_count >= expected_posts: + all_posted.set() + return result + + http_session.post = Mock(side_effect=counting_post) + + client._handle_message(msg) + + assert client._ready.is_set() + 
assert client._session_id == "test_session_123" + + await asyncio.wait_for(all_posted.wait(), timeout=1.0) + + assert http_session.post.call_count == expected_posts + + +async def test_ready_set_after_resubscribe(client: EventSubClient) -> None: + """Ready event fires after welcome (even with resubscriptions).""" + client._subscriptions = {"123": "old_sub"} + msg = load_fixture("eventsub_welcome.json") + client._handle_message(msg) + assert client._ready.is_set() + + +# --- Subscription Management (extended) --- + + +async def test_subscribe_timeout_when_not_ready(client: EventSubClient) -> None: + """If ready event not set within timeout, subscribe is a no-op.""" + client._session_id = "test_session" + + async def fast_timeout(coro: Any, timeout: float) -> None: # noqa: ARG001 + raise TimeoutError + + with patch( + "music_assistant.providers.twitch.eventsub.asyncio.wait_for", side_effect=fast_timeout + ): + await client.subscribe_raids("123") + + assert "123" not in client._subscriptions + + +async def test_subscribe_skips_if_welcome_already_subscribed( + client: EventSubClient, http_session: Mock +) -> None: + """If welcome handler already created sub while waiting, don't duplicate.""" + client._session_id = "test_session" + + original_wait_for = asyncio.wait_for + + async def wait_that_simulates_welcome(coro: Any, timeout: float) -> Any: + # Simulate the welcome handler creating a sub during the wait + client._subscriptions["123"] = "sub_from_welcome" + client._ready.set() + return await original_wait_for(coro, timeout=timeout) + + with patch( + "music_assistant.providers.twitch.eventsub.asyncio.wait_for", + side_effect=wait_that_simulates_welcome, + ): + await client.subscribe_raids("123") + + # No POST should have been made — the welcome handler's sub was detected + http_session.post.assert_not_called() + assert client._subscriptions["123"] == "sub_from_welcome" + + +# --- Raid Notification --- + + +def test_raid_event_fires_callback(client: EventSubClient) -> 
None: + """channel.raid notification calls the on_raid callback.""" + raids_received: list[tuple[str, str]] = [] + client._on_raid = lambda from_l, to_l: raids_received.append((from_l, to_l)) + + msg = load_fixture("eventsub_raid.json") + client._handle_message(msg) + + assert len(raids_received) == 1 + assert raids_received[0] == ("streamer_a", "streamer_c") + + +def test_non_raid_notification_ignored(client: EventSubClient) -> None: + """Other notification types don't fire callback.""" + raids_received: list[tuple[str, str]] = [] + client._on_raid = lambda from_l, to_l: raids_received.append((from_l, to_l)) + + msg = load_fixture("eventsub_raid.json") + msg["metadata"]["subscription_type"] = "stream.online" + client._handle_message(msg) + + assert len(raids_received) == 0 + + +def test_invalid_json_ignored(client: EventSubClient) -> None: + """Malformed message is ignored (no crash).""" + client._handle_message({}) + client._handle_message({"metadata": {}}) + client._handle_message({"metadata": {"message_type": "unknown_type"}}) diff --git a/tests/providers/twitch/test_provider.py b/tests/providers/twitch/test_provider.py new file mode 100644 index 0000000000..678472b23f --- /dev/null +++ b/tests/providers/twitch/test_provider.py @@ -0,0 +1,134 @@ +"""Test Twitch Provider lifecycle and config.""" + +from unittest.mock import AsyncMock, Mock, patch + +from music_assistant_models.enums import ConfigEntryType, ProviderFeature + +from music_assistant.providers.twitch import ( + CONF_AUTO_RAID, + SUPPORTED_FEATURES, + TwitchProvider, + get_config_entries, + setup, +) +from tests.providers.twitch.conftest import config_side_effect + +# --- Provider Loading --- + + +async def test_setup_returns_provider_instance( + mass_mock: Mock, manifest_mock: Mock, config_mock: Mock +) -> None: + """setup() returns a TwitchProvider instance when authenticated.""" + config_mock.get_value.side_effect = lambda key, default=None: { + "client_id": "test", + "client_secret": "test", + 
"streamlink_token": "", + "auto_raid": True, + "log_level": "GLOBAL", + "access_token": "test_token", + "refresh_token": "test_refresh", + }.get(key, default) + provider = await setup(mass_mock, manifest_mock, config_mock) + assert isinstance(provider, TwitchProvider) + + +async def test_provider_is_streaming_provider(provider: TwitchProvider) -> None: + """is_streaming_provider property returns True.""" + assert provider.is_streaming_provider is True + + +async def test_supported_features_declared() -> None: + """SUPPORTED_FEATURES includes BROWSE, SEARCH, LIBRARY_RADIOS.""" + assert ProviderFeature.BROWSE in SUPPORTED_FEATURES + assert ProviderFeature.SEARCH in SUPPORTED_FEATURES + assert ProviderFeature.LIBRARY_RADIOS in SUPPORTED_FEATURES + + +async def test_supported_features_no_edit() -> None: + """SUPPORTED_FEATURES does not include library edit (Twitch removed follow API).""" + assert ProviderFeature.LIBRARY_RADIOS_EDIT not in SUPPORTED_FEATURES + + +async def test_unload_cleans_up(provider: TwitchProvider) -> None: + """unload() succeeds and cleans up resources.""" + await provider.unload() + + +async def test_unload_with_no_active_resources(provider: TwitchProvider) -> None: + """unload() succeeds when nothing is running (fresh provider, no playback).""" + await provider.unload() + + +async def test_provider_domain(provider: TwitchProvider) -> None: + """Provider domain matches manifest.""" + assert provider.domain == "twitch" + + +async def test_provider_instance_id(provider: TwitchProvider) -> None: + """Provider instance_id comes from config.""" + assert provider.instance_id == "twitch_test" + + +# --- Provider Initialization --- + + +async def test_handle_async_init_resolves_user_id(provider: TwitchProvider) -> None: + """handle_async_init() with valid access_token calls /helix/users and stores _user_id.""" + provider.config.get_value.side_effect = config_side_effect( # type: ignore[attr-defined] + { + "client_id": "test_client", + "client_secret": 
"test_secret", + "access_token": "test_token", + "refresh_token": "test_refresh", + } + ) + + with ( + patch.object( + provider, + "_api_get", + new_callable=AsyncMock, + return_value={"data": [{"id": "12345", "login": "testuser"}]}, + ), + patch("music_assistant.providers.twitch.ad_handling.patch_ad_handling"), + ): + await provider.handle_async_init() + + assert provider._user_id == "12345" + + +async def test_handle_async_init_no_token_skips(provider: TwitchProvider) -> None: + """handle_async_init() with no access_token does not call API.""" + provider.config.get_value.side_effect = config_side_effect( # type: ignore[attr-defined] + { + "client_id": "test_client", + "client_secret": "test_secret", + } + ) + + with ( + patch.object(provider, "_api_get", new_callable=AsyncMock) as mock_api, + patch("music_assistant.providers.twitch.ad_handling.patch_ad_handling"), + ): + await provider.handle_async_init() + + mock_api.assert_not_called() + assert provider._user_id is None + + +# --- Config Entries --- + + +async def test_config_entries_includes_auto_raid(mass_mock: Mock) -> None: + """get_config_entries() includes auto_raid config entry.""" + entries = await get_config_entries(mass_mock) + keys = {e.key for e in entries} + assert CONF_AUTO_RAID in keys + + +async def test_auto_raid_is_boolean(mass_mock: Mock) -> None: + """auto_raid config entry is BOOLEAN type.""" + entries = await get_config_entries(mass_mock) + entry = next(e for e in entries if e.key == CONF_AUTO_RAID) + assert entry.type == ConfigEntryType.BOOLEAN diff --git a/tests/providers/twitch/test_raid.py b/tests/providers/twitch/test_raid.py new file mode 100644 index 0000000000..5c6b89afb3 --- /dev/null +++ b/tests/providers/twitch/test_raid.py @@ -0,0 +1,455 @@ +"""Test raid following: ref-counted streams, multi-queue support.""" + +from __future__ import annotations + +import asyncio +import time +from unittest.mock import AsyncMock, Mock, patch + +from music_assistant_models.enums import 
PlaybackState + +from music_assistant.providers.twitch import TwitchProvider + +# --- Stream Tracking (ref counting) --- + + +async def test_track_stream_start_increments(provider: TwitchProvider) -> None: + """First stream for a channel sets count to 1.""" + provider._active_streams = {} + provider._unsubscribe_timers = {} + provider._auto_raid = False # avoid EventSub side effects + + provider._track_stream_start("streamer_a") + + assert provider._active_streams["streamer_a"] == 1 + + +async def test_track_stream_start_increments_multiple(provider: TwitchProvider) -> None: + """Second stream for same channel increments to 2.""" + provider._active_streams = {"streamer_a": 1} + provider._unsubscribe_timers = {} + provider._auto_raid = False + + provider._track_stream_start("streamer_a") + + assert provider._active_streams["streamer_a"] == 2 + + +async def test_track_stream_start_cancels_pending_unsubscribe(provider: TwitchProvider) -> None: + """Starting a stream cancels any pending delayed unsubscribe for that channel.""" + provider._active_streams = {} + provider._auto_raid = False + mock_timer = Mock() + mock_timer.cancel = Mock() + provider._unsubscribe_timers = {"streamer_a": mock_timer} + + provider._track_stream_start("streamer_a") + + mock_timer.cancel.assert_called_once() + assert "streamer_a" not in provider._unsubscribe_timers + + +async def test_track_stream_start_subscribes_on_first(provider: TwitchProvider) -> None: + """First stream (0->1) triggers EventSub subscription.""" + provider._active_streams = {} + provider._unsubscribe_timers = {} + provider._auto_raid = True + provider._access_token = "test" + provider._client_id = "test" + provider._eventsub = Mock() + provider._eventsub.subscribe_raids = AsyncMock() + provider._eventsub.start = AsyncMock() + + subscribed = asyncio.Event() + original_subscribe = provider._eventsub.subscribe_raids + + async def subscribe_and_signal(*args: object, **kwargs: object) -> None: + await original_subscribe(*args, 
**kwargs) + subscribed.set() + + provider._eventsub.subscribe_raids = AsyncMock(side_effect=subscribe_and_signal) + + with patch.object(provider, "_get_users", new_callable=AsyncMock, return_value=[{"id": "123"}]): + provider._track_stream_start("streamer_a") + await asyncio.wait_for(subscribed.wait(), timeout=1.0) + + provider._eventsub.subscribe_raids.assert_called_once_with("123") + + +async def test_track_stream_start_no_subscribe_on_second(provider: TwitchProvider) -> None: + """Second stream (1->2) does not trigger another subscription.""" + provider._active_streams = {"streamer_a": 1} + provider._unsubscribe_timers = {} + provider._auto_raid = True + provider._access_token = "test" + + with patch.object(provider, "_subscribe_raids_for_channel", new_callable=AsyncMock) as mock_sub: + provider._track_stream_start("streamer_a") + await asyncio.sleep(0) # yield to event loop — no task should be pending + + mock_sub.assert_not_called() + + +async def test_track_stream_end_decrements(provider: TwitchProvider) -> None: + """Ending one of two streams decrements count.""" + provider._active_streams = {"streamer_a": 2} + provider._unsubscribe_timers = {} + + provider._track_stream_end("streamer_a") + + assert provider._active_streams["streamer_a"] == 1 + + +async def test_track_stream_end_last_starts_grace_timer(provider: TwitchProvider) -> None: + """Ending last stream starts delayed unsubscribe timer.""" + provider._active_streams = {"streamer_a": 1} + provider._unsubscribe_timers = {} + + provider._track_stream_end("streamer_a") + + assert "streamer_a" not in provider._active_streams + assert "streamer_a" in provider._unsubscribe_timers + # Clean up the task + provider._unsubscribe_timers["streamer_a"].cancel() + + +# --- Delayed Unsubscribe --- + + +async def test_delayed_unsubscribe_calls_eventsub(provider: TwitchProvider) -> None: + """After grace period, unsubscribe_raids is called.""" + provider._eventsub = Mock() + provider._eventsub.unsubscribe_raids = 
AsyncMock() + provider._unsubscribe_timers = {"streamer_a": Mock()} + + with ( + patch.object(provider, "_get_users", new_callable=AsyncMock, return_value=[{"id": "123"}]), + patch("music_assistant.providers.twitch.asyncio.sleep", new_callable=AsyncMock), + ): + await provider._delayed_unsubscribe("streamer_a") + + provider._eventsub.unsubscribe_raids.assert_called_once_with("123") + + +# --- Raid Handling --- + + +async def test_raid_switches_playing_queues(provider: TwitchProvider) -> None: + """Raid switches all queues playing the raiding channel.""" + provider._active_streams = {"streamer_a": 1} + provider._unsubscribe_timers = {} + + queue1 = Mock() + queue1.state = PlaybackState.PLAYING + queue1.current_item = Mock() + queue1.current_item.streamdetails = Mock() + queue1.current_item.streamdetails.item_id = "streamer_a" + queue1.queue_id = "queue_1" + + queue2 = Mock() + queue2.state = PlaybackState.PLAYING + queue2.current_item = Mock() + queue2.current_item.streamdetails = Mock() + queue2.current_item.streamdetails.item_id = "streamer_a" + queue2.queue_id = "queue_2" + + provider.mass.player_queues.all = Mock( # type: ignore[method-assign] + return_value=(queue1, queue2) + ) + provider.mass.player_queues.play_media = AsyncMock() # type: ignore[method-assign] + + await provider._on_raid("streamer_a", "streamer_c") + + assert provider.mass.player_queues.play_media.call_count == 2 + calls = provider.mass.player_queues.play_media.call_args_list + assert calls[0].kwargs["queue_id"] == "queue_1" + assert calls[1].kwargs["queue_id"] == "queue_2" + assert "streamer_c" in str(calls[0]) + + +async def test_raid_skips_paused_queues(provider: TwitchProvider) -> None: + """Raid does not switch paused queues.""" + provider._active_streams = {"streamer_a": 1} + provider._unsubscribe_timers = {} + + queue1 = Mock() + queue1.state = PlaybackState.PAUSED + queue1.current_item = Mock() + queue1.current_item.streamdetails = Mock() + queue1.current_item.streamdetails.item_id = 
"streamer_a" + + provider.mass.player_queues.all = Mock( # type: ignore[method-assign] + return_value=(queue1,) + ) + provider.mass.player_queues.play_media = AsyncMock() # type: ignore[method-assign] + + await provider._on_raid("streamer_a", "streamer_c") + + provider.mass.player_queues.play_media.assert_not_called() + + +async def test_raid_skips_other_channels(provider: TwitchProvider) -> None: + """Raid only switches queues playing the raiding channel.""" + provider._active_streams = {"streamer_a": 1} + provider._unsubscribe_timers = {} + + queue1 = Mock() + queue1.state = PlaybackState.PLAYING + queue1.current_item = Mock() + queue1.current_item.streamdetails = Mock() + queue1.current_item.streamdetails.item_id = "streamer_b" # different channel + + provider.mass.player_queues.all = Mock( # type: ignore[method-assign] + return_value=(queue1,) + ) + provider.mass.player_queues.play_media = AsyncMock() # type: ignore[method-assign] + + await provider._on_raid("streamer_a", "streamer_c") + + provider.mass.player_queues.play_media.assert_not_called() + + +async def test_raid_cleans_up_active_streams(provider: TwitchProvider) -> None: + """Raid removes the raiding channel from active_streams.""" + provider._active_streams = {"streamer_a": 2} + provider._unsubscribe_timers = {} + provider.mass.player_queues.all = Mock( # type: ignore[method-assign] + return_value=() + ) + provider.mass.player_queues.play_media = AsyncMock() # type: ignore[method-assign] + + await provider._on_raid("streamer_a", "streamer_c") + + assert "streamer_a" not in provider._active_streams + + +async def test_raid_cancels_pending_unsubscribe(provider: TwitchProvider) -> None: + """Raid cancels any pending unsubscribe timer for the raiding channel.""" + provider._active_streams = {"streamer_a": 1} + mock_timer = Mock() + mock_timer.cancel = Mock() + provider._unsubscribe_timers = {"streamer_a": mock_timer} + provider.mass.player_queues.all = Mock( # type: ignore[method-assign] + 
return_value=() + ) + provider.mass.player_queues.play_media = AsyncMock() # type: ignore[method-assign] + + await provider._on_raid("streamer_a", "streamer_c") + + mock_timer.cancel.assert_called_once() + assert "streamer_a" not in provider._unsubscribe_timers + + +async def test_stale_raid_ignored(provider: TwitchProvider) -> None: + """Raid from channel not in active_streams or grace period is ignored.""" + provider._active_streams = {} + provider._unsubscribe_timers = {} + provider.mass.player_queues.play_media = AsyncMock() # type: ignore[method-assign] + + await provider._on_raid("streamer_b", "streamer_c") + + provider.mass.player_queues.play_media.assert_not_called() + + +async def test_raid_error_handled(provider: TwitchProvider) -> None: + """play_media error is logged, not raised.""" + provider._active_streams = {"streamer_a": 1} + provider._unsubscribe_timers = {} + + queue1 = Mock() + queue1.state = PlaybackState.PLAYING + queue1.current_item = Mock() + queue1.current_item.streamdetails = Mock() + queue1.current_item.streamdetails.item_id = "streamer_a" + queue1.queue_id = "queue_1" + + provider.mass.player_queues.all = Mock( # type: ignore[method-assign] + return_value=(queue1,) + ) + provider.mass.player_queues.play_media = AsyncMock( # type: ignore[method-assign] + side_effect=Exception("offline") + ) + + # Should not raise + await provider._on_raid("streamer_a", "streamer_c") + + +# --- Raid During Grace Period (IDLE queues) --- + + +async def test_raid_switches_idle_queue_in_grace_period(provider: TwitchProvider) -> None: + """Raid during grace period switches IDLE queues that were playing the raiding channel.""" + provider._active_streams = {} + # Simulate grace period — timer exists for this channel + mock_timer = Mock() + mock_timer.cancel = Mock() + provider._unsubscribe_timers = {"streamer_a": mock_timer} + + queue1 = Mock() + queue1.state = PlaybackState.IDLE + queue1.current_item = Mock() + queue1.current_item.streamdetails = Mock() + 
queue1.current_item.streamdetails.item_id = "streamer_a" + queue1.queue_id = "queue_1" + queue1.elapsed_time_last_updated = time.time() - 5 # idle for 5s + + provider.mass.player_queues.all = Mock( # type: ignore[method-assign] + return_value=(queue1,) + ) + provider.mass.player_queues.play_media = AsyncMock() # type: ignore[method-assign] + + await provider._on_raid("streamer_a", "streamer_c") + + provider.mass.player_queues.play_media.assert_called_once() + + +async def test_raid_ignores_idle_queue_too_long(provider: TwitchProvider) -> None: + """IDLE queue that's been idle longer than 2x grace period is not switched.""" + provider._active_streams = {} + mock_timer = Mock() + mock_timer.cancel = Mock() + provider._unsubscribe_timers = {"streamer_a": mock_timer} + + queue1 = Mock() + queue1.state = PlaybackState.IDLE + queue1.current_item = Mock() + queue1.current_item.streamdetails = Mock() + queue1.current_item.streamdetails.item_id = "streamer_a" + queue1.queue_id = "queue_1" + queue1.elapsed_time_last_updated = time.time() - 60 # idle for 60s (> 30s threshold) + + provider.mass.player_queues.all = Mock( # type: ignore[method-assign] + return_value=(queue1,) + ) + provider.mass.player_queues.play_media = AsyncMock() # type: ignore[method-assign] + + await provider._on_raid("streamer_a", "streamer_c") + + provider.mass.player_queues.play_media.assert_not_called() + + +async def test_raid_ignores_idle_queue_without_grace_period(provider: TwitchProvider) -> None: + """IDLE queue is not switched when NOT in grace period (no timer).""" + provider._active_streams = {"streamer_a": 1} + provider._unsubscribe_timers = {} # no grace timer + + queue1 = Mock() + queue1.state = PlaybackState.IDLE + queue1.current_item = Mock() + queue1.current_item.streamdetails = Mock() + queue1.current_item.streamdetails.item_id = "streamer_a" + queue1.queue_id = "queue_1" + + provider.mass.player_queues.all = Mock( # type: ignore[method-assign] + return_value=(queue1,) + ) + 
provider.mass.player_queues.play_media = AsyncMock() # type: ignore[method-assign] + + await provider._on_raid("streamer_a", "streamer_c") + + provider.mass.player_queues.play_media.assert_not_called() + + +async def test_raid_ignores_paused_queue_in_grace_period(provider: TwitchProvider) -> None: + """Paused queue is never switched, even during grace period.""" + provider._active_streams = {} + mock_timer = Mock() + mock_timer.cancel = Mock() + provider._unsubscribe_timers = {"streamer_a": mock_timer} + + queue1 = Mock() + queue1.state = PlaybackState.PAUSED + queue1.current_item = Mock() + queue1.current_item.streamdetails = Mock() + queue1.current_item.streamdetails.item_id = "streamer_a" + queue1.queue_id = "queue_1" + + provider.mass.player_queues.all = Mock( # type: ignore[method-assign] + return_value=(queue1,) + ) + provider.mass.player_queues.play_media = AsyncMock() # type: ignore[method-assign] + + await provider._on_raid("streamer_a", "streamer_c") + + provider.mass.player_queues.play_media.assert_not_called() + + +async def test_raid_in_grace_period_accepted(provider: TwitchProvider) -> None: + """Raid is accepted when channel is only in _unsubscribe_timers (not _active_streams).""" + provider._active_streams = {} + mock_timer = Mock() + mock_timer.cancel = Mock() + provider._unsubscribe_timers = {"streamer_a": mock_timer} + + # No matching queues (all ended), but raid should still be accepted and timer cancelled + provider.mass.player_queues.all = Mock( # type: ignore[method-assign] + return_value=() + ) + provider.mass.player_queues.play_media = AsyncMock() # type: ignore[method-assign] + + await provider._on_raid("streamer_a", "streamer_c") + + # Timer should be cancelled even with no matching queues + mock_timer.cancel.assert_called_once() + assert "streamer_a" not in provider._unsubscribe_timers + + +# --- Auto-Raid Toggle --- + + +async def test_auto_raid_disabled_ignores_raids(provider: TwitchProvider) -> None: + """With auto_raid=False, raid 
events are ignored.""" + provider._auto_raid = False + provider._active_streams = {"streamer_a": 1} + provider.mass.player_queues.play_media = AsyncMock() # type: ignore[method-assign] + + await provider._on_raid("streamer_a", "streamer_c") + + provider.mass.player_queues.play_media.assert_not_called() + + +async def test_auto_raid_disabled_no_subscribe(provider: TwitchProvider) -> None: + """With auto_raid=False, _subscribe_raids_for_channel is a no-op.""" + provider._auto_raid = False + provider._access_token = "test" + provider._eventsub = None + + await provider._subscribe_raids_for_channel("streamer_a") + + assert provider._eventsub is None + + +# --- Cleanup --- + + +async def test_unload_cancels_timers(provider: TwitchProvider) -> None: + """unload() cancels all pending unsubscribe timers.""" + mock_timer1 = Mock() + mock_timer1.cancel = Mock() + mock_timer2 = Mock() + mock_timer2.cancel = Mock() + provider._unsubscribe_timers = {"a": mock_timer1, "b": mock_timer2} + provider._active_streams = {"a": 1} + provider._eventsub = Mock() + provider._eventsub.stop = AsyncMock() + + await provider.unload() + + mock_timer1.cancel.assert_called_once() + mock_timer2.cancel.assert_called_once() + assert len(provider._unsubscribe_timers) == 0 + assert len(provider._active_streams) == 0 + + +async def test_unload_stops_eventsub(provider: TwitchProvider) -> None: + """unload() stops EventSub and sets it to None.""" + eventsub_mock = Mock() + eventsub_mock.stop = AsyncMock() + provider._eventsub = eventsub_mock + provider._unsubscribe_timers = {} + provider._active_streams = {} + + await provider.unload() + + eventsub_mock.stop.assert_called_once() + assert provider._eventsub is None diff --git a/tests/providers/twitch/test_recommendations.py b/tests/providers/twitch/test_recommendations.py new file mode 100644 index 0000000000..77ed6dfa5b --- /dev/null +++ b/tests/providers/twitch/test_recommendations.py @@ -0,0 +1,128 @@ +"""Test Twitch Provider recommendations.""" + 
+from __future__ import annotations + +from music_assistant_models.media_items import Radio, RecommendationFolder + +from music_assistant.providers.twitch import TwitchProvider +from tests.providers.twitch.conftest import ( + MockResponse, + load_fixture, + make_mock_session_method, +) + + +def _users_response() -> MockResponse: + """Return a mock users API response.""" + return MockResponse(status=200, json_data=load_fixture("user_lookup.json")) + + +def _setup_authenticated_provider(provider: TwitchProvider) -> None: + """Configure provider with test credentials and cached data.""" + provider._access_token = "test_token" + provider._client_id = "test_client" + provider._user_id = "99" + + +# --- Recommendations --- + + +async def test_recommendations_returns_live_channels_folder(provider: TwitchProvider) -> None: + """recommendations() returns a single RecommendationFolder with live Radio items.""" + _setup_authenticated_provider(provider) + + fixture_channels = load_fixture("followed_channels.json") + fixture_streams = load_fixture("live_streams.json") + + provider.mass.http_session.get = make_mock_session_method( # type: ignore[method-assign] + [ + MockResponse(status=200, json_data=fixture_channels["page2"]), + MockResponse(status=200, json_data=fixture_streams), + _users_response(), + ] + ) + + result = await provider.recommendations() + assert len(result) == 1 + folder = result[0] + assert isinstance(folder, RecommendationFolder) + assert folder.name == "Twitch Live Channels" + assert len(folder.items) > 0 + assert all(isinstance(item, Radio) for item in folder.items) + + +async def test_recommendations_folder_contains_only_live(provider: TwitchProvider) -> None: + """Offline followed channels are not in the recommendations folder.""" + _setup_authenticated_provider(provider) + + fixture_channels = load_fixture("followed_channels.json") + fixture_streams = load_fixture("live_streams.json") + + provider.mass.http_session.get = make_mock_session_method( # 
type: ignore[method-assign] + [ + # Return all 3 channels (page1 + page2) + MockResponse(status=200, json_data=fixture_channels["page1"]), + MockResponse(status=200, json_data=fixture_channels["page2"]), + # Only streamer_a and streamer_c are live + MockResponse(status=200, json_data=fixture_streams), + _users_response(), + ] + ) + + result = await provider.recommendations() + assert len(result) == 1 + logins = [item.item_id for item in result[0].items] + assert "streamer_a" in logins + assert "streamer_c" in logins + assert "streamer_b" not in logins + + +async def test_recommendations_empty_when_none_live(provider: TwitchProvider) -> None: + """Returns empty list (no folder) when no followed channels are live.""" + _setup_authenticated_provider(provider) + + fixture_channels = load_fixture("followed_channels.json") + + provider.mass.http_session.get = make_mock_session_method( # type: ignore[method-assign] + [ + MockResponse(status=200, json_data=fixture_channels["page2"]), + MockResponse(status=200, json_data={"data": []}), # nobody live + _users_response(), + ] + ) + + result = await provider.recommendations() + assert result == [] + + +async def test_recommendations_requires_auth(provider: TwitchProvider) -> None: + """Returns empty list when not authenticated.""" + provider._access_token = None + provider._user_id = None + + result = await provider.recommendations() + assert result == [] + + +async def test_recommendations_folder_metadata(provider: TwitchProvider) -> None: + """RecommendationFolder has correct name, icon, and provider.""" + _setup_authenticated_provider(provider) + + fixture_channels = load_fixture("followed_channels.json") + fixture_streams = load_fixture("live_streams.json") + + provider.mass.http_session.get = make_mock_session_method( # type: ignore[method-assign] + [ + MockResponse(status=200, json_data=fixture_channels["page2"]), + MockResponse(status=200, json_data=fixture_streams), + _users_response(), + ] + ) + + result = await 
provider.recommendations() + assert len(result) == 1 + folder = result[0] + assert folder.name == "Twitch Live Channels" + assert folder.icon == "mdi-broadcast" + assert folder.provider == provider.instance_id + assert folder.item_id == f"{provider.instance_id}_live_channels" diff --git a/tests/providers/twitch/test_streaming.py b/tests/providers/twitch/test_streaming.py new file mode 100644 index 0000000000..dbc695ac23 --- /dev/null +++ b/tests/providers/twitch/test_streaming.py @@ -0,0 +1,430 @@ +"""Test Twitch Provider core audio streaming.""" + +from __future__ import annotations + +import asyncio +from typing import Any +from unittest.mock import MagicMock, patch + +import pytest +from music_assistant_models.enums import ContentType, MediaType, StreamType +from music_assistant_models.media_items import AudioFormat +from music_assistant_models.streamdetails import StreamDetails + +from music_assistant.providers.twitch import ( + MAX_CONSECUTIVE_RECONNECTS, + RECONNECT_DELAY, + STREAM_CHUNK_SIZE, + TwitchProvider, +) + + +@pytest.fixture(autouse=True) +def _mock_streamlink_session(provider: TwitchProvider) -> None: + """Auto-mock _create_streamlink_session so get_audio_stream doesn't need real Streamlink.""" + provider._create_streamlink_session = MagicMock(return_value=MagicMock()) # type: ignore[method-assign] + + +# --- Stream Details --- + + +async def test_stream_details_returns_custom_type(provider: TwitchProvider) -> None: + """get_stream_details() returns StreamDetails with stream_type=StreamType.CUSTOM.""" + details = await provider.get_stream_details("testchannel", MediaType.RADIO) + assert details.stream_type == StreamType.CUSTOM + + +async def test_stream_details_media_type_is_radio(provider: TwitchProvider) -> None: + """media_type is RADIO.""" + details = await provider.get_stream_details("testchannel", MediaType.RADIO) + assert details.media_type == MediaType.RADIO + + +async def test_stream_details_no_seek(provider: TwitchProvider) -> None: + 
"""Live streams cannot be seeked.""" + details = await provider.get_stream_details("testchannel", MediaType.RADIO) + assert details.allow_seek is False + assert details.can_seek is False + + +async def test_stream_details_provider_set(provider: TwitchProvider) -> None: + """Provider field matches self.instance_id.""" + details = await provider.get_stream_details("testchannel", MediaType.RADIO) + assert details.provider == provider.instance_id + + +async def test_stream_details_content_type_unknown(provider: TwitchProvider) -> None: + """Content type is UNKNOWN (let ffmpeg detect from MPEG-TS stream).""" + details = await provider.get_stream_details("testchannel", MediaType.RADIO) + assert details.audio_format.content_type == ContentType.UNKNOWN + + +# --- Audio Stream — Happy Path --- + + +@pytest.fixture +def stream_details(provider: TwitchProvider) -> StreamDetails: + """Return StreamDetails for a test channel.""" + return StreamDetails( + provider=provider.instance_id, + item_id="testchannel", + audio_format=AudioFormat(content_type=ContentType.UNKNOWN), + media_type=MediaType.RADIO, + stream_type=StreamType.CUSTOM, + ) + + +@pytest.fixture +def mock_streamlink_stream() -> tuple[MagicMock, MagicMock]: + """Return a mock Streamlink stream with fd that yields chunks then closes.""" + mock_fd = MagicMock() + mock_stream = MagicMock() + mock_stream.open.return_value = mock_fd + return mock_stream, mock_fd + + +async def test_yields_bytes_from_streamlink( + provider: TwitchProvider, stream_details: StreamDetails +) -> None: + """get_audio_stream() yields non-empty bytes chunks from a mock Streamlink stream.""" + chunk1 = b"\x00" * 1024 + chunk2 = b"\xff" * 1024 + + mock_fd = MagicMock() + mock_fd.read.side_effect = [chunk1, chunk2, b""] + mock_fd.close.return_value = None + + mock_stream = MagicMock() + mock_stream.open.return_value = mock_fd + + with patch.object(provider, "_resolve_streams", return_value={"audio_only": mock_stream}): + chunks = [] + async for chunk 
in provider.get_audio_stream(stream_details): + chunks.append(chunk) + if len(chunks) >= 2: + # Simulate: after 2 chunks, empty read triggers reconnect path + # which will fail since _resolve_streams returns None on 2nd call + break + + assert len(chunks) == 2 + assert chunks[0] == chunk1 + assert chunks[1] == chunk2 + + +async def test_uses_audio_only_quality(provider: TwitchProvider) -> None: + """Streamlink quality selection picks audio_only when available.""" + streams = {"audio_only": "audio_stream", "worst": "worst_stream", "720p": "hd_stream"} + result = provider._select_quality(streams) + assert result == "audio_stream" + + +async def test_falls_back_to_worst_quality(provider: TwitchProvider) -> None: + """When audio_only unavailable, selects worst.""" + streams = {"worst": "worst_stream", "720p": "hd_stream", "1080p": "fhd_stream"} + result = provider._select_quality(streams) + assert result == "worst_stream" + + +async def test_returns_none_when_no_qualities(provider: TwitchProvider) -> None: + """When no matching qualities, returns None.""" + streams = {"720p": "hd_stream", "1080p": "fhd_stream"} + result = provider._select_quality(streams) + assert result is None + + +async def test_streamlink_called_via_to_thread( + provider: TwitchProvider, stream_details: StreamDetails +) -> None: + """Verify get_audio_stream dispatches blocking Streamlink calls via asyncio.to_thread.""" + mock_fd = MagicMock() + read_results = iter([b"chunk", b""]) + mock_fd.read.side_effect = lambda *_args: next(read_results, b"") + mock_fd.close.return_value = None + + mock_stream = MagicMock() + mock_stream.open.return_value = mock_fd + + resolve_results = iter([{"audio_only": mock_stream}]) + + with patch.object( + provider, "_resolve_streams", side_effect=lambda *_a: next(resolve_results, None) + ): + to_thread_calls: list[tuple[Any, ...]] = [] + original_to_thread = asyncio.to_thread + + async def tracking_to_thread(func: Any, /, *args: Any) -> Any: + 
to_thread_calls.append((func, *args)) + return await original_to_thread(func, *args) + + with patch("asyncio.to_thread", side_effect=tracking_to_thread): + async for _ in provider.get_audio_stream(stream_details): + pass + + called_funcs = [call[0] for call in to_thread_calls] + # Blocking Streamlink operations must go through to_thread + assert provider._create_streamlink_session in called_funcs + assert mock_stream.open in called_funcs + assert mock_fd.read in called_funcs + assert mock_fd.close in called_funcs + + +async def test_chunk_size_is_64kb() -> None: + """Read chunks are 64KB.""" + assert STREAM_CHUNK_SIZE == 64 * 1024 + + +# --- Audio Stream — Streamlink Token --- + + +async def test_streamlink_token_passed_as_header(provider: TwitchProvider) -> None: + """When streamlink_token configured, Streamlink session gets OAuth header.""" + mock_session = MagicMock() + + provider.config.get_value.side_effect = lambda key, default=None: { # type: ignore[attr-defined] + "streamlink_token": "test_oauth_token", + "log_level": "GLOBAL", + }.get(key, default) + + with patch("music_assistant.providers.twitch.Streamlink", return_value=mock_session): + TwitchProvider._create_streamlink_session(provider) + + # set_option called for queue deadline + OAuth header + assert mock_session.set_option.call_count == 2 + calls = mock_session.set_option.call_args_list + assert calls[0].args == ("stream-segmented-queue-deadline", 6) + assert calls[1].args[0] == "http-headers" + assert "OAuth test_oauth_token" in str(calls[1]) + + +async def test_streamlink_token_omitted_when_empty(provider: TwitchProvider) -> None: + """When streamlink_token not set, no extra auth header on Streamlink.""" + mock_session = MagicMock() + + provider.config.get_value.side_effect = lambda key, default=None: { # type: ignore[attr-defined] + "streamlink_token": "", + "log_level": "GLOBAL", + }.get(key, default) + + with patch("music_assistant.providers.twitch.Streamlink", return_value=mock_session): + 
TwitchProvider._create_streamlink_session(provider) + + # Only the queue deadline option — no OAuth header + mock_session.set_option.assert_called_once_with("stream-segmented-queue-deadline", 6) + + +async def test_invalid_streamlink_token_stream_still_plays(provider: TwitchProvider) -> None: + """Bad/expired streamlink_token doesn't prevent playback — session still created.""" + mock_session = MagicMock() + + provider.config.get_value.side_effect = lambda key, default=None: { # type: ignore[attr-defined] + "streamlink_token": "invalid_expired_token", + "log_level": "GLOBAL", + }.get(key, default) + + with patch("music_assistant.providers.twitch.Streamlink", return_value=mock_session): + result = TwitchProvider._create_streamlink_session(provider) + + # Session should still be created despite potentially invalid token + assert result is mock_session + + +async def test_reconnect_delay_is_half_second() -> None: + """Reconnect delay between attempts is 0.5 seconds.""" + assert RECONNECT_DELAY == 0.5 + + +# --- Audio Stream — Reconnection --- + + +async def test_reconnects_on_empty_read( + provider: TwitchProvider, stream_details: StreamDetails +) -> None: + """Empty read triggers stream close + Streamlink re-resolution + continued yielding.""" + mock_fd1 = MagicMock() + mock_fd1.read.side_effect = [b"chunk1", b""] # data then empty + mock_fd1.close.return_value = None + + mock_fd2 = MagicMock() + mock_fd2.read.side_effect = [b"chunk2", b""] # data then empty on reconnect + mock_fd2.close.return_value = None + + mock_stream1 = MagicMock() + mock_stream1.open.return_value = mock_fd1 + + mock_stream2 = MagicMock() + mock_stream2.open.return_value = mock_fd2 + + resolve_calls = [ + {"audio_only": mock_stream1}, + {"audio_only": mock_stream2}, + None, # third resolve fails — end + ] + + with patch.object(provider, "_resolve_streams", side_effect=resolve_calls): + chunks = [] + async for chunk in provider.get_audio_stream(stream_details): + chunks.append(chunk) + + assert 
b"chunk1" in chunks + assert b"chunk2" in chunks + + +async def test_reconnect_resets_counter_on_success( + provider: TwitchProvider, stream_details: StreamDetails +) -> None: + """After reconnect, a successful chunk read resets the consecutive failure counter.""" + # First stream: data, then empty (triggers reconnect, counter=1) + mock_fd1 = MagicMock() + mock_fd1.read.side_effect = [b"data1", b""] + mock_fd1.close.return_value = None + mock_stream1 = MagicMock() + mock_stream1.open.return_value = mock_fd1 + + # Second stream: data (resets counter to 0), then empty (counter=1 again) + mock_fd2 = MagicMock() + mock_fd2.read.side_effect = [b"data2", b""] + mock_fd2.close.return_value = None + mock_stream2 = MagicMock() + mock_stream2.open.return_value = mock_fd2 + + resolve_calls = [ + {"audio_only": mock_stream1}, + {"audio_only": mock_stream2}, + None, # end + ] + + with patch.object(provider, "_resolve_streams", side_effect=resolve_calls): + chunks = [] + async for chunk in provider.get_audio_stream(stream_details): + chunks.append(chunk) + + # Both chunks received — counter was reset between them + assert chunks == [b"data1", b"data2"] + + +async def test_max_consecutive_reconnects( + provider: TwitchProvider, stream_details: StreamDetails +) -> None: + """After MAX_CONSECUTIVE_RECONNECTS consecutive empty reads, generator ends.""" + + def make_empty_stream() -> dict[str, Any]: + fd = MagicMock() + fd.read.return_value = b"" + fd.close.return_value = None + s = MagicMock() + s.open.return_value = fd + return {"audio_only": s} + + # Return streams for each reconnect attempt, plus the initial + resolve_calls = [make_empty_stream() for _ in range(MAX_CONSECUTIVE_RECONNECTS + 2)] + + with patch.object(provider, "_resolve_streams", side_effect=resolve_calls): + chunks = [] + async for chunk in provider.get_audio_stream(stream_details): + chunks.append(chunk) + + assert chunks == [] + + +async def test_generator_ends_on_resolve_failure( + provider: TwitchProvider, 
stream_details: StreamDetails +) -> None: + """If Streamlink re-resolution returns no streams, generator ends cleanly.""" + with patch.object(provider, "_resolve_streams", return_value=None): + chunks = [] + async for chunk in provider.get_audio_stream(stream_details): + chunks.append(chunk) + assert chunks == [] + + +async def test_fd_closed_before_reconnect( + provider: TwitchProvider, stream_details: StreamDetails +) -> None: + """Old stream fd is closed before attempting re-resolution.""" + call_order: list[str] = [] + + mock_fd = MagicMock() + mock_fd.read.side_effect = [b"data", b""] + + def tracking_close() -> None: + call_order.append("close") + + mock_fd.close.side_effect = tracking_close + + mock_stream = MagicMock() + mock_stream.open.return_value = mock_fd + + resolve_results = iter([{"audio_only": mock_stream}, None]) + + def tracking_resolve(_channel: str, _session: Any = None) -> dict[str, Any] | None: + call_order.append("resolve") + return next(resolve_results) + + with patch.object(provider, "_resolve_streams", side_effect=tracking_resolve): + async for _ in provider.get_audio_stream(stream_details): + pass + + # There should be two resolve calls and one close between them + # Pattern: resolve(initial), close, resolve(reconnect attempt) + assert call_order.count("close") >= 1 + assert call_order.count("resolve") >= 2 + # First close must come before the second resolve + first_close = call_order.index("close") + second_resolve = len(call_order) - 1 - call_order[::-1].index("resolve") + assert first_close < second_resolve + + +# --- Audio Stream — Error Cases --- + + +async def test_offline_channel_returns_empty( + provider: TwitchProvider, stream_details: StreamDetails +) -> None: + """get_audio_stream() for offline channel yields nothing (resolve returns None).""" + with patch.object(provider, "_resolve_streams", return_value=None): + chunks = [] + async for chunk in provider.get_audio_stream(stream_details): + chunks.append(chunk) + assert chunks 
== [] + + +async def test_streamlink_plugin_error_returns_empty( + provider: TwitchProvider, stream_details: StreamDetails +) -> None: + """Streamlink PluginError during resolution is caught, generator ends cleanly. + + Mocks at the Streamlink session level to exercise the except block in + _resolve_streams(), not at _resolve_streams() itself. + """ + mock_session = MagicMock() + mock_session.streams.side_effect = Exception("No plugin can handle URL") + provider._create_streamlink_session = MagicMock( # type: ignore[method-assign] + return_value=mock_session + ) + + chunks = [] + async for chunk in provider.get_audio_stream(stream_details): + chunks.append(chunk) + assert chunks == [] + + +async def test_exception_during_read_closes_fd( + provider: TwitchProvider, stream_details: StreamDetails +) -> None: + """Exception from fd.read() still closes the fd via finally block.""" + mock_fd = MagicMock() + mock_fd.read.side_effect = OSError("read failed") + mock_fd.close.return_value = None + + mock_stream = MagicMock() + mock_stream.open.return_value = mock_fd + + with ( + patch.object(provider, "_resolve_streams", return_value={"audio_only": mock_stream}), + pytest.raises(OSError, match="read failed"), + ): + async for _ in provider.get_audio_stream(stream_details): + pass + + # fd.close was still called (via finally) + mock_fd.close.assert_called() diff --git a/tests/providers/twitch/test_twitch_api.py b/tests/providers/twitch/test_twitch_api.py new file mode 100644 index 0000000000..4a8796c2a0 --- /dev/null +++ b/tests/providers/twitch/test_twitch_api.py @@ -0,0 +1,297 @@ +"""Test Twitch API client: pagination, batching, caching, error handling.""" + +from __future__ import annotations + +import time + +import pytest +from music_assistant_models.errors import LoginFailed + +from music_assistant.providers.twitch import TwitchProvider +from tests.providers.twitch.conftest import ( + MockResponse, + load_fixture, + make_mock_session_method, +) + +# --- Request Pattern 
def _authenticate(provider: TwitchProvider, *, with_user: bool = False) -> None:
    """Install fake API credentials (and optionally a user id) on the provider."""
    provider._access_token = "test_token"
    provider._client_id = "test_client"
    if with_user:
        provider._user_id = "99"


async def test_get_includes_auth_headers(provider: TwitchProvider) -> None:
    """Every GET carries Authorization and Client-Id headers."""
    _authenticate(provider)

    provider.mass.http_session.get = make_mock_session_method(  # type: ignore[method-assign]
        MockResponse(status=200, json_data={"data": []})
    )

    await provider._api_get("/helix/streams")

    sent_headers = provider.mass.http_session.get.call_args.kwargs.get("headers", {})
    assert sent_headers["Authorization"] == "Bearer test_token"
    assert sent_headers["Client-Id"] == "test_client"


async def test_unauthenticated_request_raises(provider: TwitchProvider) -> None:
    """A 401 without usable tokens surfaces as LoginFailed."""
    provider._access_token = None
    provider._refresh_token = None
    provider._client_id = "test_client"

    provider.mass.http_session.get = make_mock_session_method(  # type: ignore[method-assign]
        MockResponse(status=401)
    )
    provider.mass.http_session.post = make_mock_session_method(  # type: ignore[method-assign]
        MockResponse(status=401, text_data="bad refresh")
    )

    with pytest.raises(LoginFailed):
        await provider._api_get("/helix/users")


async def test_non_200_raises(provider: TwitchProvider) -> None:
    """Server errors propagate as exceptions mentioning the status code."""
    _authenticate(provider)

    provider.mass.http_session.get = make_mock_session_method(  # type: ignore[method-assign]
        MockResponse(status=500)
    )

    with pytest.raises(Exception, match=r"500"):
        await provider._api_get("/helix/streams")


# --- Pagination ---


async def test_followed_channels_paginates(provider: TwitchProvider) -> None:
    """The cursor is followed across pages until a page arrives without one."""
    _authenticate(provider, with_user=True)

    pages = load_fixture("followed_channels.json")
    provider.mass.http_session.get = make_mock_session_method(  # type: ignore[method-assign]
        [
            MockResponse(status=200, json_data=pages["page1"]),
            MockResponse(status=200, json_data=pages["page2"]),
        ]
    )

    channels = await provider._get_followed_channels()
    assert len(channels) == 3
    assert provider.mass.http_session.get.call_count == 2


async def test_single_page_no_extra_requests(provider: TwitchProvider) -> None:
    """A cursor-less response means exactly one request."""
    _authenticate(provider, with_user=True)

    pages = load_fixture("followed_channels.json")
    provider.mass.http_session.get = make_mock_session_method(  # type: ignore[method-assign]
        MockResponse(status=200, json_data=pages["page2"])
    )

    channels = await provider._get_followed_channels()
    assert len(channels) == 1
    assert provider.mass.http_session.get.call_count == 1


# --- Batching ---


async def test_live_streams_batches_over_100(provider: TwitchProvider) -> None:
    """150 user ids are queried as one batch of 100 plus one of 50."""
    _authenticate(provider)

    provider.mass.http_session.get = make_mock_session_method(  # type: ignore[method-assign]
        [
            MockResponse(status=200, json_data={"data": [{"user_id": "1"}]}),
            MockResponse(status=200, json_data={"data": [{"user_id": "101"}]}),
        ]
    )

    streams = await provider._get_live_streams([str(i) for i in range(150)])
    assert len(streams) == 2
    assert provider.mass.http_session.get.call_count == 2


async def test_live_streams_empty_input(provider: TwitchProvider) -> None:
    """No ids -> no API traffic, empty result."""
    _authenticate(provider)

    assert await provider._get_live_streams([]) == []


async def test_user_profiles_batches_over_100(provider: TwitchProvider) -> None:
    """150 user ids are looked up as one batch of 100 plus one of 50."""
    _authenticate(provider)

    provider.mass.http_session.get = make_mock_session_method(  # type: ignore[method-assign]
        [
            MockResponse(status=200, json_data={"data": [{"id": "1", "login": "a"}]}),
            MockResponse(status=200, json_data={"data": [{"id": "101", "login": "b"}]}),
        ]
    )

    profiles = await provider._get_user_profiles([str(i) for i in range(150)])
    assert len(profiles) == 2
    assert provider.mass.http_session.get.call_count == 2


# --- Caching ---


async def test_live_status_cached_within_ttl(provider: TwitchProvider) -> None:
    """A second call inside the 5-minute TTL is served from cache, no API hit."""
    _authenticate(provider, with_user=True)

    provider.mass.http_session.get = make_mock_session_method(  # type: ignore[method-assign]
        [
            MockResponse(status=200, json_data=load_fixture("followed_channels.json")["page2"]),
            MockResponse(status=200, json_data=load_fixture("live_streams.json")),
            MockResponse(status=200, json_data=load_fixture("user_lookup.json")),
        ]
    )

    await provider._get_followed_live_status()  # populates the cache
    calls_after_first = provider.mass.http_session.get.call_count

    await provider._get_followed_live_status()  # must hit the cache
    assert provider.mass.http_session.get.call_count == calls_after_first


async def test_live_status_refreshed_after_ttl(provider: TwitchProvider) -> None:
    """Once the 5-minute TTL lapses, a fresh API fetch happens."""
    _authenticate(provider, with_user=True)

    page = load_fixture("followed_channels.json")["page2"]
    streams = load_fixture("live_streams.json")
    users = load_fixture("user_lookup.json")
    # Two full fetch cycles: before and after the TTL expires.
    responses = [
        MockResponse(status=200, json_data=payload)
        for payload in (page, streams, users, page, streams, users)
    ]
    provider.mass.http_session.get = make_mock_session_method(responses)  # type: ignore[method-assign]

    await provider._get_followed_live_status()
    calls_after_first = provider.mass.http_session.get.call_count

    provider._cache_time = time.monotonic() - 301  # push past the 5-minute TTL

    await provider._get_followed_live_status()
    assert provider.mass.http_session.get.call_count > calls_after_first


async def test_cache_cleared_on_logout(provider: TwitchProvider) -> None:
    """_clear_cache() drops channels, live map and the cache timestamp."""
    provider._cached_channels = [{"broadcaster_id": "123"}]
    provider._cached_live = {"streamer_a": {"viewer_count": 100}}
    provider._cache_time = time.monotonic()

    provider._clear_cache()

    assert provider._cached_channels is None
    assert provider._cached_live is None  # type: ignore[unreachable]
    assert provider._cache_time == 0.0


# --- User Lookup ---


async def test_get_users_resolves_login_to_id(provider: TwitchProvider) -> None:
    """GET /users?login=X yields the user record with its numeric id."""
    _authenticate(provider)

    provider.mass.http_session.get = make_mock_session_method(  # type: ignore[method-assign]
        MockResponse(status=200, json_data=load_fixture("user_lookup.json"))
    )

    users = await provider._get_users(logins=["streamer_a"])
    assert len(users) == 1
    assert users[0]["id"] == "123"
    assert users[0]["login"] == "streamer_a"


async def test_get_users_empty_returns_empty(provider: TwitchProvider) -> None:
    """No logins -> no API call, empty list."""
    _authenticate(provider)

    assert await provider._get_users(logins=[]) == []


async def test_followed_channels_aggregates_pages(provider: TwitchProvider) -> None:
    """Channels from every page land in one combined list."""
    _authenticate(provider, with_user=True)

    pages = load_fixture("followed_channels.json")
    provider.mass.http_session.get = make_mock_session_method(  # type: ignore[method-assign]
        [
            MockResponse(status=200, json_data=pages["page1"]),
            MockResponse(status=200, json_data=pages["page2"]),
        ]
    )

    channels = await provider._get_followed_channels()
    logins = {entry["broadcaster_login"] for entry in channels}
    assert "streamer_a" in logins
    assert "streamer_b" in logins
    assert "streamer_c" in logins


async def test_live_streams_aggregates_batches(provider: TwitchProvider) -> None:
    """Streams from every batch land in one combined list."""
    _authenticate(provider)

    provider.mass.http_session.get = make_mock_session_method(  # type: ignore[method-assign]
        [
            MockResponse(status=200, json_data={"data": [{"user_id": "1", "user_login": "a"}]}),
            MockResponse(status=200, json_data={"data": [{"user_id": "101", "user_login": "b"}]}),
        ]
    )

    streams = await provider._get_live_streams([str(i) for i in range(150)])
    logins = [entry["user_login"] for entry in streams]
    assert "a" in logins
    assert "b" in logins