Skip to content

Commit 8aa2697

Browse files
committed
receiver: wait for batched FDs instead of returning MismatchedCount
When the sender batches file descriptors across multiple sendmsg() calls, the receiver may parse the complete JSON message before all FDs have arrived. Previously this returned a MismatchedCount error immediately. Buffer the parsed message in a pending_message field and let the receive() loop continue calling read_more_data() until enough FDs have been collected. Assisted-by: Claude Opus 4.6 <noreply@anthropic.com> Signed-off-by: Giuseppe Scrivano <gscrivan@redhat.com>
1 parent 7557460 commit 8aa2697

2 files changed

Lines changed: 112 additions & 4 deletions

File tree

src/transport.rs

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ impl UnixSocketTransport {
5555
stream,
5656
buffer: Vec::new(),
5757
fd_queue: VecDeque::new(),
58+
pending_message: None,
5859
},
5960
)
6061
}
@@ -262,6 +263,8 @@ pub struct Receiver {
262263
stream: Arc<TokioUnixStream>,
263264
buffer: Vec<u8>,
264265
fd_queue: VecDeque<OwnedFd>,
266+
/// A fully parsed JSON message waiting for its FDs to arrive.
267+
pending_message: Option<(serde_json::Value, usize)>,
265268
}
266269

267270
impl Receiver {
@@ -301,6 +304,23 @@ impl Receiver {
301304
}
302305

303306
fn try_parse_message(&mut self) -> Result<Option<MessageWithFds>> {
307+
// Check if we have a pending message waiting for FDs.
308+
if let Some((ref value, fd_count)) = self.pending_message {
309+
if self.fd_queue.len() >= fd_count {
310+
let value = value.clone();
311+
self.pending_message = None;
312+
313+
let fds: Vec<OwnedFd> = (0..fd_count)
314+
.map(|_| self.fd_queue.pop_front().unwrap())
315+
.collect();
316+
317+
let message = JsonRpcMessage::from_json_value(value)?;
318+
return Ok(Some(MessageWithFds::new(message, fds)));
319+
}
320+
// Still waiting for more FDs
321+
return Ok(None);
322+
}
323+
304324
if self.buffer.is_empty() {
305325
return Ok(None);
306326
}
@@ -323,10 +343,16 @@ impl Receiver {
323343
let fd_count = get_fd_count(&value);
324344

325345
if fd_count > self.fd_queue.len() {
326-
return Err(Error::MismatchedCount {
327-
expected: fd_count,
328-
found: self.fd_queue.len(),
329-
});
346+
// FDs may arrive across multiple recvmsg() calls when the
347+
// sender batches them. Buffer the parsed message and let
348+
// the receive() loop read more data.
349+
debug!(
350+
"Message expects {} FDs but only {} available, waiting for more",
351+
fd_count,
352+
self.fd_queue.len()
353+
);
354+
self.pending_message = Some((value, fd_count));
355+
return Ok(None);
330356
}
331357

332358
let fds: Vec<OwnedFd> = (0..fd_count)

tests/integration_tests.rs

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2276,3 +2276,85 @@ async fn test_fd_batching_many_fds_small_batches() -> Result<()> {
22762276
server_handle.abort();
22772277
Ok(())
22782278
}
2279+
2280+
/// Test that the receiver correctly waits for batched FDs from the server.
2281+
///
2282+
/// When the server responds with many FDs using a small batch size, the
2283+
/// receiver may parse the JSON message before all FDs have arrived. The
2284+
/// receiver must buffer the parsed message and keep reading until enough
2285+
/// FDs are available, rather than returning a MismatchedCount error.
2286+
#[tokio::test]
2287+
async fn test_receiver_waits_for_batched_response_fds() -> Result<()> {
2288+
let temp_dir = TempDir::new().unwrap();
2289+
let socket_path = temp_dir.path().join("test_receiver_batch.sock");
2290+
2291+
let num_fds = 5;
2292+
2293+
let listener = tokio::net::UnixListener::bind(&socket_path).unwrap();
2294+
2295+
let server_handle = tokio::spawn(async move {
2296+
if let Ok((stream, _)) = listener.accept().await {
2297+
let transport = UnixSocketTransport::new(stream);
2298+
let (mut sender, mut receiver) = transport.split();
2299+
2300+
// Force small batches on the server side so the client's
2301+
// receiver sees FDs arriving across multiple recvmsg() calls.
2302+
sender.set_max_fds_per_sendmsg(NonZeroUsize::new(1).unwrap());
2303+
2304+
// Read the request.
2305+
let request = receiver.receive().await.unwrap();
2306+
assert!(request.file_descriptors.is_empty());
2307+
2308+
// Build a response with many FDs.
2309+
let mut fds: Vec<OwnedFd> = Vec::new();
2310+
for i in 0..num_fds {
2311+
let mut temp_file = NamedTempFile::new().unwrap();
2312+
write!(temp_file, "response file {i}").unwrap();
2313+
temp_file.flush().unwrap();
2314+
temp_file.seek(SeekFrom::Start(0)).unwrap();
2315+
fds.push(temp_file.into_file().into());
2316+
}
2317+
2318+
let response = jsonrpc_fdpass::JsonRpcResponse::success(
2319+
Value::String("here are your files".to_string()),
2320+
Value::Number(1.into()),
2321+
);
2322+
let msg = MessageWithFds::new(JsonRpcMessage::Response(response), fds);
2323+
sender.send(msg).await.unwrap();
2324+
}
2325+
});
2326+
2327+
// Client side: send request, receive response with batched FDs.
2328+
let stream = tokio::net::UnixStream::connect(&socket_path).await.unwrap();
2329+
let transport = UnixSocketTransport::new(stream);
2330+
let (mut sender, mut receiver) = transport.split();
2331+
2332+
let request = JsonRpcRequest::new("get_files".to_string(), None, Value::Number(1.into()));
2333+
sender
2334+
.send(MessageWithFds::new(
2335+
JsonRpcMessage::Request(request),
2336+
Vec::new(),
2337+
))
2338+
.await?;
2339+
2340+
// This is the critical part: the receiver must wait for all FDs
2341+
// instead of failing with MismatchedCount.
2342+
let response = receiver.receive().await?;
2343+
assert_eq!(
2344+
response.file_descriptors.len(),
2345+
num_fds,
2346+
"Expected {num_fds} FDs in batched response"
2347+
);
2348+
2349+
// Verify FD contents are correct and in order.
2350+
for (i, fd) in response.file_descriptors.into_iter().enumerate() {
2351+
let mut file = File::from(fd);
2352+
let mut contents = String::new();
2353+
file.seek(SeekFrom::Start(0)).unwrap();
2354+
file.read_to_string(&mut contents).unwrap();
2355+
assert_eq!(contents, format!("response file {i}"));
2356+
}
2357+
2358+
server_handle.await.unwrap();
2359+
Ok(())
2360+
}

0 commit comments

Comments
 (0)