diff --git a/music_assistant/constants.py b/music_assistant/constants.py index 0a9d25918e..f3f8a86c54 100644 --- a/music_assistant/constants.py +++ b/music_assistant/constants.py @@ -140,11 +140,13 @@ CONF_PROTOCOL_KEY_SPLITTER: Final[str] = "||protocol||" CONF_PROTOCOL_CATEGORY_PREFIX: Final[str] = "protocol" CONF_DEFAULT_PROVIDERS_SETUP: Final[str] = "default_providers_setup" +CONF_BACKGROUND_SCAN_CONCURRENCY: Final[str] = "background_scan_concurrency" # config default values DEFAULT_HOST: Final[str] = "0.0.0.0" DEFAULT_PORT: Final[int] = 8095 +DEFAULT_BACKGROUND_SCAN_CONCURRENCY: Final[int] = 1 # common db tables diff --git a/music_assistant/controllers/streams/audio_analysis.py b/music_assistant/controllers/streams/audio_analysis.py index 4a16e0df91..5a43786ac3 100644 --- a/music_assistant/controllers/streams/audio_analysis.py +++ b/music_assistant/controllers/streams/audio_analysis.py @@ -4,17 +4,21 @@ import asyncio import contextlib +import dataclasses import os +import time from math import inf -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any import torch from music_assistant_models.background_task import TaskSchedule -from music_assistant_models.enums import MediaType, ProviderType, StreamType +from music_assistant_models.enums import ContentType, MediaType, ProviderType, StreamType from music_assistant.constants import ( + CONF_BACKGROUND_SCAN_CONCURRENCY, DB_TABLE_AUDIO_ANALYSIS, DB_TABLE_PROVIDER_MAPPINGS, + DEFAULT_BACKGROUND_SCAN_CONCURRENCY, LOUDNESS_MEASUREMENT_MIN_LUFS, ) from music_assistant.helpers.datetime import local_clock_time_to_utc @@ -23,12 +27,12 @@ from music_assistant.models.audio_analysis_provider import AudioAnalysisProvider from music_assistant.models.music_provider import MusicProvider -CHUNK_PROCESS_TIMEOUT = 1.0 +CHUNK_PROCESS_TIMEOUT_SECONDS = 1.0 LOUDNESS_ANALYSIS_DOMAIN = "loudness_analysis" BACKGROUND_SCAN_TASK_ID = "audio_analysis_background_scan" -BACKGROUND_SCAN_BATCH_SIZE = 250 
-BACKGROUND_SCAN_SLEEP_BETWEEN_ITEMS = 2.0 -# providers whose tracks can be analyzed from their local filesystem path +BACKGROUND_PER_TRACK_TIMEOUT_SECONDS = 300 +# Per-run wall-clock cap; in-flight tracks finish, new ones defer to the next run. +BACKGROUND_SCAN_RUN_BUDGET_SECONDS = 4 * 3600 FILESYSTEM_PROVIDER_DOMAINS: tuple[str, ...] = ( "filesystem_local", "filesystem_smb", @@ -47,10 +51,7 @@ class AudioAnalysisController: """Controller that distributes PCM chunks to all registered AudioAnalysisProviders.""" def __init__(self, streams: StreamsController) -> None: - """Initialize the AudioAnalysisController. - - :param streams: Parent StreamsController instance. - """ + """Initialize the AudioAnalysisController.""" self.streams = streams self.mass = streams.mass self.logger = self.mass.logger.getChild("audio_analysis") @@ -67,8 +68,21 @@ def setup(self) -> None: handler=self._run_background_scan, schedule=TaskSchedule.daily(hour=utc_hour, minute=utc_minute), metadata={"task_domain": "audio_analysis"}, + allow_retry=True, ) + async def close(self) -> None: + """Drain in-flight sessions and chunk workers on shutdown.""" + workers = list(self._workers.values()) + self._workers.clear() + for worker in workers: + if not worker.done(): + worker.cancel() + for session_key in list(self._active_sessions): + self._cancel_providers(session_key) + if workers: + await asyncio.gather(*workers, return_exceptions=True) + def _configure_thread_caps(self) -> None: """Cap PyTorch threading so Audio Analysis inference stays around a quarter of cpu_count.""" budget = self._aa_thread_budget() @@ -99,9 +113,6 @@ async def start_analysis( """ Start analysis session for a track across all providers. - Starts an analysis session for the given track on all available - Audio Analysis providers. - :param audio_buffer: The AudioBuffer to observe for PCM chunks. :param streamdetails: The stream details for the item being analyzed. 
""" @@ -131,7 +142,6 @@ async def start_analysis( queue: asyncio.Queue[bytes | None] = asyncio.Queue(maxsize=10) self._workers[session_key] = self.mass.create_task(self._chunk_worker(session_key, queue)) - # Build and register closures for callbacks on the audio buffer finalized = False async def _on_chunk(position_seconds: int, pcm_data: bytes, is_last_chunk: bool) -> None: # noqa: ARG001 @@ -144,7 +154,7 @@ async def _on_chunk(position_seconds: int, pcm_data: bytes, is_last_chunk: bool) self.mass.create_task(_finalize_session()) return try: - await asyncio.wait_for(queue.put(pcm_data), timeout=CHUNK_PROCESS_TIMEOUT) + await asyncio.wait_for(queue.put(pcm_data), timeout=CHUNK_PROCESS_TIMEOUT_SECONDS) except (TimeoutError, asyncio.QueueFull): return @@ -330,100 +340,176 @@ async def get_audio_analysis_version( return int(row["analysis_version"]) async def _run_background_scan(self) -> None: - """ - Run the nightly background scan across all audio analysis providers. - - Iterates each available provider, queries tracks from local-filesystem - music providers that do not yet have analysis for that provider, and - hands each one to the provider's `analyze_file` hook. Results are - persisted via `set_audio_analysis`. The batch aborts for a given - provider if its storage backend goes offline. 
- """ + """Run the scan as decode-once-fan-out streaming over candidate tracks.""" providers = self.providers if not providers: return - for provider in providers: - candidates = await self._find_tracks_missing_analysis( - provider.domain, BACKGROUND_SCAN_BATCH_SIZE - ) - if not candidates: - continue + domains = [p.domain for p in providers] + candidates = await self._find_candidates_missing_analysis(domains, limit=0) + if not candidates: + return + + scan_started = time.monotonic() + run_deadline = scan_started + BACKGROUND_SCAN_RUN_BUDGET_SECONDS + self.logger.info( + "Background analysis (streaming): %d track(s) pending across %d provider(s); " + "run budget %.1fh", + len(candidates), + len(providers), + BACKGROUND_SCAN_RUN_BUDGET_SECONDS / 3600, + ) + + concurrency = self._get_scan_concurrency() + semaphore = asyncio.Semaphore(concurrency) + provider_by_domain = {p.domain: p for p in providers} + + processed = 0 + deferred = 0 + + async def _run_one(candidate: dict[str, Any]) -> None: + nonlocal processed, deferred + async with semaphore: + if time.monotonic() >= run_deadline: + deferred += 1 + return + + item_id = candidate["item_id"] + provider_instance = candidate["provider_instance"] + missing = candidate["missing_domains"] - self.logger.info( - "Background %s analysis: %d track(s) pending", - provider.domain, - len(candidates), - ) - processed = 0 - for row in candidates: - if not provider.available: - # provider was disabled mid-run - break - item_id = str(row["item_id"]) - provider_instance = str(row["provider_instance"]) music_prov = self.mass.get_provider(provider_instance, provider_type=MusicProvider) if music_prov is None or not music_prov.available: - # storage may be offline right now (e.g. 
NAS asleep) — stop the - # batch rather than churning through failures for the remaining - # tracks self.logger.debug( - "Background %s analysis: provider %s unavailable, aborting batch", - provider.domain, - provider_instance, + "Skipping %s: music provider %s unavailable", item_id, provider_instance ) - break + return try: streamdetails = await music_prov.get_stream_details(item_id, MediaType.TRACK) except Exception as err: - self.logger.debug( - "Background %s analysis: skipping %s (stream details failed: %s)", - provider.domain, - item_id, - err, - ) - continue + self.logger.debug("Skipping %s: stream details failed: %s", item_id, err) + return if streamdetails.stream_type != StreamType.LOCAL_FILE: - continue + return if not isinstance(streamdetails.path, str) or not streamdetails.path: - continue + return - try: - result = await provider.analyze_file(streamdetails) - except Exception as err: - self.logger.warning( - "Background %s analysis failed for %s: %s", - provider.domain, - item_id, - err, - ) - result = None - - if result is not None: - await self.set_audio_analysis( - item_id=item_id, - provider_instance_id_or_domain=music_prov.instance_id, - aa_provider_domain=provider.domain, - analysis=result, - analysis_version=provider.analysis_version, - ) - processed += 1 + providers_for_track = [ + p + for p in (provider_by_domain.get(d) for d in missing) + if p is not None and p.available + ] + if not providers_for_track: + return + + await self._run_background_streaming_for_track(streamdetails, providers_for_track) + processed += 1 - await asyncio.sleep(BACKGROUND_SCAN_SLEEP_BETWEEN_ITEMS) + await asyncio.gather(*(_run_one(c) for c in candidates)) + elapsed = time.monotonic() - scan_started + if deferred: + self.logger.info( + "Background analysis: run-budget reached " + "(%d processed, %d deferred to next run, %.1fs elapsed)", + processed, + deferred, + elapsed, + ) + else: self.logger.info( - "Background %s analysis: analyzed %d/%d track(s)", - 
provider.domain, + "Background analysis: complete (%d candidates processed in %.1fs)", processed, - len(candidates), + elapsed, + ) + + async def _run_background_streaming_for_track( + self, + streamdetails: StreamDetails, + providers: list[AudioAnalysisProvider], + ) -> None: + """Run a single track through the streaming pipeline using ffmpeg as the source.""" + session_key = streamdetails.uri + if session_key in self._active_sessions: + self.logger.debug( + "Background streaming: session already active for %s, skipping", session_key + ) + return + + try: + await asyncio.wait_for( + self._run_background_streaming_inner(session_key, streamdetails, providers), + timeout=BACKGROUND_PER_TRACK_TIMEOUT_SECONDS, + ) + except asyncio.CancelledError: + # CancelledError inherits from BaseException — the broad except below + # does not catch it. Clean up the session, then re-raise. + self.logger.debug("Background analysis cancelled for %s", session_key) + self._cancel_providers(session_key) + raise + except TimeoutError: + self.logger.warning( + "Background analysis exceeded %ds budget for %s, skipping", + BACKGROUND_PER_TRACK_TIMEOUT_SECONDS, + session_key, + ) + self._cancel_providers(session_key) + self.mass.tasks.add_task_failure( + BACKGROUND_SCAN_TASK_ID, + f"Timed out after {BACKGROUND_PER_TRACK_TIMEOUT_SECONDS}s: {session_key}", + ) + except Exception as err: + self.logger.warning("Background analysis failed for %s: %s", session_key, err) + self._cancel_providers(session_key) + self.mass.tasks.add_task_failure( + BACKGROUND_SCAN_TASK_ID, + f"Failed: {session_key}: {err}", ) - async def _find_tracks_missing_analysis( - self, aa_provider_domain: str, limit: int - ) -> list[dict[str, object]]: - """Return up to N local-filesystem tracks without analysis for the given AA provider.""" + async def _run_background_streaming_inner( + self, + session_key: str, + streamdetails: StreamDetails, + providers: list[AudioAnalysisProvider], + ) -> None: + """Inner body of 
_run_background_streaming_for_track, wrapped by wait_for.""" + if not isinstance(streamdetails.path, str) or not streamdetails.path: + return + + # Override content_type so ffmpeg decodes rather than re-muxing the source codec. + pcm_format = dataclasses.replace( + streamdetails.audio_format, + content_type=ContentType.from_bit_depth(streamdetails.audio_format.bit_depth), + ) + + accepted = await self._start_analysis_on_providers( + session_key, streamdetails, pcm_format, providers + ) + if not accepted: + self.logger.debug("No providers accepted background analysis for %s", session_key) + return + self._active_sessions[session_key] = accepted + + audio_source = self.mass.streams.audio.get_media_stream(streamdetails, pcm_format) + async for chunk in audio_source: + if session_key not in self._active_sessions: + # all providers evicted — bail early + break + await self._distribute_chunk(session_key, chunk) + if session_key in self._active_sessions: + self._finalize_providers(session_key) + + async def _find_candidates_missing_analysis( + self, + aa_provider_domains: list[str], + limit: int, + ) -> list[dict[str, Any]]: + """Return rows {item_id, provider_instance, missing_domains} for tracks lacking analysis.""" + if not aa_provider_domains: + return [] + filesystem_domains = tuple( domain for domain in FILESYSTEM_PROVIDER_DOMAINS @@ -435,26 +521,47 @@ async def _find_tracks_missing_analysis( if not filesystem_domains: return [] - domains_sql = ", ".join(f"'{d}'" for d in filesystem_domains) - track_media_type = MediaType.TRACK.value - # audio_analysis.item_id holds the provider-native item id, - # so join against provider_mappings.provider_item_id (not pm.item_id, - # which is the integer library-row id) + # CROSS JOIN (track x possible domain), keep pairs with no analysis row, + # GROUP_CONCAT the missing domains per track. 
+ fs_inline = ", ".join(f"'{d}'" for d in filesystem_domains) + aa_select_terms = " UNION ALL ".join( + f"SELECT :aa_{i} AS aa_provider_domain" for i in range(len(aa_provider_domains)) + ) + params: dict[str, Any] = { + "media_type": MediaType.TRACK.value, + **{f"aa_{i}": d for i, d in enumerate(aa_provider_domains)}, + } query = ( f"SELECT pm.provider_item_id AS item_id, " - f" pm.provider_instance AS provider_instance " + f" pm.provider_instance AS provider_instance, " + f" GROUP_CONCAT(possible.aa_provider_domain) AS missing_domains " f"FROM {DB_TABLE_PROVIDER_MAPPINGS} pm " - f"LEFT JOIN {DB_TABLE_AUDIO_ANALYSIS} aa " - f" ON aa.item_id = pm.provider_item_id " - f" AND aa.provider = pm.provider_instance " - f" AND aa.aa_provider_domain = '{aa_provider_domain}' " - f" AND aa.media_type = '{track_media_type}' " - f"WHERE pm.media_type = '{track_media_type}' " - f" AND pm.provider_domain IN ({domains_sql}) " - f" AND aa.id IS NULL" + f"CROSS JOIN ({aa_select_terms}) possible " + f"WHERE pm.media_type = :media_type " + f" AND pm.provider_domain IN ({fs_inline}) " + f" AND NOT EXISTS (" + f" SELECT 1 FROM {DB_TABLE_AUDIO_ANALYSIS} aa " + f" WHERE aa.item_id = pm.provider_item_id " + f" AND aa.provider = pm.provider_instance " + f" AND aa.aa_provider_domain = possible.aa_provider_domain " + f" AND aa.media_type = :media_type" + f" ) " + f"GROUP BY pm.provider_item_id, pm.provider_instance" ) - rows = await self.mass.music.database.get_rows_from_query(query, limit=limit) - return [dict(r) for r in rows] + rows = await self.mass.music.database.get_rows_from_query(query, params, limit=limit) + results: list[dict[str, Any]] = [] + for r in rows: + missing_raw = r["missing_domains"] + if not missing_raw: + continue + results.append( + { + "item_id": str(r["item_id"]), + "provider_instance": str(r["provider_instance"]), + "missing_domains": sorted(set(missing_raw.split(","))), + } + ) + return results async def _start_analysis_on_providers( self, @@ -499,62 +606,75 @@ def 
_cancel_providers(self, session_key: str) -> None: if provider and isinstance(provider, AudioAnalysisProvider) and provider.available: self.mass.create_task(provider.cancel(session_key)) + async def _distribute_chunk(self, session_key: str, pcm_data: bytes) -> None: + """Fan a single PCM chunk to every provider in the session.""" + provider_ids = self._active_sessions.get(session_key) + if not provider_ids: + return + + async def _process(prov_id: str) -> str | None: + try: + provider = self.mass.get_provider(prov_id) + if not ( + provider and isinstance(provider, AudioAnalysisProvider) and provider.available + ): + return None + await asyncio.wait_for( + provider.process_pcm_chunk(session_key, pcm_data), + timeout=CHUNK_PROCESS_TIMEOUT_SECONDS, + ) + except TimeoutError: + self.logger.warning( + "Provider %s timed out processing chunk for %s, removing from session", + prov_id, + session_key, + ) + return prov_id + except Exception as err: + self.logger.warning("Error processing PCM chunk on provider %s: %s", prov_id, err) + return prov_id + return None + + results = await asyncio.gather(*[_process(prov_id) for prov_id in provider_ids]) + evicted = {prov_id for prov_id in results if prov_id is not None} + if evicted: + for prov_id in evicted: + provider = self.mass.get_provider(prov_id) + if provider and isinstance(provider, AudioAnalysisProvider) and provider.available: + self.mass.create_task(provider.cancel(session_key)) + provider_ids -= evicted + if not provider_ids: + self._active_sessions.pop(session_key, None) + async def _chunk_worker(self, session_key: str, queue: asyncio.Queue[bytes | None]) -> None: - """Background worker that processes queued PCM chunks concurrently across providers.""" + """Background worker that processes queued PCM chunks via _distribute_chunk.""" while True: chunk = await queue.get() if chunk is None: break - - provider_ids = self._active_sessions.get(session_key) - if not provider_ids: + if session_key not in 
self._active_sessions: + break + await self._distribute_chunk(session_key, chunk) + if session_key not in self._active_sessions: + # all providers evicted by _distribute_chunk + self._workers.pop(session_key, None) break - - pcm_data = chunk # bind for closure (chunk is narrowed to bytes here) - - async def _process(prov_id: str, pcm_data: bytes = pcm_data) -> str | None: - try: - provider = self.mass.get_provider(prov_id) - if not ( - provider - and isinstance(provider, AudioAnalysisProvider) - and provider.available - ): - return None - await asyncio.wait_for( - provider.process_pcm_chunk(session_key, pcm_data), - timeout=CHUNK_PROCESS_TIMEOUT, - ) - except TimeoutError: - self.logger.warning( - "Provider %s timed out processing chunk for %s, removing from session", - prov_id, - session_key, - ) - return prov_id - except Exception as err: - self.logger.warning( - "Error processing PCM chunk on provider %s: %s", prov_id, err - ) - return None - - results = await asyncio.gather(*[_process(prov_id) for prov_id in provider_ids]) - timed_out = {prov_id for prov_id in results if prov_id is not None} - if timed_out: - for prov_id in timed_out: - provider = self.mass.get_provider(prov_id) - if ( - provider - and isinstance(provider, AudioAnalysisProvider) - and provider.available - ): - self.mass.create_task(provider.cancel(session_key)) - provider_ids -= timed_out - if not provider_ids: - self._active_sessions.pop(session_key, None) - self._workers.pop(session_key, None) - break def _aa_thread_budget(self) -> int: """Return the per-op PyTorch intra-op thread budget for inference (~25% of cpu_count).""" return max(1, (os.process_cpu_count() or os.cpu_count() or 4) // 4) + + def _get_scan_concurrency(self) -> int: + """Read background scan concurrency from config, clamped to [1, 8].""" + try: + value = int( + self.mass.config.get_raw_core_config_value( + "streams", + CONF_BACKGROUND_SCAN_CONCURRENCY, + DEFAULT_BACKGROUND_SCAN_CONCURRENCY, + ) + or 
DEFAULT_BACKGROUND_SCAN_CONCURRENCY + ) + except Exception: + value = DEFAULT_BACKGROUND_SCAN_CONCURRENCY + return max(1, min(value, 8)) diff --git a/music_assistant/controllers/streams/controller.py b/music_assistant/controllers/streams/controller.py index 40374c10b7..6ff8fed1b4 100644 --- a/music_assistant/controllers/streams/controller.py +++ b/music_assistant/controllers/streams/controller.py @@ -30,6 +30,7 @@ from music_assistant.constants import ( ANNOUNCE_ALERT_FILE, + CONF_BACKGROUND_SCAN_CONCURRENCY, CONF_BIND_IP, CONF_BIND_PORT, CONF_CROSSFADE_DURATION, @@ -43,6 +44,7 @@ CONF_VOLUME_NORMALIZATION_FIXED_GAIN_TRACKS, CONF_VOLUME_NORMALIZATION_RADIO, CONF_VOLUME_NORMALIZATION_TRACKS, + DEFAULT_BACKGROUND_SCAN_CONCURRENCY, DEFAULT_STREAM_HEADERS, DLNA_CONTENT_FEATURES, DLNA_CONTENT_FEATURES_REALTIME, @@ -249,9 +251,20 @@ async def get_config_entries( description="Log level for the Smart Fades mixer and analyzer.", options=CONF_ENTRY_LOG_LEVEL.options, default_value="GLOBAL", - category="generic", + category="audio_analysis", advanced=True, ), + ConfigEntry( + key=CONF_BACKGROUND_SCAN_CONCURRENCY, + type=ConfigEntryType.INTEGER, + range=(1, 8), + default_value=DEFAULT_BACKGROUND_SCAN_CONCURRENCY, + label="Background analysis concurrency", + description="Maximum number of tracks analyzed concurrently during the nightly " + "background scan. Default 1 (serial). 
Increase only if your hardware can handle " + "concurrent torch/ffmpeg work.", + category="audio_analysis", + ), ) async def setup(self, config: CoreConfig) -> None: @@ -313,6 +326,7 @@ async def setup(self, config: CoreConfig) -> None: async def close(self) -> None: """Cleanup on exit.""" + await self._audio_analysis.close() await self._server.close() async def resolve_stream_url(self, player_id: str, media: PlayerMedia) -> str: diff --git a/music_assistant/models/audio_analysis_provider.py b/music_assistant/models/audio_analysis_provider.py index c6f310269d..fc2335090c 100644 --- a/music_assistant/models/audio_analysis_provider.py +++ b/music_assistant/models/audio_analysis_provider.py @@ -28,15 +28,13 @@ class AnalysisSessionData: class AudioAnalysisProvider(Provider): - """Base representation of an Audio Analysis Provider. - - Audio Analysis Provider implementations should inherit from this base model. - These providers receive PCM audio chunks during streaming and produce analysis - results such as beat tracking, key detection, phrase boundaries, etc. + """ + Base representation of an Audio Analysis Provider. - The AudioAnalysisController creates session IDs and passes them to all methods. - Providers implement _start_analysis and _finalize as hooks — the base class - manages session lifecycle, version gating, and cleanup. + Receives PCM audio chunks during streaming and produces analysis results + such as beat tracking, key detection, or loudness. The same hooks drive + both live playback and background scans; providers do not need to know + which context they are running in. """ # Version of the analysis algorithm. Providers should increment this when @@ -61,10 +59,9 @@ async def start_analysis( streamdetails: StreamDetails, audio_format: AudioFormat, ) -> bool: - """Start analysis for a new session. + """ + Start analysis for a new session. 
- Checks whether analysis is needed (version gating), stores session data, - and calls _start_analysis for provider-specific initialization. Returns True if the provider accepted the session. :param session_id: Session ID created by the AudioAnalysisController. @@ -95,11 +92,10 @@ async def _start_analysis( streamdetails: StreamDetails, audio_format: AudioFormat, ) -> bool: - """Provider-specific initialization for a new analysis session. + """ + Provider-specific initialization for a new analysis session. - Called by start_analysis after version gating and session storage. Return False to reject the session (e.g. unsupported format). - Session data is available in self._sessions[session_id]. :param session_id: The analysis session ID. :param streamdetails: The stream details for the item being analyzed. @@ -112,53 +108,74 @@ async def process_pcm_chunk( session_id: str, pcm_chunk: bytes, ) -> None: - """Process a PCM audio chunk. + """ + Process a PCM audio chunk. - Called for each chunk of audio data during streaming. + Implementations MUST `await` all chunk-processing work; the controller + relies on this to backpressure the audio source. :param session_id: The analysis session ID. :param pcm_chunk: Raw PCM audio data. """ @abstractmethod - async def _finalize(self, session_id: str) -> None: - """Finalize analysis and store results. + async def _finalize(self, session_id: str) -> AudioAnalysisData | None: + """ + Compute and return the analysis for this session (or None to skip). - Called when the track has finished streaming. Providers are responsible - for storing their results via mass.streams.audio_analysis.set_audio_analysis(). + The base class persists the returned value via set_audio_analysis() and + then fires post_analysis(). Return None to skip both. :param session_id: The analysis session ID. """ async def finalize(self, session_id: str) -> None: - """Finalize analysis and clean up session state. 
- - Calls _finalize, then removes the session from _sessions. - The controller calls this method — providers override _finalize. - - :param session_id: The analysis session ID. - """ + """Finalize analysis, persist the result, fire post_analysis, then clean up.""" + analysis: AudioAnalysisData | None = None try: - await self._finalize(session_id) - finally: - self._sessions.pop(session_id, None) + analysis = await self._finalize(session_id) + except Exception as err: + self.logger.error("_finalize raised for session %s: %s", session_id, err, exc_info=err) + session = self._sessions.get(session_id) + if analysis is not None and session is not None: + try: + await self.mass.streams.audio_analysis.set_audio_analysis( + item_id=session.streamdetails.item_id, + provider_instance_id_or_domain=session.streamdetails.provider, + aa_provider_domain=self.domain, + analysis=analysis, + analysis_version=self.analysis_version, + media_type=session.streamdetails.media_type, + ) + except Exception as err: + self.logger.warning( + "set_audio_analysis raised for %s: %s", self.domain, err, exc_info=err + ) + else: + try: + await self.post_analysis(session.streamdetails, analysis) + except Exception as err: + self.logger.warning( + "post_analysis raised for %s: %s", self.domain, err, exc_info=err + ) + self._sessions.pop(session_id, None) - async def analyze_file( + async def post_analysis( self, streamdetails: StreamDetails, - ) -> AudioAnalysisData | None: + analysis: AudioAnalysisData, + ) -> None: """ - Run analysis directly on a local audio file. + Run side effects after analysis is finalized and persisted. - Used by the AudioAnalysisController's background scan. Providers that can - analyze a file without going through live PCM streaming (e.g. by handing - the path to FFmpeg/librosa/etc.) should override this. Default returns - None, meaning the provider does not support file-based analysis. + Default is a no-op. 
Implementations MUST self-gate on whether + `streamdetails.path` is a writable filesystem path, since this hook + fires for both live and background-scan analyses. - :param streamdetails: StreamDetails for the item being analyzed. - Contains the local file path and audio format. + :param streamdetails: The stream details for the analyzed item. + :param analysis: The analysis data that was persisted by `_finalize`. """ - return None + return async def cancel(self, session_id: str) -> None: """Cancel an in-progress analysis session.""" diff --git a/music_assistant/providers/_demo_audio_analysis_provider/__init__.py b/music_assistant/providers/_demo_audio_analysis_provider/__init__.py index ab48822783..1f8354f521 100644 --- a/music_assistant/providers/_demo_audio_analysis_provider/__init__.py +++ b/music_assistant/providers/_demo_audio_analysis_provider/__init__.py @@ -96,7 +96,7 @@ class DemoAudioAnalysisProvider(AudioAnalysisProvider): The base class uses this to skip re-analysis of already-analyzed tracks. - If you have other conditions that determine whether to skip an analysis, implement them in _start_analysis and return False to reject the session. - - Store results via self.mass.streams.audio_analysis.set_audio_analysis() in _finalize. + - Return AudioAnalysisData from _finalize; the base class persists it. """ # Increment this when your analysis algorithm changes significantly. @@ -162,26 +162,18 @@ async def process_pcm_chunk( ) async def _finalize(self, session_id: str) -> None: - """Finalize analysis and store results. + """Finalize analysis and return the result. Called when the track has finished buffering and all chunks have been - processed. This is where a real provider would compute final results - and store them via self.mass.streams.audio_analysis.set_audio_analysis(). + processed. 
A real provider would compute its final result and return it + as an AudioAnalysisData; the base class then persists it via + set_audio_analysis() and fires post_analysis(). Return None to skip both. - Example of storing results (not done in this demo):: + Example return (not done in this demo):: from music_assistant.models.audio_analysis import AudioAnalysisData - session = self._sessions[session_id] - analysis = AudioAnalysisData(bpm=120.0, duration=180.5) - await self.mass.streams.audio_analysis.set_audio_analysis( - item_id=session.streamdetails.item_id, - provider_instance_id_or_domain=session.streamdetails.provider, - aa_provider_domain=self.domain, - analysis=analysis, - analysis_version=self.analysis_version, - media_type=session.streamdetails.media_type, - ) + return AudioAnalysisData(bpm=120.0, duration=180.5) Note: The base class's finalize() method calls this, then cleans up the session from self._sessions automatically. Do not override finalize() diff --git a/music_assistant/providers/loudness_analysis/provider.py b/music_assistant/providers/loudness_analysis/provider.py index 20a060604a..8f68ffba3b 100644 --- a/music_assistant/providers/loudness_analysis/provider.py +++ b/music_assistant/providers/loudness_analysis/provider.py @@ -83,30 +83,6 @@ async def cancel(self, session_id: str) -> None: await data.ffmpeg.close() await super().cancel(session_id) - async def analyze_file(self, streamdetails: StreamDetails) -> AudioAnalysisData | None: - """Run ebur128 directly on a local audio file and return the measurement.""" - if not isinstance(streamdetails.path, str) or not streamdetails.path: - return None - metrics = await _run_ebur128_on_file(streamdetails.path, streamdetails.audio_format) - if metrics is None: - return None - loudness, loudness_range, true_peak = metrics - if loudness is None or loudness <= LOUDNESS_MEASUREMENT_MIN_LUFS: - return None - if self.config.get_value(CONF_WRITE_REPLAYGAIN_TAGS): - # ReplayGain 2.0: track_gain_db = -18 - 
loudness_lufs - track_gain_db = -18.0 - loudness - ok = await write_replaygain_track_gain(streamdetails.path, track_gain_db) - if ok: - self.logger.debug( - "Background loudness: wrote ReplayGain tag to %s", streamdetails.path - ) - return AudioAnalysisData( - loudness_integrated=round(loudness, 2), - loudness_range=round(loudness_range, 2) if loudness_range is not None else None, - true_peak=round(true_peak, 2) if true_peak is not None else None, - ) - async def _start_analysis( self, session_id: str, @@ -131,11 +107,11 @@ async def _start_analysis( self._data[session_id] = LoudnessSessionData(ffmpeg=ffmpeg) return True - async def _finalize(self, session_id: str) -> None: + async def _finalize(self, session_id: str) -> AudioAnalysisData | None: """Persist the final loudness measurement for the session.""" data = self._data.pop(session_id, None) if not data: - return + return None await self._send_eof(data) try: @@ -143,14 +119,14 @@ async def _finalize(self, session_id: str) -> None: except Exception as err: self.logger.debug("Loudness analysis ffmpeg failed: %s", err) await data.ffmpeg.close() - return + return None metrics = _parse_ebur128_metrics(data.ffmpeg.log_history) await data.ffmpeg.close() session = self._sessions.get(session_id) if session is None: - return + return None if data.chunks_received < MIN_DURATION_SECONDS: self.logger.debug( @@ -160,7 +136,7 @@ async def _finalize(self, session_id: str) -> None: data.chunks_received, MIN_DURATION_SECONDS, ) - return + return None loudness, loudness_range, true_peak = metrics if loudness is None: @@ -168,7 +144,7 @@ async def _finalize(self, session_id: str) -> None: "Could not determine loudness of %s from buffer analysis", session.streamdetails.uri, ) - return + return None if loudness <= LOUDNESS_MEASUREMENT_MIN_LUFS: # ebur128 reports ~-70 LUFS on near-silence / cancelled streams, @@ -180,21 +156,13 @@ async def _finalize(self, session_id: str) -> None: loudness, LOUDNESS_MEASUREMENT_MIN_LUFS, ) - return 
+ return None analysis = AudioAnalysisData( loudness_integrated=round(loudness, 2), loudness_range=round(loudness_range, 2) if loudness_range is not None else None, true_peak=round(true_peak, 2) if true_peak is not None else None, ) - await self.mass.streams.audio_analysis.set_audio_analysis( - item_id=session.streamdetails.item_id, - provider_instance_id_or_domain=session.streamdetails.provider, - aa_provider_domain=self.domain, - analysis=analysis, - analysis_version=self.analysis_version, - media_type=session.streamdetails.media_type, - ) # update in-memory streamdetails so subsequent seeks use the measurement # instead of dynamic normalization session.streamdetails.loudness = round(loudness, 2) @@ -205,6 +173,29 @@ async def _finalize(self, session_id: str) -> None: loudness_range, true_peak, ) + return analysis + + async def post_analysis( + self, + streamdetails: StreamDetails, + analysis: AudioAnalysisData, + ) -> None: + """Write the ReplayGain track-gain tag back to the source file when configured.""" + if not isinstance(streamdetails.path, str) or not streamdetails.path: + return + if not self.config.get_value(CONF_WRITE_REPLAYGAIN_TAGS): + return + if analysis.loudness_integrated is None: + return + # ReplayGain 2.0: track_gain_db = -18 - loudness_lufs + track_gain_db = -18.0 - analysis.loudness_integrated + ok = await write_replaygain_track_gain(streamdetails.path, track_gain_db) + if ok: + self.logger.debug( + "Wrote ReplayGain tag to %s (gain=%.2f dB)", + streamdetails.path, + track_gain_db, + ) async def _send_eof(self, data: LoudnessSessionData) -> None: """Signal end-of-input to the session's ffmpeg process (idempotent).""" @@ -234,23 +225,3 @@ def _match_float(pattern: re.Pattern[str], text: str) -> float | None: return float(match.group(1)) except ValueError: return None - - -async def _run_ebur128_on_file( - file_path: str, audio_format: AudioFormat -) -> tuple[float | None, float | None, float | None] | None: - """Run ebur128 on a local audio 
file and return the (I, LRA, TP) tuple.""" - try: - async with FFMpeg( - audio_input=file_path, - input_format=audio_format, - output_format=audio_format, - audio_output="NULL", - filter_params=["ebur128=framelog=verbose"], - collect_log_history=True, - loglevel="info", - ) as ffmpeg: - await ffmpeg.wait() - return _parse_ebur128_metrics(ffmpeg.log_history) - except Exception: - return None diff --git a/music_assistant/providers/smart_fades/provider.py b/music_assistant/providers/smart_fades/provider.py index 5865cabd7b..3b2e275967 100644 --- a/music_assistant/providers/smart_fades/provider.py +++ b/music_assistant/providers/smart_fades/provider.py @@ -142,10 +142,7 @@ async def process_pcm_chunk( await self._process_block(data) async def cancel(self, session_id: str) -> None: - """Cancel a beat tracking session. - - :param session_id: The analysis session ID. - """ + """Cancel a beat tracking session.""" data = self._data.pop(session_id, None) if data: data.pcm_buffer.clear() @@ -189,11 +186,11 @@ async def _start_analysis( self.logger.debug("Started beat tracking session %s", session_id) return True - async def _finalize(self, session_id: str) -> None: + async def _finalize(self, session_id: str) -> AudioAnalysisData | None: """Finalize beat tracking and store results.""" data = self._data.pop(session_id, None) if not data: - return + return None # Flush remaining buffered PCM if data.pcm_samples: @@ -205,7 +202,7 @@ async def _finalize(self, session_id: str) -> None: data.beats_feature_blocks.append(final_feats) if not data.beats_feature_blocks: - return + return None feats = np.concatenate(data.beats_feature_blocks, axis=0) duration = data.total_pcm_samples / ANALYSIS_SAMPLE_RATE @@ -220,7 +217,7 @@ async def _finalize(self, session_id: str) -> None: beats, downbeats = await asyncio.to_thread(self._infer_beat_timings, feats) if len(beats) < 2: self.logger.debug("Not enough beats detected, skipping storage") - return + return None key, mode = await 
asyncio.to_thread(self._infer_musical_key, all_vqt) bpm = calculate_overall_bpm(beats) @@ -259,23 +256,15 @@ async def _finalize(self, session_id: str) -> None: mode=mode, ) - await self.mass.streams.audio_analysis.set_audio_analysis( - data.item_id, - data.provider, - self.domain, - analysis, - analysis_version=self.analysis_version, - media_type=MediaType.TRACK, - ) - self.logger.debug( - "Stored beat analysis for %s: BPM=%.1f, %d beats, %d downbeats, key=%s", + "Beat analysis for %s: BPM=%.1f, %d beats, %d downbeats, key=%s", data.item_id, bpm, len(beats), len(downbeats), f"{key} {mode}" if key else "unknown", ) + return analysis async def _process_block(self, data: SmartFadesData, *, last: bool = False) -> None: """Resample accumulated PCM buffer and extract features.""" @@ -312,12 +301,7 @@ async def _process_block(self, data: SmartFadesData, *, last: bool = False) -> N def _compute_energy_and_spectral_centroids( self, pcm_22k: np.ndarray, data: SmartFadesData ) -> None: - """Compute fine-resolution RMS energy and spectral centroid for a block. - - RMS is computed in 100ms windows (~2205 samples at 22050 Hz). - Spectral centroid is computed per hop frame (~43 frames/s). - Both are interpolated to the fixed 1800-bin output representation in _finalize. 
- """ + """Compute fine-resolution RMS energy and spectral centroid for a block.""" sr = ANALYSIS_SAMPLE_RATE # RMS energy in 100ms windows, including partial final window window_samples = sr // 10 # 2205 samples = 100ms diff --git a/tests/controllers/__init__.py b/tests/controllers/__init__.py new file mode 100644 index 0000000000..e161d0afaa --- /dev/null +++ b/tests/controllers/__init__.py @@ -0,0 +1 @@ +"""Tests for Music Assistant controllers.""" diff --git a/tests/controllers/streams/__init__.py b/tests/controllers/streams/__init__.py new file mode 100644 index 0000000000..6cc1fdbb53 --- /dev/null +++ b/tests/controllers/streams/__init__.py @@ -0,0 +1 @@ +"""Tests for Music Assistant stream controllers.""" diff --git a/tests/controllers/streams/test_audio_analysis.py b/tests/controllers/streams/test_audio_analysis.py new file mode 100644 index 0000000000..800b81137e --- /dev/null +++ b/tests/controllers/streams/test_audio_analysis.py @@ -0,0 +1,484 @@ +"""Tests for the AudioAnalysisController.""" + +from __future__ import annotations + +import asyncio +from collections.abc import AsyncGenerator +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from music_assistant_models.enums import ContentType, StreamType +from music_assistant_models.media_items import AudioFormat + +import music_assistant.controllers.streams.audio_analysis as audio_analysis_mod +from music_assistant.controllers.streams.audio_analysis import AudioAnalysisController +from music_assistant.models.audio_analysis_provider import AudioAnalysisProvider + + +@pytest.mark.asyncio +async def test_distribute_chunk_calls_all_providers() -> None: + """_distribute_chunk must invoke process_pcm_chunk on every active provider.""" + controller = _make_controller() + session_key = "track://provider/abc" + controller._active_sessions[session_key] = {"prov-1", "prov-2"} + + p1 = _make_aa_provider("prov-1", available=True) + p2 = _make_aa_provider("prov-2", available=True) + provider_map 
= {"prov-1": p1, "prov-2": p2} + controller.mass.get_provider = MagicMock(side_effect=provider_map.get) # type: ignore[method-assign] + + await controller._distribute_chunk(session_key, b"\x00" * 1024) + + p1.process_pcm_chunk.assert_awaited_once_with(session_key, b"\x00" * 1024) + p2.process_pcm_chunk.assert_awaited_once_with(session_key, b"\x00" * 1024) + + +@pytest.mark.asyncio +async def test_distribute_chunk_evicts_provider_on_timeout() -> None: + """A provider whose process_pcm_chunk exceeds CHUNK_PROCESS_TIMEOUT_SECONDS is evicted.""" + controller = _make_controller() + session_key = "track://provider/abc" + controller._active_sessions[session_key] = {"slow", "fast"} + + async def _hang(*_args: object, **_kwargs: object) -> None: + await asyncio.sleep(10) + + slow = _make_aa_provider("slow", available=True, process_pcm_chunk=AsyncMock(side_effect=_hang)) + fast = _make_aa_provider("fast", available=True) + provider_map = {"slow": slow, "fast": fast} + controller.mass.get_provider = MagicMock(side_effect=provider_map.get) # type: ignore[method-assign] + + with patch.object(audio_analysis_mod, "CHUNK_PROCESS_TIMEOUT_SECONDS", 0.05): + await controller._distribute_chunk(session_key, b"\x00" * 1024) + + assert "slow" not in controller._active_sessions[session_key] + assert "fast" in controller._active_sessions[session_key] + + +@pytest.mark.asyncio +async def test_distribute_chunk_evicts_provider_on_exception() -> None: + """A provider that raises in process_pcm_chunk is evicted; others continue.""" + controller = _make_controller() + session_key = "track://provider/abc" + controller._active_sessions[session_key] = {"raises", "ok"} + + raises = _make_aa_provider( + "raises", + available=True, + process_pcm_chunk=AsyncMock(side_effect=RuntimeError("boom")), + ) + ok = _make_aa_provider("ok", available=True) + provider_map = {"raises": raises, "ok": ok} + controller.mass.get_provider = MagicMock(side_effect=provider_map.get) # type: ignore[method-assign] + + await 
controller._distribute_chunk(session_key, b"\x00" * 1024) + + assert "raises" not in controller._active_sessions[session_key] + assert "ok" in controller._active_sessions[session_key] + + +@pytest.mark.asyncio +async def test_get_scan_concurrency_returns_default_on_unset() -> None: + """When the config value is unset/None, fall back to DEFAULT_BACKGROUND_SCAN_CONCURRENCY.""" + controller = _make_controller() + controller.mass.config.get_raw_core_config_value = MagicMock(return_value=None) # type: ignore[method-assign] + assert controller._get_scan_concurrency() == 1 + + +@pytest.mark.asyncio +async def test_get_scan_concurrency_clamps_to_max() -> None: + """Values above 8 are clamped to 8.""" + controller = _make_controller() + controller.mass.config.get_raw_core_config_value = MagicMock(return_value=99) # type: ignore[method-assign] + assert controller._get_scan_concurrency() == 8 + + +@pytest.mark.asyncio +async def test_get_scan_concurrency_clamps_to_min() -> None: + """Values below 1 are clamped to 1.""" + controller = _make_controller() + controller.mass.config.get_raw_core_config_value = MagicMock(return_value=0) # type: ignore[method-assign] + assert controller._get_scan_concurrency() == 1 + + +def _make_stream_mock(chunks: list[bytes]) -> object: + """Return a get_media_stream mock that yields the given chunks.""" + + async def _stream( + _streamdetails: object, _pcm_format: object, **_kwargs: object + ) -> AsyncGenerator[bytes, None]: + for chunk in chunks: + yield chunk + + return _stream + + +@pytest.mark.asyncio +async def test_background_streaming_happy_path() -> None: + """PCM chunks reach providers; session is cleaned up on clean EOF.""" + controller = _make_controller() + streamdetails = _make_streamdetails(path="/music/test.flac") + p = _make_aa_provider("p1", available=True) + p.start_analysis = AsyncMock(return_value=True) + p.finalize = AsyncMock(return_value=None) + controller.mass.get_provider = MagicMock(return_value=p) # type: 
ignore[method-assign] + + fake_chunks = [b"\x00\x01" * 512 for _ in range(5)] + controller.mass.streams.audio.get_media_stream = _make_stream_mock(fake_chunks) # type: ignore[method-assign,assignment] + + await controller._run_background_streaming_for_track(streamdetails, [p]) + + assert p.start_analysis.await_count == 1 + assert p.process_pcm_chunk.await_count == len(fake_chunks) + # _finalize_providers pops the session key before dispatching — key must be gone + assert streamdetails.uri not in controller._active_sessions + + +@pytest.mark.asyncio +async def test_background_streaming_per_track_timeout(monkeypatch: pytest.MonkeyPatch) -> None: + """Per-track timeout cancels providers and cleans up the session.""" + controller = _make_controller() + streamdetails = _make_streamdetails(path="/music/test.flac") + p = _make_aa_provider("p1", available=True) + p.start_analysis = AsyncMock(return_value=True) + + async def _hang_chunk(*_args: object, **_kwargs: object) -> None: + await asyncio.sleep(10) + + p.process_pcm_chunk = AsyncMock(side_effect=_hang_chunk) + controller.mass.get_provider = MagicMock(return_value=p) # type: ignore[method-assign] + + controller.mass.streams.audio.get_media_stream = _make_stream_mock([b"\x00" * 1024] * 50) # type: ignore[method-assign,assignment] + monkeypatch.setattr(audio_analysis_mod, "BACKGROUND_PER_TRACK_TIMEOUT_SECONDS", 0.2) + + await controller._run_background_streaming_for_track(streamdetails, [p]) + + assert streamdetails.uri not in controller._active_sessions + # Per-track timeout must be surfaced to the TasksController so the run ends + # as PARTIAL_SUCCESS with a retryable status. 
+ controller.mass.tasks.add_task_failure.assert_called_once() # type: ignore[attr-defined] + failure_args = controller.mass.tasks.add_task_failure.call_args.args # type: ignore[attr-defined] + assert failure_args[0] == audio_analysis_mod.BACKGROUND_SCAN_TASK_ID + assert "Timed out" in failure_args[1] + assert streamdetails.uri in failure_args[1] + + +@pytest.mark.asyncio +async def test_background_streaming_ffmpeg_startup_failure() -> None: + """get_media_stream failure cancels providers cleanly without raising.""" + controller = _make_controller() + streamdetails = _make_streamdetails(path="/nonexistent.flac") + p = _make_aa_provider("p1", available=True) + p.start_analysis = AsyncMock(return_value=True) + controller.mass.get_provider = MagicMock(return_value=p) # type: ignore[method-assign] + + def _failing_stream(*_args: object, **_kwargs: object) -> AsyncGenerator[bytes, None]: + raise RuntimeError("ffmpeg startup failed") + + controller.mass.streams.audio.get_media_stream = _failing_stream # type: ignore[method-assign] + + # Should not raise + await controller._run_background_streaming_for_track(streamdetails, [p]) + assert streamdetails.uri not in controller._active_sessions + # Per-track exception must be surfaced to the TasksController. 
+ controller.mass.tasks.add_task_failure.assert_called_once() # type: ignore[attr-defined] + failure_args = controller.mass.tasks.add_task_failure.call_args.args # type: ignore[attr-defined] + assert failure_args[0] == audio_analysis_mod.BACKGROUND_SCAN_TASK_ID + assert "Failed" in failure_args[1] + assert "ffmpeg startup failed" in failure_args[1] + + +def _make_streamdetails(*, path: str, item_id: str = "test-item") -> MagicMock: + sd = MagicMock() + sd.path = path + sd.uri = f"track://test/{path}" + sd.audio_format = AudioFormat( + content_type=ContentType.FLAC, + sample_rate=44100, + bit_depth=16, + channels=2, + ) + sd.item_id = item_id + sd.provider = "test-provider" + sd.media_type = MagicMock() + return sd + + +def _make_controller() -> AudioAnalysisController: + streams = MagicMock() + streams.mass = MagicMock() + streams.mass.logger.getChild.return_value = MagicMock() + return AudioAnalysisController(streams) + + +def _make_aa_provider( + instance_id: str, + *, + available: bool = True, + process_pcm_chunk: AsyncMock | None = None, +) -> MagicMock: + provider = MagicMock(spec=AudioAnalysisProvider) + provider.instance_id = instance_id + provider.available = available + provider.process_pcm_chunk = process_pcm_chunk or AsyncMock(return_value=None) + provider.cancel = AsyncMock(return_value=None) + return provider + + +@pytest.mark.asyncio +async def test_run_background_scan_uses_union_candidate_query( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """The new scan loop drives _run_background_streaming_for_track per candidate.""" + controller = _make_controller() + p1 = _make_aa_provider("prov-1", available=True) + p1.domain = "p1" + p1.start_analysis = AsyncMock(return_value=True) + monkeypatch.setattr( + controller.__class__, + "providers", + property(lambda _self: [p1]), + ) + + candidates = [ + {"item_id": "track-1", "provider_instance": "filesystem_local", "missing_domains": ["p1"]}, + {"item_id": "track-2", "provider_instance": "filesystem_local", 
"missing_domains": ["p1"]}, + ] + monkeypatch.setattr( + controller, "_find_candidates_missing_analysis", AsyncMock(return_value=candidates) + ) + + streamdetails_list = [ + _make_streamdetails(path=f"/music/{c['item_id']}.flac", item_id=str(c["item_id"])) + for c in candidates + ] + for sd in streamdetails_list: + sd.stream_type = StreamType.LOCAL_FILE + + music_prov = MagicMock() + music_prov.available = True + music_prov.get_stream_details = AsyncMock(side_effect=streamdetails_list) + music_prov.instance_id = "filesystem_local" + controller.mass.get_provider = MagicMock(return_value=music_prov) # type: ignore[method-assign] + + streaming_calls: list[str] = [] + + async def _track_streaming(streamdetails: MagicMock, _providers: object) -> None: + streaming_calls.append(streamdetails.item_id) + + monkeypatch.setattr(controller, "_run_background_streaming_for_track", _track_streaming) + + await controller._run_background_scan() + + assert sorted(streaming_calls) == ["track-1", "track-2"] + + +@pytest.mark.asyncio +async def test_find_candidates_handles_sqlite_row_without_get( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """ + _find_candidates_missing_analysis must use __getitem__ not .get() on rows. + + sqlite3.Row supports only __getitem__, not .get(). This regression test + uses a row class that lacks .get() to ensure we never reintroduce the bug. 
+ """ + controller = _make_controller() + p1 = _make_aa_provider("prov-1", available=True) + p1.domain = "loudness_analysis" + p1.available = True + monkeypatch.setattr( + controller.__class__, + "providers", + property(lambda _self: [p1]), + ) + + # Make the filesystem-providers gate succeed + fs_prov = MagicMock() + fs_prov.domain = "filesystem_local" + fs_prov.available = True + controller.mass.get_providers = MagicMock(return_value=[fs_prov]) # type: ignore[method-assign] + + class _RowNoGet: + """Mimics sqlite3.Row: __getitem__ only, no .get().""" + + def __init__(self, data: dict[str, object]) -> None: + self._d = data + + def __getitem__(self, key: str) -> object: + return self._d[key] + + # SQL filters out fully-covered tracks via NOT EXISTS + GROUP BY, so the + # rows we receive from the database only contain missing-domain pairs. + rows = [ + _RowNoGet( + { + "item_id": "track-1", + "provider_instance": "filesystem_local", + "missing_domains": "loudness_analysis", + } + ), + ] + controller.mass.music.database.get_rows_from_query = AsyncMock(return_value=rows) # type: ignore[method-assign] + + result = await controller._find_candidates_missing_analysis(["loudness_analysis"], 100) + + assert len(result) == 1 + assert result[0]["item_id"] == "track-1" + assert result[0]["missing_domains"] == ["loudness_analysis"] + + +@pytest.mark.asyncio +async def test_run_background_scan_concurrency_semaphore( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """At most CONF_BACKGROUND_SCAN_CONCURRENCY tracks run concurrently.""" + controller = _make_controller() + monkeypatch.setattr(controller, "_get_scan_concurrency", lambda: 2) + + p1 = _make_aa_provider("prov-1", available=True) + p1.domain = "p1" + p1.start_analysis = AsyncMock(return_value=True) + monkeypatch.setattr( + controller.__class__, + "providers", + property(lambda _self: [p1]), + ) + + candidates = [ + { + "item_id": f"track-{i}", + "provider_instance": "filesystem_local", + "missing_domains": ["p1"], + } 
+ for i in range(5) + ] + monkeypatch.setattr( + controller, "_find_candidates_missing_analysis", AsyncMock(return_value=candidates) + ) + + streamdetails_list = [ + _make_streamdetails(path=f"/music/{c['item_id']}.flac") for c in candidates + ] + for sd in streamdetails_list: + sd.stream_type = StreamType.LOCAL_FILE + music_prov = MagicMock() + music_prov.available = True + music_prov.get_stream_details = AsyncMock(side_effect=streamdetails_list) + music_prov.instance_id = "filesystem_local" + controller.mass.get_provider = MagicMock(return_value=music_prov) # type: ignore[method-assign] + + in_flight = 0 + max_in_flight = 0 + + async def _track_streaming(_streamdetails: MagicMock, _providers: object) -> None: + nonlocal in_flight, max_in_flight + in_flight += 1 + max_in_flight = max(max_in_flight, in_flight) + await asyncio.sleep(0.05) + in_flight -= 1 + + monkeypatch.setattr(controller, "_run_background_streaming_for_track", _track_streaming) + + await controller._run_background_scan() + + assert max_in_flight == 2 + + +@pytest.mark.asyncio +async def test_background_streaming_cancellation_cleans_up( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """CancelledError mid-track must trigger _cancel_providers and re-raise.""" + controller = _make_controller() + streamdetails = _make_streamdetails(path="/music/test.flac") + p = _make_aa_provider("p1", available=True) + controller.mass.get_provider = MagicMock(return_value=p) # type: ignore[method-assign] + + session_key = streamdetails.uri + + async def _inner_cancelled(_session_key: str, _sd: object, _providers: object) -> None: + # Simulate the inner having registered the session before being cancelled. 
+ controller._active_sessions[session_key] = {"p1"} + raise asyncio.CancelledError + + monkeypatch.setattr(controller, "_run_background_streaming_inner", _inner_cancelled) + + with pytest.raises(asyncio.CancelledError): + await controller._run_background_streaming_for_track(streamdetails, [p]) + + # Session must be popped and provider.cancel scheduled. + assert session_key not in controller._active_sessions + p.cancel.assert_called_once_with(session_key) + + +@pytest.mark.asyncio +async def test_run_background_scan_defers_past_run_budget( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Tracks past the run-budget deadline are deferred to the next run.""" + controller = _make_controller() + + p1 = _make_aa_provider("prov-1", available=True) + p1.domain = "p1" + monkeypatch.setattr( + controller.__class__, + "providers", + property(lambda _self: [p1]), + ) + + candidates = [ + { + "item_id": f"track-{i}", + "provider_instance": "filesystem_local", + "missing_domains": ["p1"], + } + for i in range(3) + ] + monkeypatch.setattr( + controller, "_find_candidates_missing_analysis", AsyncMock(return_value=candidates) + ) + + # Force budget to negative so every candidate is past deadline. + monkeypatch.setattr(audio_analysis_mod, "BACKGROUND_SCAN_RUN_BUDGET_SECONDS", -1) + + streaming_called = False + + async def _track_streaming(_sd: object, _providers: object) -> None: + nonlocal streaming_called + streaming_called = True + + monkeypatch.setattr(controller, "_run_background_streaming_for_track", _track_streaming) + + await controller._run_background_scan() + + assert not streaming_called + + +@pytest.mark.asyncio +async def test_close_drains_sessions_and_workers() -> None: + """close() cancels in-flight chunk workers and dispatches provider cancels.""" + controller = _make_controller() + + # Real asyncio task that swallows cancellation cleanly. 
+ async def _busy_worker() -> None: + try: + await asyncio.sleep(60) + except asyncio.CancelledError: + return + + worker_task = asyncio.create_task(_busy_worker()) + controller._workers["track://test/a"] = worker_task + + p = _make_aa_provider("p1", available=True) + controller.mass.get_provider = MagicMock(return_value=p) # type: ignore[method-assign] + controller._active_sessions["track://test/a"] = {"p1"} + + await controller.close() + + # Worker awaited to completion; both dicts drained. + assert worker_task.done() + assert controller._workers == {} + assert controller._active_sessions == {} + # Provider cancel scheduled with the session key. + p.cancel.assert_called_once_with("track://test/a") diff --git a/tests/core/test_audio_analysis_controller.py b/tests/core/test_audio_analysis_controller.py index 6744237a12..8721d434fb 100644 --- a/tests/core/test_audio_analysis_controller.py +++ b/tests/core/test_audio_analysis_controller.py @@ -320,14 +320,19 @@ async def test_finalized_guard_prevents_double_finalize( @pytest.mark.asyncio -async def test_provider_error_during_chunk_processing( +async def test_provider_error_during_chunk_processing_evicts_provider( controller: AudioAnalysisController, audio_buffer: AudioBuffer, mock_stream_details: MagicMock, mock_provider: MagicMock, mock_mass: MagicMock, ) -> None: - """Provider raising in process_pcm_chunk still processes remaining chunks and finalizes.""" + """Provider that raises in process_pcm_chunk is evicted from the session. + + The first chunk processes successfully. The second chunk's exception + triggers eviction. The third chunk is not delivered. The provider's + cancel hook is dispatched (replaces finalize for evicted providers). 
+ """ call_count = 0 async def _flaky_process(_session_id: str, _chunk: bytes) -> None: @@ -340,8 +345,12 @@ async def _flaky_process(_session_id: str, _chunk: bytes) -> None: await controller.start_analysis(audio_buffer, mock_stream_details) await _send_chunks(audio_buffer, 3) await _await_tasks(mock_mass) - assert call_count == 3 - mock_provider.finalize.assert_called_once() + + # Provider was called twice: chunk 1 (success), chunk 2 (raised → evicted) + assert call_count == 2 + # Evicted provider does NOT get finalize, but does get cancel + mock_provider.finalize.assert_not_called() + mock_provider.cancel.assert_called_once() @pytest.mark.asyncio @@ -393,7 +402,7 @@ def _get_prov(pid: str) -> MagicMock: mock_mass.get_provider = MagicMock(side_effect=_get_prov) with unittest.mock.patch( - "music_assistant.controllers.streams.audio_analysis.CHUNK_PROCESS_TIMEOUT", 0.1 + "music_assistant.controllers.streams.audio_analysis.CHUNK_PROCESS_TIMEOUT_SECONDS", 0.1 ): await controller.start_analysis(audio_buffer, mock_stream_details) await _send_chunks(audio_buffer, 3) @@ -423,10 +432,11 @@ async def test_provider_rejects_analysis( @pytest.mark.asyncio async def test_finalize_cleans_up_provider_sessions() -> None: - """Verify provider._sessions is cleaned up after finalize, even if _finalize raises.""" + """Verify provider._sessions is cleaned up after finalize.""" provider = MagicMock(spec=AudioAnalysisProvider) + provider.logger = MagicMock() provider._sessions = {"test_session": MagicMock(spec=AnalysisSessionData)} - provider._finalize = AsyncMock() + provider._finalize = AsyncMock(return_value=None) await AudioAnalysisProvider.finalize(provider, "test_session") @@ -467,13 +477,19 @@ async def test_provider_start_analysis_uses_media_type_for_version_gating() -> N @pytest.mark.asyncio -async def test_finalize_cleans_up_sessions_on_error() -> None: - """Verify provider._sessions is cleaned up even when _finalize raises.""" +async def 
test_finalize_swallows_finalize_exception_and_cleans_up() -> None: + """Verify provider._sessions is cleaned up even when _finalize raises. + + The finalize wrapper catches _finalize exceptions and logs ERROR; it must + not propagate them to the controller, and must still pop the session. + """ provider = MagicMock(spec=AudioAnalysisProvider) + provider.logger = MagicMock() provider._sessions = {"test_session": MagicMock(spec=AnalysisSessionData)} provider._finalize = AsyncMock(side_effect=RuntimeError("analysis failed")) - with pytest.raises(RuntimeError, match="analysis failed"): - await AudioAnalysisProvider.finalize(provider, "test_session") + # MUST NOT raise — exception is swallowed and logged + await AudioAnalysisProvider.finalize(provider, "test_session") assert "test_session" not in provider._sessions + provider.logger.error.assert_called_once() diff --git a/tests/fixtures/audio/short_test.flac b/tests/fixtures/audio/short_test.flac new file mode 100644 index 0000000000..6fb0726a95 Binary files /dev/null and b/tests/fixtures/audio/short_test.flac differ diff --git a/tests/integration/__init__.py b/tests/integration/__init__.py new file mode 100644 index 0000000000..314980ed31 --- /dev/null +++ b/tests/integration/__init__.py @@ -0,0 +1 @@ +"""Integration tests that exercise real external processes (ffmpeg).""" diff --git a/tests/integration/test_background_scan_streaming.py b/tests/integration/test_background_scan_streaming.py new file mode 100644 index 0000000000..b089dd141f --- /dev/null +++ b/tests/integration/test_background_scan_streaming.py @@ -0,0 +1,153 @@ +"""End-to-end integration test for the streaming background scan.""" + +from __future__ import annotations + +import asyncio +from pathlib import Path +from typing import TYPE_CHECKING +from unittest.mock import AsyncMock, MagicMock + +import pytest +from music_assistant_models.enums import ContentType, MediaType +from music_assistant_models.media_items import AudioFormat + +from 
music_assistant.constants import CONF_LOG_LEVEL +from music_assistant.controllers.streams.audio_analysis import AudioAnalysisController +from music_assistant.helpers.ffmpeg import FFMpeg +from music_assistant.models.audio_analysis import AudioAnalysisData +from music_assistant.providers.loudness_analysis.provider import ( + CONF_WRITE_REPLAYGAIN_TAGS, + LoudnessAnalysisProvider, +) + +if TYPE_CHECKING: + from collections.abc import AsyncGenerator + + from music_assistant_models.streamdetails import StreamDetails + +FIXTURE_AUDIO = Path(__file__).parent.parent / "fixtures" / "audio" / "short_test.flac" + + +async def _real_get_media_stream( + sd: StreamDetails, pcm_format: AudioFormat, **_kwargs: object +) -> AsyncGenerator[bytes, None]: + """Real-ffmpeg stand-in for mass.streams.audio.get_media_stream. + + Mirrors the wait-then-close pattern in audio.py:466-528 so close() doesn't + hit the SIGINT path on Windows when the process is still running. + """ + assert isinstance(sd.path, str) + proc = FFMpeg( + audio_input=sd.path, + input_format=sd.audio_format, + output_format=pcm_format, + collect_log_history=True, + ) + try: + await proc.start() + async for chunk in proc.iter_chunked(pcm_format.pcm_sample_size): + yield chunk + await proc.wait_with_timeout(5) + finally: + await proc.close() + + +@pytest.mark.skipif(not FIXTURE_AUDIO.exists(), reason="fixture FLAC missing") +async def test_streaming_background_scan_loudness_end_to_end() -> None: + """ + Drive a real LoudnessAnalysisProvider through _run_background_streaming_for_track. 
+ + Verifies: + - An audio_analysis row is written (captured via mocked set_audio_analysis) + - The analysis contains a plausible loudness value + - The session is removed from _active_sessions + """ + captured_rows: list[tuple[str, str, AudioAnalysisData]] = [] + + async def _capture_set(**kwargs: object) -> None: + captured_rows.append( + ( + str(kwargs["item_id"]), + str(kwargs["aa_provider_domain"]), + kwargs["analysis"], # type: ignore[arg-type] + ) + ) + + mass = MagicMock() + mass.streams.audio_analysis.set_audio_analysis = AsyncMock(side_effect=_capture_set) + mass.streams.audio_analysis.get_audio_analysis_version = AsyncMock(return_value=None) + + manifest = MagicMock() + manifest.domain = "loudness_analysis" + config = MagicMock() + config.instance_id = "loudness_analysis_test" + config.values = {} + # write_replaygain_tags=False so post_analysis is a no-op; log_level must be a valid string + config.get_value = MagicMock( + side_effect=lambda key: { + CONF_LOG_LEVEL: "GLOBAL", + CONF_WRITE_REPLAYGAIN_TAGS: False, + }.get(key, "GLOBAL") + ) + + provider = LoudnessAnalysisProvider(mass, manifest, config, supported_features=set()) + # domain comes from manifest.domain, instance_id from config.instance_id (both already set) + provider.available = True + + # Wire get_provider so the controller can look up the provider by instance_id + mass.get_provider = MagicMock(return_value=provider) + + # create_task schedules real asyncio tasks so finalize can run + created_tasks: list[asyncio.Task[None]] = [] + + def _create_task(coro: object) -> asyncio.Task[None]: + task: asyncio.Task[None] = asyncio.create_task(coro) # type: ignore[arg-type] + created_tasks.append(task) + return task + + mass.create_task = MagicMock(side_effect=_create_task) + mass.logger.getChild = MagicMock(return_value=MagicMock()) + + mass.streams.audio.get_media_stream = _real_get_media_stream + + streams = MagicMock() + streams.mass = mass + controller = AudioAnalysisController(streams) + + 
streamdetails = MagicMock() + streamdetails.path = str(FIXTURE_AUDIO) + streamdetails.uri = f"track://test/{FIXTURE_AUDIO.name}" + streamdetails.audio_format = AudioFormat( + content_type=ContentType.FLAC, + sample_rate=44100, + bit_depth=16, + channels=1, + ) + streamdetails.item_id = "fixture-track" + streamdetails.provider = "filesystem_local" + streamdetails.media_type = MediaType.TRACK + # None is not VolumeNormalizationMode.DISABLED, so loudness analysis proceeds + streamdetails.volume_normalization_mode = None + + await controller._run_background_streaming_for_track(streamdetails, [provider]) + + # Drain any tasks spawned by _finalize_providers (finalize is dispatched via create_task) + for _ in range(5): + await asyncio.sleep(0.05) + if created_tasks: + await asyncio.gather(*created_tasks, return_exceptions=True) + + # Assertions + assert len(captured_rows) == 1, f"expected exactly one analysis row; got {len(captured_rows)}" + item_id, aa_domain, analysis = captured_rows[0] + assert item_id == "fixture-track" + assert aa_domain == "loudness_analysis" + assert analysis.loudness_integrated is not None, "loudness_integrated must be populated" + measured = analysis.loudness_integrated + assert -70.0 < measured < 0.0, ( + f"loudness {measured} LUFS is outside the plausible range (-70, 0)" + ) + + assert streamdetails.uri not in controller._active_sessions, ( + "session must be cleaned up from _active_sessions" + ) diff --git a/tests/models/__init__.py b/tests/models/__init__.py new file mode 100644 index 0000000000..c28fa33304 --- /dev/null +++ b/tests/models/__init__.py @@ -0,0 +1 @@ +"""Tests for Music Assistant model base classes.""" diff --git a/tests/models/test_audio_analysis_provider.py b/tests/models/test_audio_analysis_provider.py new file mode 100644 index 0000000000..ea83ee763e --- /dev/null +++ b/tests/models/test_audio_analysis_provider.py @@ -0,0 +1,121 @@ +"""Tests for the AudioAnalysisProvider base class lifecycle.""" + +from __future__ import 
annotations + +from typing import TYPE_CHECKING +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from music_assistant.models.audio_analysis import AudioAnalysisData +from music_assistant.models.audio_analysis_provider import AudioAnalysisProvider + +if TYPE_CHECKING: + from music_assistant_models.media_items import AudioFormat + from music_assistant_models.streamdetails import StreamDetails + + +class _StubProvider(AudioAnalysisProvider): + """Minimal concrete provider for base-class tests.""" + + async def _start_analysis( + self, session_id: str, streamdetails: StreamDetails, audio_format: AudioFormat + ) -> bool: + return True + + async def process_pcm_chunk(self, session_id: str, pcm_chunk: bytes) -> None: + return None + + async def _finalize(self, session_id: str) -> AudioAnalysisData | None: + return None + + +def _make_provider() -> _StubProvider: + mass = MagicMock() + mass.streams.audio_analysis.get_audio_analysis_version = AsyncMock(return_value=None) + mass.streams.audio_analysis.set_audio_analysis = AsyncMock() + manifest = MagicMock() + manifest.domain = "test_stub_provider" + config = MagicMock() + config.get_value = MagicMock(return_value="GLOBAL") + return _StubProvider(mass, manifest, config, supported_features=set()) + + +@pytest.mark.asyncio +async def test_post_analysis_default_is_noop() -> None: + """Default post_analysis must be a no-op that returns None.""" + provider = _make_provider() + streamdetails = MagicMock() + analysis = AudioAnalysisData() + await provider.post_analysis(streamdetails, analysis) + + +@pytest.mark.asyncio +async def test_finalize_calls_post_analysis_when_finalize_returns_analysis() -> None: + """When _finalize returns analysis, finalize must call post_analysis with it.""" + provider = _make_provider() + streamdetails = MagicMock() + audio_format = MagicMock() + analysis = AudioAnalysisData(loudness_integrated=-14.0) + + provider._finalize = AsyncMock(return_value=analysis) # type: 
ignore[method-assign] + provider.post_analysis = AsyncMock(return_value=None) # type: ignore[method-assign] + + await provider.start_analysis("session-1", streamdetails, audio_format) + await provider.finalize("session-1") + + provider.post_analysis.assert_awaited_once_with(streamdetails, analysis) + assert "session-1" not in provider._sessions + + +@pytest.mark.asyncio +async def test_finalize_skips_post_analysis_when_finalize_returns_none() -> None: + """When _finalize returns None, post_analysis must NOT be called.""" + provider = _make_provider() + streamdetails = MagicMock() + audio_format = MagicMock() + + provider._finalize = AsyncMock(return_value=None) # type: ignore[method-assign] + provider.post_analysis = AsyncMock(return_value=None) # type: ignore[method-assign] + + await provider.start_analysis("session-2", streamdetails, audio_format) + await provider.finalize("session-2") + + provider.post_analysis.assert_not_awaited() + assert "session-2" not in provider._sessions + + +@pytest.mark.asyncio +async def test_finalize_swallows_finalize_exception_and_skips_post_analysis() -> None: + """If _finalize raises, post_analysis must not be called and the exception must not propagate.""" + provider = _make_provider() + streamdetails = MagicMock() + audio_format = MagicMock() + + provider._finalize = AsyncMock(side_effect=RuntimeError("boom")) # type: ignore[method-assign] + provider.post_analysis = AsyncMock(return_value=None) # type: ignore[method-assign] + + await provider.start_analysis("session-3", streamdetails, audio_format) + await provider.finalize("session-3") + + provider.post_analysis.assert_not_awaited() + assert "session-3" not in provider._sessions + + +@pytest.mark.asyncio +async def test_finalize_swallows_post_analysis_exception() -> None: + """post_analysis raising must be caught; the analysis row stays valid.""" + provider = _make_provider() + streamdetails = MagicMock() + audio_format = MagicMock() + analysis = AudioAnalysisData() + + 
provider._finalize = AsyncMock(return_value=analysis) # type: ignore[method-assign] + provider.post_analysis = AsyncMock(side_effect=RuntimeError("tag write failed")) # type: ignore[method-assign] + + await provider.start_analysis("session-4", streamdetails, audio_format) + # Must not raise + await provider.finalize("session-4") + + provider.post_analysis.assert_awaited_once() + assert "session-4" not in provider._sessions diff --git a/tests/providers/loudness_analysis/__init__.py b/tests/providers/loudness_analysis/__init__.py new file mode 100644 index 0000000000..b777edd3d6 --- /dev/null +++ b/tests/providers/loudness_analysis/__init__.py @@ -0,0 +1 @@ +"""Tests for the loudness_analysis provider.""" diff --git a/tests/providers/loudness_analysis/test_provider.py b/tests/providers/loudness_analysis/test_provider.py new file mode 100644 index 0000000000..40e981027d --- /dev/null +++ b/tests/providers/loudness_analysis/test_provider.py @@ -0,0 +1,206 @@ +"""Tests for the LoudnessAnalysisProvider._finalize return-value contract.""" + +from __future__ import annotations + +from unittest.mock import AsyncMock, MagicMock + +import pytest +from music_assistant_models.enums import MediaType + +from music_assistant.constants import CONF_LOG_LEVEL +from music_assistant.models.audio_analysis import AudioAnalysisData +from music_assistant.models.audio_analysis_provider import AnalysisSessionData +from music_assistant.providers.loudness_analysis.provider import ( + CONF_WRITE_REPLAYGAIN_TAGS, + MIN_DURATION_SECONDS, + LoudnessAnalysisProvider, + LoudnessSessionData, +) + + +def _make_provider() -> LoudnessAnalysisProvider: + """Construct a LoudnessAnalysisProvider with mocked MA infrastructure.""" + mass = MagicMock() + mass.streams.audio_analysis.get_audio_analysis_version = AsyncMock(return_value=None) + mass.streams.audio_analysis.set_audio_analysis = AsyncMock() + manifest = MagicMock() + manifest.domain = "loudness_analysis" + config = MagicMock() + config.instance_id = 
"loudness_analysis_test" + config.get_value = MagicMock(return_value="GLOBAL") + config.values = {} + return LoudnessAnalysisProvider(mass, manifest, config, set()) + + +def _make_session_data() -> tuple[LoudnessSessionData, MagicMock]: + """Return (LoudnessSessionData with mocked ffmpeg, streamdetails mock).""" + streamdetails = MagicMock() + streamdetails.item_id = "track-1" + streamdetails.provider = "test_provider" + streamdetails.uri = "test://track-1" + streamdetails.media_type = MediaType.TRACK + + ffmpeg = MagicMock() + ffmpeg.wait = AsyncMock() + ffmpeg.close = AsyncMock() + ffmpeg.write_eof = AsyncMock() + ffmpeg.log_history = [] + + session_data = LoudnessSessionData(ffmpeg=ffmpeg) + return session_data, streamdetails + + +@pytest.mark.asyncio +async def test_finalize_returns_analysis_on_success(monkeypatch: pytest.MonkeyPatch) -> None: + """_finalize must return AudioAnalysisData with the parsed metrics when analysis succeeds.""" + provider = _make_provider() + session_id = "test-session-success" + + session_data, streamdetails = _make_session_data() + session_data.chunks_received = MIN_DURATION_SECONDS + 1 + session_data.eof_sent = True # already sent, _send_eof will be a no-op + + provider._data[session_id] = session_data + provider._sessions[session_id] = AnalysisSessionData( + streamdetails=streamdetails, + audio_format=MagicMock(), + ) + + # Patch _parse_ebur128_metrics to return a valid result above the threshold + monkeypatch.setattr( + "music_assistant.providers.loudness_analysis.provider._parse_ebur128_metrics", + lambda _log: (-14.5, 7.2, -1.2), + ) + + result = await provider._finalize(session_id) + + assert isinstance(result, AudioAnalysisData) + assert result.loudness_integrated == -14.5 + + +@pytest.mark.asyncio +async def test_finalize_returns_none_when_insufficient_duration() -> None: + """_finalize must return None when chunks_received is below MIN_DURATION_SECONDS.""" + provider = _make_provider() + session_id = "test-session-short" + 
+ session_data, streamdetails = _make_session_data() + session_data.chunks_received = MIN_DURATION_SECONDS - 1 + session_data.eof_sent = True + + provider._data[session_id] = session_data + provider._sessions[session_id] = AnalysisSessionData( + streamdetails=streamdetails, + audio_format=MagicMock(), + ) + + result = await provider._finalize(session_id) + + assert result is None + + +# --------------------------------------------------------------------------- +# post_analysis tests +# --------------------------------------------------------------------------- + + +def _make_loudness_provider(*, write_replaygain_tags: bool) -> LoudnessAnalysisProvider: + """Construct a LoudnessAnalysisProvider with a config gated on write_replaygain_tags.""" + mass = MagicMock() + manifest = MagicMock() + manifest.domain = "loudness_analysis" + config = MagicMock() + config.instance_id = "loudness_analysis_test" + config.values = {} + config.get_value = MagicMock( + side_effect=lambda key: { + CONF_LOG_LEVEL: "GLOBAL", + CONF_WRITE_REPLAYGAIN_TAGS: write_replaygain_tags, + }.get(key, "GLOBAL") + ) + return LoudnessAnalysisProvider(mass, manifest, config, supported_features=set()) + + +@pytest.mark.asyncio +async def test_post_analysis_writes_tag_when_path_writable_and_config_on( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """post_analysis writes ReplayGain tag when path is filesystem-writable AND config is on.""" + provider = _make_loudness_provider(write_replaygain_tags=True) + streamdetails = MagicMock() + streamdetails.path = "/music/test.flac" + analysis = AudioAnalysisData(loudness_integrated=-14.0) + + write_mock = AsyncMock(return_value=True) + monkeypatch.setattr( + "music_assistant.providers.loudness_analysis.provider.write_replaygain_track_gain", + write_mock, + ) + + await provider.post_analysis(streamdetails, analysis) + + # ReplayGain 2.0: track_gain_db = -18 - loudness_lufs = -18 - (-14) = -4 + write_mock.assert_awaited_once_with("/music/test.flac", -4.0) + + 
+@pytest.mark.asyncio
+async def test_post_analysis_skips_when_path_not_writable(
+    monkeypatch: pytest.MonkeyPatch,
+) -> None:
+    """post_analysis is a no-op when streamdetails.path is not a usable path (here: None)."""
+    provider = _make_loudness_provider(write_replaygain_tags=True)
+    streamdetails = MagicMock()
+    streamdetails.path = None
+    analysis = AudioAnalysisData(loudness_integrated=-14.0)
+
+    write_mock = AsyncMock(return_value=True)
+    monkeypatch.setattr(
+        "music_assistant.providers.loudness_analysis.provider.write_replaygain_track_gain",
+        write_mock,
+    )
+
+    await provider.post_analysis(streamdetails, analysis)
+
+    write_mock.assert_not_awaited()
+
+
+@pytest.mark.asyncio
+async def test_post_analysis_skips_when_config_off(
+    monkeypatch: pytest.MonkeyPatch,
+) -> None:
+    """post_analysis is a no-op when write_replaygain_tags config is False."""
+    provider = _make_loudness_provider(write_replaygain_tags=False)
+    streamdetails = MagicMock()
+    streamdetails.path = "/music/test.flac"
+    analysis = AudioAnalysisData(loudness_integrated=-14.0)
+
+    write_mock = AsyncMock(return_value=True)
+    monkeypatch.setattr(
+        "music_assistant.providers.loudness_analysis.provider.write_replaygain_track_gain",
+        write_mock,
+    )
+
+    await provider.post_analysis(streamdetails, analysis)
+
+    write_mock.assert_not_awaited()
+
+
+@pytest.mark.asyncio
+async def test_post_analysis_skips_when_loudness_missing(
+    monkeypatch: pytest.MonkeyPatch,
+) -> None:
+    """post_analysis is a no-op when analysis.loudness_integrated is None."""
+    provider = _make_loudness_provider(write_replaygain_tags=True)
+    streamdetails = MagicMock()
+    streamdetails.path = "/music/test.flac"
+    analysis = AudioAnalysisData(loudness_integrated=None)
+
+    write_mock = AsyncMock(return_value=True)
+    monkeypatch.setattr(
+        "music_assistant.providers.loudness_analysis.provider.write_replaygain_track_gain",
+        write_mock,
+    )
+
+    await provider.post_analysis(streamdetails, analysis)
+
+    write_mock.assert_not_awaited()
diff 
--git a/tests/providers/smart_fades/test_provider.py b/tests/providers/smart_fades/test_provider.py index 9881766f22..df46da1cf0 100644 --- a/tests/providers/smart_fades/test_provider.py +++ b/tests/providers/smart_fades/test_provider.py @@ -3,12 +3,14 @@ from __future__ import annotations from pathlib import Path -from unittest.mock import AsyncMock, Mock +from unittest.mock import AsyncMock, Mock, patch +import numpy as np import pytest from music_assistant_models.enums import ContentType, MediaType from music_assistant_models.media_items import AudioFormat +from music_assistant.models.audio_analysis import AudioAnalysisData from music_assistant.providers.smart_fades.provider import SmartFadesProvider FIXTURE_DIR = Path(__file__).parent / "fixtures" @@ -147,7 +149,7 @@ async def test_beat_detection(provider: SmartFadesProvider, mass_mock: Mock) -> # Verify set_audio_analysis was called with correct data set_aa_mock = mass_mock.streams.audio_analysis.set_audio_analysis set_aa_mock.assert_awaited_once() - analysis = set_aa_mock.call_args[0][3] # 4th positional arg: analysis + analysis = set_aa_mock.call_args.kwargs["analysis"] beats = analysis.beats downbeats = analysis.downbeats @@ -205,7 +207,7 @@ async def test_extended_analysis_fields(provider: SmartFadesProvider, mass_mock: await provider.finalize(session_id) set_aa_mock = mass_mock.streams.audio_analysis.set_audio_analysis - analysis = set_aa_mock.call_args[0][3] + analysis = set_aa_mock.call_args.kwargs["analysis"] # Energy curve should be 1800 bins, normalized to [0, 1] assert analysis.rms_energy is not None @@ -240,3 +242,73 @@ async def test_extended_analysis_fields(provider: SmartFadesProvider, mass_mock: # BPM and beats should still be correct assert analysis.bpm is not None assert 115 < analysis.bpm < 125 + + +async def test_finalize_returns_audio_analysis_data(provider: SmartFadesProvider) -> None: + """Test that _finalize returns an AudioAnalysisData on success.""" + audio_format = AudioFormat( + 
content_type=ContentType.PCM_F32LE, + bit_depth=32, + sample_rate=44100, + channels=2, + ) + + stream_details = Mock() + stream_details.item_id = "test_finalize_return" + stream_details.provider = "test" + stream_details.queue_id = "test" + stream_details.uri = "test://finalize_return" + stream_details.media_type = MediaType.TRACK + + session_id = "test:test:test_finalize_return" + await provider.start_analysis(session_id, stream_details, audio_format) + + pcm_data = FIXTURE_PCM.read_bytes() + chunk_size = 44100 * 2 * 4 + offset = 0 + while offset < len(pcm_data): + chunk = pcm_data[offset : offset + chunk_size] + await provider.process_pcm_chunk(session_id, chunk) + offset += chunk_size + + result = await provider._finalize(session_id) + + assert isinstance(result, AudioAnalysisData) + + +async def test_finalize_returns_none_on_early_exit(provider: SmartFadesProvider) -> None: + """Test that _finalize returns None when not enough beats are detected.""" + audio_format = AudioFormat( + content_type=ContentType.PCM_F32LE, + bit_depth=32, + sample_rate=44100, + channels=2, + ) + + stream_details = Mock() + stream_details.item_id = "test_finalize_none" + stream_details.provider = "test" + stream_details.queue_id = "test" + stream_details.uri = "test://finalize_none" + stream_details.media_type = MediaType.TRACK + + session_id = "test:test:test_finalize_none" + await provider.start_analysis(session_id, stream_details, audio_format) + + pcm_data = FIXTURE_PCM.read_bytes() + chunk_size = 44100 * 2 * 4 + offset = 0 + while offset < len(pcm_data): + chunk = pcm_data[offset : offset + chunk_size] + await provider.process_pcm_chunk(session_id, chunk) + offset += chunk_size + + # Patch _infer_beat_timings to return fewer than 2 beats → triggers early exit + with patch.object( + provider, + "_infer_beat_timings", + return_value=(np.array([0.5]), np.array([])), + ): + result = await provider._finalize(session_id) + + assert result is None