diff --git a/.gitignore b/.gitignore index 69f03c827..ced743546 100644 --- a/.gitignore +++ b/.gitignore @@ -128,3 +128,4 @@ examples/test_streaming/ # See docs/rust_data_daemon_development.md#packaging-the-wheel. neuracore/data_daemon/bin/ neuracore/data_daemon/_native_producer*.so +neuracore/core/streaming/p2p/_native_webrtc*.so diff --git a/neuracore-dictionary.txt b/neuracore-dictionary.txt index 3eed49d11..c44e60009 100644 --- a/neuracore-dictionary.txt +++ b/neuracore-dictionary.txt @@ -1,18 +1,65 @@ +Autoencoders +CLOEXEC +CNNMLPP +CREAT +Colormaps +DONTNEED +Deque +EADDRINUSE +ENOENT +EPERM +EPIPE +ESRCH +EWOULDBLOCK +Emika +FRACT +FRANKA +Irqfx +MJFC +Mish +Mujoco +NCHW +Nonblock +PJRT +POLLIN +PRNN +Pieb +QTBD +Qbcaa +Qtoqmzp +Qwen +RDWR +REMB +Robotiq +SAVPF +SCTP +SRTP +SSRC +Safetensors +TDVB +TTYNTK +UNITREE +URDF +Unet's +VIPERX +Vaswani +Vsqo +WRONLY +ZBWs absolutises absolutising acked acks adarms -ADARMS addfinalizer addinivalue agentview -agilex aiolimiter aiortc allclose allocvec -altclip +analyzeduration +annexb asarray ascontiguousarray assimp @@ -20,7 +67,6 @@ asyncio atfork attns autocast -Autoencoders autoregressively autoset autouse @@ -28,18 +74,25 @@ avgpool axvline backpressure bfloat +bframes bgra bibtex bigendian bigym +bindgen blit blowaway bodyless Brawner broadcastable +bufsize buildtool +burstiness byteswap +cabac calcsize +cand +cands caplog capsys castagnoli @@ -57,18 +110,13 @@ chonk chrono CLIPMLP clippy -CLOEXEC closedness cmap cmaps cmeel cnnmlp -CNNMLP -CNNMLPP colab -Colab colcon -Colormaps colorspace colwise conaffinity @@ -76,13 +124,11 @@ concatenator concats condim condvar +conftest connectionstatechange conq -Conq contype -Conv convnet -CREAT ctrllimited ctrlrange cuda @@ -90,13 +136,17 @@ cudnn daemonising daemonization dashmap +datachannel +dataclass dataconfig dataformats dataid +dcep ddim -DDIM ddpm DDPM +delenv +deblocking debouncer delenv demux @@ -106,42 +156,36 @@ demuxes denoise denoised denoising -Deque +depacketize +depacketizer +depacketizes descheduled deserialise +desync devnull diaginertia dinov dinov2 distclass distilbert +dlsr docstrings doctests -DONTNEED +drainable +dtls dtype -EADDRINUSE eigenpy elementwise -elems embs -Emika -ENOENT enoexec -ENOTDIR -EPERM -EPIPE -eprintln erfinv errno -ESRCH -EWOULDBLOCK excinfo execv expanduser extr extractall extrinsics -Extrinsics faceadr facecolor facenum @@ -156,6 +200,7 @@ filtergraph finetune finetuning finfo +fippo fixturenames fogend fogstart @@ -164,10 +209,7 @@ forcelimited forcerange fourcc fovy -FRACT framecode -FRANKA -freqs frictionloss fromarray frombuffer @@ -183,7 +225,6 @@ genpts geomadr geomid geomnum -getbuffer getpgid getpid gettid @@ -195,31 +236,24 @@ Groot hdlc hookwrapper hparams -hstack huggingface hyperparameters iceoryx -idat -iend -ihdr -iiwa +icetransport +idempotent +idempotently imageio -imagenet imgmsg imgs -imread imshow inertiafromgeom inertiagrouprange inproc inspectable +interarrival ipython iquat -Irqfx itemsize -jaco -jdata -jname jntadr jntid jntnum @@ -230,10 +264,10 @@ jtps keepalive keepdim kernelspec +keyint keypoints killpg -KINOVA -kwargs +kurento lavfi layernorm lecun @@ -242,51 +276,58 @@ LEROBOT levelname libavformat libc +libclang libcs libgl libglew libglib +libjuice +libneuracore libosmesa +libpython +libsrtp +libssl libx linalg listconfig llava -Llava logdir loglevel loglik logsigmoid logvar +loopback +macroblocks makereport maxpool +maxrate maxs -mcap -MCAP +mdns meanpooling -memfd -meshgrid metadatas metafunc metas +mids +midsession mimsave -miniz -Mish mjcf -MJFC mline mocap +monkeypatched moov movflags mpng mpsa mpsc -Mujoco +msid +mtap mujocoinclude multihead -Multihead multinode multirun muxer +nack +nals nans nbconvert nbformat @@ -294,18 +335,18 @@ nbody nbytes nccl ncdata -NCHW ncon +ncwebrtc ndarray neginf +netem +netns neuraco neuracore -newbyteorder nhead nheads njoints nokey -Nonblock noprint nostdin nprocs @@ -319,13 +360,21 @@ oneshot openarm openarm_description opencv +oneshot +onicecandidate +ontrack +openarm openpi -OPENPI +openssl optim osmesa -outc outut packb +packetization +packetize +packetized +packetizer +packetizes paligemma parentbody parentid @@ -334,53 +383,58 @@ pathlib pathsep pbar pbtxt +peerconnection perceptrons -pgoa +pframe pidfile -Pieb pinnochio -PJRT -plotly +pkts +playout +playsinline +pli pointcloud -POLLIN -pooler popleft posinf +pranswer precheck preds preexec +preopen +prereqs pretrained pretraining prio +probesize proprio proprios PSNR pthread pyav pycache +pyclass pydantic +pydict pyfunction pygments +pymethods pymodule +pyo3 pyquaternion pytest -Qbcaa +qdisc qpos qposadr -QTBD -Qtoqmzp quadprog qvel -Qwen randn -randperm rawvideo +rbsp rclpy -RCVTIMEO -RDEM -RDWR reannounce -rels +recvonly +remb +reneg +renegotiation renice renicing reqwest @@ -388,14 +442,14 @@ reraises restartability resumably RETRYABLE +resumably +rfind rgba rgbd rgbs rlds rlib -Robotiq rosdep -rotvec roundoff rowwise rposition @@ -403,16 +457,15 @@ rsplit rtype rustls rustup -Safetensors -scanline -scanlines +scenecut schematypens -SCTP sdecode +sdes sdist sdpa secho sencode +sendonly seqlen serde sess @@ -421,8 +474,6 @@ setsid setuptools shadowsize siglip -Siglip -SIGLIP silu softmax solimp @@ -430,6 +481,9 @@ solref splitn sqlx squaredcos +srtp +ssrc +stap startcode startcodes staticmethod @@ -446,7 +500,6 @@ temb tensorboard testsrc tfds -TFDS tfrecord thiserror thres @@ -458,67 +511,65 @@ tmpfs tobytes tolist torchdynamo -torchserve -torq -tqdm traj +trickle triu truecolour trunc tryfirst -TTYNTK typer UFACTORY ultrawide unet -Unet -Unet's unflushed uninit unistd -UNITREE unitreeh unitreeh1 unniced unnormalization +unpaced unpackb unparseable unregisterable upserts -URDF urdfdom usefixtures userspace +usrsctp utaustin varint -Vaswani vdecode vencode vertadr vertnum -VIPERX -viser vlln -Vsqo vsync waitpid wakelock -WAKELOCK +webrtc widowx worldbody writeback -WRONLY wxyz -XARM xdata +xfail +xfailed +xfails xlabel xmat xmls xmlstr +xpass xquat xyzrgb xyzw ylabel -yourdfpy -ZBWs +zerolatency znear +baselining +multislice +ppid +pytestmark +respawn +shortenable diff --git a/neuracore/core/streaming/data_stream.py b/neuracore/core/streaming/data_stream.py index 0520ddfff..439af3188 100644 --- a/neuracore/core/streaming/data_stream.py +++ b/neuracore/core/streaming/data_stream.py @@ -123,6 +123,7 @@ def _handle_ensure_producer_channel(self, context: DataRecordingContext) -> None self._producer_channel.start_recording_session( recording_id=context.recording_id ) + self._on_producer_channel_ready() def prepare_recording_stopped(self) -> tuple[ProducerChannel | None, int]: """Mark the producer channel as stopping and return it. diff --git a/neuracore/core/streaming/p2p/provider/client_provider_stream_manager.py b/neuracore/core/streaming/p2p/provider/client_provider_stream_manager.py index 472a6c97e..48165803a 100644 --- a/neuracore/core/streaming/p2p/provider/client_provider_stream_manager.py +++ b/neuracore/core/streaming/p2p/provider/client_provider_stream_manager.py @@ -9,6 +9,7 @@ import logging from uuid import uuid4 +import numpy as np from aiohttp import ClientSession from neuracore_types import ( DataType, @@ -27,14 +28,25 @@ ) from neuracore.core.streaming.p2p.enabled_manager import EnabledManager from neuracore.core.streaming.p2p.provider.json_source import JSONSource +from neuracore.core.streaming.p2p.provider.native_broadcast_provider import ( + NativeBroadcastProvider, +) +from neuracore.core.streaming.p2p.webrtc_selection import ( + load_native, + rust_webrtc_enabled, +) from neuracore.core.utils.background_coroutine_tracker import BackgroundCoroutineTracker from .global_live_data_enabled import get_provide_live_data_enabled_manager from .provider_connection import PeerToPeerProviderConnection -from .video_source import DepthVideoSource, VideoSource +from .video_source import STREAMING_FPS, DepthVideoSource, VideoSource logger = logging.getLogger(__name__) +# How often the native pump drains signaling events and re-submits frames for +# the shared encode (matches the streaming frame rate). +NATIVE_PUMP_INTERVAL_S = 1 / STREAMING_FPS + class ClientProviderStreamManager(BaseP2PStreamManager): """Manages WebRTC streaming connections for robot sensor data. @@ -88,6 +100,67 @@ def __init__( self.tracks: list[VideoSource] = [] self.track_metadata: dict[str, RobotStreamTrack] = {} + # The new Rust stack: one Broadcaster fanned out to every browser, the + # legacy aiortc per-connection path stays untouched when the flag is off. + self._native: NativeBroadcastProvider | None = None + self._native_enabled: dict[str, EnabledManager] = {} + self._native_sources: dict[str, VideoSource] = {} + if rust_webrtc_enabled(): + self._setup_native() + + def _setup_native(self) -> None: + """Stand up the native Broadcaster and start its drain/feed pump.""" + broadcaster = load_native().Broadcaster() + self._native = NativeBroadcastProvider( + broadcaster, self._send_native_signal, browser_facing=True + ) + self.background_tracker.submit_background_coroutine(self._native_pump_loop()) + + def _send_native_signal( + self, + connection_id: str, + remote_stream_id: str, + message_type: MessageType, + data: str, + ) -> None: + """Deliver a broadcaster signaling event over the web transport.""" + self.background_tracker.submit_background_coroutine( + self.client_session.post( + f"{API_URL}/org/{self.org_id}/signalling/message/submit", + headers=self.auth.get_headers(), + json=HandshakeMessage( + connection_id=connection_id, + from_id=self.local_stream_id, + to_id=remote_stream_id, + type=message_type, + data=data, + ).model_dump(mode="json"), + ) + ) + + @staticmethod + def _native_frame(source: VideoSource) -> np.ndarray: + """Latest frame as a C-contiguous (H, W, 3) uint8 array for submit_frame.""" + return np.ascontiguousarray(source.get_last_frame().to_ndarray(format="rgb24")) + + def _native_send_json(self, label: str, message: str) -> None: + """Fan one JSON state update to every browser, guarded against teardown.""" + if self._native is None: + return + self._native.send_json(label, message) + + async def _native_pump_loop(self) -> None: + """Drain broadcaster events and re-feed the shared encode at the FPS.""" + assert self._native is not None + while self.streaming.is_enabled(): + try: + self._native.pump_once() + for mid, source in list(self._native_sources.items()): + self._native.submit_frame(mid, self._native_frame(source)) + except Exception: + logger.exception("native webrtc pump error") + await asyncio.sleep(NATIVE_PUMP_INTERVAL_S) + @property def enabled_manager(self) -> EnabledManager: """Get the enabled manager for this streaming manager. @@ -113,7 +186,16 @@ def get_video_source( if sensor_key in self.video_tracks_cache: return self.video_tracks_cache[sensor_key] - mid = str(len(self.tracks)) + # The Producer owns the mid: this one value is registered in available_robots + # (submit_track) and handed to the native Broadcaster, which uses it verbatim + # as the SDP m-line mid. So the SSE manifest the browser keys on and the + # offer's a=mid always agree, and the browser's identityForMid succeeds for + # both the initial connection and a mid-session add. The "v" prefix keeps the + # mid out of the data channel's m-line namespace (the control/json channels + # take a=mid:0); a bare index would collide with that, and libdatachannel + # would silently drop the colliding video track. self.tracks only grows, so + # the index is stable and never reused per source. + mid = f"v{len(self.tracks)}" self.background_tracker.submit_background_coroutine( self.submit_track(mid, data_type, sensor_name) ) @@ -127,6 +209,13 @@ def get_video_source( self.video_tracks_cache[sensor_key] = video_source self.tracks.append(video_source) + if self._native is not None: + # One shared encode per source, fanned out to every browser; the + # pump loop feeds frames. No per-connection aiortc track. + self._native.add_video_track(mid) + self._native_sources[mid] = video_source + return video_source + for connection in self.connections.values(): if ( connection.connection_details.video_format @@ -163,6 +252,18 @@ def get_json_source( self.event_source_cache[sensor_key] = source + if self._native is not None: + # One reliable data channel per source, fanned to every browser; each + # state update is sent to all of them. The channel label is the source + # mid, exactly as the aiortc path names it, so the SSE manifest the + # browser keys by lines up. Future browsers get the channel at bootstrap. + self._native.add_data_channel(mid) + source.add_listener( + source.STATE_UPDATED_EVENT, + lambda message, label=mid: self._native_send_json(label, message), + ) + return source + for connection in self.connections.values(): connection.add_event_source(source) @@ -221,6 +322,12 @@ async def create_new_connection( connection_id: Unique identifier for this connection connection_details: The describes the type of connection to establish. """ + if ( + self._native is not None + and connection_details.video_format == VideoFormat.WEB_RTC_NEGOTIATED + ): + return self._create_native_connection(remote_stream_id, connection_id) + connection = PeerToPeerProviderConnection( connection_id=connection_id, local_stream_id=self.local_stream_id, @@ -252,12 +359,37 @@ def on_close() -> None: return connection.enabled_manager + def _create_native_connection( + self, remote_stream_id: str, connection_id: str + ) -> EnabledManager: + """Add one browser as an answer-only consumer of the shared broadcast. + + The producer is the sole offerer; the offer rides the native pump. + """ + assert self._native is not None + enabled = EnabledManager.derived_manger(self.streaming, loop=self.loop) + self._native_enabled[connection_id] = enabled + self._native.add_consumer(connection_id, remote_stream_id) + + @enabled.on(EnabledManager.DISABLED) + def on_close() -> None: + if self._native is not None: + self._native.remove_consumer(connection_id) + self._native_enabled.pop(connection_id, None) + + return enabled + async def remove_connection(self, connection_id: str) -> None: """Remove a peer-to-peer connection. Args: connection_id: ID of the connection to end. """ + native_enabled = self._native_enabled.get(connection_id) + if native_enabled is not None: + native_enabled.disable() + return + connection = self.connections.pop(connection_id, None) if connection is None: return @@ -270,6 +402,15 @@ async def on_message(self, message: HandshakeMessage) -> None: Args: message: The message to handle. """ + if self._native is not None and message.connection_id in self._native_enabled: + if message.type == MessageType.ICE_CANDIDATE: + self._native.on_ice_candidate(message.connection_id, message.data) + elif message.type == MessageType.SDP_ANSWER: + self._native.on_answer(message.connection_id, message.data) + else: + logger.warning(f"Unsupported message type: {message.type}") + return + connection = self.connections.get(message.connection_id, None) if not connection: raise ValueError(f"Connection not found for id: {message.connection_id}") @@ -286,6 +427,11 @@ def _on_close(self) -> None: for connection in self.connections.values(): connection.close() + if self._native is not None: + self._native.close() + self._native_enabled.clear() + self._native_sources.clear() + self.connections.clear() self.tracks.clear() self.video_tracks_cache.clear() diff --git a/neuracore/core/streaming/p2p/provider/native_broadcast_provider.py b/neuracore/core/streaming/p2p/provider/native_broadcast_provider.py new file mode 100644 index 000000000..db968147c --- /dev/null +++ b/neuracore/core/streaming/p2p/provider/native_broadcast_provider.py @@ -0,0 +1,310 @@ +"""Native (Rust) broadcast provider for the web streaming path. + +This is the producer-side wiring of the new ``neuracore_webrtc`` stack for +browser consumers, gated by ``NCD_RUST_WEBRTC`` alongside the legacy aiortc +:class:`PeerToPeerProviderConnection`. The producer is the **sole offerer** and +every browser is an **answer-only** consumer, so there is no glare. + +One :class:`NativeBroadcastProvider` owns a single native ``Broadcaster`` (one +shared encode per source fanned out to N browsers). It maps the broadcaster's +drained, per-consumer signaling events onto the web signaling transport +(``send_handshake_message``) and feeds the browser's answer / candidates back in +via ``set_remote_answer(consumer_id, …)`` / ``add_remote_candidate(consumer_id, …)``. + +Lifecycle: a browser connecting is ``add_consumer``; disconnecting is +``remove_consumer``; a PR7 reconnect-needed ``on_error{where:"connection"}`` is a +remove + re-add for that consumer (the binding cannot ICE-restart — libjuice is +single-shot, upstream #130). + +The Chrome ``a=ssrc … cname`` munge is turned ON for these browser-facing +sessions (``NCD_WEBRTC_CHROME_SDP``); it is left off only for the +libdatachannel-to-libdatachannel loopback tests that assert byte-identical SDP. + +The module-level helpers (:func:`outbound_signal`, :func:`inbound_candidate`) +are pure so they can be unit-tested peer-free with a fake producer. +""" + +from __future__ import annotations + +import json +import logging +import os +from collections.abc import Callable +from dataclasses import dataclass +from typing import Protocol + +from neuracore_types import MessageType + +logger = logging.getLogger(__name__) + + +class BroadcasterProducer(Protocol): + """The subset of the native ``Broadcaster`` API this adapter drives.""" + + def add_consumer(self, consumer_id: str) -> None: + """Stand up an answer-only consumer peer connection.""" + + def remove_consumer(self, consumer_id: str) -> None: + """Tear down one consumer peer connection.""" + + def set_remote_answer(self, consumer_id: str, sdp: str) -> None: + """Apply a consumer's SDP answer.""" + + def add_remote_candidate( + self, consumer_id: str, candidate: str, mid: str | None + ) -> None: + """Apply a consumer's trickled ICE candidate.""" + + def add_video_track(self, track_id: str) -> None: + """Add a shared video source visible to all consumers.""" + + def remove_video_track(self, track_id: str) -> None: + """Remove a shared video source.""" + + def submit_frame(self, track_id: str, frame: object) -> None: + """Submit one frame for the shared encode.""" + + def add_data_channel(self, label: str, kind: str) -> None: + """Open a reliable data channel with ``label`` on every consumer.""" + + def send_json(self, label: str, payload: str) -> None: + """Send a JSON payload over ``label`` to every consumer that has it.""" + + def drain_events(self) -> list[dict]: + """Drain pending per-consumer events.""" + + def close(self) -> None: + """Tear down the broadcaster.""" + + +# Env flag that turns on the producer's Chrome-only SDP munge (bare ``a=ssrc`` -> +# ``a=ssrc … cname``). Browser sessions need it; the byte-identical loopback +# tests deliberately leave it unset. +CHROME_SDP_ENV = "NCD_WEBRTC_CHROME_SDP" + + +@dataclass(frozen=True) +class OutboundSignal: + """A signaling message the producer must deliver to one browser consumer.""" + + consumer_id: str + message_type: MessageType + data: str + + +def outbound_signal(event: dict) -> OutboundSignal | None: + """Map one drained broadcaster event to a web signaling message. + + The producer is the sole offerer, so it emits SDP **offers** and ICE + candidates; the browser replies with an answer (handled inbound). Events + without a ``consumer_id`` (a shared-encode error) or that are not signaling + map to ``None``. + + Args: + event: a single dict from ``Broadcaster.drain_events()``. + + Returns: + The message to send to that consumer, or ``None`` if not deliverable. + """ + consumer_id = event.get("consumer_id") + if consumer_id is None: + return None + kind = event.get("kind") + if kind == "on_local_description" and event.get("sdp_type") == "offer": + return OutboundSignal(consumer_id, MessageType.SDP_OFFER, event["sdp"]) + if kind == "on_local_candidate": + payload = json.dumps({ + "candidate": event["candidate"], + "sdpMid": event.get("mid"), + "sdpMLineIndex": event.get("mid"), + }) + return OutboundSignal(consumer_id, MessageType.ICE_CANDIDATE, payload) + return None + + +def inbound_candidate(data: str) -> tuple[str, str | None]: + """Parse a browser ICE candidate (``RTCIceCandidate.toJSON()``) for intake. + + Args: + data: JSON string the browser sent as the handshake payload. + + Returns: + ``(candidate, mid)`` for ``add_remote_candidate(consumer_id, …)``. + """ + content = json.loads(data) + return content["candidate"], content.get("sdpMid") + + +def needs_reconnect(event: dict) -> str | None: + """Return the consumer id that needs a reconnect, or ``None``. + + PR7 surfaces a dropped connection as ``on_error{where:"connection"}`` with + the originating ``consumer_id``; the producer recovers it by remove + re-add. + """ + if event.get("kind") == "on_error" and event.get("where") == "connection": + return event.get("consumer_id") + return None + + +@dataclass +class _Consumer: + """Bookkeeping for one browser consumer of the broadcast.""" + + connection_id: str + remote_stream_id: str + + +# (connection_id, remote_stream_id, message_type, data) -> None +SendMessage = Callable[[str, str, MessageType, str], None] + + +class NativeBroadcastProvider: + """Owns the native ``Broadcaster`` and bridges it to the web transport.""" + + def __init__( + self, + producer: BroadcasterProducer, + send_message: SendMessage, + *, + browser_facing: bool = True, + ) -> None: + """Initialize the provider. + + Args: + producer: a native ``Broadcaster`` instance (injected so tests can + pass a fake, peer-free producer). + send_message: delivers an outbound signaling message to a browser. + browser_facing: when True (the web path) the Chrome ``a=ssrc cname`` + munge is enabled for the process. + """ + self.producer = producer + self.send_message = send_message + self._consumers: dict[str, _Consumer] = {} + self._video_tracks: set[str] = set() + self._data_channels: set[str] = set() + if browser_facing: + self._enable_chrome_sdp() + + @staticmethod + def _enable_chrome_sdp() -> None: + """Turn the Chrome SDP munge on for browser sessions (idempotent).""" + os.environ.setdefault(CHROME_SDP_ENV, "1") + + # --- consumer lifecycle -------------------------------------------------- + + def add_consumer(self, connection_id: str, remote_stream_id: str) -> None: + """A browser connected: stand up an answer-only consumer for it.""" + if connection_id in self._consumers: + return + self._consumers[connection_id] = _Consumer(connection_id, remote_stream_id) + self.producer.add_consumer(connection_id) + + def remove_consumer(self, connection_id: str) -> None: + """A browser disconnected: tear down only its consumer.""" + if self._consumers.pop(connection_id, None) is None: + return + self.producer.remove_consumer(connection_id) + + def reconnect_consumer(self, connection_id: str) -> None: + """PR7 recovery: remove + re-add one consumer (no ICE restart).""" + consumer = self._consumers.get(connection_id) + if consumer is None: + return + self.producer.remove_consumer(connection_id) + self.producer.add_consumer(connection_id) + + # --- media / data sources ------------------------------------------------ + + def add_video_track(self, track_id: str) -> None: + """Register a video source visible to every (current + future) browser.""" + if track_id in self._video_tracks: + return + self._video_tracks.add(track_id) + self.producer.add_video_track(track_id) + + def remove_video_track(self, track_id: str) -> None: + """Drop a video source from the broadcast.""" + if track_id not in self._video_tracks: + return + self._video_tracks.discard(track_id) + self.producer.remove_video_track(track_id) + + def submit_frame(self, track_id: str, frame: object) -> None: + """Submit one frame for the shared encode (fanned out to all browsers).""" + self.producer.submit_frame(track_id, frame) + + def add_data_channel(self, label: str, kind: str = "reliable") -> None: + """Open a reliable data channel ``label`` on every (current + future) browser. + + Mirrors :meth:`add_video_track`: the channel is opened on each consumer by + the broadcaster (over the existing SCTP association for a live consumer, at + bootstrap for a future one), so json/joints reach every browser. + """ + if label in self._data_channels: + return + self._data_channels.add(label) + self.producer.add_data_channel(label, kind) + + def send_json(self, label: str, payload: str) -> None: + """Fan a serialised JSON payload to every browser's ``label`` channel.""" + self.producer.send_json(label, payload) + + # --- inbound signaling --------------------------------------------------- + + def on_answer(self, connection_id: str, sdp: str) -> None: + """Feed a browser's SDP answer back into its consumer.""" + if connection_id not in self._consumers: + return + self.producer.set_remote_answer(connection_id, sdp) + + def on_ice_candidate(self, connection_id: str, data: str) -> None: + """Feed a browser's trickled ICE candidate back into its consumer.""" + if connection_id not in self._consumers: + return + candidate, mid = inbound_candidate(data) + self.producer.add_remote_candidate(connection_id, candidate, mid) + + # --- outbound signaling pump --------------------------------------------- + + def pump_once(self) -> None: + """Drain the broadcaster and dispatch every pending event. + + Offers/candidates go to the browser via the web transport; a + reconnect-needed error triggers a remove + re-add for that consumer. + """ + for event in self.producer.drain_events(): + reconnect_id = needs_reconnect(event) + if reconnect_id is not None: + logger.warning( + "webrtc consumer %s needs reconnect: %s", + reconnect_id, + event.get("detail"), + ) + self.reconnect_consumer(reconnect_id) + continue + if event.get("kind") == "on_error": + logger.warning( + "webrtc error where=%s consumer=%s detail=%s", + event.get("where"), + event.get("consumer_id"), + event.get("detail"), + ) + signal = outbound_signal(event) + if signal is None: + continue + consumer = self._consumers.get(signal.consumer_id) + if consumer is None: + continue + self.send_message( + consumer.connection_id, + consumer.remote_stream_id, + signal.message_type, + signal.data, + ) + + def close(self) -> None: + """Tear down every consumer and the shared encode.""" + self._consumers.clear() + self._video_tracks.clear() + self._data_channels.clear() + self.producer.close() diff --git a/neuracore/core/streaming/p2p/recording_bridge.py b/neuracore/core/streaming/p2p/recording_bridge.py new file mode 100644 index 000000000..f4a96f906 --- /dev/null +++ b/neuracore/core/streaming/p2p/recording_bridge.py @@ -0,0 +1,113 @@ +"""Bridge the data daemon's recording entry points onto the WebRTC data plane. + +Both native send peers expose the same single channel send path, +``add_data_channel(label, kind)`` + ``send_json(label, payload)``: the 1:1 +``Producer`` (one consumer) and the ``Broadcaster`` (one shared producer fanned +out to N answer-only browsers). PR1's integration suite drives that path +directly; the live recording pipeline reaches it through the same ``log_json`` / +``log_joints`` entry points it already uses for disk recording (see +``neuracore/data_daemon/.../recording_context.py``). This adapter mirrors those +two signatures and forwards each call to ``send_json`` over a reliable-ordered +data channel, so the disk path and the streaming path converge on one send path. + +It is deliberately duck-typed against that send path (anything exposing +``add_data_channel(label, kind)`` and ``send_json(label, payload)``), so it is +**parameterised over the ``Producer`` or the ``Broadcaster`` rather than +duplicated**: over a ``Broadcaster`` each ``send_json`` fans to every browser's +channel for that label. It pulls in no heavy dependencies, so it can be wired in +at the cutover without disturbing the existing aiortc provider. One reliable +channel is opened lazily per stream; the reserved ``"control"`` label is never +used for application data. +""" + +from __future__ import annotations + +import json +from typing import Protocol + + +class _SendPath(Protocol): + """The slice of the native ``Producer`` / ``Broadcaster`` this bridge needs.""" + + def add_data_channel(self, label: str, kind: str) -> None: ... + + def send_json(self, label: str, payload: str) -> None: ... + + +class WebrtcRecordingBridge: + """Forward ``log_json`` / ``log_joints`` recording calls to ``send_json``. + + Each distinct stream maps to one reliable-ordered data channel, opened on + first use (which drives the send peer's negotiation just like any other + ``add_data_channel``). The wire payload is a small JSON envelope carrying the + capture timestamp and the sample, so a consumer can route by data type. The + same bridge serves the 1:1 ``Producer`` and the fan-out ``Broadcaster`` + interchangeably (see the module docstring). + """ + + #: Reserved by the transport for the manifest; never an application stream. + CONTROL_LABEL = "control" + + def __init__(self, producer: _SendPath) -> None: + """Bridge recording calls onto ``producer``'s ``send_json`` path. + + Args: + producer: any send peer exposing ``add_data_channel`` / + ``send_json`` — a native ``Producer`` or a ``Broadcaster``. + """ + self._producer = producer + self._open: set[str] = set() + + def _channel(self, label: str) -> str: + """Open ``label`` as a reliable channel on first use; return it.""" + if label == self.CONTROL_LABEL: + raise ValueError("'control' is reserved for the manifest transport") + if label not in self._open: + self._producer.add_data_channel(label, "reliable") + self._open.add(label) + return label + + def log_joints( + self, + data_type: str, + timestamp: float, + items: list[tuple[str, float]], + ) -> None: + """Stream a batch of ``(joint_name, value)`` samples for ``data_type``. + + Mirrors ``RecordingContext.log_joints``; routes to ``send_json`` over the + ``data_type`` channel. + """ + if not items: + return + label = self._channel(data_type) + envelope = { + "type": "joints", + "data_type": data_type, + "timestamp": timestamp, + "values": {name: value for name, value in items}, + } + self._producer.send_json(label, json.dumps(envelope)) + + def log_json( + self, + data_type: str, + name: str, + payload: bytes, + timestamp: float, + ) -> None: + """Stream one already-serialised JSON sample for the ``name`` stream. + + Mirrors ``RecordingContext.log_json``; routes to ``send_json`` over the + ``data_type/name`` channel. ``payload`` is the serialised sample as + produced by the recording path and is forwarded verbatim as text. + """ + label = self._channel(f"{data_type}/{name}") + envelope = { + "type": "json", + "data_type": data_type, + "name": name, + "timestamp": timestamp, + "payload": payload.decode("utf-8"), + } + self._producer.send_json(label, json.dumps(envelope)) diff --git a/neuracore/core/streaming/p2p/webrtc_selection.py b/neuracore/core/streaming/p2p/webrtc_selection.py new file mode 100644 index 000000000..2f02f5ddb --- /dev/null +++ b/neuracore/core/streaming/p2p/webrtc_selection.py @@ -0,0 +1,57 @@ +"""Feature-flag selection and native loader for the Rust WebRTC stack. + +Mirrors [rust_selection.py](neuracore/data_daemon/rust_selection.py): one place +for the ``NCD_RUST_WEBRTC`` environment-variable check that gates the new Rust +streaming core alongside the existing aiortc path, and for importing the +compiled PyO3 module that exposes the ``Producer``/``Consumer`` entry points. + +The new stack lives in the ``neuracore_webrtc`` Rust crate, built into the +package tree as ``neuracore.core.streaming.p2p._native_webrtc`` by +[build_wheel_artefacts.sh](rust/scripts/build_wheel_artefacts.sh). When the flag +is off (the default), nothing here is imported and the aiortc connections in +[provider/](neuracore/core/streaming/p2p/provider/) and +[consumer/](neuracore/core/streaming/p2p/consumer/) are used unchanged. + +Kept dependency-free so it can be imported without pulling in the streaming +runtime or aiortc. +""" + +from __future__ import annotations + +import os +from importlib import import_module +from types import ModuleType + +_TRUTHY_VALUES = frozenset({"1", "true", "yes", "y"}) + +_NATIVE_MODULE: ModuleType | None = None + +_NATIVE_IMPORT_HINT = ( + "neuracore.core.streaming.p2p._native_webrtc is not available. Build the " + "Rust neuracore_webrtc crate with rust/scripts/build_wheel_artefacts.sh " + "(which places the extension in the package tree), or unset NCD_RUST_WEBRTC " + "to use the legacy aiortc streaming path." +) + + +def rust_webrtc_enabled() -> bool: + """Return True when ``NCD_RUST_WEBRTC`` selects the Rust WebRTC stack.""" + return os.environ.get("NCD_RUST_WEBRTC", "").strip().lower() in _TRUTHY_VALUES + + +def load_native() -> ModuleType: + """Lazily import and cache the PyO3 WebRTC module for the process. + + Raises: + RuntimeError: if the compiled extension is not importable, with a hint + on how to build it or how to fall back to aiortc. + """ + global _NATIVE_MODULE + if _NATIVE_MODULE is None: + try: + _NATIVE_MODULE = import_module( + "neuracore.core.streaming.p2p._native_webrtc" + ) + except ImportError as error: + raise RuntimeError(_NATIVE_IMPORT_HINT) from error + return _NATIVE_MODULE diff --git a/rust/Cargo.lock b/rust/Cargo.lock index 00229f264..f440c5a1b 100644 --- a/rust/Cargo.lock +++ b/rust/Cargo.lock @@ -151,6 +151,26 @@ version = "1.8.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2af50177e190e07a26ab74f8b1efbfe2ef87da2116221318cb1c2e82baf7de06" +[[package]] +name = "bindgen" +version = "0.71.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5f58bf3d7db68cfbac37cfc485a8d711e87e064c3d0fe0435b92f7a407f9d6b3" +dependencies = [ + "bitflags", + "cexpr", + "clang-sys", + "itertools", + "log", + "prettyplease", + "proc-macro2", + "quote", + "regex", + "rustc-hash", + "shlex", + "syn", +] + [[package]] name = "bindgen" version = "0.72.1" @@ -301,6 +321,15 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c8d4a3bb8b1e0c1050499d1815f5ab16d04f0959b233085fb31653fbfc9d98f9" +[[package]] +name = "cmake" +version = "0.1.58" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c0f78a02292a74a88ac736019ab962ece0bc380e3f977bf72e376c5d78ff0678" +dependencies = [ + "cc", +] + [[package]] name = "cobs" version = "0.3.0" @@ -337,6 +366,30 @@ version = "0.8.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "773648b94d0e5d620f64f280777445740e61fe701025087ec8b57f45c791888b" +[[package]] +name = "cpp_build" +version = "0.5.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4f6fed3200ba0708c2adca5f6ed5ae202edd824bd4cbac7935a85edac9bcddce" +dependencies = [ + "cc", + "cpp_common", + "proc-macro2", + "regex", + "syn", + "unicode-xid", +] + +[[package]] +name = "cpp_common" +version = "0.5.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f7280a73ef92e18d27d2ec3005b57fe0043b51d1b506be86b0bf66f588f9857b" +dependencies = [ + "proc-macro2", + "syn", +] + [[package]] name = "cpufeatures" version = "0.2.17" @@ -478,6 +531,33 @@ dependencies = [ "tracing", ] +[[package]] +name = "datachannel" +version = "0.16.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "15faaf8ab2b10994dcc623bf0d1c243f210ad31112943211dd9b43df4977f6c2" +dependencies = [ + "datachannel-sys", + "derive_more", + "log", + "parking_lot", + "serde", + "webrtc-sdp", +] + +[[package]] +name = "datachannel-sys" +version = "0.23.0+0.23.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2adb5cdc7cbd3ec035819d1c0cd3715727a0b282ce7624242e7150602e10dd27" +dependencies = [ + "bindgen 0.71.1", + "cmake", + "cpp_build", + "once_cell", + "openssl-src", +] + [[package]] name = "deadpool" version = "0.12.3" @@ -507,6 +587,28 @@ dependencies = [ "zeroize", ] +[[package]] +name = "derive_more" +version = "2.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d751e9e49156b02b44f9c1815bcb94b984cdcc4396ecc32521c739452808b134" +dependencies = [ + "derive_more-impl", +] + +[[package]] +name = "derive_more-impl" +version = "2.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "799a97264921d8623a957f6c3b9011f3b5492f557bbb7a5a19b7fa6d06ba8dcb" +dependencies = [ + "proc-macro2", + "quote", + "rustc_version", + "syn", + "unicode-xid", +] + [[package]] name = "digest" version = "0.10.7" @@ -1308,7 +1410,7 @@ version = "0.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1db27a7ee0e2d5a3b873b16bb1edaa52d71ff0cf5f9c8f9d93d328b46d777472" dependencies = [ - "bindgen", + "bindgen 0.72.1", "cc", "iceoryx2-pal-posix", ] @@ -1319,7 +1421,7 @@ version = "0.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "047637116257281d1102490689e2d0073d94b8d5f4cda20b00b65aed35c09e6e" dependencies = [ - "bindgen", + "bindgen 0.72.1", "cc", "iceoryx2-pal-concurrency-sync", "iceoryx2-pal-configuration", @@ -1649,6 +1751,18 @@ dependencies = [ "windows-sys 0.61.2", ] +[[package]] +name = "neuracore_webrtc" +version = "0.1.0" +dependencies = [ + "datachannel", + "datachannel-sys", + "once_cell", + "pyo3", + "serde_json", + "tokio", +] + [[package]] name = "nix" version = "0.29.0" @@ -1748,6 +1862,15 @@ version = "1.70.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "384b8ab6d37215f3c5301a95a4accb5d64aa607f1fcb26a11b5303878451b4fe" +[[package]] +name = "openssl-src" +version = "300.6.1+3.6.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "46eb8fb9fb3b61ce1c0f8a026c4c1a0714d3a9e138e7fbde78753ce2babc3846" +dependencies = [ + "cc", +] + [[package]] name = "option-ext" version = "0.2.0" @@ -3351,6 +3474,16 @@ dependencies = [ "rustls-pki-types", ] +[[package]] +name = "webrtc-sdp" +version = "0.3.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a87d58624aae43577604ea137de9dcaf92793eccc4d816efad482001c2e055ca" +dependencies = [ + "log", + "url", +] + [[package]] name = "whoami" version = "1.6.1" diff --git a/rust/Cargo.toml b/rust/Cargo.toml index 0e20c8537..af0155ba4 100644 --- a/rust/Cargo.toml +++ b/rust/Cargo.toml @@ -1,5 +1,10 @@ [workspace] -members = ["data_daemon", "data_daemon_shared", "data_daemon_producer"] +members = [ + "data_daemon", + "data_daemon_shared", + "data_daemon_producer", + "neuracore_webrtc", +] resolver = "2" [workspace.package] diff --git a/rust/data_daemon/src/cloud/notifier.rs b/rust/data_daemon/src/cloud/notifier.rs new file mode 100644 index 000000000..53291ea95 --- /dev/null +++ b/rust/data_daemon/src/cloud/notifier.rs @@ -0,0 +1,327 @@ +//! Shared skeleton for the backend recording-lifecycle notifiers. +//! +//! The start / stop / cancel notifiers each POST a different backend endpoint, +//! but their machinery is identical: subscribe to the event bus, sweep any +//! recordings whose notification is pending from a previous (offline) session, +//! then POST whenever the relevant lifecycle event fires — retrying via a +//! startup sweep and on broadcast lag. This module owns that machinery once; a +//! notifier supplies only the three things that actually differ via +//! [`RecordingNotifier`]: which event(s) trigger it, which "pending" query +//! drives its recovery sweep, and the per-recording POST itself. +//! +//! Events are processed sequentially: each POST is awaited inline before the +//! next event is read, so a slow or retrying POST delays later events on the +//! same notifier and can push the broadcast channel into `Lagged`. That is +//! handled by re-running the recovery sweep (the POSTs are idempotent), which +//! is the recovery mechanism rather than a failure. +//! +//! Each `recording_*_notifier` module defines a small unit struct implementing +//! the trait plus a thin `spawn_recording_*_notifier` wrapper, so the call +//! sites (and their tests) are unchanged. + +use std::sync::Arc; + +use async_trait::async_trait; +use tokio::sync::broadcast; +use tokio::task::JoinHandle; + +use crate::api::ApiClient; +use crate::cloud::OrgIdRx; +use crate::lifecycle::signals::ShutdownSignal; +use crate::state::{ + DaemonEvent, EventBus, RecordingRow, SqliteStateStore, StateStore, StateStoreError, +}; + +/// Handle returned by every recording notifier's `spawn_*` wrapper. +pub struct NotifierHandle { + join: JoinHandle<()>, + label: &'static str, +} + +impl NotifierHandle { + /// Wait for the notifier task to exit. + pub async fn join(self) { + if let Err(error) = self.join.await { + tracing::warn!( + ?error, + notifier = self.label, + "recording notifier join failed" + ); + } + } +} + +/// Shared dependencies handed to a notifier's `notify`. +pub struct NotifierCtx { + /// State store (already `Arc`-wrapped for the spawned task). + pub store: Arc, + /// Backend HTTP client. + pub client: Arc, + /// Event bus — the start notifier publishes `RecordingCloudIdAssigned` on it. + pub bus: EventBus, + /// Live current-org receiver, read at POST time. + pub org_rx: OrgIdRx, +} + +/// One backend recording-lifecycle notifier (start / stop / cancel). +/// +/// Everything common — the spawn loop, the offline-recovery sweep, the +/// shutdown/lag handling — lives in [`spawn_notifier`]; an implementor supplies +/// only what differs. +#[async_trait] +pub trait RecordingNotifier: Send + Sync + 'static { + /// Short label used in this notifier's log lines. + fn label(&self) -> &'static str; + + /// The recording index to notify for `event`, or `None` to ignore it. + fn triggered_by(&self, event: &DaemonEvent) -> Option; + + /// Recordings whose notification is still pending — the offline-recovery + /// sweep set, run on startup and after a broadcast lag. + async fn pending( + &self, + store: &Arc, + ) -> Result, StateStoreError>; + + /// Fire the backend POST for one recording. Idempotent and self-logging: + /// the shared loop never inspects the result. + async fn notify(&self, ctx: &NotifierCtx, recording_index: i64); +} + +/// Spawn a notifier task driven by `notifier` on the current Tokio runtime. +/// +/// Sweeps pending notifications first (so recordings that finished while the +/// daemon was offline recover), then serves live bus events until shutdown. +pub fn spawn_notifier( + notifier: N, + store: SqliteStateStore, + bus: EventBus, + client: Arc, + org_rx: OrgIdRx, + mut shutdown_rx: broadcast::Receiver, +) -> NotifierHandle { + let label = notifier.label(); + let mut subscriber = bus.subscribe(); + let ctx = NotifierCtx { + store: Arc::new(store), + client, + bus, + org_rx, + }; + let join = tokio::spawn(async move { + // Recover pending notifications before serving live events. Run inside a + // `select!` against shutdown so a long sweep cannot hold up exit. + tokio::select! { + biased; + signal = shutdown_rx.recv() => { + tracing::debug!(?signal, notifier = label, "recording notifier shutting down before sweep"); + return; + } + _ = sweep(¬ifier, &ctx) => {} + } + loop { + tokio::select! { + biased; + signal = shutdown_rx.recv() => { + tracing::debug!(?signal, notifier = label, "recording notifier shutting down"); + break; + } + event = subscriber.recv() => { + match event { + Ok(event) => { + if let Some(recording_index) = notifier.triggered_by(&event) { + notifier.notify(&ctx, recording_index).await; + } + } + Err(broadcast::error::RecvError::Lagged(skipped)) => { + tracing::warn!( + skipped, + notifier = label, + "recording notifier missed bus events; re-sweeping pending notifications", + ); + sweep(¬ifier, &ctx).await; + } + Err(broadcast::error::RecvError::Closed) => { + tracing::debug!(notifier = label, "event bus closed; recording notifier exiting"); + break; + } + } + } + } + } + }); + NotifierHandle { join, label } +} + +/// Notify every recording the notifier reports as pending. +async fn sweep(notifier: &N, ctx: &NotifierCtx) { + let pending = match notifier.pending(&ctx.store).await { + Ok(rows) => rows, + Err(error) => { + tracing::warn!(%error, notifier = notifier.label(), "failed to query recordings pending notify"); + return; + } + }; + if pending.is_empty() { + return; + } + tracing::info!( + count = pending.len(), + notifier = notifier.label(), + "sweeping recordings with pending backend notify", + ); + for row in pending { + notifier.notify(ctx, row.recording_index).await; + } +} + +/// Which `/recording/*` endpoint a lifecycle notify targets. The stop and +/// cancel notifiers run the *same* guard chain (row fetch → already-notified +/// guard → cloud-id guard → org guard → `stop_timestamp_ns` guard → POST → +/// 404-as-success → mark-notified); only these per-kind bits differ. +#[derive(Clone, Copy)] +pub enum LifecycleKind { + Stop, + Cancel, +} + +impl LifecycleKind { + /// Word used in this notifier's log lines ("stop" / "cancel"). + fn action(self) -> &'static str { + match self { + LifecycleKind::Stop => "stop", + LifecycleKind::Cancel => "cancel", + } + } + + /// Whether this recording's notification has already been persisted. + fn already_notified(self, row: &RecordingRow) -> bool { + match self { + LifecycleKind::Stop => row.backend_stop_notified_at.is_some(), + LifecycleKind::Cancel => row.backend_cancel_notified_at.is_some(), + } + } +} + +/// Run the shared stop/cancel backend-notify flow for one recording. +/// +/// Idempotent and self-logging (the spawn loop never inspects the result): a +/// 404 is treated as success (the start notifier's prior-pending resolution +/// already closed the recording server-side), and a persist failure after a +/// successful POST is left for the next sweep since the POST is idempotent. +pub async fn notify_recording_lifecycle( + kind: LifecycleKind, + store: &Arc, + client: &Arc, + org_rx: &OrgIdRx, + recording_index: i64, +) { + let action = kind.action(); + let row = match store.get_recording(recording_index).await { + Ok(Some(row)) => row, + Ok(None) => { + tracing::warn!( + recording_index, + "recording row missing on {action}; skipping backend notify" + ); + return; + } + Err(error) => { + tracing::warn!(%error, recording_index, "failed to look up recording for {action} notify"); + return; + } + }; + + if kind.already_notified(&row) { + // Another path (sweep or earlier event) already notified. + return; + } + // Stop is also triggered by `RecordingCloudIdAssigned`, which can fire for a + // still-running recording; hold the POST until it has actually stopped. + // (A cancel only ever reaches here once `cancelled_at` is stamped.) + if matches!(kind, LifecycleKind::Stop) && row.stopped_at.is_none() { + return; + } + let Some(recording_id) = row.recording_id else { + // No cloud id → nothing exists server-side to act on. The sweep + // re-fires once the start notifier mints the id. + tracing::debug!( + recording_index, + "recording has no cloud id at {action} time; deferring backend notify" + ); + return; + }; + let Some(org_id) = org_rx.borrow().clone() else { + tracing::warn!( + recording_index, + recording_id, + "no current org_id configured at {action} time; skipping backend notify" + ); + return; + }; + let Some(stop_timestamp_ns) = row.stop_timestamp_ns else { + tracing::warn!( + recording_index, + recording_id, + "recording has no stop_timestamp_ns at {action} time; skipping backend notify" + ); + return; + }; + // The producer captured this as the recording window's real upper bound; + // the backend requires it (seconds) and derives the reported duration from + // it, so a late notify still reports correctly. + let end_time = stop_timestamp_ns as f64 / 1_000_000_000.0; + + let post_result = match kind { + LifecycleKind::Stop => { + client + .recording_stop(&org_id, &recording_id, end_time) + .await + } + LifecycleKind::Cancel => { + client + .recording_cancel(&org_id, &recording_id, end_time) + .await + } + }; + + let mark_result = match &post_result { + Ok(()) => mark_notified(kind, store, recording_index).await, + // 404 means the backend no longer has this recording open — the + // start-notifier's `resolve_prior_pending` already closed it. That is + // the post-condition we wanted, so record it rather than re-sweeping. + Err(error) if error.is_not_found() => mark_notified(kind, store, recording_index).await, + Err(error) => { + tracing::warn!(%error, recording_index, recording_id, "failed to notify backend of recording {action}"); + return; + } + }; + if let Err(error) = mark_result { + tracing::warn!( + %error, + recording_index, + recording_id, + "POST succeeded but persisting backend_{action}_notified_at failed; \ + the next sweep will re-post (the backend POST is idempotent)", + ); + } else { + tracing::info!( + recording_index, + recording_id, + "backend notified of recording {action}" + ); + } +} + +/// Persist the "notified" timestamp for the given lifecycle kind. +async fn mark_notified( + kind: LifecycleKind, + store: &Arc, + recording_index: i64, +) -> Result<(), StateStoreError> { + match kind { + LifecycleKind::Stop => store.mark_recording_stop_notified(recording_index).await, + LifecycleKind::Cancel => store.mark_recording_cancel_notified(recording_index).await, + } + .map(|_| ()) +} diff --git a/rust/data_daemon/src/cloud/org_watcher.rs b/rust/data_daemon/src/cloud/org_watcher.rs new file mode 100644 index 000000000..1f2cfdf9b --- /dev/null +++ b/rust/data_daemon/src/cloud/org_watcher.rs @@ -0,0 +1,190 @@ +//! Live `org_id` resolution. +//! +//! The organisation that owns a recording is no longer frozen onto the +//! recording row at creation time. Instead this module watches the +//! SDK-managed `~/.neuracore/config.json` and publishes the *current* +//! `current_org_id` into a [`watch::channel`] that every cloud coordinator +//! reads at the moment it issues a backend POST. A recording opened before +//! the org was selected therefore picks the org up as soon as it lands in +//! config — no daemon restart, and no per-recording backfill. + +use std::path::PathBuf; +use std::time::Duration; + +use tokio::sync::{broadcast, watch}; +use tokio::task::JoinHandle; +use tokio::time::{interval, MissedTickBehavior}; + +use crate::cloud::{read_org_id_from_config, read_org_id_from_config_async}; +use crate::lifecycle::signals::ShutdownSignal; + +/// Shared read handle for the current `org_id`. Cheap to clone; read the +/// current value with `org_rx.borrow().clone()`. +pub type OrgIdRx = tokio::sync::watch::Receiver>; + +/// How often the watcher re-reads the config file. Config writes are rare and +/// user-driven, and the file is tiny, so a coarse poll that re-parses each +/// tick is plenty — cheaper to reason about than mtime gating, which would +/// miss two writes landing within the same mtime granularity (e.g. `login` +/// immediately followed by `set_organization`). The per-tick read is async +/// (`tokio::fs`), so it never blocks a runtime worker. +const POLL_INTERVAL: Duration = Duration::from_secs(1); + +/// Handle for the config-file watcher task. +pub struct OrgWatcherHandle { + join: JoinHandle<()>, +} + +impl OrgWatcherHandle { + /// Wait for the watcher task to exit. + pub async fn join(self) { + if let Err(error) = self.join.await { + tracing::warn!(?error, "org watcher join failed"); + } + } +} + +/// Spawn the config-file watcher. +/// +/// Returns a [`OrgIdRx`] seeded with the org resolved at spawn time and the +/// task handle. `fallback` is the daemon-profile override (`NCD_CURRENT_ORG_ID` +/// / YAML profile) used whenever the config file has no `current_org_id`, +/// matching the launch-time resolution order. +pub fn spawn_org_watcher( + config_path: PathBuf, + fallback: Option, + mut shutdown_rx: broadcast::Receiver, +) -> (OrgIdRx, OrgWatcherHandle) { + // One-shot blocking seed is fine — it runs once before the task spawns. + let initial = read_org_id_from_config(&config_path).or_else(|| fallback.clone()); + let (org_tx, org_rx) = watch::channel(initial); + + let join = tokio::spawn(async move { + let mut ticker = interval(POLL_INTERVAL); + ticker.set_missed_tick_behavior(MissedTickBehavior::Delay); + + loop { + tokio::select! { + biased; + signal = shutdown_rx.recv() => { + tracing::debug!(?signal, "org watcher shutting down"); + break; + } + _ = ticker.tick() => { + let current = read_org_id_from_config_async(&config_path) + .await + .or_else(|| fallback.clone()); + org_tx.send_if_modified(|existing| { + if *existing == current { + false + } else { + tracing::info!( + org_id = ?current, + "config change picked up; updating current org_id" + ); + *existing = current; + true + } + }); + } + } + } + }); + + (org_rx, OrgWatcherHandle { join }) +} + +#[cfg(test)] +mod tests { + use super::*; + use std::io::Write; + use tempfile::TempDir; + use tokio::time::timeout; + + fn write_config(path: &std::path::Path, org_id: Option<&str>) { + let body = match org_id { + Some(org) => format!(r#"{{"current_org_id": "{org}"}}"#), + None => "{}".to_string(), + }; + let mut file = std::fs::File::create(path).expect("write config"); + file.write_all(body.as_bytes()).expect("write body"); + } + + #[tokio::test] + async fn seeds_initial_value_from_config() { + let dir = TempDir::new().unwrap(); + let path = dir.path().join("config.json"); + write_config(&path, Some("org-initial")); + + let (shutdown_tx, _) = broadcast::channel::(8); + let (org_rx, handle) = spawn_org_watcher(path, None, shutdown_tx.subscribe()); + assert_eq!(org_rx.borrow().as_deref(), Some("org-initial")); + + let _ = shutdown_tx.send(ShutdownSignal::Sigterm); + handle.join().await; + } + + #[tokio::test] + async fn falls_back_when_config_has_no_org() { + let dir = TempDir::new().unwrap(); + let path = dir.path().join("config.json"); + write_config(&path, None); + + let (shutdown_tx, _) = broadcast::channel::(8); + let (org_rx, handle) = spawn_org_watcher( + path, + Some("profile-org".to_string()), + shutdown_tx.subscribe(), + ); + assert_eq!(org_rx.borrow().as_deref(), Some("profile-org")); + + let _ = shutdown_tx.send(ShutdownSignal::Sigterm); + handle.join().await; + } + + #[tokio::test] + async fn corrupt_config_falls_back_without_crashing() { + // M15: a present-but-corrupt config must not crash the watcher or wipe + // the fallback org — it logs and is treated as "no org in config". + let dir = TempDir::new().unwrap(); + let path = dir.path().join("config.json"); + std::fs::write(&path, b"{ this is not valid json ").unwrap(); + + let (shutdown_tx, _) = broadcast::channel::(8); + let (org_rx, handle) = spawn_org_watcher( + path, + Some("profile-org".to_string()), + shutdown_tx.subscribe(), + ); + assert_eq!(org_rx.borrow().as_deref(), Some("profile-org")); + + let _ = shutdown_tx.send(ShutdownSignal::Sigterm); + handle.join().await; + } + + #[tokio::test] + async fn picks_up_org_written_after_launch() { + // The recording-blocking case: the daemon comes up before any org is + // selected, then the SDK writes one. The watcher must publish it + // without a restart. + let dir = TempDir::new().unwrap(); + let path = dir.path().join("config.json"); + write_config(&path, None); + + let (shutdown_tx, _) = broadcast::channel::(8); + let (mut org_rx, handle) = spawn_org_watcher(path.clone(), None, shutdown_tx.subscribe()); + assert_eq!(org_rx.borrow().as_deref(), None, "starts org-less"); + + // Select an org after launch. + write_config(&path, Some("org-late")); + + timeout(Duration::from_secs(5), org_rx.changed()) + .await + .expect("watcher must observe the config change within 5s") + .expect("sender alive"); + assert_eq!(org_rx.borrow().as_deref(), Some("org-late")); + + let _ = shutdown_tx.send(ShutdownSignal::Sigterm); + handle.join().await; + } +} diff --git a/rust/data_daemon/src/cloud/progress.rs b/rust/data_daemon/src/cloud/progress.rs new file mode 100644 index 000000000..aff928689 --- /dev/null +++ b/rust/data_daemon/src/cloud/progress.rs @@ -0,0 +1,543 @@ +//! Periodic progress reporter. +//! +//! Every [`FAST_PROGRESS_TICK`] the reporter sweeps the recordings still +//! pending a report ([`StateStore::recordings_pending_progress`] — a +//! server-side filter, so fully-settled recordings drop out of the scan) and, +//! for every stopped recording whose traces have all finished *writing* (and +//! whose `progress_reported` is still `Pending`), +//! POSTs `/org/{org}/recording/{rec}/traces-metadata` with the per-trace +//! `total_bytes` snapshot. This establishes the recording's upload +//! denominators on the backend up front — before uploads finish — so the +//! live per-trace `uploaded_bytes` stream renders as a partial-upload +//! percentage rather than a single jump to 100%. On success the recording +//! row flips to `progress_reported = 'reported'`. + +use std::collections::HashMap; +use std::sync::Arc; +use std::time::Duration; + +use tokio::sync::broadcast; +use tokio::task::JoinHandle; +use tokio::time::{interval, MissedTickBehavior}; + +use crate::api::ApiClient; +use crate::cloud::OrgIdRx; +use crate::lifecycle::signals::ShutdownSignal; +use crate::state::{ + ProgressReportStatus, RecordingRow, SqliteStateStore, StateStore, TraceRecord, TraceWriteStatus, +}; + +/// Interval between progress-report sweeps. Kept short so a newly-uploaded +/// recording doesn't sit before reporting; the sweep is cheap because +/// [`StateStore::recordings_pending_progress`] filters settled recordings out +/// server-side and every actual flush is still guarded by the +/// upload-complete check. +const FAST_PROGRESS_TICK: Duration = Duration::from_secs(2); + +/// Handle returned by [`spawn_progress_reporter`]. +pub struct ProgressReporterHandle { + join: JoinHandle<()>, +} + +impl ProgressReporterHandle { + /// Wait for the reporter task to exit. + pub async fn join(self) { + if let Err(error) = self.join.await { + tracing::warn!(?error, "progress reporter join failed"); + } + } +} + +/// Spawn the progress reporter task on the current Tokio runtime. +pub fn spawn_progress_reporter( + store: SqliteStateStore, + client: Arc, + org_rx: OrgIdRx, + mut shutdown_rx: broadcast::Receiver, +) -> ProgressReporterHandle { + let store = Arc::new(store); + let join = tokio::spawn(async move { + let mut ticker = interval(FAST_PROGRESS_TICK); + ticker.set_missed_tick_behavior(MissedTickBehavior::Delay); + + loop { + tokio::select! { + biased; + signal = shutdown_rx.recv() => { + tracing::debug!(?signal, "progress reporter shutting down"); + break; + } + _ = ticker.tick() => { + sweep_once(&store, &client, &org_rx).await; + } + } + } + }); + ProgressReporterHandle { join } +} + +async fn sweep_once(store: &Arc, client: &Arc, org_rx: &OrgIdRx) { + // Server-side filter to stopped, non-cancelled, cloud-id-assigned + // recordings that still have reporting work outstanding, so fully-settled + // recordings drop out of the sweep instead of being re-scanned (and their + // traces re-fetched) on every tick. The cancelled/stopped/cloud-id guards + // below are kept as belt-and-braces against a row racing the query. + let recordings = match store.recordings_pending_progress().await { + Ok(rows) => rows, + Err(error) => { + tracing::warn!(%error, "progress reporter could not query pending recordings"); + return; + } + }; + for recording in recordings { + if recording.stopped_at.is_none() || recording.cancelled_at.is_some() { + continue; + } + let Some(org_id) = org_rx.borrow().clone() else { + continue; + }; + // Every cloud URL needs the backend `recording_id`. A None here means + // the start notifier hasn't populated the cloud id yet — skip until it + // has (e.g. a recording made while the daemon was offline). + let Some(recording_id) = recording.recording_id.clone() else { + tracing::warn!( + recording_index = recording.recording_index, + "progress reporter skipping recording with no cloud recording_id yet" + ); + continue; + }; + let traces = match store + .list_traces_for_recording(recording.recording_index) + .await + { + Ok(rows) => rows, + Err(error) => { + tracing::warn!(%error, recording_index = recording.recording_index, "progress reporter could not list traces"); + continue; + } + }; + if traces.is_empty() { + continue; + } + report_expected_trace_count(store, client, &recording, &org_id, &recording_id, &traces) + .await; + if matches!(recording.progress_reported, ProgressReportStatus::Reported) { + continue; + } + report_progress(store, client, &recording, &org_id, &recording_id, &traces).await; + } +} + +/// Tell the backend how many traces this recording will have. Until this PUT +/// lands, the backend keeps the recording hidden from its parent dataset +/// regardless of how many trace blobs are already uploaded. Idempotent: +/// short-circuits once `expected_trace_count_reported` is non-zero. +async fn report_expected_trace_count( + store: &Arc, + client: &Arc, + recording: &RecordingRow, + org_id: &str, + recording_id: &str, + traces: &[TraceRecord], +) { + if recording.expected_trace_count_reported > 0 { + return; + } + // Wait until every trace has reached a terminal write state. Reporting + // the count too early would race the per-trace actors and risk telling + // the backend a number that excludes traces still being flushed. + if !traces.iter().all(write_status_is_terminal) { + return; + } + let count = i64::try_from(traces.len()).unwrap_or(i64::MAX); + + // Persist locally first so a transient PUT failure does not lose the + // count, and so a re-claim by the next tick sees the same value. + if let Err(error) = store + .set_expected_trace_count(recording.recording_index, count) + .await + { + tracing::warn!( + %error, + recording_index = recording.recording_index, + "failed to persist expected trace count" + ); + return; + } + + match client + .put_expected_trace_count(org_id, recording_id, count) + .await + { + Ok(()) => { + if let Err(error) = store + .mark_expected_trace_count_reported(recording.recording_index, count) + .await + { + tracing::warn!( + %error, + recording_index = recording.recording_index, + "failed to mark expected trace count as reported" + ); + return; + } + tracing::info!( + recording_index = recording.recording_index, + recording_id, + count, + "expected trace count reported" + ); + } + Err(error) => { + tracing::warn!( + %error, + recording_index = recording.recording_index, + "expected trace count PUT failed" + ); + } + } +} + +async fn report_progress( + store: &Arc, + client: &Arc, + recording: &RecordingRow, + org_id: &str, + recording_id: &str, + traces: &[TraceRecord], +) { + // Send the snapshot of per-trace sizes (`total_bytes`) as soon as every + // trace has finished *writing* — not once it has finished *uploading*. + // This establishes the recording's denominators on the backend early, so + // the live per-trace `uploaded_bytes` stream (sent via the batch-update + // endpoint) can render a partial-upload percentage. Gating on upload + // completion instead would withhold the denominators until the whole + // recording is already uploaded, collapsing progress to a single 0→100% + // jump. Failed writes are terminal too, so one bad trace can't pin the + // recording in `progress_reported = pending` forever. + if !traces.iter().all(write_status_is_terminal) { + return; + } + let trace_map: HashMap = traces + .iter() + .map(|trace| (trace.trace_id.clone(), trace.total_bytes)) + .collect(); + // Move into a Reporting state so a slow request can't be re-issued + // by the next tick. + match store + .set_progress_report_status( + recording.recording_index, + ProgressReportStatus::Pending, + ProgressReportStatus::Reporting, + ) + .await + { + Ok(Some(row)) if matches!(row.progress_reported, ProgressReportStatus::Reporting) => {} + _ => return, + } + + match client + .report_progress(org_id, recording_id, &trace_map) + .await + { + Ok(()) => { + let _ = store + .set_progress_report_status( + recording.recording_index, + ProgressReportStatus::Reporting, + ProgressReportStatus::Reported, + ) + .await; + tracing::info!( + recording_index = recording.recording_index, + recording_id, + "progress report sent" + ); + } + Err(error) => { + tracing::warn!(%error, recording_index = recording.recording_index, "progress report failed"); + let _ = store + .set_progress_report_status( + recording.recording_index, + ProgressReportStatus::Reporting, + ProgressReportStatus::Pending, + ) + .await; + } + } +} + +fn write_status_is_terminal(trace: &TraceRecord) -> bool { + matches!( + trace.write_status, + TraceWriteStatus::Written | TraceWriteStatus::Failed + ) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::api::auth::StaticAuthProvider; + use crate::api::client::ApiClientOptions; + use crate::state::store::{NewRecording, TraceUpdate}; + use crate::state::{TraceUploadStatus, TraceWriteStatus}; + use tempfile::TempDir; + use wiremock::matchers::{body_json, method, path}; + use wiremock::{Mock, MockServer, ResponseTemplate}; + + async fn open_store() -> (SqliteStateStore, TempDir) { + let dir = TempDir::new().unwrap(); + let store = SqliteStateStore::open(&dir.path().join("state.db")) + .await + .unwrap(); + (store, dir) + } + + /// Create a recording stamped with `org-1` and the given cloud + /// `recording_id` so the wiremock URL expectations resolve. Returns the + /// local `recording_index`. + async fn seed_recording(store: &SqliteStateStore, cloud_recording_id: &str) -> i64 { + let recording = store + .create_recording(NewRecording::default()) + .await + .unwrap(); + store + .mark_recording_start_notified(recording.recording_index, cloud_recording_id) + .await + .unwrap(); + recording.recording_index + } + + /// A live-org receiver fixed at `org`. The sender is leaked so the channel + /// stays open for the test's duration. + fn org_rx(org: Option<&str>) -> OrgIdRx { + let (org_tx, org_rx) = tokio::sync::watch::channel(org.map(str::to_string)); + Box::leak(Box::new(org_tx)); + org_rx + } + + fn client(server: &MockServer) -> Arc { + let auth = Arc::new(StaticAuthProvider::new("test")); + let mut options = ApiClientOptions::new(server.uri()); + options.max_backoff = Duration::from_millis(10); + Arc::new(ApiClient::new(options, auth).unwrap()) + } + + #[tokio::test] + async fn sweep_reports_count_and_progress_once_writes_settle() { + let server = MockServer::start().await; + Mock::given(method("PUT")) + .and(path("/org/org-1/recording/rec-1/expected-trace-count")) + .respond_with(ResponseTemplate::new(200)) + .expect(1) + .mount(&server) + .await; + // The progress snapshot must carry each trace's `total_bytes` (the + // upload denominator), not its `uploaded_bytes` — and it must fire as + // soon as writes settle, before uploads finish, so the backend can + // render a live percentage from the streamed byte counts. + Mock::given(method("POST")) + .and(path("/org/org-1/recording/rec-1/traces-metadata")) + .and(body_json(serde_json::json!({ + "traces": { "t-1": 100, "t-2": 200 } + }))) + .respond_with(ResponseTemplate::new(200)) + .expect(1) + .mount(&server) + .await; + + let (store, _dir) = open_store().await; + let recording_index = seed_recording(&store, "rec-1").await; + // Two traces finished writing (with known sizes) but neither has + // uploaded yet — both the expected-count PUT and the progress POST + // must fire on write completion, not upload completion. + for (trace_id, total_bytes) in [("t-1", 100), ("t-2", 200)] { + store + .create_trace(recording_index, trace_id, Some("JOINT_POSITIONS"), None) + .await + .unwrap(); + store + .update_trace( + trace_id, + TraceUpdate { + write_status: Some(TraceWriteStatus::Written), + total_bytes: Some(total_bytes), + ..TraceUpdate::default() + }, + ) + .await + .unwrap(); + } + store + .mark_recording_stopped(recording_index, 0) + .await + .unwrap(); + + let api = client(&server); + sweep_once(&Arc::new(store.clone()), &api, &org_rx(Some("org-1"))).await; + + let recording = store.get_recording(recording_index).await.unwrap().unwrap(); + assert_eq!(recording.expected_trace_count, Some(2)); + assert_eq!(recording.expected_trace_count_reported, 2); + // Progress reports once writes settle — uploads need not be done. + assert!(matches!( + recording.progress_reported, + ProgressReportStatus::Reported + )); + } + + #[tokio::test] + async fn sweep_skips_expected_count_while_writes_in_flight() { + let server = MockServer::start().await; + // No mock for the PUT — if the sweep fires it would 404 and we'd + // catch a state-change side effect via the assertion below. + let (store, _dir) = open_store().await; + let recording_index = seed_recording(&store, "rec-1").await; + store + .create_trace(recording_index, "t-1", Some("JOINT_POSITIONS"), None) + .await + .unwrap(); + store + .update_trace( + "t-1", + TraceUpdate { + write_status: Some(TraceWriteStatus::Writing), + ..TraceUpdate::default() + }, + ) + .await + .unwrap(); + store + .mark_recording_stopped(recording_index, 0) + .await + .unwrap(); + + let api = client(&server); + sweep_once(&Arc::new(store.clone()), &api, &org_rx(Some("org-1"))).await; + + let recording = store.get_recording(recording_index).await.unwrap().unwrap(); + assert_eq!(recording.expected_trace_count, None); + assert_eq!(recording.expected_trace_count_reported, 0); + } + + #[tokio::test] + async fn sweep_reports_when_one_trace_failed_and_rest_uploaded() { + // Mixed terminal state: one trace Uploaded, one trace Failed. + // The progress reporter should still POST and flip the + // recording's status — a single failure must not deadlock the + // whole recording. + let server = MockServer::start().await; + Mock::given(method("PUT")) + .and(path("/org/org-1/recording/rec-1/expected-trace-count")) + .respond_with(ResponseTemplate::new(200)) + .expect(1) + .mount(&server) + .await; + Mock::given(method("POST")) + .and(path("/org/org-1/recording/rec-1/traces-metadata")) + .respond_with(ResponseTemplate::new(200)) + .expect(1) + .mount(&server) + .await; + + let (store, _dir) = open_store().await; + let recording_index = seed_recording(&store, "rec-1").await; + store + .create_trace(recording_index, "ok", Some("JOINT_POSITIONS"), None) + .await + .unwrap(); + store + .update_trace( + "ok", + TraceUpdate { + write_status: Some(TraceWriteStatus::Written), + upload_status: Some(TraceUploadStatus::Uploaded), + bytes_uploaded: Some(7), + total_bytes: Some(7), + ..TraceUpdate::default() + }, + ) + .await + .unwrap(); + store + .create_trace(recording_index, "bad", Some("JOINT_POSITIONS"), None) + .await + .unwrap(); + store + .update_trace( + "bad", + TraceUpdate { + write_status: Some(TraceWriteStatus::Written), + upload_status: Some(TraceUploadStatus::Failed), + ..TraceUpdate::default() + }, + ) + .await + .unwrap(); + store + .mark_recording_stopped(recording_index, 0) + .await + .unwrap(); + + let api = client(&server); + sweep_once(&Arc::new(store.clone()), &api, &org_rx(Some("org-1"))).await; + + let recording = store.get_recording(recording_index).await.unwrap().unwrap(); + assert!( + matches!(recording.progress_reported, ProgressReportStatus::Reported), + "progress should be reported even when one trace failed; \ + got {:?}", + recording.progress_reported + ); + } + + #[tokio::test] + async fn sweep_marks_recording_reported_after_post() { + let server = MockServer::start().await; + Mock::given(method("PUT")) + .and(path("/org/org-1/recording/rec-1/expected-trace-count")) + .respond_with(ResponseTemplate::new(200)) + .expect(1) + .mount(&server) + .await; + Mock::given(method("POST")) + .and(path("/org/org-1/recording/rec-1/traces-metadata")) + .respond_with(ResponseTemplate::new(200)) + .expect(1) + .mount(&server) + .await; + + let (store, _dir) = open_store().await; + let recording_index = seed_recording(&store, "rec-1").await; + store + .create_trace(recording_index, "trace-1", Some("JOINT_POSITIONS"), None) + .await + .unwrap(); + store + .update_trace( + "trace-1", + TraceUpdate { + write_status: Some(TraceWriteStatus::Written), + upload_status: Some(TraceUploadStatus::Uploaded), + bytes_uploaded: Some(42), + total_bytes: Some(42), + ..TraceUpdate::default() + }, + ) + .await + .unwrap(); + store + .mark_recording_stopped(recording_index, 0) + .await + .unwrap(); + + let api = client(&server); + sweep_once(&Arc::new(store.clone()), &api, &org_rx(Some("org-1"))).await; + + let recording = store.get_recording(recording_index).await.unwrap().unwrap(); + assert!(matches!( + recording.progress_reported, + ProgressReportStatus::Reported + )); + } +} diff --git a/rust/data_daemon/src/cloud/recording_cancel_notifier.rs b/rust/data_daemon/src/cloud/recording_cancel_notifier.rs new file mode 100644 index 000000000..43b9c644e --- /dev/null +++ b/rust/data_daemon/src/cloud/recording_cancel_notifier.rs @@ -0,0 +1,347 @@ +//! Backend recording-cancel notifier. +//! +//! Subscribes to [`DaemonEvent::RecordingCancelled`] and POSTs +//! `/org/{org}/recording/cancel` (JSON body `{recording_id, end_time}`) to the +//! backend. The Python +//! SDK used to make this call inline from `nc.cancel_recording`, but that +//! required the SDK to know the cloud `recording_id` — which the thin-shipper +//! model removes. The notifier picks up the responsibility: once the local +//! cancel is stamped and the cloud id is known, it fires the POST in the +//! background with the daemon's standard retry policy. +//! +//! Recordings cancelled before `/recording/start` was ever notified (i.e. +//! `recording_id IS NULL`) have no cloud representation, so there is nothing +//! to cancel server-side; the notifier silently skips them. + +use std::sync::Arc; + +use async_trait::async_trait; +use tokio::sync::broadcast; + +use crate::api::ApiClient; +use crate::cloud::notifier::{ + notify_recording_lifecycle, spawn_notifier, LifecycleKind, NotifierCtx, NotifierHandle, + RecordingNotifier, +}; +use crate::cloud::OrgIdRx; +use crate::lifecycle::signals::ShutdownSignal; +use crate::state::{ + DaemonEvent, EventBus, RecordingRow, SqliteStateStore, StateStore, StateStoreError, +}; + +/// Notifier that POSTs `/recording/cancel` once a recording is cancelled and +/// its cloud id is known. Recordings cancelled before `/recording/start` ever +/// landed have no cloud representation, so `notify_backend` skips them. +struct CancelNotifier; + +#[async_trait] +impl RecordingNotifier for CancelNotifier { + fn label(&self) -> &'static str { + "recording-cancel" + } + + fn triggered_by(&self, event: &DaemonEvent) -> Option { + match event { + DaemonEvent::RecordingCancelled { recording_index } => Some(*recording_index), + _ => None, + } + } + + async fn pending( + &self, + store: &Arc, + ) -> Result, StateStoreError> { + store.recordings_pending_cancel_notify().await + } + + async fn notify(&self, ctx: &NotifierCtx, recording_index: i64) { + notify_recording_lifecycle( + LifecycleKind::Cancel, + &ctx.store, + &ctx.client, + &ctx.org_rx, + recording_index, + ) + .await; + } +} + +/// Spawn the recording-cancel notifier on the current Tokio runtime. +pub fn spawn_recording_cancel_notifier( + store: SqliteStateStore, + bus: EventBus, + client: Arc, + org_rx: OrgIdRx, + shutdown_rx: broadcast::Receiver, +) -> NotifierHandle { + spawn_notifier(CancelNotifier, store, bus, client, org_rx, shutdown_rx) +} + +#[cfg(test)] +mod tests { + use super::*; + + use std::time::Duration; + + use tempfile::TempDir; + use tokio::sync::broadcast; + use tokio::time::{sleep, timeout}; + use wiremock::matchers::{body_partial_json, method, path}; + use wiremock::{Mock, MockServer, ResponseTemplate}; + + use crate::api::auth::StaticAuthProvider; + use crate::api::{ApiClient, ApiClientOptions}; + use crate::state::{DaemonEvent, EventBus, NewRecording, SqliteStateStore, StateStore}; + + async fn open_store() -> (SqliteStateStore, TempDir) { + let dir = TempDir::new().expect("tempdir"); + let store = SqliteStateStore::open(&dir.path().join("state.db")) + .await + .expect("open store"); + (store, dir) + } + + fn options(base_url: String) -> ApiClientOptions { + ApiClientOptions { + base_url, + timeout: Duration::from_secs(5), + max_retries: 1, + max_backoff: Duration::from_secs(1), + } + } + + async fn seed_cancelled_recording_with_cloud_id( + store: &SqliteStateStore, + cloud_id: &str, + ) -> i64 { + let row = store + .create_recording(NewRecording { + robot_id: Some("robot-1"), + robot_instance: Some(0), + start_timestamp_ns: 0, + ..NewRecording::default() + }) + .await + .expect("create_recording"); + let index = row.recording_index; + store + .mark_recording_start_notified(index, cloud_id) + .await + .expect("mark start notified"); + store + .cancel_recording(index, 5_000_000_000) + .await + .expect("cancel"); + index + } + + /// A live-org receiver fixed at `org`. The sender is leaked so the channel + /// stays open for the test's duration. + fn org_rx(org: Option<&str>) -> OrgIdRx { + let (org_tx, org_rx) = tokio::sync::watch::channel(org.map(str::to_string)); + Box::leak(Box::new(org_tx)); + org_rx + } + + #[tokio::test] + async fn posts_backend_cancel_on_recording_cancelled_event() { + let server = MockServer::start().await; + Mock::given(method("POST")) + .and(path("/org/org-1/recording/cancel")) + .and(body_partial_json( + serde_json::json!({ "recording_id": "rec-cancel-1" }), + )) + .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!("ok"))) + .mount(&server) + .await; + + let (store, _dir) = open_store().await; + seed_cancelled_recording_with_cloud_id(&store, "rec-cancel-1").await; + + let auth = Arc::new(StaticAuthProvider::new("token-1")); + let client = Arc::new(ApiClient::new(options(server.uri()), auth).expect("client")); + let bus = EventBus::new(); + let (shutdown_tx, _) = broadcast::channel::(8); + let handle = spawn_recording_cancel_notifier( + store.clone(), + bus.clone(), + client, + org_rx(Some("org-1")), + shutdown_tx.subscribe(), + ); + + bus.publish(DaemonEvent::RecordingCancelled { recording_index: 1 }); + + timeout(Duration::from_secs(3), async { + loop { + let received = server.received_requests().await.unwrap_or_default(); + if !received.is_empty() { + break; + } + sleep(Duration::from_millis(20)).await; + } + }) + .await + .expect("expected one POST within 3s"); + + let _ = shutdown_tx.send(ShutdownSignal::Sigterm); + handle.join().await; + } + + #[tokio::test] + async fn startup_sweep_recovers_recordings_cancelled_while_offline() { + let server = MockServer::start().await; + Mock::given(method("POST")) + .and(path("/org/org-1/recording/cancel")) + .and(body_partial_json( + serde_json::json!({ "recording_id": "rec-offline-cancel" }), + )) + .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!("ok"))) + .mount(&server) + .await; + + let (store, _dir) = open_store().await; + let index = seed_cancelled_recording_with_cloud_id(&store, "rec-offline-cancel").await; + + let auth = Arc::new(StaticAuthProvider::new("token-1")); + let client = Arc::new(ApiClient::new(options(server.uri()), auth).expect("client")); + let bus = EventBus::new(); + let (shutdown_tx, _) = broadcast::channel::(8); + let handle = spawn_recording_cancel_notifier( + store.clone(), + bus, + client, + org_rx(Some("org-1")), + shutdown_tx.subscribe(), + ); + + timeout(Duration::from_secs(3), async { + loop { + let received = server.received_requests().await.unwrap_or_default(); + if !received.is_empty() { + break; + } + sleep(Duration::from_millis(20)).await; + } + }) + .await + .expect("sweep must POST within 3s"); + + timeout(Duration::from_secs(3), async { + loop { + let row = store + .get_recording(index) + .await + .expect("get") + .expect("exists"); + if row.backend_cancel_notified_at.is_some() { + break; + } + sleep(Duration::from_millis(20)).await; + } + }) + .await + .expect("backend_cancel_notified_at must be stamped within 3s"); + + let _ = shutdown_tx.send(ShutdownSignal::Sigterm); + handle.join().await; + } + + #[tokio::test] + async fn treats_backend_404_as_already_cancelled() { + // The start-notifier's `resolve_prior_pending` may have closed this + // recording on the backend first (cancel-then-start with no gap), so a + // 404 here is the desired post-condition, not a failure: the row must + // still be marked notified so the sweep stops re-posting. + let server = MockServer::start().await; + Mock::given(method("POST")) + .and(path("/org/org-1/recording/cancel")) + .respond_with( + ResponseTemplate::new(404) + .set_body_json(serde_json::json!({ "detail": "Recording not found." })), + ) + .mount(&server) + .await; + + let (store, _dir) = open_store().await; + let index = seed_cancelled_recording_with_cloud_id(&store, "rec-already-gone").await; + + let auth = Arc::new(StaticAuthProvider::new("token-1")); + let client = Arc::new(ApiClient::new(options(server.uri()), auth).expect("client")); + let bus = EventBus::new(); + let (shutdown_tx, _) = broadcast::channel::(8); + let handle = spawn_recording_cancel_notifier( + store.clone(), + bus, + client, + org_rx(Some("org-1")), + shutdown_tx.subscribe(), + ); + + timeout(Duration::from_secs(3), async { + loop { + let row = store + .get_recording(index) + .await + .expect("get") + .expect("exists"); + if row.backend_cancel_notified_at.is_some() { + break; + } + sleep(Duration::from_millis(20)).await; + } + }) + .await + .expect("a 404 must still stamp backend_cancel_notified_at within 3s"); + + let _ = shutdown_tx.send(ShutdownSignal::Sigterm); + handle.join().await; + } + + #[tokio::test] + async fn skips_notify_when_recording_has_no_cloud_id() { + let server = MockServer::start().await; + let (store, _dir) = open_store().await; + + // A recording that was cancelled before /start was ever notified. + let row = store + .create_recording(NewRecording { + robot_id: Some("robot-1"), + robot_instance: Some(0), + start_timestamp_ns: 0, + ..NewRecording::default() + }) + .await + .unwrap(); + store + .cancel_recording(row.recording_index, 5_000_000_000) + .await + .unwrap(); + + let auth = Arc::new(StaticAuthProvider::new("token-1")); + let client = Arc::new(ApiClient::new(options(server.uri()), auth).expect("client")); + let bus = EventBus::new(); + let (shutdown_tx, _) = broadcast::channel::(8); + let handle = spawn_recording_cancel_notifier( + store, + bus.clone(), + client, + org_rx(Some("org-1")), + shutdown_tx.subscribe(), + ); + + bus.publish(DaemonEvent::RecordingCancelled { + recording_index: row.recording_index, + }); + + sleep(Duration::from_millis(150)).await; + let received = server.received_requests().await.unwrap_or_default(); + assert!( + received.is_empty(), + "no backend POST expected when recording has no cloud id" + ); + + let _ = shutdown_tx.send(ShutdownSignal::Sigterm); + handle.join().await; + } +} diff --git a/rust/data_daemon/src/cloud/recording_reaper.rs b/rust/data_daemon/src/cloud/recording_reaper.rs new file mode 100644 index 000000000..916e347e9 --- /dev/null +++ b/rust/data_daemon/src/cloud/recording_reaper.rs @@ -0,0 +1,142 @@ +//! Periodic recording reaper. +//! +//! Reclaims recordings whose local copy is redundant — the daemon owns no other +//! cleanup for a recording that reaches a settled terminal state, so without +//! this task both their files and DB rows leak forever. Two shapes qualify: +//! +//! * **Stopped + fully uploaded** — every declared trace uploaded and the +//! backend fully notified (stop POSTed, expected-trace-count + per-trace +//! progress reported). The cloud holds everything. +//! * **Cancelled** — the data was discarded; once the backend cancel has been +//! notified (`backend_cancel_notified_at`) nothing local needs keeping. +//! +//! For both, the reaper deletes the on-disk recording directory and then the +//! `recordings` / `traces` rows, keeping local disk and the state DB bounded +//! over a long-running daemon's lifetime. It is the single owner of +//! cancelled-recording file removal — the cancel path no longer unlinks files. +//! +//! The uploaded gate reads the authoritative per-trace `upload_status` rows; a +//! recording with a permanently `failed` trace never satisfies it, so data that +//! did not upload is intentionally retained. The startup sweep still handles +//! partial (mid-write) recordings separately. + +use std::path::PathBuf; +use std::sync::Arc; +use std::time::Duration; + +use tokio::sync::broadcast; +use tokio::task::JoinHandle; +use tokio::time::{interval, MissedTickBehavior}; + +use crate::lifecycle::signals::ShutdownSignal; +use crate::state::{RecordingRow, SqliteStateStore, StateStore}; +use crate::storage::paths::recording_dir; + +/// Interval between reclaim sweeps. Reclamation is never latency-sensitive — +/// it only frees space already fully replicated to the cloud — so a relaxed +/// cadence keeps the scan off the hot path. +pub const RECLAIM_INTERVAL: Duration = Duration::from_secs(60); + +/// Handle returned by [`spawn_recording_reaper`]. +pub struct RecordingReaperHandle { + join: JoinHandle<()>, +} + +impl RecordingReaperHandle { + /// Wait for the reaper task to exit. + pub async fn join(self) { + if let Err(error) = self.join.await { + tracing::warn!(?error, "recording reaper join failed"); + } + } +} + +/// Spawn the recording reaper task on the current Tokio runtime. +pub fn spawn_recording_reaper( + store: SqliteStateStore, + recordings_root: Arc, + mut shutdown_rx: broadcast::Receiver, +) -> RecordingReaperHandle { + let store = Arc::new(store); + let join = tokio::spawn(async move { + let mut ticker = interval(RECLAIM_INTERVAL); + ticker.set_missed_tick_behavior(MissedTickBehavior::Delay); + + loop { + tokio::select! { + biased; + signal = shutdown_rx.recv() => { + tracing::debug!(?signal, "recording reaper shutting down"); + break; + } + _ = ticker.tick() => { + sweep_once(&store, &recordings_root).await; + } + } + } + }); + RecordingReaperHandle { join } +} + +async fn sweep_once(store: &Arc, recordings_root: &Arc) { + // Server-side filter returns *only* durably-settled, reclaimable + // recordings (cancel-notified, or stopped + fully uploaded with the + // expected trace count met). This walks neither every recording nor the + // traces of a recording wedged on a permanently-failed upload — both of + // which the old `list_recordings` + per-row trace fetch re-scanned every + // sweep, forever. + let recordings = match store.recordings_pending_reclaim().await { + Ok(rows) => rows, + Err(error) => { + tracing::warn!(%error, "recording reaper could not list reclaimable recordings"); + return; + } + }; + for recording in recordings { + reclaim(store, recordings_root, &recording).await; + } +} + +/// Remove the recording's on-disk directory, then its DB rows. Files are +/// deleted first: if the unlink fails the rows are left in place so the next +/// sweep retries rather than orphaning files with no row pointing at them. +async fn reclaim( + store: &Arc, + recordings_root: &Arc, + recording: &RecordingRow, +) { + let dir = recording_dir(recordings_root, recording.recording_index); + // `tokio::fs` so a large directory tree doesn't block a runtime worker + // (the sweep runs on the async reaper task). + match tokio::fs::remove_dir_all(&dir).await { + Ok(()) => {} + // Already gone (e.g. reclaimed on a prior sweep that crashed before the + // row delete committed) — fall through and finish removing the rows. + Err(error) if error.kind() == std::io::ErrorKind::NotFound => {} + Err(error) => { + tracing::warn!( + %error, + recording_index = recording.recording_index, + path = %dir.display(), + "recording reaper could not remove recording directory; retrying next sweep" + ); + return; + } + } + + match store + .delete_recording_cascade(recording.recording_index) + .await + { + Ok(traces_deleted) => tracing::info!( + recording_index = recording.recording_index, + traces_deleted, + "reclaimed fully-uploaded recording" + ), + Err(error) => tracing::warn!( + %error, + recording_index = recording.recording_index, + "recording reaper removed files but could not delete rows" + ), + } +} diff --git a/rust/data_daemon/src/cloud/recording_start_notifier.rs b/rust/data_daemon/src/cloud/recording_start_notifier.rs new file mode 100644 index 000000000..56e91b193 --- /dev/null +++ b/rust/data_daemon/src/cloud/recording_start_notifier.rs @@ -0,0 +1,550 @@ +//! Backend recording-start notifier. +//! +//! Subscribes to [`DaemonEvent::RecordingStarted`] and POSTs +//! `/org/{org}/recording/start` to the backend, persisting the cloud +//! `recording_id` the backend mints in response. The Python SDK used to make +//! this call inline from `nc.start_recording`, but the staging POST has a fat +//! upper tail. Doing it here means the SDK call returns as soon as the +//! producer publishes the `StartRecording` envelope, and the cloud-id mint +//! rides the daemon's standard retry policy in the background. +//! +//! The shared loop/sweep/lag semantics live in +//! [`notifier`](super::notifier); see there for how events are processed. What +//! is start-specific: the cloud `recording_id` is minted and persisted here, +//! and any prior pending recording is closed before the next start. Every +//! downstream coordinator (registration, progress, upload) waits for this id, +//! so an offline recording simply stays pending until the daemon is online and +//! `/recording/start` lands. + +use std::sync::Arc; + +use async_trait::async_trait; +use tokio::sync::broadcast; + +use crate::api::ApiClient; +use crate::cloud::notifier::{spawn_notifier, NotifierCtx, NotifierHandle, RecordingNotifier}; +use crate::cloud::OrgIdRx; +use crate::lifecycle::signals::ShutdownSignal; +use crate::state::{ + DaemonEvent, EventBus, RecordingRow, SqliteStateStore, StateStore, StateStoreError, +}; + +/// Notifier that POSTs `/recording/start` and persists the cloud `recording_id` +/// the backend mints. The cloud id is always minted here — every downstream +/// coordinator waits on it — so an offline recording stays pending until the +/// daemon is online and the start POST lands. Before opening the new recording +/// it closes any earlier still-pending recording for the same source (see +/// [`resolve_prior_pending`]). +struct StartNotifier; + +#[async_trait] +impl RecordingNotifier for StartNotifier { + fn label(&self) -> &'static str { + "recording-start" + } + + fn triggered_by(&self, event: &DaemonEvent) -> Option { + match event { + DaemonEvent::RecordingStarted { recording_index } => Some(*recording_index), + _ => None, + } + } + + async fn pending( + &self, + store: &Arc, + ) -> Result, StateStoreError> { + store.recordings_pending_start_notify().await + } + + async fn notify(&self, ctx: &NotifierCtx, recording_index: i64) { + notify_backend( + &ctx.store, + &ctx.client, + &ctx.bus, + &ctx.org_rx, + recording_index, + ) + .await; + } +} + +/// Spawn the recording-start notifier on the current Tokio runtime. +pub fn spawn_recording_start_notifier( + store: SqliteStateStore, + bus: EventBus, + client: Arc, + org_rx: OrgIdRx, + shutdown_rx: broadcast::Receiver, +) -> NotifierHandle { + spawn_notifier(StartNotifier, store, bus, client, org_rx, shutdown_rx) +} + +async fn notify_backend( + store: &Arc, + client: &Arc, + bus: &EventBus, + org_rx: &OrgIdRx, + recording_index: i64, +) { + let row = match store.get_recording(recording_index).await { + Ok(Some(row)) => row, + Ok(None) => { + tracing::warn!( + recording_index, + "recording row missing on start; skipping backend notify", + ); + return; + } + Err(error) => { + tracing::warn!( + %error, + recording_index, + "failed to look up recording for start notify", + ); + return; + } + }; + if row.recording_id.is_some() || row.backend_start_notified_at.is_some() { + // Already notified — another path handled it. + return; + } + + let Some(org_id) = org_rx.borrow().clone() else { + // No current org configured yet (not logged in / org not selected). + // Without it we can't address the POST; the next sweep retries once + // the config watcher picks up a current org. + tracing::warn!( + recording_index, + "no current org_id configured at start time; skipping backend notify", + ); + return; + }; + let Some(robot_id) = row.robot_id else { + tracing::warn!( + recording_index, + "recording has no robot_id at start time; skipping backend notify", + ); + return; + }; + let Some(dataset_id) = row.dataset_id else { + tracing::warn!( + recording_index, + "recording has no dataset_id at start time; skipping backend notify", + ); + return; + }; + let instance = row.robot_instance.unwrap_or(0); + let Some(start_timestamp_ns) = row.start_timestamp_ns else { + tracing::warn!( + recording_index, + "recording has no start_timestamp_ns at start time; skipping backend notify", + ); + return; + }; + // The producer captured this as the recording window's real lower bound; + // the backend requires it (seconds) and derives the reported duration from + // it, so a late notify (e.g. after reconnecting) still reports correctly. + let start_time = start_timestamp_ns as f64 / 1_000_000_000.0; + + // Before opening this recording server-side, close any earlier recording for + // the same source that finished locally (cancel/stop) but whose backend + // notification has not landed yet. The backend dedupes pending recordings + // per robot instance — it returns the existing pending recording instead of + // minting a new one — so a still-pending prior recording would otherwise + // hand its cloud id to this one, collapsing both into one backend recording + // (e.g. cancel-then-start with no gap). The start notifier processes + // `RecordingStarted` events in order, so the prior recording's cloud id is + // already on its row by the time we reach here. + resolve_prior_pending(store, client, &org_id, &robot_id, instance, recording_index).await; + + match client + .recording_start(&org_id, &robot_id, instance, &dataset_id, start_time) + .await + { + Ok(recording_id) => { + if let Err(error) = store + .mark_recording_start_notified(recording_index, &recording_id) + .await + { + tracing::warn!( + %error, + recording_index, + recording_id, + "POST succeeded but persisting the cloud recording_id failed; \ + the next sweep will re-post (the start notify is idempotent)", + ); + } else { + tracing::info!( + recording_index, + recording_id, + "backend notified of recording start", + ); + // The cloud id is now available. Wake any coordinator that was + // waiting on it — notably the stop notifier, for a recording + // that was stopped while offline before its start was notified. + bus.publish(DaemonEvent::RecordingCloudIdAssigned { recording_index }); + } + } + Err(error) => { + // The producer-side iceoryx2 publish has already succeeded by + // the time we get here; logging is the only available recourse + // until the next sweep retries. + tracing::warn!( + %error, + recording_index, + "failed to notify backend of recording start", + ); + } + } +} + +/// Close, on the backend, any earlier recording for `(robot_id, instance)` that +/// finished locally (cancelled or stopped) but is still pending server-side, so +/// the backend does not hand its cloud id to the next `/recording/start` for +/// this instance. See +/// [`StateStore::recordings_pending_backend_resolution_for_source`]. +async fn resolve_prior_pending( + store: &Arc, + client: &Arc, + org_id: &str, + robot_id: &str, + instance: i64, + before_index: i64, +) { + let prior = match store + .recordings_pending_backend_resolution_for_source(robot_id, instance, before_index) + .await + { + Ok(rows) => rows, + Err(error) => { + tracing::warn!( + %error, + before_index, + "failed to query prior pending recordings for source; next start may reuse a cloud id", + ); + return; + } + }; + for row in prior { + let index = row.recording_index; + let is_cancelled = row.cancelled_at.is_some(); + // Cancel and stop both report the recording's captured stop time as + // `end_time` (a cancel is a stop that discards data). Compute it before + // `recording_id` is moved out of `row`. + let end_time = row.stop_timestamp_ns.map(|ns| ns as f64 / 1_000_000_000.0); + // Defensive against the query contract: the pending-resolution query + // only returns cloud-id-assigned, stopped/cancelled rows (which always + // carry a stop timestamp), so these guards should never skip in + // practice — they just keep the extraction total. + let Some(recording_id) = row.recording_id else { + continue; + }; + let Some(end_time) = end_time else { + continue; + }; + if is_cancelled { + match client + .recording_cancel(org_id, &recording_id, end_time) + .await + { + Ok(()) => { + let _ = store.mark_recording_cancel_notified(index).await; + tracing::info!( + recording_index = index, + recording_id, + next_recording_index = before_index, + "cancelled prior pending recording on the backend before opening the next", + ); + } + Err(error) if error.is_not_found() => { + // Already closed — the cancel-notifier sweep won the race. + // The prior recording is not pending on the backend, so the + // next start cannot reuse its id; mark it notified so the + // sweep stops re-posting too. + let _ = store.mark_recording_cancel_notified(index).await; + tracing::debug!( + recording_index = index, + recording_id, + next_recording_index = before_index, + "prior pending recording already cancelled on backend (404)", + ); + } + Err(error) => { + tracing::warn!( + %error, + recording_index = index, + recording_id, + "failed to cancel prior pending recording before next start; \ + the next start may reuse its cloud id", + ); + } + } + } else { + match client.recording_stop(org_id, &recording_id, end_time).await { + Ok(()) => { + let _ = store.mark_recording_stop_notified(index).await; + tracing::info!( + recording_index = index, + recording_id, + next_recording_index = before_index, + "stopped prior pending recording on the backend before opening the next", + ); + } + Err(error) if error.is_not_found() => { + // Already closed — the stop-notifier sweep won the race. Mark + // it notified so the sweep stops re-posting too. + let _ = store.mark_recording_stop_notified(index).await; + tracing::debug!( + recording_index = index, + recording_id, + next_recording_index = before_index, + "prior pending recording already stopped on backend (404)", + ); + } + Err(error) => { + tracing::warn!( + %error, + recording_index = index, + recording_id, + "failed to stop prior pending recording before next start", + ); + } + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + use std::time::Duration; + + use tempfile::TempDir; + use tokio::sync::broadcast; + use tokio::time::{sleep, timeout}; + use wiremock::matchers::{body_partial_json, method, path}; + use wiremock::{Mock, MockServer, ResponseTemplate}; + + use crate::api::auth::StaticAuthProvider; + use crate::api::{ApiClient, ApiClientOptions}; + use crate::lifecycle::signals::ShutdownSignal; + use crate::state::{DaemonEvent, EventBus, NewRecording, SqliteStateStore, StateStore}; + + async fn open_store() -> (SqliteStateStore, TempDir) { + let dir = TempDir::new().expect("tempdir"); + let store = SqliteStateStore::open(&dir.path().join("state.db")) + .await + .expect("open store"); + (store, dir) + } + + fn options(base_url: String) -> ApiClientOptions { + ApiClientOptions { + base_url, + timeout: Duration::from_secs(5), + max_retries: 1, + max_backoff: Duration::from_secs(1), + } + } + + /// Insert a fresh recording (no cloud id yet) and return its local index. + async fn seed_recording(store: &SqliteStateStore) -> i64 { + store + .create_recording(NewRecording { + robot_id: Some("robot-1"), + robot_instance: Some(7), + dataset_id: Some("ds-1"), + start_timestamp_ns: 1_700_000_000_000_000_000, + }) + .await + .expect("create recording") + .recording_index + } + + /// A live-org receiver fixed at `org`. The sender is leaked so the channel + /// stays open for the test's duration. + fn org_rx(org: Option<&str>) -> OrgIdRx { + let (org_tx, org_rx) = tokio::sync::watch::channel(org.map(str::to_string)); + Box::leak(Box::new(org_tx)); + org_rx + } + + fn start_ok_mock(recording_id: &'static str) -> wiremock::Mock { + Mock::given(method("POST")) + .and(path("/org/org-1/recording/start")) + .respond_with( + ResponseTemplate::new(200).set_body_json(serde_json::json!({ "id": recording_id })), + ) + } + + #[tokio::test] + async fn posts_backend_start_on_recording_started_event() { + let server = MockServer::start().await; + start_ok_mock("cloud-rec-1").mount(&server).await; + + let (store, _dir) = open_store().await; + let index = seed_recording(&store).await; + + let auth = Arc::new(StaticAuthProvider::new("token-1")); + let client = Arc::new(ApiClient::new(options(server.uri()), auth).expect("client")); + + let bus = EventBus::new(); + let (shutdown_tx, _) = broadcast::channel::(8); + let handle = spawn_recording_start_notifier( + store.clone(), + bus.clone(), + client, + org_rx(Some("org-1")), + shutdown_tx.subscribe(), + ); + + bus.publish(DaemonEvent::RecordingStarted { + recording_index: index, + }); + + // The cloud id lands on the row once the POST round-trips. + timeout(Duration::from_secs(3), async { + loop { + let row = store + .get_recording(index) + .await + .expect("get") + .expect("exists"); + if row.recording_id.is_some() { + assert_eq!(row.recording_id.as_deref(), Some("cloud-rec-1")); + assert!(row.backend_start_notified_at.is_some()); + break; + } + sleep(Duration::from_millis(20)).await; + } + }) + .await + .expect("cloud recording_id must be persisted within 3s"); + + let _ = shutdown_tx.send(ShutdownSignal::Sigterm); + handle.join().await; + } + + #[tokio::test] + async fn cancels_prior_pending_recording_before_opening_the_next() { + // Cancel-then-start (no gap) for one source: the prior recording was + // cancelled before its cloud id was notified, so it is still pending on + // the backend. Opening the next recording must cancel it FIRST, so the + // backend mints a fresh id instead of handing back the cancelled one + // (which would collapse both recordings into one cloud recording). + let server = MockServer::start().await; + Mock::given(method("POST")) + .and(path("/org/org-1/recording/cancel")) + .and(body_partial_json( + serde_json::json!({ "recording_id": "cloud-cancelled-A" }), + )) + .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!("ok"))) + .mount(&server) + .await; + start_ok_mock("cloud-fresh-B").mount(&server).await; + + let (store, _dir) = open_store().await; + // Prior recording A (same source): start-notified, then cancelled, with + // its backend cancel still pending. + let prior = seed_recording(&store).await; + store + .mark_recording_start_notified(prior, "cloud-cancelled-A") + .await + .expect("mark start notified"); + store + .cancel_recording(prior, 5_000_000_000) + .await + .expect("cancel"); + // The next recording B for the same source. + let next = seed_recording(&store).await; + + let auth = Arc::new(StaticAuthProvider::new("token-1")); + let client = Arc::new(ApiClient::new(options(server.uri()), auth).expect("client")); + let bus = EventBus::new(); + let (shutdown_tx, _) = broadcast::channel::(8); + let handle = spawn_recording_start_notifier( + store.clone(), + bus.clone(), + client, + org_rx(Some("org-1")), + shutdown_tx.subscribe(), + ); + + bus.publish(DaemonEvent::RecordingStarted { + recording_index: next, + }); + + timeout(Duration::from_secs(3), async { + loop { + let prior_row = store + .get_recording(prior) + .await + .expect("get") + .expect("exists"); + let next_row = store + .get_recording(next) + .await + .expect("get") + .expect("exists"); + if prior_row.backend_cancel_notified_at.is_some() && next_row.recording_id.is_some() + { + // Prior cancelled server-side; next opened with a FRESH id. + assert_eq!(next_row.recording_id.as_deref(), Some("cloud-fresh-B")); + break; + } + sleep(Duration::from_millis(20)).await; + } + }) + .await + .expect("prior recording must be cancelled and next opened fresh within 3s"); + + let _ = shutdown_tx.send(ShutdownSignal::Sigterm); + handle.join().await; + } + + #[tokio::test] + async fn startup_sweep_notifies_recordings_opened_while_offline() { + // A recording opened during a previous offline session: no cloud id, + // no start-notify/failed stamps. The pre-loop sweep must POST and + // persist the minted cloud id. + let server = MockServer::start().await; + start_ok_mock("cloud-rec-offline").mount(&server).await; + + let (store, _dir) = open_store().await; + let index = seed_recording(&store).await; + + let auth = Arc::new(StaticAuthProvider::new("token-1")); + let client = Arc::new(ApiClient::new(options(server.uri()), auth).expect("client")); + + let bus = EventBus::new(); + let (shutdown_tx, _) = broadcast::channel::(8); + let handle = spawn_recording_start_notifier( + store.clone(), + bus, + client, + org_rx(Some("org-1")), + shutdown_tx.subscribe(), + ); + + timeout(Duration::from_secs(3), async { + loop { + let row = store + .get_recording(index) + .await + .expect("get") + .expect("exists"); + if row.recording_id.as_deref() == Some("cloud-rec-offline") { + break; + } + sleep(Duration::from_millis(20)).await; + } + }) + .await + .expect("sweep must persist the minted cloud id within 3s"); + + let _ = shutdown_tx.send(ShutdownSignal::Sigterm); + handle.join().await; + } +} diff --git a/rust/data_daemon/src/cloud/recording_stop_notifier.rs b/rust/data_daemon/src/cloud/recording_stop_notifier.rs new file mode 100644 index 000000000..fe15321c2 --- /dev/null +++ b/rust/data_daemon/src/cloud/recording_stop_notifier.rs @@ -0,0 +1,452 @@ +//! Backend recording-stop notifier. +//! +//! Subscribes to [`DaemonEvent::RecordingStopped`] and POSTs +//! `/org/{org}/recording/stop` (JSON body `{recording_id, end_time}`) to the +//! backend. The Python SDK +//! used to make this call inline from `nc.stop_recording`, but the staging +//! POST has a fat upper tail (occasional 1-2 s spikes on otherwise +//! sub-second calls). Doing it here means the SDK call returns as soon as +//! the producer publishes the `StopRecording` envelope, and the staging +//! notification rides the daemon's standard retry policy in the background. +//! +//! A single long-lived task processes lifecycle events sequentially, awaiting +//! each POST inline (a pre-loop sweep and a broadcast-lag sweep recover any +//! recordings whose notification is pending from an offline session — see +//! [`notifier`](super::notifier) for the shared loop). Failures are logged +//! with the recording index but never surfaced to the SDK — by the time we +//! reach this notifier the SDK is long gone and the producer's iceoryx2 +//! publish already succeeded. + +use std::sync::Arc; + +use async_trait::async_trait; +use tokio::sync::broadcast; + +use crate::api::ApiClient; +use crate::cloud::notifier::{ + notify_recording_lifecycle, spawn_notifier, LifecycleKind, NotifierCtx, NotifierHandle, + RecordingNotifier, +}; +use crate::cloud::OrgIdRx; +use crate::lifecycle::signals::ShutdownSignal; +use crate::state::{ + DaemonEvent, EventBus, RecordingRow, SqliteStateStore, StateStore, StateStoreError, +}; + +/// Notifier that POSTs `/recording/stop` once a recording stops and its cloud +/// id is known. Triggered by `RecordingStopped` (the live path) and by +/// `RecordingCloudIdAssigned` (offline recovery: a recording stopped while +/// offline already fired `RecordingStopped` before any coordinator could see +/// it, so the POST is unblocked only when the start notifier later mints the +/// cloud id — `notify_backend` no-ops for a not-yet-stopped recording). +struct StopNotifier; + +#[async_trait] +impl RecordingNotifier for StopNotifier { + fn label(&self) -> &'static str { + "recording-stop" + } + + fn triggered_by(&self, event: &DaemonEvent) -> Option { + match event { + DaemonEvent::RecordingStopped { recording_index } + | DaemonEvent::RecordingCloudIdAssigned { recording_index } => Some(*recording_index), + _ => None, + } + } + + async fn pending( + &self, + store: &Arc, + ) -> Result, StateStoreError> { + store.recordings_pending_stop_notify().await + } + + async fn notify(&self, ctx: &NotifierCtx, recording_index: i64) { + notify_recording_lifecycle( + LifecycleKind::Stop, + &ctx.store, + &ctx.client, + &ctx.org_rx, + recording_index, + ) + .await; + } +} + +/// Spawn the recording-stop notifier on the current Tokio runtime. +pub fn spawn_recording_stop_notifier( + store: SqliteStateStore, + bus: EventBus, + client: Arc, + org_rx: OrgIdRx, + shutdown_rx: broadcast::Receiver, +) -> NotifierHandle { + spawn_notifier(StopNotifier, store, bus, client, org_rx, shutdown_rx) +} + +#[cfg(test)] +mod tests { + use super::*; + + use std::time::Duration; + + use tempfile::TempDir; + use tokio::sync::broadcast; + use tokio::time::{sleep, timeout}; + use wiremock::matchers::{body_partial_json, method, path}; + use wiremock::{Mock, MockServer, ResponseTemplate}; + + use crate::api::auth::StaticAuthProvider; + use crate::api::{ApiClient, ApiClientOptions}; + use crate::lifecycle::signals::ShutdownSignal; + use crate::state::{DaemonEvent, EventBus, NewRecording, SqliteStateStore, StateStore}; + + async fn open_store() -> (SqliteStateStore, TempDir) { + let dir = TempDir::new().expect("tempdir"); + let store = SqliteStateStore::open(&dir.path().join("state.db")) + .await + .expect("open store"); + (store, dir) + } + + fn options(base_url: String) -> ApiClientOptions { + ApiClientOptions { + base_url, + timeout: Duration::from_secs(5), + max_retries: 1, + max_backoff: Duration::from_secs(1), + } + } + + /// Insert a recording, stamp its cloud id (as if `/start` was notified), + /// and return its local index. + async fn seed_notified_recording(store: &SqliteStateStore, recording_id: &str) -> i64 { + let index = store + .create_recording(NewRecording { + robot_id: Some("robot-1"), + robot_instance: Some(0), + dataset_id: Some("ds-1"), + start_timestamp_ns: 1_700_000_000_000_000_000, + }) + .await + .expect("create recording") + .recording_index; + store + .mark_recording_start_notified(index, recording_id) + .await + .expect("mark start notified"); + index + } + + /// A live-org receiver fixed at `org`. The sender is leaked so the channel + /// stays open for the test's duration. + fn org_rx(org: Option<&str>) -> OrgIdRx { + let (org_tx, org_rx) = tokio::sync::watch::channel(org.map(str::to_string)); + Box::leak(Box::new(org_tx)); + org_rx + } + + #[tokio::test] + async fn posts_backend_stop_on_recording_stopped_event() { + let server = MockServer::start().await; + Mock::given(method("POST")) + .and(path("/org/org-1/recording/stop")) + .and(body_partial_json( + serde_json::json!({ "recording_id": "rec-stop-1" }), + )) + .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!("ok"))) + .mount(&server) + .await; + + let (store, _dir) = open_store().await; + let index = seed_notified_recording(&store, "rec-stop-1").await; + store + .mark_recording_stopped(index, 1) + .await + .expect("mark stopped"); + + let auth = Arc::new(StaticAuthProvider::new("token-1")); + let client = Arc::new(ApiClient::new(options(server.uri()), auth).expect("client")); + + let bus = EventBus::new(); + let (shutdown_tx, _) = broadcast::channel::(8); + let handle = spawn_recording_stop_notifier( + store.clone(), + bus.clone(), + client, + org_rx(Some("org-1")), + shutdown_tx.subscribe(), + ); + + bus.publish(DaemonEvent::RecordingStopped { + recording_index: index, + }); + + // Give the notifier task a moment to drain the event and call wiremock. + timeout(Duration::from_secs(3), async { + loop { + let received = server.received_requests().await.unwrap_or_default(); + if !received.is_empty() { + break; + } + sleep(Duration::from_millis(20)).await; + } + }) + .await + .expect("expected one POST within 3s"); + + let _ = shutdown_tx.send(ShutdownSignal::Sigterm); + handle.join().await; + } + + #[tokio::test] + async fn startup_sweep_recovers_recordings_stopped_while_offline() { + // Simulate a daemon coming online with a recording that was + // stopped during a previous offline session: `stopped_at` is + // already set, `backend_stop_notified_at` is still NULL. The + // notifier's pre-loop sweep must POST and mark the row notified. + let server = MockServer::start().await; + Mock::given(method("POST")) + .and(path("/org/org-1/recording/stop")) + .and(body_partial_json( + serde_json::json!({ "recording_id": "rec-offline-1" }), + )) + .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!("ok"))) + .mount(&server) + .await; + + let (store, _dir) = open_store().await; + let index = seed_notified_recording(&store, "rec-offline-1").await; + store + .mark_recording_stopped(index, 1) + .await + .expect("mark stopped"); + + let auth = Arc::new(StaticAuthProvider::new("token-1")); + let client = Arc::new(ApiClient::new(options(server.uri()), auth).expect("client")); + + let bus = EventBus::new(); + let (shutdown_tx, _) = broadcast::channel::(8); + let handle = spawn_recording_stop_notifier( + store.clone(), + bus, + client, + org_rx(Some("org-1")), + shutdown_tx.subscribe(), + ); + + timeout(Duration::from_secs(3), async { + loop { + let received = server.received_requests().await.unwrap_or_default(); + if !received.is_empty() { + break; + } + sleep(Duration::from_millis(20)).await; + } + }) + .await + .expect("sweep must POST within 3s"); + + // Give the notifier a beat to persist the success column. + timeout(Duration::from_secs(3), async { + loop { + let row = store + .get_recording(index) + .await + .expect("get") + .expect("exists"); + if row.backend_stop_notified_at.is_some() { + break; + } + sleep(Duration::from_millis(20)).await; + } + }) + .await + .expect("backend_stop_notified_at must be stamped within 3s"); + + let _ = shutdown_tx.send(ShutdownSignal::Sigterm); + handle.join().await; + } + + #[tokio::test] + async fn skips_notify_when_recording_row_missing() { + let server = MockServer::start().await; + let (store, _dir) = open_store().await; + let auth = Arc::new(StaticAuthProvider::new("token-1")); + let client = Arc::new(ApiClient::new(options(server.uri()), auth).expect("client")); + + let bus = EventBus::new(); + let (shutdown_tx, _) = broadcast::channel::(8); + let handle = spawn_recording_stop_notifier( + store, + bus.clone(), + client, + org_rx(Some("org-1")), + shutdown_tx.subscribe(), + ); + + bus.publish(DaemonEvent::RecordingStopped { + recording_index: 9_999, + }); + + // Yield enough for the notifier to process the event and bail. We + // assert *absence* of an HTTP request: wiremock has no mocks armed, + // so any incoming request would have already failed the test. A + // short sleep is the cheapest way to observe quiescence. + sleep(Duration::from_millis(150)).await; + let received = server.received_requests().await.unwrap_or_default(); + assert!( + received.is_empty(), + "no backend POST expected when recording row is missing" + ); + + let _ = shutdown_tx.send(ShutdownSignal::Sigterm); + handle.join().await; + } + + #[tokio::test] + async fn skips_notify_when_cloud_id_absent() { + // A stopped recording without a cloud id has nothing to stop + // server-side; the notifier must defer (no POST) until the start + // notifier fills the id. + let server = MockServer::start().await; + let (store, _dir) = open_store().await; + let index = store + .create_recording(NewRecording { + robot_id: Some("robot-1"), + robot_instance: Some(0), + start_timestamp_ns: 1_700_000_000_000_000_000, + ..NewRecording::default() + }) + .await + .expect("create recording") + .recording_index; + store + .mark_recording_stopped(index, 1) + .await + .expect("mark stopped"); + + let auth = Arc::new(StaticAuthProvider::new("token-1")); + let client = Arc::new(ApiClient::new(options(server.uri()), auth).expect("client")); + + let bus = EventBus::new(); + let (shutdown_tx, _) = broadcast::channel::(8); + let handle = spawn_recording_stop_notifier( + store, + bus.clone(), + client, + org_rx(Some("org-1")), + shutdown_tx.subscribe(), + ); + + bus.publish(DaemonEvent::RecordingStopped { + recording_index: index, + }); + + sleep(Duration::from_millis(150)).await; + let received = server.received_requests().await.unwrap_or_default(); + assert!( + received.is_empty(), + "no backend POST expected when the cloud recording_id is absent" + ); + + let _ = shutdown_tx.send(ShutdownSignal::Sigterm); + handle.join().await; + } + + #[tokio::test] + async fn cloud_id_assigned_event_notifies_a_recording_stopped_while_offline() { + // Offline recovery: a recording stopped while offline already fired its + // `RecordingStopped` (which no coordinator saw). Once the start notifier + // assigns the cloud id and publishes `RecordingCloudIdAssigned`, the + // stop notifier must POST `/recording/stop`. + let server = MockServer::start().await; + Mock::given(method("POST")) + .and(path("/org/org-1/recording/stop")) + .and(body_partial_json( + serde_json::json!({ "recording_id": "rec-recovered-1" }), + )) + .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!("ok"))) + .mount(&server) + .await; + + let (store, _dir) = open_store().await; + let index = seed_notified_recording(&store, "rec-recovered-1").await; + store + .mark_recording_stopped(index, 1) + .await + .expect("mark stopped"); + + let auth = Arc::new(StaticAuthProvider::new("token-1")); + let client = Arc::new(ApiClient::new(options(server.uri()), auth).expect("client")); + + let bus = EventBus::new(); + let (shutdown_tx, _) = broadcast::channel::(8); + let handle = spawn_recording_stop_notifier( + store.clone(), + bus.clone(), + client, + org_rx(Some("org-1")), + shutdown_tx.subscribe(), + ); + + // The cloud-id-assigned event — not RecordingStopped — drives the POST. + bus.publish(DaemonEvent::RecordingCloudIdAssigned { + recording_index: index, + }); + + timeout(Duration::from_secs(3), async { + loop { + let received = server.received_requests().await.unwrap_or_default(); + if !received.is_empty() { + break; + } + sleep(Duration::from_millis(20)).await; + } + }) + .await + .expect("cloud-id-assigned event must POST within 3s"); + + let _ = shutdown_tx.send(ShutdownSignal::Sigterm); + handle.join().await; + } + + #[tokio::test] + async fn cloud_id_assigned_event_ignores_a_running_recording() { + // A recording that just got its cloud id but has not stopped yet must + // not be stop-notified — the `stopped_at` guard holds the POST until the + // recording actually stops. + let server = MockServer::start().await; + let (store, _dir) = open_store().await; + let index = seed_notified_recording(&store, "rec-running-1").await; + // Deliberately NOT stopped. + + let auth = Arc::new(StaticAuthProvider::new("token-1")); + let client = Arc::new(ApiClient::new(options(server.uri()), auth).expect("client")); + + let bus = EventBus::new(); + let (shutdown_tx, _) = broadcast::channel::(8); + let handle = spawn_recording_stop_notifier( + store, + bus.clone(), + client, + org_rx(Some("org-1")), + shutdown_tx.subscribe(), + ); + + bus.publish(DaemonEvent::RecordingCloudIdAssigned { + recording_index: index, + }); + + sleep(Duration::from_millis(150)).await; + let received = server.received_requests().await.unwrap_or_default(); + assert!( + received.is_empty(), + "no backend POST expected for a recording that has not stopped" + ); + + let _ = shutdown_tx.send(ShutdownSignal::Sigterm); + handle.join().await; + } +} diff --git a/rust/data_daemon/src/cloud/registration.rs b/rust/data_daemon/src/cloud/registration.rs new file mode 100644 index 000000000..b500f0996 --- /dev/null +++ b/rust/data_daemon/src/cloud/registration.rs @@ -0,0 +1,770 @@ +//! Batch registration coordinator. +//! +//! Claims traces whose row exists (any write_status except `failed`) — not just +//! fully-written ones — buffers up to `BATCH_SIZE` (or `MAX_WAIT`) and POSTs +//! them to `/org/{org}/recording/traces/batch-register`. Registration only +//! needs the trace's *identity* (recording id, trace id, data type, cloud +//! files), all known at `/recording/start`, so it runs **while the recording is +//! still writing** — overlapping the round trip with the recording instead of +//! adding it to the post-stop tail ("pre-registration"). +//! +//! Because registration and the on-disk write can finish in either order, +//! `ReadyForUpload` is gated on BOTH states; [`publish_ready_traces`] owns that +//! promotion (and its write-behind-lag safety-net role). Registration failures +//! roll the status back to `Pending` so the next tick re-claims them. + +use std::collections::HashMap; +use std::sync::Arc; +use std::time::Duration; + +use tokio::sync::broadcast; +use tokio::task::JoinHandle; +use tokio::time::{interval, MissedTickBehavior}; + +use crate::api::models::RegisterTraceRequest; +use crate::api::ApiClient; +use crate::cloud::cloud_files::cloud_file_list; +use crate::cloud::OrgIdRx; +use crate::lifecycle::signals::ShutdownSignal; +use crate::state::store::TraceUpdate; +use crate::state::{ + DaemonEvent, EventBus, SqliteStateStore, StateStore, TraceRecord, TraceRegistrationStatus, +}; + +/// Maximum traces to register in a single call. Matches the +/// `claim_traces_for_registration` size trigger. +pub const BATCH_SIZE: usize = 50; +/// Maximum age before flushing a partial batch. +pub const MAX_WAIT: Duration = Duration::from_millis(200); +/// Poll interval the coordinator falls back to when the bus is quiet. +pub const POLL_INTERVAL: Duration = Duration::from_millis(500); +/// How many times a trace the backend explicitly rejects (returns in +/// `failed_traces`) is rolled back to `pending` and retried before being marked +/// terminally `failed`. Backend registration errors are frequently transient +/// (e.g. a staging "Unexpected error during registration" under a large +/// registration burst); terminally failing on the first one permanently wedges +/// the whole recording (its traces never upload, so it never reaches "all +/// uploaded" and is never reaped). A small bounded retry rides out the hiccup +/// while still terminating a genuinely-permanent failure. +const MAX_REGISTRATION_ATTEMPTS: u32 = 5; + +/// Handle returned by [`spawn_registration`]. +pub struct RegistrationCoordinatorHandle { + join: JoinHandle<()>, +} + +impl RegistrationCoordinatorHandle { + /// Wait for the coordinator task to exit. + pub async fn join(self) { + if let Err(error) = self.join.await { + tracing::warn!(?error, "registration coordinator join failed"); + } + } +} + +/// Spawn the registration coordinator on the current Tokio runtime. +pub fn spawn_registration( + store: SqliteStateStore, + bus: EventBus, + client: Arc, + org_rx: OrgIdRx, + mut shutdown_rx: broadcast::Receiver, +) -> RegistrationCoordinatorHandle { + let mut subscriber = bus.subscribe(); + let store = Arc::new(store); + let join = tokio::spawn(async move { + let mut ticker = interval(POLL_INTERVAL); + ticker.set_missed_tick_behavior(MissedTickBehavior::Delay); + + // Per-trace count of backend-rejected registration attempts, kept for + // the coordinator's lifetime so the retry budget spans drains. Entries + // are removed once a trace registers or is terminally failed, so the map + // only ever holds currently-retrying traces. + let mut registration_attempts: HashMap = HashMap::new(); + + loop { + tokio::select! { + biased; + signal = shutdown_rx.recv() => { + tracing::debug!(?signal, "registration coordinator shutting down"); + break; + } + event = subscriber.recv() => { + match event { + Ok(DaemonEvent::TraceWritten { .. }) => { + drain_once(&store, &bus, &client, &org_rx, MAX_WAIT, &mut registration_attempts).await; + } + Ok(_) => {} + Err(broadcast::error::RecvError::Lagged(skipped)) => { + tracing::warn!( + skipped, + "registration coordinator missed bus events; \ + falling back to a drain" + ); + drain_once(&store, &bus, &client, &org_rx, MAX_WAIT, &mut registration_attempts).await; + } + Err(broadcast::error::RecvError::Closed) => break, + } + } + _ = ticker.tick() => { + drain_once(&store, &bus, &client, &org_rx, MAX_WAIT, &mut registration_attempts).await; + } + } + } + }); + RegistrationCoordinatorHandle { join } +} + +async fn drain_once( + store: &Arc, + bus: &EventBus, + client: &Arc, + org_rx: &OrgIdRx, + max_wait: Duration, + registration_attempts: &mut HashMap, +) { + // Safety net: promote any traces that became (registered + written) since + // the last drain. This runs even when there is nothing new to register, so + // the periodic tick eventually promotes a pre-registered trace once its + // write-behind `write_status = written` commit lands. + publish_ready_traces(store, bus).await; + + let claimed = match store + .claim_traces_for_registration(BATCH_SIZE, max_wait.as_secs_f64()) + .await + { + Ok(rows) => rows, + Err(error) => { + tracing::warn!(%error, "claim_traces_for_registration failed"); + return; + } + }; + if claimed.is_empty() { + return; + } + tracing::debug!(count = claimed.len(), "claimed traces for registration"); + submit_batch(store, bus, client, org_rx, claimed, registration_attempts).await; + publish_ready_traces(store, bus).await; +} + +async fn submit_batch( + store: &Arc, + bus: &EventBus, + client: &Arc, + org_rx: &OrgIdRx, + traces: Vec, + registration_attempts: &mut HashMap, +) { + // Group by recording so we can look up the recording row once per + // recording rather than once per trace; in practice every claim ships + // traces from a single recording but the protocol does not require that. + let mut by_recording: HashMap> = HashMap::new(); + for trace in traces { + by_recording + .entry(trace.recording_index) + .or_default() + .push(trace); + } + + for (recording_index, traces) in by_recording { + let row = match store.get_recording(recording_index).await { + Ok(Some(row)) => row, + Ok(None) => { + tracing::warn!( + recording_index, + "recording row missing; rolling traces back to pending" + ); + rollback_to_pending(store, &traces).await; + continue; + } + Err(error) => { + tracing::warn!(%error, recording_index, "failed to read recording row"); + rollback_to_pending(store, &traces).await; + continue; + } + }; + + let Some(org_id) = org_rx.borrow().clone() else { + tracing::warn!( + recording_index, + "no current org_id configured yet; rolling traces back to pending" + ); + rollback_to_pending(store, &traces).await; + continue; + }; + + // The backend recording_id always comes from `/recording/start`. An + // offline recording (or one whose `/recording/start` POST has not yet + // landed) carries no cloud id, so there is nothing to register against + // yet — roll the traces back to pending and retry once the start + // notifier has populated the id. + let Some(cloud_id) = row.recording_id.clone() else { + rollback_to_pending(store, &traces).await; + continue; + }; + + let payload: Vec = traces + .iter() + .map(|trace| RegisterTraceRequest { + recording_id: cloud_id.clone(), + data_type: trace.data_type.clone().unwrap_or_default(), + trace_id: trace.trace_id.clone(), + cloud_files: cloud_file_list( + trace.data_type.as_deref().unwrap_or(""), + trace.data_type_name.as_deref(), + ), + }) + .collect(); + + match client.batch_register(&org_id, &payload).await { + Ok(response) => { + let registered_ids: HashMap = response + .registered_traces + .into_iter() + .map(|entry| (entry.trace_id.clone(), entry.upload_session_uris)) + .collect(); + let failed_ids: HashMap> = response + .failed_traces + .into_iter() + .map(|entry| (entry.trace_id, entry.error)) + .collect(); + + for trace in &traces { + if let Some(uris) = registered_ids.get(&trace.trace_id) { + // A serialise failure must NOT mark the trace registered + // with a "{}" placeholder — that records an empty URI map + // and the uploader later finalises it as 0 bytes uploaded + // (silent data loss). Roll back to pending so the next + // tick re-registers it instead. + let serialised = match serde_json::to_string(uris) { + Ok(serialised) => serialised, + Err(error) => { + tracing::warn!(%error, trace_id = trace.trace_id, "failed to serialise session URIs; rolling back to pending"); + rollback_single_to_pending(store, &trace.trace_id).await; + continue; + } + }; + let update = TraceUpdate { + registration_status: Some(TraceRegistrationStatus::Registered), + upload_session_uris: Some(serialised), + ..TraceUpdate::default() + }; + if let Err(error) = store.update_trace(&trace.trace_id, update).await { + // The backend registered the trace but we couldn't + // persist it; leaving it in `registering` would wedge + // it for the session (no coordinator re-claims that + // state mid-session). Roll back to `pending` so the + // next tick re-claims and re-registers it. + tracing::warn!(%error, trace_id = trace.trace_id, "failed to persist registration outcome; rolling back to pending"); + rollback_single_to_pending(store, &trace.trace_id).await; + continue; + } + // Registered — clear any accumulated retry budget. + registration_attempts.remove(&trace.trace_id); + bus.publish(DaemonEvent::TraceRegistered { + trace_id: trace.trace_id.clone(), + recording_index, + }); + } else if let Some(error) = failed_ids.get(&trace.trace_id) { + // Backend rejections are usually transient (e.g. a + // staging burst error); retry under the shared budget. + handle_registration_setback( + store, + registration_attempts, + &trace.trace_id, + error.clone(), + error.as_deref().unwrap_or("backend rejected trace"), + ) + .await; + } else { + // Backend returned neither a registered nor a failed + // entry for this trace; retry under the same bounded + // budget so a persistently-omitted trace can't loop + // forever. + handle_registration_setback( + store, + registration_attempts, + &trace.trace_id, + Some("backend returned no registration outcome".to_string()), + "backend silently dropped trace", + ) + .await; + } + } + } + Err(error) => { + tracing::warn!(%error, recording_index, "batch register request failed"); + rollback_to_pending(store, &traces).await; + } + } + } +} + +/// Promote any traces that are now both registered and written to `queued` and +/// emit `ReadyForUpload` for each. +/// +/// Run on every drain (including the periodic tick) so it doubles as the safety +/// net for the lag between the `TraceWritten` event and the write-behind commit +/// of `write_status`: a pre-registered trace is promoted on whichever drain +/// first sees both states committed, rather than depending on a single event. +async fn publish_ready_traces(store: &Arc, bus: &EventBus) { + match store.promote_ready_traces_to_queued().await { + Ok(ready) => { + for (trace_id, recording_index) in ready { + bus.publish(DaemonEvent::ReadyForUpload { + trace_id, + recording_index, + }); + } + } + Err(error) => { + tracing::warn!(%error, "failed to promote ready traces for upload"); + } + } +} + +async fn rollback_to_pending(store: &Arc, traces: &[TraceRecord]) { + for trace in traces { + rollback_single_to_pending(store, &trace.trace_id).await; + } +} + +/// Apply bounded-retry accounting to a trace the backend did not register — +/// either an explicit rejection or a silent omission. Rolls the trace back to +/// `pending` for another attempt, or terminally marks it `failed` once +/// [`MAX_REGISTRATION_ATTEMPTS`] is reached, so a persistently-unregisterable +/// trace can't re-claim and re-POST forever. `error_message` is the reason +/// persisted on terminal failure; `reason` is the human-readable log context. +async fn handle_registration_setback( + store: &Arc, + registration_attempts: &mut HashMap, + trace_id: &str, + error_message: Option, + reason: &str, +) { + let attempts = registration_attempts + .entry(trace_id.to_string()) + .or_insert(0); + *attempts += 1; + if *attempts < MAX_REGISTRATION_ATTEMPTS { + tracing::warn!( + trace_id, + reason, + attempt = *attempts, + "trace registration setback; rolling back to pending for retry" + ); + rollback_single_to_pending(store, trace_id).await; + return; + } + tracing::warn!( + trace_id, + reason, + attempts = *attempts, + "trace registration setback after retry budget exhausted; marking failed" + ); + registration_attempts.remove(trace_id); + let update = TraceUpdate { + registration_status: Some(TraceRegistrationStatus::Failed), + error_message: Some(error_message), + ..TraceUpdate::default() + }; + // If persisting the `failed` status itself fails, the trace would otherwise + // sit in `registering` forever (no coordinator re-claims that state + // mid-session), so roll it back to `pending` for the next tick. + if let Err(persist_error) = store.update_trace(trace_id, update).await { + tracing::warn!(%persist_error, trace_id, "failed to persist registration failure; rolling back to pending"); + rollback_single_to_pending(store, trace_id).await; + } +} + +async fn rollback_single_to_pending(store: &Arc, trace_id: &str) { + let update = TraceUpdate { + registration_status: Some(TraceRegistrationStatus::Pending), + ..TraceUpdate::default() + }; + if let Err(error) = store.update_trace(trace_id, update).await { + tracing::warn!(%error, trace_id, "failed to roll registration status back"); + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::api::auth::StaticAuthProvider; + use crate::api::client::ApiClientOptions; + use crate::state::store::TraceUpdate; + use crate::state::{NewRecording, TraceUploadStatus, TraceWriteStatus}; + use std::time::Duration; + use tempfile::TempDir; + use wiremock::matchers::{method, path}; + use wiremock::{Mock, MockServer, ResponseTemplate}; + + async fn open_store() -> (SqliteStateStore, TempDir) { + let dir = TempDir::new().unwrap(); + let store = SqliteStateStore::open(&dir.path().join("state.db")) + .await + .unwrap(); + (store, dir) + } + + /// A live-org receiver fixed at `org` for the duration of a test. The + /// sender is leaked so the channel stays open and `borrow()` keeps + /// returning the seeded value. + fn org_rx(org: Option<&str>) -> OrgIdRx { + let (org_tx, org_rx) = tokio::sync::watch::channel(org.map(str::to_string)); + Box::leak(Box::new(org_tx)); + org_rx + } + + /// Seed a recording plus a single written trace under it, returning the + /// local `recording_index`. When `cloud_id` is `Some`, the recording's + /// cloud `recording_id` is persisted (as the start notifier would) so + /// registration finds one; when `None`, the recording has no cloud id yet + /// and registration must defer. + async fn seed_written_trace( + store: &SqliteStateStore, + trace_id: &str, + cloud_id: Option<&str>, + ) -> i64 { + let recording_index = store + .create_recording(NewRecording { + robot_id: Some("robot-1"), + robot_instance: Some(0), + dataset_id: Some("ds-1"), + start_timestamp_ns: 1_700_000_000_000_000_000, + }) + .await + .unwrap() + .recording_index; + if let Some(cloud_id) = cloud_id { + store + .mark_recording_start_notified(recording_index, cloud_id) + .await + .unwrap(); + } + store + .create_trace( + recording_index, + trace_id, + Some("JOINT_POSITIONS"), + Some("arm0"), + ) + .await + .unwrap(); + store + .update_trace( + trace_id, + TraceUpdate { + write_status: Some(TraceWriteStatus::Written), + ..TraceUpdate::default() + }, + ) + .await + .unwrap(); + recording_index + } + + fn client(server: &MockServer) -> Arc { + let auth = Arc::new(StaticAuthProvider::new("test-token")); + let mut options = ApiClientOptions::new(server.uri()); + options.max_backoff = Duration::from_millis(10); + Arc::new(ApiClient::new(options, auth).unwrap()) + } + + #[tokio::test] + async fn successful_registration_persists_session_uri_and_emits_event() { + let server = MockServer::start().await; + Mock::given(method("POST")) + .and(path("/org/org-1/recording/traces/batch-register")) + .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({ + "registered_traces": [{ + "trace_id": "trace-1", + "upload_session_uris": {"JOINT_POSITIONS/arm0/trace.json": "https://upload/abc"} + }], + "failed_traces": [] + }))) + .expect(1) + .mount(&server) + .await; + + let (store, _dir) = open_store().await; + let recording_index = seed_written_trace(&store, "trace-1", Some("cloud-rec-1")).await; + let bus = EventBus::new(); + let mut subscriber = bus.subscribe(); + let api = client(&server); + + // Drive a single drain directly so the test does not depend on the + // ticker firing: register the batch, then run the promotion sweep that + // emits ReadyForUpload once a trace is both registered and written. + let store_arc = Arc::new(store.clone()); + let claimed = store + .claim_traces_for_registration(BATCH_SIZE, 0.0) + .await + .unwrap(); + submit_batch( + &store_arc, + &bus, + &api, + &org_rx(Some("org-1")), + claimed, + &mut HashMap::new(), + ) + .await; + publish_ready_traces(&store_arc, &bus).await; + + let trace = store.get_trace("trace-1").await.unwrap().unwrap(); + assert_eq!( + trace.registration_status, + TraceRegistrationStatus::Registered + ); + assert_eq!(trace.upload_status, TraceUploadStatus::Queued); + assert!(trace + .upload_session_uris + .as_ref() + .unwrap() + .contains("https://upload/abc")); + + // First two events on the bus are TraceRegistered + ReadyForUpload. + let mut saw_registered = false; + let mut saw_ready = false; + for _ in 0..2 { + match subscriber.recv().await.unwrap() { + DaemonEvent::TraceRegistered { + trace_id, + recording_index: event_index, + } => { + assert_eq!(trace_id, "trace-1"); + assert_eq!(event_index, recording_index); + saw_registered = true; + } + DaemonEvent::ReadyForUpload { + trace_id, + recording_index: event_index, + } => { + assert_eq!(trace_id, "trace-1"); + assert_eq!(event_index, recording_index); + saw_ready = true; + } + other => panic!("unexpected event: {other:?}"), + } + } + assert!(saw_registered); + assert!(saw_ready); + } + + #[tokio::test] + async fn failed_request_rolls_back_to_pending() { + let server = MockServer::start().await; + Mock::given(method("POST")) + .and(path("/org/org-1/recording/traces/batch-register")) + .respond_with(ResponseTemplate::new(500)) + .mount(&server) + .await; + + let (store, _dir) = open_store().await; + seed_written_trace(&store, "trace-1", Some("cloud-rec-1")).await; + let bus = EventBus::new(); + let api = client(&server); + + let claimed = store + .claim_traces_for_registration(BATCH_SIZE, 0.0) + .await + .unwrap(); + submit_batch( + &Arc::new(store.clone()), + &bus, + &api, + &org_rx(Some("org-1")), + claimed, + &mut HashMap::new(), + ) + .await; + + let trace = store.get_trace("trace-1").await.unwrap().unwrap(); + assert_eq!(trace.registration_status, TraceRegistrationStatus::Pending); + } + + #[tokio::test] + async fn missing_org_id_rolls_back_to_pending() { + let server = MockServer::start().await; + let (store, _dir) = open_store().await; + let recording_index = store + .create_recording(NewRecording { + robot_id: Some("robot-1"), + robot_instance: Some(0), + dataset_id: Some("ds-1"), + start_timestamp_ns: 1_700_000_000_000_000_000, + }) + .await + .unwrap() + .recording_index; + store + .create_trace( + recording_index, + "trace-1", + Some("JOINT_POSITIONS"), + Some("arm"), + ) + .await + .unwrap(); + store + .update_trace( + "trace-1", + TraceUpdate { + write_status: Some(TraceWriteStatus::Written), + ..TraceUpdate::default() + }, + ) + .await + .unwrap(); + let bus = EventBus::new(); + let api = client(&server); + + let claimed = store + .claim_traces_for_registration(BATCH_SIZE, 0.0) + .await + .unwrap(); + submit_batch( + &Arc::new(store.clone()), + &bus, + &api, + &org_rx(None), + claimed, + &mut HashMap::new(), + ) + .await; + + let trace = store.get_trace("trace-1").await.unwrap().unwrap(); + assert_eq!(trace.registration_status, TraceRegistrationStatus::Pending); + } + + #[tokio::test] + async fn defers_registration_when_recording_has_no_cloud_id() { + let server = MockServer::start().await; + // The recording has no cloud id yet, so registration must not POST. + Mock::given(method("POST")) + .and(path("/org/org-1/recording/traces/batch-register")) + .respond_with(ResponseTemplate::new(200)) + .expect(0) + .mount(&server) + .await; + + let (store, _dir) = open_store().await; + // No cloud id seeded: the start notifier hasn't populated one yet. + let recording_index = seed_written_trace(&store, "trace-1", None).await; + assert_eq!( + store + .get_recording(recording_index) + .await + .unwrap() + .unwrap() + .recording_id, + None, + "recording starts with no cloud id" + ); + let bus = EventBus::new(); + let api = client(&server); + + let claimed = store + .claim_traces_for_registration(BATCH_SIZE, 0.0) + .await + .unwrap(); + submit_batch( + &Arc::new(store.clone()), + &bus, + &api, + &org_rx(Some("org-1")), + claimed, + &mut HashMap::new(), + ) + .await; + + // The recording still has no cloud id — none is minted locally. + let row = store.get_recording(recording_index).await.unwrap().unwrap(); + assert_eq!( + row.recording_id, None, + "registration must not mint a cloud id" + ); + // The trace is rolled back to pending for a later retry. + let trace = store.get_trace("trace-1").await.unwrap().unwrap(); + assert_eq!(trace.registration_status, TraceRegistrationStatus::Pending); + } + + #[tokio::test] + async fn backend_rejection_retries_then_fails_after_budget() { + // A backend that rejects a trace (returns it in `failed_traces`) is + // treated as transient: the trace is rolled back to `pending` and + // retried up to MAX_REGISTRATION_ATTEMPTS, then marked terminally + // `failed`. Terminally failing on the first rejection would permanently + // wedge the recording (the regression a staging burst-error exposed). + let server = MockServer::start().await; + Mock::given(method("POST")) + .and(path("/org/org-1/recording/traces/batch-register")) + .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({ + "registered_traces": [], + "failed_traces": [{ + "trace_id": "trace-1", + "error": "Unexpected error during registration" + }] + }))) + .mount(&server) + .await; + + let (store, _dir) = open_store().await; + seed_written_trace(&store, "trace-1", Some("cloud-rec-1")).await; + let bus = EventBus::new(); + let api = client(&server); + let store_arc = Arc::new(store.clone()); + let mut attempts = HashMap::new(); + + // Each of the first MAX-1 rejections rolls the trace back to pending so + // the next tick re-claims and retries it. + for attempt in 1..MAX_REGISTRATION_ATTEMPTS { + let claimed = store + .claim_traces_for_registration(BATCH_SIZE, 0.0) + .await + .unwrap(); + assert_eq!(claimed.len(), 1, "the pending trace is re-claimable"); + submit_batch( + &store_arc, + &bus, + &api, + &org_rx(Some("org-1")), + claimed, + &mut attempts, + ) + .await; + let trace = store.get_trace("trace-1").await.unwrap().unwrap(); + assert_eq!( + trace.registration_status, + TraceRegistrationStatus::Pending, + "attempt {attempt} (< budget) must retry, not terminate" + ); + } + + // The final rejection exhausts the budget → terminal failure. + let claimed = store + .claim_traces_for_registration(BATCH_SIZE, 0.0) + .await + .unwrap(); + submit_batch( + &store_arc, + &bus, + &api, + &org_rx(Some("org-1")), + claimed, + &mut attempts, + ) + .await; + let trace = store.get_trace("trace-1").await.unwrap().unwrap(); + assert_eq!( + trace.registration_status, + TraceRegistrationStatus::Failed, + "an exhausted retry budget terminates the trace" + ); + assert_eq!( + trace.error_message.as_deref(), + Some("Unexpected error during registration") + ); + } +} diff --git a/rust/data_daemon/src/cloud/status.rs b/rust/data_daemon/src/cloud/status.rs new file mode 100644 index 000000000..9faf5a863 --- /dev/null +++ b/rust/data_daemon/src/cloud/status.rs @@ -0,0 +1,421 @@ +//! Debounced trace status updater. +//! +//! The uploader pushes [`StatusUpdate`] entries onto an unbounded mpsc; the +//! updater coalesces them into per-recording batches and flushes when one of +//! the following becomes true: +//! +//! - `MAX_BATCH_SIZE` (50) traces are queued. +//! - `IN_PROGRESS_MAX_WAIT` (4 s) elapsed since the batch opened. +//! - A completed-trace entry is in the batch and `COMPLETION_MAX_WAIT` +//! (0.2 s) has elapsed. + +use std::collections::HashMap; +use std::sync::Arc; +use std::time::{Duration, Instant}; + +use tokio::sync::broadcast; +use tokio::sync::mpsc; +use tokio::task::{JoinHandle, JoinSet}; +use tokio::time::{interval, MissedTickBehavior}; + +use crate::api::models::{TraceStatusUpdate, TraceStatusValue}; +use crate::api::ApiClient; +use crate::cloud::OrgIdRx; +use crate::lifecycle::signals::ShutdownSignal; +use crate::state::{RecordingRow, SqliteStateStore, StateStore}; + +/// Maximum number of traces to coalesce before flushing. +pub const MAX_BATCH_SIZE: usize = 50; +/// Maximum age of an in-progress batch before flushing. +pub const IN_PROGRESS_MAX_WAIT: Duration = Duration::from_secs(4); +/// Maximum age of a batch containing a completed trace. +pub const COMPLETION_MAX_WAIT: Duration = Duration::from_millis(200); +/// How long to wait before re-attempting a flush when no current `org_id` is +/// configured yet, or the recording's cloud id hasn't been assigned. Picked +/// larger than the `MAX_WAIT` triggers above so a perpetually-missing org +/// doesn't spin the executor while waiting for login / org selection. +const ORG_RESOLVE_RETRY_BACKOFF: Duration = Duration::from_secs(2); + +/// Update emitted by the uploader for the status coordinator to forward to +/// the backend. +#[derive(Debug, Clone)] +pub struct StatusUpdate { + /// Recording the trace belongs to (local `recording_index`). + pub recording_index: i64, + /// Trace identifier. + pub trace_id: String, + /// Bytes uploaded so far. + pub uploaded_bytes: i64, + /// `true` when this update represents an `UPLOAD_COMPLETE` transition. + pub completed: bool, + /// Total bytes once finalised; required when `completed` is `true`. + pub total_bytes: Option, +} + +impl StatusUpdate { + /// Build an in-progress (bytes-only) status update. + pub fn in_progress(recording_index: i64, trace_id: String, uploaded_bytes: i64) -> Self { + Self { + recording_index, + trace_id, + uploaded_bytes, + completed: false, + total_bytes: None, + } + } + + /// Build a completion update (status=UPLOAD_COMPLETE). + pub fn completed(recording_index: i64, trace_id: String, total_bytes: i64) -> Self { + Self { + recording_index, + trace_id, + uploaded_bytes: total_bytes, + completed: true, + total_bytes: Some(total_bytes), + } + } +} + +/// Handle returned by [`spawn_status_updater`]. +pub struct StatusUpdaterHandle { + join: JoinHandle<()>, +} + +impl StatusUpdaterHandle { + /// Wait for the status updater to exit. + pub async fn join(self) { + if let Err(error) = self.join.await { + tracing::warn!(?error, "status updater join failed"); + } + } +} + +/// Spawn the status updater. Returns the mpsc sender used by the uploader. +pub fn spawn_status_updater( + store: SqliteStateStore, + client: Arc, + org_rx: OrgIdRx, + inbox: mpsc::UnboundedReceiver, + shutdown_rx: broadcast::Receiver, +) -> StatusUpdaterHandle { + let store = Arc::new(store); + let join = tokio::spawn(async move { + run(store, client, org_rx, inbox, shutdown_rx).await; + }); + StatusUpdaterHandle { join } +} + +async fn run( + store: Arc, + client: Arc, + org_rx: OrgIdRx, + mut inbox: mpsc::UnboundedReceiver, + mut shutdown_rx: broadcast::Receiver, +) { + // Per-recording pending batches keyed by recording_index; preserves the + // last-seen update per trace (later updates supersede earlier ones). + let mut pending: HashMap = HashMap::new(); + // Flush tasks running in the background — spawned by flush_due and the + // max-batch path so the select loop never blocks on HTTP round-trips. + let mut background_flushes: JoinSet> = JoinSet::new(); + // Periodic flush ticker — fires every 100 ms regardless of inbox load. + let mut flush_ticker = interval(Duration::from_millis(100)); + flush_ticker.set_missed_tick_behavior(MissedTickBehavior::Skip); + loop { + tokio::select! { + biased; + signal = shutdown_rx.recv() => { + tracing::debug!(?signal, "status updater shutting down"); + // Let in-flight flushes finish; re-queue any deferred batches + // so flush_all gets a chance to send them. + while let Some(flush_result) = background_flushes.join_next().await { + if let Ok(Some(deferred_batch)) = flush_result { + pending.insert(deferred_batch.recording_index, deferred_batch); + } + } + flush_all(&store, &client, &org_rx, &mut pending).await; + break; + } + // Drain completed background flush tasks without blocking the loop. + Some(flush_result) = background_flushes.join_next(), + if !background_flushes.is_empty() => + { + match flush_result { + Ok(Some(deferred_batch)) => { + pending.insert(deferred_batch.recording_index, deferred_batch); + } + Ok(None) => {} + Err(panic_err) => { + tracing::warn!(?panic_err, "flush_batch task panicked"); + } + } + } + _ = flush_ticker.tick() => { + flush_due(&store, &client, &org_rx, &mut pending, &mut background_flushes); + } + maybe_update = inbox.recv() => { + let Some(update) = maybe_update else { break }; + let recording_index = update.recording_index; + let batch = pending + .entry(recording_index) + .or_insert_with(|| RecordingBatch::new(recording_index)); + batch.add(update); + if batch.size() >= MAX_BATCH_SIZE { + if let Some(batch) = pending.remove(&recording_index) { + background_flushes.spawn(flush_batch( + Arc::clone(&store), + Arc::clone(&client), + org_rx.clone(), + batch, + )); + } + } + } + } + } +} + +/// Spawn a background task for every batch whose deadline has passed. +/// Synchronous — never blocks the select loop on HTTP I/O. +fn flush_due( + store: &Arc, + client: &Arc, + org_rx: &OrgIdRx, + pending: &mut HashMap, + background_flushes: &mut JoinSet>, +) { + let now = Instant::now(); + let due_ids: Vec = pending + .iter() + .filter(|(_, batch)| now >= batch.deadline()) + .map(|(recording_index, _)| *recording_index) + .collect(); + for recording_index in &due_ids { + if let Some(batch) = pending.remove(recording_index) { + background_flushes.spawn(flush_batch( + Arc::clone(store), + Arc::clone(client), + org_rx.clone(), + batch, + )); + } + } +} + +async fn flush_all( + store: &Arc, + client: &Arc, + org_rx: &OrgIdRx, + pending: &mut HashMap, +) { + let mut tasks: JoinSet> = JoinSet::new(); + for (_, batch) in pending.drain() { + tasks.spawn(flush_batch( + Arc::clone(store), + Arc::clone(client), + org_rx.clone(), + batch, + )); + } + // Deferred batches (org_id / cloud id not yet known) can't be sent and are + // dropped on shutdown. The persisted trace rows and the final reclaim are + // the source of truth that recovers state; the live per-trace progress in + // these dropped batches is forfeited on shutdown. Count them so a + // surprising number is visible rather than silent. + let mut dropped = 0usize; + while let Some(result) = tasks.join_next().await { + match result { + Ok(Some(_deferred_batch)) => dropped += 1, + Ok(None) => {} + Err(panic_err) => { + tracing::warn!(?panic_err, "flush_batch task panicked on shutdown"); + } + } + } + if dropped > 0 { + tracing::info!( + dropped, + "dropped deferred status batches on shutdown (no org/cloud id yet; \ + persisted rows remain source-of-truth)" + ); + } +} + +/// Flush a single recording's batch. Returns the batch back if the recording's +/// `org_id` / cloud `recording_id` isn't available yet (caller should re-insert +/// with deferred deadline), or `None` when the flush was sent (or the batch was +/// empty). +async fn flush_batch( + store: Arc, + client: Arc, + org_rx: OrgIdRx, + mut batch: RecordingBatch, +) -> Option { + let recording_index = batch.recording_index; + let row = match resolve_recording(&store, recording_index).await { + Some(row) => row, + None => { + // Re-queue with a fresh `opened_at` pushed + // `ORG_RESOLVE_RETRY_BACKOFF` into the future so the next + // `flush_due` skips this batch until the start notifier has + // populated the cloud id. Without this, a missing field pins + // `deadline()` permanently in the past and the select loop becomes + // a busy-wait until the row is ready. + batch.defer(ORG_RESOLVE_RETRY_BACKOFF); + return Some(batch); + } + }; + let (Some(org_id), Some(recording_id)) = (org_rx.borrow().clone(), row.recording_id) else { + batch.defer(ORG_RESOLVE_RETRY_BACKOFF); + return Some(batch); + }; + let updates = batch.into_updates(); + if updates.is_empty() { + return None; + } + let updates_payload: HashMap = updates.into_iter().collect(); + match client + .batch_update_traces(&org_id, &recording_id, &updates_payload) + .await + { + Ok(()) => { + tracing::debug!( + recording_index, + recording_id, + count = updates_payload.len(), + "flushed status updates" + ); + } + Err(error) => { + tracing::warn!(%error, recording_index, recording_id, count = updates_payload.len(), "status batch update failed"); + } + } + None +} + +async fn resolve_recording( + store: &Arc, + recording_index: i64, +) -> Option { + match store.get_recording(recording_index).await { + Ok(Some(row)) => Some(row), + Ok(None) => None, + Err(error) => { + tracing::warn!(%error, recording_index, "status updater could not read recording row"); + None + } + } +} + +#[derive(Debug)] +struct RecordingBatch { + recording_index: i64, + opened_at: Instant, + has_completion: bool, + updates: HashMap, +} + +impl RecordingBatch { + fn new(recording_index: i64) -> Self { + Self { + recording_index, + opened_at: Instant::now(), + has_completion: false, + updates: HashMap::new(), + } + } + + fn add(&mut self, update: StatusUpdate) { + let entry = self.updates.entry(update.trace_id).or_default(); + entry.uploaded_bytes = Some(update.uploaded_bytes); + if update.completed { + entry.status = Some(TraceStatusValue::UploadComplete); + entry.total_bytes = update.total_bytes.or(entry.total_bytes); + self.has_completion = true; + } + } + + fn size(&self) -> usize { + self.updates.len() + } + + fn deadline(&self) -> Instant { + if self.has_completion { + self.opened_at + COMPLETION_MAX_WAIT + } else { + self.opened_at + IN_PROGRESS_MAX_WAIT + } + } + + /// Slide `opened_at` forward by `delay` so the next deadline tick lands + /// at least `delay` from now. Used by the org-id retry path to space + /// out flush attempts when the recording's org isn't yet stamped. + fn defer(&mut self, delay: Duration) { + // Pin the new `opened_at` so that whatever the current deadline + // policy returns is at least `delay` from now. + let target = Instant::now() + delay; + let policy_wait = if self.has_completion { + COMPLETION_MAX_WAIT + } else { + IN_PROGRESS_MAX_WAIT + }; + self.opened_at = target.checked_sub(policy_wait).unwrap_or(target); + } + + fn into_updates(self) -> Vec<(String, TraceStatusUpdate)> { + self.updates.into_iter().collect() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn batch_records_completion_flag() { + let mut batch = RecordingBatch::new(1); + batch.add(StatusUpdate::in_progress(1, "t1".to_string(), 1)); + assert!(!batch.has_completion); + batch.add(StatusUpdate::completed(1, "t1".to_string(), 100)); + assert!(batch.has_completion); + // The latest update for the same trace_id overrides bytes_uploaded. + let entry = batch.updates.get("t1").unwrap(); + assert_eq!(entry.uploaded_bytes, Some(100)); + assert!(matches!( + entry.status, + Some(TraceStatusValue::UploadComplete) + )); + } + + #[test] + fn defer_slides_deadline_forward_into_future() { + // The defer path is invoked when no current org_id is configured + // yet. Without it the batch's deadline stays in the past + // and the select loop spins; with it the next deadline is at + // least `delay` from now. + let mut batch = RecordingBatch::new(1); + batch.add(StatusUpdate::in_progress(1, "t".to_string(), 1)); + // Force the batch's apparent deadline well into the past. + batch.opened_at = Instant::now() - Duration::from_secs(60); + assert!(batch.deadline() < Instant::now()); + + let delay = Duration::from_secs(2); + let before = Instant::now(); + batch.defer(delay); + let deadline = batch.deadline(); + // `deadline` should be at least `delay` from `before` (timing + // slop ~50ms is generous for CI). The exact value is `before + + // delay` because the batch is in-progress (IN_PROGRESS_MAX_WAIT + // is subtracted then re-added by deadline()). + assert!(deadline >= before + delay - Duration::from_millis(50)); + } + + #[test] + fn completion_deadline_is_shorter() { + let mut batch = RecordingBatch::new(1); + let baseline = batch.opened_at + IN_PROGRESS_MAX_WAIT; + assert!(batch.deadline() <= baseline); + batch.add(StatusUpdate::completed(1, "t".to_string(), 1)); + assert!(batch.deadline() < baseline); + } +} diff --git a/rust/data_daemon/src/cloud/upload_transfer.rs b/rust/data_daemon/src/cloud/upload_transfer.rs new file mode 100644 index 000000000..7c25f0e78 --- /dev/null +++ b/rust/data_daemon/src/cloud/upload_transfer.rs @@ -0,0 +1,624 @@ +//! Wire-level resumable-upload transfer mechanics. +//! +//! The per-file and per-chunk PUT machinery the upload coordinator +//! ([`super::uploader`]) drives: [`upload_one_file`] streams a single on-disk +//! artefact as 16 MiB chunks to the GCS resumable session URI, handling the +//! 308-continue, 410-session-expired, and 401-auth-refresh transitions, and +//! verifies the server-side CRC32C checksum on completion. + +use std::sync::Arc; +use std::time::{Duration, Instant}; + +use base64::engine::general_purpose::STANDARD as BASE64; +use base64::Engine; +use bytes::Bytes; +use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION, CONTENT_LENGTH, CONTENT_RANGE}; +use reqwest::StatusCode; +use tokio::fs::File; +use tokio::io::{AsyncReadExt, AsyncSeekExt}; +use tokio::time::{sleep, timeout}; + +use crate::api::ApiClient; +use crate::cloud::status::StatusUpdate; +use crate::cloud::OrgIdRx; +use crate::state::{DaemonEvent, EventBus, TraceRecord, TraceWriteHandle}; + +/// Chunk size used for resumable uploads. +/// +/// Must be a multiple of 256 KiB (the GCS resumable-upload requirement for +/// every non-final chunk); 16 MiB = 64 × 256 KiB. Larger chunks raise peak +/// throughput on fast links (fewer sequential PUTs) at the cost of coarser +/// upload-progress granularity, higher per-upload memory, and a higher minimum +/// sustained speed: a chunk must transfer within `CHUNK_UPLOAD_TIMEOUT` (200 s), +/// so the minimum sustained speed is `CHUNK_SIZE / CHUNK_UPLOAD_TIMEOUT` +/// (16 MiB / 200 s ≈ 0.67 Mbit/s). +pub const CHUNK_SIZE: usize = 16 * 1024 * 1024; +/// Persist `bytes_uploaded` to SQLite only every Nth chunk (plus once when the +/// file finishes), instead of every chunk. The per-chunk write took the store's +/// single `write_guard` once per 16 MiB and, at `MAX_CONCURRENT_UPLOADS` +/// in-flight files, serialised all uploads against each other and against the +/// notifiers/progress reporter — eroding the stop-recording SLA. Resume +/// correctness does not depend on this value: the 308-continue path +/// (`parse_resume_offset`) re-derives the committed offset from the server on +/// restart, so a stale DB offset only costs re-sending at most this many chunks. +const PROGRESS_PERSIST_EVERY_CHUNKS: u32 = 4; +/// Cap on the exponential backoff for transient upload failures. +const MAX_BACKOFF: Duration = Duration::from_secs(300); +/// Maximum retries for a single chunk. +const MAX_RETRIES: u32 = 5; +/// Hard deadline for a single chunk PUT. Belt-and-braces over the reqwest +/// client-level timeout, which can silently fail to fire for direct GCS +/// resumable session URI uploads. +const CHUNK_UPLOAD_TIMEOUT: Duration = Duration::from_secs(200); + +/// Outcome of [`upload_one_file`]. Carries a refreshed `session_uri` when the +/// server expired the original one mid-upload so the caller can persist it +/// for restart-resume. +pub(crate) struct UploadFileOutcome { + pub(crate) bytes_uploaded: i64, + pub(crate) final_session_uri: Option, +} + +#[allow(clippy::too_many_arguments)] +pub(crate) async fn upload_one_file( + client: &Arc, + trace_writer: &TraceWriteHandle, + bus: &EventBus, + org_rx: &OrgIdRx, + status_tx: &tokio::sync::mpsc::UnboundedSender, + trace: &TraceRecord, + recording_id: &str, + local_path: &std::path::Path, + cloud_filepath: &str, + content_type: &str, + session_uri: String, +) -> Result { + let metadata = tokio::fs::metadata(local_path) + .await + .map_err(|error| format!("stat {} failed: {error}", local_path.display()))?; + let total_bytes = metadata.len(); + let original_uri = session_uri.clone(); + if total_bytes == 0 { + return Ok(UploadFileOutcome { + bytes_uploaded: 0, + final_session_uri: None, + }); + } + + let mut file = File::open(local_path) + .await + .map_err(|error| format!("open {} failed: {error}", local_path.display()))?; + let mut offset: u64 = 0; + let mut crc: u32 = 0; + let mut server_crc: Option = None; + let mut session_uri = session_uri; + let recording_index = trace.recording_index; + let trace_id = trace.trace_id.clone(); + let Some(org_id) = org_rx.borrow().clone() else { + return Err("no current org_id configured; cannot refresh session URI".to_string()); + }; + + tracing::info!( + trace_id, + path = %local_path.display(), + bytes = total_bytes, + "starting file upload" + ); + let upload_started = Instant::now(); + let mut chunks_since_persist: u32 = 0; + let mut last_persisted_offset: u64 = 0; + // Bound consecutive iterations that make no forward progress (e.g. a peer + // returning repeated zero-advance 308s) so a misbehaving server cannot wedge + // this upload task — and its concurrency permit — in an infinite loop. + let mut stalled_iterations: u32 = 0; + while offset < total_bytes { + let offset_before = offset; + let chunk_end = (offset + CHUNK_SIZE as u64).min(total_bytes) - 1; + let chunk_len = (chunk_end - offset + 1) as usize; + let mut buffer = vec![0u8; chunk_len]; + file.seek(std::io::SeekFrom::Start(offset)) + .await + .map_err(|error| format!("seek failed: {error}"))?; + file.read_exact(&mut buffer) + .await + .map_err(|error| format!("read failed: {error}"))?; + let chunk = Bytes::from(buffer); + let is_final = chunk_end + 1 == total_bytes; + + let outcome = put_chunk( + client, + &session_uri, + chunk.clone(), + offset, + chunk_end, + total_bytes, + is_final, + ) + .await?; + match outcome { + PutChunkOutcome::Accepted { headers, body } => { + crc = crc32c::crc32c_append(crc, &chunk); + if is_final { + server_crc = extract_server_crc32c(&headers, &body); + } + offset += chunk_len as u64; + } + PutChunkOutcome::Incomplete { headers } => { + // 308 — the server reports how much it actually committed via + // the Range header. GCS commits in 256 KiB units, so it can + // accept only a *prefix* of a 16 MiB chunk. We must hash exactly + // the committed prefix: hashing the whole chunk then resuming at + // `server_offset` re-reads — and re-hashes — the uncommitted + // tail on the next iteration, double-counting it into the local + // checksum and failing the final compare with a spurious mismatch. + let server_offset = parse_resume_offset(&headers).unwrap_or(offset); + match resume_decision(offset, chunk_len, server_offset) { + ResumeDecision::Behind => { + return Err(format!( + "server resume offset {server_offset} is behind local offset \ + {offset}; refusing to corrupt {}", + local_path.display() + )); + } + ResumeDecision::Ahead { new_offset } => { + // Server has bytes we didn't send this session (e.g. a + // prior session) and can't re-hash — accept its view but + // flag the local checksum untrustworthy. + tracing::warn!( + server_offset, + local_offset = offset + chunk_len as u64, + path = %local_path.display(), + "server resume offset is ahead of local; skipping local checksum" + ); + crc = crc32c::crc32c_append(crc, &chunk); + server_crc = None; + offset = new_offset; + } + ResumeDecision::Committed { + hash_len, + new_offset, + } => { + // Fold in only the committed prefix; the tail is re-sent + // (and hashed) on the next read, so every byte is hashed + // exactly once. + crc = crc32c::crc32c_append(crc, &chunk[..hash_len]); + offset = new_offset; + } + } + } + PutChunkOutcome::SessionExpired => { + tracing::info!( + trace_id, + path = %local_path.display(), + "upload session expired; requesting fresh URI" + ); + match client + .fetch_resumable_upload_url(&org_id, recording_id, cloud_filepath, content_type) + .await + { + Ok(new_uri) => { + session_uri = new_uri; + // A new session means the server has zero bytes for + // this file; restart from offset 0 and rehash. + offset = 0; + crc = 0; + server_crc = None; + continue; + } + Err(error) => { + return Err(format!("failed to fetch fresh session URI: {error}")); + } + } + } + PutChunkOutcome::Failed { status, body } => { + return Err(format!( + "upload failed with HTTP {status} for {}: {body}", + local_path.display() + )); + } + } + + if offset > offset_before { + stalled_iterations = 0; + } else { + stalled_iterations += 1; + if stalled_iterations >= MAX_RETRIES { + return Err(format!( + "upload of {} stalled at offset {offset}: server reported no \ + progress after {MAX_RETRIES} consecutive attempts", + local_path.display() + )); + } + } + + bus.publish(DaemonEvent::UploadProgress { + trace_id: trace_id.clone(), + recording_index, + bytes_uploaded: offset as i64, + total_bytes: Some(total_bytes as i64), + }); + let _ = status_tx.send(StatusUpdate::in_progress( + recording_index, + trace_id.clone(), + offset as i64, + )); + // Persist the rolling progress on a coarse cadence (not every chunk): + // the in-memory bus/status updates above are debounced downstream, and + // only the SQLite write contends on the shared write_guard. Resume + // correctness comes from the server's 308 offset, not this row. + chunks_since_persist += 1; + if chunks_since_persist >= PROGRESS_PERSIST_EVERY_CHUNKS { + persist_upload_offset(trace_writer, &trace_id, offset); + chunks_since_persist = 0; + last_persisted_offset = offset; + } + } + + // Persist the final offset once so the DB row reflects the completed bytes + // even if the last persisted checkpoint was several chunks back. + if offset != last_persisted_offset { + persist_upload_offset(trace_writer, &trace_id, offset); + } + + tracing::info!( + trace_id, + path = %local_path.display(), + bytes = total_bytes, + elapsed_ms = upload_started.elapsed().as_millis(), + "file upload complete" + ); + if let Some(expected) = server_crc { + if expected != crc { + return Err(format!( + "crc32c mismatch for {}: local={crc:#010x} server={expected:#010x}", + local_path.display() + )); + } + } + let final_session_uri = (session_uri != original_uri).then_some(session_uri); + Ok(UploadFileOutcome { + bytes_uploaded: total_bytes as i64, + final_session_uri, + }) +} + +/// Persist the rolling `bytes_uploaded` checkpoint for a trace via the +/// coalescing write-behind — fire-and-forget, so a burst of concurrent uploads +/// collapses to one batched row write per flush tick instead of a synchronous +/// transaction each. A missed checkpoint only costs re-sending a few chunks on +/// restart, never correctness (resume uses the server's 308 offset). +fn persist_upload_offset(trace_writer: &TraceWriteHandle, trace_id: &str, offset: u64) { + trace_writer.upload_progress(trace_id, offset as i64); +} + +/// Outcome of a single PUT to the resumable session URI. Returned by +/// [`put_chunk`] so [`upload_one_file`] can dispatch on it without re-parsing +/// status codes. +enum PutChunkOutcome { + /// 2xx — chunk accepted. Headers/body carry the final response on the + /// last chunk (the server-side CRC32C lives here). + Accepted { headers: HeaderMap, body: String }, + /// 308 — chunk accepted but the server wants more bytes. The Range + /// header tells us where it is. + Incomplete { headers: HeaderMap }, + /// 410/404 — the resumable session is gone. The caller must call + /// `/resumable_upload_url` to obtain a fresh one. + SessionExpired, + /// Any other non-retryable status; the caller surfaces it as a hard + /// error and lets the upload coordinator roll the trace to `retrying`. + Failed { status: StatusCode, body: String }, +} + +async fn put_chunk( + client: &Arc, + session_uri: &str, + chunk: Bytes, + chunk_start: u64, + chunk_end: u64, + total_bytes: u64, + is_final: bool, +) -> Result { + let mut headers = HeaderMap::new(); + let content_range = if is_final { + format!("bytes {chunk_start}-{chunk_end}/{total_bytes}") + } else { + format!("bytes {chunk_start}-{chunk_end}/*") + }; + headers.insert( + CONTENT_RANGE, + HeaderValue::from_str(&content_range).unwrap(), + ); + headers.insert(CONTENT_LENGTH, HeaderValue::from(chunk.len() as u64)); + + let mut attempt: u32 = 0; + let mut refreshed_auth = false; + loop { + let bearer = match client.auth().bearer_token().await { + Ok(token) => token, + Err(error) => { + tracing::warn!(%error, "uploader could not read auth token"); + return Err(format!("auth load failed: {error}")); + } + }; + let mut request_headers = headers.clone(); + request_headers.insert( + AUTHORIZATION, + HeaderValue::from_str(&format!("Bearer {bearer}")) + .map_err(|error| format!("auth header invalid: {error}"))?, + ); + + // `Bytes` is cheaply cloneable (Arc-backed), so re-sending the + // same chunk on retry is a refcount bump, not a 16 MiB copy. + let request = client + .raw_client() + .put(session_uri) + .headers(request_headers) + .body(chunk.clone()) + .build() + .map_err(|error| format!("failed to build request: {error}"))?; + tracing::debug!( + attempt, + bytes = chunk.len(), + chunk_start, + chunk_end, + "sending upload chunk" + ); + let chunk_started = Instant::now(); + let response = + match timeout(CHUNK_UPLOAD_TIMEOUT, client.raw_client().execute(request)).await { + Ok(Ok(response)) => response, + Ok(Err(error)) => { + if attempt + 1 >= MAX_RETRIES { + return Err(format!("transport error: {error}")); + } + attempt += 1; + tracing::warn!(%error, attempt, "upload chunk transport error; retrying"); + sleep(backoff(attempt)).await; + continue; + } + Err(_elapsed) => { + tracing::warn!( + attempt, + timeout_secs = CHUNK_UPLOAD_TIMEOUT.as_secs(), + bytes = chunk.len(), + "upload chunk PUT timed out; retrying" + ); + if attempt + 1 >= MAX_RETRIES { + return Err(format!( + "chunk PUT timed out after {}s ({MAX_RETRIES} attempts exhausted)", + CHUNK_UPLOAD_TIMEOUT.as_secs() + )); + } + attempt += 1; + sleep(backoff(attempt)).await; + continue; + } + }; + tracing::debug!( + elapsed_ms = chunk_started.elapsed().as_millis(), + bytes = chunk.len(), + status = response.status().as_u16(), + "upload chunk response received" + ); + + let status = response.status(); + let response_headers = response.headers().clone(); + let body = response.text().await.unwrap_or_default(); + + if status == StatusCode::UNAUTHORIZED && !refreshed_auth { + if let Err(error) = client.auth().reload().await { + return Err(format!("auth reload failed: {error}")); + } + refreshed_auth = true; + continue; + } + if status.is_success() { + return Ok(PutChunkOutcome::Accepted { + headers: response_headers, + body, + }); + } + if status.as_u16() == 308 { + return Ok(PutChunkOutcome::Incomplete { + headers: response_headers, + }); + } + if matches!(status.as_u16(), 410 | 404) { + return Ok(PutChunkOutcome::SessionExpired); + } + if matches!(status.as_u16(), 429 | 500 | 502 | 503 | 504) && attempt + 1 < MAX_RETRIES { + attempt += 1; + tracing::warn!(%status, attempt, "retrying upload chunk after transient failure"); + sleep(backoff(attempt)).await; + continue; + } + return Ok(PutChunkOutcome::Failed { status, body }); + } +} + +fn backoff(attempt: u32) -> Duration { + let secs = 2u64.saturating_pow(attempt.saturating_sub(1)); + Duration::from_secs(secs.min(MAX_BACKOFF.as_secs())) +} + +fn parse_resume_offset(headers: &HeaderMap) -> Option { + let value = headers.get("range")?.to_str().ok()?; + let last = value.split('-').nth(1)?; + let last_byte: u64 = last.parse().ok()?; + Some(last_byte + 1) +} + +/// How a 308's committed `server_offset` reconciles against the just-sent chunk +/// `[offset, offset + chunk_len)`. +#[derive(Debug, PartialEq, Eq)] +enum ResumeDecision { + /// Server is behind our local offset — would corrupt the object; abort. + Behind, + /// Server is ahead of anything we sent this session (bytes we can't + /// re-hash); accept its offset but treat the local checksum as unusable. + Ahead { new_offset: u64 }, + /// Server committed `hash_len` bytes of this chunk; fold exactly that prefix + /// into the running checksum and resume from `new_offset`. + Committed { hash_len: usize, new_offset: u64 }, +} + +/// Decide how many bytes of the just-sent chunk the running checksum should absorb +/// after a 308, given the server's committed `server_offset`. Hashing only the +/// committed prefix is what keeps every byte hashed exactly once across a +/// partial (sub-chunk) commit and the resend of its tail. +fn resume_decision(offset: u64, chunk_len: usize, server_offset: u64) -> ResumeDecision { + if server_offset < offset { + ResumeDecision::Behind + } else if server_offset > offset + chunk_len as u64 { + ResumeDecision::Ahead { + new_offset: server_offset, + } + } else { + ResumeDecision::Committed { + hash_len: (server_offset - offset) as usize, + new_offset: server_offset, + } + } +} + +/// Extract the server's CRC32C for the completed object as a `u32`. +/// +/// GCS reports CRC32C as base64 of the 4-byte big-endian checksum, via the +/// `x-goog-hash` header (`crc32c=…,md5=…`, components in arbitrary order) on a +/// resumable PUT and via the `crc32c` field of the JSON object resource. Unlike +/// `md5Hash`, CRC32C is present on every object — including composite objects — +/// so the completion check can never be silently skipped. +fn extract_server_crc32c(headers: &HeaderMap, body: &str) -> Option { + let decode = |b64: &str| -> Option { + let bytes = BASE64.decode(b64).ok()?; + Some(u32::from_be_bytes( + <[u8; 4]>::try_from(bytes.as_slice()).ok()?, + )) + }; + if let Some(text) = headers + .get("x-goog-hash") + .and_then(|value| value.to_str().ok()) + { + for part in text.split(',') { + if let Some(b64) = part.trim().strip_prefix("crc32c=") { + return decode(b64); + } + } + } + if let Ok(json) = serde_json::from_str::(body) { + if let Some(b64) = json.get("crc32c").and_then(|value| value.as_str()) { + return decode(b64); + } + } + None +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn parse_resume_offset_uses_last_byte_plus_one() { + // 308 carries a `Range: bytes=0-` header; the offset is + // ` + 1`. The 308 commit path keys off this so anyone + // refactoring it later sees an explicit covering test. + let mut headers = HeaderMap::new(); + headers.insert("range", HeaderValue::from_static("bytes=0-99")); + assert_eq!(parse_resume_offset(&headers), Some(100)); + let mut empty = HeaderMap::new(); + empty.insert("range", HeaderValue::from_static("bytes=*")); + assert_eq!(parse_resume_offset(&empty), None); + } + + #[test] + fn resume_full_chunk_commit_hashes_whole_chunk() { + // Server committed the entire chunk → hash all of it, advance fully. + assert_eq!( + resume_decision(0, CHUNK_SIZE, CHUNK_SIZE as u64), + ResumeDecision::Committed { + hash_len: CHUNK_SIZE, + new_offset: CHUNK_SIZE as u64, + } + ); + } + + #[test] + fn resume_partial_commit_hashes_only_committed_prefix() { + // M7 regression: GCS commits in 256 KiB units, so a 16 MiB chunk can be + // committed only up to, say, 16 MiB − 256 KiB. We must hash exactly that + // committed prefix — NOT the whole chunk — or the re-sent tail is hashed + // twice and the final checksum spuriously mismatches. + let committed = (CHUNK_SIZE - 256 * 1024) as u64; + assert_eq!( + resume_decision(0, CHUNK_SIZE, committed), + ResumeDecision::Committed { + hash_len: committed as usize, + new_offset: committed, + } + ); + } + + #[test] + fn resume_zero_advance_hashes_nothing() { + // Server has nothing yet (missing/zero Range) → hash nothing, retry the + // same offset; otherwise the whole chunk would be double-hashed. + assert_eq!( + resume_decision(100, CHUNK_SIZE, 100), + ResumeDecision::Committed { + hash_len: 0, + new_offset: 100, + } + ); + } + + #[test] + fn resume_ahead_marks_checksum_untrustworthy() { + assert_eq!( + resume_decision(0, CHUNK_SIZE, CHUNK_SIZE as u64 + 1), + ResumeDecision::Ahead { + new_offset: CHUNK_SIZE as u64 + 1, + } + ); + } + + #[test] + fn resume_behind_is_a_corruption_guard() { + assert_eq!(resume_decision(100, CHUNK_SIZE, 50), ResumeDecision::Behind); + } + + #[test] + fn crc32c_matches_known_vector() { + // Castagnoli CRC32C of the standard check string; guards against a + // future swap to the wrong (zlib/ISO) polynomial, which would compile + // fine but never match a GCS-reported checksum. + assert_eq!(crc32c::crc32c(b"123456789"), 0xE306_9283); + } + + #[test] + fn extract_server_crc32c_reads_x_goog_hash_in_any_order() { + let expected = crc32c::crc32c(b"hello world"); + let b64 = BASE64.encode(expected.to_be_bytes()); + let mut headers = HeaderMap::new(); + // md5 first, crc32c second — component order is arbitrary, md5 ignored. + let value = format!("md5=ignored, crc32c={b64}"); + headers.insert("x-goog-hash", HeaderValue::from_str(&value).unwrap()); + assert_eq!(extract_server_crc32c(&headers, ""), Some(expected)); + } + + #[test] + fn extract_server_crc32c_falls_back_to_json_body() { + let expected = crc32c::crc32c(b"resumable-payload"); + let b64 = BASE64.encode(expected.to_be_bytes()); + let body = format!(r#"{{"crc32c":"{b64}","md5Hash":"ignored"}}"#); + assert_eq!( + extract_server_crc32c(&HeaderMap::new(), &body), + Some(expected) + ); + } + + #[test] + fn extract_server_crc32c_absent_is_none() { + // No crc32c anywhere → None → completion check is skipped, not failed. + assert_eq!(extract_server_crc32c(&HeaderMap::new(), ""), None); + } +} diff --git a/rust/data_daemon/src/cloud/uploader.rs b/rust/data_daemon/src/cloud/uploader.rs new file mode 100644 index 000000000..00fb7295c --- /dev/null +++ b/rust/data_daemon/src/cloud/uploader.rs @@ -0,0 +1,939 @@ +//! Resumable file uploader coordinator. +//! +//! Subscribes to [`DaemonEvent::ReadyForUpload`] (and re-scans the +//! state store on startup for any traces already in the registered/queued +//! state). For each on-disk artefact the coordinator PUTs `CHUNK_SIZE` (16 MiB) +//! chunks to the GCS resumable session URI persisted by the registration coordinator, +//! handling 308-continue, 410-session-expired, and 401-auth-refresh +//! transitions. On completion the trace is marked `Uploaded` and the upload +//! sub-system publishes `UploadComplete` for the status updater. + +use std::collections::{HashMap, HashSet}; +use std::path::PathBuf; +use std::sync::Arc; +use std::time::Duration; + +use tokio::sync::{broadcast, Semaphore}; +use tokio::task::{JoinHandle, JoinSet}; +use tokio::time::{interval, MissedTickBehavior}; + +use crate::api::ApiClient; +use crate::cloud::cloud_files::content_type_for_filename; +use crate::cloud::status::StatusUpdate; +use crate::cloud::upload_transfer::{upload_one_file, UploadFileOutcome}; +use crate::cloud::OrgIdRx; +use crate::lifecycle::signals::ShutdownSignal; +use crate::state::store::TraceUpdate; +use crate::state::{ + ConnectionState, DaemonEvent, EventBus, SqliteStateStore, StateStore, TraceRecord, + TraceUploadStatus, TraceWriteHandle, +}; +use crate::storage::paths::TracePath; + +/// Maximum number of traces uploading concurrently. With 8 parallel contexts +/// each queuing ~128 traces simultaneously (1024 total), 32 slots serialise +/// into ~32 rounds × 300 ms ≈ 9.6 s. 128 slots cuts that to ~8 rounds × +/// 300 ms ≈ 2.4 s, giving ~6 s headroom against the 9 s stop-recording SLA. +pub const MAX_CONCURRENT_UPLOADS: usize = 128; + +/// Handle returned by [`spawn_uploader`]. +pub struct UploaderHandle { + join: JoinHandle<()>, +} + +impl UploaderHandle { + /// Wait for the uploader task to exit. + pub async fn join(self) { + if let Err(error) = self.join.await { + tracing::warn!(?error, "uploader join failed"); + } + } +} + +/// Spawn the uploader task on the current Tokio runtime. +#[allow(clippy::too_many_arguments)] +pub fn spawn_uploader( + store: SqliteStateStore, + trace_writer: TraceWriteHandle, + bus: EventBus, + client: Arc, + recordings_root: Arc, + org_rx: OrgIdRx, + status_tx: tokio::sync::mpsc::UnboundedSender, + mut shutdown_rx: broadcast::Receiver, +) -> UploaderHandle { + let mut subscriber = bus.subscribe(); + let store = Arc::new(store); + let join = tokio::spawn(async move { + let semaphore = Arc::new(Semaphore::new(MAX_CONCURRENT_UPLOADS)); + let mut in_flight: JoinSet = JoinSet::new(); + // Tracks dispatched trace IDs to prevent a drain triggered by + // join_next from re-queuing a trace whose task hasn't yet run the + // DB update to mark itself Uploading. + let mut in_flight_ids: HashSet = HashSet::new(); + // Safety-net rescan: catch any traces that were skipped when the + // semaphore was full during a drain, without relying on bus events. + let mut rescan_tick = interval(Duration::from_secs(5)); + rescan_tick.set_missed_tick_behavior(MissedTickBehavior::Skip); + let mut connected = false; + loop { + tokio::select! { + biased; + signal = shutdown_rx.recv() => { + tracing::debug!(?signal, "uploader shutting down"); + break; + } + // Reap a completed task immediately and chain the next drain + // so a finishing upload starts the next one without waiting + // for a bus event or the rescan tick. + Some(join_result) = in_flight.join_next(), if !in_flight.is_empty() => { + match join_result { + Ok(completed_trace_id) => { in_flight_ids.remove(&completed_trace_id); } + Err(panic_err) => { tracing::warn!(?panic_err, "upload task panicked"); } + } + if connected { + drain_ready_traces( + &store, + &trace_writer, + &bus, + &client, + &recordings_root, + &org_rx, + &status_tx, + &semaphore, + &mut in_flight, + &mut in_flight_ids, + ) + .await; + } + } + event = subscriber.recv() => { + match event { + Ok(DaemonEvent::ConnectionStateChanged(state)) => { + connected = matches!(state, ConnectionState::Up); + if connected { + drain_ready_traces( + &store, + &trace_writer, + &bus, + &client, + &recordings_root, + &org_rx, + &status_tx, + &semaphore, + &mut in_flight, + &mut in_flight_ids, + ) + .await; + } + } + Ok(DaemonEvent::ReadyForUpload { trace_id, .. }) => { + if !connected { + tracing::debug!(trace_id, "deferring upload until connection up"); + continue; + } + spawn_upload_task( + &store, + &trace_writer, + &bus, + &client, + &recordings_root, + &org_rx, + &status_tx, + &semaphore, + &mut in_flight, + &mut in_flight_ids, + trace_id, + ); + } + Ok(_) => {} + Err(broadcast::error::RecvError::Lagged(skipped)) => { + tracing::warn!(skipped, "uploader missed bus events; rescanning"); + if connected { + drain_ready_traces( + &store, + &trace_writer, + &bus, + &client, + &recordings_root, + &org_rx, + &status_tx, + &semaphore, + &mut in_flight, + &mut in_flight_ids, + ) + .await; + } + } + Err(broadcast::error::RecvError::Closed) => break, + } + } + _ = rescan_tick.tick() => { + if connected { + drain_ready_traces( + &store, + &trace_writer, + &bus, + &client, + &recordings_root, + &org_rx, + &status_tx, + &semaphore, + &mut in_flight, + &mut in_flight_ids, + ) + .await; + } + } + } + } + in_flight.shutdown().await; + }); + UploaderHandle { join } +} + +#[allow(clippy::too_many_arguments)] +async fn drain_ready_traces( + store: &Arc, + trace_writer: &TraceWriteHandle, + bus: &EventBus, + client: &Arc, + recordings_root: &Arc, + org_rx: &OrgIdRx, + status_tx: &tokio::sync::mpsc::UnboundedSender, + semaphore: &Arc, + in_flight: &mut JoinSet, + in_flight_ids: &mut HashSet, +) { + // Server-side filter for `queued`/`retrying` traces (uses + // `idx_traces_upload_status`) instead of walking every recording's full + // trace set on each completed upload — the old N+1 scan was quadratic under + // the burst this loop runs after every `join_next`. + let trace_ids = match store.traces_ready_for_upload().await { + Ok(ids) => ids, + Err(error) => { + tracing::warn!(%error, "uploader could not query traces ready for upload"); + return; + } + }; + for trace_id in trace_ids { + spawn_upload_task( + store, + trace_writer, + bus, + client, + recordings_root, + org_rx, + status_tx, + semaphore, + in_flight, + in_flight_ids, + trace_id, + ); + } +} + +#[allow(clippy::too_many_arguments)] +fn spawn_upload_task( + store: &Arc, + trace_writer: &TraceWriteHandle, + bus: &EventBus, + client: &Arc, + recordings_root: &Arc, + org_rx: &OrgIdRx, + status_tx: &tokio::sync::mpsc::UnboundedSender, + semaphore: &Arc, + in_flight: &mut JoinSet, + in_flight_ids: &mut HashSet, + trace_id: String, +) { + if in_flight_ids.contains(&trace_id) { + tracing::debug!(trace_id, "trace already dispatched; skipping duplicate"); + return; + } + let Ok(permit) = Arc::clone(semaphore).try_acquire_owned() else { + tracing::debug!(trace_id, "upload semaphore full; will retry on next drain"); + return; + }; + in_flight_ids.insert(trace_id.clone()); + let store = Arc::clone(store); + let trace_writer = trace_writer.clone(); + let bus = bus.clone(); + let client = Arc::clone(client); + let recordings_root = Arc::clone(recordings_root); + let org_rx = org_rx.clone(); + let status_tx = status_tx.clone(); + in_flight.spawn(async move { + upload_single( + &store, + &trace_writer, + &bus, + &client, + &recordings_root, + &org_rx, + &status_tx, + &trace_id, + ) + .await; + drop(permit); + trace_id + }); +} + +#[allow(clippy::too_many_arguments)] +async fn upload_single( + store: &Arc, + trace_writer: &TraceWriteHandle, + bus: &EventBus, + client: &Arc, + recordings_root: &Arc, + org_rx: &OrgIdRx, + status_tx: &tokio::sync::mpsc::UnboundedSender, + trace_id: &str, +) { + let trace = match store.get_trace(trace_id).await { + Ok(Some(trace)) => trace, + Ok(None) => { + tracing::warn!(trace_id, "uploader could not find trace row"); + return; + } + Err(error) => { + tracing::warn!(%error, trace_id, "uploader failed to load trace row"); + return; + } + }; + let session_uris = match parse_session_uris(&trace) { + Some(uris) => uris, + None => return, + }; + if session_uris.is_empty() { + // Nothing to upload — mark uploaded immediately so downstream + // accounting matches a registered-but-empty trace. + finalise_upload(store, bus, status_tx, &trace, 0).await; + return; + } + + // Resolve the cloud `recording_id` (needed for the resumable-upload-url + // refresh) before we touch the trace's upload state. A None here means + // registration hasn't minted the cloud id yet — leave the trace in its + // queued/retrying state and skip; a later drain re-enters once it lands. + let Some(recording_id) = recording_cloud_id(store, trace.recording_index).await else { + tracing::warn!( + trace_id, + recording_index = trace.recording_index, + "recording has no cloud recording_id yet; deferring upload" + ); + return; + }; + + // Mark the trace as uploading so the next bus tick doesn't repeat the + // attempt (the registration path is one-shot, but the periodic rescan + // could re-enter on a long-running upload). + let _ = store + .update_trace( + trace_id, + TraceUpdate { + upload_status: Some(TraceUploadStatus::Uploading), + ..TraceUpdate::default() + }, + ) + .await; + + tracing::info!(trace_id, data_type = ?trace.data_type, "starting trace upload"); + let Some(data_type) = trace.data_type.as_deref() else { + // No data_type means we never saw a `StartTrace` for this row, so we + // can't locate the on-disk artefact. Surface the failure both to the + // status updater and on the event bus so the recording's progress + // reporter (which gates on every trace having settled) doesn't wait + // for an upload that can never happen. + tracing::warn!(trace_id, "trace row missing data_type; marking failed"); + mark_failed_and_emit(store, bus, status_tx, &trace, "trace missing data_type").await; + return; + }; + // On-disk artefacts are keyed by the local `recording_index`, matching the + // directory the dispatcher / trace actors wrote to. + let trace_dir = TracePath::new( + trace.recording_index.to_string(), + data_type, + trace_id.to_string(), + ) + .directory(recordings_root.as_path()); + + // Upload each on-disk artefact under its session URI and persist the + // refreshed URI back into the same slot (by index) for resume on retry. + let mut total_uploaded: i64 = 0; + let mut session_uris = session_uris; + for index in 0..session_uris.len() { + let (filename, session_uri) = session_uris[index].clone(); + let local_path = trace_dir.join(file_basename(&filename)); + if !local_path.exists() { + tracing::warn!( + trace_id, + path = %local_path.display(), + "expected upload artefact missing; marking trace failed" + ); + mark_failed_and_emit( + store, + bus, + status_tx, + &trace, + &format!("missing artefact {filename}"), + ) + .await; + return; + } + + // `content_type` here drives the GCS-side metadata when we re-acquire a + // session URI on 410. Use the same filename→type mapping registration + // used (`cloud_files::content_type_for_filename`) so the refresh can't + // disagree with what was originally registered. + let content_type = content_type_for_filename(&filename); + let outcome = upload_one_file( + client, + trace_writer, + bus, + org_rx, + status_tx, + &trace, + &recording_id, + &local_path, + &filename, + content_type, + session_uri, + ) + .await; + match outcome { + Ok(UploadFileOutcome { + bytes_uploaded, + final_session_uri, + }) => { + total_uploaded = total_uploaded.saturating_add(bytes_uploaded); + // Persist the (possibly refreshed) URI so a subsequent + // restart resumes from the right session, even if the + // refresh path fired mid-stream. + if let Some(new_uri) = final_session_uri { + session_uris[index].1 = new_uri; + persist_session_uris(store, trace_id, &session_uris).await; + } + } + Err(error) => { + tracing::warn!(%error, trace_id, "upload failed; rolling back to retrying"); + let update = TraceUpdate { + upload_status: Some(TraceUploadStatus::Retrying), + error_message: Some(Some(error)), + ..TraceUpdate::default() + }; + if let Err(error) = store.update_trace(trace_id, update).await { + tracing::warn!(%error, trace_id, "failed to mark trace as retrying"); + } + return; + } + } + } + + finalise_upload(store, bus, status_tx, &trace, total_uploaded).await; +} + +async fn finalise_upload( + store: &Arc, + bus: &EventBus, + status_tx: &tokio::sync::mpsc::UnboundedSender, + trace: &TraceRecord, + total_uploaded: i64, +) { + let update = TraceUpdate { + upload_status: Some(TraceUploadStatus::Uploaded), + bytes_uploaded: Some(total_uploaded), + total_bytes: Some(total_uploaded.max(trace.total_bytes)), + ..TraceUpdate::default() + }; + if let Err(error) = store.update_trace(&trace.trace_id, update).await { + tracing::warn!(%error, trace_id = trace.trace_id, "failed to mark trace uploaded"); + } + bus.publish(DaemonEvent::UploadComplete { + trace_id: trace.trace_id.clone(), + recording_index: trace.recording_index, + }); + let _ = status_tx.send(StatusUpdate::completed( + trace.recording_index, + trace.trace_id.clone(), + total_uploaded.max(trace.total_bytes), + )); +} + +async fn mark_failed_and_emit( + store: &Arc, + bus: &EventBus, + status_tx: &tokio::sync::mpsc::UnboundedSender, + trace: &TraceRecord, + message: &str, +) { + let update = TraceUpdate { + upload_status: Some(TraceUploadStatus::Failed), + error_message: Some(Some(message.to_string())), + ..TraceUpdate::default() + }; + if let Err(error) = store.update_trace(&trace.trace_id, update).await { + tracing::warn!(%error, trace_id = trace.trace_id, "failed to mark trace as failed"); + } + // Publishing on the upload-complete topic lets the progress reporter and + // status updater treat the trace as terminal — without this signal a + // single bad trace would block the recording's progress report forever. + bus.publish(DaemonEvent::UploadComplete { + trace_id: trace.trace_id.clone(), + recording_index: trace.recording_index, + }); + let _ = status_tx.send(StatusUpdate::completed( + trace.recording_index, + trace.trace_id.clone(), + trace.total_bytes.max(0), + )); +} + +async fn persist_session_uris( + store: &Arc, + trace_id: &str, + uris: &[(String, String)], +) { + let map: HashMap<&str, &str> = uris + .iter() + .map(|(filename, uri)| (filename.as_str(), uri.as_str())) + .collect(); + let serialised = match serde_json::to_string(&map) { + Ok(serialised) => serialised, + Err(error) => { + tracing::warn!(%error, trace_id, "failed to serialise refreshed session URIs"); + return; + } + }; + let update = TraceUpdate { + upload_session_uris: Some(serialised), + ..TraceUpdate::default() + }; + if let Err(error) = store.update_trace(trace_id, update).await { + tracing::warn!(%error, trace_id, "failed to persist refreshed session URIs"); + } +} + +fn parse_session_uris(trace: &TraceRecord) -> Option> { + let Some(serialised) = &trace.upload_session_uris else { + tracing::warn!( + trace_id = trace.trace_id, + "trace ready-for-upload but no session URIs stored" + ); + return None; + }; + match serde_json::from_str::>(serialised) { + Ok(map) => Some(map.into_iter().collect()), + Err(error) => { + tracing::warn!(%error, trace_id = trace.trace_id, "failed to decode stored session URIs"); + None + } + } +} + +fn file_basename(path: &str) -> &str { + match path.rsplit_once('/') { + Some((_, tail)) => tail, + None => path, + } +} + +/// Resolve the cloud `recording_id` (the backend handle every cloud URL needs) +/// from its local `recording_index`. `None` when registration hasn't minted +/// the cloud id yet, or the row is missing. +async fn recording_cloud_id(store: &Arc, recording_index: i64) -> Option { + match store.get_recording(recording_index).await { + Ok(Some(row)) => row.recording_id, + Ok(None) => None, + Err(error) => { + tracing::warn!(%error, recording_index, "uploader could not read recording cloud id"); + None + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::api::auth::StaticAuthProvider; + use crate::api::client::ApiClientOptions; + use crate::state::store::{NewRecording, TraceUpdate}; + use crate::state::{TraceUploadStatus, TraceWriteStatus}; + use crate::storage::paths::TRACE_JSON_FILENAME; + use base64::engine::general_purpose::STANDARD as BASE64; + use base64::Engine; + use std::collections::HashMap; + use std::time::Duration; + use tempfile::TempDir; + use tokio::sync::mpsc; + use wiremock::matchers::{method, path}; + use wiremock::{Mock, MockServer, Request, ResponseTemplate}; + + /// A live-org receiver fixed at `org`. The sender is leaked so the channel + /// stays open for the test's duration. + fn org_rx(org: Option<&str>) -> OrgIdRx { + let (org_tx, org_rx) = tokio::sync::watch::channel(org.map(str::to_string)); + Box::leak(Box::new(org_tx)); + org_rx + } + + async fn open_store() -> (SqliteStateStore, TempDir) { + let dir = TempDir::new().unwrap(); + let store = SqliteStateStore::open(&dir.path().join("state.db")) + .await + .unwrap(); + (store, dir) + } + + fn client(server: &MockServer) -> Arc { + let auth = Arc::new(StaticAuthProvider::new("t")); + let mut options = ApiClientOptions::new(server.uri()); + options.max_backoff = Duration::from_millis(10); + Arc::new(ApiClient::new(options, auth).unwrap()) + } + + #[allow(clippy::too_many_arguments)] + async fn seed_ready_trace( + store: &SqliteStateStore, + recordings_root: &std::path::Path, + cloud_recording_id: &str, + trace_id: &str, + data_type: &str, + data_type_name: &str, + session_uri: &str, + contents: &[u8], + ) -> (i64, std::path::PathBuf) { + let recording = store + .create_recording(NewRecording::default()) + .await + .unwrap(); + let recording_index = recording.recording_index; + // Stamp the cloud `recording_id` so the uploader's cloud-id resolution + // and the resumable-upload-url refresh see the same id the wiremock + // expectations assert on. + store + .mark_recording_start_notified(recording_index, cloud_recording_id) + .await + .unwrap(); + store + .create_trace( + recording_index, + trace_id, + Some(data_type), + Some(data_type_name), + ) + .await + .unwrap(); + // On-disk artefacts are keyed by the local `recording_index`. + let dir = TracePath::new(recording_index.to_string(), data_type, trace_id.to_string()) + .directory(recordings_root); + std::fs::create_dir_all(&dir).unwrap(); + let local = dir.join(TRACE_JSON_FILENAME); + std::fs::write(&local, contents).unwrap(); + let mut uris = HashMap::new(); + uris.insert( + format!("{data_type}/{data_type_name}/{TRACE_JSON_FILENAME}"), + session_uri.to_string(), + ); + store + .update_trace( + trace_id, + TraceUpdate { + write_status: Some(TraceWriteStatus::Written), + upload_status: Some(TraceUploadStatus::Queued), + upload_session_uris: Some(serde_json::to_string(&uris).unwrap()), + total_bytes: Some(contents.len() as i64), + ..TraceUpdate::default() + }, + ) + .await + .unwrap(); + (recording_index, local) + } + + #[tokio::test] + async fn bad_server_checksum_marks_trace_retrying() { + // Server returns a deliberately wrong CRC32C (0) — the real payload's + // checksum is non-zero, so the guard must reject the upload and roll + // the trace back to `retrying` + // (the registration coordinator's recovery sweep takes it from + // there). The doc-claimed happy path is covered by + // `uploader_marks_uploaded_when_checksum_matches` below. + let server = MockServer::start().await; + Mock::given(method("PUT")) + .and(path("/upload/abc")) + .respond_with(|_req: &Request| { + ResponseTemplate::new(200).insert_header("X-Goog-Hash", "crc32c=AAAAAA==") + }) + .expect(1) + .mount(&server) + .await; + + let (store, tempdir) = open_store().await; + let recordings_root = tempdir.path().join("recordings"); + let payload = b"some-bytes"; + let (_recording_index, _) = seed_ready_trace( + &store, + &recordings_root, + "rec-1", + "trace-1", + "JOINT_POSITIONS", + "arm", + &format!("{}/upload/abc", server.uri()), + payload, + ) + .await; + + let api = client(&server); + let bus = EventBus::new(); + let (status_tx, mut status_rx) = mpsc::unbounded_channel::(); + + let store_arc = Arc::new(store.clone()); + let (trace_writer, _trace_writer_owner) = + crate::state::trace_writer::spawn(store_arc.clone()); + let recordings_root = Arc::new(recordings_root); + upload_single( + &store_arc, + &trace_writer, + &bus, + &api, + &recordings_root, + &org_rx(Some("org-1")), + &status_tx, + "trace-1", + ) + .await; + + let trace = store.get_trace("trace-1").await.unwrap().unwrap(); + assert_eq!(trace.upload_status, TraceUploadStatus::Retrying); + // Status updates are sent regardless. + let _ = status_rx.try_recv(); + } + + #[tokio::test] + async fn uploader_marks_uploaded_when_checksum_matches() { + let server = MockServer::start().await; + // Use the CRC32C of the payload below. + let payload = b"hello world"; + let b64 = BASE64.encode(crc32c::crc32c(payload).to_be_bytes()); + let header_value = format!("crc32c={b64}"); + let header_value_clone = header_value.clone(); + Mock::given(method("PUT")) + .and(path("/upload/abc")) + .respond_with(move |_req: &Request| { + ResponseTemplate::new(200).insert_header("X-Goog-Hash", header_value_clone.as_str()) + }) + .expect(1) + .mount(&server) + .await; + + let (store, tempdir) = open_store().await; + let recordings_root = tempdir.path().join("recordings"); + let (_recording_index, _) = seed_ready_trace( + &store, + &recordings_root, + "rec-1", + "trace-1", + "JOINT_POSITIONS", + "arm", + &format!("{}/upload/abc", server.uri()), + payload, + ) + .await; + + let api = client(&server); + let bus = EventBus::new(); + let (status_tx, mut status_rx) = mpsc::unbounded_channel::(); + + let store_arc = Arc::new(store.clone()); + let (trace_writer, _trace_writer_owner) = + crate::state::trace_writer::spawn(store_arc.clone()); + let recordings_root = Arc::new(recordings_root); + upload_single( + &store_arc, + &trace_writer, + &bus, + &api, + &recordings_root, + &org_rx(Some("org-1")), + &status_tx, + "trace-1", + ) + .await; + + let trace = store.get_trace("trace-1").await.unwrap().unwrap(); + assert_eq!(trace.upload_status, TraceUploadStatus::Uploaded); + assert_eq!(trace.bytes_uploaded, payload.len() as i64); + // At least one in-progress + one final status update should have + // been queued. + let mut count = 0; + while status_rx.try_recv().is_ok() { + count += 1; + } + assert!(count >= 1); + } + + #[tokio::test] + async fn session_expired_410_fetches_fresh_uri_and_restarts() { + // First PUT to /upload/dead returns 410 (expired session). + // GET resumable_upload_url returns a fresh /upload/live URI. + // Subsequent PUT to /upload/live succeeds with a valid checksum. + let server = MockServer::start().await; + let payload = b"resumable-payload"; + let b64 = BASE64.encode(crc32c::crc32c(payload).to_be_bytes()); + let header_value = format!("crc32c={b64}"); + + Mock::given(method("PUT")) + .and(path("/upload/dead")) + .respond_with(ResponseTemplate::new(410)) + .expect(1) + .mount(&server) + .await; + let live_uri = format!("{}/upload/live", server.uri()); + let fresh_response = serde_json::json!({"url": live_uri}); + Mock::given(method("GET")) + .and(path("/org/org-1/recording/rec-1/resumable_upload_url")) + .respond_with(ResponseTemplate::new(200).set_body_json(fresh_response)) + .expect(1) + .mount(&server) + .await; + let header_value_clone = header_value.clone(); + Mock::given(method("PUT")) + .and(path("/upload/live")) + .respond_with(move |_req: &Request| { + ResponseTemplate::new(200).insert_header("X-Goog-Hash", header_value_clone.as_str()) + }) + .expect(1) + .mount(&server) + .await; + + let (store, tempdir) = open_store().await; + let recordings_root = tempdir.path().join("recordings"); + let dead_uri = format!("{}/upload/dead", server.uri()); + let (_recording_index, _) = seed_ready_trace( + &store, + &recordings_root, + "rec-1", + "trace-1", + "JOINT_POSITIONS", + "arm", + &dead_uri, + payload, + ) + .await; + + let api = client(&server); + let bus = EventBus::new(); + let (status_tx, _status_rx) = mpsc::unbounded_channel::(); + + let store_arc = Arc::new(store.clone()); + let (trace_writer, _trace_writer_owner) = + crate::state::trace_writer::spawn(store_arc.clone()); + let recordings_root = Arc::new(recordings_root); + upload_single( + &store_arc, + &trace_writer, + &bus, + &api, + &recordings_root, + &org_rx(Some("org-1")), + &status_tx, + "trace-1", + ) + .await; + + let trace = store.get_trace("trace-1").await.unwrap().unwrap(); + assert_eq!(trace.upload_status, TraceUploadStatus::Uploaded); + // Persisted URI must be the refreshed one so a restart resumes + // against the live session, not the dead one. + let serialised = trace.upload_session_uris.as_ref().expect("uris stored"); + assert!( + serialised.contains("/upload/live"), + "refreshed URI not persisted: {serialised}" + ); + assert!( + !serialised.contains("/upload/dead"), + "dead URI still present: {serialised}" + ); + } + + #[tokio::test] + async fn missing_data_type_emits_terminal_failure_and_unblocks_progress() { + // A trace registered without a data_type cannot be located on + // disk. The uploader must mark it Failed *and* emit an + // UploadComplete so the progress reporter's "all settled" gate + // moves on — otherwise the recording sits as `pending` forever. + let server = MockServer::start().await; + let (store, tempdir) = open_store().await; + let recordings_root = tempdir.path().join("recordings"); + + let recording = store + .create_recording(NewRecording::default()) + .await + .unwrap(); + let recording_index = recording.recording_index; + // Stamp a cloud id so the uploader's cloud-id resolution passes and it + // reaches the missing-data-type branch (not the deferral path). + store + .mark_recording_start_notified(recording_index, "rec-1") + .await + .unwrap(); + // Insert directly with NULL data_type so the uploader hits the + // missing-data-type branch. + store + .create_trace(recording_index, "trace-1", None, None) + .await + .unwrap(); + let mut uris = HashMap::new(); + uris.insert("dummy".to_string(), "https://upload/abc".to_string()); + store + .update_trace( + "trace-1", + TraceUpdate { + write_status: Some(TraceWriteStatus::Written), + upload_status: Some(TraceUploadStatus::Queued), + upload_session_uris: Some(serde_json::to_string(&uris).unwrap()), + ..TraceUpdate::default() + }, + ) + .await + .unwrap(); + + let api = client(&server); + let bus = EventBus::new(); + let mut subscriber = bus.subscribe(); + let (status_tx, mut status_rx) = mpsc::unbounded_channel::(); + + let store_arc = Arc::new(store.clone()); + let (trace_writer, _trace_writer_owner) = + crate::state::trace_writer::spawn(store_arc.clone()); + let recordings_root = Arc::new(recordings_root); + upload_single( + &store_arc, + &trace_writer, + &bus, + &api, + &recordings_root, + &org_rx(Some("org-1")), + &status_tx, + "trace-1", + ) + .await; + + let trace = store.get_trace("trace-1").await.unwrap().unwrap(); + assert_eq!(trace.upload_status, TraceUploadStatus::Failed); + // UploadComplete fires so the recording's progress report can + // proceed — otherwise a stray no-data-type trace would deadlock + // the whole recording. + match subscriber.try_recv() { + Ok(DaemonEvent::UploadComplete { trace_id, .. }) => { + assert_eq!(trace_id, "trace-1"); + } + other => panic!("expected UploadComplete event, got {other:?}"), + } + // Status updater also gets a terminal entry. + let update = status_rx.try_recv().expect("status update enqueued"); + assert!(update.completed); + } +} diff --git a/rust/data_daemon/src/lifecycle/signals.rs b/rust/data_daemon/src/lifecycle/signals.rs new file mode 100644 index 000000000..87e2107cc --- /dev/null +++ b/rust/data_daemon/src/lifecycle/signals.rs @@ -0,0 +1,110 @@ +//! Async SIGTERM / SIGINT handling, fanned out to subscribers over a +//! `tokio::sync::broadcast`. +//! +//! Both signals trigger a graceful shutdown: the broadcast channel is the +//! notification the daemon's main loop awaits. SIGHUP is intentionally not +//! handled. + +use tokio::signal::unix::{signal, SignalKind}; +use tokio::sync::broadcast; + +/// Source of a graceful-shutdown notification, useful for log messages. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum ShutdownSignal { + /// `SIGTERM` (default `kill` signal, CLI `stop` command). + Sigterm, + /// `SIGINT` (Ctrl-C from a controlling terminal). + Sigint, +} + +/// A handle to the shutdown channel: clone [`subscribe`](Self::subscribe) for +/// each task that needs to wait for shutdown. +#[derive(Clone)] +pub struct ShutdownHandle { + sender: broadcast::Sender, +} + +impl ShutdownHandle { + /// Subscribe to receive a single shutdown notification. + pub fn subscribe(&self) -> broadcast::Receiver { + self.sender.subscribe() + } + + /// Fire an explicit shutdown (used by `SystemExit`-style flows and + /// tests). Returns the number of receivers notified. + #[allow(dead_code)] + pub fn signal(&self, kind: ShutdownSignal) -> usize { + self.sender.send(kind).unwrap_or(0) + } +} + +/// Install async SIGTERM and SIGINT handlers, returning a +/// [`ShutdownHandle`] and the *primary* shutdown receiver that the caller +/// must await. +/// +/// Returning the primary receiver alongside the handle closes a race that +/// would otherwise exist between the supervisor task's first `send` and the +/// caller's first `subscribe()`: `broadcast::Sender::send` returns +/// `SendError` when there are zero receivers, and `broadcast` does not +/// replay messages for receivers that subscribe later. By constructing the +/// primary receiver up-front via `broadcast::channel`, we guarantee at least +/// one receiver exists from the moment the supervisor task starts. +pub fn install_shutdown_handler( +) -> std::io::Result<(ShutdownHandle, broadcast::Receiver)> { + let (sender, primary_receiver) = broadcast::channel(8); + let supervisor_sender = sender.clone(); + + let mut sigterm = signal(SignalKind::terminate())?; + let mut sigint = signal(SignalKind::interrupt())?; + + tokio::spawn(async move { + loop { + let received = tokio::select! { + Some(()) = sigterm.recv() => ShutdownSignal::Sigterm, + Some(()) = sigint.recv() => ShutdownSignal::Sigint, + else => return, + }; + tracing::info!(signal = ?received, "shutdown signal received"); + // The primary receiver returned from this function keeps the + // channel populated with at least one receiver; further sends + // only fail if every subscriber has been dropped (typically + // during shutdown), which we ignore. + let _ = supervisor_sender.send(received); + } + }); + + Ok((ShutdownHandle { sender }, primary_receiver)) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn explicit_signal_reaches_subscriber() { + // Construct the channel directly (we cannot install signal handlers + // in tests because they're a process-global resource) and exercise + // the `signal` / `subscribe` plumbing the supervisor task relies on. + let (sender, primary_receiver) = broadcast::channel(8); + let handle = ShutdownHandle { sender }; + // Drop the primary so we can test that the explicit subscribe path + // also works for additional listeners. + drop(primary_receiver); + let mut subscriber = handle.subscribe(); + + let notified = handle.signal(ShutdownSignal::Sigterm); + assert_eq!(notified, 1); + let received = subscriber.recv().await.expect("recv"); + assert_eq!(received, ShutdownSignal::Sigterm); + } + + #[tokio::test] + async fn signal_with_no_subscribers_returns_zero() { + let (sender, primary_receiver) = broadcast::channel(8); + let handle = ShutdownHandle { sender }; + drop(primary_receiver); + // No live receivers — `send` returns Err, surfaced as `0` from our + // `signal` wrapper. + assert_eq!(handle.signal(ShutdownSignal::Sigint), 0); + } +} diff --git a/rust/data_daemon/src/state/trace_writer.rs b/rust/data_daemon/src/state/trace_writer.rs new file mode 100644 index 000000000..63c48463b --- /dev/null +++ b/rust/data_daemon/src/state/trace_writer.rs @@ -0,0 +1,558 @@ +//! Coalescing + batching write-behind for per-trace actor writes. +//! +//! Per-trace actors fire-and-forget partial column updates (a `writing` bump, a +//! debounced `bytes_written`, the finalise `written` + `total_bytes`, or a +//! `failed`) without ever awaiting a transaction. This task coalesces +//! consecutive ops for the same trace last-writer-wins per column (a burst of +//! `bytes_written` collapses to one row write) and flushes the pending set in a +//! single batched transaction ([`SqliteStateStore::apply_trace_writes`]) on a +//! short timer or once the pending set grows past a cap. +//! +//! Terminal-state monotonicity (a late progress write can't resurrect a +//! cancelled row) lives in `apply_trace_writes`'s `WHERE` guard, so the writer +//! needs no coordination with the cancel path. + +use std::collections::HashMap; +use std::sync::Arc; + +use tokio::sync::{mpsc, oneshot}; +use tokio::task::JoinHandle; +use tokio::time::{interval, Duration, MissedTickBehavior}; + +use crate::state::schema::{TraceErrorCode, TraceWriteStatus}; +use crate::state::store::{CoalescedTraceWrite, SqliteStateStore, TraceCreate}; + +/// How often pending writes are flushed. Short enough that finalised traces +/// become visible promptly, long enough that a burst of progress updates +/// coalesces into one row write per flush. +const FLUSH_INTERVAL: Duration = Duration::from_millis(25); + +/// Flush eagerly once this many distinct traces are pending, so a wide +/// fan-out (many traces updated within one interval) doesn't grow an +/// unbounded batch before the timer fires. +const MAX_PENDING_TRACES: usize = 512; + +/// Control + data messages accepted by the writer task. +enum Message { + /// A partial column update for one trace, merged into the pending set. + Write(CoalescedTraceWrite), + /// Discard every *pending create* for a recording, then acknowledge. Sent by + /// the dispatcher's cancel path before `cancel_recording` so a not-yet- + /// flushed trace can't be inserted as an orphan row after the recording is + /// burned. Pending update-only writes are left — the terminal-state guard in + /// `apply_trace_writes` already makes them no-ops against the failed row. + DropRecording { + recording_index: i64, + ack: oneshot::Sender<()>, + }, + /// Flush everything pending now and acknowledge (tests + shutdown). + Flush(oneshot::Sender<()>), + /// Drain, flush, acknowledge, and exit. + Shutdown(oneshot::Sender<()>), +} + +/// Cloneable handle the per-trace actors use to enqueue writes. Every method is +/// synchronous and non-blocking: the actor fires an update and moves on. +#[derive(Clone)] +pub struct TraceWriteHandle { + tx: mpsc::UnboundedSender, +} + +impl TraceWriteHandle { + /// Create the trace row (fire-and-forget). Sent once, as the actor's first + /// write, so the row is inserted by the next batched flush instead of the + /// actor blocking on a synchronous `create_trace`. Works at any point in a + /// recording, including a sensor that starts logging midway. + pub fn create( + &self, + trace_id: &str, + recording_index: i64, + data_type: Option<&str>, + data_type_name: Option<&str>, + ) { + self.enqueue(CoalescedTraceWrite { + trace_id: trace_id.to_string(), + create: Some(TraceCreate { + recording_index, + data_type: data_type.map(str::to_string), + data_type_name: data_type_name.map(str::to_string), + }), + ..Default::default() + }); + } + + /// Mark the trace `writing` (first frame / first video chunk). + pub fn mark_writing(&self, trace_id: &str) { + self.enqueue(CoalescedTraceWrite { + trace_id: trace_id.to_string(), + write_status: Some(TraceWriteStatus::Writing), + ..Default::default() + }); + } + + /// Record the latest absolute on-disk byte count. + pub fn progress(&self, trace_id: &str, bytes_written: i64) { + self.enqueue(CoalescedTraceWrite { + trace_id: trace_id.to_string(), + bytes_written: Some(bytes_written), + ..Default::default() + }); + } + + /// Record the latest rolling upload offset (advisory progress). Coalesced + /// like the write-phase progress so the uploader's per-64-MiB checkpoint + /// across many concurrent uploads collapses to one batched row write + /// instead of a synchronous transaction each. Resume correctness comes from + /// the server's 308 offset, not this row, so a coalesced/late value is + /// harmless; the store skips it once the upload has settled. + pub fn upload_progress(&self, trace_id: &str, bytes_uploaded: i64) { + self.enqueue(CoalescedTraceWrite { + trace_id: trace_id.to_string(), + bytes_uploaded: Some(bytes_uploaded), + ..Default::default() + }); + } + + /// Finalise the trace: `written`, with the final byte total. + pub fn finalise(&self, trace_id: &str, total_bytes: i64) { + self.enqueue(CoalescedTraceWrite { + trace_id: trace_id.to_string(), + write_status: Some(TraceWriteStatus::Written), + total_bytes: Some(total_bytes), + bytes_written: Some(total_bytes), + ..Default::default() + }); + } + + /// Mark the trace `failed`, preserving the latest byte count. + pub fn fail(&self, trace_id: &str, bytes_written: i64) { + self.enqueue(CoalescedTraceWrite { + trace_id: trace_id.to_string(), + write_status: Some(TraceWriteStatus::Failed), + bytes_written: Some(bytes_written), + ..Default::default() + }); + } + + /// Mark the trace `failed` with a write-phase error code + message. + #[allow(dead_code)] + pub fn fail_with( + &self, + trace_id: &str, + bytes_written: i64, + error_code: TraceErrorCode, + error_message: impl Into, + ) { + self.enqueue(CoalescedTraceWrite { + trace_id: trace_id.to_string(), + write_status: Some(TraceWriteStatus::Failed), + bytes_written: Some(bytes_written), + error_code: Some(error_code), + error_message: Some(error_message.into()), + ..Default::default() + }); + } + + /// Flush all pending writes and wait for the batch to commit. Used by + /// tests and by callers that need a happens-before with the DB. + pub async fn flush(&self) { + let (ack, ack_rx) = oneshot::channel(); + if self.tx.send(Message::Flush(ack)).is_ok() { + let _ = ack_rx.await; + } + } + + /// Discard pending creates for a recording and wait for the purge to + /// complete. The dispatcher calls this before `cancel_recording` so a + /// not-yet-flushed trace of a cancelled recording can't land as an orphan + /// row after the cancel has burned the recording's existing traces. + pub async fn drop_recording(&self, recording_index: i64) { + let (ack, ack_rx) = oneshot::channel(); + if self + .tx + .send(Message::DropRecording { + recording_index, + ack, + }) + .is_ok() + { + let _ = ack_rx.await; + } + } + + fn enqueue(&self, write: CoalescedTraceWrite) { + // The channel only closes once the writer task has exited (daemon + // shutdown). A drop here means we're past the point where writes + // matter, so swallow it rather than propagate to the actor. + let _ = self.tx.send(Message::Write(write)); + } +} + +/// Owns the writer task's lifetime. Held by the daemon main loop; dropping it +/// does not stop the task (clones of the handle keep the channel open) — call +/// [`TraceWriter::shutdown`] to drain, flush, and join. +pub struct TraceWriter { + tx: mpsc::UnboundedSender, + join: JoinHandle<()>, +} + +impl TraceWriter { + /// Drain every queued write, flush a final batch, and join the task. Call + /// after the dispatcher (and therefore every actor) has shut down, so no + /// further writes can be produced, and before the store is closed. + pub async fn shutdown(self) { + let (ack, ack_rx) = oneshot::channel(); + if self.tx.send(Message::Shutdown(ack)).is_ok() { + let _ = ack_rx.await; + } + if let Err(error) = self.join.await { + tracing::warn!(?error, "trace-writer task join failed during shutdown"); + } + } +} + +/// Spawn the writer task and return a cloneable [`TraceWriteHandle`] for the +/// actors plus the [`TraceWriter`] owner for shutdown. +pub fn spawn(store: Arc) -> (TraceWriteHandle, TraceWriter) { + let (tx, rx) = mpsc::unbounded_channel(); + let join = tokio::spawn(run(store, rx)); + ( + TraceWriteHandle { tx: tx.clone() }, + TraceWriter { tx, join }, + ) +} + +/// Merge one partial update into the pending set, last-writer-wins per column. +fn merge(pending: &mut HashMap, write: CoalescedTraceWrite) { + let entry = pending + .entry(write.trace_id.clone()) + .or_insert_with(|| CoalescedTraceWrite { + trace_id: write.trace_id.clone(), + ..Default::default() + }); + // `create` is set-once — it arrives on the first write and is immutable + // thereafter (the row identity never changes). + if write.create.is_some() && entry.create.is_none() { + entry.create = write.create; + } + if write.write_status.is_some() { + entry.write_status = write.write_status; + } + if write.bytes_written.is_some() { + entry.bytes_written = write.bytes_written; + } + if write.total_bytes.is_some() { + entry.total_bytes = write.total_bytes; + } + if write.bytes_uploaded.is_some() { + entry.bytes_uploaded = write.bytes_uploaded; + } + // `error_code`/`error_message` are only ever set by `fail`, which is + // mutually exclusive with `finalise` (a trace either fails or finalises, not + // both), so a `written` status never coalesces with a stale error in the + // same entry. + if write.error_code.is_some() { + entry.error_code = write.error_code; + } + if write.error_message.is_some() { + entry.error_message = write.error_message; + } +} + +/// Discard pending entries that would *insert* a row for `recording_index`. +/// Update-only entries (whose create already flushed) are left: the terminal +/// guard in `apply_trace_writes` makes them no-ops against the cancelled row. +fn drop_recording_creates( + pending: &mut HashMap, + recording_index: i64, +) { + pending.retain(|_, write| { + write + .create + .as_ref() + .is_none_or(|create| create.recording_index != recording_index) + }); +} + +/// Flush the pending set in one batched transaction, clearing it only on a +/// successful commit. +/// +/// On error the drained batch is **re-merged** into `pending` rather than +/// discarded: dropping it loses a `finalise`/`failed`, which wedges the trace +/// in `writing` and retains its parent recording forever. Re-merging keeps the +/// updates for the next tick's retry and, because the merge is keyed by +/// `trace_id`, coalesces with any writes that arrived since — so a persistent +/// failure can't grow `pending` past the live trace count. +async fn flush(store: &SqliteStateStore, pending: &mut HashMap) { + if pending.is_empty() { + return; + } + let batch: Vec = pending.drain().map(|(_, write)| write).collect(); + if let Err(error) = store.apply_trace_writes(&batch).await { + tracing::warn!( + %error, + rows = batch.len(), + "trace-writer batch flush failed; re-queueing batch for retry" + ); + for write in batch { + merge(pending, write); + } + } +} + +async fn run(store: Arc, mut rx: mpsc::UnboundedReceiver) { + let mut pending: HashMap = HashMap::new(); + let mut ticker = interval(FLUSH_INTERVAL); + // A flush that runs long must not fire a backlog of catch-up ticks. + ticker.set_missed_tick_behavior(MissedTickBehavior::Delay); + + loop { + tokio::select! { + message = rx.recv() => match message { + Some(Message::Write(write)) => { + merge(&mut pending, write); + if pending.len() >= MAX_PENDING_TRACES { + flush(&store, &mut pending).await; + } + } + Some(Message::DropRecording { recording_index, ack }) => { + drop_recording_creates(&mut pending, recording_index); + let _ = ack.send(()); + } + Some(Message::Flush(ack)) => { + flush(&store, &mut pending).await; + let _ = ack.send(()); + } + Some(Message::Shutdown(ack)) => { + // Drain anything already queued behind the Shutdown so no + // finalise is lost, then flush a last batch. + while let Ok(message) = rx.try_recv() { + match message { + Message::Write(write) => merge(&mut pending, write), + Message::DropRecording { recording_index, ack: inner } => { + drop_recording_creates(&mut pending, recording_index); + let _ = inner.send(()); + } + Message::Flush(inner) => { + flush(&store, &mut pending).await; + let _ = inner.send(()); + } + Message::Shutdown(inner) => { + let _ = inner.send(()); + } + } + } + flush(&store, &mut pending).await; + let _ = ack.send(()); + return; + } + // All handles dropped without an explicit shutdown — flush + // whatever's left so a finalise isn't lost on an abrupt exit. + None => { + flush(&store, &mut pending).await; + return; + } + }, + _ = ticker.tick() => { + flush(&store, &mut pending).await; + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::state::schema::TraceWriteStatus; + use crate::state::store::NewRecording; + use crate::state::StateStore; + use tempfile::TempDir; + + async fn store_with_trace() -> (Arc, TempDir, String) { + let dir = TempDir::new().unwrap(); + let store = Arc::new( + SqliteStateStore::open(&dir.path().join("state.db")) + .await + .unwrap(), + ); + let rec = store + .create_recording(NewRecording { + robot_id: Some("r"), + robot_instance: Some(0), + start_timestamp_ns: 1, + ..Default::default() + }) + .await + .unwrap() + .recording_index; + store + .create_trace(rec, "t1", Some("J"), Some("j")) + .await + .unwrap(); + (store, dir, "t1".to_string()) + } + + #[tokio::test] + async fn coalesces_progress_and_finalises() { + let (store, _dir, trace_id) = store_with_trace().await; + let (handle, writer) = spawn(store.clone()); + + handle.mark_writing(&trace_id); + for bytes in [10, 20, 30, 40] { + handle.progress(&trace_id, bytes); + } + handle.finalise(&trace_id, 100); + handle.flush().await; + + let trace = store.get_trace(&trace_id).await.unwrap().unwrap(); + assert_eq!(trace.write_status, TraceWriteStatus::Written); + assert_eq!(trace.total_bytes, 100); + assert_eq!(trace.bytes_written, 100); + + writer.shutdown().await; + } + + #[tokio::test] + async fn progress_does_not_resurrect_a_failed_row() { + let (store, _dir, trace_id) = store_with_trace().await; + let (handle, writer) = spawn(store.clone()); + + // Simulate cancel burning the row to `failed` out of band. + store + .update_trace( + &trace_id, + crate::state::store::TraceUpdate { + write_status: Some(TraceWriteStatus::Failed), + ..Default::default() + }, + ) + .await + .unwrap(); + + // A late coalesced progress write must NOT move it back to writing. + handle.progress(&trace_id, 999); + handle.mark_writing(&trace_id); + handle.flush().await; + + let trace = store.get_trace(&trace_id).await.unwrap().unwrap(); + assert_eq!(trace.write_status, TraceWriteStatus::Failed); + + writer.shutdown().await; + } + + #[tokio::test] + async fn shutdown_flushes_queued_writes() { + let (store, _dir, trace_id) = store_with_trace().await; + let (handle, writer) = spawn(store.clone()); + + handle.finalise(&trace_id, 42); + // No explicit flush — shutdown must drain and persist it. + writer.shutdown().await; + + let trace = store.get_trace(&trace_id).await.unwrap().unwrap(); + assert_eq!(trace.write_status, TraceWriteStatus::Written); + assert_eq!(trace.total_bytes, 42); + } + + #[tokio::test] + async fn flush_retains_batch_when_apply_fails() { + let (store, _dir, trace_id) = store_with_trace().await; + let mut pending = HashMap::new(); + merge( + &mut pending, + CoalescedTraceWrite { + trace_id: trace_id.clone(), + write_status: Some(TraceWriteStatus::Written), + total_bytes: Some(99), + ..Default::default() + }, + ); + + // Force apply_trace_writes to fail by closing the write connection. + store.write_pool().close().await; + flush(&store, &mut pending).await; + + // Regression guard for H2: a failed flush must NOT silently drop the + // batch — a lost `finalise` would wedge the trace in `writing` and + // retain its parent recording forever. + assert_eq!(pending.len(), 1, "failed flush must retain the batch"); + let retained = pending.get(&trace_id).expect("batch retained for retry"); + assert_eq!(retained.write_status, Some(TraceWriteStatus::Written)); + assert_eq!(retained.total_bytes, Some(99)); + } + + async fn store_with_recording() -> (Arc, TempDir, i64) { + let dir = TempDir::new().unwrap(); + let store = Arc::new( + SqliteStateStore::open(&dir.path().join("state.db")) + .await + .unwrap(), + ); + let rec = store + .create_recording(NewRecording { + robot_id: Some("r"), + robot_instance: Some(0), + start_timestamp_ns: 1, + ..Default::default() + }) + .await + .unwrap() + .recording_index; + (store, dir, rec) + } + + #[tokio::test] + async fn batched_create_inserts_then_finalises() { + let (store, _dir, rec) = store_with_recording().await; + let (handle, writer) = spawn(store.clone()); + + // No synchronous create_trace — the row is born from the batch. + handle.create("t-new", rec, Some("J"), Some("j")); + handle.mark_writing("t-new"); + handle.progress("t-new", 64); + handle.finalise("t-new", 128); + handle.flush().await; + + let trace = store.get_trace("t-new").await.unwrap().unwrap(); + assert_eq!(trace.recording_index, rec); + assert_eq!(trace.data_type.as_deref(), Some("J")); + assert_eq!(trace.write_status, TraceWriteStatus::Written); + assert_eq!(trace.total_bytes, 128); + + writer.shutdown().await; + } + + #[tokio::test] + async fn create_only_write_inserts_initializing_row() { + let (store, _dir, rec) = store_with_recording().await; + let (handle, writer) = spawn(store.clone()); + + // A sensor that starts logging mid-recording: actor spawns, sends the + // create, but no data has been appended before the flush. + handle.create("t-mid", rec, Some("RGB"), Some("cam")); + handle.flush().await; + + let trace = store.get_trace("t-mid").await.unwrap().unwrap(); + assert_eq!(trace.write_status, TraceWriteStatus::Initializing); + assert_eq!(trace.recording_index, rec); + + writer.shutdown().await; + } + + #[tokio::test] + async fn drop_recording_discards_unflushed_create() { + let (store, _dir, rec) = store_with_recording().await; + let (handle, writer) = spawn(store.clone()); + + // Create queued but NOT flushed, then the recording is cancelled. + handle.create("t-cancel", rec, Some("J"), Some("j")); + handle.mark_writing("t-cancel"); + handle.drop_recording(rec).await; + handle.flush().await; + + // The orphan row must never have been inserted. + assert!(store.get_trace("t-cancel").await.unwrap().is_none()); + + writer.shutdown().await; + } +} diff --git a/rust/neuracore_webrtc/.gitignore b/rust/neuracore_webrtc/.gitignore new file mode 100644 index 000000000..ea8c4bf7f --- /dev/null +++ b/rust/neuracore_webrtc/.gitignore @@ -0,0 +1 @@ +/target diff --git a/rust/neuracore_webrtc/Cargo.toml b/rust/neuracore_webrtc/Cargo.toml new file mode 100644 index 000000000..bf5e0aa23 --- /dev/null +++ b/rust/neuracore_webrtc/Cargo.toml @@ -0,0 +1,37 @@ +[package] +name = "neuracore_webrtc" +version.workspace = true +edition.workspace = true +license.workspace = true +description = "Rust WebRTC streaming core for Neuracore, exposed to Python via PyO3. PR0 scaffolding: a Rust-owned tokio runtime, the bounded frame / drainable event queues, and the stubbed synchronous Producer/Consumer API surface. The real transport (the `datachannel` crate, `media` feature, over libdatachannel) is wired in PR2." + +[lib] +# `cdylib` produces the `.so` Python imports as +# `neuracore.core.streaming.p2p._native_webrtc`. +# `rlib` lets Rust integration tests still link against the library. +name = "neuracore_webrtc" +crate-type = ["cdylib", "rlib"] +path = "src/lib.rs" + +[dependencies] +# The real transport: datachannel-rs over libdatachannel. +# * `media` builds libdatachannel with media support so PR4 can add video +# tracks on the same PeerConnection; PR2 only uses the data plane. +# * `vendored` static-links libdatachannel *and* OpenSSL into our cdylib, so +# the shipped `.so` is self-contained: no runtime libdatachannel.so +# to locate and no system OpenSSL needed (built from source via +# openssl-src). Build prerequisites are cmake, libclang, a C/C++ +# compiler and perl — all present in this image. See +# reports/PR2-data-path.md "Build prerequisites". +datachannel = { version = "0.16", features = ["media", "vendored"] } +# The raw libdatachannel C bindings, pulled directly (same 0.23 build datachannel +# already links) so PR4 can reach two symbols datachannel-rs 0.16 never surfaces: +# * `rtcSetTrackCallback` — the consumer's only way to adopt an inbound media +# track (the high-level `PeerConnectionHandler` has no `on_track`). +# * `rtcSetMessageCallback` on that inbound track id — to receive its RTP. +# See reports/PR4-encode-and-decode.md "Consumer inbound-track adoption". +datachannel-sys = "0.23" +once_cell.workspace = true +pyo3.workspace = true +serde_json.workspace = true +tokio.workspace = true diff --git a/rust/neuracore_webrtc/src/broadcaster.rs b/rust/neuracore_webrtc/src/broadcaster.rs new file mode 100644 index 000000000..c7aee25cf --- /dev/null +++ b/rust/neuracore_webrtc/src/broadcaster.rs @@ -0,0 +1,1428 @@ +//! The broadcaster: one producer serving **many** consumers from a single shared +//! encode per video source. +//! +//! The 1:1 [`Producer`](crate::producer::Producer) owns one peer connection and +//! one ffmpeg encode per track that doubles as that single consumer's track. The +//! [`Broadcaster`] keeps the same encode/packetize separation PR4 introduced but +//! fans **one** encode out to N consumers: +//! +//! * Each consumer is its own answer-only peer connection (the broadcaster is the +//! sole offerer to each), with its own PR3 negotiation queue, its own control +//! channel + manifest (its own mids), and its own [`CongestionController`] per +//! track registered in [`PRODUCER_FB`]. +//! * Exactly one ffmpeg encode runs per video source (`track_id`), never per +//! consumer. Its NAL access units fan out to every consumer's track for that +//! source via that track's own send handle (`rtcSendMessage` on the raw track +//! id); each consumer's built-in chain packetizes independently with its own +//! SSRC and sequence space. Re-encoding per consumer is exactly what this +//! module exists to avoid. +//! * The shared encoder rung is the **minimum estimate** across all consumers' +//! controllers — the worst link caps everyone (a single-encoder tradeoff: no +//! per-consumer quality). A lower estimate is a *coarser* ladder rung (higher +//! index), so the min-fold over estimates is a `max` over ladder indices +//! ([`fold_rung`]). +//! +//! ## Join / leave +//! +//! A join adds a consumer peer connection and, for each existing source, a +//! per-consumer track, then negotiates that consumer only. It does **not** force a +//! shared-encode keyframe — a forced IDR would blip every existing consumer — so a +//! joiner waits for the next periodic IDR (the encoder's `keyint`). A joiner's +//! early PLIs (it has no decodable frame until that IDR) are coalesced and +//! suppressed for a grace window ([`should_honor_pli`]) so one joining consumer +//! cannot restart the shared encode and disrupt the rest. +//! +//! A leave tears down only that consumer's peer connection, its tracks, and its +//! controllers (deregistered from [`PRODUCER_FB`]) without disturbing the others +//! or the shared encode. Removing the last consumer is graceful: the encode idles +//! (its encoder is reaped) and the min-fold over zero consumers does not panic. + +use std::collections::{HashMap, HashSet, VecDeque}; +use std::ffi::CString; +use std::os::raw::{c_char, c_int}; +use std::sync::atomic::{AtomicBool, AtomicU64, Ordering}; +use std::sync::{Arc, Mutex}; +use std::thread::JoinHandle; +use std::time::Instant; + +use datachannel::{ + ConnectionState, DataChannelInfo, IceCandidate, IceState, PeerConnectionHandler, + RtcPeerConnection, SdpType, SessionDescription, SignalingState, +}; +use datachannel_sys as sys; +use pyo3::buffer::PyBuffer; +use pyo3::exceptions::PyValueError; +use pyo3::prelude::*; +use pyo3::types::PyList; +use tokio::sync::mpsc; + +use crate::congestion::{CongestionController, TrackControl, LADDER, TOP_STEP}; +use crate::events::{emit_closed_once, Event, EventQueue}; +use crate::media::{ + annexb_access_unit, is_open, open_flag, vcl_nal_count, DropPolicy, EncodeParams, H264Encoder, + OpenFlag, RestartPolicy, VIDEO_CLOCK_HZ, +}; +use crate::producer::{ + attach_producer_chain, deregister_feedback, last_stderr_line, on_rtcp_cb, push_capture_ts, + read_frame, republish, teardown_sys_track, Channels, Frame, FrameData, FrameError, Mutation, + NegState, OutgoingEntry, ProducerChannelHandler, FRAME_QUEUE_CAPACITY, PREOPEN_STASH_FRAMES, + PRODUCER_FB, VIDEO_PAYLOAD_TYPE, +}; +use crate::runtime::{ensure_started, runtime}; +use crate::transport::{ + chrome_sdp_enabled, connection_state_str, lock, loopback_config, map_err, munge_ssrc_cname, + parse_session, raw_pc_id, sdp_type_str, ManifestState, CONTROL_LABEL, +}; + +/// How long after a consumer's track joins its PLIs are suppressed. A joiner has +/// no decodable frame until the next periodic IDR (`keyint`, ~1 s at the source +/// rate), so it will spew PLIs; honouring them would restart the shared encode and +/// blip every other consumer. The grace window is comfortably longer than one +/// keyint so the joiner becomes decodable on the natural periodic IDR before its +/// PLIs are ever honoured. After the window a PLI is real loss and is honoured. +const JOINER_PLI_GRACE_S: f64 = 2.0; + +// --------------------------------------------------------------------------- +// Pure governance helpers (unit-tested without a peer) +// --------------------------------------------------------------------------- + +/// The shared encoder rung from the per-consumer ladder steps: the **min-fold** +/// over consumer estimates, which (because a lower estimate is a coarser, higher +/// ladder index) is the `max` over ladder indices — the worst link caps everyone. +/// Zero consumers does not panic: it returns the finest rung [`TOP_STEP`]. +pub(crate) fn fold_rung(steps: &[usize]) -> usize { + steps.iter().copied().max().unwrap_or(TOP_STEP) +} + +/// Whether a PLI from a track that joined `joined_elapsed_s` ago should be honoured +/// (i.e. trigger a shared-encode keyframe restart). A freshly joined track's PLIs +/// are suppressed until the grace window passes so a joiner cannot disrupt the +/// established consumers; an established track's PLI is real loss and is honoured. +pub(crate) fn should_honor_pli(joined_elapsed_s: f64, grace_s: f64) -> bool { + joined_elapsed_s >= grace_s +} + +/// The encoder's frame rate at a ladder rung: the rung's fps cap, but never above +/// the synthetic source's nominal 45 fps (so libx264's bit budget matches the +/// frames it actually receives). Mirrors the producer's `rung_encoder_fps`. +fn rung_encoder_fps(step: usize) -> u32 { + LADDER[step].fps_cap.min(45) +} + +// --------------------------------------------------------------------------- +// Per-source fan-out state +// --------------------------------------------------------------------------- + +/// One consumer's track for a source — the unit the shared encode fans out to. +/// Created when that consumer's add-track renegotiation is applied; removed on +/// leave or source removal. +struct FanTrack { + /// libdatachannel's integer id for this consumer's track. The shared encode + /// sends NAL units on it via `rtcSendMessage`; teardown deletes it. The + /// consumer's mid for the source lives in the link's `fan_refs` (the + /// source-independent teardown record), not here. + raw_id: i32, + /// Flipped true when this consumer's add-track renegotiation completes, so the + /// fan-out only sends to a track whose remote peer is ready to receive. + open: OpenFlag, + /// This track's congestion controller's published rung. The shared rung is the + /// min-fold over all open fan tracks' rungs. + control: Arc, + /// When this track joined, for the PLI suppression grace window. + joined_at: Instant, +} + +/// A shared video source: exactly one ffmpeg encode, fanned out to every +/// consumer's [`FanTrack`] for the source. The encode reads `ts_queue`, +/// `frames_encoded`, and `fanout` directly (cloned `Arc`s) so it never has to take +/// the broadcaster-wide `sources` lock on the per-frame hot path. +#[derive(Clone, Default)] +struct VideoSource { + /// Capture timestamps (90 kHz) in submit order, one per frame written to the + /// encoder, popped per emitted access unit (in-order baseline H.264). Shared so + /// it survives encoder restarts. + ts_queue: Arc>>, + /// Access units emitted by the single encoder for this source. The + /// observability stat that proves fan-out: it counts one encode regardless of + /// how many consumers receive it, so it does not scale with consumer count. + frames_encoded: Arc, + /// The fan-out set, keyed by consumer id. The feed governs the shared rung as a + /// min-fold over the open entries' controllers; the encode sends each access + /// unit to every open entry's track. + fanout: Arc>>, +} + +/// Persistent per-source ffmpeg encoders, keyed by `track_id`. Exactly one per +/// source (never per consumer); created lazily when the first consumer for a +/// source is ready, reaped when the source idles (no open consumers). +type Encoders = Arc>>>; + +/// Shared map of every video source, keyed by `track_id`. +type Sources = Arc>>; + +/// Shared map of every consumer link, keyed by `consumer_id`. +type Consumers = Arc>>>; + +// --------------------------------------------------------------------------- +// Per-consumer peer connection +// --------------------------------------------------------------------------- + +/// One consumer's peer connection and its negotiation state. The broadcaster is +/// the sole offerer to it; it answers. Each link owns its own PR3 negotiation +/// queue (so a burst of track adds for this consumer never overlaps an in-flight +/// offer), its own control channel + manifest (its own mids), and the set of +/// source ids it currently carries (for leave teardown). +struct ConsumerLink { + consumer_id: String, + /// The libdatachannel peer connection; dropped (set to `None`) on leave/close. + pc: Arc>>>>, + /// This consumer's single-writer negotiation queue. + neg: Arc>, + /// This consumer's outgoing data channels keyed by label (its control channel + /// plus every json/joints channel fanned to it). Shared (`Arc`) with this + /// consumer's flusher and negotiator; the broadcaster reaches into it to fan a + /// `send_json` across consumers and to open a new channel on a live consumer. + channels: Channels, + /// This consumer's published manifest (its own mids + data-channel labels). + /// Shared with the flusher and negotiator; a data-channel add upserts it here + /// and republishes over this consumer's control channel. + manifest: Arc>, + /// Source ids this consumer currently carries a track for (the *desired* set, + /// updated synchronously at enqueue time so a remove can decide whether to queue + /// a teardown even before the matching add has been applied). + sources: Mutex>, + /// The *applied* per-source track identities, `track_id -> (raw_id, mid)`, + /// populated when an add is applied and consulted on remove/leave/close. This is + /// the source-independent record of what to deregister: removing a video source + /// drops the shared `VideoSource` (and its fan-out), so teardown can no longer + /// recover a track's raw id from there — it recovers it from here instead, which + /// is what keeps `PRODUCER_FB` from leaking when a source is removed. Shared + /// (`Arc`) with this consumer's negotiator, which writes it. + fan_refs: Arc>>, + /// Wakes this consumer's negotiation pump. + pump_tx: mpsc::UnboundedSender<()>, + /// Channels signal this on open so the flusher drains their send buffers. + flush_tx: mpsc::UnboundedSender, +} + +/// Peer-connection handler for one consumer link. Identical in shape to the +/// producer's handler, except every surfaced event is wrapped with this +/// consumer's id ([`Event::ForConsumer`]) so the fan-out signaling layer routes +/// it to the right consumer. +struct BroadcastConsumerHandler { + consumer_id: String, + events: EventQueue, + neg: Arc>, + /// The open flag of the track whose add-renegotiation is in flight, flipped + /// true when that cycle returns to Stable — the point the consumer is provably + /// ready to receive the track's RTP. + pending_open: Arc>>, + pump_tx: mpsc::UnboundedSender<()>, + /// One-shot reconnect-needed surface for this consumer's current outage; + /// cleared on Connected. + reconnect_surfaced: Arc, +} + +impl BroadcastConsumerHandler { + /// Push an event tagged with this consumer's id. + fn emit(&self, inner: Event) { + self.events.push(Event::ForConsumer { + consumer_id: self.consumer_id.clone(), + inner: Box::new(inner), + }); + } +} + +impl PeerConnectionHandler for BroadcastConsumerHandler { + type DCH = ProducerChannelHandler; + + fn data_channel_handler(&mut self, _info: DataChannelInfo) -> Self::DCH { + // A consumer never opens channels back to the broadcaster, so this factory + // is effectively unused; hand back a detached handler (its flush_tx goes + // nowhere, exactly like the producer's). + let (flush_tx, _flush_rx) = mpsc::unbounded_channel(); + ProducerChannelHandler::new(String::new(), flush_tx) + } + + fn on_description(&mut self, sess_desc: SessionDescription) { + // Chrome-only SDP munge (gated, so the loopback path is byte-identical), + // mirroring the producer: give the bare a=ssrc its required cname on offers. + let mut sdp = sess_desc.sdp.to_string(); + if sess_desc.sdp_type == SdpType::Offer && chrome_sdp_enabled() { + sdp = munge_ssrc_cname(&sdp, crate::media::PACKETIZER_CNAME); + } + self.emit(Event::LocalDescription { + sdp_type: sdp_type_str(&sess_desc.sdp_type).to_string(), + sdp, + }); + } + + fn on_candidate(&mut self, cand: IceCandidate) { + self.emit(Event::LocalCandidate { + candidate: cand.candidate, + mid: Some(cand.mid), + }); + } + + fn on_connection_state_change(&mut self, state: ConnectionState) { + crate::transport::debug_trace(&self.consumer_id, connection_state_str(&state)); + if state == ConnectionState::New { + return; // the constructor already emitted the initial "new" + } + if state == ConnectionState::Connected { + self.reconnect_surfaced.store(false, Ordering::SeqCst); + } + self.emit(Event::State(connection_state_str(&state).to_string())); + // A failed consumer is torn down and re-added per consumer (this binding + // cannot ICE-restart); surface reconnect-needed once per outage, tagged + // with the consumer id, without disturbing the other consumers or the + // shared encode. + if let crate::transport::ReconnectAction::SurfaceReconnect = crate::transport::reconnect_action( + state, + crate::transport::ICE_RESTART_SUPPORTED, + ) { + if !self.reconnect_surfaced.swap(true, Ordering::SeqCst) { + self.emit(Event::error_for( + &self.consumer_id, + "connection", + "reconnect-needed: consumer connection failed and ICE restart is \ + unsupported — remove and re-add this consumer", + )); + } + } + } + + fn on_signaling_state_change(&mut self, state: SignalingState) { + crate::transport::debug_trace(&self.consumer_id, &format!("sig:{state:?}")); + // A return to Stable means the in-flight offer's answer has been applied + // (the bootstrap data-channel cycle, or a track-add cycle). Open the track + // whose add just completed, clear the gate, and wake the pump — all off the + // PC (we only touch the neg/pending locks here). + if state == SignalingState::Stable { + if let Some(open) = lock(&self.pending_open).take() { + open.store(true, Ordering::SeqCst); + } + lock(&self.neg).in_flight = false; + let _ = self.pump_tx.send(()); + } + } + + fn on_ice_state_change(&mut self, state: IceState) { + crate::transport::debug_trace(&self.consumer_id, &format!("ice:{state:?}")); + } +} + +/// Production negotiator for one consumer: applies a single track mutation against +/// that consumer's peer connection (add/drop the track on its PC, drive its offer, +/// register/unregister the [`FanTrack`] in the shared source's fan-out). Drives the +/// reused [`pump_step`](crate::producer::pump_step) control logic. +struct ConsumerNegotiator { + consumer_id: String, + pc: Arc>>>>, + channels: Channels, + manifest: Arc>, + pending_open: Arc>>, + sources: Sources, + /// Shared with the [`ConsumerLink`]: `track_id -> (raw_id, mid)` for every + /// applied track, so remove/leave/close can tear a track down without the + /// shared `VideoSource` (which a source removal drops). + fan_refs: Arc>>, + events: EventQueue, +} + +impl crate::producer::NegotiationApply for ConsumerNegotiator { + fn apply(&self, mutation: Mutation) -> Result<(), String> { + let result = match mutation { + Mutation::Add { + track_id, + mid, + ssrc, + } => self.apply_add(track_id, mid, ssrc), + Mutation::Remove { track_id } => self.apply_remove(track_id), + }; + // Surface a per-consumer negotiation failure on the event queue (tagged + // with the consumer id) rather than only tracing it; the pump still clears + // the gate and drains on, so one bad mutation never wedges this consumer. + if let Err(err) = &result { + self.events.push(Event::error_for( + &self.consumer_id, + "negotiate", + format!("track mutation failed: {err}"), + )); + } + result + } +} + +impl ConsumerNegotiator { + /// Add this consumer's track for `track_id`: create the sys track on its PC, + /// attach the built-in H.264 chain, register its controller + fan-out entry, + /// and drive the offer. Mirrors the producer's `apply_mutation` add arm. + fn apply_add(&self, track_id: String, mid: String, ssrc: u32) -> Result<(), String> { + let open = open_flag(); + let control = Arc::new(TrackControl::default()); + let raw_id = { + let mut guard = lock(&self.pc); + let pc = guard.as_mut().ok_or("consumer is closed")?; + let pc_id = raw_pc_id(pc).ok_or("cannot recover pc id for rtcAddTrackEx")?; + let mid_c = CString::new(mid.clone()).map_err(|e| e.to_string())?; + let track_c = CString::new(track_id.clone()).map_err(|e| e.to_string())?; + let init = sys::rtcTrackInit { + direction: sys::rtcDirection_RTC_DIRECTION_SENDONLY, + codec: sys::rtcCodec_RTC_CODEC_H264, + payloadType: VIDEO_PAYLOAD_TYPE, + ssrc, + mid: mid_c.as_ptr(), + name: std::ptr::null(), + msid: std::ptr::null(), + trackId: track_c.as_ptr(), + profile: std::ptr::null(), + }; + // SAFETY: `pc_id` is this live PC's id; the CString pointers live until + // the end of this block. + let raw_id = unsafe { sys::rtcAddTrackEx(pc_id, &init) }; + if raw_id < 0 { + return Err(format!("rtcAddTrackEx failed: {raw_id}")); + } + // From here the track exists; any later failure tears it down so a + // mid-setup error leaves no PRODUCER_FB entry or callback leaked. + let setup = (|| -> Result<(), String> { + attach_producer_chain(raw_id, ssrc)?; + lock(&PRODUCER_FB).insert( + raw_id, + Arc::new(CongestionController::new(ssrc, control.clone())), + ); + // SAFETY: `raw_id` is the just-created track id. + unsafe { sys::rtcSetMessageCallback(raw_id, Some(on_rtcp_cb)) }; + pc.set_local_description(SdpType::Offer) + .map_err(|e| e.to_string())?; + Ok(()) + })(); + if let Err(err) = setup { + teardown_sys_track(raw_id); + return Err(err); + } + raw_id + }; + // Record the applied identity on the link (source-independent), so a later + // remove/leave/close can deregister this track even after the shared source + // is gone. + lock(&self.fan_refs) + .insert(track_id.clone(), (raw_id, mid.clone())); + // Register the fan-out entry on the shared source so the encode reaches it. + if let Some(source) = lock(&self.sources).get(&track_id).cloned() { + lock(&source.fanout).insert( + self.consumer_id.clone(), + FanTrack { + raw_id, + open: open.clone(), + control, + joined_at: Instant::now(), + }, + ); + } + // Arm the open flag to flip when this offer's answer is applied (Stable). + *lock(&self.pending_open) = Some(open); + lock(&self.manifest).upsert_video_track(&mid, &track_id); + republish(&self.channels, &self.manifest); + Ok(()) + } + + /// Remove this consumer's track for `track_id`: tear it down on its PC, + /// deregister its controller, drop the fan-out entry, and renegotiate. Mirrors + /// the producer's `apply_mutation` remove arm. The track's raw id comes from the + /// link's `fan_refs` (not the shared source), so this still tears the track down + /// and deregisters `PRODUCER_FB` even when the source was already removed. + fn apply_remove(&self, track_id: String) -> Result<(), String> { + let Some((raw_id, mid)) = lock(&self.fan_refs).remove(&track_id) else { + return Ok(()); // unknown track for this consumer: no-op, keep draining + }; + // Drop the fan-out entry if the source still exists (the send path); if the + // source was removed, its fan-out went with it. + if let Some(source) = lock(&self.sources).get(&track_id).cloned() { + lock(&source.fanout).remove(&self.consumer_id); + } + // Deregister the controller, clear the callback, and delete the sys track + // via the shared teardown (the same path remove/close/error-cleanup use). + teardown_sys_track(raw_id); + { + let mut guard = lock(&self.pc); + let pc = guard.as_mut().ok_or("consumer is closed")?; + pc.set_local_description(SdpType::Offer) + .map_err(|e| e.to_string())?; + } + lock(&self.manifest).remove_entry(&mid); + republish(&self.channels, &self.manifest); + Ok(()) + } +} + +// --------------------------------------------------------------------------- +// The broadcaster pyclass +// --------------------------------------------------------------------------- + +/// One producer fanning a single shared encode per source out to many consumers. +#[pyclass] +pub struct Broadcaster { + events: EventQueue, + /// The single shared ingress sender, behind an `Option` so [`close`](Self::close) + /// can drop it and end the shared feed thread's `blocking_recv` deterministically. + frame_tx: Mutex>>, + /// The shared encoder feed thread, joined on close so no thread outlives the + /// broadcaster. + feed_handle: Mutex>>, + closed: Arc, + /// Broadcaster epoch: `submit_frame` stamps each frame's capture time as + /// `elapsed_since(epoch)` in 90 kHz units. + epoch: Instant, + sources: Sources, + consumers: Consumers, + encoders: Encoders, + /// Allocates a process-wide-unique RTP SSRC for every track (per-consumer + /// tracks for the same source still get distinct SSRCs). + ssrc_counter: AtomicU64, + /// The data channels to open on every consumer, in insertion order, as + /// `(label, kind)`. `add_data_channel` records each here (so a future consumer + /// opens them at bootstrap) and opens it on every current consumer. The + /// reserved `control` label is never recorded here. Mirrors the 1:1 + /// `Producer`'s data channels, fanned across consumers. + data_channels: Mutex>, +} + +#[pymethods] +impl Broadcaster { + /// Create a broadcaster with no consumers and no sources. `connection_id` is an + /// opaque label for logging/correlation. `frame_queue_capacity` sizes the + /// single shared bounded ingress queue feeding the per-source encode. + #[new] + #[pyo3(signature = (connection_id=None, frame_queue_capacity=FRAME_QUEUE_CAPACITY))] + fn new(connection_id: Option, frame_queue_capacity: usize) -> PyResult { + let _ = connection_id; + ensure_started(); + + let (frame_tx, frame_rx) = mpsc::channel::(frame_queue_capacity.max(1)); + let sources: Sources = Arc::new(Mutex::new(HashMap::new())); + let encoders: Encoders = Arc::new(Mutex::new(HashMap::new())); + let events = EventQueue::default(); + let feed_handle = + spawn_broadcast_feed(frame_rx, encoders.clone(), sources.clone(), events.clone()); + + events.push(Event::State("new".to_string())); + + Ok(Self { + events, + frame_tx: Mutex::new(Some(frame_tx)), + feed_handle: Mutex::new(Some(feed_handle)), + closed: Arc::new(AtomicBool::new(false)), + epoch: Instant::now(), + sources, + consumers: Arc::new(Mutex::new(HashMap::new())), + encoders, + ssrc_counter: AtomicU64::new(1), + data_channels: Mutex::new(Vec::new()), + }) + } + + /// Add a consumer: stand up its answer-only peer connection (the broadcaster is + /// the sole offerer), open its control channel (which triggers the offer), and + /// add a per-consumer track for every existing source, then negotiate that + /// consumer only. The track adds are queued behind the bootstrap offer so they + /// do not race it. + fn add_consumer(&self, consumer_id: &str) -> PyResult<()> { + if self.closed.load(Ordering::SeqCst) { + return Err(PyValueError::new_err("broadcaster is closed")); + } + if lock(&self.consumers).contains_key(consumer_id) { + return Err(PyValueError::new_err(format!( + "consumer {consumer_id:?} already added" + ))); + } + + let channels: Channels = Arc::new(Mutex::new(HashMap::new())); + let manifest = Arc::new(Mutex::new(ManifestState::default())); + let neg = Arc::new(Mutex::new(NegState::default())); + let pending_open: Arc>> = Arc::new(Mutex::new(None)); + let fan_refs: Arc>> = + Arc::new(Mutex::new(HashMap::new())); + let (flush_tx, flush_rx) = mpsc::unbounded_channel::(); + let (pump_tx, pump_rx) = mpsc::unbounded_channel::<()>(); + + spawn_flusher(channels.clone(), manifest.clone(), flush_rx); + + let handler = BroadcastConsumerHandler { + consumer_id: consumer_id.to_string(), + events: self.events.clone(), + neg: neg.clone(), + pending_open: pending_open.clone(), + pump_tx: pump_tx.clone(), + reconnect_surfaced: Arc::new(AtomicBool::new(false)), + }; + let pc = Arc::new(Mutex::new(Some( + RtcPeerConnection::new(&loopback_config(), handler).map_err(map_err)?, + ))); + // The constructor's connection state is not surfaced (libdatachannel emits + // New first); emit the per-consumer "new" so the signaling layer sees it. + self.events.push(Event::ForConsumer { + consumer_id: consumer_id.to_string(), + inner: Box::new(Event::State("new".to_string())), + }); + + spawn_consumer_pump( + consumer_id.to_string(), + pc.clone(), + channels.clone(), + manifest.clone(), + pending_open.clone(), + self.sources.clone(), + fan_refs.clone(), + neg.clone(), + self.events.clone(), + pump_rx, + ); + + let link = Arc::new(ConsumerLink { + consumer_id: consumer_id.to_string(), + pc: pc.clone(), + neg: neg.clone(), + channels: channels.clone(), + manifest: manifest.clone(), + sources: Mutex::new(HashSet::new()), + fan_refs, + pump_tx: pump_tx.clone(), + flush_tx, + }); + + // Gate any track adds behind the bootstrap offer: mark a cycle in flight so + // the pump will not apply a track mutation (a second set_local_description) + // until the control-channel offer's answer brings signaling back to Stable. + lock(&neg).in_flight = true; + // Open the control channel — this triggers the bootstrap offer. + { + let dch = ProducerChannelHandler::new(CONTROL_LABEL.to_string(), link.flush_tx.clone()); + let channel = { + let mut guard = lock(&pc); + let pc = guard + .as_mut() + .ok_or_else(|| PyValueError::new_err("consumer is closed"))?; + pc.create_data_channel(CONTROL_LABEL, dch) + .map_err(map_err)? + }; + lock(&channels).insert( + CONTROL_LABEL.to_string(), + OutgoingEntry { + channel, + open: false, + pending: VecDeque::new(), + }, + ); + } + + // Open every registered json/joints data channel on this fresh consumer so + // a late joiner gets the existing channels at bootstrap (DCEP over the same + // SCTP the control channel brings up — no renegotiation; see PR2). + let registered: Vec<(String, String)> = lock(&self.data_channels).clone(); + for (label, kind) in registered { + if let Err(err) = open_consumer_channel(&link, &label, &kind) { + self.events.push(Event::error_for( + consumer_id, + "data-channel", + format!("failed to open data channel {label:?} at bootstrap: {err}"), + )); + } + } + + lock(&self.consumers).insert(consumer_id.to_string(), link.clone()); + + // Add a track for each existing source (queued behind the bootstrap gate). + let existing: Vec = lock(&self.sources).keys().cloned().collect(); + for track_id in existing { + self.enqueue_add(&link, &track_id); + } + let _ = pump_tx.send(()); + Ok(()) + } + + /// Remove a consumer: tear down only its peer connection, its tracks, and its + /// controllers, without disturbing the other consumers or the shared encode. + fn remove_consumer(&self, consumer_id: &str) -> PyResult<()> { + let link = lock(&self.consumers).remove(consumer_id); + let Some(link) = link else { + return Ok(()); // unknown / already removed: idempotent + }; + teardown_consumer(&self.sources, &link); + Ok(()) + } + + /// Add a video source visible to all consumers (current and future). + /// `submit_frame(track_id, ...)` encodes it once and fans it out. For each + /// current consumer this queues a per-consumer track add (renegotiated per + /// consumer); a future consumer picks the source up when it is added. + fn add_video_track(&self, track_id: &str) -> PyResult<()> { + if self.closed.load(Ordering::SeqCst) { + return Err(PyValueError::new_err("broadcaster is closed")); + } + lock(&self.sources).entry(track_id.to_string()).or_default(); + let links: Vec> = lock(&self.consumers).values().cloned().collect(); + for link in links { + self.enqueue_add(&link, track_id); + let _ = link.pump_tx.send(()); + } + Ok(()) + } + + /// Remove a video source from every consumer and stop its shared encode. + fn remove_video_track(&self, track_id: &str) -> PyResult<()> { + let links: Vec> = lock(&self.consumers).values().cloned().collect(); + for link in links { + if lock(&link.sources).remove(track_id) { + lock(&link.neg).pending.push_back(Mutation::Remove { + track_id: track_id.to_string(), + }); + let _ = link.pump_tx.send(()); + } + } + // Stop the shared encode immediately (kills its ffmpeg); drop the source. + lock(&self.encoders).remove(track_id); + lock(&self.sources).remove(track_id); + Ok(()) + } + + /// Enqueue one raw frame for `track_id` onto the single shared bounded ingress + /// queue and return immediately. Never blocks: under overload the frame is + /// dropped. Same drop policy and contiguity contract as the producer. + #[pyo3(signature = (track_id, frame))] + fn submit_frame(&self, track_id: &str, frame: PyBuffer) -> PyResult<()> { + let FrameData { + data, + width, + height, + } = read_frame(&frame).map_err(|err| match err { + FrameError::NotContiguous => PyValueError::new_err("frame buffer must be C-contiguous"), + FrameError::BadShape => PyValueError::new_err("frame must be an 8-bit HxWx3 image"), + })?; + let capture_ts = (self.epoch.elapsed().as_secs_f64() * VIDEO_CLOCK_HZ as f64) as u32; + let job = Frame { + track_id: track_id.to_string(), + data, + width, + height, + capture_ts, + }; + let guard = lock(&self.frame_tx); + let Some(frame_tx) = guard.as_ref() else { + return Ok(()); // closed: no-op + }; + let capacity = frame_tx.max_capacity(); + let backlog = capacity - frame_tx.capacity(); + if !DropPolicy::new(capacity).admit(backlog) { + return Ok(()); + } + let _ = frame_tx.try_send(job); + Ok(()) + } + + /// Open a reliable-ordered data channel with `label` on every consumer + /// (current and future). For a current consumer the channel opens over its + /// existing SCTP association (DCEP, no renegotiation — PR2); a future consumer + /// opens it at bootstrap (see [`add_consumer`](Self::add_consumer)). `kind` is + /// an opaque label hint recorded in each consumer's manifest. The reserved + /// `control` label carries the per-consumer manifest and is never a json/joints + /// label. Idempotent per label: re-adding a known label is a no-op so a live + /// consumer never gets the same channel twice. Mirrors the 1:1 + /// [`Producer::add_data_channel`](crate::producer::Producer), fanned across + /// consumers. + fn add_data_channel(&self, label: &str, kind: &str) -> PyResult<()> { + if self.closed.load(Ordering::SeqCst) { + return Err(PyValueError::new_err("broadcaster is closed")); + } + if label == CONTROL_LABEL { + return Err(PyValueError::new_err( + "'control' is reserved for the per-consumer manifest transport", + )); + } + // Record once so a future consumer opens it at bootstrap; skip a known + // label so a live consumer is never handed a duplicate channel. + { + let mut registry = lock(&self.data_channels); + if registry.iter().any(|(existing, _)| existing == label) { + return Ok(()); + } + registry.push((label.to_string(), kind.to_string())); + } + // Open it on every current consumer over its existing association. + let links: Vec> = lock(&self.consumers).values().cloned().collect(); + for link in links { + if let Err(err) = open_consumer_channel(&link, label, kind) { + self.events.push(Event::error_for( + &link.consumer_id, + "data-channel", + format!("failed to open data channel {label:?}: {err}"), + )); + } + } + Ok(()) + } + + /// Send a JSON payload (already-serialised text) over the named data channel of + /// **every** consumer that has it. A consumer still mid-bootstrap buffers it + /// behind its pre-open gate and replays it in order on open, so no consumer + /// loses or reorders a message. A label no consumer carries (e.g. one never + /// added, or sent before any browser connected) simply reaches no one — the + /// broadcast analogue of the 1:1 [`Producer::send_json`](crate::producer::Producer). + fn send_json(&self, label: &str, payload: &str) -> PyResult<()> { + let links: Vec> = lock(&self.consumers).values().cloned().collect(); + for link in links { + let mut map = lock(&link.channels); + if let Some(entry) = map.get_mut(label) { + entry.send(payload.as_bytes().to_vec()); + } + } + Ok(()) + } + + /// Apply a consumer's SDP answer, routed by `consumer_id`. + fn set_remote_answer(&self, consumer_id: &str, sdp: &str) -> PyResult<()> { + let link = lock(&self.consumers) + .get(consumer_id) + .cloned() + .ok_or_else(|| PyValueError::new_err(format!("unknown consumer {consumer_id:?}")))?; + let sess = parse_session(sdp, SdpType::Answer)?; + let mut guard = lock(&link.pc); + let pc = guard + .as_mut() + .ok_or_else(|| PyValueError::new_err("consumer is closed"))?; + pc.set_remote_description(&sess).map_err(map_err) + } + + /// Apply a remote ICE candidate trickled from a consumer, routed by + /// `consumer_id`. + #[pyo3(signature = (consumer_id, candidate, mid=None))] + fn add_remote_candidate( + &self, + consumer_id: &str, + candidate: &str, + mid: Option, + ) -> PyResult<()> { + let link = lock(&self.consumers) + .get(consumer_id) + .cloned() + .ok_or_else(|| PyValueError::new_err(format!("unknown consumer {consumer_id:?}")))?; + let cand = IceCandidate { + candidate: candidate.to_string(), + mid: mid.unwrap_or_default(), + }; + let mut guard = lock(&link.pc); + let pc = guard + .as_mut() + .ok_or_else(|| PyValueError::new_err("consumer is closed"))?; + pc.add_remote_candidate(&cand).map_err(map_err) + } + + /// Drain and return all queued events as a list of dicts. Per-consumer events + /// carry a `"consumer_id"` key so the signaling layer routes them. + fn drain_events(&self, py: Python<'_>) -> PyResult> { + self.events.drain_to_py(py) + } + + /// The number of live shared encoders. Exactly one per active source, + /// **independent of the consumer count** — the observable that proves the + /// encode is shared (fan-out) rather than re-encoded per consumer. + fn encoder_count(&self) -> usize { + lock(&self.encoders).len() + } + + /// Access units the shared encoder has emitted for `track_id` (one encode, + /// fanned out), or `None` for an unknown source. Does not scale with the + /// consumer count. + fn frames_encoded(&self, track_id: &str) -> Option { + lock(&self.sources) + .get(track_id) + .map(|source| source.frames_encoded.load(Ordering::SeqCst)) + } + + /// The shared ladder rung a source is currently encoded at (the min-fold over + /// its open consumers' controllers; finest = 0). `None` for an unknown source. + fn congestion_step(&self, track_id: &str) -> Option { + lock(&self.sources).get(track_id).map(|source| { + let steps = open_rungs(&source); + fold_rung(&steps) + }) + } + + /// The number of consumers currently attached. + fn consumer_count(&self) -> usize { + lock(&self.consumers).len() + } + + /// Close the broadcaster. Idempotent: tears down every consumer, kills every + /// encoder, drops every source, then emits a final `on_state: "closed"`. + fn close(&self) -> PyResult<()> { + if emit_closed_once(&self.closed, &self.events) { + // Stop the shared feed first (drop the ingress sender, join the thread) + // so no thread or ffmpeg subprocess outlives the broadcaster. + *lock(&self.frame_tx) = None; + if let Some(handle) = lock(&self.feed_handle).take() { + let _ = handle.join(); + } + let links: Vec> = + lock(&self.consumers).drain().map(|(_, v)| v).collect(); + for link in links { + teardown_consumer(&self.sources, &link); + } + lock(&self.encoders).clear(); + lock(&self.sources).clear(); + } + Ok(()) + } +} + +impl Broadcaster { + /// Allocate this consumer's identity for a source and queue an add mutation on + /// its negotiation queue (the caller pings the pump). + fn enqueue_add(&self, link: &Arc, track_id: &str) { + // The Producer owns the mid: use the supplied track_id verbatim (see + // crate::producer::track_mid) as the SDP m-line mid. The Producer registers + // available_robots under this same value, so the offer's a=mid and the SSE + // manifest agree and the browser's identityForMid succeeds. Every consumer + // link receives the same source set in the same order, so one producer-owned + // mid per source is globally consistent across consumers. The old per-link + // "v{n}" counter produced "v0" on the wire while available_robots held "0", + // which never matched, so the browser rendered no tile. + let mid = crate::producer::track_mid(track_id); + let ssrc = self.ssrc_counter.fetch_add(1, Ordering::SeqCst) as u32; + lock(&link.sources).insert(track_id.to_string()); + lock(&link.neg).pending.push_back(Mutation::Add { + track_id: track_id.to_string(), + mid, + ssrc, + }); + } +} + +/// Open one reliable-ordered data channel on a single consumer link: create it on +/// that consumer's peer connection, register it in the link's channel map (so a +/// later `send_json` finds it and the pre-open gate buffers until it opens), then +/// upsert its label into that consumer's manifest and republish over its control +/// channel. Mirrors the 1:1 [`Producer::add_data_channel`](crate::producer::Producer) +/// for one consumer; the broadcaster calls it for every consumer. +fn open_consumer_channel(link: &ConsumerLink, label: &str, kind: &str) -> Result<(), String> { + let dch = ProducerChannelHandler::new(label.to_string(), link.flush_tx.clone()); + let channel = { + let mut guard = lock(&link.pc); + let pc = guard.as_mut().ok_or("consumer is closed")?; + pc.create_data_channel(label, dch) + .map_err(|e| e.to_string())? + }; + lock(&link.channels).insert( + label.to_string(), + OutgoingEntry { + channel, + open: false, + pending: VecDeque::new(), + }, + ); + lock(&link.manifest).upsert_data_channel(label, kind); + republish(&link.channels, &link.manifest); + Ok(()) +} + +/// Tear down one consumer: drop its fan-out entries from every source (deregister +/// their controllers), then drop its peer connection (which closes its tracks and +/// fires its Closed callback). The shared encode and other consumers are untouched. +fn teardown_consumer(sources: &Sources, link: &ConsumerLink) { + // Deregister every applied track from the link's own record, so cleanup is + // independent of whether the shared source still exists (a removed source has + // already dropped its fan-out). This is what keeps PRODUCER_FB from leaking on + // leave/close after a source removal. + let refs: Vec<(String, (i32, String))> = lock(&link.fan_refs).drain().collect(); + let source_map = lock(sources).clone(); + for (track_id, (raw_id, _mid)) in refs { + if let Some(source) = source_map.get(&track_id) { + lock(&source.fanout).remove(&link.consumer_id); + } + // Deregister the controller and clear the track's RTCP callback before the + // PC drop frees the track, so no chain callback races teardown with a stale + // registry entry and no entry/callback is leaked. + deregister_feedback(raw_id); + // SAFETY: `raw_id` is this consumer's track, alive until the PC drop just + // below. The PC drop deletes the track itself. + unsafe { sys::rtcSetMessageCallback(raw_id, None) }; + } + // Drop this consumer's data channels before the PC (the producer-close order), + // releasing their flush_tx handles so the flusher task ends, and tearing the + // channels down cleanly with no per-consumer leak. + lock(&link.channels).clear(); + // Dropping the PC frees its tracks/SRTP and fires the consumer's Closed; the + // pump task ends once every pump_tx sender (handler + link) is gone. + *lock(&link.pc) = None; +} + +/// Spawn one consumer's negotiation pump (reusing the producer's `pump_step`). +#[allow(clippy::too_many_arguments)] +fn spawn_consumer_pump( + consumer_id: String, + pc: Arc>>>>, + channels: Channels, + manifest: Arc>, + pending_open: Arc>>, + sources: Sources, + fan_refs: Arc>>, + neg: Arc>, + events: EventQueue, + mut pump_rx: mpsc::UnboundedReceiver<()>, +) { + runtime().spawn(async move { + let negotiator = ConsumerNegotiator { + consumer_id, + pc, + channels, + manifest, + pending_open, + sources, + fan_refs, + events, + }; + while pump_rx.recv().await.is_some() { + crate::producer::pump_step(&neg, &negotiator); + } + }); +} + +/// Spawn the per-consumer flusher: drains a channel's pre-open send buffer once it +/// opens, and re-sends the current manifest when the control channel opens. +/// Mirrors the producer's flusher. +fn spawn_flusher( + channels: Channels, + manifest: Arc>, + mut flush_rx: mpsc::UnboundedReceiver, +) { + runtime().spawn(async move { + while let Some(label) = flush_rx.recv().await { + { + let mut map = lock(&channels); + if let Some(entry) = map.get_mut(&label) { + entry.open = true; + entry.flush(); + } + } + if label == CONTROL_LABEL { + let json = lock(&manifest).to_json(); + let mut map = lock(&channels); + if let Some(entry) = map.get_mut(CONTROL_LABEL) { + entry.send(json.into_bytes()); + } + } + } + }); +} + +// --------------------------------------------------------------------------- +// The shared encoder feed (one ffmpeg encode per source, fanned to N consumers) +// --------------------------------------------------------------------------- + +/// The per-source feed state local to the feed thread. +#[derive(Default)] +struct FeedState { + applied_step: Option, + allowance: f64, + last_tick: Option, + stash: VecDeque, + /// Bounded crash-restart budget for this source's shared encoder. + restart: RestartPolicy, +} + +impl FeedState { + fn stash(&mut self, frame: Frame) { + if self.stash.len() < PREOPEN_STASH_FRAMES { + self.stash.push_back(frame); + } + } +} + +/// The desired-rung steps of a source's currently-open fan tracks. +fn open_rungs(source: &VideoSource) -> Vec { + lock(&source.fanout) + .values() + .filter(|ft| is_open(&ft.open)) + .map(|ft| ft.control.desired_step()) + .collect() +} + +/// Coalesce a source's pending PLIs across its fan tracks, draining each track's +/// request (so a suppressed joiner's PLI does not accumulate) and honouring a +/// keyframe restart only for tracks past the join grace window. +fn coalesce_pli(source: &VideoSource) -> bool { + let mut honor = false; + for ft in lock(&source.fanout).values() { + let requested = ft.control.take_pli(); + if requested && should_honor_pli(ft.joined_at.elapsed().as_secs_f64(), JOINER_PLI_GRACE_S) { + honor = true; + } + } + honor +} + +/// Spawn the shared encoder feed on a dedicated OS thread (off the tokio pool: it +/// makes blocking ffmpeg-stdin writes). It drains the single shared ingress queue, +/// runs exactly one ffmpeg encode per source, governs that encode's rung by the +/// min-fold over the source's open consumers, and fans each encoded access unit out +/// to every open consumer's track. The blocking write propagates back-pressure to +/// the shared ingress queue (the single shed point in `submit_frame`). +fn spawn_broadcast_feed( + mut frame_rx: mpsc::Receiver, + encoders: Encoders, + sources: Sources, + events: EventQueue, +) -> JoinHandle<()> { + std::thread::Builder::new() + .name("ncwebrtc-fanout-feed".into()) + .spawn(move || { + let adapt_disabled = std::env::var_os("NCD_WEBRTC_DISABLE_ADAPT").is_some(); + let mut feeds: HashMap = HashMap::new(); + while let Some(frame) = frame_rx.blocking_recv() { + let track_id = frame.track_id.clone(); + // The source must exist and have at least one open consumer before + // we encode; otherwise stash (bounded) so a first consumer's IDR is + // not lost while it is still negotiating. + let Some(source) = lock(&sources).get(&track_id).cloned() else { + feeds.entry(track_id).or_default().stash(frame); + continue; + }; + let steps = open_rungs(&source); + if steps.is_empty() { + // No open consumers: idle the encode (reap its ffmpeg) and hold + // the frame. The min-fold over zero consumers does not panic. + lock(&encoders).remove(&track_id); + let feed = feeds.entry(track_id).or_default(); + feed.applied_step = None; + feed.stash(frame); + continue; + } + let feed = feeds.entry(track_id.clone()).or_default(); + + let desired = if adapt_disabled { + TOP_STEP + } else { + fold_rung(&steps) + }; + let pli = coalesce_pli(&source); + + // Crash detection for the shared encode: a crashed ffmpeg is + // restarted (not stalled) so it does not silently stall every + // consumer. A shared-encode crash belongs to no single consumer, so + // the on_error carries no consumer_id. Resync the source's capture- + // timestamp queue and rebuild within the bounded budget. + let dead = lock(&encoders) + .get(&track_id) + .map(|e| !e.is_alive()) + .unwrap_or(false); + if dead { + let detail = lock(&encoders) + .get(&track_id) + .map(|e| e.stderr_tail()) + .unwrap_or_default(); + lock(&encoders).remove(&track_id); + lock(&source.ts_queue).clear(); + feed.applied_step = None; + if feed.restart.should_restart() { + events.push(Event::error( + "encode", + format!( + "shared encoder for {track_id:?} crashed; restarting (ffmpeg: {})", + last_stderr_line(&detail) + ), + )); + } else { + events.push(Event::error( + "encode", + format!( + "shared encoder for {track_id:?} crashed and exceeded the restart \ + budget; dropping frames (ffmpeg: {})", + last_stderr_line(&detail) + ), + )); + feed.stash(frame); + continue; + } + } else if lock(&encoders).contains_key(&track_id) { + feed.restart.reset(); + } + + let missing = !lock(&encoders).contains_key(&track_id); + if feed.applied_step != Some(desired) || pli || missing { + match make_broadcast_encoder( + frame.width, + frame.height, + desired, + &source, + events.clone(), + ) { + Some(encoder) => { + lock(&encoders).insert(track_id.clone(), encoder); + feed.applied_step = Some(desired); + } + None => { + if !dead && feed.restart.should_restart() { + events.push(Event::error( + "encode", + format!("could not spawn shared ffmpeg encoder for {track_id:?}"), + )); + } + feed.stash(frame); + continue; + } + } + } + + let Some(encoder) = lock(&encoders).get(&track_id).cloned() else { + feed.stash(frame); + continue; + }; + // Flush the pre-open stash IDR-first (the startup burst, unpaced). + for held in feed.stash.drain(..) { + push_capture_ts(&source.ts_queue, held.capture_ts); + encoder.write_frame(&held.data); + } + // Input fps cap via a token bucket at the rung's fps cap. + let fps_cap = LADDER[desired].fps_cap.max(1) as f64; + let now = Instant::now(); + if let Some(last) = feed.last_tick { + feed.allowance += now.duration_since(last).as_secs_f64() * fps_cap; + } else { + feed.allowance = 1.0; + } + feed.last_tick = Some(now); + if feed.allowance > fps_cap { + feed.allowance = fps_cap; + } + if feed.allowance < 1.0 { + continue; // over the cap -> drop this frame + } + feed.allowance -= 1.0; + push_capture_ts(&source.ts_queue, frame.capture_ts); + encoder.write_frame(&frame.data); + } + }) + .expect("spawn broadcast encoder feed thread") +} + +/// Build a fresh shared ffmpeg encoder for `step`. Its per-access-unit callback +/// fans the encoded NAL units out to every currently-open consumer track for the +/// source: it stamps the access unit's shared capture timestamp on each track and +/// sends the same Annex-B bytes on each track's raw id. Each consumer's own chain +/// packetizes independently (its own SSRC/sequence). Returns `None` if ffmpeg could +/// not be spawned. The PR5.6 invariant guards are preserved (one VCL NAL / one +/// capture timestamp per access unit). +fn make_broadcast_encoder( + width: u32, + height: u32, + step: usize, + source: &VideoSource, + events: EventQueue, +) -> Option> { + let params = EncodeParams { + fps: rung_encoder_fps(step), + bitrate: LADDER[step].bitrate, + scale: LADDER[step].scale, + }; + let ts_queue = source.ts_queue.clone(); + let frames_encoded = source.frames_encoded.clone(); + let fanout = source.fanout.clone(); + let on_access_unit = move |access_unit: Vec>| { + // Invariant: one access unit -> exactly one RTP frame under one capture + // timestamp, so exactly one VCL NAL. A multi-slice/aggregation change would + // desync the timestamp queue and fabricate per-slice timestamps — the + // Chrome-only, loopback-invisible defect in reports/SPIKE-chrome-pframe.md. + // Fail loud and drop rather than send a malformed/fabricated frame. + let vcl_count = vcl_nal_count(&access_unit); + if vcl_count != 1 { + eprintln!( + "[ncwebrtc] INVARIANT VIOLATED: access unit has {vcl_count} VCL NAL(s), \ + expected exactly 1 (one slice per frame). Dropping rather than \ + fabricating RTP timestamps — see reports/SPIKE-chrome-pframe.md." + ); + return; + } + let buf = annexb_access_unit(&access_unit); + if buf.is_empty() { + return; + } + // One capture timestamp per access unit (shared across all consumer tracks + // for the frame). No silent fallback: an underflow means more access units + // than input frames (a multi-slice encode), so fail loud and drop. + let Some(ts) = lock(&ts_queue).pop_front() else { + eprintln!( + "[ncwebrtc] INVARIANT VIOLATED: capture-timestamp queue underflow \ + (more access units than input frames — multi-slice encode?). \ + Dropping rather than fabricating a timestamp — see \ + reports/SPIKE-chrome-pframe.md." + ); + return; + }; + // One encode -> fan out to every open consumer's track. This counter is the + // observable that proves the encode is shared (one bump per access unit, + // regardless of how many consumers it is sent to). + frames_encoded.fetch_add(1, Ordering::SeqCst); + let fan = lock(&fanout); + for (consumer_id, ft) in fan.iter() { + if !is_open(&ft.open) { + continue; + } + // SAFETY: `raw_id` is a live sys track id this broadcaster created and + // owns until the consumer leaves / the source is removed; both sys calls + // are libdatachannel-internally locked. + let sent = unsafe { + sys::rtcSetTrackRtpTimestamp(ft.raw_id, ts); + sys::rtcSendMessage( + ft.raw_id, + buf.as_ptr() as *const c_char, + buf.len() as c_int, + ) + }; + if sent < 0 { + // The track went away under us (consumer closed / SRTP torn down). + // Stop sending to it (flip its open flag) so this does not spam each + // frame, and surface one per-consumer error; the leave path reclaims + // the entry. The other consumers and the shared encode are untouched. + ft.open.store(false, Ordering::SeqCst); + events.push(Event::error_for( + consumer_id, + "send", + format!("send on a closed track for consumer {consumer_id:?}; suppressing"), + )); + } + } + }; + match H264Encoder::new(width, height, params, on_access_unit) { + Ok(encoder) => Some(Arc::new(encoder)), + Err(err) => { + crate::transport::debug_trace("B", &format!("encoder spawn failed: {err}")); + None + } + } +} + +#[cfg(test)] +mod tests { + //! Peer-free unit tests for the broadcaster's governance and fan-out + //! bookkeeping: the min-fold, join PLI suppression, and the fan-out set's + //! add/remove routing. None touch a socket, a peer, ffmpeg, or the GIL. + + use super::*; + use crate::congestion::bottom_step; + + // --- min-fold: the worst link caps everyone ------------------------------ + + #[test] + fn fold_rung_is_the_max_ladder_index_the_worst_link_caps_everyone() { + // A lower estimate is a coarser (higher-index) rung, so the min estimate is + // the max index. Two consumers on the finest rung and one on a coarse rung + // -> everyone encodes at the coarse rung. + assert_eq!(fold_rung(&[0, 0, 3]), 3); + assert_eq!(fold_rung(&[1, 2, 2]), 2); + } + + #[test] + fn a_newly_worst_consumer_lowers_the_shared_rung() { + let mut steps = vec![0usize, 0, 0]; + assert_eq!(fold_rung(&steps), 0); + // A consumer's link degrades to step 4 (the coarsest) -> shared rung = 4. + steps.push(4); + assert_eq!(fold_rung(&steps), bottom_step().min(4)); + } + + #[test] + fn fold_rung_over_zero_consumers_does_not_panic() { + // The last-leave case: an empty fold returns the finest rung, no panic. + assert_eq!(fold_rung(&[]), TOP_STEP); + } + + // --- join PLI suppression ------------------------------------------------ + + #[test] + fn a_joining_consumers_early_pli_is_suppressed() { + // Inside the grace window a joiner's PLI must not trigger a shared restart + // (which would blip every other consumer); the joiner waits for the next + // periodic IDR. + assert!(!should_honor_pli(0.0, JOINER_PLI_GRACE_S)); + assert!(!should_honor_pli( + JOINER_PLI_GRACE_S - 0.5, + JOINER_PLI_GRACE_S + )); + // An established track past the window: a PLI is real loss and is honoured. + assert!(should_honor_pli(JOINER_PLI_GRACE_S, JOINER_PLI_GRACE_S)); + assert!(should_honor_pli(10.0, JOINER_PLI_GRACE_S)); + } + + // --- fan-out set routing ------------------------------------------------- + + fn fan_track(raw_id: i32, open: bool) -> FanTrack { + let flag = open_flag(); + flag.store(open, Ordering::SeqCst); + FanTrack { + raw_id, + open: flag, + control: Arc::new(TrackControl::default()), + joined_at: Instant::now(), + } + } + + #[test] + fn the_fanout_set_updates_on_add_and_remove() { + let source = VideoSource::default(); + // Two consumers subscribe; both are in the fan-out set. + lock(&source.fanout).insert("c1".into(), fan_track(1, true)); + lock(&source.fanout).insert("c2".into(), fan_track(2, true)); + assert_eq!(lock(&source.fanout).len(), 2); + + // One consumer leaves -> it stops receiving (drops out of the set); the + // other is untouched. + lock(&source.fanout).remove("c1"); + let fan = lock(&source.fanout); + assert_eq!(fan.len(), 1); + assert!(fan.contains_key("c2")); + assert!(!fan.contains_key("c1")); + } + + #[test] + fn open_rungs_only_counts_open_tracks_and_folds_to_the_worst() { + let source = VideoSource::default(); + let a = fan_track(1, true); + a.control.set_step(1); + let b = fan_track(2, true); + b.control.set_step(3); + let c = fan_track(3, false); // still negotiating: excluded from the fold + c.control.set_step(4); + lock(&source.fanout).insert("a".into(), a); + lock(&source.fanout).insert("b".into(), b); + lock(&source.fanout).insert("c".into(), c); + + let steps = open_rungs(&source); + assert_eq!(steps.len(), 2, "the not-yet-open track is excluded"); + // The shared rung is the worst (coarsest) of the open tracks. + assert_eq!(fold_rung(&steps), 3); + } + + #[test] + fn coalesce_pli_suppresses_a_joiner_but_honours_an_established_track() { + let source = VideoSource::default(); + // An established track (joined long ago) with a pending PLI: honoured. + let mut established = fan_track(1, true); + established.joined_at = Instant::now() - std::time::Duration::from_secs(10); + established.control.request_pli(); + lock(&source.fanout).insert("old".into(), established); + assert!( + coalesce_pli(&source), + "an established track's PLI is honoured" + ); + + // A fresh joiner with a pending PLI: suppressed (drained, not honoured). + let joiner = fan_track(2, true); + joiner.control.request_pli(); + lock(&source.fanout).insert("new".into(), joiner); + assert!( + !coalesce_pli(&source), + "a freshly joined consumer's PLI must not restart the shared encode" + ); + } +} diff --git a/rust/neuracore_webrtc/src/congestion.rs b/rust/neuracore_webrtc/src/congestion.rs new file mode 100644 index 000000000..f0835c9d1 --- /dev/null +++ b/rust/neuracore_webrtc/src/congestion.rs @@ -0,0 +1,671 @@ +//! Queue-driven congestion adaptation: the RTCP feedback the producer reads, and +//! the lightweight estimator that turns it into a rung on the [`LADDER`]. +//! +//! libdatachannel 0.23.2 implements **no** transport-cc and does **no** bandwidth +//! estimation — see `reports/SPIKE-pr5-media-chain.md`. The only two real signals +//! the producer can read are: +//! +//! * **REMB** — the receiver-estimated max bitrate, delivered through the chain's +//! [`rtcChainRembHandler`] callback. A real browser (Chrome) computes this from +//! its own receive-side bandwidth estimator; the libdatachannel loopback +//! consumer only echoes whatever `rtcRequestBitrate` was set to, so REMB is the +//! *Chrome-path* driver. +//! * **RTCP RR** — receiver reports carrying `fraction_lost` and `jitter`, which +//! the consumer's receiving session computes from the real sequence-number gaps. +//! RR is the *loopback-path* driver (it reacts to netem rate/loss for real, +//! where loopback REMB cannot). +//! +//! Neither signal is transport-cc and there is no full GCC here: the estimator is +//! a deliberately small loss/headroom controller over a fixed [`LADDER`]. It is a +//! pure seam — every decision is a function of the samples it is fed and a clock +//! passed in — so it is exercised by a fake clock in the unit tests below without +//! a peer, a socket, or live media. +//! +//! ## What lives here +//! +//! - [`Step`] / [`LADDER`] — the adaptation rungs (input-fps cap, resolution +//! scale, target encoder bitrate). +//! - [`parse_rtcp_reports`] — hand-parses fraction-lost/jitter out of a compound +//! RTCP RR (or SR) packet, because the C API exposes no RR decoder. +//! - [`Estimator`] — folds REMB + RR samples into a ladder step, degrading on +//! pressure and recovering conservatively (sticky). + +use std::sync::atomic::{AtomicU32, Ordering}; +use std::sync::Arc; + +/// One rung of the adaptation ladder. `fps_cap` caps the *input* frame rate fed +/// to the encoder (shed toward the 30 fps floor first); `scale` divides each +/// spatial axis (1 = full resolution, 2 = half); `bitrate` is the encoder's +/// target bits/sec. Coarser rungs combine a lower fps, a lower bitrate, and +/// eventually a downscale. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub(crate) struct Step { + pub fps_cap: u32, + pub scale: u32, + pub bitrate: u32, +} + +/// The fixed adaptation ladder, finest (index 0) to coarsest. Step 0 is +/// effectively unconstrained for a 45 fps source; step 1 sheds input fps to the +/// 30 fps floor (the cheap first move); steps 2+ lower the bitrate and finally +/// downscale for larger or sustained pressure. The 30 fps floor is held from +/// step 1 on, so adaptation never sheds *delivered* fps below the contract floor. +pub(crate) const LADDER: &[Step] = &[ + Step { fps_cap: 60, scale: 1, bitrate: 2_500_000 }, + Step { fps_cap: 40, scale: 1, bitrate: 1_500_000 }, + Step { fps_cap: 36, scale: 1, bitrate: 900_000 }, + Step { fps_cap: 36, scale: 2, bitrate: 500_000 }, + Step { fps_cap: 36, scale: 2, bitrate: 300_000 }, +]; + +/// The step the encoder starts on (finest). +pub(crate) const TOP_STEP: usize = 0; + +/// Highest (coarsest) ladder index. +pub(crate) fn bottom_step() -> usize { + LADDER.len() - 1 +} + +// --- tuning ---------------------------------------------------------------- + +/// REMB below `committed_bitrate * REMB_PRESSURE` signals the link cannot carry +/// the current rung: degrade. (Headroom for protocol overhead/burstiness.) +const REMB_PRESSURE: f64 = 0.85; +/// REMB below `committed_bitrate * REMB_SEVERE` is acute under-provisioning: +/// degrade immediately rather than waiting out the degrade window. +const REMB_SEVERE: f64 = 0.5; +/// RR fraction-lost above this sustained for the degrade window is pressure. +const LOSS_PRESSURE: f64 = 0.02; +/// RR fraction-lost above this is acute loss: degrade immediately. +const LOSS_SEVERE: f64 = 0.10; +/// Sustained pressure must persist this long before a mild degrade fires (so a +/// single noisy sample does not move the ladder). +const DEGRADE_WINDOW_S: f64 = 1.5; +/// Recovery is deliberately slow and sticky: the link must look clear this much +/// longer than a degrade before the estimator steps back up one rung. It is an +/// order of magnitude longer than the degrade window so the ladder settles on a +/// fitting rung and does not oscillate back into loss — a recovery into a rung +/// the link cannot carry re-loses and re-degrades, and that lossy excursion shows +/// as corruption (inter-frame error propagation). Recovery is especially cautious +/// because the only headroom signal here (loopback REMB) is a fixed echo that +/// cannot veto a premature step-up. +const RECOVERY_WINDOW_S: f64 = 20.0; +/// Only recover when REMB shows headroom for the *finer* rung we would move to +/// (its bitrate times this margin), so we do not immediately re-degrade. +const RECOVERY_HEADROOM: f64 = 1.25; + +/// One feedback observation. Either field may be absent: the loopback path only +/// has useful RR, the Chrome path drives REMB. The estimator uses whichever is +/// present, taking the worse of the two when both are. +#[derive(Debug, Clone, Copy, Default)] +pub(crate) struct Sample { + /// Receiver-estimated max bitrate in bits/sec, if a REMB arrived. + pub remb_bps: Option, + /// RR fraction lost in 0.0..=1.0, if an RR arrived. + pub fraction_lost: Option, +} + +/// The pure congestion estimator. Folds [`Sample`]s into a [`LADDER`] index, +/// degrading promptly under pressure and recovering slowly when the link looks +/// clear. Holds no clock of its own — every method takes `now_s` — so the unit +/// tests drive it deterministically. +#[derive(Debug)] +pub(crate) struct Estimator { + step: usize, + /// When the link first looked pressured since the last clear sample. + pressured_since: Option, + /// When the link first looked clear since the last pressured sample. + clear_since: Option, + /// Last REMB seen, carried so a recovery decision can check headroom even on + /// an RR-only sample. + last_remb: Option, +} + +impl Default for Estimator { + fn default() -> Self { + Self { + step: TOP_STEP, + pressured_since: None, + clear_since: None, + last_remb: None, + } + } +} + +impl Estimator { + /// The current ladder index. (Read by the tests; the live path uses the + /// value [`observe`](Self::observe) returns.) + #[allow(dead_code)] + pub(crate) fn step(&self) -> usize { + self.step + } + + /// Fold one observation in at time `now_s` and return the (possibly changed) + /// ladder index. Degrades immediately on severe pressure, after + /// [`DEGRADE_WINDOW_S`] on mild pressure, and recovers one rung only after a + /// longer clear window with REMB headroom. + pub(crate) fn observe(&mut self, sample: Sample, now_s: f64) -> usize { + if let Some(remb) = sample.remb_bps { + self.last_remb = Some(remb); + } + let committed = LADDER[self.step].bitrate as f64; + + let remb_ratio = sample.remb_bps.map(|r| r as f64 / committed); + let loss = sample.fraction_lost.unwrap_or(0.0); + + let severe = matches!(remb_ratio, Some(r) if r < REMB_SEVERE) || loss > LOSS_SEVERE; + let mild = severe + || matches!(remb_ratio, Some(r) if r < REMB_PRESSURE) + || loss > LOSS_PRESSURE; + + if mild { + self.clear_since = None; + let since = *self.pressured_since.get_or_insert(now_s); + let sustained = now_s - since >= DEGRADE_WINDOW_S; + if (severe || sustained) && self.step < bottom_step() { + self.step += 1; + // Reset the windows: re-measure pressure/clearness against the + // new rung rather than carrying the old timer across a move. + self.pressured_since = None; + self.clear_since = None; + } + } else { + self.pressured_since = None; + let since = *self.clear_since.get_or_insert(now_s); + let clear_long = now_s - since >= RECOVERY_WINDOW_S; + if clear_long && self.step > TOP_STEP && self.has_recovery_headroom() { + self.step -= 1; + self.pressured_since = None; + self.clear_since = None; + } + } + self.step + } + + /// Whether REMB (if known) shows enough headroom to move up to the next finer + /// rung without immediately re-degrading. With no REMB (the loopback path), + /// a clear loss window alone authorises recovery. + fn has_recovery_headroom(&self) -> bool { + match self.last_remb { + Some(remb) => { + let finer = LADDER[self.step - 1].bitrate as f64; + remb as f64 >= finer * RECOVERY_HEADROOM + } + None => true, + } + } +} + +// --------------------------------------------------------------------------- +// RTCP RR / SR report-block parsing (hand-rolled; the C API decodes none) +// --------------------------------------------------------------------------- + +/// One RTCP report block's loss + jitter figures, as parsed off the wire. +#[derive(Debug, Clone, Copy, PartialEq)] +pub(crate) struct ReportBlock { + /// SSRC the report is *about* (our outgoing media SSRC). + pub ssrc: u32, + /// Loss fraction since the last report, in 0.0..=1.0 (the wire byte / 256). + pub fraction_lost: f64, + /// Cumulative packets lost (24-bit signed on the wire; widened here). + pub cumulative_lost: u32, + /// Interarrival jitter in RTP timestamp units. + pub jitter: u32, +} + +const RTCP_HEADER_LEN: usize = 4; +const PT_SR: u8 = 200; +const PT_RR: u8 = 201; +const REPORT_BLOCK_LEN: usize = 24; +/// Bytes from an RR packet's start to its first report block: the 4-byte common +/// header plus the 4-byte reporter SSRC. +const RR_REPORTS_OFFSET: usize = 8; +/// Bytes from an SR packet's start to its first report block: the 4-byte common +/// header, the 4-byte reporter SSRC, and the 20-byte sender info. +const SR_REPORTS_OFFSET: usize = 28; + +/// Parse every report block out of a (possibly compound) RTCP packet. Walks each +/// sub-packet by its length field and pulls report blocks from RR (PT=201) and +/// SR (PT=200) packets; all other packet types (REMB/PLI/NACK/BYE/SDES) are +/// skipped. Tolerates a truncated or non-RTCP buffer by returning what it could +/// read. Pure — unit-tested against synthetic compound RTCP. +pub(crate) fn parse_rtcp_reports(buf: &[u8]) -> Vec { + let mut out = Vec::new(); + let mut off = 0usize; + while off + RTCP_HEADER_LEN <= buf.len() { + let rc = (buf[off] & 0x1F) as usize; // reception report count + let pt = buf[off + 1]; + // length is in 32-bit words minus one; total packet bytes = (len+1)*4. + let len_words = u16::from_be_bytes([buf[off + 2], buf[off + 3]]) as usize; + let packet_len = (len_words + 1) * 4; + if packet_len < RTCP_HEADER_LEN || off + packet_len > buf.len() { + break; + } + let reports_at = match pt { + PT_RR => Some(off + RR_REPORTS_OFFSET), + PT_SR => Some(off + SR_REPORTS_OFFSET), + _ => None, + }; + if let Some(mut block_off) = reports_at { + for _ in 0..rc { + if block_off + REPORT_BLOCK_LEN > off + packet_len { + break; + } + out.push(parse_report_block(&buf[block_off..block_off + REPORT_BLOCK_LEN])); + block_off += REPORT_BLOCK_LEN; + } + } + off += packet_len; + } + out +} + +/// Decode one 24-byte RTCP report block. +fn parse_report_block(b: &[u8]) -> ReportBlock { + let ssrc = u32::from_be_bytes([b[0], b[1], b[2], b[3]]); + let fraction_lost = b[4] as f64 / 256.0; + let cumulative_lost = u32::from_be_bytes([0, b[5], b[6], b[7]]); + let jitter = u32::from_be_bytes([b[12], b[13], b[14], b[15]]); + ReportBlock { + ssrc, + fraction_lost, + cumulative_lost, + jitter, + } +} + +// --------------------------------------------------------------------------- +// Shared effect surface +// --------------------------------------------------------------------------- + +/// The per-track effect surface the estimator writes and the encoder feed reads. +/// The estimator (driven on libdatachannel's RTCP callback threads) only flips +/// these atomics; the feed thread applies them — caps input fps and restarts the +/// ffmpeg subprocess at the new rung — so no libdatachannel callback ever blocks +/// on an ffmpeg restart. `pli_pending` coalesces PLIs into a single keyframe +/// request the feed satisfies via a restart. +#[derive(Debug)] +pub(crate) struct TrackControl { + desired_step: AtomicU32, + /// The coarsest rung ever requested (a high-water mark), so a test or operator + /// can see that adaptation fired even after the link recovered and the rung + /// stepped back up. + max_step: AtomicU32, + pli_pending: std::sync::atomic::AtomicBool, +} + +impl Default for TrackControl { + fn default() -> Self { + Self { + desired_step: AtomicU32::new(TOP_STEP as u32), + max_step: AtomicU32::new(TOP_STEP as u32), + pli_pending: std::sync::atomic::AtomicBool::new(false), + } + } +} + +impl TrackControl { + /// Publish the ladder rung the estimator now wants, advancing the high-water + /// mark if this is the coarsest rung seen so far. + pub(crate) fn set_step(&self, step: usize) { + self.desired_step.store(step as u32, Ordering::SeqCst); + self.max_step.fetch_max(step as u32, Ordering::SeqCst); + } + + /// The rung the feed thread should be encoding at. + pub(crate) fn desired_step(&self) -> usize { + self.desired_step.load(Ordering::SeqCst) as usize + } + + /// The coarsest rung the estimator ever requested. + pub(crate) fn max_step(&self) -> usize { + self.max_step.load(Ordering::SeqCst) as usize + } + + /// Coalesce a PLI: record that a keyframe is wanted. + pub(crate) fn request_pli(&self) { + self.pli_pending.store(true, Ordering::SeqCst); + } + + /// Take the coalesced PLI request (true at most once per burst). + pub(crate) fn take_pli(&self) -> bool { + self.pli_pending.swap(false, Ordering::SeqCst) + } +} + +/// The estimator plus its effect surface, shared between the RTCP callbacks +/// (which observe) and the feed thread (which applies). The estimator is behind a +/// `Mutex` because the REMB callback, the RR callback, and the PLI callback all +/// fire on libdatachannel threads. +/// How long after the first feedback sample the controller ignores REMB/RR. The +/// connection's startup (the stash flush, the first IDR, ICE settling) produces a +/// transient loss/jitter spike on an otherwise clean link; acting on it would +/// degrade quality on a link that is actually fine. A real sustained constraint +/// (netem) persists well past this, so the warmup only suppresses the transient. +const WARMUP_S: f64 = 2.5; + +pub(crate) struct CongestionController { + estimator: std::sync::Mutex, + control: Arc, + /// Our committed outgoing SSRC, so RR report blocks for other SSRCs are + /// ignored. + ssrc: u32, + /// Seconds the controller suppresses feedback after its first sample. + warmup_s: f64, + /// Wall-clock seconds of the first observed sample, set lazily. + started_at: std::sync::Mutex>, +} + +impl CongestionController { + pub(crate) fn new(ssrc: u32, control: Arc) -> Self { + Self::with_warmup(ssrc, control, WARMUP_S) + } + + /// Construct with an explicit warmup (0.0 disables it, for deterministic + /// unit tests of the REMB/RR wiring). + pub(crate) fn with_warmup(ssrc: u32, control: Arc, warmup_s: f64) -> Self { + Self { + estimator: std::sync::Mutex::new(Estimator::default()), + control, + ssrc, + warmup_s, + started_at: std::sync::Mutex::new(None), + } + } + + /// Whether the controller is still inside its startup warmup at `now`. + fn in_warmup(&self, now: f64) -> bool { + let mut started = self.started_at.lock().unwrap_or_else(|e| e.into_inner()); + let start = *started.get_or_insert(now); + now - start < self.warmup_s + } + + /// A monotonic seconds clock for the live path (the unit tests pass their own). + fn now_s() -> f64 { + std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .map(|d| d.as_secs_f64()) + .unwrap_or(0.0) + } + + /// Feed a REMB estimate (bits/sec) in. + pub(crate) fn on_remb(&self, bitrate_bps: u32) { + if self.in_warmup(Self::now_s()) { + return; + } + let step = { + let mut est = self.estimator.lock().unwrap_or_else(|e| e.into_inner()); + est.observe( + Sample { + remb_bps: Some(bitrate_bps), + ..Default::default() + }, + Self::now_s(), + ) + }; + self.control.set_step(step); + } + + /// Feed raw inbound RTCP in: parse RR report blocks for our SSRC and fold + /// their loss/jitter into the estimator. + pub(crate) fn on_rtcp(&self, buf: &[u8]) { + if self.in_warmup(Self::now_s()) { + return; + } + let mut applied = None; + for block in parse_rtcp_reports(buf) { + if block.ssrc != self.ssrc { + continue; + } + // Jitter is parsed for observability (the report records it); loss + // fraction is what actually drives the ladder. + crate::transport::debug_trace( + "P", + &format!( + "rr ssrc={:#x} loss={:.3} jitter={}", + block.ssrc, block.fraction_lost, block.jitter + ), + ); + let mut est = self.estimator.lock().unwrap_or_else(|e| e.into_inner()); + applied = Some(est.observe( + Sample { + fraction_lost: Some(block.fraction_lost), + ..Default::default() + }, + Self::now_s(), + )); + } + if let Some(step) = applied { + self.control.set_step(step); + } + } + + /// A PLI arrived: coalesce it into the feed's keyframe request. + pub(crate) fn on_pli(&self) { + self.control.request_pli(); + } +} + +#[cfg(test)] +mod tests { + //! Peer-free, clock-injected tests for the estimator, the RR parser, and the + //! REMB-consumption path. None touch a socket, a peer, or live media. + + use super::*; + + // --- RR / SR parsing ----------------------------------------------------- + + /// Build one RTCP RR with a single report block carrying `fraction_lost` + /// (a raw byte) and `jitter`, reporting on `ssrc`. + fn rr_packet(reporter: u32, ssrc: u32, fraction_lost: u8, cumulative: u32, jitter: u32) -> Vec { + let mut p = Vec::new(); + p.push(0x80 | 1); // V=2, P=0, RC=1 + p.push(PT_RR); + // length in words minus one: header(1) + reporter ssrc(1) + 6 words block = 8 -> len 7 + p.extend_from_slice(&7u16.to_be_bytes()); + p.extend_from_slice(&reporter.to_be_bytes()); + // report block + p.extend_from_slice(&ssrc.to_be_bytes()); + p.push(fraction_lost); + p.extend_from_slice(&cumulative.to_be_bytes()[1..]); // 24-bit + p.extend_from_slice(&0u32.to_be_bytes()); // ext highest seq + p.extend_from_slice(&jitter.to_be_bytes()); + p.extend_from_slice(&0u32.to_be_bytes()); // lsr + p.extend_from_slice(&0u32.to_be_bytes()); // dlsr + p + } + + #[test] + fn parses_fraction_lost_and_jitter_from_an_rr() { + // 13% loss -> byte 33 (33/256 ~= 0.129). + let pkt = rr_packet(0x1111, 0xCAFE, 33, 7, 4096); + let blocks = parse_rtcp_reports(&pkt); + assert_eq!(blocks.len(), 1); + assert_eq!(blocks[0].ssrc, 0xCAFE); + assert!((blocks[0].fraction_lost - 33.0 / 256.0).abs() < 1e-9); + assert_eq!(blocks[0].cumulative_lost, 7); + assert_eq!(blocks[0].jitter, 4096); + } + + #[test] + fn parses_report_blocks_out_of_a_compound_packet_after_an_sr() { + // Compound: a minimal SR (no report blocks) followed by an RR with one. + let mut sr = Vec::new(); + sr.push(0x80); // V=2, RC=0 + sr.push(PT_SR); + sr.extend_from_slice(&6u16.to_be_bytes()); // header+ssrc+sender info = 7 words -> len 6 + sr.extend_from_slice(&0x2222u32.to_be_bytes()); // ssrc + sr.extend_from_slice(&[0u8; 20]); // sender info + let rr = rr_packet(0x2222, 0xBEEF, 8, 2, 100); + let mut compound = sr; + compound.extend_from_slice(&rr); + + let blocks = parse_rtcp_reports(&compound); + assert_eq!(blocks.len(), 1, "the SR carried no blocks; the RR carried one"); + assert_eq!(blocks[0].ssrc, 0xBEEF); + assert_eq!(blocks[0].jitter, 100); + } + + #[test] + fn tolerates_reordered_and_truncated_rtcp_without_panicking() { + // A truncated tail after a valid RR: parser returns the good block and + // stops at the bad length rather than reading out of bounds. + let mut pkt = rr_packet(1, 2, 5, 1, 9); + pkt.extend_from_slice(&[0x81, PT_RR, 0xFF, 0xFF]); // claims a huge length + let blocks = parse_rtcp_reports(&pkt); + assert_eq!(blocks.len(), 1); + // Pure garbage parses to nothing, no panic. + assert!(parse_rtcp_reports(&[0xAA, 0xBB]).is_empty()); + } + + // --- estimator: degrade on loss (the loopback/RR path) ------------------- + + fn loss(f: f64) -> Sample { + Sample { + fraction_lost: Some(f), + ..Default::default() + } + } + fn remb(bps: u32) -> Sample { + Sample { + remb_bps: Some(bps), + ..Default::default() + } + } + + #[test] + fn severe_loss_degrades_immediately_one_rung() { + let mut est = Estimator::default(); + assert_eq!(est.step(), 0); + // 20% loss is above LOSS_SEVERE -> immediate single-rung degrade. + assert_eq!(est.observe(loss(0.20), 0.0), 1); + assert_eq!(est.observe(loss(0.20), 0.1), 2); + } + + #[test] + fn mild_loss_only_degrades_after_the_sustained_window() { + let mut est = Estimator::default(); + // 3% loss is mild (> LOSS_PRESSURE, < LOSS_SEVERE): no move yet. + assert_eq!(est.observe(loss(0.03), 0.0), 0); + assert_eq!(est.observe(loss(0.03), 1.0), 0, "still inside the degrade window"); + // Past DEGRADE_WINDOW_S of sustained mild pressure -> one rung down. + assert_eq!(est.observe(loss(0.03), 1.6), 1); + } + + #[test] + fn a_single_clear_sample_resets_the_degrade_window() { + let mut est = Estimator::default(); + assert_eq!(est.observe(loss(0.03), 0.0), 0); + // A clear sample mid-window cancels the pending degrade. + assert_eq!(est.observe(loss(0.0), 1.0), 0); + assert_eq!(est.observe(loss(0.03), 1.4), 0, "window restarts from the new pressure"); + assert_eq!(est.observe(loss(0.03), 3.0), 1); + } + + // --- estimator: REMB path (the Chrome path) ------------------------------ + + #[test] + fn remb_below_severe_fraction_degrades_immediately() { + let mut est = Estimator::default(); + // committed at step 0 is 2.5Mbit; REMB 1.0Mbit is < 0.5x -> severe. + assert_eq!(est.observe(remb(1_000_000), 0.0), 1); + } + + #[test] + fn remb_mild_pressure_waits_for_the_window() { + let mut est = Estimator::default(); + // step1 committed 1.5Mbit; REMB 1.2Mbit is 0.8x (< 0.85 pressure, > 0.5). + est.observe(loss(0.20), 0.0); // -> step 1 fast + assert_eq!(est.step(), 1); + assert_eq!(est.observe(remb(1_200_000), 10.0), 1, "mild, window restarts"); + assert_eq!(est.observe(remb(1_200_000), 11.6), 2, "sustained -> degrade"); + } + + // --- estimator: conservative, sticky recovery ---------------------------- + + #[test] + fn recovery_is_slow_sticky_and_needs_headroom() { + let mut est = Estimator::default(); + // Drive down to step 2 on severe loss. + est.observe(loss(0.20), 0.0); + est.observe(loss(0.20), 0.1); + assert_eq!(est.step(), 2); + + // Clear loss but no REMB headroom info yet: with no REMB, a long clear + // window recovers one rung (loopback path). + assert_eq!(est.observe(loss(0.0), 1.0), 2, "inside the recovery window"); + assert_eq!(est.observe(loss(0.0), 15.0), 2, "still inside recovery window"); + assert_eq!( + est.observe(loss(0.0), 21.5), + 1, + "recovers exactly one rung only after the long clear window (since 1.0)" + ); + // Recovery is one rung at a time: another full clear window for the next. + assert_eq!(est.observe(loss(0.0), 22.0), 1); + assert_eq!(est.observe(loss(0.0), 42.5), 0); + } + + #[test] + fn remb_recovery_requires_headroom_for_the_finer_rung() { + let mut est = Estimator::default(); + est.observe(loss(0.20), 0.0); + est.observe(loss(0.20), 0.1); + assert_eq!(est.step(), 2); // step 1 bitrate is 1.5Mbit + + // Clear loss, but REMB only 1.6Mbit: finer rung (step 1) needs + // 1.5M * 1.25 = 1.875M of headroom, so even past the clear window + // recovery is withheld. + est.observe(remb(1_600_000), 1.0); + assert_eq!( + est.observe(remb(1_600_000), 22.0), + 2, + "past the window but no headroom -> stay put" + ); + // REMB rises above the headroom bar -> recover one rung. + est.observe(remb(2_000_000), 23.0); + assert_eq!(est.observe(remb(2_000_000), 44.0), 1); + } + + #[test] + fn never_degrades_past_the_bottom_rung() { + let mut est = Estimator::default(); + for t in 0..20 { + est.observe(loss(0.5), t as f64 * 0.1); + } + assert_eq!(est.step(), bottom_step()); + } + + // --- controller wiring: REMB + RR -> the shared control ------------------ + + #[test] + fn controller_publishes_the_estimator_step_from_remb() { + let control = Arc::new(TrackControl::default()); + let ctrl = CongestionController::with_warmup(0xABCD, control.clone(), 0.0); + assert_eq!(control.desired_step(), 0); + ctrl.on_remb(900_000); // < 0.5x of step0's 2.5Mbit -> degrade + assert_eq!(control.desired_step(), 1); + } + + #[test] + fn controller_only_acts_on_rr_blocks_for_our_ssrc() { + let control = Arc::new(TrackControl::default()); + let ctrl = CongestionController::with_warmup(0xABCD, control.clone(), 0.0); + // RR about a different ssrc is ignored. + ctrl.on_rtcp(&rr_packet(1, 0x9999, 200, 50, 0)); + assert_eq!(control.desired_step(), 0); + // RR about our ssrc with severe loss degrades. + ctrl.on_rtcp(&rr_packet(1, 0xABCD, 200, 50, 0)); + assert_eq!(control.desired_step(), 1); + } + + #[test] + fn pli_coalesces_into_a_single_keyframe_request() { + let control = Arc::new(TrackControl::default()); + let ctrl = CongestionController::with_warmup(1, control.clone(), 0.0); + ctrl.on_pli(); + ctrl.on_pli(); + ctrl.on_pli(); + assert!(control.take_pli(), "a burst of PLIs is one pending request"); + assert!(!control.take_pli(), "taken exactly once"); + } +} diff --git a/rust/neuracore_webrtc/src/consumer.rs b/rust/neuracore_webrtc/src/consumer.rs new file mode 100644 index 000000000..3b6adb5ed --- /dev/null +++ b/rust/neuracore_webrtc/src/consumer.rs @@ -0,0 +1,553 @@ +//! The consumer peer: answer-only. It never offers and never opens channels; it +//! receives the producer's offer, lets libdatachannel auto-answer it, and +//! surfaces remote data channels, their messages, the control-channel manifest, +//! and connection state on its drainable event queue. +//! +//! ## Video tracks are observed via the manifest, not a track callback +//! +//! datachannel-rs (libdatachannel) exposes **no** incoming-track callback on the +//! peer-connection handler, so the consumer cannot learn of a producer-added video +//! track from the SDP renegotiation directly. Instead the producer republishes the +//! control-channel manifest atomically on every track add/remove, and the consumer +//! derives `on_track_added` / `on_track_removed` by diffing each manifest against +//! the previously-known video-track set. The manifest is the canonical +//! stream-identity channel (see the locked design), so this is the authoritative +//! signal — the SDP renegotiation still happens underneath and brings the media +//! m-line up; PR4's receive/decode path keys off the same mid. + +use std::collections::HashMap; +use std::os::raw::{c_char, c_int, c_void}; +use std::sync::atomic::AtomicBool; +use std::sync::{Arc, Mutex}; + +use crate::media::RestartPolicy; + +use datachannel::{ + ConnectionState, DataChannelHandler, DataChannelInfo, IceCandidate, IceState, + PeerConnectionHandler, RtcDataChannel, RtcPeerConnection, SdpType, SessionDescription, +}; +use datachannel_sys as sys; +use once_cell::sync::Lazy; +use pyo3::exceptions::PyValueError; +use pyo3::prelude::*; +use pyo3::types::PyList; +use serde_json::Value; + +use crate::events::{emit_closed_once, Event, EventQueue}; +use crate::media::{H264Decoder, RtpDepacketizer}; +use crate::runtime::ensure_started; +use crate::transport::{ + connection_state_str, debug_trace, lock, loopback_config, map_err, parse_session, raw_pc_id, + reliability_kind_hint, sdp_type_str, CONTROL_LABEL, +}; + +/// The bitrate (bits/sec) the consumer requests via `rtcRequestBitrate`, which is +/// what makes its receiving session emit REMB toward the producer. On the +/// libdatachannel loopback this is a fixed echoed number (the library does no +/// bandwidth estimation); a real browser computes its own REMB. We request a high +/// ceiling so REMB never *itself* throttles — the producer's estimator owns the +/// adaptation. See `reports/SPIKE-pr5-media-chain.md` §3. +const REQUESTED_BITRATE_BPS: u32 = 8_000_000; + +/// Decoded-frame dimensions. The synthetic source is fixed 640x480 rgb24; the +/// producer encodes that and the consumer decodes it back to the same shape. A +/// later PR can carry per-track dimensions in the manifest if sources vary. +const FRAME_WIDTH: u32 = 640; +const FRAME_HEIGHT: u32 = 480; + +/// Shared mid -> track_id view, written by the control-channel manifest diff and +/// read by an inbound track's frame emitter so `on_frame` carries the app track id. +type MidToTrack = Arc>>; + +/// Inbound channels kept alive for the connection's lifetime. The `Box` is +/// mandatory and cannot be elided: libdatachannel stores a raw pointer to each +/// `RtcDataChannel`'s heap location (`rtcSetUserPointer`), so the value must not +/// move — clippy's `vec_box` suggestion to unbox is wrong here. +#[allow(clippy::vec_box)] +type IncomingChannels = Vec>>; + +/// Per-data-channel handler for the consumer's inbound channels. Surfaces +/// application messages as `on_message`, and control-channel payloads as +/// `on_manifest` plus the derived `on_track_added` / `on_track_removed`. +pub(crate) struct ConsumerChannelHandler { + events: EventQueue, + label: String, + is_control: bool, + /// Video tracks known from the last manifest, mid -> track_id. Only the + /// control channel's handler uses this; it persists for the connection's + /// lifetime (the control channel opens once), and on_message is `&mut self`. + known_tracks: HashMap, + /// The shared mid -> track_id view the inbound-track frame emitter reads, so + /// each decoded `on_frame` carries the application track id. Updated on every + /// manifest from the control channel's handler. + mid_to_track: MidToTrack, +} + +impl DataChannelHandler for ConsumerChannelHandler { + fn on_message(&mut self, msg: &[u8]) { + if self.is_control { + // The manifest is JSON text; ignore anything non-UTF-8 on control. + if let Ok(json) = std::str::from_utf8(msg) { + self.diff_tracks(json); + self.events.push(Event::Manifest { + json: json.to_string(), + }); + } + } else { + self.events.push(Event::Message { + label: self.label.clone(), + data: msg.to_vec(), + }); + } + } +} + +impl ConsumerChannelHandler { + /// Diff this manifest's video-track entries against the previously-known set + /// and emit `on_track_added` (new mids) and `on_track_removed` (vanished + /// mids). Idempotent: a manifest republished for a non-track change (e.g. a + /// data-channel add) produces no track events. + fn diff_tracks(&mut self, json: &str) { + let Ok(Value::Object(map)) = serde_json::from_str::(json) else { + // Not a manifest object; leave the known set untouched. + return; + }; + let mut current: HashMap = HashMap::new(); + for (key, value) in &map { + if value.get("type").and_then(Value::as_str) == Some("video_track") { + let track_id = value + .get("track_id") + .and_then(Value::as_str) + .unwrap_or_default() + .to_string(); + current.insert(key.clone(), track_id); + } + } + + for (mid, track_id) in ¤t { + if !self.known_tracks.contains_key(mid) { + self.events.push(Event::TrackAdded { + track_id: track_id.clone(), + mid: mid.clone(), + }); + } + } + for mid in self.known_tracks.keys() { + if !current.contains_key(mid) { + self.events.push(Event::TrackRemoved { mid: mid.clone() }); + } + } + // Publish the latest mid -> track_id view for the inbound-track emitter. + *lock(&self.mid_to_track) = current.clone(); + self.known_tracks = current; + } +} + +/// Peer-connection handler for the consumer: relays signaling/state callbacks +/// and adopts inbound data channels. +pub(crate) struct ConsumerHandler { + events: EventQueue, + /// Inbound channels are kept alive here; dropping a channel deletes it in + /// libdatachannel, which would stop delivering its messages. + incoming: Arc>, + /// Shared mid -> track_id view, handed to each channel handler so the control + /// channel's manifest diff can publish it for the inbound-track emitter. + mid_to_track: MidToTrack, +} + +impl PeerConnectionHandler for ConsumerHandler { + type DCH = ConsumerChannelHandler; + + fn data_channel_handler(&mut self, info: DataChannelInfo) -> Self::DCH { + ConsumerChannelHandler { + events: self.events.clone(), + label: info.label.clone(), + is_control: info.label == CONTROL_LABEL, + known_tracks: HashMap::new(), + mid_to_track: self.mid_to_track.clone(), + } + } + + fn on_description(&mut self, sess_desc: SessionDescription) { + // The consumer is answer-only, so this is the auto-generated answer. + self.events.push(Event::LocalDescription { + sdp_type: sdp_type_str(&sess_desc.sdp_type).to_string(), + sdp: sess_desc.sdp.to_string(), + }); + } + + fn on_candidate(&mut self, cand: IceCandidate) { + self.events.push(Event::LocalCandidate { + candidate: cand.candidate, + mid: Some(cand.mid), + }); + } + + fn on_connection_state_change(&mut self, state: ConnectionState) { + crate::transport::debug_trace("C", connection_state_str(&state)); + // The constructor emits the initial "new"; skip the duplicate. + if state == ConnectionState::New { + return; + } + self.events + .push(Event::State(connection_state_str(&state).to_string())); + } + + fn on_ice_state_change(&mut self, state: IceState) { + crate::transport::debug_trace("C", &format!("ice:{state:?}")); + } + + fn on_data_channel(&mut self, data_channel: Box>) { + let label = data_channel.label(); + // The control channel is the manifest transport, not an application + // stream; do not surface it as a data channel. + if label != CONTROL_LABEL { + let kind_hint = reliability_kind_hint(&data_channel.reliability()); + self.events.push(Event::DataChannel { label, kind_hint }); + } + lock(&self.incoming).push(data_channel); + } +} + +// --------------------------------------------------------------------------- +// Inbound media: receive, depacketize, decode, emit on_frame +// --------------------------------------------------------------------------- +// +// datachannel-rs 0.16's `PeerConnectionHandler` has no `on_track`, so the +// consumer cannot learn of a producer-added media track through the safe API. +// We register libdatachannel's C track callback directly via the sys layer +// (`rtcSetTrackCallback`) to adopt each inbound track by its raw id, set a +// message callback on it to receive its RTP, depacketize FU-A back into NAL +// units, feed them to a per-track ffmpeg decoder, and surface each decoded +// picture as an `on_frame` event. +// +// Both C callbacks are plain `extern "C"` functions with no closure environment, +// so they route through process-global registries keyed by the libdatachannel +// integer ids: `MEDIA` maps a peer-connection id to its `ConsumerMedia`, and +// `TRACK_PC` maps an inbound track id back to its peer-connection id. We never +// touch the peer connection's user pointer (datachannel-rs owns it for its own +// callbacks), so the registries are how the callbacks find their context. + +/// Process-global: peer-connection id -> its consumer media context. +static MEDIA: Lazy>>> = Lazy::new(Default::default); +/// Process-global: inbound track id -> the peer-connection id that owns it. +static TRACK_PC: Lazy>> = Lazy::new(Default::default); + +/// The number of consumer media contexts in `MEDIA`. A diagnostics accessor for +/// the soak test's registry-baseline check (no leaked entries after churn). +pub(crate) fn media_registry_len() -> usize { + lock(&MEDIA).len() +} + +/// The number of inbound-track->pc mappings in `TRACK_PC`. A diagnostics accessor +/// for the soak test's registry-baseline check. +pub(crate) fn track_pc_registry_len() -> usize { + lock(&TRACK_PC).len() +} + +/// One inbound track's receive pipeline: a stateful RTP depacketizer feeding a +/// persistent ffmpeg decoder. Dropping it stops the decoder (kills its ffmpeg). +/// The decoder is behind a `Mutex` so a crashed one can be replaced in place +/// (bounded by `restart`) without losing the depacketizer's reassembly state. +struct TrackReceiver { + depacketizer: Mutex, + decoder: Mutex, + /// The track's mid, kept so a restart can rebuild the decoder's frame emitter. + mid: String, + /// Bounded crash-restart budget for this track's decoder. + restart: Mutex, +} + +/// The consumer's media context, shared between the Python-facing `Consumer` and +/// the global registry the C callbacks consult. +struct ConsumerMedia { + events: EventQueue, + mid_to_track: MidToTrack, + receivers: Mutex>, +} + +impl ConsumerMedia { + /// Build a fresh decoder for `mid` whose decoded frames become `on_frame` + /// events keyed to the app track id via the manifest. Used by both [`adopt`] and + /// the crash-restart path, so the frame emitter is identical across a restart. + fn build_decoder(&self, mid: &str) -> std::io::Result { + let events = self.events.clone(); + let mid_to_track = self.mid_to_track.clone(); + let emit_mid = mid.to_string(); + let on_frame = move |data: Vec| { + let track = lock(&mid_to_track) + .get(&emit_mid) + .cloned() + .unwrap_or_default(); + events.push(Event::Frame { + track_id: track, + mid: emit_mid.clone(), + data, + width: FRAME_WIDTH, + height: FRAME_HEIGHT, + }); + }; + H264Decoder::new(FRAME_WIDTH, FRAME_HEIGHT, on_frame) + } + + /// Adopt an inbound track: stand up its decoder and register its receive + /// pipeline. Idempotent per track id. + fn adopt(&self, track_id: i32, mid: String) { + if lock(&self.receivers).contains_key(&track_id) { + return; + } + match self.build_decoder(&mid) { + Ok(decoder) => { + lock(&self.receivers).insert( + track_id, + TrackReceiver { + depacketizer: Mutex::new(RtpDepacketizer::new()), + decoder: Mutex::new(decoder), + mid: mid.clone(), + restart: Mutex::new(RestartPolicy::default()), + }, + ); + } + Err(err) => debug_trace("C", &format!("decoder spawn failed for mid {mid}: {err}")), + } + } + + /// Restart a crashed decoder in place, within the per-track bounded budget, + /// surfacing an `on_error` with ffmpeg's stderr tail. The new decoder recovers + /// on the producer's next periodic IDR. The depacketizer's reassembly state is + /// preserved (it self-recovers on the next clean NAL). + fn restart_decoder(&self, track_id: i32) { + let receivers = lock(&self.receivers); + let Some(receiver) = receivers.get(&track_id) else { + return; + }; + let mid = receiver.mid.clone(); + let detail = lock(&receiver.decoder).stderr_tail(); + if !lock(&receiver.restart).should_restart() { + self.events.push(Event::error( + "decode", + format!("decoder for mid {mid:?} crashed and exceeded the restart budget"), + )); + return; + } + match self.build_decoder(&mid) { + Ok(decoder) => { + // Installing the new decoder drops the old one, killing its ffmpeg. + *lock(&receiver.decoder) = decoder; + self.events.push(Event::error( + "decode", + format!( + "decoder for mid {mid:?} crashed; restarting (ffmpeg: {})", + crate::producer::last_stderr_line(&detail) + ), + )); + } + Err(err) => self.events.push(Event::error( + "decode", + format!("could not respawn decoder for mid {mid:?}: {err}"), + )), + } + } + + /// Feed one inbound RTP packet through the track's depacketizer and into its + /// decoder as Annex-B NAL units, restarting the decoder first if it has crashed. + fn feed(&self, track_id: i32, packet: &[u8]) { + // Detect a crashed decoder (ffmpeg exited) and restart it before feeding, + // rather than silently dropping every inbound packet into a dead pipe. + let dead = { + let receivers = lock(&self.receivers); + receivers + .get(&track_id) + .map(|r| !lock(&r.decoder).is_alive()) + .unwrap_or(false) + }; + if dead { + self.restart_decoder(track_id); + } + let receivers = lock(&self.receivers); + let Some(receiver) = receivers.get(&track_id) else { + return; + }; + if !dead { + // A healthy live decoder: clear the consecutive-crash budget. + lock(&receiver.restart).reset(); + } + let nals = lock(&receiver.depacketizer).depacketize(packet); + let decoder = lock(&receiver.decoder); + for nal in nals { + decoder.feed_nal(&nal); + } + } +} + +/// libdatachannel's track callback: an inbound media track was created. Look up +/// its peer connection's media context, adopt the track by mid, and wire its +/// message callback so its RTP starts flowing to the decoder. +unsafe extern "C" fn on_track_cb(pc: c_int, track: c_int, _ptr: *mut c_void) { + let Some(media) = lock(&MEDIA).get(&pc).cloned() else { + return; + }; + let mid = track_mid(track); + debug_trace("C", &format!("adopt track tr={track} mid={mid}")); + media.adopt(track, mid); + lock(&TRACK_PC).insert(track, pc); + // Attach the built-in RTCP receiving session so the consumer answers the + // producer's SR with RR (loss/jitter) and processes NACK/PLI, and request a + // bitrate so the session emits REMB toward the producer. Inbound media still + // arrives as whole RTP packets to the message callback (the C API has no + // depacketizer), so the FU-A depacketizer is kept; the chain only adds the + // RTCP feedback the producer's estimator reads. + if sys::rtcChainRtcpReceivingSession(track) < 0 { + debug_trace("C", &format!("receiving-session attach failed tr={track}")); + } + if sys::rtcRequestBitrate(track, REQUESTED_BITRATE_BPS) < 0 { + debug_trace("C", &format!("requestBitrate failed tr={track}")); + } + sys::rtcSetMessageCallback(track, Some(on_message_cb)); +} + +/// libdatachannel's per-track message callback: one inbound RTP packet. Binary +/// messages carry a non-negative size; string messages (size < 0) are not media. +unsafe extern "C" fn on_message_cb(id: c_int, msg: *const c_char, size: c_int, _ptr: *mut c_void) { + if size < 0 || msg.is_null() { + return; + } + let packet = std::slice::from_raw_parts(msg as *const u8, size as usize); + let Some(pc) = lock(&TRACK_PC).get(&id).copied() else { + return; + }; + let Some(media) = lock(&MEDIA).get(&pc).cloned() else { + return; + }; + media.feed(id, packet); +} + +/// Read an inbound track's mid via the sys layer (the same two-call size-then-fill +/// pattern datachannel-rs uses internally). +fn track_mid(track: i32) -> String { + unsafe { + let size = sys::rtcGetTrackMid(track, std::ptr::null_mut(), 0); + if size <= 0 { + return String::new(); + } + let mut buf = vec![0u8; size as usize]; + if sys::rtcGetTrackMid(track, buf.as_mut_ptr() as *mut c_char, size) < 0 { + return String::new(); + } + let end = buf.iter().position(|&b| b == 0).unwrap_or(buf.len()); + String::from_utf8_lossy(&buf[..end]).into_owned() + } +} + +/// The consumer-side WebRTC peer exposed to Python. Answer-only by design. +#[pyclass] +pub struct Consumer { + events: EventQueue, + closed: Arc, + pc: Mutex>>>, + incoming: Arc>, + /// The raw peer-connection id this consumer registered media callbacks under, + /// used to deregister and drop its media context on close. + pc_id: Option, +} + +#[pymethods] +impl Consumer { + /// Create an answer-only consumer. `connection_id` is an opaque label used + /// only for logging/correlation. + #[new] + #[pyo3(signature = (connection_id=None))] + fn new(connection_id: Option) -> PyResult { + let _ = connection_id; + ensure_started(); + + let events = EventQueue::default(); + let incoming = Arc::new(Mutex::new(Vec::new())); + let mid_to_track: MidToTrack = Arc::new(Mutex::new(HashMap::new())); + let handler = ConsumerHandler { + events: events.clone(), + incoming: incoming.clone(), + mid_to_track: mid_to_track.clone(), + }; + let pc = RtcPeerConnection::new(&loopback_config(), handler).map_err(map_err)?; + events.push(Event::State("new".to_string())); + + // Register the media receive path under this peer connection's raw id: + // publish its context and install libdatachannel's track callback so + // inbound media tracks are adopted (the only way to receive — the safe + // handler has no on_track). + let pc_id = raw_pc_id(&pc); + if let Some(id) = pc_id { + let media = Arc::new(ConsumerMedia { + events: events.clone(), + mid_to_track, + receivers: Mutex::new(HashMap::new()), + }); + lock(&MEDIA).insert(id, media); + // SAFETY: `id` is this live peer connection's id; the callback only + // touches the process-global registries. + unsafe { sys::rtcSetTrackCallback(id, Some(on_track_cb)) }; + } else { + debug_trace("C", "could not recover pc id; inbound media disabled"); + } + + Ok(Self { + events, + closed: Arc::new(AtomicBool::new(false)), + pc: Mutex::new(Some(pc)), + incoming, + pc_id, + }) + } + + /// Apply the producer's SDP offer. libdatachannel auto-generates the answer, + /// delivered as an `on_local_description` event. + fn set_remote_offer(&self, sdp: &str) -> PyResult<()> { + let sess = parse_session(sdp, SdpType::Offer)?; + let mut guard = lock(&self.pc); + let pc = guard + .as_mut() + .ok_or_else(|| PyValueError::new_err("consumer is closed"))?; + pc.set_remote_description(&sess).map_err(map_err) + } + + /// Apply a remote ICE candidate trickled from the producer. + #[pyo3(signature = (candidate, mid=None))] + fn add_remote_candidate(&self, candidate: &str, mid: Option) -> PyResult<()> { + let cand = IceCandidate { + candidate: candidate.to_string(), + mid: mid.unwrap_or_default(), + }; + let mut guard = lock(&self.pc); + let pc = guard + .as_mut() + .ok_or_else(|| PyValueError::new_err("consumer is closed"))?; + pc.add_remote_candidate(&cand).map_err(map_err) + } + + /// Drain and return all queued events as a list of dicts. See the + /// [`events`](crate::events) module for the dict schema. + fn drain_events(&self, py: Python<'_>) -> PyResult> { + self.events.drain_to_py(py) + } + + /// Close the consumer. Idempotent: the first call deregisters the media + /// callbacks, drops its decoders (killing their ffmpeg), drops inbound + /// channels and the peer connection, then emits a final `on_state: "closed"`. + fn close(&self) -> PyResult<()> { + if emit_closed_once(&self.closed, &self.events) { + if let Some(id) = self.pc_id { + // Stop new track callbacks, then drop this pc's media context + // (its decoders/ffmpeg) and forget its tracks. A message callback + // racing teardown finds no context and no-ops. The peer connection + // itself is still alive here; it is dropped just below. + unsafe { sys::rtcSetTrackCallback(id, None) }; + lock(&MEDIA).remove(&id); + lock(&TRACK_PC).retain(|_, owner| *owner != id); + } + lock(&self.incoming).clear(); + *lock(&self.pc) = None; + } + Ok(()) + } +} diff --git a/rust/neuracore_webrtc/src/events.rs b/rust/neuracore_webrtc/src/events.rs new file mode 100644 index 000000000..33a16349f --- /dev/null +++ b/rust/neuracore_webrtc/src/events.rs @@ -0,0 +1,430 @@ +//! The drainable, thread-safe event queue both peers expose to Python. +//! +//! Rust tasks running on the core's tokio runtime [`push`](EventQueue::push) +//! events without ever touching the GIL; Python pulls them with a single +//! synchronous [`drain_events`] call that converts each queued [`Event`] into a +//! plain `dict` under the GIL. This keeps the producer side lock-free of Python +//! and the consumer side cheap: one lock acquisition drains the whole backlog. +//! +//! Each drained dict carries a `"kind"` discriminator. The six kinds mandated +//! by the API contract are: +//! +//! - `on_state` — connection-state transitions (`{"kind", "state"}`) +//! - `on_track_added` — a remote track appeared (`{"kind", "track_id", "mid"}`) +//! - `on_track_removed` — a remote track went away (`{"kind", "mid"}`) +//! - `on_data_channel` — a remote data channel opened (`{"kind", "label", "kind_hint"}`) +//! - `on_message` — a message arrived on a data channel (`{"kind", "label", "data"}`) +//! - `on_frame` — a decoded video frame (`{"kind", "track_id", "mid", "data", "width", "height"}`) +//! - `on_manifest` — the mid→RobotStreamTrack manifest was republished (`{"kind", "json"}`) +//! +//! Two further kinds carry signaling *out* of the core (the producer is the +//! sole offerer, the consumer answers): `on_local_description` +//! (`{"kind", "sdp_type", "sdp"}`) and `on_local_candidate` +//! (`{"kind", "candidate", "mid"}`). They are part of the surface PR1 compiles +//! against; the core only starts emitting them once PR2 wires real signaling. +//! +//! [`drain_events`]: crate::producer::Producer::drain_events + +use std::collections::VecDeque; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::{Arc, Mutex}; + +use pyo3::prelude::*; +use pyo3::types::{PyBytes, PyDict, PyList}; + +/// A single event surfaced to Python via the drainable queue. Variants map +/// one-to-one onto the `"kind"` values documented on the module. +/// +/// `dead_code` is allowed because PR0 only ever constructs `State`; the other +/// variants are the agreed event surface and are emitted from PR2 onward. They +/// are already wired through `kind()`/`to_pydict` so the schema is locked now. +#[allow(dead_code)] +#[derive(Clone, Debug, PartialEq)] +pub(crate) enum Event { + /// Connection-state transition, e.g. `"new"`, `"connecting"`, `"closed"`. + State(String), + /// A remote media track became available. + TrackAdded { track_id: String, mid: String }, + /// A previously-added remote media track was removed. + TrackRemoved { mid: String }, + /// A remote data channel opened. `kind_hint` mirrors the producer-side + /// `add_data_channel(label, kind)` label (e.g. `"json"`, `"control"`). + DataChannel { label: String, kind_hint: String }, + /// A message arrived on a data channel. + Message { label: String, data: Vec }, + /// A decoded video frame ready for the consumer. `data` is the raw picture + /// as 8-bit HxWx3 (RGB or BGR; the block codec is colour-order agnostic). + Frame { + track_id: String, + mid: String, + data: Vec, + width: u32, + height: u32, + }, + /// The mid→RobotStreamTrack manifest, republished verbatim as JSON text. + Manifest { json: String }, + /// A locally-created SDP offer/answer to relay over signaling. + LocalDescription { sdp_type: String, sdp: String }, + /// A locally-gathered ICE candidate to trickle over signaling. + LocalCandidate { + candidate: String, + mid: Option, + }, + /// A recoverable error surfaced from a hot or lifecycle path instead of + /// panicking across the FFI boundary: a subprocess crash, a chain-attach or + /// SDP failure, a send on a closed track, or a reconnect-needed signal. Rendered + /// as `{"kind": "on_error", "where", "detail", "consumer_id"?}`. `consumer_id` + /// is present only for a broadcaster's per-consumer error (a shared-encode error + /// has none); `where` is a short location tag (`"encode"`, `"decode"`, + /// `"negotiate"`, `"connection"`, `"send"`). + Error { + consumer_id: Option, + location: String, + detail: String, + }, + /// A per-consumer event from a [`Broadcaster`](crate::broadcaster::Broadcaster): + /// the wrapped `inner` event rendered with an extra `"consumer_id"` key so a + /// fan-out signaling layer routes it to the right consumer. The single 1:1 + /// `Producer`/`Consumer` path never constructs this — its events stay + /// untagged and byte-identical, so the single-consumer suite is unaffected. + ForConsumer { + consumer_id: String, + inner: Box, + }, +} + +impl Event { + /// An untagged recoverable error (1:1 `Producer`/`Consumer`, or a + /// broadcaster's shared-encode error that belongs to no single consumer). + pub(crate) fn error(location: &str, detail: impl Into) -> Event { + Event::Error { + consumer_id: None, + location: location.to_string(), + detail: detail.into(), + } + } + + /// A recoverable error attributed to one broadcaster consumer. + pub(crate) fn error_for(consumer_id: &str, location: &str, detail: impl Into) -> Event { + Event::Error { + consumer_id: Some(consumer_id.to_string()), + location: location.to_string(), + detail: detail.into(), + } + } + + /// The stable `"kind"` discriminator placed on the Python dict. + pub(crate) fn kind(&self) -> &'static str { + match self { + Event::State(_) => "on_state", + Event::TrackAdded { .. } => "on_track_added", + Event::TrackRemoved { .. } => "on_track_removed", + Event::DataChannel { .. } => "on_data_channel", + Event::Message { .. } => "on_message", + Event::Frame { .. } => "on_frame", + Event::Manifest { .. } => "on_manifest", + Event::LocalDescription { .. } => "on_local_description", + Event::LocalCandidate { .. } => "on_local_candidate", + Event::Error { .. } => "on_error", + // A wrapped per-consumer event keeps the inner event's kind; the + // consumer_id rides alongside it on the dict. + Event::ForConsumer { inner, .. } => inner.kind(), + } + } + + /// Render this event as a Python `dict`. Caller holds the GIL. + fn to_pydict<'py>(&self, py: Python<'py>) -> PyResult> { + // A per-consumer event renders exactly as its inner event plus a + // `consumer_id` key, so the fan-out signaling layer routes it without + // changing any other field. Done first so the recursion is obvious. + if let Event::ForConsumer { consumer_id, inner } = self { + let dict = inner.to_pydict(py)?; + dict.set_item("consumer_id", consumer_id)?; + return Ok(dict); + } + let dict = PyDict::new_bound(py); + dict.set_item("kind", self.kind())?; + match self { + Event::State(state) => { + dict.set_item("state", state)?; + } + Event::TrackAdded { track_id, mid } => { + dict.set_item("track_id", track_id)?; + dict.set_item("mid", mid)?; + } + Event::TrackRemoved { mid } => { + dict.set_item("mid", mid)?; + } + Event::DataChannel { label, kind_hint } => { + dict.set_item("label", label)?; + dict.set_item("kind_hint", kind_hint)?; + } + Event::Message { label, data } => { + dict.set_item("label", label)?; + dict.set_item("data", PyBytes::new_bound(py, data))?; + } + Event::Frame { + track_id, + mid, + data, + width, + height, + } => { + dict.set_item("track_id", track_id)?; + dict.set_item("mid", mid)?; + dict.set_item("data", PyBytes::new_bound(py, data))?; + dict.set_item("width", width)?; + dict.set_item("height", height)?; + } + Event::Manifest { json } => { + dict.set_item("json", json)?; + } + Event::LocalDescription { sdp_type, sdp } => { + dict.set_item("sdp_type", sdp_type)?; + dict.set_item("sdp", sdp)?; + } + Event::LocalCandidate { candidate, mid } => { + dict.set_item("candidate", candidate)?; + dict.set_item("mid", mid.clone())?; + } + Event::Error { + consumer_id, + location, + detail, + } => { + dict.set_item("where", location)?; + dict.set_item("detail", detail)?; + // `consumer_id` is optional: present only for a broadcaster's + // per-consumer error, so the fan-out signaling layer can route it. + if let Some(id) = consumer_id { + dict.set_item("consumer_id", id)?; + } + } + // Handled by the early return above; the match is over the inner + // (unwrapped) variants only. + Event::ForConsumer { .. } => unreachable!("ForConsumer rendered above"), + } + Ok(dict) + } +} + +/// A cloneable handle onto the shared event backlog. Cloning shares the same +/// underlying queue (it is an `Arc`), so the core's tasks and the Python-facing +/// peer hold the same queue. +#[derive(Clone, Default)] +pub(crate) struct EventQueue { + inner: Arc>>, +} + +impl EventQueue { + /// Append an event. Cheap, GIL-free, callable from any thread/task. + pub(crate) fn push(&self, event: Event) { + // A poisoned lock here only means another thread panicked mid-push; + // recovering the guard keeps event delivery alive rather than cascading + // the panic across the FFI boundary. + let mut queue = self.inner.lock().unwrap_or_else(|e| e.into_inner()); + queue.push_back(event); + } + + /// Move the whole backlog out in FIFO order, leaving the queue empty. This + /// is the GIL-free drain primitive [`drain_to_py`](Self::drain_to_py) builds + /// on; unit tests use it to assert ordering and drain semantics without the + /// interpreter. + pub(crate) fn take_all(&self) -> Vec { + let mut queue = self.inner.lock().unwrap_or_else(|e| e.into_inner()); + queue.drain(..).collect() + } + + /// Drain every queued event into a fresh Python list of dicts. The lock is + /// held only long enough to move the backlog out; the Python objects are + /// built afterwards so no task is blocked while we touch the interpreter. + pub(crate) fn drain_to_py(&self, py: Python<'_>) -> PyResult> { + let drained = self.take_all(); + let list = PyList::empty_bound(py); + for event in drained { + list.append(event.to_pydict(py)?)?; + } + Ok(list.unbind()) + } +} + +/// Emit a single terminal `on_state: "closed"` the first time it is called for +/// a given `closed` flag, and nothing on any later call. Returns `true` exactly +/// once — on that first call — so the caller can run one-shot teardown (drop +/// channels/tracks/the peer connection) under the same guard. Both peers' close +/// paths funnel through here, so "closed" is observed exactly once however many +/// times Python calls `close()`. +pub(crate) fn emit_closed_once(closed: &AtomicBool, events: &EventQueue) -> bool { + if closed.swap(true, Ordering::SeqCst) { + return false; + } + events.push(Event::State("closed".to_string())); + true +} + +#[cfg(test)] +mod tests { + //! Peer-free, GIL-free unit tests for the event queue: FIFO drain semantics, + //! the `"kind"` discriminator schema, and close-once idempotency. The + //! Event -> `dict` *field* rendering (`to_pydict`) needs the interpreter and + //! is exercised by the integration relay, which reads each field by name; + //! here we pin the discriminators (the schema's type tags) and the queueing. + + use super::*; + + #[test] + fn take_all_preserves_fifo_order_and_drains() { + let queue = EventQueue::default(); + queue.push(Event::State("new".to_string())); + queue.push(Event::DataChannel { + label: "telemetry".to_string(), + kind_hint: "reliable".to_string(), + }); + queue.push(Event::State("connected".to_string())); + + let drained = queue.take_all(); + assert_eq!( + drained, + vec![ + Event::State("new".to_string()), + Event::DataChannel { + label: "telemetry".to_string(), + kind_hint: "reliable".to_string(), + }, + Event::State("connected".to_string()), + ] + ); + // A second drain sees an empty queue: the first drain emptied it. + assert!(queue.take_all().is_empty()); + } + + #[test] + fn kind_discriminators_match_the_documented_schema() { + assert_eq!(Event::State("new".to_string()).kind(), "on_state"); + assert_eq!( + Event::TrackAdded { + track_id: "cam".to_string(), + mid: "v0".to_string(), + } + .kind(), + "on_track_added" + ); + assert_eq!( + Event::TrackRemoved { + mid: "v0".to_string(), + } + .kind(), + "on_track_removed" + ); + assert_eq!( + Event::DataChannel { + label: "telemetry".to_string(), + kind_hint: "reliable".to_string(), + } + .kind(), + "on_data_channel" + ); + assert_eq!( + Event::Message { + label: "telemetry".to_string(), + data: vec![1, 2, 3], + } + .kind(), + "on_message" + ); + assert_eq!( + Event::Frame { + track_id: "cam0".to_string(), + mid: "v0".to_string(), + data: vec![0, 0, 0], + width: 640, + height: 480, + } + .kind(), + "on_frame" + ); + assert_eq!( + Event::Manifest { + json: "{}".to_string(), + } + .kind(), + "on_manifest" + ); + assert_eq!( + Event::LocalDescription { + sdp_type: "offer".to_string(), + sdp: "v=0".to_string(), + } + .kind(), + "on_local_description" + ); + assert_eq!( + Event::LocalCandidate { + candidate: "candidate:1 ...".to_string(), + mid: Some("0".to_string()), + } + .kind(), + "on_local_candidate" + ); + assert_eq!( + Event::Error { + consumer_id: None, + location: "encode".to_string(), + detail: "ffmpeg died".to_string(), + } + .kind(), + "on_error" + ); + } + + #[test] + fn for_consumer_wraps_an_error_keeping_its_kind() { + // A per-consumer error can also ride the ForConsumer wrapper (the + // broadcaster tags reconnect/negotiate errors this way); the inner kind is + // preserved either way. + let wrapped = Event::ForConsumer { + consumer_id: "c1".to_string(), + inner: Box::new(Event::Error { + consumer_id: None, + location: "connection".to_string(), + detail: "reconnect-needed".to_string(), + }), + }; + assert_eq!(wrapped.kind(), "on_error"); + } + + #[test] + fn for_consumer_delegates_kind_to_the_wrapped_event() { + // A per-consumer wrapper keeps the inner event's "kind" discriminator; + // the consumer_id is added as a sibling key (rendered under the GIL, so + // exercised by the multi-consumer relay, not here). + let wrapped = Event::ForConsumer { + consumer_id: "c1".to_string(), + inner: Box::new(Event::LocalDescription { + sdp_type: "offer".to_string(), + sdp: "v=0".to_string(), + }), + }; + assert_eq!(wrapped.kind(), "on_local_description"); + let wrapped_state = Event::ForConsumer { + consumer_id: "c2".to_string(), + inner: Box::new(Event::State("connected".to_string())), + }; + assert_eq!(wrapped_state.kind(), "on_state"); + } + + #[test] + fn close_emits_on_state_closed_exactly_once() { + let closed = AtomicBool::new(false); + let events = EventQueue::default(); + + // The first close does the work; every later close is a no-op. This is + // the exact guard both Producer::close and Consumer::close run. + assert!(emit_closed_once(&closed, &events)); + assert!(!emit_closed_once(&closed, &events)); + assert!(!emit_closed_once(&closed, &events)); + + let drained = events.take_all(); + assert_eq!(drained, vec![Event::State("closed".to_string())]); + } +} diff --git a/rust/neuracore_webrtc/src/lib.rs b/rust/neuracore_webrtc/src/lib.rs new file mode 100644 index 000000000..4698a0b22 --- /dev/null +++ b/rust/neuracore_webrtc/src/lib.rs @@ -0,0 +1,67 @@ +// PyO3 0.22's `#[pyfunction]`/`#[pymethods]` expansion includes an `.into()` on +// the `PyResult` return value that fires clippy's `useless_conversion` lint +// when T resolves to `()`. The lint is correct about the generated code but the +// conversion lives in the macro expansion, not anything we wrote, so we silence +// it at the crate level rather than spraying allows over every method. +#![allow(clippy::useless_conversion)] + +//! PyO3 WebRTC streaming core for Neuracore — the synchronous, queue-backed +//! replacement for the aiortc stack. +//! +//! This crate ships as `neuracore.core.streaming.p2p._native_webrtc` inside the +//! Python wheel. It exposes two peers, [`Producer`](producer::Producer) (the +//! sole offerer) and [`Consumer`](consumer::Consumer) (answer-only), behind a +//! deliberately **synchronous, thread-safe, queue-backed** API: +//! +//! - Rust owns a tokio runtime and drives it on its own threads (see +//! [`runtime`]). Python never touches the runtime. +//! - `submit_frame` enqueues onto a bounded queue and returns immediately. +//! - Both peers expose a drainable event queue (see [`events`]). +//! +//! ## Scope +//! +//! PR2 wires the real data plane: a libdatachannel [`RtcPeerConnection`] (via +//! the `datachannel` crate, `media` feature) carrying reliable-ordered data +//! channels with trickle ICE and a control-channel manifest. The producer is +//! the sole offerer; the consumer answers. Video (`add_video_track` / +//! `remove_video_track`) stays stubbed until PR4. +//! +//! [`RtcPeerConnection`]: datachannel::RtcPeerConnection + +mod broadcaster; +mod congestion; +mod consumer; +mod events; +mod media; +mod producer; +mod runtime; +mod transport; + +use pyo3::prelude::*; + +/// The sizes of the three process-global lifecycle registries, as +/// `(producer_fb, media, track_pc)`. The hardening soak test reads this to assert +/// every registry returns to its starting size after add/remove churn — i.e. no +/// leaked entries. Diagnostics only; not part of the streaming API. +#[pyfunction] +fn registry_sizes() -> (usize, usize, usize) { + ( + producer::producer_fb_len(), + consumer::media_registry_len(), + consumer::track_pc_registry_len(), + ) +} + +/// Python module entrypoint registered as +/// `neuracore.core.streaming.p2p._native_webrtc`. +#[pymodule] +fn _native_webrtc(module: &Bound<'_, PyModule>) -> PyResult<()> { + module.add_class::()?; + module.add_class::()?; + module.add_class::()?; + // Surface the bounded frame-queue depth so callers/tests can read it + // without hard-coding the constant. + module.add("FRAME_QUEUE_CAPACITY", producer::FRAME_QUEUE_CAPACITY)?; + module.add_function(wrap_pyfunction!(registry_sizes, module)?)?; + Ok(()) +} diff --git a/rust/neuracore_webrtc/src/media.rs b/rust/neuracore_webrtc/src/media.rs new file mode 100644 index 000000000..0d593c228 --- /dev/null +++ b/rust/neuracore_webrtc/src/media.rs @@ -0,0 +1,1330 @@ +//! H.264 encode/decode, the Annex-B framing the producer feeds the built-in +//! packetizer, and the FU-A depacketizer the consumer still needs. +//! +//! ## Why the producer no longer hand-rolls RTP (PR5) +//! +//! PR4 hand-rolled the producer's RTP (FU-A) because datachannel-rs keeps a +//! track's integer id private. PR5 creates the producer track through the sys +//! layer (`rtcAddTrackEx`, see [`crate::producer`]), recovering that raw id, so it +//! can attach libdatachannel's **built-in** media chain to it: +//! `rtcSetH264Packetizer` (does the FU-A framing, sequence numbers, marker bit and +//! SSRC), plus `rtcChainRtcpSrReporter` / `rtcChainRtcpNackResponder` / +//! `rtcChainPliHandler` / `rtcChainRembHandler`. The producer therefore sends +//! **raw NAL units** (as an Annex-B access unit) and the library packetizes — the +//! hand-rolled `RtpPacketizer` and its unit tests are gone. +//! +//! The **consumer** keeps the FU-A depacketizer: the C API exposes no +//! depacketizer (`rtcChainRtcpReceivingSession` only validates and passes whole +//! RTP packets through to the message callback — see +//! `reports/SPIKE-pr5-media-chain.md`), so inbound media still arrives as raw RTP +//! and [`RtpDepacketizer`] reassembles NAL units for the decoder. +//! +//! The encode (`numpy -> H.264 NAL units`) is kept separate from the send so a +//! later PR can fan one encode out to many consumers. Both ends shell out to a +//! **persistent** ffmpeg subprocess (one per track) — spawning per frame would +//! blow the glass-to-glass budget — exactly as the disk recording path does. The +//! encoder is restartable at a coarser [`crate::congestion::Step`] (lower +//! bitrate, then downscale) so the queue-driven adaptation can degrade under +//! congestion; the restart's first IDR carries SPS/PPS (`repeat-headers=1`). +//! +//! ## What lives here +//! +//! - [`NalSplitter`] — streaming Annex-B byte stream → NAL units. +//! - [`AccessUnitAssembler`] — NAL units → per-frame access units (flush on VCL). +//! - [`annexb_access_unit`] — NAL units → one Annex-B buffer the built-in +//! packetizer splits (the producer's send payload). +//! - [`RtpDepacketizer`] — the consumer's FU-A reassembly (kept; pure seam). +//! - [`DropPolicy`] — the shed-on-backlog decision behind a fake-clock seam. +//! - [`H264Encoder`] / [`H264Decoder`] — the persistent ffmpeg subprocesses. + +use std::io::{Read, Write}; +use std::process::{Child, ChildStdin, Command, Stdio}; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::mpsc::RecvTimeoutError; +use std::sync::{Arc, Mutex}; +use std::thread::JoinHandle; +use std::time::Duration; + +/// How long the encoder reader waits for more ffmpeg output before deciding the +/// frame is complete and flushing the trailing NAL. Annex-B carries no NAL length +/// prefix, so a NAL is only "complete" when the next start code arrives — which +/// never happens for the final frame of a paused stream. An idle this short +/// (below the inter-frame gap at 45 fps) flushes that trailing NAL promptly while +/// staying clear of a single frame's atomic ffmpeg write. +const ENCODER_FLUSH_IDLE: Duration = Duration::from_millis(12); + +/// How long the decoder feed waits with no new NAL before nudging ffmpeg to emit +/// its last held frame. The h264 decoder only outputs a frame once it sees the +/// *next* access unit, so the final frame of a paused stream is stuck inside +/// ffmpeg until an Access Unit Delimiter (or EOF) arrives. This must clear the +/// inter-frame gap (33 ms at 30 fps) *and* any startup or mid-stream jitter, +/// because a spuriously-early AUD makes ffmpeg re-emit the frame it is still +/// holding — a duplicate. It must also stay below the test collector's quiet +/// window (0.5 s) so a genuinely paused stream's last frame still flushes in +/// time. 250 ms sits comfortably between the two. +const DECODER_FLUSH_IDLE: Duration = Duration::from_millis(250); + +/// A bare H.264 Access Unit Delimiter NAL (start code + type-9 + primary_pic +/// payload), fed to the decoder on an idle to flush its last held frame. +const AUD_NAL: [u8; 6] = [0, 0, 0, 1, 0x09, 0x10]; + +/// How many trailing bytes of a subprocess's stderr to retain for the +/// `on_error` detail when it crashes. ffmpeg's last error line(s) are the useful +/// diagnostic; an unbounded capture would be a slow leak on a chatty process, so +/// the tail is ring-trimmed to this cap. +const STDERR_TAIL_CAP: usize = 2048; + +/// The default bounded crash-restart budget for a persistent subprocess: how many +/// consecutive deaths the feed will try to recover from before surfacing a +/// terminal error rather than spinning forever (e.g. ffmpeg permanently missing). +/// A healthy run resets the budget. See [`RestartPolicy`]. +pub(crate) const DEFAULT_RESTART_BUDGET: u32 = 5; + +/// The built-in H.264 packetizer's max RTP fragment size (bytes). Capped well +/// below the 64 KiB loopback datagram so a 640x480 keyframe actually fragments +/// into FU-A packets rather than riding as one datagram — the fragmentation path +/// must be exercised on loopback, not just under a real MTU. +pub(crate) const MAX_FRAGMENT_SIZE: u16 = 1200; + +/// RTP fixed header length (V/P/X/CC + M/PT + seq + timestamp + ssrc), no CSRCs. +/// The consumer's depacketizer skips this prefix on every inbound packet. +const RTP_HEADER_LEN: usize = 12; + +/// The FU-A NAL type (RFC 6184 §5.8): fragmentation unit without DON. +const FU_A_TYPE: u8 = 28; + +/// The RTCP CNAME the built-in packetizer's SR reporter advertises. libdatachannel +/// **throws** (`rtcSetH264Packetizer` returns -1) if `rtcPacketizerInit.cname` is +/// null — and then `rtcChainRtcpSrReporter` fails too because it chains onto the +/// packetizer's RTP config that was never created. So the cname must always be a +/// non-null, non-empty C string; [`packetizer_cname`] is the single source and +/// `cname_is_non_null_and_non_empty` guards the invariant. +pub(crate) const PACKETIZER_CNAME: &str = "neuracore"; + +/// The packetizer's non-null CNAME as a `CString`, ready for +/// `rtcPacketizerInit.cname`. Panics only if [`PACKETIZER_CNAME`] ever contains an +/// interior NUL (a compile-time-constant string that never will), so the live +/// path can `.expect` it and the guard test pins it. +pub(crate) fn packetizer_cname() -> std::ffi::CString { + std::ffi::CString::new(PACKETIZER_CNAME).expect("packetizer cname has no interior NUL") +} + +/// 90 kHz is the RTP clock for video. The producer steps the track's RTP +/// timestamp by `VIDEO_CLOCK_HZ / fps` per access unit (via +/// `rtcSetTrackRtpTimestamp`) so every packet of one frame shares a timestamp and +/// the SR reporter sees a coherent clock. +pub(crate) const VIDEO_CLOCK_HZ: u32 = 90_000; + +/// The ffmpeg binary to shell out to: `NEURACORE_WEBRTC_FFMPEG` if set, else +/// `ffmpeg` on `PATH` (the same provisioning the disk recording path uses). +pub(crate) fn ffmpeg_bin() -> String { + std::env::var("NEURACORE_WEBRTC_FFMPEG").unwrap_or_else(|_| "ffmpeg".to_string()) +} + +// --------------------------------------------------------------------------- +// Annex-B NAL splitting and access-unit grouping +// --------------------------------------------------------------------------- + +/// Streaming Annex-B splitter: feed it arbitrary byte chunks from the encoder's +/// stdout and it yields complete NAL units (start codes stripped). A NAL is only +/// emitted once the *next* start code is seen, so the trailing partial NAL stays +/// buffered until more bytes arrive or [`NalSplitter::flush`] is called at EOF. +pub(crate) struct NalSplitter { + buf: Vec, +} + +impl NalSplitter { + pub(crate) fn new() -> Self { + Self { buf: Vec::new() } + } + + /// Indices of every `00 00 01` start-code prefix in `buf`. A 4-byte + /// `00 00 00 01` start code contains a `00 00 01` at offset 1; the extra + /// leading zero is stripped from the preceding NAL's tail instead. + fn start_codes(buf: &[u8]) -> Vec { + let mut positions = Vec::new(); + if buf.len() < 3 { + return positions; + } + let mut i = 0; + while i + 2 < buf.len() { + if buf[i] == 0 && buf[i + 1] == 0 && buf[i + 2] == 1 { + positions.push(i); + i += 3; + } else { + i += 1; + } + } + positions + } + + /// Feed more bytes; return every NAL unit that became complete. + pub(crate) fn push(&mut self, bytes: &[u8]) -> Vec> { + self.buf.extend_from_slice(bytes); + let positions = Self::start_codes(&self.buf); + let mut out = Vec::new(); + if positions.len() < 2 { + return out; + } + for pair in positions.windows(2) { + let start = pair[0] + 3; + let mut end = pair[1]; + // Strip the trailing zero(s) that belong to the next 4-byte start + // code (H.264 RBSP never legitimately ends in 0x00 after the stop + // bit, so this cannot truncate real payload). + while end > start && self.buf[end - 1] == 0 { + end -= 1; + } + if end > start { + out.push(self.buf[start..end].to_vec()); + } + } + // Retain from the final start code onward (its NAL is not yet complete). + let last = *positions.last().unwrap(); + self.buf.drain(..last); + out + } + + /// Emit the final buffered NAL at end of stream (if any). + pub(crate) fn flush(&mut self) -> Option> { + let positions = Self::start_codes(&self.buf); + let start = *positions.first()? + 3; + let mut end = self.buf.len(); + while end > start && self.buf[end - 1] == 0 { + end -= 1; + } + let nal = (end > start).then(|| self.buf[start..end].to_vec()); + self.buf.clear(); + nal + } +} + +/// Returns whether a NAL header byte denotes a VCL (coded-slice) NAL — types 1 +/// (non-IDR slice) through 5 (IDR slice). A frame's single slice is its last +/// NAL, so a VCL NAL completes the access unit. +fn is_vcl(nal: &[u8]) -> bool { + matches!(nal.first().map(|b| b & 0x1F), Some(1..=5)) +} + +/// Groups NAL units into access units (one decoded frame each). The encoder is +/// configured single-slice, so a VCL NAL is the last NAL of its frame; non-VCL +/// NALs (SPS/PPS/SEI) accumulate as the prefix of the access unit they precede. +pub(crate) struct AccessUnitAssembler { + nals: Vec>, +} + +impl AccessUnitAssembler { + pub(crate) fn new() -> Self { + Self { nals: Vec::new() } + } + + /// Append a NAL; return the completed access unit when this NAL is a VCL slice. + pub(crate) fn push(&mut self, nal: Vec) -> Option>> { + let vcl = is_vcl(&nal); + self.nals.push(nal); + vcl.then(|| std::mem::take(&mut self.nals)) + } +} + +/// The number of VCL (coded-slice) NAL units in an access unit. The producer +/// sends one access unit as exactly **one** RTP frame under one capture +/// timestamp, so a well-formed access unit carries exactly one VCL NAL. More than +/// one means a slicing or NAL-aggregation change (e.g. x264 multi-slice threading, +/// or STAP-A) silently broke the one-VCL-per-frame invariant the timestamping +/// depends on — the exact defect in `reports/SPIKE-chrome-pframe.md`. The send +/// path asserts `== 1` and drops loudly rather than fabricate timestamps; see +/// [`x264_params`] for the encoder lever that keeps it true. Pure and testable. +pub(crate) fn vcl_nal_count(nals: &[Vec]) -> usize { + nals.iter().filter(|nal| is_vcl(nal)).count() +} + +// --------------------------------------------------------------------------- +// Producer send framing (the built-in packetizer does the RTP) +// --------------------------------------------------------------------------- + +/// Join an access unit's NAL units (start codes stripped) into one Annex-B +/// buffer with 4-byte long start codes, the payload the producer hands +/// `rtcSendMessage`. The attached `rtcSetH264Packetizer` is configured with the +/// long-start-sequence NAL separator, so it splits this back into NAL units and +/// does the RTP framing (single-NAL or FU-A), sequence numbers, marker bit and +/// SSRC itself. Pure so the framing is unit-testable without the chain. +pub(crate) fn annexb_access_unit(nals: &[Vec]) -> Vec { + let mut out = Vec::new(); + for nal in nals { + if nal.is_empty() { + continue; + } + out.extend_from_slice(&[0, 0, 0, 1]); + out.extend_from_slice(nal); + } + out +} + +/// Reassembles H.264 NAL units from inbound RTP. Single NAL packets pass through; +/// FU-A fragments are stitched back into the original NAL. A sequence gap mid-FU-A +/// drops the partial NAL rather than emitting a corrupt one. Owned and pure. +pub(crate) struct RtpDepacketizer { + fu_buffer: Vec, + in_fu: bool, + last_seq: Option, +} + +impl RtpDepacketizer { + pub(crate) fn new() -> Self { + Self { + fu_buffer: Vec::new(), + in_fu: false, + last_seq: None, + } + } + + /// Feed one RTP packet; return any NAL unit(s) that became complete. + pub(crate) fn depacketize(&mut self, packet: &[u8]) -> Vec> { + let mut out = Vec::new(); + if packet.len() <= RTP_HEADER_LEN { + return out; + } + let seq = u16::from_be_bytes([packet[2], packet[3]]); + // Drop a non-advancing sequence: the producer's NACK responder + // retransmits lost packets, and a retransmit (or a reordered duplicate) + // arrives with a sequence we have already processed. Re-feeding it to the + // decoder would surface a duplicate frame, so only strictly-newer + // sequences pass. "Newer" is the forward half of the 16-bit sequence + // space (1..=32767 ahead); a duplicate (0) or an old packet is dropped. + if let Some(prev) = self.last_seq { + let ahead = seq.wrapping_sub(prev); + if ahead == 0 || ahead >= 0x8000 { + return out; + } + } + let gap = matches!(self.last_seq, Some(prev) if seq != prev.wrapping_add(1)); + self.last_seq = Some(seq); + + let payload = &packet[RTP_HEADER_LEN..]; + let nal_type = payload[0] & 0x1F; + match nal_type { + FU_A_TYPE => { + if payload.len() < 2 { + return out; + } + let fu_indicator = payload[0]; + let fu_header = payload[1]; + let start = fu_header & 0x80 != 0; + let end = fu_header & 0x40 != 0; + let orig_type = fu_header & 0x1F; + if start { + self.fu_buffer.clear(); + self.fu_buffer.push((fu_indicator & 0xE0) | orig_type); + self.fu_buffer.extend_from_slice(&payload[2..]); + self.in_fu = true; + } else if !self.in_fu || gap { + // Missing the FU start, or a packet was lost mid-fragment: + // abandon the partial NAL rather than emit a corrupt one. + self.in_fu = false; + self.fu_buffer.clear(); + return out; + } else { + self.fu_buffer.extend_from_slice(&payload[2..]); + } + if end && self.in_fu { + out.push(std::mem::take(&mut self.fu_buffer)); + self.in_fu = false; + } + } + 1..=23 => { + // A single, complete NAL unit. Any in-flight FU-A is broken. + self.in_fu = false; + self.fu_buffer.clear(); + out.push(payload.to_vec()); + } + _ => { + // STAP-A/MTAP/etc. are never produced by our packetizer; ignore. + } + } + out + } +} + +// --------------------------------------------------------------------------- +// Drop policy +// --------------------------------------------------------------------------- + +/// The shed-on-backlog decision. The encoder is a fixed-rate sink fed from the +/// bounded ingress queue; while the queue has room (the steady state at or below +/// the encoder's throughput) nothing is ever dropped, so at or below 30 fps the +/// stream is loss-free. Frames are shed only once the encoder has backed the +/// queue up to capacity — which only happens above the sustainable rate. +/// +/// Pure so a fake clock/queue can drive the shed-above-30 / never-below-30 / +/// zero-drop-at-or-below-30 contract without a live encoder. +pub(crate) struct DropPolicy { + capacity: usize, +} + +impl DropPolicy { + pub(crate) fn new(capacity: usize) -> Self { + Self { + capacity: capacity.max(1), + } + } + + /// Admit a newly submitted frame iff the ingress queue has room for it. + pub(crate) fn admit(&self, backlog: usize) -> bool { + backlog < self.capacity + } +} + +// --------------------------------------------------------------------------- +// Crash-restart policy (the bounded-retry decision behind a seam) +// --------------------------------------------------------------------------- + +/// The bounded decision a feed makes when a persistent subprocess dies: try to +/// restart it, up to a budget, then give up and surface a terminal error rather +/// than spinning. The budget guards against a permanently-broken process (e.g. +/// ffmpeg missing or a fatal arg) turning a crash into a hot loop. A subprocess +/// that comes back healthy (produces output again) calls [`reset`](Self::reset) +/// to restore the full budget, so transient crashes never exhaust it. +/// +/// Pure and clock-free, so the crash-restart decision is unit-tested with a fake +/// "subprocess that dies" without a live ffmpeg. +#[derive(Debug)] +pub(crate) struct RestartPolicy { + max_consecutive: u32, + consecutive: u32, +} + +impl RestartPolicy { + pub(crate) fn new(max_consecutive: u32) -> Self { + Self { + max_consecutive, + consecutive: 0, + } + } + + /// Record a death and decide whether to attempt another restart. Returns + /// `true` (counting the attempt) while the budget remains; `false` once it is + /// exhausted. + pub(crate) fn should_restart(&mut self) -> bool { + if self.consecutive >= self.max_consecutive { + return false; + } + self.consecutive += 1; + true + } + + /// Mark the subprocess healthy again, clearing the consecutive-failure count + /// so a later transient crash gets the full budget. + pub(crate) fn reset(&mut self) { + self.consecutive = 0; + } + + /// Whether the restart budget is spent (no further restart will be attempted + /// until a [`reset`](Self::reset)). Exercised by the crash-restart unit tests. + #[cfg_attr(not(test), allow(dead_code))] + pub(crate) fn exhausted(&self) -> bool { + self.consecutive >= self.max_consecutive + } +} + +impl Default for RestartPolicy { + fn default() -> Self { + Self::new(DEFAULT_RESTART_BUDGET) + } +} + +/// Append `chunk` to a bounded stderr-tail buffer, trimming the front so it never +/// grows past [`STDERR_TAIL_CAP`]. Shared by the encoder/decoder stderr drains. +fn append_stderr_tail(tail: &Mutex, chunk: &str) { + let mut buf = tail.lock().unwrap_or_else(|e| e.into_inner()); + buf.push_str(chunk); + if buf.len() > STDERR_TAIL_CAP { + let cut = buf.len() - STDERR_TAIL_CAP; + // Trim on a char boundary so the retained tail is valid UTF-8. + let cut = (cut..=buf.len()) + .find(|&i| buf.is_char_boundary(i)) + .unwrap_or(buf.len()); + buf.drain(..cut); + } +} + +/// Drain a subprocess's stderr to a bounded tail buffer (and echo it when +/// `NEURACORE_WEBRTC_DEBUG` is set). Reused for both ffmpeg subprocesses; runs on +/// its own thread and ends on stderr EOF, so it is joined on `Drop` like the other +/// reader threads. +fn spawn_stderr_drain( + name: &str, + mut stderr: std::process::ChildStderr, + tail: Arc>, +) -> std::io::Result> { + let echo = std::env::var_os("NEURACORE_WEBRTC_DEBUG").is_some(); + std::thread::Builder::new() + .name(name.into()) + .spawn(move || { + let mut buf = [0u8; 1 << 12]; + loop { + match stderr.read(&mut buf) { + Ok(0) | Err(_) => break, + Ok(n) => { + let chunk = String::from_utf8_lossy(&buf[..n]); + if echo { + eprint!("{chunk}"); + } + append_stderr_tail(&tail, &chunk); + } + } + } + }) +} + +// --------------------------------------------------------------------------- +// Persistent ffmpeg encode / decode subprocesses +// --------------------------------------------------------------------------- + +/// The encoder's rate-control / resolution parameters for one ladder rung. The +/// input is always the full-resolution rgb24 source; `scale` (1 = full, 2 = half +/// each axis) downscales **inside** ffmpeg, and `bitrate` is the libx264 target. +/// Restarting the encoder with a coarser [`EncodeParams`] is how the queue-driven +/// adaptation degrades under congestion. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub(crate) struct EncodeParams { + pub fps: u32, + pub bitrate: u32, + pub scale: u32, +} + +/// The libx264 `-x264-params` string for an encoder with the given `keyint`. +/// +/// ## `threads=1` is load-bearing for Chrome (single slice per frame, low latency) +/// +/// `-tune zerolatency` turns on x264 **slice-based threading** (`sliced-threads=1`), +/// which splits every coded frame into one slice *per worker thread* — on a many-core +/// host that is ~7+ slices per frame, each its own VCL NAL. Our pipeline assumes +/// **one VCL NAL per frame**: [`AccessUnitAssembler`] flushes an access unit on each +/// VCL slice, and the producer sends each access unit as a separate `rtcSendMessage` +/// with its own RTP timestamp and marker bit. With multi-slice frames that turns one +/// captured frame into N RTP "frames" — N timestamps (the capture-timestamp queue +/// underflows ~N:1 and fabricates the rest, so they run *backwards*), N markers, and +/// each carrying only 1/N of the macroblocks. The loopback ffmpeg consumer reassembles +/// the Annex-B NALs regardless and masks it, but **Chrome's RTP frame assembler keys +/// frames on the timestamp**: it sees N partial pseudo-frames per real frame and never +/// completes the inter-keyframe P-frames (`framesReceived` stalls at the keyframe rate +/// with `packetsLost == 0`). See `reports/SPIKE-chrome-pframe.md`. +/// +/// `threads=1` (not merely `sliced-threads=0`) is the fix: a single thread emits a +/// single slice per frame. Disabling *only* slice threading would let x264 fall back +/// to **frame-based** threading, whose pipeline delays output by one frame per worker +/// (~14 frames / ~310 ms on a 14-core host — a glass-to-glass SLO blowout). A fully +/// serial encoder has no such pipeline and still clears the source rate with room to +/// spare at `ultrafast`. `slices=1` is kept belt-and-braces (a single thread cannot +/// slice-parallelise anyway). +pub(crate) fn x264_params(keyint: &str) -> String { + // Test-only hook (`NCD_WEBRTC_FORCE_SLICES=N`): deliberately emit N slices per + // frame so the one-VCL-NAL-per-access-unit invariant guard can be exercised + // end-to-end — it must DROP the malformed access unit and shout, never panic. + // Unset (the only production value) keeps the single-slice invariant below. + if let Some(n) = std::env::var_os("NCD_WEBRTC_FORCE_SLICES") { + let n = n.to_string_lossy(); + return format!( + "keyint={keyint}:min-keyint={keyint}:scenecut=0:bframes=0:\ + repeat-headers=1:slices={n}:annexb=1" + ); + } + format!( + "keyint={keyint}:min-keyint={keyint}:scenecut=0:bframes=0:\ + repeat-headers=1:slices=1:threads=1:annexb=1" + ) +} + +/// A persistent ffmpeg encoder: raw rgb24 frames in on stdin, H.264 Annex-B out +/// on stdout. A reader thread splits the output into NAL units, groups them into +/// per-frame access units, and hands each access unit to the `on_access_unit` +/// callback (which frames and sends it on the track). One per track; restarted +/// (a fresh instance) when the ladder rung changes. +pub(crate) struct H264Encoder { + child: Mutex, + stdin: Mutex>, + reader: Mutex>>, + splitter: Mutex>>, + stderr: Mutex>>, + frame_len: usize, + /// Cleared by the reader thread when ffmpeg's stdout reaches EOF — i.e. the + /// subprocess exited. The feed reads [`is_alive`](Self::is_alive) to detect a + /// crash and trigger a bounded restart instead of stalling silently. + alive: Arc, + /// The tail of ffmpeg's stderr, the diagnostic surfaced in the `on_error` + /// event when the encoder crashes. + stderr_tail: Arc>, +} + +impl H264Encoder { + /// Spawn ffmpeg for a `width`x`height` rgb24 source at the given + /// [`EncodeParams`]. `on_access_unit` runs on the reader thread for every + /// encoded frame. The first output IDR always carries SPS/PPS + /// (`repeat-headers=1`), so a restart at a new rung is itself a clean + /// keyframe — which is how a coalesced PLI is satisfied. + pub(crate) fn new( + width: u32, + height: u32, + params: EncodeParams, + mut on_access_unit: impl FnMut(Vec>) + Send + 'static, + ) -> std::io::Result { + let fps = params.fps.max(1); + let keyint = fps.to_string(); + let scale = params.scale.max(1); + // Even dimensions for yuv420p chroma. Downscale inside ffmpeg so the wire + // resolution drops while the source stays full-res; the consumer scales + // any rung back to its fixed decode size. + let out_w = ((width / scale) & !1).max(2); + let out_h = ((height / scale) & !1).max(2); + // VBV-capped CRF rate control: CRF keeps low-complexity content small (so + // a clean link never bursts the loopback socket — strict CBR padding + // would, and did), while `maxrate`/`bufsize` cap the peak at the ladder + // rung so a coarser rung genuinely throttles the stream under a + // constrained link. A small `bufsize` keeps the cap responsive (low + // latency) rather than letting a big VBV buffer absorb a whole second. + let kbps = (params.bitrate / 1000).max(50); + let maxrate = format!("{kbps}k"); + let bufsize = format!("{}k", (kbps / 2).max(32)); + let scale_filter = format!("scale={out_w}:{out_h}:flags=fast_bilinear"); + let mut args: Vec = [ + "-hide_banner", "-loglevel", "error", + "-f", "rawvideo", "-pix_fmt", "rgb24", + ] + .iter() + .map(|s| s.to_string()) + .collect(); + args.extend([ + "-s".into(), format!("{width}x{height}"), + "-r".into(), fps.to_string(), + "-i".into(), "pipe:0".into(), + "-an".into(), + "-vf".into(), scale_filter, + "-c:v".into(), "libx264".into(), + "-profile:v".into(), "baseline".into(), + "-pix_fmt".into(), "yuv420p".into(), + "-preset".into(), "ultrafast".into(), + "-tune".into(), "zerolatency".into(), + "-bf".into(), "0".into(), + "-g".into(), keyint.clone(), + "-crf".into(), "26".into(), + "-maxrate".into(), maxrate, + "-bufsize".into(), bufsize, + "-x264-params".into(), + x264_params(&keyint), + "-f".into(), "h264".into(), "pipe:1".into(), + ]); + // Pipe stderr (always) into a bounded tail buffer so a crash carries + // ffmpeg's last error line in the surfaced `on_error`; the drain thread + // also echoes it under NEURACORE_WEBRTC_DEBUG, preserving the old inherit + // behaviour for interactive debugging. + let mut child = Command::new(ffmpeg_bin()) + .args(&args) + .stdin(Stdio::piped()) + .stdout(Stdio::piped()) + .stderr(Stdio::piped()) + .spawn()?; + + let stdin = child.stdin.take(); + let mut stdout = child + .stdout + .take() + .ok_or_else(|| std::io::Error::other("ffmpeg encoder stdout unavailable"))?; + let child_stderr = child + .stderr + .take() + .ok_or_else(|| std::io::Error::other("ffmpeg encoder stderr unavailable"))?; + let stderr_tail = Arc::new(Mutex::new(String::new())); + let stderr = spawn_stderr_drain("ncwebrtc-encode-err", child_stderr, stderr_tail.clone())?; + + // Two threads: the reader does blocking stdout reads and forwards raw + // chunks; the splitter feeds them through the NAL splitter/assembler with + // an idle-flush so the trailing NAL of the final (or any paused) frame is + // not stranded waiting for a start code that never comes. + let alive = Arc::new(AtomicBool::new(true)); + let reader_alive = alive.clone(); + let (chunk_tx, chunk_rx) = std::sync::mpsc::channel::>(); + let reader = std::thread::Builder::new() + .name("ncwebrtc-encode-read".into()) + .spawn(move || { + let mut buf = [0u8; 1 << 16]; + loop { + match stdout.read(&mut buf) { + // stdout EOF/err == the subprocess exited: mark it dead so + // the feed restarts it rather than stalling. + Ok(0) | Err(_) => { + reader_alive.store(false, Ordering::SeqCst); + break; + } + Ok(n) => { + if chunk_tx.send(buf[..n].to_vec()).is_err() { + break; + } + } + } + } + })?; + let splitter = std::thread::Builder::new() + .name("ncwebrtc-encode-split".into()) + .spawn(move || { + let mut splitter = NalSplitter::new(); + let mut assembler = AccessUnitAssembler::new(); + let mut emit = |nal: Vec, assembler: &mut AccessUnitAssembler| { + if let Some(au) = assembler.push(nal) { + on_access_unit(au); + } + }; + loop { + match chunk_rx.recv_timeout(ENCODER_FLUSH_IDLE) { + Ok(chunk) => { + for nal in splitter.push(&chunk) { + emit(nal, &mut assembler); + } + } + // Idle: ffmpeg produced nothing for a frame period, so the + // buffered trailing NAL is a complete frame — flush it. + Err(RecvTimeoutError::Timeout) => { + if let Some(nal) = splitter.flush() { + emit(nal, &mut assembler); + } + } + Err(RecvTimeoutError::Disconnected) => { + if let Some(nal) = splitter.flush() { + emit(nal, &mut assembler); + } + break; + } + } + } + })?; + + Ok(Self { + child: Mutex::new(child), + stdin: Mutex::new(stdin), + reader: Mutex::new(Some(reader)), + splitter: Mutex::new(Some(splitter)), + stderr: Mutex::new(Some(stderr)), + frame_len: (width as usize) * (height as usize) * 3, + alive, + stderr_tail, + }) + } + + /// Write one raw rgb24 frame to ffmpeg's stdin. Blocks if ffmpeg has backed + /// up — that back-pressure propagates to the bounded ingress queue, which is + /// where frames are shed. Returns false once stdin has gone away (a crash): + /// the feed treats that, like [`is_alive`](Self::is_alive), as a death signal. + pub(crate) fn write_frame(&self, data: &[u8]) -> bool { + if data.len() != self.frame_len { + return true; // wrong shape for this encoder; skip, do not wedge + } + let mut guard = self.stdin.lock().unwrap_or_else(|e| e.into_inner()); + match guard.as_mut() { + Some(stdin) => stdin.write_all(data).is_ok(), + None => false, + } + } + + /// Whether ffmpeg is still running (its stdout has not hit EOF). The feed + /// polls this to detect a crash and trigger a bounded restart. + pub(crate) fn is_alive(&self) -> bool { + self.alive.load(Ordering::SeqCst) + } + + /// A snapshot of ffmpeg's stderr tail, for the surfaced crash diagnostic. + pub(crate) fn stderr_tail(&self) -> String { + self.stderr_tail + .lock() + .unwrap_or_else(|e| e.into_inner()) + .clone() + } +} + +impl Drop for H264Encoder { + fn drop(&mut self) { + // Close stdin (EOF -> ffmpeg flushes and exits), then make sure the + // process is gone and join the reader so no thread outlives the encoder. + *self.stdin.lock().unwrap_or_else(|e| e.into_inner()) = None; + if let Ok(mut child) = self.child.lock() { + let _ = child.kill(); + let _ = child.wait(); + } + // Reader exits on stdout EOF and drops its chunk sender; the splitter then + // sees Disconnected and exits; the stderr drain exits on stderr EOF. Join + // all three so no thread (and no zombie child) outlives the encoder. + for slot in [&self.reader, &self.splitter, &self.stderr] { + if let Some(handle) = slot.lock().unwrap_or_else(|e| e.into_inner()).take() { + let _ = handle.join(); + } + } + } +} + +/// A persistent ffmpeg decoder: H.264 Annex-B in on stdin, raw rgb24 frames out +/// on stdout. A writer thread feeds NAL units (start-code framed) to stdin; a +/// reader thread reads fixed-size decoded frames and hands each to `on_frame`. +/// One per inbound track. +pub(crate) struct H264Decoder { + input_tx: Mutex>>>, + child: Mutex, + writer: Mutex>>, + reader: Mutex>>, + stderr: Mutex>>, + /// Cleared by the reader thread when ffmpeg's stdout reaches EOF (the + /// subprocess exited). The consumer polls [`is_alive`](Self::is_alive) to + /// detect a decoder crash and restart the receive pipeline. + alive: Arc, + /// The tail of ffmpeg's stderr for the surfaced crash diagnostic. + stderr_tail: Arc>, +} + +impl H264Decoder { + pub(crate) fn new( + width: u32, + height: u32, + mut on_frame: impl FnMut(Vec) + Send + 'static, + ) -> std::io::Result { + // Normalise every rung back to the fixed decode size: the producer may + // downscale under congestion (a coarser ladder rung), so scale whatever + // resolution arrives up to width x height. The scale filter reconfigures + // on a mid-stream resolution change, and the block-coded header band + // survives the rescale, so the consumer's fixed-size frame reader stays + // correct across an adaptation step. + let scale_filter = format!("scale={width}:{height}:flags=fast_bilinear"); + let mut child = Command::new(ffmpeg_bin()) + .args([ + "-hide_banner", + "-loglevel", + "error", + // Minimal probe + low_delay so the decoder emits each frame as + // soon as it is decoded, starting from the first IDR, rather than + // buffering megabytes to analyse the stream. Baseline has no + // B-frames, so output order == input order. + "-probesize", + "32", + "-analyzeduration", + "0", + "-flags", + "low_delay", + "-f", + "h264", + "-i", + "pipe:0", + "-an", + "-vf", + &scale_filter, + "-f", + "rawvideo", + "-pix_fmt", + "rgb24", + "pipe:1", + ]) + .stdin(Stdio::piped()) + .stdout(Stdio::piped()) + .stderr(Stdio::piped()) + .spawn()?; + + let mut stdin = child + .stdin + .take() + .ok_or_else(|| std::io::Error::other("ffmpeg decoder stdin unavailable"))?; + let mut stdout = child + .stdout + .take() + .ok_or_else(|| std::io::Error::other("ffmpeg decoder stdout unavailable"))?; + let child_stderr = child + .stderr + .take() + .ok_or_else(|| std::io::Error::other("ffmpeg decoder stderr unavailable"))?; + let stderr_tail = Arc::new(Mutex::new(String::new())); + let stderr = spawn_stderr_drain("ncwebrtc-decode-err", child_stderr, stderr_tail.clone())?; + let alive = Arc::new(AtomicBool::new(true)); + + let (input_tx, input_rx) = std::sync::mpsc::channel::>(); + let writer = std::thread::Builder::new() + .name("ncwebrtc-decode-in".into()) + .spawn(move || { + // `flushed` tracks whether the last real NAL has already been + // followed by an AUD, so a long pause emits exactly one AUD (which + // flushes ffmpeg's last held frame) rather than a stream of them. + let mut flushed = true; + loop { + match input_rx.recv_timeout(DECODER_FLUSH_IDLE) { + Ok(annexb) => { + if stdin.write_all(&annexb).is_err() { + break; + } + flushed = false; + } + Err(RecvTimeoutError::Timeout) => { + if !flushed { + if stdin.write_all(&AUD_NAL).is_err() { + break; + } + flushed = true; + } + } + Err(RecvTimeoutError::Disconnected) => break, + } + } + // Dropping stdin signals EOF so ffmpeg drains and exits. + })?; + + let frame_len = (width as usize) * (height as usize) * 3; + let reader_alive = alive.clone(); + let reader = std::thread::Builder::new() + .name("ncwebrtc-decode-out".into()) + .spawn(move || { + let mut frame = vec![0u8; frame_len]; + // Drop a decoded picture byte-identical to the one just emitted. + // libdatachannel's RTCP receiving session re-delivers the first + // frame's media at startup (re-packetized, so its new RTP + // sequences slip past the depacketizer's sequence de-dup), making + // the decoder emit the identical first picture twice. A real + // encoder never produces two bit-exact consecutive frames, so + // suppressing an exact duplicate is information-preserving and + // costs one frame comparison. + let mut prev: Option> = None; + // Each decoded picture is exactly frame_len bytes; a short read + // means EOF/shutdown. + while stdout.read_exact(&mut frame).is_ok() { + if prev.as_deref() == Some(frame.as_slice()) { + continue; + } + on_frame(frame.clone()); + prev = Some(frame.clone()); + } + // A short read means EOF/shutdown — the subprocess exited. Mark it + // dead so the consumer restarts the receive pipeline. + reader_alive.store(false, Ordering::SeqCst); + })?; + + Ok(Self { + input_tx: Mutex::new(Some(input_tx)), + child: Mutex::new(child), + writer: Mutex::new(Some(writer)), + reader: Mutex::new(Some(reader)), + stderr: Mutex::new(Some(stderr)), + alive, + stderr_tail, + }) + } + + /// Whether ffmpeg is still running (its stdout has not hit EOF). The consumer + /// polls this to detect a decoder crash and restart the receive pipeline. + pub(crate) fn is_alive(&self) -> bool { + self.alive.load(Ordering::SeqCst) + } + + /// A snapshot of ffmpeg's stderr tail, for the surfaced crash diagnostic. + pub(crate) fn stderr_tail(&self) -> String { + self.stderr_tail + .lock() + .unwrap_or_else(|e| e.into_inner()) + .clone() + } + + /// Feed one NAL unit (no start code) to the decoder as Annex-B. + pub(crate) fn feed_nal(&self, nal: &[u8]) { + let mut annexb = Vec::with_capacity(nal.len() + 4); + annexb.extend_from_slice(&[0, 0, 0, 1]); + annexb.extend_from_slice(nal); + if let Some(tx) = self + .input_tx + .lock() + .unwrap_or_else(|e| e.into_inner()) + .as_ref() + { + let _ = tx.send(annexb); + } + } +} + +impl Drop for H264Decoder { + fn drop(&mut self) { + // Drop the sender (writer thread ends, closes stdin -> ffmpeg EOF), then + // kill/reap the process and join both threads. + *self.input_tx.lock().unwrap_or_else(|e| e.into_inner()) = None; + if let Ok(mut child) = self.child.lock() { + let _ = child.kill(); + let _ = child.wait(); + } + for slot in [&self.writer, &self.reader, &self.stderr] { + if let Some(handle) = slot.lock().unwrap_or_else(|e| e.into_inner()).take() { + let _ = handle.join(); + } + } + } +} + +/// A live track-open flag shared between the producer's track handler (which +/// flips it on `on_open`) and the encoder feed thread (which holds frames until +/// it is set, so the very first encoded access unit — always an IDR — is the +/// first thing sent once SRTP is ready). +pub(crate) type OpenFlag = Arc; + +/// Convenience: a fresh, closed open-flag. +pub(crate) fn open_flag() -> OpenFlag { + Arc::new(AtomicBool::new(false)) +} + +/// Whether `flag` has been flipped open. +pub(crate) fn is_open(flag: &OpenFlag) -> bool { + flag.load(Ordering::SeqCst) +} + +#[cfg(test)] +mod tests { + //! Peer-free, ffmpeg-free unit tests for the framing we own: Annex-B + //! splitting, access-unit grouping, RTP packetize/depacketize (including + //! FU-A and gap handling), and the drop policy behind a fake clock/queue. + + use super::*; + + fn start_code(four: bool, nal: &[u8]) -> Vec { + let mut v = Vec::new(); + v.extend_from_slice(if four { &[0, 0, 0, 1] } else { &[0, 0, 1] }); + v.extend_from_slice(nal); + v + } + + // --- Annex-B splitting ---------------------------------------------------- + + #[test] + fn nal_splitter_separates_units_and_strips_4byte_start_codes() { + let mut s = NalSplitter::new(); + let mut stream = Vec::new(); + stream.extend(start_code(true, &[0x67, 1, 2, 3])); // SPS (4-byte code) + stream.extend(start_code(false, &[0x68, 4, 5])); // PPS (3-byte code) + stream.extend(start_code(true, &[0x65, 9, 9, 9])); // IDR slice + + // Feed in two arbitrary chunks to exercise the cross-read buffering. + let mut out = s.push(&stream[..7]); + out.extend(s.push(&stream[7..])); + if let Some(tail) = s.flush() { + out.push(tail); + } + assert_eq!( + out, + vec![vec![0x67, 1, 2, 3], vec![0x68, 4, 5], vec![0x65, 9, 9, 9],] + ); + } + + #[test] + fn access_unit_assembler_flushes_on_the_vcl_slice() { + let mut a = AccessUnitAssembler::new(); + assert!(a.push(vec![0x67, 1]).is_none()); // SPS: non-VCL, accumulates + assert!(a.push(vec![0x68, 2]).is_none()); // PPS: non-VCL, accumulates + let au = a + .push(vec![0x65, 3]) + .expect("IDR completes the access unit"); + assert_eq!(au, vec![vec![0x67, 1], vec![0x68, 2], vec![0x65, 3]]); + // A bare non-IDR slice is its own access unit. + let au2 = a + .push(vec![0x61, 4]) + .expect("non-IDR slice completes its AU"); + assert_eq!(au2, vec![vec![0x61, 4]]); + } + + // --- producer send framing (Annex-B for the built-in packetizer) --------- + + #[test] + fn annexb_access_unit_prefixes_each_nal_with_a_long_start_code() { + let sps = vec![0x67, 1, 2, 3]; + let idr = vec![0x65, 9, 9]; + let buf = annexb_access_unit(&[sps.clone(), idr.clone()]); + assert_eq!( + buf, + vec![0, 0, 0, 1, 0x67, 1, 2, 3, 0, 0, 0, 1, 0x65, 9, 9], + "each NAL is preceded by a 4-byte start code so the chain's \ + long-start-sequence separator can split them" + ); + } + + #[test] + fn annexb_access_unit_skips_empty_nals() { + assert!(annexb_access_unit(&[vec![]]).is_empty()); + assert_eq!( + annexb_access_unit(&[vec![], vec![0x41, 7]]), + vec![0, 0, 0, 1, 0x41, 7] + ); + } + + // --- single-slice encoder invariant (Chrome P-frame assembly) ------------ + + #[test] + fn x264_params_force_a_single_slice_per_frame() { + // The producer sends each access unit as one RTP frame (its own timestamp + // + marker), and AccessUnitAssembler flushes one access unit per VCL NAL, + // so the encoder MUST emit exactly one slice per frame. `-tune zerolatency` + // otherwise enables slice-based threading (one slice per worker thread), + // which Chrome's timestamp-keyed frame assembler cannot reassemble — it + // sees N partial pseudo-frames per real frame and never completes the + // P-frames. Pin both levers that guarantee a single slice. + let params = x264_params("45"); + assert!( + params.contains("slices=1"), + "must request a single slice: {params}" + ); + assert!( + params.contains("threads=1"), + "must run x264 fully serial: one slice/frame and no frame-thread \ + pipeline latency (zerolatency's slice threading would multi-slice): {params}" + ); + // The keyint is interpolated into both keyint and min-keyint. + assert!(params.contains("keyint=45:min-keyint=45"), "keyint wired: {params}"); + } + + // --- one-VCL-NAL-per-access-unit invariant ------------------------------- + + #[test] + fn vcl_nal_count_is_one_for_a_normal_access_unit_and_trips_on_multi_slice() { + // A normal access unit — SPS + PPS + a single IDR slice (the assembler's + // steady-state output) — carries exactly one VCL NAL and so maps to one + // RTP frame under one capture timestamp. + let normal = vec![vec![0x67, 1, 2], vec![0x68, 3], vec![0x65, 9, 9]]; + assert_eq!(vcl_nal_count(&normal), 1, "SPS+PPS+IDR is one VCL NAL"); + // A bare non-IDR slice is also a single VCL NAL. + assert_eq!(vcl_nal_count(&[vec![0x41, 7]]), 1); + // Two coded slices grouped into one access unit — what a multi-slice / + // NAL-aggregation change would produce — has >1 VCL NAL, so the send + // path's `!= 1` guard trips and drops the AU rather than emitting + // out-of-order or fabricated per-slice timestamps. + let multi_slice = vec![vec![0x65, 1], vec![0x65, 2], vec![0x65, 3]]; + assert_eq!(vcl_nal_count(&multi_slice), 3, "multi-slice trips the guard"); + // Parameter sets alone are not a complete frame (zero VCL NALs). + assert_eq!(vcl_nal_count(&[vec![0x67, 1], vec![0x68, 2]]), 0); + } + + // --- packetizer-init cname guard ----------------------------------------- + + #[test] + fn cname_is_non_null_and_non_empty() { + // rtcSetH264Packetizer returns -1 (and the SR reporter then fails) if the + // packetizer init's cname is null. Pin that the single source is always a + // valid, non-empty C string, so the live path never passes null. + let cname = packetizer_cname(); + assert!(!cname.as_bytes().is_empty(), "cname must be non-empty"); + // CString guarantees NUL termination and no interior NUL. + assert_eq!(cname.to_str(), Ok(PACKETIZER_CNAME)); + } + + // --- RTP depacketizer (kept; the C API has no depacketizer) --------------- + + /// Build the RTP packets for one access unit the way libdatachannel's + /// built-in H.264 packetizer does (single-NAL under the fragment cap, FU-A + /// above it), so the kept depacketizer tests have wire-shaped input without + /// the deleted producer packetizer. Sequence numbers are monotonic from + /// `seq0`; the marker is set on the access unit's final packet. + fn build_rtp(seq0: u16, ssrc: u32, nals: &[Vec]) -> Vec> { + let cap = MAX_FRAGMENT_SIZE as usize; + let mut seq = seq0; + let mut header = |seq: &mut u16| { + let mut h = vec![0x80u8, 96]; + h.extend_from_slice(&seq.to_be_bytes()); + h.extend_from_slice(&0u32.to_be_bytes()); // timestamp (unused here) + h.extend_from_slice(&ssrc.to_be_bytes()); + *seq = seq.wrapping_add(1); + h + }; + let mut out: Vec> = Vec::new(); + for nal in nals { + if nal.len() <= cap { + let mut pkt = header(&mut seq); + pkt.extend_from_slice(nal); + out.push(pkt); + } else { + let f_nri = nal[0] & 0xE0; + let nal_type = nal[0] & 0x1F; + let chunks: Vec<&[u8]> = nal[1..].chunks(cap - 2).collect(); + let last = chunks.len() - 1; + for (i, chunk) in chunks.iter().enumerate() { + let fu_header = ((i == 0) as u8) << 7 | ((i == last) as u8) << 6 | nal_type; + let mut pkt = header(&mut seq); + pkt.push(f_nri | FU_A_TYPE); + pkt.push(fu_header); + pkt.extend_from_slice(chunk); + out.push(pkt); + } + } + } + if let Some(last) = out.last_mut() { + last[1] |= 0x80; + } + out + } + + #[test] + fn round_trip_reassembles_a_fragmented_nal_exactly() { + let mut nal = vec![0x65]; + nal.extend((0..(MAX_FRAGMENT_SIZE as usize * 2 + 37)).map(|i| (i % 253) as u8)); + let pkts = build_rtp(0, 42, &[nal.clone()]); + + let mut d = RtpDepacketizer::new(); + let mut got = Vec::new(); + for pkt in &pkts { + got.extend(d.depacketize(pkt)); + } + assert_eq!(got, vec![nal], "FU-A round trip is byte-exact"); + } + + #[test] + fn round_trip_preserves_a_multi_nal_access_unit() { + let sps = vec![0x67, 1, 2, 3]; + let pps = vec![0x68, 4, 5]; + let mut idr = vec![0x65]; + idr.extend(std::iter::repeat(9u8).take(MAX_FRAGMENT_SIZE as usize + 5)); // forces FU-A + let pkts = build_rtp(0, 7, &[sps.clone(), pps.clone(), idr.clone()]); + + let mut d = RtpDepacketizer::new(); + let mut got = Vec::new(); + for pkt in &pkts { + got.extend(d.depacketize(pkt)); + } + assert_eq!(got, vec![sps, pps, idr]); + } + + #[test] + fn depacketizer_drops_a_retransmitted_duplicate_sequence() { + // The producer's NACK responder retransmits packets; a retransmit (or a + // reordered duplicate) carries a sequence already processed and must not + // surface a second copy of the NAL. + let pkts = build_rtp(10, 7, &[vec![0x67, 1], vec![0x41, 2], vec![0x41, 3]]); + let mut d = RtpDepacketizer::new(); + let mut got = Vec::new(); + for pkt in &pkts { + got.extend(d.depacketize(pkt)); + } + assert_eq!(got.len(), 3, "three distinct NALs"); + // Replay the middle packet (a retransmit): no new NAL is emitted. + assert!( + d.depacketize(&pkts[1]).is_empty(), + "an already-seen sequence is dropped, not re-emitted" + ); + } + + #[test] + fn depacketizer_drops_a_partial_nal_on_a_sequence_gap() { + let mut nal = vec![0x65]; + nal.extend(std::iter::repeat(3u8).take(MAX_FRAGMENT_SIZE as usize * 3)); + let pkts = build_rtp(0, 7, &[nal]); + assert!(pkts.len() >= 3); + + let mut d = RtpDepacketizer::new(); + // Deliver the FU-A start, then SKIP a middle fragment, then the rest. + let mut got = Vec::new(); + got.extend(d.depacketize(&pkts[0])); // start + for pkt in &pkts[2..] { + got.extend(d.depacketize(pkt)); // gap: pkts[1] dropped + } + assert!(got.is_empty(), "a mid-FU gap abandons the corrupt NAL"); + + // The depacketizer recovers cleanly on the next complete single NAL. + let single = build_rtp(500, 7, &[vec![0x67, 1, 2, 3]]); + let recovered = d.depacketize(&single[0]); + assert_eq!(recovered, vec![vec![0x67, 1, 2, 3]]); + } + + // --- drop policy ---------------------------------------------------------- + + /// Drive the drop policy over a fake clock: submit at `fps`, drain (encode) + /// at `encoder_fps`, and report (delivered, dropped) across `seconds`. + fn simulate(capacity: usize, fps: f64, encoder_fps: f64, seconds: f64) -> (usize, usize) { + let policy = DropPolicy::new(capacity); + let submit_dt = 1.0 / fps; + let drain_dt = 1.0 / encoder_fps; + let mut backlog = 0usize; + let mut next_drain = drain_dt; + let (mut delivered, mut dropped) = (0usize, 0usize); + let total = (fps * seconds) as usize; + for i in 0..total { + let now = i as f64 * submit_dt; + while next_drain <= now { + if backlog > 0 { + backlog -= 1; + delivered += 1; + } + next_drain += drain_dt; + } + if policy.admit(backlog) { + backlog += 1; + } else { + dropped += 1; + } + } + delivered += backlog; // the encoder drains the remainder + (delivered, dropped) + } + + #[test] + fn zero_drops_at_or_below_thirty_when_the_encoder_keeps_up() { + // Encoder comfortably faster than a 30fps source -> the queue never fills. + let (delivered, dropped) = simulate(16, 30.0, 45.0, 4.0); + assert_eq!(dropped, 0, "no deliberate drops at or below 30fps"); + assert_eq!(delivered, (30.0f64 * 4.0) as usize, "every frame delivered"); + } + + #[test] + fn sheds_above_thirty_but_holds_the_delivered_floor() { + // 45fps source, encoder sustains ~35fps -> excess is shed, floor holds. + let (delivered, dropped) = simulate(16, 45.0, 35.0, 4.0); + assert!(dropped > 0, "over-rate source sheds the excess"); + let delivered_fps = delivered as f64 / 4.0; + assert!( + delivered_fps >= 30.0, + "delivered {delivered_fps:.1}fps must stay at/above the 30 floor" + ); + } + + #[test] + fn admit_tracks_queue_occupancy() { + let policy = DropPolicy::new(16); + assert!(policy.admit(0), "room -> admit"); + assert!(policy.admit(15), "last slot -> admit"); + assert!(!policy.admit(16), "full -> shed"); + } + + // --- crash-restart policy ------------------------------------------------- + + #[test] + fn restart_policy_permits_up_to_the_budget_then_gives_up() { + // A "subprocess that keeps dying": each death asks should_restart. The + // feed gets the budget of restart attempts, then a terminal give-up so a + // permanently-broken encoder surfaces an error instead of hot-looping. + let mut policy = RestartPolicy::new(3); + assert!(policy.should_restart(), "1st death -> restart"); + assert!(policy.should_restart(), "2nd death -> restart"); + assert!(policy.should_restart(), "3rd death -> restart"); + assert!(!policy.should_restart(), "budget spent -> give up"); + assert!(policy.exhausted()); + } + + #[test] + fn restart_policy_reset_after_a_healthy_run_restores_the_budget() { + // A recovered subprocess (it produced output again) resets the budget, so + // a much later, unrelated transient crash still gets the full retry count. + let mut policy = RestartPolicy::new(2); + assert!(policy.should_restart()); + assert!(policy.should_restart()); + assert!(!policy.should_restart(), "budget spent"); + policy.reset(); + assert!(!policy.exhausted(), "reset clears the exhausted state"); + assert!(policy.should_restart(), "full budget again after a healthy run"); + } + + // --- bounded stderr tail -------------------------------------------------- + + #[test] + fn stderr_tail_is_ring_trimmed_to_the_cap() { + let tail = Mutex::new(String::new()); + // Write well past the cap; only the trailing STDERR_TAIL_CAP bytes survive. + let chunk = "x".repeat(STDERR_TAIL_CAP); + append_stderr_tail(&tail, &chunk); + append_stderr_tail(&tail, "ERROR: ffmpeg died"); + let got = tail.lock().unwrap().clone(); + assert!(got.len() <= STDERR_TAIL_CAP, "tail stays bounded: {}", got.len()); + assert!( + got.ends_with("ERROR: ffmpeg died"), + "the most recent (diagnostic) bytes are retained" + ); + } +} diff --git a/rust/neuracore_webrtc/src/producer.rs b/rust/neuracore_webrtc/src/producer.rs new file mode 100644 index 000000000..52678ae5b --- /dev/null +++ b/rust/neuracore_webrtc/src/producer.rs @@ -0,0 +1,1813 @@ +//! The producer peer: the sole offerer in the negotiation model. +//! +//! The producer owns a libdatachannel [`RtcPeerConnection`] and is the only peer +//! that creates data channels and video tracks. Creating the first data channel +//! makes libdatachannel emit an offer (`on_local_description`); candidates trickle +//! out as `on_local_candidate`; the consumer's answer comes back through +//! [`Producer::set_remote_answer`]. Auto-negotiation is left at the libdatachannel +//! default for **data channels** — additional channels open over the existing SCTP +//! association without further SDP. +//! +//! ## Video tracks and the renegotiation queue (PR3) +//! +//! Unlike `createDataChannel`, libdatachannel's `addTrack` does **not** auto-offer +//! (verified in `impl/peerconnection.cpp`): a track add only changes the local +//! state, and the producer must call `set_local_description(Offer)` itself to +//! renegotiate. libdatachannel also silently drops a track that is added while a +//! prior offer is still in flight (no answer applied yet). The producer therefore +//! serialises all track mutations through a single-writer **negotiation queue**: +//! +//! * `add_video_track` / `remove_video_track` allocate identity, enqueue a +//! [`Mutation`], and signal the pump. They return immediately (the mid is +//! allocated synchronously, so `add_video_track` can return it). +//! * The pump task applies at most **one** mutation per offer/answer cycle: it +//! pops a mutation, applies it (`add_track_ex` + `set_local_description(Offer)`, +//! or drop-the-track + `set_local_description(Offer)`), and marks the cycle +//! in-flight. It does nothing more until the consumer's answer is applied and +//! `on_signaling_state_change(Stable)` fires, which clears the in-flight flag +//! and re-signals the pump to advance. +//! +//! This serialises a burst of adds/removes into one mutation per offer and never +//! mutates tracks mid-cycle, so no track is silently dropped. The transport is +//! provisioned with `force_media_transport` up front (see +//! [`crate::transport::loopback_config`]) so the first track reuses the existing +//! BUNDLE/DTLS transport — a track add never triggers a second DTLS handshake. +//! +//! The pump never calls into libdatachannel from inside a libdatachannel callback: +//! the signaling-state callback only sets a flag and pings the pump channel; the +//! pump runs on the tokio runtime, outside any callback (the same discipline the +//! data-channel flusher uses for `on_open`). It also never holds the negotiation +//! lock while taking the peer-connection lock, so it cannot deadlock against the +//! callback (which takes the negotiation lock while libdatachannel holds the PC). +//! +//! ## Pre-open send gate +//! +//! libdatachannel rejects a send before the channel's SCTP stream is open, so +//! each outgoing channel buffers sends until its `on_open` fires. The handler +//! cannot safely send from inside its own `on_open` callback (that would alias +//! the channel libdatachannel is mid-callback on), so it instead signals a +//! per-producer flusher task over a channel; the task takes the channels lock — +//! outside any callback — and replays the buffer in order via the safe +//! `RtcDataChannel::send`. +//! +//! Still real from PR0: the Rust-owned tokio runtime, the bounded `submit_frame` +//! queue, and the drainable event queue. PR4 feeds encoded RTP into the per-track +//! [`RtcTrack`] handles registered here. + +use std::collections::{HashMap, VecDeque}; +use std::ffi::CString; +use std::os::raw::{c_char, c_int, c_uint, c_void}; +use std::sync::atomic::{AtomicBool, AtomicU64, Ordering}; +use std::sync::{Arc, Mutex}; +use std::thread::JoinHandle; +use std::time::Instant; + +use datachannel::{ + ConnectionState, DataChannelHandler, DataChannelInfo, IceCandidate, IceState, + PeerConnectionHandler, RtcDataChannel, RtcPeerConnection, SdpType, SessionDescription, + SignalingState, +}; +use datachannel_sys as sys; +use once_cell::sync::Lazy; +use pyo3::buffer::PyBuffer; +use pyo3::exceptions::PyValueError; +use pyo3::prelude::*; +use pyo3::types::PyList; +use tokio::sync::mpsc; + +use crate::congestion::{CongestionController, TrackControl, LADDER, TOP_STEP}; +use crate::events::{emit_closed_once, Event, EventQueue}; +use crate::media::{ + annexb_access_unit, is_open, open_flag, packetizer_cname, vcl_nal_count, DropPolicy, + EncodeParams, H264Encoder, OpenFlag, RestartPolicy, MAX_FRAGMENT_SIZE, VIDEO_CLOCK_HZ, +}; +use crate::runtime::{ensure_started, runtime}; +use crate::transport::{ + chrome_sdp_enabled, connection_state_str, lock, loopback_config, map_err, munge_ssrc_cname, + parse_session, raw_pc_id, sdp_type_str, ManifestState, CONTROL_LABEL, +}; + +/// The NACK responder's stored-packet history depth (packets it can retransmit on +/// a NACK). One second at the source rate is ample for loopback/LAN RTT. +const NACK_HISTORY: c_uint = 512; + +// --------------------------------------------------------------------------- +// Producer-side RTCP feedback routing (sys-level chain callbacks) +// --------------------------------------------------------------------------- +// +// The producer's video track is created through `rtcAddTrackEx` (see +// `apply_mutation`), so unlike a datachannel-rs `RtcTrack` it has no safe handler +// — its chain callbacks are bare `extern "C"` functions. The REMB handler, the +// PLI handler, and the inbound-RTCP message callback (which carries RR) all route +// through this process-global registry keyed by libdatachannel's integer track +// id, exactly as the consumer's inbound-track callbacks do. We never touch the +// PC's user pointer (datachannel-rs owns it). + +/// Process-global: producer track id -> its congestion controller. Shared with +/// the broadcaster (its per-consumer tracks register here too) — the registry is +/// keyed by libdatachannel's globally-unique integer track id, so one map serves +/// every peer connection in the process. +pub(crate) static PRODUCER_FB: Lazy>>> = + Lazy::new(Default::default); + +/// Cap on a track's capture-timestamp queue. A healthy 1:1 pipeline keeps it tiny +/// (one push per frame written, one pop per access unit emitted); the cap is a +/// belt-and-braces bound so a transiently-stalled or mid-crash encoder (writes +/// accepted, no access units emitted) can never grow it without bound. This is the +/// only place besides the ingress queue that could grow under fault, so it is +/// explicitly bounded. See [`push_capture_ts`]. +pub(crate) const TS_QUEUE_CAP: usize = 256; + +/// Push a capture timestamp onto a track's queue, dropping the oldest once the cap +/// is reached so the queue stays bounded under any fault. +pub(crate) fn push_capture_ts(queue: &Mutex>, ts: u32) { + let mut q = lock(queue); + if q.len() >= TS_QUEUE_CAP { + q.pop_front(); + } + q.push_back(ts); +} + +/// Deregister a producer/broadcaster sys track from the process-global feedback +/// registry. Split from the libdatachannel teardown ([`teardown_sys_track`]) so the +/// registry-hygiene invariant is unit-testable without a live track. +pub(crate) fn deregister_feedback(raw_id: i32) { + lock(&PRODUCER_FB).remove(&raw_id); +} + +/// The number of live feedback controllers in `PRODUCER_FB`. A diagnostics +/// accessor the soak test reads to assert the registry returns to baseline after +/// churn (no leaked entries). +pub(crate) fn producer_fb_len() -> usize { + lock(&PRODUCER_FB).len() +} + +/// Fully tear down a producer/broadcaster sys track: deregister its feedback +/// controller, clear its RTCP message callback, and delete the track in +/// libdatachannel. Every remove, every close, AND every mid-setup error path +/// funnels through here, so no `PRODUCER_FB` entry or chain callback is ever +/// leaked — even when setup fails partway. +pub(crate) fn teardown_sys_track(raw_id: i32) { + deregister_feedback(raw_id); + // SAFETY: `raw_id` is a sys track this peer created and still owns; clearing + // the callback before deletion stops any in-flight RTCP routing. + unsafe { + sys::rtcSetMessageCallback(raw_id, None); + sys::rtcDeleteTrack(raw_id); + } +} + +/// REMB handler callback: the receiver-estimated max bitrate for this track. +pub(crate) unsafe extern "C" fn on_remb_cb(tr: c_int, bitrate: c_uint, _ptr: *mut c_void) { + if let Some(ctrl) = lock(&PRODUCER_FB).get(&tr).cloned() { + ctrl.on_remb(bitrate); + } +} + +/// PLI handler callback: the receiver wants a keyframe. Coalesced into a single +/// pending request the feed satisfies via an encoder restart. +pub(crate) unsafe extern "C" fn on_pli_cb(tr: c_int, _ptr: *mut c_void) { + if let Some(ctrl) = lock(&PRODUCER_FB).get(&tr).cloned() { + ctrl.on_pli(); + } +} + +/// Inbound-RTCP message callback on the producer's (send-only) track. The chain's +/// SR reporter is outgoing-only, so inbound RR/REMB pass through to here; we +/// hand-parse the RR report blocks (the C API decodes none). Binary messages +/// carry a non-negative size. +pub(crate) unsafe extern "C" fn on_rtcp_cb( + id: c_int, + msg: *const c_char, + size: c_int, + _ptr: *mut c_void, +) { + if size < 0 || msg.is_null() { + return; + } + let buf = std::slice::from_raw_parts(msg as *const u8, size as usize); + if let Some(ctrl) = lock(&PRODUCER_FB).get(&id).cloned() { + ctrl.on_rtcp(buf); + } +} + +/// Persistent per-track ffmpeg encoders, keyed by `track_id`. Created lazily on +/// the first frame for a track (when its dimensions are known) and dropped on +/// close (which kills the subprocess). +type Encoders = Arc>>>; + +/// How many pre-open frames to stash per track before the track's SRTP is up. +/// The feed thread drains the bounded ingress queue into this stash (so +/// `submit_frame` never drops while the connection is still coming up) and +/// flushes it in order — IDR first — the moment the track opens. Generous enough +/// to cover the sub-second open window at the source rate; the oldest frames are +/// the IDR we most want, so an over-long open sheds the newest. +pub(crate) const PREOPEN_STASH_FRAMES: usize = 32; + +/// Capacity of the bounded queue behind [`Producer::submit_frame`]. Frames are +/// dropped (never block the caller) once this many are in flight; PR4 replaces +/// the stub drain with the real encoder feed and may revisit the depth. +pub(crate) const FRAME_QUEUE_CAPACITY: usize = 16; + +/// The H.264 dynamic payload type advertised on every video track. Constrained +/// baseline / packetization-mode 1 is the first browser target; for the PR3 +/// loopback (both peers libdatachannel) only consistency matters. +pub(crate) const VIDEO_PAYLOAD_TYPE: i32 = 96; + +/// One raw frame handed off by `submit_frame`: its owned bytes, the routing +/// track id, and the dimensions the encoder needs to start ffmpeg. The pixel +/// format is fixed rgb24 (8-bit, 3-channel, the synthetic source shape). +pub(crate) struct Frame { + pub(crate) track_id: String, + pub(crate) data: Vec, + pub(crate) width: u32, + pub(crate) height: u32, + /// Capture time in 90 kHz RTP units (from the producer's epoch at + /// `submit_frame`). Carried through to the encoder output so each access + /// unit's RTP timestamp reflects when the frame was *captured*, not when the + /// (bursty) encoder happened to emit it — a strict receiver (Chrome) needs a + /// timestamp cadence that matches real arrival or it discards frames. + pub(crate) capture_ts: u32, +} + +/// The slice of a raw frame buffer `submit_frame` needs: whether it is +/// C-contiguous, its shape, and a copy of its bytes. Abstracted behind a trait so +/// the contiguity/shape gate and the copy can be unit-tested with a fake, without +/// constructing a live `PyBuffer` (which needs the interpreter). Production +/// implements it over `PyBuffer` (the buffer-protocol view of the caller's +/// numpy array); the bytes are copied because that array may be reused the +/// instant `submit_frame` returns. +pub(crate) trait FrameBytes { + fn is_contiguous(&self) -> bool; + fn shape(&self) -> Vec; + fn to_owned_bytes(&self) -> Vec; +} + +impl FrameBytes for PyBuffer { + fn is_contiguous(&self) -> bool { + self.is_c_contiguous() + } + + fn shape(&self) -> Vec { + self.shape().to_vec() + } + + fn to_owned_bytes(&self) -> Vec { + let len = self.item_count(); + // SAFETY: the GIL is held (this runs inside a `#[pymethods]` call), the + // buffer is a validated `u8` C-contiguous buffer, the length comes from + // `PyBuffer::item_count`, and we only read. + unsafe { std::slice::from_raw_parts(self.buf_ptr() as *const u8, len).to_vec() } + } +} + +/// A validated frame: its bytes plus the dimensions extracted from the buffer's +/// shape. Channels are fixed at 3 (rgb24/bgr24). +#[derive(Debug, PartialEq, Eq)] +pub(crate) struct FrameData { + pub(crate) data: Vec, + pub(crate) width: u32, + pub(crate) height: u32, +} + +/// Why a frame buffer was rejected by [`read_frame`]. Kept as an enum rather +/// than a `PyErr` so the validation decision is testable without the +/// interpreter; `submit_frame` maps each to the `ValueError` callers see. +#[derive(Debug, PartialEq, Eq)] +pub(crate) enum FrameError { + /// The buffer is not C-contiguous (the encoder needs a packed row layout). + NotContiguous, + /// The buffer is not an 8-bit HxWx3 image (the only format the encoder feeds). + BadShape, +} + +/// Validate and copy a frame buffer: reject a non-C-contiguous buffer (the same +/// methodology as the disk recording path's `log_frame`), require an `HxWx3` +/// 8-bit shape so the encoder knows the input format, then copy its bytes out. +/// Pure given the [`FrameBytes`] seam. +pub(crate) fn read_frame(buffer: &B) -> Result { + if !buffer.is_contiguous() { + return Err(FrameError::NotContiguous); + } + let shape = buffer.shape(); + let [height, width, channels] = shape[..] else { + return Err(FrameError::BadShape); + }; + if channels != 3 || width == 0 || height == 0 { + return Err(FrameError::BadShape); + } + Ok(FrameData { + data: buffer.to_owned_bytes(), + width: width as u32, + height: height as u32, + }) +} + +/// Shared map of the producer's outgoing data channels, keyed by label. +pub(crate) type Channels = Arc>>; + +/// Shared map of the producer's outgoing video tracks, keyed by `track_id`. +type Tracks = Arc>>; + +/// A producer-owned outgoing data channel plus its pre-open send buffer. +pub(crate) struct OutgoingEntry { + /// The live channel; dropping it deletes the channel in libdatachannel. + pub(crate) channel: Box>, + /// True once the channel's SCTP stream is open. + pub(crate) open: bool, + /// Bytes submitted before the channel opened, replayed in order on open. + pub(crate) pending: VecDeque>, +} + +impl OutgoingEntry { + /// Send now if open, else buffer. Flushes any backlog first to preserve order. + pub(crate) fn send(&mut self, bytes: Vec) { + if self.open { + self.flush(); + let _ = self.channel.send(&bytes); + } else { + self.pending.push_back(bytes); + } + } + + pub(crate) fn flush(&mut self) { + while let Some(message) = self.pending.pop_front() { + let _ = self.channel.send(&message); + } + } +} + +/// A producer-owned outgoing video track. PR5 creates the track through the sys +/// layer (`rtcAddTrackEx`) so it can attach libdatachannel's built-in H.264 +/// chain (packetizer + SR/NACK/PLI/REMB) to the raw id — datachannel-rs keeps a +/// safe `RtcTrack`'s id private. We therefore own the track's lifecycle manually: +/// `rtcDeleteTrack(raw_id)` on remove/close (see [`apply_mutation`] / +/// [`Producer::close`]), in place of PR4's drop-the-`Box` lifecycle. +struct TrackEntry { + mid: String, + /// libdatachannel's integer id for the track, from `rtcAddTrackEx`. The + /// encoder feed sends NAL units on it via `rtcSendMessage`; remove/close + /// deletes it via `rtcDeleteTrack`. + raw_id: i32, + /// Flipped true when the track's renegotiation completes (its answer applied). + /// The encoder feed holds frames until this is set so the first sent access + /// unit (always an IDR) is the first thing the consumer receives. + open: OpenFlag, + /// The adaptation effect surface: the estimator (driven by the RTCP callbacks) + /// publishes a ladder rung here and the feed thread applies it (fps cap + + /// encoder restart). Shared with the [`CongestionController`] in + /// [`PRODUCER_FB`]. + control: Arc, + /// Capture timestamps (90 kHz) queued in submit order, one pushed per frame + /// written to the encoder and popped per emitted access unit. Baseline H.264 + /// is in-order, so the Nth output frame carries the Nth input frame's capture + /// time. Shared (`Arc`) so it survives encoder restarts. + ts_queue: Arc>>, +} + +/// The SDP m-line mid for a video source: the caller-supplied `track_id`, used +/// verbatim. The caller (the Python provider for the broadcaster, the loopback +/// harness for the 1:1 producer) owns the mid and registers the same value for +/// browser identity (available_robots), so the offer's a=mid and that record +/// always agree. There is deliberately no per-connection "v{n}" counter: that +/// produced a wire mid the identity record never matched. The caller keeps mids +/// unique within a peer connection and must not reuse the data channel's m-line +/// mid ("0"); libdatachannel silently drops a video track whose mid collides with +/// the data channel (the Python provider prefixes "v" to stay clear of it). +pub(crate) fn track_mid(track_id: &str) -> String { + track_id.to_string() +} + +/// One queued track mutation. Removal is keyed by `track_id` (symmetric with the +/// add); the mid is recovered from the registry when the removal is applied. +#[derive(Debug, Clone, PartialEq, Eq)] +pub(crate) enum Mutation { + Add { + track_id: String, + mid: String, + ssrc: u32, + }, + Remove { + track_id: String, + }, +} + +impl Mutation { + /// The track id this mutation targets (the queue's stable identity for both + /// add and remove). Used by the negotiation-queue unit tests to assert the + /// order mutations are applied in. + #[cfg(test)] + fn track_id(&self) -> &str { + match self { + Mutation::Add { track_id, .. } | Mutation::Remove { track_id } => track_id, + } + } +} + +/// The single-writer negotiation queue state. `in_flight` is true between +/// applying a mutation (offer sent) and the consumer's answer being applied +/// (signaling back to Stable). At most one offer/answer cycle is ever in flight. +#[derive(Default)] +pub(crate) struct NegState { + pub(crate) in_flight: bool, + pub(crate) pending: VecDeque, +} + +/// Per-data-channel handler for the producer's outgoing channels. On open it +/// signals the flusher task to drain the channel's pre-open send buffer. The +/// channels are send-only here, so inbound messages are ignored. +pub(crate) struct ProducerChannelHandler { + label: String, + flush_tx: mpsc::UnboundedSender, +} + +impl ProducerChannelHandler { + /// Build a channel handler that signals `flush_tx` (with `label`) when its + /// SCTP stream opens. Reused by the broadcaster's per-consumer control + /// channels, which share the producer's pre-open send/flush discipline. + pub(crate) fn new(label: String, flush_tx: mpsc::UnboundedSender) -> Self { + Self { label, flush_tx } + } +} + +impl DataChannelHandler for ProducerChannelHandler { + fn on_open(&mut self) { + // Do not send from inside on_open: that would re-enter the channel + // libdatachannel is mid-callback on. Defer the flush to the task, which + // sends from outside any callback. + let _ = self.flush_tx.send(self.label.clone()); + } +} + +/// Peer-connection handler for the producer: relays libdatachannel's signaling +/// and state callbacks onto the drainable event queue, and drives the +/// renegotiation queue forward when a cycle completes. All callbacks fire on +/// libdatachannel threads and only touch mutex-backed shared state (never the PC +/// itself — the pump does that off-callback). +pub(crate) struct ProducerHandler { + events: EventQueue, + /// Cleared to advance the queue when a negotiation cycle returns to Stable. + neg: Arc>, + /// The open flag of the track whose add-renegotiation is in flight, if any. + /// Flipped true when that cycle returns to Stable — the precise point the + /// consumer is provably ready to receive the track's RTP. See + /// [`on_signaling_state_change`](Self::on_signaling_state_change). + pending_open: Arc>>, + /// Pinged after `in_flight` is cleared so the pump applies the next mutation. + pump_tx: mpsc::UnboundedSender<()>, + /// Set once a reconnect-needed error has been surfaced for the current outage, + /// cleared on `Connected`, so a Disconnected->Failed sequence surfaces it once + /// per outage rather than on every transition. + reconnect_surfaced: Arc, +} + +impl PeerConnectionHandler for ProducerHandler { + type DCH = ProducerChannelHandler; + + fn data_channel_handler(&mut self, _info: DataChannelInfo) -> Self::DCH { + // The consumer never opens channels back to the producer, so this + // factory is effectively unused; hand back a detached handler. + let (flush_tx, _flush_rx) = mpsc::unbounded_channel(); + ProducerChannelHandler { + label: String::new(), + flush_tx, + } + } + + fn on_description(&mut self, sess_desc: SessionDescription) { + // Chrome-only SDP munge (gated, so the loopback path is byte-identical): + // give libdatachannel's bare `a=ssrc:` line a cname, which Chrome's + // stricter parser requires before it will set up the receive track. + let mut sdp = sess_desc.sdp.to_string(); + if sess_desc.sdp_type == SdpType::Offer && chrome_sdp_enabled() { + sdp = munge_ssrc_cname(&sdp, crate::media::PACKETIZER_CNAME); + } + self.events.push(Event::LocalDescription { + sdp_type: sdp_type_str(&sess_desc.sdp_type).to_string(), + sdp, + }); + } + + fn on_candidate(&mut self, cand: IceCandidate) { + self.events.push(Event::LocalCandidate { + candidate: cand.candidate, + mid: Some(cand.mid), + }); + } + + fn on_connection_state_change(&mut self, state: ConnectionState) { + crate::transport::debug_trace("P", connection_state_str(&state)); + // The constructor emits the initial "new"; skip the duplicate so a + // single new->connecting->connected sequence is observed. + if state == ConnectionState::New { + return; + } + if state == ConnectionState::Connected { + // A fresh connection re-arms the one-shot reconnect surface. + self.reconnect_surfaced.store(false, Ordering::SeqCst); + } + self.events + .push(Event::State(connection_state_str(&state).to_string())); + // Reconnect handling: this binding cannot restart ICE (libjuice + // single-shot agent), so a Disconnected/Failed connection surfaces a clear + // reconnect-needed error once per outage and the app removes + re-adds the + // peer. If the binding ever gains ICE restart, the seam returns IceRestart. + match crate::transport::reconnect_action(state, crate::transport::ICE_RESTART_SUPPORTED) { + crate::transport::ReconnectAction::SurfaceReconnect => { + if !self.reconnect_surfaced.swap(true, Ordering::SeqCst) { + self.events.push(Event::error( + "connection", + "reconnect-needed: connection failed and ICE restart is \ + unsupported on this binding — remove and re-add the peer", + )); + } + } + crate::transport::ReconnectAction::IceRestart + | crate::transport::ReconnectAction::None => {} + } + } + + fn on_signaling_state_change(&mut self, state: SignalingState) { + crate::transport::debug_trace("P", &format!("sig:{state:?}")); + // A return to Stable means the in-flight offer's answer has been applied + // (or there was nothing in flight). Clear the gate and wake the pump so + // it applies the next queued mutation. Done off the PC: we only touch the + // neg/pending locks here, never call back into libdatachannel from its + // callback. + if state == SignalingState::Stable { + // If a track add just completed, the consumer has processed the media + // offer and answered it, so it is now ready to receive. Open the + // track for the encoder feed at exactly this point (not at the track's + // own too-early SRTP on_open). + if let Some(open) = lock(&self.pending_open).take() { + open.store(true, Ordering::SeqCst); + crate::transport::debug_trace("P", "gate-open (renegotiation complete)"); + } + lock(&self.neg).in_flight = false; + let _ = self.pump_tx.send(()); + } + } + + fn on_ice_state_change(&mut self, state: IceState) { + crate::transport::debug_trace("P", &format!("ice:{state:?}")); + } +} + +/// The producer-side WebRTC peer exposed to Python. +#[pyclass] +pub struct Producer { + events: EventQueue, + /// The bounded ingress sender, behind an `Option` so [`close`](Self::close) can + /// drop it: dropping the last sender ends the feed thread's `blocking_recv`, so + /// close stops the feed deterministically instead of leaking it until the + /// `Producer` is garbage-collected. + frame_tx: Mutex>>, + /// The encoder feed thread, joined on close so no thread outlives the producer. + feed_handle: Mutex>>, + closed: Arc, + /// Producer epoch: `submit_frame` stamps each frame's capture time as + /// `elapsed_since(epoch)` in 90 kHz units. + epoch: Instant, + /// The libdatachannel peer connection. Shared (`Arc`) so the pump task can + /// apply track mutations on it off-callback; dropped on `close`. + pc: Arc>>>>, + /// Outgoing data channels keyed by label (includes the control channel). + channels: Channels, + /// Outgoing video tracks keyed by track_id (the RTP send registry). + tracks: Tracks, + /// Persistent per-track ffmpeg encoders, created lazily by the feed thread. + encoders: Encoders, + /// The published stream manifest (data channels keyed by label, video tracks + /// keyed by mid). Shared with the flusher and the pump's manifest republish. + manifest: Arc>, + /// The single-writer renegotiation queue, shared with the pump and handler. + neg: Arc>, + /// Allocates a unique RTP SSRC per track. + ssrc_counter: AtomicU64, + /// Channels signal this on open so the flusher drains their send buffers. + flush_tx: mpsc::UnboundedSender, + /// Pinged on every track mutation (and from the handler on Stable) to wake + /// the pump. + pump_tx: mpsc::UnboundedSender<()>, +} + +#[pymethods] +impl Producer { + /// Create a producer. `connection_id` is an opaque label used only for + /// logging/correlation; the Python signaling layer owns connection + /// identity. `frame_queue_capacity` sizes the bounded `submit_frame` queue. + #[new] + #[pyo3(signature = (connection_id=None, frame_queue_capacity=FRAME_QUEUE_CAPACITY))] + fn new(connection_id: Option, frame_queue_capacity: usize) -> PyResult { + let _ = connection_id; + ensure_started(); + + let (frame_tx, frame_rx) = mpsc::channel::(frame_queue_capacity.max(1)); + + let channels: Channels = Arc::new(Mutex::new(HashMap::new())); + let tracks: Tracks = Arc::new(Mutex::new(HashMap::new())); + let encoders: Encoders = Arc::new(Mutex::new(HashMap::new())); + let events = EventQueue::default(); + // The encoder feed: drains the bounded ingress queue, lazily spins up one + // ffmpeg encoder per track, gates on track-open, restarts a crashed encoder + // (surfacing on_error), and hands each encoded access unit to the + // packetize+send stage. Its handle is joined on close. + let feed_handle = spawn_feed(frame_rx, encoders.clone(), tracks.clone(), events.clone()); + let manifest = Arc::new(Mutex::new(ManifestState::default())); + let neg = Arc::new(Mutex::new(NegState::default())); + let pending_open: Arc>> = Arc::new(Mutex::new(None)); + let (flush_tx, flush_rx) = mpsc::unbounded_channel::(); + let (pump_tx, pump_rx) = mpsc::unbounded_channel::<()>(); + Self::spawn_flusher(channels.clone(), manifest.clone(), flush_rx); + + let handler = ProducerHandler { + events: events.clone(), + neg: neg.clone(), + pending_open: pending_open.clone(), + pump_tx: pump_tx.clone(), + reconnect_surfaced: Arc::new(AtomicBool::new(false)), + }; + let pc = Arc::new(Mutex::new(Some( + RtcPeerConnection::new(&loopback_config(), handler).map_err(map_err)?, + ))); + events.push(Event::State("new".to_string())); + + Self::spawn_pump( + pc.clone(), + tracks.clone(), + encoders.clone(), + manifest.clone(), + channels.clone(), + pending_open.clone(), + neg.clone(), + events.clone(), + pump_rx, + ); + + Ok(Self { + events, + frame_tx: Mutex::new(Some(frame_tx)), + feed_handle: Mutex::new(Some(feed_handle)), + closed: Arc::new(AtomicBool::new(false)), + epoch: Instant::now(), + pc, + channels, + tracks, + encoders, + manifest, + neg, + ssrc_counter: AtomicU64::new(1), + flush_tx, + pump_tx, + }) + } + + /// Add a video track and return its negotiated mid. The caller owns the mid: + /// the supplied `track_id` is used verbatim as the SDP m-line mid, so the value + /// the caller registers for identity (e.g. available_robots) and the offer's + /// a=mid are the same. The caller must keep mids unique within the peer + /// connection and must not reuse the data m-line mid ("0"); libdatachannel drops + /// a track whose mid collides with the data channel. The mid is returned + /// synchronously (so callers learn it immediately); the actual `add_track_ex` + /// and the renegotiation it triggers are serialised through the queue, so a + /// burst of adds never overlaps an in-flight offer. + fn add_video_track(&self, track_id: &str) -> PyResult { + if self.closed.load(Ordering::SeqCst) { + return Err(PyValueError::new_err("producer is closed")); + } + let mid = track_mid(track_id); + let ssrc = self.ssrc_counter.fetch_add(1, Ordering::SeqCst) as u32; + lock(&self.neg).pending.push_back(Mutation::Add { + track_id: track_id.to_string(), + mid: mid.clone(), + ssrc, + }); + let _ = self.pump_tx.send(()); + Ok(mid) + } + + /// Remove a previously-added video track by its `track_id`, routed through the + /// queue. The consumer learns of the removal by mid via the republished + /// manifest (libdatachannel surfaces no incoming-track callback). + fn remove_video_track(&self, track_id: &str) -> PyResult<()> { + if self.closed.load(Ordering::SeqCst) { + return Err(PyValueError::new_err("producer is closed")); + } + lock(&self.neg).pending.push_back(Mutation::Remove { + track_id: track_id.to_string(), + }); + let _ = self.pump_tx.send(()); + Ok(()) + } + + /// Open a reliable-ordered data channel. The first channel triggers the + /// offer (negotiation-needed); later channels open over the existing SCTP + /// association. `kind` is an opaque label hint recorded in the manifest. + /// The reserved `"control"` label carries the manifest and is not itself + /// listed in it. + fn add_data_channel(&self, label: &str, kind: &str) -> PyResult<()> { + if self.closed.load(Ordering::SeqCst) { + return Err(PyValueError::new_err("producer is closed")); + } + let handler = ProducerChannelHandler { + label: label.to_string(), + flush_tx: self.flush_tx.clone(), + }; + let channel = { + let mut guard = lock(&self.pc); + let pc = guard + .as_mut() + .ok_or_else(|| PyValueError::new_err("producer is closed"))?; + pc.create_data_channel(label, handler).map_err(map_err)? + }; + lock(&self.channels).insert( + label.to_string(), + OutgoingEntry { + channel, + open: false, + pending: VecDeque::new(), + }, + ); + + if label != CONTROL_LABEL { + lock(&self.manifest).upsert_data_channel(label, kind); + } + // Republish the manifest on every change (atomic full-state message). + // Buffers until the control channel opens, after which the flusher + // replays it, so the consumer always converges on the latest set. + republish(&self.channels, &self.manifest); + Ok(()) + } + + /// Send a JSON payload (already-serialised text) over the named data + /// channel. This is the single channel send path both `send_json` and the + /// recording-context `log_*` bridge reach. + fn send_json(&self, label: &str, payload: &str) -> PyResult<()> { + let mut map = lock(&self.channels); + let entry = map + .get_mut(label) + .ok_or_else(|| PyValueError::new_err(format!("no data channel labelled {label:?}")))?; + entry.send(payload.as_bytes().to_vec()); + Ok(()) + } + + /// Enqueue one raw frame for `track_id` onto the bounded queue and return + /// immediately. Never blocks: under overload the frame is dropped rather + /// than back-pressuring the caller. The frame buffer must be C-contiguous + /// (same zero-copy methodology as the disk recording path); the bytes are + /// copied under the GIL because the caller's numpy array may be reused the + /// instant this returns. + #[pyo3(signature = (track_id, frame))] + fn submit_frame(&self, track_id: &str, frame: PyBuffer) -> PyResult<()> { + let FrameData { + data, + width, + height, + } = read_frame(&frame).map_err(|err| match err { + FrameError::NotContiguous => PyValueError::new_err("frame buffer must be C-contiguous"), + FrameError::BadShape => PyValueError::new_err("frame must be an 8-bit HxWx3 image"), + })?; + let capture_ts = (self.epoch.elapsed().as_secs_f64() * VIDEO_CLOCK_HZ as f64) as u32; + let job = Frame { + track_id: track_id.to_string(), + data, + width, + height, + capture_ts, + }; + // Drop policy: admit only while the bounded ingress queue has room. With + // room (the steady state at or below the encoder's throughput) nothing is + // shed; once the encoder backs the queue to capacity, frames are dropped + // here. Never blocks the caller. After close the sender is gone -> no-op. + let guard = lock(&self.frame_tx); + let Some(frame_tx) = guard.as_ref() else { + return Ok(()); + }; + let capacity = frame_tx.max_capacity(); + let backlog = capacity - frame_tx.capacity(); + if !DropPolicy::new(capacity).admit(backlog) { + return Ok(()); + } + let _ = frame_tx.try_send(job); + Ok(()) + } + + /// Apply the consumer's SDP answer to complete a negotiation round. + fn set_remote_answer(&self, sdp: &str) -> PyResult<()> { + let sess = parse_session(sdp, SdpType::Answer)?; + let mut guard = lock(&self.pc); + let pc = guard + .as_mut() + .ok_or_else(|| PyValueError::new_err("producer is closed"))?; + pc.set_remote_description(&sess).map_err(map_err) + } + + /// Apply a remote ICE candidate trickled from the consumer. + #[pyo3(signature = (candidate, mid=None))] + fn add_remote_candidate(&self, candidate: &str, mid: Option) -> PyResult<()> { + let cand = IceCandidate { + candidate: candidate.to_string(), + mid: mid.unwrap_or_default(), + }; + let mut guard = lock(&self.pc); + let pc = guard + .as_mut() + .ok_or_else(|| PyValueError::new_err("producer is closed"))?; + pc.add_remote_candidate(&cand).map_err(map_err) + } + + /// Drain and return all queued events as a list of dicts. See the + /// [`events`](crate::events) module for the dict schema. + fn drain_events(&self, py: Python<'_>) -> PyResult> { + self.events.drain_to_py(py) + } + + /// The congestion ladder rung a track is currently encoded at (0 = finest, + /// higher = coarser), or `None` for an unknown track. The structured signal a + /// constrained-link test reads to confirm the estimator is adapting. + fn congestion_step(&self, track_id: &str) -> Option { + lock(&self.tracks) + .get(track_id) + .map(|entry| entry.control.desired_step()) + } + + /// The coarsest ladder rung a track ever reached (a high-water mark), so a + /// test sees adaptation fired even after the link recovered and the rung + /// stepped back up. `None` for an unknown track. + fn congestion_max_step(&self, track_id: &str) -> Option { + lock(&self.tracks) + .get(track_id) + .map(|entry| entry.control.max_step()) + } + + /// Close the producer. Idempotent: the first call drops every data channel, + /// every track, and the peer connection, then emits a final + /// `on_state: "closed"`. Dropping the channels/tracks/PC releases the + /// handler-held flush/pump senders, so the flusher and pump tasks end. + fn close(&self) -> PyResult<()> { + if emit_closed_once(&self.closed, &self.events) { + // Stop the feed: dropping the ingress sender ends the feed thread's + // blocking_recv. Join it so no thread (and, once it drops its encoder + // clones, no ffmpeg subprocess) outlives the producer. + *lock(&self.frame_tx) = None; + if let Some(handle) = lock(&self.feed_handle).take() { + let _ = handle.join(); + } + // Drop encoders first so their ffmpeg subprocesses are killed before + // the tracks they send on go away. + lock(&self.encoders).clear(); + lock(&self.channels).clear(); + // Deregister each sys track's feedback controller and clear its + // message callback before the PC drop frees the tracks, so no chain + // callback races teardown with a stale registry entry. + { + let tracks = lock(&self.tracks); + let mut fb = lock(&PRODUCER_FB); + for entry in tracks.values() { + fb.remove(&entry.raw_id); + // SAFETY: raw_id is this producer's track, still alive until + // the PC is dropped just below. + unsafe { sys::rtcSetMessageCallback(entry.raw_id, None) }; + } + } + lock(&self.tracks).clear(); + *lock(&self.pc) = None; + } + Ok(()) + } +} + +impl Producer { + /// Spawn the task that drains a channel's pre-open send buffer once it opens. + /// For the control channel it also re-sends the current manifest so a freshly + /// connected consumer gets the up-to-date stream set. + fn spawn_flusher( + channels: Channels, + manifest: Arc>, + mut flush_rx: mpsc::UnboundedReceiver, + ) { + runtime().spawn(async move { + while let Some(label) = flush_rx.recv().await { + { + let mut map = lock(&channels); + if let Some(entry) = map.get_mut(&label) { + entry.open = true; + entry.flush(); + } + } + if label == CONTROL_LABEL { + let json = lock(&manifest).to_json(); + let mut map = lock(&channels); + if let Some(entry) = map.get_mut(CONTROL_LABEL) { + entry.send(json.into_bytes()); + } + } + } + }); + } + + /// Spawn the single-writer negotiation pump. It is woken by `pump_tx` from + /// both the Python track methods (a mutation was enqueued) and the handler + /// (a cycle returned to Stable), and applies at most one mutation per cycle. + #[allow(clippy::too_many_arguments)] + fn spawn_pump( + pc: Arc>>>>, + tracks: Tracks, + encoders: Encoders, + manifest: Arc>, + channels: Channels, + pending_open: Arc>>, + neg: Arc>, + events: EventQueue, + mut pump_rx: mpsc::UnboundedReceiver<()>, + ) { + runtime().spawn(async move { + let negotiator = ProducerNegotiator { + pc, + tracks, + encoders, + manifest, + channels, + pending_open, + events, + }; + while pump_rx.recv().await.is_some() { + pump_step(&neg, &negotiator); + } + }); + } +} + +/// The track facts the encoder feed needs, cloned out of the [`TrackEntry`] so it +/// can build a send closure and read the adaptation state without holding the +/// tracks lock across an ffmpeg restart. +struct TrackMeta { + raw_id: i32, + open: OpenFlag, + control: Arc, + ts_queue: Arc>>, +} + +/// The encoder feed's per-track state, local to the feed thread: which ladder +/// rung is currently encoded, the token-bucket pacer that enforces the rung's +/// input fps cap, and the pre-open stash. +#[derive(Default)] +struct FeedState { + applied_step: Option, + /// Token-bucket pacer: `allowance` frames are available to send, refilled at + /// `fps_cap` per second. A simple "skip if too soon" gate would beat against + /// the source's fixed frame grid (e.g. a 45 fps source through a 33 fps gate + /// only lands on 22 ms boundaries, quantising to ~22.7 fps); the bucket hits + /// the target rate regardless of the input cadence. + allowance: f64, + last_tick: Option, + stash: VecDeque, + /// Bounded crash-restart budget for this track's encoder. Reset whenever a + /// healthy live encoder is observed, so only repeated back-to-back crashes + /// exhaust it. + restart: RestartPolicy, +} + +impl FeedState { + /// Stash a pre-open (or pre-encoder) frame, keeping the earliest + /// [`PREOPEN_STASH_FRAMES`] so the IDR-first prefix survives and memory stays + /// bounded if the track never opens. + fn stash(&mut self, frame: Frame) { + if self.stash.len() < PREOPEN_STASH_FRAMES { + self.stash.push_back(frame); + } + } +} + +/// Spawn the producer's encoder feed on a dedicated OS thread (off the tokio +/// pool, because it makes blocking ffmpeg-stdin writes). It drains the bounded +/// ingress queue, holds frames until the track's renegotiation completes, then +/// encodes them — restarting the per-track ffmpeg encoder whenever the congestion +/// estimator moves the ladder rung (a coarser bitrate/resolution) or a PLI is +/// pending, and capping the input fps to the rung's floor. The blocking write +/// propagates back-pressure to the bounded ingress queue, which is the single +/// place steady-state overload sheds frames (drop-on-full in `submit_frame`). +fn spawn_feed( + mut frame_rx: mpsc::Receiver, + encoders: Encoders, + tracks: Tracks, + events: EventQueue, +) -> JoinHandle<()> { + std::thread::Builder::new() + .name("ncwebrtc-feed".into()) + .spawn(move || { + let adapt_disabled = std::env::var_os("NCD_WEBRTC_DISABLE_ADAPT").is_some(); + let mut feeds: HashMap = HashMap::new(); + while let Some(frame) = frame_rx.blocking_recv() { + let track_id = frame.track_id.clone(); + // The track must be registered and its renegotiation complete + // before we encode/send; until then, stash and wait. + let Some(meta) = track_meta(&tracks, &track_id) else { + feeds.entry(track_id).or_default().stash(frame); + continue; + }; + if !is_open(&meta.open) { + feeds.entry(track_id).or_default().stash(frame); + continue; + } + let feed = feeds.entry(track_id.clone()).or_default(); + + // (Re)build the encoder when the rung changes, a PLI is pending + // (restart = clean keyframe), or none exists yet. `NCD_WEBRTC_ + // DISABLE_ADAPT` pins the finest rung so the constrained-link + // gate can prove it fails *without* adaptation (the estimator + // still observes, but the feed ignores it). + let desired = if adapt_disabled { + TOP_STEP + } else { + meta.control.desired_step() + }; + let pli = meta.control.take_pli(); + + // Crash detection: a crashed ffmpeg encoder (its stdout hit EOF) + // must be restarted, not silently stalled. Surface an on_error with + // ffmpeg's stderr tail, resync the capture-timestamp queue (the dead + // encoder emitted none of its queued stamps), and rebuild — within a + // bounded budget so a permanently-broken ffmpeg surfaces a terminal + // error instead of hot-looping. + let dead = lock(&encoders) + .get(&track_id) + .map(|e| !e.is_alive()) + .unwrap_or(false); + if dead { + let detail = lock(&encoders) + .get(&track_id) + .map(|e| e.stderr_tail()) + .unwrap_or_default(); + lock(&encoders).remove(&track_id); + lock(&meta.ts_queue).clear(); + feed.applied_step = None; + if feed.restart.should_restart() { + events.push(Event::error( + "encode", + format!( + "encoder for {track_id:?} crashed; restarting (ffmpeg: {})", + last_stderr_line(&detail) + ), + )); + } else { + events.push(Event::error( + "encode", + format!( + "encoder for {track_id:?} crashed and exceeded the restart \ + budget; dropping frames (ffmpeg: {})", + last_stderr_line(&detail) + ), + )); + feed.stash(frame); + continue; + } + } else if lock(&encoders).contains_key(&track_id) { + // A healthy live encoder: clear the consecutive-crash budget so + // only repeated back-to-back crashes ever exhaust it. + feed.restart.reset(); + } + + let missing = !lock(&encoders).contains_key(&track_id); + if feed.applied_step != Some(desired) || pli || missing { + match make_encoder(frame.width, frame.height, desired, &meta, events.clone()) { + Some(encoder) => { + // Dropping the previous encoder kills its ffmpeg. + lock(&encoders).insert(track_id.clone(), encoder); + feed.applied_step = Some(desired); + } + None => { + // Spawn failed; surface once per budget so a missing + // ffmpeg does not emit an unbounded error stream. + if !dead && feed.restart.should_restart() { + events.push(Event::error( + "encode", + format!("could not spawn ffmpeg encoder for {track_id:?}"), + )); + } + feed.stash(frame); + continue; + } + } + } + + let Some(encoder) = lock(&encoders).get(&track_id).cloned() else { + feed.stash(frame); + continue; + }; + // Flush the pre-open stash IDR-first (unpaced: the startup burst). + // Queue each frame's capture timestamp so the encoder callback + // stamps the matching output access unit with it. + for held in feed.stash.drain(..) { + push_capture_ts(&meta.ts_queue, held.capture_ts); + encoder.write_frame(&held.data); + } + // Input fps cap via a token bucket: refill `allowance` at the + // rung's fps_cap, spend one token per frame, drop when empty. This + // sheds toward the floor without beating against the source grid. + let fps_cap = LADDER[desired].fps_cap.max(1) as f64; + let now = Instant::now(); + if let Some(last) = feed.last_tick { + feed.allowance += now.duration_since(last).as_secs_f64() * fps_cap; + } else { + feed.allowance = 1.0; // first frame after a (re)start may send + } + feed.last_tick = Some(now); + // Cap the bucket at ~1s of frames so a stall does not let a burst + // through afterwards. + if feed.allowance > fps_cap { + feed.allowance = fps_cap; + } + if feed.allowance < 1.0 { + continue; // over the cap -> drop this frame + } + feed.allowance -= 1.0; + push_capture_ts(&meta.ts_queue, frame.capture_ts); + encoder.write_frame(&frame.data); + } + }) + .expect("spawn encoder feed thread") +} + +/// The last non-empty line of an ffmpeg stderr tail, trimmed — the concise +/// diagnostic put in a crash `on_error`. Empty when stderr captured nothing. +pub(crate) fn last_stderr_line(tail: &str) -> String { + tail.lines() + .rev() + .map(str::trim) + .find(|l| !l.is_empty()) + .unwrap_or("") + .to_string() +} + +/// Clone the feed-relevant facts out of a track's registry entry. +fn track_meta(tracks: &Tracks, track_id: &str) -> Option { + lock(tracks).get(track_id).map(|entry| TrackMeta { + raw_id: entry.raw_id, + open: entry.open.clone(), + control: entry.control.clone(), + ts_queue: entry.ts_queue.clone(), + }) +} + +/// The encoder's frame rate at a ladder rung: the rung's fps cap, but never above +/// the synthetic source's nominal 45 fps, so libx264's CBR bit budget matches the +/// frames it actually receives. +fn rung_encoder_fps(step: usize) -> u32 { + LADDER[step].fps_cap.min(45) +} + +/// Build a fresh per-track ffmpeg encoder for `step`. Its per-access-unit callback +/// frames the encoded NAL units as Annex-B and sends them on the sys track via +/// `rtcSendMessage` — the attached built-in packetizer does the RTP. Each access +/// unit advances the track's 90 kHz RTP timestamp (pushed via +/// `rtcSetTrackRtpTimestamp`) so a frame's packets share a timestamp across +/// restarts. Returns `None` if ffmpeg could not be spawned. +fn make_encoder( + width: u32, + height: u32, + step: usize, + meta: &TrackMeta, + events: EventQueue, +) -> Option> { + let params = EncodeParams { + fps: rung_encoder_fps(step), + bitrate: LADDER[step].bitrate, + scale: LADDER[step].scale, + }; + let raw_id = meta.raw_id; + let open = meta.open.clone(); + let ts_queue = meta.ts_queue.clone(); + let on_access_unit = move |access_unit: Vec>| { + if !is_open(&open) { + return; + } + // Invariant: the producer sends one access unit as exactly one RTP frame + // under one capture timestamp, so the access unit must carry exactly one + // VCL NAL. A slicing or NAL-aggregation change that emits more than one + // would desync the capture-timestamp queue and fabricate per-slice + // timestamps — precisely the Chrome-only, loopback-invisible defect in + // reports/SPIKE-chrome-pframe.md. Fail loud and drop the frame rather than + // send a malformed one. (x264 `threads=1` keeps this true; see + // media::x264_params.) + let vcl_count = vcl_nal_count(&access_unit); + if vcl_count != 1 { + eprintln!( + "[ncwebrtc] INVARIANT VIOLATED: access unit has {vcl_count} VCL NAL(s), \ + expected exactly 1 (one slice per frame). Dropping rather than \ + fabricating RTP timestamps — see reports/SPIKE-chrome-pframe.md." + ); + return; + } + let buf = annexb_access_unit(&access_unit); + if buf.is_empty() { + return; + } + // Stamp this access unit with the matching input frame's *capture* time. + // The queue is pushed once per frame written to the encoder and popped + // once per emitted access unit (in-order baseline, so the head is this + // frame's), so a healthy 1:1 pipeline never underflows. An underflow means + // more access units than input frames (a multi-slice encode the VCL guard + // above did not catch) — fail loud and drop rather than fabricate a + // backward-running timestamp (the exact regression the spike found). No + // silent fallback: a capture-time RTP clock is what a strict receiver + // (Chrome's jitter buffer) needs to assemble inter-keyframe frames. + let Some(ts) = lock(&ts_queue).pop_front() else { + eprintln!( + "[ncwebrtc] INVARIANT VIOLATED: capture-timestamp queue underflow \ + (more access units than input frames — multi-slice encode?). \ + Dropping rather than fabricating a timestamp — see \ + reports/SPIKE-chrome-pframe.md." + ); + return; + }; + // SAFETY: `raw_id` is a live sys track id this producer created and owns + // until remove/close; both calls are libdatachannel-internally locked. + let sent = unsafe { + sys::rtcSetTrackRtpTimestamp(raw_id, ts); + sys::rtcSendMessage(raw_id, buf.as_ptr() as *const c_char, buf.len() as c_int) + }; + if sent < 0 { + // The consumer's track went away under us (closed / SRTP torn down). + // Stop sending (close the gate so this does not spam) and surface once; + // the Failed connection state separately drives reconnect-needed. + open.store(false, Ordering::SeqCst); + events.push(Event::error( + "send", + "send on a closed track; suppressing further sends", + )); + } + }; + match H264Encoder::new(width, height, params, on_access_unit) { + Ok(encoder) => Some(Arc::new(encoder)), + Err(err) => { + crate::transport::debug_trace("P", &format!("encoder spawn failed: {err}")); + None + } + } +} + +/// Build and attach libdatachannel's built-in H.264 send chain to a sys track id: +/// the packetizer (FU-A / sequence / marker / SSRC) plus the SR reporter, NACK +/// responder, PLI handler and REMB handler. The packetizer init's `cname` **must** +/// be non-null or `rtcSetH264Packetizer` returns -1 (and the SR reporter then +/// fails) — [`packetizer_cname`] guarantees that. The cname is copied into the +/// config during the call, so it only needs to outlive `rtcSetH264Packetizer`. +pub(crate) fn attach_producer_chain(raw_id: i32, ssrc: u32) -> Result<(), String> { + let cname = packetizer_cname(); + let init = sys::rtcPacketizerInit { + ssrc, + cname: cname.as_ptr(), + payloadType: VIDEO_PAYLOAD_TYPE as u8, + clockRate: VIDEO_CLOCK_HZ, + sequenceNumber: 0, + timestamp: 0, + maxFragmentSize: MAX_FRAGMENT_SIZE, + nalSeparator: sys::rtcNalUnitSeparator_RTC_NAL_SEPARATOR_LONG_START_SEQUENCE, + obuPacketization: sys::rtcObuPacketization_RTC_OBU_PACKETIZED_OBU, + playoutDelayId: 0, + playoutDelayMin: 0, + playoutDelayMax: 0, + }; + // SAFETY: `raw_id` is a freshly created sys track id; `init.cname` is a + // non-null C string that lives until the end of this function. + unsafe { + if sys::rtcSetH264Packetizer(raw_id, &init) < 0 { + return Err("rtcSetH264Packetizer failed".into()); + } + if sys::rtcChainRtcpSrReporter(raw_id) < 0 { + return Err("rtcChainRtcpSrReporter failed".into()); + } + if sys::rtcChainRtcpNackResponder(raw_id, NACK_HISTORY) < 0 { + return Err("rtcChainRtcpNackResponder failed".into()); + } + if sys::rtcChainPliHandler(raw_id, Some(on_pli_cb)) < 0 { + return Err("rtcChainPliHandler failed".into()); + } + if sys::rtcChainRembHandler(raw_id, Some(on_remb_cb)) < 0 { + return Err("rtcChainRembHandler failed".into()); + } + } + drop(cname); // copied by rtcSetH264Packetizer; nothing else holds the pointer + Ok(()) +} + +/// The side-effecting half of one negotiation cycle: apply a single track +/// mutation against the real transport (add/drop the track, drive the offer, +/// update the registry and manifest). Abstracted behind a trait so the queue's +/// *control logic* in [`pump_step`] — one cycle in flight, in-order draining, +/// error recovery — can be driven by a fake in unit tests without a live +/// `PeerConnection`. Production is [`ProducerNegotiator`]. +pub(crate) trait NegotiationApply { + fn apply(&self, mutation: Mutation) -> Result<(), String>; +} + +/// Production negotiator: applies a mutation against the real peer connection +/// (the `pc`/`tracks`/`manifest`/`channels` the pump previously took directly). +struct ProducerNegotiator { + pc: Arc>>>>, + tracks: Tracks, + encoders: Encoders, + manifest: Arc>, + channels: Channels, + pending_open: Arc>>, + events: EventQueue, +} + +impl NegotiationApply for ProducerNegotiator { + fn apply(&self, mutation: Mutation) -> Result<(), String> { + let result = apply_mutation( + &self.pc, + &self.tracks, + &self.encoders, + &self.manifest, + &self.channels, + &self.pending_open, + mutation, + ); + // Surface a negotiation failure (chain-attach / SDP / add-track error) on + // the event queue rather than only tracing it; the queue's control logic + // still clears the in-flight gate and drains on, so one bad mutation never + // wedges the pump. + if let Err(err) = &result { + self.events + .push(Event::error("negotiate", format!("track mutation failed: {err}"))); + } + result + } +} + +/// Republish the current manifest over the control channel, if it exists. +/// Buffers (via the pre-open gate) until the control channel is open. +pub(crate) fn republish(channels: &Channels, manifest: &Arc>) { + let json = lock(manifest).to_json(); + let mut map = lock(channels); + if let Some(entry) = map.get_mut(CONTROL_LABEL) { + entry.send(json.into_bytes()); + } +} + +/// Advance the negotiation queue. Starts at most one offer/answer cycle: if a +/// cycle is already in flight, returns immediately (the Stable callback will wake +/// the pump again). On an apply error the gate is cleared and the next mutation is +/// attempted, so one bad mutation never wedges the queue. +/// +/// Generic over [`NegotiationApply`] so the control logic here is exercised by a +/// fake in unit tests; production passes [`ProducerNegotiator`]. +/// +/// Lock discipline: the negotiation lock is only ever held to pop a mutation and +/// is released before the negotiator runs (which takes the peer-connection lock). +/// The signaling callback takes the negotiation lock while libdatachannel holds +/// the PC, so taking them in the opposite order here would deadlock — hence the +/// release-before-apply. +pub(crate) fn pump_step(neg: &Arc>, negotiator: &impl NegotiationApply) { + loop { + let mutation = { + let mut state = lock(neg); + if state.in_flight { + return; + } + match state.pending.pop_front() { + Some(mutation) => { + state.in_flight = true; + mutation + } + None => return, + } + }; + + match negotiator.apply(mutation) { + Ok(()) => return, // cycle in flight; the Stable callback resumes us + Err(err) => { + crate::transport::debug_trace("P", &format!("mutation failed: {err}")); + lock(neg).in_flight = false; + // Try the next mutation rather than wedging on a bad one. + } + } + } +} + +/// Apply one track mutation: add or remove a media m-line and trigger the offer. +/// Never holds the tracks/manifest lock across the PC lock. +#[allow(clippy::too_many_arguments)] +fn apply_mutation( + pc: &Arc>>>>, + tracks: &Tracks, + encoders: &Encoders, + manifest: &Arc>, + channels: &Channels, + pending_open: &Arc>>, + mutation: Mutation, +) -> Result<(), String> { + match mutation { + Mutation::Add { + track_id, + mid, + ssrc, + } => { + let open = open_flag(); + let control = Arc::new(TrackControl::default()); + let ts_queue = Arc::new(Mutex::new(VecDeque::new())); + // Create the track through the sys layer so we keep its raw id and can + // attach the built-in chain to it (datachannel-rs `add_track_ex` + // swallows the id). `rtcAddTrackEx` does NOT auto-offer, so we drive + // the offer ourselves — same as PR3. + let raw_id = { + let mut guard = lock(pc); + let pc = guard.as_mut().ok_or("producer is closed")?; + let pc_id = raw_pc_id(pc).ok_or("cannot recover pc id for rtcAddTrackEx")?; + let mid_c = CString::new(mid.clone()).map_err(|e| e.to_string())?; + let track_c = CString::new(track_id.clone()).map_err(|e| e.to_string())?; + let init = sys::rtcTrackInit { + direction: sys::rtcDirection_RTC_DIRECTION_SENDONLY, + codec: sys::rtcCodec_RTC_CODEC_H264, + payloadType: VIDEO_PAYLOAD_TYPE, + ssrc, + mid: mid_c.as_ptr(), + name: std::ptr::null(), + msid: std::ptr::null(), + trackId: track_c.as_ptr(), + profile: std::ptr::null(), + }; + // SAFETY: `pc_id` is this live PC's id; the CString pointers live + // until the end of this block. + let raw_id = unsafe { sys::rtcAddTrackEx(pc_id, &init) }; + if raw_id < 0 { + return Err(format!("rtcAddTrackEx failed: {raw_id}")); + } + // From here the track exists in libdatachannel; any later failure + // must tear it down (deregister the controller, clear the callback, + // delete the track) so a mid-setup error leaves no PRODUCER_FB entry + // or callback leaked. Run the fallible setup, then clean up on Err. + let setup = (|| -> Result<(), String> { + attach_producer_chain(raw_id, ssrc)?; + // Register the feedback controller, then route the chain's RR + // (raw inbound RTCP) to it via the track's message callback. + lock(&PRODUCER_FB).insert( + raw_id, + Arc::new(CongestionController::new(ssrc, control.clone())), + ); + // SAFETY: `raw_id` is the just-created track id. + unsafe { sys::rtcSetMessageCallback(raw_id, Some(on_rtcp_cb)) }; + pc.set_local_description(SdpType::Offer) + .map_err(|e| e.to_string())?; + Ok(()) + })(); + if let Err(err) = setup { + teardown_sys_track(raw_id); + return Err(err); + } + raw_id + }; + lock(tracks).insert( + track_id.clone(), + TrackEntry { + mid: mid.clone(), + raw_id, + open: open.clone(), + control, + ts_queue, + }, + ); + // Arm the open flag to flip when this offer's answer is applied (the + // Stable callback). That is the point the consumer is ready, so the + // encoder feed may start sending without losing the first IDR. + *lock(pending_open) = Some(open); + lock(manifest).upsert_video_track(&mid, &track_id); + republish(channels, manifest); + Ok(()) + } + Mutation::Remove { track_id } => { + let entry = lock(tracks).remove(&track_id); + let Some(entry) = entry else { + // Unknown track id: nothing to negotiate, treat as a no-op so the + // queue keeps draining. + return Ok(()); + }; + let mid = entry.mid.clone(); + let raw_id = entry.raw_id; + drop(entry); + // Reap the encoder (kills its ffmpeg), then deregister the controller, + // clear the callback, and delete the sys track via the shared teardown — + // the same path the close and mid-setup-error paths use. + lock(encoders).remove(&track_id); + teardown_sys_track(raw_id); + { + let mut guard = lock(pc); + let pc = guard.as_mut().ok_or("producer is closed")?; + pc.set_local_description(SdpType::Offer) + .map_err(|e| e.to_string())?; + } + lock(manifest).remove_entry(&mid); + republish(channels, manifest); + Ok(()) + } + } +} + +#[cfg(test)] +mod tests { + //! Peer-free unit tests for the producer's frame ingress and the negotiation + //! queue's control logic. The queue is driven through a fake + //! [`NegotiationApply`] and frame ingress through a fake [`FrameBytes`], so + //! neither needs a live `PeerConnection`, a socket, or the GIL. + + use super::*; + use std::collections::HashSet; + use tokio::sync::mpsc::error::TrySendError; + + // --- frame ingress: the contiguity gate ---------------------------------- + + struct FakeFrame { + contiguous: bool, + shape: Vec, + bytes: Vec, + } + + impl FrameBytes for FakeFrame { + fn is_contiguous(&self) -> bool { + self.contiguous + } + fn shape(&self) -> Vec { + self.shape.clone() + } + fn to_owned_bytes(&self) -> Vec { + self.bytes.clone() + } + } + + #[test] + fn read_frame_rejects_a_non_contiguous_buffer() { + let frame = FakeFrame { + contiguous: false, + shape: vec![2, 2, 3], + bytes: vec![0; 12], + }; + // submit_frame maps this exact error onto the ValueError callers see. + assert_eq!(read_frame(&frame), Err(FrameError::NotContiguous)); + } + + #[test] + fn read_frame_rejects_a_non_hwc3_shape() { + // Wrong dimensionality (a flat buffer) and wrong channel count both fail + // the shape gate before any encoder sees them. + let flat = FakeFrame { + contiguous: true, + shape: vec![12], + bytes: vec![0; 12], + }; + assert_eq!(read_frame(&flat), Err(FrameError::BadShape)); + let four_channel = FakeFrame { + contiguous: true, + shape: vec![2, 2, 4], + bytes: vec![0; 16], + }; + assert_eq!(read_frame(&four_channel), Err(FrameError::BadShape)); + } + + #[test] + fn read_frame_extracts_dimensions_and_copies_a_contiguous_buffer() { + let frame = FakeFrame { + contiguous: true, + shape: vec![480, 640, 3], + bytes: vec![9, 8, 7], + }; + assert_eq!( + read_frame(&frame), + Ok(FrameData { + data: vec![9, 8, 7], + width: 640, + height: 480, + }) + ); + } + + // --- frame ingress: the bounded queue ------------------------------------ + + fn frame(tag: u8) -> Frame { + Frame { + track_id: "cam0".to_string(), + data: vec![tag], + width: 640, + height: 480, + capture_ts: tag as u32, + } + } + + #[test] + fn frame_queue_capacity_is_sixteen() { + assert_eq!(FRAME_QUEUE_CAPACITY, 16); + } + + #[test] + fn submit_drops_on_overflow_and_never_blocks() { + // The producer's bounded channel, exactly as submit_frame builds it. + // `try_send` is the non-blocking enqueue: it returns immediately whether + // or not the queue is full. + let (tx, mut rx) = mpsc::channel::(FRAME_QUEUE_CAPACITY); + for i in 0..FRAME_QUEUE_CAPACITY { + assert!(tx.try_send(frame(i as u8)).is_ok(), "frame {i} should fit"); + } + // Full queue: the next enqueue is dropped (Full), never blocked. + match tx.try_send(frame(0xff)) { + Err(TrySendError::Full(_)) => {} + other => panic!("expected a Full drop on overflow, got {other:?}"), + } + // Draining one slot admits exactly one more (bounded depth N, FIFO). + assert!(rx.try_recv().is_ok()); + assert!(tx.try_send(frame(0x01)).is_ok()); + } + + // --- negotiation queue: control logic behind a fake ---------------------- + + /// Records every mutation it applies, in order. A track id listed in `fail` + /// returns an error, to drive the error-recovery path. + #[derive(Default)] + struct FakeNegotiator { + applied: Mutex>, + fail: Mutex>, + } + + impl FakeNegotiator { + fn applied_ids(&self) -> Vec { + self.applied + .lock() + .unwrap() + .iter() + .map(|m| m.track_id().to_string()) + .collect() + } + } + + impl NegotiationApply for FakeNegotiator { + fn apply(&self, mutation: Mutation) -> Result<(), String> { + if self.fail.lock().unwrap().contains(mutation.track_id()) { + return Err(format!("forced failure for {}", mutation.track_id())); + } + self.applied.lock().unwrap().push(mutation); + Ok(()) + } + } + + fn queue_with(mutations: Vec) -> Arc> { + let neg = Arc::new(Mutex::new(NegState::default())); + lock(&neg).pending.extend(mutations); + neg + } + + /// Models `on_signaling_state_change(Stable)`: the in-flight offer's answer + /// has been applied, so clear the gate. (The real callback also pings the + /// pump; here the test drives `pump_step` directly.) + fn complete_cycle(neg: &Arc>) { + lock(neg).in_flight = false; + } + + fn add(track_id: &str) -> Mutation { + Mutation::Add { + track_id: track_id.to_string(), + mid: format!("v_{track_id}"), + ssrc: 1, + } + } + fn remove(track_id: &str) -> Mutation { + Mutation::Remove { + track_id: track_id.to_string(), + } + } + + #[test] + fn applies_at_most_one_mutation_per_cycle_and_holds_the_rest() { + let neg = queue_with(vec![add("a"), add("b"), remove("c")]); + let negotiator = FakeNegotiator::default(); + + // One pump starts exactly one cycle; the rest stay queued (serialized, + // not batched — the PR3 contract). + pump_step(&neg, &negotiator); + assert_eq!(negotiator.applied_ids(), vec!["a"]); + assert!(lock(&neg).in_flight, "a cycle must be in flight"); + assert_eq!(lock(&neg).pending.len(), 2); + } + + #[test] + fn no_mutation_is_applied_while_a_cycle_is_in_flight() { + let neg = queue_with(vec![add("a"), add("b")]); + let negotiator = FakeNegotiator::default(); + + pump_step(&neg, &negotiator); // applies "a"; in_flight = true + // Further pumps are no-ops until the cycle completes. + pump_step(&neg, &negotiator); + pump_step(&neg, &negotiator); + assert_eq!(negotiator.applied_ids(), vec!["a"]); + } + + #[test] + fn queued_mutations_apply_in_order_after_each_completion() { + let neg = queue_with(vec![add("a"), add("b"), remove("c")]); + let negotiator = FakeNegotiator::default(); + + pump_step(&neg, &negotiator); // a + complete_cycle(&neg); + pump_step(&neg, &negotiator); // b + complete_cycle(&neg); + pump_step(&neg, &negotiator); // c + + assert_eq!(negotiator.applied_ids(), vec!["a", "b", "c"]); + } + + #[test] + fn the_queue_drains_to_a_converged_final_state() { + let mutations = vec![add("a"), add("b"), remove("a"), add("c")]; + let neg = queue_with(mutations.clone()); + let negotiator = FakeNegotiator::default(); + + for _ in 0..mutations.len() { + pump_step(&neg, &negotiator); + complete_cycle(&neg); + } + // Every mutation applied exactly once, in order; queue empty, not in flight. + assert_eq!(*negotiator.applied.lock().unwrap(), mutations); + assert!(lock(&neg).pending.is_empty()); + assert!(!lock(&neg).in_flight); + } + + #[test] + fn a_failed_mutation_clears_the_gate_and_the_queue_keeps_draining() { + let neg = queue_with(vec![add("bad"), add("good")]); + let negotiator = FakeNegotiator::default(); + negotiator.fail.lock().unwrap().insert("bad".to_string()); + + // One pump: "bad" fails (gate cleared, the loop continues), then "good" + // applies and leaves its cycle in flight — one bad mutation never wedges + // the queue. + pump_step(&neg, &negotiator); + assert_eq!(negotiator.applied_ids(), vec!["good"]); + assert!(lock(&neg).in_flight); + assert!(lock(&neg).pending.is_empty()); + } + + // --- registry hygiene: deregistration leaves no leak --------------------- + + #[test] + fn deregister_feedback_removes_only_the_keyed_track() { + // The producer-feedback registry is process-global, so assert on the + // specific ids this test owns (parallel-test-safe) rather than emptiness. + // Two tracks are registered (an add); deregistering one — the remove or a + // mid-setup-failure cleanup — leaves no entry for it and does not disturb + // the other. + let id_a = 0x07_70_00_01; + let id_b = 0x07_70_00_02; + let ctrl = || Arc::new(CongestionController::new(1, Arc::new(TrackControl::default()))); + lock(&PRODUCER_FB).insert(id_a, ctrl()); + lock(&PRODUCER_FB).insert(id_b, ctrl()); + + deregister_feedback(id_a); + assert!(!lock(&PRODUCER_FB).contains_key(&id_a), "added-then-removed leaves no entry"); + assert!(lock(&PRODUCER_FB).contains_key(&id_b), "the other track is untouched"); + + // Simulate the mid-setup-failure cleanup path (register, then fail, then + // deregister): the partially-set-up track also leaves no leaked entry. + deregister_feedback(id_b); + assert!(!lock(&PRODUCER_FB).contains_key(&id_b), "setup-failure cleanup leaves no entry"); + } + + // --- the capture-timestamp queue stays bounded --------------------------- + + #[test] + fn push_capture_ts_caps_the_queue_and_keeps_the_newest() { + let queue = Mutex::new(VecDeque::new()); + for ts in 0..(TS_QUEUE_CAP as u32 + 50) { + push_capture_ts(&queue, ts); + } + let q = lock(&queue); + assert_eq!(q.len(), TS_QUEUE_CAP, "the queue never grows past the cap"); + // The oldest stamps were shed; the most recent are retained. + assert_eq!(*q.back().unwrap(), TS_QUEUE_CAP as u32 + 49); + assert_eq!(*q.front().unwrap(), 50); + } + + // --- crash diagnostic extraction ----------------------------------------- + + #[test] + fn last_stderr_line_picks_the_last_non_empty_line() { + assert_eq!( + last_stderr_line("Input #0\n[libx264] fatal: out of memory\n\n"), + "[libx264] fatal: out of memory" + ); + assert_eq!(last_stderr_line(""), ""); + assert_eq!(last_stderr_line(" \n \n"), ""); + } + + // --- the caller-owned mid is used verbatim (no per-producer v-counter) ----- + + #[test] + fn track_mid_is_the_supplied_track_id_verbatim() { + // The single source of truth for the video m-line mid: the caller-supplied + // track_id, used verbatim. The old per-producer / per-link "v{n}" counter is + // gone, so the value a caller registers for identity (available_robots) and + // the offer's a=mid are the one and same producer-owned value. A regression + // that reintroduced a counter would fail here. + assert_eq!(track_mid("wrist_cam"), "wrist_cam"); + assert_eq!(track_mid("v0"), "v0"); + assert_eq!(track_mid("0"), "0"); + } +} diff --git a/rust/neuracore_webrtc/src/runtime.rs b/rust/neuracore_webrtc/src/runtime.rs new file mode 100644 index 000000000..e240c91d4 --- /dev/null +++ b/rust/neuracore_webrtc/src/runtime.rs @@ -0,0 +1,42 @@ +//! The process-global tokio runtime owned and run by the Rust WebRTC core. +//! +//! The Python boundary is deliberately synchronous (see the crate docs): the +//! tokio runtime is created and driven entirely by Rust on its own worker +//! threads, and Python never holds, polls, or awaits anything on it. Every +//! `#[pymethods]` entry point is a plain blocking call that either touches a +//! thread-safe queue or hands work to a task `spawn`ed onto this runtime. +//! +//! The runtime is a lazily-initialised singleton so the first `Producer` or +//! `Consumer` constructed in the process stands it up, and it then lives for +//! the lifetime of the process. There is no shutdown hook: dropping the last +//! handle does not tear the runtime down, which keeps the close path simple and +//! avoids racing a half-torn-down runtime against in-flight tasks. + +use once_cell::sync::Lazy; +use tokio::runtime::Runtime; + +/// Number of tokio worker threads the core runs on. Two is plenty for the +/// PR0 scaffolding (one stub frame-drain task per producer); PR2+ may revisit +/// this once the real transport tasks land. +const WORKER_THREADS: usize = 2; + +static RUNTIME: Lazy = Lazy::new(|| { + tokio::runtime::Builder::new_multi_thread() + .worker_threads(WORKER_THREADS) + .thread_name("neuracore-webrtc") + .enable_all() + .build() + .expect("failed to build the neuracore-webrtc tokio runtime") +}); + +/// Return the shared, Rust-owned runtime, initialising it on first use. +pub(crate) fn runtime() -> &'static Runtime { + &RUNTIME +} + +/// Ensure the runtime is stood up. Called from the `Producer`/`Consumer` +/// constructors so the runtime is owned by Rust the moment a peer exists, even +/// if the peer never spawns a task itself (the answer-only `Consumer`). +pub(crate) fn ensure_started() { + Lazy::force(&RUNTIME); +} diff --git a/rust/neuracore_webrtc/src/transport.rs b/rust/neuracore_webrtc/src/transport.rs new file mode 100644 index 000000000..f8e8d2269 --- /dev/null +++ b/rust/neuracore_webrtc/src/transport.rs @@ -0,0 +1,606 @@ +//! Shared transport plumbing for the producer and consumer peers. +//! +//! Both peers wrap a libdatachannel [`RtcPeerConnection`] (via datachannel-rs) +//! and translate between its callback surface and the synchronous, queue-backed +//! Python API. This module holds the pieces both sides need: +//! +//! - [`loopback_config`] — the ICE configuration used in-process (host +//! candidates only; no STUN/TURN). +//! - SDP / state translation ([`parse_session`], [`sdp_type_str`], +//! [`connection_state_str`], [`reliability_kind_hint`]). +//! - [`ManifestState`] — the control-channel manifest model (a flat map keyed by +//! data-channel label now, video-track mid later) and its JSON rendering. +//! +//! libdatachannel refuses (`outgoing()` throws) any message sent before a data +//! channel's SCTP stream is open, so the producer buffers outgoing bytes per +//! channel until `on_open` fires and then flushes them in order; that send gate +//! lives in [`crate::producer`] next to the channels it owns. +//! +//! All shared state is interior-mutable behind `Arc`/`Mutex` so the libdatachannel +//! callback threads and the Python-facing methods can touch it concurrently while +//! the peers stay `Send + Sync`. + +use std::collections::BTreeMap; +use std::sync::{Mutex, MutexGuard}; + +use datachannel::{ + ConnectionState, DataChannelHandler, PeerConnectionHandler, Reliability, RtcConfig, + RtcPeerConnection, SdpType, SessionDescription, +}; +use pyo3::exceptions::{PyRuntimeError, PyValueError}; +use pyo3::PyErr; +use serde_json::{json, Map, Value}; + +/// Recover the raw libdatachannel peer-connection id that the sys-level calls +/// (`rtcSetTrackCallback`, `rtcAddTrackEx`) need. datachannel-rs 0.16 surfaces it +/// only as the opaque `PeerConnectionId`, whose inner `i32` is private; its +/// derived `Debug` renders as `PeerConnectionId()`, so we parse the integer +/// out. The alternative is forking the binding; this keeps the workaround +/// contained to one function (used by both peers). If parsing ever fails, the +/// caller disables its sys-level path gracefully. +pub(crate) fn raw_pc_id

(pc: &RtcPeerConnection

) -> Option +where + P: PeerConnectionHandler + Send, + P::DCH: DataChannelHandler + Send, +{ + let dbg = format!("{:?}", pc.id()); + let start = dbg.find('(')? + 1; + let end = dbg.rfind(')')?; + dbg.get(start..end)?.trim().parse().ok() +} + +/// The reserved control-channel label. It carries the manifest and is never +/// itself listed in the manifest nor surfaced to the consumer as a data channel. +pub(crate) const CONTROL_LABEL: &str = "control"; + +/// Stderr trace of a transport event, gated on `NEURACORE_WEBRTC_DEBUG`. Used to +/// time the ICE/DTLS phases when diagnosing connect latency; a no-op otherwise. +pub(crate) fn debug_trace(peer: &str, what: &str) { + if std::env::var_os("NEURACORE_WEBRTC_DEBUG").is_some() { + let now = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .map(|d| d.as_secs_f64()) + .unwrap_or(0.0); + eprintln!("[ncwebrtc {now:.3}] {peer} {what}"); + } +} + +/// Recover a mutex guard even if a holder panicked mid-update. A poisoned lock +/// here only means a callback thread panicked while pushing; keeping delivery +/// alive beats cascading the panic across the FFI boundary. +pub(crate) fn lock(m: &Mutex) -> MutexGuard<'_, T> { + m.lock().unwrap_or_else(|e| e.into_inner()) +} + +/// Map a libdatachannel error into a Python exception. +pub(crate) fn map_err(err: E) -> PyErr { + PyRuntimeError::new_err(format!("datachannel error: {err}")) +} + +/// In-process ICE configuration: no ICE servers (host candidates only) and the +/// ICE agent bound to a single address so exactly one host candidate is +/// gathered. Two peers in the same process then connect over that one address +/// with no STUN/TURN round trip. +/// +/// Binding matters: left unbound, libdatachannel gathers a candidate per local +/// interface and ICE tries the highest-priority pair first. In containers the +/// highest-priority candidate is often an unreachable IPv6 ULA, so the agent +/// stalls a full second on the STUN retransmit before falling back to IPv4 — +/// blowing the connect SLO. Binding to one reachable address removes the dud +/// pair. The address is overridable via `NEURACORE_WEBRTC_BIND_ADDRESS` (default +/// `127.0.0.1`, correct for the in-process peers this PR ships; the production +/// cutover supplies its own RtcConfig with real ICE servers). +/// +/// `force_media_transport` is mandatory from PR3 on **both** peers. libdatachannel +/// only stands up the DTLS-SRTP transport when the initial connection already has +/// media or this flag is set; otherwise a track added by a *later* renegotiation +/// hits `iterateRemoteTracks` with no SRTP transport and the track is errored +/// ("The connection has no media transport" — see libdatachannel +/// `impl/peerconnection.cpp`). Forcing it up front means the first video-track add +/// reuses the existing BUNDLE transport with no second DTLS handshake, so +/// connect-latency is paid once during bootstrap and `connect_ms` does not regress. +pub(crate) fn loopback_config() -> RtcConfig { + let no_servers: [&str; 0] = []; + RtcConfig::new(&no_servers) + .bind_address(&bind_address()) + .force_media_transport() +} + +/// The single ICE bind address used in-process: `NEURACORE_WEBRTC_BIND_ADDRESS` +/// if set, else `127.0.0.1`. Split out of [`loopback_config`] so the selection +/// is unit-testable without building an `RtcConfig` (which would pull in the ICE +/// agent and bind a socket). +pub(crate) fn bind_address() -> String { + std::env::var("NEURACORE_WEBRTC_BIND_ADDRESS").unwrap_or_else(|_| "127.0.0.1".to_string()) +} + +/// The wire `sdp_type` string for an SDP. Mirrors datachannel-rs's private +/// `SdpType::val`, which we cannot call. +pub(crate) fn sdp_type_str(sdp_type: &SdpType) -> &'static str { + match sdp_type { + SdpType::Answer => "answer", + SdpType::Offer => "offer", + SdpType::Pranswer => "pranswer", + SdpType::Rollback => "rollback", + } +} + +/// Whether outgoing offers should be munged for Chrome's stricter SDP parser. +/// Gated by `NCD_WEBRTC_CHROME_SDP` so the libdatachannel-to-libdatachannel +/// loopback path (which parses the bare `a=ssrc` line fine) is untouched. +pub(crate) fn chrome_sdp_enabled() -> bool { + std::env::var_os("NCD_WEBRTC_CHROME_SDP").is_some() +} + +/// Make a libdatachannel offer acceptable to Chrome's stricter SDP parser by +/// giving every bare `a=ssrc:` line a `cname` attribute. +/// +/// libdatachannel emits a bare `a=ssrc:` (no source attribute); Chrome rejects +/// it ("a=ssrc Expects 2 fields") and never sets up the receive track, so the +/// producer's H.264 never reaches the decoder. RFC 5576 requires the SSRC carry at +/// least a `cname`. We append `cname:` (the packetizer's RTCP CNAME, so the +/// SDP and the RTCP SR agree) to any bare `a=ssrc:` line, leaving an already- +/// qualified line (`a=ssrc: ...`) untouched and idempotent. Pure so it is +/// unit-testable without a peer; applied only when [`chrome_sdp_enabled`]. +pub(crate) fn munge_ssrc_cname(sdp: &str, cname: &str) -> String { + // Preserve the original line endings: SDP is CRLF on the wire, but the munge + // must not rewrite an `\n`-only document into CRLF (or vice versa). + let mut out = String::with_capacity(sdp.len() + 32); + let mut rest = sdp; + while let Some(nl) = rest.find('\n') { + let (line_with_cr, tail) = rest.split_at(nl + 1); + out.push_str(&munge_ssrc_line(line_with_cr, cname)); + rest = tail; + } + if !rest.is_empty() { + out.push_str(&munge_ssrc_line(rest, cname)); + } + out +} + +/// Munge a single SDP line (which may carry a trailing `\r\n`/`\n`). Appends +/// ` cname:` to a bare `a=ssrc:` line; everything else passes through. +fn munge_ssrc_line(line: &str, cname: &str) -> String { + let trimmed = line.trim_end_matches(['\r', '\n']); + let eol = &line[trimmed.len()..]; + let Some(value) = trimmed.strip_prefix("a=ssrc:") else { + return line.to_string(); + }; + // Already qualified (`a=ssrc: cname:...`) -> leave it alone (idempotent). + if value.split_whitespace().count() != 1 { + return line.to_string(); + } + format!("a=ssrc:{value} cname:{cname}{eol}") +} + +// --------------------------------------------------------------------------- +// Reconnect / ICE-restart decision (a pure seam) +// --------------------------------------------------------------------------- + +/// Whether this binding can perform an ICE restart. **It cannot.** libdatachannel +/// builds ICE on libjuice, whose agent is single-shot: it cannot regather with new +/// credentials, so there is no `restart_ice`/`setLocalDescription({iceRestart})` +/// path (upstream libjuice #130; recorded in `reports/PR8-*` from the spike line). +/// On a Disconnected/Failed connection the peer therefore surfaces a clear +/// reconnect-needed signal and the application removes and re-adds the peer (for a +/// broadcaster, the single failed consumer) rather than attempting an in-place +/// restart. This constant documents the binding capability in one place; if a +/// future libdatachannel/libjuice gains real ICE restart, flip it and +/// [`reconnect_action`] starts returning [`ReconnectAction::IceRestart`]. +pub(crate) const ICE_RESTART_SUPPORTED: bool = false; + +/// What to do when a peer connection changes state, given whether the binding can +/// restart ICE. The decision is pure so it is unit-tested behind a fake without a +/// live peer (see the reconnect-decision test). +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub(crate) enum ReconnectAction { + /// A healthy/transitional state: do nothing. + None, + /// Disconnected/Failed and the binding supports ICE restart: attempt a bounded + /// in-place restart. Unreachable while [`ICE_RESTART_SUPPORTED`] is false. + IceRestart, + /// Disconnected/Failed and the binding cannot restart ICE: surface a + /// reconnect-needed error so the app tears down and re-adds the peer/consumer. + SurfaceReconnect, +} + +/// Map a connection state to a reconnect action. `Failed` and `Disconnected` are +/// the only states that need recovery; everything else is `None`. +pub(crate) fn reconnect_action(state: ConnectionState, ice_restart_supported: bool) -> ReconnectAction { + match state { + ConnectionState::Failed | ConnectionState::Disconnected => { + if ice_restart_supported { + ReconnectAction::IceRestart + } else { + ReconnectAction::SurfaceReconnect + } + } + _ => ReconnectAction::None, + } +} + +/// The lowercase wire string for a connection state (the `on_state` payload). +pub(crate) fn connection_state_str(state: &ConnectionState) -> &'static str { + match state { + ConnectionState::New => "new", + ConnectionState::Connecting => "connecting", + ConnectionState::Connected => "connected", + ConnectionState::Disconnected => "disconnected", + ConnectionState::Failed => "failed", + ConnectionState::Closed => "closed", + } +} + +/// A coarse reliability hint surfaced with a newly observed data channel. All +/// channels in this PR are reliable-ordered, but we report what the channel +/// actually negotiated so the hint stays honest if that changes. +pub(crate) fn reliability_kind_hint(reliability: &Reliability) -> String { + if reliability.unreliable { + "unreliable".to_string() + } else if reliability.unordered { + "unordered".to_string() + } else { + "reliable".to_string() + } +} + +/// Parse a wire SDP string into the [`SessionDescription`] datachannel-rs wants +/// for `set_remote_description`. The SDP round-trips through webrtc_sdp's +/// parser, so it is semantically — not byte — faithful, which is fine when both +/// peers are libdatachannel. +pub(crate) fn parse_session(sdp: &str, sdp_type: SdpType) -> Result { + let parsed = datachannel::sdp::parse_sdp(sdp, false) + .map_err(|e| PyValueError::new_err(format!("invalid SDP: {e}")))?; + Ok(SessionDescription { + sdp: parsed, + sdp_type, + }) +} + +/// The control-channel manifest: the producer's published view of the streams a +/// consumer can expect on this connection. +/// +/// It is a flat JSON object keyed by stream identity — data-channel **label** +/// now, video-track **mid** once PR4 adds tracks — so a consumer reads the whole +/// stream set from the object's keys. Each value is a small descriptor object +/// carrying a `type` discriminator plus type-specific fields. The control +/// channel itself is never an entry. See `reports/PR2-data-path.md` for the +/// schema PR4 extends. +#[derive(Default)] +pub(crate) struct ManifestState { + entries: BTreeMap, +} + +impl ManifestState { + /// Insert or replace a data-channel entry keyed by its label. + pub(crate) fn upsert_data_channel(&mut self, label: &str, kind: &str) { + self.entries.insert( + label.to_string(), + json!({ "type": "data_channel", "kind": kind }), + ); + } + + /// Insert or replace a video-track entry keyed by its negotiated `mid`. The + /// shape is the one PR2 reserved and the video test asserts: `mid` is the + /// key, the descriptor carries `type: "video_track"`, the producer-side + /// `track_id`, and the `mid` itself. + pub(crate) fn upsert_video_track(&mut self, mid: &str, track_id: &str) { + self.entries.insert( + mid.to_string(), + json!({ "type": "video_track", "track_id": track_id, "mid": mid }), + ); + } + + /// Remove an entry by its key (a data-channel label or a video-track mid). + pub(crate) fn remove_entry(&mut self, key: &str) { + self.entries.remove(key); + } + + /// Render the current manifest as a JSON object string. + pub(crate) fn to_json(&self) -> String { + let object: Map = self + .entries + .iter() + .map(|(key, value)| (key.clone(), value.clone())) + .collect(); + serde_json::to_string(&Value::Object(object)).unwrap_or_else(|_| "{}".to_string()) + } +} + +#[cfg(test)] +mod tests { + //! Peer-free unit tests for the transport translation and the manifest + //! model. None of these touch a `PeerConnection`, a socket, or the GIL: they + //! pin the pure SDP/state/reliability translation and the control-channel + //! manifest JSON schema deterministically. + + use super::*; + use datachannel::{IceCandidate, Reliability}; + use std::collections::BTreeSet; + + // The SDP tests drive `datachannel::sdp::parse_sdp` directly (the same parser + // `parse_session` wraps) rather than `parse_session` itself: the wrapper + // returns a `PyErr`, and referencing pyo3's runtime from a `cargo test` + // binary would need libpython linked (the crate is an `extension-module`). + // The wrapper is a one-line `map_err`; the semantics under test are the + // parser's round trip, exercised here without the interpreter. + + fn reliability(unordered: bool, unreliable: bool) -> Reliability { + Reliability { + unordered, + unreliable, + max_packet_life_time: 0, + max_retransmits: 0, + } + } + + // --- Chrome a=ssrc cname munge ------------------------------------------- + + #[test] + fn munge_adds_cname_to_a_bare_ssrc_line() { + let sdp = "v=0\r\nm=video 9 UDP/TLS/RTP/SAVPF 96\r\na=ssrc:1\r\n"; + let out = munge_ssrc_cname(sdp, "neuracore"); + assert!(out.contains("a=ssrc:1 cname:neuracore\r\n"), "{out}"); + // Untouched lines pass through unchanged, CRLF preserved. + assert!(out.starts_with("v=0\r\nm=video 9 UDP/TLS/RTP/SAVPF 96\r\n")); + } + + #[test] + fn munge_is_idempotent_and_leaves_qualified_ssrc_lines() { + // An already-qualified ssrc line (Chrome's own, or a second munge pass) is + // left exactly as-is. + let already = "a=ssrc:1 cname:neuracore\r\n"; + assert_eq!(munge_ssrc_cname(already, "neuracore"), already); + let other_attr = "a=ssrc:42 msid:stream track\r\n"; + assert_eq!(munge_ssrc_cname(other_attr, "neuracore"), other_attr); + } + + #[test] + fn munge_preserves_lf_only_documents_and_the_final_unterminated_line() { + // `\n`-only input stays `\n`-only (no spurious CR), and a trailing line + // without an EOL is still munged. + let sdp = "a=ssrc:7\nc=IN IP4 0.0.0.0\na=ssrc:8"; + let out = munge_ssrc_cname(sdp, "nc"); + assert_eq!(out, "a=ssrc:7 cname:nc\nc=IN IP4 0.0.0.0\na=ssrc:8 cname:nc"); + } + + // --- SDP / state / reliability translation ------------------------------- + + #[test] + fn sdp_type_maps_to_wire_strings() { + assert_eq!(sdp_type_str(&SdpType::Offer), "offer"); + assert_eq!(sdp_type_str(&SdpType::Answer), "answer"); + assert_eq!(sdp_type_str(&SdpType::Pranswer), "pranswer"); + assert_eq!(sdp_type_str(&SdpType::Rollback), "rollback"); + } + + #[test] + fn connection_state_maps_to_on_state_strings() { + assert_eq!(connection_state_str(&ConnectionState::New), "new"); + assert_eq!(connection_state_str(&ConnectionState::Connecting), "connecting"); + assert_eq!(connection_state_str(&ConnectionState::Connected), "connected"); + assert_eq!( + connection_state_str(&ConnectionState::Disconnected), + "disconnected" + ); + assert_eq!(connection_state_str(&ConnectionState::Failed), "failed"); + assert_eq!(connection_state_str(&ConnectionState::Closed), "closed"); + } + + // --- reconnect / ICE-restart decision ------------------------------------ + + #[test] + fn reconnect_action_surfaces_when_ice_restart_is_unsupported() { + // The binding cannot restart ICE (libjuice single-shot agent), so a + // failed/disconnected connection surfaces reconnect-needed. + assert_eq!( + reconnect_action(ConnectionState::Failed, false), + ReconnectAction::SurfaceReconnect + ); + assert_eq!( + reconnect_action(ConnectionState::Disconnected, false), + ReconnectAction::SurfaceReconnect + ); + // Healthy/transitional states need no recovery. + for state in [ + ConnectionState::New, + ConnectionState::Connecting, + ConnectionState::Connected, + ConnectionState::Closed, + ] { + assert_eq!(reconnect_action(state, false), ReconnectAction::None); + } + } + + #[test] + fn reconnect_action_would_restart_if_the_binding_supported_it() { + // Behind the fake "supported" flag the same states choose an in-place ICE + // restart — the branch that goes live only if libjuice ever gains it. + assert_eq!( + reconnect_action(ConnectionState::Failed, true), + ReconnectAction::IceRestart + ); + assert_eq!( + reconnect_action(ConnectionState::Disconnected, true), + ReconnectAction::IceRestart + ); + // The shipped binding pins the unsupported path. + assert!(!ICE_RESTART_SUPPORTED); + } + + #[test] + fn reliability_kind_hint_reports_what_was_negotiated() { + assert_eq!(reliability_kind_hint(&reliability(false, false)), "reliable"); + assert_eq!(reliability_kind_hint(&reliability(true, false)), "unordered"); + // unreliable wins over unordered (it is checked first). + assert_eq!( + reliability_kind_hint(&reliability(true, true)), + "unreliable" + ); + } + + // --- SDP round trip ------------------------------------------------------ + + /// A minimal but valid data-channel offer SDP. `parse_sdp(.., false)` is the + /// lenient (non-local) mode the consumer/producer use for remote SDPs. + const SAMPLE_SDP: &str = "v=0\r\n\ +o=- 0 0 IN IP4 127.0.0.1\r\n\ +s=-\r\n\ +c=IN IP4 127.0.0.1\r\n\ +t=0 0\r\n\ +m=application 9 UDP/DTLS/SCTP webrtc-datachannel\r\n\ +a=mid:0\r\n\ +a=sctp-port:5000\r\n"; + + #[test] + fn parse_sdp_round_trip_is_semantically_faithful() { + let parsed = datachannel::sdp::parse_sdp(SAMPLE_SDP, false) + .expect("sample data-channel SDP should parse"); + // This is exactly what parse_session wraps into a SessionDescription; + // the sdp_type comes from the caller, not the wire SDP. + let session = SessionDescription { + sdp: parsed, + sdp_type: SdpType::Offer, + }; + assert_eq!(sdp_type_str(&session.sdp_type), "offer"); + + // Render back out and re-parse: a faithful round trip parses again and + // preserves the application m-line and the origin address. + let rendered = session.sdp.to_string(); + assert!( + rendered.contains("m=application"), + "data m-line lost on round trip: {rendered}" + ); + assert!( + rendered.contains("IN IP4 127.0.0.1"), + "origin address lost on round trip: {rendered}" + ); + datachannel::sdp::parse_sdp(&rendered, false) + .expect("rendered SDP should parse again (idempotent)"); + } + + #[test] + fn parse_sdp_rejects_garbage() { + // parse_session maps this Err into a ValueError for callers; here we pin + // the underlying rejection. + assert!(datachannel::sdp::parse_sdp("not an sdp at all", false).is_err()); + } + + // --- ICE candidate parse/format ----------------------------------------- + + #[test] + fn ice_candidate_carries_candidate_and_mid_faithfully() { + let wire = "candidate:1 1 udp 2113937151 127.0.0.1 54321 typ host"; + let candidate = IceCandidate { + candidate: wire.to_string(), + mid: "0".to_string(), + }; + // Both fields survive construction (the shape on_candidate emits and + // add_remote_candidate consumes). + assert_eq!(candidate.candidate, wire); + assert_eq!(candidate.mid, "0"); + } + + // --- ManifestState schema ------------------------------------------------ + + fn manifest_object(state: &ManifestState) -> Map { + serde_json::from_str::(&state.to_json()) + .expect("manifest renders valid JSON") + .as_object() + .expect("manifest is a flat JSON object") + .clone() + } + + #[test] + fn data_channel_entry_matches_the_fixed_schema() { + let mut state = ManifestState::default(); + state.upsert_data_channel("telemetry", "reliable"); + let object = manifest_object(&state); + // Keyed by label; descriptor carries the type discriminator + kind. + assert_eq!( + object.get("telemetry"), + Some(&json!({ "type": "data_channel", "kind": "reliable" })) + ); + } + + #[test] + fn video_track_entry_matches_the_fixed_schema() { + let mut state = ManifestState::default(); + state.upsert_video_track("v0", "wrist_cam"); + let object = manifest_object(&state); + // Keyed by mid; descriptor carries type, track_id, and the mid itself. + assert_eq!( + object.get("v0"), + Some(&json!({ "type": "video_track", "track_id": "wrist_cam", "mid": "v0" })) + ); + } + + #[test] + fn manifest_is_a_flat_object_with_no_envelope_or_version_key() { + let mut state = ManifestState::default(); + state.upsert_data_channel("telemetry", "reliable"); + state.upsert_video_track("v0", "wrist_cam"); + state.upsert_data_channel("joints", "reliable"); + let object = manifest_object(&state); + assert_eq!( + object.keys().cloned().collect::>(), + BTreeSet::from([ + "joints".to_string(), + "telemetry".to_string(), + "v0".to_string(), + ]) + ); + // No top-level envelope/version key: the keys ARE the stream set. + assert!(!object.contains_key("version")); + assert!(!object.contains_key("streams")); + } + + #[test] + fn control_is_never_a_manifest_entry() { + // Mirror the producer's rule: the control channel carries the manifest + // and is never itself listed in it. + let mut state = ManifestState::default(); + for label in ["control", "telemetry", "joints"] { + if label != CONTROL_LABEL { + state.upsert_data_channel(label, "reliable"); + } + } + assert!(!manifest_object(&state).contains_key(CONTROL_LABEL)); + } + + #[test] + fn remove_entry_drops_only_the_keyed_stream() { + let mut state = ManifestState::default(); + state.upsert_data_channel("telemetry", "reliable"); + state.upsert_video_track("v0", "wrist_cam"); + state.remove_entry("v0"); + let object = manifest_object(&state); + assert!(!object.contains_key("v0")); + assert!(object.contains_key("telemetry")); + } + + #[test] + fn to_json_republishes_the_full_state_every_call() { + // Each render is an atomic full-state message: every entry, every time. + let mut state = ManifestState::default(); + state.upsert_data_channel("a", "reliable"); + assert_eq!(manifest_object(&state).len(), 1); + state.upsert_data_channel("b", "reliable"); + let object = manifest_object(&state); + assert_eq!(object.len(), 2); + assert!(object.contains_key("a") && object.contains_key("b")); + } + + // --- bind-address selection ---------------------------------------------- + + #[test] + fn bind_address_defaults_to_loopback_and_honours_the_override() { + // This is the only test that touches NEURACORE_WEBRTC_BIND_ADDRESS, so + // the set/remove cannot race another test reading the same variable. + std::env::remove_var("NEURACORE_WEBRTC_BIND_ADDRESS"); + assert_eq!(bind_address(), "127.0.0.1"); + std::env::set_var("NEURACORE_WEBRTC_BIND_ADDRESS", "10.1.2.3"); + assert_eq!(bind_address(), "10.1.2.3"); + std::env::remove_var("NEURACORE_WEBRTC_BIND_ADDRESS"); + } +} diff --git a/rust/scripts/build_webrtc_artefact.sh b/rust/scripts/build_webrtc_artefact.sh new file mode 100755 index 000000000..b9308b73e --- /dev/null +++ b/rust/scripts/build_webrtc_artefact.sh @@ -0,0 +1,64 @@ +#!/usr/bin/env bash +# Build ONLY the neuracore_webrtc cdylib and drop it into the Python package tree +# as neuracore/core/streaming/p2p/_native_webrtc.so (renamed from +# libneuracore_webrtc.so so PyO3's PyInit__native_webrtc is import-discoverable). +# +# This is the slice of build_wheel_artefacts.sh the WebRTC stack needs: the +# frontend integration devcontainer wants the NCD_RUST_WEBRTC native module but +# not the data-daemon / producer artefacts, so it avoids building (and pulling the +# deps of) those two crates. The .so is gated at runtime by NCD_RUST_WEBRTC (see +# neuracore/core/streaming/p2p/webrtc_selection.py). +# +# See docs/rust_data_daemon_development.md#packaging-the-wheel for the rationale. + +set -euo pipefail + +script_dir="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +workspace_root="$(cd "$script_dir/.." && pwd)" +repo_root="$(cd "$workspace_root/.." && pwd)" + +webrtc_dst="$repo_root/neuracore/core/streaming/p2p/_native_webrtc.so" + +# Pick the interpreter PyO3 links against so the .so matches the target Python's +# ABI (the extension is not abi3). Caller-supplied PYO3_PYTHON always wins; else +# probe VIRTUAL_ENV / CONDA_PREFIX / system, then fall back to python3 on PATH. +# Mirrors build_wheel_artefacts.sh. +if [[ -z "${PYO3_PYTHON:-}" ]]; then + pyo3_candidates=() + [[ -n "${VIRTUAL_ENV:-}" ]] && pyo3_candidates+=("$VIRTUAL_ENV/bin/python") + [[ -n "${CONDA_PREFIX:-}" ]] && pyo3_candidates+=("$CONDA_PREFIX/bin/python") + pyo3_candidates+=("/usr/bin/python") + for candidate in "${pyo3_candidates[@]}"; do + if [[ -x "$candidate" ]]; then + export PYO3_PYTHON="$candidate" + break + fi + done + if [[ -z "${PYO3_PYTHON:-}" ]]; then + if command -v python3 >/dev/null 2>&1; then + export PYO3_PYTHON + PYO3_PYTHON="$(command -v python3)" + else + echo "error: no python interpreter found; set PYO3_PYTHON or install python3" >&2 + exit 1 + fi + fi +fi +echo "==> building neuracore_webrtc against PYO3_PYTHON=$PYO3_PYTHON" + +# Patch the libdatachannel that datachannel-sys builds to cap its DTLS retransmit +# timer (otherwise every WebRTC connection eats OpenSSL's 1s loopback retransmit; +# see reports/PR2-data-path.md). Idempotent and only forces a rebuild on change. +echo "==> patch libdatachannel (DTLS retransmit timer)" +bash "$script_dir/patch_libdatachannel.sh" + +echo "==> cargo build --release -p neuracore_webrtc" +cargo build --release --manifest-path "$workspace_root/Cargo.toml" -p neuracore_webrtc + +webrtc_src="$workspace_root/target/release/libneuracore_webrtc.so" +if [[ ! -f "$webrtc_src" ]]; then + echo "error: cdylib not found at $webrtc_src (Linux-first; macOS/Windows unsupported)" >&2 + exit 1 +fi +install -m 0755 "$webrtc_src" "$webrtc_dst" +echo " wrote $webrtc_dst" diff --git a/rust/scripts/build_wheel_artefacts.sh b/rust/scripts/build_wheel_artefacts.sh index 3a4e8c126..e2307dc18 100755 --- a/rust/scripts/build_wheel_artefacts.sh +++ b/rust/scripts/build_wheel_artefacts.sh @@ -9,6 +9,11 @@ # 2. The data_daemon_producer cdylib -> neuracore/data_daemon/_native_producer.so # Renamed from libdata_daemon_producer.so so PyO3's PyInit__native_producer # is discoverable by the Python import machinery. +# 3. The neuracore_webrtc cdylib -> +# neuracore/core/streaming/p2p/_native_webrtc.so +# Renamed from libneuracore_webrtc.so so PyO3's PyInit__native_webrtc is +# discoverable. Gated at runtime by NCD_RUST_WEBRTC (see +# neuracore/core/streaming/p2p/webrtc_selection.py). # # See docs/rust_data_daemon_development.md#packaging-the-wheel for the rationale. @@ -22,6 +27,8 @@ package_dir="$repo_root/neuracore/data_daemon" bin_dst="$package_dir/bin/data-daemon" cdylib_dst="$package_dir/_native_producer.so" +webrtc_dst="$repo_root/neuracore/core/streaming/p2p/_native_webrtc.so" + # PyO3's build-config probes (in order) PYO3_PYTHON, VIRTUAL_ENV/bin/python, # CONDA_PREFIX/bin/python, then /usr/bin/python. On minimal Debian/Ubuntu # images only python3 is on PATH and some dev environments set VIRTUAL_ENV to @@ -60,6 +67,15 @@ cargo build --release --manifest-path "$workspace_root/Cargo.toml" -p data-daemo echo "==> cargo build --release -p data_daemon_producer" cargo build --release --manifest-path "$workspace_root/Cargo.toml" -p data_daemon_producer +# Patch the libdatachannel that datachannel-sys builds to cap its DTLS retransmit +# timer (otherwise every WebRTC connection eats OpenSSL's 1s loopback retransmit; +# see reports/PR2-data-path.md). Idempotent and only forces a rebuild on change. +echo "==> patch libdatachannel (DTLS retransmit timer)" +bash "$script_dir/patch_libdatachannel.sh" + +echo "==> cargo build --release -p neuracore_webrtc" +cargo build --release --manifest-path "$workspace_root/Cargo.toml" -p neuracore_webrtc + mkdir -p "$(dirname "$bin_dst")" install -m 0755 "$workspace_root/target/release/data-daemon" "$bin_dst" echo " wrote $bin_dst" @@ -74,3 +90,12 @@ if [[ ! -f "$cdylib_src" ]]; then fi install -m 0755 "$cdylib_src" "$cdylib_dst" echo " wrote $cdylib_dst" + +webrtc_src="$workspace_root/target/release/libneuracore_webrtc.so" +if [[ ! -f "$webrtc_src" ]]; then + echo "error: cdylib not found at $webrtc_src" >&2 + echo " (data-daemon-rewrite.md is Linux-first; macOS/Windows are not supported)" >&2 + exit 1 +fi +install -m 0755 "$webrtc_src" "$webrtc_dst" +echo " wrote $webrtc_dst" diff --git a/rust/scripts/patch_libdatachannel.sh b/rust/scripts/patch_libdatachannel.sh new file mode 100755 index 000000000..cfa032067 --- /dev/null +++ b/rust/scripts/patch_libdatachannel.sh @@ -0,0 +1,91 @@ +#!/usr/bin/env bash +# Idempotently patch the vendored libdatachannel that `datachannel-sys` builds, +# to cap the OpenSSL DTLS handshake retransmit timer. +# +# Why: on fast loopback the DTLS responder's first flight is dropped while its +# ICE transport is momentarily not yet Connected (libdatachannel gates outgoing +# packets on ICE state). OpenSSL's default DTLS retransmit then waits a full +# second before resending, so every WebRTC connection takes ~1006ms and the +# connect-latency SLO (< 500ms p95) is impossible to meet. Capping the timer to +# ~50ms initial resends the dropped flight quickly: connect drops to ~56ms. +# +# This edits the crate source under CARGO_HOME in place (datachannel-sys ships +# libdatachannel inside the published crate; there is no upstream knob and no +# lighter override point). It is idempotent — re-runs detect the marker and skip +# — and only forces a datachannel-sys rebuild when it actually applies the patch. +# +# Invoked by build_wheel_artefacts.sh before building neuracore_webrtc. See +# reports/PR2-data-path.md "DTLS retransmit on loopback". + +set -euo pipefail + +script_dir="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +workspace_root="$(cd "$script_dir/.." && pwd)" +manifest="$workspace_root/Cargo.toml" + +# Make sure the datachannel-sys source is fetched/extracted so we can find it. +cargo fetch --manifest-path "$manifest" >/dev/null 2>&1 || true + +cargo_home="${CARGO_HOME:-$HOME/.cargo}" +target="" +for d in "$cargo_home"/registry/src/*/datachannel-sys-0.23.*; do + candidate="$d/libdatachannel/src/impl/dtlstransport.cpp" + if [[ -f "$candidate" ]]; then + target="$candidate" + break + fi +done + +if [[ -z "$target" ]]; then + echo "warn: datachannel-sys source not found under $cargo_home; skipping DTLS patch" >&2 + exit 0 +fi + +if grep -q "NEURACORE PATCH" "$target"; then + echo "==> libdatachannel DTLS retransmit patch already applied" + exit 0 +fi + +python3 - "$target" <<'PY' +import sys + +path = sys.argv[1] +with open(path, "r", encoding="utf-8") as fh: + src = fh.read() + +anchor = "\t\tSSL_set_ex_data(mSsl, TransportExIndex, this);\n" +if anchor not in src: + sys.exit("error: DTLS patch anchor not found; libdatachannel layout changed") + +block = ( + "\n" + "\t\t// NEURACORE PATCH: cap the DTLS handshake retransmit timer. On fast\n" + "\t\t// loopback the responder's first flight is dropped while its ICE is not\n" + "\t\t// yet Connected (icetransport.cpp gates outgoing on ICE state); OpenSSL's\n" + "\t\t// default 1s initial retransmit then dominates connection latency. Cap it\n" + "\t\t// so a dropped flight is resent in tens of ms, not a full second. See\n" + "\t\t// reports/PR2-data-path.md \"DTLS retransmit on loopback\".\n" + "\t\tDTLS_set_timer_cb(mSsl, [](SSL *, unsigned int timer_us) -> unsigned int {\n" + "\t\t\tunsigned int first = 50000; /* 50 ms initial */\n" + "\t\t\tunsigned int cap = 1000000; /* 1 s backoff cap */\n" + "\t\t\tunsigned int next = (timer_us == 0) ? first : timer_us * 2;\n" + "\t\t\tif (next < first) next = first;\n" + "\t\t\tif (next > cap) next = cap;\n" + "\t\t\treturn next;\n" + "\t\t});\n" +) + +src = src.replace(anchor, anchor + block, 1) +with open(path, "w", encoding="utf-8") as fh: + fh.write(src) +print(f"==> applied libdatachannel DTLS retransmit patch to {path}") +PY + +# Force a datachannel-sys rebuild so the patched C++ is recompiled. cargo's +# fingerprint does not track edits to the crate's bundled C source, so drop the +# build artefacts explicitly. (On a fresh checkout there is nothing to remove.) +rm -rf "$workspace_root"/target/*/.fingerprint/datachannel-sys-* \ + "$workspace_root"/target/*/build/datachannel-sys-* \ + "$workspace_root"/target/*/deps/libdatachannel_sys-* \ + "$workspace_root"/target/*/.fingerprint/neuracore_webrtc-* \ + "$workspace_root"/target/*/libneuracore_webrtc.so 2>/dev/null || true diff --git a/setup.py b/setup.py index f1c681bff..19893cd99 100644 --- a/setup.py +++ b/setup.py @@ -21,9 +21,13 @@ # build (which does not run that script), so their presence is what decides # whether this is a binary wheel. _DATA_DAEMON_DIR = os.path.join(os.path.dirname(__file__), "neuracore", "data_daemon") +_P2P_DIR = os.path.join( + os.path.dirname(__file__), "neuracore", "core", "streaming", "p2p" +) _RUST_ARTEFACTS = ( os.path.join(_DATA_DAEMON_DIR, "bin", "data-daemon"), os.path.join(_DATA_DAEMON_DIR, "_native_producer.so"), + os.path.join(_P2P_DIR, "_native_webrtc.so"), ) _HAS_RUST_ARTEFACTS = all(os.path.exists(path) for path in _RUST_ARTEFACTS) @@ -59,6 +63,12 @@ def has_ext_modules(self) -> bool: "bin/data-daemon", "_native_producer.so", ], + "neuracore.core.streaming.p2p": [ + # Pre-built neuracore_webrtc cdylib (the Rust WebRTC streaming + # core). Generated by the same build script, gitignored, and gated + # at runtime by NCD_RUST_WEBRTC; absent from a plain checkout. + "_native_webrtc.so", + ], }, version=version, author="Stephen James", diff --git a/tests/integration/platform/data_daemon/shared/runners.py b/tests/integration/platform/data_daemon/shared/runners.py index 3cc5c45ff..ba67ddefb 100644 --- a/tests/integration/platform/data_daemon/shared/runners.py +++ b/tests/integration/platform/data_daemon/shared/runners.py @@ -75,7 +75,6 @@ def offline_daemon_running() -> Generator[None, None, None]: """ with scoped_daemon_storage_env(), scoped_offline_profile(): try: - stop_daemon() assert_daemon_cleanup() ensure_daemon_running(timeout_s=DEFAULT_DAEMON_STARTUP_TIMEOUT_SECONDS) with Timer(MAX_TIME_TO_START_S, label="nc.login", always_log=True): diff --git a/tests/integration/webrtc/__init__.py b/tests/integration/webrtc/__init__.py new file mode 100644 index 000000000..b49e81908 --- /dev/null +++ b/tests/integration/webrtc/__init__.py @@ -0,0 +1,13 @@ +"""In-process integration suite for the Rust WebRTC streaming core. + +Two `neuracore` native peers — a `Producer` (sole offerer) and a `Consumer` +(answer-only) — are wired together over an in-process signaling relay (no +separate signaling server). The suite drives asynchronous add/remove of data +and video streams with renegotiation across three tests: behavioural +correctness, data integrity, and performance. + +Written red-first against the PR0 stubs: every assertion that depends on real +WebRTC behaviour is marked ``xfail(strict=True)`` and grouped by the PR that +greens it. Later PRs flip a slice to green by deleting its marker only — never +by editing an assertion. See [markers.py](shared/markers.py) for the map. +""" diff --git a/tests/integration/webrtc/chrome_interop.py b/tests/integration/webrtc/chrome_interop.py new file mode 100644 index 000000000..19acea34b --- /dev/null +++ b/tests/integration/webrtc/chrome_interop.py @@ -0,0 +1,317 @@ +"""Chrome interop decision gate: our producer -> real Google Chrome, under netem. + +The fast loopback suite validates protocol mechanics on the libdatachannel<-> +libdatachannel path (the 1% path). The real consumer is Chrome (the 99% path), +and the whole REMB-driven-adaptation choice (`reports/SPIKE-pr5-media-chain.md`) +rests on Chrome-plus-REMB holding the live-preview SLOs under constraint. This +harness is that decision gate: it drives **installed Google Chrome** (Playwright +channel "chrome", *not* open-source Chromium, which often lacks the H.264 +decoder) as a recvonly WebRTC peer, with our stack as the sole offerer. + +Shape: + * The producer offers a sendonly H.264 video track (goog-remb + nack, no + transport-cc), so Chrome runs its own receive-side bandwidth estimator and + sends REMB back toward the producer. + * Signaling is bridged in-process: the producer's drained events + (`on_local_description` offer, `on_local_candidate`) are fed to the browser + page, and the browser's answer + ICE candidates are fed back to + `set_remote_answer` / `add_remote_candidate`. + * The whole process runs inside a private netns with a netem-shaped `lo` (same + out-of-band shaping as `netem_runner`), so Chrome and the producer talk over + the constrained loopback. + +It asserts from two sides and prints a JSON verdict: + * Chrome `getStats` inbound-rtp: framesPerSecond / framesDecoded hold a floor, + freezeCount / totalFreezesDuration stay low, frameHeight reflects any + downscale step, no decode stall. + * The producer's structured ladder: REMB/RR drove a step down under constraint + (max_step > 0). + +Run via the gated test (`test_chrome_interop`) or directly: + unshare -n env NCD_RUST_WEBRTC=1 PYTHONPATH=$PWD \ + python3 -m tests.integration.webrtc.chrome_interop +""" + +from __future__ import annotations + +import json +import os +import subprocess +import sys +import time + +# The receiver page below is embedded JavaScript whose lines exceed the Python +# line-length limit; that is fine for an inline harness page. +# ruff: noqa: E501 + + +RECEIVER_HTML = """ + + + +""" + +DEFAULT_NETEM = "delay 20ms rate 400kbit limit 48" + + +def _chrome_munge(sdp: str) -> str: + """Make a libdatachannel offer acceptable to Chrome's stricter SDP parser. + + libdatachannel emits a bare ``a=ssrc:`` line; Chrome rejects it + ("a=ssrc Expects 2 fields") and requires at least an attribute such as + ``cname``. We append ``cname`` (matching the packetizer's RTCP CNAME) to any + bare ssrc line. This is a real producer-side cutover concern flagged in + reports/PR5-congestion.md; the munge keeps the interop gate honest without + changing the loopback path (libdatachannel parses the bare line fine). + """ + out = [] + for line in sdp.replace("\r\n", "\n").split("\n"): + if line.startswith("a=ssrc:") and " " not in line.strip()[len("a=ssrc:") :]: + ssrc = line.strip()[len("a=ssrc:") :] + out.append(f"a=ssrc:{ssrc} cname:neuracore") + else: + out.append(line) + return "\r\n".join(out) + + +def _sh(cmd: str) -> tuple[int, str]: + proc = subprocess.run(cmd, shell=True, capture_output=True, text=True) # noqa: S602 + return proc.returncode, (proc.stdout + proc.stderr).strip() + + +def _setup_link() -> str: + netem = os.environ.get("NEURACORE_WEBRTC_NETEM", DEFAULT_NETEM) + rc, out = _sh("ip link set lo up") + if rc != 0: + raise RuntimeError(f"could not bring up lo: {out}") + _sh("tc qdisc del dev lo root") + rc, out = _sh(f"tc qdisc add dev lo root netem {netem}") + if rc != 0: + raise RuntimeError(f"could not apply netem '{netem}': {out}") + return netem + + +def _run(seconds: float, settle: float) -> dict: + os.environ.setdefault("NCD_RUST_WEBRTC", "1") + # Turn on the producer's Chrome-only SDP munge (bare a=ssrc -> a=ssrc cname), + # which Chrome's parser requires. Gated so the loopback path stays byte-identical. + os.environ.setdefault("NCD_WEBRTC_CHROME_SDP", "1") + from playwright.sync_api import sync_playwright + + from neuracore.core.streaming.p2p.webrtc_selection import load_native + from tests.integration.webrtc.shared import metrics + from tests.integration.webrtc.shared.frames import encode_frame + + native = load_native() + track_id = "cam0" + producer = native.Producer(connection_id=None, frame_queue_capacity=16) + + with sync_playwright() as p: + browser = p.chromium.launch( + channel="chrome", + headless=True, + args=[ + "--no-sandbox", + "--autoplay-policy=no-user-gesture-required", + "--disable-gpu", + # In an isolated netns Chrome's default mDNS ICE candidate + # obfuscation (.local hostnames) hangs — there is no mDNS + # responder. Expose raw host IPs so ICE gathers the 127.0.0.1 + # candidate directly and the handshake completes. + "--disable-features=WebRtcHideLocalIpsWithMdns", + ], + ) + page = browser.new_page() + page.set_content(RECEIVER_HTML) + + # Playwright's sync API is single-threaded, so the whole bridge runs on + # this thread: pump producer signaling-out into the page and the page's + # ICE candidates back. `_pump_signaling` is called both during the + # handshake and throughout the frame loop (trickle ICE continues). + state = {"answered": False} + + def pump_signaling() -> None: + for event in producer.drain_events(): + kind = event.get("kind") + if kind == "on_local_description" and event.get("sdp_type") == "offer": + # Answer every offer (the data-only bootstrap and the video + # renegotiation) on the same Chrome peer connection. The + # producer already munges the bare a=ssrc line for Chrome + # (NCD_WEBRTC_CHROME_SDP, set below); _chrome_munge stays as an + # idempotent backstop in case the env gate is ever off. + answer = page.evaluate( + "(sdp) => window.handleOffer(sdp)", _chrome_munge(event["sdp"]) + ) + producer.set_remote_answer(answer) + state["answered"] = True + elif kind == "on_local_candidate": + page.evaluate( + "(a) => window.addCand(a.cand, a.mid)", + {"cand": event["candidate"], "mid": event.get("mid")}, + ) + if state["answered"]: + for c in page.evaluate("() => window.takeCands()") or []: + producer.add_remote_candidate( + c.get("candidate", ""), c.get("sdpMid") + ) + + # The producer must offer the video track up front so the first drained + # description carries the m-line Chrome answers. + producer.add_data_channel("control", "reliable") + producer.add_video_track(track_id) + + deadline = time.time() + 10 + while not state["answered"] and time.time() < deadline: + pump_signaling() + time.sleep(0.02) + if not state["answered"]: + browser.close() + raise RuntimeError("offer/answer did not complete with Chrome") + + # Feed frames at the source rate and poll both sides. + period = 1.0 / metrics.SOURCE_FPS + total = int(seconds * metrics.SOURCE_FPS) + samples: list[dict] = [] + steps: list[int] = [] + start = time.perf_counter() + last_poll = 0.0 + for i in range(total): + target = start + i * period + now = time.perf_counter() + if target > now: + time.sleep(target - now) + producer.submit_frame(track_id, encode_frame(i)) + t = time.perf_counter() - start + if t - last_poll >= 0.5: + last_poll = t + pump_signaling() # keep trickling ICE both ways + steps.append(producer.congestion_step(track_id) or 0) + s = page.evaluate("() => window.inboundStats()") + if s: + s["t"] = round(t, 2) + samples.append(s) + + max_step = producer.congestion_max_step(track_id) + + # Steady-state Chrome stats: compare two getStats samples in the tail to + # derive the decode fps Chrome actually sustained after adaptation. + tail = [s for s in samples if s["t"] >= settle] + verdict: dict = {"ok": True, "max_step": max_step, "step_timeline": steps} + if len(tail) >= 2: + first, last = tail[0], tail[-1] + dt = last["t"] - first["t"] + decoded_fps = ( + (last["framesDecoded"] - first["framesDecoded"]) / dt if dt else 0.0 + ) + verdict.update( + tail_decoded_fps=round(decoded_fps, 2), + tail_reported_fps=round(last["framesPerSecond"], 2), + freezeCount=last["freezeCount"], + totalFreezesDuration=round(last["totalFreezesDuration"], 3), + frameWidth=last["frameWidth"], + frameHeight=last["frameHeight"], + framesDecoded=last["framesDecoded"], + keyFramesDecoded=last["keyFramesDecoded"], + framesDropped=last["framesDropped"], + packetsLost=last["packetsLost"], + packetsReceived=last["packetsReceived"], + nackCount=last["nackCount"], + pliCount=last["pliCount"], + bytesReceived=last["bytesReceived"], + ) + else: + verdict.update( + ok=False, error=f"insufficient Chrome stats samples: {len(tail)}" + ) + browser.close() + return verdict + + +def _host_ip() -> str: + """The host's primary (non-loopback) IP, which both the producer ICE agent + and Chrome will gather as a host candidate so they pair on the host network. + """ + rc, out = _sh("ip route get 1.1.1.1") + for tok in out.split(): + if tok == "src": + return out.split("src", 1)[1].split()[0] + return "127.0.0.1" + + +def main() -> int: + seconds = float(os.environ.get("NEURACORE_WEBRTC_CHROME_SECONDS", 24)) + settle = float(os.environ.get("NEURACORE_WEBRTC_CHROME_SETTLE", 12)) + # Host mode (no netns): Chrome's WebRTC stack does not function inside an + # isolated net namespace (ICE gathering hangs), and the host loopback cannot + # be netem-shaped (it is the host's). So the Chrome decision gate runs clean + # on the host to validate the wire interop (Chrome decoding the built-in + # chain's H.264 + REMB exchange); the adaptation-under-constraint proof comes + # from the libdatachannel-consumer netem gate. See reports/PR5-congestion.md. + host_mode = os.environ.get("NEURACORE_WEBRTC_CHROME_HOST", "1") not in ("0", "") + try: + if host_mode: + os.environ.setdefault("NEURACORE_WEBRTC_BIND_ADDRESS", _host_ip()) + netem = "none (host mode: clean loopback interop validation)" + else: + netem = _setup_link() + result = _run(seconds, settle) + result["netem"] = netem + result["host_mode"] = host_mode + except Exception as exc: # noqa: BLE001 + result = {"ok": False, "error": f"{type(exc).__name__}: {exc}"} + print(json.dumps(result), flush=True) + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/tests/integration/webrtc/conftest.py b/tests/integration/webrtc/conftest.py new file mode 100644 index 000000000..955dcc9f6 --- /dev/null +++ b/tests/integration/webrtc/conftest.py @@ -0,0 +1,125 @@ +"""Fixtures for the in-process WebRTC integration suite. + +Selects the Rust WebRTC stack (``NCD_RUST_WEBRTC=1``), loads the native module, +and hands tests a started :class:`Relay` (or a factory for several, for the +multi-consumer case). The factory and the relay are real and never call a +stubbed method, so fixture setup always succeeds against the PR0 stubs — only +the test bodies go red. +""" + +from __future__ import annotations + +import os +from collections.abc import Callable, Iterator + +import pytest + +# The suite exercises the Rust stack; select it before the native module loads. +os.environ.setdefault("NCD_RUST_WEBRTC", "1") + +from neuracore.core.streaming.p2p.webrtc_selection import load_native # noqa: E402 +from tests.integration.webrtc.shared.harness import BroadcastRelay, Relay # noqa: E402 +from tests.integration.webrtc.shared.metrics import Metrics, emit # noqa: E402 +from tests.integration.webrtc.shared.server_transport import ( # noqa: E402 + ServerBroadcastRelay, + ServerRelay, + signaling_config_from_env, +) + + +@pytest.fixture(scope="session") +def native() -> object: + """The compiled ``_native_webrtc`` module, or skip if it is not built.""" + try: + return load_native() + except Exception as exc: # noqa: BLE001 - RuntimeError hint from the loader + pytest.skip(f"native webrtc module unavailable: {exc}") + + +RelayFactory = Callable[..., Relay] + + +@pytest.fixture +def make_relay(native: object) -> Iterator[RelayFactory]: + """Factory creating started producer<->consumer relays, closed on teardown.""" + relays: list[Relay] = [] + + # When the operator sets the signaling env, the same native peers connect + # through the real backend; otherwise the in-process relay is used unchanged. + config = signaling_config_from_env() + + def _make(*, frame_queue_capacity: int = 16, name: str = "relay") -> Relay: + producer = native.Producer( + connection_id=None, frame_queue_capacity=frame_queue_capacity + ) + consumer = native.Consumer(connection_id=None) + if config is not None: + relay: Relay = ServerRelay( + producer, consumer, config=config, name=name + ).start() + else: + relay = Relay(producer, consumer, name=name).start() + relays.append(relay) + return relay + + yield _make + + for relay in relays: + relay.close() + + +@pytest.fixture +def relay(make_relay: RelayFactory) -> Relay: + """A single started producer<->consumer relay.""" + return make_relay() + + +BroadcastFactory = Callable[..., BroadcastRelay] + + +@pytest.fixture +def make_broadcast(native: object) -> Iterator[BroadcastFactory]: + """Factory creating started one-broadcaster<->many-consumers relays. + + The returned relay owns one ``Broadcaster``; ``add_consumer(id)`` builds an + answer-only ``Consumer`` peer and wires its signaling. Closed on teardown. + """ + relays: list[BroadcastRelay] = [] + config = signaling_config_from_env() + + def _make( + *, frame_queue_capacity: int = 16, name: str = "broadcast" + ) -> BroadcastRelay: + broadcaster = native.Broadcaster( + connection_id=None, frame_queue_capacity=frame_queue_capacity + ) + if config is not None: + relay: BroadcastRelay = ServerBroadcastRelay( + broadcaster, config=config, name=name + ).start() + else: + relay = BroadcastRelay(broadcaster, name=name).start() + relays.append(relay) + return relay + + def _add_consumer(relay: BroadcastRelay, consumer_id: str) -> object: + consumer = native.Consumer(connection_id=consumer_id) + relay.add_consumer(consumer_id, consumer) + return consumer + + # Expose the consumer constructor so tests can join consumers without reaching + # for the native module directly. + _make.add_consumer = _add_consumer # type: ignore[attr-defined] + + yield _make + + for relay in relays: + relay.close() + + +@pytest.fixture(scope="session") +def perf_metrics() -> Iterator[Metrics]: + """Shared structured perf output, emitted as JSON at session teardown.""" + metrics = Metrics() + yield metrics + emit(metrics) diff --git a/tests/integration/webrtc/netem_runner.py b/tests/integration/webrtc/netem_runner.py new file mode 100644 index 000000000..4e6fa16b2 --- /dev/null +++ b/tests/integration/webrtc/netem_runner.py @@ -0,0 +1,187 @@ +"""Constrained-link harness body, run inside a private network namespace. + +`test_perf_under_constrained_link` cannot shape the loopback the in-process peers +use: the container is `network_mode: host`, so its `lo` is the host's, and a +long-lived global tokio runtime cannot be moved into a namespace after its +sockets exist. So the test re-execs *this* script under ``unshare -n`` (a fresh +net namespace), and here — already inside that namespace, before any peer is +built — we bring up the namespace's private `lo` and apply a real ``tc netem`` +profile to it. Everything the peers do then traverses the shaped loopback. + +The script runs one producer->consumer relay at the source fps for a fixed +window, then prints a single JSON line of results to stdout for the parent test +to assert against: + + {"delivered_fps", "corrupted", "closed", "max_step", "sent", "ok"} + +Env in: + * ``NEURACORE_WEBRTC_NETEM`` tc netem args (default: a profile that bites) + * ``NEURACORE_WEBRTC_NETEM_SECONDS`` window length (default 12) + * ``NCD_WEBRTC_DISABLE_ADAPT`` if set, the producer pins the finest rung, so + the same constraint should *fail* the floor — the proof the test bites. + +Requires CAP_NET_ADMIN (tc) and that it is already in a private netns (the +parent supplies CAP_SYS_ADMIN via ``unshare -n``). +""" + +# cspell: ignore WEBRTC + +from __future__ import annotations + +import json +import os +import subprocess +import sys +import time + +# Default netem profile: a delay plus a rate low enough that the full-resolution +# top rung overflows the qdisc (drops -> RR loss -> the producer's estimator +# degrades), but a downscaled coarser rung fits. Tuned so the test is red without +# adaptation and green with it. +# +# The rate sits in the wide window between the full-resolution top rung +# (~2.5 Mbit, maxrate-capped) and the half-resolution bottom rung (~0.3 Mbit). +# That separation only exists because the synthetic frames carry high-frequency +# detail (see shared/frames.py): a smooth gradient compresses to almost nothing at +# every resolution, leaving the downscale ladder with no rate lever — the earlier +# `rate 400kbit limit 48` profile only bit because a since-fixed producer bug +# emitted ~7 RTP slices per frame, inflating packet/header rate ~7x (see +# reports/SPIKE-chrome-pframe.md). With one slice per frame and high-frequency +# content the gate is driven by genuine bitrate, not a packetization artefact. +DEFAULT_NETEM = "delay 20ms rate 800kbit limit 64" + + +def _sh(cmd: str) -> tuple[int, str]: + proc = subprocess.run(cmd, shell=True, capture_output=True, text=True) # noqa: S602 + return proc.returncode, (proc.stdout + proc.stderr).strip() + + +def _setup_link() -> str: + """Bring up the namespace's private loopback and shape it with netem.""" + netem = os.environ.get("NEURACORE_WEBRTC_NETEM", DEFAULT_NETEM) + rc, out = _sh("ip link set lo up") + if rc != 0: + raise RuntimeError(f"could not bring up lo in the netns: {out}") + # Clear any inherited qdisc, then apply the profile to loopback. + _sh("tc qdisc del dev lo root") + rc, out = _sh(f"tc qdisc add dev lo root netem {netem}") + if rc != 0: + raise RuntimeError(f"could not apply netem '{netem}': {out}") + return netem + + +def _run_relay(seconds: float) -> dict: + # Imported here so the import cost is paid inside the namespace, after the + # link is shaped. + os.environ.setdefault("NCD_RUST_WEBRTC", "1") + from neuracore.core.streaming.p2p.webrtc_selection import load_native + from tests.integration.webrtc.shared import metrics + from tests.integration.webrtc.shared.frames import ( + decode_frame, + parse_video_frame_event, + ) + from tests.integration.webrtc.shared.harness import ( + Relay, + bootstrap_connection, + collect_video_frames, + decoded_counters, + recv_time, + submit_at_rate, + ) + + # The adaptation is loss-driven, so a constrained link has a transient + # head — packets lost while the estimator settles the ladder — before it + # reaches a steady, fitting rung. The contract ("degrades gracefully, + # delivered fps holds the floor, no corruption") is a *steady-state* + # property, so we measure only the tail after the settle window. + settle = float(os.environ.get("NEURACORE_WEBRTC_NETEM_SETTLE", 12.0)) + + native = load_native() + track_id = "cam0" + relay = Relay( + native.Producer(connection_id=None, frame_queue_capacity=16), + native.Consumer(connection_id=None), + ).start() + try: + relay.producer.add_data_channel("control", "reliable") + relay.producer.add_video_track(track_id) + bootstrap_connection(relay) + + # Poll the ladder rung while submitting so the timeline is visible. + import threading + + steps: list[int] = [] + stop = threading.Event() + + def poll() -> None: + while not stop.is_set(): + s = relay.producer.congestion_step(track_id) + if s is not None: + steps.append(s) + time.sleep(0.5) + + poller = threading.Thread(target=poll, daemon=True) + poller.start() + origin = time.perf_counter() + sent, _ = submit_at_rate( + relay, track_id, fps=metrics.SOURCE_FPS, seconds=seconds + ) + stop.set() + poller.join(timeout=1.0) + time.sleep(0.3) + frames = collect_video_frames(relay, track_id) + + # Whole-window and steady-state (tail) figures. Delivered counts only + # frames whose embedded checksum verifies — a corrupt frame's recovered + # counter is garbage and must not inflate the delivered count. + all_counters, all_corrupted = decoded_counters(frames) + tail_start = origin + settle + tail_window = max(seconds - settle, 1e-9) + tail_ok: set[int] = set() + tail_corrupt = 0 + for event in frames: + rt = recv_time(event) + if rt is None or rt < tail_start: + continue + _, _, array = parse_video_frame_event(event) + counter, ok = decode_frame(array) + if ok: + tail_ok.add(counter) + else: + tail_corrupt += 1 + + max_step = relay.producer.congestion_max_step(track_id) + final_step = relay.producer.congestion_step(track_id) + closed = "closed" in relay.state_sequence("consumer") + return { + "sent": sent, + "window_delivered_fps": len(set(all_counters)) / seconds, + "window_corrupted": len(all_corrupted), + "delivered_fps": len(tail_ok) / tail_window, + "corrupted": tail_corrupt, + "closed": closed, + "max_step": max_step, + "final_step": final_step, + "step_timeline": steps, + "settle": settle, + } + finally: + relay.close() + + +def main() -> int: + seconds = float(os.environ.get("NEURACORE_WEBRTC_NETEM_SECONDS", 24)) + try: + netem = _setup_link() + result = _run_relay(seconds) + result["netem"] = netem + result["adapt_disabled"] = bool(os.environ.get("NCD_WEBRTC_DISABLE_ADAPT")) + result["ok"] = True + except Exception as exc: # noqa: BLE001 - report failure as JSON, not a trace + result = {"ok": False, "error": f"{type(exc).__name__}: {exc}"} + print(json.dumps(result), flush=True) + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/tests/integration/webrtc/shared/__init__.py b/tests/integration/webrtc/shared/__init__.py new file mode 100644 index 000000000..cbe941ffb --- /dev/null +++ b/tests/integration/webrtc/shared/__init__.py @@ -0,0 +1 @@ +"""Shared harness, frame codec, markers, and metrics for the WebRTC suite.""" diff --git a/tests/integration/webrtc/shared/constants.py b/tests/integration/webrtc/shared/constants.py new file mode 100644 index 000000000..7c5fcecd3 --- /dev/null +++ b/tests/integration/webrtc/shared/constants.py @@ -0,0 +1,33 @@ +"""Tuning constants for the WebRTC integration suite. + +Timeouts are deliberately generous — comfortably above the performance SLOs in +[metrics.py](metrics.py) so green runs do not flake — yet irrelevant to the red +runs, where the first stubbed call raises long before any wait elapses. All are +overridable from the environment so CI can shorten or lengthen them. +""" + +from __future__ import annotations + +import os + + +def _f(name: str, default: float) -> float: + return float(os.environ.get(name, default)) + + +# Wait for both peers to report on_state "connected" after the first +# negotiation-triggering call (add_data_channel / add_video_track). +CONNECT_TIMEOUT_S = _f("NEURACORE_WEBRTC_CONNECT_TIMEOUT", 5.0) + +# Wait for a renegotiation to surface a track add/remove at the consumer. +RENEG_TIMEOUT_S = _f("NEURACORE_WEBRTC_RENEG_TIMEOUT", 3.0) + +# Wait for a freshly added data channel to be observed at the consumer. +DC_OPEN_TIMEOUT_S = _f("NEURACORE_WEBRTC_DC_OPEN_TIMEOUT", 3.0) + +# Wait for a known count of data-channel messages to arrive at the consumer. +MESSAGE_TIMEOUT_S = _f("NEURACORE_WEBRTC_MESSAGE_TIMEOUT", 5.0) + +# After the last frame is submitted, how long to keep draining before deciding +# the decoded video stream has gone quiet. +FRAME_SETTLE_TIMEOUT_S = _f("NEURACORE_WEBRTC_FRAME_SETTLE", 5.0) diff --git a/tests/integration/webrtc/shared/frames.py b/tests/integration/webrtc/shared/frames.py new file mode 100644 index 000000000..5bd3765cd --- /dev/null +++ b/tests/integration/webrtc/shared/frames.py @@ -0,0 +1,116 @@ +"""Synthetic video frames with a counter + checksum embedded in the pixels. + +The producer side keeps only bytes + track_id (frame metadata is deferred to +PR5), so order and integrity must travel *inside* the picture. We block-code a +monotonic 32-bit counter plus a 16-bit checksum into a header band of large +solid squares at the top of each frame. Solid blocks survive H.264's chroma +subsampling and deblocking, and decoding samples each block's centre to dodge +edge bleed — so the consumer can recover the counter from a *decoded* frame and +verify it was not corrupted. + +The rest of the frame is a deterministic high-spatial-frequency texture (fine +diagonal stripes) XOR-ed with a per-counter moving gradient, so the encoder has +real, changing content to compress. The high-frequency detail is deliberate: it +makes the encoded bitrate strongly resolution-dependent (fine stripes survive at +full resolution but average toward grey when the congestion ladder downscales), +which is what gives the rungs real bitrate separation. A smooth gradient +compresses to almost nothing at *every* resolution, so the downscale ladder had +no rate lever and the constrained-link netem gate could not distinguish a fat +rung from a thin one (see reports/SPIKE-chrome-pframe.md §constrained-link). +""" + +from __future__ import annotations + +import zlib + +import numpy as np + +WIDTH = 640 +HEIGHT = 480 +CHANNELS = 3 + +# Header band geometry: a grid of solid BLOCK x BLOCK squares, one per bit. +_BLOCK = 40 +_COLS = WIDTH // _BLOCK # 16 blocks across +_HEADER_ROWS = 3 # -> 48 blocks available in the top 120 rows +_BITS = _COLS * _HEADER_ROWS # 48 +_COUNTER_BITS = 32 +_CHECK_BITS = 16 +assert _COUNTER_BITS + _CHECK_BITS == _BITS, "header band must hold counter+checksum" + +_COUNTER_MASK = (1 << _COUNTER_BITS) - 1 +_CHECK_MASK = (1 << _CHECK_BITS) - 1 + + +def _checksum(counter: int) -> int: + """16-bit checksum over the 32-bit counter.""" + return zlib.crc32(counter.to_bytes(4, "big")) & _CHECK_MASK + + +def _payload_bits(counter: int) -> list[int]: + counter &= _COUNTER_MASK + payload = (counter << _CHECK_BITS) | _checksum(counter) + return [(payload >> (_BITS - 1 - i)) & 1 for i in range(_BITS)] + + +def _block_box(idx: int) -> tuple[int, int, int, int]: + row, col = divmod(idx, _COLS) + y0 = row * _BLOCK + x0 = col * _BLOCK + return y0, x0, y0 + _BLOCK, x0 + _BLOCK + + +def encode_frame(counter: int) -> np.ndarray: + """Build a C-contiguous (H, W, 3) uint8 frame carrying ``counter``.""" + # High-frequency body: a static fine-stripe texture (period ~2px, so it + # carries energy near Nyquist and shrinks sharply under downscale) XOR-ed with + # a moving gradient (per-counter, so every frame changes and P-frames carry + # real residual). Deterministic and reproducible. + rows = np.arange(HEIGHT, dtype=np.uint32)[:, None] + cols = np.arange(WIDTH, dtype=np.uint32)[None, :] + stripes = (rows * 127 + cols * 127) % 256 + moving = (rows + cols + counter * 4) & 0xFF + body = (stripes ^ moving).astype(np.uint8) + frame = np.repeat(body[:, :, None], CHANNELS, axis=2) + + # Overwrite the header band with the block-coded counter + checksum. + for idx, bit in enumerate(_payload_bits(counter)): + y0, x0, y1, x1 = _block_box(idx) + frame[y0:y1, x0:x1, :] = 255 if bit else 0 + + return np.ascontiguousarray(frame) + + +def decode_frame(frame: np.ndarray) -> tuple[int, bool]: + """Recover ``(counter, ok)`` from a decoded frame. + + ``ok`` is False when the embedded checksum does not match the recovered + counter, i.e. the header band was corrupted in transit. + """ + gray = frame[..., :CHANNELS].mean(axis=2) if frame.ndim == 3 else frame + margin = _BLOCK // 4 # sample the central half of each block + payload = 0 + for idx in range(_BITS): + y0, x0, y1, x1 = _block_box(idx) + patch = gray[y0 + margin : y1 - margin, x0 + margin : x1 - margin] + payload = (payload << 1) | (1 if patch.mean() >= 127 else 0) + counter = payload >> _CHECK_BITS + check = payload & _CHECK_MASK + return counter, check == _checksum(counter) + + +def parse_video_frame_event(event: dict) -> tuple[str | None, str | None, np.ndarray]: + """Reshape a consumer ``on_frame`` event into ``(track_id, mid, array)``. + + Forward contract introduced by PR5 (extends the PR0 event schema): the + consumer surfaces each decoded frame on its drainable queue as + ``{"kind": "on_frame", "track_id", "mid", "data": bytes, "width", + "height"}``. ``data`` is the decoded picture as 8-bit HxWx3; the block codec + above is colour-order agnostic (black vs white blocks), so RGB or BGR both + decode. + """ + width = int(event.get("width", WIDTH)) + height = int(event.get("height", HEIGHT)) + array = np.frombuffer(event["data"], dtype=np.uint8) + array = array[: width * height * CHANNELS].reshape(height, width, CHANNELS) + return event.get("track_id"), event.get("mid"), array diff --git a/tests/integration/webrtc/shared/harness.py b/tests/integration/webrtc/shared/harness.py new file mode 100644 index 000000000..93ebb529c --- /dev/null +++ b/tests/integration/webrtc/shared/harness.py @@ -0,0 +1,534 @@ +"""In-process signaling relay and event pump for two native WebRTC peers. + +No separate signaling server: a background pump repeatedly drains both peers' +event queues and relays the signaling-out events between them, so the producer +offers and the consumer answers entirely in-process. + +Relay rules (the producer is the sole offerer; the consumer never offers and +never adds tracks): + + * producer offer ``on_local_description`` -> ``consumer.set_remote_offer`` + * consumer answer ``on_local_description`` -> ``producer.set_remote_answer`` + * either peer's ``on_local_candidate`` -> the other peer's + ``add_remote_candidate`` (trickle, both directions) + +Every drained event is also recorded (in arrival order, with a perf-clock +receive timestamp) into a per-peer log that tests assert against via the +``wait_*`` / ``*_events`` helpers. + +Against the PR0 stubs no signaling-out events are emitted, so the pump records +only the initial ``on_state: "new"`` and never invokes a stubbed +``set_remote_*`` / ``add_remote_candidate`` — the harness itself runs clean; only +the test bodies, which call the stubbed producer methods directly, go red. +""" + +from __future__ import annotations + +import threading +import time +from collections.abc import Callable + +from tests.integration.webrtc.shared import constants +from tests.integration.webrtc.shared.frames import ( + decode_frame, + encode_frame, + parse_video_frame_event, +) + +EventPredicate = Callable[[dict], bool] + +# Key the relay stamps onto every recorded event with its perf-clock arrival +# time (same clock as submit timestamps, so glass-to-glass subtracts cleanly). +RECV_TS_KEY = "_recv_perf" + + +def _kind(event: dict) -> str | None: + return event.get("kind") + + +def recv_time(event: dict) -> float | None: + """Perf-clock time at which the relay recorded ``event`` (or None).""" + return event.get(RECV_TS_KEY) + + +class Relay: + """One producer + one consumer joined by an in-process pump thread.""" + + POLL_INTERVAL_S = 0.002 + + def __init__(self, producer: object, consumer: object, *, name: str = "relay"): + self.producer = producer + self.consumer = consumer + self.name = name + self._lock = threading.RLock() + self._events: dict[str, list[dict]] = {"producer": [], "consumer": []} + # Exceptions raised while *relaying* signaling (not test failures). In a + # red run this stays empty (no signaling-out events fire). + self.dispatch_errors: list[BaseException] = [] + self._stop = threading.Event() + self._thread: threading.Thread | None = None + + # --- lifecycle ----------------------------------------------------------- + def start(self) -> Relay: + if self._thread is None: + self._thread = threading.Thread( + target=self._run, name=self.name, daemon=True + ) + self._thread.start() + return self + + def stop(self) -> None: + self._stop.set() + if self._thread is not None: + self._thread.join(timeout=2.0) + self._thread = None + + def close(self) -> None: + self.stop() + for peer in (self.producer, self.consumer): + try: + peer.close() + except Exception: # noqa: BLE001 - teardown best effort + pass + + # --- pump ---------------------------------------------------------------- + def _run(self) -> None: + while not self._stop.is_set(): + try: + self.pump_once() + except Exception as exc: # noqa: BLE001 - never let the pump die + self.dispatch_errors.append(exc) + time.sleep(self.POLL_INTERVAL_S) + + def pump_once(self) -> None: + """Drain both peers once and relay their signaling-out events.""" + for event in self.producer.drain_events(): + self._record("producer", event) + self._relay(event, dst=self.consumer) + for event in self.consumer.drain_events(): + self._record("consumer", event) + self._relay(event, dst=self.producer) + + def _record(self, which: str, event: dict) -> None: + event[RECV_TS_KEY] = time.perf_counter() + with self._lock: + self._events[which].append(event) + + def _relay(self, event: dict, dst: object) -> None: + kind = _kind(event) + try: + if kind == "on_local_description": + if event.get("sdp_type") == "offer": + dst.set_remote_offer(event["sdp"]) + else: + dst.set_remote_answer(event["sdp"]) + elif kind == "on_local_candidate": + dst.add_remote_candidate(event["candidate"], event.get("mid")) + except Exception as exc: # noqa: BLE001 - surfaced via dispatch_errors + self.dispatch_errors.append(exc) + + # --- observation --------------------------------------------------------- + def events(self, which: str) -> list[dict]: + with self._lock: + return list(self._events[which]) + + def producer_events(self) -> list[dict]: + return self.events("producer") + + def consumer_events(self) -> list[dict]: + return self.events("consumer") + + def state_sequence(self, which: str) -> list[str]: + """Ordered ``on_state`` values seen for a peer (for PC-reset checks).""" + return [e["state"] for e in self.events(which) if _kind(e) == "on_state"] + + def wait_for( + self, which: str, predicate: EventPredicate, timeout: float + ) -> dict | None: + """Return the first recorded event matching ``predicate``, or None.""" + deadline = time.monotonic() + timeout + while True: + for event in self.events(which): + if predicate(event): + return event + if time.monotonic() >= deadline: + return None + time.sleep(self.POLL_INTERVAL_S) + + def wait_consumer(self, predicate: EventPredicate, timeout: float) -> dict | None: + return self.wait_for("consumer", predicate, timeout) + + def wait_producer(self, predicate: EventPredicate, timeout: float) -> dict | None: + return self.wait_for("producer", predicate, timeout) + + def wait_connected(self, timeout: float = constants.CONNECT_TIMEOUT_S) -> bool: + """True once *both* peers have reported on_state "connected".""" + + def connected(events: list[dict]) -> bool: + return any( + _kind(e) == "on_state" and e.get("state") == "connected" for e in events + ) + + deadline = time.monotonic() + timeout + while time.monotonic() < deadline: + if connected(self.producer_events()) and connected(self.consumer_events()): + return True + time.sleep(self.POLL_INTERVAL_S) + return False + + # --- data-channel message collection ------------------------------------- + def messages(self, which: str, label: str) -> list[bytes]: + """Ordered payloads of ``on_message`` events for ``label``.""" + return [ + e["data"] + for e in self.events(which) + if _kind(e) == "on_message" and e.get("label") == label + ] + + def wait_messages( + self, which: str, label: str, count: int, timeout: float + ) -> list[bytes]: + """Wait until ``count`` messages for ``label`` arrive; return them.""" + deadline = time.monotonic() + timeout + while True: + got = self.messages(which, label) + if len(got) >= count or time.monotonic() >= deadline: + return got + time.sleep(self.POLL_INTERVAL_S) + + # --- video frame collection ---------------------------------------------- + def video_frames(self, which: str, track_id: str | None = None) -> list[dict]: + """Ordered ``on_frame`` events, optionally filtered by track_id.""" + return [ + e + for e in self.events(which) + if _kind(e) == "on_frame" + and (track_id is None or e.get("track_id") == track_id) + ] + + +class BroadcastRelay: + """One :class:`Broadcaster` joined to N answer-only consumers by a pump thread. + + The broadcaster is the sole offerer to each consumer; its signaling-out events + are tagged with a ``consumer_id`` so the pump routes each to the right consumer + peer, and each consumer's answer/candidates are routed back to the broadcaster + with that id. This is the multi-consumer analogue of :class:`Relay`: one shared + encode per source fans out to every consumer. + + Relay rules: + * broadcaster ``on_local_description{consumer_id, offer}`` -> + ``consumers[consumer_id].set_remote_offer`` + * broadcaster ``on_local_candidate{consumer_id, ...}`` -> + ``consumers[consumer_id].add_remote_candidate`` + * consumer answer ``on_local_description`` -> + ``broadcaster.set_remote_answer(consumer_id, sdp)`` + * consumer ``on_local_candidate`` -> + ``broadcaster.add_remote_candidate(consumer_id, ...)`` + """ + + POLL_INTERVAL_S = 0.002 + + def __init__(self, broadcaster: object, *, name: str = "broadcast"): + self.broadcaster = broadcaster + self.name = name + self.consumers: dict[str, object] = {} + self._lock = threading.RLock() + self._events: dict[str, list[dict]] = {"broadcaster": []} + self.dispatch_errors: list[BaseException] = [] + self._stop = threading.Event() + self._thread: threading.Thread | None = None + + # --- lifecycle ----------------------------------------------------------- + def add_consumer(self, consumer_id: str, consumer: object) -> None: + """Register a consumer peer and add it to the broadcaster (triggers its + offer). The pump routes its signaling from then on.""" + with self._lock: + self.consumers[consumer_id] = consumer + self._events.setdefault(consumer_id, []) + self.broadcaster.add_consumer(consumer_id) + + def remove_consumer(self, consumer_id: str) -> None: + self.broadcaster.remove_consumer(consumer_id) + consumer = self.consumers.pop(consumer_id, None) + if consumer is not None: + try: + consumer.close() + except Exception: # noqa: BLE001 - teardown best effort + pass + + def start(self) -> BroadcastRelay: + if self._thread is None: + self._thread = threading.Thread( + target=self._run, name=self.name, daemon=True + ) + self._thread.start() + return self + + def stop(self) -> None: + self._stop.set() + if self._thread is not None: + self._thread.join(timeout=2.0) + self._thread = None + + def close(self) -> None: + self.stop() + for peer in (self.broadcaster, *self.consumers.values()): + try: + peer.close() + except Exception: # noqa: BLE001 - teardown best effort + pass + + # --- pump ---------------------------------------------------------------- + def _run(self) -> None: + while not self._stop.is_set(): + try: + self.pump_once() + except Exception as exc: # noqa: BLE001 - never let the pump die + self.dispatch_errors.append(exc) + time.sleep(self.POLL_INTERVAL_S) + + def pump_once(self) -> None: + for event in self.broadcaster.drain_events(): + consumer_id = event.get("consumer_id") + self._record("broadcaster", event) + if consumer_id is not None: + self._relay_to_consumer(event, consumer_id) + # Snapshot the consumer set so a mid-pump add/remove does not error here. + with self._lock: + current = dict(self.consumers) + for consumer_id, consumer in current.items(): + for event in consumer.drain_events(): + self._record(consumer_id, event) + self._relay_to_broadcaster(event, consumer_id) + + def _record(self, which: str, event: dict) -> None: + event[RECV_TS_KEY] = time.perf_counter() + with self._lock: + self._events.setdefault(which, []).append(event) + + def _relay_to_consumer(self, event: dict, consumer_id: str) -> None: + consumer = self.consumers.get(consumer_id) + if consumer is None: + return + kind = _kind(event) + try: + if kind == "on_local_description" and event.get("sdp_type") == "offer": + consumer.set_remote_offer(event["sdp"]) + elif kind == "on_local_candidate": + consumer.add_remote_candidate(event["candidate"], event.get("mid")) + except Exception as exc: # noqa: BLE001 + self.dispatch_errors.append(exc) + + def _relay_to_broadcaster(self, event: dict, consumer_id: str) -> None: + kind = _kind(event) + try: + if kind == "on_local_description" and event.get("sdp_type") == "answer": + self.broadcaster.set_remote_answer(consumer_id, event["sdp"]) + elif kind == "on_local_candidate": + self.broadcaster.add_remote_candidate( + consumer_id, event["candidate"], event.get("mid") + ) + except Exception as exc: # noqa: BLE001 + self.dispatch_errors.append(exc) + + # --- observation --------------------------------------------------------- + def events(self, which: str) -> list[dict]: + with self._lock: + return list(self._events.get(which, [])) + + def _has_connected(self, which: str) -> bool: + return any( + _kind(e) == "on_state" and e.get("state") == "connected" + for e in self.events(which) + ) + + def wait_for( + self, which: str, predicate: EventPredicate, timeout: float + ) -> dict | None: + """Return the first recorded event for ``which`` matching ``predicate``.""" + deadline = time.monotonic() + timeout + while True: + for event in self.events(which): + if predicate(event): + return event + if time.monotonic() >= deadline: + return None + time.sleep(self.POLL_INTERVAL_S) + + def wait_consumer_connected( + self, consumer_id: str, timeout: float = constants.CONNECT_TIMEOUT_S + ) -> bool: + """True once both the broadcaster (for this consumer) and the consumer + peer report ``on_state: connected``.""" + + def broadcaster_connected() -> bool: + return any( + e.get("consumer_id") == consumer_id + and _kind(e) == "on_state" + and e.get("state") == "connected" + for e in self.events("broadcaster") + ) + + deadline = time.monotonic() + timeout + while time.monotonic() < deadline: + if broadcaster_connected() and self._has_connected(consumer_id): + return True + time.sleep(self.POLL_INTERVAL_S) + return False + + def video_frames(self, consumer_id: str, track_id: str | None = None) -> list[dict]: + return [ + e + for e in self.events(consumer_id) + if _kind(e) == "on_frame" + and (track_id is None or e.get("track_id") == track_id) + ] + + # --- data-channel message collection ------------------------------------- + def messages(self, consumer_id: str, label: str) -> list[bytes]: + """Ordered payloads of ``on_message`` events for ``label`` at a consumer.""" + return [ + e["data"] + for e in self.events(consumer_id) + if _kind(e) == "on_message" and e.get("label") == label + ] + + def wait_messages( + self, consumer_id: str, label: str, count: int, timeout: float + ) -> list[bytes]: + """Wait until ``count`` messages for ``label`` arrive at ``consumer_id``.""" + deadline = time.monotonic() + timeout + while True: + got = self.messages(consumer_id, label) + if len(got) >= count or time.monotonic() >= deadline: + return got + time.sleep(self.POLL_INTERVAL_S) + + def collect_video_frames( + self, + consumer_id: str, + track_id: str, + *, + settle: float = constants.FRAME_SETTLE_TIMEOUT_S, + quiet: float = 0.5, + ) -> list[dict]: + """Drain until the decoded stream for ``consumer_id``/``track_id`` is quiet.""" + last_change = time.monotonic() + deadline = last_change + settle + seen = len(self.video_frames(consumer_id, track_id)) + while time.monotonic() < deadline: + current = len(self.video_frames(consumer_id, track_id)) + if current != seen: + seen = current + last_change = time.monotonic() + deadline = last_change + settle + elif seen > 0 and time.monotonic() - last_change >= quiet: + break + time.sleep(0.01) + return self.video_frames(consumer_id, track_id) + + def submit_at_rate( + self, track_id: str, *, fps: float, seconds: float, start_counter: int = 0 + ) -> int: + """Submit encoded frames to the broadcaster at ``fps`` for ``seconds`` — + one shared encode, fanned to every consumer.""" + period = 1.0 / fps + total = int(round(fps * seconds)) + origin = time.perf_counter() + for index in range(total): + target = origin + index * period + now = time.perf_counter() + if target > now: + time.sleep(target - now) + self.broadcaster.submit_frame(track_id, encode_frame(index + start_counter)) + return total + + +def bootstrap_connection( + relay: Relay, + *, + control_label: str = "control", + timeout: float = constants.CONNECT_TIMEOUT_S, +) -> None: + """Open a control data channel (triggering the offer) and wait connected. + + Raises if the connection does not establish in ``timeout``. Against the PR0 + stubs ``add_data_channel`` raises ``NotImplementedError`` here, which is the + expected red path. + """ + relay.producer.add_data_channel(control_label, "reliable") + if not relay.wait_connected(timeout): + raise TimeoutError(f"connection did not establish within {timeout}s") + + +def submit_at_rate( + relay: Relay, + track_id: str, + *, + fps: float, + seconds: float, + start_counter: int = 0, +) -> tuple[int, dict[int, float]]: + """Submit encoded frames at ``fps`` for ``seconds``. + + Returns ``(submitted_count, submit_times)`` where ``submit_times`` maps each + counter to its perf-clock submit time (for glass-to-glass measurement). + """ + period = 1.0 / fps + total = int(round(fps * seconds)) + submit_times: dict[int, float] = {} + origin = time.perf_counter() + for index in range(total): + target = origin + index * period + now = time.perf_counter() + if target > now: + time.sleep(target - now) + counter = start_counter + index + submit_times[counter] = time.perf_counter() + relay.producer.submit_frame(track_id, encode_frame(counter)) + return total, submit_times + + +def collect_video_frames( + relay: Relay, + track_id: str, + *, + settle: float = constants.FRAME_SETTLE_TIMEOUT_S, + quiet: float = 0.5, +) -> list[dict]: + """Drain until the decoded video stream for ``track_id`` goes quiet. + + Returns the recorded ``on_frame`` events once no new frame has arrived for + ``quiet`` seconds, or after ``settle`` seconds total. + """ + last_change = time.monotonic() + deadline = last_change + settle + seen = len(relay.video_frames("consumer", track_id)) + while time.monotonic() < deadline: + current = len(relay.video_frames("consumer", track_id)) + if current != seen: + seen = current + last_change = time.monotonic() + deadline = last_change + settle + elif seen > 0 and time.monotonic() - last_change >= quiet: + break + time.sleep(0.01) + return relay.video_frames("consumer", track_id) + + +def decoded_counters(frames: list[dict]) -> tuple[list[int], list[int]]: + """Decode ``on_frame`` events into ``(counters, corrupted_counters)``. + + ``counters`` preserves arrival order; ``corrupted_counters`` holds any whose + embedded checksum failed. + """ + counters: list[int] = [] + corrupted: list[int] = [] + for event in frames: + _, _, array = parse_video_frame_event(event) + counter, ok = decode_frame(array) + counters.append(counter) + if not ok: + corrupted.append(counter) + return counters, corrupted diff --git a/tests/integration/webrtc/shared/markers.py b/tests/integration/webrtc/shared/markers.py new file mode 100644 index 000000000..ec8fc52c2 --- /dev/null +++ b/tests/integration/webrtc/shared/markers.py @@ -0,0 +1,66 @@ +"""Centralised xfail markers grouping each assertion to the PR that greens it. + +Every test in this suite depends on real WebRTC behaviour the PR0 stubs do not +provide: the protocol methods raise ``NotImplementedError`` and the +signaling-out events (``on_local_description`` / ``on_local_candidate``) are +never emitted. Each test is therefore expected to *fail* against the stubs. + +We mark every such test ``xfail(strict=True)`` so that: + + * a test that fails or errors against the stubs is reported ``xfailed`` — the + expected red state for this PR, and + * a test that *passes* is reported ``XPASS`` which, under ``strict``, turns + the run RED. That is the signal to the implementing PR: the behaviour is now + real, so DELETE this marker (never edit the assertion to keep it red). + +Markers are grouped by the PR that makes the slice real. A later PR greens its +slice by removing exactly the marker(s) tagged with its number. Because the PRs +land in order (PR2 < PR3 < ... < PR7), a later-PR test may freely rely on +earlier-PR behaviour being real by the time its own marker is removed (e.g. a +PR4 video test bootstraps a PR2 data channel). +""" + +from __future__ import annotations + +import pytest + +# PR number -> one-line description of the slice that PR turns green. The keys +# are the only valid `pr` arguments to `greened_by`; the report's +# test-name -> marker-group table is generated from this same mapping. +PR_SLICES: dict[str, str] = { + "PR2": ( + "data channel send/recv integrity (zero loss/reorder), data channel " + "add observed at consumer, mid->track manifest, connect & dc-add timing" + ), + "PR3": ( + "rapid data-channel add correctness under coalesced in-flight " + "renegotiation (no channel silently dropped)" + ), + "PR4": ( + "video track add/remove via renegotiation, manifest atomicity, " + "PC-not-reset, rapid video churn, add/remove reneg timing" + ), + "PR5": ( + "video frame integrity (monotonic counters, no corruption), " + "glass-to-glass latency, sustained fps from a 45fps source (single " + "consumer)" + ), + "PR6": "performance under a constrained link (netem-shaped loopback)", + "PR7": "multi-consumer performance (per-consumer SLOs hold with N consumers)", +} + + +def greened_by(pr: str, detail: str) -> pytest.MarkDecorator: + """Return an ``xfail(strict=True)`` marker tagged to the PR that greens it. + + Args: + pr: one of the keys in :data:`PR_SLICES` (e.g. ``"PR4"``). + detail: a one-line description of what this specific test asserts, used + in the xfail reason so a red run is self-documenting. + """ + if pr not in PR_SLICES: + raise KeyError(f"unknown PR slice {pr!r}; expected one of {sorted(PR_SLICES)}") + return pytest.mark.xfail( + strict=True, + reason=f"red until {pr} greens [{PR_SLICES[pr]}] -- {detail}", + ) diff --git a/tests/integration/webrtc/shared/metrics.py b/tests/integration/webrtc/shared/metrics.py new file mode 100644 index 000000000..5567c7bbb --- /dev/null +++ b/tests/integration/webrtc/shared/metrics.py @@ -0,0 +1,78 @@ +"""Performance SLOs, percentile maths, and the structured CI output schema. + +The performance test (Test 3) asserts each timing against the agreed SLO and +records it into a single :class:`Metrics` dataclass. At session teardown the +dataclass is emitted as one JSON object so CI can scrape it. The schema is +fixed here so it is stable across the PRs that progressively fill it in. +""" + +from __future__ import annotations + +import json +import math +import os +import sys +from dataclasses import asdict, dataclass + +# --- Agreed SLOs (the thresholds the perf tests assert against) --------------- +CONNECT_MS_P95 = 500.0 # connection established under 500ms p95 (trickle on) +RENEG_ADD_MS_P95 = 300.0 # add-track renegotiation under 300ms p95 +RENEG_REMOVE_MS_P95 = 300.0 # remove-track renegotiation under 300ms p95 +DC_ADD_MS_P95 = 300.0 # data channel add usable at consumer under 300ms p95 +G2G_P50_MS = 120.0 # glass-to-glass under 120ms p50 +G2G_P95_MS = 200.0 # glass-to-glass under 200ms p95 +MIN_DELIVERED_FPS = 30.0 # delivered fps floor from an over-rate source + +# --- Workload knobs (overridable for CI) -------------------------------------- +SOURCE_FPS = float(os.environ.get("NEURACORE_WEBRTC_SOURCE_FPS", 45)) +AT_RATE_FPS = float(os.environ.get("NEURACORE_WEBRTC_AT_RATE_FPS", 30)) +PERF_DURATION_S = float(os.environ.get("NEURACORE_WEBRTC_PERF_SECONDS", 60)) +PERF_SAMPLES = int(os.environ.get("NEURACORE_WEBRTC_PERF_SAMPLES", 20)) +MULTI_CONSUMER_N = int(os.environ.get("NEURACORE_WEBRTC_CONSUMERS", 3)) + + +@dataclass +class Metrics: + """Structured perf output for CI. ``None`` means "not measured this run". + + Field names are the CI contract — do not rename without updating the report. + """ + + connect_ms: float | None = None + reneg_add_ms: float | None = None + reneg_remove_ms: float | None = None + dc_add_ms: float | None = None + g2g_p50_ms: float | None = None + g2g_p95_ms: float | None = None + delivered_fps: float | None = None + drop_rate: float | None = None + + +def percentile(samples: list[float], pct: float) -> float: + """Linear-interpolated percentile (``pct`` in 0..100).""" + if not samples: + raise ValueError("percentile of empty sample set") + ordered = sorted(samples) + if len(ordered) == 1: + return ordered[0] + rank = (len(ordered) - 1) * (pct / 100.0) + low = math.floor(rank) + high = math.ceil(rank) + if low == high: + return ordered[int(rank)] + return ordered[low] * (high - rank) + ordered[high] * (rank - low) + + +def emit(metrics: Metrics, *, label: str = "neuracore-webrtc-perf") -> str: + """Emit ``metrics`` as one JSON line for CI; return the JSON payload. + + Always prints ``[label] {json}`` to stderr. If ``NEURACORE_WEBRTC_PERF_OUT`` + is set, also writes the JSON to that path. + """ + payload = json.dumps(asdict(metrics), sort_keys=True) + print(f"\n[{label}] {payload}", file=sys.stderr, flush=True) + out_path = os.environ.get("NEURACORE_WEBRTC_PERF_OUT") + if out_path: + with open(out_path, "w", encoding="utf-8") as handle: + handle.write(payload + "\n") + return payload diff --git a/tests/integration/webrtc/shared/server_transport.py b/tests/integration/webrtc/shared/server_transport.py new file mode 100644 index 000000000..9d6881cd7 --- /dev/null +++ b/tests/integration/webrtc/shared/server_transport.py @@ -0,0 +1,663 @@ +"""Server-backed signaling transport for the WebRTC integration suite. + +The default harness joins two native peers with an in-process relay (no signaling +server). This module adds an additive, env-gated alternative that drives the same +native ``Producer`` / ``Broadcaster`` / ``Consumer`` peers through the real +running backend's signaling: the REST submit endpoint plus the SSE notification +stream the deprecated aiortc path and the PR8 web path both use. It validates the +real transport end to end without a browser, as the intermediate gate between the +in-process loopback and the browser Playwright run. + +Design: subclass the relays and swap only the cross-peer hop. The parent +drain-and-record pump is unchanged, so every test assertion and ``wait_*`` helper +keeps working; the only difference is that a peer's drained +``on_local_description`` / ``on_local_candidate`` is mapped to a +``HandshakeMessage`` and POSTed to the backend instead of handed to the other +peer in-process, and inbound messages arrive on per-stream SSE reader threads. + +The producer-side mappings are reused from PR8 +(:func:`outbound_signal` / :func:`inbound_candidate` / :func:`needs_reconnect`), +not duplicated. The consumer-side mirror lives here because the production +consumer is a browser, so it is test-only. + +Wire contract (confirmed against the live backend): + * submit: POST /api/org/{org}/signalling/message/submit (HandshakeMessage) + * subscribe: GET /api/org/{org}/signalling/notifications/{stream_id} (SSE) + * keepalive: POST /api/org/{org}/signalling/alive/{stream_id} (body "pong") + * auth: Authorization: Bearer +A submit whose ``to_id`` is not yet a registered SSE subscriber is silently +dropped with no backend queue, so a peer's SSE stream is opened and confirmed +before any offer is POSTed to it. The 25s inactivity reaper requires answering +SSE heartbeats by POSTing the alive endpoint. Both peers run in this one process, +so only signaling crosses the backend; media is local loopback P2P (no TURN). +""" + +from __future__ import annotations + +import logging +import os +import threading +import time +from collections.abc import Callable, Iterable, Iterator +from dataclasses import dataclass +from uuid import uuid4 + +import requests +from neuracore_types import HandshakeMessage, MessageType + +from neuracore.core.streaming.p2p.provider.native_broadcast_provider import ( + inbound_candidate, + needs_reconnect, + outbound_signal, +) +from tests.integration.webrtc.shared import constants +from tests.integration.webrtc.shared.harness import BroadcastRelay, Relay + +logger = logging.getLogger(__name__) + +# Env that selects and configures the server-backed transport. When any is unset +# the suite falls back to the in-process relay and the server-backed tests skip. +URL_ENV = "NEURACORE_WEBRTC_SIGNALING_URL" +ORG_ENV = "NEURACORE_WEBRTC_SIGNALING_ORG" +TOKEN_ENV = "NEURACORE_WEBRTC_SIGNALING_TOKEN" + +# Placeholder consumer id injected so the 1:1 producer (which emits no +# ``consumer_id``) and the consumer side can reuse the PR8 ``outbound_signal`` +# candidate formatting without duplicating it. The value is never sent on the +# wire; addressing is by stream id. +_PLACEHOLDER_ID = "_" + + +# --------------------------------------------------------------------------- # +# Config +# --------------------------------------------------------------------------- # +@dataclass(frozen=True) +class SignalingConfig: + """Operator-supplied backend coordinates for the server-backed transport.""" + + base_url: str # includes the /api prefix, e.g. http://host:8000/api + org: str + token: str + + +def signaling_config_from_env() -> SignalingConfig | None: + """Build a config from env, or None when the transport is not selected.""" + base_url = os.environ.get(URL_ENV) + org = os.environ.get(ORG_ENV) + token = os.environ.get(TOKEN_ENV) + if not (base_url and org and token): + return None + return SignalingConfig(base_url.rstrip("/"), org, token) + + +# --------------------------------------------------------------------------- # +# Pure mappings (peer-free, unit-tested) +# --------------------------------------------------------------------------- # +def consumer_outbound_signal(event: dict) -> tuple[MessageType, str] | None: + """Map one drained consumer event to a signaling message. + + The consumer is answer-only, so it emits SDP **answers** and ICE candidates. + Candidate formatting is delegated to the PR8 producer mapper (with a + placeholder ``consumer_id``) so the JSON shape is not duplicated. + + Args: + event: a single dict from ``Consumer.drain_events()``. + + Returns: + ``(message_type, data)`` to submit, or ``None`` if not deliverable. + """ + kind = event.get("kind") + if kind == "on_local_description" and event.get("sdp_type") == "answer": + return MessageType.SDP_ANSWER, event["sdp"] + if kind == "on_local_candidate": + signal = outbound_signal({**event, "consumer_id": _PLACEHOLDER_ID}) + if signal is not None: + return signal.message_type, signal.data + return None + + +def parse_sse_lines(lines: Iterable[bytes | str]) -> Iterator[tuple[str, str]]: + """Parse a stream of SSE lines into ``(event_type, data)`` frames. + + Handles the backend's ``event:data`` / ``event:heartbeat`` / ``event:end`` + framing (optional leading space after the colon, comment lines, multi-line + data joined by newline, and the blank-line frame boundary). + """ + event = "message" + data_parts: list[str] = [] + for raw in lines: + line = raw.decode("utf-8") if isinstance(raw, bytes) else raw + if line == "": + if data_parts: + yield event, "\n".join(data_parts) + event = "message" + data_parts = [] + continue + if line.startswith(":"): + continue + field, _, value = line.partition(":") + if value.startswith(" "): + value = value[1:] + if field == "event": + event = value + elif field == "data": + data_parts.append(value) + if data_parts: + yield event, "\n".join(data_parts) + + +# --------------------------------------------------------------------------- # +# Inbound application with trickle buffering +# --------------------------------------------------------------------------- # +class _Inbound: + """Applies inbound signaling to one peer, buffering early candidates. + + Candidates that arrive before the remote description is set are buffered and + flushed immediately after it, mirroring the web path and the aiortc + ``received_offer_event`` / ``received_answer_event`` gate. The peer setters + are injected so this is unit-testable with a fake peer. + """ + + def __init__( + self, + description: Callable[[str], None], + add_candidate: Callable[[str, str | None], None], + ) -> None: + self._description = description + self._add_candidate = add_candidate + self._have_description = False + self._pending: list[tuple[str, str | None]] = [] + self._lock = threading.Lock() + + def description(self, sdp: str) -> None: + """Apply the remote offer/answer, then flush buffered candidates.""" + self._description(sdp) + with self._lock: + self._have_description = True + pending = self._pending + self._pending = [] + for candidate, mid in pending: + self._add_candidate(candidate, mid) + + def candidate(self, candidate: str, mid: str | None) -> None: + """Apply a candidate, or buffer it until the description is set.""" + with self._lock: + if not self._have_description: + self._pending.append((candidate, mid)) + return + self._add_candidate(candidate, mid) + + +# --------------------------------------------------------------------------- # +# HTTP client (POST submit/alive + SSE subscribe) +# --------------------------------------------------------------------------- # +class _Subscription: + """One SSE reader thread for a stream id, with reconnect and heartbeat.""" + + def __init__(self, client: SignalingClient, stream_id: str) -> None: + self._client = client + self._stream_id = stream_id + self._stop = threading.Event() + self._connected = threading.Event() + self._response: requests.Response | None = None + self._thread = threading.Thread( + target=self._run, name=f"sse-{stream_id[:8]}", daemon=True + ) + self._thread.start() + + def wait_connected(self, timeout: float) -> bool: + """True once the SSE GET has established (the stream is registered).""" + return self._connected.wait(timeout) + + def stop(self) -> None: + self._stop.set() + if self._response is not None: + try: + self._response.close() # unblock a blocked iter_lines + except Exception: # noqa: BLE001 - teardown best effort + pass + self._thread.join(timeout=2.0) + + def _run(self) -> None: + backoff = 0.05 + while not self._stop.is_set(): + try: + response = self._client.open_stream(self._stream_id) + self._response = response + self._connected.set() + backoff = 0.05 + for event, data in parse_sse_lines( + response.iter_lines(decode_unicode=True) + ): + if self._stop.is_set(): + break + if event == "data": + self._dispatch(data) + elif event == "heartbeat": + self._client.mark_alive(self._stream_id) + elif event == "end": + break + except Exception as exc: # noqa: BLE001 - reconnect on any error + if not self._stop.is_set(): + logger.warning("sse stream %s error: %s", self._stream_id, exc) + if self._stop.is_set(): + break + time.sleep(backoff) + backoff = min(5.0, backoff * 2) + + def _dispatch(self, data: str) -> None: + try: + message = HandshakeMessage.model_validate_json(data) + except Exception: # noqa: BLE001 - ignore malformed frames + logger.warning("sse stream %s dropped malformed message", self._stream_id) + return + self._client.deliver(self._stream_id, message) + + +class SignalingClient: + """Thin HTTP client over the backend signaling contract. + + POSTs handshake messages, opens SSE subscriptions, answers heartbeats, and + fans inbound messages to the handler registered for each stream id. + """ + + CONNECT_TIMEOUT_S = 10.0 + # Longer than the 20s server heartbeat so a healthy stream never times out + # but a dead one eventually errors and the subscription reconnects. + READ_TIMEOUT_S = 40.0 + + def __init__(self, config: SignalingConfig) -> None: + self._config = config + self._session = requests.Session() + self._handlers: dict[str, Callable[[HandshakeMessage], None]] = {} + self._lock = threading.Lock() + + # --- url + headers ------------------------------------------------------- + def _signalling_base(self) -> str: + return f"{self._config.base_url}/org/{self._config.org}/signalling" + + def _headers(self) -> dict[str, str]: + return {"Authorization": f"Bearer {self._config.token}"} + + # --- outbound ------------------------------------------------------------ + def submit( + self, + from_id: str, + to_id: str, + connection_id: str, + message_type: MessageType, + data: str, + ) -> None: + """POST one handshake message; a fresh id avoids backend LRU de-duplication.""" + message = HandshakeMessage( + from_id=from_id, + to_id=to_id, + connection_id=connection_id, + type=message_type, + data=data, + ) + response = self._session.post( + f"{self._signalling_base()}/message/submit", + headers=self._headers(), + json=message.model_dump(mode="json"), + timeout=self.CONNECT_TIMEOUT_S, + ) + response.raise_for_status() + + def mark_alive(self, stream_id: str) -> None: + """Answer a heartbeat so the inactivity reaper keeps the stream.""" + try: + self._session.post( + f"{self._signalling_base()}/alive/{stream_id}", + headers=self._headers(), + data="pong", + timeout=self.CONNECT_TIMEOUT_S, + ) + except Exception as exc: # noqa: BLE001 - keepalive is best effort + logger.warning("alive ping for %s failed: %s", stream_id, exc) + + # --- inbound ------------------------------------------------------------- + def subscribe( + self, stream_id: str, on_message: Callable[[HandshakeMessage], None] + ) -> _Subscription: + """Register a handler and start an SSE reader thread for ``stream_id``.""" + with self._lock: + self._handlers[stream_id] = on_message + return _Subscription(self, stream_id) + + def open_stream(self, stream_id: str) -> requests.Response: + """Open the SSE GET (used by :class:`_Subscription`).""" + response = self._session.get( + f"{self._signalling_base()}/notifications/{stream_id}", + headers=self._headers(), + stream=True, + timeout=(self.CONNECT_TIMEOUT_S, self.READ_TIMEOUT_S), + ) + response.raise_for_status() + return response + + def deliver(self, stream_id: str, message: HandshakeMessage) -> None: + """Route one inbound message to its stream's handler.""" + with self._lock: + handler = self._handlers.get(stream_id) + if handler is not None: + handler(message) + + def close(self) -> None: + with self._lock: + self._handlers.clear() + try: + self._session.close() + except Exception: # noqa: BLE001 - teardown best effort + pass + + +# --------------------------------------------------------------------------- # +# 1:1 server-backed relay +# --------------------------------------------------------------------------- # +class ServerRelay(Relay): + """A :class:`Relay` that signals over the real backend instead of in-process. + + The parent pump still drains and records both peers; this subclass only + rewrites the cross-peer hop (``_relay``) to POST and adds two SSE reader + threads that apply inbound messages to the local peers. + """ + + def __init__( + self, + producer: object, + consumer: object, + *, + config: SignalingConfig, + name: str = "server-relay", + connect_timeout: float = constants.CONNECT_TIMEOUT_S, + ) -> None: + super().__init__(producer, consumer, name=name) + self._client = SignalingClient(config) + self._connect_timeout = connect_timeout + self._producer_stream_id = uuid4().hex + self._consumer_stream_id = uuid4().hex + self._connection_id = uuid4().hex + self._producer_inbound = _Inbound( + producer.set_remote_answer, producer.add_remote_candidate + ) + self._consumer_inbound = _Inbound( + consumer.set_remote_offer, consumer.add_remote_candidate + ) + self._subs: list[_Subscription] = [] + + def start(self) -> ServerRelay: + # Open and confirm both SSE streams before the pump can POST any offer, + # so neither peer is an unregistered (silently dropped) recipient. + self._subs.append( + self._client.subscribe(self._producer_stream_id, self._on_producer_message) + ) + self._subs.append( + self._client.subscribe(self._consumer_stream_id, self._on_consumer_message) + ) + for sub in self._subs: + sub.wait_connected(self._connect_timeout) + super().start() + return self + + def close(self) -> None: + for sub in self._subs: + sub.stop() + self._client.close() + super().close() + + # --- outbound (override the in-process hop) ------------------------------ + def _relay(self, event: dict, dst: object) -> None: + try: + if dst is self.consumer: + # producer -> consumer: 1:1 producer emits no consumer_id, so + # inject the placeholder to reuse the PR8 producer mapping. + signal = outbound_signal({**event, "consumer_id": _PLACEHOLDER_ID}) + if signal is None: + return + self._client.submit( + self._producer_stream_id, + self._consumer_stream_id, + self._connection_id, + signal.message_type, + signal.data, + ) + else: + mapped = consumer_outbound_signal(event) + if mapped is None: + return + message_type, data = mapped + self._client.submit( + self._consumer_stream_id, + self._producer_stream_id, + self._connection_id, + message_type, + data, + ) + except Exception as exc: # noqa: BLE001 - surfaced via dispatch_errors + self.dispatch_errors.append(exc) + + # --- inbound ------------------------------------------------------------- + def _on_producer_message(self, message: HandshakeMessage) -> None: + if message.type == MessageType.SDP_ANSWER: + self._producer_inbound.description(message.data) + elif message.type == MessageType.ICE_CANDIDATE: + candidate, mid = inbound_candidate(message.data) + self._producer_inbound.candidate(candidate, mid) + + def _on_consumer_message(self, message: HandshakeMessage) -> None: + if message.type == MessageType.SDP_OFFER: + self._consumer_inbound.description(message.data) + elif message.type == MessageType.ICE_CANDIDATE: + candidate, mid = inbound_candidate(message.data) + self._consumer_inbound.candidate(candidate, mid) + + +# --------------------------------------------------------------------------- # +# Multi-consumer server-backed relay +# --------------------------------------------------------------------------- # +class ServerBroadcastRelay(BroadcastRelay): + """A :class:`BroadcastRelay` that signals each consumer over the backend. + + Per-consumer routing maps the broadcaster's internal ``consumer_id`` to a + server ``connection_id`` and stream id so offers, answers, and candidates + reach the right peer. A join opens the consumer's SSE stream before + ``add_consumer`` (so the broadcaster's offer is not dropped); a leave tears + its subscription down; a PR7 ``on_error{where:"connection"}`` is a remove + + re-add over the server with a fresh ``connection_id`` (no ICE restart). + """ + + def __init__( + self, + broadcaster: object, + *, + config: SignalingConfig, + name: str = "server-broadcast", + connect_timeout: float = constants.CONNECT_TIMEOUT_S, + ) -> None: + super().__init__(broadcaster, name=name) + self._client = SignalingClient(config) + self._connect_timeout = connect_timeout + self._broadcaster_stream_id = uuid4().hex + self._consumer_stream: dict[str, str] = {} + self._consumer_conn: dict[str, str] = {} + self._conn_to_consumer: dict[str, str] = {} + self._broadcaster_inbound: dict[str, _Inbound] = {} + self._consumer_subs: dict[str, _Subscription] = {} + self._broadcaster_sub: _Subscription | None = None + + def start(self) -> ServerBroadcastRelay: + self._broadcaster_sub = self._client.subscribe( + self._broadcaster_stream_id, self._on_broadcaster_message + ) + self._broadcaster_sub.wait_connected(self._connect_timeout) + super().start() + return self + + def close(self) -> None: + for sub in self._consumer_subs.values(): + sub.stop() + if self._broadcaster_sub is not None: + self._broadcaster_sub.stop() + self._client.close() + super().close() + + # --- consumer lifecycle -------------------------------------------------- + def add_consumer(self, consumer_id: str, consumer: object) -> None: + stream_id = uuid4().hex + connection_id = uuid4().hex + with self._lock: + self.consumers[consumer_id] = consumer + self._events.setdefault(consumer_id, []) + self._consumer_stream[consumer_id] = stream_id + self._consumer_conn[consumer_id] = connection_id + self._conn_to_consumer[connection_id] = consumer_id + self._broadcaster_inbound[consumer_id] = self._make_broadcaster_inbound( + consumer_id + ) + sub = self._client.subscribe( + stream_id, self._make_consumer_handler(consumer_id, consumer) + ) + self._consumer_subs[consumer_id] = sub + # Confirm the consumer is registered before the broadcaster offers to it. + sub.wait_connected(self._connect_timeout) + self.broadcaster.add_consumer(consumer_id) + + def remove_consumer(self, consumer_id: str) -> None: + self.broadcaster.remove_consumer(consumer_id) + sub = self._consumer_subs.pop(consumer_id, None) + if sub is not None: + sub.stop() + consumer = self.consumers.pop(consumer_id, None) + with self._lock: + self._consumer_stream.pop(consumer_id, None) + conn = self._consumer_conn.pop(consumer_id, None) + if conn is not None: + self._conn_to_consumer.pop(conn, None) + self._broadcaster_inbound.pop(consumer_id, None) + if consumer is not None: + try: + consumer.close() + except Exception: # noqa: BLE001 - teardown best effort + pass + + def _make_broadcaster_inbound(self, consumer_id: str) -> _Inbound: + return _Inbound( + lambda sdp: self.broadcaster.set_remote_answer(consumer_id, sdp), + lambda candidate, mid: self.broadcaster.add_remote_candidate( + consumer_id, candidate, mid + ), + ) + + def _make_consumer_handler( + self, consumer_id: str, consumer: object + ) -> Callable[[HandshakeMessage], None]: + inbound = _Inbound(consumer.set_remote_offer, consumer.add_remote_candidate) + + def handle(message: HandshakeMessage) -> None: + if message.type == MessageType.SDP_OFFER: + inbound.description(message.data) + elif message.type == MessageType.ICE_CANDIDATE: + candidate, mid = inbound_candidate(message.data) + inbound.candidate(candidate, mid) + + return handle + + # --- pump (add PR7 reconnect handling to the parent routing) ------------- + def pump_once(self) -> None: + for event in self.broadcaster.drain_events(): + consumer_id = event.get("consumer_id") + self._record("broadcaster", event) + reconnect_id = needs_reconnect(event) + if reconnect_id is not None: + logger.warning( + "webrtc consumer %s needs reconnect: %s", + reconnect_id, + event.get("detail"), + ) + self._reconnect(reconnect_id) + continue + if consumer_id is not None: + self._relay_to_consumer(event, consumer_id) + with self._lock: + current = dict(self.consumers) + for consumer_id, consumer in current.items(): + for event in consumer.drain_events(): + self._record(consumer_id, event) + self._relay_to_broadcaster(event, consumer_id) + + def _reconnect(self, consumer_id: str) -> None: + """Remove + re-add one consumer over the server with a fresh id.""" + if consumer_id not in self.consumers: + return + new_conn = uuid4().hex + with self._lock: + old_conn = self._consumer_conn.get(consumer_id) + if old_conn is not None: + self._conn_to_consumer.pop(old_conn, None) + self._consumer_conn[consumer_id] = new_conn + self._conn_to_consumer[new_conn] = consumer_id + self._broadcaster_inbound[consumer_id] = self._make_broadcaster_inbound( + consumer_id + ) + self.broadcaster.remove_consumer(consumer_id) + self.broadcaster.add_consumer(consumer_id) + + # --- outbound (override the in-process hops) ----------------------------- + def _relay_to_consumer(self, event: dict, consumer_id: str) -> None: + try: + signal = outbound_signal(event) # event is already consumer-tagged + if signal is None: + return + with self._lock: + to_id = self._consumer_stream.get(consumer_id) + connection_id = self._consumer_conn.get(consumer_id) + if to_id is None or connection_id is None: + return + self._client.submit( + self._broadcaster_stream_id, + to_id, + connection_id, + signal.message_type, + signal.data, + ) + except Exception as exc: # noqa: BLE001 - surfaced via dispatch_errors + self.dispatch_errors.append(exc) + + def _relay_to_broadcaster(self, event: dict, consumer_id: str) -> None: + try: + mapped = consumer_outbound_signal(event) + if mapped is None: + return + message_type, data = mapped + with self._lock: + from_id = self._consumer_stream.get(consumer_id) + connection_id = self._consumer_conn.get(consumer_id) + if from_id is None or connection_id is None: + return + self._client.submit( + from_id, + self._broadcaster_stream_id, + connection_id, + message_type, + data, + ) + except Exception as exc: # noqa: BLE001 - surfaced via dispatch_errors + self.dispatch_errors.append(exc) + + # --- inbound ------------------------------------------------------------- + def _on_broadcaster_message(self, message: HandshakeMessage) -> None: + with self._lock: + consumer_id = self._conn_to_consumer.get(message.connection_id) + inbound = ( + self._broadcaster_inbound.get(consumer_id) + if consumer_id is not None + else None + ) + if inbound is None: + return # unknown connection_id: ignore, do not raise + if message.type == MessageType.SDP_ANSWER: + inbound.description(message.data) + elif message.type == MessageType.ICE_CANDIDATE: + candidate, mid = inbound_candidate(message.data) + inbound.candidate(candidate, mid) diff --git a/tests/integration/webrtc/test_behavioural_correctness.py b/tests/integration/webrtc/test_behavioural_correctness.py new file mode 100644 index 000000000..aa9fe75af --- /dev/null +++ b/tests/integration/webrtc/test_behavioural_correctness.py @@ -0,0 +1,154 @@ +"""Test 1 - behavioural correctness of async add/remove with renegotiation. + +A data-only session adds and removes video tracks and data channels +mid-session; the consumer must observe each change via renegotiation, the +mid<->track manifest must stay consistent, and the peer connection must NOT be +torn down or reset. Rapid churn guards the in-flight-negotiation hazard. + +xfail groups (see shared/markers.py): the video behavioural assertions and +rapid *video* churn green in PR4; rapid *data-channel* churn greens in PR3. +""" + +from __future__ import annotations + +import json + +from tests.integration.webrtc.shared import constants +from tests.integration.webrtc.shared.harness import Relay, bootstrap_connection + + +def _manifest_map(event: dict) -> dict: + return json.loads(event["json"]) + + +def test_video_track_add_remove_midsession_pc_not_reset(relay: Relay) -> None: + # Start data-only, then bring the connection up. + relay.producer.add_data_channel("json", "reliable") + bootstrap_connection(relay) + states_before = relay.state_sequence("consumer") + + # --- add a video track mid-session ------------------------------------- + track_id = "wrist_cam" + mid = relay.producer.add_video_track(track_id) + + added = relay.wait_consumer( + lambda e: e.get("kind") == "on_track_added" and e.get("track_id") == track_id, + constants.RENEG_TIMEOUT_S, + ) + assert added is not None, "consumer never observed on_track_added via reneg" + assert added["mid"] == mid, "on_track_added mid disagrees with add_video_track" + + # The manifest republished on this renegotiation already carries the new + # mid -> track mapping as one coherent (atomic) update. + manifest = relay.wait_consumer( + lambda e: e.get("kind") == "on_manifest" and mid in _manifest_map(e), + constants.RENEG_TIMEOUT_S, + ) + assert manifest is not None, "manifest not republished with the new track" + assert track_id in json.dumps( + _manifest_map(manifest)[mid] + ), "manifest entry for the new mid does not reference its track_id" + + # --- remove the video track (keyed by track_id) ------------------------ + # The consumer learns of removal by mid only, so map track_id -> mid from + # the earlier on_track_added to assert the matching removal. + relay.producer.remove_video_track(track_id) + removed = relay.wait_consumer( + lambda e: e.get("kind") == "on_track_removed" and e.get("mid") == mid, + constants.RENEG_TIMEOUT_S, + ) + assert removed is not None, "consumer never observed on_track_removed for the mid" + + # --- the PC must NOT be reset ------------------------------------------ + states_after = relay.state_sequence("consumer") + assert states_after.count("new") == states_before.count( + "new" + ), "peer connection churned back to 'new' (full reset) during renegotiation" + assert "closed" not in states_after, "peer connection was torn down on remove" + assert ( + not relay.dispatch_errors + ), f"signaling relay errored: {relay.dispatch_errors}" + + +# Baseline guard (un-xfailed in PR3): data-channel add is inherently safe. SCTP +# data channels do not renegotiate — after the first channel brings up the SCTP +# association, each further channel is a DCEP stream open over it, so there is no +# in-flight-renegotiation hazard for data channels (PR2 proved this). The +# in-flight hazard exists only for media tracks; this test stays as a passing +# regression guard that rapid data-channel add never drops a channel. +def test_rapid_data_channel_churn_no_silent_drop(relay: Relay) -> None: + bootstrap_connection(relay) + + # Add many channels with zero spacing so several negotiations are in flight + # at once - the hazard PR3's single-in-flight coalescing must absorb. + labels = [f"dc{i}" for i in range(12)] + for label in labels: + relay.producer.add_data_channel(label, "reliable") + + # Every channel must surface at the consumer; none silently dropped. + for label in labels: + observed = relay.wait_consumer( + lambda e, lbl=label: e.get("kind") == "on_data_channel" + and e.get("label") == lbl, + constants.DC_OPEN_TIMEOUT_S, + ) + assert observed is not None, f"data channel {label!r} was silently dropped" + + # Final manifest carries all of them (state matches on both sides). + manifest = relay.wait_consumer( + lambda e: e.get("kind") == "on_manifest" + and all(lbl in e["json"] for lbl in labels), + constants.RENEG_TIMEOUT_S, + ) + assert manifest is not None, "final manifest is missing some data channels" + assert ( + not relay.dispatch_errors + ), f"signaling relay errored: {relay.dispatch_errors}" + + +def test_rapid_video_track_churn_no_silent_drop(relay: Relay) -> None: + relay.producer.add_data_channel("json", "reliable") + bootstrap_connection(relay) + + track_ids = [f"cam{i}" for i in range(6)] + mids: dict[str, str] = {} + + # Interleave adds and removes with no spacing to overlap negotiations. + for track_id in track_ids: + mids[track_id] = relay.producer.add_video_track(track_id) + removed = set(track_ids[::2]) # remove every other one + for track_id in removed: + relay.producer.remove_video_track(track_id) + + survivors = [t for t in track_ids if t not in removed] + + # Each survivor must be observed and not later removed. + for track_id in survivors: + added = relay.wait_consumer( + lambda e, t=track_id: e.get("kind") == "on_track_added" + and e.get("track_id") == t, + constants.RENEG_TIMEOUT_S, + ) + assert added is not None, f"video track {track_id!r} was silently dropped" + + # Each removed track must surface a removal for its mid. + for track_id in removed: + gone = relay.wait_consumer( + lambda e, m=mids[track_id]: e.get("kind") == "on_track_removed" + and e.get("mid") == m, + constants.RENEG_TIMEOUT_S, + ) + assert gone is not None, f"removal of {track_id!r} was never observed" + + # Final manifest = exactly the survivors; PC never reset. + final = relay.wait_consumer( + lambda e: e.get("kind") == "on_manifest" + and {mids[t] for t in survivors}.issubset(set(json.loads(e["json"]))) + and not ({mids[t] for t in removed} & set(json.loads(e["json"]))), + constants.RENEG_TIMEOUT_S, + ) + assert final is not None, "final manifest does not match the survivor set" + assert "new" not in relay.state_sequence("consumer")[1:], "PC reset during churn" + assert ( + not relay.dispatch_errors + ), f"signaling relay errored: {relay.dispatch_errors}" diff --git a/tests/integration/webrtc/test_data_integrity.py b/tests/integration/webrtc/test_data_integrity.py new file mode 100644 index 000000000..463e028d5 --- /dev/null +++ b/tests/integration/webrtc/test_data_integrity.py @@ -0,0 +1,145 @@ +"""Test 2 - data integrity over reliable channels and over video. + +Reliable-ordered data channels must deliver every message exactly once and in +order; decoded video frames must carry monotonic, uncorrupted counters. + +xfail groups (see shared/markers.py): data-channel delivery, data-channel add, +and the manifest green in PR2; video frame integrity greens in PR5. +""" + +from __future__ import annotations + +import json +import time +from collections.abc import Callable + +from tests.integration.webrtc.shared import constants, metrics +from tests.integration.webrtc.shared.harness import ( + BroadcastRelay, + Relay, + bootstrap_connection, + collect_video_frames, + decoded_counters, + submit_at_rate, +) + +BroadcastFactory = Callable[..., BroadcastRelay] + + +def test_data_channels_zero_loss_zero_reorder(relay: Relay) -> None: + relay.producer.add_data_channel("json", "reliable") + relay.producer.add_data_channel("joints", "reliable") + bootstrap_connection(relay) + + json_seq = [{"i": i, "kind": "json", "payload": f"msg-{i}"} for i in range(50)] + joints_seq = [{"i": i, "q": [float(i + j) for j in range(7)]} for i in range(50)] + + for payload in json_seq: + relay.producer.send_json("json", json.dumps(payload)) + for payload in joints_seq: + relay.producer.send_json("joints", json.dumps(payload)) + + got_json = relay.wait_messages( + "consumer", "json", len(json_seq), constants.MESSAGE_TIMEOUT_S + ) + got_joints = relay.wait_messages( + "consumer", "joints", len(joints_seq), constants.MESSAGE_TIMEOUT_S + ) + + # Completeness + ordering in one comparison: decode in arrival order and + # require exact equality with what was sent. + assert [json.loads(d) for d in got_json] == json_seq, "json loss or reorder" + assert [json.loads(d) for d in got_joints] == joints_seq, "joints loss or reorder" + + +def test_data_channel_add_observed_with_manifest(relay: Relay) -> None: + bootstrap_connection(relay) + + relay.producer.add_data_channel("telemetry", "reliable") + + observed = relay.wait_consumer( + lambda e: e.get("kind") == "on_data_channel" and e.get("label") == "telemetry", + constants.DC_OPEN_TIMEOUT_S, + ) + assert observed is not None, "consumer never observed the new data channel" + + manifest = relay.wait_consumer( + lambda e: e.get("kind") == "on_manifest" and "telemetry" in e["json"], + constants.RENEG_TIMEOUT_S, + ) + assert manifest is not None, "manifest not republished with the new channel" + assert isinstance(json.loads(manifest["json"]), dict), "manifest is not a mid map" + + +def test_video_frames_monotonic_and_intact(relay: Relay) -> None: + track_id = "cam0" + relay.producer.add_data_channel("json", "reliable") + relay.producer.add_video_track(track_id) + bootstrap_connection(relay) + + submitted, _ = submit_at_rate(relay, track_id, fps=30, seconds=4) + assert submitted > 0 + + # allow the tail of the pipeline to flush + time.sleep(0.2) + frames = collect_video_frames(relay, track_id) + assert frames, "no video frames delivered to the consumer" + + counters, corrupted = decoded_counters(frames) + assert not corrupted, f"corrupted frames detected at counters {corrupted}" + # Drops are allowed (lossy ingress/network); reorder and duplication are not. + assert counters == sorted(counters), "frames delivered out of order" + assert len(counters) == len(set(counters)), "duplicate frames delivered" + + +def test_multi_consumer_data_zero_loss_zero_reorder( + make_broadcast: BroadcastFactory, +) -> None: + """One Broadcaster fans json/joints data channels to N consumers losslessly. + + The data analogue of ``test_multi_consumer_perf``: every consumer must receive + the exact known sequence on each reliable channel with zero loss and zero + reorder. ``json`` is registered before the consumers join (each gets it at + bootstrap) and ``joints`` after they connect (DCEP over the live association, + no renegotiation — PR2), so both the bootstrap and the live-add fan-out paths + are covered. + """ + n = metrics.MULTI_CONSUMER_N + relay = make_broadcast() + + # Registered before any consumer: late joiners pick it up at bootstrap. + relay.broadcaster.add_data_channel("json", "reliable") + + consumer_ids = [f"c{i}" for i in range(n)] + for consumer_id in consumer_ids: + make_broadcast.add_consumer(relay, consumer_id) + for consumer_id in consumer_ids: + assert relay.wait_consumer_connected( + consumer_id + ), f"consumer {consumer_id} did not connect" + + # Added on the live association: every connected consumer gets it via DCEP. + relay.broadcaster.add_data_channel("joints", "reliable") + + json_seq = [{"i": i, "kind": "json", "payload": f"msg-{i}"} for i in range(50)] + joints_seq = [{"i": i, "q": [float(i + j) for j in range(7)]} for i in range(50)] + + for payload in json_seq: + relay.broadcaster.send_json("json", json.dumps(payload)) + for payload in joints_seq: + relay.broadcaster.send_json("joints", json.dumps(payload)) + + # Every consumer receives both full sequences, in order, exactly once. + for consumer_id in consumer_ids: + got_json = relay.wait_messages( + consumer_id, "json", len(json_seq), constants.MESSAGE_TIMEOUT_S + ) + got_joints = relay.wait_messages( + consumer_id, "joints", len(joints_seq), constants.MESSAGE_TIMEOUT_S + ) + assert [ + json.loads(d) for d in got_json + ] == json_seq, f"consumer {consumer_id}: json loss or reorder" + assert [ + json.loads(d) for d in got_joints + ] == joints_seq, f"consumer {consumer_id}: joints loss or reorder" diff --git a/tests/integration/webrtc/test_hardening.py b/tests/integration/webrtc/test_hardening.py new file mode 100644 index 000000000..557336c84 --- /dev/null +++ b/tests/integration/webrtc/test_hardening.py @@ -0,0 +1,313 @@ +"""Test 4 - operational hardening (PR7). + +Proves the streaming core survives real operation: ffmpeg subprocess crashes are +detected, surfaced (``on_error``) and restarted; teardown and error paths leak no +threads, file descriptors, subprocesses, or registry entries; backpressure has a +single drop point; and close fully tears both surfaces down. + +Two flavours: + + * **Error-injection** (peer-capable but fast, runs by default): kill an encode + mid-stream and assert it restarts + surfaces ``on_error`` + the stream + recovers; close a peer out from under a live sender and assert it stays + graceful; force a malformed multi-slice encode and assert the PR5.6 invariant + guard drops it (never panics). + * **Soak/stress** (gated by ``NEURACORE_WEBRTC_SOAK``; shortened via + ``NEURACORE_WEBRTC_SOAK_SECONDS``, like the netem and Chrome gates): a long run + of churn — add/remove consumers, add/remove video tracks, sustained submit — + with periodic forced ffmpeg kills, asserting every resource (subprocess / + thread / fd count and the three process-global registries) returns to baseline, + with no zombies and no panics. + +The ``/proc`` resource probes are Linux-only; the suite skips elsewhere. +""" + +from __future__ import annotations + +import os +import signal +import sys +import time + +import pytest + +from tests.integration.webrtc.shared import constants +from tests.integration.webrtc.shared.frames import encode_frame +from tests.integration.webrtc.shared.harness import ( + Relay, + bootstrap_connection, + collect_video_frames, + submit_at_rate, +) + +# Resource probes read /proc, so the whole module is Linux-only. +pytestmark = pytest.mark.skipif( + not sys.platform.startswith("linux"), + reason="resource/subprocess probes require /proc (Linux only)", +) + + +# --- /proc resource probes --------------------------------------------------- +def _own_children() -> list[tuple[int, str, str]]: + """`(pid, comm, state)` for every direct child process of this process.""" + me = os.getpid() + children: list[tuple[int, str, str]] = [] + for entry in os.listdir("/proc"): + if not entry.isdigit(): + continue + pid = int(entry) + try: + with open(f"/proc/{pid}/status") as handle: + status = handle.read() + except OSError: + continue # the process exited between listdir and open + ppid = comm = state = None + for line in status.splitlines(): + if line.startswith("PPid:"): + ppid = int(line.split()[1]) + elif line.startswith("Name:"): + comm = line.split(":", 1)[1].strip() + elif line.startswith("State:"): + state = line.split(":", 1)[1].strip() + if ppid == me: + children.append((pid, comm or "", state or "")) + return children + + +def _ffmpeg_children() -> list[int]: + """The pids of our directly-spawned ffmpeg encode/decode subprocesses.""" + return [pid for pid, comm, _ in _own_children() if "ffmpeg" in comm] + + +def _zombie_children() -> int: + """How many of our children are unreaped zombies (state ``Z``).""" + return sum(1 for _, _, state in _own_children() if state.startswith("Z")) + + +def _thread_count() -> int: + return len(os.listdir("/proc/self/task")) + + +def _fd_count() -> int: + return len(os.listdir("/proc/self/fd")) + + +def _kill_ffmpeg() -> int: + """SIGKILL every live ffmpeg child; return how many were killed.""" + pids = _ffmpeg_children() + for pid in pids: + try: + os.kill(pid, signal.SIGKILL) + except ProcessLookupError: + pass + return len(pids) + + +def _wait_until(predicate, timeout: float, interval: float = 0.05) -> bool: + deadline = time.monotonic() + timeout + while time.monotonic() < deadline: + if predicate(): + return True + time.sleep(interval) + return predicate() + + +def _producer_errors(relay: Relay, where: str | None = None) -> list[dict]: + return [ + e + for e in relay.producer_events() + if e.get("kind") == "on_error" and (where is None or e.get("where") == where) + ] + + +# --- error-injection: encode crash -> restart + on_error + recovery ---------- +def test_encode_crash_is_detected_restarted_and_surfaced(make_relay) -> None: + """Killing the encoder ffmpeg mid-stream restarts it, surfaces an + ``on_error{where: encode}``, and the decoded stream recovers — not a silent + stall and not a panic.""" + relay = make_relay() + bootstrap_connection(relay) + relay.producer.add_video_track("cam0") + + # Prime the stream so an encoder is running and frames are flowing. + submit_at_rate(relay, "cam0", fps=30, seconds=2) + assert _wait_until( + lambda: len(relay.video_frames("consumer", "cam0")) > 0, timeout=3.0 + ), "no frames decoded before the injected crash" + before = len(relay.video_frames("consumer", "cam0")) + + killed = _kill_ffmpeg() + assert killed > 0, "expected at least the encoder ffmpeg to be running" + + # Keep submitting: the feed must detect the dead encoder and restart it. + submit_at_rate(relay, "cam0", fps=30, seconds=3, start_counter=1000) + + assert _wait_until( + lambda: len(_producer_errors(relay, "encode")) > 0, timeout=3.0 + ), "the encoder crash was not surfaced as on_error{where: encode}" + + # The stream recovered past the crash (more frames decoded after it). + after = collect_video_frames(relay, "cam0") + assert len(after) > before, "the stream did not recover after the encoder crash" + + # No zombie ffmpeg children left behind by the restart. + assert _zombie_children() == 0, "restart left a zombie subprocess" + + +# --- error-injection: send on a closed track stays graceful ------------------ +def test_send_on_a_closed_track_is_graceful(make_relay) -> None: + """Closing the consumer out from under a live producer must not raise or + panic; the producer stays responsive and surfaces an error rather than + crashing.""" + relay = make_relay() + bootstrap_connection(relay) + relay.producer.add_video_track("cam0") + submit_at_rate(relay, "cam0", fps=30, seconds=1) + + # Tear the consumer down; the producer is now sending on a track whose remote + # end is gone. + relay.consumer.close() + + # Submitting after the peer closed must never raise (graceful, drop-on-fail). + for index in range(80): + relay.producer.submit_frame("cam0", encode_frame(2000 + index)) + time.sleep(0.01) + + # The producer is still usable (no panic across the FFI boundary): a further + # API call returns normally and the event queue is still drainable. + relay.producer.add_data_channel("late", "reliable") + assert isinstance(relay.producer.drain_events(), list) + assert _zombie_children() == 0 + + +# --- error-injection: malformed multi-slice input trips the PR5.6 guard ------- +def test_multislice_input_trips_the_invariant_guard_without_panic( + make_broadcast, monkeypatch +) -> None: + """Forcing a multi-slice encode makes one input frame emit several access + units; the PR5.6 capture-timestamp underflow guard drops the extras (one + timestamp per input frame) instead of fabricating timestamps or panicking, so + the emitted-access-unit count tracks the input rather than multiplying by the + slice count.""" + monkeypatch.setenv("NCD_WEBRTC_FORCE_SLICES", "4") + track_id = "cam0" + relay = make_broadcast() + relay.broadcaster.add_video_track(track_id) + make_broadcast.add_consumer(relay, "c0") + assert relay.wait_consumer_connected("c0"), "consumer did not connect" + assert ( + relay.wait_for( + "c0", + lambda e: e.get("kind") == "on_track_added" + and e.get("track_id") == track_id, + constants.RENEG_TIMEOUT_S, + ) + is not None + ) + + submitted = relay.submit_at_rate(track_id, fps=30, seconds=2) + time.sleep(0.5) + + encoded = relay.broadcaster.frames_encoded(track_id) + assert encoded is not None, "the source should exist (no crash)" + # Without the guard a 4-slice frame would emit ~4 access units per input frame + # (the assembler flushes per VCL slice). The underflow guard drops the extra 3 + # — one capture timestamp per input frame — so the count stays ~1:1 with the + # input rather than ~4x it. Generous bound so it is robust to keyframe + # restarts and host slice-count variation; the point is it does not multiply. + assert encoded <= submitted * 2, ( + f"emitted {encoded} access units for {submitted} input frames — the " + f"multi-slice underflow guard did not fire (count multiplied by slices)" + ) + # And nothing panicked: the broadcaster is still live. + assert relay.broadcaster.consumer_count() == 1 + assert _zombie_children() == 0 + + +# --- soak/stress (gated; shortenable) ---------------------------------------- +def _soak_enabled() -> bool: + return os.environ.get("NEURACORE_WEBRTC_SOAK", "") not in ("", "0", "false") + + +@pytest.mark.skipif( + not _soak_enabled(), + reason="long soak/churn gate; enable with NEURACORE_WEBRTC_SOAK=1 " + "(shorten with NEURACORE_WEBRTC_SOAK_SECONDS)", +) +def test_soak_churn_returns_all_resources_to_baseline(make_broadcast) -> None: + """A long run of consumer/track churn plus sustained submit and periodic + forced ffmpeg kills returns every resource to baseline: the three + process-global registries back to their starting size, no leaked + subprocesses, no zombies, thread and fd counts not growing without bound.""" + from neuracore.core.streaming.p2p.webrtc_selection import load_native + + module = load_native() + + seconds = float(os.environ.get("NEURACORE_WEBRTC_SOAK_SECONDS", 30)) + + # Warm the runtime up (its global threads persist) before baselining, so the + # baseline reflects the steady state, not a cold process. + warm = make_broadcast() + make_broadcast.add_consumer(warm, "warm") + warm.broadcaster.add_video_track("warm0") + warm.wait_consumer_connected("warm") + warm.submit_at_rate("warm0", fps=30, seconds=1) + warm.remove_consumer("warm") + warm.broadcaster.remove_video_track("warm0") + warm.close() + time.sleep(1.0) + + base_threads = _thread_count() + base_fds = _fd_count() + base_registries = module.registry_sizes() + assert base_registries == ( + 0, + 0, + 0, + ), f"registries not clean before the soak: {base_registries}" + assert _ffmpeg_children() == [], "stray ffmpeg before the soak" + + deadline = time.monotonic() + seconds + iteration = 0 + while time.monotonic() < deadline: + iteration += 1 + relay = make_broadcast() + relay.broadcaster.add_video_track("cam0") + ids = [f"c{iteration}_{i}" for i in range(3)] + for cid in ids: + make_broadcast.add_consumer(relay, cid) + for cid in ids: + relay.wait_consumer_connected(cid, timeout=constants.CONNECT_TIMEOUT_S) + + # Add then remove a second track mid-stream (track churn). + relay.broadcaster.add_video_track("cam1") + relay.submit_at_rate("cam0", fps=30, seconds=1) + relay.broadcaster.remove_video_track("cam1") + + # Forced ffmpeg kill (crash injection) every other iteration. + if iteration % 2 == 0: + _kill_ffmpeg() + relay.submit_at_rate("cam0", fps=30, seconds=1) + + # Consumer churn: drop one, keep submitting, then tear the relay down. + relay.remove_consumer(ids[0]) + relay.submit_at_rate("cam0", fps=20, seconds=0.5) + relay.close() + # Each fully-closed iteration must return the registries to baseline. + assert _wait_until( + lambda: module.registry_sizes() == base_registries, timeout=3.0 + ), f"registries leaked after iteration {iteration}: {module.registry_sizes()}" + + # Let any teardown threads/subprocesses wind down, then assert baseline. + time.sleep(2.0) + assert module.registry_sizes() == base_registries, "registry entries leaked" + assert _ffmpeg_children() == [], "ffmpeg subprocesses leaked" + assert _zombie_children() == 0, "zombie subprocesses left behind" + # Threads and fds may wobble slightly (runtime worker reuse), but must not grow + # with the iteration count — a small fixed slack, not a per-iteration one. + assert ( + _thread_count() <= base_threads + 4 + ), f"threads grew from {base_threads} to {_thread_count()} (leak)" + assert ( + _fd_count() <= base_fds + 16 + ), f"fds grew from {base_fds} to {_fd_count()} (leak)" diff --git a/tests/integration/webrtc/test_performance.py b/tests/integration/webrtc/test_performance.py new file mode 100644 index 000000000..51b6aa1f6 --- /dev/null +++ b/tests/integration/webrtc/test_performance.py @@ -0,0 +1,417 @@ +"""Test 3 - performance against the agreed SLOs. + +Each test measures one slice, records it into the shared :class:`Metrics` +(emitted as structured JSON at session end for CI), and asserts the SLO. The +structured-output schema is: + + {connect_ms, reneg_add_ms, reneg_remove_ms, dc_add_ms, + g2g_p50_ms, g2g_p95_ms, delivered_fps, drop_rate} + +xfail groups (see shared/markers.py): + * connect + dc-add timing -> PR2 + * add/remove renegotiation timing -> PR4 + * glass-to-glass + sustained fps (1 consumer) -> PR5 + * performance under a constrained link -> PR6 + * multi-consumer performance -> PR7 + +In a red run the first stubbed call (add_data_channel / add_video_track) raises +before any measurement loop runs, so even the 60s sustained-fps test xfails +immediately. +""" + +from __future__ import annotations + +import json +import os +import shutil +import subprocess +import sys +import time +from collections.abc import Callable +from time import perf_counter + +import pytest + +from tests.integration.webrtc.shared import constants, metrics +from tests.integration.webrtc.shared.frames import decode_frame, parse_video_frame_event +from tests.integration.webrtc.shared.harness import ( + BroadcastRelay, + Relay, + collect_video_frames, + decoded_counters, + recv_time, + submit_at_rate, +) +from tests.integration.webrtc.shared.metrics import Metrics, percentile + +RelayFactory = Callable[..., Relay] +BroadcastFactory = Callable[..., BroadcastRelay] + + +# --- connection + data-channel timing (PR2) ---------------------------------- +def test_connect_established_under_slo( + make_relay: RelayFactory, perf_metrics: Metrics +) -> None: + samples: list[float] = [] + for _ in range(metrics.PERF_SAMPLES): + relay = make_relay() + start = perf_counter() + relay.producer.add_data_channel("control", "reliable") # raises in red + assert relay.wait_connected(constants.CONNECT_TIMEOUT_S), "no connection" + samples.append((perf_counter() - start) * 1000.0) + relay.close() + + p95 = percentile(samples, 95) + perf_metrics.connect_ms = p95 + assert p95 < metrics.CONNECT_MS_P95, f"connect p95 {p95:.1f}ms over SLO" + + +def test_data_channel_add_under_slo(relay: Relay, perf_metrics: Metrics) -> None: + relay.producer.add_data_channel("control", "reliable") # raises in red + assert relay.wait_connected(constants.CONNECT_TIMEOUT_S), "no connection" + + samples: list[float] = [] + for i in range(metrics.PERF_SAMPLES): + label = f"dc{i}" + start = perf_counter() + relay.producer.add_data_channel(label, "reliable") + observed = relay.wait_consumer( + lambda e, lbl=label: e.get("kind") == "on_data_channel" + and e.get("label") == lbl, + constants.DC_OPEN_TIMEOUT_S, + ) + assert observed is not None, f"channel {label} not usable at consumer" + samples.append((perf_counter() - start) * 1000.0) + + p95 = percentile(samples, 95) + perf_metrics.dc_add_ms = p95 + assert p95 < metrics.DC_ADD_MS_P95, f"dc-add p95 {p95:.1f}ms over SLO" + + +# --- renegotiation timing (PR3) ---------------------------------------------- +def test_reneg_add_track_under_slo(relay: Relay, perf_metrics: Metrics) -> None: + relay.producer.add_data_channel("control", "reliable") # raises in red + bootstrap_wait(relay) + + add_samples: list[float] = [] + for i in range(metrics.PERF_SAMPLES): + track_id = f"cam{i}" + start = perf_counter() + mid = relay.producer.add_video_track(track_id) + added = relay.wait_consumer( + lambda e, t=track_id: e.get("kind") == "on_track_added" + and e.get("track_id") == t, + constants.RENEG_TIMEOUT_S, + ) + assert added is not None, f"add of {track_id} not observed" + add_samples.append((perf_counter() - start) * 1000.0) + # clean up so the next iteration starts from a known track set + relay.producer.remove_video_track(track_id) + relay.wait_consumer( + lambda e, m=mid: e.get("kind") == "on_track_removed" and e.get("mid") == m, + constants.RENEG_TIMEOUT_S, + ) + + p95 = percentile(add_samples, 95) + perf_metrics.reneg_add_ms = p95 + assert p95 < metrics.RENEG_ADD_MS_P95, f"reneg-add p95 {p95:.1f}ms over SLO" + + +def test_reneg_remove_track_under_slo(relay: Relay, perf_metrics: Metrics) -> None: + relay.producer.add_data_channel("control", "reliable") # raises in red + bootstrap_wait(relay) + + remove_samples: list[float] = [] + for i in range(metrics.PERF_SAMPLES): + track_id = f"cam{i}" + mid = relay.producer.add_video_track(track_id) + added = relay.wait_consumer( + lambda e, t=track_id: e.get("kind") == "on_track_added" + and e.get("track_id") == t, + constants.RENEG_TIMEOUT_S, + ) + assert added is not None, f"setup add of {track_id} not observed" + + start = perf_counter() + relay.producer.remove_video_track(track_id) + removed = relay.wait_consumer( + lambda e, m=mid: e.get("kind") == "on_track_removed" and e.get("mid") == m, + constants.RENEG_TIMEOUT_S, + ) + assert removed is not None, f"remove of {track_id} not observed" + remove_samples.append((perf_counter() - start) * 1000.0) + + p95 = percentile(remove_samples, 95) + perf_metrics.reneg_remove_ms = p95 + assert p95 < metrics.RENEG_REMOVE_MS_P95, f"reneg-remove p95 {p95:.1f}ms over SLO" + + +# --- glass-to-glass + sustained fps, single consumer (PR5) ------------------- +def test_glass_to_glass_under_slo(relay: Relay, perf_metrics: Metrics) -> None: + track_id = "cam0" + relay.producer.add_data_channel("control", "reliable") # raises in red + relay.producer.add_video_track(track_id) + bootstrap_wait(relay) + + submitted, submit_times = submit_at_rate( + relay, track_id, fps=metrics.SOURCE_FPS, seconds=5 + ) + assert submitted > 0 + time.sleep(0.2) + frames = collect_video_frames(relay, track_id) + assert frames, "no frames delivered for glass-to-glass measurement" + + g2g: list[float] = [] + for event in frames: + _, _, array = parse_video_frame_event(event) + counter, ok = decode_frame(array) + recv = recv_time(event) + if ok and counter in submit_times and recv is not None: + g2g.append((recv - submit_times[counter]) * 1000.0) + assert g2g, "no decodable frames matched a submit timestamp" + + p50 = percentile(g2g, 50) + p95 = percentile(g2g, 95) + perf_metrics.g2g_p50_ms = p50 + perf_metrics.g2g_p95_ms = p95 + assert p50 < metrics.G2G_P50_MS, f"g2g p50 {p50:.1f}ms over SLO" + assert p95 < metrics.G2G_P95_MS, f"g2g p95 {p95:.1f}ms over SLO" + + +def test_sustained_fps_single_consumer(relay: Relay, perf_metrics: Metrics) -> None: + track_id = "cam0" + relay.producer.add_data_channel("control", "reliable") # raises in red + relay.producer.add_video_track(track_id) + bootstrap_wait(relay) + + # Phase 1: at-or-below 30fps must drop nothing. + at_rate_seconds = min(10.0, metrics.PERF_DURATION_S) + sent_lo, _ = submit_at_rate( + relay, track_id, fps=metrics.AT_RATE_FPS, seconds=at_rate_seconds + ) + time.sleep(0.2) + frames_lo = collect_video_frames(relay, track_id) + counters_lo, corrupted_lo = decoded_counters(frames_lo) + assert not corrupted_lo, "corruption at or below 30fps" + delivered_lo = len(set(counters_lo)) + assert ( + delivered_lo == sent_lo + ), f"dropped {sent_lo - delivered_lo} frames at or below 30fps (expected 0)" + + # Phase 2: over-rate (45fps) for the full duration. Drops allowed here, but + # delivered throughput must hold the floor. + start_counter = sent_lo + sent_hi, _ = submit_at_rate( + relay, + track_id, + fps=metrics.SOURCE_FPS, + seconds=metrics.PERF_DURATION_S, + start_counter=start_counter, + ) + time.sleep(0.2) + frames_all = collect_video_frames(relay, track_id) + counters_all, corrupted_all = decoded_counters(frames_all) + assert not corrupted_all, "corruption during sustained run" + delivered_hi = len({c for c in counters_all if c >= start_counter}) + + delivered_fps = delivered_hi / metrics.PERF_DURATION_S + drop_rate = 1.0 - (delivered_hi / sent_hi) if sent_hi else 0.0 + perf_metrics.delivered_fps = delivered_fps + perf_metrics.drop_rate = drop_rate + assert ( + delivered_fps >= metrics.MIN_DELIVERED_FPS + ), f"delivered {delivered_fps:.1f}fps below the {metrics.MIN_DELIVERED_FPS} floor" + + +# --- constrained link -------------------------------------------------------- +def _netns_available() -> str | None: + """Return None if a private netns + netem can be set up, else a skip reason. + + The constrained-link test needs CAP_SYS_ADMIN (``unshare -n``) and + CAP_NET_ADMIN (``tc``); on a host without them (most CI) the test skips + rather than fails. It is a real-netem decision/nightly gate, not a fast + per-PR check. + """ + if shutil.which("tc") is None or shutil.which("unshare") is None: + return "tc/unshare not on PATH" + probe = subprocess.run( + ["unshare", "-n", "tc", "qdisc", "show"], + capture_output=True, + text=True, + ) + if probe.returncode != 0: + return ( + "cannot enter a private netns (need CAP_SYS_ADMIN/NET_ADMIN): " + f"{probe.stderr.strip()}" + ) + return None + + +def test_perf_under_constrained_link(perf_metrics: Metrics) -> None: + """Under a real netem-shaped loopback the stream degrades gracefully. + + Netem cannot be applied to the in-process peers' loopback (the container is + network_mode: host), so the body runs in a private network namespace via + ``netem_runner`` under ``unshare -n``: it brings up the namespace's `lo`, + applies ``NEURACORE_WEBRTC_NETEM`` (a profile that overflows a 45 fps stream + but fits a degraded one), runs the relay, and reports steady-state results. + + The contract: once the REMB+RR estimator settles the ladder on a fitting + rung, the *good* (checksum-valid) delivered-fps holds the floor, the producer + demonstrably adapted, and the connection stays up. Without adaptation + (``NCD_WEBRTC_DISABLE_ADAPT``) the same constraint collapses the stream — the + proof the netem bite is real (recorded in reports/PR5-congestion.md). + """ + skip_reason = _netns_available() + if skip_reason is not None: + pytest.skip(skip_reason) + + env = dict(os.environ) + env.setdefault("NCD_RUST_WEBRTC", "1") + env.setdefault("PYTHONPATH", os.getcwd()) + proc = subprocess.run( + [ + "unshare", + "-n", + sys.executable, + "-m", + "tests.integration.webrtc.netem_runner", + ], + capture_output=True, + text=True, + env=env, + timeout=120, + ) + assert proc.returncode == 0, f"netem runner crashed: {proc.stderr[-2000:]}" + line = proc.stdout.strip().splitlines()[-1] + result = json.loads(line) + assert result.get("ok"), f"netem runner failed: {result.get('error')}" + + perf_metrics.delivered_fps = result["delivered_fps"] + assert not result["closed"], "connection dropped under the constrained link" + assert result["max_step"] > 0, ( + "producer did not adapt (ladder never left the finest rung) under " + f"netem {result['netem']!r}" + ) + assert result["delivered_fps"] >= metrics.MIN_DELIVERED_FPS, ( + f"steady-state delivered {result['delivered_fps']:.1f}fps below the " + f"{metrics.MIN_DELIVERED_FPS} floor under netem {result['netem']!r}" + ) + + +# --- Chrome interop decision gate -------------------------------------------- +@pytest.mark.skipif( + os.environ.get("NEURACORE_WEBRTC_CHROME") not in ("1", "true", "yes"), + reason="Chrome interop is the REMB decision/nightly gate: it needs installed " + "Google Chrome (Playwright channel 'chrome') and is kept out of the fast " + "per-PR loopback suite. Enable with NEURACORE_WEBRTC_CHROME=1. The recorded " + "verdict lives in reports/PR5-congestion.md.", +) +def test_chrome_interop(perf_metrics: Metrics) -> None: + """Drive real Google Chrome as the consumer of the producer's built-in chain. + + Runs ``chrome_interop`` (host mode) and asserts Chrome negotiates and receives + our H.264 with zero loss and the producer's REMB-driven estimator engages. + The full verdict (including the Chrome-under-netem environmental limitation + and the P-frame RTP-assembly finding) is recorded in the report; this test is + a smoke gate that the harness still works. + """ + proc = subprocess.run( + [sys.executable, "-m", "tests.integration.webrtc.chrome_interop"], + capture_output=True, + text=True, + env={**os.environ, "NCD_RUST_WEBRTC": "1", "PYTHONPATH": os.getcwd()}, + timeout=180, + ) + assert proc.returncode == 0, f"chrome harness crashed: {proc.stderr[-2000:]}" + result = json.loads(proc.stdout.strip().splitlines()[-1]) + assert result.get("ok"), f"chrome harness failed: {result.get('error')}" + # Chrome received our chain's media with no loss and the REMB estimator engaged. + assert result["packetsReceived"] > 0, "Chrome received no media from the producer" + assert result["packetsLost"] == 0, "unexpected loss on a clean host link" + assert result["max_step"] >= 0 # adaptation ran (driven by Chrome's real REMB) + perf_metrics.delivered_fps = result.get("tail_decoded_fps") + + +# --- multi-consumer ---------------------------------------------------------- +def test_multi_consumer_perf( + make_broadcast: BroadcastFactory, perf_metrics: Metrics +) -> None: + """One producer serves N consumers from a single shared encode per source. + + Submitting at the source rate, every consumer must deliver at or above the + 30 fps floor. The decisive assertion is that the encode is *shared*: exactly + one encoder runs for the one source regardless of N, and the per-source + frames-encoded stat tracks the submitted count (one encode), not N times it. + That is the observable that proves fan-out rather than N independent encodes. + + Loopback caveat: N co-located consumers contend for CPU, so the per-consumer + min-governance under real loss is a Chrome-capable-netem measurement (see + reports/PR6-fanout.md), not a loopback one; here the min-fold is unit-tested. + """ + track_id = "cam0" + n = metrics.MULTI_CONSUMER_N + seconds = 10 + relay = make_broadcast() + + # One source, visible to all consumers (current and future). + relay.broadcaster.add_video_track(track_id) + + consumer_ids = [f"c{i}" for i in range(n)] + for consumer_id in consumer_ids: + make_broadcast.add_consumer(relay, consumer_id) + for consumer_id in consumer_ids: + assert relay.wait_consumer_connected( + consumer_id + ), f"consumer {consumer_id} did not connect" + # Wait until this consumer has observed the source track (manifest diff), + # so its track is negotiated before the shared encode starts. + added = relay.wait_for( + consumer_id, + lambda e, t=track_id: e.get("kind") == "on_track_added" + and e.get("track_id") == t, + constants.RENEG_TIMEOUT_S, + ) + assert added is not None, f"consumer {consumer_id} never saw the source track" + + submitted = relay.submit_at_rate(track_id, fps=metrics.SOURCE_FPS, seconds=seconds) + assert submitted > 0 + + # The encode is shared: exactly one encoder for the one source, regardless of + # how many consumers receive it. + assert ( + relay.broadcaster.encoder_count() == 1 + ), f"expected one shared encode, saw {relay.broadcaster.encoder_count()}" + + # ...and the per-source frames-encoded stat tracks the submitted count (one + # encode), not N times it. A small overhead allowance covers keyframe + # restarts; the point is it does NOT scale with the consumer count. + encoded = relay.broadcaster.frames_encoded(track_id) + assert encoded is not None + assert encoded <= submitted * 1.5, ( + f"frames_encoded {encoded} scales with consumers (submitted {submitted}, " + f"{n} consumers) — the encode is not shared" + ) + + delivered: list[float] = [] + for consumer_id in consumer_ids: + frames = relay.collect_video_frames(consumer_id, track_id) + counters, corrupted = decoded_counters(frames) + assert not corrupted, f"consumer {consumer_id} saw corruption" + delivered_fps = len(set(counters)) / seconds + delivered.append(delivered_fps) + assert delivered_fps >= metrics.MIN_DELIVERED_FPS, ( + f"consumer {consumer_id} delivered {delivered_fps:.1f}fps below the " + f"{metrics.MIN_DELIVERED_FPS} floor" + ) + + # Record the worst consumer's delivered fps for the CI metrics line. + perf_metrics.delivered_fps = min(delivered) + + +# --- local helpers ----------------------------------------------------------- +def bootstrap_wait(relay: Relay) -> None: + """Wait for the connection brought up by a prior add_* call to establish.""" + assert relay.wait_connected( + constants.CONNECT_TIMEOUT_S + ), "connection not established" diff --git a/tests/integration/webrtc/test_server_transport_unit.py b/tests/integration/webrtc/test_server_transport_unit.py new file mode 100644 index 000000000..8bfd56abc --- /dev/null +++ b/tests/integration/webrtc/test_server_transport_unit.py @@ -0,0 +1,262 @@ +"""Peer-free unit tests for the server-backed signaling transport. + +These need no native peers and no backend, so they run in the sandbox: they +cover the consumer-side signaling mapping, connection_id / consumer_id routing, +and SSE frame dispatch by type with a fake feed. +""" + +from __future__ import annotations + +import json + +from neuracore_types import HandshakeMessage, MessageType + +from tests.integration.webrtc.shared.server_transport import ( + ServerBroadcastRelay, + ServerRelay, + SignalingConfig, + _Inbound, + consumer_outbound_signal, + inbound_candidate, + parse_sse_lines, + signaling_config_from_env, +) + +DUMMY_CONFIG = SignalingConfig(base_url="http://backend/api", org="org1", token="tok") + + +# --------------------------------------------------------------------------- # +# Fakes +# --------------------------------------------------------------------------- # +class FakeConsumer: + """Records the consumer-side inbound setter calls.""" + + def __init__(self) -> None: + self.offers: list[str] = [] + self.candidates: list[tuple[str, str | None]] = [] + + def set_remote_offer(self, sdp: str) -> None: + self.offers.append(sdp) + + def add_remote_candidate(self, candidate: str, mid: str | None) -> None: + self.candidates.append((candidate, mid)) + + def set_remote_answer(self, sdp: str) -> None: # producer-side use + self.offers.append(sdp) + + +class FakeBroadcaster: + """Records the broadcaster-side per-consumer inbound setter calls.""" + + def __init__(self) -> None: + self.answers: list[tuple[str, str]] = [] + self.candidates: list[tuple[str, str, str | None]] = [] + self.added: list[str] = [] + self.removed: list[str] = [] + + def set_remote_answer(self, consumer_id: str, sdp: str) -> None: + self.answers.append((consumer_id, sdp)) + + def add_remote_candidate( + self, consumer_id: str, candidate: str, mid: str | None + ) -> None: + self.candidates.append((consumer_id, candidate, mid)) + + def add_consumer(self, consumer_id: str) -> None: + self.added.append(consumer_id) + + def remove_consumer(self, consumer_id: str) -> None: + self.removed.append(consumer_id) + + def drain_events(self) -> list[dict]: + return [] + + def close(self) -> None: + pass + + +def _message( + message_type: MessageType, data: str, *, connection_id: str = "conn" +) -> HandshakeMessage: + return HandshakeMessage( + from_id="from", + to_id="to", + connection_id=connection_id, + type=message_type, + data=data, + ) + + +# --------------------------------------------------------------------------- # +# Consumer-side mapping +# --------------------------------------------------------------------------- # +def test_consumer_answer_maps_to_sdp_answer() -> None: + event = {"kind": "on_local_description", "sdp_type": "answer", "sdp": "the-sdp"} + assert consumer_outbound_signal(event) == (MessageType.SDP_ANSWER, "the-sdp") + + +def test_consumer_offer_is_not_emitted() -> None: + # The consumer is answer-only: it never offers, so an offer maps to None. + event = {"kind": "on_local_description", "sdp_type": "offer", "sdp": "x"} + assert consumer_outbound_signal(event) is None + + +def test_consumer_non_signaling_event_maps_to_none() -> None: + assert consumer_outbound_signal({"kind": "on_state", "state": "connected"}) is None + + +def test_consumer_candidate_formats_and_round_trips() -> None: + event = {"kind": "on_local_candidate", "candidate": "candidate:1 udp", "mid": "0"} + mapped = consumer_outbound_signal(event) + assert mapped is not None + message_type, data = mapped + assert message_type == MessageType.ICE_CANDIDATE + payload = json.loads(data) + assert payload["candidate"] == "candidate:1 udp" + assert payload["sdpMid"] == "0" + # The same parser the producer uses recovers (candidate, mid). + assert inbound_candidate(data) == ("candidate:1 udp", "0") + + +def test_inbound_sdp_offer_drives_set_remote_offer() -> None: + relay = ServerBroadcastRelay(FakeBroadcaster(), config=DUMMY_CONFIG) + consumer = FakeConsumer() + handle = relay._make_consumer_handler("c0", consumer) + handle(_message(MessageType.SDP_OFFER, "offer-sdp")) + handle( + _message( + MessageType.ICE_CANDIDATE, + json.dumps({"candidate": "cand", "sdpMid": "1", "sdpMLineIndex": "1"}), + ) + ) + assert consumer.offers == ["offer-sdp"] + assert consumer.candidates == [("cand", "1")] + + +# --------------------------------------------------------------------------- # +# Trickle buffering +# --------------------------------------------------------------------------- # +def test_inbound_buffers_candidates_until_description() -> None: + applied: list[str] = [] + candidates: list[tuple[str, str | None]] = [] + inbound = _Inbound(applied.append, lambda c, m: candidates.append((c, m))) + + inbound.candidate("early-1", "0") + inbound.candidate("early-2", "0") + assert candidates == [] # buffered, not applied before the description + + inbound.description("sdp") + assert applied == ["sdp"] + assert candidates == [("early-1", "0"), ("early-2", "0")] # flushed in order + + inbound.candidate("late", "0") + assert candidates[-1] == ("late", "0") # applied immediately afterwards + + +# --------------------------------------------------------------------------- # +# Routing by connection_id / consumer_id +# --------------------------------------------------------------------------- # +def test_broadcaster_inbound_routes_answer_to_correct_consumer() -> None: + broadcaster = FakeBroadcaster() + relay = ServerBroadcastRelay(broadcaster, config=DUMMY_CONFIG) + relay._conn_to_consumer["conn-a"] = "c-a" + relay._conn_to_consumer["conn-b"] = "c-b" + relay._broadcaster_inbound["c-a"] = relay._make_broadcaster_inbound("c-a") + relay._broadcaster_inbound["c-b"] = relay._make_broadcaster_inbound("c-b") + + relay._on_broadcaster_message( + _message(MessageType.SDP_ANSWER, "ans-b", connection_id="conn-b") + ) + relay._on_broadcaster_message( + _message( + MessageType.ICE_CANDIDATE, + json.dumps({"candidate": "cb", "sdpMid": "0"}), + connection_id="conn-b", + ) + ) + assert broadcaster.answers == [("c-b", "ans-b")] + assert broadcaster.candidates == [("c-b", "cb", "0")] + + +def test_broadcaster_inbound_ignores_unknown_connection_id() -> None: + broadcaster = FakeBroadcaster() + relay = ServerBroadcastRelay(broadcaster, config=DUMMY_CONFIG) + # No route registered: an unknown connection_id is ignored, not raised. + relay._on_broadcaster_message( + _message(MessageType.SDP_ANSWER, "x", connection_id="ghost") + ) + assert broadcaster.answers == [] + + +def test_relay_inbound_offer_and_answer_dispatch() -> None: + producer = FakeConsumer() # exposes set_remote_answer + add_remote_candidate + consumer = FakeConsumer() + relay = ServerRelay(producer, consumer, config=DUMMY_CONFIG) + + relay._on_consumer_message(_message(MessageType.SDP_OFFER, "offer")) + relay._on_producer_message(_message(MessageType.SDP_ANSWER, "answer")) + assert consumer.offers == ["offer"] + assert producer.offers == ["answer"] # FakeConsumer.set_remote_answer records here + + +# --------------------------------------------------------------------------- # +# SSE dispatch by type (fake feed) +# --------------------------------------------------------------------------- # +def test_parse_sse_lines_classifies_by_event_type() -> None: + message = _message(MessageType.SDP_OFFER, "sdp-body").model_dump_json() + feed = [ + "event:data", + f"data:{message}", + "", + "event:heartbeat", + "data:ping", + "", + "event:end", + "data:", + "", + ] + frames = list(parse_sse_lines(feed)) + assert [event for event, _ in frames] == ["data", "heartbeat", "end"] + + data_event, data_body = frames[0] + assert data_event == "data" + # The data frame round-trips back into a HandshakeMessage. + parsed = HandshakeMessage.model_validate_json(data_body) + assert parsed.type == MessageType.SDP_OFFER + assert parsed.data == "sdp-body" + + +def test_parse_sse_lines_handles_comments_and_leading_space() -> None: + feed = [ + ":comment-keepalive", + "event: data", + "data: hello", + "", + ] + assert list(parse_sse_lines(feed)) == [("data", "hello")] + + +def test_parse_sse_lines_accepts_bytes() -> None: + feed = [b"event:heartbeat", b"data:ping", b""] + assert list(parse_sse_lines(feed)) == [("heartbeat", "ping")] + + +# --------------------------------------------------------------------------- # +# Env gating +# --------------------------------------------------------------------------- # +def test_config_from_env_none_when_unset(monkeypatch) -> None: + for var in ( + "NEURACORE_WEBRTC_SIGNALING_URL", + "NEURACORE_WEBRTC_SIGNALING_ORG", + "NEURACORE_WEBRTC_SIGNALING_TOKEN", + ): + monkeypatch.delenv(var, raising=False) + assert signaling_config_from_env() is None + + +def test_config_from_env_built_when_set(monkeypatch) -> None: + monkeypatch.setenv("NEURACORE_WEBRTC_SIGNALING_URL", "http://backend/api/") + monkeypatch.setenv("NEURACORE_WEBRTC_SIGNALING_ORG", "org42") + monkeypatch.setenv("NEURACORE_WEBRTC_SIGNALING_TOKEN", "secret") + config = signaling_config_from_env() + assert config == SignalingConfig("http://backend/api", "org42", "secret") diff --git a/tests/unit/webrtc/__init__.py b/tests/unit/webrtc/__init__.py new file mode 100644 index 000000000..7282d82f4 --- /dev/null +++ b/tests/unit/webrtc/__init__.py @@ -0,0 +1,9 @@ +"""Fast, peer-free unit tests for the Rust WebRTC stack's Python surface. + +These pin the pure logic added in PR0-PR3 (the recording bridge, the +feature-flag selection/loader, and the PR1 test-helper logic) without a native +module, a peer connection, sockets, or sleeps. The native transport, the +negotiation queue, and the manifest model are unit-tested on the Rust side in +``neuracore_webrtc``'s ``#[cfg(test)]`` modules; everything requiring a live +``PeerConnection`` stays in ``tests/integration/webrtc``. +""" diff --git a/tests/unit/webrtc/test_frames.py b/tests/unit/webrtc/test_frames.py new file mode 100644 index 000000000..5ea2edbc8 --- /dev/null +++ b/tests/unit/webrtc/test_frames.py @@ -0,0 +1,94 @@ +"""Unit tests for the synthetic frame codec (PR1 ``shared/frames.py``). + +Formalises the self-checks the PR1 report noted: the embedded 32-bit counter +round-trips across the whole ``[0, 2**32)`` range, a scrambled header band is +flagged as corrupt, and the counter survives a simulated lossy (blur + quantise ++ noise) path the way it must survive H.264. +""" + +from __future__ import annotations + +import numpy as np +import pytest + +from tests.integration.webrtc.shared import frames + +# A spread across the 32-bit range: edges, powers of two, and an odd large value. +_COUNTERS = [ + 0, + 1, + 255, + 256, + 65535, + 65536, + 1234567, + 2**31, + 2**32 - 2, + 2**32 - 1, +] + + +@pytest.mark.parametrize("counter", _COUNTERS) +def test_counter_round_trips_clean(counter: int) -> None: + frame = frames.encode_frame(counter) + # Contract for submit_frame: C-contiguous uint8 H x W x 3. + assert frame.dtype == np.uint8 + assert frame.shape == (frames.HEIGHT, frames.WIDTH, frames.CHANNELS) + assert frame.flags["C_CONTIGUOUS"] + + recovered, ok = frames.decode_frame(frame) + assert ok, f"checksum failed for counter {counter}" + assert recovered == counter + + +def test_scrambled_header_band_is_flagged_as_corrupt() -> None: + frame = frames.encode_frame(987654) + # Flip the header band (the top rows that carry the block-coded counter + + # checksum) so the recovered counter no longer matches its checksum. + header_rows = frames._HEADER_ROWS * frames._BLOCK + corrupted = frame.copy() + corrupted[:header_rows, :, :] = 255 - corrupted[:header_rows, :, :] + + _, ok = frames.decode_frame(corrupted) + assert not ok, "corruption in the header band must be detected by the checksum" + + +def _lossy(frame: np.ndarray) -> np.ndarray: + """A deterministic lossy transform standing in for an H.264 round trip: + a 3x3 box blur, an 8-level quantisation, and a small fixed ripple.""" + blurred = frame.astype(np.float32) + # Separable 3x3 box blur via shifts (cheap, no scipy dependency). + for axis in (0, 1): + blurred = ( + blurred + np.roll(blurred, 1, axis=axis) + np.roll(blurred, -1, axis=axis) + ) / 3.0 + quantised = np.round(blurred / 32.0) * 32.0 + rows = np.arange(frame.shape[0])[:, None, None] + ripple = (8.0 * np.sin(rows / 7.0)).astype(np.float32) + return np.clip(quantised + ripple, 0, 255).astype(np.uint8) + + +@pytest.mark.parametrize("counter", [0, 42, 4096, 1_000_000, 2**32 - 1]) +def test_counter_survives_a_lossy_path(counter: int) -> None: + frame = frames.encode_frame(counter) + recovered, ok = frames.decode_frame(_lossy(frame)) + # The solid blocks + centre sampling are engineered to survive exactly this. + assert ok, f"checksum failed through the lossy path for counter {counter}" + assert recovered == counter + + +def test_parse_video_frame_event_reshapes_the_pr5_contract() -> None: + counter = 314159 + frame = frames.encode_frame(counter) + event = { + "kind": "on_frame", + "track_id": "cam0", + "mid": "v0", + "data": frame.tobytes(), + "width": frames.WIDTH, + "height": frames.HEIGHT, + } + track_id, mid, array = frames.parse_video_frame_event(event) + assert (track_id, mid) == ("cam0", "v0") + recovered, ok = frames.decode_frame(array) + assert ok and recovered == counter diff --git a/tests/unit/webrtc/test_markers.py b/tests/unit/webrtc/test_markers.py new file mode 100644 index 000000000..83d3af0d3 --- /dev/null +++ b/tests/unit/webrtc/test_markers.py @@ -0,0 +1,92 @@ +"""Unit tests for the xfail-marker mechanism (PR1 ``shared/markers.py``). + +Pins two PR1 self-checks: ``greened_by`` returns a strict xfail marker for every +valid PR slice (and rejects an unknown one), and the suite's live ``greened_by`` +tags form a clean test -> PR map — every tag is a known ``PR_SLICES`` key and no +test is tagged more than once. (A PR greens its slice by deleting its marker, so +PR2-PR4 keys are no longer used as live tags; the invariant tested here is that +whatever tags remain are valid and unambiguous.) +""" + +from __future__ import annotations + +import ast +from pathlib import Path + +import pytest + +from tests.integration.webrtc.shared.markers import PR_SLICES, greened_by + +_SUITE_DIR = Path(__file__).resolve().parents[2] / "integration" / "webrtc" +_TEST_MODULES = sorted(_SUITE_DIR.glob("test_*.py")) + + +def test_pr_slices_descriptions_are_non_empty_and_unique() -> None: + assert PR_SLICES, "the PR -> slice map must not be empty" + assert all(desc.strip() for desc in PR_SLICES.values()) + assert len(set(PR_SLICES.values())) == len(PR_SLICES), "duplicate slice text" + + +@pytest.mark.parametrize("pr", sorted(PR_SLICES)) +def test_greened_by_yields_a_strict_xfail_for_every_slice(pr: str) -> None: + detail = "a specific assertion this test pins" + marker = greened_by(pr, detail) + + assert marker.mark.name == "xfail" + assert marker.mark.kwargs["strict"] is True + reason = marker.mark.kwargs["reason"] + assert pr in reason + assert PR_SLICES[pr] in reason + assert detail in reason + + +def test_greened_by_rejects_an_unknown_pr() -> None: + with pytest.raises(KeyError): + greened_by("PR999", "no such slice") + + +def _greened_tags_by_test() -> dict[str, list[str]]: + """Map each integration test function -> the PR tags it is greened_by. + + Parses the suite statically (no import, no native module) so the check runs + in the unit layer. + """ + tags: dict[str, list[str]] = {} + for module in _TEST_MODULES: + tree = ast.parse(module.read_text(), filename=str(module)) + for node in ast.walk(tree): + if not isinstance(node, ast.FunctionDef): + continue + if not node.name.startswith("test_"): + continue + found: list[str] = [] + for decorator in node.decorator_list: + if ( + isinstance(decorator, ast.Call) + and isinstance(decorator.func, ast.Name) + and decorator.func.id == "greened_by" + and decorator.args + and isinstance(decorator.args[0], ast.Constant) + ): + found.append(decorator.args[0].value) + tags[node.name] = found + return tags + + +def test_suite_is_discovered() -> None: + # Guard against a path typo silently passing the static checks below. + assert _TEST_MODULES, f"no integration test modules under {_SUITE_DIR}" + assert _greened_tags_by_test(), "no test functions discovered in the suite" + + +def test_every_live_tag_is_a_known_slice() -> None: + for test_name, found in _greened_tags_by_test().items(): + for tag in found: + assert tag in PR_SLICES, f"{test_name} tagged with unknown slice {tag!r}" + + +def test_no_test_is_greened_by_more_than_once() -> None: + # "covers every test exactly once": a marked test maps to a single PR; an + # unmarked (already-greened) test maps to none. + for test_name, found in _greened_tags_by_test().items(): + assert len(found) <= 1, f"{test_name} carries multiple greened_by markers" diff --git a/tests/unit/webrtc/test_metrics.py b/tests/unit/webrtc/test_metrics.py new file mode 100644 index 000000000..dc5825d1f --- /dev/null +++ b/tests/unit/webrtc/test_metrics.py @@ -0,0 +1,96 @@ +"""Unit tests for the perf metrics helpers (PR1 ``shared/metrics.py``). + +Pins the percentile maths on known inputs and the structured CI output schema +``emit`` writes — the contract CI scrapes. +""" + +from __future__ import annotations + +import dataclasses +import json + +import pytest + +from tests.integration.webrtc.shared import metrics +from tests.integration.webrtc.shared.metrics import Metrics, emit, percentile + + +def test_percentile_on_known_inputs() -> None: + data = [10.0, 20.0, 30.0, 40.0, 50.0] + assert percentile(data, 0) == 10.0 + assert percentile(data, 50) == 30.0 # median + assert percentile(data, 100) == 50.0 + assert percentile(data, 25) == 20.0 + + +def test_percentile_linear_interpolation_between_ranks() -> None: + # rank = (4-1)*0.5 = 1.5 -> halfway between ordered[1]=2 and ordered[2]=3. + assert percentile([1.0, 2.0, 3.0, 4.0], 50) == 2.5 + + +def test_percentile_is_order_independent() -> None: + assert percentile([50.0, 10.0, 40.0, 20.0, 30.0], 50) == 30.0 + + +def test_percentile_single_sample_returns_that_sample() -> None: + assert percentile([7.5], 95) == 7.5 + + +def test_percentile_of_empty_raises() -> None: + with pytest.raises(ValueError): + percentile([], 50) + + +def test_emit_writes_the_full_schema( + monkeypatch: pytest.MonkeyPatch, tmp_path, capsys +) -> None: + out = tmp_path / "perf.json" + monkeypatch.setenv("NEURACORE_WEBRTC_PERF_OUT", str(out)) + + sample = Metrics( + connect_ms=57.0, + reneg_add_ms=4.7, + reneg_remove_ms=5.6, + dc_add_ms=2.2, + g2g_p50_ms=80.0, + g2g_p95_ms=150.0, + delivered_fps=42.0, + drop_rate=0.06, + ) + payload = emit(sample) + + # Every field of the dataclass is a schema key, present in all three sinks: + # the returned payload, the file, and the stderr line. + expected_keys = {f.name for f in dataclasses.fields(Metrics)} + assert expected_keys == { + "connect_ms", + "reneg_add_ms", + "reneg_remove_ms", + "dc_add_ms", + "g2g_p50_ms", + "g2g_p95_ms", + "delivered_fps", + "drop_rate", + } + assert set(json.loads(payload)) == expected_keys + assert set(json.loads(out.read_text())) == expected_keys + assert "[neuracore-webrtc-perf]" in capsys.readouterr().err + + +def test_emit_null_fields_still_pin_the_schema( + monkeypatch: pytest.MonkeyPatch, capsys +) -> None: + # An all-None Metrics (nothing measured this run) still emits every key. + monkeypatch.delenv("NEURACORE_WEBRTC_PERF_OUT", raising=False) + payload = json.loads(emit(Metrics())) + assert set(payload) == {f.name for f in dataclasses.fields(Metrics)} + assert all(value is None for value in payload.values()) + + +def test_slo_constants_are_present() -> None: + # The SLO thresholds the perf suite asserts against are part of the contract. + assert metrics.CONNECT_MS_P95 == 500.0 + assert metrics.RENEG_ADD_MS_P95 == 300.0 + assert metrics.RENEG_REMOVE_MS_P95 == 300.0 + assert metrics.DC_ADD_MS_P95 == 300.0 + assert metrics.MIN_DELIVERED_FPS == 30.0 diff --git a/tests/unit/webrtc/test_native_broadcast_provider.py b/tests/unit/webrtc/test_native_broadcast_provider.py new file mode 100644 index 000000000..b4c86f0f5 --- /dev/null +++ b/tests/unit/webrtc/test_native_broadcast_provider.py @@ -0,0 +1,397 @@ +"""Peer-free unit tests for the web producer wiring (PR8). + +These exercise :mod:`native_broadcast_provider` against a fake, peer-free +producer (a stand-in for the native ``Broadcaster``) so the event mapping, +add/remove/reconnect consumer lifecycle and the Chrome cname munge can be +checked without a real WebRTC connection. +""" + +from __future__ import annotations + +import json +import os + +import pytest +from neuracore_types import MessageType + +from neuracore.core.streaming.p2p.provider.native_broadcast_provider import ( + CHROME_SDP_ENV, + NativeBroadcastProvider, + inbound_candidate, + needs_reconnect, + outbound_signal, +) + + +class FakeProducer: + """Records calls and replays a scripted ``drain_events`` queue.""" + + def __init__(self) -> None: + self.added: list[str] = [] + self.removed: list[str] = [] + self.answers: list[tuple[str, str]] = [] + self.candidates: list[tuple[str, str, str | None]] = [] + self.tracks: list[str] = [] + self.frames: list[tuple[str, object]] = [] + self.closed = False + self._pending: list[dict] = [] + # Reference model of the native Broadcaster's data-channel fan-out: which + # consumers are live, the registry of channels every consumer gets, and + # each consumer's received messages keyed by label. This faithfully mirrors + # the Rust contract (open on add, bootstrap for late joiners, tear down on + # leave) so the peer-free tests exercise the real semantics. + self._consumers: set[str] = set() + self._registry: list[tuple[str, str]] = [] + self.consumer_channels: dict[str, dict[str, list[str]]] = {} + + def queue(self, *events: dict) -> None: + self._pending.extend(events) + + # native Broadcaster API ------------------------------------------------ + def add_consumer(self, consumer_id: str) -> None: + self.added.append(consumer_id) + self._consumers.add(consumer_id) + # A late joiner gets every already-registered channel at bootstrap. + self.consumer_channels[consumer_id] = { + label: [] for label, _kind in self._registry + } + + def remove_consumer(self, consumer_id: str) -> None: + self.removed.append(consumer_id) + self._consumers.discard(consumer_id) + # A leaving consumer's channels tear down with it (no registry leak). + self.consumer_channels.pop(consumer_id, None) + + def set_remote_answer(self, consumer_id: str, sdp: str) -> None: + self.answers.append((consumer_id, sdp)) + + def add_remote_candidate( + self, consumer_id: str, candidate: str, mid: str | None + ) -> None: + self.candidates.append((consumer_id, candidate, mid)) + + def add_video_track(self, track_id: str) -> None: + self.tracks.append(track_id) + + def remove_video_track(self, track_id: str) -> None: + self.tracks.remove(track_id) + + def submit_frame(self, track_id: str, frame: object) -> None: + self.frames.append((track_id, frame)) + + def add_data_channel(self, label: str, kind: str) -> None: + self._registry.append((label, kind)) + # Opened on every current consumer; future consumers get it at bootstrap. + for channels in self.consumer_channels.values(): + channels.setdefault(label, []) + + def send_json(self, label: str, payload: str) -> None: + # Fans to every consumer that carries the label. + for channels in self.consumer_channels.values(): + if label in channels: + channels[label].append(payload) + + def drain_events(self) -> list[dict]: + drained, self._pending = self._pending, [] + return drained + + def close(self) -> None: + self.closed = True + + +def make_provider(*, browser_facing: bool = True): + producer = FakeProducer() + sent: list[tuple] = [] + provider = NativeBroadcastProvider( + producer, + lambda cid, rid, mt, data: sent.append((cid, rid, mt, data)), + browser_facing=browser_facing, + ) + return provider, producer, sent + + +# --- pure helpers ---------------------------------------------------------- + + +def test_outbound_signal_maps_offer_and_candidate(): + offer = outbound_signal({ + "kind": "on_local_description", + "sdp_type": "offer", + "sdp": "v=0...", + "consumer_id": "c1", + }) + assert offer.consumer_id == "c1" + assert offer.message_type == MessageType.SDP_OFFER + assert offer.data == "v=0..." + + cand = outbound_signal({ + "kind": "on_local_candidate", + "candidate": "candidate:1 ...", + "mid": "0", + "consumer_id": "c1", + }) + assert cand.message_type == MessageType.ICE_CANDIDATE + payload = json.loads(cand.data) + assert payload["candidate"] == "candidate:1 ..." + assert payload["sdpMid"] == "0" + + +def test_outbound_signal_ignores_answers_and_untagged_events(): + # The producer is the offerer; an answer sdp_type is never sent outbound. + assert ( + outbound_signal({ + "kind": "on_local_description", + "sdp_type": "answer", + "sdp": "x", + "consumer_id": "c1", + }) + is None + ) + # A shared-encode event without a consumer_id is not deliverable. + assert outbound_signal({"kind": "on_local_candidate", "candidate": "x"}) is None + assert outbound_signal({"kind": "on_state", "state": "connected"}) is None + + +def test_inbound_candidate_parses_browser_payload(): + data = json.dumps({ + "candidate": "candidate:2 ...", + "sdpMid": "1", + "sdpMLineIndex": 1, + }) + assert inbound_candidate(data) == ("candidate:2 ...", "1") + + +def test_needs_reconnect_only_for_connection_errors(): + assert ( + needs_reconnect({ + "kind": "on_error", + "where": "connection", + "consumer_id": "c1", + }) + == "c1" + ) + assert ( + needs_reconnect({"kind": "on_error", "where": "encode", "consumer_id": "c1"}) + is None + ) + assert needs_reconnect({"kind": "on_state", "state": "failed"}) is None + + +# --- cname munge ----------------------------------------------------------- + + +def test_browser_session_enables_cname_munge(monkeypatch): + monkeypatch.delenv(CHROME_SDP_ENV, raising=False) + make_provider(browser_facing=True) + assert os.environ.get(CHROME_SDP_ENV) == "1" + + +def test_loopback_session_leaves_cname_munge_off(monkeypatch): + monkeypatch.delenv(CHROME_SDP_ENV, raising=False) + make_provider(browser_facing=False) + assert CHROME_SDP_ENV not in os.environ + + +# --- consumer lifecycle ---------------------------------------------------- + + +def test_add_and_remove_consumer(): + provider, producer, _ = make_provider() + provider.add_consumer("c1", "stream-1") + provider.add_consumer("c1", "stream-1") # idempotent + assert producer.added == ["c1"] + + provider.remove_consumer("c1") + provider.remove_consumer("c1") # idempotent + assert producer.removed == ["c1"] + + +def test_reconnect_is_remove_then_readd(): + provider, producer, _ = make_provider() + provider.add_consumer("c1", "stream-1") + provider.reconnect_consumer("c1") + assert producer.removed == ["c1"] + assert producer.added == ["c1", "c1"] + + +def test_reconnect_unknown_consumer_is_noop(): + provider, producer, _ = make_provider() + provider.reconnect_consumer("ghost") + assert producer.removed == [] + + +# --- inbound signaling ----------------------------------------------------- + + +def test_on_answer_and_candidate_feed_the_producer(): + provider, producer, _ = make_provider() + provider.add_consumer("c1", "stream-1") + provider.on_answer("c1", "answer-sdp") + provider.on_ice_candidate( + "c1", json.dumps({"candidate": "candidate:9 ...", "sdpMid": "2"}) + ) + assert producer.answers == [("c1", "answer-sdp")] + assert producer.candidates == [("c1", "candidate:9 ...", "2")] + + +def test_inbound_for_unknown_consumer_is_ignored(): + provider, producer, _ = make_provider() + provider.on_answer("ghost", "sdp") + provider.on_ice_candidate("ghost", json.dumps({"candidate": "x", "sdpMid": "0"})) + assert producer.answers == [] + assert producer.candidates == [] + + +# --- the pump -------------------------------------------------------------- + + +def test_pump_routes_offer_to_the_right_browser_transport(): + provider, producer, sent = make_provider() + provider.add_consumer("c1", "stream-1") + producer.queue( + { + "kind": "on_local_description", + "sdp_type": "offer", + "sdp": "offer-sdp", + "consumer_id": "c1", + }, + { + "kind": "on_local_candidate", + "candidate": "candidate:1 ...", + "mid": "0", + "consumer_id": "c1", + }, + ) + provider.pump_once() + + assert sent[0] == ("c1", "stream-1", MessageType.SDP_OFFER, "offer-sdp") + assert sent[1][0:3] == ("c1", "stream-1", MessageType.ICE_CANDIDATE) + + +def test_pump_handles_reconnect_error_by_readding_consumer(): + provider, producer, sent = make_provider() + provider.add_consumer("c1", "stream-1") + producer.queue({ + "kind": "on_error", + "where": "connection", + "consumer_id": "c1", + "detail": "dropped", + }) + provider.pump_once() + + assert producer.removed == ["c1"] + assert producer.added == ["c1", "c1"] + assert sent == [] # an error is not forwarded as a signaling message + + +def test_pump_drops_signal_for_an_unknown_consumer(): + provider, producer, sent = make_provider() + producer.queue({ + "kind": "on_local_description", + "sdp_type": "offer", + "sdp": "x", + "consumer_id": "ghost", + }) + provider.pump_once() + assert sent == [] + + +# --- media + close --------------------------------------------------------- + + +def test_video_track_registration_and_close(): + provider, producer, _ = make_provider() + provider.add_video_track("cam0") + provider.add_video_track("cam0") # idempotent + provider.submit_frame("cam0", object()) + assert producer.tracks == ["cam0"] + assert len(producer.frames) == 1 + + provider.close() + assert producer.closed is True + + +# --- data-channel fan-out -------------------------------------------------- + + +def test_add_data_channel_fans_to_all_current_consumers(): + provider, producer, _ = make_provider() + provider.add_consumer("a", "stream-a") + provider.add_consumer("b", "stream-b") + + provider.add_data_channel("joints") + + # Opened on every current consumer. + assert "joints" in producer.consumer_channels["a"] + assert "joints" in producer.consumer_channels["b"] + + +def test_add_data_channel_is_idempotent_per_label(): + provider, producer, _ = make_provider() + provider.add_consumer("a", "stream-a") + provider.add_data_channel("joints") + provider.add_data_channel("joints") # second add is a no-op + assert producer._registry == [("joints", "reliable")] + + +def test_send_json_fans_to_every_consumer(): + provider, producer, _ = make_provider() + provider.add_consumer("a", "stream-a") + provider.add_consumer("b", "stream-b") + provider.add_data_channel("json") + + provider.send_json("json", '{"i": 1}') + + assert producer.consumer_channels["a"]["json"] == ['{"i": 1}'] + assert producer.consumer_channels["b"]["json"] == ['{"i": 1}'] + + +def test_a_consumer_added_after_add_data_channel_still_gets_the_channel(): + provider, producer, _ = make_provider() + provider.add_data_channel("json") + # The browser joins only afterwards. + provider.add_consumer("late", "stream-late") + + provider.send_json("json", '{"i": 7}') + + # The late joiner got the channel at bootstrap and receives the message. + assert producer.consumer_channels["late"]["json"] == ['{"i": 7}'] + + +def test_a_leaving_consumer_tears_down_its_channels_no_registry_leak(): + provider, producer, _ = make_provider() + provider.add_consumer("a", "stream-a") + provider.add_consumer("b", "stream-b") + provider.add_data_channel("json") + + provider.remove_consumer("a") + provider.send_json("json", '{"i": 2}') + + # 'a' is gone with its channels; only 'b' receives. + assert "a" not in producer.consumer_channels + assert producer.consumer_channels["b"]["json"] == ['{"i": 2}'] + + +def test_recording_bridge_routes_log_calls_to_broadcaster_send_json(): + """The RecordingContext bridge drives ``send_json`` over the Broadcaster.""" + from neuracore.core.streaming.p2p.recording_bridge import WebrtcRecordingBridge + + provider, producer, _ = make_provider() + provider.add_consumer("a", "stream-a") + bridge = WebrtcRecordingBridge(provider) + + bridge.log_joints("joint_positions", timestamp=1.0, items=[("j0", 0.5)]) + bridge.log_json("rgb", "cam", b"{}", timestamp=2.0) + + # Both log entry points opened their channel and reached the consumer. + assert json.loads(producer.consumer_channels["a"]["joint_positions"][0]) == { + "type": "joints", + "data_type": "joint_positions", + "timestamp": 1.0, + "values": {"j0": 0.5}, + } + assert json.loads(producer.consumer_channels["a"]["rgb/cam"][0])["type"] == "json" + + +if __name__ == "__main__": + raise SystemExit(pytest.main([__file__, "-q"])) diff --git a/tests/unit/webrtc/test_recording_bridge.py b/tests/unit/webrtc/test_recording_bridge.py new file mode 100644 index 000000000..283470074 --- /dev/null +++ b/tests/unit/webrtc/test_recording_bridge.py @@ -0,0 +1,106 @@ +"""Unit tests for ``WebrtcRecordingBridge`` (PR2). + +The bridge mirrors ``RecordingContext.log_json`` / ``log_joints`` and forwards +each call to ``Producer.send_json`` over a lazily-opened reliable data channel. +It is duck-typed against the native producer, so a tiny spy implementing +``add_data_channel`` / ``send_json`` stands in for it — no native module, no peer +connection. These tests pin the channel-label derivation, the lazy open, and the +JSON forwarded onto the one send path. +""" + +from __future__ import annotations + +import json + +import pytest + +from neuracore.core.streaming.p2p.recording_bridge import WebrtcRecordingBridge + + +class SpyProducer: + """The slice of the native ``Producer`` the bridge touches, recorded.""" + + def __init__(self) -> None: + self.opened: list[tuple[str, str]] = [] + self.sent: list[tuple[str, str]] = [] + + def add_data_channel(self, label: str, kind: str) -> None: + self.opened.append((label, kind)) + + def send_json(self, label: str, payload: str) -> None: + self.sent.append((label, payload)) + + +def test_log_json_derives_the_data_type_slash_name_channel() -> None: + producer = SpyProducer() + bridge = WebrtcRecordingBridge(producer) + + bridge.log_json("rgb", "wrist_cam", b'{"frame": 1}', timestamp=12.5) + + # Label is data_type/name; the channel is opened reliable on first use. + assert producer.opened == [("rgb/wrist_cam", "reliable")] + assert len(producer.sent) == 1 + label, payload = producer.sent[0] + assert label == "rgb/wrist_cam" + assert json.loads(payload) == { + "type": "json", + "data_type": "rgb", + "name": "wrist_cam", + "timestamp": 12.5, + "payload": '{"frame": 1}', + } + + +def test_log_joints_derives_the_data_type_channel() -> None: + producer = SpyProducer() + bridge = WebrtcRecordingBridge(producer) + + bridge.log_joints( + "joint_positions", timestamp=3.0, items=[("j0", 1.0), ("j1", 2.5)] + ) + + assert producer.opened == [("joint_positions", "reliable")] + label, payload = producer.sent[0] + assert label == "joint_positions" + assert json.loads(payload) == { + "type": "joints", + "data_type": "joint_positions", + "timestamp": 3.0, + "values": {"j0": 1.0, "j1": 2.5}, + } + + +def test_channels_open_lazily_and_exactly_once_per_label() -> None: + producer = SpyProducer() + bridge = WebrtcRecordingBridge(producer) + + bridge.log_json("rgb", "cam", b"{}", timestamp=0.0) + bridge.log_json("rgb", "cam", b"{}", timestamp=1.0) + bridge.log_joints("joints", timestamp=0.0, items=[("j0", 0.0)]) + bridge.log_joints("joints", timestamp=1.0, items=[("j0", 1.0)]) + + # One open per distinct stream, despite four log calls. + assert producer.opened == [("rgb/cam", "reliable"), ("joints", "reliable")] + assert len(producer.sent) == 4 + + +def test_empty_joint_batch_sends_nothing() -> None: + producer = SpyProducer() + bridge = WebrtcRecordingBridge(producer) + + bridge.log_joints("joints", timestamp=0.0, items=[]) + + assert producer.opened == [] + assert producer.sent == [] + + +def test_control_label_is_reserved_and_rejected() -> None: + producer = SpyProducer() + bridge = WebrtcRecordingBridge(producer) + + # "control" carries the manifest; the recording path must never use it as a + # stream label. log_joints' label is the data_type, so this hits the guard. + with pytest.raises(ValueError): + bridge.log_joints("control", timestamp=0.0, items=[("j0", 0.0)]) + assert producer.opened == [] + assert producer.sent == [] diff --git a/tests/unit/webrtc/test_webrtc_selection.py b/tests/unit/webrtc/test_webrtc_selection.py new file mode 100644 index 000000000..63e0ab914 --- /dev/null +++ b/tests/unit/webrtc/test_webrtc_selection.py @@ -0,0 +1,79 @@ +"""Unit tests for the feature-flag selection and native loader (PR0). + +``rust_webrtc_enabled`` gates the Rust stack on ``NCD_RUST_WEBRTC``; +``load_native`` imports the compiled extension and, when it is absent, raises a +``RuntimeError`` carrying a build/fallback hint. Neither path needs the native +module to be built — the import is monkeypatched. +""" + +from __future__ import annotations + +from types import ModuleType + +import pytest + +from neuracore.core.streaming.p2p import webrtc_selection + + +@pytest.mark.parametrize("value", ["1", "true", "True", "YES", "y", " yes "]) +def test_rust_webrtc_enabled_truthy_values( + monkeypatch: pytest.MonkeyPatch, value: str +) -> None: + monkeypatch.setenv("NCD_RUST_WEBRTC", value) + assert webrtc_selection.rust_webrtc_enabled() is True + + +@pytest.mark.parametrize("value", ["0", "false", "no", "n", "", "off", "2"]) +def test_rust_webrtc_enabled_falsy_values( + monkeypatch: pytest.MonkeyPatch, value: str +) -> None: + monkeypatch.setenv("NCD_RUST_WEBRTC", value) + assert webrtc_selection.rust_webrtc_enabled() is False + + +def test_rust_webrtc_enabled_unset_is_false(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.delenv("NCD_RUST_WEBRTC", raising=False) + assert webrtc_selection.rust_webrtc_enabled() is False + + +def test_load_native_raises_a_hinted_runtime_error_when_absent( + monkeypatch: pytest.MonkeyPatch, +) -> None: + # Force a fresh import attempt and make the extension import fail. + monkeypatch.setattr(webrtc_selection, "_NATIVE_MODULE", None) + + def _missing(name: str) -> ModuleType: + raise ImportError(f"no module named {name!r}") + + monkeypatch.setattr(webrtc_selection, "import_module", _missing) + + with pytest.raises(RuntimeError) as excinfo: + webrtc_selection.load_native() + + # The hint must point at the build script and the aiortc fallback. + message = str(excinfo.value) + assert "build_wheel_artefacts.sh" in message + assert "NCD_RUST_WEBRTC" in message + # The original ImportError is chained for debuggability. + assert isinstance(excinfo.value.__cause__, ImportError) + + +def test_load_native_caches_and_returns_the_module( + monkeypatch: pytest.MonkeyPatch, +) -> None: + monkeypatch.setattr(webrtc_selection, "_NATIVE_MODULE", None) + sentinel = ModuleType("fake_native_webrtc") + calls: list[str] = [] + + def _import(name: str) -> ModuleType: + calls.append(name) + return sentinel + + monkeypatch.setattr(webrtc_selection, "import_module", _import) + + first = webrtc_selection.load_native() + second = webrtc_selection.load_native() + + assert first is sentinel and second is sentinel + # Imported once (the documented dotted path), then served from the cache. + assert calls == ["neuracore.core.streaming.p2p._native_webrtc"]