Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 13 additions & 3 deletions http/src/h1/dispatcher.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ use std::net::Shutdown;
use tracing::trace;
use xitca_io::io::{AsyncBufRead, AsyncBufWrite};
use xitca_service::Service;
use xitca_service::shutdown::{ShutdownFutureExt, ShutdownToken};
use xitca_unsafe_collection::futures::SelectOutput;

use crate::{
Expand Down Expand Up @@ -58,6 +59,7 @@ where
config: HttpServiceConfig<H_LIMIT, R_LIMIT, W_LIMIT>,
service: &'a S,
date: &'a D,
shutdown_token: &ShutdownToken,
) -> Result<(), Error<S::Error, BE>> {
let mut dispatcher = Dispatcher::<_, _, _, _, H_LIMIT, R_LIMIT, W_LIMIT> {
io: SharedIo::new(io),
Expand Down Expand Up @@ -90,15 +92,23 @@ where

dispatcher.timer.update(dispatcher.ctx.date().now());

match read_buf.read(dispatcher.io.io()).timeout(dispatcher.timer.get()).await {
Ok((res, r_buf)) => {
let is_read_buf_empty = read_buf.is_empty();

match read_buf
.read(dispatcher.io.io())
.timeout(dispatcher.timer.get())
.with_shutdown(shutdown_token, is_read_buf_empty)
.await
{
Some(Ok((res, r_buf))) => {
read_buf = r_buf;

if res? == 0 {
break Ok(());
}
}
Err(_) => break Err(dispatcher.timer.map_to_err()),
Some(Err(_)) => break Err(dispatcher.timer.map_to_err()),
None => break Ok(()),
}
};

Expand Down
13 changes: 10 additions & 3 deletions http/src/h1/service.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
use core::{net::SocketAddr, pin::pin};

use std::sync::Arc;

use crate::body::Body;
use xitca_io::io::{AsyncBufRead, AsyncBufWrite};
use xitca_service::Service;
use xitca_service::{Service, shutdown::ShutdownToken};

use crate::{
body::RequestBody,
Expand All @@ -18,7 +20,8 @@ pub type H1Service<St, Io, S, A, const HEADER_LIMIT: usize, const READ_BUF_LIMIT
HttpService<marker::Http1, Io, St, S, A, HEADER_LIMIT, READ_BUF_LIMIT, WRITE_BUF_LIMIT>;

impl<St, Io, S, B, A, const HEADER_LIMIT: usize, const READ_BUF_LIMIT: usize, const WRITE_BUF_LIMIT: usize>
Service<(St, SocketAddr)> for H1Service<St, Io, S, A, HEADER_LIMIT, READ_BUF_LIMIT, WRITE_BUF_LIMIT>
Service<((St, SocketAddr), Arc<ShutdownToken>)>
for H1Service<St, Io, S, A, HEADER_LIMIT, READ_BUF_LIMIT, WRITE_BUF_LIMIT>
where
S: Service<Request<RequestExt<RequestBody>>, Response = Response<B>>,
A: Service<St>,
Expand All @@ -29,7 +32,10 @@ where
type Response = ();
type Error = HttpServiceError<S::Error, B::Error>;

async fn call(&self, (io, addr): (St, SocketAddr)) -> Result<Self::Response, Self::Error> {
async fn call(
&self,
((io, addr), st): ((St, SocketAddr), Arc<ShutdownToken>),
) -> Result<Self::Response, Self::Error> {
// at this stage keep-alive timer is used to tracks tls accept timeout.
let mut timer = pin!(self.keep_alive());

Expand All @@ -48,6 +54,7 @@ where
self.config,
&self.service,
self.date.get(),
&st,
)
.await
.map_err(Into::into)
Expand Down
12 changes: 9 additions & 3 deletions http/src/h2/service.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
use core::{fmt, net::SocketAddr, pin::pin};

use std::sync::Arc;

use xitca_io::io::{AsyncBufRead, AsyncBufWrite};
use xitca_service::Service;
use xitca_service::{Service, shutdown::ShutdownToken};

use crate::{
body::Body,
Expand All @@ -19,7 +21,8 @@ pub type H2Service<St, Io, S, A, const HEADER_LIMIT: usize, const READ_BUF_LIMIT
HttpService<marker::Http2, Io, St, S, A, HEADER_LIMIT, READ_BUF_LIMIT, WRITE_BUF_LIMIT>;

impl<St, Io, S, B, A, TlsSt, const HEADER_LIMIT: usize, const READ_BUF_LIMIT: usize, const WRITE_BUF_LIMIT: usize>
Service<(St, SocketAddr)> for H2Service<St, Io, S, A, HEADER_LIMIT, READ_BUF_LIMIT, WRITE_BUF_LIMIT>
Service<((St, SocketAddr), Arc<ShutdownToken>)>
for H2Service<St, Io, S, A, HEADER_LIMIT, READ_BUF_LIMIT, WRITE_BUF_LIMIT>
where
S: Service<Request<RequestExt<RequestBody>>, Response = Response<B>>,
S::Error: fmt::Debug,
Expand All @@ -32,7 +35,10 @@ where
type Response = ();
type Error = HttpServiceError<S::Error, B::Error>;

async fn call(&self, (io, addr): (St, SocketAddr)) -> Result<Self::Response, Self::Error> {
async fn call(
&self,
((io, addr), _st): ((St, SocketAddr), Arc<ShutdownToken>),
) -> Result<Self::Response, Self::Error> {
// tls accept timer.
let timer = self.keep_alive();
let mut timer = pin!(timer);
Expand Down
18 changes: 11 additions & 7 deletions http/src/h3/proto/dispatcher.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,10 @@ use ::h3::{
server::{self, RequestStream},
};
use xitca_io::net::QuicStream;
use xitca_service::Service;
use xitca_service::{
Service,
shutdown::{ShutdownFutureExt, ShutdownToken},
};
use xitca_unsafe_collection::futures::{Select, SelectOutput};

use crate::{
Expand Down Expand Up @@ -48,7 +51,7 @@ where
}
}

pub(crate) async fn run(self) -> Result<(), Error<S::Error, BE>> {
pub(crate) async fn run(self, st: &ShutdownToken) -> Result<(), Error<S::Error, BE>> {
// wait for connecting.
let conn = self.io.connecting().await?;

Expand All @@ -60,8 +63,8 @@ where

// accept loop
loop {
match conn.accept().select(queue.next()).await {
SelectOutput::A(Ok(Some(req))) => {
match conn.accept().select(queue.next()).with_shutdown(st, true).await {
Some(SelectOutput::A(Ok(Some(req)))) => {
queue.push(async move {
let (req, stream) = req.resolve_request().await?;
let (tx, rx) = stream.split();
Expand All @@ -76,13 +79,14 @@ where
h3_handler(fut, tx).await
});
}
SelectOutput::A(Ok(None)) => break,
SelectOutput::A(Err(e)) => return Err(e.into()),
SelectOutput::B(res) => {
Some(SelectOutput::A(Ok(None))) => break,
Some(SelectOutput::A(Err(e))) => return Err(e.into()),
Some(SelectOutput::B(res)) => {
if let Err(e) = res {
HttpServiceError::from(e).log("h3_dispatcher");
}
}
None => break,
}
}

Expand Down
13 changes: 9 additions & 4 deletions http/src/h3/service.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
use core::{fmt, net::SocketAddr};

use std::sync::Arc;

use crate::body::Body;
use xitca_io::net::QuicStream;
use xitca_service::{Service, ready::ReadyService};
use xitca_service::{Service, ready::ReadyService, shutdown::ShutdownToken};

use crate::{
bytes::Bytes,
Expand All @@ -24,7 +26,7 @@ impl<S> H3Service<S> {
}
}

impl<S, ResB, BE> Service<(QuicStream, SocketAddr)> for H3Service<S>
impl<S, ResB, BE> Service<((QuicStream, SocketAddr), Arc<ShutdownToken>)> for H3Service<S>
where
S: Service<Request<RequestExt<RequestBody>>, Response = Response<ResB>>,
S::Error: fmt::Debug,
Expand All @@ -33,10 +35,13 @@ where
{
type Response = ();
type Error = HttpServiceError<S::Error, BE>;
async fn call(&self, (stream, addr): (QuicStream, SocketAddr)) -> Result<Self::Response, Self::Error> {
async fn call(
&self,
((stream, addr), st): ((QuicStream, SocketAddr), Arc<ShutdownToken>),
) -> Result<Self::Response, Self::Error> {
let dispatcher = Dispatcher::new(stream, addr, &self.service);

dispatcher.run().await?;
dispatcher.run(st.as_ref()).await?;

Ok(())
}
Expand Down
34 changes: 22 additions & 12 deletions http/src/service.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
use core::{fmt, marker::PhantomData, pin::pin};

use std::sync::Arc;

use xitca_io::{
io::{AsyncBufRead, AsyncBufWrite},
net::{Stream, TcpStream},
};
use xitca_service::{Service, ready::ReadyService};
use xitca_service::{Service, ready::ReadyService, shutdown::ShutdownToken};

use super::{
body::{Body, RequestBody},
Expand Down Expand Up @@ -169,6 +171,7 @@ where
_tls_stream: impl AsVersion + AsyncBufRead + AsyncBufWrite + 'static,
_addr: core::net::SocketAddr,
mut _timer: core::pin::Pin<&mut KeepAlive>,
shutdown_token: &ShutdownToken,
) -> Result<(), HttpServiceError<S::Error, B::Error>> {
#[allow(unused_mut)]
let mut version = _tls_stream.as_version();
Expand Down Expand Up @@ -197,6 +200,7 @@ where
self.config,
&self.service,
self.date.get(),
shutdown_token,
)
.await
.map_err(From::from),
Expand Down Expand Up @@ -225,7 +229,7 @@ where
}

#[cfg(feature = "io-uring")]
impl<S, B, A, const HEADER_LIMIT: usize, const READ_BUF_LIMIT: usize, const WRITE_BUF_LIMIT: usize> Service<Stream>
impl<S, B, A, const HEADER_LIMIT: usize, const READ_BUF_LIMIT: usize, const WRITE_BUF_LIMIT: usize> Service<(Stream, Arc<ShutdownToken>)>
for HttpService<marker::Http, marker::Uring, Stream, S, A, HEADER_LIMIT, READ_BUF_LIMIT, WRITE_BUF_LIMIT>
where
S: Service<Request<RequestExt<RequestBody>>, Response = Response<B>>,
Expand All @@ -237,15 +241,15 @@ where
type Response = ();
type Error = HttpServiceError<S::Error, B::Error>;

async fn call(&self, io: Stream) -> Result<Self::Response, Self::Error> {
async fn call(&self, (io, st): (Stream, Arc<ShutdownToken>)) -> Result<Self::Response, Self::Error> {
// tls accept timer.
let timer = self.keep_alive();
let mut timer = pin!(timer);

match io {
#[cfg(feature = "http3")]
Stream::Udp(io, addr) => super::h3::Dispatcher::new(io, addr, &self.service)
.run()
.run(st.as_ref())
.await
.map_err(From::from),
Stream::Tcp(io, _addr) => {
Expand All @@ -258,7 +262,7 @@ where
.map_err(|_| HttpServiceError::Timeout(TimeoutError::TlsAccept))?
.map_err(Into::into)?;

self.dispatch(_tls_stream, _addr, timer.as_mut()).await
self.dispatch(_tls_stream, _addr, timer.as_mut(), st.as_ref()).await
}
#[cfg(unix)]
Stream::Unix(_io, _) => {
Expand All @@ -271,14 +275,15 @@ where
.map_err(|_| HttpServiceError::Timeout(TimeoutError::TlsAccept))?
.map_err(Into::into)?;

self.dispatch(_tls_stream, crate::unspecified_socket_addr(), timer.as_mut())
self.dispatch(_tls_stream, crate::unspecified_socket_addr(), timer.as_mut(), st.as_ref())
.await
}
}
}
}

impl<S, B, A, const HEADER_LIMIT: usize, const READ_BUF_LIMIT: usize, const WRITE_BUF_LIMIT: usize> Service<Stream>
impl<S, B, A, const HEADER_LIMIT: usize, const READ_BUF_LIMIT: usize, const WRITE_BUF_LIMIT: usize>
Service<(Stream, Arc<ShutdownToken>)>
for HttpService<marker::Http, marker::Poll, Stream, S, A, HEADER_LIMIT, READ_BUF_LIMIT, WRITE_BUF_LIMIT>
where
S: Service<Request<RequestExt<RequestBody>>, Response = Response<B>>,
Expand All @@ -290,15 +295,15 @@ where
type Response = ();
type Error = HttpServiceError<S::Error, B::Error>;

async fn call(&self, io: Stream) -> Result<Self::Response, Self::Error> {
async fn call(&self, (io, st): (Stream, Arc<ShutdownToken>)) -> Result<Self::Response, Self::Error> {
// tls accept timer.
let timer = self.keep_alive();
let mut timer = pin!(timer);

match io {
#[cfg(feature = "http3")]
Stream::Udp(io, addr) => super::h3::Dispatcher::new(io, addr, &self.service)
.run()
.run(st.as_ref())
.await
.map_err(From::from),
Stream::Tcp(io, _addr) => {
Expand All @@ -311,7 +316,7 @@ where
.map_err(|_| HttpServiceError::Timeout(TimeoutError::TlsAccept))?
.map_err(Into::into)?;

self.dispatch(_tls_stream, _addr, timer.as_mut()).await
self.dispatch(_tls_stream, _addr, timer.as_mut(), st.as_ref()).await
}
#[cfg(unix)]
Stream::Unix(_io, _) => {
Expand All @@ -324,8 +329,13 @@ where
.map_err(|_| HttpServiceError::Timeout(TimeoutError::TlsAccept))?
.map_err(Into::into)?;

self.dispatch(_tls_stream, crate::unspecified_socket_addr(), timer.as_mut())
.await
self.dispatch(
_tls_stream,
crate::unspecified_socket_addr(),
timer.as_mut(),
st.as_ref(),
)
.await
}
}
}
Expand Down
27 changes: 17 additions & 10 deletions http/src/util/middleware/socket_config.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
use core::{net::SocketAddr, time::Duration};

use std::io;

use socket2::{SockRef, TcpKeepalive};
use std::io;
use std::sync::Arc;

use tracing::warn;
use xitca_io::net::{Stream as ServerStream, TcpStream};
use xitca_service::{Service, ready::ReadyService};

#[cfg(unix)]
use xitca_io::net::UnixStream;
use xitca_service::shutdown::ShutdownToken;

/// A middleware for socket options config of `TcpStream` and `UnixStream`.
#[derive(Clone, Debug)]
Expand Down Expand Up @@ -90,30 +91,36 @@ where
}
}

impl<S> Service<(TcpStream, SocketAddr)> for SocketConfigService<S>
impl<S> Service<(TcpStream, SocketAddr, Arc<ShutdownToken>)> for SocketConfigService<S>
where
S: Service<(TcpStream, SocketAddr)>,
S: Service<(TcpStream, SocketAddr, Arc<ShutdownToken>)>,
{
type Response = S::Response;
type Error = S::Error;

async fn call(&self, (stream, addr): (TcpStream, SocketAddr)) -> Result<Self::Response, Self::Error> {
async fn call(
&self,
(stream, addr, st): (TcpStream, SocketAddr, Arc<ShutdownToken>),
) -> Result<Self::Response, Self::Error> {
self.try_apply_config(&stream);
self.service.call((stream, addr)).await
self.service.call((stream, addr, st)).await
}
}

#[cfg(unix)]
impl<S> Service<(UnixStream, SocketAddr)> for SocketConfigService<S>
impl<S> Service<(UnixStream, SocketAddr, Arc<ShutdownToken>)> for SocketConfigService<S>
where
S: Service<(UnixStream, SocketAddr)>,
S: Service<(UnixStream, SocketAddr, Arc<ShutdownToken>)>,
{
type Response = S::Response;
type Error = S::Error;

async fn call(&self, (stream, addr): (UnixStream, SocketAddr)) -> Result<Self::Response, Self::Error> {
async fn call(
&self,
(stream, addr, st): (UnixStream, SocketAddr, Arc<ShutdownToken>),
) -> Result<Self::Response, Self::Error> {
self.try_apply_config(&stream);
self.service.call((stream, addr)).await
self.service.call((stream, addr, st)).await
}
}

Expand Down
Loading
Loading