@@ -2276,3 +2276,195 @@ async fn test_fd_batching_many_fds_small_batches() -> Result<()> {
22762276 server_handle. abort ( ) ;
22772277 Ok ( ( ) )
22782278}
2279+
2280+ /// Test that the receiver correctly waits for batched FDs from the server.
2281+ ///
2282+ /// When the server responds with many FDs using a small batch size, the
2283+ /// receiver may parse the JSON message before all FDs have arrived. The
2284+ /// receiver must buffer the parsed message and keep reading until enough
2285+ /// FDs are available, rather than returning a MismatchedCount error.
2286+ #[ tokio:: test]
2287+ async fn test_receiver_waits_for_batched_response_fds ( ) -> Result < ( ) > {
2288+ let temp_dir = TempDir :: new ( ) . unwrap ( ) ;
2289+ let socket_path = temp_dir. path ( ) . join ( "test_receiver_batch.sock" ) ;
2290+
2291+ let num_fds = 5 ;
2292+
2293+ let listener = tokio:: net:: UnixListener :: bind ( & socket_path) . unwrap ( ) ;
2294+
2295+ let server_handle = tokio:: spawn ( async move {
2296+ if let Ok ( ( stream, _) ) = listener. accept ( ) . await {
2297+ let transport = UnixSocketTransport :: new ( stream) ;
2298+ let ( mut sender, mut receiver) = transport. split ( ) ;
2299+
2300+ // Force small batches on the server side so the client's
2301+ // receiver sees FDs arriving across multiple recvmsg() calls.
2302+ sender. set_max_fds_per_sendmsg ( NonZeroUsize :: new ( 1 ) . unwrap ( ) ) ;
2303+
2304+ // Read the request.
2305+ let request = receiver. receive ( ) . await . unwrap ( ) ;
2306+ assert ! ( request. file_descriptors. is_empty( ) ) ;
2307+
2308+ // Build a response with many FDs.
2309+ let mut fds: Vec < OwnedFd > = Vec :: new ( ) ;
2310+ for i in 0 ..num_fds {
2311+ let mut temp_file = NamedTempFile :: new ( ) . unwrap ( ) ;
2312+ write ! ( temp_file, "response file {i}" ) . unwrap ( ) ;
2313+ temp_file. flush ( ) . unwrap ( ) ;
2314+ temp_file. seek ( SeekFrom :: Start ( 0 ) ) . unwrap ( ) ;
2315+ fds. push ( temp_file. into_file ( ) . into ( ) ) ;
2316+ }
2317+
2318+ let response = jsonrpc_fdpass:: JsonRpcResponse :: success (
2319+ Value :: String ( "here are your files" . to_string ( ) ) ,
2320+ Value :: Number ( 1 . into ( ) ) ,
2321+ ) ;
2322+ let msg = MessageWithFds :: new ( JsonRpcMessage :: Response ( response) , fds) ;
2323+ sender. send ( msg) . await . unwrap ( ) ;
2324+ }
2325+ } ) ;
2326+
2327+ // Client side: send request, receive response with batched FDs.
2328+ let stream = tokio:: net:: UnixStream :: connect ( & socket_path) . await . unwrap ( ) ;
2329+ let transport = UnixSocketTransport :: new ( stream) ;
2330+ let ( mut sender, mut receiver) = transport. split ( ) ;
2331+
2332+ let request = JsonRpcRequest :: new ( "get_files" . to_string ( ) , None , Value :: Number ( 1 . into ( ) ) ) ;
2333+ sender
2334+ . send ( MessageWithFds :: new (
2335+ JsonRpcMessage :: Request ( request) ,
2336+ Vec :: new ( ) ,
2337+ ) )
2338+ . await ?;
2339+
2340+ // This is the critical part: the receiver must wait for all FDs
2341+ // instead of failing with MismatchedCount.
2342+ let response = receiver. receive ( ) . await ?;
2343+ assert_eq ! (
2344+ response. file_descriptors. len( ) ,
2345+ num_fds,
2346+ "Expected {num_fds} FDs in batched response"
2347+ ) ;
2348+
2349+ // Verify FD contents are correct and in order.
2350+ for ( i, fd) in response. file_descriptors . into_iter ( ) . enumerate ( ) {
2351+ let mut file = File :: from ( fd) ;
2352+ let mut contents = String :: new ( ) ;
2353+ file. seek ( SeekFrom :: Start ( 0 ) ) . unwrap ( ) ;
2354+ file. read_to_string ( & mut contents) . unwrap ( ) ;
2355+ assert_eq ! ( contents, format!( "response file {i}" ) ) ;
2356+ }
2357+
2358+ server_handle. await . unwrap ( ) ;
2359+ Ok ( ( ) )
2360+ }
2361+
2362+ /// Test that the receiver returns MismatchedCount when the sender starts a new
2363+ /// message before delivering all FDs for the current one (protocol violation).
2364+ #[ tokio:: test]
2365+ async fn test_receiver_errors_on_next_message_before_fds ( ) -> Result < ( ) > {
2366+ let temp_dir = TempDir :: new ( ) . unwrap ( ) ;
2367+ let socket_path = temp_dir. path ( ) . join ( "test_next_msg_violation.sock" ) ;
2368+
2369+ let listener = tokio:: net:: UnixListener :: bind ( & socket_path) . unwrap ( ) ;
2370+
2371+ let server_handle = tokio:: spawn ( async move {
2372+ if let Ok ( ( stream, _) ) = listener. accept ( ) . await {
2373+ let transport = UnixSocketTransport :: new ( stream) ;
2374+ let ( _sender, mut receiver) = transport. split ( ) ;
2375+
2376+ // The client will claim fds but send a second message before
2377+ // delivering them. We expect a MismatchedCount error.
2378+ match receiver. receive ( ) . await {
2379+ Err ( jsonrpc_fdpass:: Error :: MismatchedCount { expected, found } ) => {
2380+ assert_eq ! ( expected, 2 ) ;
2381+ assert_eq ! ( found, 0 ) ;
2382+ }
2383+ Err ( e) => panic ! ( "Expected MismatchedCount, got: {e:?}" ) ,
2384+ Ok ( _) => panic ! ( "Should have failed with MismatchedCount" ) ,
2385+ }
2386+ }
2387+ } ) ;
2388+
2389+ // Connect and send a message claiming 2 FDs, then immediately send
2390+ // a second message without delivering any FDs.
2391+ let stream = tokio:: net:: UnixStream :: connect ( & socket_path) . await . unwrap ( ) ;
2392+
2393+ use tokio:: io:: AsyncWriteExt ;
2394+ let mut stream = stream;
2395+
2396+ let first = serde_json:: json!( {
2397+ "jsonrpc" : "2.0" ,
2398+ "method" : "need_fds" ,
2399+ "params" : { } ,
2400+ "id" : 1 ,
2401+ "fds" : 2
2402+ } ) ;
2403+ let second = serde_json:: json!( {
2404+ "jsonrpc" : "2.0" ,
2405+ "method" : "violation" ,
2406+ "params" : { } ,
2407+ "id" : 2
2408+ } ) ;
2409+
2410+ // Send both messages back-to-back without any FDs.
2411+ let mut payload = serde_json:: to_vec ( & first) . unwrap ( ) ;
2412+ payload. extend_from_slice ( & serde_json:: to_vec ( & second) . unwrap ( ) ) ;
2413+ stream. write_all ( & payload) . await . unwrap ( ) ;
2414+ stream. flush ( ) . await . unwrap ( ) ;
2415+
2416+ server_handle. await . unwrap ( ) ;
2417+ Ok ( ( ) )
2418+ }
2419+
2420+ /// Test that the receiver returns MismatchedCount when the connection is
2421+ /// closed while waiting for batched FDs.
2422+ #[ tokio:: test]
2423+ async fn test_receiver_errors_on_close_while_pending ( ) -> Result < ( ) > {
2424+ let temp_dir = TempDir :: new ( ) . unwrap ( ) ;
2425+ let socket_path = temp_dir. path ( ) . join ( "test_close_pending.sock" ) ;
2426+
2427+ let listener = tokio:: net:: UnixListener :: bind ( & socket_path) . unwrap ( ) ;
2428+
2429+ let server_handle = tokio:: spawn ( async move {
2430+ if let Ok ( ( stream, _) ) = listener. accept ( ) . await {
2431+ let transport = UnixSocketTransport :: new ( stream) ;
2432+ let ( _sender, mut receiver) = transport. split ( ) ;
2433+
2434+ match receiver. receive ( ) . await {
2435+ Err ( jsonrpc_fdpass:: Error :: MismatchedCount { expected, found } ) => {
2436+ assert_eq ! ( expected, 3 ) ;
2437+ assert_eq ! ( found, 0 ) ;
2438+ }
2439+ Err ( e) => panic ! ( "Expected MismatchedCount, got: {e:?}" ) ,
2440+ Ok ( _) => panic ! ( "Should have failed with MismatchedCount" ) ,
2441+ }
2442+ }
2443+ } ) ;
2444+
2445+ // Connect, send a message claiming 3 FDs, then drop the connection.
2446+ let stream = tokio:: net:: UnixStream :: connect ( & socket_path) . await . unwrap ( ) ;
2447+
2448+ use tokio:: io:: AsyncWriteExt ;
2449+ let mut stream = stream;
2450+
2451+ let msg = serde_json:: json!( {
2452+ "jsonrpc" : "2.0" ,
2453+ "method" : "test" ,
2454+ "params" : { } ,
2455+ "id" : 1 ,
2456+ "fds" : 3
2457+ } ) ;
2458+
2459+ stream
2460+ . write_all ( & serde_json:: to_vec ( & msg) . unwrap ( ) )
2461+ . await
2462+ . unwrap ( ) ;
2463+ stream. flush ( ) . await . unwrap ( ) ;
2464+
2465+ // Close the connection without sending any FDs.
2466+ drop ( stream) ;
2467+
2468+ server_handle. await . unwrap ( ) ;
2469+ Ok ( ( ) )
2470+ }
0 commit comments