@@ -186,15 +186,13 @@ impl<'a> generated::types::HostClientHandshake for WasiTlsCtx<'a> {
186186
187187 Ok ( self
188188 . table
189- . push ( FutureClientStreams ( StreamState :: Pending ( Box :: pin (
190- async move {
191- let connector = tokio_rustls:: TlsConnector :: from ( default_client_config ( ) ) ;
192- connector
193- . connect ( domain, streams)
194- . await
195- . with_context ( || "connection failed" )
196- } ,
197- ) ) ) ) ?)
189+ . push ( FutureStreams ( StreamState :: Pending ( Box :: pin ( async move {
190+ let connector = tokio_rustls:: TlsConnector :: from ( default_client_config ( ) ) ;
191+ connector
192+ . connect ( domain, streams)
193+ . await
194+ . with_context ( || "connection failed" )
195+ } ) ) ) ) ?)
198196 }
199197
200198 fn drop (
@@ -206,21 +204,20 @@ impl<'a> generated::types::HostClientHandshake for WasiTlsCtx<'a> {
206204 }
207205}
208206
209- /// Future TLS connection after the handshake is completed.
210- pub struct FutureClientStreams ( StreamState < Result < TlsStream < WasiStreams > > > ) ;
207+ /// Future streams provides the tls streams after the handshake is completed
208+ pub struct FutureStreams < T > ( StreamState < Result < T > > ) ;
209+
210+ /// Library specific version of TLS connection after the handshake is completed.
211+ /// This alias allows it to use with wit-bindgen component generator which won't take generic types
212+ pub type FutureClientStreams = FutureStreams < TlsStream < WasiStreams > > ;
211213
212214#[ async_trait]
213- impl Pollable for FutureClientStreams {
215+ impl < T : Send + ' static > Pollable for FutureStreams < T > {
214216 async fn ready ( & mut self ) {
215- match & self . 0 {
216- StreamState :: Pending ( _) => ( ) ,
217+ match & mut self . 0 {
217218 StreamState :: Ready ( _) | StreamState :: Closed => return ,
219+ StreamState :: Pending ( task) => self . 0 = StreamState :: Ready ( task. as_mut ( ) . await ) ,
218220 }
219-
220- let StreamState :: Pending ( future) = mem:: replace ( & mut self . 0 , StreamState :: Closed ) else {
221- unreachable ! ( )
222- } ;
223- self . 0 = StreamState :: Ready ( future. await ) ;
224221 }
225222}
226223
@@ -324,7 +321,8 @@ enum StreamState<T> {
324321 Closed ,
325322}
326323
327- struct WasiStreams {
324+ /// Wrapper around Input and Output wasi IO Stream that provides Async Read/Write
325+ pub struct WasiStreams {
328326 input : StreamState < BoxInputStream > ,
329327 output : StreamState < BoxOutputStream > ,
330328}
@@ -635,3 +633,40 @@ fn try_lock_for_stream<TlsWriter>(
635633 . try_lock ( )
636634 . map_err ( |_| StreamError :: trap ( "concurrent access to resource not supported" ) )
637635}
636+
637+ #[ cfg( test) ]
638+ mod tests {
639+ use super :: * ;
640+ use tokio:: sync:: oneshot;
641+
642+ #[ tokio:: test]
643+ async fn test_future_client_streams_ready_can_be_canceled ( ) {
644+ let ( tx1, rx1) = oneshot:: channel :: < ( ) > ( ) ;
645+
646+ let mut future_streams = FutureStreams ( StreamState :: Pending ( Box :: pin ( async move {
647+ rx1. await . map_err ( |_| anyhow:: anyhow!( "oneshot canceled" ) )
648+ } ) ) ) ;
649+
650+ let mut fut = future_streams. ready ( ) ;
651+
652+ let mut cx = std:: task:: Context :: from_waker ( futures:: task:: noop_waker_ref ( ) ) ;
653+ assert ! ( fut. as_mut( ) . poll( & mut cx) . is_pending( ) ) ;
654+
655+ //cancel the readiness check
656+ drop ( fut) ;
657+
658+ match future_streams. 0 {
659+ StreamState :: Closed => panic ! ( "First future should be in Pending/ready state" ) ,
660+ _ => ( ) ,
661+ }
662+
663+ // make it ready and wait for it to progress
664+ tx1. send ( ( ) ) . unwrap ( ) ;
665+ future_streams. ready ( ) . await ;
666+
667+ match future_streams. 0 {
668+ StreamState :: Ready ( Ok ( ( ) ) ) => ( ) ,
669+ _ => panic ! ( "First future should be in Ready(Err) state" ) ,
670+ }
671+ }
672+ }
0 commit comments