diff --git a/src/program/settings/models.py b/src/program/settings/models.py index e2f81c2a4..e778b0f06 100644 --- a/src/program/settings/models.py +++ b/src/program/settings/models.py @@ -1,21 +1,25 @@ """Riven settings models""" +import re from collections.abc import Callable +from enum import Enum from pathlib import Path -from typing import Any, Literal, Annotated +from typing import Annotated, Any, Literal from pydantic import ( BaseModel, + BeforeValidator, + ConfigDict, Field, field_validator, model_validator, - BeforeValidator, ) from pydantic.networks import PostgresDsn from RTN.models import SettingsModel from program.settings.migratable import MigratableBaseModel from program.utils import generate_api_key, get_version +from program.utils.safe_formatter import SafeFormatter deprecation_warning = ( "This has been deprecated and will be removed in a future version." @@ -56,21 +60,41 @@ def __setattr__(self, name: str, value: Any): class RealDebridModel(Observable): + model_config = ConfigDict(title="Real-Debrid") + enabled: bool = Field(default=False, description="Enable Real-Debrid") - api_key: str = Field(default="", description="Real-Debrid API key") + api_key: str = Field( + default="", + description="Real-Debrid API key", + json_schema_extra={"format": "password"}, + ) class DebridLinkModel(Observable): + model_config = ConfigDict(title="Debrid-Link") + enabled: bool = Field(default=False, description="Enable Debrid-Link") - api_key: str = Field(default="", description="Debrid-Link API key") + api_key: str = Field( + default="", + description="Debrid-Link API key", + json_schema_extra={"format": "password"}, + ) class AllDebridModel(Observable): + model_config = ConfigDict(title="AllDebrid") + enabled: bool = Field(default=False, description="Enable AllDebrid") - api_key: str = Field(default="", description="AllDebrid API key") + api_key: str = Field( + default="", + description="AllDebrid API key", + json_schema_extra={"format": "password"}, + ) class DownloadersModel(Observable): + model_config = ConfigDict(title="Downloaders") + video_extensions: list[str] = Field( default_factory=lambda: list[str](["mp4", "mkv", "avi"]), description="list of video file extensions to consider for downloads", @@ -114,6 +138,8 @@ class DownloadersModel(Observable): class LibraryProfileFilterRules(BaseModel): """Filter rules for library profile matching (metadata-only)""" + model_config = ConfigDict(title="Filter Rules") + content_types: list[str] | None = Field( default=None, description="Media types to include (movie, show). None/omit = all types", @@ -196,6 +222,8 @@ def migrate_exclude_genres(self): class LibraryProfile(BaseModel): """Library profile configuration for organizing media into different libraries""" + model_config = ConfigDict(title="Library Profile") + name: str = Field(description="Human-readable profile name") library_path: str = Field( description="VFS path prefix for this profile (e.g., '/kids', '/anime')" @@ -218,8 +246,6 @@ def validate_library_path(cls, v: str): "library_path cannot be '/default' (reserved for default path)" ) # Check for valid characters (alphanumeric, dash, underscore, slash) - import re - if not re.match(r"^/[a-zA-Z0-9_\-/]+$", v): raise ValueError( "library_path must contain only alphanumeric characters, dashes, underscores, and slashes" @@ -228,6 +254,8 @@ def validate_library_path(cls, v: str): class FilesystemModel(Observable): + model_config = ConfigDict(title="Filesystem") + mount_path: Path = Field( default=Path("/path/to/riven/mount"), description="Path where Riven will mount the virtual filesystem", @@ -337,9 +365,7 @@ class FilesystemModel(Observable): @field_validator("library_profiles") def validate_library_profiles(cls, v: dict[str, LibraryProfile]): """Validate library profile keys and paths""" - import re - - for key in v.keys(): + for key in v: # Profile keys must be lowercase alphanumeric with underscores if not re.match(r"^[a-z0-9_]+$", key): raise ValueError( @@ -389,9 +415,6 @@ def validate_library_profiles(cls, v: dict[str, LibraryProfile]): ) def validate_naming_template(cls, v: str) -> str: """Validate that naming template string is syntactically valid.""" - - from program.utils.safe_formatter import SafeFormatter - try: # Test template with comprehensive dummy data formatter = SafeFormatter() @@ -449,30 +472,50 @@ def check_update_interval(cls, v: int): class PlexLibraryModel(Observable): + model_config = ConfigDict(title="Plex") + enabled: bool = Field(default=False, description="Enable Plex library updates") - token: str = Field(default="", description="Plex authentication token") + token: str = Field( + default="", + description="Plex authentication token", + json_schema_extra={"format": "password"}, + ) url: EmptyOrUrl = Field( default="http://localhost:32400", description="Plex server URL" ) class JellyfinLibraryModel(Observable): + model_config = ConfigDict(title="Jellyfin") + enabled: bool = Field(default=False, description="Enable Jellyfin library updates") - api_key: str = Field(default="", description="Jellyfin API key") + api_key: str = Field( + default="", + description="Jellyfin API key", + json_schema_extra={"format": "password"}, + ) url: EmptyOrUrl = Field( default="http://localhost:8096", description="Jellyfin server URL" ) class EmbyLibraryModel(Observable): + model_config = ConfigDict(title="Emby") + enabled: bool = Field(default=False, description="Enable Emby library updates") - api_key: str = Field(default="", description="Emby API key") + api_key: str = Field( + default="", + description="Emby API key", + json_schema_extra={"format": "password"}, + ) url: EmptyOrUrl = Field( default="http://localhost:8096", description="Emby server URL" ) class UpdatersModel(Observable): + model_config = ConfigDict(title="Media Servers") + updater_interval: int = Field( default=120, ge=1, description="Interval in seconds between library updates" ) @@ -498,6 +541,8 @@ class UpdatersModel(Observable): class ListrrModel(Updatable): + model_config = ConfigDict(title="Listrr") + enabled: bool = Field(default=False, description="Enable Listrr integration") movie_lists: list[str] = Field( default_factory=list, description="Listrr movie list IDs" @@ -505,15 +550,25 @@ class ListrrModel(Updatable): show_lists: list[str] = Field( default_factory=list, description="Listrr TV show list IDs" ) - api_key: str = Field(default="", description="Listrr API key") + api_key: str = Field( + default="", + description="Listrr API key", + json_schema_extra={"format": "password"}, + ) update_interval: int = Field( default=86400, ge=1, description="Update interval in seconds (24 hours default)" ) class MdblistModel(Updatable): + model_config = ConfigDict(title="MDBList") + enabled: bool = Field(default=False, description="Enable MDBList integration") - api_key: str = Field(default="", description="MDBList API key") + api_key: str = Field( + default="", + description="MDBList API key", + json_schema_extra={"format": "password"}, + ) lists: list[int | str] = Field( default_factory=list[int | str], description="MDBList list IDs to monitor" ) @@ -523,11 +578,17 @@ class MdblistModel(Updatable): class OverseerrModel(Updatable): + model_config = ConfigDict(title="Overseerr") + enabled: bool = Field(default=False, description="Enable Overseerr integration") url: EmptyOrUrl = Field( default="http://localhost:5055", description="Overseerr URL" ) - api_key: str = Field(default="", description="Overseerr API key") + api_key: str = Field( + default="", + description="Overseerr API key", + json_schema_extra={"format": "password"}, + ) use_webhook: bool = Field( default=False, description="Use webhook instead of polling" ) @@ -537,6 +598,8 @@ class OverseerrModel(Updatable): class PlexWatchlistModel(Updatable): + model_config = ConfigDict(title="Plex Watchlist") + enabled: bool = Field( default=False, description="Enable Plex Watchlist integration" ) @@ -551,16 +614,32 @@ class PlexWatchlistModel(Updatable): class TraktOauthModel(BaseModel): oauth_client_id: str = Field(default="", description="Trakt OAuth client ID") oauth_client_secret: str = Field( - default="", description="Trakt OAuth client secret" + default="", + description="Trakt OAuth client secret", + json_schema_extra={"format": "password"}, ) oauth_redirect_uri: str = Field(default="", description="Trakt OAuth redirect URI") - access_token: str = Field(default="", description="Trakt OAuth access token") - refresh_token: str = Field(default="", description="Trakt OAuth refresh token") + access_token: str = Field( + default="", + description="Trakt OAuth access token", + json_schema_extra={"format": "password"}, + ) + refresh_token: str = Field( + default="", + description="Trakt OAuth refresh token", + json_schema_extra={"format": "password"}, + ) class TraktModel(Updatable): + model_config = ConfigDict(title="Trakt") + enabled: bool = Field(default=False, description="Enable Trakt integration") - api_key: str = Field(default="", description="Trakt API key") + api_key: str = Field( + default="", + description="Trakt API key", + json_schema_extra={"format": "password"}, + ) watchlist: list[str] = Field( default_factory=list[str], description="Trakt usernames for watchlist monitoring", @@ -608,6 +687,8 @@ class TraktModel(Updatable): class ContentModel(Observable): + model_config = ConfigDict(title="Content Sources") + overseerr: OverseerrModel = Field( default_factory=lambda: OverseerrModel(), description="Overseerr configuration" ) @@ -630,6 +711,8 @@ class ContentModel(Observable): class TorrentioConfig(Observable): + model_config = ConfigDict(title="Torrentio") + enabled: bool = Field(default=False, description="Enable Torrentio scraper") filter: str = Field( default="sort=qualitysize%7Cqualityfilter=480p,scr,cam", @@ -649,6 +732,8 @@ class TorrentioConfig(Observable): class CometConfig(Observable): + model_config = ConfigDict(title="Comet") + enabled: bool = Field(default=False, description="Enable Comet scraper") url: EmptyOrUrl = Field(default="http://localhost:8000", description="Comet URL") timeout: int = Field(default=30, ge=1, description="Request timeout in seconds") @@ -659,6 +744,8 @@ class CometConfig(Observable): class ZileanConfig(Observable): + model_config = ConfigDict(title="Zilean") + enabled: bool = Field(default=False, description="Enable Zilean scraper") url: EmptyOrUrl = Field(default="http://localhost:8181", description="Zilean URL") timeout: int = Field(default=30, ge=1, description="Request timeout in seconds") @@ -669,6 +756,8 @@ class ZileanConfig(Observable): class MediafusionConfig(Observable): + model_config = ConfigDict(title="MediaFusion") + enabled: bool = Field(default=False, description="Enable Mediafusion scraper") url: EmptyOrUrl = Field( default="http://localhost:8000", description="Mediafusion URL" @@ -687,8 +776,14 @@ class OrionoidConfigParametersDict(Observable): class OrionoidConfig(Observable): + model_config = ConfigDict(title="Orionoid") + enabled: bool = Field(default=False, description="Enable Orionoid scraper") - api_key: str = Field(default="", description="Orionoid API key") + api_key: str = Field( + default="", + description="Orionoid API key", + json_schema_extra={"format": "password"}, + ) cached_results_only: bool = Field( default=False, description="Only return cached/downloadable results" ) @@ -704,9 +799,15 @@ class OrionoidConfig(Observable): class JackettConfig(Observable): + model_config = ConfigDict(title="Jackett") + enabled: bool = Field(default=False, description="Enable Jackett scraper") url: EmptyOrUrl = Field(default="http://localhost:9117", description="Jackett URL") - api_key: str = Field(default="", description="Jackett API key") + api_key: str = Field( + default="", + description="Jackett API key", + json_schema_extra={"format": "password"}, + ) timeout: int = Field(default=30, ge=1, description="Request timeout in seconds") retries: int = Field( default=1, ge=0, description="Number of retries for failed requests" @@ -720,9 +821,15 @@ class JackettConfig(Observable): class ProwlarrConfig(Observable): + model_config = ConfigDict(title="Prowlarr") + enabled: bool = Field(default=False, description="Enable Prowlarr scraper") url: EmptyOrUrl = Field(default="http://localhost:9696", description="Prowlarr URL") - api_key: str = Field(default="", description="Prowlarr API key") + api_key: str = Field( + default="", + description="Prowlarr API key", + json_schema_extra={"format": "password"}, + ) timeout: int = Field(default=30, ge=1, description="Request timeout in seconds") retries: int = Field( default=1, ge=0, description="Number of retries for failed requests" @@ -739,6 +846,8 @@ class ProwlarrConfig(Observable): class RarbgConfig(Observable): + model_config = ConfigDict(title="RARBG") + enabled: bool = Field(default=False, description="Enable RARBG scraper") url: EmptyOrUrl = Field(default="https://therarbg.to", description="RARBG URL") timeout: int = Field(default=30, ge=1, description="Request timeout in seconds") @@ -749,6 +858,8 @@ class RarbgConfig(Observable): class AIOStreamsConfig(Observable): + model_config = ConfigDict(title="AIOStreams") + enabled: bool = Field(default=False, description="Enable AIOStreams scraper") url: EmptyOrUrl = Field( default="http://localhost:8000", description="AIOStreams instance URL" @@ -758,12 +869,20 @@ class AIOStreamsConfig(Observable): default=1, ge=0, description="Number of retries for failed requests" ) ratelimit: bool = Field(default=True, description="Enable rate limiting") - proxy_url: EmptyOrUrl = Field(default="", description="Proxy URL for AIOStreams requests") + proxy_url: EmptyOrUrl = Field( + default="", description="Proxy URL for AIOStreams requests" + ) uuid: str = Field(default="", description="User UUID for AIOStreams authentication") - password: str = Field(default="", description="User password for AIOStreams authentication") + password: str = Field( + default="", + description="User password for AIOStreams authentication", + json_schema_extra={"format": "password"}, + ) class ScraperModel(Observable): + model_config = ConfigDict(title="Scraping") + after_2: float = Field( default=2, description="Hours to wait after 2 failed scrapes" ) @@ -814,7 +933,8 @@ class ScraperModel(Observable): default_factory=lambda: RarbgConfig(), description="RARBG configuration" ) aiostreams: AIOStreamsConfig = Field( - default_factory=lambda: AIOStreamsConfig(), description="AIOStreams configuration" + default_factory=lambda: AIOStreamsConfig(), + description="AIOStreams configuration", ) @@ -828,6 +948,8 @@ class RTNSettingsModel(SettingsModel, Observable): ... class IndexerModel(Observable): + model_config = ConfigDict(title="Indexer") + schedule_offset_minutes: int = Field( default=30, ge=0, @@ -836,6 +958,8 @@ class IndexerModel(Observable): class DatabaseModel(Observable): + model_config = ConfigDict(title="Database") + host: PostgresDsn = Field( default_factory=lambda: PostgresDsn( "postgresql+psycopg2://postgres:postgres@localhost/riven" @@ -844,10 +968,26 @@ class DatabaseModel(Observable): ) +class ItemType(str, Enum): + """Media item types for notifications.""" + + movie = "movie" + show = "show" + season = "season" + episode = "episode" + + class NotificationsModel(Observable): + model_config = ConfigDict(title="Notifications") + enabled: bool = Field(default=False, description="Enable notifications") - on_item_type: list[str] = Field( - default_factory=lambda: ["movie", "show", "season", "episode"], + on_item_type: list[ItemType] = Field( + default_factory=lambda: [ + ItemType.movie, + ItemType.show, + ItemType.season, + ItemType.episode, + ], description="Item types to send notifications for", ) service_urls: list[str] = Field( @@ -857,10 +997,14 @@ class NotificationsModel(Observable): class SubtitleProviderConfig(Observable): + model_config = ConfigDict(title="Subtitle Provider") + enabled: bool = Field(default=False, description="Enable this subtitle provider") class SubtitleProvidersDict(Observable): + model_config = ConfigDict(title="Subtitle Providers") + opensubtitles: SubtitleProviderConfig = Field( default_factory=lambda: SubtitleProviderConfig(), description="OpenSubtitles provider configuration", @@ -868,6 +1012,8 @@ class SubtitleProvidersDict(Observable): class SubtitleConfig(Observable): + model_config = ConfigDict(title="Subtitles") + enabled: bool = Field(default=False, description="Enable subtitle downloading") languages: list[str] = Field( default_factory=lambda: ["eng"], @@ -880,6 +1026,8 @@ class SubtitleConfig(Observable): class PostProcessing(Observable): + model_config = ConfigDict(title="Post Processing") + subtitle: SubtitleConfig = Field( default_factory=lambda: SubtitleConfig(), description="Subtitle post-processing configuration", @@ -887,6 +1035,8 @@ class PostProcessing(Observable): class LoggingModel(Observable): + model_config = ConfigDict(title="Logging") + enabled: bool = Field(default=True, description="Enable file logging") clean_interval: int = Field( default=60 * 60, description="Log cleanup interval in seconds (1 hour default)" @@ -909,6 +1059,8 @@ def check_compression(cls, v: str | None): class StreamModel(Observable): + model_config = ConfigDict(title="Streaming") + chunk_size_mb: int = Field( default=1, ge=1, @@ -933,7 +1085,11 @@ class StreamModel(Observable): class AppModel(Observable): version: str = Field(default_factory=get_version, description="Application version") - api_key: str = Field(default="", description="API key for Riven API access") + api_key: str = Field( + default="", + description="API key for Riven API access", + json_schema_extra={"format": "password"}, + ) log_level: Literal["TRACE", "DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] = ( Field(default="INFO", description="Logging level") ) @@ -999,9 +1155,9 @@ class AppModel(Observable): @field_validator("log_level", mode="before") def check_debug(cls, v: str | bool): - if v == True: + if v is True: return "DEBUG" - elif v == False: + elif v is False: return "INFO" return v.upper() diff --git a/src/routers/secure/settings.py b/src/routers/secure/settings.py index ff6db2083..e6ab335b4 100644 --- a/src/routers/secure/settings.py +++ b/src/routers/secure/settings.py @@ -2,7 +2,7 @@ from typing import Annotated, Any, cast from fastapi import APIRouter, Body, HTTPException, Path, Query -from pydantic import TypeAdapter, ValidationError +from pydantic import ValidationError, create_model from program.settings import settings_manager from program.settings.models import AppModel @@ -64,34 +64,15 @@ async def get_settings_schema_for_keys( detail=f"Invalid keys: {', '.join(invalid_keys)}. Valid keys are: {', '.join(sorted(valid_keys))}", ) - all_defs: dict[str, Any] = {} - properties: dict[str, Any] = {} - required: list[str] = [] - - for key in requested_keys: - field_info = model_fields[key] - adapter: TypeAdapter[Any] = TypeAdapter(field_info.annotation) - field_schema = adapter.json_schema(ref_template="#/$defs/{model}") - - if "$defs" in field_schema: - all_defs.update(field_schema.pop("$defs")) - - properties[key] = field_schema - - if field_info.is_required(): - required.append(key) - - filtered_schema: dict[str, Any] = { - "properties": properties, - "required": required, - "title": title, - "type": "object", + # Build fields to preserve all metadata + fields = { + key: (model_fields[key].annotation, model_fields[key]) for key in requested_keys } - if all_defs: - filtered_schema["$defs"] = all_defs + # Create dynamic model to preserve all metadata + filtered_model = create_model(title, **fields) # type: ignore[call-overload] - return filtered_schema + return cast(dict[str, Any], filtered_model.model_json_schema()) @router.get(