diff --git a/music_assistant/controllers/webserver/auth.py b/music_assistant/controllers/webserver/auth.py index 650737f9d0..e66c538b98 100644 --- a/music_assistant/controllers/webserver/auth.py +++ b/music_assistant/controllers/webserver/auth.py @@ -49,7 +49,7 @@ LOGGER = logging.getLogger(f"{MASS_LOGGER_NAME}.auth") # Database schema version -DB_SCHEMA_VERSION = 5 +DB_SCHEMA_VERSION = 6 # Token expiration constants (in days) TOKEN_SHORT_LIVED_EXPIRATION = 30 # Short-lived tokens (auto-renewing on use) @@ -219,6 +219,7 @@ async def _create_database_tables(self) -> None: use_count INTEGER DEFAULT 0, last_used_at TEXT, device_name TEXT, + instance_id TEXT, FOREIGN KEY (user_id) REFERENCES users(user_id) ON DELETE CASCADE ) """ @@ -301,6 +302,12 @@ async def _migrate_database(self, from_version: int) -> None: ) await self.database.commit() + # Migration to version 6: Add instance_id column to join_codes + if from_version < 6: + with contextlib.suppress(OperationalError): + await self.database.execute("ALTER TABLE join_codes ADD COLUMN instance_id TEXT") + await self.database.commit() + async def _get_or_create_jwt_secret(self) -> str: """Get or create JWT secret key from database. @@ -872,15 +879,21 @@ async def update_provider_link( # Create new link await self.link_user_to_provider(user, provider_type, provider_user_id) - async def create_token(self, user: User, name: str, is_long_lived: bool = False) -> str: - """ - Create a new JWT access token for a user. + async def create_token( + self, + user: User, + name: str, + is_long_lived: bool = False, + extra_claims: dict[str, Any] | None = None, + ) -> str: + """Create a new JWT access token for a user. :param user: The user to create the token for. :param name: A name/description for the token (e.g., device name). :param is_long_lived: Whether this is a long-lived token (default: False). Short-lived tokens (False): Auto-renewing on use, expire after 30 days of inactivity. Long-lived tokens (True): No auto-renewal, expire after 10 years. + :param extra_claims: Optional extra claims to embed in the JWT. :return: JWT token string. """ # Generate unique token ID @@ -902,6 +915,7 @@ async def create_token(self, user: User, name: str, is_long_lived: bool = False) token_name=name, expires_at=expires_at, is_long_lived=is_long_lived, + extra_claims=extra_claims, ) # Store token hash in database for revocation checking @@ -1391,6 +1405,22 @@ async def delete_user(self, user_id: str) -> None: admin_user.username, ) + async def delete_user_internal(self, user_id: str) -> None: + """Delete a user account programmatically (no auth context required). + + Used internally by providers to clean up system-managed users (e.g., party guests). + + :param user_id: The user ID to delete. + """ + user_row = await self.database.get_row("users", {"user_id": user_id}) + if not user_row: + return + + await self.database.delete("users", {"user_id": user_id}) + await self.database.commit() + self.webserver.disconnect_websockets_for_user(user_id) + self.logger.info("Internally deleted user '%s'", user_row["username"]) + @api_command("auth/me") async def get_current_user_info(self) -> User: """Get current authenticated user information.""" @@ -1612,6 +1642,7 @@ async def generate_join_code( expires_in_hours: int = JOIN_CODE_DEFAULT_EXPIRY_HOURS, max_uses: int = 1, device_name: str = "Short Code Login", + instance_id: str | None = None, ) -> tuple[str, datetime]: """Generate a short join code for link/QR-based login. @@ -1623,6 +1654,7 @@ async def generate_join_code( :param expires_in_hours: Hours until code expires (default: 8). :param max_uses: Maximum number of uses (0 = unlimited). :param device_name: Device name for tokens created with this code. + :param instance_id: Optional provider instance ID to embed in the resulting JWT. :return: Tuple of (code, expires_at datetime). """ if expires_in_hours <= 0: @@ -1646,6 +1678,7 @@ async def generate_join_code( "max_uses": max_uses, "use_count": 0, "device_name": device_name, + "instance_id": instance_id, } try: await self.database.insert("join_codes", code_data) @@ -1681,7 +1714,7 @@ async def _exchange_join_code(self, code: str) -> str | None: WHERE code = :code AND expires_at > :now AND (max_uses = 0 OR use_count < max_uses) - RETURNING user_id, device_name + RETURNING user_id, device_name, instance_id """, {"now": now.isoformat(), "code": code.upper()}, ) @@ -1700,10 +1733,14 @@ async def _exchange_join_code(self, code: str) -> str | None: return None device_name = row["device_name"] or "Short Code Login" + extra_claims: dict[str, Any] | None = None + if row["instance_id"]: + extra_claims = {"party_instance": row["instance_id"]} token = await self.create_token( user, device_name, is_long_lived=False, + extra_claims=extra_claims, ) self.logger.info( @@ -1729,24 +1766,60 @@ async def revoke_join_codes(self, user: User) -> int: self.logger.info("Revoked %d join code(s) for user %s", count, user.username) return count - async def get_active_join_code(self, user: User) -> str | None: + async def revoke_join_codes_for_instance(self, instance_id: str) -> int: + """Revoke all join codes associated with a specific provider instance. + + :param instance_id: The provider instance ID. + :return: Number of codes revoked. + """ + cursor = await self.database.execute( + "DELETE FROM join_codes WHERE instance_id = :instance_id", + {"instance_id": instance_id}, + ) + await self.database.commit() + + count = int(cursor.rowcount) + if count > 0: + self.logger.info("Revoked %d join code(s) for instance %s", count, instance_id) + return count + + async def get_active_join_code(self, user: User, instance_id: str | None = None) -> str | None: """Get the most recently created, non-expired join code for a user. :param user: The user to look up codes for. + :param instance_id: Optional instance ID to filter by. :return: The join code string if found, None otherwise. """ now = utc() - cursor = await self.database.execute( - """ - SELECT code FROM join_codes - WHERE user_id = :user_id - AND expires_at > :now - AND (max_uses = 0 OR use_count < max_uses) - ORDER BY created_at DESC - LIMIT 1 - """, - {"user_id": user.user_id, "now": now.isoformat()}, - ) + if instance_id: + cursor = await self.database.execute( + """ + SELECT code FROM join_codes + WHERE user_id = :user_id + AND instance_id = :instance_id + AND expires_at > :now + AND (max_uses = 0 OR use_count < max_uses) + ORDER BY created_at DESC + LIMIT 1 + """, + { + "user_id": user.user_id, + "instance_id": instance_id, + "now": now.isoformat(), + }, + ) + else: + cursor = await self.database.execute( + """ + SELECT code FROM join_codes + WHERE user_id = :user_id + AND expires_at > :now + AND (max_uses = 0 OR use_count < max_uses) + ORDER BY created_at DESC + LIMIT 1 + """, + {"user_id": user.user_id, "now": now.isoformat()}, + ) row = await cursor.fetchone() return str(row["code"]) if row else None diff --git a/music_assistant/helpers/jwt_auth.py b/music_assistant/helpers/jwt_auth.py index c7d6bb03a2..fd31e60bc3 100644 --- a/music_assistant/helpers/jwt_auth.py +++ b/music_assistant/helpers/jwt_auth.py @@ -43,6 +43,7 @@ def encode_token( token_name: str, expires_at: datetime, is_long_lived: bool = False, + extra_claims: dict[str, Any] | None = None, ) -> str: """Encode a JWT token for a user. @@ -51,10 +52,11 @@ def encode_token( :param token_name: Human-readable token name. :param expires_at: Token expiration datetime. :param is_long_lived: Whether this is a long-lived token. + :param extra_claims: Optional extra claims to include in the JWT payload. :return: Encoded JWT token string. """ now = utc() - payload = { + payload: dict[str, Any] = { "sub": user.user_id, "jti": token_id, "iat": int(now.timestamp()), @@ -64,6 +66,8 @@ def encode_token( "token_name": token_name, "is_long_lived": is_long_lived, } + if extra_claims: + payload["extra_claims"] = extra_claims return jwt.encode(payload, self.secret_key, algorithm=self.algorithm) diff --git a/music_assistant/providers/party/__init__.py b/music_assistant/providers/party/__init__.py index 48317dc5fb..3d90550647 100644 --- a/music_assistant/providers/party/__init__.py +++ b/music_assistant/providers/party/__init__.py @@ -24,7 +24,10 @@ from music_assistant_models.queue_item import QueueItem from music_assistant.constants import DEFAULT_PORT -from music_assistant.controllers.webserver.helpers.auth_middleware import get_current_user +from music_assistant.controllers.webserver.helpers.auth_middleware import ( + get_current_token, + get_current_user, +) from music_assistant.models.plugin import PluginProvider if TYPE_CHECKING: @@ -78,9 +81,8 @@ ("Yellow", "#FFEB3B"), ] -# Guest user configuration -PARTY_GUEST_USER = "party_guest" -PARTY_GUEST_DISPLAY_NAME = "Party Guest" +# Shared guest user for all party instances (JWT claims differentiate instances) +PARTY_GUEST_USERNAME = "party_guest" # Extra attribute keys for tracking guest items in the queue ATTR_PARTY_GUEST = "party_guest" @@ -93,6 +95,9 @@ class PartyConfig(DataClassDictMixin): """Configuration data returned to the party guest frontend.""" + # Instance identification + instance_id: str + name: str # Feature toggles enable_rate_limiting: bool enable_add_queue: bool @@ -116,6 +121,8 @@ class PartyConfig(DataClassDictMixin): boost_badge_color: str # Anti burn-in anti_burn_in: bool + # Multi-instance context + instance_count: int async def setup( @@ -127,7 +134,7 @@ async def setup( async def get_config_entries( mass: MusicAssistant, - instance_id: str | None = None, # noqa: ARG001 + instance_id: str | None = None, action: str | None = None, # noqa: ARG001 values: dict[str, ConfigValueType] | None = None, # noqa: ARG001 ) -> tuple[ConfigEntry, ...]: @@ -138,6 +145,25 @@ async def get_config_entries( :param action: Optional action key called from config entries UI. :param values: The (intermediate) raw values for config entries sent with the action. """ + # Filter out players already assigned to other party instances + used_players: set[str] = set() + for other in mass.get_provider_instances("party"): + if other.instance_id == instance_id: + continue + other_player = mass.config.get_raw_provider_config_value( + other.instance_id, CONF_PARTY_PLAYER + ) + if other_player: + used_players.add(str(other_player)) + + player_options: list[ConfigValueOption] = [] + for player in sorted( + mass.players.all_players(False, False), key=lambda p: p.display_name.lower() + ): + if player.player_id in used_players: + continue + player_options.append(ConfigValueOption(player.display_name, player.player_id)) + return ( ConfigEntry( key=CONF_ENABLE_GUEST_ACCESS, @@ -155,12 +181,7 @@ async def get_config_entries( required=True, label="Party Player", description="Select which player/queue guests will add songs to.", - options=[ - ConfigValueOption(player.display_name, player.player_id) - for player in sorted( - mass.players.all_players(False, False), key=lambda p: p.display_name.lower() - ) - ], + options=player_options, ), ConfigEntry( key=CONF_PARTY_DISPLAY_LYRICS, @@ -390,27 +411,35 @@ def __init__( self._unregister_handles: list[Callable[[], None]] = [] self._queue_lock = asyncio.Lock() + @property + def _guest_username(self) -> str: + """Return the shared party guest username.""" + return PARTY_GUEST_USERNAME + async def loaded_in_mass(self) -> None: """Call after the provider has been loaded.""" - # Register API commands and store unregister handles + iid = self.instance_id + # Register instance-namespaced API commands self._unregister_handles.append( - self.mass.register_api_command("party/url", self.get_party_url, required_role="user") + self.mass.register_api_command( + f"party/{iid}/url", self.get_party_url, required_role="user" + ) ) self._unregister_handles.append( - self.mass.register_api_command("party/player", self.get_party_player) + self.mass.register_api_command(f"party/{iid}/player", self.get_party_player) ) self._unregister_handles.append( - self.mass.register_api_command("party/config", self.get_party_config) + self.mass.register_api_command(f"party/{iid}/config", self.get_party_config) ) # Guest action commands - these are called by the guest frontend self._unregister_handles.append( - self.mass.register_api_command("party/add_to_queue", self.add_to_queue) + self.mass.register_api_command(f"party/{iid}/add_to_queue", self.add_to_queue) ) self._unregister_handles.append( - self.mass.register_api_command("party/boost_queue_item", self.boost_queue_item) + self.mass.register_api_command(f"party/{iid}/boost_queue_item", self.boost_queue_item) ) self._unregister_handles.append( - self.mass.register_api_command("party/skip", self.skip_current) + self.mass.register_api_command(f"party/{iid}/skip", self.skip_current) ) async def unload(self, is_removed: bool = False) -> None: @@ -435,29 +464,31 @@ async def unload(self, is_removed: bool = False) -> None: ) if is_removed or not guest_access_enabled: self.logger.debug("Revoking guest tokens...") - await self._revoke_guest_tokens() + await self._revoke_guest_tokens(is_removed=is_removed) await super().unload(is_removed) # ==================== Configuration API Commands ==================== async def _get_or_create_party_guest_user(self) -> User: - """Get or create the party guest user. + """Get or create the shared party guest user. + + All party instances share a single guest user account. + The JWT extra_claims differentiate which instance a guest belongs to. - :returns: The party guest User. + :returns: The shared party guest User. """ auth = self.mass.webserver.auth - user = await auth.get_user_by_username(PARTY_GUEST_USER) + user = await auth.get_user_by_username(self._guest_username) if user: return user - # Create the party guest user user = await auth.create_user( - username=PARTY_GUEST_USER, + username=self._guest_username, role=UserRole.GUEST, - display_name=PARTY_GUEST_DISPLAY_NAME, + display_name="Party Guest", ) - self.logger.info("Created party guest user account") + self.logger.info("Created shared party guest user account") return user async def _get_join_code(self) -> str: @@ -472,7 +503,7 @@ async def _get_join_code(self) -> str: guest_user = await self._get_or_create_party_guest_user() # Check for an existing active join code - existing_code = await auth.get_active_join_code(guest_user) + existing_code = await auth.get_active_join_code(guest_user, instance_id=self.instance_id) if existing_code: return existing_code @@ -482,6 +513,7 @@ async def _get_join_code(self) -> str: expires_in_hours=8, max_uses=0, device_name="Party Guest", + instance_id=self.instance_id, ) return code @@ -525,6 +557,8 @@ async def get_party_config(self) -> PartyConfig: :returns: PartyConfig with feature toggles, token limits, refill rates, and colors. """ return PartyConfig( + instance_id=self.instance_id, + name=self.name, enable_rate_limiting=cast("bool", self.config.get_value(CONF_ENABLE_RATE_LIMITING)), enable_add_queue=cast("bool", self.config.get_value(CONF_ENABLE_ADD_QUEUE)), add_queue_limit=cast("int", self.config.get_value(CONF_PARTY_ADD_QUEUE_LIMIT)), @@ -547,6 +581,7 @@ async def get_party_config(self) -> PartyConfig: request_badge_color=cast("str", self.config.get_value(CONF_REQUEST_BADGE_COLOR)), boost_badge_color=cast("str", self.config.get_value(CONF_BOOST_BADGE_COLOR)), anti_burn_in=cast("bool", self.config.get_value(CONF_ANTI_BURN_IN)), + instance_count=len(self.mass.get_provider_instances("party")), ) # ==================== Guest Action API Commands ==================== @@ -798,14 +833,27 @@ async def _add_to_priority_section( shuffle=False, ) - @staticmethod - def _validate_guest_access() -> None: - """Validate the current user is an authenticated party guest. + def _validate_guest_access(self) -> None: + """Validate the current user is an authenticated party guest for this instance. + + Checks that the user is the shared party guest and that the JWT + extra_claims.party_instance matches this provider's instance_id. - :raises InvalidDataError: If the user is not a party guest. + :raises InvalidDataError: If the user is not a party guest for this instance. """ user = get_current_user() - if not user or user.username != PARTY_GUEST_USER: + if not user or user.username != self._guest_username: + raise InvalidDataError("This endpoint is only available to party guests") + # Verify the JWT claim matches this specific instance + token = get_current_token() + if not token: + raise InvalidDataError("No authentication token found") + try: + payload = self.mass.webserver.auth.jwt_helper.decode_token(token) + except Exception as err: + raise InvalidDataError("Invalid authentication token") from err + extra_claims = payload.get("extra_claims", {}) + if extra_claims.get("party_instance") != self.instance_id: raise InvalidDataError("This endpoint is only available to party guests") @staticmethod @@ -872,31 +920,43 @@ async def skip_current(self) -> dict[str, Any]: # ==================== Helper Methods ==================== - async def _revoke_guest_tokens(self) -> None: - """Revoke all guest access tokens and codes for party. + def _is_last_party_instance(self) -> bool: + """Check if this is the last remaining party provider instance. + + :returns: True if no other party instances exist. + """ + other_instances = [ + p + for p in self.mass.get_provider_instances("party") + if p.instance_id != self.instance_id + ] + return len(other_instances) == 0 + + async def _revoke_guest_tokens(self, is_removed: bool = False) -> None: + """Revoke guest access tokens and codes for this party instance. - This is called when guest access is disabled or the plugin is removed. - We disconnect WebSocket connections to force the frontend to redirect to login, - revoke tokens so they can't reconnect, and invalidate pending join codes. + When this instance is being removed and it's the last one, the shared + guest user is deleted entirely. Otherwise, only join codes for this + instance are revoked. + + :param is_removed: Whether the provider is being permanently removed. """ auth = self.mass.webserver.auth - # Find the party guest user - guest_user = await auth.get_user_by_username(PARTY_GUEST_USER) + guest_user = await auth.get_user_by_username(self._guest_username) if not guest_user: self.logger.debug("No party guest user found, nothing to revoke") return - # Revoke pending join codes for the guest user - codes_revoked = await auth.revoke_join_codes(guest_user) + # Revoke join codes scoped to this instance + codes_revoked = await auth.revoke_join_codes_for_instance(self.instance_id) if codes_revoked > 0: - self.logger.info("Revoked %d pending join codes", codes_revoked) - - # Revoke all tokens and disconnect WebSocket connections for the guest user - revoked_count = await auth.revoke_tokens_for_user(guest_user) - if revoked_count > 0: - self.logger.info( - "Revoked %d guest access tokens for user '%s'", - revoked_count, - guest_user.username, - ) + self.logger.info("Revoked %d pending join codes for instance", codes_revoked) + + if is_removed and self._is_last_party_instance(): + # Last instance being removed — revoke all tokens and delete the guest user + revoked_count = await auth.revoke_tokens_for_user(guest_user) + if revoked_count > 0: + self.logger.info("Revoked %d guest access tokens", revoked_count) + await auth.delete_user_internal(guest_user.user_id) + self.logger.info("Deleted shared party guest user (last instance removed)") diff --git a/music_assistant/providers/party/manifest.json b/music_assistant/providers/party/manifest.json index 5606e7599a..edab38e31a 100644 --- a/music_assistant/providers/party/manifest.json +++ b/music_assistant/providers/party/manifest.json @@ -7,7 +7,7 @@ "codeowners": ["@apophisnow"], "requirements": [], "documentation": "https://music-assistant.io/plugins/party", - "multi_instance": false, + "multi_instance": true, "builtin": false, "icon": "party-popper" }