Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions client/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,10 @@ impl Client {
return Ok((conn, expected_version));
}

if !connect.tls {
return Ok((conn, expected_version));
}

timer
.as_mut()
.reset(Instant::now() + self.timeout_config.tls_connect_timeout);
Expand Down
4 changes: 3 additions & 1 deletion client/src/connect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,17 +80,19 @@ pub struct Connect<'a> {
pub(crate) uri: Uri<'a>,
pub(crate) port: u16,
pub(crate) addr: Addrs,
pub(crate) tls: bool,
}

impl<'a> Connect<'a> {
/// Create `Connect` instance by splitting the string by ':' and convert the second part to u16
pub fn new(uri: Uri<'a>) -> Self {
pub fn new(uri: Uri<'a>, tls: Option<bool>) -> Self {
let (_, port) = parse_host(uri.hostname());

Self {
uri,
port: port.unwrap_or(0),
addr: Addrs::None,
tls: tls.unwrap_or(port != Some(80)),
}
}

Expand Down
4 changes: 2 additions & 2 deletions client/src/h1/proto/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ impl<const HEADER_LIMIT: usize> DerefMut for Context<'_, '_, HEADER_LIMIT> {
}

impl<'c, 'd, const HEADER_LIMIT: usize> Context<'c, 'd, HEADER_LIMIT> {
pub(crate) fn new(date: &'c DateTimeHandle<'d>) -> Self {
Self(context::Context::new(date))
pub(crate) fn new(date: &'c DateTimeHandle<'d>, is_tls: bool) -> Self {
Self(context::Context::new(date, is_tls))
}
}
3 changes: 2 additions & 1 deletion client/src/h1/proto/dispatcher.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ pub(crate) async fn send<S, B, E>(
stream: &mut S,
date: DateTimeHandle<'_>,
req: &mut Request<B>,
is_tls: bool,
) -> Result<(Response<()>, BytesMut, TransferCoding, bool), Error>
where
S: AsyncIo + Unpin,
Expand Down Expand Up @@ -69,7 +70,7 @@ where
}

// TODO: make const generic params configurable.
let mut ctx = Context::<128>::new(&date);
let mut ctx = Context::<128>::new(&date, is_tls);

// encode request head and return transfer encoding for request body
let encoder = ctx.encode_head(&mut buf, req)?;
Expand Down
17 changes: 15 additions & 2 deletions client/src/middleware/redirect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,25 @@ where
type Error = Error;

async fn call(&self, req: ServiceRequest<'r, 'c>) -> Result<Self::Response, Self::Error> {
let ServiceRequest { req, client, timeout } = req;
let ServiceRequest {
req,
client,
timeout,
tls,
} = req;
let mut headers = req.headers().clone();
let mut method = req.method().clone();
let mut uri = req.uri().clone();
loop {
let mut res = self.service.call(ServiceRequest { req, client, timeout }).await?;
let mut res = self
.service
.call(ServiceRequest {
req,
client,
timeout,
tls,
})
.await?;
match res.status() {
StatusCode::MOVED_PERMANENTLY | StatusCode::FOUND | StatusCode::SEE_OTHER => {
if method != Method::HEAD {
Expand Down
12 changes: 12 additions & 0 deletions client/src/request.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ pub struct RequestBuilder<'a, M = marker::Http> {
err: Vec<Error>,
client: &'a Client,
timeout: Duration,
tls: Option<bool>,
_marker: PhantomData<M>,
}

Expand Down Expand Up @@ -104,6 +105,7 @@ impl<'a, M> RequestBuilder<'a, M> {
err: Vec::new(),
client,
timeout: client.timeout_config.request_timeout,
tls: None,
_marker: PhantomData,
}
}
Expand All @@ -114,6 +116,7 @@ impl<'a, M> RequestBuilder<'a, M> {
err: self.err,
client: self.client,
timeout: self.timeout,
tls: self.tls,
_marker: PhantomData,
}
}
Expand All @@ -125,6 +128,7 @@ impl<'a, M> RequestBuilder<'a, M> {
err,
client,
timeout,
tls,
..
} = self;

Expand All @@ -138,6 +142,7 @@ impl<'a, M> RequestBuilder<'a, M> {
req: &mut req,
client,
timeout,
tls,
})
.await
}
Expand Down Expand Up @@ -210,6 +215,13 @@ impl<'a, M> RequestBuilder<'a, M> {
self
}

/// Set TLS state of this request.
#[inline]
pub fn tls(mut self, tls: bool) -> Self {
self.tls = Some(tls);
self
}

fn map_body<B, E>(mut self, b: B) -> RequestBuilder<'a, M>
where
B: Stream<Item = Result<Bytes, E>> + Send + 'static,
Expand Down
13 changes: 10 additions & 3 deletions client/src/service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ pub struct ServiceRequest<'r, 'c> {
pub req: &'r mut Request<BoxBody>,
pub client: &'c Client,
pub timeout: Duration,
pub tls: Option<bool>,
}

/// type alias for object safe wrapper of type implement [Service] trait.
Expand All @@ -85,7 +86,12 @@ pub(crate) fn base_service() -> HttpService {
#[cfg(any(feature = "http1", feature = "http2", feature = "http3"))]
use crate::{error::TimeoutError, timeout::Timeout};

let ServiceRequest { req, client, timeout } = req;
let ServiceRequest {
req,
client,
timeout,
tls,
} = req;

let uri = Uri::try_parse(req.uri())?;

Expand All @@ -94,7 +100,8 @@ pub(crate) fn base_service() -> HttpService {
#[allow(unused_mut)]
let mut version = req.version();

let mut connect = Connect::new(uri);
let mut connect = Connect::new(uri, tls);
let is_tls = connect.tls;

let _date = client.date_service.handle();

Expand Down Expand Up @@ -219,7 +226,7 @@ pub(crate) fn base_service() -> HttpService {
#[cfg(feature = "http1")]
{
let mut timer = Box::pin(tokio::time::sleep(timeout));
let res = crate::h1::proto::send(&mut *_conn, _date, req)
let res = crate::h1::proto::send(&mut *_conn, _date, req, is_tls)
.timeout(timer.as_mut())
.await;

Expand Down
7 changes: 5 additions & 2 deletions http/src/h1/dispatcher.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ pub(crate) async fn run<
config: HttpServiceConfig<HEADER_LIMIT, READ_BUF_LIMIT, WRITE_BUF_LIMIT>,
service: &'a S,
date: &'a D,
is_tls: bool,
) -> Result<(), Error<S::Error, BE>>
where
S: Service<ExtRequest<ReqB>, Response = Response<ResB>>,
Expand All @@ -77,7 +78,7 @@ where
EitherBuf::Right(WriteBuf::<WRITE_BUF_LIMIT>::default())
};

Dispatcher::new(io, addr, timer, config, service, date, write_buf)
Dispatcher::new(io, addr, timer, config, service, date, is_tls, write_buf)
.run()
.await
}
Expand Down Expand Up @@ -166,19 +167,21 @@ where
W: H1BufWrite,
D: DateTime,
{
#[allow(clippy::too_many_arguments)]
fn new<const WRITE_BUF_LIMIT: usize>(
io: &'a mut St,
addr: SocketAddr,
timer: Pin<&'a mut KeepAlive>,
config: HttpServiceConfig<HEADER_LIMIT, READ_BUF_LIMIT, WRITE_BUF_LIMIT>,
service: &'a S,
date: &'a D,
is_tls: bool,
write_buf: W,
) -> Self {
Self {
io: BufferedIo::new(io, write_buf),
timer: Timer::new(timer, config.keep_alive_timeout, config.request_head_timeout),
ctx: Context::with_addr(addr, date),
ctx: Context::with_addr(addr, date, is_tls),
service,
_phantom: PhantomData,
}
Expand Down
3 changes: 2 additions & 1 deletion http/src/h1/dispatcher_uring.rs
Original file line number Diff line number Diff line change
Expand Up @@ -118,11 +118,12 @@ where
config: HttpServiceConfig<H_LIMIT, R_LIMIT, W_LIMIT>,
service: &'a S,
date: &'a D,
is_tls: bool,
) -> Self {
Self {
io: Rc::new(io),
timer: Timer::new(timer, config.keep_alive_timeout, config.request_head_timeout),
ctx: Context::<_, H_LIMIT>::with_addr(addr, date),
ctx: Context::<_, H_LIMIT>::with_addr(addr, date, is_tls),
service,
read_buf: BufOwned::new(),
write_buf: BufOwned::new(),
Expand Down
14 changes: 11 additions & 3 deletions http/src/h1/proto/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ pub struct Context<'a, D, const HEADER_LIMIT: usize> {
// http extensions reused by next request.
exts: Extensions,
date: &'a D,
is_tls: bool,
}

// A set of state for current request that are used after request's ownership is passed
Expand Down Expand Up @@ -49,21 +50,22 @@ impl<'a, D, const HEADER_LIMIT: usize> Context<'a, D, HEADER_LIMIT> {
///
/// [DateTime]: crate::date::DateTime
#[inline]
pub fn new(date: &'a D) -> Self {
Self::with_addr(crate::unspecified_socket_addr(), date)
pub fn new(date: &'a D, is_tls: bool) -> Self {
Self::with_addr(crate::unspecified_socket_addr(), date, is_tls)
}

/// Context is constructed with [SocketAddr] and reference of certain type that impl [DateTime] trait.
///
/// [DateTime]: crate::date::DateTime
#[inline]
pub fn with_addr(addr: SocketAddr, date: &'a D) -> Self {
pub fn with_addr(addr: SocketAddr, date: &'a D, is_tls: bool) -> Self {
Self {
addr,
state: ContextState::new(),
header: None,
exts: Extensions::new(),
date,
is_tls,
}
}

Expand All @@ -73,6 +75,12 @@ impl<'a, D, const HEADER_LIMIT: usize> Context<'a, D, HEADER_LIMIT> {
self.date
}

/// Check if current connection is secure.
#[inline]
pub fn is_tls(&self) -> bool {
self.is_tls
}

/// Take ownership of HeaderMap stored in Context.
///
/// When Context does not have one a new HeaderMap is constructed.
Expand Down
6 changes: 3 additions & 3 deletions http/src/h1/proto/decode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ impl<D, const MAX_HEADERS: usize> Context<'_, D, MAX_HEADERS> {
self.try_write_header(&mut headers, &mut decoder, idx, &slice, version)?;
}

let ext = Extension::new(*self.socket_addr());
let ext = Extension::new(*self.socket_addr(), self.is_tls());
let mut req = Request::new(RequestExt::from_parts((), ext));

let extensions = self.take_extensions();
Expand Down Expand Up @@ -173,7 +173,7 @@ mod test {

#[test]
fn connection_multiple_value() {
let mut ctx = Context::<_, 4>::new(&());
let mut ctx = Context::<_, 4>::new(&(), false);

let head = b"\
GET / HTTP/1.1\r\n\
Expand Down Expand Up @@ -211,7 +211,7 @@ mod test {

#[test]
fn transfer_encoding() {
let mut ctx = Context::<_, 4>::new(&());
let mut ctx = Context::<_, 4>::new(&(), false);

let head = b"\
GET / HTTP/1.1\r\n\
Expand Down
4 changes: 2 additions & 2 deletions http/src/h1/proto/encode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ mod test {

#[test]
fn append_header() {
let mut ctx = Context::<_, 64>::new(&SystemTimeDateTimeHandler);
let mut ctx = Context::<_, 64>::new(&SystemTimeDateTimeHandler, false);

let mut res = Response::new(BoxBody::new(Once::new(Bytes::new())));

Expand Down Expand Up @@ -287,7 +287,7 @@ mod test {

#[test]
fn multi_set_cookie() {
let mut ctx = Context::<_, 64>::new(&SystemTimeDateTimeHandler);
let mut ctx = Context::<_, 64>::new(&SystemTimeDateTimeHandler, false);

let mut res = Response::new(BoxBody::new(Once::new(Bytes::new())));

Expand Down
35 changes: 26 additions & 9 deletions http/src/h1/service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ impl<St, S, B, BE, A, const HEADER_LIMIT: usize, const READ_BUF_LIMIT: usize, co
Service<(St, SocketAddr)> for H1Service<St, S, A, HEADER_LIMIT, READ_BUF_LIMIT, WRITE_BUF_LIMIT>
where
S: Service<Request<RequestExt<RequestBody>>, Response = Response<B>>,
A: Service<St>,
A: Service<St> + IsTls,
St: AsyncIo,
A::Response: AsyncIo,
B: Stream<Item = Result<Bytes, BE>>,
Expand All @@ -41,9 +41,17 @@ where
.await
.map_err(|_| HttpServiceError::Timeout(TimeoutError::TlsAccept))??;

super::dispatcher::run(&mut io, addr, timer, self.config, &self.service, self.date.get())
.await
.map_err(Into::into)
super::dispatcher::run(
&mut io,
addr,
timer,
self.config,
&self.service,
self.date.get(),
self.tls_acceptor.is_tls(),
)
.await
.map_err(Into::into)
}
}

Expand All @@ -56,6 +64,7 @@ use {
xitca_service::ready::ReadyService,
};

use crate::tls::IsTls;
#[cfg(feature = "io-uring")]
use crate::{
config::HttpServiceConfig,
Expand Down Expand Up @@ -94,7 +103,7 @@ impl<S, B, BE, A, const HEADER_LIMIT: usize, const READ_BUF_LIMIT: usize, const
Service<(TcpStream, SocketAddr)> for H1UringService<S, A, HEADER_LIMIT, READ_BUF_LIMIT, WRITE_BUF_LIMIT>
where
S: Service<Request<RequestExt<RequestBody>>, Response = Response<B>>,
A: Service<TcpStream>,
A: Service<TcpStream> + IsTls,
A::Response: AsyncBufRead + AsyncBufWrite + 'static,
B: Stream<Item = Result<Bytes, BE>>,
HttpServiceError<S::Error, BE>: From<A::Error>,
Expand All @@ -113,10 +122,18 @@ where
.await
.map_err(|_| HttpServiceError::Timeout(TimeoutError::TlsAccept))??;

super::dispatcher_uring::Dispatcher::new(io, addr, timer, self.config, &self.service, self.date.get())
.run()
.await
.map_err(Into::into)
super::dispatcher_uring::Dispatcher::new(
io,
addr,
timer,
self.config,
&self.service,
self.date.get(),
self.tls_acceptor.is_tls(),
)
.run()
.await
.map_err(Into::into)
}
}

Expand Down
Loading