diff --git a/client/src/client.rs b/client/src/client.rs index 14e1ce040..d8e32b54a 100644 --- a/client/src/client.rs +++ b/client/src/client.rs @@ -263,6 +263,10 @@ impl Client { return Ok((conn, expected_version)); } + if !connect.tls { + return Ok((conn, expected_version)); + } + timer .as_mut() .reset(Instant::now() + self.timeout_config.tls_connect_timeout); diff --git a/client/src/connect.rs b/client/src/connect.rs index fc0bda861..b8c64127c 100644 --- a/client/src/connect.rs +++ b/client/src/connect.rs @@ -80,17 +80,19 @@ pub struct Connect<'a> { pub(crate) uri: Uri<'a>, pub(crate) port: u16, pub(crate) addr: Addrs, + pub(crate) tls: bool, } impl<'a> Connect<'a> { /// Create `Connect` instance by splitting the string by ':' and convert the second part to u16 - pub fn new(uri: Uri<'a>) -> Self { + pub fn new(uri: Uri<'a>, tls: Option) -> Self { let (_, port) = parse_host(uri.hostname()); Self { uri, port: port.unwrap_or(0), addr: Addrs::None, + tls: tls.unwrap_or(port != Some(80)), } } diff --git a/client/src/h1/proto/context.rs b/client/src/h1/proto/context.rs index 1b4affce8..1bbc79d21 100644 --- a/client/src/h1/proto/context.rs +++ b/client/src/h1/proto/context.rs @@ -29,7 +29,7 @@ impl DerefMut for Context<'_, '_, HEADER_LIMIT> { } impl<'c, 'd, const HEADER_LIMIT: usize> Context<'c, 'd, HEADER_LIMIT> { - pub(crate) fn new(date: &'c DateTimeHandle<'d>) -> Self { - Self(context::Context::new(date)) + pub(crate) fn new(date: &'c DateTimeHandle<'d>, is_tls: bool) -> Self { + Self(context::Context::new(date, is_tls)) } } diff --git a/client/src/h1/proto/dispatcher.rs b/client/src/h1/proto/dispatcher.rs index afd061f2d..37d0744d1 100644 --- a/client/src/h1/proto/dispatcher.rs +++ b/client/src/h1/proto/dispatcher.rs @@ -23,6 +23,7 @@ pub(crate) async fn send( stream: &mut S, date: DateTimeHandle<'_>, req: &mut Request, + is_tls: bool, ) -> Result<(Response<()>, BytesMut, TransferCoding, bool), Error> where S: AsyncIo + Unpin, @@ -69,7 +70,7 @@ where } // TODO: make const generic params configurable. - let mut ctx = Context::<128>::new(&date); + let mut ctx = Context::<128>::new(&date, is_tls); // encode request head and return transfer encoding for request body let encoder = ctx.encode_head(&mut buf, req)?; diff --git a/client/src/middleware/redirect.rs b/client/src/middleware/redirect.rs index 5f85e7137..04e03faaf 100644 --- a/client/src/middleware/redirect.rs +++ b/client/src/middleware/redirect.rs @@ -28,12 +28,25 @@ where type Error = Error; async fn call(&self, req: ServiceRequest<'r, 'c>) -> Result { - let ServiceRequest { req, client, timeout } = req; + let ServiceRequest { + req, + client, + timeout, + tls, + } = req; let mut headers = req.headers().clone(); let mut method = req.method().clone(); let mut uri = req.uri().clone(); loop { - let mut res = self.service.call(ServiceRequest { req, client, timeout }).await?; + let mut res = self + .service + .call(ServiceRequest { + req, + client, + timeout, + tls, + }) + .await?; match res.status() { StatusCode::MOVED_PERMANENTLY | StatusCode::FOUND | StatusCode::SEE_OTHER => { if method != Method::HEAD { diff --git a/client/src/request.rs b/client/src/request.rs index c45df7d6c..b3ba13468 100644 --- a/client/src/request.rs +++ b/client/src/request.rs @@ -22,6 +22,7 @@ pub struct RequestBuilder<'a, M = marker::Http> { err: Vec, client: &'a Client, timeout: Duration, + tls: Option, _marker: PhantomData, } @@ -104,6 +105,7 @@ impl<'a, M> RequestBuilder<'a, M> { err: Vec::new(), client, timeout: client.timeout_config.request_timeout, + tls: None, _marker: PhantomData, } } @@ -114,6 +116,7 @@ impl<'a, M> RequestBuilder<'a, M> { err: self.err, client: self.client, timeout: self.timeout, + tls: self.tls, _marker: PhantomData, } } @@ -125,6 +128,7 @@ impl<'a, M> RequestBuilder<'a, M> { err, client, timeout, + tls, .. } = self; @@ -138,6 +142,7 @@ impl<'a, M> RequestBuilder<'a, M> { req: &mut req, client, timeout, + tls, }) .await } @@ -210,6 +215,13 @@ impl<'a, M> RequestBuilder<'a, M> { self } + /// Set TLS state of this request. + #[inline] + pub fn tls(mut self, tls: bool) -> Self { + self.tls = Some(tls); + self + } + fn map_body(mut self, b: B) -> RequestBuilder<'a, M> where B: Stream> + Send + 'static, diff --git a/client/src/service.rs b/client/src/service.rs index 0751f7c3c..c8c42676c 100644 --- a/client/src/service.rs +++ b/client/src/service.rs @@ -68,6 +68,7 @@ pub struct ServiceRequest<'r, 'c> { pub req: &'r mut Request, pub client: &'c Client, pub timeout: Duration, + pub tls: Option, } /// type alias for object safe wrapper of type implement [Service] trait. @@ -85,7 +86,12 @@ pub(crate) fn base_service() -> HttpService { #[cfg(any(feature = "http1", feature = "http2", feature = "http3"))] use crate::{error::TimeoutError, timeout::Timeout}; - let ServiceRequest { req, client, timeout } = req; + let ServiceRequest { + req, + client, + timeout, + tls, + } = req; let uri = Uri::try_parse(req.uri())?; @@ -94,7 +100,8 @@ pub(crate) fn base_service() -> HttpService { #[allow(unused_mut)] let mut version = req.version(); - let mut connect = Connect::new(uri); + let mut connect = Connect::new(uri, tls); + let is_tls = connect.tls; let _date = client.date_service.handle(); @@ -219,7 +226,7 @@ pub(crate) fn base_service() -> HttpService { #[cfg(feature = "http1")] { let mut timer = Box::pin(tokio::time::sleep(timeout)); - let res = crate::h1::proto::send(&mut *_conn, _date, req) + let res = crate::h1::proto::send(&mut *_conn, _date, req, is_tls) .timeout(timer.as_mut()) .await; diff --git a/http/src/h1/dispatcher.rs b/http/src/h1/dispatcher.rs index 402e8b45f..07de45054 100644 --- a/http/src/h1/dispatcher.rs +++ b/http/src/h1/dispatcher.rs @@ -63,6 +63,7 @@ pub(crate) async fn run< config: HttpServiceConfig, service: &'a S, date: &'a D, + is_tls: bool, ) -> Result<(), Error> where S: Service, Response = Response>, @@ -77,7 +78,7 @@ where EitherBuf::Right(WriteBuf::::default()) }; - Dispatcher::new(io, addr, timer, config, service, date, write_buf) + Dispatcher::new(io, addr, timer, config, service, date, is_tls, write_buf) .run() .await } @@ -166,6 +167,7 @@ where W: H1BufWrite, D: DateTime, { + #[allow(clippy::too_many_arguments)] fn new( io: &'a mut St, addr: SocketAddr, @@ -173,12 +175,13 @@ where config: HttpServiceConfig, service: &'a S, date: &'a D, + is_tls: bool, write_buf: W, ) -> Self { Self { io: BufferedIo::new(io, write_buf), timer: Timer::new(timer, config.keep_alive_timeout, config.request_head_timeout), - ctx: Context::with_addr(addr, date), + ctx: Context::with_addr(addr, date, is_tls), service, _phantom: PhantomData, } diff --git a/http/src/h1/dispatcher_uring.rs b/http/src/h1/dispatcher_uring.rs index 2a49b42e1..74d559ce6 100644 --- a/http/src/h1/dispatcher_uring.rs +++ b/http/src/h1/dispatcher_uring.rs @@ -118,11 +118,12 @@ where config: HttpServiceConfig, service: &'a S, date: &'a D, + is_tls: bool, ) -> Self { Self { io: Rc::new(io), timer: Timer::new(timer, config.keep_alive_timeout, config.request_head_timeout), - ctx: Context::<_, H_LIMIT>::with_addr(addr, date), + ctx: Context::<_, H_LIMIT>::with_addr(addr, date, is_tls), service, read_buf: BufOwned::new(), write_buf: BufOwned::new(), diff --git a/http/src/h1/proto/context.rs b/http/src/h1/proto/context.rs index 271b2e71a..5b47f85f6 100644 --- a/http/src/h1/proto/context.rs +++ b/http/src/h1/proto/context.rs @@ -11,6 +11,7 @@ pub struct Context<'a, D, const HEADER_LIMIT: usize> { // http extensions reused by next request. exts: Extensions, date: &'a D, + is_tls: bool, } // A set of state for current request that are used after request's ownership is passed @@ -49,21 +50,22 @@ impl<'a, D, const HEADER_LIMIT: usize> Context<'a, D, HEADER_LIMIT> { /// /// [DateTime]: crate::date::DateTime #[inline] - pub fn new(date: &'a D) -> Self { - Self::with_addr(crate::unspecified_socket_addr(), date) + pub fn new(date: &'a D, is_tls: bool) -> Self { + Self::with_addr(crate::unspecified_socket_addr(), date, is_tls) } /// Context is constructed with [SocketAddr] and reference of certain type that impl [DateTime] trait. /// /// [DateTime]: crate::date::DateTime #[inline] - pub fn with_addr(addr: SocketAddr, date: &'a D) -> Self { + pub fn with_addr(addr: SocketAddr, date: &'a D, is_tls: bool) -> Self { Self { addr, state: ContextState::new(), header: None, exts: Extensions::new(), date, + is_tls, } } @@ -73,6 +75,12 @@ impl<'a, D, const HEADER_LIMIT: usize> Context<'a, D, HEADER_LIMIT> { self.date } + /// Check if current connection is secure. + #[inline] + pub fn is_tls(&self) -> bool { + self.is_tls + } + /// Take ownership of HeaderMap stored in Context. /// /// When Context does not have one a new HeaderMap is constructed. diff --git a/http/src/h1/proto/decode.rs b/http/src/h1/proto/decode.rs index 3695e7648..491d1f522 100644 --- a/http/src/h1/proto/decode.rs +++ b/http/src/h1/proto/decode.rs @@ -82,7 +82,7 @@ impl Context<'_, D, MAX_HEADERS> { self.try_write_header(&mut headers, &mut decoder, idx, &slice, version)?; } - let ext = Extension::new(*self.socket_addr()); + let ext = Extension::new(*self.socket_addr(), self.is_tls()); let mut req = Request::new(RequestExt::from_parts((), ext)); let extensions = self.take_extensions(); @@ -173,7 +173,7 @@ mod test { #[test] fn connection_multiple_value() { - let mut ctx = Context::<_, 4>::new(&()); + let mut ctx = Context::<_, 4>::new(&(), false); let head = b"\ GET / HTTP/1.1\r\n\ @@ -211,7 +211,7 @@ mod test { #[test] fn transfer_encoding() { - let mut ctx = Context::<_, 4>::new(&()); + let mut ctx = Context::<_, 4>::new(&(), false); let head = b"\ GET / HTTP/1.1\r\n\ diff --git a/http/src/h1/proto/encode.rs b/http/src/h1/proto/encode.rs index 88fbf4894..6159c62ab 100644 --- a/http/src/h1/proto/encode.rs +++ b/http/src/h1/proto/encode.rs @@ -257,7 +257,7 @@ mod test { #[test] fn append_header() { - let mut ctx = Context::<_, 64>::new(&SystemTimeDateTimeHandler); + let mut ctx = Context::<_, 64>::new(&SystemTimeDateTimeHandler, false); let mut res = Response::new(BoxBody::new(Once::new(Bytes::new()))); @@ -287,7 +287,7 @@ mod test { #[test] fn multi_set_cookie() { - let mut ctx = Context::<_, 64>::new(&SystemTimeDateTimeHandler); + let mut ctx = Context::<_, 64>::new(&SystemTimeDateTimeHandler, false); let mut res = Response::new(BoxBody::new(Once::new(Bytes::new()))); diff --git a/http/src/h1/service.rs b/http/src/h1/service.rs index fc9ce91db..4b519079f 100644 --- a/http/src/h1/service.rs +++ b/http/src/h1/service.rs @@ -21,7 +21,7 @@ impl for H1Service where S: Service>, Response = Response>, - A: Service, + A: Service + IsTls, St: AsyncIo, A::Response: AsyncIo, B: Stream>, @@ -41,9 +41,17 @@ where .await .map_err(|_| HttpServiceError::Timeout(TimeoutError::TlsAccept))??; - super::dispatcher::run(&mut io, addr, timer, self.config, &self.service, self.date.get()) - .await - .map_err(Into::into) + super::dispatcher::run( + &mut io, + addr, + timer, + self.config, + &self.service, + self.date.get(), + self.tls_acceptor.is_tls(), + ) + .await + .map_err(Into::into) } } @@ -56,6 +64,7 @@ use { xitca_service::ready::ReadyService, }; +use crate::tls::IsTls; #[cfg(feature = "io-uring")] use crate::{ config::HttpServiceConfig, @@ -94,7 +103,7 @@ impl for H1UringService where S: Service>, Response = Response>, - A: Service, + A: Service + IsTls, A::Response: AsyncBufRead + AsyncBufWrite + 'static, B: Stream>, HttpServiceError: From, @@ -113,10 +122,18 @@ where .await .map_err(|_| HttpServiceError::Timeout(TimeoutError::TlsAccept))??; - super::dispatcher_uring::Dispatcher::new(io, addr, timer, self.config, &self.service, self.date.get()) - .run() - .await - .map_err(Into::into) + super::dispatcher_uring::Dispatcher::new( + io, + addr, + timer, + self.config, + &self.service, + self.date.get(), + self.tls_acceptor.is_tls(), + ) + .run() + .await + .map_err(Into::into) } } diff --git a/http/src/h2/proto/dispatcher.rs b/http/src/h2/proto/dispatcher.rs index c7775d516..45d4e818c 100644 --- a/http/src/h2/proto/dispatcher.rs +++ b/http/src/h2/proto/dispatcher.rs @@ -39,6 +39,7 @@ pub(crate) struct Dispatcher<'a, TlsSt, S, ReqB> { ka_dur: Duration, service: &'a S, date: &'a DateTimeHandle, + is_tls: bool, _req_body: PhantomData, } @@ -60,6 +61,7 @@ where ka_dur: Duration, service: &'a S, date: &'a DateTimeHandle, + is_tls: bool, ) -> Self { Self { io, @@ -68,6 +70,7 @@ where ka_dur, service, date, + is_tls, _req_body: PhantomData, } } @@ -80,6 +83,7 @@ where ka_dur, service, date, + is_tls, .. } = self; @@ -107,7 +111,7 @@ where // and reconstruct as HttpRequest. let req = req.map(|body| { let body = ReqB::from(RequestBody::from(body)); - RequestExt::from_parts(body, Extension::new(addr)) + RequestExt::from_parts(body, Extension::new(addr, is_tls)) }); queue.push(async move { diff --git a/http/src/h2/service.rs b/http/src/h2/service.rs index 0acdfa1bb..bfd0326e0 100644 --- a/http/src/h2/service.rs +++ b/http/src/h2/service.rs @@ -32,7 +32,7 @@ where S: Service>, Response = Response>, S::Error: fmt::Debug, - A: Service, + A: Service + IsTls, St: AsyncIo, TlsSt: AsyncIo, @@ -73,6 +73,7 @@ where self.config.keep_alive_timeout, &self.service, self.date.get(), + self.tls_acceptor.is_tls(), ); dispatcher.run().await?; @@ -81,6 +82,7 @@ where } } +use crate::tls::IsTls; #[cfg(feature = "io-uring")] pub(crate) use io_uring::H2UringService; diff --git a/http/src/h3/proto/dispatcher.rs b/http/src/h3/proto/dispatcher.rs index 9aa3a2bc9..51f04404a 100644 --- a/http/src/h3/proto/dispatcher.rs +++ b/http/src/h3/proto/dispatcher.rs @@ -69,7 +69,7 @@ where // Reconstruct Request to attach crate body type. let req = req.map(|_| { let body = ReqB::from(RequestBody(rx)); - RequestExt::from_parts(body, Extension::new(self.addr)) + RequestExt::from_parts(body, Extension::new(self.addr, true)) }); queue.push(async move { diff --git a/http/src/http.rs b/http/src/http.rs index ddb494e2c..3487523a0 100644 --- a/http/src/http.rs +++ b/http/src/http.rs @@ -146,9 +146,10 @@ where pub(crate) struct Extension(Box<_Extension>); impl Extension { - pub(crate) fn new(addr: SocketAddr) -> Self { + pub(crate) fn new(addr: SocketAddr, is_tls: bool) -> Self { Self(Box::new(_Extension { addr, + is_tls, #[cfg(feature = "router")] params: Default::default(), })) @@ -158,6 +159,7 @@ impl Extension { #[derive(Clone, Debug)] struct _Extension { addr: SocketAddr, + is_tls: bool, #[cfg(feature = "router")] params: Params, } @@ -176,6 +178,12 @@ impl RequestExt { &self.ext.0.addr } + /// retrieve whether the connection is tls encrypted. + #[inline] + pub fn is_tls(&self) -> bool { + self.ext.0.is_tls + } + /// exclusive version of [RequestExt::socket_addr] #[inline] pub fn socket_addr_mut(&mut self) -> &mut SocketAddr { @@ -209,7 +217,7 @@ where B: Default, { fn default() -> Self { - Self::from_parts(B::default(), Extension::new(crate::unspecified_socket_addr())) + Self::from_parts(B::default(), Extension::new(crate::unspecified_socket_addr(), false)) } } diff --git a/http/src/service.rs b/http/src/service.rs index c3cd1127c..e5fc53e6d 100644 --- a/http/src/service.rs +++ b/http/src/service.rs @@ -14,6 +14,7 @@ use super::{ date::{DateTime, DateTimeService}, error::{HttpServiceError, TimeoutError}, http::{Request, RequestExt, Response}, + tls::IsTls, util::timer::{KeepAlive, Timeout}, version::AsVersion, }; @@ -73,7 +74,7 @@ impl where S: Service>, Response = Response>, - A: Service, + A: Service + IsTls, A::Response: AsyncIo + AsVersion, HttpServiceError: From, S::Error: fmt::Debug, @@ -120,6 +121,7 @@ where self.config, &self.service, self.date.get(), + self.tls_acceptor.is_tls(), ) .await .map_err(From::from), @@ -142,6 +144,7 @@ where self.config.keep_alive_timeout, &self.service, self.date.get(), + self.tls_acceptor.is_tls(), ) .run() .await @@ -168,6 +171,7 @@ where self.config, &self.service, self.date.get(), + false, ) .await .map_err(From::from) diff --git a/http/src/tls/mod.rs b/http/src/tls/mod.rs index bb8e3da63..09ebb56e0 100644 --- a/http/src/tls/mod.rs +++ b/http/src/tls/mod.rs @@ -34,6 +34,12 @@ impl Service for NoOpTlsAcceptorBuilder { pub struct NoOpTlsAcceptorService; +pub trait IsTls { + fn is_tls(&self) -> bool { + true + } +} + impl Service for NoOpTlsAcceptorService { type Response = St; type Error = TlsError; @@ -42,3 +48,9 @@ impl Service for NoOpTlsAcceptorService { Ok(io) } } + +impl IsTls for NoOpTlsAcceptorService { + fn is_tls(&self) -> bool { + false + } +} diff --git a/http/src/tls/native_tls.rs b/http/src/tls/native_tls.rs index 056253dac..9127a3c16 100644 --- a/http/src/tls/native_tls.rs +++ b/http/src/tls/native_tls.rs @@ -14,9 +14,9 @@ use native_tls::{Error, HandshakeError}; use xitca_io::io::{AsyncIo, Interest, Ready}; use xitca_service::Service; -use crate::{http::Version, version::AsVersion}; - use super::error::TlsError; +use crate::tls::IsTls; +use crate::{http::Version, version::AsVersion}; /// A wrapper type for [TlsStream](native_tls::TlsStream). /// @@ -65,6 +65,8 @@ pub struct TlsAcceptorService { acceptor: TlsAcceptor, } +impl IsTls for TlsAcceptorService {} + impl Service for TlsAcceptorService { type Response = TlsStream; type Error = NativeTlsError; diff --git a/http/src/tls/openssl.rs b/http/src/tls/openssl.rs index 6e05f97d5..eee21e6f9 100644 --- a/http/src/tls/openssl.rs +++ b/http/src/tls/openssl.rs @@ -6,9 +6,9 @@ use xitca_io::io::AsyncIo; use xitca_service::Service; use xitca_tls::openssl::ssl; -use crate::{http::Version, version::AsVersion}; - use super::error::TlsError; +use crate::tls::IsTls; +use crate::{http::Version, version::AsVersion}; pub type TlsStream = xitca_tls::openssl::TlsStream; @@ -52,6 +52,8 @@ pub struct TlsAcceptorService { acceptor: TlsAcceptor, } +impl IsTls for TlsAcceptorService {} + impl TlsAcceptorService { #[inline(never)] async fn accept(&self, io: Io) -> Result, OpensslError> { diff --git a/http/src/tls/rustls.rs b/http/src/tls/rustls.rs index 320ed05d8..1df7202ea 100644 --- a/http/src/tls/rustls.rs +++ b/http/src/tls/rustls.rs @@ -6,9 +6,9 @@ use xitca_io::io::AsyncIo; use xitca_service::Service; use xitca_tls::rustls::{Error, ServerConfig, ServerConnection, TlsStream as _TlsStream}; -use crate::{http::Version, version::AsVersion}; - use super::error::TlsError; +use crate::tls::IsTls; +use crate::{http::Version, version::AsVersion}; pub(crate) type RustlsConfig = Arc; @@ -55,6 +55,8 @@ pub struct TlsAcceptorService { acceptor: Arc, } +impl IsTls for TlsAcceptorService {} + impl Service for TlsAcceptorService { type Response = TlsStream; type Error = RustlsError; diff --git a/http/src/tls/rustls_uring.rs b/http/src/tls/rustls_uring.rs index d99dd65e8..a066e8e1a 100644 --- a/http/src/tls/rustls_uring.rs +++ b/http/src/tls/rustls_uring.rs @@ -9,9 +9,9 @@ use xitca_tls::{ rustls_uring::TlsStream as _TlsStream, }; -use crate::{http::Version, version::AsVersion}; - use super::rustls::RustlsError; +use crate::tls::IsTls; +use crate::{http::Version, version::AsVersion}; /// A stream managed by rustls for tls read/write. pub struct TlsStream { @@ -52,6 +52,8 @@ pub struct TlsAcceptorService { acceptor: Arc, } +impl IsTls for TlsAcceptorService {} + impl Service for TlsAcceptorService where Io: AsyncBufRead + AsyncBufWrite,