Skip to content

Commit a698d62

Browse files
giuseppeclaude
authored andcommitted
receiver: wait for batched FDs instead of returning MismatchedCount
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> Signed-off-by: Giuseppe Scrivano <gscrivan@redhat.com>
1 parent f2e8728 commit a698d62

2 files changed

Lines changed: 270 additions & 11 deletions

File tree

src/transport.rs

Lines changed: 81 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ impl UnixSocketTransport {
5555
stream,
5656
buffer: Vec::new(),
5757
fd_queue: VecDeque::new(),
58+
pending_message: None,
5859
},
5960
)
6061
}
@@ -262,6 +263,8 @@ pub struct Receiver {
262263
stream: Arc<TokioUnixStream>,
263264
buffer: Vec<u8>,
264265
fd_queue: VecDeque<OwnedFd>,
266+
/// A fully parsed JSON message waiting for its FDs to arrive.
267+
pending_message: Option<(serde_json::Value, usize)>,
265268
}
266269

267270
impl Receiver {
@@ -275,7 +278,19 @@ impl Receiver {
275278
return Ok(message);
276279
}
277280

278-
self.read_more_data().await?;
281+
if let Err(e) = self.read_more_data().await {
282+
if matches!(e, Error::ConnectionClosed)
283+
&& let Some((_, fd_count)) = self.pending_message.take()
284+
{
285+
// Connection closed while waiting for FDs — per spec
286+
// Section 5, Step 4 this is a Mismatched Count error.
287+
return Err(Error::MismatchedCount {
288+
expected: fd_count,
289+
found: self.fd_queue.len(),
290+
});
291+
}
292+
return Err(e);
293+
}
279294
}
280295
}
281296

@@ -300,7 +315,47 @@ impl Receiver {
300315
}
301316
}
302317

318+
/// Build a `MessageWithFds` by draining `fd_count` FDs from the queue.
319+
fn build_message(
320+
fd_queue: &mut VecDeque<OwnedFd>,
321+
value: serde_json::Value,
322+
fd_count: usize,
323+
) -> Result<MessageWithFds> {
324+
let fds: Vec<OwnedFd> = fd_queue.drain(..fd_count).collect();
325+
let message = JsonRpcMessage::from_json_value(value)?;
326+
Ok(MessageWithFds::new(message, fds))
327+
}
328+
303329
fn try_parse_message(&mut self) -> Result<Option<MessageWithFds>> {
330+
// Check if we have a pending message waiting for FDs.
331+
// While a message is pending, all subsequent message parsing is
332+
// blocked — even messages needing 0 FDs. This preserves FIFO
333+
// ordering on the Unix socket: FDs queued after the pending
334+
// message's FDs belong to later messages and must not be
335+
// consumed early.
336+
if let Some((value, fd_count)) = self
337+
.pending_message
338+
.take_if(|(_, c)| self.fd_queue.len() >= *c)
339+
{
340+
return Ok(Some(Self::build_message(
341+
&mut self.fd_queue,
342+
value,
343+
fd_count,
344+
)?));
345+
} else if let Some((_, fd_count)) = &self.pending_message {
346+
// Not enough FDs yet. Per the spec (Section 5, Step 4),
347+
// if the buffer contains any non-whitespace byte the sender
348+
// has started the next message before delivering all FDs for
349+
// the current one — that is a fatal protocol violation.
350+
if self.buffer.iter().any(|&b| !b.is_ascii_whitespace()) {
351+
return Err(Error::MismatchedCount {
352+
expected: *fd_count,
353+
found: self.fd_queue.len(),
354+
});
355+
}
356+
return Ok(None);
357+
}
358+
304359
if self.buffer.is_empty() {
305360
return Ok(None);
306361
}
@@ -323,18 +378,33 @@ impl Receiver {
323378
let fd_count = get_fd_count(&value);
324379

325380
if fd_count > self.fd_queue.len() {
326-
return Err(Error::MismatchedCount {
327-
expected: fd_count,
328-
found: self.fd_queue.len(),
329-
});
381+
// FDs may arrive across multiple recvmsg() calls when the
382+
// sender batches them. Buffer the parsed message and let
383+
// the receive() loop read more data.
384+
//
385+
// Per the spec (Section 5, Step 4), if the buffer already
386+
// contains non-whitespace bytes the sender has started the
387+
// next message before delivering all FDs — a fatal error.
388+
if self.buffer.iter().any(|&b| !b.is_ascii_whitespace()) {
389+
return Err(Error::MismatchedCount {
390+
expected: fd_count,
391+
found: self.fd_queue.len(),
392+
});
393+
}
394+
trace!(
395+
"Message expects {} FDs but only {} available, waiting for more",
396+
fd_count,
397+
self.fd_queue.len()
398+
);
399+
self.pending_message = Some((value, fd_count));
400+
return Ok(None);
330401
}
331402

332-
let fds: Vec<OwnedFd> = (0..fd_count)
333-
.map(|_| self.fd_queue.pop_front().unwrap())
334-
.collect();
335-
336-
let message = JsonRpcMessage::from_json_value(value)?;
337-
Ok(Some(MessageWithFds::new(message, fds)))
403+
Ok(Some(Self::build_message(
404+
&mut self.fd_queue,
405+
value,
406+
fd_count,
407+
)?))
338408
}
339409
Some(Err(e)) if e.is_eof() => {
340410
// Incomplete JSON - need more data

tests/integration_tests.rs

Lines changed: 189 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2276,3 +2276,192 @@ async fn test_fd_batching_many_fds_small_batches() -> Result<()> {
22762276
server_handle.abort();
22772277
Ok(())
22782278
}
2279+
2280+
/// Test that the receiver correctly waits for batched FDs from the server.
2281+
///
2282+
/// When the server responds with many FDs using a small batch size, the
2283+
/// receiver may parse the JSON message before all FDs have arrived. The
2284+
/// receiver must buffer the parsed message and keep reading until enough
2285+
/// FDs are available, rather than returning a MismatchedCount error.
2286+
#[tokio::test]
2287+
async fn test_receiver_waits_for_batched_response_fds() -> Result<()> {
2288+
let temp_dir = TempDir::new().unwrap();
2289+
let socket_path = temp_dir.path().join("test_receiver_batch.sock");
2290+
2291+
let num_fds = 5;
2292+
2293+
let listener = tokio::net::UnixListener::bind(&socket_path).unwrap();
2294+
2295+
let server_handle = tokio::spawn(async move {
2296+
let (stream, _) = listener.accept().await.unwrap();
2297+
let transport = UnixSocketTransport::new(stream);
2298+
let (mut sender, mut receiver) = transport.split();
2299+
2300+
// Force small batches on the server side so the client's
2301+
// receiver sees FDs arriving across multiple recvmsg() calls.
2302+
sender.set_max_fds_per_sendmsg(NonZeroUsize::new(1).unwrap());
2303+
2304+
// Read the request.
2305+
let request = receiver.receive().await.unwrap();
2306+
assert!(request.file_descriptors.is_empty());
2307+
2308+
// Build a response with many FDs.
2309+
let mut fds: Vec<OwnedFd> = Vec::new();
2310+
for i in 0..num_fds {
2311+
let mut temp_file = NamedTempFile::new().unwrap();
2312+
write!(temp_file, "response file {i}").unwrap();
2313+
temp_file.flush().unwrap();
2314+
temp_file.seek(SeekFrom::Start(0)).unwrap();
2315+
fds.push(temp_file.into_file().into());
2316+
}
2317+
2318+
let response = jsonrpc_fdpass::JsonRpcResponse::success(
2319+
Value::String("here are your files".to_string()),
2320+
Value::Number(1.into()),
2321+
);
2322+
let msg = MessageWithFds::new(JsonRpcMessage::Response(response), fds);
2323+
sender.send(msg).await.unwrap();
2324+
});
2325+
2326+
// Client side: send request, receive response with batched FDs.
2327+
let stream = tokio::net::UnixStream::connect(&socket_path).await.unwrap();
2328+
let transport = UnixSocketTransport::new(stream);
2329+
let (mut sender, mut receiver) = transport.split();
2330+
2331+
let request = JsonRpcRequest::new("get_files".to_string(), None, Value::Number(1.into()));
2332+
sender
2333+
.send(MessageWithFds::new(
2334+
JsonRpcMessage::Request(request),
2335+
Vec::new(),
2336+
))
2337+
.await?;
2338+
2339+
// This is the critical part: the receiver must wait for all FDs
2340+
// instead of failing with MismatchedCount.
2341+
let response = receiver.receive().await?;
2342+
assert_eq!(
2343+
response.file_descriptors.len(),
2344+
num_fds,
2345+
"Expected {num_fds} FDs in batched response"
2346+
);
2347+
2348+
// Verify FD contents are correct and in order.
2349+
for (i, fd) in response.file_descriptors.into_iter().enumerate() {
2350+
let mut file = File::from(fd);
2351+
let mut contents = String::new();
2352+
file.seek(SeekFrom::Start(0)).unwrap();
2353+
file.read_to_string(&mut contents).unwrap();
2354+
assert_eq!(contents, format!("response file {i}"));
2355+
}
2356+
2357+
server_handle.await.unwrap();
2358+
Ok(())
2359+
}
2360+
2361+
/// Test that the receiver returns MismatchedCount when the sender starts a new
2362+
/// message before delivering all FDs for the current one (protocol violation).
2363+
#[tokio::test]
2364+
async fn test_receiver_errors_on_next_message_before_fds() -> Result<()> {
2365+
let temp_dir = TempDir::new().unwrap();
2366+
let socket_path = temp_dir.path().join("test_next_msg_violation.sock");
2367+
2368+
let listener = tokio::net::UnixListener::bind(&socket_path).unwrap();
2369+
2370+
let server_handle = tokio::spawn(async move {
2371+
let (stream, _) = listener.accept().await.unwrap();
2372+
let transport = UnixSocketTransport::new(stream);
2373+
let (_sender, mut receiver) = transport.split();
2374+
2375+
// The client will claim fds but send a second message before
2376+
// delivering them. We expect a MismatchedCount error.
2377+
match receiver.receive().await {
2378+
Err(jsonrpc_fdpass::Error::MismatchedCount { expected, found }) => {
2379+
assert_eq!(expected, 2);
2380+
assert_eq!(found, 0);
2381+
}
2382+
Err(e) => panic!("Expected MismatchedCount, got: {e:?}"),
2383+
Ok(_) => panic!("Should have failed with MismatchedCount"),
2384+
}
2385+
});
2386+
2387+
// Connect and send a message claiming 2 FDs, then immediately send
2388+
// a second message without delivering any FDs.
2389+
let stream = tokio::net::UnixStream::connect(&socket_path).await.unwrap();
2390+
2391+
use tokio::io::AsyncWriteExt;
2392+
let mut stream = stream;
2393+
2394+
let first = serde_json::json!({
2395+
"jsonrpc": "2.0",
2396+
"method": "need_fds",
2397+
"params": {},
2398+
"id": 1,
2399+
"fds": 2
2400+
});
2401+
let second = serde_json::json!({
2402+
"jsonrpc": "2.0",
2403+
"method": "violation",
2404+
"params": {},
2405+
"id": 2
2406+
});
2407+
2408+
// Send both messages back-to-back without any FDs.
2409+
let mut payload = serde_json::to_vec(&first).unwrap();
2410+
payload.extend_from_slice(&serde_json::to_vec(&second).unwrap());
2411+
stream.write_all(&payload).await.unwrap();
2412+
stream.flush().await.unwrap();
2413+
2414+
server_handle.await.unwrap();
2415+
Ok(())
2416+
}
2417+
2418+
/// Test that the receiver returns MismatchedCount when the connection is
2419+
/// closed while waiting for batched FDs.
2420+
#[tokio::test]
2421+
async fn test_receiver_errors_on_close_while_pending() -> Result<()> {
2422+
let temp_dir = TempDir::new().unwrap();
2423+
let socket_path = temp_dir.path().join("test_close_pending.sock");
2424+
2425+
let listener = tokio::net::UnixListener::bind(&socket_path).unwrap();
2426+
2427+
let server_handle = tokio::spawn(async move {
2428+
let (stream, _) = listener.accept().await.unwrap();
2429+
let transport = UnixSocketTransport::new(stream);
2430+
let (_sender, mut receiver) = transport.split();
2431+
2432+
match receiver.receive().await {
2433+
Err(jsonrpc_fdpass::Error::MismatchedCount { expected, found }) => {
2434+
assert_eq!(expected, 3);
2435+
assert_eq!(found, 0);
2436+
}
2437+
Err(e) => panic!("Expected MismatchedCount, got: {e:?}"),
2438+
Ok(_) => panic!("Should have failed with MismatchedCount"),
2439+
}
2440+
});
2441+
2442+
// Connect, send a message claiming 3 FDs, then drop the connection.
2443+
let stream = tokio::net::UnixStream::connect(&socket_path).await.unwrap();
2444+
2445+
use tokio::io::AsyncWriteExt;
2446+
let mut stream = stream;
2447+
2448+
let msg = serde_json::json!({
2449+
"jsonrpc": "2.0",
2450+
"method": "test",
2451+
"params": {},
2452+
"id": 1,
2453+
"fds": 3
2454+
});
2455+
2456+
stream
2457+
.write_all(&serde_json::to_vec(&msg).unwrap())
2458+
.await
2459+
.unwrap();
2460+
stream.flush().await.unwrap();
2461+
2462+
// Close the connection without sending any FDs.
2463+
drop(stream);
2464+
2465+
server_handle.await.unwrap();
2466+
Ok(())
2467+
}

0 commit comments

Comments
 (0)