Skip to content

Commit fdb2e54

Browse files
giuseppeclaude
andcommitted
receiver: wait for batched FDs instead of returning MismatchedCount
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent ec6cd15 commit fdb2e54

2 files changed

Lines changed: 269 additions & 11 deletions

File tree

src/transport.rs

Lines changed: 80 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ impl UnixSocketTransport {
5555
stream,
5656
buffer: Vec::new(),
5757
fd_queue: VecDeque::new(),
58+
pending_message: None,
5859
},
5960
)
6061
}
@@ -262,6 +263,8 @@ pub struct Receiver {
262263
stream: Arc<TokioUnixStream>,
263264
buffer: Vec<u8>,
264265
fd_queue: VecDeque<OwnedFd>,
266+
/// A fully parsed JSON message waiting for its FDs to arrive.
267+
pending_message: Option<(serde_json::Value, usize)>,
265268
}
266269

267270
impl Receiver {
@@ -275,7 +278,18 @@ impl Receiver {
275278
return Ok(message);
276279
}
277280

278-
self.read_more_data().await?;
281+
match self.read_more_data().await {
282+
Err(Error::ConnectionClosed) if self.pending_message.is_some() => {
283+
// Connection closed while waiting for FDs — per spec
284+
// Section 5, Step 4 this is a Mismatched Count error.
285+
let (_, fd_count) = self.pending_message.take().unwrap();
286+
return Err(Error::MismatchedCount {
287+
expected: fd_count,
288+
found: self.fd_queue.len(),
289+
});
290+
}
291+
other => other?,
292+
}
279293
}
280294
}
281295

@@ -300,7 +314,47 @@ impl Receiver {
300314
}
301315
}
302316

317+
/// Build a `MessageWithFds` by draining `fd_count` FDs from the queue.
318+
fn build_message(
319+
fd_queue: &mut VecDeque<OwnedFd>,
320+
value: serde_json::Value,
321+
fd_count: usize,
322+
) -> Result<MessageWithFds> {
323+
let fds: Vec<OwnedFd> = fd_queue.drain(..fd_count).collect();
324+
let message = JsonRpcMessage::from_json_value(value)?;
325+
Ok(MessageWithFds::new(message, fds))
326+
}
327+
303328
fn try_parse_message(&mut self) -> Result<Option<MessageWithFds>> {
329+
// Check if we have a pending message waiting for FDs.
330+
// While a message is pending, all subsequent message parsing is
331+
// blocked — even messages needing 0 FDs. This preserves FIFO
332+
// ordering on the Unix socket: FDs queued after the pending
333+
// message's FDs belong to later messages and must not be
334+
// consumed early.
335+
if let Some((value, fd_count)) = self
336+
.pending_message
337+
.take_if(|(_, c)| self.fd_queue.len() >= *c)
338+
{
339+
return Ok(Some(Self::build_message(
340+
&mut self.fd_queue,
341+
value,
342+
fd_count,
343+
)?));
344+
} else if let Some((_, fd_count)) = &self.pending_message {
345+
// Not enough FDs yet. Per the spec (Section 5, Step 4),
346+
// if the buffer contains any non-whitespace byte the sender
347+
// has started the next message before delivering all FDs for
348+
// the current one — that is a fatal protocol violation.
349+
if self.buffer.iter().any(|&b| !b.is_ascii_whitespace()) {
350+
return Err(Error::MismatchedCount {
351+
expected: *fd_count,
352+
found: self.fd_queue.len(),
353+
});
354+
}
355+
return Ok(None);
356+
}
357+
304358
if self.buffer.is_empty() {
305359
return Ok(None);
306360
}
@@ -323,18 +377,33 @@ impl Receiver {
323377
let fd_count = get_fd_count(&value);
324378

325379
if fd_count > self.fd_queue.len() {
326-
return Err(Error::MismatchedCount {
327-
expected: fd_count,
328-
found: self.fd_queue.len(),
329-
});
380+
// FDs may arrive across multiple recvmsg() calls when the
381+
// sender batches them. Buffer the parsed message and let
382+
// the receive() loop read more data.
383+
//
384+
// Per the spec (Section 5, Step 4), if the buffer already
385+
// contains non-whitespace bytes the sender has started the
386+
// next message before delivering all FDs — a fatal error.
387+
if self.buffer.iter().any(|&b| !b.is_ascii_whitespace()) {
388+
return Err(Error::MismatchedCount {
389+
expected: fd_count,
390+
found: self.fd_queue.len(),
391+
});
392+
}
393+
debug!(
394+
"Message expects {} FDs but only {} available, waiting for more",
395+
fd_count,
396+
self.fd_queue.len()
397+
);
398+
self.pending_message = Some((value, fd_count));
399+
return Ok(None);
330400
}
331401

332-
let fds: Vec<OwnedFd> = (0..fd_count)
333-
.map(|_| self.fd_queue.pop_front().unwrap())
334-
.collect();
335-
336-
let message = JsonRpcMessage::from_json_value(value)?;
337-
Ok(Some(MessageWithFds::new(message, fds)))
402+
Ok(Some(Self::build_message(
403+
&mut self.fd_queue,
404+
value,
405+
fd_count,
406+
)?))
338407
}
339408
Some(Err(e)) if e.is_eof() => {
340409
// Incomplete JSON - need more data

tests/integration_tests.rs

Lines changed: 189 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2276,3 +2276,192 @@ async fn test_fd_batching_many_fds_small_batches() -> Result<()> {
22762276
server_handle.abort();
22772277
Ok(())
22782278
}
2279+
2280+
/// Test that the receiver correctly waits for batched FDs from the server.
2281+
///
2282+
/// When the server responds with many FDs using a small batch size, the
2283+
/// receiver may parse the JSON message before all FDs have arrived. The
2284+
/// receiver must buffer the parsed message and keep reading until enough
2285+
/// FDs are available, rather than returning a MismatchedCount error.
2286+
#[tokio::test]
2287+
async fn test_receiver_waits_for_batched_response_fds() -> Result<()> {
2288+
let temp_dir = TempDir::new().unwrap();
2289+
let socket_path = temp_dir.path().join("test_receiver_batch.sock");
2290+
2291+
let num_fds = 5;
2292+
2293+
let listener = tokio::net::UnixListener::bind(&socket_path).unwrap();
2294+
2295+
let server_handle = tokio::spawn(async move {
2296+
let (stream, _) = listener.accept().await.unwrap();
2297+
let transport = UnixSocketTransport::new(stream);
2298+
let (mut sender, mut receiver) = transport.split();
2299+
2300+
// Force small batches on the server side so the client's
2301+
// receiver sees FDs arriving across multiple recvmsg() calls.
2302+
sender.set_max_fds_per_sendmsg(NonZeroUsize::new(1).unwrap());
2303+
2304+
// Read the request.
2305+
let request = receiver.receive().await.unwrap();
2306+
assert!(request.file_descriptors.is_empty());
2307+
2308+
// Build a response with many FDs.
2309+
let mut fds: Vec<OwnedFd> = Vec::new();
2310+
for i in 0..num_fds {
2311+
let mut temp_file = NamedTempFile::new().unwrap();
2312+
write!(temp_file, "response file {i}").unwrap();
2313+
temp_file.flush().unwrap();
2314+
temp_file.seek(SeekFrom::Start(0)).unwrap();
2315+
fds.push(temp_file.into_file().into());
2316+
}
2317+
2318+
let response = jsonrpc_fdpass::JsonRpcResponse::success(
2319+
Value::String("here are your files".to_string()),
2320+
Value::Number(1.into()),
2321+
);
2322+
let msg = MessageWithFds::new(JsonRpcMessage::Response(response), fds);
2323+
sender.send(msg).await.unwrap();
2324+
});
2325+
2326+
// Client side: send request, receive response with batched FDs.
2327+
let stream = tokio::net::UnixStream::connect(&socket_path).await.unwrap();
2328+
let transport = UnixSocketTransport::new(stream);
2329+
let (mut sender, mut receiver) = transport.split();
2330+
2331+
let request = JsonRpcRequest::new("get_files".to_string(), None, Value::Number(1.into()));
2332+
sender
2333+
.send(MessageWithFds::new(
2334+
JsonRpcMessage::Request(request),
2335+
Vec::new(),
2336+
))
2337+
.await?;
2338+
2339+
// This is the critical part: the receiver must wait for all FDs
2340+
// instead of failing with MismatchedCount.
2341+
let response = receiver.receive().await?;
2342+
assert_eq!(
2343+
response.file_descriptors.len(),
2344+
num_fds,
2345+
"Expected {num_fds} FDs in batched response"
2346+
);
2347+
2348+
// Verify FD contents are correct and in order.
2349+
for (i, fd) in response.file_descriptors.into_iter().enumerate() {
2350+
let mut file = File::from(fd);
2351+
let mut contents = String::new();
2352+
file.seek(SeekFrom::Start(0)).unwrap();
2353+
file.read_to_string(&mut contents).unwrap();
2354+
assert_eq!(contents, format!("response file {i}"));
2355+
}
2356+
2357+
server_handle.await.unwrap();
2358+
Ok(())
2359+
}
2360+
2361+
/// Test that the receiver returns MismatchedCount when the sender starts a new
2362+
/// message before delivering all FDs for the current one (protocol violation).
2363+
#[tokio::test]
2364+
async fn test_receiver_errors_on_next_message_before_fds() -> Result<()> {
2365+
let temp_dir = TempDir::new().unwrap();
2366+
let socket_path = temp_dir.path().join("test_next_msg_violation.sock");
2367+
2368+
let listener = tokio::net::UnixListener::bind(&socket_path).unwrap();
2369+
2370+
let server_handle = tokio::spawn(async move {
2371+
let (stream, _) = listener.accept().await.unwrap();
2372+
let transport = UnixSocketTransport::new(stream);
2373+
let (_sender, mut receiver) = transport.split();
2374+
2375+
// The client will claim fds but send a second message before
2376+
// delivering them. We expect a MismatchedCount error.
2377+
match receiver.receive().await {
2378+
Err(jsonrpc_fdpass::Error::MismatchedCount { expected, found }) => {
2379+
assert_eq!(expected, 2);
2380+
assert_eq!(found, 0);
2381+
}
2382+
Err(e) => panic!("Expected MismatchedCount, got: {e:?}"),
2383+
Ok(_) => panic!("Should have failed with MismatchedCount"),
2384+
}
2385+
});
2386+
2387+
// Connect and send a message claiming 2 FDs, then immediately send
2388+
// a second message without delivering any FDs.
2389+
let stream = tokio::net::UnixStream::connect(&socket_path).await.unwrap();
2390+
2391+
use tokio::io::AsyncWriteExt;
2392+
let mut stream = stream;
2393+
2394+
let first = serde_json::json!({
2395+
"jsonrpc": "2.0",
2396+
"method": "need_fds",
2397+
"params": {},
2398+
"id": 1,
2399+
"fds": 2
2400+
});
2401+
let second = serde_json::json!({
2402+
"jsonrpc": "2.0",
2403+
"method": "violation",
2404+
"params": {},
2405+
"id": 2
2406+
});
2407+
2408+
// Send both messages back-to-back without any FDs.
2409+
let mut payload = serde_json::to_vec(&first).unwrap();
2410+
payload.extend_from_slice(&serde_json::to_vec(&second).unwrap());
2411+
stream.write_all(&payload).await.unwrap();
2412+
stream.flush().await.unwrap();
2413+
2414+
server_handle.await.unwrap();
2415+
Ok(())
2416+
}
2417+
2418+
/// Test that the receiver returns MismatchedCount when the connection is
2419+
/// closed while waiting for batched FDs.
2420+
#[tokio::test]
2421+
async fn test_receiver_errors_on_close_while_pending() -> Result<()> {
2422+
let temp_dir = TempDir::new().unwrap();
2423+
let socket_path = temp_dir.path().join("test_close_pending.sock");
2424+
2425+
let listener = tokio::net::UnixListener::bind(&socket_path).unwrap();
2426+
2427+
let server_handle = tokio::spawn(async move {
2428+
let (stream, _) = listener.accept().await.unwrap();
2429+
let transport = UnixSocketTransport::new(stream);
2430+
let (_sender, mut receiver) = transport.split();
2431+
2432+
match receiver.receive().await {
2433+
Err(jsonrpc_fdpass::Error::MismatchedCount { expected, found }) => {
2434+
assert_eq!(expected, 3);
2435+
assert_eq!(found, 0);
2436+
}
2437+
Err(e) => panic!("Expected MismatchedCount, got: {e:?}"),
2438+
Ok(_) => panic!("Should have failed with MismatchedCount"),
2439+
}
2440+
});
2441+
2442+
// Connect, send a message claiming 3 FDs, then drop the connection.
2443+
let stream = tokio::net::UnixStream::connect(&socket_path).await.unwrap();
2444+
2445+
use tokio::io::AsyncWriteExt;
2446+
let mut stream = stream;
2447+
2448+
let msg = serde_json::json!({
2449+
"jsonrpc": "2.0",
2450+
"method": "test",
2451+
"params": {},
2452+
"id": 1,
2453+
"fds": 3
2454+
});
2455+
2456+
stream
2457+
.write_all(&serde_json::to_vec(&msg).unwrap())
2458+
.await
2459+
.unwrap();
2460+
stream.flush().await.unwrap();
2461+
2462+
// Close the connection without sending any FDs.
2463+
drop(stream);
2464+
2465+
server_handle.await.unwrap();
2466+
Ok(())
2467+
}

0 commit comments

Comments
 (0)