diff --git a/actix-http/src/client/connection.rs b/actix-http/src/client/connection.rs index 778083a1c..97ecd0515 100644 --- a/actix-http/src/client/connection.rs +++ b/actix-http/src/client/connection.rs @@ -75,11 +75,14 @@ pub trait Connection { type Io: AsyncRead + AsyncWrite + Unpin; /// Send request and body - fn send_request>( + fn send_request( self, head: H, body: B, - ) -> LocalBoxFuture<'static, Result<(ResponseHead, Payload), SendRequestError>>; + ) -> LocalBoxFuture<'static, Result<(ResponseHead, Payload), SendRequestError>> + where + B: MessageBody + 'static, + H: Into + 'static; /// Send request, returns Response and Framed fn open_tunnel + 'static>( @@ -144,47 +147,31 @@ impl IoConnection { pub(crate) fn into_parts(self) -> (ConnectionType, time::Instant, Acquired) { (self.io.unwrap(), self.created, self.pool.unwrap()) } -} -impl Connection for IoConnection -where - T: AsyncRead + AsyncWrite + Unpin + 'static, -{ - type Io = T; - - fn send_request>( + async fn send_request>( mut self, head: H, body: B, - ) -> LocalBoxFuture<'static, Result<(ResponseHead, Payload), SendRequestError>> { + ) -> Result<(ResponseHead, Payload), SendRequestError> { match self.io.take().unwrap() { - ConnectionType::H1(io) => Box::pin(h1proto::send_request( - io, - head.into(), - body, - self.created, - self.pool, - )), - ConnectionType::H2(io) => Box::pin(h2proto::send_request( - io, - head.into(), - body, - self.created, - self.pool, - )), + ConnectionType::H1(io) => { + h1proto::send_request(io, head.into(), body, self.created, self.pool) + .await + } + ConnectionType::H2(io) => { + h2proto::send_request(io, head.into(), body, self.created, self.pool) + .await + } } } /// Send request, returns Response and Framed - fn open_tunnel>( + async fn open_tunnel>( mut self, head: H, - ) -> LocalBoxFuture< - 'static, - Result<(ResponseHead, Framed), SendRequestError>, - > { + ) -> Result<(ResponseHead, Framed), SendRequestError> { match self.io.take().unwrap() { - ConnectionType::H1(io) => Box::pin(h1proto::open_tunnel(io, head.into())), + ConnectionType::H1(io) => h1proto::open_tunnel(io, head.into()).await, ConnectionType::H2(io) => { if let Some(mut pool) = self.pool.take() { pool.release(IoConnection::new( @@ -193,7 +180,7 @@ where None, )); } - Box::pin(async { Err(SendRequestError::TunnelNotSupported) }) + Err(SendRequestError::TunnelNotSupported) } } } @@ -216,14 +203,18 @@ where { type Io = EitherIo; - fn send_request>( + fn send_request( self, head: H, body: RB, - ) -> LocalBoxFuture<'static, Result<(ResponseHead, Payload), SendRequestError>> { + ) -> LocalBoxFuture<'static, Result<(ResponseHead, Payload), SendRequestError>> + where + RB: MessageBody + 'static, + H: Into + 'static, + { match self { - EitherIoConnection::A(con) => con.send_request(head, body), - EitherIoConnection::B(con) => con.send_request(head, body), + EitherIoConnection::A(con) => Box::pin(con.send_request(head, body)), + EitherIoConnection::B(con) => Box::pin(con.send_request(head, body)), } } diff --git a/actix-http/src/client/connector.rs b/actix-http/src/client/connector.rs index 65536f257..8aa5b1319 100644 --- a/actix-http/src/client/connector.rs +++ b/actix-http/src/client/connector.rs @@ -1,5 +1,8 @@ use std::fmt; +use std::future::Future; use std::marker::PhantomData; +use std::pin::Pin; +use std::task::{Context, Poll}; use std::time::Duration; use actix_codec::{AsyncRead, AsyncWrite}; @@ -12,7 +15,7 @@ use actix_utils::timeout::{TimeoutError, TimeoutService}; use http::Uri; use super::config::ConnectorConfig; -use super::connection::Connection; +use super::connection::{Connection, EitherIoConnection}; use super::error::ConnectError; use super::pool::{ConnectionPool, Protocol}; use super::Connect; @@ -55,7 +58,7 @@ pub struct Connector { _phantom: PhantomData, } -trait Io: AsyncRead + AsyncWrite + Unpin {} +pub trait Io: AsyncRead + AsyncWrite + Unpin {} impl Io for T {} impl Connector<(), ()> { @@ -244,28 +247,43 @@ where self, ) -> impl Service + Clone { + let tcp_service = TimeoutService::new( + self.config.timeout, + apply_fn(self.connector.clone(), |msg: Connect, srv| { + srv.call(TcpConnect::new(msg.uri).set_addr(msg.addr)) + }) + .map_err(ConnectError::from) + .map(|stream| (stream.into_parts().0, Protocol::Http1)), + ) + .map_err(|e| match e { + TimeoutError::Service(e) => e, + TimeoutError::Timeout => ConnectError::Timeout, + }); + #[cfg(not(any(feature = "openssl", feature = "rustls")))] { - let connector = TimeoutService::new( - self.config.timeout, - apply_fn(self.connector, |msg: Connect, srv| { - srv.call(TcpConnect::new(msg.uri).set_addr(msg.addr)) - }) - .map_err(ConnectError::from) - .map(|stream| (stream.into_parts().0, Protocol::Http1)), - ) - .map_err(|e| match e { - TimeoutError::Service(e) => e, - TimeoutError::Timeout => ConnectError::Timeout, - }); + // A dummy service for annotate tls pool's type signature. + pub type DummyService = Box< + dyn Service< + Connect, + Response = (Box, Protocol), + Error = ConnectError, + Future = futures_core::future::LocalBoxFuture< + 'static, + Result<(Box, Protocol), ConnectError>, + >, + >, + >; - connect_impl::InnerConnector { + InnerConnector::<_, DummyService, _, Box> { tcp_pool: ConnectionPool::new( - connector, + tcp_service, self.config.no_disconnect_timeout(), ), + tls_pool: None, } } + #[cfg(any(feature = "openssl", feature = "rustls"))] { const H2: &[u8] = b"h2"; @@ -328,172 +346,97 @@ where TimeoutError::Timeout => ConnectError::Timeout, }); - let tcp_service = TimeoutService::new( - self.config.timeout, - apply_fn(self.connector, |msg: Connect, srv| { - srv.call(TcpConnect::new(msg.uri).set_addr(msg.addr)) - }) - .map_err(ConnectError::from) - .map(|stream| (stream.into_parts().0, Protocol::Http1)), - ) - .map_err(|e| match e { - TimeoutError::Service(e) => e, - TimeoutError::Timeout => ConnectError::Timeout, - }); - - connect_impl::InnerConnector { + InnerConnector { tcp_pool: ConnectionPool::new( tcp_service, self.config.no_disconnect_timeout(), ), - ssl_pool: ConnectionPool::new(ssl_service, self.config), + tls_pool: Some(ConnectionPool::new(ssl_service, self.config)), } } } } -#[cfg(not(any(feature = "openssl", feature = "rustls")))] -mod connect_impl { - use std::task::{Context, Poll}; +struct InnerConnector +where + S1: Service + 'static, + S2: Service + 'static, + Io1: AsyncRead + AsyncWrite + Unpin + 'static, + Io2: AsyncRead + AsyncWrite + Unpin + 'static, +{ + tcp_pool: ConnectionPool, + tls_pool: Option>, +} - use futures_core::future::LocalBoxFuture; - - use super::*; - use crate::client::connection::IoConnection; - - pub(crate) struct InnerConnector - where - Io: AsyncRead + AsyncWrite + Unpin + 'static, - T: Service + 'static, - { - pub(crate) tcp_pool: ConnectionPool, - } - - impl Clone for InnerConnector - where - Io: AsyncRead + AsyncWrite + Unpin + 'static, - T: Service + 'static, - { - fn clone(&self) -> Self { - InnerConnector { - tcp_pool: self.tcp_pool.clone(), - } - } - } - - impl Service for InnerConnector - where - Io: AsyncRead + AsyncWrite + Unpin + 'static, - T: Service + 'static, - { - type Response = IoConnection; - type Error = ConnectError; - type Future = LocalBoxFuture<'static, Result, ConnectError>>; - - fn poll_ready(&self, cx: &mut Context<'_>) -> Poll> { - self.tcp_pool.poll_ready(cx) - } - - fn call(&self, req: Connect) -> Self::Future { - match req.uri.scheme_str() { - Some("https") | Some("wss") => { - Box::pin(async { Err(ConnectError::SslIsNotSupported) }) - } - _ => self.tcp_pool.call(req), - } +impl Clone for InnerConnector +where + S1: Service + 'static, + S2: Service + 'static, + Io1: AsyncRead + AsyncWrite + Unpin + 'static, + Io2: AsyncRead + AsyncWrite + Unpin + 'static, +{ + fn clone(&self) -> Self { + InnerConnector { + tcp_pool: self.tcp_pool.clone(), + tls_pool: self.tls_pool.as_ref().cloned(), } } } -#[cfg(any(feature = "openssl", feature = "rustls"))] -mod connect_impl { - use std::future::Future; - use std::pin::Pin; - use std::task::{Context, Poll}; +impl Service for InnerConnector +where + S1: Service + 'static, + S2: Service + 'static, + Io1: AsyncRead + AsyncWrite + Unpin + 'static, + Io2: AsyncRead + AsyncWrite + Unpin + 'static, +{ + type Response = EitherIoConnection; + type Error = ConnectError; + type Future = InnerConnectorResponse; - use super::*; - use crate::client::connection::EitherIoConnection; - - pub(crate) struct InnerConnector - where - S1: Service + 'static, - S2: Service + 'static, - Io1: AsyncRead + AsyncWrite + Unpin + 'static, - Io2: AsyncRead + AsyncWrite + Unpin + 'static, - { - pub(crate) tcp_pool: ConnectionPool, - pub(crate) ssl_pool: ConnectionPool, + fn poll_ready(&self, cx: &mut Context<'_>) -> Poll> { + self.tcp_pool.poll_ready(cx) } - impl Clone for InnerConnector - where - S1: Service + 'static, - S2: Service + 'static, - Io1: AsyncRead + AsyncWrite + Unpin + 'static, - Io2: AsyncRead + AsyncWrite + Unpin + 'static, - { - fn clone(&self) -> Self { - InnerConnector { - tcp_pool: self.tcp_pool.clone(), - ssl_pool: self.ssl_pool.clone(), - } + fn call(&self, req: Connect) -> Self::Future { + match req.uri.scheme_str() { + Some("https") | Some("wss") => match self.tls_pool { + None => InnerConnectorResponse::SslIsNotSupported, + Some(ref pool) => InnerConnectorResponse::Io2(pool.call(req)), + }, + _ => InnerConnectorResponse::Io1(self.tcp_pool.call(req)), } } +} - impl Service for InnerConnector - where - S1: Service + 'static, - S2: Service + 'static, - Io1: AsyncRead + AsyncWrite + Unpin + 'static, - Io2: AsyncRead + AsyncWrite + Unpin + 'static, - { - type Response = EitherIoConnection; - type Error = ConnectError; - type Future = InnerConnectorResponse; +#[pin_project::pin_project(project = InnerConnectorProj)] +enum InnerConnectorResponse +where + S1: Service + 'static, + S2: Service + 'static, + Io1: AsyncRead + AsyncWrite + Unpin + 'static, + Io2: AsyncRead + AsyncWrite + Unpin + 'static, +{ + Io1(#[pin] as Service>::Future), + Io2(#[pin] as Service>::Future), + SslIsNotSupported, +} - fn poll_ready(&self, cx: &mut Context<'_>) -> Poll> { - self.tcp_pool.poll_ready(cx) - } +impl Future for InnerConnectorResponse +where + S1: Service + 'static, + S2: Service + 'static, + Io1: AsyncRead + AsyncWrite + Unpin + 'static, + Io2: AsyncRead + AsyncWrite + Unpin + 'static, +{ + type Output = Result, ConnectError>; - fn call(&self, req: Connect) -> Self::Future { - match req.uri.scheme_str() { - Some("https") | Some("wss") => { - InnerConnectorResponse::Io2(self.ssl_pool.call(req)) - } - _ => InnerConnectorResponse::Io1(self.tcp_pool.call(req)), - } - } - } - - #[pin_project::pin_project(project = InnerConnectorProj)] - pub(crate) enum InnerConnectorResponse - where - S1: Service + 'static, - S2: Service + 'static, - Io1: AsyncRead + AsyncWrite + Unpin + 'static, - Io2: AsyncRead + AsyncWrite + Unpin + 'static, - { - Io1(#[pin] as Service>::Future), - Io2(#[pin] as Service>::Future), - } - - impl Future for InnerConnectorResponse - where - S1: Service + 'static, - S2: Service + 'static, - Io1: AsyncRead + AsyncWrite + Unpin + 'static, - Io2: AsyncRead + AsyncWrite + Unpin + 'static, - { - type Output = Result, ConnectError>; - - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - match self.project() { - InnerConnectorProj::Io1(fut) => { - fut.poll(cx).map_ok(EitherIoConnection::A) - } - InnerConnectorProj::Io2(fut) => { - fut.poll(cx).map_ok(EitherIoConnection::B) - } + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + match self.project() { + InnerConnectorProj::Io1(fut) => fut.poll(cx).map_ok(EitherIoConnection::A), + InnerConnectorProj::Io2(fut) => fut.poll(cx).map_ok(EitherIoConnection::B), + InnerConnectorProj::SslIsNotSupported => { + Poll::Ready(Err(ConnectError::SslIsNotSupported)) } } }