From 610dd616ef3206e781a04d33ed6ed2e2ed4bd50f Mon Sep 17 00:00:00 2001 From: Joel Wurtz Date: Fri, 6 Dec 2024 09:43:26 +0100 Subject: [PATCH] feat(awc): split connector config with connect config, allow to configure connect config per request --- awc/src/client/config.rs | 100 ++++++++++++++++++++++++++++++--- awc/src/client/connector.rs | 39 +++++++------ awc/src/client/h2proto.rs | 8 +-- awc/src/client/mod.rs | 2 + awc/src/client/pool.rs | 39 +++++++++++-- awc/src/connect.rs | 23 +++++--- awc/src/frozen.rs | 13 ++++- awc/src/middleware/redirect.rs | 13 +++-- awc/src/request.rs | 19 ++++++- awc/src/sender.rs | 40 +++++++++++-- awc/src/ws.rs | 22 +++++++- 11 files changed, 261 insertions(+), 57 deletions(-) diff --git a/awc/src/client/config.rs b/awc/src/client/config.rs index 530c1e03b..bd3da234d 100644 --- a/awc/src/client/config.rs +++ b/awc/src/client/config.rs @@ -3,29 +3,33 @@ use std::{net::IpAddr, time::Duration}; const DEFAULT_H2_CONN_WINDOW: u32 = 1024 * 1024 * 2; // 2MB const DEFAULT_H2_STREAM_WINDOW: u32 = 1024 * 1024; // 1MB -/// Connector configuration -#[derive(Clone)] -pub(crate) struct ConnectorConfig { +/// Connect configuration +#[derive(Clone, Hash, Eq, PartialEq)] +pub struct ConnectConfig { pub(crate) timeout: Duration, pub(crate) handshake_timeout: Duration, pub(crate) conn_lifetime: Duration, pub(crate) conn_keep_alive: Duration, - pub(crate) disconnect_timeout: Option, - pub(crate) limit: usize, pub(crate) conn_window_size: u32, pub(crate) stream_window_size: u32, pub(crate) local_address: Option, } -impl Default for ConnectorConfig { +/// Connector configuration +#[derive(Clone)] +pub struct ConnectorConfig { + pub(crate) default_connect_config: ConnectConfig, + pub(crate) disconnect_timeout: Option, + pub(crate) limit: usize, +} + +impl Default for ConnectConfig { fn default() -> Self { Self { timeout: Duration::from_secs(5), handshake_timeout: Duration::from_secs(5), conn_lifetime: Duration::from_secs(75), conn_keep_alive: Duration::from_secs(15), - disconnect_timeout: Some(Duration::from_millis(3000)), - limit: 100, conn_window_size: DEFAULT_H2_CONN_WINDOW, stream_window_size: DEFAULT_H2_STREAM_WINDOW, local_address: None, @@ -33,10 +37,88 @@ impl Default for ConnectorConfig { } } +impl Default for ConnectorConfig { + fn default() -> Self { + Self { + default_connect_config: ConnectConfig::default(), + disconnect_timeout: Some(Duration::from_millis(3000)), + limit: 100, + } + } +} + impl ConnectorConfig { - pub(crate) fn no_disconnect_timeout(&self) -> Self { + pub fn no_disconnect_timeout(&self) -> Self { let mut res = self.clone(); res.disconnect_timeout = None; res } } + +impl ConnectConfig { + /// Sets TCP connection timeout. + /// + /// This is the max time allowed to connect to remote host, including DNS name resolution. + /// + /// By default, the timeout is 5 seconds. + pub fn timeout(mut self, timeout: Duration) -> Self { + self.timeout = timeout; + self + } + + /// Sets TLS handshake timeout. + /// + /// This is the max time allowed to perform the TLS handshake with remote host after TCP + /// connection is established. + /// + /// By default, the timeout is 5 seconds. + pub fn handshake_timeout(mut self, timeout: Duration) -> Self { + self.handshake_timeout = timeout; + self + } + + /// Sets the initial window size (in bytes) for HTTP/2 stream-level flow control for received + /// data. + /// + /// The default value is 65,535 and is good for APIs, but not for big objects. + pub fn initial_window_size(mut self, size: u32) -> Self { + self.stream_window_size = size; + self + } + + /// Sets the initial window size (in bytes) for HTTP/2 connection-level flow control for + /// received data. + /// + /// The default value is 65,535 and is good for APIs, but not for big objects. + pub fn initial_connection_window_size(mut self, size: u32) -> Self { + self.conn_window_size = size; + self + } + + /// Set keep-alive period for opened connection. + /// + /// Keep-alive period is the period between connection usage. If + /// the delay between repeated usages of the same connection + /// exceeds this period, the connection is closed. + /// Default keep-alive period is 15 seconds. + pub fn conn_keep_alive(mut self, dur: Duration) -> Self { + self.conn_keep_alive = dur; + self + } + + /// Set max lifetime period for connection. + /// + /// Connection lifetime is max lifetime of any opened connection + /// until it is closed regardless of keep-alive period. + /// Default lifetime period is 75 seconds. + pub fn conn_lifetime(mut self, dur: Duration) -> Self { + self.conn_lifetime = dur; + self + } + + /// Set local IP Address the connector would use for establishing connection. + pub fn local_address(mut self, addr: IpAddr) -> Self { + self.local_address = Some(addr); + self + } +} diff --git a/awc/src/client/connector.rs b/awc/src/client/connector.rs index 2e3f977fa..d0118d3cc 100644 --- a/awc/src/client/connector.rs +++ b/awc/src/client/connector.rs @@ -282,7 +282,7 @@ where /// /// By default, the timeout is 5 seconds. pub fn timeout(mut self, timeout: Duration) -> Self { - self.config.timeout = timeout; + self.config.default_connect_config.timeout = timeout; self } @@ -293,7 +293,7 @@ where /// /// By default, the timeout is 5 seconds. pub fn handshake_timeout(mut self, timeout: Duration) -> Self { - self.config.handshake_timeout = timeout; + self.config.default_connect_config.handshake_timeout = timeout; self } @@ -387,7 +387,7 @@ where /// /// The default value is 65,535 and is good for APIs, but not for big objects. pub fn initial_window_size(mut self, size: u32) -> Self { - self.config.stream_window_size = size; + self.config.default_connect_config.stream_window_size = size; self } @@ -396,7 +396,7 @@ where /// /// The default value is 65,535 and is good for APIs, but not for big objects. pub fn initial_connection_window_size(mut self, size: u32) -> Self { - self.config.conn_window_size = size; + self.config.default_connect_config.conn_window_size = size; self } @@ -422,7 +422,7 @@ where /// exceeds this period, the connection is closed. /// Default keep-alive period is 15 seconds. pub fn conn_keep_alive(mut self, dur: Duration) -> Self { - self.config.conn_keep_alive = dur; + self.config.default_connect_config.conn_keep_alive = dur; self } @@ -432,7 +432,7 @@ where /// until it is closed regardless of keep-alive period. /// Default lifetime period is 75 seconds. pub fn conn_lifetime(mut self, dur: Duration) -> Self { - self.config.conn_lifetime = dur; + self.config.default_connect_config.conn_lifetime = dur; self } @@ -451,7 +451,7 @@ where /// Set local IP Address the connector would use for establishing connection. pub fn local_address(mut self, addr: IpAddr) -> Self { - self.config.local_address = Some(addr); + self.config.default_connect_config.local_address = Some(addr); self } @@ -459,8 +459,8 @@ where /// /// The `Connector` builder always concludes by calling `finish()` last in its combinator chain. pub fn finish(self) -> ConnectorService { - let local_address = self.config.local_address; - let timeout = self.config.timeout; + let local_address = self.config.default_connect_config.local_address; + let timeout = self.config.default_connect_config.timeout; let tcp_service_inner = TcpConnectorInnerService::new(self.connector, timeout, local_address); @@ -523,7 +523,7 @@ where } } - let handshake_timeout = self.config.handshake_timeout; + let handshake_timeout = self.config.default_connect_config.handshake_timeout; let tls_service = TlsConnectorService { tcp_service: tcp_service_inner, @@ -557,7 +557,7 @@ where } } - let handshake_timeout = self.config.handshake_timeout; + let handshake_timeout = self.config.default_connect_config.handshake_timeout; let tls_service = TlsConnectorService { tcp_service: tcp_service_inner, @@ -596,7 +596,7 @@ where } } - let handshake_timeout = self.config.handshake_timeout; + let handshake_timeout = self.config.default_connect_config.handshake_timeout; let tls_service = TlsConnectorService { tcp_service: tcp_service_inner, @@ -630,7 +630,7 @@ where } } - let handshake_timeout = self.config.handshake_timeout; + let handshake_timeout = self.config.default_connect_config.handshake_timeout; let tls_service = TlsConnectorService { tcp_service: tcp_service_inner, @@ -667,7 +667,7 @@ where } } - let handshake_timeout = self.config.handshake_timeout; + let handshake_timeout = self.config.default_connect_config.handshake_timeout; let tls_service = TlsConnectorService { tcp_service: tcp_service_inner, @@ -701,7 +701,7 @@ where } } - let handshake_timeout = self.config.handshake_timeout; + let handshake_timeout = self.config.default_connect_config.handshake_timeout; let tls_service = TlsConnectorService { tcp_service: tcp_service_inner, @@ -824,9 +824,13 @@ where } fn call(&self, req: Connect) -> Self::Future { + let timeout = req + .config + .clone() + .map(|c| c.handshake_timeout) + .unwrap_or(self.timeout); let fut = self.tcp_service.call(req); let tls_service = self.tls_service.clone(); - let timeout = self.timeout; TlsConnectorFuture::TcpConnect { fut, @@ -935,6 +939,7 @@ where actix_service::forward_ready!(service); fn call(&self, req: Connect) -> Self::Future { + let timeout = req.config.map(|c| c.timeout).unwrap_or(self.timeout); let mut req = ConnectInfo::new(HostnameWithSni::ForTcp( req.hostname, req.port, @@ -949,7 +954,7 @@ where TcpConnectorInnerFuture { fut: self.service.call(req), - timeout: sleep(self.timeout), + timeout: sleep(timeout), } } } diff --git a/awc/src/client/h2proto.rs b/awc/src/client/h2proto.rs index c3f801f20..738a9f12b 100644 --- a/awc/src/client/h2proto.rs +++ b/awc/src/client/h2proto.rs @@ -19,7 +19,6 @@ use http::{ use log::trace; use super::{ - config::ConnectorConfig, connection::{ConnectionIo, H2Connection}, error::SendRequestError, }; @@ -186,12 +185,13 @@ where pub(crate) fn handshake( io: Io, - config: &ConnectorConfig, + stream_window_size: u32, + conn_window_size: u32, ) -> impl Future, Connection), h2::Error>> { let mut builder = Builder::new(); builder - .initial_window_size(config.stream_window_size) - .initial_connection_window_size(config.conn_window_size) + .initial_window_size(stream_window_size) + .initial_connection_window_size(conn_window_size) .enable_push(false); builder.handshake(io) } diff --git a/awc/src/client/mod.rs b/awc/src/client/mod.rs index 5ac8650d3..ca324b3da 100644 --- a/awc/src/client/mod.rs +++ b/awc/src/client/mod.rs @@ -20,6 +20,7 @@ mod h2proto; mod pool; pub use self::{ + config::ConnectConfig, connection::{Connection, ConnectionIo}, connector::{Connector, ConnectorService, HostnameWithSni}, error::{ConnectError, FreezeRequestError, InvalidUrl, SendRequestError}, @@ -49,6 +50,7 @@ pub struct Connect { pub port: u16, pub tls: bool, pub addr: Option, + pub config: Option>, } /// An asynchronous HTTP and WebSocket client. diff --git a/awc/src/client/pool.rs b/awc/src/client/pool.rs index 9b1058d85..c0429e045 100644 --- a/awc/src/client/pool.rs +++ b/awc/src/client/pool.rs @@ -190,7 +190,11 @@ where let now = Instant::now(); while let Some(mut c) = conns.pop_front() { - let config = &inner.config; + let config = req + .config + .as_ref() + .map(|c| c.as_ref()) + .unwrap_or(&inner.config.default_connect_config); let idle_dur = now - c.used; let age = now - c.created; let conn_ineligible = @@ -225,6 +229,17 @@ where conn }; + let stream_window_size = req + .config + .as_ref() + .map(|c| c.stream_window_size) + .unwrap_or(inner.config.default_connect_config.stream_window_size); + let conn_window_size = req + .config + .as_ref() + .map(|c| c.conn_window_size) + .unwrap_or(inner.config.default_connect_config.conn_window_size); + // construct acquired. It's used to put Io type back to pool/ close the Io type. // permit is carried with the whole lifecycle of Acquired. let acquired = Acquired { @@ -245,8 +260,8 @@ where if proto == Protocol::Http1 { Ok(ConnectionType::from_h1(io, Instant::now(), acquired)) } else { - let config = &acquired.inner.config; - let (sender, connection) = handshake(io, config).await?; + let (sender, connection) = + handshake(io, stream_window_size, conn_window_size).await?; let inner = H2ConnectionInner::new(sender, connection); Ok(ConnectionType::from_h2(inner, Instant::now(), acquired)) } @@ -381,6 +396,7 @@ mod test { use std::cell::Cell; use super::*; + use crate::client::ConnectConfig; /// A stream type that always returns pending on async read. /// @@ -469,6 +485,7 @@ mod test { tls: false, sni_host: None, addr: None, + config: None, }; let conn = pool.call(req.clone()).await.unwrap(); @@ -500,7 +517,10 @@ mod test { let connector = TestPoolConnector { generated }; let config = ConnectorConfig { - conn_keep_alive: Duration::from_secs(1), + default_connect_config: ConnectConfig { + conn_keep_alive: Duration::from_secs(1), + ..Default::default() + }, ..Default::default() }; @@ -512,6 +532,7 @@ mod test { tls: false, sni_host: None, addr: None, + config: None, }; let conn = pool.call(req.clone()).await.unwrap(); @@ -545,7 +566,10 @@ mod test { let connector = TestPoolConnector { generated }; let config = ConnectorConfig { - conn_lifetime: Duration::from_secs(1), + default_connect_config: ConnectConfig { + conn_lifetime: Duration::from_secs(1), + ..Default::default() + }, ..Default::default() }; @@ -557,6 +581,7 @@ mod test { tls: false, sni_host: None, addr: None, + config: None, }; let conn = pool.call(req.clone()).await.unwrap(); @@ -599,6 +624,7 @@ mod test { tls: true, sni_host: None, addr: None, + config: None, }; let conn = pool.call(req.clone()).await.unwrap(); @@ -615,6 +641,7 @@ mod test { tls: true, sni_host: None, addr: None, + config: None, }; let conn = pool.call(req.clone()).await.unwrap(); @@ -642,6 +669,7 @@ mod test { tls: true, sni_host: None, addr: None, + config: None, }; let conn = pool.call(req.clone()).await.unwrap(); @@ -654,6 +682,7 @@ mod test { tls: true, sni_host: None, addr: None, + config: None, }; let conn = pool.call(req.clone()).await.unwrap(); assert_eq!(2, generated_clone.get()); diff --git a/awc/src/connect.rs b/awc/src/connect.rs index a7bbd7b2d..f5b79d0de 100644 --- a/awc/src/connect.rs +++ b/awc/src/connect.rs @@ -14,8 +14,8 @@ use futures_core::{future::LocalBoxFuture, ready}; use crate::{ any_body::AnyBody, client::{ - Connect as ClientConnect, ConnectError, Connection, ConnectionIo, SendRequestError, - ServerName, + Connect as ClientConnect, ConnectConfig, ConnectError, Connection, ConnectionIo, + SendRequestError, ServerName, }, ClientResponse, }; @@ -41,12 +41,18 @@ pub enum ConnectRequest { AnyBody, Option, Option, + Option>, ), /// Tunnel used by WebSocket connection requests. /// /// Contains the request head, optional pre-resolved socket address and optional sni host. - Tunnel(RequestHead, Option, Option), + Tunnel( + RequestHead, + Option, + Option, + Option>, + ), } /// Combined HTTP response & WebSocket tunnel type returned from connection service. @@ -111,11 +117,13 @@ where fn call(&self, req: ConnectRequest) -> Self::Future { // connect to the host - let (head, addr, sni_host) = match req { - ConnectRequest::Client(ref head, .., addr, ref sni_host) => { - (head.as_ref(), addr, sni_host.clone()) + let (head, addr, sni_host, config) = match req { + ConnectRequest::Client(ref head, .., addr, ref sni_host, ref config) => { + (head.as_ref(), addr, sni_host.clone(), config.clone()) + } + ConnectRequest::Tunnel(ref head, addr, ref sni_host, ref config) => { + (head, addr, sni_host.clone(), config.clone()) } - ConnectRequest::Tunnel(ref head, addr, ref sni_host) => (head, addr, sni_host.clone()), }; let authority = if let Some(authority) = head.uri.authority() { @@ -144,6 +152,7 @@ where tls, sni_host, addr, + config, }); ConnectRequestFuture::Connection { diff --git a/awc/src/frozen.rs b/awc/src/frozen.rs index d622f8ece..eba5e1a43 100644 --- a/awc/src/frozen.rs +++ b/awc/src/frozen.rs @@ -11,7 +11,7 @@ use futures_core::Stream; use serde::Serialize; use crate::{ - client::{ClientConfig, ServerName}, + client::{ClientConfig, ConnectConfig, ServerName}, sender::{RequestSender, SendClientRequest}, BoxError, }; @@ -27,6 +27,7 @@ pub struct FrozenClientRequest { pub(crate) timeout: Option, pub(crate) config: ClientConfig, pub(crate) sni_host: Option, + pub(crate) connect_config: Option>, } impl FrozenClientRequest { @@ -56,6 +57,7 @@ impl FrozenClientRequest { self.timeout, &self.config, self.sni_host.clone(), + self.connect_config.clone(), body, ) } @@ -68,6 +70,7 @@ impl FrozenClientRequest { self.timeout, &self.config, self.sni_host.clone(), + self.connect_config.clone(), value, ) } @@ -80,6 +83,7 @@ impl FrozenClientRequest { self.timeout, &self.config, self.sni_host.clone(), + self.connect_config.clone(), value, ) } @@ -96,6 +100,7 @@ impl FrozenClientRequest { self.timeout, &self.config, self.sni_host.clone(), + self.connect_config.clone(), stream, ) } @@ -108,6 +113,7 @@ impl FrozenClientRequest { self.timeout, &self.config, self.sni_host.clone(), + self.connect_config.clone(), ) } @@ -163,6 +169,7 @@ impl FrozenSendBuilder { self.req.timeout, &self.req.config, self.req.sni_host.clone(), + self.req.connect_config, body, ) } @@ -179,6 +186,7 @@ impl FrozenSendBuilder { self.req.timeout, &self.req.config, self.req.sni_host.clone(), + self.req.connect_config, value, ) } @@ -195,6 +203,7 @@ impl FrozenSendBuilder { self.req.timeout, &self.req.config, self.req.sni_host.clone(), + self.req.connect_config, value, ) } @@ -215,6 +224,7 @@ impl FrozenSendBuilder { self.req.timeout, &self.req.config, self.req.sni_host.clone(), + self.req.connect_config, stream, ) } @@ -231,6 +241,7 @@ impl FrozenSendBuilder { self.req.timeout, &self.req.config, self.req.sni_host.clone(), + self.req.connect_config, ) } } diff --git a/awc/src/middleware/redirect.rs b/awc/src/middleware/redirect.rs index 81f4d7993..d927328bb 100644 --- a/awc/src/middleware/redirect.rs +++ b/awc/src/middleware/redirect.rs @@ -73,13 +73,13 @@ where fn call(&self, req: ConnectRequest) -> Self::Future { match req { - ConnectRequest::Tunnel(head, addr, sni_host) => { + ConnectRequest::Tunnel(head, addr, sni_host, config) => { let fut = self .connector - .call(ConnectRequest::Tunnel(head, addr, sni_host)); + .call(ConnectRequest::Tunnel(head, addr, sni_host, config)); RedirectServiceFuture::Tunnel { fut } } - ConnectRequest::Client(head, body, addr, sni_host) => { + ConnectRequest::Client(head, body, addr, sni_host, config) => { let connector = Rc::clone(&self.connector); let max_redirect_times = self.max_redirect_times; @@ -98,7 +98,8 @@ where _ => None, }; - let fut = connector.call(ConnectRequest::Client(head, body, addr, sni_host)); + let fut = + connector.call(ConnectRequest::Client(head, body, addr, sni_host, config)); RedirectServiceFuture::Client { fut, @@ -223,8 +224,8 @@ where let fut = connector .as_ref() .unwrap() - // @TODO find a way to get sni host - .call(ConnectRequest::Client(head, body_new, addr, None)); + // @TODO find a way to get sni host and config + .call(ConnectRequest::Client(head, body_new, addr, None, None)); self.set(RedirectServiceFuture::Client { fut, diff --git a/awc/src/request.rs b/awc/src/request.rs index b0f995a63..e24f19c70 100644 --- a/awc/src/request.rs +++ b/awc/src/request.rs @@ -14,7 +14,7 @@ use serde::Serialize; #[cfg(feature = "cookies")] use crate::cookie::{Cookie, CookieJar}; use crate::{ - client::{ClientConfig, ServerName}, + client::{ClientConfig, ConnectConfig, ServerName}, error::{FreezeRequestError, InvalidUrl}, frozen::FrozenClientRequest, sender::{PrepForSendingError, RequestSender, SendClientRequest}, @@ -49,6 +49,7 @@ pub struct ClientRequest { timeout: Option, config: ClientConfig, sni_host: Option, + connect_config: Option, #[cfg(feature = "cookies")] cookies: Option, @@ -71,6 +72,7 @@ impl ClientRequest { timeout: None, response_decompress: true, sni_host: None, + connect_config: None, } .method(method) .uri(uri) @@ -281,6 +283,15 @@ impl ClientRequest { self } + /// Set specific connector configuration for this request. + /// + /// Not all config may be applied to the request, it depends on the connector and also + /// if there is already a connection established. + pub fn connect_config(mut self, config: ConnectConfig) -> Self { + self.connect_config = Some(config); + self + } + /// Set request timeout. Overrides client wide timeout setting. /// /// Request timeout is the total time before a response must be received. @@ -332,6 +343,7 @@ impl ClientRequest { ServerName::Borrowed(r) => ServerName::Borrowed(r), ServerName::Owned(o) => ServerName::Borrowed(Rc::new(o)), }), + connect_config: slf.connect_config.map(Rc::new), }; Ok(request) @@ -353,6 +365,7 @@ impl ClientRequest { slf.timeout, &slf.config, slf.sni_host, + slf.connect_config.map(Rc::new), body, ) } @@ -370,6 +383,7 @@ impl ClientRequest { slf.timeout, &slf.config, slf.sni_host, + slf.connect_config.map(Rc::new), value, ) } @@ -389,6 +403,7 @@ impl ClientRequest { slf.timeout, &slf.config, slf.sni_host, + slf.connect_config.map(Rc::new), value, ) } @@ -410,6 +425,7 @@ impl ClientRequest { slf.timeout, &slf.config, slf.sni_host, + slf.connect_config.map(Rc::new), stream, ) } @@ -427,6 +443,7 @@ impl ClientRequest { slf.timeout, &slf.config, slf.sni_host, + slf.connect_config.map(Rc::new), ) } diff --git a/awc/src/sender.rs b/awc/src/sender.rs index ab3ca596d..8347ceb4c 100644 --- a/awc/src/sender.rs +++ b/awc/src/sender.rs @@ -23,7 +23,7 @@ use serde::Serialize; use crate::{ any_body::AnyBody, - client::{ClientConfig, ServerName}, + client::{ClientConfig, ConnectConfig, ServerName}, error::{FreezeRequestError, InvalidUrl, SendRequestError}, BoxError, ClientResponse, ConnectRequest, ConnectResponse, }; @@ -187,6 +187,7 @@ impl RequestSender { timeout: Option, config: &ClientConfig, sni_host: Option, + connect_config: Option>, body: impl MessageBody + 'static, ) -> SendClientRequest { let req = match self { @@ -195,12 +196,14 @@ impl RequestSender { AnyBody::from_message_body(body).into_boxed(), addr, sni_host, + connect_config, ), RequestSender::Rc(head, extra_headers) => ConnectRequest::Client( RequestHeadType::Rc(head, extra_headers), AnyBody::from_message_body(body).into_boxed(), addr, sni_host, + connect_config, ), }; @@ -216,6 +219,7 @@ impl RequestSender { timeout: Option, config: &ClientConfig, sni_host: Option, + connector_config: Option>, value: impl Serialize, ) -> SendClientRequest { let body = match serde_json::to_string(&value) { @@ -227,7 +231,15 @@ impl RequestSender { return err.into(); } - self.send_body(addr, response_decompress, timeout, config, sni_host, body) + self.send_body( + addr, + response_decompress, + timeout, + config, + sni_host, + connector_config, + body, + ) } pub(crate) fn send_form( @@ -237,6 +249,7 @@ impl RequestSender { timeout: Option, config: &ClientConfig, sni_host: Option, + connector_config: Option>, value: impl Serialize, ) -> SendClientRequest { let body = match serde_urlencoded::to_string(value) { @@ -251,7 +264,15 @@ impl RequestSender { return err.into(); } - self.send_body(addr, response_decompress, timeout, config, sni_host, body) + self.send_body( + addr, + response_decompress, + timeout, + config, + sni_host, + connector_config, + body, + ) } pub(crate) fn send_stream( @@ -261,6 +282,7 @@ impl RequestSender { timeout: Option, config: &ClientConfig, sni_host: Option, + connector_config: Option>, stream: S, ) -> SendClientRequest where @@ -273,6 +295,7 @@ impl RequestSender { timeout, config, sni_host, + connector_config, BodyStream::new(stream), ) } @@ -284,8 +307,17 @@ impl RequestSender { timeout: Option, config: &ClientConfig, sni_host: Option, + connector_config: Option>, ) -> SendClientRequest { - self.send_body(addr, response_decompress, timeout, config, sni_host, ()) + self.send_body( + addr, + response_decompress, + timeout, + config, + sni_host, + connector_config, + (), + ) } fn set_header_if_none(&mut self, key: HeaderName, value: V) -> Result<(), HttpError> diff --git a/awc/src/ws.rs b/awc/src/ws.rs index ef5cb7155..77b00fceb 100644 --- a/awc/src/ws.rs +++ b/awc/src/ws.rs @@ -26,7 +26,7 @@ //! } //! ``` -use std::{fmt, net::SocketAddr, str}; +use std::{fmt, net::SocketAddr, rc::Rc, str}; use actix_codec::Framed; pub use actix_http::ws::{CloseCode, CloseReason, Codec, Frame, Message}; @@ -38,7 +38,7 @@ use base64::prelude::*; #[cfg(feature = "cookies")] use crate::cookie::{Cookie, CookieJar}; use crate::{ - client::{ClientConfig, ServerName}, + client::{ClientConfig, ConnectConfig, ServerName}, connect::{BoxedSocket, ConnectRequest}, error::{HttpError, InvalidUrl, SendRequestError, WsClientError}, http::{ @@ -59,6 +59,7 @@ pub struct WebsocketsRequest { server_mode: bool, config: ClientConfig, sni_host: Option, + connect_config: Option, #[cfg(feature = "cookies")] cookies: Option, @@ -98,6 +99,7 @@ impl WebsocketsRequest { #[cfg(feature = "cookies")] cookies: None, sni_host: None, + connect_config: None, } } @@ -110,6 +112,15 @@ impl WebsocketsRequest { self } + /// Set specific connector configuration for this request. + /// + /// Not all config may be applied to the request, it depends on the connector and also + /// if there is already a connection established. + pub fn connector_config(mut self, config: ConnectConfig) -> Self { + self.connect_config = Some(config); + self + } + /// Set supported WebSocket protocols pub fn protocols(mut self, protos: U) -> Self where @@ -346,7 +357,12 @@ impl WebsocketsRequest { let max_size = self.max_size; let server_mode = self.server_mode; - let req = ConnectRequest::Tunnel(head, self.addr, self.sni_host); + let req = ConnectRequest::Tunnel( + head, + self.addr, + self.sni_host, + self.connect_config.map(Rc::new), + ); let fut = self.config.connector.call(req);