diff --git a/actix-http/src/h1/client.rs b/actix-http/src/h1/client.rs index f3947dd12..8b6150247 100644 --- a/actix-http/src/h1/client.rs +++ b/actix-http/src/h1/client.rs @@ -188,16 +188,16 @@ impl Decoder for ClientPayloadCodec { } } -impl Encoder> for ClientCodec { +impl Encoder> for ClientCodec { type Error = io::Error; fn encode( &mut self, - item: Message<(RequestHeadType, BodySize)>, + item: Message<(&mut RequestHeadType, BodySize)>, dst: &mut BytesMut, ) -> Result<(), Self::Error> { match item { - Message::Item((mut head, length)) => { + Message::Item((head, length)) => { let inner = &mut self.inner; inner.version = head.as_ref().version; inner @@ -219,7 +219,7 @@ impl Encoder> for ClientCodec { inner.encoder.encode( dst, - &mut head, + head, false, false, inner.version, diff --git a/awc/CHANGES.md b/awc/CHANGES.md index 8a2a1ec43..0ce8f63a0 100644 --- a/awc/CHANGES.md +++ b/awc/CHANGES.md @@ -5,6 +5,7 @@ - Update `brotli` dependency to `7`. - Prevent panics on connection pool drop when Tokio runtime is shutdown early. - Minimum supported Rust version (MSRV) is now 1.75. +- Allow to retrieve request head used to send the http request on `ClientResponse` ## 3.5.1 diff --git a/awc/src/client/connection.rs b/awc/src/client/connection.rs index 8164e2b59..6f3d42177 100644 --- a/awc/src/client/connection.rs +++ b/awc/src/client/connection.rs @@ -243,7 +243,7 @@ where self, head: H, body: RB, - ) -> LocalBoxFuture<'static, Result<(ResponseHead, Payload), SendRequestError>> + ) -> LocalBoxFuture<'static, Result<(RequestHeadType, ResponseHead, Payload), SendRequestError>> where H: Into + 'static, RB: MessageBody + 'static, @@ -273,17 +273,24 @@ where head: H, ) -> LocalBoxFuture< 'static, - Result<(ResponseHead, Framed, ClientCodec>), SendRequestError>, + Result< + ( + RequestHeadType, + ResponseHead, + Framed, ClientCodec>, + ), + SendRequestError, + >, > { Box::pin(async move { match self { Connection::Tcp(ConnectionType::H1(ref _conn)) => { - let (head, framed) = h1proto::open_tunnel(self, head.into()).await?; - Ok((head, framed)) + let (head, res_head, framed) = h1proto::open_tunnel(self, head.into()).await?; + Ok((head, res_head, framed)) } Connection::Tls(ConnectionType::H1(ref _conn)) => { - let (head, framed) = h1proto::open_tunnel(self, head.into()).await?; - Ok((head, framed)) + let (head, res_head, framed) = h1proto::open_tunnel(self, head.into()).await?; + Ok((head, res_head, framed)) } Connection::Tls(ConnectionType::H2(mut conn)) => { conn.release(); diff --git a/awc/src/client/h1proto.rs b/awc/src/client/h1proto.rs index 3f4c9f979..dcf32e585 100644 --- a/awc/src/client/h1proto.rs +++ b/awc/src/client/h1proto.rs @@ -28,7 +28,7 @@ pub(crate) async fn send_request( io: H1Connection, mut head: RequestHeadType, body: B, -) -> Result<(ResponseHead, Payload), SendRequestError> +) -> Result<(RequestHeadType, ResponseHead, Payload), SendRequestError> where Io: ConnectionIo, B: MessageBody, @@ -86,7 +86,7 @@ where // special handle for EXPECT request. let (do_send, mut res_head) = if is_expect { - pin_framed.send((head, body.size()).into()).await?; + pin_framed.send((&mut head, body.size()).into()).await?; let head = poll_fn(|cx| pin_framed.as_mut().poll_next(cx)) .await @@ -96,7 +96,7 @@ where // and current head would be used as final response head. (head.status == StatusCode::CONTINUE, Some(head)) } else { - pin_framed.feed((head, body.size()).into()).await?; + pin_framed.feed((&mut head, body.size()).into()).await?; (true, None) }; @@ -118,17 +118,16 @@ where res_head = Some(head); } - let head = res_head.unwrap(); - match pin_framed.codec_ref().message_type() { h1::MessageType::None => { let keep_alive = pin_framed.codec_ref().keep_alive(); pin_framed.io_mut().on_release(keep_alive); - Ok((head, Payload::None)) + Ok((head, res_head.unwrap(), Payload::None)) } _ => Ok(( head, + res_head.unwrap(), Payload::Stream { payload: Box::pin(PlStream::new(framed)), }, @@ -138,21 +137,21 @@ where pub(crate) async fn open_tunnel( io: Io, - head: RequestHeadType, -) -> Result<(ResponseHead, Framed), SendRequestError> + mut head: RequestHeadType, +) -> Result<(RequestHeadType, ResponseHead, Framed), SendRequestError> where Io: ConnectionIo, { // create Framed and send request. let mut framed = Framed::new(io, h1::ClientCodec::default()); - framed.send((head, BodySize::None).into()).await?; + framed.send((&mut head, BodySize::None).into()).await?; // read response head. - let head = poll_fn(|cx| Pin::new(&mut framed).poll_next(cx)) + let res_head = poll_fn(|cx| Pin::new(&mut framed).poll_next(cx)) .await .ok_or(ConnectError::Disconnected)??; - Ok((head, framed)) + Ok((head, res_head, framed)) } /// send request body to the peer diff --git a/awc/src/client/h2proto.rs b/awc/src/client/h2proto.rs index c3f801f20..db1b80388 100644 --- a/awc/src/client/h2proto.rs +++ b/awc/src/client/h2proto.rs @@ -29,7 +29,7 @@ pub(crate) async fn send_request( mut io: H2Connection, head: RequestHeadType, body: B, -) -> Result<(ResponseHead, Payload), SendRequestError> +) -> Result<(RequestHeadType, ResponseHead, Payload), SendRequestError> where Io: ConnectionIo, B: MessageBody, @@ -129,10 +129,10 @@ where let (parts, body) = resp.into_parts(); let payload = if head_req { Payload::None } else { body.into() }; - let mut head = ResponseHead::new(parts.status); - head.version = parts.version; - head.headers = parts.headers.into(); - Ok((head, payload)) + let mut res_head = ResponseHead::new(parts.status); + res_head.version = parts.version; + res_head.headers = parts.headers.into(); + Ok((head, res_head, payload)) } async fn send_body(body: B, mut send: SendStream) -> Result<(), SendRequestError> diff --git a/awc/src/connect.rs b/awc/src/connect.rs index 14ed9e958..80d0410d0 100644 --- a/awc/src/connect.rs +++ b/awc/src/connect.rs @@ -49,7 +49,11 @@ pub enum ConnectResponse { /// Tunnel used for WebSocket communication. /// /// Contains response head and framed HTTP/1.1 codec. - Tunnel(ResponseHead, Framed), + Tunnel( + RequestHeadType, + ResponseHead, + Framed, + ), } impl ConnectResponse { @@ -70,9 +74,15 @@ impl ConnectResponse { /// /// # Panics /// Panics if enum variant is not `Tunnel`. - pub fn into_tunnel_response(self) -> (ResponseHead, Framed) { + pub fn into_tunnel_response( + self, + ) -> ( + RequestHeadType, + ResponseHead, + Framed, + ) { match self { - ConnectResponse::Tunnel(head, framed) => (head, framed), + ConnectResponse::Tunnel(req, head, framed) => (req, head, framed), _ => { panic!("TunnelResponse only reachable with ConnectResponse::TunnelResponse variant") } @@ -133,12 +143,12 @@ pin_project_lite::pin_project! { req: Option }, Client { - fut: LocalBoxFuture<'static, Result<(ResponseHead, Payload), SendRequestError>> + fut: LocalBoxFuture<'static, Result<(RequestHeadType, ResponseHead, Payload), SendRequestError>> }, Tunnel { fut: LocalBoxFuture< 'static, - Result<(ResponseHead, Framed, ClientCodec>), SendRequestError>, + Result<(RequestHeadType, ResponseHead, Framed, ClientCodec>), SendRequestError>, >, } } @@ -181,16 +191,16 @@ where } ConnectRequestProj::Client { fut } => { - let (head, payload) = ready!(fut.as_mut().poll(cx))?; + let (req, head, payload) = ready!(fut.as_mut().poll(cx))?; Poll::Ready(Ok(ConnectResponse::Client(ClientResponse::new( - head, payload, + req, head, payload, )))) } ConnectRequestProj::Tunnel { fut } => { - let (head, framed) = ready!(fut.as_mut().poll(cx))?; + let (req, head, framed) = ready!(fut.as_mut().poll(cx))?; let framed = framed.into_map_io(|io| Box::new(io) as _); - Poll::Ready(Ok(ConnectResponse::Tunnel(head, framed))) + Poll::Ready(Ok(ConnectResponse::Tunnel(req, head, framed))) } } } diff --git a/awc/src/middleware/redirect.rs b/awc/src/middleware/redirect.rs index b2cf9c45b..67cce8c07 100644 --- a/awc/src/middleware/redirect.rs +++ b/awc/src/middleware/redirect.rs @@ -329,6 +329,7 @@ mod tests { let res = client.get(srv.url("/")).send().await.unwrap(); assert_eq!(res.status().as_u16(), 400); + assert_eq!(res.req_head().uri.path(), "/test"); } #[actix_rt::test] diff --git a/awc/src/responses/response.rs b/awc/src/responses/response.rs index 0eafcff0a..665d5de8a 100644 --- a/awc/src/responses/response.rs +++ b/awc/src/responses/response.rs @@ -8,7 +8,7 @@ use std::{ use actix_http::{ error::PayloadError, header::HeaderMap, BoxedPayloadStream, Extensions, HttpMessage, Payload, - ResponseHead, StatusCode, Version, + RequestHead, RequestHeadType, ResponseHead, StatusCode, Version, }; use actix_rt::time::{sleep, Sleep}; use bytes::Bytes; @@ -23,6 +23,7 @@ use crate::cookie::{Cookie, ParseError as CookieParseError}; pin_project! { /// Client Response pub struct ClientResponse { + pub(crate) req_head: RequestHeadType, pub(crate) head: ResponseHead, #[pin] pub(crate) payload: Payload, @@ -34,8 +35,9 @@ pin_project! { impl ClientResponse { /// Create new Request instance - pub(crate) fn new(head: ResponseHead, payload: Payload) -> Self { + pub(crate) fn new(req_head: RequestHeadType, head: ResponseHead, payload: Payload) -> Self { ClientResponse { + req_head, head, payload, timeout: ResponseTimeout::default(), @@ -43,6 +45,12 @@ impl ClientResponse { } } + /// Returns the request head used to send the request. + #[inline] + pub fn req_head(&self) -> &RequestHead { + self.req_head.as_ref() + } + #[inline] pub(crate) fn head(&self) -> &ResponseHead { &self.head @@ -77,6 +85,7 @@ impl ClientResponse { ClientResponse { payload, + req_head: self.req_head, head: self.head, timeout: self.timeout, extensions: self.extensions, @@ -105,6 +114,7 @@ impl ClientResponse { Self { payload: self.payload, head: self.head, + req_head: self.req_head, timeout, extensions: self.extensions, } diff --git a/awc/src/test.rs b/awc/src/test.rs index 126583179..c5d9dc85c 100644 --- a/awc/src/test.rs +++ b/awc/src/test.rs @@ -1,6 +1,8 @@ //! Test helpers for actix http client to use during testing. -use actix_http::{h1, header::TryIntoHeaderPair, Payload, ResponseHead, StatusCode, Version}; +use actix_http::{ + h1, header::TryIntoHeaderPair, Payload, RequestHead, ResponseHead, StatusCode, Version, +}; use bytes::Bytes; #[cfg(feature = "cookies")] @@ -9,6 +11,7 @@ use crate::ClientResponse; /// Test `ClientResponse` builder pub struct TestResponse { + req_head: RequestHead, head: ResponseHead, #[cfg(feature = "cookies")] cookies: CookieJar, @@ -18,6 +21,7 @@ pub struct TestResponse { impl Default for TestResponse { fn default() -> TestResponse { TestResponse { + req_head: RequestHead::default(), head: ResponseHead::new(StatusCode::OK), #[cfg(feature = "cookies")] cookies: CookieJar::new(), @@ -88,10 +92,10 @@ impl TestResponse { } if let Some(pl) = self.payload { - ClientResponse::new(head, pl) + ClientResponse::new(self.req_head.into(), head, pl) } else { let (_, payload) = h1::Payload::create(true); - ClientResponse::new(head, payload.into()) + ClientResponse::new(self.req_head.into(), head, payload.into()) } } } diff --git a/awc/src/ws.rs b/awc/src/ws.rs index 760331e9d..ff4b2a8c1 100644 --- a/awc/src/ws.rs +++ b/awc/src/ws.rs @@ -351,7 +351,7 @@ impl WebsocketsRequest { fut.await? }; - let (head, framed) = res.into_tunnel_response(); + let (req_head, head, framed) = res.into_tunnel_response(); // verify response if head.status != StatusCode::SWITCHING_PROTOCOLS { @@ -411,7 +411,7 @@ impl WebsocketsRequest { // response and ws framed Ok(( - ClientResponse::new(head, Payload::None), + ClientResponse::new(req_head, head, Payload::None), framed.into_map_codec(|_| { if server_mode { ws::Codec::new().max_size(max_size)