Skip to content

Commit 33b9744

Browse files
committed
receiver: wait for batched FDs instead of returning MismatchedCount
When the sender batches file descriptors across multiple sendmsg() calls, the receiver may parse the complete JSON message before all FDs have arrived. Previously this returned a MismatchedCount error immediately. Buffer the parsed message in a pending_message field and let the receive() loop continue calling read_more_data() until enough FDs have been collected. Assisted-by: Claude Opus 4.6 <noreply@anthropic.com> Signed-off-by: Giuseppe Scrivano <gscrivan@redhat.com>
1 parent 7557460 commit 33b9744

2 files changed

Lines changed: 268 additions & 11 deletions

File tree

src/transport.rs

Lines changed: 76 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ impl UnixSocketTransport {
5555
stream,
5656
buffer: Vec::new(),
5757
fd_queue: VecDeque::new(),
58+
pending_message: None,
5859
},
5960
)
6061
}
@@ -262,6 +263,8 @@ pub struct Receiver {
262263
stream: Arc<TokioUnixStream>,
263264
buffer: Vec<u8>,
264265
fd_queue: VecDeque<OwnedFd>,
266+
/// A fully parsed JSON message waiting for its FDs to arrive.
267+
pending_message: Option<(serde_json::Value, usize)>,
265268
}
266269

267270
impl Receiver {
@@ -275,7 +278,18 @@ impl Receiver {
275278
return Ok(message);
276279
}
277280

278-
self.read_more_data().await?;
281+
match self.read_more_data().await {
282+
Err(Error::ConnectionClosed) if self.pending_message.is_some() => {
283+
// Connection closed while waiting for FDs — per spec
284+
// Section 5, Step 4 this is a Mismatched Count error.
285+
let (_, fd_count) = self.pending_message.take().unwrap();
286+
return Err(Error::MismatchedCount {
287+
expected: fd_count,
288+
found: self.fd_queue.len(),
289+
});
290+
}
291+
other => other?,
292+
}
279293
}
280294
}
281295

@@ -300,7 +314,43 @@ impl Receiver {
300314
}
301315
}
302316

317+
/// Build a `MessageWithFds` by draining `fd_count` FDs from the queue.
318+
fn build_message(
319+
fd_queue: &mut VecDeque<OwnedFd>,
320+
value: serde_json::Value,
321+
fd_count: usize,
322+
) -> Result<MessageWithFds> {
323+
let fds: Vec<OwnedFd> = (0..fd_count)
324+
.map(|_| fd_queue.pop_front().unwrap())
325+
.collect();
326+
let message = JsonRpcMessage::from_json_value(value)?;
327+
Ok(MessageWithFds::new(message, fds))
328+
}
329+
303330
fn try_parse_message(&mut self) -> Result<Option<MessageWithFds>> {
331+
// Check if we have a pending message waiting for FDs.
332+
if let Some((value, fd_count)) = self.pending_message.take() {
333+
if self.fd_queue.len() >= fd_count {
334+
return Ok(Some(Self::build_message(
335+
&mut self.fd_queue,
336+
value,
337+
fd_count,
338+
)?));
339+
}
340+
// Not enough FDs yet. Per the spec (Section 5, Step 4),
341+
// if the buffer contains any non-whitespace byte the sender
342+
// has started the next message before delivering all FDs for
343+
// the current one — that is a fatal protocol violation.
344+
if self.buffer.iter().any(|&b| !b.is_ascii_whitespace()) {
345+
return Err(Error::MismatchedCount {
346+
expected: fd_count,
347+
found: self.fd_queue.len(),
348+
});
349+
}
350+
self.pending_message = Some((value, fd_count));
351+
return Ok(None);
352+
}
353+
304354
if self.buffer.is_empty() {
305355
return Ok(None);
306356
}
@@ -323,18 +373,33 @@ impl Receiver {
323373
let fd_count = get_fd_count(&value);
324374

325375
if fd_count > self.fd_queue.len() {
326-
return Err(Error::MismatchedCount {
327-
expected: fd_count,
328-
found: self.fd_queue.len(),
329-
});
376+
// FDs may arrive across multiple recvmsg() calls when the
377+
// sender batches them. Buffer the parsed message and let
378+
// the receive() loop read more data.
379+
//
380+
// Per the spec (Section 5, Step 4), if the buffer already
381+
// contains non-whitespace bytes the sender has started the
382+
// next message before delivering all FDs — a fatal error.
383+
if self.buffer.iter().any(|&b| !b.is_ascii_whitespace()) {
384+
return Err(Error::MismatchedCount {
385+
expected: fd_count,
386+
found: self.fd_queue.len(),
387+
});
388+
}
389+
debug!(
390+
"Message expects {} FDs but only {} available, waiting for more",
391+
fd_count,
392+
self.fd_queue.len()
393+
);
394+
self.pending_message = Some((value, fd_count));
395+
return Ok(None);
330396
}
331397

332-
let fds: Vec<OwnedFd> = (0..fd_count)
333-
.map(|_| self.fd_queue.pop_front().unwrap())
334-
.collect();
335-
336-
let message = JsonRpcMessage::from_json_value(value)?;
337-
Ok(Some(MessageWithFds::new(message, fds)))
398+
Ok(Some(Self::build_message(
399+
&mut self.fd_queue,
400+
value,
401+
fd_count,
402+
)?))
338403
}
339404
Some(Err(e)) if e.is_eof() => {
340405
// Incomplete JSON - need more data

tests/integration_tests.rs

Lines changed: 192 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2276,3 +2276,195 @@ async fn test_fd_batching_many_fds_small_batches() -> Result<()> {
22762276
server_handle.abort();
22772277
Ok(())
22782278
}
2279+
2280+
/// Test that the receiver correctly waits for batched FDs from the server.
2281+
///
2282+
/// When the server responds with many FDs using a small batch size, the
2283+
/// receiver may parse the JSON message before all FDs have arrived. The
2284+
/// receiver must buffer the parsed message and keep reading until enough
2285+
/// FDs are available, rather than returning a MismatchedCount error.
2286+
#[tokio::test]
2287+
async fn test_receiver_waits_for_batched_response_fds() -> Result<()> {
2288+
let temp_dir = TempDir::new().unwrap();
2289+
let socket_path = temp_dir.path().join("test_receiver_batch.sock");
2290+
2291+
let num_fds = 5;
2292+
2293+
let listener = tokio::net::UnixListener::bind(&socket_path).unwrap();
2294+
2295+
let server_handle = tokio::spawn(async move {
2296+
if let Ok((stream, _)) = listener.accept().await {
2297+
let transport = UnixSocketTransport::new(stream);
2298+
let (mut sender, mut receiver) = transport.split();
2299+
2300+
// Force small batches on the server side so the client's
2301+
// receiver sees FDs arriving across multiple recvmsg() calls.
2302+
sender.set_max_fds_per_sendmsg(NonZeroUsize::new(1).unwrap());
2303+
2304+
// Read the request.
2305+
let request = receiver.receive().await.unwrap();
2306+
assert!(request.file_descriptors.is_empty());
2307+
2308+
// Build a response with many FDs.
2309+
let mut fds: Vec<OwnedFd> = Vec::new();
2310+
for i in 0..num_fds {
2311+
let mut temp_file = NamedTempFile::new().unwrap();
2312+
write!(temp_file, "response file {i}").unwrap();
2313+
temp_file.flush().unwrap();
2314+
temp_file.seek(SeekFrom::Start(0)).unwrap();
2315+
fds.push(temp_file.into_file().into());
2316+
}
2317+
2318+
let response = jsonrpc_fdpass::JsonRpcResponse::success(
2319+
Value::String("here are your files".to_string()),
2320+
Value::Number(1.into()),
2321+
);
2322+
let msg = MessageWithFds::new(JsonRpcMessage::Response(response), fds);
2323+
sender.send(msg).await.unwrap();
2324+
}
2325+
});
2326+
2327+
// Client side: send request, receive response with batched FDs.
2328+
let stream = tokio::net::UnixStream::connect(&socket_path).await.unwrap();
2329+
let transport = UnixSocketTransport::new(stream);
2330+
let (mut sender, mut receiver) = transport.split();
2331+
2332+
let request = JsonRpcRequest::new("get_files".to_string(), None, Value::Number(1.into()));
2333+
sender
2334+
.send(MessageWithFds::new(
2335+
JsonRpcMessage::Request(request),
2336+
Vec::new(),
2337+
))
2338+
.await?;
2339+
2340+
// This is the critical part: the receiver must wait for all FDs
2341+
// instead of failing with MismatchedCount.
2342+
let response = receiver.receive().await?;
2343+
assert_eq!(
2344+
response.file_descriptors.len(),
2345+
num_fds,
2346+
"Expected {num_fds} FDs in batched response"
2347+
);
2348+
2349+
// Verify FD contents are correct and in order.
2350+
for (i, fd) in response.file_descriptors.into_iter().enumerate() {
2351+
let mut file = File::from(fd);
2352+
let mut contents = String::new();
2353+
file.seek(SeekFrom::Start(0)).unwrap();
2354+
file.read_to_string(&mut contents).unwrap();
2355+
assert_eq!(contents, format!("response file {i}"));
2356+
}
2357+
2358+
server_handle.await.unwrap();
2359+
Ok(())
2360+
}
2361+
2362+
/// Test that the receiver returns MismatchedCount when the sender starts a new
2363+
/// message before delivering all FDs for the current one (protocol violation).
2364+
#[tokio::test]
2365+
async fn test_receiver_errors_on_next_message_before_fds() -> Result<()> {
2366+
let temp_dir = TempDir::new().unwrap();
2367+
let socket_path = temp_dir.path().join("test_next_msg_violation.sock");
2368+
2369+
let listener = tokio::net::UnixListener::bind(&socket_path).unwrap();
2370+
2371+
let server_handle = tokio::spawn(async move {
2372+
if let Ok((stream, _)) = listener.accept().await {
2373+
let transport = UnixSocketTransport::new(stream);
2374+
let (_sender, mut receiver) = transport.split();
2375+
2376+
// The client will claim fds but send a second message before
2377+
// delivering them. We expect a MismatchedCount error.
2378+
match receiver.receive().await {
2379+
Err(jsonrpc_fdpass::Error::MismatchedCount { expected, found }) => {
2380+
assert_eq!(expected, 2);
2381+
assert_eq!(found, 0);
2382+
}
2383+
Err(e) => panic!("Expected MismatchedCount, got: {e:?}"),
2384+
Ok(_) => panic!("Should have failed with MismatchedCount"),
2385+
}
2386+
}
2387+
});
2388+
2389+
// Connect and send a message claiming 2 FDs, then immediately send
2390+
// a second message without delivering any FDs.
2391+
let stream = tokio::net::UnixStream::connect(&socket_path).await.unwrap();
2392+
2393+
use tokio::io::AsyncWriteExt;
2394+
let mut stream = stream;
2395+
2396+
let first = serde_json::json!({
2397+
"jsonrpc": "2.0",
2398+
"method": "need_fds",
2399+
"params": {},
2400+
"id": 1,
2401+
"fds": 2
2402+
});
2403+
let second = serde_json::json!({
2404+
"jsonrpc": "2.0",
2405+
"method": "violation",
2406+
"params": {},
2407+
"id": 2
2408+
});
2409+
2410+
// Send both messages back-to-back without any FDs.
2411+
let mut payload = serde_json::to_vec(&first).unwrap();
2412+
payload.extend_from_slice(&serde_json::to_vec(&second).unwrap());
2413+
stream.write_all(&payload).await.unwrap();
2414+
stream.flush().await.unwrap();
2415+
2416+
server_handle.await.unwrap();
2417+
Ok(())
2418+
}
2419+
2420+
/// Test that the receiver returns MismatchedCount when the connection is
2421+
/// closed while waiting for batched FDs.
2422+
#[tokio::test]
2423+
async fn test_receiver_errors_on_close_while_pending() -> Result<()> {
2424+
let temp_dir = TempDir::new().unwrap();
2425+
let socket_path = temp_dir.path().join("test_close_pending.sock");
2426+
2427+
let listener = tokio::net::UnixListener::bind(&socket_path).unwrap();
2428+
2429+
let server_handle = tokio::spawn(async move {
2430+
if let Ok((stream, _)) = listener.accept().await {
2431+
let transport = UnixSocketTransport::new(stream);
2432+
let (_sender, mut receiver) = transport.split();
2433+
2434+
match receiver.receive().await {
2435+
Err(jsonrpc_fdpass::Error::MismatchedCount { expected, found }) => {
2436+
assert_eq!(expected, 3);
2437+
assert_eq!(found, 0);
2438+
}
2439+
Err(e) => panic!("Expected MismatchedCount, got: {e:?}"),
2440+
Ok(_) => panic!("Should have failed with MismatchedCount"),
2441+
}
2442+
}
2443+
});
2444+
2445+
// Connect, send a message claiming 3 FDs, then drop the connection.
2446+
let stream = tokio::net::UnixStream::connect(&socket_path).await.unwrap();
2447+
2448+
use tokio::io::AsyncWriteExt;
2449+
let mut stream = stream;
2450+
2451+
let msg = serde_json::json!({
2452+
"jsonrpc": "2.0",
2453+
"method": "test",
2454+
"params": {},
2455+
"id": 1,
2456+
"fds": 3
2457+
});
2458+
2459+
stream
2460+
.write_all(&serde_json::to_vec(&msg).unwrap())
2461+
.await
2462+
.unwrap();
2463+
stream.flush().await.unwrap();
2464+
2465+
// Close the connection without sending any FDs.
2466+
drop(stream);
2467+
2468+
server_handle.await.unwrap();
2469+
Ok(())
2470+
}

0 commit comments

Comments
 (0)