Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 90 additions & 17 deletions music_assistant/controllers/webserver/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@
LOGGER = logging.getLogger(f"{MASS_LOGGER_NAME}.auth")

# Database schema version
DB_SCHEMA_VERSION = 5
DB_SCHEMA_VERSION = 6

# Token expiration constants (in days)
TOKEN_SHORT_LIVED_EXPIRATION = 30 # Short-lived tokens (auto-renewing on use)
Expand Down Expand Up @@ -219,6 +219,7 @@ async def _create_database_tables(self) -> None:
use_count INTEGER DEFAULT 0,
last_used_at TEXT,
device_name TEXT,
instance_id TEXT,
FOREIGN KEY (user_id) REFERENCES users(user_id) ON DELETE CASCADE
)
"""
Expand Down Expand Up @@ -301,6 +302,12 @@ async def _migrate_database(self, from_version: int) -> None:
)
await self.database.commit()

# Migration to version 6: Add instance_id column to join_codes
if from_version < 6:
with contextlib.suppress(OperationalError):
await self.database.execute("ALTER TABLE join_codes ADD COLUMN instance_id TEXT")
await self.database.commit()

async def _get_or_create_jwt_secret(self) -> str:
"""Get or create JWT secret key from database.

Expand Down Expand Up @@ -872,15 +879,21 @@ async def update_provider_link(
# Create new link
await self.link_user_to_provider(user, provider_type, provider_user_id)

async def create_token(self, user: User, name: str, is_long_lived: bool = False) -> str:
"""
Create a new JWT access token for a user.
async def create_token(
self,
user: User,
name: str,
is_long_lived: bool = False,
extra_claims: dict[str, Any] | None = None,
) -> str:
"""Create a new JWT access token for a user.

:param user: The user to create the token for.
:param name: A name/description for the token (e.g., device name).
:param is_long_lived: Whether this is a long-lived token (default: False).
Short-lived tokens (False): Auto-renewing on use, expire after 30 days of inactivity.
Long-lived tokens (True): No auto-renewal, expire after 10 years.
:param extra_claims: Optional extra claims to embed in the JWT.
:return: JWT token string.
"""
# Generate unique token ID
Expand All @@ -902,6 +915,7 @@ async def create_token(self, user: User, name: str, is_long_lived: bool = False)
token_name=name,
expires_at=expires_at,
is_long_lived=is_long_lived,
extra_claims=extra_claims,
)

# Store token hash in database for revocation checking
Expand Down Expand Up @@ -1391,6 +1405,22 @@ async def delete_user(self, user_id: str) -> None:
admin_user.username,
)

async def delete_user_internal(self, user_id: str) -> None:
"""Delete a user account programmatically (no auth context required).

Used internally by providers to clean up system-managed users (e.g., party guests).

:param user_id: The user ID to delete.
"""
user_row = await self.database.get_row("users", {"user_id": user_id})
if not user_row:
return

await self.database.delete("users", {"user_id": user_id})
await self.database.commit()
self.webserver.disconnect_websockets_for_user(user_id)
self.logger.info("Internally deleted user '%s'", user_row["username"])

@api_command("auth/me")
async def get_current_user_info(self) -> User:
"""Get current authenticated user information."""
Expand Down Expand Up @@ -1612,6 +1642,7 @@ async def generate_join_code(
expires_in_hours: int = JOIN_CODE_DEFAULT_EXPIRY_HOURS,
max_uses: int = 1,
device_name: str = "Short Code Login",
instance_id: str | None = None,
) -> tuple[str, datetime]:
"""Generate a short join code for link/QR-based login.

Expand All @@ -1623,6 +1654,7 @@ async def generate_join_code(
:param expires_in_hours: Hours until code expires (default: 8).
:param max_uses: Maximum number of uses (0 = unlimited).
:param device_name: Device name for tokens created with this code.
:param instance_id: Optional provider instance ID to embed in the resulting JWT.
:return: Tuple of (code, expires_at datetime).
"""
if expires_in_hours <= 0:
Expand All @@ -1646,6 +1678,7 @@ async def generate_join_code(
"max_uses": max_uses,
"use_count": 0,
"device_name": device_name,
"instance_id": instance_id,
}
try:
await self.database.insert("join_codes", code_data)
Expand Down Expand Up @@ -1681,7 +1714,7 @@ async def _exchange_join_code(self, code: str) -> str | None:
WHERE code = :code
AND expires_at > :now
AND (max_uses = 0 OR use_count < max_uses)
RETURNING user_id, device_name
RETURNING user_id, device_name, instance_id
""",
{"now": now.isoformat(), "code": code.upper()},
)
Expand All @@ -1700,10 +1733,14 @@ async def _exchange_join_code(self, code: str) -> str | None:
return None

device_name = row["device_name"] or "Short Code Login"
extra_claims: dict[str, Any] | None = None
if row["instance_id"]:
extra_claims = {"party_instance": row["instance_id"]}
token = await self.create_token(
user,
device_name,
is_long_lived=False,
extra_claims=extra_claims,
)

self.logger.info(
Expand All @@ -1729,24 +1766,60 @@ async def revoke_join_codes(self, user: User) -> int:
self.logger.info("Revoked %d join code(s) for user %s", count, user.username)
return count

async def get_active_join_code(self, user: User) -> str | None:
async def revoke_join_codes_for_instance(self, instance_id: str) -> int:
"""Revoke all join codes associated with a specific provider instance.

:param instance_id: The provider instance ID.
:return: Number of codes revoked.
"""
cursor = await self.database.execute(
"DELETE FROM join_codes WHERE instance_id = :instance_id",
{"instance_id": instance_id},
)
await self.database.commit()

count = int(cursor.rowcount)
if count > 0:
self.logger.info("Revoked %d join code(s) for instance %s", count, instance_id)
return count

async def get_active_join_code(self, user: User, instance_id: str | None = None) -> str | None:
"""Get the most recently created, non-expired join code for a user.

:param user: The user to look up codes for.
:param instance_id: Optional instance ID to filter by.
:return: The join code string if found, None otherwise.
"""
now = utc()
cursor = await self.database.execute(
"""
SELECT code FROM join_codes
WHERE user_id = :user_id
AND expires_at > :now
AND (max_uses = 0 OR use_count < max_uses)
ORDER BY created_at DESC
LIMIT 1
""",
{"user_id": user.user_id, "now": now.isoformat()},
)
if instance_id:
cursor = await self.database.execute(
"""
SELECT code FROM join_codes
WHERE user_id = :user_id
AND instance_id = :instance_id
AND expires_at > :now
AND (max_uses = 0 OR use_count < max_uses)
ORDER BY created_at DESC
LIMIT 1
""",
{
"user_id": user.user_id,
"instance_id": instance_id,
"now": now.isoformat(),
},
)
else:
cursor = await self.database.execute(
"""
SELECT code FROM join_codes
WHERE user_id = :user_id
AND expires_at > :now
AND (max_uses = 0 OR use_count < max_uses)
ORDER BY created_at DESC
LIMIT 1
""",
{"user_id": user.user_id, "now": now.isoformat()},
)
row = await cursor.fetchone()
return str(row["code"]) if row else None

Expand Down
6 changes: 5 additions & 1 deletion music_assistant/helpers/jwt_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ def encode_token(
token_name: str,
expires_at: datetime,
is_long_lived: bool = False,
extra_claims: dict[str, Any] | None = None,
) -> str:
"""Encode a JWT token for a user.

Expand All @@ -51,10 +52,11 @@ def encode_token(
:param token_name: Human-readable token name.
:param expires_at: Token expiration datetime.
:param is_long_lived: Whether this is a long-lived token.
:param extra_claims: Optional extra claims to include in the JWT payload.
:return: Encoded JWT token string.
"""
now = utc()
payload = {
payload: dict[str, Any] = {
"sub": user.user_id,
"jti": token_id,
"iat": int(now.timestamp()),
Expand All @@ -64,6 +66,8 @@ def encode_token(
"token_name": token_name,
"is_long_lived": is_long_lived,
}
if extra_claims:
payload["extra_claims"] = extra_claims

return jwt.encode(payload, self.secret_key, algorithm=self.algorithm)

Expand Down
Loading
Loading