1
0
Fork 0
mirror of https://github.com/actix/actix-web.git synced 2025-02-14 01:55:14 +00:00

feat(awc): allow to retrieve request head in client response

This commit is contained in:
Joel Wurtz 2024-12-16 15:40:23 +01:00
parent 8115c818c1
commit adf3a06805
No known key found for this signature in database
GPG key ID: ED264D1967A51B0D
10 changed files with 74 additions and 42 deletions

View file

@ -188,16 +188,16 @@ impl Decoder for ClientPayloadCodec {
}
}
impl Encoder<Message<(RequestHeadType, BodySize)>> for ClientCodec {
impl Encoder<Message<(&mut RequestHeadType, BodySize)>> for ClientCodec {
type Error = io::Error;
fn encode(
&mut self,
item: Message<(RequestHeadType, BodySize)>,
item: Message<(&mut RequestHeadType, BodySize)>,
dst: &mut BytesMut,
) -> Result<(), Self::Error> {
match item {
Message::Item((mut head, length)) => {
Message::Item((head, length)) => {
let inner = &mut self.inner;
inner.version = head.as_ref().version;
inner
@ -219,7 +219,7 @@ impl Encoder<Message<(RequestHeadType, BodySize)>> for ClientCodec {
inner.encoder.encode(
dst,
&mut head,
head,
false,
false,
inner.version,

View file

@ -5,6 +5,7 @@
- Update `brotli` dependency to `7`.
- Prevent panics on connection pool drop when Tokio runtime is shutdown early.
- Minimum supported Rust version (MSRV) is now 1.75.
- Allow to retrieve request head used to send the http request on `ClientResponse`
## 3.5.1

View file

@ -243,7 +243,7 @@ where
self,
head: H,
body: RB,
) -> LocalBoxFuture<'static, Result<(ResponseHead, Payload), SendRequestError>>
) -> LocalBoxFuture<'static, Result<(RequestHeadType, ResponseHead, Payload), SendRequestError>>
where
H: Into<RequestHeadType> + 'static,
RB: MessageBody + 'static,
@ -273,17 +273,24 @@ where
head: H,
) -> LocalBoxFuture<
'static,
Result<(ResponseHead, Framed<Connection<A, B>, ClientCodec>), SendRequestError>,
Result<
(
RequestHeadType,
ResponseHead,
Framed<Connection<A, B>, ClientCodec>,
),
SendRequestError,
>,
> {
Box::pin(async move {
match self {
Connection::Tcp(ConnectionType::H1(ref _conn)) => {
let (head, framed) = h1proto::open_tunnel(self, head.into()).await?;
Ok((head, framed))
let (head, res_head, framed) = h1proto::open_tunnel(self, head.into()).await?;
Ok((head, res_head, framed))
}
Connection::Tls(ConnectionType::H1(ref _conn)) => {
let (head, framed) = h1proto::open_tunnel(self, head.into()).await?;
Ok((head, framed))
let (head, res_head, framed) = h1proto::open_tunnel(self, head.into()).await?;
Ok((head, res_head, framed))
}
Connection::Tls(ConnectionType::H2(mut conn)) => {
conn.release();

View file

@ -28,7 +28,7 @@ pub(crate) async fn send_request<Io, B>(
io: H1Connection<Io>,
mut head: RequestHeadType,
body: B,
) -> Result<(ResponseHead, Payload), SendRequestError>
) -> Result<(RequestHeadType, ResponseHead, Payload), SendRequestError>
where
Io: ConnectionIo,
B: MessageBody,
@ -86,7 +86,7 @@ where
// special handle for EXPECT request.
let (do_send, mut res_head) = if is_expect {
pin_framed.send((head, body.size()).into()).await?;
pin_framed.send((&mut head, body.size()).into()).await?;
let head = poll_fn(|cx| pin_framed.as_mut().poll_next(cx))
.await
@ -96,7 +96,7 @@ where
// and current head would be used as final response head.
(head.status == StatusCode::CONTINUE, Some(head))
} else {
pin_framed.feed((head, body.size()).into()).await?;
pin_framed.feed((&mut head, body.size()).into()).await?;
(true, None)
};
@ -118,17 +118,16 @@ where
res_head = Some(head);
}
let head = res_head.unwrap();
match pin_framed.codec_ref().message_type() {
h1::MessageType::None => {
let keep_alive = pin_framed.codec_ref().keep_alive();
pin_framed.io_mut().on_release(keep_alive);
Ok((head, Payload::None))
Ok((head, res_head.unwrap(), Payload::None))
}
_ => Ok((
head,
res_head.unwrap(),
Payload::Stream {
payload: Box::pin(PlStream::new(framed)),
},
@ -138,21 +137,21 @@ where
pub(crate) async fn open_tunnel<Io>(
io: Io,
head: RequestHeadType,
) -> Result<(ResponseHead, Framed<Io, h1::ClientCodec>), SendRequestError>
mut head: RequestHeadType,
) -> Result<(RequestHeadType, ResponseHead, Framed<Io, h1::ClientCodec>), SendRequestError>
where
Io: ConnectionIo,
{
// create Framed and send request.
let mut framed = Framed::new(io, h1::ClientCodec::default());
framed.send((head, BodySize::None).into()).await?;
framed.send((&mut head, BodySize::None).into()).await?;
// read response head.
let head = poll_fn(|cx| Pin::new(&mut framed).poll_next(cx))
let res_head = poll_fn(|cx| Pin::new(&mut framed).poll_next(cx))
.await
.ok_or(ConnectError::Disconnected)??;
Ok((head, framed))
Ok((head, res_head, framed))
}
/// send request body to the peer

View file

@ -29,7 +29,7 @@ pub(crate) async fn send_request<Io, B>(
mut io: H2Connection<Io>,
head: RequestHeadType,
body: B,
) -> Result<(ResponseHead, Payload), SendRequestError>
) -> Result<(RequestHeadType, ResponseHead, Payload), SendRequestError>
where
Io: ConnectionIo,
B: MessageBody,
@ -129,10 +129,10 @@ where
let (parts, body) = resp.into_parts();
let payload = if head_req { Payload::None } else { body.into() };
let mut head = ResponseHead::new(parts.status);
head.version = parts.version;
head.headers = parts.headers.into();
Ok((head, payload))
let mut res_head = ResponseHead::new(parts.status);
res_head.version = parts.version;
res_head.headers = parts.headers.into();
Ok((head, res_head, payload))
}
async fn send_body<B>(body: B, mut send: SendStream<Bytes>) -> Result<(), SendRequestError>

View file

@ -49,7 +49,11 @@ pub enum ConnectResponse {
/// Tunnel used for WebSocket communication.
///
/// Contains response head and framed HTTP/1.1 codec.
Tunnel(ResponseHead, Framed<BoxedSocket, ClientCodec>),
Tunnel(
RequestHeadType,
ResponseHead,
Framed<BoxedSocket, ClientCodec>,
),
}
impl ConnectResponse {
@ -70,9 +74,15 @@ impl ConnectResponse {
///
/// # Panics
/// Panics if enum variant is not `Tunnel`.
pub fn into_tunnel_response(self) -> (ResponseHead, Framed<BoxedSocket, ClientCodec>) {
pub fn into_tunnel_response(
self,
) -> (
RequestHeadType,
ResponseHead,
Framed<BoxedSocket, ClientCodec>,
) {
match self {
ConnectResponse::Tunnel(head, framed) => (head, framed),
ConnectResponse::Tunnel(req, head, framed) => (req, head, framed),
_ => {
panic!("TunnelResponse only reachable with ConnectResponse::TunnelResponse variant")
}
@ -133,12 +143,12 @@ pin_project_lite::pin_project! {
req: Option<ConnectRequest>
},
Client {
fut: LocalBoxFuture<'static, Result<(ResponseHead, Payload), SendRequestError>>
fut: LocalBoxFuture<'static, Result<(RequestHeadType, ResponseHead, Payload), SendRequestError>>
},
Tunnel {
fut: LocalBoxFuture<
'static,
Result<(ResponseHead, Framed<Connection<Io>, ClientCodec>), SendRequestError>,
Result<(RequestHeadType, ResponseHead, Framed<Connection<Io>, ClientCodec>), SendRequestError>,
>,
}
}
@ -181,16 +191,16 @@ where
}
ConnectRequestProj::Client { fut } => {
let (head, payload) = ready!(fut.as_mut().poll(cx))?;
let (req, head, payload) = ready!(fut.as_mut().poll(cx))?;
Poll::Ready(Ok(ConnectResponse::Client(ClientResponse::new(
head, payload,
req, head, payload,
))))
}
ConnectRequestProj::Tunnel { fut } => {
let (head, framed) = ready!(fut.as_mut().poll(cx))?;
let (req, head, framed) = ready!(fut.as_mut().poll(cx))?;
let framed = framed.into_map_io(|io| Box::new(io) as _);
Poll::Ready(Ok(ConnectResponse::Tunnel(head, framed)))
Poll::Ready(Ok(ConnectResponse::Tunnel(req, head, framed)))
}
}
}

View file

@ -329,6 +329,7 @@ mod tests {
let res = client.get(srv.url("/")).send().await.unwrap();
assert_eq!(res.status().as_u16(), 400);
assert_eq!(res.req_head().uri.path(), "/test");
}
#[actix_rt::test]

View file

@ -8,7 +8,7 @@ use std::{
use actix_http::{
error::PayloadError, header::HeaderMap, BoxedPayloadStream, Extensions, HttpMessage, Payload,
ResponseHead, StatusCode, Version,
RequestHead, RequestHeadType, ResponseHead, StatusCode, Version,
};
use actix_rt::time::{sleep, Sleep};
use bytes::Bytes;
@ -23,6 +23,7 @@ use crate::cookie::{Cookie, ParseError as CookieParseError};
pin_project! {
/// Client Response
pub struct ClientResponse<S = BoxedPayloadStream> {
pub(crate) req_head: RequestHeadType,
pub(crate) head: ResponseHead,
#[pin]
pub(crate) payload: Payload<S>,
@ -34,8 +35,9 @@ pin_project! {
impl<S> ClientResponse<S> {
/// Create new Request instance
pub(crate) fn new(head: ResponseHead, payload: Payload<S>) -> Self {
pub(crate) fn new(req_head: RequestHeadType, head: ResponseHead, payload: Payload<S>) -> Self {
ClientResponse {
req_head,
head,
payload,
timeout: ResponseTimeout::default(),
@ -43,6 +45,12 @@ impl<S> ClientResponse<S> {
}
}
/// Returns the request head used to send the request.
#[inline]
pub fn req_head(&self) -> &RequestHead {
self.req_head.as_ref()
}
#[inline]
pub(crate) fn head(&self) -> &ResponseHead {
&self.head
@ -77,6 +85,7 @@ impl<S> ClientResponse<S> {
ClientResponse {
payload,
req_head: self.req_head,
head: self.head,
timeout: self.timeout,
extensions: self.extensions,
@ -105,6 +114,7 @@ impl<S> ClientResponse<S> {
Self {
payload: self.payload,
head: self.head,
req_head: self.req_head,
timeout,
extensions: self.extensions,
}

View file

@ -1,6 +1,8 @@
//! Test helpers for actix http client to use during testing.
use actix_http::{h1, header::TryIntoHeaderPair, Payload, ResponseHead, StatusCode, Version};
use actix_http::{
h1, header::TryIntoHeaderPair, Payload, RequestHead, ResponseHead, StatusCode, Version,
};
use bytes::Bytes;
#[cfg(feature = "cookies")]
@ -9,6 +11,7 @@ use crate::ClientResponse;
/// Test `ClientResponse` builder
pub struct TestResponse {
req_head: RequestHead,
head: ResponseHead,
#[cfg(feature = "cookies")]
cookies: CookieJar,
@ -18,6 +21,7 @@ pub struct TestResponse {
impl Default for TestResponse {
fn default() -> TestResponse {
TestResponse {
req_head: RequestHead::default(),
head: ResponseHead::new(StatusCode::OK),
#[cfg(feature = "cookies")]
cookies: CookieJar::new(),
@ -88,10 +92,10 @@ impl TestResponse {
}
if let Some(pl) = self.payload {
ClientResponse::new(head, pl)
ClientResponse::new(self.req_head.into(), head, pl)
} else {
let (_, payload) = h1::Payload::create(true);
ClientResponse::new(head, payload.into())
ClientResponse::new(self.req_head.into(), head, payload.into())
}
}
}

View file

@ -351,7 +351,7 @@ impl WebsocketsRequest {
fut.await?
};
let (head, framed) = res.into_tunnel_response();
let (req_head, head, framed) = res.into_tunnel_response();
// verify response
if head.status != StatusCode::SWITCHING_PROTOCOLS {
@ -411,7 +411,7 @@ impl WebsocketsRequest {
// response and ws framed
Ok((
ClientResponse::new(head, Payload::None),
ClientResponse::new(req_head, head, Payload::None),
framed.into_map_codec(|_| {
if server_mode {
ws::Codec::new().max_size(max_size)