1
0
Fork 0
mirror of https://github.com/actix/actix-web.git synced 2024-12-23 00:26:34 +00:00

reduce duplicate code (#2020)

This commit is contained in:
fakeshadow 2021-02-22 03:15:12 -08:00 committed by GitHub
parent 2dbdf61c37
commit aacec30ad1
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 130 additions and 196 deletions

View file

@ -75,11 +75,14 @@ pub trait Connection {
type Io: AsyncRead + AsyncWrite + Unpin; type Io: AsyncRead + AsyncWrite + Unpin;
/// Send request and body /// Send request and body
fn send_request<B: MessageBody + 'static, H: Into<RequestHeadType>>( fn send_request<B, H>(
self, self,
head: H, head: H,
body: B, body: B,
) -> LocalBoxFuture<'static, Result<(ResponseHead, Payload), SendRequestError>>; ) -> LocalBoxFuture<'static, Result<(ResponseHead, Payload), SendRequestError>>
where
B: MessageBody + 'static,
H: Into<RequestHeadType> + 'static;
/// Send request, returns Response and Framed /// Send request, returns Response and Framed
fn open_tunnel<H: Into<RequestHeadType> + 'static>( fn open_tunnel<H: Into<RequestHeadType> + 'static>(
@ -144,47 +147,31 @@ impl<T: AsyncRead + AsyncWrite + Unpin> IoConnection<T> {
pub(crate) fn into_parts(self) -> (ConnectionType<T>, time::Instant, Acquired<T>) { pub(crate) fn into_parts(self) -> (ConnectionType<T>, time::Instant, Acquired<T>) {
(self.io.unwrap(), self.created, self.pool.unwrap()) (self.io.unwrap(), self.created, self.pool.unwrap())
} }
}
impl<T> Connection for IoConnection<T> async fn send_request<B: MessageBody + 'static, H: Into<RequestHeadType>>(
where
T: AsyncRead + AsyncWrite + Unpin + 'static,
{
type Io = T;
fn send_request<B: MessageBody + 'static, H: Into<RequestHeadType>>(
mut self, mut self,
head: H, head: H,
body: B, body: B,
) -> LocalBoxFuture<'static, Result<(ResponseHead, Payload), SendRequestError>> { ) -> Result<(ResponseHead, Payload), SendRequestError> {
match self.io.take().unwrap() { match self.io.take().unwrap() {
ConnectionType::H1(io) => Box::pin(h1proto::send_request( ConnectionType::H1(io) => {
io, h1proto::send_request(io, head.into(), body, self.created, self.pool)
head.into(), .await
body, }
self.created, ConnectionType::H2(io) => {
self.pool, h2proto::send_request(io, head.into(), body, self.created, self.pool)
)), .await
ConnectionType::H2(io) => Box::pin(h2proto::send_request( }
io,
head.into(),
body,
self.created,
self.pool,
)),
} }
} }
/// Send request, returns Response and Framed /// Send request, returns Response and Framed
fn open_tunnel<H: Into<RequestHeadType>>( async fn open_tunnel<H: Into<RequestHeadType>>(
mut self, mut self,
head: H, head: H,
) -> LocalBoxFuture< ) -> Result<(ResponseHead, Framed<T, ClientCodec>), SendRequestError> {
'static,
Result<(ResponseHead, Framed<Self::Io, ClientCodec>), SendRequestError>,
> {
match self.io.take().unwrap() { match self.io.take().unwrap() {
ConnectionType::H1(io) => Box::pin(h1proto::open_tunnel(io, head.into())), ConnectionType::H1(io) => h1proto::open_tunnel(io, head.into()).await,
ConnectionType::H2(io) => { ConnectionType::H2(io) => {
if let Some(mut pool) = self.pool.take() { if let Some(mut pool) = self.pool.take() {
pool.release(IoConnection::new( pool.release(IoConnection::new(
@ -193,7 +180,7 @@ where
None, None,
)); ));
} }
Box::pin(async { Err(SendRequestError::TunnelNotSupported) }) Err(SendRequestError::TunnelNotSupported)
} }
} }
} }
@ -216,14 +203,18 @@ where
{ {
type Io = EitherIo<A, B>; type Io = EitherIo<A, B>;
fn send_request<RB: MessageBody + 'static, H: Into<RequestHeadType>>( fn send_request<RB, H>(
self, self,
head: H, head: H,
body: RB, body: RB,
) -> LocalBoxFuture<'static, Result<(ResponseHead, Payload), SendRequestError>> { ) -> LocalBoxFuture<'static, Result<(ResponseHead, Payload), SendRequestError>>
where
RB: MessageBody + 'static,
H: Into<RequestHeadType> + 'static,
{
match self { match self {
EitherIoConnection::A(con) => con.send_request(head, body), EitherIoConnection::A(con) => Box::pin(con.send_request(head, body)),
EitherIoConnection::B(con) => con.send_request(head, body), EitherIoConnection::B(con) => Box::pin(con.send_request(head, body)),
} }
} }

View file

@ -1,5 +1,8 @@
use std::fmt; use std::fmt;
use std::future::Future;
use std::marker::PhantomData; use std::marker::PhantomData;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::time::Duration; use std::time::Duration;
use actix_codec::{AsyncRead, AsyncWrite}; use actix_codec::{AsyncRead, AsyncWrite};
@ -12,7 +15,7 @@ use actix_utils::timeout::{TimeoutError, TimeoutService};
use http::Uri; use http::Uri;
use super::config::ConnectorConfig; use super::config::ConnectorConfig;
use super::connection::Connection; use super::connection::{Connection, EitherIoConnection};
use super::error::ConnectError; use super::error::ConnectError;
use super::pool::{ConnectionPool, Protocol}; use super::pool::{ConnectionPool, Protocol};
use super::Connect; use super::Connect;
@ -55,7 +58,7 @@ pub struct Connector<T, U> {
_phantom: PhantomData<U>, _phantom: PhantomData<U>,
} }
trait Io: AsyncRead + AsyncWrite + Unpin {} pub trait Io: AsyncRead + AsyncWrite + Unpin {}
impl<T: AsyncRead + AsyncWrite + Unpin> Io for T {} impl<T: AsyncRead + AsyncWrite + Unpin> Io for T {}
impl Connector<(), ()> { impl Connector<(), ()> {
@ -244,11 +247,9 @@ where
self, self,
) -> impl Service<Connect, Response = impl Connection, Error = ConnectError> + Clone ) -> impl Service<Connect, Response = impl Connection, Error = ConnectError> + Clone
{ {
#[cfg(not(any(feature = "openssl", feature = "rustls")))] let tcp_service = TimeoutService::new(
{
let connector = TimeoutService::new(
self.config.timeout, self.config.timeout,
apply_fn(self.connector, |msg: Connect, srv| { apply_fn(self.connector.clone(), |msg: Connect, srv| {
srv.call(TcpConnect::new(msg.uri).set_addr(msg.addr)) srv.call(TcpConnect::new(msg.uri).set_addr(msg.addr))
}) })
.map_err(ConnectError::from) .map_err(ConnectError::from)
@ -259,13 +260,30 @@ where
TimeoutError::Timeout => ConnectError::Timeout, TimeoutError::Timeout => ConnectError::Timeout,
}); });
connect_impl::InnerConnector { #[cfg(not(any(feature = "openssl", feature = "rustls")))]
{
// A dummy service for annotate tls pool's type signature.
pub type DummyService = Box<
dyn Service<
Connect,
Response = (Box<dyn Io>, Protocol),
Error = ConnectError,
Future = futures_core::future::LocalBoxFuture<
'static,
Result<(Box<dyn Io>, Protocol), ConnectError>,
>,
>,
>;
InnerConnector::<_, DummyService, _, Box<dyn Io>> {
tcp_pool: ConnectionPool::new( tcp_pool: ConnectionPool::new(
connector, tcp_service,
self.config.no_disconnect_timeout(), self.config.no_disconnect_timeout(),
), ),
tls_pool: None,
} }
} }
#[cfg(any(feature = "openssl", feature = "rustls"))] #[cfg(any(feature = "openssl", feature = "rustls"))]
{ {
const H2: &[u8] = b"h2"; const H2: &[u8] = b"h2";
@ -328,125 +346,50 @@ where
TimeoutError::Timeout => ConnectError::Timeout, TimeoutError::Timeout => ConnectError::Timeout,
}); });
let tcp_service = TimeoutService::new( InnerConnector {
self.config.timeout,
apply_fn(self.connector, |msg: Connect, srv| {
srv.call(TcpConnect::new(msg.uri).set_addr(msg.addr))
})
.map_err(ConnectError::from)
.map(|stream| (stream.into_parts().0, Protocol::Http1)),
)
.map_err(|e| match e {
TimeoutError::Service(e) => e,
TimeoutError::Timeout => ConnectError::Timeout,
});
connect_impl::InnerConnector {
tcp_pool: ConnectionPool::new( tcp_pool: ConnectionPool::new(
tcp_service, tcp_service,
self.config.no_disconnect_timeout(), self.config.no_disconnect_timeout(),
), ),
ssl_pool: ConnectionPool::new(ssl_service, self.config), tls_pool: Some(ConnectionPool::new(ssl_service, self.config)),
} }
} }
} }
} }
#[cfg(not(any(feature = "openssl", feature = "rustls")))] struct InnerConnector<S1, S2, Io1, Io2>
mod connect_impl { where
use std::task::{Context, Poll}; S1: Service<Connect, Response = (Io1, Protocol), Error = ConnectError> + 'static,
S2: Service<Connect, Response = (Io2, Protocol), Error = ConnectError> + 'static,
Io1: AsyncRead + AsyncWrite + Unpin + 'static,
Io2: AsyncRead + AsyncWrite + Unpin + 'static,
{
tcp_pool: ConnectionPool<S1, Io1>,
tls_pool: Option<ConnectionPool<S2, Io2>>,
}
use futures_core::future::LocalBoxFuture; impl<S1, S2, Io1, Io2> Clone for InnerConnector<S1, S2, Io1, Io2>
where
use super::*; S1: Service<Connect, Response = (Io1, Protocol), Error = ConnectError> + 'static,
use crate::client::connection::IoConnection; S2: Service<Connect, Response = (Io2, Protocol), Error = ConnectError> + 'static,
Io1: AsyncRead + AsyncWrite + Unpin + 'static,
pub(crate) struct InnerConnector<T, Io> Io2: AsyncRead + AsyncWrite + Unpin + 'static,
where {
Io: AsyncRead + AsyncWrite + Unpin + 'static,
T: Service<Connect, Response = (Io, Protocol), Error = ConnectError> + 'static,
{
pub(crate) tcp_pool: ConnectionPool<T, Io>,
}
impl<T, Io> Clone for InnerConnector<T, Io>
where
Io: AsyncRead + AsyncWrite + Unpin + 'static,
T: Service<Connect, Response = (Io, Protocol), Error = ConnectError> + 'static,
{
fn clone(&self) -> Self { fn clone(&self) -> Self {
InnerConnector { InnerConnector {
tcp_pool: self.tcp_pool.clone(), tcp_pool: self.tcp_pool.clone(),
} tls_pool: self.tls_pool.as_ref().cloned(),
}
}
impl<T, Io> Service<Connect> for InnerConnector<T, Io>
where
Io: AsyncRead + AsyncWrite + Unpin + 'static,
T: Service<Connect, Response = (Io, Protocol), Error = ConnectError> + 'static,
{
type Response = IoConnection<Io>;
type Error = ConnectError;
type Future = LocalBoxFuture<'static, Result<IoConnection<Io>, ConnectError>>;
fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.tcp_pool.poll_ready(cx)
}
fn call(&self, req: Connect) -> Self::Future {
match req.uri.scheme_str() {
Some("https") | Some("wss") => {
Box::pin(async { Err(ConnectError::SslIsNotSupported) })
}
_ => self.tcp_pool.call(req),
}
} }
} }
} }
#[cfg(any(feature = "openssl", feature = "rustls"))] impl<S1, S2, Io1, Io2> Service<Connect> for InnerConnector<S1, S2, Io1, Io2>
mod connect_impl { where
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
use super::*;
use crate::client::connection::EitherIoConnection;
pub(crate) struct InnerConnector<S1, S2, Io1, Io2>
where
S1: Service<Connect, Response = (Io1, Protocol), Error = ConnectError> + 'static, S1: Service<Connect, Response = (Io1, Protocol), Error = ConnectError> + 'static,
S2: Service<Connect, Response = (Io2, Protocol), Error = ConnectError> + 'static, S2: Service<Connect, Response = (Io2, Protocol), Error = ConnectError> + 'static,
Io1: AsyncRead + AsyncWrite + Unpin + 'static, Io1: AsyncRead + AsyncWrite + Unpin + 'static,
Io2: AsyncRead + AsyncWrite + Unpin + 'static, Io2: AsyncRead + AsyncWrite + Unpin + 'static,
{ {
pub(crate) tcp_pool: ConnectionPool<S1, Io1>,
pub(crate) ssl_pool: ConnectionPool<S2, Io2>,
}
impl<S1, S2, Io1, Io2> Clone for InnerConnector<S1, S2, Io1, Io2>
where
S1: Service<Connect, Response = (Io1, Protocol), Error = ConnectError> + 'static,
S2: Service<Connect, Response = (Io2, Protocol), Error = ConnectError> + 'static,
Io1: AsyncRead + AsyncWrite + Unpin + 'static,
Io2: AsyncRead + AsyncWrite + Unpin + 'static,
{
fn clone(&self) -> Self {
InnerConnector {
tcp_pool: self.tcp_pool.clone(),
ssl_pool: self.ssl_pool.clone(),
}
}
}
impl<S1, S2, Io1, Io2> Service<Connect> for InnerConnector<S1, S2, Io1, Io2>
where
S1: Service<Connect, Response = (Io1, Protocol), Error = ConnectError> + 'static,
S2: Service<Connect, Response = (Io2, Protocol), Error = ConnectError> + 'static,
Io1: AsyncRead + AsyncWrite + Unpin + 'static,
Io2: AsyncRead + AsyncWrite + Unpin + 'static,
{
type Response = EitherIoConnection<Io1, Io2>; type Response = EitherIoConnection<Io1, Io2>;
type Error = ConnectError; type Error = ConnectError;
type Future = InnerConnectorResponse<S1, S2, Io1, Io2>; type Future = InnerConnectorResponse<S1, S2, Io1, Io2>;
@ -457,43 +400,43 @@ mod connect_impl {
fn call(&self, req: Connect) -> Self::Future { fn call(&self, req: Connect) -> Self::Future {
match req.uri.scheme_str() { match req.uri.scheme_str() {
Some("https") | Some("wss") => { Some("https") | Some("wss") => match self.tls_pool {
InnerConnectorResponse::Io2(self.ssl_pool.call(req)) None => InnerConnectorResponse::SslIsNotSupported,
} Some(ref pool) => InnerConnectorResponse::Io2(pool.call(req)),
},
_ => InnerConnectorResponse::Io1(self.tcp_pool.call(req)), _ => InnerConnectorResponse::Io1(self.tcp_pool.call(req)),
} }
} }
} }
#[pin_project::pin_project(project = InnerConnectorProj)] #[pin_project::pin_project(project = InnerConnectorProj)]
pub(crate) enum InnerConnectorResponse<S1, S2, Io1, Io2> enum InnerConnectorResponse<S1, S2, Io1, Io2>
where where
S1: Service<Connect, Response = (Io1, Protocol), Error = ConnectError> + 'static, S1: Service<Connect, Response = (Io1, Protocol), Error = ConnectError> + 'static,
S2: Service<Connect, Response = (Io2, Protocol), Error = ConnectError> + 'static, S2: Service<Connect, Response = (Io2, Protocol), Error = ConnectError> + 'static,
Io1: AsyncRead + AsyncWrite + Unpin + 'static, Io1: AsyncRead + AsyncWrite + Unpin + 'static,
Io2: AsyncRead + AsyncWrite + Unpin + 'static, Io2: AsyncRead + AsyncWrite + Unpin + 'static,
{ {
Io1(#[pin] <ConnectionPool<S1, Io1> as Service<Connect>>::Future), Io1(#[pin] <ConnectionPool<S1, Io1> as Service<Connect>>::Future),
Io2(#[pin] <ConnectionPool<S2, Io2> as Service<Connect>>::Future), Io2(#[pin] <ConnectionPool<S2, Io2> as Service<Connect>>::Future),
} SslIsNotSupported,
}
impl<S1, S2, Io1, Io2> Future for InnerConnectorResponse<S1, S2, Io1, Io2> impl<S1, S2, Io1, Io2> Future for InnerConnectorResponse<S1, S2, Io1, Io2>
where where
S1: Service<Connect, Response = (Io1, Protocol), Error = ConnectError> + 'static, S1: Service<Connect, Response = (Io1, Protocol), Error = ConnectError> + 'static,
S2: Service<Connect, Response = (Io2, Protocol), Error = ConnectError> + 'static, S2: Service<Connect, Response = (Io2, Protocol), Error = ConnectError> + 'static,
Io1: AsyncRead + AsyncWrite + Unpin + 'static, Io1: AsyncRead + AsyncWrite + Unpin + 'static,
Io2: AsyncRead + AsyncWrite + Unpin + 'static, Io2: AsyncRead + AsyncWrite + Unpin + 'static,
{ {
type Output = Result<EitherIoConnection<Io1, Io2>, ConnectError>; type Output = Result<EitherIoConnection<Io1, Io2>, ConnectError>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
match self.project() { match self.project() {
InnerConnectorProj::Io1(fut) => { InnerConnectorProj::Io1(fut) => fut.poll(cx).map_ok(EitherIoConnection::A),
fut.poll(cx).map_ok(EitherIoConnection::A) InnerConnectorProj::Io2(fut) => fut.poll(cx).map_ok(EitherIoConnection::B),
} InnerConnectorProj::SslIsNotSupported => {
InnerConnectorProj::Io2(fut) => { Poll::Ready(Err(ConnectError::SslIsNotSupported))
fut.poll(cx).map_ok(EitherIoConnection::B)
}
} }
} }
} }