diff --git a/awc/src/builder.rs b/awc/src/builder.rs index 3d1613c66..b5e8bf5bf 100644 --- a/awc/src/builder.rs +++ b/awc/src/builder.rs @@ -7,7 +7,7 @@ use actix_http::client::{Connect as HttpConnect, ConnectError, Connection, Conne use actix_http::http::{self, header, Error as HttpError, HeaderMap, HeaderName}; use actix_service::Service; -use crate::connect::{Connect, ConnectorWrapper}; +use crate::connect::{ConnectService, ConnectorWrapper}; use crate::{Client, ClientConfig}; /// An HTTP Client builder @@ -23,7 +23,7 @@ pub struct ClientBuilder { conn_window_size: Option, headers: HeaderMap, timeout: Option, - connector: Option>, + connector: Option, } impl Default for ClientBuilder { @@ -54,7 +54,7 @@ impl ClientBuilder { T::Response: Connection, T::Future: 'static, { - self.connector = Some(Box::new(ConnectorWrapper(connector))); + self.connector = Some(Box::new(ConnectorWrapper::new(connector))); self } @@ -181,7 +181,7 @@ impl ClientBuilder { if let Some(val) = self.stream_window_size { connector = connector.initial_window_size(val) }; - Box::new(ConnectorWrapper(connector.finish())) as _ + Box::new(ConnectorWrapper::new(connector.finish())) as _ }; let config = ClientConfig { headers: self.headers, diff --git a/awc/src/connect.rs b/awc/src/connect.rs index 9a2ded195..97af2d1cc 100644 --- a/awc/src/connect.rs +++ b/awc/src/connect.rs @@ -16,74 +16,100 @@ use futures_core::future::LocalBoxFuture; use crate::response::ClientResponse; -pub(crate) struct ConnectorWrapper(pub T); - -type TunnelResponse = (ResponseHead, Framed); - -pub(crate) trait Connect { - fn send_request( - &self, - head: RequestHeadType, - body: Body, - addr: Option, - ) -> LocalBoxFuture<'static, Result>; - - /// Send request, returns Response and Framed - fn open_tunnel( - &self, - head: RequestHead, - addr: Option, - ) -> LocalBoxFuture<'static, Result>; +pub(crate) struct ConnectorWrapper { + connector: T, } -impl Connect for ConnectorWrapper +impl ConnectorWrapper { + pub(crate) fn new(connector: T) -> Self { + Self { connector } + } +} + +pub type ConnectService = Box< + dyn Service< + ConnectRequest, + Response = ConnectResponse, + Error = SendRequestError, + Future = LocalBoxFuture<'static, Result>, + >, +>; + +pub enum ConnectRequest { + Client(RequestHeadType, Body, Option), + Tunnel(RequestHead, Option), +} + +pub enum ConnectResponse { + Client(ClientResponse), + Tunnel(ResponseHead, Framed), +} + +impl ConnectResponse { + pub fn into_client_response(self) -> ClientResponse { + match self { + ConnectResponse::Client(res) => res, + _ => panic!( + "ClientResponse only reachable with ConnectResponse::ClientResponse variant" + ), + } + } + + pub fn into_tunnel_response(self) -> (ResponseHead, Framed) { + match self { + ConnectResponse::Tunnel(head, framed) => (head, framed), + _ => panic!( + "TunnelResponse only reachable with ConnectResponse::TunnelResponse variant" + ), + } + } +} + +impl Service for ConnectorWrapper where T: Service, T::Response: Connection, ::Io: 'static, T::Future: 'static, { - fn send_request( - &self, - head: RequestHeadType, - body: Body, - addr: Option, - ) -> LocalBoxFuture<'static, Result> { + type Response = ConnectResponse; + type Error = SendRequestError; + type Future = LocalBoxFuture<'static, Result>; + + actix_service::forward_ready!(connector); + + fn call(&self, req: ConnectRequest) -> Self::Future { // connect to the host - let fut = self.0.call(ClientConnect { - uri: head.as_ref().uri.clone(), - addr, - }); + let fut = match req { + ConnectRequest::Client(ref head, .., addr) => self.connector.call(ClientConnect { + uri: head.as_ref().uri.clone(), + addr, + }), + ConnectRequest::Tunnel(ref head, addr) => self.connector.call(ClientConnect { + uri: head.uri.clone(), + addr, + }), + }; Box::pin(async move { let connection = fut.await?; - // send request - let (head, payload) = connection.send_request(head, body).await?; + match req { + ConnectRequest::Client(head, body, ..) => { + // send request + let (head, payload) = connection.send_request(head, body).await?; - Ok(ClientResponse::new(head, payload)) - }) - } + Ok(ConnectResponse::Client(ClientResponse::new(head, payload))) + } + ConnectRequest::Tunnel(head, ..) => { + // send request + let (head, framed) = + connection.open_tunnel(RequestHeadType::from(head)).await?; - fn open_tunnel( - &self, - head: RequestHead, - addr: Option, - ) -> LocalBoxFuture<'static, Result> { - // connect to the host - let fut = self.0.call(ClientConnect { - uri: head.uri.clone(), - addr, - }); - - Box::pin(async move { - let connection = fut.await?; - - // send request - let (head, framed) = connection.open_tunnel(RequestHeadType::from(head)).await?; - - let framed = framed.into_map_io(|io| BoxedSocket(Box::new(Socket(io)))); - Ok((head, framed)) + let framed = framed.into_map_io(|io| BoxedSocket(Box::new(Socket(io)))); + Ok(ConnectResponse::Tunnel(head, framed)) + } + } }) } } diff --git a/awc/src/lib.rs b/awc/src/lib.rs index e50c19c8c..1e27ed9ab 100644 --- a/awc/src/lib.rs +++ b/awc/src/lib.rs @@ -115,13 +115,13 @@ pub mod test; pub mod ws; pub use self::builder::ClientBuilder; -pub use self::connect::BoxedSocket; +pub use self::connect::{BoxedSocket, ConnectRequest, ConnectResponse, ConnectService}; pub use self::frozen::{FrozenClientRequest, FrozenSendBuilder}; pub use self::request::ClientRequest; pub use self::response::{ClientResponse, JsonBody, MessageBody}; pub use self::sender::SendClientRequest; -use self::connect::{Connect, ConnectorWrapper}; +use self::connect::ConnectorWrapper; /// An asynchronous HTTP and WebSocket client. /// @@ -146,7 +146,7 @@ use self::connect::{Connect, ConnectorWrapper}; pub struct Client(Rc); pub(crate) struct ClientConfig { - pub(crate) connector: Box, + pub(crate) connector: ConnectService, pub(crate) headers: HeaderMap, pub(crate) timeout: Option, } @@ -154,7 +154,7 @@ pub(crate) struct ClientConfig { impl Default for Client { fn default() -> Self { Client(Rc::new(ClientConfig { - connector: Box::new(ConnectorWrapper(Connector::new().finish())), + connector: Box::new(ConnectorWrapper::new(Connector::new().finish())), headers: HeaderMap::new(), timeout: Some(Duration::from_secs(5)), })) diff --git a/awc/src/sender.rs b/awc/src/sender.rs index 6bac401c5..1170c69a0 100644 --- a/awc/src/sender.rs +++ b/awc/src/sender.rs @@ -18,12 +18,13 @@ use actix_http::{ use actix_rt::time::{sleep, Sleep}; use bytes::Bytes; use derive_more::From; -use futures_core::{ready, Stream}; +use futures_core::Stream; use serde::Serialize; #[cfg(feature = "compress")] use actix_http::{encoding::Decoder, http::header::ContentEncoding, Payload, PayloadStream}; +use crate::connect::{ConnectRequest, ConnectResponse}; use crate::error::{FreezeRequestError, InvalidUrl, SendRequestError}; use crate::response::ClientResponse; use crate::ClientConfig; @@ -56,7 +57,8 @@ impl From for SendRequestError { #[must_use = "futures do nothing unless polled"] pub enum SendClientRequest { Fut( - Pin>>>, + Pin>>>, + // FIXME: use a pinned Sleep instead of box. Option>>, bool, ), @@ -65,7 +67,7 @@ pub enum SendClientRequest { impl SendClientRequest { pub(crate) fn new( - send: Pin>>>, + send: Pin>>>, response_decompress: bool, timeout: Option, ) -> SendClientRequest { @@ -89,14 +91,19 @@ impl Future for SendClientRequest { } } - let res = ready!(send.as_mut().poll(cx)).map(|res| { - res._timeout(delay.take()).map_body(|head, payload| { - if *response_decompress { - Payload::Stream(Decoder::from_headers(payload, &head.headers)) - } else { - Payload::Stream(Decoder::new(payload, ContentEncoding::Identity)) - } - }) + let res = futures_core::ready!(send.as_mut().poll(cx)).map(|res| { + res.into_client_response()._timeout(delay.take()).map_body( + |head, payload| { + if *response_decompress { + Payload::Stream(Decoder::from_headers(payload, &head.headers)) + } else { + Payload::Stream(Decoder::new( + payload, + ContentEncoding::Identity, + )) + } + }, + ) }); Poll::Ready(res) @@ -122,10 +129,9 @@ impl Future for SendClientRequest { return Poll::Ready(Err(SendRequestError::Timeout)); } } - send.as_mut() .poll(cx) - .map_ok(|res| res._timeout(delay.take())) + .map_ok(|res| res.into_client_response()._timeout(delay.take())) } SendClientRequest::Err(ref mut e) => match e.take() { Some(e) => Poll::Ready(Err(e)), @@ -177,19 +183,19 @@ impl RequestSender { where B: Into, { - let fut = match self { + let req = match self { RequestSender::Owned(head) => { - config - .connector - .send_request(RequestHeadType::Owned(head), body.into(), addr) + ConnectRequest::Client(RequestHeadType::Owned(head), body.into(), addr) } - RequestSender::Rc(head, extra_headers) => config.connector.send_request( + RequestSender::Rc(head, extra_headers) => ConnectRequest::Client( RequestHeadType::Rc(head, extra_headers), body.into(), addr, ), }; + let fut = config.connector.call(req); + SendClientRequest::new(fut, response_decompress, timeout.or(config.timeout)) } diff --git a/awc/src/ws.rs b/awc/src/ws.rs index d5528595d..5f4570963 100644 --- a/awc/src/ws.rs +++ b/awc/src/ws.rs @@ -36,10 +36,11 @@ use actix_codec::Framed; use actix_http::cookie::{Cookie, CookieJar}; use actix_http::{ws, Payload, RequestHead}; use actix_rt::time::timeout; +use actix_service::Service; pub use actix_http::ws::{CloseCode, CloseReason, Codec, Frame, Message}; -use crate::connect::BoxedSocket; +use crate::connect::{BoxedSocket, ConnectRequest}; use crate::error::{InvalidUrl, SendRequestError, WsClientError}; use crate::http::header::{self, HeaderName, HeaderValue, IntoHeaderValue, AUTHORIZATION}; use crate::http::{ConnectionType, Error as HttpError, Method, StatusCode, Uri, Version}; @@ -327,18 +328,21 @@ impl WebsocketsRequest { let max_size = self.max_size; let server_mode = self.server_mode; - let fut = self.config.connector.open_tunnel(head, self.addr); + let req = ConnectRequest::Tunnel(head, self.addr); + + let fut = self.config.connector.call(req); // set request timeout - let (head, framed) = if let Some(to) = self.config.timeout { + let res = if let Some(to) = self.config.timeout { timeout(to, fut) .await - .map_err(|_| SendRequestError::Timeout) - .and_then(|res| res)? + .map_err(|_| SendRequestError::Timeout)?? } else { fut.await? }; + let (head, framed) = res.into_tunnel_response(); + // verify response if head.status != StatusCode::SWITCHING_PROTOCOLS { return Err(WsClientError::InvalidResponseStatus(head.status));