1
0
Fork 0
mirror of https://github.com/actix/actix-web.git synced 2024-12-17 05:36:36 +00:00

feat(awc): split connector config with connect config, allow to configure connect config per request

This commit is contained in:
Joel Wurtz 2024-12-06 09:43:26 +01:00
parent 0915879267
commit 610dd616ef
No known key found for this signature in database
GPG key ID: ED264D1967A51B0D
11 changed files with 261 additions and 57 deletions

View file

@ -3,29 +3,33 @@ use std::{net::IpAddr, time::Duration};
const DEFAULT_H2_CONN_WINDOW: u32 = 1024 * 1024 * 2; // 2MB const DEFAULT_H2_CONN_WINDOW: u32 = 1024 * 1024 * 2; // 2MB
const DEFAULT_H2_STREAM_WINDOW: u32 = 1024 * 1024; // 1MB const DEFAULT_H2_STREAM_WINDOW: u32 = 1024 * 1024; // 1MB
/// Connector configuration /// Connect configuration
#[derive(Clone)] #[derive(Clone, Hash, Eq, PartialEq)]
pub(crate) struct ConnectorConfig { pub struct ConnectConfig {
pub(crate) timeout: Duration, pub(crate) timeout: Duration,
pub(crate) handshake_timeout: Duration, pub(crate) handshake_timeout: Duration,
pub(crate) conn_lifetime: Duration, pub(crate) conn_lifetime: Duration,
pub(crate) conn_keep_alive: Duration, pub(crate) conn_keep_alive: Duration,
pub(crate) disconnect_timeout: Option<Duration>,
pub(crate) limit: usize,
pub(crate) conn_window_size: u32, pub(crate) conn_window_size: u32,
pub(crate) stream_window_size: u32, pub(crate) stream_window_size: u32,
pub(crate) local_address: Option<IpAddr>, pub(crate) local_address: Option<IpAddr>,
} }
impl Default for ConnectorConfig { /// Connector configuration
#[derive(Clone)]
pub struct ConnectorConfig {
pub(crate) default_connect_config: ConnectConfig,
pub(crate) disconnect_timeout: Option<Duration>,
pub(crate) limit: usize,
}
impl Default for ConnectConfig {
fn default() -> Self { fn default() -> Self {
Self { Self {
timeout: Duration::from_secs(5), timeout: Duration::from_secs(5),
handshake_timeout: Duration::from_secs(5), handshake_timeout: Duration::from_secs(5),
conn_lifetime: Duration::from_secs(75), conn_lifetime: Duration::from_secs(75),
conn_keep_alive: Duration::from_secs(15), conn_keep_alive: Duration::from_secs(15),
disconnect_timeout: Some(Duration::from_millis(3000)),
limit: 100,
conn_window_size: DEFAULT_H2_CONN_WINDOW, conn_window_size: DEFAULT_H2_CONN_WINDOW,
stream_window_size: DEFAULT_H2_STREAM_WINDOW, stream_window_size: DEFAULT_H2_STREAM_WINDOW,
local_address: None, local_address: None,
@ -33,10 +37,88 @@ impl Default for ConnectorConfig {
} }
} }
impl Default for ConnectorConfig {
fn default() -> Self {
Self {
default_connect_config: ConnectConfig::default(),
disconnect_timeout: Some(Duration::from_millis(3000)),
limit: 100,
}
}
}
impl ConnectorConfig { impl ConnectorConfig {
pub(crate) fn no_disconnect_timeout(&self) -> Self { pub fn no_disconnect_timeout(&self) -> Self {
let mut res = self.clone(); let mut res = self.clone();
res.disconnect_timeout = None; res.disconnect_timeout = None;
res res
} }
} }
impl ConnectConfig {
/// Sets TCP connection timeout.
///
/// This is the max time allowed to connect to remote host, including DNS name resolution.
///
/// By default, the timeout is 5 seconds.
pub fn timeout(mut self, timeout: Duration) -> Self {
self.timeout = timeout;
self
}
/// Sets TLS handshake timeout.
///
/// This is the max time allowed to perform the TLS handshake with remote host after TCP
/// connection is established.
///
/// By default, the timeout is 5 seconds.
pub fn handshake_timeout(mut self, timeout: Duration) -> Self {
self.handshake_timeout = timeout;
self
}
/// Sets the initial window size (in bytes) for HTTP/2 stream-level flow control for received
/// data.
///
/// The default value is 65,535 and is good for APIs, but not for big objects.
pub fn initial_window_size(mut self, size: u32) -> Self {
self.stream_window_size = size;
self
}
/// Sets the initial window size (in bytes) for HTTP/2 connection-level flow control for
/// received data.
///
/// The default value is 65,535 and is good for APIs, but not for big objects.
pub fn initial_connection_window_size(mut self, size: u32) -> Self {
self.conn_window_size = size;
self
}
/// Set keep-alive period for opened connection.
///
/// Keep-alive period is the period between connection usage. If
/// the delay between repeated usages of the same connection
/// exceeds this period, the connection is closed.
/// Default keep-alive period is 15 seconds.
pub fn conn_keep_alive(mut self, dur: Duration) -> Self {
self.conn_keep_alive = dur;
self
}
/// Set max lifetime period for connection.
///
/// Connection lifetime is max lifetime of any opened connection
/// until it is closed regardless of keep-alive period.
/// Default lifetime period is 75 seconds.
pub fn conn_lifetime(mut self, dur: Duration) -> Self {
self.conn_lifetime = dur;
self
}
/// Set local IP Address the connector would use for establishing connection.
pub fn local_address(mut self, addr: IpAddr) -> Self {
self.local_address = Some(addr);
self
}
}

View file

@ -282,7 +282,7 @@ where
/// ///
/// By default, the timeout is 5 seconds. /// By default, the timeout is 5 seconds.
pub fn timeout(mut self, timeout: Duration) -> Self { pub fn timeout(mut self, timeout: Duration) -> Self {
self.config.timeout = timeout; self.config.default_connect_config.timeout = timeout;
self self
} }
@ -293,7 +293,7 @@ where
/// ///
/// By default, the timeout is 5 seconds. /// By default, the timeout is 5 seconds.
pub fn handshake_timeout(mut self, timeout: Duration) -> Self { pub fn handshake_timeout(mut self, timeout: Duration) -> Self {
self.config.handshake_timeout = timeout; self.config.default_connect_config.handshake_timeout = timeout;
self self
} }
@ -387,7 +387,7 @@ where
/// ///
/// The default value is 65,535 and is good for APIs, but not for big objects. /// The default value is 65,535 and is good for APIs, but not for big objects.
pub fn initial_window_size(mut self, size: u32) -> Self { pub fn initial_window_size(mut self, size: u32) -> Self {
self.config.stream_window_size = size; self.config.default_connect_config.stream_window_size = size;
self self
} }
@ -396,7 +396,7 @@ where
/// ///
/// The default value is 65,535 and is good for APIs, but not for big objects. /// The default value is 65,535 and is good for APIs, but not for big objects.
pub fn initial_connection_window_size(mut self, size: u32) -> Self { pub fn initial_connection_window_size(mut self, size: u32) -> Self {
self.config.conn_window_size = size; self.config.default_connect_config.conn_window_size = size;
self self
} }
@ -422,7 +422,7 @@ where
/// exceeds this period, the connection is closed. /// exceeds this period, the connection is closed.
/// Default keep-alive period is 15 seconds. /// Default keep-alive period is 15 seconds.
pub fn conn_keep_alive(mut self, dur: Duration) -> Self { pub fn conn_keep_alive(mut self, dur: Duration) -> Self {
self.config.conn_keep_alive = dur; self.config.default_connect_config.conn_keep_alive = dur;
self self
} }
@ -432,7 +432,7 @@ where
/// until it is closed regardless of keep-alive period. /// until it is closed regardless of keep-alive period.
/// Default lifetime period is 75 seconds. /// Default lifetime period is 75 seconds.
pub fn conn_lifetime(mut self, dur: Duration) -> Self { pub fn conn_lifetime(mut self, dur: Duration) -> Self {
self.config.conn_lifetime = dur; self.config.default_connect_config.conn_lifetime = dur;
self self
} }
@ -451,7 +451,7 @@ where
/// Set local IP Address the connector would use for establishing connection. /// Set local IP Address the connector would use for establishing connection.
pub fn local_address(mut self, addr: IpAddr) -> Self { pub fn local_address(mut self, addr: IpAddr) -> Self {
self.config.local_address = Some(addr); self.config.default_connect_config.local_address = Some(addr);
self self
} }
@ -459,8 +459,8 @@ where
/// ///
/// The `Connector` builder always concludes by calling `finish()` last in its combinator chain. /// The `Connector` builder always concludes by calling `finish()` last in its combinator chain.
pub fn finish(self) -> ConnectorService<S, IO> { pub fn finish(self) -> ConnectorService<S, IO> {
let local_address = self.config.local_address; let local_address = self.config.default_connect_config.local_address;
let timeout = self.config.timeout; let timeout = self.config.default_connect_config.timeout;
let tcp_service_inner = let tcp_service_inner =
TcpConnectorInnerService::new(self.connector, timeout, local_address); TcpConnectorInnerService::new(self.connector, timeout, local_address);
@ -523,7 +523,7 @@ where
} }
} }
let handshake_timeout = self.config.handshake_timeout; let handshake_timeout = self.config.default_connect_config.handshake_timeout;
let tls_service = TlsConnectorService { let tls_service = TlsConnectorService {
tcp_service: tcp_service_inner, tcp_service: tcp_service_inner,
@ -557,7 +557,7 @@ where
} }
} }
let handshake_timeout = self.config.handshake_timeout; let handshake_timeout = self.config.default_connect_config.handshake_timeout;
let tls_service = TlsConnectorService { let tls_service = TlsConnectorService {
tcp_service: tcp_service_inner, tcp_service: tcp_service_inner,
@ -596,7 +596,7 @@ where
} }
} }
let handshake_timeout = self.config.handshake_timeout; let handshake_timeout = self.config.default_connect_config.handshake_timeout;
let tls_service = TlsConnectorService { let tls_service = TlsConnectorService {
tcp_service: tcp_service_inner, tcp_service: tcp_service_inner,
@ -630,7 +630,7 @@ where
} }
} }
let handshake_timeout = self.config.handshake_timeout; let handshake_timeout = self.config.default_connect_config.handshake_timeout;
let tls_service = TlsConnectorService { let tls_service = TlsConnectorService {
tcp_service: tcp_service_inner, tcp_service: tcp_service_inner,
@ -667,7 +667,7 @@ where
} }
} }
let handshake_timeout = self.config.handshake_timeout; let handshake_timeout = self.config.default_connect_config.handshake_timeout;
let tls_service = TlsConnectorService { let tls_service = TlsConnectorService {
tcp_service: tcp_service_inner, tcp_service: tcp_service_inner,
@ -701,7 +701,7 @@ where
} }
} }
let handshake_timeout = self.config.handshake_timeout; let handshake_timeout = self.config.default_connect_config.handshake_timeout;
let tls_service = TlsConnectorService { let tls_service = TlsConnectorService {
tcp_service: tcp_service_inner, tcp_service: tcp_service_inner,
@ -824,9 +824,13 @@ where
} }
fn call(&self, req: Connect) -> Self::Future { fn call(&self, req: Connect) -> Self::Future {
let timeout = req
.config
.clone()
.map(|c| c.handshake_timeout)
.unwrap_or(self.timeout);
let fut = self.tcp_service.call(req); let fut = self.tcp_service.call(req);
let tls_service = self.tls_service.clone(); let tls_service = self.tls_service.clone();
let timeout = self.timeout;
TlsConnectorFuture::TcpConnect { TlsConnectorFuture::TcpConnect {
fut, fut,
@ -935,6 +939,7 @@ where
actix_service::forward_ready!(service); actix_service::forward_ready!(service);
fn call(&self, req: Connect) -> Self::Future { fn call(&self, req: Connect) -> Self::Future {
let timeout = req.config.map(|c| c.timeout).unwrap_or(self.timeout);
let mut req = ConnectInfo::new(HostnameWithSni::ForTcp( let mut req = ConnectInfo::new(HostnameWithSni::ForTcp(
req.hostname, req.hostname,
req.port, req.port,
@ -949,7 +954,7 @@ where
TcpConnectorInnerFuture { TcpConnectorInnerFuture {
fut: self.service.call(req), fut: self.service.call(req),
timeout: sleep(self.timeout), timeout: sleep(timeout),
} }
} }
} }

View file

@ -19,7 +19,6 @@ use http::{
use log::trace; use log::trace;
use super::{ use super::{
config::ConnectorConfig,
connection::{ConnectionIo, H2Connection}, connection::{ConnectionIo, H2Connection},
error::SendRequestError, error::SendRequestError,
}; };
@ -186,12 +185,13 @@ where
pub(crate) fn handshake<Io: ConnectionIo>( pub(crate) fn handshake<Io: ConnectionIo>(
io: Io, io: Io,
config: &ConnectorConfig, stream_window_size: u32,
conn_window_size: u32,
) -> impl Future<Output = Result<(SendRequest<Bytes>, Connection<Io, Bytes>), h2::Error>> { ) -> impl Future<Output = Result<(SendRequest<Bytes>, Connection<Io, Bytes>), h2::Error>> {
let mut builder = Builder::new(); let mut builder = Builder::new();
builder builder
.initial_window_size(config.stream_window_size) .initial_window_size(stream_window_size)
.initial_connection_window_size(config.conn_window_size) .initial_connection_window_size(conn_window_size)
.enable_push(false); .enable_push(false);
builder.handshake(io) builder.handshake(io)
} }

View file

@ -20,6 +20,7 @@ mod h2proto;
mod pool; mod pool;
pub use self::{ pub use self::{
config::ConnectConfig,
connection::{Connection, ConnectionIo}, connection::{Connection, ConnectionIo},
connector::{Connector, ConnectorService, HostnameWithSni}, connector::{Connector, ConnectorService, HostnameWithSni},
error::{ConnectError, FreezeRequestError, InvalidUrl, SendRequestError}, error::{ConnectError, FreezeRequestError, InvalidUrl, SendRequestError},
@ -49,6 +50,7 @@ pub struct Connect {
pub port: u16, pub port: u16,
pub tls: bool, pub tls: bool,
pub addr: Option<std::net::SocketAddr>, pub addr: Option<std::net::SocketAddr>,
pub config: Option<Rc<ConnectConfig>>,
} }
/// An asynchronous HTTP and WebSocket client. /// An asynchronous HTTP and WebSocket client.

View file

@ -190,7 +190,11 @@ where
let now = Instant::now(); let now = Instant::now();
while let Some(mut c) = conns.pop_front() { while let Some(mut c) = conns.pop_front() {
let config = &inner.config; let config = req
.config
.as_ref()
.map(|c| c.as_ref())
.unwrap_or(&inner.config.default_connect_config);
let idle_dur = now - c.used; let idle_dur = now - c.used;
let age = now - c.created; let age = now - c.created;
let conn_ineligible = let conn_ineligible =
@ -225,6 +229,17 @@ where
conn conn
}; };
let stream_window_size = req
.config
.as_ref()
.map(|c| c.stream_window_size)
.unwrap_or(inner.config.default_connect_config.stream_window_size);
let conn_window_size = req
.config
.as_ref()
.map(|c| c.conn_window_size)
.unwrap_or(inner.config.default_connect_config.conn_window_size);
// construct acquired. It's used to put Io type back to pool/ close the Io type. // construct acquired. It's used to put Io type back to pool/ close the Io type.
// permit is carried with the whole lifecycle of Acquired. // permit is carried with the whole lifecycle of Acquired.
let acquired = Acquired { let acquired = Acquired {
@ -245,8 +260,8 @@ where
if proto == Protocol::Http1 { if proto == Protocol::Http1 {
Ok(ConnectionType::from_h1(io, Instant::now(), acquired)) Ok(ConnectionType::from_h1(io, Instant::now(), acquired))
} else { } else {
let config = &acquired.inner.config; let (sender, connection) =
let (sender, connection) = handshake(io, config).await?; handshake(io, stream_window_size, conn_window_size).await?;
let inner = H2ConnectionInner::new(sender, connection); let inner = H2ConnectionInner::new(sender, connection);
Ok(ConnectionType::from_h2(inner, Instant::now(), acquired)) Ok(ConnectionType::from_h2(inner, Instant::now(), acquired))
} }
@ -381,6 +396,7 @@ mod test {
use std::cell::Cell; use std::cell::Cell;
use super::*; use super::*;
use crate::client::ConnectConfig;
/// A stream type that always returns pending on async read. /// A stream type that always returns pending on async read.
/// ///
@ -469,6 +485,7 @@ mod test {
tls: false, tls: false,
sni_host: None, sni_host: None,
addr: None, addr: None,
config: None,
}; };
let conn = pool.call(req.clone()).await.unwrap(); let conn = pool.call(req.clone()).await.unwrap();
@ -500,7 +517,10 @@ mod test {
let connector = TestPoolConnector { generated }; let connector = TestPoolConnector { generated };
let config = ConnectorConfig { let config = ConnectorConfig {
conn_keep_alive: Duration::from_secs(1), default_connect_config: ConnectConfig {
conn_keep_alive: Duration::from_secs(1),
..Default::default()
},
..Default::default() ..Default::default()
}; };
@ -512,6 +532,7 @@ mod test {
tls: false, tls: false,
sni_host: None, sni_host: None,
addr: None, addr: None,
config: None,
}; };
let conn = pool.call(req.clone()).await.unwrap(); let conn = pool.call(req.clone()).await.unwrap();
@ -545,7 +566,10 @@ mod test {
let connector = TestPoolConnector { generated }; let connector = TestPoolConnector { generated };
let config = ConnectorConfig { let config = ConnectorConfig {
conn_lifetime: Duration::from_secs(1), default_connect_config: ConnectConfig {
conn_lifetime: Duration::from_secs(1),
..Default::default()
},
..Default::default() ..Default::default()
}; };
@ -557,6 +581,7 @@ mod test {
tls: false, tls: false,
sni_host: None, sni_host: None,
addr: None, addr: None,
config: None,
}; };
let conn = pool.call(req.clone()).await.unwrap(); let conn = pool.call(req.clone()).await.unwrap();
@ -599,6 +624,7 @@ mod test {
tls: true, tls: true,
sni_host: None, sni_host: None,
addr: None, addr: None,
config: None,
}; };
let conn = pool.call(req.clone()).await.unwrap(); let conn = pool.call(req.clone()).await.unwrap();
@ -615,6 +641,7 @@ mod test {
tls: true, tls: true,
sni_host: None, sni_host: None,
addr: None, addr: None,
config: None,
}; };
let conn = pool.call(req.clone()).await.unwrap(); let conn = pool.call(req.clone()).await.unwrap();
@ -642,6 +669,7 @@ mod test {
tls: true, tls: true,
sni_host: None, sni_host: None,
addr: None, addr: None,
config: None,
}; };
let conn = pool.call(req.clone()).await.unwrap(); let conn = pool.call(req.clone()).await.unwrap();
@ -654,6 +682,7 @@ mod test {
tls: true, tls: true,
sni_host: None, sni_host: None,
addr: None, addr: None,
config: None,
}; };
let conn = pool.call(req.clone()).await.unwrap(); let conn = pool.call(req.clone()).await.unwrap();
assert_eq!(2, generated_clone.get()); assert_eq!(2, generated_clone.get());

View file

@ -14,8 +14,8 @@ use futures_core::{future::LocalBoxFuture, ready};
use crate::{ use crate::{
any_body::AnyBody, any_body::AnyBody,
client::{ client::{
Connect as ClientConnect, ConnectError, Connection, ConnectionIo, SendRequestError, Connect as ClientConnect, ConnectConfig, ConnectError, Connection, ConnectionIo,
ServerName, SendRequestError, ServerName,
}, },
ClientResponse, ClientResponse,
}; };
@ -41,12 +41,18 @@ pub enum ConnectRequest {
AnyBody, AnyBody,
Option<net::SocketAddr>, Option<net::SocketAddr>,
Option<ServerName>, Option<ServerName>,
Option<Rc<ConnectConfig>>,
), ),
/// Tunnel used by WebSocket connection requests. /// Tunnel used by WebSocket connection requests.
/// ///
/// Contains the request head, optional pre-resolved socket address and optional sni host. /// Contains the request head, optional pre-resolved socket address and optional sni host.
Tunnel(RequestHead, Option<net::SocketAddr>, Option<ServerName>), Tunnel(
RequestHead,
Option<net::SocketAddr>,
Option<ServerName>,
Option<Rc<ConnectConfig>>,
),
} }
/// Combined HTTP response & WebSocket tunnel type returned from connection service. /// Combined HTTP response & WebSocket tunnel type returned from connection service.
@ -111,11 +117,13 @@ where
fn call(&self, req: ConnectRequest) -> Self::Future { fn call(&self, req: ConnectRequest) -> Self::Future {
// connect to the host // connect to the host
let (head, addr, sni_host) = match req { let (head, addr, sni_host, config) = match req {
ConnectRequest::Client(ref head, .., addr, ref sni_host) => { ConnectRequest::Client(ref head, .., addr, ref sni_host, ref config) => {
(head.as_ref(), addr, sni_host.clone()) (head.as_ref(), addr, sni_host.clone(), config.clone())
}
ConnectRequest::Tunnel(ref head, addr, ref sni_host, ref config) => {
(head, addr, sni_host.clone(), config.clone())
} }
ConnectRequest::Tunnel(ref head, addr, ref sni_host) => (head, addr, sni_host.clone()),
}; };
let authority = if let Some(authority) = head.uri.authority() { let authority = if let Some(authority) = head.uri.authority() {
@ -144,6 +152,7 @@ where
tls, tls,
sni_host, sni_host,
addr, addr,
config,
}); });
ConnectRequestFuture::Connection { ConnectRequestFuture::Connection {

View file

@ -11,7 +11,7 @@ use futures_core::Stream;
use serde::Serialize; use serde::Serialize;
use crate::{ use crate::{
client::{ClientConfig, ServerName}, client::{ClientConfig, ConnectConfig, ServerName},
sender::{RequestSender, SendClientRequest}, sender::{RequestSender, SendClientRequest},
BoxError, BoxError,
}; };
@ -27,6 +27,7 @@ pub struct FrozenClientRequest {
pub(crate) timeout: Option<Duration>, pub(crate) timeout: Option<Duration>,
pub(crate) config: ClientConfig, pub(crate) config: ClientConfig,
pub(crate) sni_host: Option<ServerName>, pub(crate) sni_host: Option<ServerName>,
pub(crate) connect_config: Option<Rc<ConnectConfig>>,
} }
impl FrozenClientRequest { impl FrozenClientRequest {
@ -56,6 +57,7 @@ impl FrozenClientRequest {
self.timeout, self.timeout,
&self.config, &self.config,
self.sni_host.clone(), self.sni_host.clone(),
self.connect_config.clone(),
body, body,
) )
} }
@ -68,6 +70,7 @@ impl FrozenClientRequest {
self.timeout, self.timeout,
&self.config, &self.config,
self.sni_host.clone(), self.sni_host.clone(),
self.connect_config.clone(),
value, value,
) )
} }
@ -80,6 +83,7 @@ impl FrozenClientRequest {
self.timeout, self.timeout,
&self.config, &self.config,
self.sni_host.clone(), self.sni_host.clone(),
self.connect_config.clone(),
value, value,
) )
} }
@ -96,6 +100,7 @@ impl FrozenClientRequest {
self.timeout, self.timeout,
&self.config, &self.config,
self.sni_host.clone(), self.sni_host.clone(),
self.connect_config.clone(),
stream, stream,
) )
} }
@ -108,6 +113,7 @@ impl FrozenClientRequest {
self.timeout, self.timeout,
&self.config, &self.config,
self.sni_host.clone(), self.sni_host.clone(),
self.connect_config.clone(),
) )
} }
@ -163,6 +169,7 @@ impl FrozenSendBuilder {
self.req.timeout, self.req.timeout,
&self.req.config, &self.req.config,
self.req.sni_host.clone(), self.req.sni_host.clone(),
self.req.connect_config,
body, body,
) )
} }
@ -179,6 +186,7 @@ impl FrozenSendBuilder {
self.req.timeout, self.req.timeout,
&self.req.config, &self.req.config,
self.req.sni_host.clone(), self.req.sni_host.clone(),
self.req.connect_config,
value, value,
) )
} }
@ -195,6 +203,7 @@ impl FrozenSendBuilder {
self.req.timeout, self.req.timeout,
&self.req.config, &self.req.config,
self.req.sni_host.clone(), self.req.sni_host.clone(),
self.req.connect_config,
value, value,
) )
} }
@ -215,6 +224,7 @@ impl FrozenSendBuilder {
self.req.timeout, self.req.timeout,
&self.req.config, &self.req.config,
self.req.sni_host.clone(), self.req.sni_host.clone(),
self.req.connect_config,
stream, stream,
) )
} }
@ -231,6 +241,7 @@ impl FrozenSendBuilder {
self.req.timeout, self.req.timeout,
&self.req.config, &self.req.config,
self.req.sni_host.clone(), self.req.sni_host.clone(),
self.req.connect_config,
) )
} }
} }

View file

@ -73,13 +73,13 @@ where
fn call(&self, req: ConnectRequest) -> Self::Future { fn call(&self, req: ConnectRequest) -> Self::Future {
match req { match req {
ConnectRequest::Tunnel(head, addr, sni_host) => { ConnectRequest::Tunnel(head, addr, sni_host, config) => {
let fut = self let fut = self
.connector .connector
.call(ConnectRequest::Tunnel(head, addr, sni_host)); .call(ConnectRequest::Tunnel(head, addr, sni_host, config));
RedirectServiceFuture::Tunnel { fut } RedirectServiceFuture::Tunnel { fut }
} }
ConnectRequest::Client(head, body, addr, sni_host) => { ConnectRequest::Client(head, body, addr, sni_host, config) => {
let connector = Rc::clone(&self.connector); let connector = Rc::clone(&self.connector);
let max_redirect_times = self.max_redirect_times; let max_redirect_times = self.max_redirect_times;
@ -98,7 +98,8 @@ where
_ => None, _ => None,
}; };
let fut = connector.call(ConnectRequest::Client(head, body, addr, sni_host)); let fut =
connector.call(ConnectRequest::Client(head, body, addr, sni_host, config));
RedirectServiceFuture::Client { RedirectServiceFuture::Client {
fut, fut,
@ -223,8 +224,8 @@ where
let fut = connector let fut = connector
.as_ref() .as_ref()
.unwrap() .unwrap()
// @TODO find a way to get sni host // @TODO find a way to get sni host and config
.call(ConnectRequest::Client(head, body_new, addr, None)); .call(ConnectRequest::Client(head, body_new, addr, None, None));
self.set(RedirectServiceFuture::Client { self.set(RedirectServiceFuture::Client {
fut, fut,

View file

@ -14,7 +14,7 @@ use serde::Serialize;
#[cfg(feature = "cookies")] #[cfg(feature = "cookies")]
use crate::cookie::{Cookie, CookieJar}; use crate::cookie::{Cookie, CookieJar};
use crate::{ use crate::{
client::{ClientConfig, ServerName}, client::{ClientConfig, ConnectConfig, ServerName},
error::{FreezeRequestError, InvalidUrl}, error::{FreezeRequestError, InvalidUrl},
frozen::FrozenClientRequest, frozen::FrozenClientRequest,
sender::{PrepForSendingError, RequestSender, SendClientRequest}, sender::{PrepForSendingError, RequestSender, SendClientRequest},
@ -49,6 +49,7 @@ pub struct ClientRequest {
timeout: Option<Duration>, timeout: Option<Duration>,
config: ClientConfig, config: ClientConfig,
sni_host: Option<ServerName>, sni_host: Option<ServerName>,
connect_config: Option<ConnectConfig>,
#[cfg(feature = "cookies")] #[cfg(feature = "cookies")]
cookies: Option<CookieJar>, cookies: Option<CookieJar>,
@ -71,6 +72,7 @@ impl ClientRequest {
timeout: None, timeout: None,
response_decompress: true, response_decompress: true,
sni_host: None, sni_host: None,
connect_config: None,
} }
.method(method) .method(method)
.uri(uri) .uri(uri)
@ -281,6 +283,15 @@ impl ClientRequest {
self self
} }
/// Set specific connector configuration for this request.
///
/// Not all config may be applied to the request, it depends on the connector and also
/// if there is already a connection established.
pub fn connect_config(mut self, config: ConnectConfig) -> Self {
self.connect_config = Some(config);
self
}
/// Set request timeout. Overrides client wide timeout setting. /// Set request timeout. Overrides client wide timeout setting.
/// ///
/// Request timeout is the total time before a response must be received. /// Request timeout is the total time before a response must be received.
@ -332,6 +343,7 @@ impl ClientRequest {
ServerName::Borrowed(r) => ServerName::Borrowed(r), ServerName::Borrowed(r) => ServerName::Borrowed(r),
ServerName::Owned(o) => ServerName::Borrowed(Rc::new(o)), ServerName::Owned(o) => ServerName::Borrowed(Rc::new(o)),
}), }),
connect_config: slf.connect_config.map(Rc::new),
}; };
Ok(request) Ok(request)
@ -353,6 +365,7 @@ impl ClientRequest {
slf.timeout, slf.timeout,
&slf.config, &slf.config,
slf.sni_host, slf.sni_host,
slf.connect_config.map(Rc::new),
body, body,
) )
} }
@ -370,6 +383,7 @@ impl ClientRequest {
slf.timeout, slf.timeout,
&slf.config, &slf.config,
slf.sni_host, slf.sni_host,
slf.connect_config.map(Rc::new),
value, value,
) )
} }
@ -389,6 +403,7 @@ impl ClientRequest {
slf.timeout, slf.timeout,
&slf.config, &slf.config,
slf.sni_host, slf.sni_host,
slf.connect_config.map(Rc::new),
value, value,
) )
} }
@ -410,6 +425,7 @@ impl ClientRequest {
slf.timeout, slf.timeout,
&slf.config, &slf.config,
slf.sni_host, slf.sni_host,
slf.connect_config.map(Rc::new),
stream, stream,
) )
} }
@ -427,6 +443,7 @@ impl ClientRequest {
slf.timeout, slf.timeout,
&slf.config, &slf.config,
slf.sni_host, slf.sni_host,
slf.connect_config.map(Rc::new),
) )
} }

View file

@ -23,7 +23,7 @@ use serde::Serialize;
use crate::{ use crate::{
any_body::AnyBody, any_body::AnyBody,
client::{ClientConfig, ServerName}, client::{ClientConfig, ConnectConfig, ServerName},
error::{FreezeRequestError, InvalidUrl, SendRequestError}, error::{FreezeRequestError, InvalidUrl, SendRequestError},
BoxError, ClientResponse, ConnectRequest, ConnectResponse, BoxError, ClientResponse, ConnectRequest, ConnectResponse,
}; };
@ -187,6 +187,7 @@ impl RequestSender {
timeout: Option<Duration>, timeout: Option<Duration>,
config: &ClientConfig, config: &ClientConfig,
sni_host: Option<ServerName>, sni_host: Option<ServerName>,
connect_config: Option<Rc<ConnectConfig>>,
body: impl MessageBody + 'static, body: impl MessageBody + 'static,
) -> SendClientRequest { ) -> SendClientRequest {
let req = match self { let req = match self {
@ -195,12 +196,14 @@ impl RequestSender {
AnyBody::from_message_body(body).into_boxed(), AnyBody::from_message_body(body).into_boxed(),
addr, addr,
sni_host, sni_host,
connect_config,
), ),
RequestSender::Rc(head, extra_headers) => ConnectRequest::Client( RequestSender::Rc(head, extra_headers) => ConnectRequest::Client(
RequestHeadType::Rc(head, extra_headers), RequestHeadType::Rc(head, extra_headers),
AnyBody::from_message_body(body).into_boxed(), AnyBody::from_message_body(body).into_boxed(),
addr, addr,
sni_host, sni_host,
connect_config,
), ),
}; };
@ -216,6 +219,7 @@ impl RequestSender {
timeout: Option<Duration>, timeout: Option<Duration>,
config: &ClientConfig, config: &ClientConfig,
sni_host: Option<ServerName>, sni_host: Option<ServerName>,
connector_config: Option<Rc<ConnectConfig>>,
value: impl Serialize, value: impl Serialize,
) -> SendClientRequest { ) -> SendClientRequest {
let body = match serde_json::to_string(&value) { let body = match serde_json::to_string(&value) {
@ -227,7 +231,15 @@ impl RequestSender {
return err.into(); return err.into();
} }
self.send_body(addr, response_decompress, timeout, config, sni_host, body) self.send_body(
addr,
response_decompress,
timeout,
config,
sni_host,
connector_config,
body,
)
} }
pub(crate) fn send_form( pub(crate) fn send_form(
@ -237,6 +249,7 @@ impl RequestSender {
timeout: Option<Duration>, timeout: Option<Duration>,
config: &ClientConfig, config: &ClientConfig,
sni_host: Option<ServerName>, sni_host: Option<ServerName>,
connector_config: Option<Rc<ConnectConfig>>,
value: impl Serialize, value: impl Serialize,
) -> SendClientRequest { ) -> SendClientRequest {
let body = match serde_urlencoded::to_string(value) { let body = match serde_urlencoded::to_string(value) {
@ -251,7 +264,15 @@ impl RequestSender {
return err.into(); return err.into();
} }
self.send_body(addr, response_decompress, timeout, config, sni_host, body) self.send_body(
addr,
response_decompress,
timeout,
config,
sni_host,
connector_config,
body,
)
} }
pub(crate) fn send_stream<S, E>( pub(crate) fn send_stream<S, E>(
@ -261,6 +282,7 @@ impl RequestSender {
timeout: Option<Duration>, timeout: Option<Duration>,
config: &ClientConfig, config: &ClientConfig,
sni_host: Option<ServerName>, sni_host: Option<ServerName>,
connector_config: Option<Rc<ConnectConfig>>,
stream: S, stream: S,
) -> SendClientRequest ) -> SendClientRequest
where where
@ -273,6 +295,7 @@ impl RequestSender {
timeout, timeout,
config, config,
sni_host, sni_host,
connector_config,
BodyStream::new(stream), BodyStream::new(stream),
) )
} }
@ -284,8 +307,17 @@ impl RequestSender {
timeout: Option<Duration>, timeout: Option<Duration>,
config: &ClientConfig, config: &ClientConfig,
sni_host: Option<ServerName>, sni_host: Option<ServerName>,
connector_config: Option<Rc<ConnectConfig>>,
) -> SendClientRequest { ) -> SendClientRequest {
self.send_body(addr, response_decompress, timeout, config, sni_host, ()) self.send_body(
addr,
response_decompress,
timeout,
config,
sni_host,
connector_config,
(),
)
} }
fn set_header_if_none<V>(&mut self, key: HeaderName, value: V) -> Result<(), HttpError> fn set_header_if_none<V>(&mut self, key: HeaderName, value: V) -> Result<(), HttpError>

View file

@ -26,7 +26,7 @@
//! } //! }
//! ``` //! ```
use std::{fmt, net::SocketAddr, str}; use std::{fmt, net::SocketAddr, rc::Rc, str};
use actix_codec::Framed; use actix_codec::Framed;
pub use actix_http::ws::{CloseCode, CloseReason, Codec, Frame, Message}; pub use actix_http::ws::{CloseCode, CloseReason, Codec, Frame, Message};
@ -38,7 +38,7 @@ use base64::prelude::*;
#[cfg(feature = "cookies")] #[cfg(feature = "cookies")]
use crate::cookie::{Cookie, CookieJar}; use crate::cookie::{Cookie, CookieJar};
use crate::{ use crate::{
client::{ClientConfig, ServerName}, client::{ClientConfig, ConnectConfig, ServerName},
connect::{BoxedSocket, ConnectRequest}, connect::{BoxedSocket, ConnectRequest},
error::{HttpError, InvalidUrl, SendRequestError, WsClientError}, error::{HttpError, InvalidUrl, SendRequestError, WsClientError},
http::{ http::{
@ -59,6 +59,7 @@ pub struct WebsocketsRequest {
server_mode: bool, server_mode: bool,
config: ClientConfig, config: ClientConfig,
sni_host: Option<ServerName>, sni_host: Option<ServerName>,
connect_config: Option<ConnectConfig>,
#[cfg(feature = "cookies")] #[cfg(feature = "cookies")]
cookies: Option<CookieJar>, cookies: Option<CookieJar>,
@ -98,6 +99,7 @@ impl WebsocketsRequest {
#[cfg(feature = "cookies")] #[cfg(feature = "cookies")]
cookies: None, cookies: None,
sni_host: None, sni_host: None,
connect_config: None,
} }
} }
@ -110,6 +112,15 @@ impl WebsocketsRequest {
self self
} }
/// Set specific connector configuration for this request.
///
/// Not all config may be applied to the request, it depends on the connector and also
/// if there is already a connection established.
pub fn connector_config(mut self, config: ConnectConfig) -> Self {
self.connect_config = Some(config);
self
}
/// Set supported WebSocket protocols /// Set supported WebSocket protocols
pub fn protocols<U, V>(mut self, protos: U) -> Self pub fn protocols<U, V>(mut self, protos: U) -> Self
where where
@ -346,7 +357,12 @@ impl WebsocketsRequest {
let max_size = self.max_size; let max_size = self.max_size;
let server_mode = self.server_mode; let server_mode = self.server_mode;
let req = ConnectRequest::Tunnel(head, self.addr, self.sni_host); let req = ConnectRequest::Tunnel(
head,
self.addr,
self.sni_host,
self.connect_config.map(Rc::new),
);
let fut = self.config.connector.call(req); let fut = self.config.connector.call(req);