1
0
Fork 0
mirror of https://github.com/actix/actix-web.git synced 2024-09-27 22:01:56 +00:00

expose on_connect v2 (#1754)

Co-authored-by: Mikail Bagishov <bagishov.mikail@yandex.ru>
This commit is contained in:
Rob Ede 2020-10-30 02:03:26 +00:00 committed by GitHub
parent 4519db36b2
commit 9963a5ef54
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
16 changed files with 372 additions and 70 deletions

View file

@ -6,6 +6,7 @@
* Add request-local data extractor `web::ReqData`. [#1748] * Add request-local data extractor `web::ReqData`. [#1748]
* Add ability to register closure for request middleware logging. [#1749] * Add ability to register closure for request middleware logging. [#1749]
* Add `app_data` to `ServiceConfig`. [#1757] * Add `app_data` to `ServiceConfig`. [#1757]
* Expose `on_connect` for access to the connection stream before request is handled. [#1754]
### Changed ### Changed
* Print non-configured `Data<T>` type when attempting extraction. [#1743] * Print non-configured `Data<T>` type when attempting extraction. [#1743]
@ -16,6 +17,7 @@
[#1743]: https://github.com/actix/actix-web/pull/1743 [#1743]: https://github.com/actix/actix-web/pull/1743
[#1748]: https://github.com/actix/actix-web/pull/1748 [#1748]: https://github.com/actix/actix-web/pull/1748
[#1750]: https://github.com/actix/actix-web/pull/1750 [#1750]: https://github.com/actix/actix-web/pull/1750
[#1754]: https://github.com/actix/actix-web/pull/1754
[#1749]: https://github.com/actix/actix-web/pull/1749 [#1749]: https://github.com/actix/actix-web/pull/1749

View file

@ -64,6 +64,14 @@ required-features = ["compress"]
name = "test_server" name = "test_server"
required-features = ["compress"] required-features = ["compress"]
[[example]]
name = "on_connect"
required-features = []
[[example]]
name = "client"
required-features = ["rustls"]
[dependencies] [dependencies]
actix-codec = "0.3.0" actix-codec = "0.3.0"
actix-service = "1.0.6" actix-service = "1.0.6"
@ -105,7 +113,7 @@ tinyvec = { version = "1", features = ["alloc"] }
actix = "0.10.0" actix = "0.10.0"
actix-http = { version = "2.0.0", features = ["actors"] } actix-http = { version = "2.0.0", features = ["actors"] }
rand = "0.7" rand = "0.7"
env_logger = "0.7" env_logger = "0.8"
serde_derive = "1.0" serde_derive = "1.0"
brotli2 = "0.3.2" brotli2 = "0.3.2"
flate2 = "1.0.13" flate2 = "1.0.13"
@ -125,10 +133,6 @@ actix-files = { path = "actix-files" }
actix-multipart = { path = "actix-multipart" } actix-multipart = { path = "actix-multipart" }
awc = { path = "awc" } awc = { path = "awc" }
[[example]]
name = "client"
required-features = ["rustls"]
[[bench]] [[bench]]
name = "server" name = "server"
harness = false harness = false

View file

@ -1,9 +1,16 @@
# Changes # Changes
## Unreleased - 2020-xx-xx ## Unreleased - 2020-xx-xx
### Added
* Added more flexible `on_connect_ext` methods for on-connect handling. [#1754]
### Changed
* Upgrade `base64` to `0.13`. * Upgrade `base64` to `0.13`.
* Upgrade `pin-project` to `1.0`. * Upgrade `pin-project` to `1.0`.
[#1754]: https://github.com/actix/actix-web/pull/1754
## 2.0.0 - 2020-09-11 ## 2.0.0 - 2020-09-11
* No significant changes from `2.0.0-beta.4`. * No significant changes from `2.0.0-beta.4`.

View file

@ -14,10 +14,11 @@ use crate::helpers::{Data, DataFactory};
use crate::request::Request; use crate::request::Request;
use crate::response::Response; use crate::response::Response;
use crate::service::HttpService; use crate::service::HttpService;
use crate::{ConnectCallback, Extensions};
/// A http service builder /// A HTTP service builder
/// ///
/// This type can be used to construct an instance of `http service` through a /// This type can be used to construct an instance of [`HttpService`] through a
/// builder-like pattern. /// builder-like pattern.
pub struct HttpServiceBuilder<T, S, X = ExpectHandler, U = UpgradeHandler<T>> { pub struct HttpServiceBuilder<T, S, X = ExpectHandler, U = UpgradeHandler<T>> {
keep_alive: KeepAlive, keep_alive: KeepAlive,
@ -27,7 +28,9 @@ pub struct HttpServiceBuilder<T, S, X = ExpectHandler, U = UpgradeHandler<T>> {
local_addr: Option<net::SocketAddr>, local_addr: Option<net::SocketAddr>,
expect: X, expect: X,
upgrade: Option<U>, upgrade: Option<U>,
// DEPRECATED: in favor of on_connect_ext
on_connect: Option<Rc<dyn Fn(&T) -> Box<dyn DataFactory>>>, on_connect: Option<Rc<dyn Fn(&T) -> Box<dyn DataFactory>>>,
on_connect_ext: Option<Rc<ConnectCallback<T>>>,
_t: PhantomData<(T, S)>, _t: PhantomData<(T, S)>,
} }
@ -49,6 +52,7 @@ where
expect: ExpectHandler, expect: ExpectHandler,
upgrade: None, upgrade: None,
on_connect: None, on_connect: None,
on_connect_ext: None,
_t: PhantomData, _t: PhantomData,
} }
} }
@ -138,6 +142,7 @@ where
expect: expect.into_factory(), expect: expect.into_factory(),
upgrade: self.upgrade, upgrade: self.upgrade,
on_connect: self.on_connect, on_connect: self.on_connect,
on_connect_ext: self.on_connect_ext,
_t: PhantomData, _t: PhantomData,
} }
} }
@ -167,14 +172,16 @@ where
expect: self.expect, expect: self.expect,
upgrade: Some(upgrade.into_factory()), upgrade: Some(upgrade.into_factory()),
on_connect: self.on_connect, on_connect: self.on_connect,
on_connect_ext: self.on_connect_ext,
_t: PhantomData, _t: PhantomData,
} }
} }
/// Set on-connect callback. /// Set on-connect callback.
/// ///
/// It get called once per connection and result of the call /// Called once per connection. Return value of the call is stored in request extensions.
/// get stored to the request's extensions. ///
/// *SOFT DEPRECATED*: Prefer the `on_connect_ext` style callback.
pub fn on_connect<F, I>(mut self, f: F) -> Self pub fn on_connect<F, I>(mut self, f: F) -> Self
where where
F: Fn(&T) -> I + 'static, F: Fn(&T) -> I + 'static,
@ -184,7 +191,20 @@ where
self self
} }
/// Finish service configuration and create *http service* for HTTP/1 protocol. /// Sets the callback to be run on connection establishment.
///
/// Has mutable access to a data container that will be merged into request extensions.
/// This enables transport layer data (like client certificates) to be accessed in middleware
/// and handlers.
pub fn on_connect_ext<F>(mut self, f: F) -> Self
where
F: Fn(&T, &mut Extensions) + 'static,
{
self.on_connect_ext = Some(Rc::new(f));
self
}
/// Finish service configuration and create a HTTP Service for HTTP/1 protocol.
pub fn h1<F, B>(self, service: F) -> H1Service<T, S, B, X, U> pub fn h1<F, B>(self, service: F) -> H1Service<T, S, B, X, U>
where where
B: MessageBody, B: MessageBody,
@ -200,13 +220,15 @@ where
self.secure, self.secure,
self.local_addr, self.local_addr,
); );
H1Service::with_config(cfg, service.into_factory()) H1Service::with_config(cfg, service.into_factory())
.expect(self.expect) .expect(self.expect)
.upgrade(self.upgrade) .upgrade(self.upgrade)
.on_connect(self.on_connect) .on_connect(self.on_connect)
.on_connect_ext(self.on_connect_ext)
} }
/// Finish service configuration and create *http service* for HTTP/2 protocol. /// Finish service configuration and create a HTTP service for HTTP/2 protocol.
pub fn h2<F, B>(self, service: F) -> H2Service<T, S, B> pub fn h2<F, B>(self, service: F) -> H2Service<T, S, B>
where where
B: MessageBody + 'static, B: MessageBody + 'static,
@ -223,7 +245,10 @@ where
self.secure, self.secure,
self.local_addr, self.local_addr,
); );
H2Service::with_config(cfg, service.into_factory()).on_connect(self.on_connect)
H2Service::with_config(cfg, service.into_factory())
.on_connect(self.on_connect)
.on_connect_ext(self.on_connect_ext)
} }
/// Finish service configuration and create `HttpService` instance. /// Finish service configuration and create `HttpService` instance.
@ -243,9 +268,11 @@ where
self.secure, self.secure,
self.local_addr, self.local_addr,
); );
HttpService::with_config(cfg, service.into_factory()) HttpService::with_config(cfg, service.into_factory())
.expect(self.expect) .expect(self.expect)
.upgrade(self.upgrade) .upgrade(self.upgrade)
.on_connect(self.on_connect) .on_connect(self.on_connect)
.on_connect_ext(self.on_connect_ext)
} }
} }

View file

@ -1,5 +1,5 @@
use std::any::{Any, TypeId}; use std::any::{Any, TypeId};
use std::fmt; use std::{fmt, mem};
use fxhash::FxHashMap; use fxhash::FxHashMap;
@ -66,6 +66,11 @@ impl Extensions {
pub fn extend(&mut self, other: Extensions) { pub fn extend(&mut self, other: Extensions) {
self.map.extend(other.map); self.map.extend(other.map);
} }
/// Sets (or overrides) items from `other` into this map.
pub(crate) fn drain_from(&mut self, other: &mut Self) {
self.map.extend(mem::take(&mut other.map));
}
} }
impl fmt::Debug for Extensions { impl fmt::Debug for Extensions {
@ -213,4 +218,27 @@ mod tests {
assert_eq!(extensions.get(), Some(&20u8)); assert_eq!(extensions.get(), Some(&20u8));
assert_eq!(extensions.get_mut(), Some(&mut 20u8)); assert_eq!(extensions.get_mut(), Some(&mut 20u8));
} }
#[test]
fn test_drain_from() {
let mut ext = Extensions::new();
ext.insert(2isize);
let mut more_ext = Extensions::new();
more_ext.insert(5isize);
more_ext.insert(5usize);
assert_eq!(ext.get::<isize>(), Some(&2isize));
assert_eq!(ext.get::<usize>(), None);
assert_eq!(more_ext.get::<isize>(), Some(&5isize));
assert_eq!(more_ext.get::<usize>(), Some(&5usize));
ext.drain_from(&mut more_ext);
assert_eq!(ext.get::<isize>(), Some(&5isize));
assert_eq!(ext.get::<usize>(), Some(&5usize));
assert_eq!(more_ext.get::<isize>(), None);
assert_eq!(more_ext.get::<usize>(), None);
}
} }

View file

@ -12,7 +12,6 @@ use bytes::{Buf, BytesMut};
use log::{error, trace}; use log::{error, trace};
use pin_project::pin_project; use pin_project::pin_project;
use crate::body::{Body, BodySize, MessageBody, ResponseBody};
use crate::cloneable::CloneableService; use crate::cloneable::CloneableService;
use crate::config::ServiceConfig; use crate::config::ServiceConfig;
use crate::error::{DispatchError, Error}; use crate::error::{DispatchError, Error};
@ -21,6 +20,10 @@ use crate::helpers::DataFactory;
use crate::httpmessage::HttpMessage; use crate::httpmessage::HttpMessage;
use crate::request::Request; use crate::request::Request;
use crate::response::Response; use crate::response::Response;
use crate::{
body::{Body, BodySize, MessageBody, ResponseBody},
Extensions,
};
use super::codec::Codec; use super::codec::Codec;
use super::payload::{Payload, PayloadSender, PayloadStatus}; use super::payload::{Payload, PayloadSender, PayloadStatus};
@ -88,6 +91,7 @@ where
expect: CloneableService<X>, expect: CloneableService<X>,
upgrade: Option<CloneableService<U>>, upgrade: Option<CloneableService<U>>,
on_connect: Option<Box<dyn DataFactory>>, on_connect: Option<Box<dyn DataFactory>>,
on_connect_data: Extensions,
flags: Flags, flags: Flags,
peer_addr: Option<net::SocketAddr>, peer_addr: Option<net::SocketAddr>,
error: Option<DispatchError>, error: Option<DispatchError>,
@ -167,7 +171,7 @@ where
U: Service<Request = (Request, Framed<T, Codec>), Response = ()>, U: Service<Request = (Request, Framed<T, Codec>), Response = ()>,
U::Error: fmt::Display, U::Error: fmt::Display,
{ {
/// Create http/1 dispatcher. /// Create HTTP/1 dispatcher.
pub(crate) fn new( pub(crate) fn new(
stream: T, stream: T,
config: ServiceConfig, config: ServiceConfig,
@ -175,6 +179,7 @@ where
expect: CloneableService<X>, expect: CloneableService<X>,
upgrade: Option<CloneableService<U>>, upgrade: Option<CloneableService<U>>,
on_connect: Option<Box<dyn DataFactory>>, on_connect: Option<Box<dyn DataFactory>>,
on_connect_data: Extensions,
peer_addr: Option<net::SocketAddr>, peer_addr: Option<net::SocketAddr>,
) -> Self { ) -> Self {
Dispatcher::with_timeout( Dispatcher::with_timeout(
@ -187,6 +192,7 @@ where
expect, expect,
upgrade, upgrade,
on_connect, on_connect,
on_connect_data,
peer_addr, peer_addr,
) )
} }
@ -202,6 +208,7 @@ where
expect: CloneableService<X>, expect: CloneableService<X>,
upgrade: Option<CloneableService<U>>, upgrade: Option<CloneableService<U>>,
on_connect: Option<Box<dyn DataFactory>>, on_connect: Option<Box<dyn DataFactory>>,
on_connect_data: Extensions,
peer_addr: Option<net::SocketAddr>, peer_addr: Option<net::SocketAddr>,
) -> Self { ) -> Self {
let keepalive = config.keep_alive_enabled(); let keepalive = config.keep_alive_enabled();
@ -234,6 +241,7 @@ where
expect, expect,
upgrade, upgrade,
on_connect, on_connect,
on_connect_data,
flags, flags,
peer_addr, peer_addr,
ka_expire, ka_expire,
@ -526,11 +534,15 @@ where
let pl = this.codec.message_type(); let pl = this.codec.message_type();
req.head_mut().peer_addr = *this.peer_addr; req.head_mut().peer_addr = *this.peer_addr;
// DEPRECATED
// set on_connect data // set on_connect data
if let Some(ref on_connect) = this.on_connect { if let Some(ref on_connect) = this.on_connect {
on_connect.set(&mut req.extensions_mut()); on_connect.set(&mut req.extensions_mut());
} }
// merge on_connect_ext data into request extensions
req.extensions_mut().drain_from(this.on_connect_data);
if pl == MessageType::Stream && this.upgrade.is_some() { if pl == MessageType::Stream && this.upgrade.is_some() {
this.messages.push_back(DispatcherMessage::Upgrade(req)); this.messages.push_back(DispatcherMessage::Upgrade(req));
break; break;
@ -927,8 +939,10 @@ mod tests {
CloneableService::new(ExpectHandler), CloneableService::new(ExpectHandler),
None, None,
None, None,
Extensions::new(),
None, None,
); );
match Pin::new(&mut h1).poll(cx) { match Pin::new(&mut h1).poll(cx) {
Poll::Pending => panic!(), Poll::Pending => panic!(),
Poll::Ready(res) => assert!(res.is_err()), Poll::Ready(res) => assert!(res.is_err()),

View file

@ -18,6 +18,7 @@ use crate::error::{DispatchError, Error, ParseError};
use crate::helpers::DataFactory; use crate::helpers::DataFactory;
use crate::request::Request; use crate::request::Request;
use crate::response::Response; use crate::response::Response;
use crate::{ConnectCallback, Extensions};
use super::codec::Codec; use super::codec::Codec;
use super::dispatcher::Dispatcher; use super::dispatcher::Dispatcher;
@ -30,6 +31,7 @@ pub struct H1Service<T, S, B, X = ExpectHandler, U = UpgradeHandler<T>> {
expect: X, expect: X,
upgrade: Option<U>, upgrade: Option<U>,
on_connect: Option<Rc<dyn Fn(&T) -> Box<dyn DataFactory>>>, on_connect: Option<Rc<dyn Fn(&T) -> Box<dyn DataFactory>>>,
on_connect_ext: Option<Rc<ConnectCallback<T>>>,
_t: PhantomData<(T, B)>, _t: PhantomData<(T, B)>,
} }
@ -52,6 +54,7 @@ where
expect: ExpectHandler, expect: ExpectHandler,
upgrade: None, upgrade: None,
on_connect: None, on_connect: None,
on_connect_ext: None,
_t: PhantomData, _t: PhantomData,
} }
} }
@ -213,6 +216,7 @@ where
srv: self.srv, srv: self.srv,
upgrade: self.upgrade, upgrade: self.upgrade,
on_connect: self.on_connect, on_connect: self.on_connect,
on_connect_ext: self.on_connect_ext,
_t: PhantomData, _t: PhantomData,
} }
} }
@ -229,6 +233,7 @@ where
srv: self.srv, srv: self.srv,
expect: self.expect, expect: self.expect,
on_connect: self.on_connect, on_connect: self.on_connect,
on_connect_ext: self.on_connect_ext,
_t: PhantomData, _t: PhantomData,
} }
} }
@ -241,6 +246,12 @@ where
self.on_connect = f; self.on_connect = f;
self self
} }
/// Set on connect callback.
pub(crate) fn on_connect_ext(mut self, f: Option<Rc<ConnectCallback<T>>>) -> Self {
self.on_connect_ext = f;
self
}
} }
impl<T, S, B, X, U> ServiceFactory for H1Service<T, S, B, X, U> impl<T, S, B, X, U> ServiceFactory for H1Service<T, S, B, X, U>
@ -274,6 +285,7 @@ where
expect: None, expect: None,
upgrade: None, upgrade: None,
on_connect: self.on_connect.clone(), on_connect: self.on_connect.clone(),
on_connect_ext: self.on_connect_ext.clone(),
cfg: Some(self.cfg.clone()), cfg: Some(self.cfg.clone()),
_t: PhantomData, _t: PhantomData,
} }
@ -303,6 +315,7 @@ where
expect: Option<X::Service>, expect: Option<X::Service>,
upgrade: Option<U::Service>, upgrade: Option<U::Service>,
on_connect: Option<Rc<dyn Fn(&T) -> Box<dyn DataFactory>>>, on_connect: Option<Rc<dyn Fn(&T) -> Box<dyn DataFactory>>>,
on_connect_ext: Option<Rc<ConnectCallback<T>>>,
cfg: Option<ServiceConfig>, cfg: Option<ServiceConfig>,
_t: PhantomData<(T, B)>, _t: PhantomData<(T, B)>,
} }
@ -352,23 +365,26 @@ where
Poll::Ready(result.map(|service| { Poll::Ready(result.map(|service| {
let this = self.as_mut().project(); let this = self.as_mut().project();
H1ServiceHandler::new( H1ServiceHandler::new(
this.cfg.take().unwrap(), this.cfg.take().unwrap(),
service, service,
this.expect.take().unwrap(), this.expect.take().unwrap(),
this.upgrade.take(), this.upgrade.take(),
this.on_connect.clone(), this.on_connect.clone(),
this.on_connect_ext.clone(),
) )
})) }))
} }
} }
/// `Service` implementation for HTTP1 transport /// `Service` implementation for HTTP/1 transport
pub struct H1ServiceHandler<T, S: Service, B, X: Service, U: Service> { pub struct H1ServiceHandler<T, S: Service, B, X: Service, U: Service> {
srv: CloneableService<S>, srv: CloneableService<S>,
expect: CloneableService<X>, expect: CloneableService<X>,
upgrade: Option<CloneableService<U>>, upgrade: Option<CloneableService<U>>,
on_connect: Option<Rc<dyn Fn(&T) -> Box<dyn DataFactory>>>, on_connect: Option<Rc<dyn Fn(&T) -> Box<dyn DataFactory>>>,
on_connect_ext: Option<Rc<ConnectCallback<T>>>,
cfg: ServiceConfig, cfg: ServiceConfig,
_t: PhantomData<(T, B)>, _t: PhantomData<(T, B)>,
} }
@ -390,6 +406,7 @@ where
expect: X, expect: X,
upgrade: Option<U>, upgrade: Option<U>,
on_connect: Option<Rc<dyn Fn(&T) -> Box<dyn DataFactory>>>, on_connect: Option<Rc<dyn Fn(&T) -> Box<dyn DataFactory>>>,
on_connect_ext: Option<Rc<ConnectCallback<T>>>,
) -> H1ServiceHandler<T, S, B, X, U> { ) -> H1ServiceHandler<T, S, B, X, U> {
H1ServiceHandler { H1ServiceHandler {
srv: CloneableService::new(srv), srv: CloneableService::new(srv),
@ -397,6 +414,7 @@ where
upgrade: upgrade.map(CloneableService::new), upgrade: upgrade.map(CloneableService::new),
cfg, cfg,
on_connect, on_connect,
on_connect_ext,
_t: PhantomData, _t: PhantomData,
} }
} }
@ -462,11 +480,13 @@ where
} }
fn call(&mut self, (io, addr): Self::Request) -> Self::Future { fn call(&mut self, (io, addr): Self::Request) -> Self::Future {
let on_connect = if let Some(ref on_connect) = self.on_connect { let deprecated_on_connect = self.on_connect.as_ref().map(|handler| handler(&io));
Some(on_connect(&io))
} else { let mut connect_extensions = Extensions::new();
None if let Some(ref handler) = self.on_connect_ext {
}; // run on_connect_ext callback, populating connect extensions
handler(&io, &mut connect_extensions);
}
Dispatcher::new( Dispatcher::new(
io, io,
@ -474,7 +494,8 @@ where
self.srv.clone(), self.srv.clone(),
self.expect.clone(), self.expect.clone(),
self.upgrade.clone(), self.upgrade.clone(),
on_connect, deprecated_on_connect,
connect_extensions,
addr, addr,
) )
} }

View file

@ -24,6 +24,7 @@ use crate::message::ResponseHead;
use crate::payload::Payload; use crate::payload::Payload;
use crate::request::Request; use crate::request::Request;
use crate::response::Response; use crate::response::Response;
use crate::Extensions;
const CHUNK_SIZE: usize = 16_384; const CHUNK_SIZE: usize = 16_384;
@ -36,6 +37,7 @@ where
service: CloneableService<S>, service: CloneableService<S>,
connection: Connection<T, Bytes>, connection: Connection<T, Bytes>,
on_connect: Option<Box<dyn DataFactory>>, on_connect: Option<Box<dyn DataFactory>>,
on_connect_data: Extensions,
config: ServiceConfig, config: ServiceConfig,
peer_addr: Option<net::SocketAddr>, peer_addr: Option<net::SocketAddr>,
ka_expire: Instant, ka_expire: Instant,
@ -56,6 +58,7 @@ where
service: CloneableService<S>, service: CloneableService<S>,
connection: Connection<T, Bytes>, connection: Connection<T, Bytes>,
on_connect: Option<Box<dyn DataFactory>>, on_connect: Option<Box<dyn DataFactory>>,
on_connect_data: Extensions,
config: ServiceConfig, config: ServiceConfig,
timeout: Option<Delay>, timeout: Option<Delay>,
peer_addr: Option<net::SocketAddr>, peer_addr: Option<net::SocketAddr>,
@ -82,6 +85,7 @@ where
peer_addr, peer_addr,
connection, connection,
on_connect, on_connect,
on_connect_data,
ka_expire, ka_expire,
ka_timer, ka_timer,
_t: PhantomData, _t: PhantomData,
@ -130,11 +134,15 @@ where
head.headers = parts.headers.into(); head.headers = parts.headers.into();
head.peer_addr = this.peer_addr; head.peer_addr = this.peer_addr;
// DEPRECATED
// set on_connect data // set on_connect data
if let Some(ref on_connect) = this.on_connect { if let Some(ref on_connect) = this.on_connect {
on_connect.set(&mut req.extensions_mut()); on_connect.set(&mut req.extensions_mut());
} }
// merge on_connect_ext data into request extensions
req.extensions_mut().drain_from(&mut this.on_connect_data);
actix_rt::spawn(ServiceResponse::< actix_rt::spawn(ServiceResponse::<
S::Future, S::Future,
S::Response, S::Response,

View file

@ -2,7 +2,7 @@ use std::future::Future;
use std::marker::PhantomData; use std::marker::PhantomData;
use std::pin::Pin; use std::pin::Pin;
use std::task::{Context, Poll}; use std::task::{Context, Poll};
use std::{net, rc}; use std::{net, rc::Rc};
use actix_codec::{AsyncRead, AsyncWrite}; use actix_codec::{AsyncRead, AsyncWrite};
use actix_rt::net::TcpStream; use actix_rt::net::TcpStream;
@ -23,6 +23,7 @@ use crate::error::{DispatchError, Error};
use crate::helpers::DataFactory; use crate::helpers::DataFactory;
use crate::request::Request; use crate::request::Request;
use crate::response::Response; use crate::response::Response;
use crate::{ConnectCallback, Extensions};
use super::dispatcher::Dispatcher; use super::dispatcher::Dispatcher;
@ -30,7 +31,8 @@ use super::dispatcher::Dispatcher;
pub struct H2Service<T, S, B> { pub struct H2Service<T, S, B> {
srv: S, srv: S,
cfg: ServiceConfig, cfg: ServiceConfig,
on_connect: Option<rc::Rc<dyn Fn(&T) -> Box<dyn DataFactory>>>, on_connect: Option<Rc<dyn Fn(&T) -> Box<dyn DataFactory>>>,
on_connect_ext: Option<Rc<ConnectCallback<T>>>,
_t: PhantomData<(T, B)>, _t: PhantomData<(T, B)>,
} }
@ -50,19 +52,27 @@ where
H2Service { H2Service {
cfg, cfg,
on_connect: None, on_connect: None,
on_connect_ext: None,
srv: service.into_factory(), srv: service.into_factory(),
_t: PhantomData, _t: PhantomData,
} }
} }
/// Set on connect callback. /// Set on connect callback.
pub(crate) fn on_connect( pub(crate) fn on_connect(
mut self, mut self,
f: Option<rc::Rc<dyn Fn(&T) -> Box<dyn DataFactory>>>, f: Option<Rc<dyn Fn(&T) -> Box<dyn DataFactory>>>,
) -> Self { ) -> Self {
self.on_connect = f; self.on_connect = f;
self self
} }
/// Set on connect callback.
pub(crate) fn on_connect_ext(mut self, f: Option<Rc<ConnectCallback<T>>>) -> Self {
self.on_connect_ext = f;
self
}
} }
impl<S, B> H2Service<TcpStream, S, B> impl<S, B> H2Service<TcpStream, S, B>
@ -203,6 +213,7 @@ where
fut: self.srv.new_service(()), fut: self.srv.new_service(()),
cfg: Some(self.cfg.clone()), cfg: Some(self.cfg.clone()),
on_connect: self.on_connect.clone(), on_connect: self.on_connect.clone(),
on_connect_ext: self.on_connect_ext.clone(),
_t: PhantomData, _t: PhantomData,
} }
} }
@ -214,7 +225,8 @@ pub struct H2ServiceResponse<T, S: ServiceFactory, B> {
#[pin] #[pin]
fut: S::Future, fut: S::Future,
cfg: Option<ServiceConfig>, cfg: Option<ServiceConfig>,
on_connect: Option<rc::Rc<dyn Fn(&T) -> Box<dyn DataFactory>>>, on_connect: Option<Rc<dyn Fn(&T) -> Box<dyn DataFactory>>>,
on_connect_ext: Option<Rc<ConnectCallback<T>>>,
_t: PhantomData<(T, B)>, _t: PhantomData<(T, B)>,
} }
@ -237,6 +249,7 @@ where
H2ServiceHandler::new( H2ServiceHandler::new(
this.cfg.take().unwrap(), this.cfg.take().unwrap(),
this.on_connect.clone(), this.on_connect.clone(),
this.on_connect_ext.clone(),
service, service,
) )
})) }))
@ -247,7 +260,8 @@ where
pub struct H2ServiceHandler<T, S: Service, B> { pub struct H2ServiceHandler<T, S: Service, B> {
srv: CloneableService<S>, srv: CloneableService<S>,
cfg: ServiceConfig, cfg: ServiceConfig,
on_connect: Option<rc::Rc<dyn Fn(&T) -> Box<dyn DataFactory>>>, on_connect: Option<Rc<dyn Fn(&T) -> Box<dyn DataFactory>>>,
on_connect_ext: Option<Rc<ConnectCallback<T>>>,
_t: PhantomData<(T, B)>, _t: PhantomData<(T, B)>,
} }
@ -261,12 +275,14 @@ where
{ {
fn new( fn new(
cfg: ServiceConfig, cfg: ServiceConfig,
on_connect: Option<rc::Rc<dyn Fn(&T) -> Box<dyn DataFactory>>>, on_connect: Option<Rc<dyn Fn(&T) -> Box<dyn DataFactory>>>,
on_connect_ext: Option<Rc<ConnectCallback<T>>>,
srv: S, srv: S,
) -> H2ServiceHandler<T, S, B> { ) -> H2ServiceHandler<T, S, B> {
H2ServiceHandler { H2ServiceHandler {
cfg, cfg,
on_connect, on_connect,
on_connect_ext,
srv: CloneableService::new(srv), srv: CloneableService::new(srv),
_t: PhantomData, _t: PhantomData,
} }
@ -296,18 +312,21 @@ where
} }
fn call(&mut self, (io, addr): Self::Request) -> Self::Future { fn call(&mut self, (io, addr): Self::Request) -> Self::Future {
let on_connect = if let Some(ref on_connect) = self.on_connect { let deprecated_on_connect = self.on_connect.as_ref().map(|handler| handler(&io));
Some(on_connect(&io))
} else { let mut connect_extensions = Extensions::new();
None if let Some(ref handler) = self.on_connect_ext {
}; // run on_connect_ext callback, populating connect extensions
handler(&io, &mut connect_extensions);
}
H2ServiceHandlerResponse { H2ServiceHandlerResponse {
state: State::Handshake( state: State::Handshake(
Some(self.srv.clone()), Some(self.srv.clone()),
Some(self.cfg.clone()), Some(self.cfg.clone()),
addr, addr,
on_connect, deprecated_on_connect,
Some(connect_extensions),
server::handshake(io), server::handshake(io),
), ),
} }
@ -325,6 +344,7 @@ where
Option<ServiceConfig>, Option<ServiceConfig>,
Option<net::SocketAddr>, Option<net::SocketAddr>,
Option<Box<dyn DataFactory>>, Option<Box<dyn DataFactory>>,
Option<Extensions>,
Handshake<T, Bytes>, Handshake<T, Bytes>,
), ),
} }
@ -360,6 +380,7 @@ where
ref mut config, ref mut config,
ref peer_addr, ref peer_addr,
ref mut on_connect, ref mut on_connect,
ref mut on_connect_data,
ref mut handshake, ref mut handshake,
) => match Pin::new(handshake).poll(cx) { ) => match Pin::new(handshake).poll(cx) {
Poll::Ready(Ok(conn)) => { Poll::Ready(Ok(conn)) => {
@ -367,6 +388,7 @@ where
srv.take().unwrap(), srv.take().unwrap(),
conn, conn,
on_connect.take(), on_connect.take(),
on_connect_data.take().unwrap(),
config.take().unwrap(), config.take().unwrap(),
None, None,
*peer_addr, *peer_addr,

View file

@ -50,6 +50,7 @@ impl<'a> io::Write for Writer<'a> {
self.0.extend_from_slice(buf); self.0.extend_from_slice(buf);
Ok(buf.len()) Ok(buf.len())
} }
fn flush(&mut self) -> io::Result<()> { fn flush(&mut self) -> io::Result<()> {
Ok(()) Ok(())
} }

View file

@ -1,4 +1,4 @@
//! Basic http primitives for actix-net framework. //! Basic HTTP primitives for the Actix ecosystem.
#![deny(rust_2018_idioms)] #![deny(rust_2018_idioms)]
#![allow( #![allow(
@ -78,3 +78,5 @@ pub enum Protocol {
Http1, Http1,
Http2, Http2,
} }
type ConnectCallback<IO> = dyn Fn(&IO, &mut Extensions);

View file

@ -1,7 +1,7 @@
use std::marker::PhantomData; use std::marker::PhantomData;
use std::pin::Pin; use std::pin::Pin;
use std::task::{Context, Poll}; use std::task::{Context, Poll};
use std::{fmt, net, rc}; use std::{fmt, net, rc::Rc};
use actix_codec::{AsyncRead, AsyncWrite, Framed}; use actix_codec::{AsyncRead, AsyncWrite, Framed};
use actix_rt::net::TcpStream; use actix_rt::net::TcpStream;
@ -20,15 +20,17 @@ use crate::error::{DispatchError, Error};
use crate::helpers::DataFactory; use crate::helpers::DataFactory;
use crate::request::Request; use crate::request::Request;
use crate::response::Response; use crate::response::Response;
use crate::{h1, h2::Dispatcher, Protocol}; use crate::{h1, h2::Dispatcher, ConnectCallback, Extensions, Protocol};
/// `ServiceFactory` HTTP1.1/HTTP2 transport implementation /// A `ServiceFactory` for HTTP/1.1 or HTTP/2 protocol.
pub struct HttpService<T, S, B, X = h1::ExpectHandler, U = h1::UpgradeHandler<T>> { pub struct HttpService<T, S, B, X = h1::ExpectHandler, U = h1::UpgradeHandler<T>> {
srv: S, srv: S,
cfg: ServiceConfig, cfg: ServiceConfig,
expect: X, expect: X,
upgrade: Option<U>, upgrade: Option<U>,
on_connect: Option<rc::Rc<dyn Fn(&T) -> Box<dyn DataFactory>>>, // DEPRECATED: in favor of on_connect_ext
on_connect: Option<Rc<dyn Fn(&T) -> Box<dyn DataFactory>>>,
on_connect_ext: Option<Rc<ConnectCallback<T>>>,
_t: PhantomData<(T, B)>, _t: PhantomData<(T, B)>,
} }
@ -66,6 +68,7 @@ where
expect: h1::ExpectHandler, expect: h1::ExpectHandler,
upgrade: None, upgrade: None,
on_connect: None, on_connect: None,
on_connect_ext: None,
_t: PhantomData, _t: PhantomData,
} }
} }
@ -81,6 +84,7 @@ where
expect: h1::ExpectHandler, expect: h1::ExpectHandler,
upgrade: None, upgrade: None,
on_connect: None, on_connect: None,
on_connect_ext: None,
_t: PhantomData, _t: PhantomData,
} }
} }
@ -113,6 +117,7 @@ where
srv: self.srv, srv: self.srv,
upgrade: self.upgrade, upgrade: self.upgrade,
on_connect: self.on_connect, on_connect: self.on_connect,
on_connect_ext: self.on_connect_ext,
_t: PhantomData, _t: PhantomData,
} }
} }
@ -138,6 +143,7 @@ where
srv: self.srv, srv: self.srv,
expect: self.expect, expect: self.expect,
on_connect: self.on_connect, on_connect: self.on_connect,
on_connect_ext: self.on_connect_ext,
_t: PhantomData, _t: PhantomData,
} }
} }
@ -145,11 +151,17 @@ where
/// Set on connect callback. /// Set on connect callback.
pub(crate) fn on_connect( pub(crate) fn on_connect(
mut self, mut self,
f: Option<rc::Rc<dyn Fn(&T) -> Box<dyn DataFactory>>>, f: Option<Rc<dyn Fn(&T) -> Box<dyn DataFactory>>>,
) -> Self { ) -> Self {
self.on_connect = f; self.on_connect = f;
self self
} }
/// Set connect callback with mutable access to request data container.
pub(crate) fn on_connect_ext(mut self, f: Option<Rc<ConnectCallback<T>>>) -> Self {
self.on_connect_ext = f;
self
}
} }
impl<S, B, X, U> HttpService<TcpStream, S, B, X, U> impl<S, B, X, U> HttpService<TcpStream, S, B, X, U>
@ -355,6 +367,7 @@ where
expect: None, expect: None,
upgrade: None, upgrade: None,
on_connect: self.on_connect.clone(), on_connect: self.on_connect.clone(),
on_connect_ext: self.on_connect_ext.clone(),
cfg: self.cfg.clone(), cfg: self.cfg.clone(),
_t: PhantomData, _t: PhantomData,
} }
@ -378,7 +391,8 @@ pub struct HttpServiceResponse<
fut_upg: Option<U::Future>, fut_upg: Option<U::Future>,
expect: Option<X::Service>, expect: Option<X::Service>,
upgrade: Option<U::Service>, upgrade: Option<U::Service>,
on_connect: Option<rc::Rc<dyn Fn(&T) -> Box<dyn DataFactory>>>, on_connect: Option<Rc<dyn Fn(&T) -> Box<dyn DataFactory>>>,
on_connect_ext: Option<Rc<ConnectCallback<T>>>,
cfg: ServiceConfig, cfg: ServiceConfig,
_t: PhantomData<(T, B)>, _t: PhantomData<(T, B)>,
} }
@ -429,6 +443,7 @@ where
.fut .fut
.poll(cx) .poll(cx)
.map_err(|e| log::error!("Init http service error: {:?}", e))); .map_err(|e| log::error!("Init http service error: {:?}", e)));
Poll::Ready(result.map(|service| { Poll::Ready(result.map(|service| {
let this = self.as_mut().project(); let this = self.as_mut().project();
HttpServiceHandler::new( HttpServiceHandler::new(
@ -437,6 +452,7 @@ where
this.expect.take().unwrap(), this.expect.take().unwrap(),
this.upgrade.take(), this.upgrade.take(),
this.on_connect.clone(), this.on_connect.clone(),
this.on_connect_ext.clone(),
) )
})) }))
} }
@ -448,7 +464,8 @@ pub struct HttpServiceHandler<T, S: Service, B, X: Service, U: Service> {
expect: CloneableService<X>, expect: CloneableService<X>,
upgrade: Option<CloneableService<U>>, upgrade: Option<CloneableService<U>>,
cfg: ServiceConfig, cfg: ServiceConfig,
on_connect: Option<rc::Rc<dyn Fn(&T) -> Box<dyn DataFactory>>>, on_connect: Option<Rc<dyn Fn(&T) -> Box<dyn DataFactory>>>,
on_connect_ext: Option<Rc<ConnectCallback<T>>>,
_t: PhantomData<(T, B, X)>, _t: PhantomData<(T, B, X)>,
} }
@ -469,11 +486,13 @@ where
srv: S, srv: S,
expect: X, expect: X,
upgrade: Option<U>, upgrade: Option<U>,
on_connect: Option<rc::Rc<dyn Fn(&T) -> Box<dyn DataFactory>>>, on_connect: Option<Rc<dyn Fn(&T) -> Box<dyn DataFactory>>>,
on_connect_ext: Option<Rc<ConnectCallback<T>>>,
) -> HttpServiceHandler<T, S, B, X, U> { ) -> HttpServiceHandler<T, S, B, X, U> {
HttpServiceHandler { HttpServiceHandler {
cfg, cfg,
on_connect, on_connect,
on_connect_ext,
srv: CloneableService::new(srv), srv: CloneableService::new(srv),
expect: CloneableService::new(expect), expect: CloneableService::new(expect),
upgrade: upgrade.map(CloneableService::new), upgrade: upgrade.map(CloneableService::new),
@ -543,11 +562,12 @@ where
} }
fn call(&mut self, (io, proto, peer_addr): Self::Request) -> Self::Future { fn call(&mut self, (io, proto, peer_addr): Self::Request) -> Self::Future {
let on_connect = if let Some(ref on_connect) = self.on_connect { let mut connect_extensions = Extensions::new();
Some(on_connect(&io))
} else { let deprecated_on_connect = self.on_connect.as_ref().map(|handler| handler(&io));
None if let Some(ref handler) = self.on_connect_ext {
}; handler(&io, &mut connect_extensions);
}
match proto { match proto {
Protocol::Http2 => HttpServiceHandlerResponse { Protocol::Http2 => HttpServiceHandlerResponse {
@ -555,10 +575,12 @@ where
server::handshake(io), server::handshake(io),
self.cfg.clone(), self.cfg.clone(),
self.srv.clone(), self.srv.clone(),
on_connect, deprecated_on_connect,
connect_extensions,
peer_addr, peer_addr,
))), ))),
}, },
Protocol::Http1 => HttpServiceHandlerResponse { Protocol::Http1 => HttpServiceHandlerResponse {
state: State::H1(h1::Dispatcher::new( state: State::H1(h1::Dispatcher::new(
io, io,
@ -566,7 +588,8 @@ where
self.srv.clone(), self.srv.clone(),
self.expect.clone(), self.expect.clone(),
self.upgrade.clone(), self.upgrade.clone(),
on_connect, deprecated_on_connect,
connect_extensions,
peer_addr, peer_addr,
)), )),
}, },
@ -595,6 +618,7 @@ where
ServiceConfig, ServiceConfig,
CloneableService<S>, CloneableService<S>,
Option<Box<dyn DataFactory>>, Option<Box<dyn DataFactory>>,
Extensions,
Option<net::SocketAddr>, Option<net::SocketAddr>,
)>, )>,
), ),
@ -670,9 +694,16 @@ where
} else { } else {
panic!() panic!()
}; };
let (_, cfg, srv, on_connect, peer_addr) = data.take().unwrap(); let (_, cfg, srv, on_connect, on_connect_data, peer_addr) =
data.take().unwrap();
self.set(State::H2(Dispatcher::new( self.set(State::H2(Dispatcher::new(
srv, conn, on_connect, cfg, None, peer_addr, srv,
conn,
on_connect,
on_connect_data,
cfg,
None,
peer_addr,
))); )));
self.poll(cx) self.poll(cx)
} }

View file

@ -411,8 +411,10 @@ async fn test_h2_on_connect() {
let srv = test_server(move || { let srv = test_server(move || {
HttpService::build() HttpService::build()
.on_connect(|_| 10usize) .on_connect(|_| 10usize)
.on_connect_ext(|_, data| data.insert(20isize))
.h2(|req: Request| { .h2(|req: Request| {
assert!(req.extensions().contains::<usize>()); assert!(req.extensions().contains::<usize>());
assert!(req.extensions().contains::<isize>());
ok::<_, ()>(Response::Ok().finish()) ok::<_, ()>(Response::Ok().finish())
}) })
.openssl(ssl_acceptor()) .openssl(ssl_acceptor())

View file

@ -663,8 +663,10 @@ async fn test_h1_on_connect() {
let srv = test_server(|| { let srv = test_server(|| {
HttpService::build() HttpService::build()
.on_connect(|_| 10usize) .on_connect(|_| 10usize)
.on_connect_ext(|_, data| data.insert(20isize))
.h1(|req: Request| { .h1(|req: Request| {
assert!(req.extensions().contains::<usize>()); assert!(req.extensions().contains::<usize>());
assert!(req.extensions().contains::<isize>());
future::ok::<_, ()>(Response::Ok().finish()) future::ok::<_, ()>(Response::Ok().finish())
}) })
.tcp() .tcp()

51
examples/on_connect.rs Normal file
View file

@ -0,0 +1,51 @@
//! This example shows how to use `actix_web::HttpServer::on_connect` to access a lower-level socket
//! properties and pass them to a handler through request-local data.
//!
//! For an example of extracting a client TLS certificate, see:
//! <https://github.com/actix/examples/tree/HEAD/rustls-client-cert>
use std::{any::Any, env, io, net::SocketAddr};
use actix_web::{dev::Extensions, rt::net::TcpStream, web, App, HttpServer};
#[derive(Debug, Clone)]
struct ConnectionInfo {
bind: SocketAddr,
peer: SocketAddr,
ttl: Option<u32>,
}
async fn route_whoami(conn_info: web::ReqData<ConnectionInfo>) -> String {
format!(
"Here is some info about your connection:\n\n{:#?}",
conn_info
)
}
fn get_conn_info(connection: &dyn Any, data: &mut Extensions) {
if let Some(sock) = connection.downcast_ref::<TcpStream>() {
data.insert(ConnectionInfo {
bind: sock.local_addr().unwrap(),
peer: sock.peer_addr().unwrap(),
ttl: sock.ttl().ok(),
});
} else {
unreachable!("connection should only be plaintext since no TLS is set up");
}
}
#[actix_web::main]
async fn main() -> io::Result<()> {
if env::var("RUST_LOG").is_err() {
env::set_var("RUST_LOG", "info");
}
env_logger::init();
HttpServer::new(|| App::new().default_service(web::to(route_whoami)))
.on_connect(get_conn_info)
.bind(("127.0.0.1", 8080))?
.workers(1)
.run()
.await
}

View file

@ -1,8 +1,14 @@
use std::marker::PhantomData; use std::{
use std::sync::{Arc, Mutex}; any::Any,
use std::{fmt, io, net}; fmt, io,
marker::PhantomData,
net,
sync::{Arc, Mutex},
};
use actix_http::{body::MessageBody, Error, HttpService, KeepAlive, Request, Response}; use actix_http::{
body::MessageBody, Error, Extensions, HttpService, KeepAlive, Request, Response,
};
use actix_server::{Server, ServerBuilder}; use actix_server::{Server, ServerBuilder};
use actix_service::{map_config, IntoServiceFactory, Service, ServiceFactory}; use actix_service::{map_config, IntoServiceFactory, Service, ServiceFactory};
@ -64,6 +70,7 @@ where
backlog: i32, backlog: i32,
sockets: Vec<Socket>, sockets: Vec<Socket>,
builder: ServerBuilder, builder: ServerBuilder,
on_connect_fn: Option<Arc<dyn Fn(&dyn Any, &mut Extensions) + Send + Sync>>,
_t: PhantomData<(S, B)>, _t: PhantomData<(S, B)>,
} }
@ -91,6 +98,32 @@ where
backlog: 1024, backlog: 1024,
sockets: Vec::new(), sockets: Vec::new(),
builder: ServerBuilder::default(), builder: ServerBuilder::default(),
on_connect_fn: None,
_t: PhantomData,
}
}
/// Sets function that will be called once before each connection is handled.
/// It will receive a `&std::any::Any`, which contains underlying connection type and an
/// [Extensions] container so that request-local data can be passed to middleware and handlers.
///
/// For example:
/// - `actix_tls::openssl::SslStream<actix_web::rt::net::TcpStream>` when using openssl.
/// - `actix_tls::rustls::TlsStream<actix_web::rt::net::TcpStream>` when using rustls.
/// - `actix_web::rt::net::TcpStream` when no encryption is used.
///
/// See `on_connect` example for additional details.
pub fn on_connect<CB>(self, f: CB) -> HttpServer<F, I, S, B>
where
CB: Fn(&dyn Any, &mut Extensions) + Send + Sync + 'static,
{
HttpServer {
factory: self.factory,
config: self.config,
backlog: self.backlog,
sockets: self.sockets,
builder: self.builder,
on_connect_fn: Some(Arc::new(f)),
_t: PhantomData, _t: PhantomData,
} }
} }
@ -240,6 +273,7 @@ where
addr, addr,
scheme: "http", scheme: "http",
}); });
let on_connect_fn = self.on_connect_fn.clone();
self.builder = self.builder.listen( self.builder = self.builder.listen(
format!("actix-web-service-{}", addr), format!("actix-web-service-{}", addr),
@ -252,11 +286,20 @@ where
c.host.clone().unwrap_or_else(|| format!("{}", addr)), c.host.clone().unwrap_or_else(|| format!("{}", addr)),
); );
HttpService::build() let svc = HttpService::build()
.keep_alive(c.keep_alive) .keep_alive(c.keep_alive)
.client_timeout(c.client_timeout) .client_timeout(c.client_timeout)
.local_addr(addr) .local_addr(addr);
.finish(map_config(factory(), move |_| cfg.clone()))
let svc = if let Some(handler) = on_connect_fn.clone() {
svc.on_connect_ext(move |io: &_, ext: _| {
(handler)(io as &dyn Any, ext)
})
} else {
svc
};
svc.finish(map_config(factory(), move |_| cfg.clone()))
.tcp() .tcp()
}, },
)?; )?;
@ -289,6 +332,8 @@ where
scheme: "https", scheme: "https",
}); });
let on_connect_fn = self.on_connect_fn.clone();
self.builder = self.builder.listen( self.builder = self.builder.listen(
format!("actix-web-service-{}", addr), format!("actix-web-service-{}", addr),
lst, lst,
@ -299,11 +344,21 @@ where
addr, addr,
c.host.clone().unwrap_or_else(|| format!("{}", addr)), c.host.clone().unwrap_or_else(|| format!("{}", addr)),
); );
HttpService::build()
let svc = HttpService::build()
.keep_alive(c.keep_alive) .keep_alive(c.keep_alive)
.client_timeout(c.client_timeout) .client_timeout(c.client_timeout)
.client_disconnect(c.client_shutdown) .client_disconnect(c.client_shutdown);
.finish(map_config(factory(), move |_| cfg.clone()))
let svc = if let Some(handler) = on_connect_fn.clone() {
svc.on_connect_ext(move |io: &_, ext: _| {
(&*handler)(io as &dyn Any, ext)
})
} else {
svc
};
svc.finish(map_config(factory(), move |_| cfg.clone()))
.openssl(acceptor.clone()) .openssl(acceptor.clone())
}, },
)?; )?;
@ -336,6 +391,8 @@ where
scheme: "https", scheme: "https",
}); });
let on_connect_fn = self.on_connect_fn.clone();
self.builder = self.builder.listen( self.builder = self.builder.listen(
format!("actix-web-service-{}", addr), format!("actix-web-service-{}", addr),
lst, lst,
@ -346,11 +403,21 @@ where
addr, addr,
c.host.clone().unwrap_or_else(|| format!("{}", addr)), c.host.clone().unwrap_or_else(|| format!("{}", addr)),
); );
HttpService::build()
let svc = HttpService::build()
.keep_alive(c.keep_alive) .keep_alive(c.keep_alive)
.client_timeout(c.client_timeout) .client_timeout(c.client_timeout)
.client_disconnect(c.client_shutdown) .client_disconnect(c.client_shutdown);
.finish(map_config(factory(), move |_| cfg.clone()))
let svc = if let Some(handler) = on_connect_fn.clone() {
svc.on_connect_ext(move |io: &_, ext: _| {
(handler)(io as &dyn Any, ext)
})
} else {
svc
};
svc.finish(map_config(factory(), move |_| cfg.clone()))
.rustls(config.clone()) .rustls(config.clone())
}, },
)?; )?;
@ -441,7 +508,7 @@ where
} }
#[cfg(unix)] #[cfg(unix)]
/// Start listening for unix domain connections on existing listener. /// Start listening for unix domain (UDS) connections on existing listener.
pub fn listen_uds( pub fn listen_uds(
mut self, mut self,
lst: std::os::unix::net::UnixListener, lst: std::os::unix::net::UnixListener,
@ -460,6 +527,7 @@ where
}); });
let addr = format!("actix-web-service-{:?}", lst.local_addr()?); let addr = format!("actix-web-service-{:?}", lst.local_addr()?);
let on_connect_fn = self.on_connect_fn.clone();
self.builder = self.builder.listen_uds(addr, lst, move || { self.builder = self.builder.listen_uds(addr, lst, move || {
let c = cfg.lock().unwrap(); let c = cfg.lock().unwrap();
@ -468,11 +536,23 @@ where
socket_addr, socket_addr,
c.host.clone().unwrap_or_else(|| format!("{}", socket_addr)), c.host.clone().unwrap_or_else(|| format!("{}", socket_addr)),
); );
pipeline_factory(|io: UnixStream| ok((io, Protocol::Http1, None))).and_then( pipeline_factory(|io: UnixStream| ok((io, Protocol::Http1, None))).and_then(
HttpService::build() {
.keep_alive(c.keep_alive) let svc = HttpService::build()
.client_timeout(c.client_timeout) .keep_alive(c.keep_alive)
.finish(map_config(factory(), move |_| config.clone())), .client_timeout(c.client_timeout);
let svc = if let Some(handler) = on_connect_fn.clone() {
svc.on_connect_ext(move |io: &_, ext: _| {
(&*handler)(io as &dyn Any, ext)
})
} else {
svc
};
svc.finish(map_config(factory(), move |_| config.clone()))
},
) )
})?; })?;
Ok(self) Ok(self)