From 7548fe02dcb714f2b3f0f7fc4a12f0217e2f0947 Mon Sep 17 00:00:00 2001 From: fakeshadow <24548779@qq.com> Date: Tue, 3 Mar 2026 15:21:16 +0800 Subject: [PATCH 01/21] more homebrew protocol --- postgres/Cargo.toml | 2 + postgres/src/driver/codec.rs | 86 +-- postgres/src/driver/codec/encode.rs | 26 +- postgres/src/driver/generic.rs | 15 +- postgres/src/protocol.rs | 166 +---- postgres/src/protocol/message.rs | 2 + postgres/src/protocol/message/backend.rs | 856 ++++++++++++++++++++++ postgres/src/protocol/message/frontend.rs | 165 +++++ 8 files changed, 1074 insertions(+), 244 deletions(-) create mode 100644 postgres/src/protocol/message.rs create mode 100644 postgres/src/protocol/message/backend.rs create mode 100644 postgres/src/protocol/message/frontend.rs diff --git a/postgres/Cargo.toml b/postgres/Cargo.toml index 73ace6207..1a743e6f7 100644 --- a/postgres/Cargo.toml +++ b/postgres/Cargo.toml @@ -38,9 +38,11 @@ nightly = [] xitca-io = { version = "0.5.1", features = ["runtime"] } xitca-unsafe-collection = { version = "0.2.0", features = ["bytes"] } +byteorder = "1.5.0" fallible-iterator = "0.2" futures-core = { version = "0.3", default-features = false } lru = { version = "0.16", default-features = false } +memchr = "2.7.1" percent-encoding = "2" postgres-protocol = "0.6.5" postgres-types = "0.2" diff --git a/postgres/src/driver/codec.rs b/postgres/src/driver/codec.rs index 4f4b33722..b17418c22 100644 --- a/postgres/src/driver/codec.rs +++ b/postgres/src/driver/codec.rs @@ -17,27 +17,18 @@ use crate::{ pub(super) fn request_pair() -> (ResponseSender, Response) { let (tx, rx) = unbounded_channel(); - ( - tx, - Response { - rx, - buf: BytesMut::new(), - }, - ) + (tx, Response { rx }) } #[derive(Debug)] pub struct Response { rx: ResponseReceiver, - buf: BytesMut, } impl Response { pub(crate) fn blocking_recv(&mut self) -> Result { - if self.buf.is_empty() { - self.buf = self.rx.blocking_recv().ok_or_else(|| Error::from(ClosedByDriver))?; - } - self.parse_message() + 
let msg = self.rx.blocking_recv().ok_or_else(|| Error::from(ClosedByDriver))?; + Self::parse_message(msg) } pub(crate) fn recv(&mut self) -> impl Future> + Send + '_ { @@ -45,10 +36,8 @@ impl Response { } pub(crate) fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll> { - if self.buf.is_empty() { - self.buf = ready!(self.rx.poll_recv(cx)).ok_or_else(|| Error::from(ClosedByDriver))?; - } - Poll::Ready(self.parse_message()) + let msg = ready!(self.rx.poll_recv(cx)).ok_or_else(|| Error::from(ClosedByDriver))?; + Poll::Ready(Self::parse_message(msg)) } pub(crate) fn try_into_row_affected(mut self) -> impl Future> + Send { @@ -100,8 +89,8 @@ impl Response { } } - fn parse_message(&mut self) -> Result { - match backend::Message::parse(&mut self.buf)?.expect("must not parse message from empty buffer.") { + fn parse_message(msg: backend::MessageRaw) -> Result { + match msg.try_into_message()? { backend::Message::ErrorResponse(body) => Err(Error::db(body.fields())), msg => Ok(msg), } @@ -115,24 +104,24 @@ pub(crate) fn body_to_affected_rows(body: &backend::CommandCompleteBody) -> Resu .map(|r| r.rsplit(' ').next().unwrap().parse().unwrap_or(0)) } -pub(super) type ResponseSender = UnboundedSender; +pub(super) type ResponseSender = UnboundedSender; // TODO: remove this lint. 
#[allow(dead_code)] -pub(super) type ResponseReceiver = UnboundedReceiver; +pub(super) type ResponseReceiver = UnboundedReceiver; pub(super) struct BytesMessage { - pub(super) buf: BytesMut, + pub(super) msg: backend::MessageRaw, pub(super) complete: bool, } impl BytesMessage { #[cold] #[inline(never)] - pub(super) fn parse_error(&mut self) -> Error { - match backend::Message::parse(&mut self.buf) { + pub(super) fn parse_error(self) -> Error { + match self.msg.try_into_message() { Err(e) => Error::from(e), - Ok(Some(backend::Message::ErrorResponse(body))) => Error::db(body.fields()), + Ok(backend::Message::ErrorResponse(body)) => Error::db(body.fields()), _ => Error::unexpected(), } } @@ -145,47 +134,20 @@ pub(super) enum ResponseMessage { impl ResponseMessage { pub(crate) fn try_from_buf(buf: &mut BytesMut) -> Result, Error> { - let mut tail = 0; - let mut complete = false; - - loop { - let slice = &buf[tail..]; - let Some(header) = backend::Header::parse(slice)? else { - break; - }; - let len = header.len() as usize + 1; - - if slice.len() < len { - break; + let Some(msg) = backend::MessageRaw::parse(buf)? else { + return Ok(None); + }; + + match msg.tag() { + backend::NOTICE_RESPONSE_TAG | backend::NOTIFICATION_RESPONSE_TAG | backend::PARAMETER_STATUS_TAG => { + let message = msg.try_into_message()?; + Ok(Some(ResponseMessage::Async(message))) } - - match header.tag() { - backend::NOTICE_RESPONSE_TAG | backend::NOTIFICATION_RESPONSE_TAG | backend::PARAMETER_STATUS_TAG => { - if tail > 0 { - break; - } - let message = backend::Message::parse(buf)? - .expect("buffer contains at least one Message. 
parser must produce Some"); - return Ok(Some(ResponseMessage::Async(message))); - } - tag => { - tail += len; - if matches!(tag, backend::READY_FOR_QUERY_TAG) { - complete = true; - break; - } - } + tag => { + let complete = matches!(tag, backend::READY_FOR_QUERY_TAG); + Ok(Some(ResponseMessage::Normal(BytesMessage { msg, complete }))) } } - - if tail == 0 { - Ok(None) - } else { - Ok(Some(ResponseMessage::Normal(BytesMessage { - buf: buf.split_to(tail), - complete, - }))) - } } } diff --git a/postgres/src/driver/codec/encode.rs b/postgres/src/driver/codec/encode.rs index 4d943dc89..dfcff9d2f 100644 --- a/postgres/src/driver/codec/encode.rs +++ b/postgres/src/driver/codec/encode.rs @@ -82,7 +82,7 @@ where fn encode_statement_create(name: &str, stmt: &str, types: &[Type], buf: &mut BytesMut) -> Result<(), Error> { frontend::parse(name, stmt, types.iter().map(Type::oid), buf)?; frontend::describe(b'S', name, buf)?; - protocol::sync(buf); + frontend::sync(buf); Ok(()) } @@ -94,8 +94,8 @@ impl Encode for StatementPreparedCancel<'_> { #[inline] fn encode(self, buf: &mut BytesMut) -> Result { let Self { name } = self; - protocol::close(b'S', name, buf)?; - protocol::sync(buf); + frontend::close(b'S', name, buf)?; + frontend::sync(buf); Ok(NoOpIntoRowStream) } } @@ -135,8 +135,8 @@ where P: AsParams, { encode_bind(stmt.name(), stmt.params(), params, "", buf)?; - protocol::execute("", 0, buf)?; - protocol::sync(buf); + frontend::execute("", 0, buf)?; + frontend::sync(buf); Ok(()) } @@ -156,8 +156,8 @@ where frontend::parse("", stmt, types.iter().map(Type::oid), buf)?; encode_bind("", types, params, "", buf)?; frontend::describe(b'S', "", buf)?; - protocol::execute("", 0, buf)?; - protocol::sync(buf); + frontend::execute("", 0, buf)?; + frontend::sync(buf); Ok(IntoRowStreamGuard(cli)) } } @@ -186,7 +186,7 @@ where params, } = self; encode_bind(stmt, types, params, name, buf)?; - protocol::sync(buf); + frontend::sync(buf); Ok(NoOpIntoRowStream) } } @@ -202,8 +202,8 @@ impl 
Encode for PortalCancel<'_> { #[inline] fn encode(self, buf: &mut BytesMut) -> Result { - protocol::close(b'P', self.name, buf)?; - protocol::sync(buf); + frontend::close(b'P', self.name, buf)?; + frontend::sync(buf); Ok(NoOpIntoRowStream) } } @@ -226,8 +226,8 @@ impl<'s> Encode for PortalQuery<'s> { max_rows, columns, } = self; - protocol::execute(name, max_rows, buf)?; - protocol::sync(buf); + frontend::execute(name, max_rows, buf)?; + frontend::sync(buf); Ok(columns) } } @@ -252,7 +252,7 @@ where let params = params.zip(types); - protocol::bind( + frontend::bind( portal_name, stmt_name, params.clone().map(|(p, ty)| p.borrow_to_sql().encode_format(ty) as _), diff --git a/postgres/src/driver/generic.rs b/postgres/src/driver/generic.rs index dd38eeadb..0fa15e8f7 100644 --- a/postgres/src/driver/generic.rs +++ b/postgres/src/driver/generic.rs @@ -351,13 +351,18 @@ impl DriverRx { while let Some(res) = ResponseMessage::try_from_buf(read_buf)? { match res { - ResponseMessage::Normal(mut msg) => { + ResponseMessage::Normal(msg) => { // lock the shared state only when needed and keep the lock around a bit for possible multiple messages let inner = guard.get_or_insert_with(|| self.guarded.lock().unwrap()); - let res = inner.res.pop_front().ok_or_else(|| msg.parse_error())?; - let _ = res.send(msg.buf); - if !msg.complete { - inner.res.push_front(res); + + match inner.res.pop_front() { + Some(tx) => { + let _ = tx.send(msg.msg); + if !msg.complete { + inner.res.push_front(tx); + } + } + None => return Err(msg.parse_error()), } } ResponseMessage::Async(msg) => return Ok(Some(msg)), diff --git a/postgres/src/protocol.rs b/postgres/src/protocol.rs index 34e65f512..9f1901347 100644 --- a/postgres/src/protocol.rs +++ b/postgres/src/protocol.rs @@ -1,167 +1,5 @@ -//! re-export of postgres-protocol crate with additional query functions +//! 
re-export of postgres-protocol crate with additional protocol functions -use xitca_io::bytes::{BufMut, BytesMut}; - -use std::{convert::Infallible, io}; - -use super::error::Error; +pub(crate) mod message; pub(crate) use postgres_protocol::*; - -// optimized version of protocol functions depending on specific inputs - -// optimized version of postgres_protocol::frontend::bind -pub(crate) fn bind( - portal_name: &str, - stmt_name: &str, - formats: I, - values: J, - mut serializer: F, - buf: &mut BytesMut, -) -> Result<(), Error> -where - I: ExactSizeIterator, - J: ExactSizeIterator, - F: FnMut(T, &mut BytesMut) -> Result, -{ - buf.put_u8(b'B'); - write_body(buf, |buf| { - write_cstr(portal_name, buf); - write_cstr(stmt_name, buf); - write_counted( - formats, - |f, buf| { - buf.put_i16(f); - Ok::<_, Infallible>(()) - }, - buf, - )?; - write_counted(values, |v, buf| write_nullable(|buf| serializer(v, buf), buf), buf)?; - result_fmt_binary(buf); - Ok(()) - }) -} - -// optimized version of postgres_protocol::frontend::execute -pub(crate) fn execute(portal_name: &str, max_rows: i32, buf: &mut BytesMut) -> Result<(), Error> { - buf.put_u8(b'E'); - write_body(buf, |buf| { - write_cstr(portal_name, buf); - buf.put_i32(max_rows); - Ok(()) - }) -} - -// optimized version of postgres_protocol::frontend::close -pub(crate) fn close(variant: u8, name: &str, buf: &mut BytesMut) -> Result<(), Error> { - buf.put_u8(b'C'); - write_body(buf, |buf| { - buf.put_u8(variant); - write_cstr(name, buf); - Ok(()) - }) -} - -// optimized version of postgres_protocol::frontend::sync -pub(crate) fn sync(buf: &mut BytesMut) { - buf.extend_from_slice(&[b'S', 0, 0, 0, 4]); -} - -fn write_body(buf: &mut BytesMut, f: F) -> Result<(), Error> -where - F: FnOnce(&mut BytesMut) -> Result<(), Error>, -{ - let base = buf.len(); - buf.put_i32(0); - - f(buf)?; - - let size = FromUsize::from_usize(buf.len() - base)?; - buf.put_i32_at(base, size); - Ok(()) -} - -fn write_counted(items: I, mut serializer: F, 
buf: &mut BytesMut) -> Result<(), Error> -where - I: ExactSizeIterator, - F: FnMut(T, &mut BytesMut) -> Result<(), E>, - Error: From, -{ - let count = FromUsize::from_usize(items.len())?; - buf.put_u16(count); - - for item in items { - serializer(item, buf)?; - } - - Ok(()) -} - -fn write_nullable(serializer: F, buf: &mut BytesMut) -> Result<(), Error> -where - F: FnOnce(&mut BytesMut) -> Result, -{ - let base = buf.len(); - buf.put_i32(0); - - let size = match serializer(buf)? { - IsNull::No => FromUsize::from_usize(buf.len() - base - 4)?, - IsNull::Yes => -1, - }; - buf.put_i32_at(base, size); - - Ok(()) -} - -fn write_cstr(s: &str, buf: &mut BytesMut) { - // strings used inside library dont contain c style null bytes. - // debug assertions is enough for catching violation. - #[cfg(debug_assertions)] - if s.as_bytes().contains(&0) { - panic!("input string: {s} contains embedded null") - } - - buf.put_slice(s.as_bytes()); - buf.put_u8(0); -} - -// hard coded reesult format to always ask for binary output from server -fn result_fmt_binary(buf: &mut BytesMut) { - buf.extend_from_slice(&[0, 1, 0, 1]); -} - -trait FromUsize: Sized { - fn from_usize(x: usize) -> Result; -} - -macro_rules! 
from_usize { - ($t:ty) => { - impl FromUsize for $t { - #[inline] - fn from_usize(x: usize) -> io::Result<$t> { - if x > <$t>::MAX as usize { - Err(io::Error::new( - io::ErrorKind::InvalidInput, - "value too large to transmit", - )) - } else { - Ok(x as $t) - } - } - } - }; -} - -from_usize!(i16); -from_usize!(u16); -from_usize!(i32); - -trait WriteSize { - fn put_i32_at(&mut self, offset: usize, num: i32); -} - -impl WriteSize for BytesMut { - fn put_i32_at(&mut self, offset: usize, num: i32) { - self[offset..offset + 4].copy_from_slice(&num.to_be_bytes()); - } -} diff --git a/postgres/src/protocol/message.rs b/postgres/src/protocol/message.rs new file mode 100644 index 000000000..9f6aebb7b --- /dev/null +++ b/postgres/src/protocol/message.rs @@ -0,0 +1,2 @@ +pub(crate) mod backend; +pub(crate) mod frontend; diff --git a/postgres/src/protocol/message/backend.rs b/postgres/src/protocol/message/backend.rs new file mode 100644 index 000000000..bde862b24 --- /dev/null +++ b/postgres/src/protocol/message/backend.rs @@ -0,0 +1,856 @@ +pub(crate) use postgres_protocol::message::backend::*; + +use core::{cmp, ops::Range}; + +use std::io; + +use byteorder::{BigEndian, ReadBytesExt}; +use fallible_iterator::FallibleIterator; +use xitca_io::bytes::{Buf, Bytes, BytesMut}; + +use crate::types::Oid; + +/// An enum representing Postgres backend messages. 
+#[non_exhaustive] +pub enum Message { + AuthenticationCleartextPassword, + AuthenticationGss, + AuthenticationKerberosV5, + AuthenticationMd5Password(AuthenticationMd5PasswordBody), + AuthenticationOk, + AuthenticationScmCredential, + AuthenticationSspi, + AuthenticationGssContinue(AuthenticationGssContinueBody), + AuthenticationSasl(AuthenticationSaslBody), + AuthenticationSaslContinue(AuthenticationSaslContinueBody), + AuthenticationSaslFinal(AuthenticationSaslFinalBody), + BackendKeyData(BackendKeyDataBody), + BindComplete, + CloseComplete, + CommandComplete(CommandCompleteBody), + CopyData(CopyDataBody), + CopyDone, + CopyInResponse(CopyInResponseBody), + CopyOutResponse(CopyOutResponseBody), + DataRow(DataRowBody), + EmptyQueryResponse, + ErrorResponse(ErrorResponseBody), + NoData, + NoticeResponse(NoticeResponseBody), + NotificationResponse(NotificationResponseBody), + ParameterDescription(ParameterDescriptionBody), + ParameterStatus(ParameterStatusBody), + ParseComplete, + PortalSuspended, + ReadyForQuery(ReadyForQueryBody), + RowDescription(RowDescriptionBody), +} + +impl Message { + pub(crate) fn parse(buf: &mut BytesMut) -> io::Result> { + let Some(msg) = MessageRaw::parse(buf)? else { + return Ok(None); + }; + msg.try_into_message().map(Some) + } +} + +macro_rules! empty_check { + ($buf: ident) => { + debug_assert!( + $buf.slice().is_empty(), + "invalid message length: expected buffer to be empty" + ); + }; +} + +pub(crate) struct MessageRaw { + pub(crate) buf: BytesMut, +} + +impl MessageRaw { + #[inline(always)] + pub(crate) fn tag(&self) -> u8 { + *self.buf.get(0).expect("MessageRaw::parse produced illformed data type") + } + + pub(crate) fn parse(buf: &mut BytesMut) -> io::Result> { + let Some(header) = Header::parse(buf)? 
else { + return Ok(None); + }; + + let len = header.len() as usize + 1; + + Ok(if buf.len() < len { + None + } else { + Some(Self { buf: buf.split_to(len) }) + }) + } + + #[inline] + pub(crate) fn try_into_message(self) -> io::Result { + let tag = self.tag(); + + let mut buf = Buffer { + bytes: self.buf.freeze(), + idx: 5, + }; + + let message = match tag { + PARSE_COMPLETE_TAG => Message::ParseComplete, + BIND_COMPLETE_TAG => Message::BindComplete, + CLOSE_COMPLETE_TAG => Message::CloseComplete, + NOTIFICATION_RESPONSE_TAG => { + let process_id = buf.read_i32::()?; + let channel = buf.read_cstr()?; + let message = buf.read_cstr()?; + empty_check!(buf); + Message::NotificationResponse(NotificationResponseBody { + process_id, + channel, + message, + }) + } + COPY_DONE_TAG => Message::CopyDone, + COMMAND_COMPLETE_TAG => { + let tag = buf.read_cstr()?; + empty_check!(buf); + Message::CommandComplete(CommandCompleteBody { tag }) + } + COPY_DATA_TAG => { + let storage = buf.read_all(); + Message::CopyData(CopyDataBody { storage }) + } + DATA_ROW_TAG => { + let len = buf.read_u16::()?; + let storage = buf.read_all(); + Message::DataRow(DataRowBody { storage, len }) + } + ERROR_RESPONSE_TAG => { + let storage = buf.read_all(); + Message::ErrorResponse(ErrorResponseBody { storage }) + } + COPY_IN_RESPONSE_TAG => { + let format = buf.read_u8()?; + let len = buf.read_u16::()?; + let storage = buf.read_all(); + Message::CopyInResponse(CopyInResponseBody { format, len, storage }) + } + COPY_OUT_RESPONSE_TAG => { + let format = buf.read_u8()?; + let len = buf.read_u16::()?; + let storage = buf.read_all(); + Message::CopyOutResponse(CopyOutResponseBody { format, len, storage }) + } + EMPTY_QUERY_RESPONSE_TAG => Message::EmptyQueryResponse, + BACKEND_KEY_DATA_TAG => { + let process_id = buf.read_i32::()?; + let secret_key = buf.read_i32::()?; + empty_check!(buf); + Message::BackendKeyData(BackendKeyDataBody { process_id, secret_key }) + } + NO_DATA_TAG => Message::NoData, + 
NOTICE_RESPONSE_TAG => { + let storage = buf.read_all(); + Message::NoticeResponse(NoticeResponseBody { storage }) + } + AUTHENTICATION_TAG => match buf.read_i32::()? { + 0 => { + empty_check!(buf); + Message::AuthenticationOk + } + 2 => { + empty_check!(buf); + Message::AuthenticationKerberosV5 + } + 3 => { + empty_check!(buf); + Message::AuthenticationCleartextPassword + } + 5 => { + let mut salt = [0; 4]; + io::Read::read_exact(&mut buf, &mut salt)?; + empty_check!(buf); + Message::AuthenticationMd5Password(AuthenticationMd5PasswordBody { salt }) + } + 6 => { + empty_check!(buf); + Message::AuthenticationScmCredential + } + 7 => { + empty_check!(buf); + Message::AuthenticationGss + } + 8 => { + let storage = buf.read_all(); + Message::AuthenticationGssContinue(AuthenticationGssContinueBody(storage)) + } + 9 => { + empty_check!(buf); + Message::AuthenticationSspi + } + 10 => { + let storage = buf.read_all(); + Message::AuthenticationSasl(AuthenticationSaslBody(storage)) + } + 11 => { + let storage = buf.read_all(); + Message::AuthenticationSaslContinue(AuthenticationSaslContinueBody(storage)) + } + 12 => { + let storage = buf.read_all(); + Message::AuthenticationSaslFinal(AuthenticationSaslFinalBody(storage)) + } + tag => { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + format!("unknown authentication tag `{tag}`"), + )); + } + }, + PORTAL_SUSPENDED_TAG => { + empty_check!(buf); + Message::PortalSuspended + } + PARAMETER_STATUS_TAG => { + let name = buf.read_cstr()?; + let value = buf.read_cstr()?; + empty_check!(buf); + Message::ParameterStatus(ParameterStatusBody { name, value }) + } + PARAMETER_DESCRIPTION_TAG => { + let len = buf.read_u16::()?; + let storage = buf.read_all(); + Message::ParameterDescription(ParameterDescriptionBody { storage, len }) + } + ROW_DESCRIPTION_TAG => { + let len = buf.read_u16::()?; + let storage = buf.read_all(); + Message::RowDescription(RowDescriptionBody { storage, len }) + } + READY_FOR_QUERY_TAG => { + let 
status = buf.read_u8()?; + empty_check!(buf); + Message::ReadyForQuery(ReadyForQueryBody { status }) + } + tag => { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + format!("unknown message tag `{tag}`"), + )); + } + }; + + Ok(message) + } +} + +struct Buffer { + bytes: Bytes, + idx: usize, +} + +impl Buffer { + #[inline(always)] + fn slice(&self) -> &[u8] { + &self.bytes[self.idx..] + } + + #[inline] + fn read_cstr(&mut self) -> io::Result { + match memchr::memchr(0, self.slice()) { + Some(pos) => { + let start = self.idx; + let end = start + pos; + let cstr = self.bytes.slice(start..end); + self.idx = end + 1; + Ok(cstr) + } + None => Err(io::Error::new(io::ErrorKind::UnexpectedEof, "unexpected EOF")), + } + } + + #[inline(always)] + fn read_all(mut self) -> Bytes { + self.bytes.advance(self.idx); + self.bytes + } +} + +impl io::Read for Buffer { + #[inline] + fn read(&mut self, buf: &mut [u8]) -> io::Result { + let len = { + let slice = self.slice(); + let len = cmp::min(slice.len(), buf.len()); + buf[..len].copy_from_slice(&slice[..len]); + len + }; + self.idx += len; + Ok(len) + } +} + +pub struct AuthenticationMd5PasswordBody { + salt: [u8; 4], +} + +impl AuthenticationMd5PasswordBody { + #[inline] + pub fn salt(&self) -> [u8; 4] { + self.salt + } +} + +pub struct AuthenticationGssContinueBody(Bytes); + +impl AuthenticationGssContinueBody { + #[inline] + pub fn data(&self) -> &[u8] { + &self.0 + } +} + +pub struct AuthenticationSaslBody(Bytes); + +impl AuthenticationSaslBody { + #[inline] + pub fn mechanisms(&self) -> SaslMechanisms<'_> { + SaslMechanisms(&self.0) + } +} + +pub struct SaslMechanisms<'a>(&'a [u8]); + +impl<'a> FallibleIterator for SaslMechanisms<'a> { + type Item = &'a str; + type Error = io::Error; + + #[inline] + fn next(&mut self) -> io::Result> { + let value_end = find_null(self.0, 0)?; + if value_end == 0 { + if self.0.len() != 1 { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + "invalid message length: expected 
to be at end of iterator for sasl", + )); + } + Ok(None) + } else { + let value = get_str(&self.0[..value_end])?; + self.0 = &self.0[value_end + 1..]; + Ok(Some(value)) + } + } +} + +pub struct AuthenticationSaslContinueBody(Bytes); + +impl AuthenticationSaslContinueBody { + #[inline] + pub fn data(&self) -> &[u8] { + &self.0 + } +} + +pub struct AuthenticationSaslFinalBody(Bytes); + +impl AuthenticationSaslFinalBody { + #[inline] + pub fn data(&self) -> &[u8] { + &self.0 + } +} + +pub struct BackendKeyDataBody { + process_id: i32, + secret_key: i32, +} + +impl BackendKeyDataBody { + #[inline] + pub fn process_id(&self) -> i32 { + self.process_id + } + + #[inline] + pub fn secret_key(&self) -> i32 { + self.secret_key + } +} + +pub struct CommandCompleteBody { + tag: Bytes, +} + +impl CommandCompleteBody { + #[inline] + pub fn tag(&self) -> io::Result<&str> { + get_str(&self.tag) + } +} + +pub struct CopyDataBody { + storage: Bytes, +} + +impl CopyDataBody { + #[inline] + pub fn data(&self) -> &[u8] { + &self.storage + } + + #[inline] + pub fn into_bytes(self) -> Bytes { + self.storage + } +} + +pub struct CopyInResponseBody { + format: u8, + len: u16, + storage: Bytes, +} + +impl CopyInResponseBody { + #[inline] + pub fn format(&self) -> u8 { + self.format + } + + #[inline] + pub fn column_formats(&self) -> ColumnFormats<'_> { + ColumnFormats { + remaining: self.len, + buf: &self.storage, + } + } +} + +pub struct ColumnFormats<'a> { + buf: &'a [u8], + remaining: u16, +} + +impl FallibleIterator for ColumnFormats<'_> { + type Item = u16; + type Error = io::Error; + + #[inline] + fn next(&mut self) -> io::Result> { + if self.remaining == 0 { + if self.buf.is_empty() { + return Ok(None); + } else { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "invalid message length: wrong column formats", + )); + } + } + + self.remaining -= 1; + self.buf.read_u16::().map(Some) + } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + let len = self.remaining as 
usize; + (len, Some(len)) + } +} + +pub struct CopyOutResponseBody { + format: u8, + len: u16, + storage: Bytes, +} + +impl CopyOutResponseBody { + #[inline] + pub fn format(&self) -> u8 { + self.format + } + + #[inline] + pub fn column_formats(&self) -> ColumnFormats<'_> { + ColumnFormats { + remaining: self.len, + buf: &self.storage, + } + } +} + +#[derive(Debug, Clone)] +pub struct DataRowBody { + storage: Bytes, + len: u16, +} + +impl DataRowBody { + #[inline] + pub fn ranges(&self) -> DataRowRanges<'_> { + DataRowRanges { + buf: &self.storage, + len: self.storage.len(), + remaining: self.len, + } + } + + #[inline] + pub fn buffer(&self) -> &[u8] { + &self.storage + } + + #[inline] + pub fn buffer_bytes(&self) -> &Bytes { + &self.storage + } +} + +pub struct DataRowRanges<'a> { + buf: &'a [u8], + len: usize, + remaining: u16, +} + +impl FallibleIterator for DataRowRanges<'_> { + type Item = Option>; + type Error = io::Error; + + #[inline] + fn next(&mut self) -> io::Result>>> { + if self.remaining == 0 { + if self.buf.is_empty() { + return Ok(None); + } else { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "invalid message length: datarowrange is not empty", + )); + } + } + + self.remaining -= 1; + let len = self.buf.read_i32::()?; + if len < 0 { + Ok(Some(None)) + } else { + let len = len as usize; + if self.buf.len() < len { + return Err(io::Error::new(io::ErrorKind::UnexpectedEof, "unexpected EOF")); + } + let base = self.len - self.buf.len(); + self.buf = &self.buf[len..]; + Ok(Some(Some(base..base + len))) + } + } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + let len = self.remaining as usize; + (len, Some(len)) + } +} + +pub struct ErrorResponseBody { + storage: Bytes, +} + +impl ErrorResponseBody { + #[inline] + pub fn fields(&self) -> ErrorFields<'_> { + ErrorFields { buf: &self.storage } + } +} + +pub struct ErrorFields<'a> { + buf: &'a [u8], +} + +impl<'a> FallibleIterator for ErrorFields<'a> { + type Item = ErrorField<'a>; 
+ type Error = io::Error; + + #[inline] + fn next(&mut self) -> io::Result>> { + let type_ = self.buf.read_u8()?; + if type_ == 0 { + if self.buf.is_empty() { + return Ok(None); + } else { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "invalid message length: error fields is not drained", + )); + } + } + + let value_end = find_null(self.buf, 0)?; + let value = &self.buf[..value_end]; + self.buf = &self.buf[value_end + 1..]; + + Ok(Some(ErrorField { type_, value })) + } +} + +pub struct ErrorField<'a> { + type_: u8, + value: &'a [u8], +} + +impl ErrorField<'_> { + #[inline] + pub fn type_(&self) -> u8 { + self.type_ + } + + #[inline] + #[deprecated(note = "use value_bytes instead", since = "0.6.7")] + pub fn value(&self) -> &str { + core::str::from_utf8(self.value).expect("error field value contained non-UTF8 bytes") + } + + #[inline] + pub fn value_bytes(&self) -> &[u8] { + self.value + } +} + +pub struct NoticeResponseBody { + storage: Bytes, +} + +impl NoticeResponseBody { + #[inline] + pub fn fields(&self) -> ErrorFields<'_> { + ErrorFields { buf: &self.storage } + } +} + +pub struct NotificationResponseBody { + process_id: i32, + channel: Bytes, + message: Bytes, +} + +impl NotificationResponseBody { + #[inline] + pub fn process_id(&self) -> i32 { + self.process_id + } + + #[inline] + pub fn channel(&self) -> io::Result<&str> { + get_str(&self.channel) + } + + #[inline] + pub fn message(&self) -> io::Result<&str> { + get_str(&self.message) + } +} + +pub struct ParameterDescriptionBody { + storage: Bytes, + len: u16, +} + +impl ParameterDescriptionBody { + #[inline] + pub fn parameters(&self) -> Parameters<'_> { + Parameters { + buf: &self.storage, + remaining: self.len, + } + } +} + +pub struct Parameters<'a> { + buf: &'a [u8], + remaining: u16, +} + +impl FallibleIterator for Parameters<'_> { + type Item = Oid; + type Error = io::Error; + + #[inline] + fn next(&mut self) -> io::Result> { + if self.remaining == 0 { + if self.buf.is_empty() { + 
return Ok(None); + } else { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "invalid message length: parameters is not drained", + )); + } + } + + self.remaining -= 1; + self.buf.read_u32::().map(Some) + } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + let len = self.remaining as usize; + (len, Some(len)) + } +} + +pub struct ParameterStatusBody { + name: Bytes, + value: Bytes, +} + +impl ParameterStatusBody { + #[inline] + pub fn name(&self) -> io::Result<&str> { + get_str(&self.name) + } + + #[inline] + pub fn value(&self) -> io::Result<&str> { + get_str(&self.value) + } +} + +pub struct ReadyForQueryBody { + status: u8, +} + +impl ReadyForQueryBody { + #[inline] + pub fn status(&self) -> u8 { + self.status + } +} + +pub struct RowDescriptionBody { + storage: Bytes, + len: u16, +} + +impl RowDescriptionBody { + #[inline] + pub fn fields(&self) -> Fields<'_> { + Fields { + buf: &self.storage, + remaining: self.len, + } + } +} + +pub struct Fields<'a> { + buf: &'a [u8], + remaining: u16, +} + +impl<'a> FallibleIterator for Fields<'a> { + type Item = Field<'a>; + type Error = io::Error; + + #[inline] + fn next(&mut self) -> io::Result>> { + if self.remaining == 0 { + if self.buf.is_empty() { + return Ok(None); + } else { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "invalid message length: field is not drained", + )); + } + } + + self.remaining -= 1; + let name_end = find_null(self.buf, 0)?; + let name = get_str(&self.buf[..name_end])?; + self.buf = &self.buf[name_end + 1..]; + let table_oid = self.buf.read_u32::()?; + let column_id = self.buf.read_i16::()?; + let type_oid = self.buf.read_u32::()?; + let type_size = self.buf.read_i16::()?; + let type_modifier = self.buf.read_i32::()?; + let format = self.buf.read_i16::()?; + + Ok(Some(Field { + name, + table_oid, + column_id, + type_oid, + type_size, + type_modifier, + format, + })) + } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + let len = self.remaining as 
usize; + (len, Some(len)) + } +} + +pub struct Field<'a> { + name: &'a str, + table_oid: Oid, + column_id: i16, + type_oid: Oid, + type_size: i16, + type_modifier: i32, + format: i16, +} + +impl<'a> Field<'a> { + #[inline] + pub fn name(&self) -> &'a str { + self.name + } + + #[inline] + pub fn table_oid(&self) -> Oid { + self.table_oid + } + + #[inline] + pub fn column_id(&self) -> i16 { + self.column_id + } + + #[inline] + pub fn type_oid(&self) -> Oid { + self.type_oid + } + + #[inline] + pub fn type_size(&self) -> i16 { + self.type_size + } + + #[inline] + pub fn type_modifier(&self) -> i32 { + self.type_modifier + } + + #[inline] + pub fn format(&self) -> i16 { + self.format + } +} + +#[inline] +fn find_null(buf: &[u8], start: usize) -> io::Result { + match memchr::memchr(0, &buf[start..]) { + Some(pos) => Ok(pos + start), + None => Err(io::Error::new(io::ErrorKind::UnexpectedEof, "unexpected EOF")), + } +} + +#[inline] +fn get_str(buf: &[u8]) -> io::Result<&str> { + core::str::from_utf8(buf).map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e)) +} diff --git a/postgres/src/protocol/message/frontend.rs b/postgres/src/protocol/message/frontend.rs new file mode 100644 index 000000000..45dd9e598 --- /dev/null +++ b/postgres/src/protocol/message/frontend.rs @@ -0,0 +1,165 @@ +pub(crate) use postgres_protocol::message::frontend::*; + +use xitca_io::bytes::{BufMut, BytesMut}; + +use std::{convert::Infallible, io}; + +use crate::{error::Error, protocol::IsNull}; + +// optimized version of protocol functions depending on specific inputs + +// optimized version of postgres_protocol::frontend::bind +pub(crate) fn bind( + portal_name: &str, + stmt_name: &str, + formats: I, + values: J, + mut serializer: F, + buf: &mut BytesMut, +) -> Result<(), Error> +where + I: ExactSizeIterator, + J: ExactSizeIterator, + F: FnMut(T, &mut BytesMut) -> Result, +{ + buf.put_u8(b'B'); + write_body(buf, |buf| { + write_cstr(portal_name, buf); + write_cstr(stmt_name, buf); + 
write_counted( + formats, + |f, buf| { + buf.put_i16(f); + Ok::<_, Infallible>(()) + }, + buf, + )?; + write_counted(values, |v, buf| write_nullable(|buf| serializer(v, buf), buf), buf)?; + result_fmt_binary(buf); + Ok(()) + }) +} + +// optimized version of postgres_protocol::frontend::execute +pub(crate) fn execute(portal_name: &str, max_rows: i32, buf: &mut BytesMut) -> Result<(), Error> { + buf.put_u8(b'E'); + write_body(buf, |buf| { + write_cstr(portal_name, buf); + buf.put_i32(max_rows); + Ok(()) + }) +} + +// optimized version of postgres_protocol::frontend::close +pub(crate) fn close(variant: u8, name: &str, buf: &mut BytesMut) -> Result<(), Error> { + buf.put_u8(b'C'); + write_body(buf, |buf| { + buf.put_u8(variant); + write_cstr(name, buf); + Ok(()) + }) +} + +// optimized version of postgres_protocol::frontend::sync +pub(crate) fn sync(buf: &mut BytesMut) { + buf.extend_from_slice(&[b'S', 0, 0, 0, 4]); +} + +fn write_body(buf: &mut BytesMut, f: F) -> Result<(), Error> +where + F: FnOnce(&mut BytesMut) -> Result<(), Error>, +{ + let base = buf.len(); + buf.put_i32(0); + + f(buf)?; + + let size = FromUsize::from_usize(buf.len() - base)?; + buf.put_i32_at(base, size); + Ok(()) +} + +fn write_counted(items: I, mut serializer: F, buf: &mut BytesMut) -> Result<(), Error> +where + I: ExactSizeIterator, + F: FnMut(T, &mut BytesMut) -> Result<(), E>, + Error: From, +{ + let count = FromUsize::from_usize(items.len())?; + buf.put_u16(count); + + for item in items { + serializer(item, buf)?; + } + + Ok(()) +} + +fn write_nullable(serializer: F, buf: &mut BytesMut) -> Result<(), Error> +where + F: FnOnce(&mut BytesMut) -> Result, +{ + let base = buf.len(); + buf.put_i32(0); + + let size = match serializer(buf)? { + IsNull::No => FromUsize::from_usize(buf.len() - base - 4)?, + IsNull::Yes => -1, + }; + buf.put_i32_at(base, size); + + Ok(()) +} + +fn write_cstr(s: &str, buf: &mut BytesMut) { + // strings used inside library dont contain c style null bytes. 
+ // debug assertions is enough for catching violation. + #[cfg(debug_assertions)] + if s.as_bytes().contains(&0) { + panic!("input string: {s} contains embedded null") + } + + buf.put_slice(s.as_bytes()); + buf.put_u8(0); +} + +// hard coded reesult format to always ask for binary output from server +fn result_fmt_binary(buf: &mut BytesMut) { + buf.extend_from_slice(&[0, 1, 0, 1]); +} + +trait FromUsize: Sized { + fn from_usize(x: usize) -> Result; +} + +macro_rules! from_usize { + ($t:ty) => { + impl FromUsize for $t { + #[inline] + fn from_usize(x: usize) -> io::Result<$t> { + if x > <$t>::MAX as usize { + Err(io::Error::new( + io::ErrorKind::InvalidInput, + "value too large to transmit", + )) + } else { + Ok(x as $t) + } + } + } + }; +} + +from_usize!(i16); +from_usize!(u16); +from_usize!(i32); + +trait WriteSize { + fn put_i32_at(&mut self, offset: usize, num: i32); +} + +impl WriteSize for BytesMut { + fn put_i32_at(&mut self, offset: usize, num: i32) { + self[offset..offset + 4].copy_from_slice(&num.to_be_bytes()); + } +} From c222bae22b1be1a334943dd0610fc273eccd5af2 Mon Sep 17 00:00:00 2001 From: fakeshadow <24548779@qq.com> Date: Tue, 3 Mar 2026 15:31:08 +0800 Subject: [PATCH 02/21] clippy fix --- postgres/src/protocol/message/backend.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/postgres/src/protocol/message/backend.rs b/postgres/src/protocol/message/backend.rs index bde862b24..aa5c5b640 100644 --- a/postgres/src/protocol/message/backend.rs +++ b/postgres/src/protocol/message/backend.rs @@ -71,7 +71,7 @@ pub(crate) struct MessageRaw { impl MessageRaw { #[inline(always)] pub(crate) fn tag(&self) -> u8 { - *self.buf.get(0).expect("MessageRaw::parse produced illformed data type") + *self.buf.first().expect("MessageRaw::parse produced illformed data type") } pub(crate) fn parse(buf: &mut BytesMut) -> io::Result> { From f2dd7ee99d250af93da3079fdf91e5b754dd7407 Mon Sep 17 00:00:00 2001 From: fakeshadow <24548779@qq.com> Date: Tue, 3 Mar 
2026 15:32:29 +0800 Subject: [PATCH 03/21] more empty buf check --- postgres/src/protocol/message/backend.rs | 25 +++++++++++++++++++----- 1 file changed, 20 insertions(+), 5 deletions(-) diff --git a/postgres/src/protocol/message/backend.rs b/postgres/src/protocol/message/backend.rs index aa5c5b640..42af6ed49 100644 --- a/postgres/src/protocol/message/backend.rs +++ b/postgres/src/protocol/message/backend.rs @@ -71,7 +71,10 @@ pub(crate) struct MessageRaw { impl MessageRaw { #[inline(always)] pub(crate) fn tag(&self) -> u8 { - *self.buf.first().expect("MessageRaw::parse produced illformed data type") + *self + .buf + .first() + .expect("MessageRaw::parse produced illformed data type") } pub(crate) fn parse(buf: &mut BytesMut) -> io::Result> { @@ -98,9 +101,18 @@ impl MessageRaw { }; let message = match tag { - PARSE_COMPLETE_TAG => Message::ParseComplete, - BIND_COMPLETE_TAG => Message::BindComplete, - CLOSE_COMPLETE_TAG => Message::CloseComplete, + PARSE_COMPLETE_TAG => { + empty_check!(buf); + Message::ParseComplete + } + BIND_COMPLETE_TAG => { + empty_check!(buf); + Message::BindComplete + } + CLOSE_COMPLETE_TAG => { + empty_check!(buf); + Message::CloseComplete + } NOTIFICATION_RESPONSE_TAG => { let process_id = buf.read_i32::()?; let channel = buf.read_cstr()?; @@ -112,7 +124,10 @@ impl MessageRaw { message, }) } - COPY_DONE_TAG => Message::CopyDone, + COPY_DONE_TAG => { + empty_check!(buf); + Message::CopyDone + } COMMAND_COMPLETE_TAG => { let tag = buf.read_cstr()?; empty_check!(buf); From 7c19b648f8fc82ef8d4fac0ba8339756d3551c64 Mon Sep 17 00:00:00 2001 From: fakeshadow <24548779@qq.com> Date: Tue, 3 Mar 2026 18:10:57 +0800 Subject: [PATCH 04/21] style clean up --- postgres/src/driver/codec.rs | 2 +- postgres/src/driver/generic.rs | 8 +- postgres/src/protocol/message/backend.rs | 156 +++++++++-------------- 3 files changed, 65 insertions(+), 101 deletions(-) diff --git a/postgres/src/driver/codec.rs b/postgres/src/driver/codec.rs index b17418c22..88cae94ef 
100644 --- a/postgres/src/driver/codec.rs +++ b/postgres/src/driver/codec.rs @@ -118,7 +118,7 @@ pub(super) struct BytesMessage { impl BytesMessage { #[cold] #[inline(never)] - pub(super) fn parse_error(self) -> Error { + pub(super) fn into_error(self) -> Error { match self.msg.try_into_message() { Err(e) => Error::from(e), Ok(backend::Message::ErrorResponse(body)) => Error::db(body.fields()), diff --git a/postgres/src/driver/generic.rs b/postgres/src/driver/generic.rs index 0fa15e8f7..83fd6feaa 100644 --- a/postgres/src/driver/generic.rs +++ b/postgres/src/driver/generic.rs @@ -355,14 +355,14 @@ impl DriverRx { // lock the shared state only when needed and keep the lock around a bit for possible multiple messages let inner = guard.get_or_insert_with(|| self.guarded.lock().unwrap()); - match inner.res.pop_front() { + match inner.res.front_mut() { Some(tx) => { let _ = tx.send(msg.msg); - if !msg.complete { - inner.res.push_front(tx); + if msg.complete { + inner.res.pop_front(); } } - None => return Err(msg.parse_error()), + None => return Err(msg.into_error()), } } ResponseMessage::Async(msg) => return Ok(Some(msg)), diff --git a/postgres/src/protocol/message/backend.rs b/postgres/src/protocol/message/backend.rs index 42af6ed49..0364c0ae3 100644 --- a/postgres/src/protocol/message/backend.rs +++ b/postgres/src/protocol/message/backend.rs @@ -6,7 +6,7 @@ use std::io; use byteorder::{BigEndian, ReadBytesExt}; use fallible_iterator::FallibleIterator; -use xitca_io::bytes::{Buf, Bytes, BytesMut}; +use xitca_io::bytes::{Bytes, BytesMut}; use crate::types::Oid; @@ -55,17 +55,8 @@ impl Message { } } -macro_rules! 
empty_check { - ($buf: ident) => { - debug_assert!( - $buf.slice().is_empty(), - "invalid message length: expected buffer to be empty" - ); - }; -} - pub(crate) struct MessageRaw { - pub(crate) buf: BytesMut, + buf: BytesMut, } impl MessageRaw { @@ -101,36 +92,22 @@ impl MessageRaw { }; let message = match tag { - PARSE_COMPLETE_TAG => { - empty_check!(buf); - Message::ParseComplete - } - BIND_COMPLETE_TAG => { - empty_check!(buf); - Message::BindComplete - } - CLOSE_COMPLETE_TAG => { - empty_check!(buf); - Message::CloseComplete - } + PARSE_COMPLETE_TAG => Message::ParseComplete, + BIND_COMPLETE_TAG => Message::BindComplete, + CLOSE_COMPLETE_TAG => Message::CloseComplete, NOTIFICATION_RESPONSE_TAG => { let process_id = buf.read_i32::()?; let channel = buf.read_cstr()?; let message = buf.read_cstr()?; - empty_check!(buf); Message::NotificationResponse(NotificationResponseBody { process_id, channel, message, }) } - COPY_DONE_TAG => { - empty_check!(buf); - Message::CopyDone - } + COPY_DONE_TAG => Message::CopyDone, COMMAND_COMPLETE_TAG => { let tag = buf.read_cstr()?; - empty_check!(buf); Message::CommandComplete(CommandCompleteBody { tag }) } COPY_DATA_TAG => { @@ -162,7 +139,6 @@ impl MessageRaw { BACKEND_KEY_DATA_TAG => { let process_id = buf.read_i32::()?; let secret_key = buf.read_i32::()?; - empty_check!(buf); Message::BackendKeyData(BackendKeyDataBody { process_id, secret_key }) } NO_DATA_TAG => Message::NoData, @@ -171,40 +147,21 @@ impl MessageRaw { Message::NoticeResponse(NoticeResponseBody { storage }) } AUTHENTICATION_TAG => match buf.read_i32::()? 
{ - 0 => { - empty_check!(buf); - Message::AuthenticationOk - } - 2 => { - empty_check!(buf); - Message::AuthenticationKerberosV5 - } - 3 => { - empty_check!(buf); - Message::AuthenticationCleartextPassword - } + 0 => Message::AuthenticationOk, + 2 => Message::AuthenticationKerberosV5, + 3 => Message::AuthenticationCleartextPassword, 5 => { let mut salt = [0; 4]; io::Read::read_exact(&mut buf, &mut salt)?; - empty_check!(buf); Message::AuthenticationMd5Password(AuthenticationMd5PasswordBody { salt }) } - 6 => { - empty_check!(buf); - Message::AuthenticationScmCredential - } - 7 => { - empty_check!(buf); - Message::AuthenticationGss - } + 6 => Message::AuthenticationScmCredential, + 7 => Message::AuthenticationGss, 8 => { let storage = buf.read_all(); Message::AuthenticationGssContinue(AuthenticationGssContinueBody(storage)) } - 9 => { - empty_check!(buf); - Message::AuthenticationSspi - } + 9 => Message::AuthenticationSspi, 10 => { let storage = buf.read_all(); Message::AuthenticationSasl(AuthenticationSaslBody(storage)) @@ -224,14 +181,10 @@ impl MessageRaw { )); } }, - PORTAL_SUSPENDED_TAG => { - empty_check!(buf); - Message::PortalSuspended - } + PORTAL_SUSPENDED_TAG => Message::PortalSuspended, PARAMETER_STATUS_TAG => { let name = buf.read_cstr()?; let value = buf.read_cstr()?; - empty_check!(buf); Message::ParameterStatus(ParameterStatusBody { name, value }) } PARAMETER_DESCRIPTION_TAG => { @@ -246,7 +199,6 @@ impl MessageRaw { } READY_FOR_QUERY_TAG => { let status = buf.read_u8()?; - empty_check!(buf); Message::ReadyForQuery(ReadyForQueryBody { status }) } tag => { @@ -257,6 +209,12 @@ impl MessageRaw { } }; + #[cfg(debug_assertions)] + assert!( + buf.slice().is_empty(), + "invalid message length: expected buffer to be empty" + ); + Ok(message) } } @@ -274,23 +232,29 @@ impl Buffer { #[inline] fn read_cstr(&mut self) -> io::Result { - match memchr::memchr(0, self.slice()) { - Some(pos) => { - let start = self.idx; - let end = start + pos; - let cstr = 
self.bytes.slice(start..end); - self.idx = end + 1; - Ok(cstr) - } - None => Err(io::Error::new(io::ErrorKind::UnexpectedEof, "unexpected EOF")), - } + let pos = memchr::memchr(0, self.slice()) + .ok_or_else(|| io::Error::new(io::ErrorKind::UnexpectedEof, "unexpected EOF"))?; + let start = self.idx; + let end = start + pos; + let cstr = self.bytes.slice(start..end); + self.idx = end + 1; + Ok(cstr) } + #[cfg(not(debug_assertions))] #[inline(always)] fn read_all(mut self) -> Bytes { + use xitca_io::bytes::Buf; self.bytes.advance(self.idx); self.bytes } + + #[cfg(debug_assertions)] + fn read_all(&mut self) -> Bytes { + let buf = self.bytes.slice(self.idx..); + self.idx = self.bytes.len(); + buf + } } impl io::Read for Buffer { @@ -456,14 +420,14 @@ impl FallibleIterator for ColumnFormats<'_> { #[inline] fn next(&mut self) -> io::Result> { if self.remaining == 0 { - if self.buf.is_empty() { - return Ok(None); + return if self.buf.is_empty() { + Ok(None) } else { - return Err(io::Error::new( + Err(io::Error::new( io::ErrorKind::InvalidInput, "invalid message length: wrong column formats", - )); - } + )) + }; } self.remaining -= 1; @@ -538,14 +502,14 @@ impl FallibleIterator for DataRowRanges<'_> { #[inline] fn next(&mut self) -> io::Result>>> { if self.remaining == 0 { - if self.buf.is_empty() { - return Ok(None); + return if self.buf.is_empty() { + Ok(None) } else { - return Err(io::Error::new( + Err(io::Error::new( io::ErrorKind::InvalidInput, "invalid message length: datarowrange is not empty", - )); - } + )) + }; } self.remaining -= 1; @@ -593,14 +557,14 @@ impl<'a> FallibleIterator for ErrorFields<'a> { fn next(&mut self) -> io::Result>> { let type_ = self.buf.read_u8()?; if type_ == 0 { - if self.buf.is_empty() { - return Ok(None); + return if self.buf.is_empty() { + Ok(None) } else { - return Err(io::Error::new( + Err(io::Error::new( io::ErrorKind::InvalidInput, "invalid message length: error fields is not drained", - )); - } + )) + }; } let value_end = 
find_null(self.buf, 0)?; @@ -695,14 +659,14 @@ impl FallibleIterator for Parameters<'_> { #[inline] fn next(&mut self) -> io::Result> { if self.remaining == 0 { - if self.buf.is_empty() { - return Ok(None); + return if self.buf.is_empty() { + Ok(None) } else { - return Err(io::Error::new( + Err(io::Error::new( io::ErrorKind::InvalidInput, "invalid message length: parameters is not drained", - )); - } + )) + }; } self.remaining -= 1; @@ -771,14 +735,14 @@ impl<'a> FallibleIterator for Fields<'a> { #[inline] fn next(&mut self) -> io::Result>> { if self.remaining == 0 { - if self.buf.is_empty() { - return Ok(None); + return if self.buf.is_empty() { + Ok(None) } else { - return Err(io::Error::new( + Err(io::Error::new( io::ErrorKind::InvalidInput, "invalid message length: field is not drained", - )); - } + )) + }; } self.remaining -= 1; From 13e30a783e6e71e34f2b08c6e96fb685cb725306 Mon Sep 17 00:00:00 2001 From: fakeshadow <24548779@qq.com> Date: Thu, 5 Mar 2026 22:39:16 +0800 Subject: [PATCH 05/21] avoid unnecessary waker clone --- postgres/src/driver/codec.rs | 2 +- postgres/src/driver/generic.rs | 5 ++++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/postgres/src/driver/codec.rs b/postgres/src/driver/codec.rs index 88cae94ef..1496b5a4b 100644 --- a/postgres/src/driver/codec.rs +++ b/postgres/src/driver/codec.rs @@ -98,7 +98,7 @@ impl Response { } // Extract the number of rows affected. 
-pub(crate) fn body_to_affected_rows(body: &backend::CommandCompleteBody) -> Result { +fn body_to_affected_rows(body: &backend::CommandCompleteBody) -> Result { body.tag() .map_err(|_| Error::todo()) .map(|r| r.rsplit(' ').next().unwrap().parse().unwrap_or(0)) diff --git a/postgres/src/driver/generic.rs b/postgres/src/driver/generic.rs index 83fd6feaa..1ccf17d74 100644 --- a/postgres/src/driver/generic.rs +++ b/postgres/src/driver/generic.rs @@ -132,7 +132,10 @@ pub(super) struct State { impl State { fn register(&mut self, waker: &Waker) { - self.waker = Some(waker.clone()); + match self.waker { + Some(ref w) if w.will_wake(waker) => {} + _ => self.waker = Some(waker.clone()), + }; } fn wake(&mut self) { From 2216bb6e599cd88db3666996a08c3a7062fb4b0e Mon Sep 17 00:00:00 2001 From: fakeshadow <24548779@qq.com> Date: Tue, 10 Mar 2026 11:30:43 +0800 Subject: [PATCH 06/21] backend protocol optimization --- postgres/src/protocol/message/backend.rs | 29 ++++++++++++++---------- postgres/src/row/types.rs | 10 +------- postgres/src/session.rs | 4 ++-- 3 files changed, 20 insertions(+), 23 deletions(-) diff --git a/postgres/src/protocol/message/backend.rs b/postgres/src/protocol/message/backend.rs index 0364c0ae3..f1a9564d9 100644 --- a/postgres/src/protocol/message/backend.rs +++ b/postgres/src/protocol/message/backend.rs @@ -308,7 +308,7 @@ impl<'a> FallibleIterator for SaslMechanisms<'a> { #[inline] fn next(&mut self) -> io::Result> { - let value_end = find_null(self.0, 0)?; + let value_end = find_null(self.0)?; if value_end == 0 { if self.0.len() != 1 { return Err(io::Error::new( @@ -496,11 +496,11 @@ pub struct DataRowRanges<'a> { } impl FallibleIterator for DataRowRanges<'_> { - type Item = Option>; + type Item = Range; type Error = io::Error; #[inline] - fn next(&mut self) -> io::Result>>> { + fn next(&mut self) -> io::Result>> { if self.remaining == 0 { return if self.buf.is_empty() { Ok(None) @@ -515,7 +515,15 @@ impl FallibleIterator for DataRowRanges<'_> { 
self.remaining -= 1; let len = self.buf.read_i32::()?; if len < 0 { - Ok(Some(None)) + /* + an empty Range value is used to represent null pg value offsets inside row's raw + data buffer. + when empty range is used to slice data collection through a safe Rust API(`<&[u8]>::get(Range)` + in this case) it always produce Option type where the None variant can be used as final output of null + pg value. + this saves 8 bytes per range storage + */ + Ok(Some(Range { start: 1, end: 0 })) } else { let len = len as usize; if self.buf.len() < len { @@ -523,7 +531,7 @@ impl FallibleIterator for DataRowRanges<'_> { } let base = self.len - self.buf.len(); self.buf = &self.buf[len..]; - Ok(Some(Some(base..base + len))) + Ok(Some(base..base + len)) } } @@ -567,7 +575,7 @@ impl<'a> FallibleIterator for ErrorFields<'a> { }; } - let value_end = find_null(self.buf, 0)?; + let value_end = find_null(self.buf)?; let value = &self.buf[..value_end]; self.buf = &self.buf[value_end + 1..]; @@ -746,7 +754,7 @@ impl<'a> FallibleIterator for Fields<'a> { } self.remaining -= 1; - let name_end = find_null(self.buf, 0)?; + let name_end = find_null(self.buf)?; let name = get_str(&self.buf[..name_end])?; self.buf = &self.buf[name_end + 1..]; let table_oid = self.buf.read_u32::()?; @@ -822,11 +830,8 @@ impl<'a> Field<'a> { } #[inline] -fn find_null(buf: &[u8], start: usize) -> io::Result { - match memchr::memchr(0, &buf[start..]) { - Some(pos) => Ok(pos + start), - None => Err(io::Error::new(io::ErrorKind::UnexpectedEof, "unexpected EOF")), - } +fn find_null(buf: &[u8]) -> io::Result { + memchr::memchr(0, buf).ok_or_else(|| io::Error::new(io::ErrorKind::UnexpectedEof, "unexpected EOF")) } #[inline] diff --git a/postgres/src/row/types.rs b/postgres/src/row/types.rs index 2b27dfff0..1052cc04c 100644 --- a/postgres/src/row/types.rs +++ b/postgres/src/row/types.rs @@ -57,15 +57,7 @@ where ranges_mut.clear(); while let Some(range) = iter.next()? 
{ - /* - when unwrapping the Range an empty value is used to represent null pg value offsets inside row's raw - data buffer. - when empty range is used to slice data collection through a safe Rust API(`<&[u8]>::get(Range)` - in this case) it always produce Option type where the None variant can be used as final output of null - pg value. - this saves 8 bytes per range storage - */ - ranges_mut.push(range.unwrap_or(Range { start: 1, end: 0 })); + ranges_mut.push(range); } Ok(Self { diff --git a/postgres/src/session.rs b/postgres/src/session.rs index 06d92a18f..43977aa03 100644 --- a/postgres/src/session.rs +++ b/postgres/src/session.rs @@ -116,8 +116,8 @@ impl Session { loop { match drv.recv().await? { backend::Message::DataRow(body) => { - let range = body.ranges().next()?.flatten().ok_or(Error::todo())?; - let slice = &body.buffer()[range.start..range.end]; + let range = body.ranges().next()?.ok_or(Error::todo())?; + let slice = body.buffer().get(range).ok_or(Error::todo())?; match (slice, cfg.get_target_session_attrs()) { (b"on", TargetSessionAttrs::ReadWrite) => return Err(Error::todo()), (b"off", TargetSessionAttrs::ReadOnly) => return Err(Error::todo()), From 2618cb355590fb60018aef57e57f0cb4c4df84d3 Mon Sep 17 00:00:00 2001 From: fakeshadow <24548779@qq.com> Date: Mon, 30 Mar 2026 22:57:00 +0800 Subject: [PATCH 07/21] use low level UnbufferedConnection --- http/src/tls/rustls_uring.rs | 9 +- tls/src/rustls_uring.rs | 657 ++++++++++++++++++++++------------- 2 files changed, 409 insertions(+), 257 deletions(-) diff --git a/http/src/tls/rustls_uring.rs b/http/src/tls/rustls_uring.rs index d99dd65e8..d13eb287c 100644 --- a/http/src/tls/rustls_uring.rs +++ b/http/src/tls/rustls_uring.rs @@ -4,10 +4,7 @@ use std::{io, net::Shutdown, sync::Arc}; use xitca_io::io_uring::{AsyncBufRead, AsyncBufWrite, BoundedBuf, BoundedBufMut}; use xitca_service::Service; -use xitca_tls::{ - rustls::{ServerConfig, ServerConnection}, - rustls_uring::TlsStream as _TlsStream, -}; +use 
xitca_tls::rustls_uring::{ServerConfig, TlsStream as _TlsStream, UnbufferedServerConnection}; use crate::{http::Version, version::AsVersion}; @@ -15,7 +12,7 @@ use super::rustls::RustlsError; /// A stream managed by rustls for tls read/write. pub struct TlsStream { - inner: _TlsStream, + inner: _TlsStream, } impl AsVersion for TlsStream { @@ -60,7 +57,7 @@ where type Error = RustlsError; async fn call(&self, io: Io) -> Result { - let conn = ServerConnection::new(self.acceptor.clone())?; + let conn = UnbufferedServerConnection::new(self.acceptor.clone())?; let inner = _TlsStream::handshake(io, conn).await?; Ok(TlsStream { inner }) } diff --git a/tls/src/rustls_uring.rs b/tls/src/rustls_uring.rs index 63c8eca81..e1dd7c39a 100644 --- a/tls/src/rustls_uring.rs +++ b/tls/src/rustls_uring.rs @@ -1,51 +1,122 @@ #![allow(clippy::await_holding_refcell_ref)] // clippy is dumb -use core::{ - cell::RefCell, - ops::{Deref, DerefMut}, - slice, -}; +use core::{cell::RefCell, slice}; use std::{io, net::Shutdown, rc::Rc}; pub use rustls_crate::*; +use rustls_crate::{ + client::UnbufferedClientConnection, + server::UnbufferedServerConnection, + unbuffered::UnbufferedConnectionCommon, + unbuffered::{ConnectionState, EncryptError, UnbufferedStatus}, +}; + use xitca_io::{ bytes::{Buf, BytesMut}, io_uring::{AsyncBufRead, AsyncBufWrite, BoundedBuf, BoundedBufMut}, }; -use self::buf::WriteBuf; +/// Trait to abstract over `UnbufferedServerConnection` and `UnbufferedClientConnection`, +/// since `process_tls_records` is not on a shared trait in rustls. +#[doc(hidden)] +pub trait ProcessTlsRecords: sealed::Sealed { + type Data; -/// A tls stream type enable concurrent async read/write through [AsyncBufRead] and [AsyncBufWrite] -/// traits. -/// -/// # Panics -/// For now due to design limitation TlsStream offers concurrency with [AsyncBufRead::read] and -/// [AsyncBufWrite::write] but in either case the async function must run to completion and cancel -/// it prematurely would cause panic. 
-/// ``` -/// use xitca_io::{ -/// io_uring::{AsyncBufRead, AsyncBufWrite}, -/// net::io_uring::TcpStream -/// }; -/// use xitca_tls::rustls_uring::{ServerConnection, TlsStream}; -/// -/// async fn complete(stream: TlsStream) { -/// let _ = stream.read(vec![0; 128]).await; -/// let _ = stream.read(vec![0; 128]).await; // serialize read to complete is ok. -/// -/// let _ = stream.read(vec![0; 128]); -/// let _ = stream.write(vec![0; 128]); // concurrent read and write is ok. + fn process_tls_records<'c, 'i>(&'c mut self, incoming_tls: &'i mut [u8]) -> UnbufferedStatus<'c, 'i, Self::Data>; +} + +mod sealed { + pub trait Sealed {} + impl Sealed for super::UnbufferedServerConnection {} + impl Sealed for super::UnbufferedClientConnection {} +} + +impl ProcessTlsRecords for UnbufferedServerConnection { + type Data = server::ServerConnectionData; + + fn process_tls_records<'c, 'i>(&'c mut self, incoming_tls: &'i mut [u8]) -> UnbufferedStatus<'c, 'i, Self::Data> { + let inner: &mut UnbufferedConnectionCommon = self; + inner.process_tls_records(incoming_tls) + } +} + +impl ProcessTlsRecords for UnbufferedClientConnection { + type Data = client::ClientConnectionData; + + fn process_tls_records<'c, 'i>(&'c mut self, incoming_tls: &'i mut [u8]) -> UnbufferedStatus<'c, 'i, Self::Data> { + let inner: &mut UnbufferedConnectionCommon = self; + inner.process_tls_records(incoming_tls) + } +} + +/// Reduced `ConnectionState` that doesn't borrow the connection or incoming buffer. +/// Created by draining all needed data from the borrowed state variants. +enum State { + /// Read traffic was processed (plaintext drained by caller). + ReadTraffic, + /// Handshake data was encoded into write_buf. + EncodedTlsData, + /// Encoded data needs to be transmitted, then call done. + TransmitTlsData, + /// Need more ciphertext. + BlockedHandshake, + /// Handshake complete, ready to send app data. + WriteTraffic, + /// Peer sent close_notify or connection fully closed. 
+ Closed, + /// Peer closed (edge-triggered). + PeerClosed, +} + +/// Process one round of TLS records and return an owned State plus the discard count. +/// This function handles `EncodeTlsData` inline (encoding into write_buf) and +/// `ReadTraffic` inline (draining plaintext into a BytesMut). +fn process_once( + conn: &mut C, + read_buf: &mut BytesMut, + write_buf: &mut BytesMut, +) -> io::Result { + let UnbufferedStatus { discard, state } = conn.process_tls_records(read_buf.as_mut()); + + let state = match state.map_err(tls_err)? { + ConnectionState::ReadTraffic(_) => State::ReadTraffic, + ConnectionState::EncodeTlsData(mut state) => { + encode_tls_data(&mut state, write_buf)?; + State::EncodedTlsData + } + ConnectionState::TransmitTlsData(state) => { + state.done(); + State::TransmitTlsData + } + ConnectionState::WriteTraffic(_) => { + // WriteTraffic may have pending TLS data (key updates). + // We don't encrypt here — caller decides. + State::WriteTraffic + } + ConnectionState::BlockedHandshake => State::BlockedHandshake, + ConnectionState::PeerClosed => State::PeerClosed, + ConnectionState::Closed => State::Closed, + _ => State::BlockedHandshake, // Unknown variants treated as needing more data. + }; + + // Discard consumed bytes from read_buf after all borrows are released. + read_buf.advance(discard); + + Ok(state) +} + +/// A TLS stream type that supports concurrent async read/write through [AsyncBufRead] and +/// [AsyncBufWrite] traits. /// -/// let read = stream.read(vec![0; 128]); -/// drop(read); // drop read without completion. -/// let read = stream.read(vec![0; 128]).await; // this line would cause panic. +/// [AsyncBufRead::read] and [AsyncBufWrite::write] can be polled concurrently from separate +/// tasks. The read path owns `read_buf` during IO and the write path owns `write_buf`, so +/// neither blocks the other while awaiting kernel completions. /// -/// let read = stream.read(vec![0; 128]); // making two concurrent read future. 
-/// let read = stream.read(vec![0; 128]); // this line would cause panic. -/// } -/// ``` +/// # Panics +/// Each async read/write operation must be polled to completion. Dropping a future before it +/// completes will leave internal buffers in a taken state, causing the next call to panic. pub struct TlsStream { io: Io, session: Rc>>, @@ -64,256 +135,299 @@ where } struct Session { - session: C, + conn: C, read_buf: Option, - write_buf: Option, + /// Write buffer for application data (used by write path). + write_buf: Option, + /// Write buffer for TLS protocol responses during reads (key updates, alerts). + proto_write_buf: BytesMut, + /// Plaintext buffered from a previous read. + pending_plaintext: BytesMut, } -impl Session +impl TlsStream where - C: DerefMut + Deref>, + C: ProcessTlsRecords, + Io: AsyncBufRead + AsyncBufWrite, { - fn read_plain(&mut self, buf: &mut impl BoundedBufMut) -> io::Result { - io::Read::read(&mut self.session.reader(), io_ref_mut_slice(buf)).inspect(|n| { - // SAFETY - // required by IoBufMut trait. when n bytes is write into buffer this method - // must be called to advance the initialized part of it. - unsafe { buf.set_init(*n) }; - }) - } - - fn write_plain(&mut self, buf: &impl BoundedBuf) -> io::Result { - let writer = &mut self.session.writer(); - let n = io::Write::write(writer, io_ref_slice(buf))?; - // keep this no op in case rustls change it's behavior. 
- io::Write::flush(writer).expect("::flush should be no op"); - Ok(n) + pub async fn handshake(io: Io, conn: C) -> io::Result { + let stream = TlsStream { + io, + session: Rc::new(RefCell::new(Session { + conn, + read_buf: Some(BytesMut::new()), + write_buf: Some(BytesMut::new()), + proto_write_buf: BytesMut::new(), + pending_plaintext: BytesMut::new(), + })), + }; + stream._handshake().await?; + Ok(stream) } -} -impl TlsStream -where - C: DerefMut + Deref>, - S: SideData, - Io: AsyncBufRead, -{ - async fn read_tls(&self) -> io::Result { + async fn _handshake(&self) -> io::Result<()> { let mut session = self.session.borrow_mut(); + let mut read_buf = session.read_buf.take().expect(POLL_TO_COMPLETE); + let mut proto_write_buf = session.proto_write_buf.split(); - let mut buf = session.read_buf.take().expect(POLL_TO_COMPLETE); + let res = loop { + let res = match process_once(&mut session.conn, &mut read_buf, &mut proto_write_buf) { + Err(e) => Err(e), - if buf.is_empty() { - drop(session); + // Continue processing — more handshake data may follow. + Ok(State::EncodedTlsData) => continue, - let rem = buf.capacity() - buf.len(); - if rem < 4096 { - buf.reserve(4096 - rem); - } + Ok(State::TransmitTlsData) => { + let (res, b) = write_all_buf(&self.io, proto_write_buf).await; + proto_write_buf = b; + res + } + + Ok(State::BlockedHandshake) => { + let (res, b) = read_to_buf(&self.io, read_buf).await; + read_buf = b; + res + } - let (res, b) = self.io.read(buf).await; - buf = b; + Ok(State::WriteTraffic | State::ReadTraffic) => break Ok(()), - session = self.session.borrow_mut(); + Ok(State::PeerClosed | State::Closed) => { + Err(io::Error::new(io::ErrorKind::UnexpectedEof, "tls handshake eof")) + } + }; - if res? == 0 { - session.read_buf.replace(buf); - return Ok(0); + if res.is_err() { + break res; } + }; + + session.read_buf.replace(read_buf); + session.proto_write_buf = proto_write_buf; + res + } + + /// Read ciphertext from IO, decrypt, and return plaintext. 
+ async fn read_tls(&self, plain_buf: &mut impl BoundedBufMut) -> io::Result { + let mut session = self.session.borrow_mut(); + + // Check for plaintext buffered from a previous read first. + if !session.pending_plaintext.is_empty() { + let dst = io_ref_mut_slice(plain_buf); + let n = session.pending_plaintext.len().min(dst.len()); + dst[..n].copy_from_slice(&session.pending_plaintext[..n]); + session.pending_plaintext.advance(n); + unsafe { plain_buf.set_init(n) }; + return Ok(n); } - let res = session.session.read_tls(&mut buf.as_ref()).inspect(|n| buf.advance(*n)); + let mut read_buf = session.read_buf.take().expect(POLL_TO_COMPLETE); - session.read_buf.replace(buf); + let res = loop { + // Call process_tls_records directly to copy record payload + // straight into the caller's buffer (no intermediate BytesMut). + let session_ref = &mut *session; - let n = res?; + let UnbufferedStatus { discard, state } = session_ref.conn.process_tls_records(read_buf.as_mut()); - let state = session - .session - .process_new_packets() - .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?; + let res = match state.map_err(tls_err) { + Err(e) => { + read_buf.advance(discard); + break Err(e); + } - if state.peer_has_closed() && session.session.is_handshaking() { - return Err(io::ErrorKind::UnexpectedEof.into()); - } + Ok(ConnectionState::ReadTraffic(mut traffic)) => { + let dst = io_ref_mut_slice(plain_buf); + let mut written = 0; + + let mut err = None; + while let Some(res) = traffic.next_record() { + match res.map_err(tls_err) { + Ok(record) => { + let payload = record.payload; + let n = payload.len().min(dst.len() - written); + dst[written..written + n].copy_from_slice(&payload[..n]); + written += n; + // Buffer overflow into pending_plaintext. 
+ if n < payload.len() { + session_ref.pending_plaintext.extend_from_slice(&payload[n..]); + } + } + Err(e) => { + err = Some(e); + break; + } + } + } + + drop(traffic); + read_buf.advance(discard); + + if let Some(e) = err { + break Err(e); + } + + // Empty plaintext means TLS overhead with no payload — keep going. + if written == 0 { + continue; + } + + unsafe { plain_buf.set_init(written) }; + break Ok(written); + } - Ok(n) - } -} + Ok(ConnectionState::EncodeTlsData(mut state)) => { + // Encode into proto_write_buf via session_ref (same borrow scope as state). + let enc_res = encode_tls_data(&mut state, &mut session_ref.proto_write_buf); + drop(state); + read_buf.advance(discard); -impl TlsStream -where - C: DerefMut + Deref>, - S: SideData, - Io: AsyncBufWrite, -{ - async fn write_tls(&self) -> io::Result { - let mut session = self.session.borrow_mut(); + if let Err(e) = enc_res { + break Err(e); + } + continue; + } - let mut write_buf = session.write_buf.take().expect(POLL_TO_COMPLETE); + Ok(ConnectionState::TransmitTlsData(state)) => { + // Data is in proto_write_buf. Acknowledge and continue — + // write_tls will flush it on the next write call. + state.done(); + read_buf.advance(discard); + continue; + } - let n = match session.session.write_tls(&mut write_buf) { - Ok(n) => n, - Err(e) => { - session.write_buf.replace(write_buf); - return Err(e); - } - }; + // Need more ciphertext. 
+ Ok(ConnectionState::BlockedHandshake | ConnectionState::WriteTraffic(_)) => { + read_buf.advance(discard); - drop(session); + drop(session); - let (res, write_buf) = write_buf.write_io(&self.io).await; + let (res, b) = read_to_buf(&self.io, read_buf).await; + read_buf = b; - session = self.session.borrow_mut(); - session.write_buf.replace(write_buf); + session = self.session.borrow_mut(); - match res { - Ok(0) => Err(io::ErrorKind::UnexpectedEof.into()), - Ok(_) => Ok(n), - Err(e) => Err(e), - } - } -} + res + } -impl TlsStream -where - C: DerefMut + Deref>, - S: SideData, - Io: AsyncBufRead + AsyncBufWrite, -{ - pub async fn handshake(io: Io, session: C) -> io::Result { - let mut stream = TlsStream { - io, - session: Rc::new(RefCell::new(Session { - session, - read_buf: Some(BytesMut::new()), - write_buf: Some(WriteBuf::default()), - })), + Ok(ConnectionState::PeerClosed | ConnectionState::Closed) => { + read_buf.advance(discard); + break Ok(0); + } + + Ok(_) => { + read_buf.advance(discard); + continue; + } + }; + + if let Err(e) = res { + break Err(e); + } }; - stream._handshake().await?; - Ok(stream) - } - fn session_mut(&mut self) -> &mut C { - Rc::get_mut(&mut self.session) - .map(|session| &mut session.get_mut().session) - .expect("handshake must have exclusive ownership of TlsStream until it returns") + session.read_buf.replace(read_buf); + res } - pub(crate) async fn _handshake(&mut self) -> io::Result<(usize, usize)> { - let mut wrlen = 0; - let mut rdlen = 0; - let mut eof = false; + /// Encrypt plaintext and write ciphertext to IO. + async fn write_tls(&self, plain: &impl BoundedBuf) -> io::Result { + let mut session = self.session.borrow_mut(); + let mut write_buf = session.write_buf.take().expect(POLL_TO_COMPLETE); + let plaintext = io_ref_slice(plain); + + // Flush protocol data buffered by read path (key updates, alerts). 
+ if !session.proto_write_buf.is_empty() { + write_buf.extend_from_slice(&session.proto_write_buf); + session.proto_write_buf.clear(); + } - loop { - while self.session_mut().wants_write() && self.session_mut().is_handshaking() { - wrlen += self.write_tls().await?; - } + let res = loop { + // Pass empty slice — write path doesn't process incoming TLS records. + // Incoming data (key updates, etc.) is handled by the read path. + let UnbufferedStatus { state, .. } = session.conn.process_tls_records(&mut []); + + match state.map_err(tls_err) { + Err(e) => break Err(e), + + Ok(ConnectionState::WriteTraffic(mut traffic)) => { + let enc_res = encrypt_to_buf(&mut traffic, plaintext, &mut write_buf); + + if let Err(e) = enc_res { + break Err(e); + } + + drop(session); + + let (res, b) = write_all_buf(&self.io, write_buf).await; + write_buf = b; + + session = self.session.borrow_mut(); - while !eof && self.session_mut().wants_read() && self.session_mut().is_handshaking() { - let n = self.read_tls().await?; - rdlen += n; - if n == 0 { - eof = true; + break res.map(|_| plaintext.len()); } - } - match (eof, self.session_mut().is_handshaking()) { - (true, true) => { - return Err(io::Error::new(io::ErrorKind::UnexpectedEof, "tls handshake eof")); + Ok(ConnectionState::EncodeTlsData(mut state)) => { + let enc_res = encode_tls_data(&mut state, &mut write_buf); + drop(state); + + if let Err(e) = enc_res { + break Err(e); + } } - (false, true) => {} - (_, false) => break, - }; - } - while self.session_mut().wants_write() { - wrlen += self.write_tls().await?; - } + Ok(ConnectionState::TransmitTlsData(state)) => { + state.done(); + + drop(session); + + let (res, b) = write_all_buf(&self.io, write_buf).await; + write_buf = b; + + session = self.session.borrow_mut(); + + if let Err(e) = res { + break Err(e); + } + } - Ok((rdlen, wrlen)) + Ok(ConnectionState::PeerClosed | ConnectionState::Closed) => { + break Err(io::ErrorKind::UnexpectedEof.into()); + } + + Ok(_) => {} + } + }; + + 
session.write_buf.replace(write_buf); + res } } -impl AsyncBufRead for TlsStream +impl AsyncBufRead for TlsStream where - C: DerefMut + Deref>, - S: SideData, - Io: AsyncBufRead, + C: ProcessTlsRecords, + Io: AsyncBufRead + AsyncBufWrite, { async fn read(&self, mut buf: B) -> (io::Result, B) where B: BoundedBufMut, { - let mut session = self.session.borrow_mut(); - - loop { - match session.read_plain(&mut buf) { - Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {} - res => return (res, buf), - } - - drop(session); - - match self.read_tls().await { - Ok(0) => return (Err(io::ErrorKind::UnexpectedEof.into()), buf), - Ok(_) => session = self.session.borrow_mut(), - e => return (e, buf), - }; - } + let res = self.read_tls(&mut buf).await; + (res, buf) } } -impl AsyncBufWrite for TlsStream +impl AsyncBufWrite for TlsStream where - C: DerefMut + Deref>, - S: SideData, - Io: AsyncBufWrite, + C: ProcessTlsRecords, + Io: AsyncBufRead + AsyncBufWrite, { async fn write(&self, buf: B) -> (io::Result, B) where B: BoundedBuf, { - let mut session = self.session.borrow_mut(); - - let len = match session.write_plain(&buf) { - Ok(n) => n, - e => return (e, buf), - }; - - let mut write_buf = session.write_buf.take().expect(POLL_TO_COMPLETE); - - // currently there is no AsyncBufWrite::flush so write must keep flushing io until - // every bit of tls data is sent. this could be changed in the future for more efficient - // rustl buffer usage. 
- while session.session.wants_write() { - if let Err(e) = session.session.write_tls(&mut write_buf) { - session.write_buf.replace(write_buf); - return (Err(e), buf); - } - - drop(session); - - let (res, b) = write_buf.write_io(&self.io).await; - write_buf = b; - - session = self.session.borrow_mut(); - - match res { - Ok(0) => { - session.write_buf.replace(write_buf); - return (Err(io::ErrorKind::UnexpectedEof.into()), buf); - } - Ok(_) => {} - e => { - session.write_buf.replace(write_buf); - return (e, buf); - } - } - } - - session.write_buf.replace(write_buf); - - (Ok(len), buf) + let res = self.write_tls(&buf).await; + (res, buf) } fn shutdown(&self, direction: Shutdown) -> io::Result<()> { @@ -322,52 +436,93 @@ where } fn io_ref_slice(buf: &impl BoundedBuf) -> &[u8] { - // SAFETY - // have to trust IoBuf implementor to provide valid pointer and it's length. + // SAFETY: trust BoundedBuf implementor to provide valid pointer and length. unsafe { slice::from_raw_parts(buf.stable_ptr(), buf.bytes_init()) } } fn io_ref_mut_slice(buf: &mut impl BoundedBufMut) -> &mut [u8] { - // SAFETY - // have to trust IoBufMut implementor to provide valid pointer and it's capacity. + // SAFETY: trust BoundedBufMut implementor to provide valid pointer and capacity. unsafe { slice::from_raw_parts_mut(buf.stable_mut_ptr(), buf.bytes_total()) } } -const POLL_TO_COMPLETE: &str = "previous call to future dropped before polling to completion"; +fn tls_err(e: Error) -> io::Error { + io::Error::new(io::ErrorKind::InvalidData, e) +} + +/// Read from IO into a BytesMut, reserving space if needed. 
+async fn read_to_buf(io: &impl AsyncBufRead, mut buf: BytesMut) -> (io::Result<()>, BytesMut) { + let len = buf.len(); + buf.reserve(4096); -mod buf { - use super::*; + let (res, b) = io.read(buf.slice(len..)).await; + buf = b.into_inner(); - pub(super) struct WriteBuf { - buf: BytesMut, + match res { + Ok(0) => (Err(io::ErrorKind::UnexpectedEof.into()), buf), + Ok(_) => (Ok(()), buf), + Err(e) => (Err(e), buf), } +} - impl Default for WriteBuf { - fn default() -> Self { - Self { buf: BytesMut::new() } +/// Write all bytes from a BytesMut to IO, then clear it. +async fn write_all_buf(io: &impl AsyncBufWrite, mut buf: BytesMut) -> (io::Result<()>, BytesMut) { + while !buf.is_empty() { + let (res, b) = io.write(buf).await; + buf = b; + match res { + Ok(0) => return (Err(io::ErrorKind::UnexpectedEof.into()), buf), + Ok(n) => buf.advance(n), + Err(e) => return (Err(e), buf), } } + (Ok(()), buf) +} - impl WriteBuf { - pub(super) async fn write_io(mut self, io: &impl AsyncBufWrite) -> (io::Result, Self) { - if self.buf.is_empty() { - return (Ok(0), self); +/// Encode TLS handshake data into the write buffer, resizing if needed. +fn encode_tls_data(state: &mut unbuffered::EncodeTlsData<'_, Data>, write_buf: &mut BytesMut) -> io::Result<()> { + loop { + let spare = write_buf.spare_capacity_mut(); + // SAFETY: encode writes into the spare capacity; we track how many bytes. + let dst = unsafe { slice::from_raw_parts_mut(spare.as_mut_ptr().cast::(), spare.len()) }; + match state.encode(dst) { + Ok(n) => { + // SAFETY: encode wrote n bytes into the spare capacity. 
+ unsafe { write_buf.set_len(write_buf.len() + n) }; + return Ok(()); } - let (res, buf) = io.write(self.buf).await; - self.buf = buf; - - (res.inspect(|n| self.buf.advance(*n)), self) + Err(unbuffered::EncodeError::InsufficientSize(unbuffered::InsufficientSizeError { required_size })) => { + write_buf.reserve(required_size); + } + Err(e) => return Err(io::Error::new(io::ErrorKind::InvalidData, e.to_string())), } } +} - impl io::Write for WriteBuf { - fn write(&mut self, buf: &[u8]) -> io::Result { - self.buf.extend_from_slice(buf); - Ok(buf.len()) - } +/// Encrypt plaintext into the write buffer, resizing if needed. +fn encrypt_to_buf( + traffic: &mut unbuffered::WriteTraffic<'_, Data>, + plaintext: &[u8], + write_buf: &mut BytesMut, +) -> io::Result<()> { + let needed = plaintext.len() + 64; + if write_buf.capacity() - write_buf.len() < needed { + write_buf.reserve(needed); + } - fn flush(&mut self) -> io::Result<()> { - Ok(()) + loop { + let spare = write_buf.spare_capacity_mut(); + let dst = unsafe { slice::from_raw_parts_mut(spare.as_mut_ptr().cast::(), spare.len()) }; + match traffic.encrypt(plaintext, dst) { + Ok(n) => { + unsafe { write_buf.set_len(write_buf.len() + n) }; + return Ok(()); + } + Err(EncryptError::InsufficientSize(unbuffered::InsufficientSizeError { required_size })) => { + write_buf.reserve(required_size); + } + Err(e) => return Err(io::Error::new(io::ErrorKind::InvalidData, e.to_string())), } } } + +const POLL_TO_COMPLETE: &str = "previous call to future dropped before polling to completion"; From 5dc0ffedc41418a88ab82a2c0c0ba6855ba46796 Mon Sep 17 00:00:00 2001 From: fakeshadow <24548779@qq.com> Date: Tue, 31 Mar 2026 17:10:59 +0800 Subject: [PATCH 08/21] fix import. 
add draft openssl-uring --- examples/io-uring-h2/Cargo.toml | 6 +- examples/io-uring-h2/src/main.rs | 21 +- examples/io-uring/Cargo.toml | 4 +- examples/io-uring/src/main.rs | 2 +- http/CHANGES.md | 4 +- http/Cargo.toml | 4 +- http/src/h2/dispatcher_uring.rs | 35 ++- http/src/tls/rustls_uring.rs | 2 +- tls/CHANGES.md | 4 +- tls/Cargo.toml | 16 +- tls/src/lib.rs | 2 + tls/src/openssl_uring.rs | 398 +++++++++++++++++++++++++++++++ tls/src/rustls.rs | 2 +- tls/src/rustls_uring.rs | 256 ++++++++++++-------- 14 files changed, 616 insertions(+), 140 deletions(-) create mode 100644 tls/src/openssl_uring.rs diff --git a/examples/io-uring-h2/Cargo.toml b/examples/io-uring-h2/Cargo.toml index db1b43cbb..fb8a6cbfb 100644 --- a/examples/io-uring-h2/Cargo.toml +++ b/examples/io-uring-h2/Cargo.toml @@ -5,13 +5,15 @@ authors = ["fakeshadow <24548779@qq.com>"] edition = "2024" [dependencies] -xitca-http = { path = "../../http", features = ["http2", "io-uring", "router"] } +xitca-http = { path = "../../http", features = ["http2", "io-uring", "router"] } xitca-server = { version = "0.6.1", features = ["io-uring"] } xitca-service = "0.3" futures-core = "0.3" - mimalloc = { version = "0.1.48", default-features = false, features = ["v3"] } +# rcgen = "0.14" +# rustls = "0.23" +# rustls-pki-types = "1" [profile.release] opt-level = 3 diff --git a/examples/io-uring-h2/src/main.rs b/examples/io-uring-h2/src/main.rs index 0d0423e9c..712c17603 100644 --- a/examples/io-uring-h2/src/main.rs +++ b/examples/io-uring-h2/src/main.rs @@ -1,7 +1,7 @@ //! A Http/2 server returns Hello World String as Response. //! -//! *. use h2c prior knowledge as protocol. //! *. io_uring is a linux OS feature. +//! *. random self signed cert is used for tls certification. #[global_allocator] static ALLOC: mimalloc::MiMalloc = mimalloc::MiMalloc; @@ -69,3 +69,22 @@ impl Stream for Once { (len, Some(len)) } } + +// // rustls configuration. 
+// fn tls_config() -> std::sync::Arc { +// let subject_alt_names = vec!["127.0.0.1".to_string(), "localhost".to_string()]; + +// let cert = rcgen::generate_simple_self_signed(subject_alt_names).unwrap(); + +// let mut config = rustls::ServerConfig::builder() +// .with_no_client_auth() +// .with_single_cert( +// vec![cert.cert.into()], +// cert.signing_key.serialize_der().try_into().unwrap(), +// ) +// .unwrap(); + +// config.alpn_protocols = vec![b"h2".to_vec()]; + +// std::sync::Arc::new(config) +// } diff --git a/examples/io-uring/Cargo.toml b/examples/io-uring/Cargo.toml index b85a00114..e7ba94894 100644 --- a/examples/io-uring/Cargo.toml +++ b/examples/io-uring/Cargo.toml @@ -9,6 +9,6 @@ xitca-http = { version = "0.7", features = ["io-uring", "router", "rustls-uring" xitca-server = { version = "0.5", features = ["io-uring"] } xitca-service = "0.3" -rcgen = "0.13" +rcgen = "0.14" rustls = "0.23" -rustls-pki-types = "1" \ No newline at end of file +rustls-pki-types = "1" diff --git a/examples/io-uring/src/main.rs b/examples/io-uring/src/main.rs index 67bbb6503..b8756f80f 100644 --- a/examples/io-uring/src/main.rs +++ b/examples/io-uring/src/main.rs @@ -45,7 +45,7 @@ fn tls_config() -> Arc { .with_no_client_auth() .with_single_cert( vec![cert.cert.into()], - cert.key_pair.serialize_der().try_into().unwrap(), + cert.signing_key.serialize_der().try_into().unwrap(), ) .unwrap(); diff --git a/http/CHANGES.md b/http/CHANGES.md index 94cc7ad78..7be00e31d 100644 --- a/http/CHANGES.md +++ b/http/CHANGES.md @@ -1,4 +1,6 @@ -# unreleased +# unreleased 0.8.3 +## Change +- update `xitca-tls` to `0.5.2` # 0.8.2 ## Fix diff --git a/http/Cargo.toml b/http/Cargo.toml index bd4e5616d..b7b8a2a89 100644 --- a/http/Cargo.toml +++ b/http/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "xitca-http" -version = "0.8.2" +version = "0.8.3" edition = "2024" license = "Apache-2.0" description = "http library for xitca" @@ -53,7 +53,7 @@ tracing = { version = "0.1.40", default-features = false } 
native-tls = { version = "0.2.7", features = ["alpn"], optional = true } # tls support shared -xitca-tls = { version = "0.5.1", optional = true } +xitca-tls = { version = "0.5.2", optional = true } # http/1 support httparse = { version = "1.8", optional = true } diff --git a/http/src/h2/dispatcher_uring.rs b/http/src/h2/dispatcher_uring.rs index 1035016a1..5c0cacf4f 100644 --- a/http/src/h2/dispatcher_uring.rs +++ b/http/src/h2/dispatcher_uring.rs @@ -1096,7 +1096,7 @@ impl<'a> EncodeContext<'a> { inner.flow.remote_settings.encode(&mut self.encoder, write_buf); - let is_eof = loop { + let writable = loop { match inner.queue.poll_recv() { Poll::Ready(Some(msg)) => match msg { Message::Head(headers) => { @@ -1127,11 +1127,11 @@ impl<'a> EncodeContext<'a> { // TRAILERS) queued by response tasks are still delivered. // The write task exits naturally when the queue is both // empty and closed (Poll::Ready(None) below). - break false; + break true; } }, - Poll::Pending => break false, - Poll::Ready(None) => break true, + Poll::Pending => break true, + Poll::Ready(None) => break false, } }; @@ -1155,14 +1155,12 @@ impl<'a> EncodeContext<'a> { inner.queue.keepalive_ping = KeepalivePing::InFlight; } - // Return Pending only when there is nothing to write AND the queue is - // still open. If is_eof is true (queue closed+empty), return Ready even - // with an empty buffer: poll_recv returns Poll::Ready(None) without - // storing a waker, so returning Pending here would stall forever. 
- if write_buf.is_empty() && !is_eof { - Poll::Pending + if !write_buf.is_empty() { + Poll::Ready(true) + } else if !writable { + Poll::Ready(false) } else { - Poll::Ready(is_eof) + Poll::Pending } } } @@ -1403,18 +1401,19 @@ where let mut read_task = pin!(read_io(read_buf, &io)); let mut write_task = pin!(async { - loop { - let is_eof = poll_fn(|_| enc.poll_encode(&mut write_buf)).await; - - let (res, buf) = write_io(write_buf, &io).await; + while poll_fn(|_| enc.poll_encode(&mut write_buf)).await { + let (res, buf) = io.write(write_buf).await; write_buf = buf; - res?; - if is_eof { - return Ok(()); + match res { + Ok(0) => return Err(io::ErrorKind::WriteZero.into()), + Ok(n) => write_buf.advance(n), + Err(e) => return Err(e), } } + + Ok(()) }); let shutdown = loop { diff --git a/http/src/tls/rustls_uring.rs b/http/src/tls/rustls_uring.rs index d13eb287c..078f21ab4 100644 --- a/http/src/tls/rustls_uring.rs +++ b/http/src/tls/rustls_uring.rs @@ -4,7 +4,7 @@ use std::{io, net::Shutdown, sync::Arc}; use xitca_io::io_uring::{AsyncBufRead, AsyncBufWrite, BoundedBuf, BoundedBufMut}; use xitca_service::Service; -use xitca_tls::rustls_uring::{ServerConfig, TlsStream as _TlsStream, UnbufferedServerConnection}; +use xitca_tls::rustls_uring::{ServerConfig, TlsStream as _TlsStream, server::UnbufferedServerConnection}; use crate::{http::Version, version::AsVersion}; diff --git a/tls/CHANGES.md b/tls/CHANGES.md index 649395d97..37e34bebd 100644 --- a/tls/CHANGES.md +++ b/tls/CHANGES.md @@ -1,4 +1,6 @@ -# unreleased +# unreleased 0.5.2 +## Change +- internal change to reduce memory copy when `io-uring` feature enabled # 0.5.1 ## Add diff --git a/tls/Cargo.toml b/tls/Cargo.toml index 570f60994..bfce3e74b 100644 --- a/tls/Cargo.toml +++ b/tls/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "xitca-tls" -version = "0.5.1" +version = "0.5.2" edition = "2024" license = "Apache-2.0" description = "tls utility for xitca" @@ -11,19 +11,21 @@ readme= "README.md" [features] openssl = 
["dep:openssl"] +# openssl for xitca-io io-uring traits +# openssl-uring = ["dep:openssl", "xitca-io/runtime-uring"] # rustls with no default crypto provider -rustls-no-crypto = ["rustls_crate"] +rustls-no-crypto = ["dep:rustls"] # rustls with aws-lc as crypto provider (default provider from `rustls` crate) -rustls = ["rustls_crate/aws-lc-rs"] +rustls = ["rustls/aws-lc-rs"] # rustls with ring as crypto provider -rustls-ring-crypto = ["rustls_crate/ring"] +rustls-ring-crypto = ["rustls/ring"] # rustls with no crypto provider for xitca-io io-uring traits -rustls-uring-no-crypto = ["rustls_crate", "xitca-io/runtime-uring"] +rustls-uring-no-crypto = ["rustls-no-crypto", "xitca-io/runtime-uring"] # rustls with aws-lc as crypto provider for xitca-io io-uring trait (default provider from `rustls` crate) -rustls-uring = ["rustls_crate/aws-lc-rs", "xitca-io/runtime-uring"] +rustls-uring = ["rustls/aws-lc-rs", "xitca-io/runtime-uring"] [dependencies] xitca-io = { version = "0.5.0", features = ["runtime"] } openssl = { version = "0.10", optional = true } -rustls_crate = { package = "rustls", version = "0.23", default-features = false, features = ["logging", "std", "tls12"], optional = true } +rustls = { version = "0.23", default-features = false, features = ["std", "tls12"], optional = true } diff --git a/tls/src/lib.rs b/tls/src/lib.rs index 771c1479b..03ded8fe5 100644 --- a/tls/src/lib.rs +++ b/tls/src/lib.rs @@ -1,5 +1,7 @@ #[cfg(feature = "openssl")] pub mod openssl; +// #[cfg(feature = "openssl-uring")] +// pub mod openssl_uring; #[cfg(any(feature = "rustls", feature = "rustls-ring-crypto", feature = "rustls-no-crypto"))] pub mod rustls; #[cfg(any(feature = "rustls-uring", feature = "rustls-uring-no-crypto"))] diff --git a/tls/src/openssl_uring.rs b/tls/src/openssl_uring.rs new file mode 100644 index 000000000..fddb30785 --- /dev/null +++ b/tls/src/openssl_uring.rs @@ -0,0 +1,398 @@ +use core::cell::RefCell; + +use std::{ + io::{self, Read, Write}, + net::Shutdown, + 
rc::Rc, +}; + +pub use openssl::*; + +use openssl::ssl::{ErrorCode, Ssl, SslRef, SslStream}; + +use xitca_io::{ + bytes::{Buf, BytesMut}, + io_uring::{AsyncBufRead, AsyncBufWrite, BoundedBuf, BoundedBufMut}, +}; + +/// A TLS stream backed by OpenSSL that implements completion-based IO traits. +/// +/// Uses a sync-to-async bridge: OpenSSL reads/writes against in-memory buffers +/// ([`SyncStream`]), and the actual socket IO is performed asynchronously. +/// +/// Supports concurrent read/write: `ssl_read`/`ssl_write` are synchronous +/// operations on in-memory buffers. The `RefCell` borrow is dropped before +/// any async socket IO, allowing the other path to proceed. +/// +/// # Panics +/// Each async read/write operation must be polled to completion. Dropping a future before it +/// completes will leave internal buffers in a taken state, causing the next call to panic. +/// Concurrent reads or concurrent writes (two reads at the same time, etc.) will also panic. +pub struct TlsStream { + io: Io, + session: Rc>, +} + +struct Session { + ssl: SslStream, + /// Protocol data produced by read path (key updates, alerts). + /// Flushed by write path before sending application data. + proto_write_buf: BytesMut, +} + +/// Synchronous stream adapter that OpenSSL reads from / writes to. +/// +/// Buffers are `Option` to detect concurrent misuse: each path +/// takes its buffer before async IO and replaces it after. A second concurrent +/// operation on the same path will find `None` and panic. +/// +/// `Read` pulls ciphertext from `read_buf`. `Write` appends ciphertext to +/// `write_buf`. Returns `WouldBlock` when `read_buf` is empty to signal +/// OpenSSL to yield. +struct SyncStream { + /// Ciphertext from the socket, consumed by OpenSSL during `ssl_read`. + read_buf: Option, + /// Ciphertext produced by OpenSSL, to be sent to the socket. 
+ write_buf: Option, +} + +impl Read for SyncStream { + fn read(&mut self, buf: &mut [u8]) -> io::Result { + let read_buf = self.read_buf.as_mut().expect(POLL_TO_COMPLETE); + if read_buf.is_empty() { + return Err(io::ErrorKind::WouldBlock.into()); + } + let n = buf.len().min(read_buf.len()); + buf[..n].copy_from_slice(&read_buf[..n]); + read_buf.advance(n); + Ok(n) + } +} + +impl Write for SyncStream { + fn write(&mut self, buf: &[u8]) -> io::Result { + let write_buf = self.write_buf.as_mut().expect(POLL_TO_COMPLETE); + write_buf.extend_from_slice(buf); + Ok(buf.len()) + } + + fn flush(&mut self) -> io::Result<()> { + Ok(()) + } +} + +impl TlsStream +where + Io: AsyncBufRead + AsyncBufWrite, +{ + /// Perform a TLS handshake as the server side. + pub async fn accept(ssl: Ssl, io: Io) -> Result { + let stream = Self::new(ssl, io)?; + stream.handshake(|ssl| ssl.accept()).await?; + Ok(stream) + } + + /// Perform a TLS handshake as the client side. + pub async fn connect(ssl: Ssl, io: Io) -> Result { + let stream = Self::new(ssl, io)?; + stream.handshake(|ssl| ssl.connect()).await?; + Ok(stream) + } + + fn new(ssl: Ssl, io: Io) -> Result { + let sync_stream = SyncStream { + read_buf: Some(BytesMut::new()), + write_buf: Some(BytesMut::new()), + }; + let ssl_stream = SslStream::new(ssl, sync_stream)?; + + Ok(TlsStream { + io, + session: Rc::new(RefCell::new(Session { + ssl: ssl_stream, + proto_write_buf: BytesMut::new(), + })), + }) + } + + /// Acquire a reference to the `SslRef` for inspecting the session. + pub fn session(&self) -> impl core::ops::Deref + '_ { + std::cell::Ref::map(self.session.borrow(), |s| s.ssl.ssl()) + } + + async fn handshake(&self, mut func: F) -> Result<(), Error> + where + F: FnMut(&mut SslStream) -> Result<(), openssl::ssl::Error>, + { + let mut session = self.session.borrow_mut(); + + loop { + match func(&mut session.ssl) { + Ok(()) => { + // Flush any remaining handshake data. 
+ let sync = session.ssl.get_mut(); + if sync.write_buf.as_ref().is_some_and(|b| !b.is_empty()) { + let mut write_buf = sync.write_buf.take().expect(POLL_TO_COMPLETE); + drop(session); + let (res, b) = flush_write_buf(&self.io, write_buf).await; + write_buf = b; + session = self.session.borrow_mut(); + session.ssl.get_mut().write_buf = Some(write_buf); + res?; + } + return Ok(()); + } + Err(ref e) if e.code() == ErrorCode::WANT_READ => { + // Flush outgoing handshake data first. + let sync = session.ssl.get_mut(); + if sync.write_buf.as_ref().is_some_and(|b| !b.is_empty()) { + let mut write_buf = sync.write_buf.take().expect(POLL_TO_COMPLETE); + drop(session); + let (res, b) = flush_write_buf(&self.io, write_buf).await; + write_buf = b; + session = self.session.borrow_mut(); + session.ssl.get_mut().write_buf = Some(write_buf); + res?; + } + + // Read more ciphertext from the socket. + let mut read_buf = session.ssl.get_mut().read_buf.take().expect(POLL_TO_COMPLETE); + drop(session); + let (res, b) = read_to_buf(&self.io, read_buf).await; + read_buf = b; + session = self.session.borrow_mut(); + session.ssl.get_mut().read_buf = Some(read_buf); + res?; + } + Err(ref e) if e.code() == ErrorCode::WANT_WRITE => { + // Flush outgoing handshake data. + let mut write_buf = session.ssl.get_mut().write_buf.take().expect(POLL_TO_COMPLETE); + drop(session); + let (res, b) = flush_write_buf(&self.io, write_buf).await; + write_buf = b; + session = self.session.borrow_mut(); + session.ssl.get_mut().write_buf = Some(write_buf); + res?; + } + Err(e) => return Err(Error::Tls(e)), + } + } + } + + /// Read plaintext by decrypting ciphertext from the socket. + /// + /// Protocol data produced during read (key updates, alerts) is buffered + /// in `proto_write_buf` and flushed by the next `write_tls` call. 
+ async fn read_tls(&self, plain_buf: &mut impl BoundedBufMut) -> io::Result { + let mut session = self.session.borrow_mut(); + + loop { + let dst = io_ref_mut_slice(plain_buf); + match session.ssl.ssl_read(dst) { + Ok(n) => { + unsafe { plain_buf.set_init(n) }; + + // Drain protocol data into proto_write_buf for write path to flush. + drain_proto_write(&mut session); + + return Ok(n); + } + Err(ref e) if e.code() == ErrorCode::ZERO_RETURN => return Ok(0), + Err(ref e) if e.code() == ErrorCode::WANT_READ => { + // Drain protocol data into proto_write_buf. + drain_proto_write(&mut session); + + // Read more ciphertext from the socket. + let mut read_buf = session.ssl.get_mut().read_buf.take().expect(POLL_TO_COMPLETE); + drop(session); + let (res, b) = read_to_buf(&self.io, read_buf).await; + read_buf = b; + session = self.session.borrow_mut(); + session.ssl.get_mut().read_buf = Some(read_buf); + res?; + } + Err(e) => { + return Err(e + .into_io_error() + .unwrap_or_else(|e| io::Error::new(io::ErrorKind::Other, e))); + } + } + } + } + + /// Encrypt plaintext and write ciphertext to the socket. + /// + /// Flushes any protocol data buffered by the read path before writing. + async fn write_tls(&self, plain: &impl BoundedBuf) -> io::Result { + let mut session = self.session.borrow_mut(); + let plaintext = io_ref_slice(plain); + + // Flush protocol data from read path into write_buf. + let session_ref = &mut *session; + if !session_ref.proto_write_buf.is_empty() { + let write_buf = session_ref.ssl.get_mut().write_buf.as_mut().expect(POLL_TO_COMPLETE); + write_buf.extend_from_slice(&session_ref.proto_write_buf); + session_ref.proto_write_buf.clear(); + } + + loop { + match session.ssl.ssl_write(plaintext) { + Ok(n) => { + // Flush ciphertext to the socket. 
+ let mut write_buf = session.ssl.get_mut().write_buf.take().expect(POLL_TO_COMPLETE); + drop(session); + let (res, b) = flush_write_buf(&self.io, write_buf).await; + write_buf = b; + session = self.session.borrow_mut(); + session.ssl.get_mut().write_buf = Some(write_buf); + res?; + + return Ok(n); + } + Err(ref e) if e.code() == ErrorCode::WANT_WRITE => { + // Flush and retry. + let mut write_buf = session.ssl.get_mut().write_buf.take().expect(POLL_TO_COMPLETE); + drop(session); + let (res, b) = flush_write_buf(&self.io, write_buf).await; + write_buf = b; + session = self.session.borrow_mut(); + session.ssl.get_mut().write_buf = Some(write_buf); + res?; + } + Err(ref e) if e.code() == ErrorCode::WANT_READ => { + // Renegotiation — flush then read before retrying. + let sync = session.ssl.get_mut(); + if sync.write_buf.as_ref().is_some_and(|b| !b.is_empty()) { + let mut write_buf = sync.write_buf.take().expect(POLL_TO_COMPLETE); + drop(session); + let (res, b) = flush_write_buf(&self.io, write_buf).await; + write_buf = b; + session = self.session.borrow_mut(); + session.ssl.get_mut().write_buf = Some(write_buf); + res?; + } + + let mut read_buf = session.ssl.get_mut().read_buf.take().expect(POLL_TO_COMPLETE); + drop(session); + let (res, b) = read_to_buf(&self.io, read_buf).await; + read_buf = b; + session = self.session.borrow_mut(); + session.ssl.get_mut().read_buf = Some(read_buf); + res?; + } + Err(e) => { + return Err(e + .into_io_error() + .unwrap_or_else(|e| io::Error::new(io::ErrorKind::Other, e))); + } + } + } + } +} + +/// Move protocol ciphertext produced during `ssl_read` to `proto_write_buf`. +/// The write path will flush it to the socket. 
+fn drain_proto_write(session: &mut Session) { + let sync = session.ssl.get_mut(); + if let Some(write_buf) = sync.write_buf.as_mut() { + if !write_buf.is_empty() { + session.proto_write_buf.extend_from_slice(write_buf); + write_buf.clear(); + } + } +} + +impl AsyncBufRead for TlsStream +where + Io: AsyncBufRead + AsyncBufWrite, +{ + async fn read(&self, mut buf: B) -> (io::Result, B) + where + B: BoundedBufMut, + { + let res = self.read_tls(&mut buf).await; + (res, buf) + } +} + +impl AsyncBufWrite for TlsStream +where + Io: AsyncBufRead + AsyncBufWrite, +{ + async fn write(&self, buf: B) -> (io::Result, B) + where + B: BoundedBuf, + { + let res = self.write_tls(&buf).await; + (res, buf) + } + + fn shutdown(&self, direction: Shutdown) -> io::Result<()> { + self.io.shutdown(direction) + } +} + +fn io_ref_slice(buf: &impl BoundedBuf) -> &[u8] { + unsafe { core::slice::from_raw_parts(buf.stable_ptr(), buf.bytes_init()) } +} + +fn io_ref_mut_slice(buf: &mut impl BoundedBufMut) -> &mut [u8] { + unsafe { core::slice::from_raw_parts_mut(buf.stable_mut_ptr(), buf.bytes_total()) } +} + +/// Read from IO into a BytesMut, reserving space if needed. +async fn read_to_buf(io: &impl AsyncBufRead, mut buf: BytesMut) -> (io::Result<()>, BytesMut) { + let len = buf.len(); + buf.reserve(4096); + + let (res, b) = io.read(buf.slice(len..)).await; + buf = b.into_inner(); + + match res { + Ok(0) => (Err(io::ErrorKind::UnexpectedEof.into()), buf), + Ok(_) => (Ok(()), buf), + Err(e) => (Err(e), buf), + } +} + +/// Write all bytes from a BytesMut to IO, then clear it. +async fn flush_write_buf(io: &impl AsyncBufWrite, mut buf: BytesMut) -> (io::Result<()>, BytesMut) { + let (res, b) = xitca_io::io_uring::write_all(io, buf).await; + buf = b; + if res.is_ok() { + buf.clear(); + } + (res, buf) +} + +const POLL_TO_COMPLETE: &str = "previous call to future dropped before polling to completion"; + +/// Collection of OpenSSL error types. 
+#[derive(Debug)] +pub enum Error { + Io(io::Error), + Tls(openssl::ssl::Error), +} + +impl From for Error { + fn from(e: openssl::error::ErrorStack) -> Self { + Self::Tls(e.into()) + } +} + +impl From for Error { + fn from(e: io::Error) -> Self { + Self::Io(e) + } +} + +impl core::fmt::Display for Error { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + match self { + Self::Io(e) => core::fmt::Display::fmt(e, f), + Self::Tls(e) => core::fmt::Display::fmt(e, f), + } + } +} + +impl std::error::Error for Error {} diff --git a/tls/src/rustls.rs b/tls/src/rustls.rs index 1c9d61ee0..70b99d5a0 100644 --- a/tls/src/rustls.rs +++ b/tls/src/rustls.rs @@ -7,7 +7,7 @@ use core::{ use std::io; -pub use rustls_crate::*; +pub use rustls::*; use xitca_io::io::{AsyncIo, Interest, Ready}; diff --git a/tls/src/rustls_uring.rs b/tls/src/rustls_uring.rs index e1dd7c39a..03838caa2 100644 --- a/tls/src/rustls_uring.rs +++ b/tls/src/rustls_uring.rs @@ -4,9 +4,9 @@ use core::{cell::RefCell, slice}; use std::{io, net::Shutdown, rc::Rc}; -pub use rustls_crate::*; +pub use rustls::*; -use rustls_crate::{ +use rustls::{ client::UnbufferedClientConnection, server::UnbufferedServerConnection, unbuffered::UnbufferedConnectionCommon, @@ -53,60 +53,6 @@ impl ProcessTlsRecords for UnbufferedClientConnection { /// Reduced `ConnectionState` that doesn't borrow the connection or incoming buffer. /// Created by draining all needed data from the borrowed state variants. -enum State { - /// Read traffic was processed (plaintext drained by caller). - ReadTraffic, - /// Handshake data was encoded into write_buf. - EncodedTlsData, - /// Encoded data needs to be transmitted, then call done. - TransmitTlsData, - /// Need more ciphertext. - BlockedHandshake, - /// Handshake complete, ready to send app data. - WriteTraffic, - /// Peer sent close_notify or connection fully closed. - Closed, - /// Peer closed (edge-triggered). 
- PeerClosed, -} - -/// Process one round of TLS records and return an owned State plus the discard count. -/// This function handles `EncodeTlsData` inline (encoding into write_buf) and -/// `ReadTraffic` inline (draining plaintext into a BytesMut). -fn process_once( - conn: &mut C, - read_buf: &mut BytesMut, - write_buf: &mut BytesMut, -) -> io::Result { - let UnbufferedStatus { discard, state } = conn.process_tls_records(read_buf.as_mut()); - - let state = match state.map_err(tls_err)? { - ConnectionState::ReadTraffic(_) => State::ReadTraffic, - ConnectionState::EncodeTlsData(mut state) => { - encode_tls_data(&mut state, write_buf)?; - State::EncodedTlsData - } - ConnectionState::TransmitTlsData(state) => { - state.done(); - State::TransmitTlsData - } - ConnectionState::WriteTraffic(_) => { - // WriteTraffic may have pending TLS data (key updates). - // We don't encrypt here — caller decides. - State::WriteTraffic - } - ConnectionState::BlockedHandshake => State::BlockedHandshake, - ConnectionState::PeerClosed => State::PeerClosed, - ConnectionState::Closed => State::Closed, - _ => State::BlockedHandshake, // Unknown variants treated as needing more data. - }; - - // Discard consumed bytes from read_buf after all borrows are released. - read_buf.advance(discard); - - Ok(state) -} - /// A TLS stream type that supports concurrent async read/write through [AsyncBufRead] and /// [AsyncBufWrite] traits. /// @@ -171,29 +117,53 @@ where let mut proto_write_buf = session.proto_write_buf.split(); let res = loop { - let res = match process_once(&mut session.conn, &mut read_buf, &mut proto_write_buf) { - Err(e) => Err(e), + let UnbufferedStatus { discard, state } = session.conn.process_tls_records(read_buf.as_mut()); - // Continue processing — more handshake data may follow. 
- Ok(State::EncodedTlsData) => continue, + let res = match state.map_err(tls_err) { + Err(e) => { + read_buf.advance(discard); + Err(e) + } + + Ok(ConnectionState::EncodeTlsData(mut state)) => { + let enc_res = encode_tls_data(&mut state, &mut proto_write_buf); + drop(state); + read_buf.advance(discard); + enc_res?; + continue; + } + + Ok(ConnectionState::TransmitTlsData(state)) => { + state.done(); + read_buf.advance(discard); - Ok(State::TransmitTlsData) => { let (res, b) = write_all_buf(&self.io, proto_write_buf).await; proto_write_buf = b; res } - Ok(State::BlockedHandshake) => { + Ok(ConnectionState::BlockedHandshake) => { + read_buf.advance(discard); + let (res, b) = read_to_buf(&self.io, read_buf).await; read_buf = b; res } - Ok(State::WriteTraffic | State::ReadTraffic) => break Ok(()), + Ok(ConnectionState::WriteTraffic(_) | ConnectionState::ReadTraffic(_)) => { + read_buf.advance(discard); + break Ok(()); + } - Ok(State::PeerClosed | State::Closed) => { + Ok(ConnectionState::PeerClosed | ConnectionState::Closed) => { + read_buf.advance(discard); Err(io::Error::new(io::ErrorKind::UnexpectedEof, "tls handshake eof")) } + + Ok(_) => { + read_buf.advance(discard); + continue; + } }; if res.is_err() { @@ -205,7 +175,13 @@ where session.proto_write_buf = proto_write_buf; res } +} +impl TlsStream +where + C: ProcessTlsRecords, + Io: AsyncBufRead, +{ /// Read ciphertext from IO, decrypt, and return plaintext. async fn read_tls(&self, plain_buf: &mut impl BoundedBufMut) -> io::Result { let mut session = self.session.borrow_mut(); @@ -328,8 +304,14 @@ where session.read_buf.replace(read_buf); res } +} - /// Encrypt plaintext and write ciphertext to IO. +impl TlsStream +where + C: ProcessTlsRecords, + Io: AsyncBufWrite, +{ + /// Encrypt plaintext and write all ciphertext to IO. 
async fn write_tls(&self, plain: &impl BoundedBuf) -> io::Result { let mut session = self.session.borrow_mut(); let mut write_buf = session.write_buf.take().expect(POLL_TO_COMPLETE); @@ -406,7 +388,7 @@ where impl AsyncBufRead for TlsStream where C: ProcessTlsRecords, - Io: AsyncBufRead + AsyncBufWrite, + Io: AsyncBufRead, { async fn read(&self, mut buf: B) -> (io::Result, B) where @@ -420,7 +402,7 @@ where impl AsyncBufWrite for TlsStream where C: ProcessTlsRecords, - Io: AsyncBufRead + AsyncBufWrite, + Io: AsyncBufWrite, { async fn write(&self, buf: B) -> (io::Result, B) where @@ -466,36 +448,28 @@ async fn read_to_buf(io: &impl AsyncBufRead, mut buf: BytesMut) -> (io::Result<( /// Write all bytes from a BytesMut to IO, then clear it. async fn write_all_buf(io: &impl AsyncBufWrite, mut buf: BytesMut) -> (io::Result<()>, BytesMut) { - while !buf.is_empty() { - let (res, b) = io.write(buf).await; - buf = b; - match res { - Ok(0) => return (Err(io::ErrorKind::UnexpectedEof.into()), buf), - Ok(n) => buf.advance(n), - Err(e) => return (Err(e), buf), - } + let (res, b) = xitca_io::io_uring::write_all(io, buf).await; + buf = b; + if res.is_ok() { + buf.clear(); } - (Ok(()), buf) + (res, buf) } /// Encode TLS handshake data into the write buffer, resizing if needed. fn encode_tls_data(state: &mut unbuffered::EncodeTlsData<'_, Data>, write_buf: &mut BytesMut) -> io::Result<()> { - loop { - let spare = write_buf.spare_capacity_mut(); - // SAFETY: encode writes into the spare capacity; we track how many bytes. - let dst = unsafe { slice::from_raw_parts_mut(spare.as_mut_ptr().cast::(), spare.len()) }; - match state.encode(dst) { - Ok(n) => { - // SAFETY: encode wrote n bytes into the spare capacity. - unsafe { write_buf.set_len(write_buf.len() + n) }; - return Ok(()); - } - Err(unbuffered::EncodeError::InsufficientSize(unbuffered::InsufficientSizeError { required_size })) => { + // SAFETY: EncodeTlsData::encode copies a single chunk contiguously from index 0. 
+ // On Ok(n), exactly n bytes are written. On InsufficientSize or AlreadyEncoded, + // the size check happens before any write so the slice is untouched. + while let Err(e) = unsafe { SpareCapBuf::new(write_buf).with_mut_slice(|slice| state.encode(slice)) } { + match e { + unbuffered::EncodeError::InsufficientSize(unbuffered::InsufficientSizeError { required_size }) => { write_buf.reserve(required_size); } - Err(e) => return Err(io::Error::new(io::ErrorKind::InvalidData, e.to_string())), + e => return Err(io::Error::new(io::ErrorKind::InvalidData, e.to_string())), } } + Ok(()) } /// Encrypt plaintext into the write buffer, resizing if needed. @@ -504,25 +478,101 @@ fn encrypt_to_buf( plaintext: &[u8], write_buf: &mut BytesMut, ) -> io::Result<()> { - let needed = plaintext.len() + 64; - if write_buf.capacity() - write_buf.len() < needed { - write_buf.reserve(needed); - } - - loop { - let spare = write_buf.spare_capacity_mut(); - let dst = unsafe { slice::from_raw_parts_mut(spare.as_mut_ptr().cast::(), spare.len()) }; - match traffic.encrypt(plaintext, dst) { - Ok(n) => { - unsafe { write_buf.set_len(write_buf.len() + n) }; - return Ok(()); - } - Err(EncryptError::InsufficientSize(unbuffered::InsufficientSizeError { required_size })) => { + write_buf.reserve(plaintext.len() + 64); + // SAFETY: WriteTraffic::encrypt writes TLS records contiguously from index 0 via + // write_fragments. On Ok(n), exactly n bytes are written. On InsufficientSize, + // check_required_size returns before any write. On EncryptExhausted, the error + // is returned during pre-encryption checks before any write. 
+ while let Err(err) = + unsafe { SpareCapBuf::new(write_buf).with_mut_slice(|spare| traffic.encrypt(plaintext, spare)) } + { + match err { + EncryptError::InsufficientSize(unbuffered::InsufficientSizeError { required_size }) => { write_buf.reserve(required_size); } - Err(e) => return Err(io::Error::new(io::ErrorKind::InvalidData, e.to_string())), + e => return Err(io::Error::new(io::ErrorKind::InvalidData, e.to_string())), } } + Ok(()) +} + +/// Wraps a `BytesMut`'s spare capacity as a mutable byte slice. +/// +/// Encapsulates the unsafe operations of interpreting spare capacity as `&mut [u8]` +/// and committing written bytes via `set_len`. +struct SpareCapBuf<'a> { + buf: &'a mut BytesMut, +} + +impl<'a> SpareCapBuf<'a> { + fn new(buf: &'a mut BytesMut) -> Self { + Self { buf } + } + + /// # Safety + /// + /// The callback `func` must uphold the following contract: + /// - Writes must be sequential and contiguous, starting from index 0 of the slice. + /// - On `Ok(n)`, exactly `n` bytes must have been written to `slice[..n]`. + /// - On `Err`, zero bytes must have been written into the slice. + unsafe fn with_mut_slice(self, func: F) -> Result<(), E> + where + F: FnOnce(&mut [u8]) -> Result, + { + let spare = self.buf.spare_capacity_mut(); + + // SAFETY: the caller must write into the slice before reading. + // We only expose this for write-before-read patterns (TLS encode/encrypt). + let slice = unsafe { slice::from_raw_parts_mut(spare.as_mut_ptr().cast::(), spare.len()) }; + + let n = func(slice)?; + + // SAFETY: caller guarantees n bytes were written into the spare capacity. 
+ unsafe { self.buf.set_len(self.buf.len() + n) }; + + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn spare_cap_buf_write_and_commit() { + let mut buf = BytesMut::with_capacity(64); + buf.extend_from_slice(b"hello"); + + let res = unsafe { + SpareCapBuf::new(&mut buf).with_mut_slice(|slice| { + assert!(slice.len() >= 59); + slice[..5].copy_from_slice(b"world"); + Ok::<_, ()>(5) + }) + }; + assert!(res.is_ok()); + assert_eq!(&buf[..], b"helloworld"); + } + + #[test] + fn spare_cap_buf_commit_zero() { + let mut buf = BytesMut::with_capacity(16); + buf.extend_from_slice(b"abc"); + + let res = unsafe { SpareCapBuf::new(&mut buf).with_mut_slice(|_| Ok::<_, ()>(0)) }; + assert!(res.is_ok()); + assert_eq!(&buf[..], b"abc"); + } + + #[test] + fn spare_cap_buf_error_no_commit() { + let mut buf = BytesMut::with_capacity(16); + buf.extend_from_slice(b"abc"); + + let res = unsafe { SpareCapBuf::new(&mut buf).with_mut_slice(|_| Err::("too small")) }; + assert!(res.is_err()); + assert_eq!(&buf[..], b"abc"); + } } const POLL_TO_COMPLETE: &str = "previous call to future dropped before polling to completion"; From 39049266d5ecca2efd54ddd068359a289d150b7c Mon Sep 17 00:00:00 2001 From: fakeshadow <24548779@qq.com> Date: Tue, 31 Mar 2026 17:15:33 +0800 Subject: [PATCH 09/21] clippy fix --- tls/src/rustls_uring.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tls/src/rustls_uring.rs b/tls/src/rustls_uring.rs index 03838caa2..24bf333ba 100644 --- a/tls/src/rustls_uring.rs +++ b/tls/src/rustls_uring.rs @@ -534,6 +534,8 @@ impl<'a> SpareCapBuf<'a> { } } +const POLL_TO_COMPLETE: &str = "previous call to future dropped before polling to completion"; + #[cfg(test)] mod tests { use super::*; @@ -574,5 +576,3 @@ mod tests { assert_eq!(&buf[..], b"abc"); } } - -const POLL_TO_COMPLETE: &str = "previous call to future dropped before polling to completion"; From 1a0bdb993a55dccecf09a191db182a37097cb7ad Mon Sep 17 00:00:00 2001 From: fakeshadow 
<24548779@qq.com> Date: Tue, 31 Mar 2026 17:40:16 +0800 Subject: [PATCH 10/21] feature flag fix --- tls/Cargo.toml | 13 ++++++++----- tls/src/rustls.rs | 2 +- tls/src/rustls_uring.rs | 4 ++-- 3 files changed, 11 insertions(+), 8 deletions(-) diff --git a/tls/Cargo.toml b/tls/Cargo.toml index bfce3e74b..ba4caffd8 100644 --- a/tls/Cargo.toml +++ b/tls/Cargo.toml @@ -11,21 +11,24 @@ readme= "README.md" [features] openssl = ["dep:openssl"] + # openssl for xitca-io io-uring traits # openssl-uring = ["dep:openssl", "xitca-io/runtime-uring"] + # rustls with no default crypto provider -rustls-no-crypto = ["dep:rustls"] +rustls-no-crypto = ["dep:rustls_crate"] # rustls with aws-lc as crypto provider (default provider from `rustls` crate) -rustls = ["rustls/aws-lc-rs"] +rustls = ["rustls-no-crypto", "rustls_crate/aws-lc-rs"] # rustls with ring as crypto provider -rustls-ring-crypto = ["rustls/ring"] +rustls-ring-crypto = ["rustls-no-crypto", "rustls_crate/ring"] + # rustls with no crypto provider for xitca-io io-uring traits rustls-uring-no-crypto = ["rustls-no-crypto", "xitca-io/runtime-uring"] # rustls with aws-lc as crypto provider for xitca-io io-uring trait (default provider from `rustls` crate) -rustls-uring = ["rustls/aws-lc-rs", "xitca-io/runtime-uring"] +rustls-uring = ["rustls-uring-no-crypto", "rustls_crate/aws-lc-rs"] [dependencies] xitca-io = { version = "0.5.0", features = ["runtime"] } openssl = { version = "0.10", optional = true } -rustls = { version = "0.23", default-features = false, features = ["std", "tls12"], optional = true } +rustls_crate = { package = "rustls", version = "0.23", default-features = false, features = ["std", "tls12"], optional = true } diff --git a/tls/src/rustls.rs b/tls/src/rustls.rs index 70b99d5a0..1c9d61ee0 100644 --- a/tls/src/rustls.rs +++ b/tls/src/rustls.rs @@ -7,7 +7,7 @@ use core::{ use std::io; -pub use rustls::*; +pub use rustls_crate::*; use xitca_io::io::{AsyncIo, Interest, Ready}; diff --git a/tls/src/rustls_uring.rs 
b/tls/src/rustls_uring.rs index 24bf333ba..b887264e5 100644 --- a/tls/src/rustls_uring.rs +++ b/tls/src/rustls_uring.rs @@ -4,9 +4,9 @@ use core::{cell::RefCell, slice}; use std::{io, net::Shutdown, rc::Rc}; -pub use rustls::*; +pub use rustls_crate::*; -use rustls::{ +use rustls_crate::{ client::UnbufferedClientConnection, server::UnbufferedServerConnection, unbuffered::UnbufferedConnectionCommon, From 44d96f89ddf095fdd82518950471f48704693ac1 Mon Sep 17 00:00:00 2001 From: fakeshadow <24548779@qq.com> Date: Tue, 31 Mar 2026 21:09:51 +0800 Subject: [PATCH 11/21] version bump for tokio-uring and io --- io/CHANGES.md | 4 +- io/Cargo.toml | 6 +- tokio-uring/CHANGELOG.md | 4 +- tokio-uring/Cargo.toml | 13 ++-- tokio-uring/src/buf/mod.rs | 1 + tokio-uring/src/lib.rs | 147 ++++--------------------------------- 6 files changed, 32 insertions(+), 143 deletions(-) diff --git a/io/CHANGES.md b/io/CHANGES.md index 5e15fb517..d5aea4381 100644 --- a/io/CHANGES.md +++ b/io/CHANGES.md @@ -1,4 +1,6 @@ -# unreleased +# unreleased 0.6.0 +## Change +- update `tokio-uring-xitca` to `0.2.0` # 0.5.1 ## Fix diff --git a/io/Cargo.toml b/io/Cargo.toml index 2bd5a0366..87206b2fc 100644 --- a/io/Cargo.toml +++ b/io/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "xitca-io" -version = "0.5.1" +version = "0.6.0" edition = "2024" license = "Apache-2.0" description = "async network io types and traits" @@ -14,7 +14,7 @@ default = [] # tokio runtime support runtime = ["tokio"] # tokio-uring runtime support -runtime-uring = ["dep:tokio-uring-xitca"] +runtime-uring = ["tokio-uring-xitca/runtime"] # quic support quic = ["dep:quinn", "runtime"] @@ -24,4 +24,4 @@ xitca-unsafe-collection = { version = "0.2.0", features = ["bytes"] } bytes = "1.4" quinn = { version = "0.11", features = ["ring"], optional = true } tokio = { version = "1.48", features = ["net"], optional = true } -tokio-uring-xitca = { version = "0.1.1", features = ["bytes"], optional = true } +tokio-uring-xitca = { version = "0.2.0", 
features = ["bytes"] } diff --git a/tokio-uring/CHANGELOG.md b/tokio-uring/CHANGELOG.md index 285ad0512..1d3892f14 100644 --- a/tokio-uring/CHANGELOG.md +++ b/tokio-uring/CHANGELOG.md @@ -1,4 +1,6 @@ -# unreleased 0.1.2 +# unreleased 0.2.0 +- remove runtime from default feature +- add runtime feature for io-uring runtime - perf improvement # 0.1.1 diff --git a/tokio-uring/Cargo.toml b/tokio-uring/Cargo.toml index 9bd82b7a8..90b51c9b8 100644 --- a/tokio-uring/Cargo.toml +++ b/tokio-uring/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "tokio-uring-xitca" -version = "0.1.2" +version = "0.2.0" authors = ["Tokio Contributors ", "fakeshadow "] edition = "2024" readme = "README.md" @@ -18,13 +18,14 @@ workspace = true [features] bytes = ["dep:bytes"] +runtime = ["dep:tokio", "dep:slab", "dep:libc", "dep:io-uring", "dep:socket2"] [dependencies] -tokio = { version = "1.48", features = ["net", "rt", "sync"] } -slab = "0.4.11" -libc = "0.2.178" -io-uring = "0.7.11" -socket2 = { version = "0.6.1", features = ["all"] } +tokio = { version = "1.48", features = ["net", "rt", "sync"], optional = true } +slab = { version = "0.4.11", optional = true } +libc = { version = "0.2.178", optional = true } +io-uring = { version = "0.7.11", optional = true } +socket2 = { version = "0.6.1", features = ["all"], optional = true } bytes = { version = "1.11.0", optional = true } [dev-dependencies] diff --git a/tokio-uring/src/buf/mod.rs b/tokio-uring/src/buf/mod.rs index 71ab196c2..c6fa458e3 100644 --- a/tokio-uring/src/buf/mod.rs +++ b/tokio-uring/src/buf/mod.rs @@ -4,6 +4,7 @@ //! crate defines [`IoBuf`] and [`IoBufMut`] traits which are implemented by buffer //! types that respect the `io-uring` contract. 
+#[cfg(feature = "runtime")] pub mod fixed; mod io_buf; diff --git a/tokio-uring/src/lib.rs b/tokio-uring/src/lib.rs index 19245cf78..6c21359d2 100644 --- a/tokio-uring/src/lib.rs +++ b/tokio-uring/src/lib.rs @@ -59,6 +59,7 @@ #![warn(missing_docs)] #![allow(clippy::missing_const_for_thread_local)] +#[cfg(feature = "runtime")] macro_rules! syscall { ($fn: ident ( $($arg: expr),* $(,)* ) ) => {{ let res = unsafe { ::libc::$fn($($arg, )*) }; @@ -70,24 +71,33 @@ macro_rules! syscall { }}; } +#[cfg(feature = "runtime")] mod io; +#[cfg(feature = "runtime")] mod runtime; pub mod buf; +#[cfg(feature = "runtime")] pub mod fs; +#[cfg(feature = "runtime")] pub mod net; +#[cfg(feature = "runtime")] pub use io::write::*; +#[cfg(feature = "runtime")] pub use runtime::{ Runtime, driver::op::{InFlightOneshot, OneshotOutputTransform, UnsubmittedOneshot}, spawn, }; +#[cfg(feature = "runtime")] use core::future::Future; +#[cfg(feature = "runtime")] use runtime::driver::op::Op; +#[cfg(feature = "runtime")] /// Starts an `io_uring` enabled Tokio runtime. /// /// All `tokio-uring` resource types must be used from within the context of a @@ -100,80 +110,28 @@ use runtime::driver::op::Op; /// `io-uring` driver. All tasks spawned on the `tokio-uring` runtime are /// executed on the current thread. To add concurrency, spawn multiple threads, /// each with a `tokio-uring` runtime. -/// -/// # Examples -/// -/// Basic usage -/// -/// ```no_run -/// use tokio_uring_xitca::fs::File; -/// -/// fn main() -> Result<(), Box> { -/// tokio_uring_xitca::start(async { -/// // Open a file -/// let file = File::open("hello.txt").await?; -/// -/// let buf = vec![0; 4096]; -/// // Read some data, the buffer is passed by ownership and -/// // submitted to the kernel. When the operation completes, -/// // we get the buffer back. 
-/// let (res, buf) = file.read_at(buf, 0).await; -/// let n = res?; -/// -/// // Display the contents -/// println!("{:?}", &buf[..n]); -/// -/// Ok(()) -/// }) -/// } -/// ``` -/// -/// Using Tokio types from the `tokio-uring` runtime -/// -/// -/// ```no_run -/// use tokio::net::TcpListener; -/// -/// fn main() -> Result<(), Box> { -/// tokio_uring_xitca::start(async { -/// let listener = TcpListener::bind("127.0.0.1:8080").await?; -/// -/// loop { -/// let (socket, _) = listener.accept().await?; -/// // process socket -/// } -/// }) -/// } -/// ``` pub fn start(future: F) -> F::Output { let rt = Runtime::new(&builder()).unwrap(); rt.block_on(future) } +#[cfg(feature = "runtime")] /// Creates and returns an io_uring::Builder that can then be modified /// through its implementation methods. -/// -/// This function is provided to avoid requiring the user of this crate from -/// having to use the io_uring crate as well. Refer to Builder::start example -/// for its intended usage. pub fn uring_builder() -> io_uring::Builder { io_uring::IoUring::builder() } +#[cfg(feature = "runtime")] /// Builder API that can create and start the `io_uring` runtime with non-default parameters, /// while abstracting away the underlying io_uring crate. -// #[derive(Clone, Default)] pub struct Builder { entries: u32, urb: io_uring::Builder, } +#[cfg(feature = "runtime")] /// Constructs a [`Builder`] with default settings. -/// -/// Use this to alter submission and completion queue parameters, and to create the io_uring -/// Runtime. -/// -/// Refer to [`Builder::start`] for an example. pub fn builder() -> Builder { Builder { entries: 256, @@ -181,13 +139,9 @@ pub fn builder() -> Builder { } } +#[cfg(feature = "runtime")] impl Builder { /// Sets the number of Submission Queue entries in uring. - /// - /// The default value is 256. - /// The kernel requires the number of submission queue entries to be a power of two, - /// and that it be less than the number of completion queue entries. 
- /// This function will adjust the `cq_entries` value to be at least 2 times `sq_entries` pub fn entries(&mut self, sq_entries: u32) -> &mut Self { self.entries = sq_entries; self @@ -195,40 +149,12 @@ impl Builder { /// Replaces the default [`io_uring::Builder`], which controls the settings for the /// inner `io_uring` API. - /// - /// Refer to the [`io_uring::Builder`] documentation for all the supported methods. pub fn uring_builder(&mut self, b: &io_uring::Builder) -> &mut Self { self.urb = b.clone(); self } /// Starts an `io_uring` enabled Tokio runtime. - /// - /// # Examples - /// - /// Creating a uring driver with only 64 submission queue entries but - /// many more completion queue entries. - /// - /// ```no_run - /// use tokio::net::TcpListener; - /// - /// fn main() -> Result<(), Box> { - /// tokio_uring_xitca::builder() - /// .entries(64) - /// .uring_builder(tokio_uring_xitca::uring_builder() - /// .setup_cqsize(1024) - /// ) - /// .start(async { - /// let listener = TcpListener::bind("127.0.0.1:8080").await?; - /// - /// loop { - /// let (socket, _) = listener.accept().await?; - /// // process socket - /// } - /// } - /// ) - /// } - /// ``` pub fn start(&self, future: F) -> F::Output { let rt = runtime::Runtime::new(self).unwrap(); rt.block_on(future) @@ -236,53 +162,10 @@ impl Builder { } /// A specialized `Result` type for `io-uring` operations with buffers. -/// -/// This type is used as a return value for asynchronous `io-uring` methods that -/// require passing ownership of a buffer to the runtime. When the operation -/// completes, the buffer is returned whether or not the operation completed -/// successfully. 
-/// -/// # Examples -/// -/// ```no_run -/// use tokio_uring_xitca::fs::File; -/// -/// fn main() -> Result<(), Box> { -/// tokio_uring_xitca::start(async { -/// // Open a file -/// let file = File::open("hello.txt").await?; -/// -/// let buf = vec![0; 4096]; -/// // Read some data, the buffer is passed by ownership and -/// // submitted to the kernel. When the operation completes, -/// // we get the buffer back. -/// let (res, buf) = file.read_at(buf, 0).await; -/// let n = res?; -/// -/// // Display the contents -/// println!("{:?}", &buf[..n]); -/// -/// Ok(()) -/// }) -/// } -/// ``` pub type BufResult = (std::io::Result, B); +#[cfg(feature = "runtime")] /// The simplest possible operation. Just posts a completion event, nothing else. -/// -/// This has a place in benchmarking and sanity checking uring. -/// -/// # Examples -/// -/// ```no_run -/// fn main() -> Result<(), Box> { -/// tokio_uring_xitca::start(async { -/// // Place a NoOp on the ring, and await completion event -/// tokio_uring_xitca::no_op().await?; -/// Ok(()) -/// }) -/// } -/// ``` pub async fn no_op() -> std::io::Result<()> { let op = Op::::no_op().unwrap(); op.await From 362a4d66809f308bd8df8f13a75885d291edb022 Mon Sep 17 00:00:00 2001 From: fakeshadow <24548779@qq.com> Date: Wed, 1 Apr 2026 17:24:36 +0800 Subject: [PATCH 12/21] wip --- examples/Cargo.toml | 7 + examples/io-uring/Cargo.toml | 4 +- examples/io-uring/src/main.rs | 9 +- http/CHANGES.md | 6 +- http/Cargo.toml | 20 +- http/src/builder.rs | 12 - http/src/h1/body.rs | 221 +----- http/src/h1/builder.rs | 21 +- http/src/h1/dispatcher.rs | 702 +++++++++++------- http/src/h1/dispatcher_compio.rs | 451 ----------- http/src/h1/dispatcher_uring.rs | 456 ------------ http/src/h1/mod.rs | 8 +- http/src/h1/service.rs | 93 +-- http/src/h2/dispatcher_uring.rs | 2 +- http/src/h2/service.rs | 2 +- http/src/service.rs | 32 +- http/src/tls/mod.rs | 2 - http/src/tls/rustls.rs | 66 +- http/src/tls/rustls_uring.rs | 94 --- io/CHANGES.md | 4 + 
io/Cargo.toml | 8 +- io/src/io.rs | 258 +------ io/src/{io_uring.rs => io/complete.rs} | 15 +- io/src/io/poll.rs | 254 +++++++ io/src/lib.rs | 6 +- io/src/net.rs | 81 ++ io/src/net/io_uring.rs | 6 +- server/CHANGES.md | 5 +- server/Cargo.toml | 6 +- tls/CHANGES.md | 10 +- tls/Cargo.toml | 21 +- tls/src/bridge.rs | 123 +++ tls/src/lib.rs | 15 +- tls/src/native_tls_complete.rs | 246 ++++++ tls/src/openssl.rs | 27 +- tls/src/openssl_complete.rs | 188 +++++ tls/src/openssl_uring.rs | 398 ---------- .../{rustls_uring.rs => rustls_complete.rs} | 74 +- tokio-uring/CHANGELOG.md | 10 +- tokio-uring/src/buf/bounded.rs | 91 ++- tokio-uring/src/buf/fixed/handle.rs | 6 +- tokio-uring/tests/fixed_buf.rs | 2 +- 42 files changed, 1614 insertions(+), 2448 deletions(-) delete mode 100644 http/src/h1/dispatcher_compio.rs delete mode 100644 http/src/h1/dispatcher_uring.rs delete mode 100644 http/src/tls/rustls_uring.rs rename io/src/{io_uring.rs => io/complete.rs} (57%) create mode 100644 io/src/io/poll.rs create mode 100644 tls/src/bridge.rs create mode 100644 tls/src/native_tls_complete.rs create mode 100644 tls/src/openssl_complete.rs delete mode 100644 tls/src/openssl_uring.rs rename tls/src/{rustls_uring.rs => rustls_complete.rs} (89%) diff --git a/examples/Cargo.toml b/examples/Cargo.toml index a745a868c..ce4110fe6 100644 --- a/examples/Cargo.toml +++ b/examples/Cargo.toml @@ -25,3 +25,10 @@ members = [ [profile.dev] debug-assertions = false + +[patch.crates-io] +tokio-uring-xitca = { path = "../tokio-uring" } +xitca-tls= { path = "../tls" } +xitca-io= { path = "../io" } +xitca-server = { path = "../server" } +xitca-http = { path = "../http" } diff --git a/examples/io-uring/Cargo.toml b/examples/io-uring/Cargo.toml index e7ba94894..e38e7f33b 100644 --- a/examples/io-uring/Cargo.toml +++ b/examples/io-uring/Cargo.toml @@ -5,8 +5,8 @@ authors = ["fakeshadow <24548779@qq.com>"] edition = "2024" [dependencies] -xitca-http = { version = "0.7", features = ["io-uring", "router", 
"rustls-uring"] } -xitca-server = { version = "0.5", features = ["io-uring"] } +xitca-http = { version = "0.8", features = ["io-uring", "router", "rustls"] } +xitca-server = { version = "0.7", features = ["io-uring"] } xitca-service = "0.3" rcgen = "0.14" diff --git a/examples/io-uring/src/main.rs b/examples/io-uring/src/main.rs index b8756f80f..495714ac2 100644 --- a/examples/io-uring/src/main.rs +++ b/examples/io-uring/src/main.rs @@ -7,11 +7,10 @@ use std::{convert::Infallible, io, sync::Arc}; use rustls::ServerConfig; use xitca_http::{ - h1, - http::{const_header_value::TEXT_UTF8, header::CONTENT_TYPE, Request, RequestExt, Response}, - HttpServiceBuilder, ResponseBody, + HttpServiceBuilder, ResponseBody, h1, + http::{Request, RequestExt, Response, const_header_value::TEXT_UTF8, header::CONTENT_TYPE}, }; -use xitca_service::{fn_service, ServiceExt}; +use xitca_service::{ServiceExt, fn_service}; fn main() -> io::Result<()> { xitca_server::Builder::new() @@ -21,7 +20,7 @@ fn main() -> io::Result<()> { fn_service(handler).enclosed( HttpServiceBuilder::h1() .io_uring() // specify io_uring flavor of http service. - .rustls_uring(tls_config()), // specify io_uring flavor of tls. + .rustls(tls_config()), // specify io_uring flavor of tls. ), )? 
.build() diff --git a/http/CHANGES.md b/http/CHANGES.md index 7be00e31d..6caaa218b 100644 --- a/http/CHANGES.md +++ b/http/CHANGES.md @@ -1,6 +1,8 @@ -# unreleased 0.8.3 +# unreleased 0.9.0 ## Change -- update `xitca-tls` to `0.5.2` +- update `xitca-io` to `0.6.0` +- update `xitca-tls` to `0.6.0` +- use completion based API for all I/O operations # 0.8.2 ## Fix diff --git a/http/Cargo.toml b/http/Cargo.toml index b7b8a2a89..0c8a30fd8 100644 --- a/http/Cargo.toml +++ b/http/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "xitca-http" -version = "0.8.3" +version = "0.9.0" edition = "2024" license = "Apache-2.0" description = "http library for xitca" @@ -23,11 +23,9 @@ http2 = ["h2", "fnv", "futures-util/alloc", "runtime", "slab"] # http3 specific feature. http3 = ["xitca-io/quic", "futures-util/alloc", "h3", "h3-quinn", "runtime"] # openssl as server side tls. -openssl = ["xitca-tls/openssl", "runtime"] +openssl = ["xitca-tls/openssl"] # rustls as server side tls. -rustls = ["xitca-tls/rustls-no-crypto", "runtime"] -# rustls as server side tls. -rustls-uring = ["rustls", "xitca-tls/rustls-uring-no-crypto", "xitca-io/runtime-uring"] +rustls = ["xitca-tls/rustls"] # rustls as server side tls. native-tls = ["dep:native-tls", "runtime"] # async runtime feature. @@ -35,11 +33,10 @@ runtime = ["xitca-io/runtime", "tokio"] # unstable features that are subject to be changed at anytime. 
io-uring = ["xitca-io/runtime-uring"] -compio = ["dep:compio-buf", "dep:compio-io", "dep:compio-net"] router = ["xitca-router"] [dependencies] -xitca-io = "0.5.1" +xitca-io = "0.6.0" xitca-service = { version = "0.3.0", features = ["alloc"] } xitca-unsafe-collection = { version = "0.2.0", features = ["bytes"] } @@ -53,7 +50,7 @@ tracing = { version = "0.1.40", default-features = false } native-tls = { version = "0.2.7", features = ["alpn"], optional = true } # tls support shared -xitca-tls = { version = "0.5.2", optional = true } +xitca-tls = { version = "0.6.0", optional = true } # http/1 support httparse = { version = "1.8", optional = true } @@ -75,16 +72,11 @@ tokio = { version = "1.48", features = ["rt", "time"], optional = true } # util service support xitca-router = { version = "0.4.1", optional = true } -# compio optional. not officially supported only exist for possible benchmarking usage -compio-buf = { version = "0.7", features = ["bytes"], optional = true } -compio-io = { version = "0.8", optional = true } -compio-net = { version = "0.10", optional = true } - [target.'cfg(not(target_family = "wasm"))'.dependencies] socket2 = { version = "0.6.0", features = ["all"] } [dev-dependencies] -criterion = "0.7" +criterion = "0.8" xitca-server = "0.6.1" [[bench]] diff --git a/http/src/builder.rs b/http/src/builder.rs index c5bd8417d..daef6f0f8 100644 --- a/http/src/builder.rs +++ b/http/src/builder.rs @@ -17,8 +17,6 @@ pub(crate) mod marker { pub struct Http; #[cfg(feature = "http1")] pub struct Http1; - #[cfg(all(feature = "io-uring", feature = "http1"))] - pub struct Http1Uring; #[cfg(all(feature = "io-uring", feature = "http2"))] pub struct Http2Uring; #[cfg(feature = "http2")] @@ -160,16 +158,6 @@ impl HttpServiceBuilder - { - self.with_tls(tls::rustls_uring::TlsAcceptorBuilder::new(config)) - } - #[cfg(feature = "native-tls")] /// use native-tls as tls service. tnative-tlsls service is used for Http/1 protocol only. 
pub fn native_tls( diff --git a/http/src/h1/body.rs b/http/src/h1/body.rs index b55aed3a2..d126f9a32 100644 --- a/http/src/h1/body.rs +++ b/http/src/h1/body.rs @@ -1,21 +1,15 @@ use core::{ - cell::{RefCell, RefMut}, fmt, - future::poll_fn, - ops::DerefMut, pin::Pin, - task::{Context, Poll, Waker}, + task::{Context, Poll}, }; -use std::{collections::VecDeque, io, rc::Rc}; +use std::io; use futures_core::stream::Stream; use crate::bytes::Bytes; -/// max buffer size 32k -pub(crate) const MAX_BUFFER_SIZE: usize = 32_768; - /// Buffered stream of request body chunk. /// /// impl [Stream] trait to produce chunk as [Bytes] type in async manner. @@ -36,16 +30,6 @@ impl Default for RequestBody { } impl RequestBody { - // an async spsc channel where sender is used to push data and popped from RequestBody. - pub(super) fn channel(eof: bool) -> (BodySender, Self) { - if eof { - (ChannelBody::none(), RequestBody::none()) - } else { - let body = ChannelBody::stream(); - (body.clone(), RequestBody::stream(body)) - } - } - pub(super) fn stream(stream: S) -> Self where S: Stream> + 'static, @@ -74,204 +58,3 @@ impl From for crate::body::RequestBody { Self::H1(body) } } - -/// Sender part of the payload stream -pub(super) type BodySender = ChannelBody; - -#[derive(Clone)] -pub(super) struct ChannelBody(Option>>); - -impl Stream for ChannelBody { - type Item = io::Result; - - fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll>> { - match self.get_mut().0 { - Some(ref body) => body.borrow_mut().poll_next_unpin(cx), - None => Poll::Ready(None), - } - } -} - -// TODO: rework early eof error handling. 
-impl Drop for ChannelBody { - fn drop(&mut self) { - if let Some(mut inner) = self.try_inner() { - if !inner.eof { - inner.feed_error(io::ErrorKind::UnexpectedEof.into()); - } - } - } -} - -impl ChannelBody { - fn stream() -> Self { - Self(Some(Default::default())) - } - - fn none() -> Self { - Self(None) - } - - // try to get a mutable reference of inner and ignore RequestBody::None variant. - fn try_inner(&mut self) -> Option> { - self.try_inner_on_none_with(|| {}) - } - - // try to get a mutable reference of inner and panic on RequestBody::None variant. - // this is a runtime check for internal optimization to avoid unnecessary operations. - // public api must not be able to trigger this panic. - fn try_inner_infallible(&mut self) -> Option> { - self.try_inner_on_none_with(|| panic!("No Request Body found. Do not waste operation on Sender.")) - } - - fn try_inner_on_none_with(&mut self, func: F) -> Option> - where - F: FnOnce(), - { - match self.0 { - Some(ref inner) => { - // request body is a shared pointer between only two owners and no weak reference. - debug_assert!(Rc::strong_count(inner) <= 2); - debug_assert_eq!(Rc::weak_count(inner), 0); - (Rc::strong_count(inner) != 1).then_some(inner.borrow_mut()) - } - None => { - func(); - None - } - } - } - - pub(super) fn feed_error(&mut self, e: io::Error) { - if let Some(mut inner) = self.try_inner_infallible() { - inner.feed_error(e); - } - } - - pub(super) fn feed_eof(&mut self) { - if let Some(mut inner) = self.try_inner_infallible() { - inner.feed_eof(); - } - } - - pub(super) fn feed_data(&mut self, data: Bytes) { - if let Some(mut inner) = self.try_inner_infallible() { - inner.feed_data(data); - } - } - - pub(super) fn ready(&mut self) -> impl Future> + '_ { - self.ready_with(|inner| !inner.backpressure()) - } - - // Lazily wait until RequestBody is already polled. - // For specific use case body must not be eagerly polled. - // For example: Request with Expect: Continue header. 
- pub(super) fn wait_for_poll(&mut self) -> impl Future> + '_ { - self.ready_with(|inner| inner.waiting()) - } - - async fn ready_with(&mut self, func: F) -> io::Result<()> - where - F: Fn(&mut Inner) -> bool, - { - poll_fn(|cx| { - // Check only if Payload (other side) is alive, Otherwise always return io error. - match self.try_inner_infallible() { - Some(mut inner) => { - if func(inner.deref_mut()) { - Poll::Ready(Ok(())) - } else { - // when payload is not ready register current task waker and wait. - inner.register_io(cx); - Poll::Pending - } - } - None => Poll::Ready(Err(io::ErrorKind::UnexpectedEof.into())), - } - }) - .await - } -} - -#[derive(Debug, Default)] -struct Inner { - eof: bool, - len: usize, - err: Option, - items: VecDeque, - task: Option, - io_task: Option, -} - -impl Inner { - /// Wake up future waiting for payload data to be available. - fn wake(&mut self) { - if let Some(waker) = self.task.take() { - waker.wake(); - } - } - - /// Wake up future feeding data to Payload. - fn wake_io(&mut self) { - if let Some(waker) = self.io_task.take() { - waker.wake(); - } - } - - /// true when a future is waiting for payload data. - fn waiting(&self) -> bool { - self.task.is_some() - } - - /// Register future waiting data from payload. - /// Waker would be used in `Inner::wake` - fn register(&mut self, cx: &Context<'_>) { - if self.task.as_ref().map(|w| !cx.waker().will_wake(w)).unwrap_or(true) { - self.task = Some(cx.waker().clone()); - } - } - - // Register future feeding data to payload. 
- /// Waker would be used in `Inner::wake_io` - fn register_io(&mut self, cx: &Context<'_>) { - if self.io_task.as_ref().map(|w| !cx.waker().will_wake(w)).unwrap_or(true) { - self.io_task = Some(cx.waker().clone()); - } - } - - fn feed_error(&mut self, err: io::Error) { - self.err = Some(err); - self.wake(); - } - - fn feed_eof(&mut self) { - self.eof = true; - self.wake(); - } - - fn feed_data(&mut self, data: Bytes) { - self.len += data.len(); - self.items.push_back(data); - self.wake(); - } - - fn backpressure(&self) -> bool { - self.len >= MAX_BUFFER_SIZE - } - - fn poll_next_unpin(&mut self, cx: &Context<'_>) -> Poll>> { - if let Some(data) = self.items.pop_front() { - self.len -= data.len(); - Poll::Ready(Some(Ok(data))) - } else if let Some(err) = self.err.take() { - Poll::Ready(Some(Err(err))) - } else if self.eof { - Poll::Ready(None) - } else { - self.register(cx); - self.wake_io(); - Poll::Pending - } - } -} diff --git a/http/src/h1/builder.rs b/http/src/h1/builder.rs index df9a4320d..28aa14c39 100644 --- a/http/src/h1/builder.rs +++ b/http/src/h1/builder.rs @@ -29,7 +29,7 @@ impl HttpServiceBuilder< - marker::Http1Uring, + marker::Http1, xitca_io::net::io_uring::TcpStream, FA, HEADER_LIMIT, @@ -65,22 +65,3 @@ where Ok(H1Service::new(self.config, service, tls_acceptor)) } } - -#[cfg(feature = "io-uring")] -impl - Service> - for HttpServiceBuilder -where - FA: Service, - FA::Error: fmt::Debug + 'static, - E: fmt::Debug + 'static, -{ - type Response = super::service::H1UringService; - type Error = Error; - - async fn call(&self, res: Result) -> Result { - let service = res.map_err(|e| Box::new(e) as Error)?; - let tls_acceptor = self.tls_factory.call(()).await.map_err(|e| Box::new(e) as Error)?; - Ok(super::service::H1UringService::new(self.config, service, tls_acceptor)) - } -} diff --git a/http/src/h1/dispatcher.rs b/http/src/h1/dispatcher.rs index b2a13d447..8d5dc8c16 100644 --- a/http/src/h1/dispatcher.rs +++ b/http/src/h1/dispatcher.rs @@ -1,326 
+1,540 @@ use core::{ - convert::Infallible, - future::{pending, poll_fn}, + cell::RefCell, + future::poll_fn, marker::PhantomData, + mem, net::SocketAddr, pin::{Pin, pin}, + task::{self, Poll, Waker, ready}, time::Duration, }; -use std::io; +use std::{io, net::Shutdown, rc::Rc}; use futures_core::stream::Stream; +use pin_project_lite::pin_project; use tracing::trace; -use xitca_io::io::{AsyncIo, Interest, Ready}; +use xitca_io::io::{AsyncBufRead, AsyncBufWrite, BoundedBuf, write_all}; use xitca_service::Service; -use xitca_unsafe_collection::futures::{Select as _, SelectOutput}; +use xitca_unsafe_collection::futures::SelectOutput; use crate::{ body::NoneBody, - bytes::{Bytes, EitherBuf}, + bytes::{Bytes, BytesMut}, config::HttpServiceConfig, date::DateTime, h1::{ - body::{BodySender, RequestBody}, + body::RequestBody, error::Error, + proto::{buf_write::H1BufWrite, error::ProtoError}, }, - http::{ - StatusCode, - response::{Parts, Response}, - }, - util::{ - buffered::{BufferedIo, ListWriteBuf, ReadBuf, WriteBuf}, - timer::{KeepAlive, Timeout}, - }, + http::{StatusCode, response::Response}, + util::timer::{KeepAlive, Timeout}, }; use super::proto::{ - buf_write::H1BufWrite, codec::{ChunkResult, TransferCoding}, context::Context, - encode::CONTINUE, - error::ProtoError, + encode::CONTINUE_BYTES, }; type ExtRequest = crate::http::Request>; -/// function to generic over different writer buffer types dispatcher. 
-pub(crate) async fn run< - 'a, - St, - S, - ReqB, - ResB, - BE, - D, - const HEADER_LIMIT: usize, - const READ_BUF_LIMIT: usize, - const WRITE_BUF_LIMIT: usize, ->( - io: &'a mut St, - addr: SocketAddr, - timer: Pin<&'a mut KeepAlive>, - config: HttpServiceConfig, - service: &'a S, - date: &'a D, -) -> Result<(), Error> -where - S: Service, Response = Response>, - ReqB: From, - ResB: Stream>, - St: AsyncIo, - D: DateTime, -{ - let write_buf = if config.vectored_write && io.is_vectored_write() { - EitherBuf::Left(ListWriteBuf::<_, WRITE_BUF_LIMIT>::default()) - } else { - EitherBuf::Right(WriteBuf::::default()) - }; - - Dispatcher::new(io, addr, timer, config, service, date, write_buf) - .run() - .await -} - /// Http/1 dispatcher -struct Dispatcher<'a, St, S, ReqB, W, D, const HEADER_LIMIT: usize, const READ_BUF_LIMIT: usize> { - io: BufferedIo<'a, St, W, READ_BUF_LIMIT>, +pub struct Dispatcher<'a, Io, S, ReqB, D, const H_LIMIT: usize, const R_LIMIT: usize, const W_LIMIT: usize> { + io: SharedIo, timer: Timer<'a>, - ctx: Context<'a, D, HEADER_LIMIT>, + ctx: Context<'a, D, H_LIMIT>, service: &'a S, _phantom: PhantomData, } -// timer state is transformed in following order: -// -// Idle (expecting keep-alive duration) <-- -// | | -// --> Wait (expecting request head duration) | -// | | -// --> Throttle (expecting manually set to Idle again) -enum TimerState { - Idle, - Wait, - Throttle, -} +trait BufIo { + fn read(self, io: &impl AsyncBufRead) -> impl Future, Self)>; -pub(super) struct Timer<'a> { - timer: Pin<&'a mut KeepAlive>, - state: TimerState, - ka_dur: Duration, - req_dur: Duration, + fn write(self, io: &impl AsyncBufWrite) -> impl Future, Self)>; } -impl<'a> Timer<'a> { - pub(super) fn new(timer: Pin<&'a mut KeepAlive>, ka_dur: Duration, req_dur: Duration) -> Self { - Self { - timer, - state: TimerState::Idle, - ka_dur, - req_dur, - } - } +impl BufIo for BytesMut { + async fn read(mut self, io: &impl AsyncBufRead) -> (io::Result, Self) { + let len = 
self.len(); - pub(super) fn reset_state(&mut self) { - self.state = TimerState::Idle; - } + self.reserve(4096); - pub(super) fn get(&mut self) -> Pin<&mut KeepAlive> { - self.timer.as_mut() + let (res, buf) = io.read(self.slice(len..)).await; + (res, buf.into_inner()) } - // update timer with a given base instant value. the final deadline is calculated base on it. - pub(super) fn update(&mut self, now: tokio::time::Instant) { - let dur = match self.state { - TimerState::Idle => { - self.state = TimerState::Wait; - self.ka_dur - } - TimerState::Wait => { - self.state = TimerState::Throttle; - self.req_dur - } - TimerState::Throttle => return, - }; - self.timer.as_mut().update(now + dur) - } - - #[cold] - #[inline(never)] - pub(super) fn map_to_err(&self) -> Error { - match self.state { - TimerState::Wait => Error::KeepAliveExpire, - TimerState::Throttle => Error::RequestTimeout, - TimerState::Idle => unreachable!(), - } + async fn write(self, io: &impl AsyncBufWrite) -> (io::Result<()>, Self) { + let (res, mut buf) = write_all(io, self).await; + buf.clear(); + (res, buf) } } -impl<'a, St, S, ReqB, ResB, BE, W, D, const HEADER_LIMIT: usize, const READ_BUF_LIMIT: usize> - Dispatcher<'a, St, S, ReqB, W, D, HEADER_LIMIT, READ_BUF_LIMIT> +impl<'a, Io, S, ReqB, ResB, BE, D, const H_LIMIT: usize, const R_LIMIT: usize, const W_LIMIT: usize> + Dispatcher<'a, Io, S, ReqB, D, H_LIMIT, R_LIMIT, W_LIMIT> where + Io: AsyncBufRead + AsyncBufWrite + 'static, S: Service, Response = Response>, ReqB: From, ResB: Stream>, - St: AsyncIo, - W: H1BufWrite, D: DateTime, { - fn new( - io: &'a mut St, + pub async fn run( + io: Io, addr: SocketAddr, timer: Pin<&'a mut KeepAlive>, - config: HttpServiceConfig, + config: HttpServiceConfig, service: &'a S, date: &'a D, - write_buf: W, - ) -> Self { - Self { - io: BufferedIo::new(io, write_buf), + ) -> Result<(), Error> { + let mut dispatcher = Dispatcher::<_, _, _, _, H_LIMIT, R_LIMIT, W_LIMIT> { + io: SharedIo::new(io), timer: Timer::new(timer, 
config.keep_alive_timeout, config.request_head_timeout), ctx: Context::with_addr(addr, date), service, _phantom: PhantomData, - } - } + }; + + let mut read_buf = BytesMut::new(); + let mut write_buf = BytesMut::new(); - async fn run(mut self) -> Result<(), Error> { loop { - if let Err(err) = self._run().await { - handle_error(&mut self.ctx, &mut self.io.write_buf, err)?; + let (res, r_buf, w_buf) = dispatcher._run(read_buf, write_buf).await; + read_buf = r_buf; + write_buf = w_buf; + if let Err(err) = res { + handle_error(&mut dispatcher.ctx, &mut write_buf, err)?; } - // TODO: add timeout for drain write? - self.io.drain_write().await?; + let (res, w_buf) = write_buf.write(dispatcher.io.io()).await; + write_buf = w_buf; + + res?; - if self.ctx.is_connection_closed() { - return self.io.shutdown().await.map_err(Into::into); + if dispatcher.ctx.is_connection_closed() { + return dispatcher.shutdown().await; } } } - async fn _run(&mut self) -> Result<(), Error> { + async fn _run( + &mut self, + mut read_buf: BytesMut, + mut write_buf: BytesMut, + ) -> (Result<(), Error>, BytesMut, BytesMut) { self.timer.update(self.ctx.date().now()); - let read = self - .io - .read() - .timeout(self.timer.get()) - .await - .map_err(|_| self.timer.map_to_err())??; - - if read == 0 { - self.ctx.set_close(); - return Ok(()); + + match read_buf.read(self.io.io()).timeout(self.timer.get()).await { + Ok((res, r_buf)) => { + read_buf = r_buf; + match res { + Ok(read) => { + if read == 0 { + self.ctx.set_close(); + return (Ok(()), read_buf, write_buf); + } + } + Err(e) => return (Err(e.into()), read_buf, write_buf), + } + } + // read_buf is lost during timeout cancel. make an empty new one instead + Err(_) => return (Err(self.timer.map_to_err()), BytesMut::new(), write_buf), } - while let Some((req, decoder)) = self.ctx.decode_head::(&mut self.io.read_buf)? 
{ + loop { + let (req, decoder) = match self.ctx.decode_head::(&mut read_buf) { + Ok(Some(req)) => req, + Ok(None) => break, + Err(e) => return (Err(e.into()), read_buf, write_buf), + }; + self.timer.reset_state(); - let (mut body_reader, body) = BodyReader::from_coding(decoder); + let (wait_for_notify, body) = if decoder.is_eof() { + (false, RequestBody::none()) + } else { + let body = body( + self.io.notifier(), + self.ctx.is_expect_header(), + R_LIMIT, + decoder, + read_buf.split(), + ); + + (true, body) + }; + let req = req.map(|ext| ext.map_body(|_| ReqB::from(body))); - let (parts, body) = match self - .service - .call(req) - .select(self.request_body_handler(&mut body_reader)) - .await - { - SelectOutput::A(Ok(res)) => res.into_parts(), - SelectOutput::A(Err(e)) => return Err(Error::Service(e)), - SelectOutput::B(Err(e)) => return Err(e), - SelectOutput::B(Ok(i)) => match i {}, + let (parts, body) = match self.service.call(req).await { + Ok(res) => res.into_parts(), + Err(e) => return (Err(Error::Service(e)), read_buf, write_buf), }; - let encoder = &mut self.encode_head(parts, &body)?; - let mut body = pin!(body); - - loop { - match self - .try_poll_body(body.as_mut()) - .select(self.io_ready(&mut body_reader)) - .await - { - SelectOutput::A(Some(Ok(bytes))) => encoder.encode(bytes, &mut self.io.write_buf), - SelectOutput::B(Ok(ready)) => { - if ready.is_readable() { - match self.io.try_read() { - Ok(Some(0)) => body_reader.feed_error(io::ErrorKind::UnexpectedEof.into()), - Ok(_) => {} - Err(e) => body_reader.feed_error(e), + let mut encoder = match self.ctx.encode_head(parts, &body, &mut write_buf) { + Ok(encoder) => encoder, + Err(e) => return (Err(e.into()), read_buf, write_buf), + }; + + // this block is necessary. ResB has to be dropped asap as it may hold ownership of + // Body type which if not dropped before Notifier::notify is called would prevent + // Notifier from waking up Notify. 
+ { + let mut body = pin!(body); + + loop { + let res = poll_fn(|cx| match body.as_mut().poll_next(cx) { + Poll::Ready(res) => Poll::Ready(SelectOutput::A(res)), + Poll::Pending if write_buf.is_empty() => Poll::Pending, + Poll::Pending => Poll::Ready(SelectOutput::B(())), + }) + .await; + + match res { + SelectOutput::A(Some(Ok(bytes))) => { + encoder.encode(bytes, &mut write_buf); + if write_buf.len() < W_LIMIT { + continue; } } - if ready.is_writable() { - self.io.try_write()?; + SelectOutput::A(Some(Err(e))) => { + let (res, w_buf) = self.on_body_error(e, write_buf).await; + write_buf = w_buf; + return (res, read_buf, write_buf); } + SelectOutput::A(None) => break encoder.encode_eof(&mut write_buf), + SelectOutput::B(_) => {} } - SelectOutput::A(None) => { - encoder.encode_eof(&mut self.io.write_buf); - break; + + let (res, w_buf) = write_buf.write(self.io.io()).await; + write_buf = w_buf; + if let Err(e) = res { + return (Err(e.into()), read_buf, write_buf); } - SelectOutput::B(Err(e)) => return Err(e.into()), - SelectOutput::A(Some(Err(e))) => return Err(Error::Body(e)), } } - if !body_reader.decoder.is_eof() { - self.ctx.set_close(); - break; + if wait_for_notify { + match self.io.wait().await { + Some(r_buf) => read_buf = r_buf, + None => { + self.ctx.set_close(); + break; + } + } } } - Ok(()) + (Ok(()), read_buf, write_buf) } - fn encode_head(&mut self, parts: Parts, body: &impl Stream) -> Result { - self.ctx.encode_head(parts, body, &mut self.io.write_buf) + #[cold] + #[inline(never)] + async fn shutdown(self) -> Result<(), Error> { + self.io.io().shutdown(Shutdown::Both).await.map_err(Into::into) } - // an associated future of self.service that runs until service is resolved or error produced. - async fn request_body_handler(&mut self, body_reader: &mut BodyReader) -> Result> { - if self.ctx.is_expect_header() { - // wait for service future to start polling RequestBody. 
- if body_reader.wait_for_poll().await.is_ok() { - // encode continue as service future want a body. - self.io.write_buf.write_buf_static(CONTINUE); - // use drain write to make sure continue is sent to client. - self.io.drain_write().await?; - } + #[cold] + #[inline(never)] + async fn on_body_error(&mut self, e: BE, write_buf: BytesMut) -> (Result<(), Error>, BytesMut) { + let (res, write_buf) = write_buf.write(self.io.io()).await; + let e = res.err().map(Error::from).unwrap_or(Error::Body(e)); + (Err(e), write_buf) + } +} + +fn body( + io: NotifierIo, + is_expect: bool, + limit: usize, + decoder: TransferCoding, + read_buf: BytesMut, +) -> RequestBody +where + Io: AsyncBufRead + AsyncBufWrite + 'static, +{ + let body = BodyInner { + io, + decoder: Decoder { + decoder, + limit, + read_buf, + }, + }; + + let state = if is_expect { + State::ExpectWrite { + fut: async { + let (res, _) = write_all(body.io.io(), CONTINUE_BYTES).await; + res.map(|_| body) + }, } + } else { + State::Body { body } + }; + + RequestBody::stream(BodyReader { chunk_read, state }) +} + +pin_project! { + #[project = StateProj] + #[project_replace = StateProjReplace] + enum State { + Body { + body: BodyInner + }, + ChunkRead { + #[pin] + fut: FutC + }, + ExpectWrite { + #[pin] + fut: FutE, + }, + None, + } +} + +pin_project! 
{ + struct BodyReader { + chunk_read: F, + #[pin] + state: State + } +} + +struct BodyInner { + io: NotifierIo, + decoder: Decoder, +} + +async fn chunk_read(mut body: BodyInner) -> io::Result<(usize, BodyInner)> +where + Io: AsyncBufRead, +{ + let (res, r_buf) = body.decoder.read_buf.split().read(body.io.io()).await; + body.decoder.read_buf.unsplit(r_buf); + let read = res?; + Ok((read, body)) +} + +impl Stream for BodyReader +where + Io: AsyncBufRead, + F: Fn(BodyInner) -> FutC, + FutC: Future)>>, + FutE: Future>>, +{ + type Item = io::Result; + #[inline] + fn poll_next(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll> { + let mut this = self.project(); loop { - body_reader.ready(&mut self.io.read_buf).await; - if self.io.read().await? == 0 { - body_reader.feed_error(io::ErrorKind::UnexpectedEof.into()); + match this.state.as_mut().project() { + StateProj::Body { body } => { + match body.decoder.decode() { + ChunkResult::Ok(bytes) => return Poll::Ready(Some(Ok(bytes))), + ChunkResult::Err(e) => return Poll::Ready(Some(Err(e))), + ChunkResult::InsufficientData => body.decoder.limit_check()?, + _ => return Poll::Ready(None), + } + + let StateProjReplace::Body { body } = this.state.as_mut().project_replace(State::None) else { + unreachable!() + }; + this.state.as_mut().project_replace(State::ChunkRead { + fut: (this.chunk_read)(body), + }); + } + StateProj::ChunkRead { fut } => { + let (read, body) = ready!(fut.poll(cx))?; + if read == 0 { + this.state.as_mut().project_replace(State::None); + return Poll::Ready(None); + } + this.state.as_mut().project_replace(State::Body { body }); + } + StateProj::ExpectWrite { fut } => { + let body = ready!(fut.poll(cx))?; + this.state.as_mut().project_replace(State::ChunkRead { + fut: (this.chunk_read)(body), + }); + } + StateProj::None => return Poll::Ready(None), } } } +} - fn try_poll_body<'b>(&self, mut body: Pin<&'b mut ResB>) -> impl Future>> + 'b { - let want_buf = self.io.write_buf.want_write_buf(); - async move 
{ - if want_buf { - poll_fn(|cx| body.as_mut().poll_next(cx)).await - } else { - pending().await +impl Drop for BodyInner { + fn drop(&mut self) { + if self.decoder.decoder.is_eof() { + let buf = mem::take(&mut self.decoder.read_buf); + self.io.notify(buf); + } + } +} + +struct Decoder { + decoder: TransferCoding, + limit: usize, + read_buf: BytesMut, +} + +impl Decoder { + fn decode(&mut self) -> ChunkResult { + self.decoder.decode(&mut self.read_buf) + } + + fn limit_check(&self) -> io::Result<()> { + if self.read_buf.len() < self.limit { + return Ok(()); + } + + let msg = format!( + "READ_BUF_LIMIT reached: {{ limit: {}, length: {} }}", + self.limit, + self.read_buf.len() + ); + Err(io::Error::other(msg)) + } +} + +struct SharedIo { + inner: Rc<_SharedIo>, +} + +struct _SharedIo { + io: Io, + notify: RefCell>, +} + +impl SharedIo { + fn new(io: Io) -> Self { + Self { + inner: Rc::new(_SharedIo { + io, + notify: RefCell::new(Inner { waker: None, val: None }), + }), + } + } + + #[inline(always)] + fn io(&self) -> &Io { + &self.inner.io + } + + fn notifier(&mut self) -> NotifierIo { + NotifierIo { + inner: self.inner.clone(), + } + } + + fn wait(&mut self) -> impl Future> { + poll_fn(|cx| { + let mut inner = self.inner.notify.borrow_mut(); + if let Some(val) = inner.val.take() { + return Poll::Ready(Some(val)); + } else if Rc::strong_count(&self.inner) == 1 { + return Poll::Ready(None); } + inner.waker = Some(cx.waker().clone()); + Poll::Pending + }) + } +} + +struct NotifierIo { + inner: Rc<_SharedIo>, +} + +impl Drop for NotifierIo { + fn drop(&mut self) { + if let Some(waker) = self.inner.notify.borrow_mut().waker.take() { + waker.wake(); } } +} - // Check readable and writable state of BufferedIo and ready state of request body reader. - // return error when runtime is shutdown.(See AsyncIo::ready for reason). 
- async fn io_ready(&mut self, body_reader: &mut BodyReader) -> io::Result { - if !self.io.write_buf.want_write_io() { - body_reader.ready(&mut self.io.read_buf).await; - self.io.io.ready(Interest::READABLE).await - } else { - match body_reader - .ready(&mut self.io.read_buf) - .select(self.io.io.ready(Interest::WRITABLE)) - .await - { - SelectOutput::A(_) => self.io.io.ready(Interest::READABLE | Interest::WRITABLE).await, - SelectOutput::B(res) => res, +impl NotifierIo { + fn io(&self) -> &Io { + &self.inner.io + } + + fn notify(&mut self, val: BytesMut) { + self.inner.notify.borrow_mut().val = Some(val); + } +} + +struct Inner { + waker: Option, + val: Option, +} + +// timer state is transformed in following order: +// +// Idle (expecting keep-alive duration) <-- +// | | +// --> Wait (expecting request head duration) | +// | | +// --> Throttle (expecting manually set to Idle again) +enum TimerState { + Idle, + Wait, + Throttle, +} + +struct Timer<'a> { + timer: Pin<&'a mut KeepAlive>, + state: TimerState, + ka_dur: Duration, + req_dur: Duration, +} + +impl<'a> Timer<'a> { + fn new(timer: Pin<&'a mut KeepAlive>, ka_dur: Duration, req_dur: Duration) -> Self { + Self { + timer, + state: TimerState::Idle, + ka_dur, + req_dur, + } + } + + fn reset_state(&mut self) { + self.state = TimerState::Idle; + } + + fn get(&mut self) -> Pin<&mut KeepAlive> { + self.timer.as_mut() + } + + // update timer with a given base instant value. the final deadline is calculated base on it. 
+ fn update(&mut self, now: tokio::time::Instant) { + let dur = match self.state { + TimerState::Idle => { + self.state = TimerState::Wait; + self.ka_dur + } + TimerState::Wait => { + self.state = TimerState::Throttle; + self.req_dur } + TimerState::Throttle => return, + }; + self.timer.as_mut().update(now + dur) + } + + #[cold] + #[inline(never)] + fn map_to_err(&self) -> Error { + match self.state { + TimerState::Wait => Error::KeepAliveExpire, + TimerState::Throttle => Error::RequestTimeout, + TimerState::Idle => unreachable!(), } } } @@ -359,51 +573,3 @@ where } Ok(()) } - -pub(super) struct BodyReader { - pub(super) decoder: TransferCoding, - tx: BodySender, -} - -impl BodyReader { - pub(super) fn from_coding(decoder: TransferCoding) -> (Self, RequestBody) { - let (tx, body) = RequestBody::channel(decoder.is_eof()); - let body_reader = BodyReader { decoder, tx }; - (body_reader, body) - } - - // dispatcher MUST call this method before do any io reading. - // a none ready state means the body consumer either is in backpressure or don't expect body. - pub(super) async fn ready(&mut self, read_buf: &mut ReadBuf) { - loop { - match self.decoder.decode(&mut *read_buf) { - ChunkResult::Ok(bytes) => self.tx.feed_data(bytes), - ChunkResult::InsufficientData => match self.tx.ready().await { - Ok(_) => return, - // service future drop RequestBody so marker decoder to corrupted. - Err(_) => self.decoder.set_corrupted(), - }, - ChunkResult::OnEof => self.tx.feed_eof(), - ChunkResult::AlreadyEof | ChunkResult::Corrupted => pending().await, - ChunkResult::Err(e) => self.feed_error(e), - } - } - } - - // feed error to body sender and prepare for close connection. - #[cold] - #[inline(never)] - pub(super) fn feed_error(&mut self, e: io::Error) { - self.tx.feed_error(e); - self.decoder.set_corrupted(); - } - - // wait for service start to consume RequestBody. 
- pub(super) async fn wait_for_poll(&mut self) -> io::Result<()> { - // IMPORTANT: service future drop RequestBody so marker decoder to corrupted. - self.tx - .wait_for_poll() - .await - .inspect_err(|_| self.decoder.set_corrupted()) - } -} diff --git a/http/src/h1/dispatcher_compio.rs b/http/src/h1/dispatcher_compio.rs deleted file mode 100644 index a40eeb77f..000000000 --- a/http/src/h1/dispatcher_compio.rs +++ /dev/null @@ -1,451 +0,0 @@ -use core::{ - cell::RefCell, - future::poll_fn, - marker::PhantomData, - mem, - net::SocketAddr, - pin::{Pin, pin}, - task::{self, Poll, Waker, ready}, -}; - -use std::{io, rc::Rc}; - -use compio_buf::{BufResult, IntoInner, IoBuf}; -use compio_io::{AsyncRead, AsyncWrite, AsyncWriteExt}; -use compio_net::TcpStream; -use futures_core::stream::Stream; -use pin_project_lite::pin_project; -use xitca_service::Service; -use xitca_unsafe_collection::futures::SelectOutput; - -use crate::{ - bytes::{Bytes, BytesMut}, - date::DateTime, - h1::{body::RequestBody, error::Error}, - http::response::Response, -}; - -use super::{ - dispatcher::handle_error, - proto::{ - codec::{ChunkResult, TransferCoding}, - context::Context, - encode::CONTINUE_BYTES, - }, -}; - -type ExtRequest = crate::http::Request>; - -/// Http/1 dispatcher -pub struct Dispatcher<'a, S, ReqB, D, const H_LIMIT: usize, const R_LIMIT: usize, const W_LIMIT: usize> { - io: SharedIo, - ctx: Context<'a, D, H_LIMIT>, - service: &'a S, - _phantom: PhantomData, -} - -trait BufIo { - fn read(self, io: &TcpStream) -> impl Future, Self)>; - - fn write(self, io: &TcpStream) -> impl Future, Self)>; -} - -impl BufIo for BytesMut { - async fn read(mut self, mut io: &TcpStream) -> (io::Result, Self) { - let len = self.len(); - - self.reserve(4096); - - let BufResult(res, buf) = (&mut io).read(self.slice(len..)).await; - (res, buf.into_inner()) - } - - async fn write(self, mut io: &TcpStream) -> (io::Result<()>, Self) { - let BufResult(res, mut buf) = (&mut io).write_all(self).await; - 
buf.clear(); - (res, buf) - } -} - -impl<'a, S, ReqB, ResB, BE, D, const H_LIMIT: usize, const R_LIMIT: usize, const W_LIMIT: usize> - Dispatcher<'a, S, ReqB, D, H_LIMIT, R_LIMIT, W_LIMIT> -where - S: Service, Response = Response>, - ReqB: From, - ResB: Stream>, - D: DateTime, -{ - pub async fn run(io: TcpStream, addr: SocketAddr, service: &'a S, date: &'a D) -> Result<(), Error> { - let mut dispatcher = Dispatcher::<_, _, _, H_LIMIT, R_LIMIT, W_LIMIT> { - io: SharedIo::new(io), - ctx: Context::with_addr(addr, date), - service, - _phantom: PhantomData, - }; - - let mut read_buf = BytesMut::new(); - let mut write_buf = BytesMut::new(); - - loop { - let (res, r_buf, w_buf) = dispatcher._run(read_buf, write_buf).await; - read_buf = r_buf; - write_buf = w_buf; - if let Err(err) = res { - handle_error(&mut dispatcher.ctx, &mut write_buf, err)?; - } - - let (res, w_buf) = write_buf.write(dispatcher.io.io()).await; - write_buf = w_buf; - - res?; - - if dispatcher.ctx.is_connection_closed() { - return dispatcher.shutdown().await; - } - } - } - - async fn _run( - &mut self, - mut read_buf: BytesMut, - mut write_buf: BytesMut, - ) -> (Result<(), Error>, BytesMut, BytesMut) { - let (res, r_buf) = read_buf.read(self.io.io()).await; - read_buf = r_buf; - match res { - Ok(read) => { - if read == 0 { - self.ctx.set_close(); - return (Ok(()), read_buf, write_buf); - } - } - Err(e) => return (Err(e.into()), read_buf, write_buf), - } - - loop { - let (req, decoder) = match self.ctx.decode_head::(&mut read_buf) { - Ok(Some(req)) => req, - Ok(None) => break, - Err(e) => return (Err(e.into()), read_buf, write_buf), - }; - - let (wait_for_notify, body) = if decoder.is_eof() { - (false, RequestBody::none()) - } else { - let body = body( - self.io.notifier(), - self.ctx.is_expect_header(), - R_LIMIT, - decoder, - read_buf.split(), - ); - - (true, body) - }; - - let req = req.map(|ext| ext.map_body(|_| ReqB::from(body))); - - let (parts, body) = match self.service.call(req).await { - 
Ok(res) => res.into_parts(), - Err(e) => return (Err(Error::Service(e)), read_buf, write_buf), - }; - - let mut encoder = match self.ctx.encode_head(parts, &body, &mut write_buf) { - Ok(encoder) => encoder, - Err(e) => return (Err(e.into()), read_buf, write_buf), - }; - - // this block is necessary. ResB has to be dropped asap as it may hold ownership of - // Body type which if not dropped before Notifier::notify is called would prevent - // Notifier from waking up Notify. - { - let mut body = pin!(body); - - loop { - let res = poll_fn(|cx| match body.as_mut().poll_next(cx) { - Poll::Ready(res) => Poll::Ready(SelectOutput::A(res)), - Poll::Pending if write_buf.is_empty() => Poll::Pending, - Poll::Pending => Poll::Ready(SelectOutput::B(())), - }) - .await; - - match res { - SelectOutput::A(Some(Ok(bytes))) => { - encoder.encode(bytes, &mut write_buf); - if write_buf.len() < W_LIMIT { - continue; - } - } - SelectOutput::A(Some(Err(e))) => { - let (res, w_buf) = self.on_body_error(e, write_buf).await; - write_buf = w_buf; - return (res, read_buf, write_buf); - } - SelectOutput::A(None) => break encoder.encode_eof(&mut write_buf), - SelectOutput::B(_) => {} - } - - let (res, w_buf) = write_buf.write(self.io.io()).await; - write_buf = w_buf; - if let Err(e) = res { - return (Err(e.into()), read_buf, write_buf); - } - } - } - - if wait_for_notify { - match self.io.wait().await { - Some(r_buf) => read_buf = r_buf, - None => { - self.ctx.set_close(); - break; - } - } - } - } - - (Ok(()), read_buf, write_buf) - } - - #[cold] - #[inline(never)] - async fn shutdown(self) -> Result<(), Error> { - self.io.take().shutdown().await.map_err(Into::into) - } - - #[cold] - #[inline(never)] - async fn on_body_error(&mut self, e: BE, write_buf: BytesMut) -> (Result<(), Error>, BytesMut) { - let (res, write_buf) = write_buf.write(self.io.io()).await; - let e = res.err().map(Error::from).unwrap_or(Error::Body(e)); - (Err(e), write_buf) - } -} - -fn body( - io: NotifierIo, - is_expect: 
bool, - limit: usize, - decoder: TransferCoding, - read_buf: BytesMut, -) -> RequestBody { - let body = BodyInner { - io, - decoder: Decoder { - decoder, - limit, - read_buf, - }, - }; - - let state = if is_expect { - State::ExpectWrite { - fut: async { - let BufResult(res, _) = body.io.io().write_all(BytesMut::from(CONTINUE_BYTES)).await; - res.map(|_| body) - }, - } - } else { - State::Body { body } - }; - - RequestBody::stream(BodyReader { chunk_read, state }) -} - -pin_project! { - #[project = StateProj] - #[project_replace = StateProjReplace] - enum State { - Body { - body: BodyInner - }, - ChunkRead { - #[pin] - fut: FutC - }, - ExpectWrite { - #[pin] - fut: FutE, - }, - None, - } -} - -pin_project! { - struct BodyReader< F, FutC, FutE> { - chunk_read: F, - #[pin] - state: State - } -} - -struct BodyInner { - io: NotifierIo, - decoder: Decoder, -} - -async fn chunk_read(mut body: BodyInner) -> io::Result<(usize, BodyInner)> { - let (res, r_buf) = body.decoder.read_buf.split().read(body.io.io()).await; - body.decoder.read_buf.unsplit(r_buf); - let read = res?; - Ok((read, body)) -} - -impl Stream for BodyReader -where - F: Fn(BodyInner) -> FutC, - FutC: Future>, - FutE: Future>, -{ - type Item = io::Result; - - #[inline] - fn poll_next(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll> { - let mut this = self.project(); - loop { - match this.state.as_mut().project() { - StateProj::Body { body } => { - match body.decoder.decode() { - ChunkResult::Ok(bytes) => return Poll::Ready(Some(Ok(bytes))), - ChunkResult::Err(e) => return Poll::Ready(Some(Err(e))), - ChunkResult::InsufficientData => body.decoder.limit_check()?, - _ => return Poll::Ready(None), - } - - let StateProjReplace::Body { body } = this.state.as_mut().project_replace(State::None) else { - unreachable!() - }; - this.state.as_mut().project_replace(State::ChunkRead { - fut: (this.chunk_read)(body), - }); - } - StateProj::ChunkRead { fut } => { - let (read, body) = ready!(fut.poll(cx))?; - if 
read == 0 { - this.state.as_mut().project_replace(State::None); - return Poll::Ready(None); - } - this.state.as_mut().project_replace(State::Body { body }); - } - StateProj::ExpectWrite { fut } => { - let body = ready!(fut.poll(cx))?; - this.state.as_mut().project_replace(State::ChunkRead { - fut: (this.chunk_read)(body), - }); - } - StateProj::None => return Poll::Ready(None), - } - } - } -} - -impl Drop for BodyInner { - fn drop(&mut self) { - if self.decoder.decoder.is_eof() { - let buf = mem::take(&mut self.decoder.read_buf); - self.io.notify(buf); - } - } -} - -struct Decoder { - decoder: TransferCoding, - limit: usize, - read_buf: BytesMut, -} - -impl Decoder { - fn decode(&mut self) -> ChunkResult { - self.decoder.decode(&mut self.read_buf) - } - - fn limit_check(&self) -> io::Result<()> { - if self.read_buf.len() < self.limit { - return Ok(()); - } - - let msg = format!( - "READ_BUF_LIMIT reached: {{ limit: {}, length: {} }}", - self.limit, - self.read_buf.len() - ); - Err(io::Error::other(msg)) - } -} - -struct SharedIo { - inner: Rc<_SharedIo>, -} - -struct _SharedIo { - io: Io, - notify: RefCell>, -} - -impl SharedIo { - fn new(io: Io) -> Self { - Self { - inner: Rc::new(_SharedIo { - io, - notify: RefCell::new(Inner { waker: None, val: None }), - }), - } - } - - fn take(self) -> Io { - Rc::try_unwrap(self.inner) - .ok() - .expect("SharedIo must have exclusive ownership to Io when closing connection") - .io - } - - fn io(&self) -> &Io { - &self.inner.io - } - - fn notifier(&mut self) -> NotifierIo { - NotifierIo { - inner: self.inner.clone(), - } - } - - fn wait(&mut self) -> impl Future> { - poll_fn(|cx| { - let mut inner = self.inner.notify.borrow_mut(); - if let Some(val) = inner.val.take() { - return Poll::Ready(Some(val)); - } else if Rc::strong_count(&self.inner) == 1 { - return Poll::Ready(None); - } - inner.waker = Some(cx.waker().clone()); - Poll::Pending - }) - } -} - -struct NotifierIo { - inner: Rc<_SharedIo>, -} - -impl Drop for NotifierIo { 
- fn drop(&mut self) { - if let Some(waker) = self.inner.notify.borrow_mut().waker.take() { - waker.wake(); - } - } -} - -impl NotifierIo { - fn io(&self) -> &Io { - &self.inner.io - } - - fn notify(&mut self, val: BytesMut) { - self.inner.notify.borrow_mut().val = Some(val); - } -} - -struct Inner { - waker: Option, - val: Option, -} diff --git a/http/src/h1/dispatcher_uring.rs b/http/src/h1/dispatcher_uring.rs deleted file mode 100644 index 1d6f2d3d6..000000000 --- a/http/src/h1/dispatcher_uring.rs +++ /dev/null @@ -1,456 +0,0 @@ -use core::{ - cell::RefCell, - future::poll_fn, - marker::PhantomData, - mem, - net::SocketAddr, - pin::{Pin, pin}, - task::{self, Poll, Waker, ready}, -}; - -use std::{io, net::Shutdown, rc::Rc}; - -use futures_core::stream::Stream; -use pin_project_lite::pin_project; -use xitca_io::io_uring::{AsyncBufRead, AsyncBufWrite, BoundedBuf, write_all}; -use xitca_service::Service; -use xitca_unsafe_collection::futures::SelectOutput; - -use crate::{ - bytes::{Bytes, BytesMut}, - config::HttpServiceConfig, - date::DateTime, - h1::{body::RequestBody, error::Error}, - http::response::Response, - util::timer::{KeepAlive, Timeout}, -}; - -use super::{ - dispatcher::{Timer, handle_error}, - proto::{ - codec::{ChunkResult, TransferCoding}, - context::Context, - encode::CONTINUE_BYTES, - }, -}; - -type ExtRequest = crate::http::Request>; - -/// Http/1 dispatcher -pub(super) struct Dispatcher<'a, Io, S, ReqB, D, const H_LIMIT: usize, const R_LIMIT: usize, const W_LIMIT: usize> { - io: Io, - notify: Notify, - timer: Timer<'a>, - ctx: Context<'a, D, H_LIMIT>, - service: &'a S, - _phantom: PhantomData, -} - -trait BufIo { - fn read(self, io: &impl AsyncBufRead) -> impl Future, Self)>; - - fn write(self, io: &impl AsyncBufWrite) -> impl Future, Self)>; -} - -impl BufIo for BytesMut { - async fn read(mut self, io: &impl AsyncBufRead) -> (io::Result, Self) { - let len = self.len(); - - self.reserve(4096); - - let (res, buf) = 
io.read(self.slice(len..)).await; - (res, buf.into_inner()) - } - - async fn write(self, io: &impl AsyncBufWrite) -> (io::Result<()>, Self) { - let (res, mut buf) = write_all(io, self).await; - buf.clear(); - (res, buf) - } -} - -impl<'a, Io, S, ReqB, ResB, BE, D, const H_LIMIT: usize, const R_LIMIT: usize, const W_LIMIT: usize> - Dispatcher<'a, Io, S, ReqB, D, H_LIMIT, R_LIMIT, W_LIMIT> -where - Io: AsyncBufRead + AsyncBufWrite + Clone + 'static, - S: Service, Response = Response>, - ReqB: From, - ResB: Stream>, - D: DateTime, -{ - pub(super) async fn run( - io: Io, - addr: SocketAddr, - timer: Pin<&'a mut KeepAlive>, - config: HttpServiceConfig, - service: &'a S, - date: &'a D, - ) -> Result<(), Error> { - let mut dispatcher = Dispatcher::<_, _, _, _, H_LIMIT, R_LIMIT, W_LIMIT> { - io, - notify: Notify::default(), - timer: Timer::new(timer, config.keep_alive_timeout, config.request_head_timeout), - ctx: Context::with_addr(addr, date), - service, - _phantom: PhantomData, - }; - - let mut read_buf = BytesMut::new(); - let mut write_buf = BytesMut::new(); - - loop { - let (res, r_buf, w_buf) = dispatcher._run(read_buf, write_buf).await; - read_buf = r_buf; - write_buf = w_buf; - if let Err(err) = res { - handle_error(&mut dispatcher.ctx, &mut write_buf, err)?; - } - - let (res, w_buf) = write_buf.write(&dispatcher.io).await; - write_buf = w_buf; - - res?; - - if dispatcher.ctx.is_connection_closed() { - return dispatcher.shutdown(); - } - } - } - - async fn _run( - &mut self, - mut read_buf: BytesMut, - mut write_buf: BytesMut, - ) -> (Result<(), Error>, BytesMut, BytesMut) { - self.timer.update(self.ctx.date().now()); - - match read_buf.read(&self.io).timeout(self.timer.get()).await { - Ok((res, r_buf)) => { - read_buf = r_buf; - match res { - Ok(read) => { - if read == 0 { - self.ctx.set_close(); - return (Ok(()), read_buf, write_buf); - } - } - Err(e) => return (Err(e.into()), read_buf, write_buf), - } - } - // read_buf is lost during timeout cancel. 
make an empty new one instead - Err(_) => return (Err(self.timer.map_to_err()), BytesMut::new(), write_buf), - } - - loop { - let (req, decoder) = match self.ctx.decode_head::(&mut read_buf) { - Ok(Some(req)) => req, - Ok(None) => break, - Err(e) => return (Err(e.into()), read_buf, write_buf), - }; - - self.timer.reset_state(); - - let (wait_for_notify, body) = if decoder.is_eof() { - (false, RequestBody::none()) - } else { - let body = body( - self.io.clone(), - self.notify.notifier(), - self.ctx.is_expect_header(), - R_LIMIT, - decoder, - read_buf.split(), - ); - - (true, body) - }; - - let req = req.map(|ext| ext.map_body(|_| ReqB::from(body))); - - let (parts, body) = match self.service.call(req).await { - Ok(res) => res.into_parts(), - Err(e) => return (Err(Error::Service(e)), read_buf, write_buf), - }; - - let mut encoder = match self.ctx.encode_head(parts, &body, &mut write_buf) { - Ok(encoder) => encoder, - Err(e) => return (Err(e.into()), read_buf, write_buf), - }; - - // this block is necessary. ResB has to be dropped asap as it may hold ownership of - // Body type which if not dropped before Notifier::notify is called would prevent - // Notifier from waking up Notify. 
- { - let mut body = pin!(body); - - loop { - let res = poll_fn(|cx| match body.as_mut().poll_next(cx) { - Poll::Ready(res) => Poll::Ready(SelectOutput::A(res)), - Poll::Pending if write_buf.is_empty() => Poll::Pending, - Poll::Pending => Poll::Ready(SelectOutput::B(())), - }) - .await; - - match res { - SelectOutput::A(Some(Ok(bytes))) => { - encoder.encode(bytes, &mut write_buf); - if write_buf.len() < W_LIMIT { - continue; - } - } - SelectOutput::A(Some(Err(e))) => { - let (res, w_buf) = self.on_body_error(e, write_buf).await; - write_buf = w_buf; - return (res, read_buf, write_buf); - } - SelectOutput::A(None) => break encoder.encode_eof(&mut write_buf), - SelectOutput::B(_) => {} - } - - let (res, w_buf) = write_buf.write(&self.io).await; - write_buf = w_buf; - if let Err(e) = res { - return (Err(e.into()), read_buf, write_buf); - } - } - } - - if wait_for_notify { - match self.notify.wait().await { - Some(r_buf) => read_buf = r_buf, - None => { - self.ctx.set_close(); - break; - } - } - } - } - - (Ok(()), read_buf, write_buf) - } - - #[cold] - #[inline(never)] - fn shutdown(self) -> Result<(), Error> { - self.io.shutdown(Shutdown::Both).map_err(Into::into) - } - - #[cold] - #[inline(never)] - async fn on_body_error(&mut self, e: BE, write_buf: BytesMut) -> (Result<(), Error>, BytesMut) { - let (res, write_buf) = write_buf.write(&self.io).await; - let e = res.err().map(Error::from).unwrap_or(Error::Body(e)); - (Err(e), write_buf) - } -} - -fn body( - io: Io, - notifier: Notifier, - is_expect: bool, - limit: usize, - decoder: TransferCoding, - read_buf: BytesMut, -) -> RequestBody -where - Io: AsyncBufRead + AsyncBufWrite + 'static, -{ - let body = BodyInner { - io, - notifier, - decoder: Decoder { - decoder, - limit, - read_buf, - }, - }; - - let state = if is_expect { - State::ExpectWrite { - fut: async { - let (res, _) = write_all(&body.io, CONTINUE_BYTES).await; - res.map(|_| body) - }, - } - } else { - State::Body { body } - }; - - 
RequestBody::stream(BodyReader { chunk_read, state }) -} - -pin_project! { - #[project = StateProj] - #[project_replace = StateProjReplace] - enum State { - Body { - body: BodyInner - }, - ChunkRead { - #[pin] - fut: FutC - }, - ExpectWrite { - #[pin] - fut: FutE, - }, - None, - } -} - -pin_project! { - struct BodyReader { - chunk_read: F, - #[pin] - state: State - } -} - -struct BodyInner { - io: Io, - notifier: Notifier, - decoder: Decoder, -} - -async fn chunk_read(mut body: BodyInner) -> io::Result<(usize, BodyInner)> -where - Io: AsyncBufRead, -{ - let (res, r_buf) = body.decoder.read_buf.split().read(&body.io).await; - body.decoder.read_buf.unsplit(r_buf); - let read = res?; - Ok((read, body)) -} - -impl Stream for BodyReader -where - Io: AsyncBufRead, - F: Fn(BodyInner) -> FutC, - FutC: Future)>>, - FutE: Future>>, -{ - type Item = io::Result; - - #[inline] - fn poll_next(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll> { - let mut this = self.project(); - loop { - match this.state.as_mut().project() { - StateProj::Body { body } => { - match body.decoder.decode() { - ChunkResult::Ok(bytes) => return Poll::Ready(Some(Ok(bytes))), - ChunkResult::Err(e) => return Poll::Ready(Some(Err(e))), - ChunkResult::InsufficientData => body.decoder.limit_check()?, - _ => return Poll::Ready(None), - } - - let StateProjReplace::Body { body } = this.state.as_mut().project_replace(State::None) else { - unreachable!() - }; - this.state.as_mut().project_replace(State::ChunkRead { - fut: (this.chunk_read)(body), - }); - } - StateProj::ChunkRead { fut } => { - let (read, body) = ready!(fut.poll(cx))?; - if read == 0 { - this.state.as_mut().project_replace(State::None); - return Poll::Ready(None); - } - this.state.as_mut().project_replace(State::Body { body }); - } - StateProj::ExpectWrite { fut } => { - let body = ready!(fut.poll(cx))?; - this.state.as_mut().project_replace(State::ChunkRead { - fut: (this.chunk_read)(body), - }); - } - StateProj::None => return 
Poll::Ready(None), - } - } - } -} - -impl Drop for BodyInner { - fn drop(&mut self) { - if self.decoder.decoder.is_eof() { - let buf = mem::take(&mut self.decoder.read_buf); - self.notifier.notify(buf); - } - } -} - -struct Decoder { - decoder: TransferCoding, - limit: usize, - read_buf: BytesMut, -} - -impl Decoder { - fn decode(&mut self) -> ChunkResult { - self.decoder.decode(&mut self.read_buf) - } - - fn limit_check(&self) -> io::Result<()> { - if self.read_buf.len() < self.limit { - return Ok(()); - } - - let msg = format!( - "READ_BUF_LIMIT reached: {{ limit: {}, length: {} }}", - self.limit, - self.read_buf.len() - ); - Err(io::Error::other(msg)) - } -} - -#[derive(Default)] -struct Notify { - inner: Rc>>, -} - -struct Notifier { - inner: Rc>>, -} - -impl Notify { - fn notifier(&mut self) -> Notifier { - Notifier { - inner: self.inner.clone(), - } - } - - fn wait(&mut self) -> impl Future> { - poll_fn(|cx| { - let mut inner = self.inner.borrow_mut(); - if let Some(val) = inner.val.take() { - return Poll::Ready(Some(val)); - } else if Rc::strong_count(&self.inner) == 1 { - return Poll::Ready(None); - } - inner.waker = Some(cx.waker().clone()); - Poll::Pending - }) - } -} - -impl Drop for Notifier { - fn drop(&mut self) { - if let Some(waker) = self.inner.borrow_mut().waker.take() { - waker.wake(); - } - } -} - -impl Notifier { - fn notify(&mut self, val: BytesMut) { - self.inner.borrow_mut().val = Some(val); - } -} - -#[derive(Default)] -struct Inner { - waker: Option, - val: Option, -} diff --git a/http/src/h1/mod.rs b/http/src/h1/mod.rs index 1f5375416..3508c8f62 100644 --- a/http/src/h1/mod.rs +++ b/http/src/h1/mod.rs @@ -3,10 +3,7 @@ pub mod dispatcher_unreal; pub mod proto; -#[cfg(feature = "compio")] -pub mod dispatcher_compio; - -pub(crate) mod dispatcher; +pub mod dispatcher; mod body; mod builder; @@ -16,6 +13,3 @@ mod service; pub use self::body::RequestBody; pub use self::error::Error; pub use self::service::H1Service; - -#[cfg(feature = 
"io-uring")] -mod dispatcher_uring; diff --git a/http/src/h1/service.rs b/http/src/h1/service.rs index d1dbcc0e1..647f237af 100644 --- a/http/src/h1/service.rs +++ b/http/src/h1/service.rs @@ -1,7 +1,7 @@ use core::{net::SocketAddr, pin::pin}; use futures_core::stream::Stream; -use xitca_io::io::AsyncIo; +use xitca_io::io::{AsyncBufRead, AsyncBufWrite}; use xitca_service::Service; use crate::{ @@ -22,8 +22,7 @@ impl>, Response = Response>, A: Service, - St: AsyncIo, - A::Response: AsyncIo, + A::Response: AsyncBufRead + AsyncBufWrite + 'static, B: Stream>, HttpServiceError: From, { @@ -34,78 +33,6 @@ where // at this stage keep-alive timer is used to tracks tls accept timeout. let mut timer = pin!(self.keep_alive()); - let mut io = self - .tls_acceptor - .call(io) - .timeout(timer.as_mut()) - .await - .map_err(|_| HttpServiceError::Timeout(TimeoutError::TlsAccept))??; - - super::dispatcher::run(&mut io, addr, timer, self.config, &self.service, self.date.get()) - .await - .map_err(Into::into) - } -} - -#[cfg(feature = "io-uring")] -use { - xitca_io::{ - io_uring::{AsyncBufRead, AsyncBufWrite}, - net::io_uring::TcpStream, - }, - xitca_service::ready::ReadyService, -}; - -#[cfg(feature = "io-uring")] -use crate::{ - config::HttpServiceConfig, - date::{DateTime, DateTimeService}, - util::timer::KeepAlive, -}; - -#[cfg(feature = "io-uring")] -pub struct H1UringService { - pub(crate) config: HttpServiceConfig, - pub(crate) date: DateTimeService, - pub(crate) service: S, - pub(crate) tls_acceptor: A, -} - -#[cfg(feature = "io-uring")] -impl - H1UringService -{ - pub(super) fn new( - config: HttpServiceConfig, - service: S, - tls_acceptor: A, - ) -> Self { - Self { - config, - date: DateTimeService::new(), - service, - tls_acceptor, - } - } -} - -#[cfg(feature = "io-uring")] -impl - Service<(TcpStream, SocketAddr)> for H1UringService -where - S: Service>, Response = Response>, - A: Service, - A::Response: AsyncBufRead + AsyncBufWrite + Clone + 'static, - B: Stream>, - 
HttpServiceError: From, -{ - type Response = (); - type Error = HttpServiceError; - async fn call(&self, (io, addr): (TcpStream, SocketAddr)) -> Result { - let accept_dur = self.config.tls_accept_timeout; - let deadline = self.date.get().now() + accept_dur; - let mut timer = pin!(KeepAlive::new(deadline)); - let io = self .tls_acceptor .call(io) @@ -113,22 +40,8 @@ where .await .map_err(|_| HttpServiceError::Timeout(TimeoutError::TlsAccept))??; - super::dispatcher_uring::Dispatcher::run(io, addr, timer, self.config, &self.service, self.date.get()) + super::dispatcher::Dispatcher::run(io, addr, timer, self.config, &self.service, self.date.get()) .await .map_err(Into::into) } } - -#[cfg(feature = "io-uring")] -impl ReadyService - for H1UringService -where - S: ReadyService, -{ - type Ready = S::Ready; - - #[inline] - async fn ready(&self) -> Self::Ready { - self.service.ready().await - } -} diff --git a/http/src/h2/dispatcher_uring.rs b/http/src/h2/dispatcher_uring.rs index 5c0cacf4f..fd9ca832c 100644 --- a/http/src/h2/dispatcher_uring.rs +++ b/http/src/h2/dispatcher_uring.rs @@ -19,7 +19,7 @@ use futures_core::stream::Stream; use tracing::error; use xitca_io::{ bytes::{Buf, BufMut, BytesMut}, - io_uring::{AsyncBufRead, AsyncBufWrite, BoundedBuf, write_all}, + io::{AsyncBufRead, AsyncBufWrite, BoundedBuf, write_all}, }; use xitca_service::Service; use xitca_unsafe_collection::{ diff --git a/http/src/h2/service.rs b/http/src/h2/service.rs index 9a91471f8..744067c84 100644 --- a/http/src/h2/service.rs +++ b/http/src/h2/service.rs @@ -80,7 +80,7 @@ pub(crate) use io_uring::H2UringService; mod io_uring { use { xitca_io::{ - io_uring::{AsyncBufRead, AsyncBufWrite}, + io::{AsyncBufRead, AsyncBufWrite}, net::io_uring::TcpStream, }, xitca_service::ready::ReadyService, diff --git a/http/src/service.rs b/http/src/service.rs index 703a814bf..feafd1fa5 100644 --- a/http/src/service.rs +++ b/http/src/service.rs @@ -2,7 +2,7 @@ use core::{fmt, marker::PhantomData, pin::pin}; use 
futures_core::Stream; use xitca_io::{ - io::AsyncIo, + io::{AsyncBufRead, AsyncBufWrite, AsyncIo}, net::{Stream as ServerStream, TcpStream}, }; use xitca_service::{Service, ready::ReadyService}; @@ -74,7 +74,7 @@ impl>, Response = Response>, A: Service, - A::Response: AsyncIo + AsVersion, + A::Response: AsyncIo + AsVersion + AsyncBufRead + AsyncBufWrite + 'static, HttpServiceError: From, S::Error: fmt::Debug, ResB: Stream>, @@ -113,16 +113,18 @@ where match version { #[cfg(feature = "http1")] - super::http::Version::HTTP_11 | super::http::Version::HTTP_10 => super::h1::dispatcher::run( - &mut _tls_stream, - _addr, - timer.as_mut(), - self.config, - &self.service, - self.date.get(), - ) - .await - .map_err(From::from), + super::http::Version::HTTP_11 | super::http::Version::HTTP_10 => { + super::h1::dispatcher::Dispatcher::run( + _tls_stream, + _addr, + timer.as_mut(), + self.config, + &self.service, + self.date.get(), + ) + .await + .map_err(From::from) + } #[cfg(feature = "http2")] super::http::Version::HTTP_2 => { // update timer to first request timeout. 
@@ -159,10 +161,10 @@ where #[cfg(feature = "http1")] { - let mut io = xitca_io::net::UnixStream::from_std(_io).expect("TODO: handle io error"); + let io = xitca_io::net::UnixStream::from_std(_io).expect("TODO: handle io error"); - super::h1::dispatcher::run( - &mut io, + super::h1::dispatcher::Dispatcher::run( + io, crate::unspecified_socket_addr(), timer.as_mut(), self.config, diff --git a/http/src/tls/mod.rs b/http/src/tls/mod.rs index bb8e3da63..42d14f1d2 100644 --- a/http/src/tls/mod.rs +++ b/http/src/tls/mod.rs @@ -10,8 +10,6 @@ pub(crate) mod native_tls; pub(crate) mod openssl; #[cfg(feature = "rustls")] pub(crate) mod rustls; -#[cfg(feature = "rustls-uring")] -pub(crate) mod rustls_uring; mod error; diff --git a/http/src/tls/rustls.rs b/http/src/tls/rustls.rs index 320ed05d8..635fecbec 100644 --- a/http/src/tls/rustls.rs +++ b/http/src/tls/rustls.rs @@ -1,10 +1,10 @@ -use core::{convert::Infallible, fmt}; +use core::{convert::Infallible, error, fmt}; -use std::{error, io, sync::Arc}; +use std::{io, net::Shutdown, sync::Arc}; -use xitca_io::io::AsyncIo; +use xitca_io::io::{AsyncBufRead, AsyncBufWrite, BoundedBuf, BoundedBufMut}; use xitca_service::Service; -use xitca_tls::rustls::{Error, ServerConfig, ServerConnection, TlsStream as _TlsStream}; +use xitca_tls::rustls_complete::{Error, ServerConfig, TlsStream as _TlsStream, server::UnbufferedServerConnection}; use crate::{http::Version, version::AsVersion}; @@ -13,17 +13,17 @@ use super::error::TlsError; pub(crate) type RustlsConfig = Arc; /// A stream managed by rustls for tls read/write. 
-pub type TlsStream = _TlsStream; +pub struct TlsStream { + inner: _TlsStream, +} -impl AsVersion for TlsStream -where - Io: AsyncIo, -{ +impl AsVersion for TlsStream { fn as_version(&self) -> Version { - self.session() - .alpn_protocol() - .map(Self::from_alpn) - .unwrap_or(Version::HTTP_11) + Version::HTTP_11 + // self.inner.session() + // .alpn_protocol() + // .map(Self::from_alpn) + // .unwrap_or(Version::HTTP_11) } } @@ -55,13 +55,47 @@ pub struct TlsAcceptorService { acceptor: Arc, } -impl Service for TlsAcceptorService { +impl Service for TlsAcceptorService +where + Io: AsyncBufRead + AsyncBufWrite, +{ type Response = TlsStream; type Error = RustlsError; async fn call(&self, io: Io) -> Result { - let conn = ServerConnection::new(self.acceptor.clone())?; - _TlsStream::handshake(io, conn).await.map_err(Into::into) + let conn = UnbufferedServerConnection::new(self.acceptor.clone())?; + let inner = _TlsStream::handshake(io, conn).await?; + Ok(TlsStream { inner }) + } +} + +impl AsyncBufRead for TlsStream +where + Io: AsyncBufRead, +{ + #[inline] + async fn read(&self, buf: B) -> (io::Result, B) + where + B: BoundedBufMut, + { + self.inner.read(buf).await + } +} + +impl AsyncBufWrite for TlsStream +where + Io: AsyncBufWrite, +{ + #[inline] + async fn write(&self, buf: B) -> (io::Result, B) + where + B: BoundedBuf, + { + self.inner.write(buf).await + } + + async fn shutdown(&self, direction: Shutdown) -> io::Result<()> { + self.inner.shutdown(direction).await } } diff --git a/http/src/tls/rustls_uring.rs b/http/src/tls/rustls_uring.rs deleted file mode 100644 index 078f21ab4..000000000 --- a/http/src/tls/rustls_uring.rs +++ /dev/null @@ -1,94 +0,0 @@ -use core::convert::Infallible; - -use std::{io, net::Shutdown, sync::Arc}; - -use xitca_io::io_uring::{AsyncBufRead, AsyncBufWrite, BoundedBuf, BoundedBufMut}; -use xitca_service::Service; -use xitca_tls::rustls_uring::{ServerConfig, TlsStream as _TlsStream, server::UnbufferedServerConnection}; - -use 
crate::{http::Version, version::AsVersion}; - -use super::rustls::RustlsError; - -/// A stream managed by rustls for tls read/write. -pub struct TlsStream { - inner: _TlsStream, -} - -impl AsVersion for TlsStream { - fn as_version(&self) -> Version { - Version::HTTP_11 - } -} - -#[derive(Clone)] -pub struct TlsAcceptorBuilder { - acceptor: Arc, -} - -impl TlsAcceptorBuilder { - pub fn new(acceptor: Arc) -> Self { - Self { acceptor } - } -} - -impl Service for TlsAcceptorBuilder { - type Response = TlsAcceptorService; - type Error = Infallible; - - async fn call(&self, _: ()) -> Result { - let service = TlsAcceptorService { - acceptor: self.acceptor.clone(), - }; - Ok(service) - } -} - -/// Rustls Acceptor. Used to accept a unsecure Stream and upgrade it to a TlsStream. -pub struct TlsAcceptorService { - acceptor: Arc, -} - -impl Service for TlsAcceptorService -where - Io: AsyncBufRead + AsyncBufWrite, -{ - type Response = TlsStream; - type Error = RustlsError; - - async fn call(&self, io: Io) -> Result { - let conn = UnbufferedServerConnection::new(self.acceptor.clone())?; - let inner = _TlsStream::handshake(io, conn).await?; - Ok(TlsStream { inner }) - } -} - -impl AsyncBufRead for TlsStream -where - Io: AsyncBufRead, -{ - #[inline] - async fn read(&self, buf: B) -> (io::Result, B) - where - B: BoundedBufMut, - { - self.inner.read(buf).await - } -} - -impl AsyncBufWrite for TlsStream -where - Io: AsyncBufWrite, -{ - #[inline] - async fn write(&self, buf: B) -> (io::Result, B) - where - B: BoundedBuf, - { - self.inner.write(buf).await - } - - fn shutdown(&self, direction: Shutdown) -> io::Result<()> { - self.inner.shutdown(direction) - } -} diff --git a/io/CHANGES.md b/io/CHANGES.md index d5aea4381..70aa463c5 100644 --- a/io/CHANGES.md +++ b/io/CHANGES.md @@ -1,5 +1,9 @@ # unreleased 0.6.0 +## Add +- add `io::{AsyncBufWrite, AsyncBufRead}` impl for `net::{TcpStream, UnixStream}` + ## Change +- `io::AsyncBufWrite::shutdown` becomes async method - update 
`tokio-uring-xitca` to `0.2.0` # 0.5.1 diff --git a/io/Cargo.toml b/io/Cargo.toml index 87206b2fc..c6479edc7 100644 --- a/io/Cargo.toml +++ b/io/Cargo.toml @@ -14,7 +14,7 @@ default = [] # tokio runtime support runtime = ["tokio"] # tokio-uring runtime support -runtime-uring = ["tokio-uring-xitca/runtime"] +runtime-uring = ["runtime", "tokio-uring-xitca/runtime"] # quic support quic = ["dep:quinn", "runtime"] @@ -22,6 +22,10 @@ quic = ["dep:quinn", "runtime"] xitca-unsafe-collection = { version = "0.2.0", features = ["bytes"] } bytes = "1.4" +# Always required for its buffer traits (BoundedBuf, BoundedBufMut, etc.) +# and completion-based IO trait definitions. The io-uring runtime itself is +# NOT enabled unless the `runtime-uring` feature is activated. +tokio-uring-xitca = { version = "0.2.0", features = ["bytes"] } + quinn = { version = "0.11", features = ["ring"], optional = true } tokio = { version = "1.48", features = ["net"], optional = true } -tokio-uring-xitca = { version = "0.2.0", features = ["bytes"] } diff --git a/io/src/io.rs b/io/src/io.rs index 8e68b86f2..cc6dee9ce 100644 --- a/io/src/io.rs +++ b/io/src/io.rs @@ -1,253 +1,7 @@ -//! re-export of [tokio::io] types and extended AsyncIo trait on top of it. +mod complete; +pub use complete::*; -// TODO: io module should not re-export tokio types so AsyncIO trait does not depend on runtime -// crate feature. -pub use tokio::io::{AsyncRead, AsyncWrite, Interest, ReadBuf, Ready}; - -use core::{ - future::Future, - pin::Pin, - task::{Context, Poll, ready}, -}; - -use std::io; - -/// A wrapper trait for an [AsyncRead]/[AsyncWrite] tokio type with additional methods. -pub trait AsyncIo: io::Read + io::Write + Unpin { - /// asynchronously wait for the IO type and return it's state as [Ready]. - /// - /// # Errors: - /// - /// The only error cause of ready should be from runtime shutdown. Indicates no further - /// operations can be done. 
- /// - /// Actual IO error should be exposed from [std::io::Read]/[std::io::Write] methods. - /// - /// This constraint is from `tokio`'s behavior which is what xitca built upon and rely on - /// in downstream crates like `xitca-http` etc. - fn ready(&mut self, interest: Interest) -> impl Future> + Send; - - /// a poll version of ready method. - /// - /// # Why: - /// This is a temporary method for backward compat of [AsyncRead] and [AsyncWrite] traits. - fn poll_ready(&mut self, interest: Interest, cx: &mut Context<'_>) -> Poll>; - - /// hint if IO can be vectored write. - /// - /// # Why: - /// std `can_vector` feature is not stabled yet and xitca make use of vectored io write. - fn is_vectored_write(&self) -> bool; - - /// poll shutdown the write part of Self. - /// - /// # Why: - /// tokio's network Stream types do not expose other api for shutdown besides [AsyncWrite::poll_shutdown]. - fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll>; -} - -/// object safe version of [AsyncIo] trait. 
-pub trait AsyncIoDyn: io::Read + io::Write + Unpin { - fn ready(&mut self, interest: Interest) -> Pin> + Send + '_>>; - - fn poll_ready(&mut self, interest: Interest, cx: &mut Context<'_>) -> Poll>; - - fn is_vectored_write(&self) -> bool; - - fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll>; -} - -impl AsyncIoDyn for Io -where - Io: AsyncIo, -{ - #[inline] - fn ready(&mut self, interest: Interest) -> Pin> + Send + '_>> { - Box::pin(AsyncIo::ready(self, interest)) - } - - #[inline] - fn poll_ready(&mut self, interest: Interest, cx: &mut Context<'_>) -> Poll> { - AsyncIo::poll_ready(self, interest, cx) - } - - fn is_vectored_write(&self) -> bool { - AsyncIo::is_vectored_write(self) - } - - fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - AsyncIo::poll_shutdown(self, cx) - } -} - -impl AsyncIo for Box -where - IoDyn: AsyncIoDyn + Send + ?Sized, -{ - #[inline] - async fn ready(&mut self, interest: Interest) -> io::Result { - AsyncIoDyn::ready(&mut **self, interest).await - } - - #[inline] - fn poll_ready(&mut self, interest: Interest, cx: &mut Context<'_>) -> Poll> { - AsyncIoDyn::poll_ready(&mut **self, interest, cx) - } - - fn is_vectored_write(&self) -> bool { - AsyncIoDyn::is_vectored_write(&**self) - } - - fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - AsyncIoDyn::poll_shutdown(Pin::new(&mut **self.get_mut()), cx) - } -} - -fn _assert_object_safe(mut io: Box) { - let _ = io.read(&mut []); - let _ = io.write(&[]); -} - -/// adapter type for transforming a type impl [AsyncIo] trait to a type impl [AsyncRead] and [AsyncWrite] traits. -/// # Example -/// ```rust -/// use std::{future::poll_fn, pin::Pin}; -/// use xitca_io::io::{AsyncIo, AsyncRead, AsyncWrite, PollIoAdapter, ReadBuf}; -/// -/// async fn adapt(io: impl AsyncIo) { -/// // wrap async io type to adapter. -/// let mut poll_io = PollIoAdapter(io); -/// // use adaptor for polling based io operations. 
-/// poll_fn(|cx| Pin::new(&mut poll_io).poll_read(cx, &mut ReadBuf::new(&mut [0u8; 1]))).await; -/// poll_fn(|cx| Pin::new(&mut poll_io).poll_write(cx, b"996")).await; -/// } -/// ``` -pub struct PollIoAdapter(pub T) -where - T: AsyncIo; - -impl AsyncRead for PollIoAdapter -where - T: AsyncIo, -{ - fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll> { - let this = self.get_mut(); - loop { - ready!(this.0.poll_ready(Interest::READABLE, cx))?; - match io::Read::read(&mut this.0, buf.initialize_unfilled()) { - Ok(n) => { - buf.advance(n); - return Poll::Ready(Ok(())); - } - Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {} - Err(e) => return Poll::Ready(Err(e)), - } - } - } -} - -impl AsyncWrite for PollIoAdapter -where - T: AsyncIo, -{ - fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll> { - let this = self.get_mut(); - loop { - ready!(this.0.poll_ready(Interest::WRITABLE, cx))?; - match io::Write::write(&mut this.0, buf) { - Ok(n) => return Poll::Ready(Ok(n)), - Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {} - Err(e) => return Poll::Ready(Err(e)), - } - } - } - - fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let this = self.get_mut(); - loop { - ready!(this.0.poll_ready(Interest::WRITABLE, cx))?; - match io::Write::flush(&mut this.0) { - Ok(_) => return Poll::Ready(Ok(())), - Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {} - Err(e) => return Poll::Ready(Err(e)), - } - } - } - - fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - AsyncIo::poll_shutdown(Pin::new(&mut self.get_mut().0), cx) - } - - fn poll_write_vectored( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - bufs: &[io::IoSlice<'_>], - ) -> Poll> { - let this = self.get_mut(); - loop { - ready!(this.0.poll_ready(Interest::WRITABLE, cx))?; - match io::Write::write_vectored(&mut this.0, bufs) { - Ok(n) => return Poll::Ready(Ok(n)), - Err(ref e) if e.kind() == 
io::ErrorKind::WouldBlock => {} - Err(e) => return Poll::Ready(Err(e)), - } - } - } - - fn is_write_vectored(&self) -> bool { - self.0.is_vectored_write() - } -} - -impl AsyncIo for PollIoAdapter -where - Io: AsyncIo, -{ - #[inline(always)] - fn ready(&mut self, interest: Interest) -> impl Future> + Send { - self.0.ready(interest) - } - - #[inline(always)] - fn poll_ready(&mut self, interest: Interest, cx: &mut Context<'_>) -> Poll> { - self.0.poll_ready(interest, cx) - } - - fn is_vectored_write(&self) -> bool { - self.0.is_vectored_write() - } - - fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - Pin::new(&mut self.get_mut().0).poll_shutdown(cx) - } -} - -impl io::Write for PollIoAdapter -where - Io: AsyncIo, -{ - #[inline(always)] - fn write(&mut self, buf: &[u8]) -> io::Result { - self.0.write(buf) - } - - #[inline(always)] - fn write_vectored(&mut self, bufs: &[io::IoSlice<'_>]) -> io::Result { - self.0.write_vectored(bufs) - } - - #[inline(always)] - fn flush(&mut self) -> io::Result<()> { - self.0.flush() - } -} - -impl io::Read for PollIoAdapter -where - Io: AsyncIo, -{ - #[inline(always)] - fn read(&mut self, buf: &mut [u8]) -> ::std::io::Result { - self.0.read(buf) - } -} +#[cfg(feature = "runtime")] +mod poll; +#[cfg(feature = "runtime")] +pub use poll::*; diff --git a/io/src/io_uring.rs b/io/src/io/complete.rs similarity index 57% rename from io/src/io_uring.rs rename to io/src/io/complete.rs index 491057dfc..46e6a2873 100644 --- a/io/src/io_uring.rs +++ b/io/src/io/complete.rs @@ -1,4 +1,9 @@ -//! async buffer io trait for linux io-uring feature with tokio-uring as runtime. +//! Completion-based async IO traits. +//! +//! These traits model IO where buffer ownership is transferred to the operation +//! and returned on completion — the pattern originated from io_uring but not +//! tied to any specific runtime. They can be implemented on top of epoll/kqueue +//! or any other async runtime. 
use core::future::Future; @@ -8,20 +13,26 @@ use tokio_uring_xitca::buf::IoBuf; pub use tokio_uring_xitca::buf::{BoundedBuf, BoundedBufMut, Slice}; +/// Async read trait with buffer ownership transfer. pub trait AsyncBufRead { + /// Read into a buffer, returning the result and the buffer. fn read(&self, buf: B) -> impl Future, B)> where B: BoundedBufMut; } +/// Async write trait with buffer ownership transfer. pub trait AsyncBufWrite { + /// Write from a buffer, returning the result and the buffer. fn write(&self, buf: B) -> impl Future, B)> where B: BoundedBuf; - fn shutdown(&self, direction: Shutdown) -> io::Result<()>; + /// Shutdown the connection in the given direction. + fn shutdown(&self, direction: Shutdown) -> impl Future>; } +/// Write all bytes from a buffer to IO. pub async fn write_all(io: &Io, buf: B) -> (io::Result<()>, B) where Io: AsyncBufWrite, diff --git a/io/src/io/poll.rs b/io/src/io/poll.rs new file mode 100644 index 000000000..42fa614a0 --- /dev/null +++ b/io/src/io/poll.rs @@ -0,0 +1,254 @@ +//! Poll-based async IO traits. +//! +//! These traits model the readiness-based IO pattern used by epoll/kqueue, +//! built on top of tokio's runtime. + +pub use tokio::io::{AsyncRead, AsyncWrite, Interest, ReadBuf, Ready}; + +use core::{ + future::Future, + pin::Pin, + task::{Context, Poll, ready}, +}; + +use std::io; + +/// A wrapper trait for an [AsyncRead]/[AsyncWrite] tokio type with additional methods. +pub trait AsyncIo: io::Read + io::Write + Unpin { + /// asynchronously wait for the IO type and return it's state as [Ready]. + /// + /// # Errors: + /// + /// The only error cause of ready should be from runtime shutdown. Indicates no further + /// operations can be done. + /// + /// Actual IO error should be exposed from [std::io::Read]/[std::io::Write] methods. + /// + /// This constraint is from `tokio`'s behavior which is what xitca built upon and rely on + /// in downstream crates like `xitca-http` etc. 
+ fn ready(&mut self, interest: Interest) -> impl Future> + Send; + + /// a poll version of ready method. + /// + /// # Why: + /// This is a temporary method for backward compat of [AsyncRead] and [AsyncWrite] traits. + fn poll_ready(&mut self, interest: Interest, cx: &mut Context<'_>) -> Poll>; + + /// hint if IO can be vectored write. + /// + /// # Why: + /// std `can_vector` feature is not stabled yet and xitca make use of vectored io write. + fn is_vectored_write(&self) -> bool; + + /// poll shutdown the write part of Self. + /// + /// # Why: + /// tokio's network Stream types do not expose other api for shutdown besides [AsyncWrite::poll_shutdown]. + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll>; +} + +/// object safe version of [AsyncIo] trait. +pub trait AsyncIoDyn: io::Read + io::Write + Unpin { + fn ready(&mut self, interest: Interest) -> Pin> + Send + '_>>; + + fn poll_ready(&mut self, interest: Interest, cx: &mut Context<'_>) -> Poll>; + + fn is_vectored_write(&self) -> bool; + + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll>; +} + +impl AsyncIoDyn for Io +where + Io: AsyncIo, +{ + #[inline] + fn ready(&mut self, interest: Interest) -> Pin> + Send + '_>> { + Box::pin(AsyncIo::ready(self, interest)) + } + + #[inline] + fn poll_ready(&mut self, interest: Interest, cx: &mut Context<'_>) -> Poll> { + AsyncIo::poll_ready(self, interest, cx) + } + + fn is_vectored_write(&self) -> bool { + AsyncIo::is_vectored_write(self) + } + + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + AsyncIo::poll_shutdown(self, cx) + } +} + +impl AsyncIo for Box +where + IoDyn: AsyncIoDyn + Send + ?Sized, +{ + #[inline] + async fn ready(&mut self, interest: Interest) -> io::Result { + AsyncIoDyn::ready(&mut **self, interest).await + } + + #[inline] + fn poll_ready(&mut self, interest: Interest, cx: &mut Context<'_>) -> Poll> { + AsyncIoDyn::poll_ready(&mut **self, interest, cx) + } + + fn 
is_vectored_write(&self) -> bool { + AsyncIoDyn::is_vectored_write(&**self) + } + + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + AsyncIoDyn::poll_shutdown(Pin::new(&mut **self.get_mut()), cx) + } +} + +fn _assert_object_safe(mut io: Box) { + let _ = io.read(&mut []); + let _ = io.write(&[]); +} + +/// adapter type for transforming a type impl [AsyncIo] trait to a type impl [AsyncRead] and [AsyncWrite] traits. +/// # Example +/// ```rust +/// use std::{future::poll_fn, pin::Pin}; +/// use xitca_io::io::{AsyncIo, AsyncRead, AsyncWrite, PollIoAdapter, ReadBuf}; +/// +/// async fn adapt(io: impl AsyncIo) { +/// // wrap async io type to adapter. +/// let mut poll_io = PollIoAdapter(io); +/// // use adaptor for polling based io operations. +/// poll_fn(|cx| Pin::new(&mut poll_io).poll_read(cx, &mut ReadBuf::new(&mut [0u8; 1]))).await; +/// poll_fn(|cx| Pin::new(&mut poll_io).poll_write(cx, b"996")).await; +/// } +/// ``` +pub struct PollIoAdapter(pub T) +where + T: AsyncIo; + +impl AsyncRead for PollIoAdapter +where + T: AsyncIo, +{ + fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll> { + let this = self.get_mut(); + loop { + ready!(this.0.poll_ready(Interest::READABLE, cx))?; + match io::Read::read(&mut this.0, buf.initialize_unfilled()) { + Ok(n) => { + buf.advance(n); + return Poll::Ready(Ok(())); + } + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {} + Err(e) => return Poll::Ready(Err(e)), + } + } + } +} + +impl AsyncWrite for PollIoAdapter +where + T: AsyncIo, +{ + fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll> { + let this = self.get_mut(); + loop { + ready!(this.0.poll_ready(Interest::WRITABLE, cx))?; + match io::Write::write(&mut this.0, buf) { + Ok(n) => return Poll::Ready(Ok(n)), + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {} + Err(e) => return Poll::Ready(Err(e)), + } + } + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> 
Poll> { + let this = self.get_mut(); + loop { + ready!(this.0.poll_ready(Interest::WRITABLE, cx))?; + match io::Write::flush(&mut this.0) { + Ok(_) => return Poll::Ready(Ok(())), + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {} + Err(e) => return Poll::Ready(Err(e)), + } + } + } + + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + AsyncIo::poll_shutdown(Pin::new(&mut self.get_mut().0), cx) + } + + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[io::IoSlice<'_>], + ) -> Poll> { + let this = self.get_mut(); + loop { + ready!(this.0.poll_ready(Interest::WRITABLE, cx))?; + match io::Write::write_vectored(&mut this.0, bufs) { + Ok(n) => return Poll::Ready(Ok(n)), + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {} + Err(e) => return Poll::Ready(Err(e)), + } + } + } + + fn is_write_vectored(&self) -> bool { + self.0.is_vectored_write() + } +} + +impl AsyncIo for PollIoAdapter +where + Io: AsyncIo, +{ + #[inline(always)] + fn ready(&mut self, interest: Interest) -> impl Future> + Send { + self.0.ready(interest) + } + + #[inline(always)] + fn poll_ready(&mut self, interest: Interest, cx: &mut Context<'_>) -> Poll> { + self.0.poll_ready(interest, cx) + } + + fn is_vectored_write(&self) -> bool { + self.0.is_vectored_write() + } + + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.get_mut().0).poll_shutdown(cx) + } +} + +impl io::Write for PollIoAdapter +where + Io: AsyncIo, +{ + #[inline(always)] + fn write(&mut self, buf: &[u8]) -> io::Result { + self.0.write(buf) + } + + #[inline(always)] + fn write_vectored(&mut self, bufs: &[io::IoSlice<'_>]) -> io::Result { + self.0.write_vectored(bufs) + } + + #[inline(always)] + fn flush(&mut self) -> io::Result<()> { + self.0.flush() + } +} + +impl io::Read for PollIoAdapter +where + Io: AsyncIo, +{ + #[inline(always)] + fn read(&mut self, buf: &mut [u8]) -> ::std::io::Result { + self.0.read(buf) + } +} diff 
--git a/io/src/lib.rs b/io/src/lib.rs index 1b06280c4..ce7287bde 100644 --- a/io/src/lib.rs +++ b/io/src/lib.rs @@ -1,11 +1,9 @@ //! Async traits and types used for Io operations. -#![forbid(unsafe_code)] +#![deny(unsafe_code)] pub mod bytes; -#[cfg(feature = "runtime")] pub mod io; -#[cfg(feature = "runtime-uring")] -pub mod io_uring; + #[cfg(feature = "runtime")] pub mod net; diff --git a/io/src/net.rs b/io/src/net.rs index f7d335da5..190087665 100644 --- a/io/src/net.rs +++ b/io/src/net.rs @@ -21,6 +21,87 @@ use core::net::SocketAddr; macro_rules! default_aio_impl { ($ty: ty) => { + impl crate::io::AsyncBufRead for $ty { + #[allow(unsafe_code)] + async fn read(&self, mut buf: B) -> (::std::io::Result, B) + where + B: crate::io::BoundedBufMut, + { + let ready = self.0.ready(crate::io::Interest::READABLE).await; + + if let Err(e) = ready { + return (Err(e), buf); + } + + let init = buf.bytes_init(); + let total = buf.bytes_total(); + + // Safety: construct a mutable slice over the spare capacity. + // try_read writes contiguously from the start of the slice + // and returns the exact byte count written on Ok(n). + let spare = unsafe { ::core::slice::from_raw_parts_mut(buf.stable_mut_ptr().add(init), total - init) }; + + let mut written = 0; + + let res = loop { + if written == spare.len() { + break Ok(written); + } + + match self.0.try_read(&mut spare[written..]) { + Ok(0) => break Ok(written), + Ok(n) => written += n, + Err(e) if e.kind() == ::std::io::ErrorKind::WouldBlock => break Ok(written), + Err(e) => break Err(e), + } + }; + + // SAFETY: TcpStream::try_read has put written bytes into buf. 
+ unsafe { + buf.set_init(init + written); + } + + (res, buf) + } + } + + impl crate::io::AsyncBufWrite for $ty { + async fn write(&self, buf: B) -> (::std::io::Result, B) + where + B: crate::io::BoundedBuf, + { + let ready = self.0.ready(crate::io::Interest::WRITABLE).await; + + if let Err(e) = ready { + return (Err(e), buf); + } + + let data = buf.chunk(); + + let mut written = 0; + + let res = loop { + if written == data.len() { + break Ok(written); + } + + match self.0.try_write(&data[written..]) { + Ok(0) => break Ok(written), + Ok(n) => written += n, + Err(e) if e.kind() == ::std::io::ErrorKind::WouldBlock => break Ok(written), + Err(e) => break Err(e), + } + }; + + (res, buf) + } + + async fn shutdown(&self, _direction: ::std::net::Shutdown) -> ::std::io::Result<()> { + // TODO: this is a no-op and shutdown is always handled by dropping the stream type + Ok(()) + } + } + impl crate::io::AsyncIo for $ty { #[inline] async fn ready(&mut self, interest: crate::io::Interest) -> ::std::io::Result { diff --git a/io/src/net/io_uring.rs b/io/src/net/io_uring.rs index f1f89895e..61c869470 100644 --- a/io/src/net/io_uring.rs +++ b/io/src/net/io_uring.rs @@ -7,7 +7,7 @@ pub use tokio_uring_xitca::net::TcpStream; #[cfg(unix)] pub use tokio_uring_xitca::net::UnixStream; -use crate::io_uring::{AsyncBufRead, AsyncBufWrite, BoundedBuf, BoundedBufMut}; +use crate::io::{AsyncBufRead, AsyncBufWrite, BoundedBuf, BoundedBufMut}; use super::Stream; @@ -51,7 +51,7 @@ impl AsyncBufWrite for TcpStream { } #[inline(always)] - fn shutdown(&self, direction: Shutdown) -> io::Result<()> { + async fn shutdown(&self, direction: Shutdown) -> io::Result<()> { TcpStream::shutdown(self, direction) } } @@ -104,7 +104,7 @@ mod unix { } #[inline(always)] - fn shutdown(&self, direction: Shutdown) -> io::Result<()> { + async fn shutdown(&self, direction: Shutdown) -> io::Result<()> { UnixStream::shutdown(self, direction) } } diff --git a/server/CHANGES.md b/server/CHANGES.md index 
2f4bfa6c7..c9c98d663 100644 --- a/server/CHANGES.md +++ b/server/CHANGES.md @@ -1,4 +1,7 @@ -# unreleased +# unreleased 0.7.0 +## Change +- update `xitca-io` to `0.6.0` +- update `tokio-uring-xitca` to `0.2.0` # 0.6.1 ## Fix diff --git a/server/Cargo.toml b/server/Cargo.toml index 9fd350a67..42cd82a61 100644 --- a/server/Cargo.toml +++ b/server/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "xitca-server" -version = "0.6.1" +version = "0.7.0" edition = "2024" license = "Apache-2.0" description = "http server for xitca" @@ -17,7 +17,7 @@ quic = ["xitca-io/quic"] io-uring = ["dep:tokio-uring-xitca"] [dependencies] -xitca-io = { version = "0.5.1", features = ["runtime"] } +xitca-io = { version = "0.6.0", features = ["runtime"] } xitca-service = { version = "0.3.0", features = ["alloc"] } xitca-unsafe-collection = "0.2.0" @@ -25,7 +25,7 @@ tokio = { version = "1.48", features = ["sync", "time"] } tracing = { version = "0.1.40", default-features = false } # io-uring support -tokio-uring-xitca = { version = "0.1.1", optional = true } +tokio-uring-xitca = { version = "0.2.0", optional = true } [target.'cfg(not(target_family = "wasm"))'.dependencies] socket2 = { version = "0.6.0" } diff --git a/tls/CHANGES.md b/tls/CHANGES.md index 37e34bebd..eb94d85e3 100644 --- a/tls/CHANGES.md +++ b/tls/CHANGES.md @@ -1,5 +1,13 @@ -# unreleased 0.5.2 +# unreleased 0.6.0 +## Remove +- removed `rustls-uring` feature. 
+ +## Add +- `rustls` feature now always carries the completion async IO trait impl from `xitca-io` + ## Change +- rename `rustls-no-crypto` feature to `rustls` +- rename `rustls` feature to `rustls-aws-crypto` - internal change to reduce memory copy when `io-uring` feature enabled # 0.5.1 diff --git a/tls/Cargo.toml b/tls/Cargo.toml index ba4caffd8..63bb50040 100644 --- a/tls/Cargo.toml +++ b/tls/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "xitca-tls" -version = "0.5.2" +version = "0.6.0" edition = "2024" license = "Apache-2.0" description = "tls utility for xitca" @@ -11,24 +11,17 @@ readme= "README.md" [features] openssl = ["dep:openssl"] - -# openssl for xitca-io io-uring traits -# openssl-uring = ["dep:openssl", "xitca-io/runtime-uring"] - +native-tls = ["dep:native_tls_crate"] # rustls with no default crypto provider -rustls-no-crypto = ["dep:rustls_crate"] +rustls = ["dep:rustls_crate"] # rustls with aws-lc as crypto provider (default provider from `rustls` crate) -rustls = ["rustls-no-crypto", "rustls_crate/aws-lc-rs"] +rustls-aws-crypto = ["rustls", "rustls_crate/aws-lc-rs", "xitca-io/runtime"] # rustls with ring as crypto provider -rustls-ring-crypto = ["rustls-no-crypto", "rustls_crate/ring"] - -# rustls with no crypto provider for xitca-io io-uring traits -rustls-uring-no-crypto = ["rustls-no-crypto", "xitca-io/runtime-uring"] -# rustls with aws-lc as crypto provider for xitca-io io-uring trait (default provider from `rustls` crate) -rustls-uring = ["rustls-uring-no-crypto", "rustls_crate/aws-lc-rs"] +rustls-ring-crypto = ["rustls", "rustls_crate/ring", "xitca-io/runtime"] [dependencies] -xitca-io = { version = "0.5.0", features = ["runtime"] } +xitca-io = { version = "0.6.0" } +native_tls_crate = { package = "native-tls", version = "0.2.7", features = ["alpn"], optional = true } openssl = { version = "0.10", optional = true } rustls_crate = { package = "rustls", version = "0.23", default-features = false, features = ["std", "tls12"], optional = true } 
diff --git a/tls/src/bridge.rs b/tls/src/bridge.rs new file mode 100644 index 000000000..64b1f2703 --- /dev/null +++ b/tls/src/bridge.rs @@ -0,0 +1,123 @@ +//! A synchronous IO bridge for adapting blocking TLS libraries (OpenSSL, native-tls) +//! to completion-based async IO traits. +//! +//! The bridge implements `std::io::Read + Write` using in-memory buffers. +//! TLS libraries read/write to the bridge synchronously, then the caller +//! drains/fills the buffers asynchronously via `AsyncBufRead`/`AsyncBufWrite`. + +use std::io; + +use xitca_io::{ + bytes::{Buf, BytesMut}, + io::{AsyncBufRead, AsyncBufWrite}, +}; + +/// A synchronous bridge that pairs in-memory buffers with an async IO handle. +/// +/// TLS libraries see `Read + Write` backed by `read_buf` and `write_buf`. +/// The caller uses [`fill_read_buf`] and [`drain_write_buf`] to move data +/// between the buffers and the underlying async IO. +pub(crate) struct SyncBridge { + /// Ciphertext read from the network, consumed by TLS `read`. + pub read_buf: BytesMut, + /// Ciphertext produced by TLS `write`, drained to the network. + pub write_buf: BytesMut, +} + +impl SyncBridge { + pub fn new() -> Self { + Self { + read_buf: BytesMut::new(), + write_buf: BytesMut::new(), + } + } +} + +impl io::Read for SyncBridge { + fn read(&mut self, buf: &mut [u8]) -> io::Result { + if self.read_buf.is_empty() { + return Err(io::ErrorKind::WouldBlock.into()); + } + let len = buf.len().min(self.read_buf.len()); + buf[..len].copy_from_slice(&self.read_buf[..len]); + self.read_buf.advance(len); + Ok(len) + } +} + +impl io::Write for SyncBridge { + fn write(&mut self, buf: &[u8]) -> io::Result { + self.write_buf.extend_from_slice(buf); + Ok(buf.len()) + } + + fn flush(&mut self) -> io::Result<()> { + Ok(()) + } +} + +/// Read ciphertext from the network into `bridge.read_buf`. 
+pub(crate) async fn fill_read_buf(io: &impl AsyncBufRead, bridge: &mut SyncBridge) -> io::Result<()> { + let len = bridge.read_buf.len(); + bridge.read_buf.reserve(4096); + + let (res, b) = io.read(bridge.read_buf.split_off(len)).await; + let returned = b; + + match res { + Ok(0) => { + bridge.read_buf.unsplit(returned); + Err(io::ErrorKind::UnexpectedEof.into()) + } + Ok(_) => { + bridge.read_buf.unsplit(returned); + Ok(()) + } + Err(e) => { + bridge.read_buf.unsplit(returned); + Err(e) + } + } +} + +/// Drain all ciphertext from `bridge.write_buf` to the network. +pub(crate) async fn drain_write_buf(io: &impl AsyncBufWrite, bridge: &mut SyncBridge) -> io::Result<()> { + if bridge.write_buf.is_empty() { + return Ok(()); + } + let buf = bridge.write_buf.split(); + let (res, b) = xitca_io::io::write_all(io, buf).await; + drop(b); + res +} + +/// Drain a pre-split write buffer to the network. +/// Used when the caller has already split the write_buf out of the bridge +/// (e.g. to drop a RefCell borrow before awaiting). +pub(crate) async fn drain_split(io: &impl AsyncBufWrite, buf: BytesMut) -> io::Result<()> { + if buf.is_empty() { + return Ok(()); + } + let (res, _) = xitca_io::io::write_all(io, buf).await; + res +} + +/// Split off a read buffer from the bridge for async filling. +/// Reserves space and returns the tail portion for IO. +/// After the read, unsplit the returned buffer back into `bridge.read_buf`. +pub(crate) fn take_read_buf(bridge: &mut SyncBridge) -> BytesMut { + let len = bridge.read_buf.len(); + bridge.read_buf.reserve(4096); + bridge.read_buf.split_off(len) +} + +/// Fill a pre-split read buffer from the network. +/// Always returns the buffer (even on error) so the caller can unsplit it back. 
+pub(crate) async fn fill_split(io: &impl AsyncBufRead, buf: BytesMut) -> (io::Result<()>, BytesMut) { + let (res, buf) = io.read(buf).await; + match res { + Ok(0) => (Err(io::ErrorKind::UnexpectedEof.into()), buf), + Ok(_) => (Ok(()), buf), + Err(e) => (Err(e), buf), + } +} diff --git a/tls/src/lib.rs b/tls/src/lib.rs index 03ded8fe5..ce73850cb 100644 --- a/tls/src/lib.rs +++ b/tls/src/lib.rs @@ -1,8 +1,13 @@ +#[cfg(feature = "native-tls")] +pub mod native_tls_complete; #[cfg(feature = "openssl")] pub mod openssl; -// #[cfg(feature = "openssl-uring")] -// pub mod openssl_uring; -#[cfg(any(feature = "rustls", feature = "rustls-ring-crypto", feature = "rustls-no-crypto"))] +#[cfg(feature = "openssl")] +pub mod openssl_complete; +#[cfg(any(feature = "rustls", feature = "rustls-ring-crypto", feature = "rustls-aws-crypto"))] pub mod rustls; -#[cfg(any(feature = "rustls-uring", feature = "rustls-uring-no-crypto"))] -pub mod rustls_uring; +#[cfg(feature = "rustls")] +pub mod rustls_complete; + +#[cfg(any(feature = "openssl", feature = "native-tls"))] +pub(crate) mod bridge; diff --git a/tls/src/native_tls_complete.rs b/tls/src/native_tls_complete.rs new file mode 100644 index 000000000..a222f4bc3 --- /dev/null +++ b/tls/src/native_tls_complete.rs @@ -0,0 +1,246 @@ +//! Completion-based async IO wrapper for native-tls TLS streams. + +use core::{cell::RefCell, fmt}; + +use std::{io, net::Shutdown}; + +use native_tls_crate::{HandshakeError, TlsAcceptor, TlsConnector}; + +use xitca_io::{ + bytes::BytesMut, + io::{AsyncBufRead, AsyncBufWrite, BoundedBuf, BoundedBufMut}, +}; + +use crate::bridge::{self, SyncBridge}; + +/// A TLS stream using native-tls with completion-based async IO. +/// +/// Supports concurrent read + write from separate tasks. Concurrent read + read +/// or write + write will panic. +pub struct TlsStream { + io: Io, + session: RefCell, +} + +struct Session { + tls: native_tls_crate::TlsStream, + /// Taken by the read path. 
Carries network data across await points. + read_buf: Option, + /// Taken by the write path. Serves as a concurrent-write guard. + write_buf: Option, +} + +const POLL_TO_COMPLETE: &str = "previous call to future dropped before polling to completion"; + +impl TlsStream +where + Io: AsyncBufRead + AsyncBufWrite, +{ + /// Perform a TLS server-side accept handshake. + pub async fn accept(acceptor: &TlsAcceptor, io: Io) -> Result { + let bridge = SyncBridge::new(); + Self::handshake(io, acceptor.accept(bridge)).await + } + + /// Perform a TLS client-side connect handshake. + pub async fn connect(connector: &TlsConnector, domain: &str, io: Io) -> Result { + let bridge = SyncBridge::new(); + Self::handshake(io, connector.connect(domain, bridge)).await + } + + async fn handshake( + io: Io, + result: Result, HandshakeError>, + ) -> Result { + let mut mid = match result { + Ok(tls) => return Ok(Self::from_tls(io, tls)), + Err(HandshakeError::WouldBlock(mid)) => mid, + Err(HandshakeError::Failure(e)) => return Err(Error::Tls(e)), + }; + + loop { + bridge::drain_write_buf(&io, mid.get_mut()).await.map_err(Error::Io)?; + bridge::fill_read_buf(&io, mid.get_mut()).await.map_err(Error::Io)?; + + match mid.handshake() { + Ok(tls) => return Ok(Self::from_tls(io, tls)), + Err(HandshakeError::WouldBlock(m)) => mid = m, + Err(HandshakeError::Failure(e)) => return Err(Error::Tls(e)), + } + } + } + + fn from_tls(io: Io, tls: native_tls_crate::TlsStream) -> Self { + TlsStream { + io, + session: RefCell::new(Session { + tls, + read_buf: Some(BytesMut::new()), + write_buf: Some(BytesMut::new()), + }), + } + } +} + +impl AsyncBufRead for TlsStream +where + Io: AsyncBufRead + AsyncBufWrite, +{ + async fn read(&self, mut buf: B) -> (io::Result, B) + where + B: BoundedBufMut, + { + let init = buf.bytes_init(); + let total = buf.bytes_total(); + let spare = unsafe { core::slice::from_raw_parts_mut(buf.stable_mut_ptr().add(init), total - init) }; + + let res = self.read_tls(spare).await; + + if 
let Ok(n) = &res { + unsafe { buf.set_init(init + n) }; + } + + (res, buf) + } +} + +impl TlsStream +where + Io: AsyncBufRead + AsyncBufWrite, +{ + async fn read_tls(&self, buf: &mut [u8]) -> io::Result { + let mut session = self.session.borrow_mut(); + let mut read_buf = session.read_buf.take().expect(POLL_TO_COMPLETE); + + // Return previously fetched network data to bridge. + session.tls.get_mut().read_buf.unsplit(read_buf); + + let res = loop { + match io::Read::read(&mut session.tls, buf) { + Ok(n) => { + let proto_data = session.tls.get_mut().write_buf.split(); + drop(session); + + let drain_res = bridge::drain_split(&self.io, proto_data).await; + + session = self.session.borrow_mut(); + if let Err(e) = drain_res { + break Err(e); + } + break Ok(n); + } + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { + let proto_data = session.tls.get_mut().write_buf.split(); + read_buf = bridge::take_read_buf(session.tls.get_mut()); + drop(session); + + let drain_res = bridge::drain_split(&self.io, proto_data).await; + let (fill_res, b) = bridge::fill_split(&self.io, read_buf).await; + read_buf = b; + + session = self.session.borrow_mut(); + session.tls.get_mut().read_buf.unsplit(read_buf); + + drain_res?; + if let Err(e) = fill_res { + break Err(e); + } + } + Err(e) => break Err(e), + } + }; + + session.read_buf = Some(BytesMut::new()); + res + } +} + +impl AsyncBufWrite for TlsStream +where + Io: AsyncBufRead + AsyncBufWrite, +{ + async fn write(&self, buf: B) -> (io::Result, B) + where + B: BoundedBuf, + { + let data = buf.chunk(); + let res = self.write_tls(data).await; + (res, buf) + } + + async fn shutdown(&self, _direction: Shutdown) -> io::Result<()> { + Ok(()) + } +} + +impl TlsStream +where + Io: AsyncBufRead + AsyncBufWrite, +{ + async fn write_tls(&self, buf: &[u8]) -> io::Result { + let mut session = self.session.borrow_mut(); + let write_buf = session.write_buf.take().expect(POLL_TO_COMPLETE); + + let res = loop { + match io::Write::write(&mut 
session.tls, buf) { + Ok(n) => { + let ciphertext = session.tls.get_mut().write_buf.split(); + drop(session); + + let drain_res = bridge::drain_split(&self.io, ciphertext).await; + + session = self.session.borrow_mut(); + if let Err(e) = drain_res { + break Err(e); + } + break Ok(n); + } + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { + let ciphertext = session.tls.get_mut().write_buf.split(); + drop(session); + + let drain_res = bridge::drain_split(&self.io, ciphertext).await; + + session = self.session.borrow_mut(); + if let Err(e) = drain_res { + break Err(e); + } + } + Err(e) => break Err(e), + } + }; + + session.write_buf = Some(write_buf); + res + } +} + +/// Collection of native-tls error types. +#[derive(Debug)] +pub enum Error { + Io(io::Error), + Tls(native_tls_crate::Error), +} + +impl From for Error { + fn from(e: io::Error) -> Self { + Self::Io(e) + } +} + +impl From for Error { + fn from(e: native_tls_crate::Error) -> Self { + Self::Tls(e) + } +} + +impl fmt::Display for Error { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::Io(e) => fmt::Display::fmt(e, f), + Self::Tls(e) => fmt::Display::fmt(e, f), + } + } +} + +impl std::error::Error for Error {} diff --git a/tls/src/openssl.rs b/tls/src/openssl.rs index 97de72d5c..541dd3553 100644 --- a/tls/src/openssl.rs +++ b/tls/src/openssl.rs @@ -1,5 +1,4 @@ use core::{ - fmt, future::Future, pin::Pin, task::{Context, Poll}, @@ -12,6 +11,8 @@ pub use openssl::*; use openssl::ssl::{ErrorCode, ShutdownResult, Ssl, SslRef, SslStream}; use xitca_io::io::{AsyncIo, Interest, Ready}; +pub use super::openssl_complete::Error; + /// A stream managed by `openssl` crate for tls read/write. pub struct TlsStream { io: SslStream, @@ -112,27 +113,3 @@ impl io::Write for TlsStream { io::Write::flush(&mut self.io) } } - -/// Collection of 'openssl' error types. 
-#[derive(Debug)] -pub enum Error { - Io(io::Error), - Tls(openssl::ssl::Error), -} - -impl From for Error { - fn from(e: openssl::error::ErrorStack) -> Self { - Self::Tls(e.into()) - } -} - -impl fmt::Display for Error { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - Self::Io(e) => fmt::Display::fmt(e, f), - Self::Tls(e) => fmt::Display::fmt(e, f), - } - } -} - -impl std::error::Error for Error {} diff --git a/tls/src/openssl_complete.rs b/tls/src/openssl_complete.rs new file mode 100644 index 000000000..442517748 --- /dev/null +++ b/tls/src/openssl_complete.rs @@ -0,0 +1,188 @@ +//! Completion-based async IO wrapper for OpenSSL TLS streams. + +use core::{cell::RefCell, fmt}; + +use std::{io, net::Shutdown}; + +pub use openssl::*; + +use openssl::ssl::{ErrorCode, Ssl, SslRef, SslStream}; + +use xitca_io::io::{AsyncBufRead, AsyncBufWrite, BoundedBuf, BoundedBufMut}; + +use crate::bridge::{self, SyncBridge}; + +/// A TLS stream using OpenSSL with completion-based async IO. +pub struct TlsStream { + io: Io, + tls: RefCell>, +} + +impl TlsStream +where + Io: AsyncBufRead + AsyncBufWrite, +{ + /// Perform a TLS server-side accept handshake. + pub async fn accept(ssl: Ssl, io: Io) -> Result { + Self::handshake(ssl, io, |tls| tls.accept()).await + } + + /// Perform a TLS client-side connect handshake. 
+ pub async fn connect(ssl: Ssl, io: Io) -> Result { + Self::handshake(ssl, io, |tls| tls.connect()).await + } + + async fn handshake(ssl: Ssl, io: Io, mut func: F) -> Result + where + F: FnMut(&mut SslStream) -> Result<(), openssl::ssl::Error>, + { + let bridge = SyncBridge::new(); + let mut tls = SslStream::new(ssl, bridge)?; + + loop { + match func(&mut tls) { + Ok(_) => { + return Ok(TlsStream { + io, + tls: RefCell::new(tls), + }); + } + Err(ref e) if e.code() == ErrorCode::WANT_READ => { + bridge::drain_write_buf(&io, tls.get_mut()).await.map_err(Error::Io)?; + bridge::fill_read_buf(&io, tls.get_mut()).await.map_err(Error::Io)?; + } + Err(ref e) if e.code() == ErrorCode::WANT_WRITE => { + bridge::drain_write_buf(&io, tls.get_mut()).await.map_err(Error::Io)?; + } + Err(e) => return Err(Error::Tls(e)), + } + } + } + + /// Acquire a reference to the SSL session. + pub fn session(&self) -> &SslRef { + let tls = self.tls.borrow(); + // SAFETY: SslRef points into the heap-allocated OpenSSL context, not + // the RefCell guard. The context lives as long as self. 
+ unsafe { &*(tls.ssl() as *const SslRef) } + } +} + +impl AsyncBufRead for TlsStream +where + Io: AsyncBufRead + AsyncBufWrite, +{ + async fn read(&self, mut buf: B) -> (io::Result, B) + where + B: BoundedBufMut, + { + let init = buf.bytes_init(); + let total = buf.bytes_total(); + let spare = unsafe { core::slice::from_raw_parts_mut(buf.stable_mut_ptr().add(init), total - init) }; + + let res = self.read_tls(spare).await; + + if let Ok(n) = &res { + unsafe { buf.set_init(init + n) }; + } + + (res, buf) + } +} + +impl TlsStream +where + Io: AsyncBufRead + AsyncBufWrite, +{ + async fn read_tls(&self, buf: &mut [u8]) -> io::Result { + loop { + let mut tls = self.tls.borrow_mut(); + match io::Read::read(&mut *tls, buf) { + Ok(n) => { + let write_data = tls.get_mut().write_buf.split(); + drop(tls); + bridge::drain_split(&self.io, write_data).await?; + return Ok(n); + } + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { + let write_data = tls.get_mut().write_buf.split(); + let read_buf = bridge::take_read_buf(tls.get_mut()); + drop(tls); + + bridge::drain_split(&self.io, write_data).await?; + let read_buf = bridge::fill_split(&self.io, read_buf).await?; + + self.tls.borrow_mut().get_mut().read_buf.unsplit(read_buf); + } + Err(e) => return Err(e), + } + } + } +} + +impl AsyncBufWrite for TlsStream +where + Io: AsyncBufRead + AsyncBufWrite, +{ + async fn write(&self, buf: B) -> (io::Result, B) + where + B: BoundedBuf, + { + let data = buf.chunk(); + let res = self.write_tls(data).await; + (res, buf) + } + + async fn shutdown(&self, _direction: Shutdown) -> io::Result<()> { + Ok(()) + } +} + +impl TlsStream +where + Io: AsyncBufRead + AsyncBufWrite, +{ + async fn write_tls(&self, buf: &[u8]) -> io::Result { + loop { + let mut tls = self.tls.borrow_mut(); + match io::Write::write(&mut *tls, buf) { + Ok(n) => { + let write_data = tls.get_mut().write_buf.split(); + drop(tls); + bridge::drain_split(&self.io, write_data).await?; + return Ok(n); + } + Err(ref e) if 
e.kind() == io::ErrorKind::WouldBlock => { + let write_data = tls.get_mut().write_buf.split(); + drop(tls); + bridge::drain_split(&self.io, write_data).await?; + } + Err(e) => return Err(e), + } + } + } +} + +/// Collection of OpenSSL error types. +#[derive(Debug)] +pub enum Error { + Io(io::Error), + Tls(openssl::ssl::Error), +} + +impl From for Error { + fn from(e: openssl::error::ErrorStack) -> Self { + Self::Tls(e.into()) + } +} + +impl fmt::Display for Error { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::Io(e) => fmt::Display::fmt(e, f), + Self::Tls(e) => fmt::Display::fmt(e, f), + } + } +} + +impl std::error::Error for Error {} diff --git a/tls/src/openssl_uring.rs b/tls/src/openssl_uring.rs deleted file mode 100644 index fddb30785..000000000 --- a/tls/src/openssl_uring.rs +++ /dev/null @@ -1,398 +0,0 @@ -use core::cell::RefCell; - -use std::{ - io::{self, Read, Write}, - net::Shutdown, - rc::Rc, -}; - -pub use openssl::*; - -use openssl::ssl::{ErrorCode, Ssl, SslRef, SslStream}; - -use xitca_io::{ - bytes::{Buf, BytesMut}, - io_uring::{AsyncBufRead, AsyncBufWrite, BoundedBuf, BoundedBufMut}, -}; - -/// A TLS stream backed by OpenSSL that implements completion-based IO traits. -/// -/// Uses a sync-to-async bridge: OpenSSL reads/writes against in-memory buffers -/// ([`SyncStream`]), and the actual socket IO is performed asynchronously. -/// -/// Supports concurrent read/write: `ssl_read`/`ssl_write` are synchronous -/// operations on in-memory buffers. The `RefCell` borrow is dropped before -/// any async socket IO, allowing the other path to proceed. -/// -/// # Panics -/// Each async read/write operation must be polled to completion. Dropping a future before it -/// completes will leave internal buffers in a taken state, causing the next call to panic. -/// Concurrent reads or concurrent writes (two reads at the same time, etc.) will also panic. 
-pub struct TlsStream { - io: Io, - session: Rc>, -} - -struct Session { - ssl: SslStream, - /// Protocol data produced by read path (key updates, alerts). - /// Flushed by write path before sending application data. - proto_write_buf: BytesMut, -} - -/// Synchronous stream adapter that OpenSSL reads from / writes to. -/// -/// Buffers are `Option` to detect concurrent misuse: each path -/// takes its buffer before async IO and replaces it after. A second concurrent -/// operation on the same path will find `None` and panic. -/// -/// `Read` pulls ciphertext from `read_buf`. `Write` appends ciphertext to -/// `write_buf`. Returns `WouldBlock` when `read_buf` is empty to signal -/// OpenSSL to yield. -struct SyncStream { - /// Ciphertext from the socket, consumed by OpenSSL during `ssl_read`. - read_buf: Option, - /// Ciphertext produced by OpenSSL, to be sent to the socket. - write_buf: Option, -} - -impl Read for SyncStream { - fn read(&mut self, buf: &mut [u8]) -> io::Result { - let read_buf = self.read_buf.as_mut().expect(POLL_TO_COMPLETE); - if read_buf.is_empty() { - return Err(io::ErrorKind::WouldBlock.into()); - } - let n = buf.len().min(read_buf.len()); - buf[..n].copy_from_slice(&read_buf[..n]); - read_buf.advance(n); - Ok(n) - } -} - -impl Write for SyncStream { - fn write(&mut self, buf: &[u8]) -> io::Result { - let write_buf = self.write_buf.as_mut().expect(POLL_TO_COMPLETE); - write_buf.extend_from_slice(buf); - Ok(buf.len()) - } - - fn flush(&mut self) -> io::Result<()> { - Ok(()) - } -} - -impl TlsStream -where - Io: AsyncBufRead + AsyncBufWrite, -{ - /// Perform a TLS handshake as the server side. - pub async fn accept(ssl: Ssl, io: Io) -> Result { - let stream = Self::new(ssl, io)?; - stream.handshake(|ssl| ssl.accept()).await?; - Ok(stream) - } - - /// Perform a TLS handshake as the client side. 
- pub async fn connect(ssl: Ssl, io: Io) -> Result { - let stream = Self::new(ssl, io)?; - stream.handshake(|ssl| ssl.connect()).await?; - Ok(stream) - } - - fn new(ssl: Ssl, io: Io) -> Result { - let sync_stream = SyncStream { - read_buf: Some(BytesMut::new()), - write_buf: Some(BytesMut::new()), - }; - let ssl_stream = SslStream::new(ssl, sync_stream)?; - - Ok(TlsStream { - io, - session: Rc::new(RefCell::new(Session { - ssl: ssl_stream, - proto_write_buf: BytesMut::new(), - })), - }) - } - - /// Acquire a reference to the `SslRef` for inspecting the session. - pub fn session(&self) -> impl core::ops::Deref + '_ { - std::cell::Ref::map(self.session.borrow(), |s| s.ssl.ssl()) - } - - async fn handshake(&self, mut func: F) -> Result<(), Error> - where - F: FnMut(&mut SslStream) -> Result<(), openssl::ssl::Error>, - { - let mut session = self.session.borrow_mut(); - - loop { - match func(&mut session.ssl) { - Ok(()) => { - // Flush any remaining handshake data. - let sync = session.ssl.get_mut(); - if sync.write_buf.as_ref().is_some_and(|b| !b.is_empty()) { - let mut write_buf = sync.write_buf.take().expect(POLL_TO_COMPLETE); - drop(session); - let (res, b) = flush_write_buf(&self.io, write_buf).await; - write_buf = b; - session = self.session.borrow_mut(); - session.ssl.get_mut().write_buf = Some(write_buf); - res?; - } - return Ok(()); - } - Err(ref e) if e.code() == ErrorCode::WANT_READ => { - // Flush outgoing handshake data first. - let sync = session.ssl.get_mut(); - if sync.write_buf.as_ref().is_some_and(|b| !b.is_empty()) { - let mut write_buf = sync.write_buf.take().expect(POLL_TO_COMPLETE); - drop(session); - let (res, b) = flush_write_buf(&self.io, write_buf).await; - write_buf = b; - session = self.session.borrow_mut(); - session.ssl.get_mut().write_buf = Some(write_buf); - res?; - } - - // Read more ciphertext from the socket. 
- let mut read_buf = session.ssl.get_mut().read_buf.take().expect(POLL_TO_COMPLETE); - drop(session); - let (res, b) = read_to_buf(&self.io, read_buf).await; - read_buf = b; - session = self.session.borrow_mut(); - session.ssl.get_mut().read_buf = Some(read_buf); - res?; - } - Err(ref e) if e.code() == ErrorCode::WANT_WRITE => { - // Flush outgoing handshake data. - let mut write_buf = session.ssl.get_mut().write_buf.take().expect(POLL_TO_COMPLETE); - drop(session); - let (res, b) = flush_write_buf(&self.io, write_buf).await; - write_buf = b; - session = self.session.borrow_mut(); - session.ssl.get_mut().write_buf = Some(write_buf); - res?; - } - Err(e) => return Err(Error::Tls(e)), - } - } - } - - /// Read plaintext by decrypting ciphertext from the socket. - /// - /// Protocol data produced during read (key updates, alerts) is buffered - /// in `proto_write_buf` and flushed by the next `write_tls` call. - async fn read_tls(&self, plain_buf: &mut impl BoundedBufMut) -> io::Result { - let mut session = self.session.borrow_mut(); - - loop { - let dst = io_ref_mut_slice(plain_buf); - match session.ssl.ssl_read(dst) { - Ok(n) => { - unsafe { plain_buf.set_init(n) }; - - // Drain protocol data into proto_write_buf for write path to flush. - drain_proto_write(&mut session); - - return Ok(n); - } - Err(ref e) if e.code() == ErrorCode::ZERO_RETURN => return Ok(0), - Err(ref e) if e.code() == ErrorCode::WANT_READ => { - // Drain protocol data into proto_write_buf. - drain_proto_write(&mut session); - - // Read more ciphertext from the socket. 
- let mut read_buf = session.ssl.get_mut().read_buf.take().expect(POLL_TO_COMPLETE); - drop(session); - let (res, b) = read_to_buf(&self.io, read_buf).await; - read_buf = b; - session = self.session.borrow_mut(); - session.ssl.get_mut().read_buf = Some(read_buf); - res?; - } - Err(e) => { - return Err(e - .into_io_error() - .unwrap_or_else(|e| io::Error::new(io::ErrorKind::Other, e))); - } - } - } - } - - /// Encrypt plaintext and write ciphertext to the socket. - /// - /// Flushes any protocol data buffered by the read path before writing. - async fn write_tls(&self, plain: &impl BoundedBuf) -> io::Result { - let mut session = self.session.borrow_mut(); - let plaintext = io_ref_slice(plain); - - // Flush protocol data from read path into write_buf. - let session_ref = &mut *session; - if !session_ref.proto_write_buf.is_empty() { - let write_buf = session_ref.ssl.get_mut().write_buf.as_mut().expect(POLL_TO_COMPLETE); - write_buf.extend_from_slice(&session_ref.proto_write_buf); - session_ref.proto_write_buf.clear(); - } - - loop { - match session.ssl.ssl_write(plaintext) { - Ok(n) => { - // Flush ciphertext to the socket. - let mut write_buf = session.ssl.get_mut().write_buf.take().expect(POLL_TO_COMPLETE); - drop(session); - let (res, b) = flush_write_buf(&self.io, write_buf).await; - write_buf = b; - session = self.session.borrow_mut(); - session.ssl.get_mut().write_buf = Some(write_buf); - res?; - - return Ok(n); - } - Err(ref e) if e.code() == ErrorCode::WANT_WRITE => { - // Flush and retry. - let mut write_buf = session.ssl.get_mut().write_buf.take().expect(POLL_TO_COMPLETE); - drop(session); - let (res, b) = flush_write_buf(&self.io, write_buf).await; - write_buf = b; - session = self.session.borrow_mut(); - session.ssl.get_mut().write_buf = Some(write_buf); - res?; - } - Err(ref e) if e.code() == ErrorCode::WANT_READ => { - // Renegotiation — flush then read before retrying. 
- let sync = session.ssl.get_mut(); - if sync.write_buf.as_ref().is_some_and(|b| !b.is_empty()) { - let mut write_buf = sync.write_buf.take().expect(POLL_TO_COMPLETE); - drop(session); - let (res, b) = flush_write_buf(&self.io, write_buf).await; - write_buf = b; - session = self.session.borrow_mut(); - session.ssl.get_mut().write_buf = Some(write_buf); - res?; - } - - let mut read_buf = session.ssl.get_mut().read_buf.take().expect(POLL_TO_COMPLETE); - drop(session); - let (res, b) = read_to_buf(&self.io, read_buf).await; - read_buf = b; - session = self.session.borrow_mut(); - session.ssl.get_mut().read_buf = Some(read_buf); - res?; - } - Err(e) => { - return Err(e - .into_io_error() - .unwrap_or_else(|e| io::Error::new(io::ErrorKind::Other, e))); - } - } - } - } -} - -/// Move protocol ciphertext produced during `ssl_read` to `proto_write_buf`. -/// The write path will flush it to the socket. -fn drain_proto_write(session: &mut Session) { - let sync = session.ssl.get_mut(); - if let Some(write_buf) = sync.write_buf.as_mut() { - if !write_buf.is_empty() { - session.proto_write_buf.extend_from_slice(write_buf); - write_buf.clear(); - } - } -} - -impl AsyncBufRead for TlsStream -where - Io: AsyncBufRead + AsyncBufWrite, -{ - async fn read(&self, mut buf: B) -> (io::Result, B) - where - B: BoundedBufMut, - { - let res = self.read_tls(&mut buf).await; - (res, buf) - } -} - -impl AsyncBufWrite for TlsStream -where - Io: AsyncBufRead + AsyncBufWrite, -{ - async fn write(&self, buf: B) -> (io::Result, B) - where - B: BoundedBuf, - { - let res = self.write_tls(&buf).await; - (res, buf) - } - - fn shutdown(&self, direction: Shutdown) -> io::Result<()> { - self.io.shutdown(direction) - } -} - -fn io_ref_slice(buf: &impl BoundedBuf) -> &[u8] { - unsafe { core::slice::from_raw_parts(buf.stable_ptr(), buf.bytes_init()) } -} - -fn io_ref_mut_slice(buf: &mut impl BoundedBufMut) -> &mut [u8] { - unsafe { core::slice::from_raw_parts_mut(buf.stable_mut_ptr(), buf.bytes_total()) } -} 
- -/// Read from IO into a BytesMut, reserving space if needed. -async fn read_to_buf(io: &impl AsyncBufRead, mut buf: BytesMut) -> (io::Result<()>, BytesMut) { - let len = buf.len(); - buf.reserve(4096); - - let (res, b) = io.read(buf.slice(len..)).await; - buf = b.into_inner(); - - match res { - Ok(0) => (Err(io::ErrorKind::UnexpectedEof.into()), buf), - Ok(_) => (Ok(()), buf), - Err(e) => (Err(e), buf), - } -} - -/// Write all bytes from a BytesMut to IO, then clear it. -async fn flush_write_buf(io: &impl AsyncBufWrite, mut buf: BytesMut) -> (io::Result<()>, BytesMut) { - let (res, b) = xitca_io::io_uring::write_all(io, buf).await; - buf = b; - if res.is_ok() { - buf.clear(); - } - (res, buf) -} - -const POLL_TO_COMPLETE: &str = "previous call to future dropped before polling to completion"; - -/// Collection of OpenSSL error types. -#[derive(Debug)] -pub enum Error { - Io(io::Error), - Tls(openssl::ssl::Error), -} - -impl From for Error { - fn from(e: openssl::error::ErrorStack) -> Self { - Self::Tls(e.into()) - } -} - -impl From for Error { - fn from(e: io::Error) -> Self { - Self::Io(e) - } -} - -impl core::fmt::Display for Error { - fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { - match self { - Self::Io(e) => core::fmt::Display::fmt(e, f), - Self::Tls(e) => core::fmt::Display::fmt(e, f), - } - } -} - -impl std::error::Error for Error {} diff --git a/tls/src/rustls_uring.rs b/tls/src/rustls_complete.rs similarity index 89% rename from tls/src/rustls_uring.rs rename to tls/src/rustls_complete.rs index b887264e5..617c76cd4 100644 --- a/tls/src/rustls_uring.rs +++ b/tls/src/rustls_complete.rs @@ -1,8 +1,8 @@ #![allow(clippy::await_holding_refcell_ref)] // clippy is dumb -use core::{cell::RefCell, slice}; +use core::{cell::RefCell, cmp, slice}; -use std::{io, net::Shutdown, rc::Rc}; +use std::{io, net::Shutdown}; pub use rustls_crate::*; @@ -15,7 +15,7 @@ use rustls_crate::{ use xitca_io::{ bytes::{Buf, BytesMut}, - 
io_uring::{AsyncBufRead, AsyncBufWrite, BoundedBuf, BoundedBufMut}, + io::{AsyncBufRead, AsyncBufWrite, BoundedBuf, BoundedBufMut}, }; /// Trait to abstract over `UnbufferedServerConnection` and `UnbufferedClientConnection`, @@ -65,19 +65,7 @@ impl ProcessTlsRecords for UnbufferedClientConnection { /// completes will leave internal buffers in a taken state, causing the next call to panic. pub struct TlsStream { io: Io, - session: Rc>>, -} - -impl Clone for TlsStream -where - Io: Clone, -{ - fn clone(&self) -> Self { - Self { - io: self.io.clone(), - session: self.session.clone(), - } - } + session: RefCell>, } struct Session { @@ -99,13 +87,13 @@ where pub async fn handshake(io: Io, conn: C) -> io::Result { let stream = TlsStream { io, - session: Rc::new(RefCell::new(Session { + session: RefCell::new(Session { conn, read_buf: Some(BytesMut::new()), write_buf: Some(BytesMut::new()), proto_write_buf: BytesMut::new(), pending_plaintext: BytesMut::new(), - })), + }), }; stream._handshake().await?; Ok(stream) @@ -188,12 +176,14 @@ where // Check for plaintext buffered from a previous read first. 
if !session.pending_plaintext.is_empty() { - let dst = io_ref_mut_slice(plain_buf); - let n = session.pending_plaintext.len().min(dst.len()); - dst[..n].copy_from_slice(&session.pending_plaintext[..n]); - session.pending_plaintext.advance(n); - unsafe { plain_buf.set_init(n) }; - return Ok(n); + let rem = plain_buf.bytes_total() - plain_buf.bytes_init(); + let aval = session.pending_plaintext.len(); + let len = cmp::min(rem, aval); + + plain_buf.put_slice(&session.pending_plaintext[..len]); + session.pending_plaintext.advance(len); + + return Ok(len); } let mut read_buf = session.read_buf.take().expect(POLL_TO_COMPLETE); @@ -212,7 +202,7 @@ where } Ok(ConnectionState::ReadTraffic(mut traffic)) => { - let dst = io_ref_mut_slice(plain_buf); + let rem = plain_buf.bytes_total() - plain_buf.bytes_init(); let mut written = 0; let mut err = None; @@ -220,13 +210,15 @@ where match res.map_err(tls_err) { Ok(record) => { let payload = record.payload; - let n = payload.len().min(dst.len() - written); - dst[written..written + n].copy_from_slice(&payload[..n]); - written += n; + let len = payload.len().min(rem - written); + + let (head, tail) = payload.split_at(len); + + plain_buf.put_slice(head); + written += len; + // Buffer overflow into pending_plaintext. - if n < payload.len() { - session_ref.pending_plaintext.extend_from_slice(&payload[n..]); - } + session_ref.pending_plaintext.extend_from_slice(tail); } Err(e) => { err = Some(e); @@ -235,7 +227,6 @@ where } } - drop(traffic); read_buf.advance(discard); if let Some(e) = err { @@ -247,7 +238,6 @@ where continue; } - unsafe { plain_buf.set_init(written) }; break Ok(written); } @@ -315,7 +305,7 @@ where async fn write_tls(&self, plain: &impl BoundedBuf) -> io::Result { let mut session = self.session.borrow_mut(); let mut write_buf = session.write_buf.take().expect(POLL_TO_COMPLETE); - let plaintext = io_ref_slice(plain); + let plaintext = plain.chunk(); // Flush protocol data buffered by read path (key updates, alerts). 
if !session.proto_write_buf.is_empty() { @@ -412,21 +402,11 @@ where (res, buf) } - fn shutdown(&self, direction: Shutdown) -> io::Result<()> { - self.io.shutdown(direction) + async fn shutdown(&self, direction: Shutdown) -> io::Result<()> { + self.io.shutdown(direction).await } } -fn io_ref_slice(buf: &impl BoundedBuf) -> &[u8] { - // SAFETY: trust BoundedBuf implementor to provide valid pointer and length. - unsafe { slice::from_raw_parts(buf.stable_ptr(), buf.bytes_init()) } -} - -fn io_ref_mut_slice(buf: &mut impl BoundedBufMut) -> &mut [u8] { - // SAFETY: trust BoundedBufMut implementor to provide valid pointer and capacity. - unsafe { slice::from_raw_parts_mut(buf.stable_mut_ptr(), buf.bytes_total()) } -} - fn tls_err(e: Error) -> io::Error { io::Error::new(io::ErrorKind::InvalidData, e) } @@ -448,7 +428,7 @@ async fn read_to_buf(io: &impl AsyncBufRead, mut buf: BytesMut) -> (io::Result<( /// Write all bytes from a BytesMut to IO, then clear it. async fn write_all_buf(io: &impl AsyncBufWrite, mut buf: BytesMut) -> (io::Result<()>, BytesMut) { - let (res, b) = xitca_io::io_uring::write_all(io, buf).await; + let (res, b) = xitca_io::io::write_all(io, buf).await; buf = b; if res.is_ok() { buf.clear(); diff --git a/tokio-uring/CHANGELOG.md b/tokio-uring/CHANGELOG.md index 1d3892f14..b077fc67b 100644 --- a/tokio-uring/CHANGELOG.md +++ b/tokio-uring/CHANGELOG.md @@ -1,8 +1,16 @@ # unreleased 0.2.0 +## Fix +- `BoundedBuf::put_slice` now extend to it's uninit part. 
Multiple calls to it would result in accumulation of bytes and not overwritting + +## Change +- `FixedBuf` would be cleared on check out to buffer pool (By setting is initialized count to zero) - remove runtime from default feature -- add runtime feature for io-uring runtime - perf improvement +## Add +- add `BoundedBuf::chunk` to get inited part as byte slice +- add runtime feature for io-uring runtime + # 0.1.1 - fix MSRV diff --git a/tokio-uring/src/buf/bounded.rs b/tokio-uring/src/buf/bounded.rs index 4d128ef77..3e1ee053a 100644 --- a/tokio-uring/src/buf/bounded.rs +++ b/tokio-uring/src/buf/bounded.rs @@ -1,7 +1,6 @@ use super::{IoBuf, IoBufMut, Slice}; -use std::ops; -use std::ptr; +use core::{ops, ptr, slice}; /// A possibly bounded view into an owned [`IoBuf`] buffer. /// @@ -74,6 +73,13 @@ pub trait BoundedBuf: Unpin + 'static { /// Total size of the view, including uninitialized memory, if any. fn bytes_total(&self) -> usize; + + /// Returns a shared reference to the initialized portion of the buffer. + fn chunk(&self) -> &[u8] { + // Safety: BoundedBuf implementor guarantees stable_ptr points to valid + // memory and bytes_init bytes starting from that pointer are initialized. + unsafe { slice::from_raw_parts(self.stable_ptr(), self.bytes_init()) } + } } impl BoundedBuf for T { @@ -159,21 +165,22 @@ pub trait BoundedBufMut: BoundedBuf { /// /// # Panics /// - /// If the slice's length exceeds the destination's total capacity, + /// If the slice's length exceeds the destination's remaining capacity, /// this method panics. 
fn put_slice(&mut self, src: &[u8]) { - assert!(self.bytes_total() >= src.len()); - let dst = self.stable_mut_ptr(); + let init = self.bytes_init(); + assert!(self.bytes_total() - init >= src.len()); // Safety: - // dst pointer validity is ensured by stable_mut_ptr; - // the length is checked to not exceed the view's total capacity; + // dst pointer validity is ensured by stable_mut_ptr, offset by + // bytes_init() to write after already-initialized data; + // the length is checked to not exceed the remaining capacity; // src (immutable) and dst (mutable) cannot point to overlapping memory; - // after copying the amount of bytes given by the slice, it's safe - // to mark them as initialized in the buffer. + // after copying, the new initialized watermark is set accordingly. unsafe { + let dst = self.stable_mut_ptr().add(init); ptr::copy_nonoverlapping(src.as_ptr(), dst, src.len()); - self.set_init(src.len()); + self.set_init(init + src.len()); } } } @@ -192,3 +199,67 @@ impl BoundedBufMut for T { unsafe { IoBufMut::set_init(self, pos) } } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn put_slice_appends() { + let mut buf = Vec::with_capacity(64); + buf.put_slice(b"hello"); + assert_eq!(&buf, b"hello"); + + buf.put_slice(b" world"); + assert_eq!(&buf, b"hello world"); + } + + #[test] + fn put_slice_empty() { + let mut buf = Vec::with_capacity(16); + buf.put_slice(b""); + assert!(buf.is_empty()); + + buf.put_slice(b"abc"); + assert_eq!(&buf, b"abc"); + + buf.put_slice(b""); + assert_eq!(&buf, b"abc"); + } + + #[test] + fn put_slice_fills_capacity() { + let mut buf = Vec::with_capacity(5); + buf.put_slice(b"ab"); + buf.put_slice(b"cde"); + assert_eq!(&buf, b"abcde"); + assert_eq!(buf.len(), 5); + } + + #[test] + #[should_panic] + fn put_slice_exceeds_capacity() { + let mut buf = Vec::with_capacity(4); + buf.put_slice(b"abcde"); + } + + #[test] + fn chunk_returns_initialized() { + let buf = b"hello".to_vec(); + assert_eq!(buf.chunk(), b"hello"); + } + + 
#[test] + fn chunk_empty() { + let buf = Vec::::with_capacity(16); + assert_eq!(buf.chunk(), b""); + } + + #[test] + fn chunk_after_put_slice() { + let mut buf = Vec::with_capacity(32); + buf.put_slice(b"foo"); + buf.put_slice(b"bar"); + assert_eq!(buf.chunk(), b"foobar"); + } +} diff --git a/tokio-uring/src/buf/fixed/handle.rs b/tokio-uring/src/buf/fixed/handle.rs index 5e77c1253..6fddf980a 100644 --- a/tokio-uring/src/buf/fixed/handle.rs +++ b/tokio-uring/src/buf/fixed/handle.rs @@ -38,11 +38,9 @@ pub struct FixedBuf { impl Drop for FixedBuf { fn drop(&mut self) { let mut registry = self.registry.borrow_mut(); - // Safety: the length of the initialized data in the buffer has been - // maintained accordingly to the safety contracts on - // Self::new and IoBufMut. + // Safety: passing 0 resets the buffer so it starts empty on next checkout. unsafe { - registry.check_in(self.buf.index, self.buf.init_len); + registry.check_in(self.buf.index, 0); } } } diff --git a/tokio-uring/tests/fixed_buf.rs b/tokio-uring/tests/fixed_buf.rs index 6f2c30872..2e17c35b3 100644 --- a/tokio-uring/tests/fixed_buf.rs +++ b/tokio-uring/tests/fixed_buf.rs @@ -55,7 +55,7 @@ fn fixed_buf_turnaround() { // The buffer has been released, check it out again. 
let fixed_buf = buffers.check_out(0).unwrap(); assert_eq!(fixed_buf.bytes_total(), 30); - assert_eq!(fixed_buf.bytes_init(), HELLO.len()); + assert_eq!(fixed_buf.bytes_init(), 0); }); } From 631f93b9fba46f6200a665f9a002143a93a39d22 Mon Sep 17 00:00:00 2001 From: fakeshadow <24548779@qq.com> Date: Wed, 1 Apr 2026 17:41:14 +0800 Subject: [PATCH 13/21] example fix --- examples/io-uring/Cargo.toml | 2 +- http/src/h1/dispatcher_unreal.rs | 181 ------------------------------- http/src/h1/mod.rs | 5 +- http/src/h1/service.rs | 2 +- http/src/service.rs | 24 ++-- server/Cargo.toml | 2 +- tokio-uring/CHANGELOG.md | 2 +- 7 files changed, 17 insertions(+), 201 deletions(-) delete mode 100644 http/src/h1/dispatcher_unreal.rs diff --git a/examples/io-uring/Cargo.toml b/examples/io-uring/Cargo.toml index e38e7f33b..fe50582b6 100644 --- a/examples/io-uring/Cargo.toml +++ b/examples/io-uring/Cargo.toml @@ -5,7 +5,7 @@ authors = ["fakeshadow <24548779@qq.com>"] edition = "2024" [dependencies] -xitca-http = { version = "0.8", features = ["io-uring", "router", "rustls"] } +xitca-http = { version = "0.9", features = ["io-uring", "router", "rustls"] } xitca-server = { version = "0.7", features = ["io-uring"] } xitca-service = "0.3" diff --git a/http/src/h1/dispatcher_unreal.rs b/http/src/h1/dispatcher_unreal.rs deleted file mode 100644 index 7f85cafce..000000000 --- a/http/src/h1/dispatcher_unreal.rs +++ /dev/null @@ -1,181 +0,0 @@ -use core::mem::MaybeUninit; - -use std::{io, rc::Rc}; - -use http::StatusCode; -use httparse::{Header, Status}; -use xitca_io::{ - bytes::{Buf, BytesMut, PagedBytesMut}, - io::{AsyncIo, Interest}, - net::TcpStream, -}; -use xitca_service::Service; -use xitca_unsafe_collection::bytes::read_buf; - -use crate::date::{DateTime, DateTimeHandle, DateTimeService}; - -pub type Error = Box; - -pub struct Request<'a, C> { - pub method: &'a str, - pub path: &'a str, - pub headers: &'a mut [Header<'a>], - pub ctx: &'a C, -} - -pub struct Response<'a, const STEP: usize = 
1> { - buf: &'a mut BytesMut, - date: &'a DateTimeHandle, -} - -impl<'a> Response<'a> { - pub fn status(self, status: StatusCode) -> Response<'a, 2> { - if status == StatusCode::OK { - self.buf.extend_from_slice(b"HTTP/1.1 200 OK"); - } else { - self.buf.extend_from_slice(b"HTTP/1.1 "); - let reason = status.canonical_reason().unwrap_or("").as_bytes(); - let status = status.as_str().as_bytes(); - self.buf.extend_from_slice(status); - self.buf.extend_from_slice(b" "); - self.buf.extend_from_slice(reason); - } - - Response { - buf: self.buf, - date: self.date, - } - } -} - -impl<'a> Response<'a, 2> { - pub fn header(self, key: &str, val: &str) -> Self { - let key = key.as_bytes(); - let val = val.as_bytes(); - - self.buf.reserve(key.len() + val.len() + 4); - self.buf.extend_from_slice(b"\r\n"); - self.buf.extend_from_slice(key); - self.buf.extend_from_slice(b": "); - self.buf.extend_from_slice(val); - self - } - - pub fn body(self, body: &[u8]) -> Response<'a, 3> { - super::proto::encode::write_length_header(self.buf, body.len()); - self.body_writer(|buf| buf.extend_from_slice(body)) - } - - pub fn body_writer(mut self, func: F) -> Response<'a, 3> - where - F: for<'b> FnOnce(&'b mut BytesMut), - { - self.try_write_date(); - - self.buf.extend_from_slice(b"\r\n\r\n"); - - func(self.buf); - - Response { - buf: self.buf, - date: self.date, - } - } - - fn try_write_date(&mut self) { - self.buf.reserve(DateTimeHandle::DATE_SIZE_HINT + 12); - self.buf.extend_from_slice(b"\r\ndate: "); - self.date.with_date(|date| self.buf.extend_from_slice(date)); - } -} - -pub struct Dispatcher { - handler: F, - ctx: C, - date: DateTimeService, -} - -impl Dispatcher { - pub fn new(handler: F, ctx: C) -> Rc { - Rc::new(Self { - handler, - ctx, - date: DateTimeService::new(), - }) - } -} - -impl Service for Dispatcher -where - F: for<'h, 'b> AsyncFn(Request<'h, C>, Response<'h>) -> Response<'h, 3>, -{ - type Response = (); - type Error = Error; - - async fn call(&self, mut stream: TcpStream) 
-> Result { - let mut r_buf = PagedBytesMut::<4096>::new(); - let mut w_buf = BytesMut::with_capacity(4096); - - let mut read_closed = false; - - loop { - stream.ready(Interest::READABLE).await?; - - loop { - match read_buf(&mut stream, &mut r_buf) { - Ok(0) => { - if core::mem::replace(&mut read_closed, true) { - return Ok(()); - } - break; - } - Ok(_) => {} - Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => break, - Err(e) => return Err(e.into()), - } - } - - loop { - let mut headers = [const { MaybeUninit::uninit() }; 16]; - - let mut req = httparse::Request::new(&mut []); - - match req.parse_with_uninit_headers(r_buf.chunk(), &mut headers)? { - Status::Complete(len) => { - let req = Request { - path: req.path.unwrap(), - method: req.method.unwrap(), - headers: req.headers, - ctx: &self.ctx, - }; - - let res = Response { - buf: &mut w_buf, - date: self.date.get(), - }; - - (self.handler)(req, res).await; - - r_buf.advance(len); - } - Status::Partial => break, - }; - } - - let mut written = 0; - - while written != w_buf.len() { - match io::Write::write(&mut stream, &w_buf[written..]) { - Ok(0) => return Err(io::Error::from(io::ErrorKind::WriteZero).into()), - Ok(n) => written += n, - Err(e) if e.kind() == io::ErrorKind::WouldBlock => { - stream.ready(Interest::WRITABLE).await?; - } - Err(e) => return Err(e.into()), - } - } - - w_buf.clear(); - } - } -} diff --git a/http/src/h1/mod.rs b/http/src/h1/mod.rs index 3508c8f62..105889a3c 100644 --- a/http/src/h1/mod.rs +++ b/http/src/h1/mod.rs @@ -1,15 +1,14 @@ //! http/1 specific module for types and protocol utilities. 
-pub mod dispatcher_unreal; pub mod proto; -pub mod dispatcher; - mod body; mod builder; +mod dispatcher; mod error; mod service; pub use self::body::RequestBody; +pub use self::dispatcher::Dispatcher; pub use self::error::Error; pub use self::service::H1Service; diff --git a/http/src/h1/service.rs b/http/src/h1/service.rs index 647f237af..f7f5d2c06 100644 --- a/http/src/h1/service.rs +++ b/http/src/h1/service.rs @@ -40,7 +40,7 @@ where .await .map_err(|_| HttpServiceError::Timeout(TimeoutError::TlsAccept))??; - super::dispatcher::Dispatcher::run(io, addr, timer, self.config, &self.service, self.date.get()) + super::Dispatcher::run(io, addr, timer, self.config, &self.service, self.date.get()) .await .map_err(Into::into) } diff --git a/http/src/service.rs b/http/src/service.rs index feafd1fa5..c1c48a8b3 100644 --- a/http/src/service.rs +++ b/http/src/service.rs @@ -113,18 +113,16 @@ where match version { #[cfg(feature = "http1")] - super::http::Version::HTTP_11 | super::http::Version::HTTP_10 => { - super::h1::dispatcher::Dispatcher::run( - _tls_stream, - _addr, - timer.as_mut(), - self.config, - &self.service, - self.date.get(), - ) - .await - .map_err(From::from) - } + super::http::Version::HTTP_11 | super::http::Version::HTTP_10 => super::h1::Dispatcher::run( + _tls_stream, + _addr, + timer.as_mut(), + self.config, + &self.service, + self.date.get(), + ) + .await + .map_err(From::from), #[cfg(feature = "http2")] super::http::Version::HTTP_2 => { // update timer to first request timeout. 
@@ -163,7 +161,7 @@ where { let io = xitca_io::net::UnixStream::from_std(_io).expect("TODO: handle io error"); - super::h1::dispatcher::Dispatcher::run( + super::h1::Dispatcher::run( io, crate::unspecified_socket_addr(), timer.as_mut(), diff --git a/server/Cargo.toml b/server/Cargo.toml index 42cd82a61..e58a35faa 100644 --- a/server/Cargo.toml +++ b/server/Cargo.toml @@ -25,7 +25,7 @@ tokio = { version = "1.48", features = ["sync", "time"] } tracing = { version = "0.1.40", default-features = false } # io-uring support -tokio-uring-xitca = { version = "0.2.0", optional = true } +tokio-uring-xitca = { version = "0.2.0", features = ["runtime"], optional = true } [target.'cfg(not(target_family = "wasm"))'.dependencies] socket2 = { version = "0.6.0" } diff --git a/tokio-uring/CHANGELOG.md b/tokio-uring/CHANGELOG.md index b077fc67b..5b3c9e162 100644 --- a/tokio-uring/CHANGELOG.md +++ b/tokio-uring/CHANGELOG.md @@ -3,7 +3,7 @@ - `BoundedBuf::put_slice` now extend to it's uninit part. Multiple calls to it would result in accumulation of bytes and not overwritting ## Change -- `FixedBuf` would be cleared on check out to buffer pool (By setting is initialized count to zero) +- `FixedBuf` would be cleared on check out to buffer pool (By setting its initialized size to zero) - remove runtime from default feature - perf improvement From 0d761ff26cfe2afd3873b8f708df37574c70ff06 Mon Sep 17 00:00:00 2001 From: fakeshadow <24548779@qq.com> Date: Wed, 1 Apr 2026 22:21:01 +0800 Subject: [PATCH 14/21] wip --- examples/io-uring-h2/Cargo.toml | 13 +- examples/io-uring-h2/src/main.rs | 25 ++ http/Cargo.toml | 5 +- http/src/h2/dispatcher_uring.rs | 2 +- http/src/tls/native_tls.rs | 137 +------ http/src/tls/openssl.rs | 14 +- http/src/tls/rustls.rs | 52 +-- tls/CHANGES.md | 16 +- tls/Cargo.toml | 6 +- tls/src/bridge.rs | 124 +++--- tls/src/lib.rs | 14 +- tls/src/native_tls.rs | 236 +++++++++++ tls/src/native_tls_complete.rs | 246 ------------ tls/src/openssl.rs | 238 ++++++++---- 
tls/src/openssl_complete.rs | 188 --------- tls/src/openssl_poll.rs | 115 ++++++ tls/src/rustls.rs | 644 +++++++++++++++++++++++++------ tls/src/rustls_complete.rs | 558 -------------------------- tls/src/rustls_poll.rs | 182 +++++++++ 19 files changed, 1385 insertions(+), 1430 deletions(-) create mode 100644 tls/src/native_tls.rs delete mode 100644 tls/src/native_tls_complete.rs delete mode 100644 tls/src/openssl_complete.rs create mode 100644 tls/src/openssl_poll.rs delete mode 100644 tls/src/rustls_complete.rs create mode 100644 tls/src/rustls_poll.rs diff --git a/examples/io-uring-h2/Cargo.toml b/examples/io-uring-h2/Cargo.toml index fb8a6cbfb..8806d8101 100644 --- a/examples/io-uring-h2/Cargo.toml +++ b/examples/io-uring-h2/Cargo.toml @@ -5,15 +5,18 @@ authors = ["fakeshadow <24548779@qq.com>"] edition = "2024" [dependencies] -xitca-http = { path = "../../http", features = ["http2", "io-uring", "router"] } -xitca-server = { version = "0.6.1", features = ["io-uring"] } +xitca-http = { version = "0.9", features = ["http2", "io-uring", "router"] } +xitca-server = { version = "0.7", features = ["io-uring"] } xitca-service = "0.3" futures-core = "0.3" mimalloc = { version = "0.1.48", default-features = false, features = ["v3"] } -# rcgen = "0.14" -# rustls = "0.23" -# rustls-pki-types = "1" + +# openssl = "0.10.44" +rcgen = "0.14" +rustls = "0.23" +rustls-pki-types = "1" + [profile.release] opt-level = 3 diff --git a/examples/io-uring-h2/src/main.rs b/examples/io-uring-h2/src/main.rs index 712c17603..e97a73cf9 100644 --- a/examples/io-uring-h2/src/main.rs +++ b/examples/io-uring-h2/src/main.rs @@ -88,3 +88,28 @@ impl Stream for Once { // std::sync::Arc::new(config) // } + +// use openssl::ssl::{AlpnError, SslAcceptor, SslFiletype, SslMethod}; +// fn tls_config() -> io::Result { +// // set up openssl and alpn protocol. 
+// let mut builder = SslAcceptor::mozilla_intermediate(SslMethod::tls())?; +// builder.set_private_key_file("../cert/key.pem", SslFiletype::PEM)?; +// builder.set_certificate_chain_file("../cert/cert.pem")?; + +// builder.set_alpn_select_callback(|_, protocols| { +// const H2: &[u8] = b"\x02h2"; +// const H11: &[u8] = b"\x08http/1.1"; + +// if protocols.windows(3).any(|window| window == H2) { +// Ok(b"h2") +// } else if protocols.windows(9).any(|window| window == H11) { +// Ok(b"http/1.1") +// } else { +// Err(AlpnError::NOACK) +// } +// }); + +// builder.set_alpn_protos(b"\x08http/1.1\x02h2")?; + +// Ok(builder.build()) +// } diff --git a/http/Cargo.toml b/http/Cargo.toml index 0c8a30fd8..6c08bf65a 100644 --- a/http/Cargo.toml +++ b/http/Cargo.toml @@ -27,7 +27,7 @@ openssl = ["xitca-tls/openssl"] # rustls as server side tls. rustls = ["xitca-tls/rustls"] # rustls as server side tls. -native-tls = ["dep:native-tls", "runtime"] +native-tls = ["xitca-tls/native-tls"] # async runtime feature. runtime = ["xitca-io/runtime", "tokio"] @@ -46,9 +46,6 @@ httpdate = "1.0" pin-project-lite = "0.2.10" tracing = { version = "0.1.40", default-features = false } -# native tls support -native-tls = { version = "0.2.7", features = ["alpn"], optional = true } - # tls support shared xitca-tls = { version = "0.6.0", optional = true } diff --git a/http/src/h2/dispatcher_uring.rs b/http/src/h2/dispatcher_uring.rs index fd9ca832c..f3f15bfa5 100644 --- a/http/src/h2/dispatcher_uring.rs +++ b/http/src/h2/dispatcher_uring.rs @@ -1514,7 +1514,7 @@ impl ShutDown { res?; // Send FIN so the peer sees a clean connection // close rather than RST (RFC 7540 §6.8). 
- let _ = io.shutdown(Shutdown::Write); + let _ = io.shutdown(Shutdown::Write).await; return read_res; } SelectOutput::B(res) => res?, diff --git a/http/src/tls/native_tls.rs b/http/src/tls/native_tls.rs index 177dc74d6..882ad6a02 100644 --- a/http/src/tls/native_tls.rs +++ b/http/src/tls/native_tls.rs @@ -1,35 +1,26 @@ -pub(crate) use native_tls::TlsAcceptor; +pub(crate) use xitca_tls::native_tls::TlsAcceptor; -use core::{ - convert::Infallible, - fmt, - pin::Pin, - task::{Context, Poll}, -}; +use core::convert::Infallible; -use std::io; - -use native_tls::{Error, HandshakeError}; -use xitca_io::io::{AsyncIo, Interest, Ready}; +use xitca_io::io::{AsyncBufRead, AsyncBufWrite}; use xitca_service::Service; use crate::{http::Version, version::AsVersion}; use super::error::TlsError; -/// A wrapper type for [TlsStream](native_tls::TlsStream). -/// -/// This is to impl new trait for it. -pub struct TlsStream { - io: native_tls::TlsStream, -} +pub type TlsStream = xitca_tls::native_tls::TlsStream; -impl AsVersion for TlsStream { +impl AsVersion for TlsStream +where + Io: AsyncBufRead + AsyncBufWrite, +{ fn as_version(&self) -> Version { - self.io + self.session() .negotiated_alpn() .ok() - .and_then(|proto| proto) + .flatten() + .as_deref() .map(Self::from_alpn) .unwrap_or(Version::HTTP_11) } @@ -54,7 +45,6 @@ impl Service for TlsAcceptorBuilder { let service = TlsAcceptorService { acceptor: self.acceptor.clone(), }; - Ok(service) } } @@ -64,109 +54,20 @@ pub struct TlsAcceptorService { acceptor: TlsAcceptor, } -impl Service for TlsAcceptorService { - type Response = TlsStream; +impl Service for TlsAcceptorService +where + Io: AsyncBufRead + AsyncBufWrite, +{ + type Response = TlsStream; type Error = NativeTlsError; - async fn call(&self, mut io: St) -> Result { - let mut interest = Interest::READABLE; - - io.ready(interest).await?; - - let mut res = self.acceptor.accept(io); - - loop { - let mut stream = match res { - Ok(io) => return Ok(TlsStream { io }), - 
Err(HandshakeError::WouldBlock(stream)) => { - interest = Interest::READABLE; - stream - } - Err(HandshakeError::Failure(e)) => return Err(e.into()), - }; - - stream.get_mut().ready(interest).await?; - - res = stream.handshake(); - } - } -} - -impl AsyncIo for TlsStream { - #[inline] - fn ready(&mut self, interest: Interest) -> impl Future> + Send { - self.io.get_mut().ready(interest) - } - - #[inline] - fn poll_ready(&mut self, interest: Interest, cx: &mut Context<'_>) -> Poll> { - self.io.get_mut().poll_ready(interest, cx) - } - - fn is_vectored_write(&self) -> bool { - self.io.get_ref().is_vectored_write() - } - - fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let this = self.get_mut(); - - this.io.shutdown()?; - - AsyncIo::poll_shutdown(Pin::new(this.io.get_mut()), cx) - } -} - -impl io::Read for TlsStream { - fn read(&mut self, buf: &mut [u8]) -> io::Result { - io::Read::read(&mut self.io, buf) - } -} - -impl io::Write for TlsStream { - fn write(&mut self, buf: &[u8]) -> io::Result { - io::Write::write(&mut self.io, buf) - } - - fn flush(&mut self) -> io::Result<()> { - io::Write::flush(&mut self.io) + async fn call(&self, io: Io) -> Result { + TlsStream::accept(&self.acceptor, io).await } } /// Collection of 'native-tls' error types. 
-pub enum NativeTlsError { - Io(io::Error), - Tls(Error), -} - -impl fmt::Debug for NativeTlsError { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match *self { - Self::Io(ref e) => fmt::Debug::fmt(e, f), - Self::Tls(ref e) => fmt::Debug::fmt(e, f), - } - } -} - -impl fmt::Display for NativeTlsError { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match *self { - Self::Io(ref e) => fmt::Display::fmt(e, f), - Self::Tls(ref e) => fmt::Display::fmt(e, f), - } - } -} - -impl From for NativeTlsError { - fn from(e: io::Error) -> Self { - Self::Io(e) - } -} - -impl From for NativeTlsError { - fn from(e: Error) -> Self { - Self::Tls(e) - } -} +pub type NativeTlsError = xitca_tls::native_tls::Error; impl From for TlsError { fn from(e: NativeTlsError) -> Self { diff --git a/http/src/tls/openssl.rs b/http/src/tls/openssl.rs index 1822254f4..9b7b847fd 100644 --- a/http/src/tls/openssl.rs +++ b/http/src/tls/openssl.rs @@ -2,7 +2,7 @@ pub(crate) use xitca_tls::openssl::ssl::SslAcceptor as TlsAcceptor; use core::convert::Infallible; -use xitca_io::io::AsyncIo; +use xitca_io::io::{AsyncBufRead, AsyncBufWrite}; use xitca_service::Service; use xitca_tls::openssl::ssl; @@ -14,7 +14,7 @@ pub type TlsStream = xitca_tls::openssl::TlsStream; impl AsVersion for TlsStream where - Io: AsyncIo, + Io: AsyncBufRead + AsyncBufWrite, { fn as_version(&self) -> Version { self.session() @@ -54,14 +54,20 @@ pub struct TlsAcceptorService { impl TlsAcceptorService { #[inline(never)] - async fn accept(&self, io: Io) -> Result, OpensslError> { + async fn accept(&self, io: Io) -> Result, OpensslError> + where + Io: AsyncBufRead + AsyncBufWrite, + { let ctx = self.acceptor.context(); let ssl = ssl::Ssl::new(ctx)?; TlsStream::accept(ssl, io).await } } -impl Service for TlsAcceptorService { +impl Service for TlsAcceptorService +where + Io: AsyncBufRead + AsyncBufWrite, +{ type Response = TlsStream; type Error = OpensslError; diff --git a/http/src/tls/rustls.rs 
b/http/src/tls/rustls.rs index 635fecbec..244c0ae27 100644 --- a/http/src/tls/rustls.rs +++ b/http/src/tls/rustls.rs @@ -1,10 +1,10 @@ use core::{convert::Infallible, error, fmt}; -use std::{io, net::Shutdown, sync::Arc}; +use std::{io, sync::Arc}; -use xitca_io::io::{AsyncBufRead, AsyncBufWrite, BoundedBuf, BoundedBufMut}; +use xitca_io::io::{AsyncBufRead, AsyncBufWrite}; use xitca_service::Service; -use xitca_tls::rustls_complete::{Error, ServerConfig, TlsStream as _TlsStream, server::UnbufferedServerConnection}; +use xitca_tls::rustls::{Error, ServerConfig, TlsStream as _TlsStream, server::UnbufferedServerConnection}; use crate::{http::Version, version::AsVersion}; @@ -13,17 +13,14 @@ use super::error::TlsError; pub(crate) type RustlsConfig = Arc; /// A stream managed by rustls for tls read/write. -pub struct TlsStream { - inner: _TlsStream, -} +pub type TlsStream = _TlsStream; impl AsVersion for TlsStream { fn as_version(&self) -> Version { - Version::HTTP_11 - // self.inner.session() - // .alpn_protocol() - // .map(Self::from_alpn) - // .unwrap_or(Version::HTTP_11) + self.session() + .alpn_protocol() + .map(Self::from_alpn) + .unwrap_or(Version::HTTP_11) } } @@ -64,38 +61,7 @@ where async fn call(&self, io: Io) -> Result { let conn = UnbufferedServerConnection::new(self.acceptor.clone())?; - let inner = _TlsStream::handshake(io, conn).await?; - Ok(TlsStream { inner }) - } -} - -impl AsyncBufRead for TlsStream -where - Io: AsyncBufRead, -{ - #[inline] - async fn read(&self, buf: B) -> (io::Result, B) - where - B: BoundedBufMut, - { - self.inner.read(buf).await - } -} - -impl AsyncBufWrite for TlsStream -where - Io: AsyncBufWrite, -{ - #[inline] - async fn write(&self, buf: B) -> (io::Result, B) - where - B: BoundedBuf, - { - self.inner.write(buf).await - } - - async fn shutdown(&self, direction: Shutdown) -> io::Result<()> { - self.inner.shutdown(direction).await + _TlsStream::handshake(io, conn).await.map_err(Into::into) } } diff --git a/tls/CHANGES.md 
b/tls/CHANGES.md index eb94d85e3..6c4e28054 100644 --- a/tls/CHANGES.md +++ b/tls/CHANGES.md @@ -1,13 +1,21 @@ # unreleased 0.6.0 ## Remove -- removed `rustls-uring` feature. +- removed `rustls-uring` feature +- removed `Clone` impl from all TlsStream types ## Add -- `rustls` feature would always carries completion asyn IO trait impl from `xitca-io` +- add `native-tls` feature ## Change -- rename `rustls-no-crypto` feature to `rustls` -- rename `rustls` feature to `rustls-aws-crypto` +- Cargo feature name rework + + `<tls-lib>` -> completion based async IO impl + + `<tls-lib>-poll` -> poll based async IO impl + + `rustls` -> follows the same convention as above, with the additional constraint that no crypto provider is enabled + + `rustls-poll-<crypto-provider>` -> specific crypto provider enabled - internal change to reduce memory copy when `io-uring` feature enabled # 0.5.1 diff --git a/tls/Cargo.toml b/tls/Cargo.toml index 63bb50040..967444445 100644 --- a/tls/Cargo.toml +++ b/tls/Cargo.toml @@ -11,13 +11,15 @@ readme= "README.md" [features] openssl = ["dep:openssl"] +openssl-poll = ["openssl", "xitca-io/runtime"] native-tls = ["dep:native_tls_crate"] +native-tls-poll = ["dep:native_tls_crate", "xitca-io/runtime"] # rustls with no default crypto provider rustls = ["dep:rustls_crate"] # rustls with aws-lc as crypto provider (default provider from `rustls` crate) -rustls-aws-crypto = ["rustls", "rustls_crate/aws-lc-rs", "xitca-io/runtime"] +rustls-poll-aws-crypto = ["dep:rustls_crate", "rustls_crate/aws-lc-rs", "xitca-io/runtime"] # rustls with ring as crypto provider -rustls-ring-crypto = ["rustls", "rustls_crate/ring", "xitca-io/runtime"] +rustls-poll-ring-crypto = ["dep:rustls_crate", "rustls_crate/ring", "xitca-io/runtime"] [dependencies] xitca-io = { version = "0.6.0" } diff --git a/tls/src/bridge.rs b/tls/src/bridge.rs index 64b1f2703..29eed8e30 100644 --- a/tls/src/bridge.rs +++ b/tls/src/bridge.rs @@ -9,7 +9,7 @@ use std::io; use xitca_io::{ bytes::{Buf, BytesMut}, - io::{AsyncBufRead, 
AsyncBufWrite}, + io::{AsyncBufRead, AsyncBufWrite, BoundedBuf, BoundedBufMut}, }; /// A synchronous bridge that pairs in-memory buffers with an async IO handle. @@ -17,37 +17,56 @@ use xitca_io::{ /// TLS libraries see `Read + Write` backed by `read_buf` and `write_buf`. /// The caller uses [`fill_read_buf`] and [`drain_write_buf`] to move data /// between the buffers and the underlying async IO. -pub(crate) struct SyncBridge { +pub struct SyncBridge { /// Ciphertext read from the network, consumed by TLS `read`. - pub read_buf: BytesMut, + /// Taken by the read path across await points. + pub read_buf: Option, /// Ciphertext produced by TLS `write`, drained to the network. - pub write_buf: BytesMut, + /// Taken by the write path across await points. + pub write_buf: Option, } impl SyncBridge { - pub fn new() -> Self { + pub(crate) fn new() -> Self { Self { - read_buf: BytesMut::new(), - write_buf: BytesMut::new(), + read_buf: Some(BytesMut::new()), + write_buf: Some(BytesMut::new()), } } + + pub(crate) fn take_write_buf(&mut self) -> BytesMut { + self.write_buf.take().expect(POLL_TO_COMPLETE) + } + + pub(crate) fn take_read_buf(&mut self) -> BytesMut { + self.read_buf.take().expect(POLL_TO_COMPLETE) + } + + pub(crate) fn set_read_buf(&mut self, buf: BytesMut) { + self.read_buf = Some(buf); + } + + pub(crate) fn set_write_buf(&mut self, buf: BytesMut) { + self.write_buf = Some(buf); + } } impl io::Read for SyncBridge { fn read(&mut self, buf: &mut [u8]) -> io::Result { - if self.read_buf.is_empty() { + let read_buf = self.read_buf.as_mut().expect(POLL_TO_COMPLETE); + if read_buf.is_empty() { return Err(io::ErrorKind::WouldBlock.into()); } - let len = buf.len().min(self.read_buf.len()); - buf[..len].copy_from_slice(&self.read_buf[..len]); - self.read_buf.advance(len); + let len = buf.len().min(read_buf.len()); + buf[..len].copy_from_slice(&read_buf[..len]); + read_buf.advance(len); Ok(len) } } impl io::Write for SyncBridge { fn write(&mut self, buf: &[u8]) -> 
io::Result { - self.write_buf.extend_from_slice(buf); + self.write_buf.as_mut().expect(POLL_TO_COMPLETE).extend_from_slice(buf); Ok(buf.len()) } @@ -58,66 +77,55 @@ impl io::Write for SyncBridge { /// Read ciphertext from the network into `bridge.read_buf`. pub(crate) async fn fill_read_buf(io: &impl AsyncBufRead, bridge: &mut SyncBridge) -> io::Result<()> { - let len = bridge.read_buf.len(); - bridge.read_buf.reserve(4096); + let mut buf = bridge.take_read_buf(); + let len = buf.len(); + buf.reserve(4096); - let (res, b) = io.read(bridge.read_buf.split_off(len)).await; - let returned = b; + let (res, buf) = io.read(buf.slice(len..)).await; + bridge.set_read_buf(buf.into_inner()); match res { - Ok(0) => { - bridge.read_buf.unsplit(returned); - Err(io::ErrorKind::UnexpectedEof.into()) - } - Ok(_) => { - bridge.read_buf.unsplit(returned); - Ok(()) - } - Err(e) => { - bridge.read_buf.unsplit(returned); - Err(e) - } + Ok(0) => Err(io::ErrorKind::UnexpectedEof.into()), + Ok(_) => Ok(()), + Err(e) => Err(e), } } /// Drain all ciphertext from `bridge.write_buf` to the network. pub(crate) async fn drain_write_buf(io: &impl AsyncBufWrite, bridge: &mut SyncBridge) -> io::Result<()> { - if bridge.write_buf.is_empty() { - return Ok(()); - } - let buf = bridge.write_buf.split(); - let (res, b) = xitca_io::io::write_all(io, buf).await; - drop(b); + let buf = bridge.take_write_buf(); + + let (res, buf) = drain_write(io, buf).await; + bridge.set_write_buf(buf); + res } -/// Drain a pre-split write buffer to the network. -/// Used when the caller has already split the write_buf out of the bridge -/// (e.g. to drop a RefCell borrow before awaiting). -pub(crate) async fn drain_split(io: &impl AsyncBufWrite, buf: BytesMut) -> io::Result<()> { +/// Drain a taken write buffer to the network. +/// Always returns the buffer (cleared on success) so the caller can put it back. 
+pub(crate) async fn drain_write(io: &impl AsyncBufWrite, buf: BytesMut) -> (io::Result<()>, BytesMut) { if buf.is_empty() { - return Ok(()); + return (Ok(()), buf); } - let (res, _) = xitca_io::io::write_all(io, buf).await; - res -} -/// Split off a read buffer from the bridge for async filling. -/// Reserves space and returns the tail portion for IO. -/// After the read, unsplit the returned buffer back into `bridge.read_buf`. -pub(crate) fn take_read_buf(bridge: &mut SyncBridge) -> BytesMut { - let len = bridge.read_buf.len(); - bridge.read_buf.reserve(4096); - bridge.read_buf.split_off(len) + let (res, mut buf) = xitca_io::io::write_all(io, buf).await; + buf.clear(); + + (res, buf) } -/// Fill a pre-split read buffer from the network. -/// Always returns the buffer (even on error) so the caller can unsplit it back. -pub(crate) async fn fill_split(io: &impl AsyncBufRead, buf: BytesMut) -> (io::Result<()>, BytesMut) { - let (res, buf) = io.read(buf).await; - match res { - Ok(0) => (Err(io::ErrorKind::UnexpectedEof.into()), buf), - Ok(_) => (Ok(()), buf), - Err(e) => (Err(e), buf), - } +/// Get a mutable slice over the spare (uninitialized) capacity of a `BoundedBufMut`. +/// +/// The returned slice is safe to write into freely. It is however unsafe to +/// read from, as the memory may be uninitialized. +/// +/// # Safety +/// +/// The caller must not read from the returned slice beyond what has been written. 
+pub(crate) unsafe fn spare_capacity_mut(buf: &mut impl BoundedBufMut) -> &mut [u8] { + let init = buf.bytes_init(); + let total = buf.bytes_total(); + unsafe { core::slice::from_raw_parts_mut(buf.stable_mut_ptr().add(init), total - init) } } + +const POLL_TO_COMPLETE: &str = "previous call to future didn't polled to completion"; diff --git a/tls/src/lib.rs b/tls/src/lib.rs index ce73850cb..33ed11df1 100644 --- a/tls/src/lib.rs +++ b/tls/src/lib.rs @@ -1,13 +1,15 @@ #[cfg(feature = "native-tls")] -pub mod native_tls_complete; +pub mod native_tls; + #[cfg(feature = "openssl")] pub mod openssl; -#[cfg(feature = "openssl")] -pub mod openssl_complete; -#[cfg(any(feature = "rustls", feature = "rustls-ring-crypto", feature = "rustls-aws-crypto"))] -pub mod rustls; +#[cfg(feature = "openssl-poll")] +pub mod openssl_poll; + #[cfg(feature = "rustls")] -pub mod rustls_complete; +pub mod rustls; +#[cfg(any(feature = "rustls-poll-ring-crypto", feature = "rustls-poll-aws-crypto"))] +pub mod rustls_poll; #[cfg(any(feature = "openssl", feature = "native-tls"))] pub(crate) mod bridge; diff --git a/tls/src/native_tls.rs b/tls/src/native_tls.rs new file mode 100644 index 000000000..fa49d4332 --- /dev/null +++ b/tls/src/native_tls.rs @@ -0,0 +1,236 @@ +#![allow(clippy::await_holding_refcell_ref)] // clippy is dumb + +//! Completion-based async IO wrapper for native-tls TLS streams. + +use core::{ + cell::{Ref, RefCell}, + fmt, +}; + +use std::{io, net::Shutdown}; + +pub use native_tls_crate::{TlsAcceptor, TlsConnector}; + +use native_tls_crate::HandshakeError; +use xitca_io::io::{AsyncBufRead, AsyncBufWrite, BoundedBuf, BoundedBufMut}; + +use crate::bridge::{self, SyncBridge}; + +/// A TLS stream using native-tls with completion-based async IO. +/// +/// Supports one concurrent read + one concurrent write. Concurrent read + read +/// or write + write will panic. 
+pub struct TlsStream { + io: Io, + tls: RefCell>, +} + +impl TlsStream +where + Io: AsyncBufRead + AsyncBufWrite, +{ + /// Perform a TLS server-side accept handshake. + pub async fn accept(acceptor: &TlsAcceptor, io: Io) -> Result { + let bridge = SyncBridge::new(); + Self::handshake(io, acceptor.accept(bridge)).await + } + + /// Perform a TLS client-side connect handshake. + pub async fn connect(connector: &TlsConnector, domain: &str, io: Io) -> Result { + let bridge = SyncBridge::new(); + Self::handshake(io, connector.connect(domain, bridge)).await + } + + async fn handshake( + io: Io, + result: Result, HandshakeError>, + ) -> Result { + let mut mid = match result { + Ok(tls) => { + return Ok(TlsStream { + io, + tls: RefCell::new(tls), + }); + } + Err(HandshakeError::WouldBlock(mid)) => mid, + Err(HandshakeError::Failure(e)) => return Err(Error::Tls(e)), + }; + + loop { + bridge::drain_write_buf(&io, mid.get_mut()).await.map_err(Error::Io)?; + bridge::fill_read_buf(&io, mid.get_mut()).await.map_err(Error::Io)?; + + match mid.handshake() { + Ok(tls) => { + return Ok(TlsStream { + io, + tls: RefCell::new(tls), + }); + } + Err(HandshakeError::WouldBlock(m)) => mid = m, + Err(HandshakeError::Failure(e)) => return Err(Error::Tls(e)), + } + } + } +} + +impl TlsStream { + /// Borrows the underlying native-tls stream (e.g. to query the negotiated ALPN protocol). + pub fn session(&self) -> Ref<'_, native_tls_crate::TlsStream> { + self.tls.borrow() + } +} + +impl AsyncBufRead for TlsStream +where + Io: AsyncBufRead + AsyncBufWrite, +{ + async fn read(&self, mut buf: B) -> (io::Result, B) + where + B: BoundedBufMut, + { + // SAFETY: only write to spare slice without reading it. + let spare = unsafe { bridge::spare_capacity_mut(&mut buf) }; + let res = self.read_tls(spare).await; + + if let Ok(n) = &res { + let init = buf.bytes_init(); + // SAFETY: read_tls writes contiguously from the start of the spare + // slice, returning exactly n bytes written. init + n is the new + // initialized boundary. 
+ unsafe { buf.set_init(init + n) }; + } + + (res, buf) + } +} + +impl TlsStream +where + Io: AsyncBufRead + AsyncBufWrite, +{ + /// Read path: only reads from the IO socket, never writes. + /// Protocol write data (key updates, alerts) is stashed in + /// `proto_write_buf` for the write path to flush. + /// Read path: only reads from the IO socket, never writes. + /// Protocol write data (key updates, alerts) is stashed in + /// `proto_write_buf` for the write path to flush. + async fn read_tls(&self, buf: &mut [u8]) -> io::Result { + let mut tls = self.tls.borrow_mut(); + + loop { + match io::Read::read(&mut *tls, buf) { + Ok(n) => return Ok(n), + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { + // Take read_buf — panics if another read is in progress. + let mut read_buf = tls.get_mut().take_read_buf(); + + let len = read_buf.len(); + read_buf.reserve(4096); + + drop(tls); + + let (res, buf) = self.io.read(read_buf.slice(len..)).await; + + tls = self.tls.borrow_mut(); + // Feed new data to bridge for next tls.read() iteration. + tls.get_mut().set_read_buf(buf.into_inner()); + + match res { + Ok(0) => return Err(io::ErrorKind::UnexpectedEof.into()), + Ok(_) => {} + Err(e) => return Err(e), + } + } + Err(e) => return Err(e), + } + } + } +} + +impl AsyncBufWrite for TlsStream +where + Io: AsyncBufRead + AsyncBufWrite, +{ + async fn write(&self, buf: B) -> (io::Result, B) + where + B: BoundedBuf, + { + let data = buf.chunk(); + let res = self.write_tls(data).await; + (res, buf) + } + + async fn shutdown(&self, _direction: Shutdown) -> io::Result<()> { + Ok(()) + } +} + +impl TlsStream +where + Io: AsyncBufRead + AsyncBufWrite, +{ + /// Write path: owns the IO write side exclusively. + /// Flushes protocol data buffered by the read path before its own ciphertext. 
+ async fn write_tls(&self, buf: &[u8]) -> io::Result { + let mut tls = self.tls.borrow_mut(); + + loop { + match io::Write::write(&mut *tls, buf) { + Ok(n) => { + let buf = tls.get_mut().take_write_buf(); + drop(tls); + + let (res, buf) = bridge::drain_write(&self.io, buf).await; + + tls = self.tls.borrow_mut(); + tls.get_mut().set_write_buf(buf); + + return res.map(|_| n); + } + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { + let buf = tls.get_mut().take_write_buf(); + drop(tls); + + let (res, buf) = bridge::drain_write(&self.io, buf).await; + + tls = self.tls.borrow_mut(); + tls.get_mut().set_write_buf(buf); + + res?; + } + Err(e) => return Err(e), + } + } + } +} + +/// Collection of native-tls error types. +#[derive(Debug)] +pub enum Error { + Io(io::Error), + Tls(native_tls_crate::Error), +} + +impl From for Error { + fn from(e: io::Error) -> Self { + Self::Io(e) + } +} + +impl From for Error { + fn from(e: native_tls_crate::Error) -> Self { + Self::Tls(e) + } +} + +impl fmt::Display for Error { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::Io(e) => fmt::Display::fmt(e, f), + Self::Tls(e) => fmt::Display::fmt(e, f), + } + } +} + +impl std::error::Error for Error {} diff --git a/tls/src/native_tls_complete.rs b/tls/src/native_tls_complete.rs deleted file mode 100644 index a222f4bc3..000000000 --- a/tls/src/native_tls_complete.rs +++ /dev/null @@ -1,246 +0,0 @@ -//! Completion-based async IO wrapper for native-tls TLS streams. - -use core::{cell::RefCell, fmt}; - -use std::{io, net::Shutdown}; - -use native_tls_crate::{HandshakeError, TlsAcceptor, TlsConnector}; - -use xitca_io::{ - bytes::BytesMut, - io::{AsyncBufRead, AsyncBufWrite, BoundedBuf, BoundedBufMut}, -}; - -use crate::bridge::{self, SyncBridge}; - -/// A TLS stream using native-tls with completion-based async IO. -/// -/// Supports concurrent read + write from separate tasks. Concurrent read + read -/// or write + write will panic. 
-pub struct TlsStream { - io: Io, - session: RefCell, -} - -struct Session { - tls: native_tls_crate::TlsStream, - /// Taken by the read path. Carries network data across await points. - read_buf: Option, - /// Taken by the write path. Serves as a concurrent-write guard. - write_buf: Option, -} - -const POLL_TO_COMPLETE: &str = "previous call to future dropped before polling to completion"; - -impl TlsStream -where - Io: AsyncBufRead + AsyncBufWrite, -{ - /// Perform a TLS server-side accept handshake. - pub async fn accept(acceptor: &TlsAcceptor, io: Io) -> Result { - let bridge = SyncBridge::new(); - Self::handshake(io, acceptor.accept(bridge)).await - } - - /// Perform a TLS client-side connect handshake. - pub async fn connect(connector: &TlsConnector, domain: &str, io: Io) -> Result { - let bridge = SyncBridge::new(); - Self::handshake(io, connector.connect(domain, bridge)).await - } - - async fn handshake( - io: Io, - result: Result, HandshakeError>, - ) -> Result { - let mut mid = match result { - Ok(tls) => return Ok(Self::from_tls(io, tls)), - Err(HandshakeError::WouldBlock(mid)) => mid, - Err(HandshakeError::Failure(e)) => return Err(Error::Tls(e)), - }; - - loop { - bridge::drain_write_buf(&io, mid.get_mut()).await.map_err(Error::Io)?; - bridge::fill_read_buf(&io, mid.get_mut()).await.map_err(Error::Io)?; - - match mid.handshake() { - Ok(tls) => return Ok(Self::from_tls(io, tls)), - Err(HandshakeError::WouldBlock(m)) => mid = m, - Err(HandshakeError::Failure(e)) => return Err(Error::Tls(e)), - } - } - } - - fn from_tls(io: Io, tls: native_tls_crate::TlsStream) -> Self { - TlsStream { - io, - session: RefCell::new(Session { - tls, - read_buf: Some(BytesMut::new()), - write_buf: Some(BytesMut::new()), - }), - } - } -} - -impl AsyncBufRead for TlsStream -where - Io: AsyncBufRead + AsyncBufWrite, -{ - async fn read(&self, mut buf: B) -> (io::Result, B) - where - B: BoundedBufMut, - { - let init = buf.bytes_init(); - let total = buf.bytes_total(); - let spare 
= unsafe { core::slice::from_raw_parts_mut(buf.stable_mut_ptr().add(init), total - init) }; - - let res = self.read_tls(spare).await; - - if let Ok(n) = &res { - unsafe { buf.set_init(init + n) }; - } - - (res, buf) - } -} - -impl TlsStream -where - Io: AsyncBufRead + AsyncBufWrite, -{ - async fn read_tls(&self, buf: &mut [u8]) -> io::Result { - let mut session = self.session.borrow_mut(); - let mut read_buf = session.read_buf.take().expect(POLL_TO_COMPLETE); - - // Return previously fetched network data to bridge. - session.tls.get_mut().read_buf.unsplit(read_buf); - - let res = loop { - match io::Read::read(&mut session.tls, buf) { - Ok(n) => { - let proto_data = session.tls.get_mut().write_buf.split(); - drop(session); - - let drain_res = bridge::drain_split(&self.io, proto_data).await; - - session = self.session.borrow_mut(); - if let Err(e) = drain_res { - break Err(e); - } - break Ok(n); - } - Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { - let proto_data = session.tls.get_mut().write_buf.split(); - read_buf = bridge::take_read_buf(session.tls.get_mut()); - drop(session); - - let drain_res = bridge::drain_split(&self.io, proto_data).await; - let (fill_res, b) = bridge::fill_split(&self.io, read_buf).await; - read_buf = b; - - session = self.session.borrow_mut(); - session.tls.get_mut().read_buf.unsplit(read_buf); - - drain_res?; - if let Err(e) = fill_res { - break Err(e); - } - } - Err(e) => break Err(e), - } - }; - - session.read_buf = Some(BytesMut::new()); - res - } -} - -impl AsyncBufWrite for TlsStream -where - Io: AsyncBufRead + AsyncBufWrite, -{ - async fn write(&self, buf: B) -> (io::Result, B) - where - B: BoundedBuf, - { - let data = buf.chunk(); - let res = self.write_tls(data).await; - (res, buf) - } - - async fn shutdown(&self, _direction: Shutdown) -> io::Result<()> { - Ok(()) - } -} - -impl TlsStream -where - Io: AsyncBufRead + AsyncBufWrite, -{ - async fn write_tls(&self, buf: &[u8]) -> io::Result { - let mut session = 
self.session.borrow_mut(); - let write_buf = session.write_buf.take().expect(POLL_TO_COMPLETE); - - let res = loop { - match io::Write::write(&mut session.tls, buf) { - Ok(n) => { - let ciphertext = session.tls.get_mut().write_buf.split(); - drop(session); - - let drain_res = bridge::drain_split(&self.io, ciphertext).await; - - session = self.session.borrow_mut(); - if let Err(e) = drain_res { - break Err(e); - } - break Ok(n); - } - Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { - let ciphertext = session.tls.get_mut().write_buf.split(); - drop(session); - - let drain_res = bridge::drain_split(&self.io, ciphertext).await; - - session = self.session.borrow_mut(); - if let Err(e) = drain_res { - break Err(e); - } - } - Err(e) => break Err(e), - } - }; - - session.write_buf = Some(write_buf); - res - } -} - -/// Collection of native-tls error types. -#[derive(Debug)] -pub enum Error { - Io(io::Error), - Tls(native_tls_crate::Error), -} - -impl From for Error { - fn from(e: io::Error) -> Self { - Self::Io(e) - } -} - -impl From for Error { - fn from(e: native_tls_crate::Error) -> Self { - Self::Tls(e) - } -} - -impl fmt::Display for Error { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - Self::Io(e) => fmt::Display::fmt(e, f), - Self::Tls(e) => fmt::Display::fmt(e, f), - } - } -} - -impl std::error::Error for Error {} diff --git a/tls/src/openssl.rs b/tls/src/openssl.rs index 541dd3553..62fae9020 100644 --- a/tls/src/openssl.rs +++ b/tls/src/openssl.rs @@ -1,115 +1,219 @@ +#![allow(clippy::await_holding_refcell_ref)] // clippy is dumb + +//! Completion-based async IO wrapper for OpenSSL TLS streams. 
+ use core::{ - future::Future, - pin::Pin, - task::{Context, Poll}, + cell::{Ref, RefCell}, + fmt, }; -use std::io; +use std::{io, net::Shutdown}; pub use openssl::*; -use openssl::ssl::{ErrorCode, ShutdownResult, Ssl, SslRef, SslStream}; -use xitca_io::io::{AsyncIo, Interest, Ready}; +use openssl::ssl::{ErrorCode, Ssl, SslRef, SslStream}; +use xitca_io::io::{AsyncBufRead, AsyncBufWrite, BoundedBuf, BoundedBufMut}; -pub use super::openssl_complete::Error; +use crate::bridge::{self, SyncBridge}; -/// A stream managed by `openssl` crate for tls read/write. +/// A TLS stream using OpenSSL with completion-based async IO. +/// +/// Supports one concurrent read + one concurrent write. Concurrent read + read +/// or write + write will panic. pub struct TlsStream { - io: SslStream, + io: Io, + tls: RefCell>, } impl TlsStream where - Io: AsyncIo, + Io: AsyncBufRead + AsyncBufWrite, { - /// acquire a reference to the session type. - pub fn session(&self) -> &SslRef { - self.io.ssl() - } - + /// Perform a TLS server-side accept handshake. pub async fn accept(ssl: Ssl, io: Io) -> Result { - Self::connect_or_accept(ssl, io, |io| io.accept()).await + Self::handshake(ssl, io, |tls| tls.accept()).await } + /// Perform a TLS client-side connect handshake. 
pub async fn connect(ssl: Ssl, io: Io) -> Result { - Self::connect_or_accept(ssl, io, |io| io.connect()).await + Self::handshake(ssl, io, |tls| tls.connect()).await } - async fn connect_or_accept(ssl: Ssl, io: Io, mut func: F) -> Result + async fn handshake(ssl: Ssl, io: Io, mut func: F) -> Result where - F: FnMut(&mut SslStream) -> Result<(), openssl::ssl::Error>, + F: FnMut(&mut SslStream) -> Result<(), openssl::ssl::Error>, { - let mut io = SslStream::new(ssl, io)?; - let mut interest = Interest::READABLE | Interest::WRITABLE; + let bridge = SyncBridge::new(); + let mut tls = SslStream::new(ssl, bridge)?; + loop { - io.get_mut().ready(interest).await.map_err(Error::Io)?; - match func(&mut io) { - Ok(_) => return Ok(TlsStream { io }), - Err(ref e) if e.code() == ErrorCode::WANT_READ => { - interest = Interest::READABLE; - } - Err(ref e) if e.code() == ErrorCode::WANT_WRITE => { - interest = Interest::WRITABLE; + let res = func(&mut tls); + + bridge::drain_write_buf(&io, tls.get_mut()).await.map_err(Error::Io)?; + bridge::fill_read_buf(&io, tls.get_mut()).await.map_err(Error::Io)?; + + match res { + Ok(_) => { + return Ok(TlsStream { + io, + tls: RefCell::new(tls), + }); } + Err(ref e) if e.code() == ErrorCode::WANT_READ || e.code() == ErrorCode::WANT_WRITE => {} Err(e) => return Err(Error::Tls(e)), } } } -} -impl AsyncIo for TlsStream { - #[inline] - fn ready(&mut self, interest: Interest) -> impl Future> + Send { - self.io.get_mut().ready(interest) + /// Acquire a reference to the SSL session. 
+ pub fn session(&self) -> Ref<'_, SslRef> { + let tls = self.tls.borrow(); + Ref::map(tls, |tls| tls.ssl()) } +} - #[inline] - fn poll_ready(&mut self, interest: Interest, cx: &mut Context<'_>) -> Poll> { - self.io.get_mut().poll_ready(interest, cx) - } +impl AsyncBufRead for TlsStream +where + Io: AsyncBufRead + AsyncBufWrite, +{ + async fn read(&self, mut buf: B) -> (io::Result, B) + where + B: BoundedBufMut, + { + // SAFETY: only write to spare slice without reading it. + let spare = unsafe { bridge::spare_capacity_mut(&mut buf) }; + let res = self.read_tls(spare).await; + + if let Ok(n) = &res { + let init = buf.bytes_init(); + // SAFETY: read_tls writes contiguously from the start of the spare + // slice, returning exactly n bytes written. init + n is the new + // initialized boundary. + unsafe { buf.set_init(init + n) }; + } - fn is_vectored_write(&self) -> bool { - self.io.get_ref().is_vectored_write() + (res, buf) } +} - fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let this = self.get_mut(); - // copied from tokio-openssl crate. - match this.io.shutdown() { - Ok(ShutdownResult::Sent) | Ok(ShutdownResult::Received) => {} - Err(ref e) if e.code() == ErrorCode::ZERO_RETURN => {} - Err(ref e) if e.code() == ErrorCode::WANT_READ || e.code() == ErrorCode::WANT_WRITE => { - return Poll::Pending; - } - Err(e) => { - return Poll::Ready(Err(e.into_io_error().unwrap_or_else(io::Error::other))); +impl TlsStream +where + Io: AsyncBufRead + AsyncBufWrite, +{ + /// Read path: only reads from the IO socket, never writes. + /// Protocol write data (key updates, alerts) is stashed in + /// `proto_write_buf` for the write path to flush. + async fn read_tls(&self, buf: &mut [u8]) -> io::Result { + let mut tls = self.tls.borrow_mut(); + + loop { + match io::Read::read(&mut *tls, buf) { + Ok(n) => return Ok(n), + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { + // Take read_buf — panics if another read is in progress. 
+ let mut read_buf = tls.get_mut().take_read_buf(); + + // Prepare read_buf for network fill. + let len = read_buf.len(); + read_buf.reserve(4096); + + drop(tls); + + let (res, buf) = self.io.read(read_buf.slice(len..)).await; + + tls = self.tls.borrow_mut(); + // Feed new data to bridge for next tls.read() iteration. + tls.get_mut().set_read_buf(buf.into_inner()); + + match res { + Ok(0) => return Err(io::ErrorKind::UnexpectedEof.into()), + Ok(_) => {} + Err(e) => return Err(e), + } + } + Err(e) => return Err(e), } } - - AsyncIo::poll_shutdown(Pin::new(this.io.get_mut()), cx) } } -impl io::Read for TlsStream { - #[inline] - fn read(&mut self, buf: &mut [u8]) -> io::Result { - io::Read::read(&mut self.io, buf) +impl AsyncBufWrite for TlsStream +where + Io: AsyncBufRead + AsyncBufWrite, +{ + async fn write(&self, buf: B) -> (io::Result, B) + where + B: BoundedBuf, + { + let data = buf.chunk(); + let res = self.write_tls(data).await; + (res, buf) + } + + async fn shutdown(&self, _direction: Shutdown) -> io::Result<()> { + Ok(()) } } -impl io::Write for TlsStream { - #[inline] - fn write(&mut self, buf: &[u8]) -> io::Result { - io::Write::write(&mut self.io, buf) +impl TlsStream +where + Io: AsyncBufRead + AsyncBufWrite, +{ + /// Write path: owns the IO write side exclusively. + /// Flushes protocol data buffered by the read path before its own ciphertext. 
+ async fn write_tls(&self, buf: &[u8]) -> io::Result { + let mut tls = self.tls.borrow_mut(); + + loop { + match io::Write::write(&mut *tls, buf) { + Ok(n) => { + let buf = tls.get_mut().take_write_buf(); + drop(tls); + + let (res, buf) = bridge::drain_write(&self.io, buf).await; + + tls = self.tls.borrow_mut(); + tls.get_mut().set_write_buf(buf); + + return res.map(|_| n); + } + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { + let buf = tls.get_mut().take_write_buf(); + drop(tls); + + let (res, buf) = bridge::drain_write(&self.io, buf).await; + + tls = self.tls.borrow_mut(); + tls.get_mut().set_write_buf(buf); + + res?; + } + Err(e) => return Err(e), + } + } } +} + +/// Collection of OpenSSL error types. +#[derive(Debug)] +pub enum Error { + Io(io::Error), + Tls(openssl::ssl::Error), +} - #[inline] - fn write_vectored(&mut self, bufs: &[io::IoSlice<'_>]) -> io::Result { - io::Write::write_vectored(&mut self.io, bufs) +impl From for Error { + fn from(e: openssl::error::ErrorStack) -> Self { + Self::Tls(e.into()) } +} - #[inline] - fn flush(&mut self) -> io::Result<()> { - io::Write::flush(&mut self.io) +impl fmt::Display for Error { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::Io(e) => fmt::Display::fmt(e, f), + Self::Tls(e) => fmt::Display::fmt(e, f), + } } } + +impl std::error::Error for Error {} diff --git a/tls/src/openssl_complete.rs b/tls/src/openssl_complete.rs deleted file mode 100644 index 442517748..000000000 --- a/tls/src/openssl_complete.rs +++ /dev/null @@ -1,188 +0,0 @@ -//! Completion-based async IO wrapper for OpenSSL TLS streams. - -use core::{cell::RefCell, fmt}; - -use std::{io, net::Shutdown}; - -pub use openssl::*; - -use openssl::ssl::{ErrorCode, Ssl, SslRef, SslStream}; - -use xitca_io::io::{AsyncBufRead, AsyncBufWrite, BoundedBuf, BoundedBufMut}; - -use crate::bridge::{self, SyncBridge}; - -/// A TLS stream using OpenSSL with completion-based async IO. 
-pub struct TlsStream { - io: Io, - tls: RefCell>, -} - -impl TlsStream -where - Io: AsyncBufRead + AsyncBufWrite, -{ - /// Perform a TLS server-side accept handshake. - pub async fn accept(ssl: Ssl, io: Io) -> Result { - Self::handshake(ssl, io, |tls| tls.accept()).await - } - - /// Perform a TLS client-side connect handshake. - pub async fn connect(ssl: Ssl, io: Io) -> Result { - Self::handshake(ssl, io, |tls| tls.connect()).await - } - - async fn handshake(ssl: Ssl, io: Io, mut func: F) -> Result - where - F: FnMut(&mut SslStream) -> Result<(), openssl::ssl::Error>, - { - let bridge = SyncBridge::new(); - let mut tls = SslStream::new(ssl, bridge)?; - - loop { - match func(&mut tls) { - Ok(_) => { - return Ok(TlsStream { - io, - tls: RefCell::new(tls), - }); - } - Err(ref e) if e.code() == ErrorCode::WANT_READ => { - bridge::drain_write_buf(&io, tls.get_mut()).await.map_err(Error::Io)?; - bridge::fill_read_buf(&io, tls.get_mut()).await.map_err(Error::Io)?; - } - Err(ref e) if e.code() == ErrorCode::WANT_WRITE => { - bridge::drain_write_buf(&io, tls.get_mut()).await.map_err(Error::Io)?; - } - Err(e) => return Err(Error::Tls(e)), - } - } - } - - /// Acquire a reference to the SSL session. - pub fn session(&self) -> &SslRef { - let tls = self.tls.borrow(); - // SAFETY: SslRef points into the heap-allocated OpenSSL context, not - // the RefCell guard. The context lives as long as self. 
- unsafe { &*(tls.ssl() as *const SslRef) } - } -} - -impl AsyncBufRead for TlsStream -where - Io: AsyncBufRead + AsyncBufWrite, -{ - async fn read(&self, mut buf: B) -> (io::Result, B) - where - B: BoundedBufMut, - { - let init = buf.bytes_init(); - let total = buf.bytes_total(); - let spare = unsafe { core::slice::from_raw_parts_mut(buf.stable_mut_ptr().add(init), total - init) }; - - let res = self.read_tls(spare).await; - - if let Ok(n) = &res { - unsafe { buf.set_init(init + n) }; - } - - (res, buf) - } -} - -impl TlsStream -where - Io: AsyncBufRead + AsyncBufWrite, -{ - async fn read_tls(&self, buf: &mut [u8]) -> io::Result { - loop { - let mut tls = self.tls.borrow_mut(); - match io::Read::read(&mut *tls, buf) { - Ok(n) => { - let write_data = tls.get_mut().write_buf.split(); - drop(tls); - bridge::drain_split(&self.io, write_data).await?; - return Ok(n); - } - Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { - let write_data = tls.get_mut().write_buf.split(); - let read_buf = bridge::take_read_buf(tls.get_mut()); - drop(tls); - - bridge::drain_split(&self.io, write_data).await?; - let read_buf = bridge::fill_split(&self.io, read_buf).await?; - - self.tls.borrow_mut().get_mut().read_buf.unsplit(read_buf); - } - Err(e) => return Err(e), - } - } - } -} - -impl AsyncBufWrite for TlsStream -where - Io: AsyncBufRead + AsyncBufWrite, -{ - async fn write(&self, buf: B) -> (io::Result, B) - where - B: BoundedBuf, - { - let data = buf.chunk(); - let res = self.write_tls(data).await; - (res, buf) - } - - async fn shutdown(&self, _direction: Shutdown) -> io::Result<()> { - Ok(()) - } -} - -impl TlsStream -where - Io: AsyncBufRead + AsyncBufWrite, -{ - async fn write_tls(&self, buf: &[u8]) -> io::Result { - loop { - let mut tls = self.tls.borrow_mut(); - match io::Write::write(&mut *tls, buf) { - Ok(n) => { - let write_data = tls.get_mut().write_buf.split(); - drop(tls); - bridge::drain_split(&self.io, write_data).await?; - return Ok(n); - } - Err(ref e) if 
e.kind() == io::ErrorKind::WouldBlock => { - let write_data = tls.get_mut().write_buf.split(); - drop(tls); - bridge::drain_split(&self.io, write_data).await?; - } - Err(e) => return Err(e), - } - } - } -} - -/// Collection of OpenSSL error types. -#[derive(Debug)] -pub enum Error { - Io(io::Error), - Tls(openssl::ssl::Error), -} - -impl From for Error { - fn from(e: openssl::error::ErrorStack) -> Self { - Self::Tls(e.into()) - } -} - -impl fmt::Display for Error { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - Self::Io(e) => fmt::Display::fmt(e, f), - Self::Tls(e) => fmt::Display::fmt(e, f), - } - } -} - -impl std::error::Error for Error {} diff --git a/tls/src/openssl_poll.rs b/tls/src/openssl_poll.rs new file mode 100644 index 000000000..fe3beaa30 --- /dev/null +++ b/tls/src/openssl_poll.rs @@ -0,0 +1,115 @@ +use core::{ + future::Future, + pin::Pin, + task::{Context, Poll}, +}; + +use std::io; + +pub use openssl::*; + +use openssl::ssl::{ErrorCode, ShutdownResult, Ssl, SslRef, SslStream}; +use xitca_io::io::{AsyncIo, Interest, Ready}; + +pub use super::openssl::Error; + +/// A stream managed by `openssl` crate for tls read/write. +pub struct TlsStream { + io: SslStream, +} + +impl TlsStream +where + Io: AsyncIo, +{ + /// acquire a reference to the session type. 
+ pub fn session(&self) -> &SslRef { + self.io.ssl() + } + + pub async fn accept(ssl: Ssl, io: Io) -> Result { + Self::connect_or_accept(ssl, io, |io| io.accept()).await + } + + pub async fn connect(ssl: Ssl, io: Io) -> Result { + Self::connect_or_accept(ssl, io, |io| io.connect()).await + } + + async fn connect_or_accept(ssl: Ssl, io: Io, mut func: F) -> Result + where + F: FnMut(&mut SslStream) -> Result<(), openssl::ssl::Error>, + { + let mut io = SslStream::new(ssl, io)?; + let mut interest = Interest::READABLE | Interest::WRITABLE; + loop { + io.get_mut().ready(interest).await.map_err(Error::Io)?; + match func(&mut io) { + Ok(_) => return Ok(TlsStream { io }), + Err(ref e) if e.code() == ErrorCode::WANT_READ => { + interest = Interest::READABLE; + } + Err(ref e) if e.code() == ErrorCode::WANT_WRITE => { + interest = Interest::WRITABLE; + } + Err(e) => return Err(Error::Tls(e)), + } + } + } +} + +impl AsyncIo for TlsStream { + #[inline] + fn ready(&mut self, interest: Interest) -> impl Future> + Send { + self.io.get_mut().ready(interest) + } + + #[inline] + fn poll_ready(&mut self, interest: Interest, cx: &mut Context<'_>) -> Poll> { + self.io.get_mut().poll_ready(interest, cx) + } + + fn is_vectored_write(&self) -> bool { + self.io.get_ref().is_vectored_write() + } + + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.get_mut(); + // copied from tokio-openssl crate. 
+ match this.io.shutdown() { + Ok(ShutdownResult::Sent) | Ok(ShutdownResult::Received) => {} + Err(ref e) if e.code() == ErrorCode::ZERO_RETURN => {} + Err(ref e) if e.code() == ErrorCode::WANT_READ || e.code() == ErrorCode::WANT_WRITE => { + return Poll::Pending; + } + Err(e) => { + return Poll::Ready(Err(e.into_io_error().unwrap_or_else(io::Error::other))); + } + } + + AsyncIo::poll_shutdown(Pin::new(this.io.get_mut()), cx) + } +} + +impl io::Read for TlsStream { + #[inline] + fn read(&mut self, buf: &mut [u8]) -> io::Result { + io::Read::read(&mut self.io, buf) + } +} + +impl io::Write for TlsStream { + #[inline] + fn write(&mut self, buf: &[u8]) -> io::Result { + io::Write::write(&mut self.io, buf) + } + + #[inline] + fn write_vectored(&mut self, bufs: &[io::IoSlice<'_>]) -> io::Result { + io::Write::write_vectored(&mut self.io, bufs) + } + + #[inline] + fn flush(&mut self) -> io::Result<()> { + io::Write::flush(&mut self.io) + } +} diff --git a/tls/src/rustls.rs b/tls/src/rustls.rs index 1c9d61ee0..cf495b39b 100644 --- a/tls/src/rustls.rs +++ b/tls/src/rustls.rs @@ -1,182 +1,574 @@ +#![allow(clippy::await_holding_refcell_ref)] // clippy is dumb + use core::{ - future::Future, - ops::DerefMut, - pin::Pin, - task::{Context, Poll}, + cell::{Ref, RefCell}, + cmp, + ops::Deref, + slice, }; -use std::io; +use std::{io, net::Shutdown}; pub use rustls_crate::*; -use xitca_io::io::{AsyncIo, Interest, Ready}; +use rustls_crate::{ + client::UnbufferedClientConnection, + server::UnbufferedServerConnection, + unbuffered::UnbufferedConnectionCommon, + unbuffered::{ConnectionState, EncryptError, UnbufferedStatus}, +}; + +use xitca_io::{ + bytes::{Buf, BytesMut}, + io::{AsyncBufRead, AsyncBufWrite, BoundedBuf, BoundedBufMut}, +}; + +/// Trait to abstract over `UnbufferedServerConnection` and `UnbufferedClientConnection`, +/// since `process_tls_records` is not on a shared trait in rustls. 
+#[doc(hidden)] +pub trait ProcessTlsRecords: sealed::Sealed { + type Data; + + fn process_tls_records<'c, 'i>(&'c mut self, incoming_tls: &'i mut [u8]) -> UnbufferedStatus<'c, 'i, Self::Data>; +} + +mod sealed { + pub trait Sealed {} + impl Sealed for super::UnbufferedServerConnection {} + impl Sealed for super::UnbufferedClientConnection {} +} + +impl ProcessTlsRecords for UnbufferedServerConnection { + type Data = server::ServerConnectionData; -/// A stream managed by `rustls` crate for tls read/write. + fn process_tls_records<'c, 'i>(&'c mut self, incoming_tls: &'i mut [u8]) -> UnbufferedStatus<'c, 'i, Self::Data> { + let inner: &mut UnbufferedConnectionCommon = self; + inner.process_tls_records(incoming_tls) + } +} + +impl ProcessTlsRecords for UnbufferedClientConnection { + type Data = client::ClientConnectionData; + + fn process_tls_records<'c, 'i>(&'c mut self, incoming_tls: &'i mut [u8]) -> UnbufferedStatus<'c, 'i, Self::Data> { + let inner: &mut UnbufferedConnectionCommon = self; + inner.process_tls_records(incoming_tls) + } +} + +/// Reduced `ConnectionState` that doesn't borrow the connection or incoming buffer. +/// Created by draining all needed data from the borrowed state variants. +/// A TLS stream type that supports concurrent async read/write through [AsyncBufRead] and +/// [AsyncBufWrite] traits. +/// +/// [AsyncBufRead::read] and [AsyncBufWrite::write] can be polled concurrently from separate +/// tasks. The read path owns `read_buf` during IO and the write path owns `write_buf`, so +/// neither blocks the other while awaiting kernel completions. +/// +/// # Panics +/// Each async read/write operation must be polled to completion. Dropping a future before it +/// completes will leave internal buffers in a taken state, causing the next call to panic. pub struct TlsStream { - conn: C, io: Io, + session: RefCell>, +} + +struct Session { + conn: C, + read_buf: Option, + /// Write buffer for application data (used by write path). 
+ write_buf: Option, + /// Write buffer for TLS protocol responses during reads (key updates, alerts). + proto_write_buf: BytesMut, + /// Plaintext buffered from a previous read. + pending_plaintext: BytesMut, } -impl TlsStream +impl TlsStream where - C: DerefMut>, - S: SideData, - Io: io::Read + io::Write, + C: Deref>, { - fn process_new_packets(&mut self) -> io::Result<()> { - match self.conn.process_new_packets() { - Ok(_) => Ok(()), - Err(e) => { - // In case we have an alert to send describing this error, - // try a last-gasp write -- but don't predate the primary - // error. - let _ = self.write_tls(); - - Err(io::Error::new(io::ErrorKind::InvalidData, e)) - } - } + /// Returns the negotiated ALPN protocol, if any. + pub fn session(&self) -> Ref<'_, CommonState> { + let session = self.session.borrow(); + Ref::map(session, |session| &**session.conn) } +} - fn write_tls(&mut self) -> io::Result { - self.conn.write_tls(&mut self.io) +impl TlsStream +where + C: ProcessTlsRecords, + Io: AsyncBufRead + AsyncBufWrite, +{ + pub async fn handshake(io: Io, conn: C) -> io::Result { + let stream = TlsStream { + io, + session: RefCell::new(Session { + conn, + read_buf: Some(BytesMut::new()), + write_buf: Some(BytesMut::new()), + proto_write_buf: BytesMut::new(), + pending_plaintext: BytesMut::new(), + }), + }; + stream._handshake().await?; + Ok(stream) } - fn read_tls(&mut self) -> io::Result { - self.conn.read_tls(&mut self.io) + async fn _handshake(&self) -> io::Result<()> { + let mut session = self.session.borrow_mut(); + let mut read_buf = session.read_buf.take().expect(POLL_TO_COMPLETE); + let mut proto_write_buf = session.proto_write_buf.split(); + + let res = loop { + let UnbufferedStatus { discard, state } = session.conn.process_tls_records(read_buf.as_mut()); + + let res = match state.map_err(tls_err) { + Err(e) => { + read_buf.advance(discard); + Err(e) + } + + Ok(ConnectionState::EncodeTlsData(mut state)) => { + let enc_res = encode_tls_data(&mut state, &mut 
proto_write_buf); + drop(state); + read_buf.advance(discard); + enc_res?; + continue; + } + + Ok(ConnectionState::TransmitTlsData(state)) => { + state.done(); + read_buf.advance(discard); + + let (res, b) = write_all_buf(&self.io, proto_write_buf).await; + proto_write_buf = b; + res + } + + Ok(ConnectionState::BlockedHandshake) => { + read_buf.advance(discard); + + let (res, b) = read_to_buf(&self.io, read_buf).await; + read_buf = b; + res + } + + Ok(ConnectionState::WriteTraffic(_) | ConnectionState::ReadTraffic(_)) => { + read_buf.advance(discard); + break Ok(()); + } + + Ok(ConnectionState::PeerClosed | ConnectionState::Closed) => { + read_buf.advance(discard); + Err(io::Error::new(io::ErrorKind::UnexpectedEof, "tls handshake eof")) + } + + Ok(_) => { + read_buf.advance(discard); + continue; + } + }; + + if res.is_err() { + break res; + } + }; + + session.read_buf.replace(read_buf); + session.proto_write_buf = proto_write_buf; + res } } -impl TlsStream +impl TlsStream where - C: DerefMut> + Unpin, - S: SideData, - Io: AsyncIo, + C: ProcessTlsRecords, + Io: AsyncBufRead, { - /// acquire a reference to the session type. Typically either [ClientConnection] or [ServerConnection] - /// - /// [ClientConnection]: rustls::ClientConnection - /// [ServerConnection]: rustls::ServerConnection - pub fn session(&self) -> &C { - &self.conn - } + /// Read ciphertext from IO, decrypt, and return plaintext. + async fn read_tls(&self, plain_buf: &mut impl BoundedBufMut) -> io::Result { + let mut session = self.session.borrow_mut(); - /// finish handshake with given io and connection type. 
- /// # Examples: - /// ```rust - /// use std::sync::Arc; - /// - /// use xitca_io::net::TcpStream; - /// use xitca_tls::rustls::{pki_types::ServerName, ClientConfig, ClientConnection, TlsStream}; - /// - /// async fn client_connect(io: TcpStream, cfg: Arc, server_name: ServerName<'static>) { - /// let conn = ClientConnection::new(cfg, server_name).unwrap(); - /// let _stream = TlsStream::handshake(io, conn).await.unwrap(); - /// } - /// ``` - pub async fn handshake(mut io: Io, mut conn: C) -> io::Result { - while conn.is_handshaking() { - if let Err(e) = conn.complete_io(&mut io) { - if !matches!(e.kind(), io::ErrorKind::WouldBlock) { - return Err(e); - } - let interest = match (conn.wants_read(), conn.wants_write()) { - (true, true) => Interest::READABLE | Interest::WRITABLE, - (true, false) => Interest::READABLE, - (false, true) => Interest::WRITABLE, - (false, false) => unreachable!(), - }; - io.ready(interest).await?; - } + // Check for plaintext buffered from a previous read first. + if !session.pending_plaintext.is_empty() { + let rem = plain_buf.bytes_total() - plain_buf.bytes_init(); + let aval = session.pending_plaintext.len(); + let len = cmp::min(rem, aval); + + plain_buf.put_slice(&session.pending_plaintext[..len]); + session.pending_plaintext.advance(len); + + return Ok(len); } - Ok(TlsStream { io, conn }) + let mut read_buf = session.read_buf.take().expect(POLL_TO_COMPLETE); + + let res = loop { + // Call process_tls_records directly to copy record payload + // straight into the caller's buffer (no intermediate BytesMut). 
+ let session_ref = &mut *session; + + let UnbufferedStatus { discard, state } = session_ref.conn.process_tls_records(read_buf.as_mut()); + + let res = match state.map_err(tls_err) { + Err(e) => { + read_buf.advance(discard); + break Err(e); + } + + Ok(ConnectionState::ReadTraffic(mut traffic)) => { + let rem = plain_buf.bytes_total() - plain_buf.bytes_init(); + let mut written = 0; + + let mut err = None; + while let Some(res) = traffic.next_record() { + match res.map_err(tls_err) { + Ok(record) => { + let payload = record.payload; + let len = payload.len().min(rem - written); + + let (head, tail) = payload.split_at(len); + + plain_buf.put_slice(head); + written += len; + + // Buffer overflow into pending_plaintext. + session_ref.pending_plaintext.extend_from_slice(tail); + } + Err(e) => { + err = Some(e); + break; + } + } + } + + read_buf.advance(discard); + + if let Some(e) = err { + break Err(e); + } + + // Empty plaintext means TLS overhead with no payload — keep going. + if written == 0 { + continue; + } + + break Ok(written); + } + + Ok(ConnectionState::EncodeTlsData(mut state)) => { + // Encode into proto_write_buf via session_ref (same borrow scope as state). + let enc_res = encode_tls_data(&mut state, &mut session_ref.proto_write_buf); + drop(state); + read_buf.advance(discard); + + if let Err(e) = enc_res { + break Err(e); + } + continue; + } + + Ok(ConnectionState::TransmitTlsData(state)) => { + // Data is in proto_write_buf. Acknowledge and continue — + // write_tls will flush it on the next write call. + state.done(); + read_buf.advance(discard); + continue; + } + + // Need more ciphertext. 
+ Ok(ConnectionState::BlockedHandshake | ConnectionState::WriteTraffic(_)) => { + read_buf.advance(discard); + + drop(session); + + let (res, b) = read_to_buf(&self.io, read_buf).await; + read_buf = b; + + session = self.session.borrow_mut(); + + res + } + + Ok(ConnectionState::PeerClosed | ConnectionState::Closed) => { + read_buf.advance(discard); + break Ok(0); + } + + Ok(_) => { + read_buf.advance(discard); + continue; + } + }; + + if let Err(e) = res { + break Err(e); + } + }; + + session.read_buf.replace(read_buf); + res } } -impl AsyncIo for TlsStream +impl TlsStream where - C: DerefMut> + Unpin, - S: SideData, - Io: AsyncIo, + C: ProcessTlsRecords, + Io: AsyncBufWrite, { - #[inline] - fn ready(&mut self, interest: Interest) -> impl Future> + Send { - self.io.ready(interest) - } + /// Encrypt plaintext and write all ciphertext to IO. + async fn write_tls(&self, plain: &impl BoundedBuf) -> io::Result { + let mut session = self.session.borrow_mut(); + let mut write_buf = session.write_buf.take().expect(POLL_TO_COMPLETE); + let plaintext = plain.chunk(); - #[inline] - fn poll_ready(&mut self, interest: Interest, cx: &mut Context<'_>) -> Poll> { - self.io.poll_ready(interest, cx) - } + // Flush protocol data buffered by read path (key updates, alerts). + if !session.proto_write_buf.is_empty() { + write_buf.extend_from_slice(&session.proto_write_buf); + session.proto_write_buf.clear(); + } + + let res = loop { + // Pass empty slice — write path doesn't process incoming TLS records. + // Incoming data (key updates, etc.) is handled by the read path. + let UnbufferedStatus { state, .. 
} = session.conn.process_tls_records(&mut []); + + match state.map_err(tls_err) { + Err(e) => break Err(e), + + Ok(ConnectionState::WriteTraffic(mut traffic)) => { + let enc_res = encrypt_to_buf(&mut traffic, plaintext, &mut write_buf); + + if let Err(e) = enc_res { + break Err(e); + } + + drop(session); + + let (res, b) = write_all_buf(&self.io, write_buf).await; + write_buf = b; + + session = self.session.borrow_mut(); - fn is_vectored_write(&self) -> bool { - self.io.is_vectored_write() + break res.map(|_| plaintext.len()); + } + + Ok(ConnectionState::EncodeTlsData(mut state)) => { + let enc_res = encode_tls_data(&mut state, &mut write_buf); + drop(state); + + if let Err(e) = enc_res { + break Err(e); + } + } + + Ok(ConnectionState::TransmitTlsData(state)) => { + state.done(); + + drop(session); + + let (res, b) = write_all_buf(&self.io, write_buf).await; + write_buf = b; + + session = self.session.borrow_mut(); + + if let Err(e) = res { + break Err(e); + } + } + + Ok(ConnectionState::PeerClosed | ConnectionState::Closed) => { + break Err(io::ErrorKind::UnexpectedEof.into()); + } + + Ok(_) => {} + } + }; + + session.write_buf.replace(write_buf); + res } +} - fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - AsyncIo::poll_shutdown(Pin::new(&mut self.get_mut().io), cx) +impl AsyncBufRead for TlsStream +where + C: ProcessTlsRecords, + Io: AsyncBufRead, +{ + async fn read(&self, mut buf: B) -> (io::Result, B) + where + B: BoundedBufMut, + { + let res = self.read_tls(&mut buf).await; + (res, buf) } } -impl io::Read for TlsStream +impl AsyncBufWrite for TlsStream where - C: DerefMut>, - S: SideData, - Io: AsyncIo, + C: ProcessTlsRecords, + Io: AsyncBufWrite, { - fn read(&mut self, buf: &mut [u8]) -> io::Result { - while self.conn.wants_read() { - let n = self.read_tls()?; + async fn write(&self, buf: B) -> (io::Result, B) + where + B: BoundedBuf, + { + let res = self.write_tls(&buf).await; + (res, buf) + } - self.process_new_packets()?; + async 
fn shutdown(&self, direction: Shutdown) -> io::Result<()> { + self.io.shutdown(direction).await + } +} - if n == 0 { - break; - } - } - self.conn.reader().read(buf) +fn tls_err(e: Error) -> io::Error { + io::Error::new(io::ErrorKind::InvalidData, e) +} + +/// Read from IO into a BytesMut, reserving space if needed. +async fn read_to_buf(io: &impl AsyncBufRead, mut buf: BytesMut) -> (io::Result<()>, BytesMut) { + let len = buf.len(); + buf.reserve(4096); + + let (res, b) = io.read(buf.slice(len..)).await; + buf = b.into_inner(); + + match res { + Ok(0) => (Err(io::ErrorKind::UnexpectedEof.into()), buf), + Ok(_) => (Ok(()), buf), + Err(e) => (Err(e), buf), } } -impl io::Write for TlsStream -where - C: DerefMut>, - S: SideData, - Io: AsyncIo, -{ - #[inline] - fn write(&mut self, buf: &[u8]) -> io::Result { - write_with(self, |writer| writer.write(buf)) +/// Write all bytes from a BytesMut to IO, then clear it. +async fn write_all_buf(io: &impl AsyncBufWrite, mut buf: BytesMut) -> (io::Result<()>, BytesMut) { + let (res, b) = xitca_io::io::write_all(io, buf).await; + buf = b; + if res.is_ok() { + buf.clear(); } + (res, buf) +} - #[inline] - fn write_vectored(&mut self, bufs: &[io::IoSlice<'_>]) -> io::Result { - write_with(self, |writer| writer.write_vectored(bufs)) +/// Encode TLS handshake data into the write buffer, resizing if needed. +fn encode_tls_data(state: &mut unbuffered::EncodeTlsData<'_, Data>, write_buf: &mut BytesMut) -> io::Result<()> { + // SAFETY: EncodeTlsData::encode copies a single chunk contiguously from index 0. + // On Ok(n), exactly n bytes are written. On InsufficientSize or AlreadyEncoded, + // the size check happens before any write so the slice is untouched. 
+ while let Err(e) = unsafe { SpareCapBuf::new(write_buf).with_mut_slice(|slice| state.encode(slice)) } { + match e { + unbuffered::EncodeError::InsufficientSize(unbuffered::InsufficientSizeError { required_size }) => { + write_buf.reserve(required_size); + } + e => return Err(io::Error::new(io::ErrorKind::InvalidData, e.to_string())), + } } + Ok(()) +} - fn flush(&mut self) -> io::Result<()> { - while self.conn.wants_write() { - if self.write_tls()? == 0 { - return Err(io::ErrorKind::WriteZero.into()); +/// Encrypt plaintext into the write buffer, resizing if needed. +fn encrypt_to_buf( + traffic: &mut unbuffered::WriteTraffic<'_, Data>, + plaintext: &[u8], + write_buf: &mut BytesMut, +) -> io::Result<()> { + write_buf.reserve(plaintext.len() + 64); + // SAFETY: WriteTraffic::encrypt writes TLS records contiguously from index 0 via + // write_fragments. On Ok(n), exactly n bytes are written. On InsufficientSize, + // check_required_size returns before any write. On EncryptExhausted, the error + // is returned during pre-encryption checks before any write. + while let Err(err) = + unsafe { SpareCapBuf::new(write_buf).with_mut_slice(|spare| traffic.encrypt(plaintext, spare)) } + { + match err { + EncryptError::InsufficientSize(unbuffered::InsufficientSizeError { required_size }) => { + write_buf.reserve(required_size); } + e => return Err(io::Error::new(io::ErrorKind::InvalidData, e.to_string())), } + } + Ok(()) +} + +/// Wraps a `BytesMut`'s spare capacity as a mutable byte slice. +/// +/// Encapsulates the unsafe operations of interpreting spare capacity as `&mut [u8]` +/// and committing written bytes via `set_len`. +struct SpareCapBuf<'a> { + buf: &'a mut BytesMut, +} + +impl<'a> SpareCapBuf<'a> { + fn new(buf: &'a mut BytesMut) -> Self { + Self { buf } + } + + /// # Safety + /// + /// The callback `func` must uphold the following contract: + /// - Writes must be sequential and contiguous, starting from index 0 of the slice. 
+ /// - On `Ok(n)`, exactly `n` bytes must have been written to `slice[..n]`. + /// - On `Err`, zero bytes must have been written into the slice. + unsafe fn with_mut_slice(self, func: F) -> Result<(), E> + where + F: FnOnce(&mut [u8]) -> Result, + { + let spare = self.buf.spare_capacity_mut(); + + // SAFETY: the caller must write into the slice before reading. + // We only expose this for write-before-read patterns (TLS encode/encrypt). + let slice = unsafe { slice::from_raw_parts_mut(spare.as_mut_ptr().cast::(), spare.len()) }; + + let n = func(slice)?; + + // SAFETY: caller guarantees n bytes were written into the spare capacity. + unsafe { self.buf.set_len(self.buf.len() + n) }; + Ok(()) } } -fn write_with(stream: &mut TlsStream, mut func: F) -> io::Result -where - Io: AsyncIo, - C: DerefMut>, - S: SideData, - F: for<'r> FnMut(&mut Writer<'r>) -> io::Result, -{ - loop { - match func(&mut stream.conn.writer())? { - // when rustls writer write 0 it means either the input buffer is empty or it's internal - // buffer is full. check the condition and flush the io. 
- 0 if stream.conn.wants_write() => io::Write::flush(stream)?, - n => return Ok(n), - } +const POLL_TO_COMPLETE: &str = "previous call to future didn't polled to completion"; + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn spare_cap_buf_write_and_commit() { + let mut buf = BytesMut::with_capacity(64); + buf.extend_from_slice(b"hello"); + + let res = unsafe { + SpareCapBuf::new(&mut buf).with_mut_slice(|slice| { + assert!(slice.len() >= 59); + slice[..5].copy_from_slice(b"world"); + Ok::<_, ()>(5) + }) + }; + assert!(res.is_ok()); + assert_eq!(&buf[..], b"helloworld"); + } + + #[test] + fn spare_cap_buf_commit_zero() { + let mut buf = BytesMut::with_capacity(16); + buf.extend_from_slice(b"abc"); + + let res = unsafe { SpareCapBuf::new(&mut buf).with_mut_slice(|_| Ok::<_, ()>(0)) }; + assert!(res.is_ok()); + assert_eq!(&buf[..], b"abc"); + } + + #[test] + fn spare_cap_buf_error_no_commit() { + let mut buf = BytesMut::with_capacity(16); + buf.extend_from_slice(b"abc"); + + let res = unsafe { SpareCapBuf::new(&mut buf).with_mut_slice(|_| Err::("too small")) }; + assert!(res.is_err()); + assert_eq!(&buf[..], b"abc"); } } diff --git a/tls/src/rustls_complete.rs b/tls/src/rustls_complete.rs deleted file mode 100644 index 617c76cd4..000000000 --- a/tls/src/rustls_complete.rs +++ /dev/null @@ -1,558 +0,0 @@ -#![allow(clippy::await_holding_refcell_ref)] // clippy is dumb - -use core::{cell::RefCell, cmp, slice}; - -use std::{io, net::Shutdown}; - -pub use rustls_crate::*; - -use rustls_crate::{ - client::UnbufferedClientConnection, - server::UnbufferedServerConnection, - unbuffered::UnbufferedConnectionCommon, - unbuffered::{ConnectionState, EncryptError, UnbufferedStatus}, -}; - -use xitca_io::{ - bytes::{Buf, BytesMut}, - io::{AsyncBufRead, AsyncBufWrite, BoundedBuf, BoundedBufMut}, -}; - -/// Trait to abstract over `UnbufferedServerConnection` and `UnbufferedClientConnection`, -/// since `process_tls_records` is not on a shared trait in rustls. 
-#[doc(hidden)] -pub trait ProcessTlsRecords: sealed::Sealed { - type Data; - - fn process_tls_records<'c, 'i>(&'c mut self, incoming_tls: &'i mut [u8]) -> UnbufferedStatus<'c, 'i, Self::Data>; -} - -mod sealed { - pub trait Sealed {} - impl Sealed for super::UnbufferedServerConnection {} - impl Sealed for super::UnbufferedClientConnection {} -} - -impl ProcessTlsRecords for UnbufferedServerConnection { - type Data = server::ServerConnectionData; - - fn process_tls_records<'c, 'i>(&'c mut self, incoming_tls: &'i mut [u8]) -> UnbufferedStatus<'c, 'i, Self::Data> { - let inner: &mut UnbufferedConnectionCommon = self; - inner.process_tls_records(incoming_tls) - } -} - -impl ProcessTlsRecords for UnbufferedClientConnection { - type Data = client::ClientConnectionData; - - fn process_tls_records<'c, 'i>(&'c mut self, incoming_tls: &'i mut [u8]) -> UnbufferedStatus<'c, 'i, Self::Data> { - let inner: &mut UnbufferedConnectionCommon = self; - inner.process_tls_records(incoming_tls) - } -} - -/// Reduced `ConnectionState` that doesn't borrow the connection or incoming buffer. -/// Created by draining all needed data from the borrowed state variants. -/// A TLS stream type that supports concurrent async read/write through [AsyncBufRead] and -/// [AsyncBufWrite] traits. -/// -/// [AsyncBufRead::read] and [AsyncBufWrite::write] can be polled concurrently from separate -/// tasks. The read path owns `read_buf` during IO and the write path owns `write_buf`, so -/// neither blocks the other while awaiting kernel completions. -/// -/// # Panics -/// Each async read/write operation must be polled to completion. Dropping a future before it -/// completes will leave internal buffers in a taken state, causing the next call to panic. -pub struct TlsStream { - io: Io, - session: RefCell>, -} - -struct Session { - conn: C, - read_buf: Option, - /// Write buffer for application data (used by write path). 
- write_buf: Option, - /// Write buffer for TLS protocol responses during reads (key updates, alerts). - proto_write_buf: BytesMut, - /// Plaintext buffered from a previous read. - pending_plaintext: BytesMut, -} - -impl TlsStream -where - C: ProcessTlsRecords, - Io: AsyncBufRead + AsyncBufWrite, -{ - pub async fn handshake(io: Io, conn: C) -> io::Result { - let stream = TlsStream { - io, - session: RefCell::new(Session { - conn, - read_buf: Some(BytesMut::new()), - write_buf: Some(BytesMut::new()), - proto_write_buf: BytesMut::new(), - pending_plaintext: BytesMut::new(), - }), - }; - stream._handshake().await?; - Ok(stream) - } - - async fn _handshake(&self) -> io::Result<()> { - let mut session = self.session.borrow_mut(); - let mut read_buf = session.read_buf.take().expect(POLL_TO_COMPLETE); - let mut proto_write_buf = session.proto_write_buf.split(); - - let res = loop { - let UnbufferedStatus { discard, state } = session.conn.process_tls_records(read_buf.as_mut()); - - let res = match state.map_err(tls_err) { - Err(e) => { - read_buf.advance(discard); - Err(e) - } - - Ok(ConnectionState::EncodeTlsData(mut state)) => { - let enc_res = encode_tls_data(&mut state, &mut proto_write_buf); - drop(state); - read_buf.advance(discard); - enc_res?; - continue; - } - - Ok(ConnectionState::TransmitTlsData(state)) => { - state.done(); - read_buf.advance(discard); - - let (res, b) = write_all_buf(&self.io, proto_write_buf).await; - proto_write_buf = b; - res - } - - Ok(ConnectionState::BlockedHandshake) => { - read_buf.advance(discard); - - let (res, b) = read_to_buf(&self.io, read_buf).await; - read_buf = b; - res - } - - Ok(ConnectionState::WriteTraffic(_) | ConnectionState::ReadTraffic(_)) => { - read_buf.advance(discard); - break Ok(()); - } - - Ok(ConnectionState::PeerClosed | ConnectionState::Closed) => { - read_buf.advance(discard); - Err(io::Error::new(io::ErrorKind::UnexpectedEof, "tls handshake eof")) - } - - Ok(_) => { - read_buf.advance(discard); - continue; - } 
- }; - - if res.is_err() { - break res; - } - }; - - session.read_buf.replace(read_buf); - session.proto_write_buf = proto_write_buf; - res - } -} - -impl TlsStream -where - C: ProcessTlsRecords, - Io: AsyncBufRead, -{ - /// Read ciphertext from IO, decrypt, and return plaintext. - async fn read_tls(&self, plain_buf: &mut impl BoundedBufMut) -> io::Result { - let mut session = self.session.borrow_mut(); - - // Check for plaintext buffered from a previous read first. - if !session.pending_plaintext.is_empty() { - let rem = plain_buf.bytes_total() - plain_buf.bytes_init(); - let aval = session.pending_plaintext.len(); - let len = cmp::min(rem, aval); - - plain_buf.put_slice(&session.pending_plaintext[..len]); - session.pending_plaintext.advance(len); - - return Ok(len); - } - - let mut read_buf = session.read_buf.take().expect(POLL_TO_COMPLETE); - - let res = loop { - // Call process_tls_records directly to copy record payload - // straight into the caller's buffer (no intermediate BytesMut). - let session_ref = &mut *session; - - let UnbufferedStatus { discard, state } = session_ref.conn.process_tls_records(read_buf.as_mut()); - - let res = match state.map_err(tls_err) { - Err(e) => { - read_buf.advance(discard); - break Err(e); - } - - Ok(ConnectionState::ReadTraffic(mut traffic)) => { - let rem = plain_buf.bytes_total() - plain_buf.bytes_init(); - let mut written = 0; - - let mut err = None; - while let Some(res) = traffic.next_record() { - match res.map_err(tls_err) { - Ok(record) => { - let payload = record.payload; - let len = payload.len().min(rem - written); - - let (head, tail) = payload.split_at(len); - - plain_buf.put_slice(head); - written += len; - - // Buffer overflow into pending_plaintext. - session_ref.pending_plaintext.extend_from_slice(tail); - } - Err(e) => { - err = Some(e); - break; - } - } - } - - read_buf.advance(discard); - - if let Some(e) = err { - break Err(e); - } - - // Empty plaintext means TLS overhead with no payload — keep going. 
- if written == 0 { - continue; - } - - break Ok(written); - } - - Ok(ConnectionState::EncodeTlsData(mut state)) => { - // Encode into proto_write_buf via session_ref (same borrow scope as state). - let enc_res = encode_tls_data(&mut state, &mut session_ref.proto_write_buf); - drop(state); - read_buf.advance(discard); - - if let Err(e) = enc_res { - break Err(e); - } - continue; - } - - Ok(ConnectionState::TransmitTlsData(state)) => { - // Data is in proto_write_buf. Acknowledge and continue — - // write_tls will flush it on the next write call. - state.done(); - read_buf.advance(discard); - continue; - } - - // Need more ciphertext. - Ok(ConnectionState::BlockedHandshake | ConnectionState::WriteTraffic(_)) => { - read_buf.advance(discard); - - drop(session); - - let (res, b) = read_to_buf(&self.io, read_buf).await; - read_buf = b; - - session = self.session.borrow_mut(); - - res - } - - Ok(ConnectionState::PeerClosed | ConnectionState::Closed) => { - read_buf.advance(discard); - break Ok(0); - } - - Ok(_) => { - read_buf.advance(discard); - continue; - } - }; - - if let Err(e) = res { - break Err(e); - } - }; - - session.read_buf.replace(read_buf); - res - } -} - -impl TlsStream -where - C: ProcessTlsRecords, - Io: AsyncBufWrite, -{ - /// Encrypt plaintext and write all ciphertext to IO. - async fn write_tls(&self, plain: &impl BoundedBuf) -> io::Result { - let mut session = self.session.borrow_mut(); - let mut write_buf = session.write_buf.take().expect(POLL_TO_COMPLETE); - let plaintext = plain.chunk(); - - // Flush protocol data buffered by read path (key updates, alerts). - if !session.proto_write_buf.is_empty() { - write_buf.extend_from_slice(&session.proto_write_buf); - session.proto_write_buf.clear(); - } - - let res = loop { - // Pass empty slice — write path doesn't process incoming TLS records. - // Incoming data (key updates, etc.) is handled by the read path. - let UnbufferedStatus { state, .. 
} = session.conn.process_tls_records(&mut []); - - match state.map_err(tls_err) { - Err(e) => break Err(e), - - Ok(ConnectionState::WriteTraffic(mut traffic)) => { - let enc_res = encrypt_to_buf(&mut traffic, plaintext, &mut write_buf); - - if let Err(e) = enc_res { - break Err(e); - } - - drop(session); - - let (res, b) = write_all_buf(&self.io, write_buf).await; - write_buf = b; - - session = self.session.borrow_mut(); - - break res.map(|_| plaintext.len()); - } - - Ok(ConnectionState::EncodeTlsData(mut state)) => { - let enc_res = encode_tls_data(&mut state, &mut write_buf); - drop(state); - - if let Err(e) = enc_res { - break Err(e); - } - } - - Ok(ConnectionState::TransmitTlsData(state)) => { - state.done(); - - drop(session); - - let (res, b) = write_all_buf(&self.io, write_buf).await; - write_buf = b; - - session = self.session.borrow_mut(); - - if let Err(e) = res { - break Err(e); - } - } - - Ok(ConnectionState::PeerClosed | ConnectionState::Closed) => { - break Err(io::ErrorKind::UnexpectedEof.into()); - } - - Ok(_) => {} - } - }; - - session.write_buf.replace(write_buf); - res - } -} - -impl AsyncBufRead for TlsStream -where - C: ProcessTlsRecords, - Io: AsyncBufRead, -{ - async fn read(&self, mut buf: B) -> (io::Result, B) - where - B: BoundedBufMut, - { - let res = self.read_tls(&mut buf).await; - (res, buf) - } -} - -impl AsyncBufWrite for TlsStream -where - C: ProcessTlsRecords, - Io: AsyncBufWrite, -{ - async fn write(&self, buf: B) -> (io::Result, B) - where - B: BoundedBuf, - { - let res = self.write_tls(&buf).await; - (res, buf) - } - - async fn shutdown(&self, direction: Shutdown) -> io::Result<()> { - self.io.shutdown(direction).await - } -} - -fn tls_err(e: Error) -> io::Error { - io::Error::new(io::ErrorKind::InvalidData, e) -} - -/// Read from IO into a BytesMut, reserving space if needed. 
-async fn read_to_buf(io: &impl AsyncBufRead, mut buf: BytesMut) -> (io::Result<()>, BytesMut) { - let len = buf.len(); - buf.reserve(4096); - - let (res, b) = io.read(buf.slice(len..)).await; - buf = b.into_inner(); - - match res { - Ok(0) => (Err(io::ErrorKind::UnexpectedEof.into()), buf), - Ok(_) => (Ok(()), buf), - Err(e) => (Err(e), buf), - } -} - -/// Write all bytes from a BytesMut to IO, then clear it. -async fn write_all_buf(io: &impl AsyncBufWrite, mut buf: BytesMut) -> (io::Result<()>, BytesMut) { - let (res, b) = xitca_io::io::write_all(io, buf).await; - buf = b; - if res.is_ok() { - buf.clear(); - } - (res, buf) -} - -/// Encode TLS handshake data into the write buffer, resizing if needed. -fn encode_tls_data(state: &mut unbuffered::EncodeTlsData<'_, Data>, write_buf: &mut BytesMut) -> io::Result<()> { - // SAFETY: EncodeTlsData::encode copies a single chunk contiguously from index 0. - // On Ok(n), exactly n bytes are written. On InsufficientSize or AlreadyEncoded, - // the size check happens before any write so the slice is untouched. - while let Err(e) = unsafe { SpareCapBuf::new(write_buf).with_mut_slice(|slice| state.encode(slice)) } { - match e { - unbuffered::EncodeError::InsufficientSize(unbuffered::InsufficientSizeError { required_size }) => { - write_buf.reserve(required_size); - } - e => return Err(io::Error::new(io::ErrorKind::InvalidData, e.to_string())), - } - } - Ok(()) -} - -/// Encrypt plaintext into the write buffer, resizing if needed. -fn encrypt_to_buf( - traffic: &mut unbuffered::WriteTraffic<'_, Data>, - plaintext: &[u8], - write_buf: &mut BytesMut, -) -> io::Result<()> { - write_buf.reserve(plaintext.len() + 64); - // SAFETY: WriteTraffic::encrypt writes TLS records contiguously from index 0 via - // write_fragments. On Ok(n), exactly n bytes are written. On InsufficientSize, - // check_required_size returns before any write. On EncryptExhausted, the error - // is returned during pre-encryption checks before any write. 
- while let Err(err) = - unsafe { SpareCapBuf::new(write_buf).with_mut_slice(|spare| traffic.encrypt(plaintext, spare)) } - { - match err { - EncryptError::InsufficientSize(unbuffered::InsufficientSizeError { required_size }) => { - write_buf.reserve(required_size); - } - e => return Err(io::Error::new(io::ErrorKind::InvalidData, e.to_string())), - } - } - Ok(()) -} - -/// Wraps a `BytesMut`'s spare capacity as a mutable byte slice. -/// -/// Encapsulates the unsafe operations of interpreting spare capacity as `&mut [u8]` -/// and committing written bytes via `set_len`. -struct SpareCapBuf<'a> { - buf: &'a mut BytesMut, -} - -impl<'a> SpareCapBuf<'a> { - fn new(buf: &'a mut BytesMut) -> Self { - Self { buf } - } - - /// # Safety - /// - /// The callback `func` must uphold the following contract: - /// - Writes must be sequential and contiguous, starting from index 0 of the slice. - /// - On `Ok(n)`, exactly `n` bytes must have been written to `slice[..n]`. - /// - On `Err`, zero bytes must have been written into the slice. - unsafe fn with_mut_slice(self, func: F) -> Result<(), E> - where - F: FnOnce(&mut [u8]) -> Result, - { - let spare = self.buf.spare_capacity_mut(); - - // SAFETY: the caller must write into the slice before reading. - // We only expose this for write-before-read patterns (TLS encode/encrypt). - let slice = unsafe { slice::from_raw_parts_mut(spare.as_mut_ptr().cast::(), spare.len()) }; - - let n = func(slice)?; - - // SAFETY: caller guarantees n bytes were written into the spare capacity. 
- unsafe { self.buf.set_len(self.buf.len() + n) }; - - Ok(()) - } -} - -const POLL_TO_COMPLETE: &str = "previous call to future dropped before polling to completion"; - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn spare_cap_buf_write_and_commit() { - let mut buf = BytesMut::with_capacity(64); - buf.extend_from_slice(b"hello"); - - let res = unsafe { - SpareCapBuf::new(&mut buf).with_mut_slice(|slice| { - assert!(slice.len() >= 59); - slice[..5].copy_from_slice(b"world"); - Ok::<_, ()>(5) - }) - }; - assert!(res.is_ok()); - assert_eq!(&buf[..], b"helloworld"); - } - - #[test] - fn spare_cap_buf_commit_zero() { - let mut buf = BytesMut::with_capacity(16); - buf.extend_from_slice(b"abc"); - - let res = unsafe { SpareCapBuf::new(&mut buf).with_mut_slice(|_| Ok::<_, ()>(0)) }; - assert!(res.is_ok()); - assert_eq!(&buf[..], b"abc"); - } - - #[test] - fn spare_cap_buf_error_no_commit() { - let mut buf = BytesMut::with_capacity(16); - buf.extend_from_slice(b"abc"); - - let res = unsafe { SpareCapBuf::new(&mut buf).with_mut_slice(|_| Err::("too small")) }; - assert!(res.is_err()); - assert_eq!(&buf[..], b"abc"); - } -} diff --git a/tls/src/rustls_poll.rs b/tls/src/rustls_poll.rs new file mode 100644 index 000000000..1c9d61ee0 --- /dev/null +++ b/tls/src/rustls_poll.rs @@ -0,0 +1,182 @@ +use core::{ + future::Future, + ops::DerefMut, + pin::Pin, + task::{Context, Poll}, +}; + +use std::io; + +pub use rustls_crate::*; + +use xitca_io::io::{AsyncIo, Interest, Ready}; + +/// A stream managed by `rustls` crate for tls read/write. +pub struct TlsStream { + conn: C, + io: Io, +} + +impl TlsStream +where + C: DerefMut>, + S: SideData, + Io: io::Read + io::Write, +{ + fn process_new_packets(&mut self) -> io::Result<()> { + match self.conn.process_new_packets() { + Ok(_) => Ok(()), + Err(e) => { + // In case we have an alert to send describing this error, + // try a last-gasp write -- but don't predate the primary + // error. 
+ let _ = self.write_tls(); + + Err(io::Error::new(io::ErrorKind::InvalidData, e)) + } + } + } + + fn write_tls(&mut self) -> io::Result { + self.conn.write_tls(&mut self.io) + } + + fn read_tls(&mut self) -> io::Result { + self.conn.read_tls(&mut self.io) + } +} + +impl TlsStream +where + C: DerefMut> + Unpin, + S: SideData, + Io: AsyncIo, +{ + /// acquire a reference to the session type. Typically either [ClientConnection] or [ServerConnection] + /// + /// [ClientConnection]: rustls::ClientConnection + /// [ServerConnection]: rustls::ServerConnection + pub fn session(&self) -> &C { + &self.conn + } + + /// finish handshake with given io and connection type. + /// # Examples: + /// ```rust + /// use std::sync::Arc; + /// + /// use xitca_io::net::TcpStream; + /// use xitca_tls::rustls::{pki_types::ServerName, ClientConfig, ClientConnection, TlsStream}; + /// + /// async fn client_connect(io: TcpStream, cfg: Arc, server_name: ServerName<'static>) { + /// let conn = ClientConnection::new(cfg, server_name).unwrap(); + /// let _stream = TlsStream::handshake(io, conn).await.unwrap(); + /// } + /// ``` + pub async fn handshake(mut io: Io, mut conn: C) -> io::Result { + while conn.is_handshaking() { + if let Err(e) = conn.complete_io(&mut io) { + if !matches!(e.kind(), io::ErrorKind::WouldBlock) { + return Err(e); + } + let interest = match (conn.wants_read(), conn.wants_write()) { + (true, true) => Interest::READABLE | Interest::WRITABLE, + (true, false) => Interest::READABLE, + (false, true) => Interest::WRITABLE, + (false, false) => unreachable!(), + }; + io.ready(interest).await?; + } + } + + Ok(TlsStream { io, conn }) + } +} + +impl AsyncIo for TlsStream +where + C: DerefMut> + Unpin, + S: SideData, + Io: AsyncIo, +{ + #[inline] + fn ready(&mut self, interest: Interest) -> impl Future> + Send { + self.io.ready(interest) + } + + #[inline] + fn poll_ready(&mut self, interest: Interest, cx: &mut Context<'_>) -> Poll> { + self.io.poll_ready(interest, cx) + } + + fn 
is_vectored_write(&self) -> bool { + self.io.is_vectored_write() + } + + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + AsyncIo::poll_shutdown(Pin::new(&mut self.get_mut().io), cx) + } +} + +impl io::Read for TlsStream +where + C: DerefMut>, + S: SideData, + Io: AsyncIo, +{ + fn read(&mut self, buf: &mut [u8]) -> io::Result { + while self.conn.wants_read() { + let n = self.read_tls()?; + + self.process_new_packets()?; + + if n == 0 { + break; + } + } + self.conn.reader().read(buf) + } +} + +impl io::Write for TlsStream +where + C: DerefMut>, + S: SideData, + Io: AsyncIo, +{ + #[inline] + fn write(&mut self, buf: &[u8]) -> io::Result { + write_with(self, |writer| writer.write(buf)) + } + + #[inline] + fn write_vectored(&mut self, bufs: &[io::IoSlice<'_>]) -> io::Result { + write_with(self, |writer| writer.write_vectored(bufs)) + } + + fn flush(&mut self) -> io::Result<()> { + while self.conn.wants_write() { + if self.write_tls()? == 0 { + return Err(io::ErrorKind::WriteZero.into()); + } + } + Ok(()) + } +} + +fn write_with(stream: &mut TlsStream, mut func: F) -> io::Result +where + Io: AsyncIo, + C: DerefMut>, + S: SideData, + F: for<'r> FnMut(&mut Writer<'r>) -> io::Result, +{ + loop { + match func(&mut stream.conn.writer())? { + // when rustls writer write 0 it means either the input buffer is empty or it's internal + // buffer is full. check the condition and flush the io. 
+ 0 if stream.conn.wants_write() => io::Write::flush(stream)?, + n => return Ok(n), + } + } +} From 145eaf42ddd16f308f804ac73dbeef48720ab467 Mon Sep 17 00:00:00 2001 From: fakeshadow <24548779@qq.com> Date: Wed, 1 Apr 2026 23:12:56 +0800 Subject: [PATCH 15/21] fix test --- client/Cargo.toml | 12 ++++++------ client/src/builder.rs | 23 ++++++++++++----------- client/src/tls/connector.rs | 8 ++++---- http/src/tls/mod.rs | 2 ++ test/Cargo.toml | 6 +++--- 5 files changed, 27 insertions(+), 24 deletions(-) diff --git a/client/Cargo.toml b/client/Cargo.toml index 12dd67932..d898ecdd4 100644 --- a/client/Cargo.toml +++ b/client/Cargo.toml @@ -17,11 +17,11 @@ http2 = ["h2", "itoa", "xitca-http/http2"] # http/3 client(tls always enabled with rustls) http3 = ["h3", "h3-quinn", "quinn", "itoa", "async-stream", "rustls-ring-crypto"] # openssl as http/1 and http/2 tls handler -openssl = ["xitca-tls/openssl"] +openssl = ["xitca-tls/openssl-poll"] # rustls as http/1 and http/2 tls handler -rustls = ["xitca-tls/rustls", "webpki-roots"] +rustls = ["xitca-tls/rustls-poll-aws-crypto", "webpki-roots"] # rustls as tls handler with ring as crypto provider -rustls-ring-crypto = ["xitca-tls/rustls-ring-crypto", "webpki-roots"] +rustls-ring-crypto = ["xitca-tls/rustls-poll-ring-crypto", "webpki-roots"] # compression and decompression middleware support compress = ["http-encoding"] # json response body parsing support @@ -36,8 +36,8 @@ multipart = ["dep:http-multipart"] dangerous = [] [dependencies] -xitca-http = { version = "0.8.0", default-features = false, features = ["runtime"] } -xitca-io = "0.5.1" +xitca-http = { version = "0.9.0", default-features = false, features = ["runtime"] } +xitca-io = "0.6.0" xitca-unsafe-collection = "0.2.0" futures-core = { version = "0.3.17", default-features = false } @@ -62,7 +62,7 @@ async-stream = { version = "0.3", optional = true } itoa = { version = "1", optional = true } # tls shared -xitca-tls = { version = "0.5.0", optional = true } +xitca-tls = 
{ version = "0.6.0", optional = true } # rustls, http3 and dangerous features shared webpki-roots = { version = "1", optional = true } diff --git a/client/src/builder.rs b/client/src/builder.rs index f709077dc..c4878cf1f 100644 --- a/client/src/builder.rs +++ b/client/src/builder.rs @@ -389,7 +389,7 @@ impl ClientBuilder { use h3_quinn::quinn::Endpoint; use webpki_roots::TLS_SERVER_ROOTS; - use xitca_tls::rustls::{ClientConfig, RootCertStore}; + use xitca_tls::rustls_poll::{ClientConfig, RootCertStore}; let mut root_store = RootCertStore::empty(); @@ -403,7 +403,7 @@ impl ClientBuilder { #[cfg(feature = "dangerous")] { - use xitca_tls::rustls::{ + use xitca_tls::rustls_poll::{ self, DigitallySignedStruct, client::danger::HandshakeSignatureValid, crypto::{verify_tls12_signature, verify_tls13_signature}, @@ -419,7 +419,7 @@ impl ClientBuilder { } } - impl rustls::client::danger::ServerCertVerifier for SkipServerVerification { + impl rustls_poll::client::danger::ServerCertVerifier for SkipServerVerification { fn verify_server_cert( &self, _end_entity: &CertificateDer<'_>, @@ -427,8 +427,9 @@ impl ClientBuilder { _server_name: &ServerName<'_>, _ocsp: &[u8], _now: UnixTime, - ) -> Result { - Ok(rustls::client::danger::ServerCertVerified::assertion()) + ) -> Result + { + Ok(rustls_poll::client::danger::ServerCertVerified::assertion()) } fn verify_tls12_signature( @@ -436,12 +437,12 @@ impl ClientBuilder { message: &[u8], cert: &CertificateDer<'_>, dss: &DigitallySignedStruct, - ) -> Result { + ) -> Result { verify_tls12_signature( message, cert, dss, - &rustls::crypto::ring::default_provider().signature_verification_algorithms, + &rustls_poll::crypto::ring::default_provider().signature_verification_algorithms, ) } @@ -450,17 +451,17 @@ impl ClientBuilder { message: &[u8], cert: &CertificateDer<'_>, dss: &DigitallySignedStruct, - ) -> Result { + ) -> Result { verify_tls13_signature( message, cert, dss, - 
&rustls::crypto::ring::default_provider().signature_verification_algorithms, + &rustls_poll::crypto::ring::default_provider().signature_verification_algorithms, ) } - fn supported_verify_schemes(&self) -> Vec { - rustls::crypto::ring::default_provider() + fn supported_verify_schemes(&self) -> Vec { + rustls_poll::crypto::ring::default_provider() .signature_verification_algorithms .supported_schemes() } diff --git a/client/src/tls/connector.rs b/client/src/tls/connector.rs index af7fdedea..cdd181bc3 100644 --- a/client/src/tls/connector.rs +++ b/client/src/tls/connector.rs @@ -39,7 +39,7 @@ pub(crate) fn nop() -> Connector { #[cfg(feature = "openssl")] pub(crate) mod openssl { use xitca_http::bytes::BufMut; - use xitca_tls::openssl::{ + use xitca_tls::openssl_poll::{ self, ssl::{SslConnector, SslMethod}, }; @@ -52,7 +52,7 @@ pub(crate) mod openssl { async fn call(&self, (name, io): (&'n str, TlsStream)) -> Result { let ssl = self.configure()?.into_ssl(name)?; - let stream = openssl::TlsStream::connect(ssl, io).await?; + let stream = openssl_poll::TlsStream::connect(ssl, io).await?; let version = stream .session() @@ -90,7 +90,7 @@ pub(crate) mod rustls { use std::sync::Arc; use webpki_roots::TLS_SERVER_ROOTS; - use xitca_tls::rustls::{self, ClientConfig, ClientConnection, RootCertStore, pki_types::ServerName}; + use xitca_tls::rustls_poll::{self, ClientConfig, ClientConnection, RootCertStore, pki_types::ServerName}; use super::*; @@ -107,7 +107,7 @@ pub(crate) mod rustls { let conn = ClientConnection::new(self.0.clone(), name).unwrap(); - let stream = rustls::TlsStream::handshake(io, conn) + let stream = rustls_poll::TlsStream::handshake(io, conn) .await .map_err(crate::error::RustlsError::Io)?; diff --git a/http/src/tls/mod.rs b/http/src/tls/mod.rs index 42d14f1d2..5d0ee4529 100644 --- a/http/src/tls/mod.rs +++ b/http/src/tls/mod.rs @@ -17,6 +17,8 @@ pub use error::TlsError; use xitca_service::Service; +// TODO: remove this allow +#[allow(dead_code)] /// A NoOp Tls 
Acceptor pass through input Stream type. #[derive(Copy, Clone)] pub struct NoOpTlsAcceptorBuilder; diff --git a/test/Cargo.toml b/test/Cargo.toml index c6a64c5ff..007c2e311 100644 --- a/test/Cargo.toml +++ b/test/Cargo.toml @@ -11,10 +11,10 @@ io-uring = ["xitca-http/io-uring", "xitca-server/io-uring"] [dependencies] xitca-client = { version = "0.1", features = ["http2", "http3", "websocket", "dangerous"] } -xitca-http = { version = "0.8.0", features = ["http2", "http3"] } +xitca-http = { version = "0.9.0", features = ["http2", "http3"] } xitca-codegen = "0.4" -xitca-io = "0.5.1" -xitca-server = { version = "0.6.1", features = ["quic"] } +xitca-io = "0.6.0" +xitca-server = { version = "0.7.0", features = ["quic"] } xitca-service = "0.3.0" xitca-unsafe-collection = "0.2" xitca-web = { version = "0.8", features = ["codegen"] } From 6089619735bc6f08e2ce4aa2f5e572d29613ef5b Mon Sep 17 00:00:00 2001 From: fakeshadow <24548779@qq.com> Date: Wed, 1 Apr 2026 23:19:33 +0800 Subject: [PATCH 16/21] fix cargo check --- http/Cargo.toml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/http/Cargo.toml b/http/Cargo.toml index 6c08bf65a..d37f614f9 100644 --- a/http/Cargo.toml +++ b/http/Cargo.toml @@ -23,11 +23,11 @@ http2 = ["h2", "fnv", "futures-util/alloc", "runtime", "slab"] # http3 specific feature. http3 = ["xitca-io/quic", "futures-util/alloc", "h3", "h3-quinn", "runtime"] # openssl as server side tls. -openssl = ["xitca-tls/openssl"] +openssl = ["xitca-tls/openssl", "runtime"] # rustls as server side tls. -rustls = ["xitca-tls/rustls"] +rustls = ["xitca-tls/rustls", "runtime"] # rustls as server side tls. -native-tls = ["xitca-tls/native-tls"] +native-tls = ["xitca-tls/native-tls", "runtime"] # async runtime feature. 
runtime = ["xitca-io/runtime", "tokio"] From 9b5faacee1e038b03d99547b32b6c52564bf9381 Mon Sep 17 00:00:00 2001 From: fakeshadow <24548779@qq.com> Date: Wed, 1 Apr 2026 23:31:07 +0800 Subject: [PATCH 17/21] fix test --- test/tests/h2spec.rs | 10 +++++++--- tls/src/rustls_poll.rs | 2 +- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/test/tests/h2spec.rs b/test/tests/h2spec.rs index d1aa983fd..ccb7643cb 100644 --- a/test/tests/h2spec.rs +++ b/test/tests/h2spec.rs @@ -14,6 +14,7 @@ mod inner { convert::Infallible, pin::Pin, task::{Context, Poll}, + time::Duration, }; use std::{net::TcpListener, process::Command}; @@ -77,9 +78,12 @@ mod inner { std::thread::spawn(move || { let service = fn_service(handler).enclosed( - HttpServiceBuilder::h2() - .io_uring() - .config(HttpServiceConfig::new().h2_max_concurrent_streams(2)), + HttpServiceBuilder::h2().io_uring().config( + HttpServiceConfig::new() + .request_head_timeout(Duration::from_mins(1)) + .keep_alive_timeout(Duration::from_mins(1)) + .h2_max_concurrent_streams(2), + ), ); let server = xitca_server::Builder::new() diff --git a/tls/src/rustls_poll.rs b/tls/src/rustls_poll.rs index 1c9d61ee0..b8db1b2e3 100644 --- a/tls/src/rustls_poll.rs +++ b/tls/src/rustls_poll.rs @@ -66,7 +66,7 @@ where /// use std::sync::Arc; /// /// use xitca_io::net::TcpStream; - /// use xitca_tls::rustls::{pki_types::ServerName, ClientConfig, ClientConnection, TlsStream}; + /// use xitca_tls::rustls_poll::{pki_types::ServerName, ClientConfig, ClientConnection, TlsStream}; /// /// async fn client_connect(io: TcpStream, cfg: Arc, server_name: ServerName<'static>) { /// let conn = ClientConnection::new(cfg, server_name).unwrap(); From 47aefca82b5a33016d0967996fa1a08bb5f6ce02 Mon Sep 17 00:00:00 2001 From: fakeshadow <24548779@qq.com> Date: Wed, 1 Apr 2026 23:36:33 +0800 Subject: [PATCH 18/21] fix test --- http/Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/http/Cargo.toml b/http/Cargo.toml index 
d37f614f9..e1c798abf 100644 --- a/http/Cargo.toml +++ b/http/Cargo.toml @@ -74,7 +74,7 @@ socket2 = { version = "0.6.0", features = ["all"] } [dev-dependencies] criterion = "0.8" -xitca-server = "0.6.1" +xitca-server = "0.7.0" [[bench]] name = "h1_decode" From 45a8fe2780280c7c44135b827aa1d8c75d7bde55 Mon Sep 17 00:00:00 2001 From: fakeshadow <24548779@qq.com> Date: Thu, 2 Apr 2026 15:06:13 +0800 Subject: [PATCH 19/21] wip --- examples/io-uring-h2/Cargo.toml | 10 +- examples/io-uring-h2/src/main.rs | 50 +++--- http/src/h1/dispatcher.rs | 9 +- http/src/h2/dispatcher_uring.rs | 158 ++++++++---------- io/Cargo.toml | 4 +- io/src/io/complete.rs | 52 +----- io/src/lib.rs | 2 +- io/src/net.rs | 74 +------- io/src/net/io_uring.rs | 90 ++-------- postgres/Cargo.toml | 6 + postgres/benches/fortune.rs | 158 ++++++++++++++++++ server/Cargo.toml | 2 +- tls/src/bridge.rs | 2 +- tls/src/native_tls.rs | 28 +++- tls/src/openssl.rs | 49 +++++- tls/src/rustls.rs | 66 ++++++-- tokio-uring/CHANGELOG.md | 15 +- tokio-uring/Cargo.toml | 3 +- tokio-uring/src/buf/mod.rs | 2 +- tokio-uring/src/io/async_buf_read.rs | 20 +++ .../src/io/async_buf_read_write_impl.rs | 101 +++++++++++ tokio-uring/src/io/async_buf_write.rs | 56 +++++++ tokio-uring/src/io/mod.rs | 48 ++++++ tokio-uring/src/lib.rs | 35 ++-- tokio-uring/src/net/tcp/stream.rs | 31 +++- tokio-uring/src/net/unix/stream.rs | 41 ++++- 26 files changed, 747 insertions(+), 365 deletions(-) create mode 100644 postgres/benches/fortune.rs create mode 100644 tokio-uring/src/io/async_buf_read.rs create mode 100644 tokio-uring/src/io/async_buf_read_write_impl.rs create mode 100644 tokio-uring/src/io/async_buf_write.rs diff --git a/examples/io-uring-h2/Cargo.toml b/examples/io-uring-h2/Cargo.toml index 8806d8101..cb90de071 100644 --- a/examples/io-uring-h2/Cargo.toml +++ b/examples/io-uring-h2/Cargo.toml @@ -5,17 +5,17 @@ authors = ["fakeshadow <24548779@qq.com>"] edition = "2024" [dependencies] -xitca-http = { version = "0.9", features = 
["http2", "io-uring", "router"] } +xitca-http = { version = "0.9", features = ["http2", "io-uring", "router", "openssl"] } xitca-server = { version = "0.7", features = ["io-uring"] } xitca-service = "0.3" futures-core = "0.3" mimalloc = { version = "0.1.48", default-features = false, features = ["v3"] } -# openssl = "0.10.44" -rcgen = "0.14" -rustls = "0.23" -rustls-pki-types = "1" +openssl = "0.10.44" +# rcgen = "0.14" +# rustls = "0.23" +# rustls-pki-types = "1" [profile.release] diff --git a/examples/io-uring-h2/src/main.rs b/examples/io-uring-h2/src/main.rs index e97a73cf9..ca63df67c 100644 --- a/examples/io-uring-h2/src/main.rs +++ b/examples/io-uring-h2/src/main.rs @@ -28,7 +28,7 @@ fn main() -> io::Result<()> { "http/2", "127.0.0.1:8080", fn_service(handler).enclosed( - HttpServiceBuilder::h2().io_uring(), // specify io_uring flavor of http service. + HttpServiceBuilder::h2().io_uring().openssl(tls_config()?), // specify io_uring flavor of http service. ), )? .build() @@ -89,27 +89,27 @@ impl Stream for Once { // std::sync::Arc::new(config) // } -// use openssl::ssl::{AlpnError, SslAcceptor, SslFiletype, SslMethod}; -// fn tls_config() -> io::Result { -// // set up openssl and alpn protocol. -// let mut builder = SslAcceptor::mozilla_intermediate(SslMethod::tls())?; -// builder.set_private_key_file("../cert/key.pem", SslFiletype::PEM)?; -// builder.set_certificate_chain_file("../cert/cert.pem")?; - -// builder.set_alpn_select_callback(|_, protocols| { -// const H2: &[u8] = b"\x02h2"; -// const H11: &[u8] = b"\x08http/1.1"; - -// if protocols.windows(3).any(|window| window == H2) { -// Ok(b"h2") -// } else if protocols.windows(9).any(|window| window == H11) { -// Ok(b"http/1.1") -// } else { -// Err(AlpnError::NOACK) -// } -// }); - -// builder.set_alpn_protos(b"\x08http/1.1\x02h2")?; - -// Ok(builder.build()) -// } +use openssl::ssl::{AlpnError, SslAcceptor, SslFiletype, SslMethod}; +fn tls_config() -> io::Result { + // set up openssl and alpn protocol. 
+ let mut builder = SslAcceptor::mozilla_intermediate(SslMethod::tls())?; + builder.set_private_key_file("../cert/key.pem", SslFiletype::PEM)?; + builder.set_certificate_chain_file("../cert/cert.pem")?; + + builder.set_alpn_select_callback(|_, protocols| { + const H2: &[u8] = b"\x02h2"; + const H11: &[u8] = b"\x08http/1.1"; + + if protocols.windows(3).any(|window| window == H2) { + Ok(b"h2") + } else if protocols.windows(9).any(|window| window == H11) { + Ok(b"http/1.1") + } else { + Err(AlpnError::NOACK) + } + }); + + builder.set_alpn_protos(b"\x08http/1.1\x02h2")?; + + Ok(builder.build()) +} diff --git a/http/src/h1/dispatcher.rs b/http/src/h1/dispatcher.rs index 8d5dc8c16..0d0effe20 100644 --- a/http/src/h1/dispatcher.rs +++ b/http/src/h1/dispatcher.rs @@ -233,7 +233,7 @@ where #[cold] #[inline(never)] async fn shutdown(self) -> Result<(), Error> { - self.io.io().shutdown(Shutdown::Both).await.map_err(Into::into) + self.io.into_io().shutdown(Shutdown::Both).await.map_err(Into::into) } #[cold] @@ -427,6 +427,13 @@ impl SharedIo { &self.inner.io } + fn into_io(self) -> Io { + Rc::try_unwrap(self.inner) + .ok() + .expect("SharedIo still has outstanding references") + .io + } + fn notifier(&mut self) -> NotifierIo { NotifierIo { inner: self.inner.clone(), diff --git a/http/src/h2/dispatcher_uring.rs b/http/src/h2/dispatcher_uring.rs index f3f15bfa5..43ad75f0b 100644 --- a/http/src/h2/dispatcher_uring.rs +++ b/http/src/h2/dispatcher_uring.rs @@ -1398,95 +1398,67 @@ where let mut queue = Queue::new(); let mut ping_pong = PingPong::new(ka, &shared, date, ka_dur); - let mut read_task = pin!(read_io(read_buf, &io)); + let res = { + let mut read_task = pin!(read_io(read_buf, &io)); - let mut write_task = pin!(async { - while poll_fn(|_| enc.poll_encode(&mut write_buf)).await { - let (res, buf) = io.write(write_buf).await; + let mut write_task = pin!(async { + while poll_fn(|_| enc.poll_encode(&mut write_buf)).await { + let (res, buf) = io.write(write_buf).await; - 
write_buf = buf; + write_buf = buf; - match res { - Ok(0) => return Err(io::ErrorKind::WriteZero.into()), - Ok(n) => write_buf.advance(n), - Err(e) => return Err(e), + match res { + Ok(0) => return Err(io::ErrorKind::WriteZero.into()), + Ok(n) => write_buf.advance(n), + Err(e) => return Err(e), + } } - } - Ok(()) - }); - - let shutdown = loop { - match read_task - .as_mut() - .select(async { - loop { - let _ = queue.next().await; - } - }) - .select(write_task.as_mut()) - .select(ping_pong.tick()) - .await - { - SelectOutput::A(SelectOutput::A(SelectOutput::A((res, buf)))) => { - read_buf = buf; + Ok(()) + }); - match res { - Ok(n) if n > 0 => { - if let Err(shutdown) = ctx.decode(&mut read_buf, |decoder, (req, id)| { - queue.push(response_task(req, id, decoder.service, decoder.ctx, decoder.date)); - }) { - break shutdown; - } + let shutdown = loop { + match read_task + .as_mut() + .select(async { + loop { + let _ = queue.next().await; } - res => break ShutDown::ReadClosed(res.map(|_| ())), - }; + }) + .select(write_task.as_mut()) + .select(ping_pong.tick()) + .await + { + SelectOutput::A(SelectOutput::A(SelectOutput::A((res, buf)))) => { + read_buf = buf; + + match res { + Ok(n) if n > 0 => { + if let Err(shutdown) = ctx.decode(&mut read_buf, |decoder, (req, id)| { + queue.push(response_task(req, id, decoder.service, decoder.ctx, decoder.date)); + }) { + break shutdown; + } + } + res => break ShutDown::ReadClosed(res.map(|_| ())), + }; - read_task.set(read_io(read_buf, &io)); + read_task.set(read_io(read_buf, &io)); + } + SelectOutput::A(SelectOutput::A(SelectOutput::B(_))) => {} + SelectOutput::A(SelectOutput::B(res)) => break ShutDown::WriteClosed(res), + SelectOutput::B(Ok(_)) => {} + SelectOutput::B(Err(e)) => break ShutDown::Timeout(e), } - SelectOutput::A(SelectOutput::A(SelectOutput::B(_))) => {} - SelectOutput::A(SelectOutput::B(res)) => break ShutDown::WriteClosed(res), - SelectOutput::B(Ok(_)) => {} - SelectOutput::B(Err(e)) => break ShutDown::Timeout(e), 
- } - }; - - shutdown.shutdown(queue, &shared, write_task, &io, ping_pong).await -} - -type BoxedFuture<'a, T> = Pin + 'a>>; - -enum ShutDown { - ReadClosed(io::Result<()>), - Graceful, - WriteClosed(io::Result<()>), - Timeout(io::Error), - DrainWrite, - Forced, -} + }; -impl ShutDown { - #[cold] - #[inline(never)] - fn shutdown<'a, T, F>( - self, - mut queue: Queue, - shared: &'a Shared, - mut write_task: Pin<&'a mut F>, - io: &'a impl AsyncBufWrite, - mut pin_pong: PingPong<'a>, - ) -> BoxedFuture<'a, io::Result<()>> - where - T: Future + 'a, - F: Future> + 'a, - { - Box::pin(async move { + Box::pin(async { let mut read_res = Ok(()); - match self { - Self::WriteClosed(res) => return res, - Self::Timeout(err) => return Err(err), - Self::ReadClosed(res) => { + match shutdown { + ShutDown::WriteClosed(res) => return res, + ShutDown::Timeout(err) => return Err(err), + ShutDown::ReadClosed(res) => { { let mut inner = shared.borrow_mut(); for state in inner.flow.stream_map.values_mut() { @@ -1502,26 +1474,38 @@ impl ShutDown { read_res = res; } - Self::Graceful => {} - Self::DrainWrite => queue.clear(), - Self::Forced => return Ok(()), + ShutDown::Graceful => {} + ShutDown::DrainWrite => queue.clear(), + ShutDown::Forced => return Ok(()), } loop { - match queue.next().select(write_task.as_mut()).select(pin_pong.tick()).await { + match queue.next().select(write_task.as_mut()).select(ping_pong.tick()).await { SelectOutput::A(SelectOutput::A(_)) => {} SelectOutput::A(SelectOutput::B(res)) => { res?; - // Send FIN so the peer sees a clean connection - // close rather than RST (RFC 7540 §6.8). - let _ = io.shutdown(Shutdown::Write).await; - return read_res; + break read_res; } SelectOutput::B(res) => res?, } } }) - } + .await + }; + + // Send FIN so the peer sees a clean connection + // close rather than RST (RFC 7540 §6.8). 
+ let _ = io.shutdown(Shutdown::Write).await; + res +} + +enum ShutDown { + ReadClosed(io::Result<()>), + Graceful, + WriteClosed(io::Result<()>), + Timeout(io::Error), + DrainWrite, + Forced, } /// Validate a PRIORITY frame payload (RFC 7540 §6.3, §5.3.1). @@ -1541,6 +1525,8 @@ fn handle_priority(id: StreamId, payload: &[u8]) -> Result<(), Error> { } } +type BoxedFuture<'a, T> = Pin + 'a>>; + /// Perform the HTTP/2 connection handshake: validate the client preface, /// then send our SETTINGS frame. Returns the read buffer (with preface /// consumed) and the write buffer (ready for reuse). diff --git a/io/Cargo.toml b/io/Cargo.toml index c6479edc7..2e061e5fb 100644 --- a/io/Cargo.toml +++ b/io/Cargo.toml @@ -12,9 +12,9 @@ readme= "README.md" [features] default = [] # tokio runtime support -runtime = ["tokio"] +runtime = ["tokio", "tokio-uring-xitca/runtime"] # tokio-uring runtime support -runtime-uring = ["runtime", "tokio-uring-xitca/runtime"] +runtime-uring = ["runtime", "tokio-uring-xitca/runtime-uring"] # quic support quic = ["dep:quinn", "runtime"] diff --git a/io/src/io/complete.rs b/io/src/io/complete.rs index 46e6a2873..76ba44fa0 100644 --- a/io/src/io/complete.rs +++ b/io/src/io/complete.rs @@ -1,54 +1,6 @@ //! Completion-based async IO traits. //! -//! These traits model IO where buffer ownership is transferred to the operation -//! and returned on completion — the pattern originated from io_uring but not -//! tied to any specific runtime. They can be implemented on top of epoll/kqueue -//! or any other async runtime. - -use core::future::Future; - -use std::{io, net::Shutdown}; - -use tokio_uring_xitca::buf::IoBuf; +//! Re-exported from [`tokio_uring_xitca::io`]. pub use tokio_uring_xitca::buf::{BoundedBuf, BoundedBufMut, Slice}; - -/// Async read trait with buffer ownership transfer. -pub trait AsyncBufRead { - /// Read into a buffer, returning the result and the buffer. 
- fn read(&self, buf: B) -> impl Future, B)> - where - B: BoundedBufMut; -} - -/// Async write trait with buffer ownership transfer. -pub trait AsyncBufWrite { - /// Write from a buffer, returning the result and the buffer. - fn write(&self, buf: B) -> impl Future, B)> - where - B: BoundedBuf; - - /// Shutdown the connection in the given direction. - fn shutdown(&self, direction: Shutdown) -> impl Future>; -} - -/// Write all bytes from a buffer to IO. -pub async fn write_all(io: &Io, buf: B) -> (io::Result<()>, B) -where - Io: AsyncBufWrite, - B: IoBuf, -{ - let mut buf = buf.slice_full(); - while buf.bytes_init() != 0 { - match io.write(buf).await { - (Ok(0), slice) => { - return (Err(io::ErrorKind::WriteZero.into()), slice.into_inner()); - } - (Ok(n), slice) => buf = slice.slice(n..), - (Err(e), slice) => { - return (Err(e), slice.into_inner()); - } - } - } - (Ok(()), buf.into_inner()) -} +pub use tokio_uring_xitca::io::{AsyncBufRead, AsyncBufWrite, write_all}; diff --git a/io/src/lib.rs b/io/src/lib.rs index ce7287bde..c21daaa32 100644 --- a/io/src/lib.rs +++ b/io/src/lib.rs @@ -1,6 +1,6 @@ //! Async traits and types used for Io operations. -#![deny(unsafe_code)] +#![forbid(unsafe_code)] pub mod bytes; pub mod io; diff --git a/io/src/net.rs b/io/src/net.rs index 190087665..f40fbb2bc 100644 --- a/io/src/net.rs +++ b/io/src/net.rs @@ -22,83 +22,27 @@ use core::net::SocketAddr; macro_rules! default_aio_impl { ($ty: ty) => { impl crate::io::AsyncBufRead for $ty { - #[allow(unsafe_code)] - async fn read(&self, mut buf: B) -> (::std::io::Result, B) + fn read(&self, buf: B) -> impl core::future::Future, B)> where B: crate::io::BoundedBufMut, { - let ready = self.0.ready(crate::io::Interest::READABLE).await; - - if let Err(e) = ready { - return (Err(e), buf); - } - - let init = buf.bytes_init(); - let total = buf.bytes_total(); - - // Safety: construct a mutable slice over the spare capacity. 
- // try_read writes contiguously from the start of the slice - // and returns the exact byte count written on Ok(n). - let spare = unsafe { ::core::slice::from_raw_parts_mut(buf.stable_mut_ptr().add(init), total - init) }; - - let mut written = 0; - - let res = loop { - if written == spare.len() { - break Ok(written); - } - - match self.0.try_read(&mut spare[written..]) { - Ok(0) => break Ok(written), - Ok(n) => written += n, - Err(e) if e.kind() == ::std::io::ErrorKind::WouldBlock => break Ok(written), - Err(e) => break Err(e), - } - }; - - // SAFETY: TcpStream::try_read has put written bytes into buf. - unsafe { - buf.set_init(init + written); - } - - (res, buf) + crate::io::AsyncBufRead::read(&self.0, buf) } } impl crate::io::AsyncBufWrite for $ty { - async fn write(&self, buf: B) -> (::std::io::Result, B) + fn write(&self, buf: B) -> impl core::future::Future, B)> where B: crate::io::BoundedBuf, { - let ready = self.0.ready(crate::io::Interest::WRITABLE).await; - - if let Err(e) = ready { - return (Err(e), buf); - } - - let data = buf.chunk(); - - let mut written = 0; - - let res = loop { - if written == data.len() { - break Ok(written); - } - - match self.0.try_write(&data[written..]) { - Ok(0) => break Ok(written), - Ok(n) => written += n, - Err(e) if e.kind() == ::std::io::ErrorKind::WouldBlock => break Ok(written), - Err(e) => break Err(e), - } - }; - - (res, buf) + crate::io::AsyncBufWrite::write(&self.0, buf) } - async fn shutdown(&self, _direction: ::std::net::Shutdown) -> ::std::io::Result<()> { - // TODO: this is a no-op and shutdown is always handled by dropping the stream type - Ok(()) + fn shutdown( + self, + direction: ::std::net::Shutdown, + ) -> impl core::future::Future> { + crate::io::AsyncBufWrite::shutdown(self.0, direction) } } diff --git a/io/src/net/io_uring.rs b/io/src/net/io_uring.rs index 61c869470..377ce422f 100644 --- a/io/src/net/io_uring.rs +++ b/io/src/net/io_uring.rs @@ -1,13 +1,8 @@ use core::net::SocketAddr; -use std::{io, 
net::Shutdown}; +use std::io; -pub use tokio_uring_xitca::net::TcpStream; - -#[cfg(unix)] -pub use tokio_uring_xitca::net::UnixStream; - -use crate::io::{AsyncBufRead, AsyncBufWrite, BoundedBuf, BoundedBufMut}; +pub use tokio_uring_xitca::net::{TcpStream, UnixStream}; use super::Stream; @@ -31,81 +26,22 @@ impl TryFrom for (TcpStream, SocketAddr) { } } -impl AsyncBufRead for TcpStream { - #[inline(always)] - async fn read(&self, buf: B) -> (io::Result, B) - where - B: BoundedBufMut, - { - TcpStream::read(self, buf).await - } -} - -impl AsyncBufWrite for TcpStream { - #[inline(always)] - async fn write(&self, buf: B) -> (io::Result, B) - where - B: BoundedBuf, - { - TcpStream::write(self, buf).submit().await - } +impl TryFrom for UnixStream { + type Error = io::Error; - #[inline(always)] - async fn shutdown(&self, direction: Shutdown) -> io::Result<()> { - TcpStream::shutdown(self, direction) + fn try_from(stream: Stream) -> Result { + <(UnixStream, std::os::unix::net::SocketAddr)>::try_from(stream).map(|(tcp, _)| tcp) } } -#[cfg(unix)] -mod unix { - use std::os::unix::net::SocketAddr; - - use tokio_uring_xitca::buf::BoundedBuf; - - use super::*; - - impl TryFrom for UnixStream { - type Error = io::Error; - - fn try_from(stream: Stream) -> Result { - <(UnixStream, SocketAddr)>::try_from(stream).map(|(tcp, _)| tcp) - } - } - - impl TryFrom for (UnixStream, SocketAddr) { - type Error = io::Error; - - fn try_from(stream: Stream) -> Result { - match stream { - Stream::Unix(unix, addr) => Ok((UnixStream::from_std(unix), addr)), - #[allow(unreachable_patterns)] - _ => unreachable!("Can not be casted to UnixStream"), - } - } - } - - impl AsyncBufRead for UnixStream { - #[inline(always)] - async fn read(&self, buf: B) -> (io::Result, B) - where - B: BoundedBufMut, - { - UnixStream::read(self, buf).await - } - } - - impl AsyncBufWrite for UnixStream { - #[inline(always)] - async fn write(&self, buf: B) -> (io::Result, B) - where - B: BoundedBuf, - { - UnixStream::write(self, 
buf).submit().await - } +impl TryFrom for (UnixStream, std::os::unix::net::SocketAddr) { + type Error = io::Error; - #[inline(always)] - async fn shutdown(&self, direction: Shutdown) -> io::Result<()> { - UnixStream::shutdown(self, direction) + fn try_from(stream: Stream) -> Result { + match stream { + Stream::Unix(unix, addr) => Ok((UnixStream::from_std(unix), addr)), + #[allow(unreachable_patterns)] + _ => unreachable!("Can not be casted to UnixStream"), } } } diff --git a/postgres/Cargo.toml b/postgres/Cargo.toml index 73ace6207..e730d0792 100644 --- a/postgres/Cargo.toml +++ b/postgres/Cargo.toml @@ -65,9 +65,15 @@ compio = { version = "0.17", features = ["bytes"], optional = true } [dev-dependencies] xitca-postgres-codegen = "0.1" bb8 = "0.9.0" +criterion = { version = "0.5", features = ["async_tokio"] } futures = { version = "0.3", default-features = false } postgres-derive = "0.4" postgres-types = { version = "0.2", features = ["with-uuid-1"] } rcgen = "0.14" +tempfile = "3" tokio = { version = "1.47.1", features = ["macros", "rt-multi-thread", "time"] } uuid = "1" + +[[bench]] +name = "fortune" +harness = false diff --git a/postgres/benches/fortune.rs b/postgres/benches/fortune.rs new file mode 100644 index 000000000..0b70ccbba --- /dev/null +++ b/postgres/benches/fortune.rs @@ -0,0 +1,158 @@ +use std::{ + future::IntoFuture, + io::Write, + process::{Command, Stdio}, +}; + +use criterion::{Criterion, criterion_group, criterion_main}; +use tokio::runtime::Runtime; +use xitca_postgres::{Execute, Postgres, Statement, iter::AsyncLendingIterator}; + +const PG_USER: &str = "bench_user"; +const PG_DB: &str = "bench_db"; +const PG_CONN: &str = "postgres://bench_user@localhost:5432/bench_db"; + +const FORTUNE_SQL: &str = "\ +CREATE TABLE IF NOT EXISTS fortune (\ + id integer NOT NULL,\ + message varchar(2048) NOT NULL,\ + PRIMARY KEY (id)\ +);\ +DELETE FROM fortune;\ +INSERT INTO fortune (id, message) VALUES (1, 'fortune: No such file or directory');\ +INSERT INTO 
fortune (id, message) VALUES (2, 'A computer scientist is someone who fixes things that aren''t broken.');\ +INSERT INTO fortune (id, message) VALUES (3, 'After enough decimal places, nobody gives a damn.');\ +INSERT INTO fortune (id, message) VALUES (4, 'A bad random number generator: 1, 1, 1, 1, 1, 4.33e+67, 1, 1, 1');\ +INSERT INTO fortune (id, message) VALUES (5, 'A computer program does what you tell it to do, not what you want it to do.');\ +INSERT INTO fortune (id, message) VALUES (6, 'Emacs is a nice operating system, but I prefer UNIX. — Tom Christaensen');\ +INSERT INTO fortune (id, message) VALUES (7, 'Any program that runs right is obsolete.');\ +INSERT INTO fortune (id, message) VALUES (8, 'A list is only as strong as its weakest link. — Donald Knuth');\ +INSERT INTO fortune (id, message) VALUES (9, 'Feature: A bug with seniority.');\ +INSERT INTO fortune (id, message) VALUES (10, 'Computers make very fast, very accurate mistakes.');\ +INSERT INTO fortune (id, message) VALUES (11, '');\ +INSERT INTO fortune (id, message) VALUES (12, 'フレームワークのベンチマーク');\ +"; + +struct PgInstance { + data_dir: tempfile::TempDir, +} + +impl PgInstance { + fn start() -> Self { + let data_dir = tempfile::tempdir().expect("failed to create tempdir"); + let dir = data_dir.path(); + + // initdb + let out = Command::new("initdb") + .args(["-D", dir.to_str().unwrap(), "--no-locale", "-E", "UTF8"]) + .output() + .expect("initdb not found"); + assert!( + out.status.success(), + "initdb failed: {}", + String::from_utf8_lossy(&out.stderr) + ); + + // start postgres + let child = Command::new("pg_ctl") + .args([ + "start", + "-D", + dir.to_str().unwrap(), + "-l", + dir.join("logfile").to_str().unwrap(), + "-o", + "-k /tmp -h localhost", + ]) + .stdout(Stdio::piped()) + .stderr(Stdio::piped()) + .spawn() + .expect("pg_ctl not found"); + let status = child.wait_with_output().expect("pg_ctl wait failed"); + assert!(status.status.success(), "pg_ctl start failed"); + + // wait for ready + 
for _ in 0..30 { + let out = Command::new("pg_isready").args(["-h", "localhost"]).output().unwrap(); + if out.status.success() { + break; + } + std::thread::sleep(std::time::Duration::from_millis(200)); + } + + // create user and database + run_cmd("createuser", &["-h", "localhost", PG_USER]); + run_cmd("createdb", &["-h", "localhost", "-O", PG_USER, PG_DB]); + + // populate tables + let mut psql = Command::new("psql") + .args(["-h", "localhost", "-d", PG_DB, "-U", PG_USER]) + .stdin(Stdio::piped()) + .stdout(Stdio::null()) + .stderr(Stdio::piped()) + .spawn() + .expect("psql not found"); + psql.stdin.as_mut().unwrap().write_all(FORTUNE_SQL.as_bytes()).unwrap(); + let out = psql.wait_with_output().unwrap(); + assert!( + out.status.success(), + "psql populate failed: {}", + String::from_utf8_lossy(&out.stderr) + ); + + PgInstance { data_dir } + } +} + +impl Drop for PgInstance { + fn drop(&mut self) { + let _ = Command::new("pg_ctl") + .args(["stop", "-D", self.data_dir.path().to_str().unwrap(), "-m", "immediate"]) + .output(); + } +} + +fn run_cmd(cmd: &str, args: &[&str]) { + let out = Command::new(cmd) + .args(args) + .output() + .expect(&format!("{cmd} not found")); + assert!( + out.status.success(), + "{cmd} failed: {}", + String::from_utf8_lossy(&out.stderr) + ); +} + +fn bench_fortune(c: &mut Criterion) { + let _pg = PgInstance::start(); + let rt = Runtime::new().unwrap(); + + let (cli, drv) = rt.block_on(Postgres::new(PG_CONN).connect()).expect("connect failed"); + let handle = rt.spawn(drv.into_future()); + + // prepare and leak to get an unguarded Statement (no borrow on cli) + let stmt = rt + .block_on(Statement::named("SELECT id, message FROM fortune", &[]).execute(&cli)) + .expect("prepare failed") + .leak(); + + c.bench_function("fortune_select_all", |b| { + b.to_async(&rt).iter(|| async { + let mut stream = stmt.query(&cli).await.unwrap(); + let mut count = 0u32; + while let Some(row) = stream.try_next().await.unwrap() { + let _id: i32 = row.get(0); + 
let _msg: &str = row.get(1); + count += 1; + } + assert_eq!(count, 12); + }); + }); + + drop(cli); + rt.block_on(handle).unwrap().unwrap(); +} + +criterion_group!(benches, bench_fortune); +criterion_main!(benches); diff --git a/server/Cargo.toml b/server/Cargo.toml index e58a35faa..6543f13cf 100644 --- a/server/Cargo.toml +++ b/server/Cargo.toml @@ -25,7 +25,7 @@ tokio = { version = "1.48", features = ["sync", "time"] } tracing = { version = "0.1.40", default-features = false } # io-uring support -tokio-uring-xitca = { version = "0.2.0", features = ["runtime"], optional = true } +tokio-uring-xitca = { version = "0.2.0", features = ["runtime-uring"], optional = true } [target.'cfg(not(target_family = "wasm"))'.dependencies] socket2 = { version = "0.6.0" } diff --git a/tls/src/bridge.rs b/tls/src/bridge.rs index 29eed8e30..3641fd254 100644 --- a/tls/src/bridge.rs +++ b/tls/src/bridge.rs @@ -108,7 +108,7 @@ pub(crate) async fn drain_write(io: &impl AsyncBufWrite, buf: BytesMut) -> (io:: return (Ok(()), buf); } - let (res, mut buf) = xitca_io::io::write_all(io, buf).await; + let (res, mut buf) = io.write_all(buf).await; buf.clear(); (res, buf) diff --git a/tls/src/native_tls.rs b/tls/src/native_tls.rs index fa49d4332..65bf7c4a7 100644 --- a/tls/src/native_tls.rs +++ b/tls/src/native_tls.rs @@ -161,8 +161,12 @@ where (res, buf) } - async fn shutdown(&self, _direction: Shutdown) -> io::Result<()> { - Ok(()) + async fn shutdown(mut self, direction: Shutdown) -> io::Result<()> { + let res = self.tls_shutdown().await; + let shutdown_res = self.io.shutdown(direction).await; + + res?; + shutdown_res } } @@ -203,6 +207,26 @@ where } } } + + async fn tls_shutdown(&mut self) -> io::Result<()> { + self.write_tls(&[]).await?; + + let tls = self.tls.get_mut(); + + loop { + match tls.shutdown() { + Ok(()) => { + bridge::drain_write_buf(&self.io, tls.get_mut()).await?; + return Ok(()); + } + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { + bridge::drain_write_buf(&self.io, 
tls.get_mut()).await?; + bridge::fill_read_buf(&self.io, tls.get_mut()).await?; + } + Err(e) => return Err(e), + } + } + } } /// Collection of native-tls error types. diff --git a/tls/src/openssl.rs b/tls/src/openssl.rs index 62fae9020..c0b7892da 100644 --- a/tls/src/openssl.rs +++ b/tls/src/openssl.rs @@ -47,19 +47,21 @@ where let mut tls = SslStream::new(ssl, bridge)?; loop { - let res = func(&mut tls); - - bridge::drain_write_buf(&io, tls.get_mut()).await.map_err(Error::Io)?; - bridge::fill_read_buf(&io, tls.get_mut()).await.map_err(Error::Io)?; - - match res { + match func(&mut tls) { Ok(_) => { + bridge::drain_write_buf(&io, tls.get_mut()).await.map_err(Error::Io)?; return Ok(TlsStream { io, tls: RefCell::new(tls), }); } - Err(ref e) if e.code() == ErrorCode::WANT_READ || e.code() == ErrorCode::WANT_WRITE => {} + Err(ref e) if e.code() == ErrorCode::WANT_WRITE => { + bridge::drain_write_buf(&io, tls.get_mut()).await.map_err(Error::Io)?; + } + Err(ref e) if e.code() == ErrorCode::WANT_READ => { + bridge::drain_write_buf(&io, tls.get_mut()).await.map_err(Error::Io)?; + bridge::fill_read_buf(&io, tls.get_mut()).await.map_err(Error::Io)?; + } Err(e) => return Err(Error::Tls(e)), } } @@ -150,8 +152,12 @@ where (res, buf) } - async fn shutdown(&self, _direction: Shutdown) -> io::Result<()> { - Ok(()) + async fn shutdown(mut self, direction: Shutdown) -> io::Result<()> { + let res = self.tls_shutdown().await; + let shutdown_res = self.io.shutdown(direction).await; + + res?; + shutdown_res } } @@ -192,6 +198,31 @@ where } } } + + async fn tls_shutdown(&mut self) -> io::Result<()> { + self.write_tls(&[]).await?; + + let tls = self.tls.get_mut(); + + loop { + match tls.shutdown() { + Ok(ssl::ShutdownResult::Sent) | Ok(ssl::ShutdownResult::Received) => { + bridge::drain_write_buf(&self.io, tls.get_mut()).await?; + return Ok(()); + } + Err(ref e) if e.code() == ErrorCode::WANT_WRITE => { + bridge::drain_write_buf(&self.io, tls.get_mut()).await?; + } + Err(ref e) if 
e.code() == ErrorCode::WANT_READ => { + bridge::drain_write_buf(&self.io, tls.get_mut()).await?; + bridge::fill_read_buf(&self.io, tls.get_mut()).await?; + } + Err(e) => { + return Err(io::Error::new(io::ErrorKind::InvalidData, e)); + } + } + } + } } /// Collection of OpenSSL error types. diff --git a/tls/src/rustls.rs b/tls/src/rustls.rs index cf495b39b..c7e4b79cb 100644 --- a/tls/src/rustls.rs +++ b/tls/src/rustls.rs @@ -338,9 +338,7 @@ where Err(e) => break Err(e), Ok(ConnectionState::WriteTraffic(mut traffic)) => { - let enc_res = encrypt_to_buf(&mut traffic, plaintext, &mut write_buf); - - if let Err(e) = enc_res { + if let Err(e) = encrypt_to_buf(&mut traffic, plaintext, &mut write_buf) { break Err(e); } @@ -355,10 +353,7 @@ where } Ok(ConnectionState::EncodeTlsData(mut state)) => { - let enc_res = encode_tls_data(&mut state, &mut write_buf); - drop(state); - - if let Err(e) = enc_res { + if let Err(e) = encode_tls_data(&mut state, &mut write_buf) { break Err(e); } } @@ -389,6 +384,53 @@ where session.write_buf.replace(write_buf); res } + + /// Send a TLS close_notify alert, flushing any pending data first. + async fn tls_shutdown(&self) -> io::Result<()> { + // Flush pending application/protocol data. + self.write_tls(&Vec::new()).await?; + + let mut session = self.session.borrow_mut(); + let mut write_buf = session.write_buf.take().expect(POLL_TO_COMPLETE); + + loop { + let UnbufferedStatus { state, .. } = session.conn.process_tls_records(&mut []); + + match state.map_err(tls_err)? { + ConnectionState::WriteTraffic(mut traffic) => { + write_buf.reserve(64); + // SAFETY: queue_close_notify writes a single TLS alert record + // contiguously from index 0. On Ok(n), exactly n bytes are written. + // On InsufficientSize, no bytes are written. 
+ let res = unsafe { + SpareCapBuf::new(&mut write_buf).with_mut_slice(|spare| traffic.queue_close_notify(spare)) + }; + + if let Err(EncryptError::InsufficientSize(s)) = res { + write_buf.reserve(s.required_size); + continue; + } + + drop(session); + + return write_all_buf(&self.io, write_buf).await.0; + } + ConnectionState::EncodeTlsData(mut state) => encode_tls_data(&mut state, &mut write_buf)?, + ConnectionState::TransmitTlsData(state) => { + state.done(); + drop(session); + + let (res, b) = write_all_buf(&self.io, write_buf).await; + write_buf = b; + + res?; + session = self.session.borrow_mut(); + } + ConnectionState::PeerClosed | ConnectionState::Closed => return Ok(()), + _ => {} + } + } + } } impl AsyncBufRead for TlsStream @@ -418,8 +460,12 @@ where (res, buf) } - async fn shutdown(&self, direction: Shutdown) -> io::Result<()> { - self.io.shutdown(direction).await + async fn shutdown(self, direction: Shutdown) -> io::Result<()> { + let res = self.tls_shutdown().await; + let shutdown_res = self.io.shutdown(direction).await; + + res?; + shutdown_res } } @@ -444,7 +490,7 @@ async fn read_to_buf(io: &impl AsyncBufRead, mut buf: BytesMut) -> (io::Result<( /// Write all bytes from a BytesMut to IO, then clear it. async fn write_all_buf(io: &impl AsyncBufWrite, mut buf: BytesMut) -> (io::Result<()>, BytesMut) { - let (res, b) = xitca_io::io::write_all(io, buf).await; + let (res, b) = io.write_all(buf).await; buf = b; if res.is_ok() { buf.clear(); diff --git a/tokio-uring/CHANGELOG.md b/tokio-uring/CHANGELOG.md index 5b3c9e162..03020d162 100644 --- a/tokio-uring/CHANGELOG.md +++ b/tokio-uring/CHANGELOG.md @@ -1,10 +1,23 @@ # unreleased 0.2.0 +## Add +- add `io::{AsyncBufRead, AsyncBufWrite, write_all}` +- add `io::{AsyncBufRead, AsyncBufWrite}` impl for `net::{TcpStream, UnixStream}` + ## Fix - `BoundedBuf::put_slice` now extend to it's uninit part.
Multiple calls to it would result in accumulation of bytes and not overwritting ## Change - `FixedBuf` would be cleared on check out to buffer pool (By setting its initialized size to zero) -- remove runtime from default feature +- Cargo feature rework + + `default` -> `buf` and `io` module for universal type and trait + + `bytes` -> enable `buf` trait impl for `bytes` crate + + `runtime` -> enable `io` trait impl for `tokio` crate type + + `runtime-uring` -> enable `io_uring` on top of `tokio` runtime + - perf improvement ## Add diff --git a/tokio-uring/Cargo.toml b/tokio-uring/Cargo.toml index 90b51c9b8..1453dbb22 100644 --- a/tokio-uring/Cargo.toml +++ b/tokio-uring/Cargo.toml @@ -18,7 +18,8 @@ workspace = true [features] bytes = ["dep:bytes"] -runtime = ["dep:tokio", "dep:slab", "dep:libc", "dep:io-uring", "dep:socket2"] +runtime = ["dep:tokio"] +runtime-uring = ["runtime", "dep:slab", "dep:libc", "dep:io-uring", "dep:socket2"] [dependencies] tokio = { version = "1.48", features = ["net", "rt", "sync"], optional = true } diff --git a/tokio-uring/src/buf/mod.rs b/tokio-uring/src/buf/mod.rs index c6fa458e3..fecae4e6d 100644 --- a/tokio-uring/src/buf/mod.rs +++ b/tokio-uring/src/buf/mod.rs @@ -4,7 +4,7 @@ //! crate defines [`IoBuf`] and [`IoBufMut`] traits which are implemented by buffer //! types that respect the `io-uring` contract. -#[cfg(feature = "runtime")] +#[cfg(feature = "runtime-uring")] pub mod fixed; mod io_buf; diff --git a/tokio-uring/src/io/async_buf_read.rs b/tokio-uring/src/io/async_buf_read.rs new file mode 100644 index 000000000..ab1398f93 --- /dev/null +++ b/tokio-uring/src/io/async_buf_read.rs @@ -0,0 +1,20 @@ +//! Completion-based async IO traits. +//! +//! These traits model IO where buffer ownership is transferred to the operation +//! and returned on completion — the pattern originated from io_uring but not +//! tied to any specific runtime. They can be implemented on top of epoll/kqueue +//! or any other async runtime.
+ +use core::future::Future; + +use std::io; + +use crate::buf::BoundedBufMut; + +/// Async read trait with buffer ownership transfer. +pub trait AsyncBufRead { + /// Read into a buffer, returning the result and the buffer. + fn read(&self, buf: B) -> impl Future, B)> + where + B: BoundedBufMut; +} diff --git a/tokio-uring/src/io/async_buf_read_write_impl.rs b/tokio-uring/src/io/async_buf_read_write_impl.rs new file mode 100644 index 000000000..17fa04ea2 --- /dev/null +++ b/tokio-uring/src/io/async_buf_read_write_impl.rs @@ -0,0 +1,101 @@ +use core::{future::poll_fn, pin::Pin, slice}; + +use std::{io, net::Shutdown}; + +use tokio::{ + io::{AsyncWrite, Interest}, + net::{TcpStream, UnixStream}, +}; + +use crate::buf::{BoundedBuf, BoundedBufMut}; + +use super::{async_buf_read::AsyncBufRead, async_buf_write::AsyncBufWrite}; + +macro_rules! trait_impl { + ($ty: ty) => { + impl AsyncBufRead for $ty { + #[allow(unsafe_code)] + async fn read(&self, mut buf: B) -> (io::Result, B) + where + B: BoundedBufMut, + { + let init = buf.bytes_init(); + let total = buf.bytes_total(); + + // Safety: construct a mutable slice over the spare capacity. + // try_read writes contiguously from the start of the slice + // and returns the exact byte count written on Ok(n). + let spare = unsafe { slice::from_raw_parts_mut(buf.stable_mut_ptr().add(init), total - init) }; + + let mut written = 0; + + let res = loop { + if written == spare.len() { + break Ok(written); + } + + match self.try_read(&mut spare[written..]) { + Ok(0) => break Ok(written), + Ok(n) => written += n, + Err(e) if e.kind() == io::ErrorKind::WouldBlock => { + if written > 0 { + break Ok(written); + } + if let Err(e) = self.ready(Interest::READABLE).await { + break Err(e); + } + } + Err(e) => break Err(e), + } + }; + + // SAFETY: TcpStream::try_read has put written bytes into buf. 
+ unsafe { + buf.set_init(init + written); + } + + (res, buf) + } + } + + impl AsyncBufWrite for $ty { + async fn write(&self, buf: B) -> (io::Result, B) + where + B: BoundedBuf, + { + let data = buf.chunk(); + + let mut written = 0; + + let res = loop { + if written == data.len() { + break Ok(written); + } + + match self.try_write(&data[written..]) { + Ok(0) => break Ok(written), + Ok(n) => written += n, + Err(e) if e.kind() == io::ErrorKind::WouldBlock => { + if written > 0 { + break Ok(written); + } + if let Err(e) = self.ready(Interest::WRITABLE).await { + break Err(e); + } + } + Err(e) => break Err(e), + } + }; + + (res, buf) + } + + async fn shutdown(mut self, _: Shutdown) -> io::Result<()> { + poll_fn(|cx| Pin::new(&mut self).poll_shutdown(cx)).await + } + } + }; +} + +trait_impl!(TcpStream); +trait_impl!(UnixStream); diff --git a/tokio-uring/src/io/async_buf_write.rs b/tokio-uring/src/io/async_buf_write.rs new file mode 100644 index 000000000..74752647e --- /dev/null +++ b/tokio-uring/src/io/async_buf_write.rs @@ -0,0 +1,56 @@ +//! Completion-based async IO traits. +//! +//! These traits model IO where buffer ownership is transferred to the operation +//! and returned on completion — the pattern originated from io_uring but not +//! tied to any specific runtime. They can be implemented on top of epoll/kqueue +//! or any other async runtime. + +use core::future::Future; + +use std::{io, net::Shutdown}; + +use crate::buf::{BoundedBuf, IoBuf}; + +/// Async write trait with buffer ownership transfer. +pub trait AsyncBufWrite { + /// Write from a buffer, returning the result and the buffer. + fn write(&self, buf: B) -> impl Future, B)> + where + B: BoundedBuf; + + /// Write all bytes from a buffer to IO. + fn write_all(&self, buf: B) -> impl Future, B)> + where + B: IoBuf, + Self: Sized, + { + write_all(self, buf) + } + + /// Shutdown the connection in the given direction. + /// + /// Takes ownership of `self` because shutdown is a terminal operation. 
+ /// No further reads or writes should occur after shutdown. + fn shutdown(self, direction: Shutdown) -> impl Future>; +} + +/// Write all bytes from a buffer to IO. +pub async fn write_all(io: &Io, buf: B) -> (io::Result<()>, B) +where + Io: AsyncBufWrite, + B: IoBuf, +{ + let mut buf = buf.slice_full(); + while buf.bytes_init() != 0 { + match io.write(buf).await { + (Ok(0), slice) => { + return (Err(io::ErrorKind::WriteZero.into()), slice.into_inner()); + } + (Ok(n), slice) => buf = slice.slice(n..), + (Err(e), slice) => { + return (Err(e), slice.into_inner()); + } + } + } + (Ok(()), buf.into_inner()) +} diff --git a/tokio-uring/src/io/mod.rs b/tokio-uring/src/io/mod.rs index 1afcef229..929efe4f4 100644 --- a/tokio-uring/src/io/mod.rs +++ b/tokio-uring/src/io/mod.rs @@ -1,60 +1,108 @@ +//! IO traits and utilities. +//! +//! Completion-based async IO traits ([`AsyncBufRead`], [`AsyncBufWrite`]) are +//! always available. The remaining io_uring operation types require the +//! `runtime-uring` feature. 
+ +#[cfg(feature = "runtime-uring")] mod accept; +#[cfg(feature = "runtime")] +mod async_buf_read_write_impl; + +#[cfg(feature = "runtime-uring")] mod close; +#[cfg(feature = "runtime-uring")] mod connect; +#[cfg(feature = "runtime-uring")] mod fallocate; +#[cfg(feature = "runtime-uring")] mod fsync; +#[cfg(feature = "runtime-uring")] mod mkdir_at; +#[cfg(feature = "runtime-uring")] mod noop; +#[cfg(feature = "runtime-uring")] pub(crate) use noop::NoOp; +#[cfg(feature = "runtime-uring")] mod open; +#[cfg(feature = "runtime-uring")] mod read; +#[cfg(feature = "runtime-uring")] mod read_fixed; +#[cfg(feature = "runtime-uring")] mod readv; +#[cfg(feature = "runtime-uring")] mod recv_from; +#[cfg(feature = "runtime-uring")] mod recvmsg; +#[cfg(feature = "runtime-uring")] mod rename_at; +#[cfg(feature = "runtime-uring")] mod send_to; +#[cfg(feature = "runtime-uring")] mod send_zc; +#[cfg(feature = "runtime-uring")] mod sendmsg; +#[cfg(feature = "runtime-uring")] mod sendmsg_zc; +#[cfg(feature = "runtime-uring")] mod shared_fd; +#[cfg(feature = "runtime-uring")] pub(crate) use shared_fd::SharedFd; +#[cfg(feature = "runtime-uring")] mod socket; +#[cfg(feature = "runtime-uring")] pub(crate) use socket::Socket; +#[cfg(feature = "runtime-uring")] mod statx; +#[cfg(feature = "runtime-uring")] mod symlink; +#[cfg(feature = "runtime-uring")] mod unlink_at; +#[cfg(feature = "runtime-uring")] mod util; +#[cfg(feature = "runtime-uring")] pub(crate) use util::cstr; +#[cfg(feature = "runtime-uring")] pub(crate) mod write; +#[cfg(feature = "runtime-uring")] mod write_fixed; +#[cfg(feature = "runtime-uring")] mod writev; +#[cfg(feature = "runtime-uring")] mod writev_all; +#[cfg(feature = "runtime-uring")] pub(crate) use writev_all::writev_at_all; + +mod async_buf_read; +mod async_buf_write; + +pub use async_buf_read::AsyncBufRead; +pub use async_buf_write::{AsyncBufWrite, write_all}; diff --git a/tokio-uring/src/lib.rs b/tokio-uring/src/lib.rs index 6c21359d2..58df8f693 100644 --- 
a/tokio-uring/src/lib.rs +++ b/tokio-uring/src/lib.rs @@ -59,7 +59,7 @@ #![warn(missing_docs)] #![allow(clippy::missing_const_for_thread_local)] -#[cfg(feature = "runtime")] +#[cfg(feature = "runtime-uring")] macro_rules! syscall { ($fn: ident ( $($arg: expr),* $(,)* ) ) => {{ let res = unsafe { ::libc::$fn($($arg, )*) }; @@ -71,33 +71,32 @@ macro_rules! syscall { }}; } -#[cfg(feature = "runtime")] -mod io; -#[cfg(feature = "runtime")] -mod runtime; - pub mod buf; -#[cfg(feature = "runtime")] +pub mod io; + +#[cfg(feature = "runtime-uring")] pub mod fs; -#[cfg(feature = "runtime")] +#[cfg(feature = "runtime-uring")] pub mod net; +#[cfg(feature = "runtime-uring")] +mod runtime; -#[cfg(feature = "runtime")] +#[cfg(feature = "runtime-uring")] pub use io::write::*; -#[cfg(feature = "runtime")] +#[cfg(feature = "runtime-uring")] pub use runtime::{ Runtime, driver::op::{InFlightOneshot, OneshotOutputTransform, UnsubmittedOneshot}, spawn, }; -#[cfg(feature = "runtime")] +#[cfg(feature = "runtime-uring")] use core::future::Future; -#[cfg(feature = "runtime")] +#[cfg(feature = "runtime-uring")] use runtime::driver::op::Op; -#[cfg(feature = "runtime")] +#[cfg(feature = "runtime-uring")] /// Starts an `io_uring` enabled Tokio runtime. /// /// All `tokio-uring` resource types must be used from within the context of a @@ -115,14 +114,14 @@ pub fn start(future: F) -> F::Output { rt.block_on(future) } -#[cfg(feature = "runtime")] +#[cfg(feature = "runtime-uring")] /// Creates and returns an io_uring::Builder that can then be modified /// through its implementation methods. pub fn uring_builder() -> io_uring::Builder { io_uring::IoUring::builder() } -#[cfg(feature = "runtime")] +#[cfg(feature = "runtime-uring")] /// Builder API that can create and start the `io_uring` runtime with non-default parameters, /// while abstracting away the underlying io_uring crate. 
pub struct Builder { @@ -130,7 +129,7 @@ pub struct Builder { urb: io_uring::Builder, } -#[cfg(feature = "runtime")] +#[cfg(feature = "runtime-uring")] /// Constructs a [`Builder`] with default settings. pub fn builder() -> Builder { Builder { @@ -139,7 +138,7 @@ pub fn builder() -> Builder { } } -#[cfg(feature = "runtime")] +#[cfg(feature = "runtime-uring")] impl Builder { /// Sets the number of Submission Queue entries in uring. pub fn entries(&mut self, sq_entries: u32) -> &mut Self { @@ -164,7 +163,7 @@ impl Builder { /// A specialized `Result` type for `io-uring` operations with buffers. pub type BufResult = (std::io::Result, B); -#[cfg(feature = "runtime")] +#[cfg(feature = "runtime-uring")] /// The simplest possible operation. Just posts a completion event, nothing else. pub async fn no_op() -> std::io::Result<()> { let op = Op::::no_op().unwrap(); diff --git a/tokio-uring/src/net/tcp/stream.rs b/tokio-uring/src/net/tcp/stream.rs index ee0903ba0..f2099b18a 100644 --- a/tokio-uring/src/net/tcp/stream.rs +++ b/tokio-uring/src/net/tcp/stream.rs @@ -1,6 +1,8 @@ +use core::net::SocketAddr; + use std::{ io, - net::SocketAddr, + net::Shutdown, os::unix::prelude::{AsRawFd, FromRawFd, RawFd}, }; @@ -8,7 +10,7 @@ use crate::{ UnsubmittedWrite, buf::fixed::FixedBuf, buf::{BoundedBuf, BoundedBufMut}, - io::{SharedFd, Socket}, + io::{AsyncBufRead, AsyncBufWrite, SharedFd, Socket}, }; /// A TCP stream between a local and a remote socket. 
@@ -263,6 +265,31 @@ impl TcpStream { } } +impl AsyncBufRead for TcpStream { + #[inline(always)] + async fn read(&self, buf: B) -> (io::Result, B) + where + B: BoundedBufMut, + { + TcpStream::read(self, buf).await + } +} + +impl AsyncBufWrite for TcpStream { + #[inline(always)] + async fn write(&self, buf: B) -> (io::Result, B) + where + B: BoundedBuf, + { + TcpStream::write(self, buf).submit().await + } + + #[inline(always)] + async fn shutdown(self, direction: Shutdown) -> io::Result<()> { + TcpStream::shutdown(&self, direction) + } +} + impl FromRawFd for TcpStream { unsafe fn from_raw_fd(fd: RawFd) -> Self { TcpStream::from_socket(Socket::from_shared_fd(SharedFd::new(fd))) diff --git a/tokio-uring/src/net/unix/stream.rs b/tokio-uring/src/net/unix/stream.rs index 666b91b89..ed8e69b8f 100644 --- a/tokio-uring/src/net/unix/stream.rs +++ b/tokio-uring/src/net/unix/stream.rs @@ -1,16 +1,18 @@ -use crate::{ - UnsubmittedWrite, - buf::fixed::FixedBuf, - buf::{BoundedBuf, BoundedBufMut}, - io::{SharedFd, Socket}, -}; -use socket2::SockAddr; use std::{ io, + net::Shutdown, os::unix::prelude::{AsRawFd, FromRawFd, RawFd}, path::Path, }; +use socket2::SockAddr; + +use crate::{ + UnsubmittedWrite, + buf::{BoundedBuf, BoundedBufMut, fixed::FixedBuf}, + io::{AsyncBufRead, AsyncBufWrite, SharedFd, Socket}, +}; + /// A Unix stream between two local sockets on a Unix OS. 
/// /// A Unix stream can either be created by connecting to an endpoint, via the @@ -213,6 +215,31 @@ impl UnixStream { } } +impl AsyncBufRead for UnixStream { + #[inline(always)] + async fn read(&self, buf: B) -> (io::Result, B) + where + B: BoundedBufMut, + { + UnixStream::read(self, buf).await + } +} + +impl AsyncBufWrite for UnixStream { + #[inline(always)] + async fn write(&self, buf: B) -> (io::Result, B) + where + B: BoundedBuf, + { + UnixStream::write(self, buf).submit().await + } + + #[inline(always)] + async fn shutdown(self, direction: Shutdown) -> io::Result<()> { + UnixStream::shutdown(&self, direction) + } +} + impl FromRawFd for UnixStream { unsafe fn from_raw_fd(fd: RawFd) -> Self { UnixStream::from_socket(Socket::from_shared_fd(SharedFd::new(fd))) From d1db357fc9caab1b58b5e54d3d786bca7281163f Mon Sep 17 00:00:00 2001 From: fakeshadow <24548779@qq.com> Date: Thu, 2 Apr 2026 15:11:27 +0800 Subject: [PATCH 20/21] clippy fix --- postgres/benches/fortune.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/postgres/benches/fortune.rs b/postgres/benches/fortune.rs index 0b70ccbba..336bfd658 100644 --- a/postgres/benches/fortune.rs +++ b/postgres/benches/fortune.rs @@ -116,7 +116,7 @@ fn run_cmd(cmd: &str, args: &[&str]) { let out = Command::new(cmd) .args(args) .output() - .expect(&format!("{cmd} not found")); + .unwrap_or_else(|_| panic!("{cmd} not found")); assert!( out.status.success(), "{cmd} failed: {}", From 81b3852021939892bfce6686ae498ec28a2583ad Mon Sep 17 00:00:00 2001 From: fakeshadow <24548779@qq.com> Date: Thu, 2 Apr 2026 15:38:00 +0800 Subject: [PATCH 21/21] increase sample size --- postgres/benches/fortune.rs | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/postgres/benches/fortune.rs b/postgres/benches/fortune.rs index 336bfd658..814d5cec5 100644 --- a/postgres/benches/fortune.rs +++ b/postgres/benches/fortune.rs @@ -154,5 +154,11 @@ fn bench_fortune(c: &mut Criterion) { 
rt.block_on(handle).unwrap().unwrap(); } -criterion_group!(benches, bench_fortune); +criterion_group! { + name = benches; + config = Criterion::default() + .sample_size(500) + .measurement_time(std::time::Duration::from_secs(10)); + targets = bench_fortune +} criterion_main!(benches);