From 0915879267b4a18afef5a5afdef12113b5d4e567 Mon Sep 17 00:00:00 2001 From: Joel Wurtz Date: Mon, 9 Dec 2024 11:16:54 +0100 Subject: [PATCH] feat(awc): allow to set a specific sni host on the request --- awc/CHANGES.md | 1 + awc/src/builder.rs | 22 ++++-- awc/src/client/connector.rs | 126 +++++++++++++++++++++++--------- awc/src/client/mod.rs | 32 ++++++-- awc/src/client/pool.rs | 62 ++++++++++------ awc/src/connect.rs | 67 +++++++++++++---- awc/src/frozen.rs | 13 +++- awc/src/middleware/redirect.rs | 13 ++-- awc/src/request.rs | 19 ++++- awc/src/sender.rs | 16 +++- awc/src/ws.rs | 12 ++- awc/tests/test_rustls_client.rs | 101 ++++++++++++++++++++++++- 12 files changed, 382 insertions(+), 102 deletions(-) diff --git a/awc/CHANGES.md b/awc/CHANGES.md index 8a2a1ec43..2a1b44622 100644 --- a/awc/CHANGES.md +++ b/awc/CHANGES.md @@ -5,6 +5,7 @@ - Update `brotli` dependency to `7`. - Prevent panics on connection pool drop when Tokio runtime is shutdown early. - Minimum supported Rust version (MSRV) is now 1.75. +- Allow to set a specific SNI hostname on the request for TLS connections. ## 3.5.1 diff --git a/awc/src/builder.rs b/awc/src/builder.rs index 5aae394f8..0dfcd5472 100644 --- a/awc/src/builder.rs +++ b/awc/src/builder.rs @@ -3,7 +3,6 @@ use std::{fmt, net::IpAddr, rc::Rc, time::Duration}; use actix_http::{ error::HttpError, header::{self, HeaderMap, HeaderName, TryIntoHeaderPair}, - Uri, }; use actix_rt::net::{ActixStream, TcpStream}; use actix_service::{boxed, Service}; @@ -11,7 +10,8 @@ use base64::prelude::*; use crate::{ client::{ - ClientConfig, ConnectInfo, Connector, ConnectorService, TcpConnectError, TcpConnection, + ClientConfig, ConnectInfo, Connector, ConnectorService, HostnameWithSni, TcpConnectError, + TcpConnection, }, connect::DefaultConnector, error::SendRequestError, @@ -46,8 +46,8 @@ impl ClientBuilder { #[allow(clippy::new_ret_no_self)] pub fn new() -> ClientBuilder< impl Service< - ConnectInfo, - Response = TcpConnection, + ConnectInfo, + Response = TcpConnection, Error = TcpConnectError, > + Clone, (), @@ -69,16 +69,22 @@ impl ClientBuilder { impl ClientBuilder where - S: Service, Response = TcpConnection, Error = TcpConnectError> - + Clone + S: Service< + ConnectInfo, + Response = TcpConnection, + Error = TcpConnectError, + > + Clone + 'static, Io: ActixStream + fmt::Debug + 'static, { /// Use custom connector service. pub fn connector(self, connector: Connector) -> ClientBuilder where - S1: Service, Response = TcpConnection, Error = TcpConnectError> - + Clone + S1: Service< + ConnectInfo, + Response = TcpConnection, + Error = TcpConnectError, + > + Clone + 'static, Io1: ActixStream + fmt::Debug + 'static, { diff --git a/awc/src/client/connector.rs b/awc/src/client/connector.rs index f3d443070..2e3f977fa 100644 --- a/awc/src/client/connector.rs +++ b/awc/src/client/connector.rs @@ -16,10 +16,9 @@ use actix_rt::{ use actix_service::Service; use actix_tls::connect::{ ConnectError as TcpConnectError, ConnectInfo, Connection as TcpConnection, - Connector as TcpConnector, Resolver, + Connector as TcpConnector, Host, Resolver, }; use futures_core::{future::LocalBoxFuture, ready}; -use http::Uri; use pin_project_lite::pin_project; use super::{ @@ -27,9 +26,41 @@ use super::{ connection::{Connection, ConnectionIo}, error::ConnectError, pool::ConnectionPool, - Connect, + Connect, ServerName, }; +pub enum HostnameWithSni { + ForTcp(String, u16, Option), + ForTls(String, u16, Option), +} + +impl Host for HostnameWithSni { + fn hostname(&self) -> &str { + match self { + HostnameWithSni::ForTcp(hostname, _, _) => hostname, + HostnameWithSni::ForTls(hostname, _, sni) => sni.as_deref().unwrap_or(hostname), + } + } + + fn port(&self) -> Option { + match self { + HostnameWithSni::ForTcp(_, port, _) => Some(*port), + HostnameWithSni::ForTls(_, port, _) => Some(*port), + } + } +} + +impl HostnameWithSni { + pub fn to_tls(self) -> Self { + match self { + HostnameWithSni::ForTcp(hostname, port, sni) => { + HostnameWithSni::ForTls(hostname, port, sni) + } + HostnameWithSni::ForTls(_, _, _) => self, + } + } +} + enum OurTlsConnector { #[allow(dead_code)] // only dead when no TLS feature is enabled None, @@ -95,8 +126,8 @@ impl Connector<()> { #[allow(clippy::new_ret_no_self, clippy::let_unit_value)] pub fn new() -> Connector< impl Service< - ConnectInfo, - Response = TcpConnection, + ConnectInfo, + Response = TcpConnection, Error = actix_tls::connect::ConnectError, > + Clone, > { @@ -214,8 +245,11 @@ impl Connector { pub fn connector(self, connector: S1) -> Connector where Io1: ActixStream + fmt::Debug + 'static, - S1: Service, Response = TcpConnection, Error = TcpConnectError> - + Clone, + S1: Service< + ConnectInfo, + Response = TcpConnection, + Error = TcpConnectError, + > + Clone, { Connector { connector, @@ -235,8 +269,11 @@ where // This remap is to hide ActixStream's trait methods. They are not meant to be called // from user code. IO: ActixStream + fmt::Debug + 'static, - S: Service, Response = TcpConnection, Error = TcpConnectError> - + Clone + S: Service< + ConnectInfo, + Response = TcpConnection, + Error = TcpConnectError, + > + Clone + 'static, { /// Sets TCP connection timeout. @@ -454,7 +491,7 @@ where use actix_utils::future::{ready, Ready}; #[allow(non_local_definitions)] - impl IntoConnectionIo for TcpConnection> { + impl IntoConnectionIo for TcpConnection> { fn into_connection_io(self) -> (Box, Protocol) { let io = self.into_parts().0; (io, Protocol::Http2) @@ -505,7 +542,7 @@ where use actix_tls::connect::openssl::{reexports::AsyncSslStream, TlsConnector}; #[allow(non_local_definitions)] - impl IntoConnectionIo for TcpConnection> { + impl IntoConnectionIo for TcpConnection> { fn into_connection_io(self) -> (Box, Protocol) { let sock = self.into_parts().0; let h2 = sock @@ -543,7 +580,7 @@ where use actix_tls::connect::rustls_0_20::{reexports::AsyncTlsStream, TlsConnector}; #[allow(non_local_definitions)] - impl IntoConnectionIo for TcpConnection> { + impl IntoConnectionIo for TcpConnection> { fn into_connection_io(self) -> (Box, Protocol) { let sock = self.into_parts().0; let h2 = sock @@ -577,7 +614,7 @@ where use actix_tls::connect::rustls_0_21::{reexports::AsyncTlsStream, TlsConnector}; #[allow(non_local_definitions)] - impl IntoConnectionIo for TcpConnection> { + impl IntoConnectionIo for TcpConnection> { fn into_connection_io(self) -> (Box, Protocol) { let sock = self.into_parts().0; let h2 = sock @@ -614,7 +651,7 @@ where use actix_tls::connect::rustls_0_22::{reexports::AsyncTlsStream, TlsConnector}; #[allow(non_local_definitions)] - impl IntoConnectionIo for TcpConnection> { + impl IntoConnectionIo for TcpConnection> { fn into_connection_io(self) -> (Box, Protocol) { let sock = self.into_parts().0; let h2 = sock @@ -648,7 +685,7 @@ where use actix_tls::connect::rustls_0_23::{reexports::AsyncTlsStream, TlsConnector}; #[allow(non_local_definitions)] - impl IntoConnectionIo for TcpConnection> { + impl IntoConnectionIo for TcpConnection> { fn into_connection_io(self) -> (Box, Protocol) { let sock = self.into_parts().0; let h2 = sock @@ -688,7 +725,7 @@ where } } -/// tcp service for map `TcpConnection` type to `(Io, Protocol)` +/// tcp service for map `TcpConnection` type to `(Io, Protocol)` #[derive(Clone)] pub struct TcpConnectorService { service: S, @@ -696,7 +733,9 @@ pub struct TcpConnectorService { impl Service for TcpConnectorService where - S: Service, Error = ConnectError> + Clone + 'static, + S: Service, Error = ConnectError> + + Clone + + 'static, { type Response = (Io, Protocol); type Error = ConnectError; @@ -721,7 +760,7 @@ pin_project! { impl Future for TcpConnectorFuture where - Fut: Future, ConnectError>>, + Fut: Future, ConnectError>>, { type Output = Result<(Io, Protocol), ConnectError>; @@ -767,9 +806,10 @@ struct TlsConnectorService { ))] impl Service for TlsConnectorService where - Tcp: - Service, Error = ConnectError> + Clone + 'static, - Tls: Service, Error = std::io::Error> + Clone + 'static, + Tcp: Service, Error = ConnectError> + + Clone + + 'static, + Tls: Service, Error = std::io::Error> + Clone + 'static, Tls::Response: IntoConnectionIo, IO: ConnectionIo, { @@ -822,9 +862,14 @@ trait IntoConnectionIo { impl Future for TlsConnectorFuture where - S: Service, Response = Res, Error = std::io::Error, Future = Fut2>, + S: Service< + TcpConnection, + Response = Res, + Error = std::io::Error, + Future = Fut2, + >, S::Response: IntoConnectionIo, - Fut1: Future, ConnectError>>, + Fut1: Future, ConnectError>>, Fut2: Future>, Io: ConnectionIo, { @@ -838,10 +883,11 @@ where timeout, } => { let res = ready!(fut.poll(cx))?; + let (io, hostname_with_sni) = res.into_parts(); let fut = tls_service .take() .expect("TlsConnectorFuture polled after complete") - .call(res); + .call(TcpConnection::new(hostname_with_sni.to_tls(), io)); let timeout = sleep(*timeout); self.set(TlsConnectorFuture::TlsConnect { fut, timeout }); self.poll(cx) @@ -875,8 +921,11 @@ impl TcpConnectorInnerService { impl Service for TcpConnectorInnerService where - S: Service, Response = TcpConnection, Error = TcpConnectError> - + Clone + S: Service< + ConnectInfo, + Response = TcpConnection, + Error = TcpConnectError, + > + Clone + 'static, { type Response = S::Response; @@ -886,7 +935,13 @@ where actix_service::forward_ready!(service); fn call(&self, req: Connect) -> Self::Future { - let mut req = ConnectInfo::new(req.uri).set_addr(req.addr); + let mut req = ConnectInfo::new(HostnameWithSni::ForTcp( + req.hostname, + req.port, + req.sni_host, + )) + .set_addr(req.addr) + .set_port(req.port); if let Some(local_addr) = self.local_address { req = req.set_local_addr(local_addr); @@ -911,9 +966,9 @@ pin_project! { impl Future for TcpConnectorInnerFuture where - Fut: Future, TcpConnectError>>, + Fut: Future, TcpConnectError>>, { - type Output = Result, ConnectError>; + type Output = Result, ConnectError>; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let this = self.project(); @@ -973,16 +1028,17 @@ where } fn call(&self, req: Connect) -> Self::Future { - match req.uri.scheme_str() { - Some("https") | Some("wss") => match self.tls_pool { + if req.tls { + match &self.tls_pool { None => ConnectorServiceFuture::SslIsNotSupported, - Some(ref pool) => ConnectorServiceFuture::Tls { + Some(pool) => ConnectorServiceFuture::Tls { fut: pool.call(req), }, - }, - _ => ConnectorServiceFuture::Tcp { + } + } else { + ConnectorServiceFuture::Tcp { fut: self.tcp_pool.call(req), - }, + } } } } diff --git a/awc/src/client/mod.rs b/awc/src/client/mod.rs index c9fa37253..5ac8650d3 100644 --- a/awc/src/client/mod.rs +++ b/awc/src/client/mod.rs @@ -1,6 +1,6 @@ //! HTTP client. -use std::{rc::Rc, time::Duration}; +use std::{ops::Deref, rc::Rc, time::Duration}; use actix_http::{error::HttpError, header::HeaderMap, Method, RequestHead, Uri}; use actix_rt::net::TcpStream; @@ -21,13 +21,33 @@ mod pool; pub use self::{ connection::{Connection, ConnectionIo}, - connector::{Connector, ConnectorService}, + connector::{Connector, ConnectorService, HostnameWithSni}, error::{ConnectError, FreezeRequestError, InvalidUrl, SendRequestError}, }; -#[derive(Clone)] +#[derive(Clone, Hash, PartialEq, Eq)] +pub enum ServerName { + Owned(String), + Borrowed(Rc), +} + +impl Deref for ServerName { + type Target = str; + + fn deref(&self) -> &str { + match self { + ServerName::Owned(ref s) => s, + ServerName::Borrowed(ref s) => s, + } + } +} + +#[derive(Clone, Hash, PartialEq, Eq)] pub struct Connect { - pub uri: Uri, + pub hostname: String, + pub sni_host: Option, + pub port: u16, + pub tls: bool, pub addr: Option, } @@ -79,8 +99,8 @@ impl Client { /// This function is equivalent of `ClientBuilder::new()`. pub fn builder() -> ClientBuilder< impl Service< - ConnectInfo, - Response = TcpConnection, + ConnectInfo, + Response = TcpConnection, Error = TcpConnectError, > + Clone, > { diff --git a/awc/src/client/pool.rs b/awc/src/client/pool.rs index 5d764f729..9b1058d85 100644 --- a/awc/src/client/pool.rs +++ b/awc/src/client/pool.rs @@ -4,6 +4,7 @@ use std::{ cell::RefCell, collections::{HashMap, VecDeque}, future::Future, + hash::Hash, io, ops::Deref, pin::Pin, @@ -127,7 +128,7 @@ where Io: AsyncWrite + Unpin + 'static, { config: ConnectorConfig, - available: RefCell>>>, + available: RefCell>>>, permits: Arc, } @@ -168,12 +169,6 @@ where let inner = self.inner.clone(); Box::pin(async move { - let key = if let Some(authority) = req.uri.authority() { - authority.clone().into() - } else { - return Err(ConnectError::Unresolved); - }; - // acquire an owned permit and carry it with connection let permit = Arc::clone(&inner.permits) .acquire_owned() @@ -191,7 +186,7 @@ where // check if there is idle connection for given key. let mut map = inner.available.borrow_mut(); - if let Some(conns) = map.get_mut(&key) { + if let Some(conns) = map.get_mut(&req) { let now = Instant::now(); while let Some(mut c) = conns.pop_front() { @@ -232,7 +227,11 @@ where // construct acquired. It's used to put Io type back to pool/ close the Io type. // permit is carried with the whole lifecycle of Acquired. - let acquired = Acquired { key, inner, permit }; + let acquired = Acquired { + req: req.clone(), + inner, + permit, + }; // match the connection and spawn new one if did not get anything. match conn { @@ -344,8 +343,8 @@ pub struct Acquired where Io: AsyncWrite + Unpin + 'static, { - /// authority key for identify connection. - key: Key, + /// hash key for identify connection. + req: Connect, /// handle to connection pool. inner: ConnectionPoolInner, /// permit for limit concurrent in-flight connection for a Client object. @@ -360,12 +359,12 @@ impl Acquired { /// Release IO back into pool. pub(super) fn release(&self, conn: ConnectionInnerType, created: Instant) { - let Acquired { key, inner, .. } = self; + let Acquired { req, inner, .. } = self; inner .available .borrow_mut() - .entry(key.clone()) + .entry(req.clone()) .or_insert_with(VecDeque::new) .push_back(PooledConnection { conn, @@ -381,8 +380,6 @@ impl Acquired { mod test { use std::cell::Cell; - use http::Uri; - use super::*; /// A stream type that always returns pending on async read. @@ -467,7 +464,10 @@ mod test { let pool = super::ConnectionPool::new(connector, config); let req = Connect { - uri: Uri::from_static("http://localhost"), + hostname: "localhost".to_string(), + port: 80, + tls: false, + sni_host: None, addr: None, }; @@ -507,7 +507,10 @@ mod test { let pool = super::ConnectionPool::new(connector, config); let req = Connect { - uri: Uri::from_static("http://localhost"), + hostname: "localhost".to_string(), + port: 80, + tls: false, + sni_host: None, addr: None, }; @@ -549,7 +552,10 @@ mod test { let pool = super::ConnectionPool::new(connector, config); let req = Connect { - uri: Uri::from_static("http://localhost"), + hostname: "localhost".to_string(), + port: 80, + tls: false, + sni_host: None, addr: None, }; @@ -588,7 +594,10 @@ mod test { let pool = super::ConnectionPool::new(connector, config); let req = Connect { - uri: Uri::from_static("https://crates.io"), + hostname: "crates.io".to_string(), + port: 443, + tls: true, + sni_host: None, addr: None, }; @@ -601,7 +610,10 @@ mod test { release(conn); let req = Connect { - uri: Uri::from_static("https://google.com"), + hostname: "google.com".to_string(), + port: 443, + tls: true, + sni_host: None, addr: None, }; @@ -625,7 +637,10 @@ mod test { let pool = Rc::new(super::ConnectionPool::new(connector, config)); let req = Connect { - uri: Uri::from_static("https://crates.io"), + hostname: "crates.io".to_string(), + port: 443, + tls: true, + sni_host: None, addr: None, }; @@ -634,7 +649,10 @@ mod test { release(conn); let req = Connect { - uri: Uri::from_static("https://google.com"), + hostname: "google.com".to_string(), + port: 443, + tls: true, + sni_host: None, addr: None, }; let conn = pool.call(req.clone()).await.unwrap(); diff --git a/awc/src/connect.rs b/awc/src/connect.rs index 14ed9e958..a7bbd7b2d 100644 --- a/awc/src/connect.rs +++ b/awc/src/connect.rs @@ -13,7 +13,10 @@ use futures_core::{future::LocalBoxFuture, ready}; use crate::{ any_body::AnyBody, - client::{Connect as ClientConnect, ConnectError, Connection, ConnectionIo, SendRequestError}, + client::{ + Connect as ClientConnect, ConnectError, Connection, ConnectionIo, SendRequestError, + ServerName, + }, ClientResponse, }; @@ -32,13 +35,18 @@ pub type BoxedSocket = Box; pub enum ConnectRequest { /// Standard HTTP request. /// - /// Contains the request head, body type, and optional pre-resolved socket address. - Client(RequestHeadType, AnyBody, Option), + /// Contains the request head, body type, optional pre-resolved socket address and optional sni host. + Client( + RequestHeadType, + AnyBody, + Option, + Option, + ), /// Tunnel used by WebSocket connection requests. /// - /// Contains the request head and optional pre-resolved socket address. - Tunnel(RequestHead, Option), + /// Contains the request head, optional pre-resolved socket address and optional sni host. + Tunnel(RequestHead, Option, Option), } /// Combined HTTP response & WebSocket tunnel type returned from connection service. @@ -103,17 +111,41 @@ where fn call(&self, req: ConnectRequest) -> Self::Future { // connect to the host - let fut = match req { - ConnectRequest::Client(ref head, .., addr) => self.connector.call(ClientConnect { - uri: head.as_ref().uri.clone(), - addr, - }), - ConnectRequest::Tunnel(ref head, addr) => self.connector.call(ClientConnect { - uri: head.uri.clone(), - addr, - }), + let (head, addr, sni_host) = match req { + ConnectRequest::Client(ref head, .., addr, ref sni_host) => { + (head.as_ref(), addr, sni_host.clone()) + } + ConnectRequest::Tunnel(ref head, addr, ref sni_host) => (head, addr, sni_host.clone()), }; + let authority = if let Some(authority) = head.uri.authority() { + authority + } else { + return ConnectRequestFuture::Error { + err: ConnectError::Unresolved, + }; + }; + + let tls = match head.uri.scheme_str() { + Some("https") | Some("wss") => true, + _ => false, + }; + + let fut = + self.connector.call(ClientConnect { + hostname: authority.host().to_string(), + port: authority.port().map(|p| p.as_u16()).unwrap_or_else(|| { + if tls { + 443 + } else { + 80 + } + }), + tls, + sni_host, + addr, + }); + ConnectRequestFuture::Connection { fut, req: Some(req), @@ -127,6 +159,9 @@ pin_project_lite::pin_project! { where Io: ConnectionIo { + Error { + err: ConnectError + }, Connection { #[pin] fut: Fut, @@ -192,6 +227,10 @@ where let framed = framed.into_map_io(|io| Box::new(io) as _); Poll::Ready(Ok(ConnectResponse::Tunnel(head, framed))) } + + ConnectRequestProj::Error { .. } => { + Poll::Ready(Err(SendRequestError::Connect(ConnectError::Unresolved))) + } } } } diff --git a/awc/src/frozen.rs b/awc/src/frozen.rs index 862405234..d622f8ece 100644 --- a/awc/src/frozen.rs +++ b/awc/src/frozen.rs @@ -11,7 +11,7 @@ use futures_core::Stream; use serde::Serialize; use crate::{ - client::ClientConfig, + client::{ClientConfig, ServerName}, sender::{RequestSender, SendClientRequest}, BoxError, }; @@ -26,6 +26,7 @@ pub struct FrozenClientRequest { pub(crate) response_decompress: bool, pub(crate) timeout: Option, pub(crate) config: ClientConfig, + pub(crate) sni_host: Option, } impl FrozenClientRequest { @@ -54,6 +55,7 @@ impl FrozenClientRequest { self.response_decompress, self.timeout, &self.config, + self.sni_host.clone(), body, ) } @@ -65,6 +67,7 @@ impl FrozenClientRequest { self.response_decompress, self.timeout, &self.config, + self.sni_host.clone(), value, ) } @@ -76,6 +79,7 @@ impl FrozenClientRequest { self.response_decompress, self.timeout, &self.config, + self.sni_host.clone(), value, ) } @@ -91,6 +95,7 @@ impl FrozenClientRequest { self.response_decompress, self.timeout, &self.config, + self.sni_host.clone(), stream, ) } @@ -102,6 +107,7 @@ impl FrozenClientRequest { self.response_decompress, self.timeout, &self.config, + self.sni_host.clone(), ) } @@ -156,6 +162,7 @@ impl FrozenSendBuilder { self.req.response_decompress, self.req.timeout, &self.req.config, + self.req.sni_host.clone(), body, ) } @@ -171,6 +178,7 @@ impl FrozenSendBuilder { self.req.response_decompress, self.req.timeout, &self.req.config, + self.req.sni_host.clone(), value, ) } @@ -186,6 +194,7 @@ impl FrozenSendBuilder { self.req.response_decompress, self.req.timeout, &self.req.config, + self.req.sni_host.clone(), value, ) } @@ -205,6 +214,7 @@ impl FrozenSendBuilder { self.req.response_decompress, self.req.timeout, &self.req.config, + self.req.sni_host.clone(), stream, ) } @@ -220,6 +230,7 @@ impl FrozenSendBuilder { self.req.response_decompress, self.req.timeout, &self.req.config, + self.req.sni_host.clone(), ) } } diff --git a/awc/src/middleware/redirect.rs b/awc/src/middleware/redirect.rs index b2cf9c45b..81f4d7993 100644 --- a/awc/src/middleware/redirect.rs +++ b/awc/src/middleware/redirect.rs @@ -73,11 +73,13 @@ where fn call(&self, req: ConnectRequest) -> Self::Future { match req { - ConnectRequest::Tunnel(head, addr) => { - let fut = self.connector.call(ConnectRequest::Tunnel(head, addr)); + ConnectRequest::Tunnel(head, addr, sni_host) => { + let fut = self + .connector + .call(ConnectRequest::Tunnel(head, addr, sni_host)); RedirectServiceFuture::Tunnel { fut } } - ConnectRequest::Client(head, body, addr) => { + ConnectRequest::Client(head, body, addr, sni_host) => { let connector = Rc::clone(&self.connector); let max_redirect_times = self.max_redirect_times; @@ -96,7 +98,7 @@ where _ => None, }; - let fut = connector.call(ConnectRequest::Client(head, body, addr)); + let fut = connector.call(ConnectRequest::Client(head, body, addr, sni_host)); RedirectServiceFuture::Client { fut, @@ -221,7 +223,8 @@ where let fut = connector .as_ref() .unwrap() - .call(ConnectRequest::Client(head, body_new, addr)); + // @TODO find a way to get sni host + .call(ConnectRequest::Client(head, body_new, addr, None)); self.set(RedirectServiceFuture::Client { fut, diff --git a/awc/src/request.rs b/awc/src/request.rs index 5f42f67ec..b0f995a63 100644 --- a/awc/src/request.rs +++ b/awc/src/request.rs @@ -14,7 +14,7 @@ use serde::Serialize; #[cfg(feature = "cookies")] use crate::cookie::{Cookie, CookieJar}; use crate::{ - client::ClientConfig, + client::{ClientConfig, ServerName}, error::{FreezeRequestError, InvalidUrl}, frozen::FrozenClientRequest, sender::{PrepForSendingError, RequestSender, SendClientRequest}, @@ -48,6 +48,7 @@ pub struct ClientRequest { response_decompress: bool, timeout: Option, config: ClientConfig, + sni_host: Option, #[cfg(feature = "cookies")] cookies: Option, @@ -69,6 +70,7 @@ impl ClientRequest { cookies: None, timeout: None, response_decompress: true, + sni_host: None, } .method(method) .uri(uri) @@ -306,6 +308,12 @@ impl ClientRequest { Ok(self) } + /// Set SNI (Server Name Indication) host for this request. + pub fn sni_host(mut self, host: impl Into) -> Self { + self.sni_host = Some(ServerName::Owned(host.into())); + self + } + /// Freeze request builder and construct `FrozenClientRequest`, /// which could be used for sending same request multiple times. pub fn freeze(self) -> Result { @@ -320,6 +328,10 @@ impl ClientRequest { response_decompress: slf.response_decompress, timeout: slf.timeout, config: slf.config, + sni_host: slf.sni_host.map(|v| match v { + ServerName::Borrowed(r) => ServerName::Borrowed(r), + ServerName::Owned(o) => ServerName::Borrowed(Rc::new(o)), + }), }; Ok(request) @@ -340,6 +352,7 @@ impl ClientRequest { slf.response_decompress, slf.timeout, &slf.config, + slf.sni_host, body, ) } @@ -356,6 +369,7 @@ impl ClientRequest { slf.response_decompress, slf.timeout, &slf.config, + slf.sni_host, value, ) } @@ -374,6 +388,7 @@ impl ClientRequest { slf.response_decompress, slf.timeout, &slf.config, + slf.sni_host, value, ) } @@ -394,6 +409,7 @@ impl ClientRequest { slf.response_decompress, slf.timeout, &slf.config, + slf.sni_host, stream, ) } @@ -410,6 +426,7 @@ impl ClientRequest { slf.response_decompress, slf.timeout, &slf.config, + slf.sni_host, ) } diff --git a/awc/src/sender.rs b/awc/src/sender.rs index b676ebf28..ab3ca596d 100644 --- a/awc/src/sender.rs +++ b/awc/src/sender.rs @@ -23,7 +23,7 @@ use serde::Serialize; use crate::{ any_body::AnyBody, - client::ClientConfig, + client::{ClientConfig, ServerName}, error::{FreezeRequestError, InvalidUrl, SendRequestError}, BoxError, ClientResponse, ConnectRequest, ConnectResponse, }; @@ -186,6 +186,7 @@ impl RequestSender { response_decompress: bool, timeout: Option, config: &ClientConfig, + sni_host: Option, body: impl MessageBody + 'static, ) -> SendClientRequest { let req = match self { @@ -193,11 +194,13 @@ impl RequestSender { RequestHeadType::Owned(head), AnyBody::from_message_body(body).into_boxed(), addr, + sni_host, ), RequestSender::Rc(head, extra_headers) => ConnectRequest::Client( RequestHeadType::Rc(head, extra_headers), AnyBody::from_message_body(body).into_boxed(), addr, + sni_host, ), }; @@ -212,6 +215,7 @@ impl RequestSender { response_decompress: bool, timeout: Option, config: &ClientConfig, + sni_host: Option, value: impl Serialize, ) -> SendClientRequest { let body = match serde_json::to_string(&value) { @@ -223,7 +227,7 @@ impl RequestSender { return err.into(); } - self.send_body(addr, response_decompress, timeout, config, body) + self.send_body(addr, response_decompress, timeout, config, sni_host, body) } pub(crate) fn send_form( @@ -232,6 +236,7 @@ impl RequestSender { response_decompress: bool, timeout: Option, config: &ClientConfig, + sni_host: Option, value: impl Serialize, ) -> SendClientRequest { let body = match serde_urlencoded::to_string(value) { @@ -246,7 +251,7 @@ impl RequestSender { return err.into(); } - self.send_body(addr, response_decompress, timeout, config, body) + self.send_body(addr, response_decompress, timeout, config, sni_host, body) } pub(crate) fn send_stream( @@ -255,6 +260,7 @@ impl RequestSender { response_decompress: bool, timeout: Option, config: &ClientConfig, + sni_host: Option, stream: S, ) -> SendClientRequest where @@ -266,6 +272,7 @@ impl RequestSender { response_decompress, timeout, config, + sni_host, BodyStream::new(stream), ) } @@ -276,8 +283,9 @@ impl RequestSender { response_decompress: bool, timeout: Option, config: &ClientConfig, + sni_host: Option, ) -> SendClientRequest { - self.send_body(addr, response_decompress, timeout, config, ()) + self.send_body(addr, response_decompress, timeout, config, sni_host, ()) } fn set_header_if_none(&mut self, key: HeaderName, value: V) -> Result<(), HttpError> diff --git a/awc/src/ws.rs b/awc/src/ws.rs index 760331e9d..ef5cb7155 100644 --- a/awc/src/ws.rs +++ b/awc/src/ws.rs @@ -38,7 +38,7 @@ use base64::prelude::*; #[cfg(feature = "cookies")] use crate::cookie::{Cookie, CookieJar}; use crate::{ - client::ClientConfig, + client::{ClientConfig, ServerName}, connect::{BoxedSocket, ConnectRequest}, error::{HttpError, InvalidUrl, SendRequestError, WsClientError}, http::{ @@ -58,6 +58,7 @@ pub struct WebsocketsRequest { max_size: usize, server_mode: bool, config: ClientConfig, + sni_host: Option, #[cfg(feature = "cookies")] cookies: Option, @@ -96,6 +97,7 @@ impl WebsocketsRequest { server_mode: false, #[cfg(feature = "cookies")] cookies: None, + sni_host: None, } } @@ -249,6 +251,12 @@ impl WebsocketsRequest { self.header(AUTHORIZATION, format!("Bearer {}", token)) } + /// Set SNI (Server Name Indication) host for this request. + pub fn sni_host(mut self, host: impl Into) -> Self { + self.sni_host = Some(ServerName::Owned(host.into())); + self + } + /// Complete request construction and connect to a WebSocket server. pub async fn connect( mut self, @@ -338,7 +346,7 @@ impl WebsocketsRequest { let max_size = self.max_size; let server_mode = self.server_mode; - let req = ConnectRequest::Tunnel(head, self.addr); + let req = ConnectRequest::Tunnel(head, self.addr, self.sni_host); let fut = self.config.connector.call(req); diff --git a/awc/tests/test_rustls_client.rs b/awc/tests/test_rustls_client.rs index 7e832f67d..afe21f9e1 100644 --- a/awc/tests/test_rustls_client.rs +++ b/awc/tests/test_rustls_client.rs @@ -43,6 +43,8 @@ fn tls_config() -> ServerConfig { } mod danger { + use std::collections::HashSet; + use rustls::{ client::danger::{ServerCertVerified, ServerCertVerifier}, pki_types::UnixTime, @@ -50,8 +52,10 @@ mod danger { use super::*; - #[derive(Debug)] - pub struct NoCertificateVerification; + #[derive(Debug, Default)] + pub struct NoCertificateVerification { + pub trusted_hosts: HashSet, + } impl ServerCertVerifier for NoCertificateVerification { fn verify_server_cert( @@ -62,7 +66,15 @@ mod danger { _ocsp_response: &[u8], _now: UnixTime, ) -> Result { - Ok(rustls::client::danger::ServerCertVerified::assertion()) + if self.trusted_hosts.is_empty() { + return Ok(ServerCertVerified::assertion()); + } + + if self.trusted_hosts.contains(_server_name.to_str().as_ref()) { + return Ok(ServerCertVerified::assertion()); + } + + Err(rustls::Error::General("untrusted host".into())) } fn verify_tls12_signature( @@ -124,7 +136,7 @@ async fn test_connection_reuse_h2() { // disable TLS verification config .dangerous() - .set_certificate_verifier(Arc::new(danger::NoCertificateVerification)); + .set_certificate_verifier(Arc::new(danger::NoCertificateVerification::default())); let client = awc::Client::builder() .connector(awc::Connector::new().rustls_0_23(Arc::new(config))) @@ -144,3 +156,84 @@ async fn test_connection_reuse_h2() { // one connection assert_eq!(num.load(Ordering::Relaxed), 1); } + +#[actix_rt::test] +async fn test_connection_with_sni() { + let srv = test_server(move || { + HttpService::build() + .h2(map_config( + App::new().service(web::resource("/").route(web::to(HttpResponse::Ok))), + |_| AppConfig::default(), + )) + .rustls_0_23(tls_config()) + .map_err(|_| ()) + }) + .await; + + let mut config = ClientConfig::builder() + .with_root_certificates(webpki_roots_cert_store()) + .with_no_client_auth(); + + let protos = vec![b"h2".to_vec(), b"http/1.1".to_vec()]; + config.alpn_protocols = protos; + + // disable TLS verification + config + .dangerous() + .set_certificate_verifier(Arc::new(danger::NoCertificateVerification { + trusted_hosts: ["localhost".to_owned()].iter().cloned().collect(), + })); + + let client = awc::Client::builder() + .connector(awc::Connector::new().rustls_0_23(Arc::new(config))) + .finish(); + + // req : standard request + let request = client.get(srv.surl("/")).send(); + let response = request.await.unwrap(); + assert!(response.status().is_success()); + + // req : test specific host with address, return trusted host + let request = client.get(srv.surl("/")).sni_host("localhost").send(); + let response = request.await.unwrap(); + assert!(response.status().is_success()); + + // req : test bad host, return untrusted host + let request = client.get(srv.surl("/")).sni_host("bad.host").send(); + let response = request.await; + + assert!(response.is_err()); + assert_eq!( + response.unwrap_err().to_string(), + "Failed to connect to host: unexpected error: untrusted host" + ); + + // req : test specific host with address, return untrusted host + let addr = srv.addr(); + let request = client.get("https://example.com:443/").address(addr).send(); + let response = request.await; + + assert!(response.is_err()); + assert_eq!( + response.unwrap_err().to_string(), + "Failed to connect to host: unexpected error: untrusted host" + ); + + // req : test specify sni_host with address and other host (authority) + let request = client + .get("https://example.com:443/") + .address(addr) + .sni_host("localhost") + .send(); + let response = request.await.unwrap(); + assert!(response.status().is_success()); + + // req : test ip address with sni host + let request = client + .get("https://127.0.0.1:443/") + .address(addr) + .sni_host("localhost") + .send(); + let response = request.await.unwrap(); + assert!(response.status().is_success()); +}