diff --git a/codex-rs/core-api/src/lib.rs b/codex-rs/core-api/src/lib.rs index e87ee82f309..81cddce050b 100644 --- a/codex-rs/core-api/src/lib.rs +++ b/codex-rs/core-api/src/lib.rs @@ -45,7 +45,14 @@ pub use codex_core::resolve_installation_id; pub use codex_core::skills::SkillsManager; pub use codex_core::thread_store_from_config; pub use codex_exec_server::EnvironmentManager; +pub use codex_exec_server::ExecServerError; pub use codex_exec_server::ExecServerRuntimePaths; +pub use codex_exec_server::NoiseChannelIdentity; +pub use codex_exec_server::NoiseChannelPublicKey; +pub use codex_exec_server::NoiseRendezvousConnectArgs; +pub use codex_exec_server::NoiseRendezvousConnectBundle; +pub use codex_exec_server::NoiseRendezvousConnectProvider; +pub use codex_exec_server::SharedNoiseRendezvousConnectProvider; pub use codex_extension_api::empty_extension_registry; pub use codex_features::Feature; pub use codex_features::Features; diff --git a/codex-rs/exec-server/Cargo.toml b/codex-rs/exec-server/Cargo.toml index 7ed185baaf2..8d2d5bfc501 100644 --- a/codex-rs/exec-server/Cargo.toml +++ b/codex-rs/exec-server/Cargo.toml @@ -48,6 +48,7 @@ tokio = { workspace = true, features = [ tokio-util = { workspace = true, features = ["rt"] } tokio-tungstenite = { workspace = true } tracing = { workspace = true } +url = { workspace = true } uuid = { workspace = true, features = ["v4"] } [dev-dependencies] diff --git a/codex-rs/exec-server/src/client.rs b/codex-rs/exec-server/src/client.rs index b6c56e9f5e2..34b659e861c 100644 --- a/codex-rs/exec-server/src/client.rs +++ b/codex-rs/exec-server/src/client.rs @@ -23,6 +23,8 @@ use crate::ProcessId; use crate::client_api::ExecServerClientConnectOptions; use crate::client_api::ExecServerTransportParams; use crate::client_api::HttpClient; +use crate::client_api::NoiseRendezvousConnectArgs; +use crate::client_api::NoiseRendezvousConnectBundle; use crate::client_api::RemoteExecServerConnectArgs; use crate::client_api::StdioExecServerConnectArgs; use crate::connection::JsonRpcConnection; @@ -115,6 +117,16 @@ impl From for ExecServerClientConnectOptions { } } +impl From for ExecServerClientConnectOptions { + fn from(value: NoiseRendezvousConnectArgs) -> Self { + Self { + client_name: value.client_name, + initialize_timeout: value.initialize_timeout, + resume_session_id: value.resume_session_id, + } + } +} + impl From for ExecServerClientConnectOptions { fn from(value: StdioExecServerConnectArgs) -> Self { Self { @@ -137,6 +149,23 @@ impl RemoteExecServerConnectArgs { } } +impl NoiseRendezvousConnectArgs { + pub fn new( + bundle: NoiseRendezvousConnectBundle, + harness_identity: crate::NoiseChannelIdentity, + client_name: String, + ) -> Self { + Self { + bundle, + harness_identity, + client_name, + connect_timeout: CONNECT_TIMEOUT, + initialize_timeout: INITIALIZE_TIMEOUT, + resume_session_id: None, + } + } +} + pub(crate) struct SessionState { wake_tx: watch::Sender, events: ExecProcessEventLog, @@ -231,6 +260,7 @@ impl LazyRemoteExecServerClient { if matches!( &self.transport_params, ExecServerTransportParams::WebSocketUrl { .. } + | ExecServerTransportParams::NoiseRendezvous { .. } ) => { ExecServerClient::connect_for_transport(self.transport_params.clone()).await? @@ -317,6 +347,8 @@ pub enum ExecServerError { EnvironmentRegistryAuth(String), #[error("environment registry request failed: {0}")] EnvironmentRegistryRequest(#[from] reqwest::Error), + #[error(transparent)] + NoiseChannel(#[from] crate::NoiseChannelError), } impl ExecServerClient { diff --git a/codex-rs/exec-server/src/client_api.rs b/codex-rs/exec-server/src/client_api.rs index 899863723fe..a6251a87579 100644 --- a/codex-rs/exec-server/src/client_api.rs +++ b/codex-rs/exec-server/src/client_api.rs @@ -1,5 +1,6 @@ use std::collections::HashMap; use std::path::PathBuf; +use std::sync::Arc; use std::time::Duration; use futures::future::BoxFuture; @@ -8,6 +9,8 @@ use crate::ExecServerError; use crate::HttpRequestParams; use crate::HttpRequestResponse; use crate::HttpResponseBodyStream; +use crate::NoiseChannelIdentity; +use crate::NoiseChannelPublicKey; pub(crate) const DEFAULT_REMOTE_EXEC_SERVER_CONNECT_TIMEOUT: Duration = Duration::from_secs(10); pub(crate) const DEFAULT_REMOTE_EXEC_SERVER_INITIALIZE_TIMEOUT: Duration = Duration::from_secs(10); @@ -30,6 +33,67 @@ pub struct RemoteExecServerConnectArgs { pub resume_session_id: Option, } +/// Registry-authorized material for one Noise rendezvous connection attempt. +pub struct NoiseRendezvousConnectBundle { + pub websocket_url: String, + pub environment_id: String, + pub executor_registration_id: String, + pub executor_public_key: NoiseChannelPublicKey, + pub harness_key_authorization: String, +} + +impl std::fmt::Debug for NoiseRendezvousConnectBundle { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("NoiseRendezvousConnectBundle") + .field( + "websocket_url", + &redacted_websocket_url(&self.websocket_url), + ) + .field("environment_id", &self.environment_id) + .field("executor_registration_id", &self.executor_registration_id) + .field("executor_public_key", &self.executor_public_key) + .field("harness_key_authorization", &"") + .finish() + } +} + +/// Connection arguments for an authenticated Noise rendezvous exec-server. +pub struct NoiseRendezvousConnectArgs { + pub bundle: NoiseRendezvousConnectBundle, + pub harness_identity: NoiseChannelIdentity, + pub client_name: String, + pub connect_timeout: Duration, + pub initialize_timeout: Duration, + pub resume_session_id: Option, +} + +impl std::fmt::Debug for NoiseRendezvousConnectArgs { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("NoiseRendezvousConnectArgs") + .field("bundle", &self.bundle) + .field("harness_identity", &"") + .field("client_name", &self.client_name) + .field("connect_timeout", &self.connect_timeout) + .field("initialize_timeout", &self.initialize_timeout) + .field("resume_session_id", &self.resume_session_id) + .finish() + } +} + +/// Supplies fresh registry-authorized material for Noise rendezvous connections. +/// +/// Implementations must preserve one endpoint-local harness identity while +/// refreshing short-lived registry material for every physical connection attempt. +pub trait NoiseRendezvousConnectProvider: Send + Sync { + /// Environment ID this provider is authorized to connect to. + fn environment_id(&self) -> &str; + + /// Returns a fresh atomic bundle for one physical connection attempt. + fn connect_args(&self) -> BoxFuture<'_, Result>; +} + +pub type SharedNoiseRendezvousConnectProvider = Arc; + /// Stdio connection arguments for a command-backed exec-server. #[derive(Debug, Clone, PartialEq, Eq)] pub(crate) struct StdioExecServerConnectArgs { @@ -49,13 +113,16 @@ pub(crate) struct StdioExecServerCommand { } /// Parameters used to connect to a remote exec-server environment. -#[derive(Debug, Clone, PartialEq, Eq)] +#[derive(Clone)] pub(crate) enum ExecServerTransportParams { WebSocketUrl { websocket_url: String, connect_timeout: Duration, initialize_timeout: Duration, }, + NoiseRendezvous { + provider: SharedNoiseRendezvousConnectProvider, + }, #[allow(dead_code)] StdioCommand { command: StdioExecServerCommand, @@ -63,6 +130,35 @@ pub(crate) enum ExecServerTransportParams { }, } +impl std::fmt::Debug for ExecServerTransportParams { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::WebSocketUrl { + websocket_url, + connect_timeout, + initialize_timeout, + } => f + .debug_struct("WebSocketUrl") + .field("websocket_url", websocket_url) + .field("connect_timeout", connect_timeout) + .field("initialize_timeout", initialize_timeout) + .finish(), + Self::NoiseRendezvous { provider } => f + .debug_struct("NoiseRendezvous") + .field("environment_id", &provider.environment_id()) + .finish(), + Self::StdioCommand { + command, + initialize_timeout, + } => f + .debug_struct("StdioCommand") + .field("command", command) + .field("initialize_timeout", initialize_timeout) + .finish(), + } + } +} + impl ExecServerTransportParams { pub(crate) fn websocket_url(websocket_url: String) -> Self { Self::WebSocketUrl { @@ -73,6 +169,17 @@ impl ExecServerTransportParams { } } +pub(crate) fn redacted_websocket_url(websocket_url: &str) -> String { + match url::Url::parse(websocket_url) { + Ok(mut url) => { + url.set_query(None); + url.set_fragment(None); + url.to_string() + } + Err(_) => "".to_string(), + } +} + /// Sends HTTP requests through a runtime-selected transport. /// /// This is the HTTP capability counterpart to [`crate::ExecBackend`]. Callers diff --git a/codex-rs/exec-server/src/client_transport.rs b/codex-rs/exec-server/src/client_transport.rs index 4bdc09a80ed..ade4c0689db 100644 --- a/codex-rs/exec-server/src/client_transport.rs +++ b/codex-rs/exec-server/src/client_transport.rs @@ -4,6 +4,7 @@ use tokio::io::BufReader; use tokio::process::Command; use tokio::time::timeout; use tokio_tungstenite::connect_async; +use tokio_tungstenite::connect_async_with_config; use tracing::debug; use tracing::warn; @@ -11,10 +12,15 @@ use codex_utils_rustls_provider::ensure_rustls_crypto_provider; use crate::ExecServerClient; use crate::ExecServerError; +use crate::client_api::NoiseRendezvousConnectArgs; +use crate::client_api::NoiseRendezvousConnectBundle; use crate::client_api::RemoteExecServerConnectArgs; use crate::client_api::StdioExecServerCommand; use crate::client_api::StdioExecServerConnectArgs; +use crate::client_api::redacted_websocket_url; use crate::connection::JsonRpcConnection; +use crate::noise_relay::noise_harness_connection_from_websocket; +use crate::noise_relay::noise_relay_websocket_config; use crate::relay::harness_connection_from_websocket; const ENVIRONMENT_CLIENT_NAME: &str = "codex-environment"; @@ -38,6 +44,15 @@ impl ExecServerClient { }) .await } + crate::client_api::ExecServerTransportParams::NoiseRendezvous { provider } => { + let args = provider.connect_args().await?; + if args.bundle.environment_id != provider.environment_id() { + return Err(ExecServerError::Protocol( + "Noise rendezvous provider returned a different environment id".to_string(), + )); + } + Self::connect_noise_rendezvous(args).await + } crate::client_api::ExecServerTransportParams::StdioCommand { command, initialize_timeout, @@ -79,6 +94,69 @@ impl ExecServerClient { Self::connect(connection, args.into()).await } + pub async fn connect_noise_rendezvous( + args: NoiseRendezvousConnectArgs, + ) -> Result { + ensure_rustls_crypto_provider(); + // This connect call owns the complete registry-issued bundle. Move each + // sensitive value into the transport task exactly once rather than + // leaving extra copies of the harness authorization or endpoint identity + // alive in `args` after the handshake starts. + let NoiseRendezvousConnectArgs { + bundle, + harness_identity, + client_name, + connect_timeout, + initialize_timeout, + resume_session_id, + } = args; + let NoiseRendezvousConnectBundle { + websocket_url, + environment_id, + executor_registration_id, + executor_public_key, + harness_key_authorization, + } = bundle; + let diagnostic_url = redacted_websocket_url(&websocket_url); + let (stream, _) = timeout( + connect_timeout, + connect_async_with_config( + websocket_url.as_str(), + Some(noise_relay_websocket_config()), + /*disable_nagle*/ false, + ), + ) + .await + .map_err(|_| ExecServerError::WebSocketConnectTimeout { + url: diagnostic_url.clone(), + timeout: connect_timeout, + })? + .map_err(|source| ExecServerError::WebSocketConnect { + url: diagnostic_url.clone(), + source, + })?; + + let connection_label = format!("Noise exec-server rendezvous websocket {diagnostic_url}"); + let connection = noise_harness_connection_from_websocket( + stream, + connection_label, + environment_id, + executor_registration_id, + harness_identity, + executor_public_key, + harness_key_authorization, + ); + Self::connect( + connection, + crate::client_api::ExecServerClientConnectOptions { + client_name, + initialize_timeout, + resume_session_id, + }, + ) + .await + } + pub(crate) async fn connect_stdio_command( args: StdioExecServerConnectArgs, ) -> Result { diff --git a/codex-rs/exec-server/src/connection.rs b/codex-rs/exec-server/src/connection.rs index b211c504aca..1cace9df4ec 100644 --- a/codex-rs/exec-server/src/connection.rs +++ b/codex-rs/exec-server/src/connection.rs @@ -43,8 +43,11 @@ pub(crate) enum JsonRpcConnectionEvent { } #[derive(Clone)] +/// Describes resources owned by a JSON-RPC connection, not the protection of +/// its byte stream. `External` covers websocket, relay, and Noise connections +/// whose transport task owns no child process for Codex to terminate. pub(crate) enum JsonRpcTransport { - Plain, + External, Stdio { transport: StdioTransport }, } @@ -57,7 +60,7 @@ impl JsonRpcTransport { pub(crate) fn terminate(&self) { match self { - Self::Plain => {} + Self::External => {} Self::Stdio { transport } => transport.terminate(), } } @@ -315,7 +318,7 @@ impl JsonRpcConnection { incoming_rx, disconnected_rx, task_handles: vec![reader_task, writer_task], - transport: JsonRpcTransport::Plain, + transport: JsonRpcTransport::External, } } @@ -452,7 +455,7 @@ impl JsonRpcConnection { incoming_rx, disconnected_rx, task_handles: vec![websocket_task], - transport: JsonRpcTransport::Plain, + transport: JsonRpcTransport::External, } } diff --git a/codex-rs/exec-server/src/environment.rs b/codex-rs/exec-server/src/environment.rs index 55dc031273e..bb99b1f6bcd 100644 --- a/codex-rs/exec-server/src/environment.rs +++ b/codex-rs/exec-server/src/environment.rs @@ -6,6 +6,7 @@ use crate::ExecServerError; use crate::ExecServerRuntimePaths; use crate::ExecutorFileSystem; use crate::HttpClient; +use crate::SharedNoiseRendezvousConnectProvider; use crate::client::LazyRemoteExecServerClient; use crate::client::http_client::ReqwestHttpClient; use crate::client_api::ExecServerTransportParams; @@ -276,6 +277,34 @@ impl EnvironmentManager { .insert(environment_id, Arc::new(environment)); Ok(()) } + + /// Adds or replaces a named remote environment that connects through an + /// authenticated, end-to-end encrypted rendezvous stream. + pub fn upsert_noise_environment( + &self, + environment_id: String, + provider: SharedNoiseRendezvousConnectProvider, + ) -> Result<(), ExecServerError> { + if environment_id.is_empty() { + return Err(ExecServerError::Protocol( + "environment id cannot be empty".to_string(), + )); + } + if environment_id != provider.environment_id() { + return Err(ExecServerError::Protocol( + "Noise environment id does not match connection provider".to_string(), + )); + } + let environment = Environment::remote_with_transport( + ExecServerTransportParams::NoiseRendezvous { provider }, + self.local_runtime_paths.clone(), + ); + self.environments + .write() + .unwrap_or_else(std::sync::PoisonError::into_inner) + .insert(environment_id, Arc::new(environment)); + Ok(()) + } } /// Concrete execution/filesystem environment selected for a session. @@ -382,6 +411,7 @@ impl Environment { websocket_url: exec_server_url, .. } => Some(exec_server_url.clone()), + ExecServerTransportParams::NoiseRendezvous { .. } => None, ExecServerTransportParams::StdioCommand { .. } => None, }; let client = LazyRemoteExecServerClient::new(remote_transport.clone()); diff --git a/codex-rs/exec-server/src/environment_toml.rs b/codex-rs/exec-server/src/environment_toml.rs index 26c178b5b8c..daa024b4f58 100644 --- a/codex-rs/exec-server/src/environment_toml.rs +++ b/codex-rs/exec-server/src/environment_toml.rs @@ -48,7 +48,7 @@ struct EnvironmentToml { initialize_timeout_sec: Option, } -#[derive(Clone, Debug, PartialEq, Eq)] +#[derive(Clone, Debug)] struct TomlEnvironmentProvider { default: EnvironmentDefault, include_local: bool, @@ -577,18 +577,26 @@ mod tests { ) .expect("provider"); + let ExecServerTransportParams::StdioCommand { + command, + initialize_timeout, + } = &provider.environments[0].1 + else { + panic!("expected stdio transport"); + }; assert_eq!( - provider.environments[0].1, - ExecServerTransportParams::StdioCommand { - command: StdioExecServerCommand { - program: "ssh".to_string(), - args: Vec::new(), - env: HashMap::new(), - cwd: Some(config_dir.path().join("workspace")), - }, - initialize_timeout: DEFAULT_REMOTE_EXEC_SERVER_INITIALIZE_TIMEOUT, + command, + &StdioExecServerCommand { + program: "ssh".to_string(), + args: Vec::new(), + env: HashMap::new(), + cwd: Some(config_dir.path().join("workspace")), } ); + assert_eq!( + *initialize_timeout, + DEFAULT_REMOTE_EXEC_SERVER_INITIALIZE_TIMEOUT + ); } #[test] @@ -614,26 +622,35 @@ mod tests { }) .expect("provider"); + let ExecServerTransportParams::WebSocketUrl { + websocket_url, + connect_timeout, + initialize_timeout, + } = &provider.environments[0].1 + else { + panic!("expected websocket transport"); + }; + assert_eq!(websocket_url, "ws://127.0.0.1:8765"); + assert_eq!(*connect_timeout, Duration::from_secs(12)); + assert_eq!(*initialize_timeout, Duration::from_secs(34)); + + let ExecServerTransportParams::StdioCommand { + command, + initialize_timeout, + } = &provider.environments[1].1 + else { + panic!("expected stdio transport"); + }; assert_eq!( - provider.environments[0].1, - ExecServerTransportParams::WebSocketUrl { - websocket_url: "ws://127.0.0.1:8765".to_string(), - connect_timeout: Duration::from_secs(12), - initialize_timeout: Duration::from_secs(34), - } - ); - assert_eq!( - provider.environments[1].1, - ExecServerTransportParams::StdioCommand { - command: StdioExecServerCommand { - program: "ssh".to_string(), - args: Vec::new(), - env: HashMap::new(), - cwd: None, - }, - initialize_timeout: Duration::from_secs(56), + command, + &StdioExecServerCommand { + program: "ssh".to_string(), + args: Vec::new(), + env: HashMap::new(), + cwd: None, } ); + assert_eq!(*initialize_timeout, Duration::from_secs(56)); } #[test] diff --git a/codex-rs/exec-server/src/lib.rs b/codex-rs/exec-server/src/lib.rs index b774452441d..19906c1fc61 100644 --- a/codex-rs/exec-server/src/lib.rs +++ b/codex-rs/exec-server/src/lib.rs @@ -13,6 +13,7 @@ mod fs_sandbox; mod local_file_system; mod local_process; mod noise_channel; +mod noise_relay; mod process; mod process_id; mod protocol; @@ -32,7 +33,11 @@ pub use client::http_client::HttpResponseBodyStream; pub use client::http_client::ReqwestHttpClient; pub use client_api::ExecServerClientConnectOptions; pub use client_api::HttpClient; +pub use client_api::NoiseRendezvousConnectArgs; +pub use client_api::NoiseRendezvousConnectBundle; +pub use client_api::NoiseRendezvousConnectProvider; pub use client_api::RemoteExecServerConnectArgs; +pub use client_api::SharedNoiseRendezvousConnectProvider; pub use codex_file_system::CopyOptions; pub use codex_file_system::CreateDirectoryOptions; pub use codex_file_system::ExecutorFileSystem; diff --git a/codex-rs/exec-server/src/noise_relay/environment.rs b/codex-rs/exec-server/src/noise_relay/environment.rs new file mode 100644 index 00000000000..6e8e43eadab --- /dev/null +++ b/codex-rs/exec-server/src/noise_relay/environment.rs @@ -0,0 +1,499 @@ +use std::collections::HashMap; +use std::sync::Arc; +use std::sync::Mutex; +use std::time::Duration; + +use futures::SinkExt; +use futures::StreamExt; +use tokio::io::AsyncRead; +use tokio::io::AsyncWrite; +use tokio::sync::mpsc; +use tokio::sync::watch; +use tokio::task::JoinSet; +use tokio::time::timeout; +use tokio_tungstenite::WebSocketStream; +use tokio_tungstenite::tungstenite::Message; +use tracing::debug; +use tracing::warn; + +use crate::ExecServerError; +use crate::connection::CHANNEL_CAPACITY; +use crate::connection::JsonRpcConnection; +use crate::connection::JsonRpcConnectionEvent; +use crate::connection::JsonRpcTransport; +use crate::noise_channel::NoiseChannelIdentity; +use crate::noise_channel::NoiseChannelPublicKey; +use crate::noise_channel::NoiseTransport; +use crate::noise_channel::PendingResponderHandshake; +use crate::noise_channel::noise_channel_prologue; +use crate::noise_relay::message_framing::JsonRpcMessageDecoder; +use crate::noise_relay::message_framing::NOISE_RECORD_PLAINTEXT_LEN; +use crate::noise_relay::message_framing::frame_jsonrpc_message; +use crate::noise_relay::ordered_ciphertext::OrderedCiphertextFrames; +use crate::noise_relay::take_next_sequence; +use crate::relay::RelayFrameBodyKind; +use crate::relay::decode_relay_message_frame; +use crate::relay::encode_relay_message_frame; +use crate::relay_proto::RelayData; +use crate::relay_proto::RelayMessageFrame; +use crate::server::ConnectionProcessor; + +// This value is already part of the relay wire contract. Keep it stable even +// though the source module now uses the more precise Noise terminology. +const NOISE_RELAY_RESET_REASON: &str = "secure_relay_protocol_error"; +const MAX_ACTIVE_NOISE_RELAY_STREAMS: usize = 128; +const MAX_HARNESS_KEY_AUTHORIZATION_BYTES: usize = 4096; +const MAX_PENDING_HANDSHAKE_VALIDATIONS: usize = 32; +const HARNESS_KEY_VALIDATION_TIMEOUT: Duration = Duration::from_secs(10); + +/// Validates that a Noise-authenticated harness public key is authorized. +/// +/// Implementations must consult an authority independent of rendezvous. The +/// exec-server invokes this after parsing the first IK message and before +/// completing the responder handshake. +pub(crate) trait HarnessKeyValidator: Send + Sync { + fn validate_harness_key( + &self, + harness_public_key: &NoiseChannelPublicKey, + authorization: &str, + ) -> impl std::future::Future> + Send; +} + +/// Serve many authenticated virtual JSON-RPC streams over one executor websocket. +/// +/// Each stream has an independent Noise handshake and transport state. The +/// outer websocket and rendezvous route are treated as untrusted delivery: +/// malformed, unauthorized, or cryptographically invalid streams fail closed +/// without creating a `JsonRpcConnection`. +pub(crate) async fn run_noise_multiplexed_environment( + stream: WebSocketStream, + processor: ConnectionProcessor, + environment_id: String, + executor_registration_id: String, + identity: NoiseChannelIdentity, + validator: V, +) where + S: AsyncRead + AsyncWrite + Unpin + Send + 'static, + V: HarnessKeyValidator + Clone + 'static, +{ + let mut websocket = stream; + let (physical_outgoing_tx, mut physical_outgoing_rx) = + mpsc::channel::>(CHANNEL_CAPACITY); + let mut streams: HashMap = HashMap::new(); + let mut pending_handshakes: HashMap = HashMap::new(); + let mut validation_tasks: JoinSet = JoinSet::new(); + let mut next_validation_id = 0u64; + + loop { + // Keep registry validation out of the main relay loop. A slow or + // malicious authorization request must not block existing streams or + // prevent other handshakes from being received and bounded. + let frame = tokio::select! { + maybe_encoded = physical_outgoing_rx.recv() => { + let Some(encoded) = maybe_encoded else { + break; + }; + if websocket.send(Message::Binary(encoded.into())).await.is_err() { + break; + } + continue; + } + validation_result = validation_tasks.join_next(), if !validation_tasks.is_empty() => { + match validation_result { + Some(Ok(validation_result)) => { + // Stream IDs may be reset and reused while validation + // is in flight. The monotonic validation ID ensures a + // stale task can never complete a newer handshake. + let is_current = pending_handshakes + .get(&validation_result.stream_id) + .is_some_and(|pending| { + pending.validation_id == validation_result.validation_id + }); + if !is_current { + continue; + } + let Some(pending) = + pending_handshakes.remove(&validation_result.stream_id) + else { + continue; + }; + if let Err(error) = validation_result.result { + warn!("Noise relay harness key validation failed: {error}"); + send_reset(&physical_outgoing_tx, validation_result.stream_id).await; + continue; + } + if streams.len() >= MAX_ACTIVE_NOISE_RELAY_STREAMS { + warn!("Noise relay has too many active streams"); + send_reset(&physical_outgoing_tx, validation_result.stream_id).await; + continue; + } + + // This is the only point where the responder completes + // IK and exposes a JSON-RPC stream: Noise authenticated + // the harness key and the registry authorized it. + let (transport, response) = match pending.handshake.complete() { + Ok(completed) => completed, + Err(error) => { + warn!("failed to complete Noise relay handshake: {error}"); + send_reset(&physical_outgoing_tx, validation_result.stream_id).await; + continue; + } + }; + let response = RelayMessageFrame::handshake( + validation_result.stream_id.clone(), + response, + ); + if physical_outgoing_tx + .send(encode_relay_message_frame(&response)) + .await + .is_err() + { + break; + } + streams.insert( + validation_result.stream_id.clone(), + spawn_noise_virtual_stream( + validation_result.stream_id, + processor.clone(), + physical_outgoing_tx.clone(), + transport, + ), + ); + } + Some(Err(error)) => { + warn!("Noise relay harness key validation task failed: {error}"); + let stream_ids = pending_handshakes.keys().cloned().collect::>(); + pending_handshakes.clear(); + for stream_id in stream_ids { + send_reset(&physical_outgoing_tx, stream_id).await; + } + } + None => {} + } + continue; + } + incoming_message = websocket.next() => match incoming_message { + Some(Ok(Message::Binary(payload))) => match decode_relay_message_frame(payload.as_ref()) { + Ok(frame) => frame, + Err(error) => { + warn!("dropping malformed Noise relay frame from harness: {error}"); + continue; + } + }, + Some(Ok(Message::Close(_))) | None => break, + Some(Ok(Message::Ping(_) | Message::Pong(_) | Message::Frame(_))) => continue, + Some(Ok(Message::Text(_))) => { + warn!("dropping non-binary Noise relay frame from harness"); + continue; + } + Some(Err(error)) => { + debug!("Noise multiplexed environment websocket read failed: {error}"); + break; + } + } + }; + + let kind = match frame.validate() { + Ok(kind) => kind, + Err(error) => { + warn!("dropping invalid Noise relay frame: {error}"); + continue; + } + }; + let stream_id = frame.stream_id.clone(); + match kind { + RelayFrameBodyKind::Handshake => { + // Bound all pre-authentication state before doing expensive + // hybrid cryptography or starting an external validation. + if streams.contains_key(&stream_id) || pending_handshakes.contains_key(&stream_id) { + send_reset(&physical_outgoing_tx, stream_id).await; + continue; + } + if streams.len() >= MAX_ACTIVE_NOISE_RELAY_STREAMS { + warn!("Noise relay has too many active streams"); + send_reset(&physical_outgoing_tx, stream_id).await; + continue; + } + if validation_tasks.len() >= MAX_PENDING_HANDSHAKE_VALIDATIONS { + warn!("Noise relay has too many pending harness key validations"); + send_reset(&physical_outgoing_tx, stream_id).await; + continue; + } + let prologue = match noise_channel_prologue( + &environment_id, + &executor_registration_id, + &stream_id, + ) { + Ok(prologue) => prologue, + Err(error) => { + warn!("failed to build Noise relay prologue: {error}"); + send_reset(&physical_outgoing_tx, stream_id).await; + continue; + } + }; + let request = match frame.into_handshake_payload() { + Ok(request) => request, + Err(error) => { + warn!("failed to read Noise relay handshake frame: {error}"); + send_reset(&physical_outgoing_tx, stream_id).await; + continue; + } + }; + let mut pending = + match PendingResponderHandshake::read_request(&identity, &prologue, &request) { + Ok(pending) => pending, + Err(error) => { + warn!("failed to read Noise relay handshake request: {error}"); + send_reset(&physical_outgoing_tx, stream_id).await; + continue; + } + }; + + // The authorization is encrypted inside the first IK message. + // It is meaningful only alongside the initiator static key + // that Clatter authenticated from that same message. + let authorization = match String::from_utf8(pending.take_payload()) { + Ok(authorization) + if authorization.len() <= MAX_HARNESS_KEY_AUTHORIZATION_BYTES => + { + authorization + } + Ok(_) => { + warn!("Noise relay handshake authorization is too long"); + send_reset(&physical_outgoing_tx, stream_id).await; + continue; + } + Err(_) => { + warn!("Noise relay handshake authorization is not UTF-8"); + send_reset(&physical_outgoing_tx, stream_id).await; + continue; + } + }; + let harness_public_key = pending.initiator_public_key().clone(); + let validation_id = next_validation_id; + let Some(next_id) = next_validation_id.checked_add(1) else { + warn!("Noise relay harness key validation id exhausted"); + send_reset(&physical_outgoing_tx, stream_id).await; + continue; + }; + next_validation_id = next_id; + pending_handshakes.insert( + stream_id.clone(), + PendingHandshake { + validation_id, + handshake: pending, + }, + ); + let validator = validator.clone(); + + // Validation is time-bounded and concurrency-bounded above. + // Failure leaves no transport state and returns a generic + // protocol reset to avoid exposing authorization details. + validation_tasks.spawn(async move { + let result = match timeout( + HARNESS_KEY_VALIDATION_TIMEOUT, + validator.validate_harness_key(&harness_public_key, &authorization), + ) + .await + { + Ok(result) => result, + Err(_) => Err(ExecServerError::Protocol( + "timed out validating Noise relay harness key".to_string(), + )), + }; + HarnessKeyValidationResult { + stream_id, + validation_id, + result, + } + }); + } + RelayFrameBodyKind::Data => { + // Data before handshake completion is always invalid. Removing + // a pending handshake also ensures a peer cannot keep its + // authorization task alive while sending application records. + let Some(stream) = streams.get_mut(&stream_id) else { + pending_handshakes.remove(&stream_id); + send_reset(&physical_outgoing_tx, stream_id).await; + continue; + }; + let data = match frame.into_data() { + Ok(data) => data, + Err(error) => { + warn!("dropping malformed Noise relay data frame: {error}"); + streams.remove(&stream_id); + send_reset(&physical_outgoing_tx, stream_id).await; + continue; + } + }; + if let Err(error) = stream.receive_data(data).await { + warn!("failed to process Noise relay payload: {error}"); + streams.remove(&stream_id); + send_reset(&physical_outgoing_tx, stream_id).await; + } + } + RelayFrameBodyKind::Reset => { + pending_handshakes.remove(&stream_id); + if let Some(stream) = streams.remove(&stream_id) { + stream.disconnect(frame.into_reset_reason()).await; + } + } + RelayFrameBodyKind::Ack + | RelayFrameBodyKind::Resume + | RelayFrameBodyKind::Heartbeat => {} + } + } + + for (_stream_id, stream) in streams { + stream.disconnect(/*reason*/ None).await; + } +} + +struct PendingHandshake { + validation_id: u64, + handshake: PendingResponderHandshake, +} + +struct HarnessKeyValidationResult { + stream_id: String, + validation_id: u64, + result: Result<(), ExecServerError>, +} + +struct NoiseVirtualStream { + incoming_tx: mpsc::Sender, + disconnected_tx: watch::Sender, + transport: Arc>, + inbound_ciphertexts: OrderedCiphertextFrames, + inbound_decoder: JsonRpcMessageDecoder, +} + +impl NoiseVirtualStream { + async fn disconnect(self, reason: Option) { + let _ = self.disconnected_tx.send(true); + let _ = self + .incoming_tx + .send(JsonRpcConnectionEvent::Disconnected { reason }) + .await; + } + + async fn receive_data(&mut self, data: RelayData) -> Result<(), ExecServerError> { + // Relay sequence ordering is enforced before taking the transport lock + // and decrypting. Each virtual stream owns one ordered Noise nonce + // space shared by its reader and writer transport halves. + for ciphertext in self.inbound_ciphertexts.push(data.seq, data.payload)? { + let plaintext = { + let mut transport = self + .transport + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner); + transport.decrypt(&ciphertext)? + }; + for message in self.inbound_decoder.push(&plaintext)? { + self.incoming_tx + .send(JsonRpcConnectionEvent::Message(message)) + .await + .map_err(|_| ExecServerError::Closed)?; + } + } + Ok(()) + } +} + +fn spawn_noise_virtual_stream( + stream_id: String, + processor: ConnectionProcessor, + physical_outgoing_tx: mpsc::Sender>, + transport: NoiseTransport, +) -> NoiseVirtualStream { + let (json_outgoing_tx, mut json_outgoing_rx) = mpsc::channel(CHANNEL_CAPACITY); + let (incoming_tx, incoming_rx) = mpsc::channel(CHANNEL_CAPACITY); + let (disconnected_tx, disconnected_rx) = watch::channel(false); + let transport = Arc::new(Mutex::new(transport)); + let writer_transport = Arc::clone(&transport); + let writer_stream_id = stream_id; + let writer_task = tokio::spawn(async move { + let mut next_seq = 0u32; + 'writer: while let Some(message) = json_outgoing_rx.recv().await { + // Frame first, then split into bounded Noise records. Each record + // receives one checked relay sequence and is encrypted exactly + // once, preserving the implicit Noise sending nonce. + let framed = match frame_jsonrpc_message(&message) { + Ok(framed) => framed, + Err(error) => { + warn!("failed to frame Noise virtual stream JSON-RPC payload: {error}"); + break; + } + }; + for plaintext_record in framed.chunks(NOISE_RECORD_PLAINTEXT_LEN) { + let seq = match take_next_sequence(&mut next_seq) { + Ok(seq) => seq, + Err(error) => { + warn!("Noise virtual stream sequence exhausted: {error}"); + break 'writer; + } + }; + let ciphertext = { + let mut transport = writer_transport + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner); + transport.encrypt(plaintext_record) + }; + let ciphertext = match ciphertext { + Ok(ciphertext) => ciphertext, + Err(error) => { + warn!("failed to encrypt Noise virtual stream payload: {error}"); + break 'writer; + } + }; + let frame = RelayMessageFrame::data(writer_stream_id.clone(), seq, ciphertext); + if physical_outgoing_tx + .send(encode_relay_message_frame(&frame)) + .await + .is_err() + { + break 'writer; + } + } + } + + // Tell the harness to discard this virtual stream whenever its writer + // exits, including processor shutdown or a cryptographic/send failure. + // Otherwise the peer could wait indefinitely on a dead stream. + let reset = + RelayMessageFrame::reset(writer_stream_id, NOISE_RELAY_RESET_REASON.to_string()); + let _ = physical_outgoing_tx + .send(encode_relay_message_frame(&reset)) + .await; + }); + + let connection = JsonRpcConnection { + outgoing_tx: json_outgoing_tx, + incoming_rx, + disconnected_rx, + task_handles: vec![writer_task], + transport: JsonRpcTransport::External, + }; + tokio::spawn(async move { + processor.run_connection(connection).await; + }); + + NoiseVirtualStream { + incoming_tx, + disconnected_tx, + transport, + inbound_ciphertexts: OrderedCiphertextFrames::default(), + inbound_decoder: JsonRpcMessageDecoder::default(), + } +} + +async fn send_reset(physical_outgoing_tx: &mpsc::Sender>, stream_id: String) { + let reset = RelayMessageFrame::reset(stream_id, NOISE_RELAY_RESET_REASON.to_string()); + let _ = physical_outgoing_tx + .send(encode_relay_message_frame(&reset)) + .await; +} + +#[cfg(test)] +#[path = "environment_tests.rs"] +mod tests; diff --git a/codex-rs/exec-server/src/noise_relay/environment_tests.rs b/codex-rs/exec-server/src/noise_relay/environment_tests.rs new file mode 100644 index 00000000000..1005dcc7997 --- /dev/null +++ b/codex-rs/exec-server/src/noise_relay/environment_tests.rs @@ -0,0 +1,163 @@ +use std::sync::Arc; +use std::sync::atomic::AtomicUsize; +use std::sync::atomic::Ordering; +use std::time::Duration; + +use anyhow::Result; +use futures::SinkExt; +use futures::StreamExt; +use tokio::net::TcpListener; +use tokio::sync::Notify; +use tokio::time::timeout; +use tokio_tungstenite::accept_async; +use tokio_tungstenite::connect_async; +use tokio_tungstenite::tungstenite::Message; + +use super::HarnessKeyValidator; +use super::MAX_HARNESS_KEY_AUTHORIZATION_BYTES; +use super::run_noise_multiplexed_environment; +use crate::ExecServerError; +use crate::ExecServerRuntimePaths; +use crate::noise_channel::InitiatorHandshake; +use crate::noise_channel::NoiseChannelIdentity; +use crate::noise_channel::NoiseChannelPublicKey; +use crate::noise_channel::noise_channel_prologue; +use crate::relay::RelayFrameBodyKind; +use crate::relay::decode_relay_message_frame; +use crate::relay::encode_relay_message_frame; +use crate::relay_proto::RelayMessageFrame; +use crate::server::ConnectionProcessor; + +const ENVIRONMENT_ID: &str = "environment-1"; +const EXECUTOR_REGISTRATION_ID: &str = "registration-1"; + +#[derive(Clone)] +struct BlockingValidator { + calls: Arc, + release: Arc, +} + +impl HarnessKeyValidator for BlockingValidator { + fn validate_harness_key( + &self, + _harness_public_key: &NoiseChannelPublicKey, + _authorization: &str, + ) -> impl std::future::Future> + Send { + let calls = Arc::clone(&self.calls); + let release = Arc::clone(&self.release); + async move { + calls.fetch_add(1, Ordering::SeqCst); + release.notified().await; + Ok(()) + } + } +} + +#[tokio::test] +async fn pending_harness_key_validation_does_not_block_new_handshakes() -> Result<()> { + let listener = TcpListener::bind("127.0.0.1:0").await?; + let websocket_url = format!("ws://{}", listener.local_addr()?); + let harness_connection = tokio::spawn(connect_async(websocket_url)); + let (socket, _peer_addr) = listener.accept().await?; + let environment_websocket = accept_async(socket).await?; + let (mut harness_websocket, _response) = harness_connection.await??; + + let environment_identity = NoiseChannelIdentity::generate()?; + let harness_identity = NoiseChannelIdentity::generate()?; + let calls = Arc::new(AtomicUsize::new(0)); + let environment_task = tokio::spawn(run_noise_multiplexed_environment( + environment_websocket, + ConnectionProcessor::new(ExecServerRuntimePaths::new( + std::env::current_exe()?, + /*codex_linux_sandbox_exe*/ None, + )?), + ENVIRONMENT_ID.to_string(), + EXECUTOR_REGISTRATION_ID.to_string(), + environment_identity.clone(), + BlockingValidator { + calls: Arc::clone(&calls), + release: Arc::new(Notify::new()), + }, + )); + + for stream_id in ["stream-1", "stream-2"] { + let prologue = noise_channel_prologue(ENVIRONMENT_ID, EXECUTOR_REGISTRATION_ID, stream_id)?; + let (_handshake, request) = InitiatorHandshake::start( + &harness_identity, + &environment_identity.public_key(), + &prologue, + b"authorization", + )?; + let frame = RelayMessageFrame::handshake(stream_id.to_string(), request); + harness_websocket + .send(Message::Binary(encode_relay_message_frame(&frame).into())) + .await?; + } + + timeout(Duration::from_secs(1), async { + while calls.load(Ordering::SeqCst) != 2 { + tokio::task::yield_now().await; + } + }) + .await?; + + harness_websocket.close(None).await?; + timeout(Duration::from_secs(1), environment_task).await??; + Ok(()) +} + +#[tokio::test] +async fn oversized_harness_authorization_is_rejected_before_validation() -> Result<()> { + let listener = TcpListener::bind("127.0.0.1:0").await?; + let websocket_url = format!("ws://{}", listener.local_addr()?); + let harness_connection = tokio::spawn(connect_async(websocket_url)); + let (socket, _peer_addr) = listener.accept().await?; + let environment_websocket = accept_async(socket).await?; + let (mut harness_websocket, _response) = harness_connection.await??; + + let environment_identity = NoiseChannelIdentity::generate()?; + let harness_identity = NoiseChannelIdentity::generate()?; + let calls = Arc::new(AtomicUsize::new(0)); + let environment_task = tokio::spawn(run_noise_multiplexed_environment( + environment_websocket, + ConnectionProcessor::new(ExecServerRuntimePaths::new( + std::env::current_exe()?, + /*codex_linux_sandbox_exe*/ None, + )?), + ENVIRONMENT_ID.to_string(), + EXECUTOR_REGISTRATION_ID.to_string(), + environment_identity.clone(), + BlockingValidator { + calls: Arc::clone(&calls), + release: Arc::new(Notify::new()), + }, + )); + + let stream_id = "stream-1"; + let prologue = noise_channel_prologue(ENVIRONMENT_ID, EXECUTOR_REGISTRATION_ID, stream_id)?; + let oversized_authorization = vec![b'a'; MAX_HARNESS_KEY_AUTHORIZATION_BYTES + 1]; + let (_handshake, request) = InitiatorHandshake::start( + &harness_identity, + &environment_identity.public_key(), + &prologue, + &oversized_authorization, + )?; + let frame = RelayMessageFrame::handshake(stream_id.to_string(), request); + harness_websocket + .send(Message::Binary(encode_relay_message_frame(&frame).into())) + .await?; + + let Message::Binary(payload) = timeout(Duration::from_secs(1), harness_websocket.next()) + .await? + .ok_or_else(|| anyhow::anyhow!("environment closed before sending reset"))?? + else { + anyhow::bail!("expected binary reset frame"); + }; + let reset = decode_relay_message_frame(payload.as_ref())?; + assert_eq!(reset.validate()?, RelayFrameBodyKind::Reset); + assert_eq!(calls.load(Ordering::SeqCst), 0); + + harness_websocket.close(None).await?; + timeout(Duration::from_secs(1), environment_task).await??; + Ok(()) +} diff --git a/codex-rs/exec-server/src/noise_relay/harness.rs b/codex-rs/exec-server/src/noise_relay/harness.rs new file mode 100644 index 00000000000..3c30a0cabb8 --- /dev/null +++ b/codex-rs/exec-server/src/noise_relay/harness.rs @@ -0,0 +1,402 @@ +use futures::Sink; +use futures::SinkExt; +use futures::Stream; +use futures::StreamExt; +use tokio::sync::mpsc; +use tokio::sync::watch; +use tokio_tungstenite::tungstenite::Message; +use tracing::debug; +use tracing::warn; +use uuid::Uuid; + +use crate::ExecServerError; +use crate::connection::CHANNEL_CAPACITY; +use crate::connection::JsonRpcConnection; +use crate::connection::JsonRpcConnectionEvent; +use crate::connection::JsonRpcTransport; +use crate::noise_channel::InitiatorHandshake; +use crate::noise_channel::NoiseChannelIdentity; +use crate::noise_channel::NoiseChannelPublicKey; +use crate::noise_channel::NoiseTransport; +use crate::noise_channel::noise_channel_prologue; +use crate::noise_relay::message_framing::JsonRpcMessageDecoder; +use crate::noise_relay::message_framing::NOISE_RECORD_PLAINTEXT_LEN; +use crate::noise_relay::message_framing::frame_jsonrpc_message; +use crate::noise_relay::ordered_ciphertext::OrderedCiphertextFrames; +use crate::noise_relay::take_next_sequence; +use crate::relay::RelayFrameBodyKind; +use crate::relay::decode_relay_message_frame; +use crate::relay::encode_relay_message_frame; +use crate::relay_proto::RelayData; +use crate::relay_proto::RelayMessageFrame; + +/// Adapt one harness rendezvous websocket into an authenticated JSON-RPC connection. +/// +/// The returned connection is not usable until the background task completes +/// hybrid IK against the registry-pinned exec-server key. Rendezvous can see +/// stream metadata and ciphertext, but never JSON-RPC plaintext or either +/// endpoint's private key. +pub(crate) fn noise_harness_connection_from_websocket( + stream: T, + connection_label: String, + environment_id: String, + executor_registration_id: String, + identity: NoiseChannelIdentity, + responder_public_key: NoiseChannelPublicKey, + harness_key_authorization: String, +) -> JsonRpcConnection +where + T: Sink + Stream> + Unpin + Send + 'static, + E: std::fmt::Display + Send + 'static, +{ + let stream_id = Uuid::new_v4().to_string(); + let (outgoing_tx, mut outgoing_rx) = mpsc::channel(CHANNEL_CAPACITY); + let (incoming_tx, incoming_rx) = mpsc::channel(CHANNEL_CAPACITY); + let (disconnected_tx, disconnected_rx) = watch::channel(false); + + let websocket_task = tokio::spawn(async move { + let mut websocket = stream; + + // Bind the Noise transcript to the exact environment registration and + // virtual relay stream before emitting any handshake bytes. A captured + // handshake cannot be spliced onto a different routed connection. + let prologue = + match noise_channel_prologue(&environment_id, &executor_registration_id, &stream_id) { + Ok(prologue) => prologue, + Err(error) => { + send_disconnected( + &incoming_tx, + &disconnected_tx, + format!("failed to build Noise relay prologue: {error}"), + ) + .await; + return; + } + }; + let (initiator_handshake, request) = match InitiatorHandshake::start( + &identity, + &responder_public_key, + &prologue, + harness_key_authorization.as_bytes(), + ) { + Ok(handshake) => handshake, + Err(error) => { + send_disconnected( + &incoming_tx, + &disconnected_tx, + format!("failed to start Noise relay handshake: {error}"), + ) + .await; + return; + } + }; + + // Resume claims the stream ID at rendezvous; Handshake carries the + // opaque first IK message. No JSON-RPC data is sent before the + // responder proves possession of the pinned static key. + let resume = RelayMessageFrame::resume(stream_id.clone()); + let handshake = RelayMessageFrame::handshake(stream_id.clone(), request); + if websocket + .send(Message::Binary(encode_relay_message_frame(&resume).into())) + .await + .is_err() + || websocket + .send(Message::Binary( + encode_relay_message_frame(&handshake).into(), + )) + .await + .is_err() + { + let _ = disconnected_tx.send(true); + return; + } + + // During the handshake, ignore unrelated routed streams and control + // frames, but reject data on our stream. Accepting early data would + // create a plaintext or unauthenticated application path. + let mut transport = loop { + let Some(incoming_message) = websocket.next().await else { + send_disconnected( + &incoming_tx, + &disconnected_tx, + "Noise relay websocket closed during handshake".to_string(), + ) + .await; + return; + }; + let message = match incoming_message { + Ok(Message::Binary(payload)) => payload, + Ok(Message::Close(_)) => { + send_disconnected( + &incoming_tx, + &disconnected_tx, + "Noise relay websocket closed during handshake".to_string(), + ) + .await; + return; + } + Ok(Message::Ping(_) | Message::Pong(_) | Message::Frame(_)) => continue, + Ok(Message::Text(_)) => { + send_disconnected( + &incoming_tx, + &disconnected_tx, + "Noise relay transport expects binary protobuf frames".to_string(), + ) + .await; + return; + } + Err(error) => { + send_disconnected( + &incoming_tx, + &disconnected_tx, + format!( + "failed to read Noise relay websocket from {connection_label}: {error}" + ), + ) + .await; + return; + } + }; + let frame = match decode_relay_message_frame(message.as_ref()) { + Ok(frame) => frame, + Err(error) => { + send_disconnected( + &incoming_tx, + &disconnected_tx, + format!("failed to parse Noise relay frame: {error}"), + ) + .await; + return; + } + }; + if frame.stream_id != stream_id { + continue; + } + match frame.validate() { + Ok(RelayFrameBodyKind::Handshake) => { + let response = match frame.into_handshake_payload() { + Ok(response) => response, + Err(error) => { + send_disconnected(&incoming_tx, &disconnected_tx, error.to_string()) + .await; + return; + } + }; + match initiator_handshake.finish(&response) { + Ok(transport) => break transport, + Err(error) => { + send_disconnected( + &incoming_tx, + &disconnected_tx, + format!("Noise relay handshake failed: {error}"), + ) + .await; + return; + } + } + } + Ok(RelayFrameBodyKind::Reset) => { + send_disconnected( + &incoming_tx, + &disconnected_tx, + frame + .into_reset_reason() + .unwrap_or_else(|| "Noise relay reset during handshake".to_string()), + ) + .await; + return; + } + Ok( + RelayFrameBodyKind::Ack + | RelayFrameBodyKind::Resume + | RelayFrameBodyKind::Heartbeat, + ) => {} + Ok(RelayFrameBodyKind::Data) | Err(_) => { + send_disconnected( + &incoming_tx, + &disconnected_tx, + "Noise relay received data before handshake completion".to_string(), + ) + .await; + return; + } + } + }; + + // After the handshake, each relay sequence maps to exactly one Noise + // transport record. Outbound records are encrypted once; inbound + // records are reordered and deduplicated before decryption. + let mut next_outbound_seq = 0u32; + let mut inbound_ciphertexts = OrderedCiphertextFrames::default(); + let mut inbound_decoder = JsonRpcMessageDecoder::default(); + 'relay: loop { + tokio::select! { + maybe_message = outgoing_rx.recv() => { + let Some(message) = maybe_message else { + break; + }; + let framed = match frame_jsonrpc_message(&message) { + Ok(framed) => framed, + Err(error) => { + warn!("failed to frame JSON-RPC payload for Noise relay: {error}"); + break; + } + }; + for plaintext_record in framed.chunks(NOISE_RECORD_PLAINTEXT_LEN) { + let seq = match take_next_sequence(&mut next_outbound_seq) { + Ok(seq) => seq, + Err(error) => { + warn!("Noise relay sequence exhausted: {error}"); + break 'relay; + } + }; + let ciphertext = match transport.encrypt(plaintext_record) { + Ok(ciphertext) => ciphertext, + Err(error) => { + warn!("failed to encrypt JSON-RPC payload for Noise relay: {error}"); + break 'relay; + } + }; + let frame = RelayMessageFrame::data(stream_id.clone(), seq, ciphertext); + if websocket + .send(Message::Binary(encode_relay_message_frame(&frame).into())) + .await + .is_err() + { + break 'relay; + } + } + } + incoming_message = websocket.next() => { + let Some(incoming_message) = incoming_message else { + break; + }; + match incoming_message { + Ok(Message::Binary(payload)) => { + let frame = match decode_relay_message_frame(payload.as_ref()) { + Ok(frame) => frame, + Err(error) => { + send_malformed(&incoming_tx, error.to_string()).await; + break; + } + }; + if frame.stream_id != stream_id { + continue; + } + match frame.validate() { + Ok(RelayFrameBodyKind::Data) => { + let data = match frame.into_data() { + Ok(data) => data, + Err(error) => { + send_malformed(&incoming_tx, error.to_string()).await; + break; + } + }; + if let Err(error) = receive_data( + &mut inbound_ciphertexts, + &mut transport, + &mut inbound_decoder, + data, + &incoming_tx, + ) + .await + { + send_malformed(&incoming_tx, error.to_string()).await; + break; + } + } + Ok(RelayFrameBodyKind::Reset) => { + let reason = frame.into_reset_reason(); + let _ = incoming_tx + .send(JsonRpcConnectionEvent::Disconnected { reason }) + .await; + break; + } + Ok( + RelayFrameBodyKind::Ack + | RelayFrameBodyKind::Resume + | RelayFrameBodyKind::Heartbeat, + ) => {} + Ok(RelayFrameBodyKind::Handshake) | Err(_) => { + send_malformed( + &incoming_tx, + "Noise relay received invalid post-handshake frame".to_string(), + ) + .await; + break; + } + } + } + Ok(Message::Close(_)) => break, + Ok(Message::Ping(_) | Message::Pong(_) | Message::Frame(_)) => {} + Ok(Message::Text(_)) => { + send_malformed( + &incoming_tx, + "Noise relay transport expects binary protobuf frames".to_string(), + ) + .await; + break; + } + Err(error) => { + debug!("Noise relay websocket read failed: {error}"); + break; + } + } + } + } + } + let _ = disconnected_tx.send(true); + }); + + JsonRpcConnection { + outgoing_tx, + incoming_rx, + disconnected_rx, + task_handles: vec![websocket_task], + transport: JsonRpcTransport::External, + } +} + +async fn receive_data( + inbound_ciphertexts: &mut OrderedCiphertextFrames, + transport: &mut NoiseTransport, + decoder: &mut JsonRpcMessageDecoder, + data: RelayData, + incoming_tx: &mpsc::Sender, +) -> Result<(), ExecServerError> { + // Ordering must happen before decryption because Noise transport nonces are + // implicit. A future or duplicate ciphertext passed directly to Clatter + // would desynchronize the channel. + for ciphertext in inbound_ciphertexts.push(data.seq, data.payload)? { + let plaintext = transport.decrypt(&ciphertext)?; + + // The authenticated byte stream can carry partial or multiple JSON-RPC + // messages; emit only complete, successfully parsed messages. + for message in decoder.push(&plaintext)? { + incoming_tx + .send(JsonRpcConnectionEvent::Message(message)) + .await + .map_err(|_| ExecServerError::Closed)?; + } + } + Ok(()) +} + +async fn send_malformed(incoming_tx: &mpsc::Sender, reason: String) { + let _ = incoming_tx + .send(JsonRpcConnectionEvent::MalformedMessage { reason }) + .await; +} + +async fn send_disconnected( + incoming_tx: &mpsc::Sender, + disconnected_tx: &watch::Sender, + reason: String, +) { + let _ = disconnected_tx.send(true); + let _ = incoming_tx + .send(JsonRpcConnectionEvent::Disconnected { + reason: Some(reason), + }) + .await; +} diff --git a/codex-rs/exec-server/src/noise_relay/message_framing.rs b/codex-rs/exec-server/src/noise_relay/message_framing.rs new file mode 100644 index 00000000000..0b69f46b486 --- /dev/null +++ b/codex-rs/exec-server/src/noise_relay/message_framing.rs @@ -0,0 +1,92 @@ +use codex_app_server_protocol::JSONRPCMessage; + +use crate::ExecServerError; + +const LENGTH_PREFIX_BYTES: usize = size_of::(); +const MAX_NOISE_JSONRPC_MESSAGE_LEN: usize = 64 * 1024 * 1024; +pub(crate) const NOISE_RECORD_PLAINTEXT_LEN: usize = 60 * 1024; + +/// Serialize one JSON-RPC message into the encrypted record byte stream. +/// +/// Clatter limits an individual Noise message to 65,535 bytes, while valid +/// exec-server responses can be much larger. A four-byte authenticated length +/// prefix lets the caller split this byte stream into bounded Noise records and +/// lets the receiver reconstruct exact JSON-RPC message boundaries. +pub(crate) fn frame_jsonrpc_message(message: &JSONRPCMessage) -> Result, ExecServerError> { + let mut framed = vec![0; LENGTH_PREFIX_BYTES]; + serde_json::to_writer(&mut framed, message)?; + let message_len = framed.len() - LENGTH_PREFIX_BYTES; + if message_len > MAX_NOISE_JSONRPC_MESSAGE_LEN { + return Err(ExecServerError::Protocol( + "Noise relay JSON-RPC message exceeds maximum length".to_string(), + )); + } + framed[..LENGTH_PREFIX_BYTES].copy_from_slice(&(message_len as u32).to_be_bytes()); + Ok(framed) +} + +/// Incrementally reconstructs authenticated JSON-RPC messages from Noise records. +#[derive(Default)] +pub(crate) struct JsonRpcMessageDecoder { + buffered: Vec, +} + +impl JsonRpcMessageDecoder { + /// Append one decrypted record and return all complete framed messages. + pub(crate) fn push( + &mut self, + plaintext_record: &[u8], + ) -> Result, ExecServerError> { + if plaintext_record.len() > NOISE_RECORD_PLAINTEXT_LEN { + return Err(ExecServerError::Protocol( + "Noise relay plaintext record exceeds maximum length".to_string(), + )); + } + self.buffered.extend_from_slice(plaintext_record); + + // One record can finish multiple messages, and one message can span + // multiple records. Parse only after the authenticated length prefix + // and the full declared payload are present. + let mut messages = Vec::new(); + loop { + let Some(message_len) = self.next_message_len()? else { + break; + }; + let framed_len = LENGTH_PREFIX_BYTES + message_len; + if self.buffered.len() < framed_len { + break; + } + messages.push(serde_json::from_slice( + &self.buffered[LENGTH_PREFIX_BYTES..framed_len], + )?); + self.buffered.drain(..framed_len); + } + + // Even before a message is complete, keep reassembly memory bounded. + if self.buffered.len() > LENGTH_PREFIX_BYTES + MAX_NOISE_JSONRPC_MESSAGE_LEN { + return Err(ExecServerError::Protocol( + "Noise relay JSON-RPC reassembly buffer exceeds maximum length".to_string(), + )); + } + Ok(messages) + } + + fn next_message_len(&self) -> Result, ExecServerError> { + let Some(prefix) = self.buffered.get(..LENGTH_PREFIX_BYTES) else { + return Ok(None); + }; + let message_len = u32::from_be_bytes([prefix[0], prefix[1], prefix[2], prefix[3]]) as usize; + // Zero-length is never a JSON-RPC message, and rejecting an oversized + // declaration immediately avoids waiting for attacker-controlled data. + if message_len == 0 || message_len > MAX_NOISE_JSONRPC_MESSAGE_LEN { + return Err(ExecServerError::Protocol( + "Noise relay JSON-RPC message has invalid length".to_string(), + )); + } + Ok(Some(message_len)) + } +} + +#[cfg(test)] +#[path = "message_framing_tests.rs"] +mod tests; diff --git a/codex-rs/exec-server/src/noise_relay/message_framing_tests.rs b/codex-rs/exec-server/src/noise_relay/message_framing_tests.rs new file mode 100644 index 00000000000..c3df14bf562 --- /dev/null +++ b/codex-rs/exec-server/src/noise_relay/message_framing_tests.rs @@ -0,0 +1,52 @@ +use codex_app_server_protocol::JSONRPCMessage; +use codex_app_server_protocol::JSONRPCNotification; +use pretty_assertions::assert_eq; + +use super::JsonRpcMessageDecoder; +use super::MAX_NOISE_JSONRPC_MESSAGE_LEN; +use super::NOISE_RECORD_PLAINTEXT_LEN; +use super::frame_jsonrpc_message; +use crate::ExecServerError; + +#[test] +fn fragments_and_reassembles_large_jsonrpc_message() { + let message = JSONRPCMessage::Notification(JSONRPCNotification { + method: "large/test".to_string(), + params: Some(serde_json::json!({ + "data": "x".repeat(128 * 1024), + })), + }); + let framed = frame_jsonrpc_message(&message).unwrap(); + assert!(framed.len() > 128 * 1024); + + let mut decoder = JsonRpcMessageDecoder::default(); + let mut decoded = Vec::new(); + for record in framed.chunks(NOISE_RECORD_PLAINTEXT_LEN) { + decoded.extend(decoder.push(record).unwrap()); + } + + assert_eq!(decoded, vec![message]); +} + +#[test] +fn rejects_declared_message_length_above_limit_without_payload() { + let mut decoder = JsonRpcMessageDecoder::default(); + let declared_len = (MAX_NOISE_JSONRPC_MESSAGE_LEN as u32 + 1).to_be_bytes(); + + assert!(matches!( + decoder.push(&declared_len), + Err(ExecServerError::Protocol(message)) + if message == "Noise relay JSON-RPC message has invalid length" + )); +} + +#[test] +fn rejects_oversized_plaintext_record() { + let mut decoder = JsonRpcMessageDecoder::default(); + + assert!(matches!( + decoder.push(&vec![0; NOISE_RECORD_PLAINTEXT_LEN + 1]), + Err(ExecServerError::Protocol(message)) + if message == "Noise relay plaintext record exceeds maximum length" + )); +} diff --git a/codex-rs/exec-server/src/noise_relay/mod.rs b/codex-rs/exec-server/src/noise_relay/mod.rs new file mode 100644 index 00000000000..2399a8cc8cd --- /dev/null +++ b/codex-rs/exec-server/src/noise_relay/mod.rs @@ -0,0 +1,33 @@ +mod environment; +mod harness; +mod message_framing; +mod ordered_ciphertext; + +use tokio_tungstenite::tungstenite::protocol::WebSocketConfig; + +use crate::ExecServerError; + +pub(crate) use environment::HarnessKeyValidator; +pub(crate) use environment::run_noise_multiplexed_environment; +pub(crate) use harness::noise_harness_connection_from_websocket; + +// This bounds allocation in tungstenite before protobuf and Noise record +// validation run. It comfortably fits one maximum Noise record plus metadata. +const MAX_NOISE_RELAY_WEBSOCKET_MESSAGE_SIZE: usize = 256 * 1024; + +/// Return the websocket limits required by every Noise relay endpoint. +pub(crate) fn noise_relay_websocket_config() -> WebSocketConfig { + WebSocketConfig::default() + .max_frame_size(Some(MAX_NOISE_RELAY_WEBSOCKET_MESSAGE_SIZE)) + .max_message_size(Some(MAX_NOISE_RELAY_WEBSOCKET_MESSAGE_SIZE)) +} + +fn take_next_sequence(next_seq: &mut u32) -> Result { + // Never wrap: relay sequence is the explicit ordering key for an implicit + // Noise nonce. Reusing zero after u32::MAX would be ambiguous and unsafe. + let seq = *next_seq; + *next_seq = next_seq.checked_add(1).ok_or_else(|| { + ExecServerError::Protocol("Noise relay sequence number exhausted".to_string()) + })?; + Ok(seq) +} diff --git a/codex-rs/exec-server/src/noise_relay/ordered_ciphertext.rs b/codex-rs/exec-server/src/noise_relay/ordered_ciphertext.rs new file mode 100644 index 00000000000..f122d497074 --- /dev/null +++ b/codex-rs/exec-server/src/noise_relay/ordered_ciphertext.rs @@ -0,0 +1,85 @@ +use std::collections::BTreeMap; + +use crate::ExecServerError; + +const MAX_REORDER_DISTANCE: u32 = 64; +const MAX_PENDING_FRAMES: usize = 64; +const MAX_PENDING_BYTES: usize = 1024 * 1024; + +/// Bounded pre-decryption reorder buffer for Noise transport records. +/// +/// Relay delivery can be duplicated or reordered, but Noise transport nonces +/// are strictly ordered. This type absorbs only a small reliable-delivery +/// window and releases ciphertexts exactly once in nonce order. It must sit +/// before `NoiseTransport::decrypt`; attempting to decrypt a future record +/// would advance or desynchronize cryptographic state. +#[derive(Default)] +pub(crate) struct OrderedCiphertextFrames { + next_seq: u32, + pending: BTreeMap>, + pending_bytes: usize, +} + +impl OrderedCiphertextFrames { + /// Accept one relay record and return the newly contiguous ciphertext run. + pub(crate) fn push( + &mut self, + seq: u32, + payload: Vec, + ) -> Result>, ExecServerError> { + // Already-delivered and already-buffered frames are retries. Keep the + // first buffered ciphertext for a sequence so a duplicate cannot + // replace it before authentication. + if seq < self.next_seq || self.pending.contains_key(&seq) { + return Ok(Vec::new()); + } + if seq > self.next_seq { + // Bound both sequence distance and actual buffered memory. Without + // both limits, an authenticated peer could hold the stream open + // while forcing unbounded pre-decryption state. + if seq - self.next_seq > MAX_REORDER_DISTANCE { + return Err(ExecServerError::Protocol( + "Noise relay ciphertext exceeds reorder window".to_string(), + )); + } + let pending_bytes = self + .pending_bytes + .checked_add(payload.len()) + .ok_or_else(|| { + ExecServerError::Protocol( + "Noise relay pending ciphertext byte count overflowed".to_string(), + ) + })?; + if self.pending.len() >= MAX_PENDING_FRAMES || pending_bytes > MAX_PENDING_BYTES { + return Err(ExecServerError::Protocol( + "Noise relay pending ciphertext buffer is full".to_string(), + )); + } + self.pending.insert(seq, payload); + self.pending_bytes = pending_bytes; + return Ok(Vec::new()); + } + + // The expected record closes the current gap. Release it and every + // contiguous buffered successor so Noise sees exactly nonce order. + let mut ready = vec![payload]; + self.advance()?; + while let Some(payload) = self.pending.remove(&self.next_seq) { + self.pending_bytes -= payload.len(); + ready.push(payload); + self.advance()?; + } + Ok(ready) + } + + fn advance(&mut self) -> Result<(), ExecServerError> { + self.next_seq = self.next_seq.checked_add(1).ok_or_else(|| { + ExecServerError::Protocol("Noise relay sequence number exhausted".to_string()) + })?; + Ok(()) + } +} + +#[cfg(test)] +#[path = "ordered_ciphertext_tests.rs"] +mod tests; diff --git a/codex-rs/exec-server/src/noise_relay/ordered_ciphertext_tests.rs b/codex-rs/exec-server/src/noise_relay/ordered_ciphertext_tests.rs new file mode 100644 index 00000000000..97f2c1a9ba3 --- /dev/null +++ b/codex-rs/exec-server/src/noise_relay/ordered_ciphertext_tests.rs @@ -0,0 +1,48 @@ +use pretty_assertions::assert_eq; + +use super::MAX_PENDING_BYTES; +use super::OrderedCiphertextFrames; + +#[test] +fn releases_ciphertexts_only_in_nonce_order() { + let mut frames = OrderedCiphertextFrames::default(); + + assert_eq!( + frames.push(1, b"second".to_vec()).unwrap(), + Vec::>::new() + ); + assert_eq!( + frames.push(0, b"first".to_vec()).unwrap(), + vec![b"first".to_vec(), b"second".to_vec()] + ); +} + +#[test] +fn ignores_duplicate_ciphertexts_without_replacing_buffered_record() { + let mut frames = OrderedCiphertextFrames::default(); + + assert_eq!( + frames.push(1, b"first copy".to_vec()).unwrap(), + Vec::>::new() + ); + assert_eq!( + frames.push(1, b"replacement".to_vec()).unwrap(), + Vec::>::new() + ); + assert_eq!( + frames.push(0, b"zero".to_vec()).unwrap(), + vec![b"zero".to_vec(), b"first copy".to_vec()] + ); + assert_eq!( + frames.push(0, b"duplicate".to_vec()).unwrap(), + Vec::>::new() + ); +} + +#[test] +fn rejects_unbounded_reordering() { + let mut frames = OrderedCiphertextFrames::default(); + + assert!(frames.push(65, Vec::new()).is_err()); + assert!(frames.push(1, vec![0; MAX_PENDING_BYTES + 1]).is_err()); +}