From 4bab50c8611683e9e51c6f49130838e934155fff Mon Sep 17 00:00:00 2001 From: Markus Unterwaditzer Date: Wed, 29 Aug 2018 20:53:31 +0200 Subject: [PATCH 1/2] Add ability to pass a custom TlsConnector (#491) --- src/client/connector.rs | 217 +++++++++++++--------------------------- 1 file changed, 68 insertions(+), 149 deletions(-) diff --git a/src/client/connector.rs b/src/client/connector.rs index 1217b5bcf..430a0f752 100644 --- a/src/client/connector.rs +++ b/src/client/connector.rs @@ -17,14 +17,16 @@ use tokio_io::{AsyncRead, AsyncWrite}; use tokio_timer::Delay; #[cfg(feature = "alpn")] -use openssl::ssl::{Error as OpensslError, SslConnector, SslMethod}; -#[cfg(feature = "alpn")] -use tokio_openssl::SslConnectorExt; +use { + openssl::ssl::{Error as SslError, SslConnector, SslMethod}, + tokio_openssl::SslConnectorExt +}; #[cfg(all(feature = "tls", not(feature = "alpn")))] -use native_tls::{Error as TlsError, TlsConnector as NativeTlsConnector}; -#[cfg(all(feature = "tls", not(feature = "alpn")))] -use tokio_tls::{TlsConnector}; +use { + native_tls::{Error as SslError, TlsConnector as NativeTlsConnector}, + tokio_tls::TlsConnector as SslConnector +}; #[cfg( all( @@ -32,42 +34,25 @@ use tokio_tls::{TlsConnector}; not(any(feature = "alpn", feature = "tls")) ) )] -use rustls::ClientConfig; +use { + rustls::ClientConfig, + std::io::Error as SslError, + std::sync::Arc, + tokio_rustls::ClientConfigExt, + webpki::DNSNameRef, + webpki_roots, +}; + #[cfg( all( feature = "rust-tls", not(any(feature = "alpn", feature = "tls")) ) )] -use std::io::Error as TLSError; -#[cfg( - all( - feature = "rust-tls", - not(any(feature = "alpn", feature = "tls")) - ) -)] -use std::sync::Arc; -#[cfg( - all( - feature = "rust-tls", - not(any(feature = "alpn", feature = "tls")) - ) -)] -use tokio_rustls::ClientConfigExt; -#[cfg( - all( - feature = "rust-tls", - not(any(feature = "alpn", feature = "tls")) - ) -)] -use webpki::DNSNameRef; -#[cfg( - all( - feature = "rust-tls", - not(any(feature = "alpn", feature = "tls")) - ) -)] -use webpki_roots; +type SslConnector = Arc; + +#[cfg(not(any(feature = "alpn", feature = "tls", feature = "rust-tls")))] +type SslConnector = (); use server::IoStream; use {HAS_OPENSSL, HAS_RUSTLS, HAS_TLS}; @@ -173,24 +158,9 @@ pub enum ClientConnectorError { SslIsNotSupported, /// SSL error - #[cfg(feature = "alpn")] + #[cfg(any(feature = "tls", feature = "alpn", feature = "rust-tls"))] #[fail(display = "{}", _0)] - SslError(#[cause] OpensslError), - - /// SSL error - #[cfg(all(feature = "tls", not(feature = "alpn")))] - #[fail(display = "{}", _0)] - SslError(#[cause] TlsError), - - /// SSL error - #[cfg( - all( - feature = "rust-tls", - not(any(feature = "alpn", feature = "tls")) - ) - )] - #[fail(display = "{}", _0)] - SslError(#[cause] TLSError), + SslError(#[cause] SslError), /// Resolver error #[fail(display = "{}", _0)] @@ -242,17 +212,7 @@ impl Paused { /// `ClientConnector` type is responsible for transport layer of a /// client connection. pub struct ClientConnector { - #[cfg(all(feature = "alpn"))] connector: SslConnector, - #[cfg(all(feature = "tls", not(feature = "alpn")))] - connector: TlsConnector, - #[cfg( - all( - feature = "rust-tls", - not(any(feature = "alpn", feature = "tls")) - ) - )] - connector: Arc, stats: ClientConnectorStats, subscriber: Option>, @@ -293,71 +253,32 @@ impl SystemService for ClientConnector {} impl Default for ClientConnector { fn default() -> ClientConnector { - #[cfg(all(feature = "alpn"))] - { - let builder = SslConnector::builder(SslMethod::tls()).unwrap(); - ClientConnector::with_connector(builder.build()) - } - #[cfg(all(feature = "tls", not(feature = "alpn")))] - { - let (tx, rx) = mpsc::unbounded(); - let builder = NativeTlsConnector::builder(); - ClientConnector { - stats: ClientConnectorStats::default(), - subscriber: None, - acq_tx: tx, - acq_rx: Some(rx), - resolver: None, - connector: builder.build().unwrap().into(), - conn_lifetime: Duration::from_secs(75), - conn_keep_alive: Duration::from_secs(15), - limit: 100, - limit_per_host: 0, - acquired: 0, - acquired_per_host: HashMap::new(), - available: HashMap::new(), - to_close: Vec::new(), - waiters: Some(HashMap::new()), - wait_timeout: None, - paused: Paused::No, - } - } - #[cfg( - all( - feature = "rust-tls", - not(any(feature = "alpn", feature = "tls")) - ) - )] - { - let mut config = ClientConfig::new(); - config - .root_store - .add_server_trust_anchors(&webpki_roots::TLS_SERVER_ROOTS); - ClientConnector::with_connector(config) - } + let connector = { + #[cfg(all(feature = "alpn"))] + { SslConnector::builder(SslMethod::tls()).unwrap().build() } - #[cfg(not(any(feature = "alpn", feature = "tls", feature = "rust-tls")))] - { - let (tx, rx) = mpsc::unbounded(); - ClientConnector { - stats: ClientConnectorStats::default(), - subscriber: None, - acq_tx: tx, - acq_rx: Some(rx), - resolver: None, - conn_lifetime: Duration::from_secs(75), - conn_keep_alive: Duration::from_secs(15), - limit: 100, - limit_per_host: 0, - acquired: 0, - acquired_per_host: HashMap::new(), - available: HashMap::new(), - to_close: Vec::new(), - waiters: Some(HashMap::new()), - wait_timeout: None, - paused: Paused::No, + #[cfg(all(feature = "tls", not(feature = "alpn")))] + { NativeTlsConnector::builder().build().unwrap().into() } + + #[cfg( + all( + feature = "rust-tls", + not(any(feature = "alpn", feature = "tls")) + ) + )] + { + let mut config = ClientConfig::new(); + config + .root_store + .add_server_trust_anchors(&webpki_roots::TLS_SERVER_ROOTS); + Arc::new(config) } - } + + #[cfg(not(any(feature = "alpn", feature = "tls", feature = "rust-tls")))] + { () } + }; + + ClientConnector::with_connector_impl(connector) } } @@ -402,27 +323,8 @@ impl ClientConnector { /// } /// ``` pub fn with_connector(connector: SslConnector) -> ClientConnector { - let (tx, rx) = mpsc::unbounded(); - - ClientConnector { - connector, - stats: ClientConnectorStats::default(), - subscriber: None, - acq_tx: tx, - acq_rx: Some(rx), - resolver: None, - conn_lifetime: Duration::from_secs(75), - conn_keep_alive: Duration::from_secs(15), - limit: 100, - limit_per_host: 0, - acquired: 0, - acquired_per_host: HashMap::new(), - available: HashMap::new(), - to_close: Vec::new(), - waiters: Some(HashMap::new()), - wait_timeout: None, - paused: Paused::No, - } + // keep level of indirection for docstrings matching featureflags + Self::with_connector_impl(connector) } #[cfg( @@ -476,10 +378,27 @@ impl ClientConnector { /// } /// ``` pub fn with_connector(connector: ClientConfig) -> ClientConnector { + // keep level of indirection for docstrings matching featureflags + Self::with_connector_impl(Arc::new(connector)) + } + + #[cfg( + all( + feature = "tls", + not(any(feature = "alpn", feature = "rust-tls")) + ) + )] + pub fn with_connector(connector: SslConnector) -> ClientConnector { + // keep level of indirection for docstrings matching featureflags + Self::with_connector_impl(connector) + } + + #[inline] + fn with_connector_impl(connector: SslConnector) -> ClientConnector { let (tx, rx) = mpsc::unbounded(); ClientConnector { - connector: Arc::new(connector), + connector, stats: ClientConnectorStats::default(), subscriber: None, acq_tx: tx, @@ -1364,4 +1283,4 @@ impl IoStream for TlsStream { fn set_linger(&mut self, dur: Option) -> io::Result<()> { self.get_mut().get_mut().set_linger(dur) } -} \ No newline at end of file +} From 797b52ecbf21bfd5cfec2306653af6741279b595 Mon Sep 17 00:00:00 2001 From: Armin Ronacher Date: Wed, 29 Aug 2018 20:58:23 +0200 Subject: [PATCH 2/2] Update CHANGES.md --- CHANGES.md | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/CHANGES.md b/CHANGES.md index eaf7b42b8..34b0a9621 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,5 +1,11 @@ # Changes +## [0.7.5] - 2018-09-xx + +### Added + +* Added the ability to pass a custom `TlsConnector`. + ## [0.7.4] - 2018-08-23 ### Added