Skip to content

Commit f13938e

Browse files
committed
Fix issue when another pollable cancels
Signed-off-by: James Sturtevant <jstur@microsoft.com>
1 parent 3cd3126 commit f13938e

File tree

4 files changed

+58
-21
lines changed

4 files changed

+58
-21
lines changed

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -587,4 +587,4 @@ wasm-mutate = { git = "https://github.com/jsturtevant/wasm-tools.git", branch="
587587
wit-parser = { git = "https://github.com/jsturtevant/wasm-tools.git", branch="unstable-refrence-stable" }
588588
wit-component = { git = "https://github.com/jsturtevant/wasm-tools.git", branch="unstable-refrence-stable" }
589589
wasm-wave = { git = "https://github.com/jsturtevant/wasm-tools.git", branch="unstable-refrence-stable" }
590-
wasm-metadata = { git = "https://github.com/jsturtevant/wasm-tools.git", branch="unstable-refrence-stable" }
590+
wasm-metadata = { git = "https://github.com/jsturtevant/wasm-tools.git", branch="unstable-refrence-stable" }

crates/wasi-tls/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,3 +33,4 @@ webpki-roots = { workspace = true }
3333
test-programs-artifacts = { workspace = true }
3434
wasmtime-wasi = { workspace = true }
3535
tokio = { workspace = true, features = ["macros"] }
36+
futures = { workspace = true }

crates/wasi-tls/src/lib.rs

Lines changed: 55 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -186,15 +186,13 @@ impl<'a> generated::types::HostClientHandshake for WasiTlsCtx<'a> {
186186

187187
Ok(self
188188
.table
189-
.push(FutureClientStreams(StreamState::Pending(Box::pin(
190-
async move {
191-
let connector = tokio_rustls::TlsConnector::from(default_client_config());
192-
connector
193-
.connect(domain, streams)
194-
.await
195-
.with_context(|| "connection failed")
196-
},
197-
))))?)
189+
.push(FutureStreams(StreamState::Pending(Box::pin(async move {
190+
let connector = tokio_rustls::TlsConnector::from(default_client_config());
191+
connector
192+
.connect(domain, streams)
193+
.await
194+
.with_context(|| "connection failed")
195+
}))))?)
198196
}
199197

200198
fn drop(
@@ -206,21 +204,20 @@ impl<'a> generated::types::HostClientHandshake for WasiTlsCtx<'a> {
206204
}
207205
}
208206

209-
/// Future TLS connection after the handshake is completed.
210-
pub struct FutureClientStreams(StreamState<Result<TlsStream<WasiStreams>>>);
207+
/// Future streams provides the tls streams after the handshake is completed
208+
pub struct FutureStreams<T>(StreamState<Result<T>>);
209+
210+
/// Library specific version of TLS connection after the handshake is completed.
211+
/// This alias allows it to use with wit-bindgen component generator which won't take generic types
212+
pub type FutureClientStreams = FutureStreams<TlsStream<WasiStreams>>;
211213

212214
#[async_trait]
213-
impl Pollable for FutureClientStreams {
215+
impl<T: Send + 'static> Pollable for FutureStreams<T> {
214216
async fn ready(&mut self) {
215-
match &self.0 {
216-
StreamState::Pending(_) => (),
217+
match &mut self.0 {
217218
StreamState::Ready(_) | StreamState::Closed => return,
219+
StreamState::Pending(task) => self.0 = StreamState::Ready(task.as_mut().await),
218220
}
219-
220-
let StreamState::Pending(future) = mem::replace(&mut self.0, StreamState::Closed) else {
221-
unreachable!()
222-
};
223-
self.0 = StreamState::Ready(future.await);
224221
}
225222
}
226223

@@ -324,7 +321,8 @@ enum StreamState<T> {
324321
Closed,
325322
}
326323

327-
struct WasiStreams {
324+
/// Wrapper around Input and Output wasi IO Stream that provides Async Read/Write
325+
pub struct WasiStreams {
328326
input: StreamState<BoxInputStream>,
329327
output: StreamState<BoxOutputStream>,
330328
}
@@ -635,3 +633,40 @@ fn try_lock_for_stream<TlsWriter>(
635633
.try_lock()
636634
.map_err(|_| StreamError::trap("concurrent access to resource not supported"))
637635
}
636+
637+
#[cfg(test)]
638+
mod tests {
639+
use super::*;
640+
use tokio::sync::oneshot;
641+
642+
#[tokio::test]
643+
async fn test_future_client_streams_ready_can_be_canceled() {
644+
let (tx1, rx1) = oneshot::channel::<()>();
645+
646+
let mut future_streams = FutureStreams(StreamState::Pending(Box::pin(async move {
647+
rx1.await.map_err(|_| anyhow::anyhow!("oneshot canceled"))
648+
})));
649+
650+
let mut fut = future_streams.ready();
651+
652+
let mut cx = std::task::Context::from_waker(futures::task::noop_waker_ref());
653+
assert!(fut.as_mut().poll(&mut cx).is_pending());
654+
655+
//cancel the readiness check
656+
drop(fut);
657+
658+
match future_streams.0 {
659+
StreamState::Closed => panic!("First future should be in Pending/ready state"),
660+
_ => (),
661+
}
662+
663+
// make it ready and wait for it to progress
664+
tx1.send(()).unwrap();
665+
future_streams.ready().await;
666+
667+
match future_streams.0 {
668+
StreamState::Ready(Ok(())) => (),
669+
_ => panic!("First future should be in Ready(Err) state"),
670+
}
671+
}
672+
}

0 commit comments

Comments
 (0)