diff --git a/http/src/h1/dispatcher.rs b/http/src/h1/dispatcher.rs index df8e1efa0..7ecd87d1c 100644 --- a/http/src/h1/dispatcher.rs +++ b/http/src/h1/dispatcher.rs @@ -12,6 +12,7 @@ use std::net::Shutdown; use tracing::trace; use xitca_io::io::{AsyncBufRead, AsyncBufWrite}; use xitca_service::Service; +use xitca_service::shutdown::{ShutdownFutureExt, ShutdownToken}; use xitca_unsafe_collection::futures::SelectOutput; use crate::{ @@ -58,6 +59,7 @@ where config: HttpServiceConfig, service: &'a S, date: &'a D, + shutdown_token: &ShutdownToken, ) -> Result<(), Error> { let mut dispatcher = Dispatcher::<_, _, _, _, H_LIMIT, R_LIMIT, W_LIMIT> { io: SharedIo::new(io), @@ -90,15 +92,23 @@ where dispatcher.timer.update(dispatcher.ctx.date().now()); - match read_buf.read(dispatcher.io.io()).timeout(dispatcher.timer.get()).await { - Ok((res, r_buf)) => { + let is_read_buf_empty = read_buf.is_empty(); + + match read_buf + .read(dispatcher.io.io()) + .timeout(dispatcher.timer.get()) + .with_shutdown(shutdown_token, is_read_buf_empty) + .await + { + Some(Ok((res, r_buf))) => { read_buf = r_buf; if res? == 0 { break Ok(()); } } - Err(_) => break Err(dispatcher.timer.map_to_err()), + Some(Err(_)) => break Err(dispatcher.timer.map_to_err()), + None => break Ok(()), } }; diff --git a/http/src/h1/service.rs b/http/src/h1/service.rs index bb6ab053d..b6c6a4a39 100644 --- a/http/src/h1/service.rs +++ b/http/src/h1/service.rs @@ -1,8 +1,10 @@ use core::{net::SocketAddr, pin::pin}; +use std::sync::Arc; + use crate::body::Body; use xitca_io::io::{AsyncBufRead, AsyncBufWrite}; -use xitca_service::Service; +use xitca_service::{Service, shutdown::ShutdownToken}; use crate::{ body::RequestBody, @@ -18,7 +20,8 @@ pub type H1Service; impl - Service<(St, SocketAddr)> for H1Service + Service<((St, SocketAddr), Arc)> + for H1Service where S: Service>, Response = Response>, A: Service, @@ -29,7 +32,10 @@ where type Response = (); type Error = HttpServiceError; - async fn call(&self, (io, addr): (St, SocketAddr)) -> Result { + async fn call( + &self, + ((io, addr), st): ((St, SocketAddr), Arc), + ) -> Result { // at this stage keep-alive timer is used to tracks tls accept timeout. let mut timer = pin!(self.keep_alive()); @@ -48,6 +54,7 @@ where self.config, &self.service, self.date.get(), + &st, ) .await .map_err(Into::into) diff --git a/http/src/h2/service.rs b/http/src/h2/service.rs index cb178a993..406196245 100644 --- a/http/src/h2/service.rs +++ b/http/src/h2/service.rs @@ -1,7 +1,9 @@ use core::{fmt, net::SocketAddr, pin::pin}; +use std::sync::Arc; + use xitca_io::io::{AsyncBufRead, AsyncBufWrite}; -use xitca_service::Service; +use xitca_service::{Service, shutdown::ShutdownToken}; use crate::{ body::Body, @@ -19,7 +21,8 @@ pub type H2Service; impl - Service<(St, SocketAddr)> for H2Service + Service<((St, SocketAddr), Arc)> + for H2Service where S: Service>, Response = Response>, S::Error: fmt::Debug, @@ -32,7 +35,10 @@ where type Response = (); type Error = HttpServiceError; - async fn call(&self, (io, addr): (St, SocketAddr)) -> Result { + async fn call( + &self, + ((io, addr), _st): ((St, SocketAddr), Arc), + ) -> Result { // tls accept timer. let timer = self.keep_alive(); let mut timer = pin!(timer); diff --git a/http/src/h3/proto/dispatcher.rs b/http/src/h3/proto/dispatcher.rs index bbd62de7e..4f81f1a68 100644 --- a/http/src/h3/proto/dispatcher.rs +++ b/http/src/h3/proto/dispatcher.rs @@ -11,7 +11,10 @@ use ::h3::{ server::{self, RequestStream}, }; use xitca_io::net::QuicStream; -use xitca_service::Service; +use xitca_service::{ + Service, + shutdown::{ShutdownFutureExt, ShutdownToken}, +}; use xitca_unsafe_collection::futures::{Select, SelectOutput}; use crate::{ @@ -48,7 +51,7 @@ where } } - pub(crate) async fn run(self) -> Result<(), Error> { + pub(crate) async fn run(self, st: &ShutdownToken) -> Result<(), Error> { // wait for connecting. let conn = self.io.connecting().await?; @@ -60,8 +63,8 @@ where // accept loop loop { - match conn.accept().select(queue.next()).await { - SelectOutput::A(Ok(Some(req))) => { + match conn.accept().select(queue.next()).with_shutdown(st, true).await { + Some(SelectOutput::A(Ok(Some(req)))) => { queue.push(async move { let (req, stream) = req.resolve_request().await?; let (tx, rx) = stream.split(); @@ -76,13 +79,14 @@ where h3_handler(fut, tx).await }); } - SelectOutput::A(Ok(None)) => break, - SelectOutput::A(Err(e)) => return Err(e.into()), - SelectOutput::B(res) => { + Some(SelectOutput::A(Ok(None))) => break, + Some(SelectOutput::A(Err(e))) => return Err(e.into()), + Some(SelectOutput::B(res)) => { if let Err(e) = res { HttpServiceError::from(e).log("h3_dispatcher"); } } + None => break, } } diff --git a/http/src/h3/service.rs b/http/src/h3/service.rs index d2fda176c..751db29bc 100644 --- a/http/src/h3/service.rs +++ b/http/src/h3/service.rs @@ -1,8 +1,10 @@ use core::{fmt, net::SocketAddr}; +use std::sync::Arc; + use crate::body::Body; use xitca_io::net::QuicStream; -use xitca_service::{Service, ready::ReadyService}; +use xitca_service::{Service, ready::ReadyService, shutdown::ShutdownToken}; use crate::{ bytes::Bytes, @@ -24,7 +26,7 @@ impl H3Service { } } -impl Service<(QuicStream, SocketAddr)> for H3Service +impl Service<((QuicStream, SocketAddr), Arc)> for H3Service where S: Service>, Response = Response>, S::Error: fmt::Debug, @@ -33,10 +35,13 @@ where { type Response = (); type Error = HttpServiceError; - async fn call(&self, (stream, addr): (QuicStream, SocketAddr)) -> Result { + async fn call( + &self, + ((stream, addr), st): ((QuicStream, SocketAddr), Arc), + ) -> Result { let dispatcher = Dispatcher::new(stream, addr, &self.service); - dispatcher.run().await?; + dispatcher.run(st.as_ref()).await?; Ok(()) } diff --git a/http/src/service.rs b/http/src/service.rs index fa0214033..6819221fb 100644 --- a/http/src/service.rs +++ b/http/src/service.rs @@ -1,10 +1,12 @@ use core::{fmt, marker::PhantomData, pin::pin}; +use std::sync::Arc; + use xitca_io::{ io::{AsyncBufRead, AsyncBufWrite}, net::{Stream, TcpStream}, }; -use xitca_service::{Service, ready::ReadyService}; +use xitca_service::{Service, ready::ReadyService, shutdown::ShutdownToken}; use super::{ body::{Body, RequestBody}, @@ -169,6 +171,7 @@ where _tls_stream: impl AsVersion + AsyncBufRead + AsyncBufWrite + 'static, _addr: core::net::SocketAddr, mut _timer: core::pin::Pin<&mut KeepAlive>, + shutdown_token: &ShutdownToken, ) -> Result<(), HttpServiceError> { #[allow(unused_mut)] let mut version = _tls_stream.as_version(); @@ -197,6 +200,7 @@ where self.config, &self.service, self.date.get(), + shutdown_token, ) .await .map_err(From::from), @@ -225,7 +229,7 @@ where } #[cfg(feature = "io-uring")] -impl Service +impl Service<(Stream, Arc)> for HttpService where S: Service>, Response = Response>, @@ -237,7 +241,7 @@ where type Response = (); type Error = HttpServiceError; - async fn call(&self, io: Stream) -> Result { + async fn call(&self, (io, st): (Stream, Arc)) -> Result { // tls accept timer. let timer = self.keep_alive(); let mut timer = pin!(timer); @@ -245,7 +249,7 @@ where match io { #[cfg(feature = "http3")] Stream::Udp(io, addr) => super::h3::Dispatcher::new(io, addr, &self.service) - .run() + .run(st.as_ref()) .await .map_err(From::from), Stream::Tcp(io, _addr) => { @@ -258,7 +262,7 @@ where .map_err(|_| HttpServiceError::Timeout(TimeoutError::TlsAccept))? .map_err(Into::into)?; - self.dispatch(_tls_stream, _addr, timer.as_mut()).await + self.dispatch(_tls_stream, _addr, timer.as_mut(), st.as_ref()).await } #[cfg(unix)] Stream::Unix(_io, _) => { @@ -271,14 +275,15 @@ where .map_err(|_| HttpServiceError::Timeout(TimeoutError::TlsAccept))? .map_err(Into::into)?; - self.dispatch(_tls_stream, crate::unspecified_socket_addr(), timer.as_mut()) + self.dispatch(_tls_stream, crate::unspecified_socket_addr(), timer.as_mut(), st.as_ref()) .await } } } } -impl Service +impl + Service<(Stream, Arc)> for HttpService where S: Service>, Response = Response>, @@ -290,7 +295,7 @@ where type Response = (); type Error = HttpServiceError; - async fn call(&self, io: Stream) -> Result { + async fn call(&self, (io, st): (Stream, Arc)) -> Result { // tls accept timer. let timer = self.keep_alive(); let mut timer = pin!(timer); @@ -298,7 +303,7 @@ where match io { #[cfg(feature = "http3")] Stream::Udp(io, addr) => super::h3::Dispatcher::new(io, addr, &self.service) - .run() + .run(st.as_ref()) .await .map_err(From::from), Stream::Tcp(io, _addr) => { @@ -311,7 +316,7 @@ where .map_err(|_| HttpServiceError::Timeout(TimeoutError::TlsAccept))? .map_err(Into::into)?; - self.dispatch(_tls_stream, _addr, timer.as_mut()).await + self.dispatch(_tls_stream, _addr, timer.as_mut(), st.as_ref()).await } #[cfg(unix)] Stream::Unix(_io, _) => { @@ -324,8 +329,13 @@ where .map_err(|_| HttpServiceError::Timeout(TimeoutError::TlsAccept))? .map_err(Into::into)?; - self.dispatch(_tls_stream, crate::unspecified_socket_addr(), timer.as_mut()) - .await + self.dispatch( + _tls_stream, + crate::unspecified_socket_addr(), + timer.as_mut(), + st.as_ref(), + ) + .await } } } diff --git a/http/src/util/middleware/socket_config.rs b/http/src/util/middleware/socket_config.rs index 55b587c11..f49cd6b44 100644 --- a/http/src/util/middleware/socket_config.rs +++ b/http/src/util/middleware/socket_config.rs @@ -1,8 +1,8 @@ use core::{net::SocketAddr, time::Duration}; -use std::io; - use socket2::{SockRef, TcpKeepalive}; +use std::io; +use std::sync::Arc; use tracing::warn; use xitca_io::net::{Stream as ServerStream, TcpStream}; @@ -10,6 +10,7 @@ use xitca_service::{Service, ready::ReadyService}; #[cfg(unix)] use xitca_io::net::UnixStream; +use xitca_service::shutdown::ShutdownToken; /// A middleware for socket options config of `TcpStream` and `UnixStream`. #[derive(Clone, Debug)] @@ -90,30 +91,36 @@ where } } -impl Service<(TcpStream, SocketAddr)> for SocketConfigService +impl Service<(TcpStream, SocketAddr, Arc)> for SocketConfigService where - S: Service<(TcpStream, SocketAddr)>, + S: Service<(TcpStream, SocketAddr, Arc)>, { type Response = S::Response; type Error = S::Error; - async fn call(&self, (stream, addr): (TcpStream, SocketAddr)) -> Result { + async fn call( + &self, + (stream, addr, st): (TcpStream, SocketAddr, Arc), + ) -> Result { self.try_apply_config(&stream); - self.service.call((stream, addr)).await + self.service.call((stream, addr, st)).await } } #[cfg(unix)] -impl Service<(UnixStream, SocketAddr)> for SocketConfigService +impl Service<(UnixStream, SocketAddr, Arc)> for SocketConfigService where - S: Service<(UnixStream, SocketAddr)>, + S: Service<(UnixStream, SocketAddr, Arc)>, { type Response = S::Response; type Error = S::Error; - async fn call(&self, (stream, addr): (UnixStream, SocketAddr)) -> Result { + async fn call( + &self, + (stream, addr, st): (UnixStream, SocketAddr, Arc), + ) -> Result { self.try_apply_config(&stream); - self.service.call((stream, addr)).await + self.service.call((stream, addr, st)).await } } diff --git a/server/src/lib.rs b/server/src/lib.rs index bb2a0367a..b6e2204a1 100644 --- a/server/src/lib.rs +++ b/server/src/lib.rs @@ -17,14 +17,20 @@ compile_error!("io_uring can only be used on linux system"); #[cfg(test)] mod test { + use std::sync::Arc; use xitca_io::net::TcpStream; use xitca_service::fn_service; + use xitca_service::shutdown::ShutdownToken; #[test] fn test_builder() { let listener = std::net::TcpListener::bind("localhost:0").unwrap(); let _server = crate::builder::Builder::new() - .listen("test", listener, fn_service(|_: TcpStream| async { Ok::<_, ()>(()) })) + .listen( + "test", + listener, + fn_service(|_: (TcpStream, Arc)| async { Ok::<_, ()>(()) }), + ) .build(); } } diff --git a/server/src/net/mod.rs b/server/src/net/mod.rs index 09bf066fd..808eae827 100644 --- a/server/src/net/mod.rs +++ b/server/src/net/mod.rs @@ -15,11 +15,12 @@ use xitca_io::net::{QuicListener, QuicListenerBuilder}; /// /// # Examples /// ```rust -/// use std::io; +/// use std::{io, sync::Arc}; /// /// use xitca_io::net::Stream; /// use xitca_server::net::{IntoListener, Listen}; /// use xitca_service::fn_service; +/// use xitca_service::shutdown::ShutdownToken; /// /// // arbitrary socket type /// struct MySocket; @@ -45,7 +46,7 @@ use xitca_io::net::{QuicListener, QuicListenerBuilder}; /// } /// /// // service function receive connection stream from MySocket's Listen::accept method -/// let service = fn_service(async |stream: Stream| { +/// let service = fn_service(async |(stream, st): (Stream, Arc)| { /// Ok::<_, io::Error>(()) /// }); /// diff --git a/server/src/server/future.rs b/server/src/server/future.rs index 9470b80b8..682f40b85 100644 --- a/server/src/server/future.rs +++ b/server/src/server/future.rs @@ -28,13 +28,15 @@ impl ServerFuture { /// # Examples: /// /// ```rust + /// # use std::sync::Arc; /// # use xitca_io::net::{TcpStream}; /// # use xitca_server::Builder; /// # use xitca_service::fn_service; + /// # use xitca_service::shutdown::ShutdownToken; /// # #[tokio::main] /// # async fn main() { /// let mut server = Builder::new() - /// .bind("test", "127.0.0.1:0", fn_service(|_io: TcpStream| async { Ok::<_, ()>(())})) + /// .bind("test", "127.0.0.1:0", fn_service(|(_io, _st): (TcpStream, Arc)| async { Ok::<_, ()>(())})) /// .unwrap() /// .build(); /// @@ -52,9 +54,11 @@ impl ServerFuture { match *self { Self::Init { ref server, .. } => Ok(ServerHandle { tx: server.tx_cmd.clone(), + shutdown_token: server.shutdown_token.clone(), }), Self::Running(ref inner) => Ok(ServerHandle { tx: inner.server.tx_cmd.clone(), + shutdown_token: inner.server.shutdown_token.clone(), }), Self::Error(_) => match mem::take(self) { Self::Error(e) => Err(e), diff --git a/server/src/server/handle.rs b/server/src/server/handle.rs index 2fc90d358..a170075ad 100644 --- a/server/src/server/handle.rs +++ b/server/src/server/handle.rs @@ -1,10 +1,15 @@ +use std::sync::Arc; + use tokio::sync::mpsc::UnboundedSender; +use xitca_service::shutdown::ShutdownToken; + use super::Command; #[derive(Clone)] pub struct ServerHandle { pub(super) tx: UnboundedSender, + pub(super) shutdown_token: Arc, } impl ServerHandle { @@ -17,5 +22,6 @@ impl ServerHandle { }; let _ = self.tx.send(cmd); + self.shutdown_token.shutdown(); } } diff --git a/server/src/server/mod.rs b/server/src/server/mod.rs index aa8317fa5..60b403a42 100644 --- a/server/src/server/mod.rs +++ b/server/src/server/mod.rs @@ -20,6 +20,8 @@ use tokio::{ sync::mpsc::{UnboundedReceiver, UnboundedSender}, }; +use xitca_service::shutdown::ShutdownToken; + use crate::{builder::Builder, worker}; pub struct Server { @@ -28,6 +30,7 @@ pub struct Server { rx_cmd: UnboundedReceiver, rt: Option, worker_join_handles: Vec>>, + shutdown_token: Arc, } impl Server { @@ -114,6 +117,8 @@ impl Server { let is_graceful_shutdown = Arc::new(AtomicBool::new(false)); let is_graceful_shutdown2 = is_graceful_shutdown.clone(); + let st = Arc::new(ShutdownToken::new()); + let st2 = st.clone(); let worker_handles = thread::Builder::new() .name(String::from("xitca-server-worker-shared-scope")) @@ -133,7 +138,7 @@ impl Server { let mut services = Vec::new(); for (name, factory) in factories.iter() { - match factory.call((name, &listeners)).await { + match factory.call((name, &listeners, st2.clone())).await { Ok((h, s)) => { handles.extend(h); services.push(s); @@ -176,6 +181,7 @@ impl Server { is_graceful_shutdown, tx_cmd, rx_cmd, + shutdown_token: st, rt: Some(rt), worker_join_handles: vec![worker_handles], }) @@ -184,6 +190,8 @@ impl Server { pub(crate) fn stop(&mut self, graceful: bool) { if let Some(rt) = self.rt.take() { self.is_graceful_shutdown.store(graceful, Ordering::SeqCst); + self.shutdown_token.shutdown(); + rt.shutdown_background(); mem::take(&mut self.worker_join_handles).into_iter().for_each(|handle| { let _ = handle.join().unwrap(); diff --git a/server/src/server/service.rs b/server/src/server/service.rs index bb359bb51..d7abd7be0 100644 --- a/server/src/server/service.rs +++ b/server/src/server/service.rs @@ -1,10 +1,10 @@ use core::marker::PhantomData; -use std::rc::Rc; +use std::{rc::Rc, sync::Arc}; use tokio::task::JoinHandle; use xitca_io::net::Stream; -use xitca_service::{Service, ready::ReadyService}; +use xitca_service::{Service, ready::ReadyService, shutdown::ShutdownToken}; use crate::{ net::ListenerDyn, @@ -13,7 +13,7 @@ use crate::{ pub type ServiceObj = Box< dyn for<'a> xitca_service::object::ServiceObject< - (&'a str, &'a [(String, ListenerDyn)]), + (&'a str, &'a [(String, ListenerDyn)], Arc), Response = (Vec>, ServiceAny), Error = (), > + Send @@ -25,7 +25,7 @@ struct Container { _t: PhantomData, } -impl<'a, F, Req> Service<(&'a str, &'a [(String, ListenerDyn)])> for Container +impl<'a, F, Req> Service<(&'a str, &'a [(String, ListenerDyn)], Arc)> for Container where F: IntoServiceObj, Req: TryFrom + 'static, @@ -35,7 +35,7 @@ where async fn call( &self, - (name, listeners): (&'a str, &'a [(String, ListenerDyn)]), + (name, listeners, st): (&'a str, &'a [(String, ListenerDyn)], Arc), ) -> Result { let service = self.inner.call(()).await.map_err(|_| ())?; let service = Rc::new(service); @@ -43,7 +43,7 @@ where let handles = listeners .iter() .filter(|(n, _)| n == name) - .map(|(_, listener)| worker::start(listener, &service)) + .map(|(_, listener)| worker::start(listener, &service, st.clone())) .collect::>(); Ok((handles, service as _)) @@ -56,7 +56,7 @@ where Self: Service + Send + Sync + 'static, Req: TryFrom + 'static, { - type Service: ReadyService + Service; + type Service: ReadyService + Service<(Req, Arc)>; fn into_object(self) -> ServiceObj; } @@ -64,7 +64,7 @@ where impl IntoServiceObj for T where T: Service + Send + Sync + 'static, - T::Response: ReadyService + Service, + T::Response: ReadyService + Service<(Req, Arc)>, Req: TryFrom + 'static, { type Service = T::Response; diff --git a/server/src/worker/mod.rs b/server/src/worker/mod.rs index b93589d13..84d8a4410 100644 --- a/server/src/worker/mod.rs +++ b/server/src/worker/mod.rs @@ -2,12 +2,12 @@ mod shutdown; use core::{any::Any, sync::atomic::AtomicBool, time::Duration}; -use std::{io, rc::Rc, thread}; +use std::{io, rc::Rc, sync::Arc, thread}; use tokio::{task::JoinHandle, time::sleep}; use tracing::{error, info}; use xitca_io::net::Stream; -use xitca_service::{Service, ready::ReadyService}; +use xitca_service::{Service, ready::ReadyService, shutdown::ShutdownToken}; use crate::net::ListenerDyn; @@ -16,9 +16,9 @@ use self::shutdown::ShutdownHandle; // erase Rc> type and only use it for counting the reference counter of Rc. pub(crate) type ServiceAny = Rc; -pub(crate) fn start(listener: &ListenerDyn, service: &Rc) -> JoinHandle<()> +pub(crate) fn start(listener: &ListenerDyn, service: &Rc, st: Arc) -> JoinHandle<()> where - S: ReadyService + Service + 'static, + S: ReadyService + Service<(Req, Arc)> + 'static, S::Ready: 'static, Req: TryFrom + 'static, { @@ -33,8 +33,10 @@ where Ok(stream) => { if let Ok(req) = TryFrom::try_from(stream) { let service = service.clone(); + let st = st.clone(); + tokio::task::spawn_local(async move { - let _ = service.call(req).await; + let _ = service.call((req, st)).await; drop(ready); }); } diff --git a/service/Cargo.toml b/service/Cargo.toml index 38a9f0c11..3835631f1 100644 --- a/service/Cargo.toml +++ b/service/Cargo.toml @@ -12,8 +12,11 @@ readme= "README.md" [lints] workspace = true +[dependencies] +spin = { version = "0.9", default-features = false, features = ["mutex", "spin_mutex"], optional = true } + [features] -alloc = [] +alloc = ["dep:spin"] [dev-dependencies] xitca-unsafe-collection = "0.2.0" diff --git a/service/src/lib.rs b/service/src/lib.rs index a4110e4ff..beadca6de 100644 --- a/service/src/lib.rs +++ b/service/src/lib.rs @@ -36,6 +36,9 @@ pub mod middleware; pub mod pipeline; pub mod ready; +#[cfg(feature = "alloc")] +pub mod shutdown; + pub use self::{ middleware::{EnclosedBuilder, EnclosedFnBuilder, MapBuilder, MapErrorBuilder}, service::{FnService, Service, ServiceExt, fn_build, fn_service}, diff --git a/service/src/shutdown.rs b/service/src/shutdown.rs new file mode 100644 index 000000000..e3e1466a0 --- /dev/null +++ b/service/src/shutdown.rs @@ -0,0 +1,262 @@ +use alloc::{boxed::Box, vec::Vec}; + +use core::{ + future::Future, + pin::Pin, + task::{Context, Poll, Waker}, +}; + +use spin::Mutex; + +/// A thread-safe token that tracks whether a shutdown has been requested. +/// +/// `ShutdownToken` is designed to be shared by reference (`&ShutdownToken`). +/// Calling [`ShutdownToken::shutdown`] marks the token as shut down and +/// immediately wakes all futures that are waiting on it. +pub struct ShutdownToken { + inner: Mutex, +} + +struct ShutdownInner { + is_shutdown: bool, + wakers: Vec, +} + +impl ShutdownToken { + /// Create a new `ShutdownToken` in the non-shutdown state. + pub const fn new() -> Self { + Self { + inner: Mutex::new(ShutdownInner { + is_shutdown: false, + wakers: Vec::new(), + }), + } + } + + /// Mark this token as shut down, waking all futures that are waiting on it. + /// + /// Any armed [`ShutdownFuture`] referencing this token will resolve to + /// `None` as soon as its executor re-polls it. + pub fn shutdown(&self) { + let wakers = { + let mut inner = self.inner.lock(); + inner.is_shutdown = true; + // Take the wakers out so we can wake them outside the lock. + core::mem::take(&mut inner.wakers) + }; + for waker in wakers { + waker.wake(); + } + } + + /// Returns `true` if [`ShutdownToken::shutdown`] has been called. + pub fn is_shutdown(&self) -> bool { + self.inner.lock().is_shutdown + } + + /// Register a waker to be notified when shutdown occurs. + /// + /// If shutdown has already happened the waker is woken immediately. + fn register_waker(&self, waker: &Waker) { + let mut inner = self.inner.lock(); + + if inner.is_shutdown { + drop(inner); + waker.wake_by_ref(); + return; + } + + // Avoid duplicating wakers for the same task. + for existing in inner.wakers.iter() { + if existing.will_wake(waker) { + return; + } + } + inner.wakers.push(waker.clone()); + } +} + +impl Default for ShutdownToken { + fn default() -> Self { + Self::new() + } +} + +/// Extension trait that adds a `.with_shutdown(token, armed)` combinator to any [`Future`]. +pub trait ShutdownFutureExt: Future + Sized { + /// Wrap this future so that it can be cancelled by a [`ShutdownToken`]. + /// + /// The inner future is boxed so that it can be pinned without requiring + /// `Unpin`. + /// + /// - When `armed` is `true` **and** the token is (or becomes) shut down, the + /// future resolves to `None` — the executor is woken immediately when + /// [`ShutdownToken::shutdown`] is called. + /// - When `armed` is `false`, the shutdown token is ignored and the future + /// behaves as if it were unwrapped, resolving to `Some(output)`. + /// + /// This lets the caller decide at construction time whether to honour + /// shutdown — for example, only opt in when a buffer is empty so that + /// in-flight work is not interrupted mid-way. + fn with_shutdown(self, token: &ShutdownToken, armed: bool) -> ShutdownFuture<'_, Self>; +} + +impl ShutdownFutureExt for F { + fn with_shutdown(self, token: &ShutdownToken, armed: bool) -> ShutdownFuture<'_, Self> { + ShutdownFuture { + future: Box::pin(self), + token, + armed, + } + } +} + +/// A future that resolves to `Some(T)` if the inner future completes (or if +/// shutdown is not armed), or `None` if shutdown is armed and the token is +/// (or becomes) shut down. +pub struct ShutdownFuture<'a, F: Future> { + future: Pin>, + token: &'a ShutdownToken, + armed: bool, +} + +impl Future for ShutdownFuture<'_, F> { + type Output = Option; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.get_mut(); + + if !this.armed { + return this.future.as_mut().poll(cx).map(Some); + } + + if this.token.is_shutdown() { + return Poll::Ready(None); + } + + if let Poll::Ready(output) = this.future.as_mut().poll(cx) { + return Poll::Ready(Some(output)); + } + + // Register the waker so we get notified when shutdown occurs. + // register_waker atomically checks is_shutdown inside the lock, + // so no separate double-check is needed. + this.token.register_waker(cx.waker()); + + Poll::Pending + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn token_starts_not_shutdown() { + let token = ShutdownToken::new(); + assert!(!token.is_shutdown()); + } + + #[test] + fn token_shutdown_sets_state() { + let token = ShutdownToken::new(); + token.shutdown(); + assert!(token.is_shutdown()); + } + + #[test] + fn shared_ref_observes_shutdown() { + let token = ShutdownToken::new(); + let r = &token; + token.shutdown(); + assert!(r.is_shutdown()); + } + + #[test] + fn armed_and_already_shutdown_resolves_none() { + let token = ShutdownToken::new(); + token.shutdown(); + + let fut = core::future::ready(42).with_shutdown(&token, true); + let result = pollster_block_on(fut); + assert_eq!(result, None); + } + + #[test] + fn not_armed_ignores_shutdown() { + let token = ShutdownToken::new(); + token.shutdown(); + + let fut = core::future::ready(42).with_shutdown(&token, false); + let result = pollster_block_on(fut); + assert_eq!(result, Some(42)); + } + + #[test] + fn armed_future_completes_before_shutdown() { + let token = ShutdownToken::new(); + let fut = core::future::ready(42).with_shutdown(&token, true); + let result = pollster_block_on(fut); + assert_eq!(result, Some(42)); + } + + #[test] + fn shutdown_wakes_registered_waker() { + use alloc::sync::Arc; + use alloc::task::Wake; + use core::sync::atomic::{AtomicUsize, Ordering}; + + struct CountWake(AtomicUsize); + impl Wake for CountWake { + fn wake(self: Arc) { + self.0.fetch_add(1, Ordering::SeqCst); + } + } + let counter = Arc::new(CountWake(AtomicUsize::new(0))); + let waker = Waker::from(counter.clone()); + + let token = ShutdownToken::new(); + token.register_waker(&waker); + + assert_eq!(counter.0.load(Ordering::SeqCst), 0); + token.shutdown(); + assert_eq!(counter.0.load(Ordering::SeqCst), 1); + } + + #[test] + fn duplicate_waker_not_registered_twice() { + use alloc::sync::Arc; + use alloc::task::Wake; + use core::sync::atomic::{AtomicUsize, Ordering}; + + struct CountWake(AtomicUsize); + impl Wake for CountWake { + fn wake(self: Arc) { + self.0.fetch_add(1, Ordering::SeqCst); + } + } + let counter = Arc::new(CountWake(AtomicUsize::new(0))); + let waker = Waker::from(counter.clone()); + + let token = ShutdownToken::new(); + token.register_waker(&waker); + token.register_waker(&waker); + + token.shutdown(); + assert_eq!(counter.0.load(Ordering::SeqCst), 1); + } + + /// Minimal single-threaded block_on for tests (no tokio needed). + fn pollster_block_on(fut: F) -> F::Output { + let mut fut = core::pin::pin!(fut); + let waker = core::task::Waker::noop(); + let mut cx = Context::from_waker(waker); + + match fut.as_mut().poll(&mut cx) { + Poll::Ready(val) => return val, + Poll::Pending => { + panic!("future returned Pending in a synchronous test"); + } + } + } +} diff --git a/test/src/lib.rs b/test/src/lib.rs index 495a340d4..7025b5dc8 100644 --- a/test/src/lib.rs +++ b/test/src/lib.rs @@ -5,6 +5,7 @@ use std::{ net::SocketAddr, net::TcpListener, pin::Pin, + sync::Arc, task::{Context, Poll}, time::Duration, }; @@ -18,7 +19,7 @@ use xitca_http::{ }; use xitca_io::{bytes::Bytes, net::Stream as NetStream}; use xitca_server::{Builder, ServerFuture, ServerHandle}; -use xitca_service::{Service, ServiceExt, ready::ReadyService}; +use xitca_service::{Service, ServiceExt, ready::ReadyService, shutdown::ShutdownToken}; pub type Error = Box; @@ -29,7 +30,7 @@ type HResponse = Response; pub fn test_server(service: T) -> Result where T: Service + Send + Sync + 'static, - T::Response: ReadyService + Service, + T::Response: ReadyService + Service<(Req, Arc)>, Req: TryFrom + 'static, { let lst = TcpListener::bind("127.0.0.1:0")?; diff --git a/test/tests/h1.rs b/test/tests/h1.rs index 0bf9929b8..2e1d2e3b9 100644 --- a/test/tests/h1.rs +++ b/test/tests/h1.rs @@ -33,7 +33,7 @@ async fn h1_get() -> Result<(), Error> { assert_eq!("GET Response", body); } - handle.try_handle()?.stop(false); + handle.try_handle()?.stop(true); handle.await?; @@ -61,7 +61,7 @@ async fn h1_get_without_body_reading() -> Result<(), Error> { let body = res.string().await?; assert_eq!("GET Response", body); - handle.try_handle()?.stop(false); + handle.try_handle()?.stop(true); handle.await?; Ok(()) @@ -83,7 +83,7 @@ async fn h1_head() -> Result<(), Error> { assert_eq!("", body); } - handle.try_handle()?.stop(false); + handle.try_handle()?.stop(true); handle.await?; @@ -112,7 +112,7 @@ async fn h1_post() -> Result<(), Error> { assert_eq!(body.len(), body_len); } - handle.try_handle()?.stop(false); + handle.try_handle()?.stop(true); handle.await?; diff --git a/test/tests/h2.rs b/test/tests/h2.rs index 610a83215..55e459afd 100644 --- a/test/tests/h2.rs +++ b/test/tests/h2.rs @@ -26,7 +26,7 @@ async fn h2_get() -> Result<(), Error> { assert_eq!("GET Response", body); } - handle.try_handle()?.stop(false); + handle.try_handle()?.stop(true); handle.await?; @@ -52,7 +52,7 @@ async fn h2_no_host_header() -> Result<(), Error> { assert_eq!("", body); } - handle.try_handle()?.stop(false); + handle.try_handle()?.stop(true); handle.await?; @@ -78,7 +78,7 @@ async fn h2_post() -> Result<(), Error> { let _ = res.body().await; } - handle.try_handle()?.stop(false); + handle.try_handle()?.stop(true); handle.await?; @@ -118,7 +118,7 @@ async fn h2_connect() -> Result<(), Error> { core::future::poll_fn(|cx| core::pin::Pin::new(&mut tunnel).poll_shutdown(cx)).await?; - handle.try_handle()?.stop(false); + handle.try_handle()?.stop(true); handle.await?; diff --git a/test/tests/h3.rs b/test/tests/h3.rs index a14e81ab9..bcf99d9cc 100644 --- a/test/tests/h3.rs +++ b/test/tests/h3.rs @@ -23,7 +23,7 @@ async fn h3_get() -> Result<(), Error> { assert_eq!("GET Response", body); } - handle.try_handle()?.stop(false); + handle.try_handle()?.stop(true); handle.await?; @@ -49,7 +49,7 @@ async fn h3_no_host_header() -> Result<(), Error> { assert_eq!("", body); } - handle.try_handle()?.stop(false); + handle.try_handle()?.stop(true); handle.await?; @@ -74,7 +74,7 @@ async fn h3_post() -> Result<(), Error> { assert!(!res.can_close_connection()); } - handle.try_handle()?.stop(false); + handle.try_handle()?.stop(true); handle.await?;