diff --git a/client/src/builder.rs b/client/src/builder.rs index a9a2e28da..03bd3c64d 100644 --- a/client/src/builder.rs +++ b/client/src/builder.rs @@ -72,11 +72,11 @@ impl ClientBuilder { /// /// // trait implement for the logic of middleware. most of the types are boilerplate /// // that can be copy/pasted. the real logic goes into `async fn call` - /// impl<'r, 'c> Service> for MyMiddleware { + /// impl<'c> Service> for MyMiddleware { /// type Response = Response; /// type Error = Error; /// - /// async fn call(&self, req: ServiceRequest<'r, 'c>) -> Result { + /// async fn call(&self, req: ServiceRequest<'c>) -> Result { /// // my middleware receive ServiceRequest and can do pre-process before passing it to /// // HttpService. in this case we just print out the HTTP method of request. /// println!("request method is: {}", req.req.method()); @@ -115,7 +115,7 @@ impl ClientBuilder { pub fn middleware(mut self, func: F) -> Self where F: FnOnce(HttpService) -> S, - S: for<'r, 'c> Service, Response = Response, Error = Error> + Send + Sync + 'static, + S: for<'c> Service, Response = Response, Error = Error> + Send + Sync + 'static, { self.service = Box::new(func(self.service)); self diff --git a/client/src/h1/proto/dispatcher.rs b/client/src/h1/proto/dispatcher.rs index afd061f2d..739df52d9 100644 --- a/client/src/h1/proto/dispatcher.rs +++ b/client/src/h1/proto/dispatcher.rs @@ -22,7 +22,7 @@ use super::context::Context; pub(crate) async fn send( stream: &mut S, date: DateTimeHandle<'_>, - req: &mut Request, + mut req: Request, ) -> Result<(Response<()>, BytesMut, TransferCoding, bool), Error> where S: AsyncIo + Unpin, @@ -72,7 +72,7 @@ where let mut ctx = Context::<128>::new(&date); // encode request head and return transfer encoding for request body - let encoder = ctx.encode_head(&mut buf, req)?; + let encoder = ctx.encode_head(&mut buf, &mut req)?; // it's important to call set_head_method after encode_head. Context would remove http body it encodes/decodes // for head http method. diff --git a/client/src/middleware/decompress.rs b/client/src/middleware/decompress.rs index b134e3e79..ed1c05105 100644 --- a/client/src/middleware/decompress.rs +++ b/client/src/middleware/decompress.rs @@ -23,14 +23,14 @@ impl Decompress { } } -impl<'r, 'c, S> Service> for Decompress +impl<'c, S> Service> for Decompress where - S: for<'r2, 'c2> Service, Response = Response, Error = Error> + Send + Sync, + S: for<'c2> Service, Response = Response, Error = Error> + Send + Sync, { type Response = Response; type Error = Error; - async fn call(&self, req: ServiceRequest<'r, 'c>) -> Result { + async fn call(&self, mut req: ServiceRequest<'c>) -> Result { req.req .headers_mut() .insert(ACCEPT_ENCODING, HeaderValue::from_static("gzip, deflate, br")); diff --git a/client/src/middleware/redirect.rs b/client/src/middleware/redirect.rs index eeb1e4ad3..4cc1d84eb 100644 --- a/client/src/middleware/redirect.rs +++ b/client/src/middleware/redirect.rs @@ -1,5 +1,5 @@ +use xitca_http::Request; use crate::{ - body::BoxBody, error::{Error, InvalidUri}, http::{ header::{ @@ -41,22 +41,21 @@ impl FollowRedirect { } } -impl<'r, 'c, S, const MAX: usize> Service> for FollowRedirect +impl<'c, S, const MAX: usize> Service> for FollowRedirect where - S: for<'r2, 'c2> Service, Response = Response, Error = Error> + Send + Sync, + S: for<'c2> Service, Response = Response, Error = Error> + Send + Sync, { type Response = Response; type Error = Error; - async fn call(&self, req: ServiceRequest<'r, 'c>) -> Result { + async fn call(&self, req: ServiceRequest<'c>) -> Result { let ServiceRequest { req, client, timeout } = req; - let mut headers = req.headers().clone(); - let mut method = req.method().clone(); - let mut uri = req.uri().clone(); - let ext = req.extensions().clone(); + let (mut head, mut body) = req.into_parts(); let mut count = 0; loop { + let body = core::mem::take(&mut body); + let req = Request::from_parts(head.clone(), body); let mut res = self.service.call(ServiceRequest { req, client, timeout }).await?; if count == MAX { @@ -65,14 +64,12 @@ where match res.status() { StatusCode::MOVED_PERMANENTLY | StatusCode::FOUND | StatusCode::SEE_OTHER => { - if method != Method::HEAD { - method = Method::GET; + if head.method != Method::HEAD { + head.method = Method::GET; } - *req.body_mut() = BoxBody::default(); - for header in &[TRANSFER_ENCODING, CONTENT_ENCODING, CONTENT_TYPE, CONTENT_LENGTH] { - headers.remove(header); + head.headers.remove(header); } } StatusCode::TEMPORARY_REDIRECT | StatusCode::PERMANENT_REDIRECT => {} @@ -83,7 +80,7 @@ where return Ok(res); }; - let parts = uri.into_parts(); + let parts = core::mem::take(&mut head.uri).into_parts(); let parts_location = location .to_str() @@ -93,9 +90,9 @@ where // remove authenticated headers when redirected to different scheme/authority if parts_location.scheme != parts.scheme || parts_location.authority != parts.authority { - headers.remove(AUTHORIZATION); - headers.remove(PROXY_AUTHORIZATION); - headers.remove(COOKIE); + head.headers.remove(AUTHORIZATION); + head.headers.remove(PROXY_AUTHORIZATION); + head.headers.remove(COOKIE); } let mut uri_builder = Uri::builder(); @@ -109,12 +106,7 @@ where } let path = parts_location.path_and_query.ok_or(InvalidUri::MissingPathQuery)?; - uri = uri_builder.path_and_query(path).build().unwrap(); - - *req.uri_mut() = uri.clone(); - *req.method_mut() = method.clone(); - *req.headers_mut() = headers.clone(); - *req.extensions_mut() = ext.clone(); + head.uri = uri_builder.path_and_query(path).build().unwrap(); count += 1; } @@ -124,11 +116,10 @@ where #[cfg(test)] mod test { use crate::{ - body::ResponseBody, + body::{BoxBody, ResponseBody}, http, service::{mock_service, Service}, }; - use super::*; #[tokio::test] @@ -155,21 +146,21 @@ mod test { p => panic!("unexpected uri path: {p}"), }; - let mut req = http::Request::builder() + let req = http::Request::builder() .uri("http://foo.bar/foo") .body(Default::default()) .unwrap(); - let req = handle.mock(&mut req, handler); + let req = handle.mock(req, handler); let res = redirect.call(req).await.unwrap(); assert_eq!(res.status(), StatusCode::IM_A_TEAPOT); - let mut req = http::Request::builder() + let req = http::Request::builder() .uri("http://foo.bar/fur") .body(Default::default()) .unwrap(); - let req = handle.mock(&mut req, handler); + let req = handle.mock(req, handler); let res = redirect.call(req).await.unwrap(); assert_eq!(res.status(), StatusCode::SEE_OTHER); assert_eq!(res.headers().get(LOCATION).unwrap().to_str().unwrap(), "/bar"); diff --git a/client/src/request.rs b/client/src/request.rs index 9490e56d9..49cb671c5 100644 --- a/client/src/request.rs +++ b/client/src/request.rs @@ -121,7 +121,7 @@ impl<'a, M> RequestBuilder<'a, M> { // send request to server pub(crate) async fn _send(self) -> Result { let Self { - mut req, + req, err, client, timeout, @@ -132,14 +132,7 @@ impl<'a, M> RequestBuilder<'a, M> { return Err(err.into()); } - client - .service - .call(ServiceRequest { - req: &mut req, - client, - timeout, - }) - .await + client.service.call(ServiceRequest { req, client, timeout }).await } pub(crate) fn push_error(&mut self, e: Error) { diff --git a/client/src/service.rs b/client/src/service.rs index cd56adcc4..9394eb28b 100644 --- a/client/src/service.rs +++ b/client/src/service.rs @@ -64,28 +64,32 @@ where /// It's similar to [RequestBuilder] type but with additional side effect enabled. /// /// [RequestBuilder]: crate::request::RequestBuilder -pub struct ServiceRequest<'r, 'c> { - pub req: &'r mut Request, +pub struct ServiceRequest<'c> { + pub req: Request, pub client: &'c Client, pub timeout: Duration, } /// type alias for object safe wrapper of type implement [Service] trait. pub type HttpService = - Box ServiceDyn, Response = Response, Error = Error> + Send + Sync>; + Box ServiceDyn, Response = Response, Error = Error> + Send + Sync>; pub(crate) fn base_service() -> HttpService { struct HttpService; - impl<'r, 'c> Service> for HttpService { + impl<'c> Service> for HttpService { type Response = Response; type Error = Error; - async fn call(&self, req: ServiceRequest<'r, 'c>) -> Result { + async fn call(&self, req: ServiceRequest<'c>) -> Result { #[cfg(any(feature = "http1", feature = "http2", feature = "http3"))] use crate::{error::TimeoutError, timeout::Timeout}; - let ServiceRequest { req, client, timeout } = req; + let ServiceRequest { + mut req, + client, + timeout, + } = req; let uri = Uri::try_parse(req.uri())?; @@ -108,10 +112,7 @@ pub(crate) fn base_service() -> HttpService { return match _conn.conn { #[cfg(feature = "http2")] crate::connection::ConnectionShared::H2(ref mut conn) => { - match crate::h2::proto::send(conn, _date, core::mem::take(req)) - .timeout(_timer.as_mut()) - .await - { + match crate::h2::proto::send(conn, _date, req).timeout(_timer.as_mut()).await { Ok(Ok(res)) => { let timeout = client.timeout_config.response_timeout; Ok(Response::new(res, _timer, timeout)) @@ -128,7 +129,7 @@ pub(crate) fn base_service() -> HttpService { } #[cfg(feature = "http3")] crate::connection::ConnectionShared::H3(ref mut conn) => { - let res = crate::h3::proto::send(conn, _date, core::mem::take(req)) + let res = crate::h3::proto::send(conn, _date, req) .timeout(_timer.as_mut()) .await .map_err(|_| TimeoutError::Request)??; @@ -297,11 +298,11 @@ mod test { impl HttpServiceMockHandle { /// compose a service request with given http request and it's mocked server side handler function - pub(crate) fn mock<'r, 'c>( - &'c self, - req: &'r mut Request, + pub(crate) fn mock( + &self, + mut req: Request, handler: impl Fn(Request) -> Result, Error> + Send + Sync + 'static, - ) -> ServiceRequest<'r, 'c> { + ) -> ServiceRequest<'_> { req.extensions_mut().insert(Arc::new(handler) as HandlerFn); ServiceRequest { req, @@ -311,17 +312,17 @@ mod test { } } - impl<'r, 'c> Service> for HttpServiceMock { + impl<'c> Service> for HttpServiceMock { type Response = Response; type Error = Error; async fn call( &self, - ServiceRequest { req, timeout, .. }: ServiceRequest<'r, 'c>, + ServiceRequest { req, timeout, .. }: ServiceRequest<'c>, ) -> Result { let handler = req.extensions().get::().unwrap().clone(); - let res = handler(core::mem::take(req))?; + let res = handler(req)?; Ok(Response::new( res,