From d286ccb4f5a86eca12c65b1632506a8bd8b37d19 Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Fri, 28 Jun 2019 14:34:26 +0600 Subject: [PATCH] Add on-connect callback #946 --- actix-http/CHANGES.md | 6 ++++- actix-http/Cargo.toml | 6 ++--- actix-http/src/builder.rs | 21 +++++++++++++++ actix-http/src/h1/dispatcher.rs | 17 ++++++++++-- actix-http/src/h1/service.rs | 33 ++++++++++++++++++++++- actix-http/src/h2/dispatcher.rs | 6 ++++- actix-http/src/h2/service.rs | 36 ++++++++++++++++++++++++-- actix-http/src/helpers.rs | 14 ++++++++++ actix-http/src/service.rs | 46 ++++++++++++++++++++++++++++++--- 9 files changed, 171 insertions(+), 14 deletions(-) diff --git a/actix-http/CHANGES.md b/actix-http/CHANGES.md index 6dea516dc..636cbedf7 100644 --- a/actix-http/CHANGES.md +++ b/actix-http/CHANGES.md @@ -1,6 +1,10 @@ # Changes -## [0.2.5] - unreleased +## [0.2.5] - 2019-06-28 + +### Added + +* Add `on-connect` callback, `HttpServiceBuilder::on_connect()` #946 ### Changed diff --git a/actix-http/Cargo.toml b/actix-http/Cargo.toml index c3930a7a6..afbf0a487 100644 --- a/actix-http/Cargo.toml +++ b/actix-http/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "actix-http" -version = "0.2.4" +version = "0.2.5" authors = ["Nikolay Kim "] description = "Actix http primitives" readme = "README.md" @@ -44,10 +44,10 @@ fail = ["failure"] secure-cookies = ["ring"] [dependencies] -actix-service = "0.4.0" +actix-service = "0.4.1" actix-codec = "0.1.2" actix-connect = "0.2.0" -actix-utils = "0.4.1" +actix-utils = "0.4.2" actix-server-config = "0.1.1" actix-threadpool = "0.1.0" diff --git a/actix-http/src/builder.rs b/actix-http/src/builder.rs index b1b193a9e..b6967d94d 100644 --- a/actix-http/src/builder.rs +++ b/actix-http/src/builder.rs @@ -1,5 +1,6 @@ use std::fmt; use std::marker::PhantomData; +use std::rc::Rc; use actix_codec::Framed; use actix_server_config::ServerConfig as SrvConfig; @@ -10,6 +11,7 @@ use crate::config::{KeepAlive, ServiceConfig}; use crate::error::Error; use crate::h1::{Codec, ExpectHandler, H1Service, UpgradeHandler}; use crate::h2::H2Service; +use crate::helpers::{Data, DataFactory}; use crate::request::Request; use crate::response::Response; use crate::service::HttpService; @@ -24,6 +26,7 @@ pub struct HttpServiceBuilder> { client_disconnect: u64, expect: X, upgrade: Option, + on_connect: Option Box>>, _t: PhantomData<(T, S)>, } @@ -41,6 +44,7 @@ where client_disconnect: 0, expect: ExpectHandler, upgrade: None, + on_connect: None, _t: PhantomData, } } @@ -115,6 +119,7 @@ where client_disconnect: self.client_disconnect, expect: expect.into_new_service(), upgrade: self.upgrade, + on_connect: self.on_connect, _t: PhantomData, } } @@ -140,10 +145,24 @@ where client_disconnect: self.client_disconnect, expect: self.expect, upgrade: Some(upgrade.into_new_service()), + on_connect: self.on_connect, _t: PhantomData, } } + /// Set on-connect callback. + /// + /// It get called once per connection and result of the call + /// get stored to the request's extensions. + pub fn on_connect(mut self, f: F) -> Self + where + F: Fn(&T) -> I + 'static, + I: Clone + 'static, + { + self.on_connect = Some(Rc::new(move |io| Box::new(Data(f(io))))); + self + } + /// Finish service configuration and create *http service* for HTTP/1 protocol. pub fn h1(self, service: F) -> H1Service where @@ -161,6 +180,7 @@ where H1Service::with_config(cfg, service.into_new_service()) .expect(self.expect) .upgrade(self.upgrade) + .on_connect(self.on_connect) } /// Finish service configuration and create *http service* for HTTP/2 protocol. @@ -199,5 +219,6 @@ where HttpService::with_config(cfg, service.into_new_service()) .expect(self.expect) .upgrade(self.upgrade) + .on_connect(self.on_connect) } } diff --git a/actix-http/src/h1/dispatcher.rs b/actix-http/src/h1/dispatcher.rs index 220984f8d..91990d05c 100644 --- a/actix-http/src/h1/dispatcher.rs +++ b/actix-http/src/h1/dispatcher.rs @@ -16,6 +16,8 @@ use crate::body::{Body, BodySize, MessageBody, ResponseBody}; use crate::config::ServiceConfig; use crate::error::{DispatchError, Error}; use crate::error::{ParseError, PayloadError}; +use crate::helpers::DataFactory; +use crate::httpmessage::HttpMessage; use crate::request::Request; use crate::response::Response; @@ -81,6 +83,7 @@ where service: CloneableService, expect: CloneableService, upgrade: Option>, + on_connect: Option>, flags: Flags, peer_addr: Option, error: Option, @@ -174,12 +177,13 @@ where U::Error: fmt::Display, { /// Create http/1 dispatcher. - pub fn new( + pub(crate) fn new( stream: T, config: ServiceConfig, service: CloneableService, expect: CloneableService, upgrade: Option>, + on_connect: Option>, ) -> Self { Dispatcher::with_timeout( stream, @@ -190,11 +194,12 @@ where service, expect, upgrade, + on_connect, ) } /// Create http/1 dispatcher with slow request timeout. - pub fn with_timeout( + pub(crate) fn with_timeout( io: T, codec: Codec, config: ServiceConfig, @@ -203,6 +208,7 @@ where service: CloneableService, expect: CloneableService, upgrade: Option>, + on_connect: Option>, ) -> Self { let keepalive = config.keep_alive_enabled(); let flags = if keepalive { @@ -234,6 +240,7 @@ where service, expect, upgrade, + on_connect, flags, ka_expire, ka_timer, @@ -495,6 +502,11 @@ where let pl = self.codec.message_type(); req.head_mut().peer_addr = self.peer_addr; + // on_connect data + if let Some(ref on_connect) = self.on_connect { + on_connect.set(&mut req.extensions_mut()); + } + if pl == MessageType::Stream && self.upgrade.is_some() { self.messages.push_back(DispatcherMessage::Upgrade(req)); break; @@ -851,6 +863,7 @@ mod tests { ), CloneableService::new(ExpectHandler), None, + None, ); assert!(h1.poll().is_err()); diff --git a/actix-http/src/h1/service.rs b/actix-http/src/h1/service.rs index 2c0a48eba..192d1b598 100644 --- a/actix-http/src/h1/service.rs +++ b/actix-http/src/h1/service.rs @@ -1,5 +1,6 @@ use std::fmt; use std::marker::PhantomData; +use std::rc::Rc; use actix_codec::Framed; use actix_server_config::{Io, IoStream, ServerConfig as SrvConfig}; @@ -11,6 +12,7 @@ use futures::{try_ready, Async, Future, IntoFuture, Poll, Stream}; use crate::body::MessageBody; use crate::config::{KeepAlive, ServiceConfig}; use crate::error::{DispatchError, Error, ParseError}; +use crate::helpers::DataFactory; use crate::request::Request; use crate::response::Response; @@ -24,6 +26,7 @@ pub struct H1Service> { cfg: ServiceConfig, expect: X, upgrade: Option, + on_connect: Option Box>>, _t: PhantomData<(T, P, B)>, } @@ -44,6 +47,7 @@ where srv: service.into_new_service(), expect: ExpectHandler, upgrade: None, + on_connect: None, _t: PhantomData, } } @@ -55,6 +59,7 @@ where srv: service.into_new_service(), expect: ExpectHandler, upgrade: None, + on_connect: None, _t: PhantomData, } } @@ -79,6 +84,7 @@ where cfg: self.cfg, srv: self.srv, upgrade: self.upgrade, + on_connect: self.on_connect, _t: PhantomData, } } @@ -94,9 +100,19 @@ where cfg: self.cfg, srv: self.srv, expect: self.expect, + on_connect: self.on_connect, _t: PhantomData, } } + + /// Set on connect callback. + pub(crate) fn on_connect( + mut self, + f: Option Box>>, + ) -> Self { + self.on_connect = f; + self + } } impl NewService for H1Service @@ -133,6 +149,7 @@ where fut_upg: self.upgrade.as_ref().map(|f| f.new_service(cfg)), expect: None, upgrade: None, + on_connect: self.on_connect.clone(), cfg: Some(self.cfg.clone()), _t: PhantomData, } @@ -157,6 +174,7 @@ where fut_upg: Option, expect: Option, upgrade: Option, + on_connect: Option Box>>, cfg: Option, _t: PhantomData<(T, P, B)>, } @@ -205,6 +223,7 @@ where service, self.expect.take().unwrap(), self.upgrade.take(), + self.on_connect.clone(), ))) } } @@ -214,6 +233,7 @@ pub struct H1ServiceHandler { srv: CloneableService, expect: CloneableService, upgrade: Option>, + on_connect: Option Box>>, cfg: ServiceConfig, _t: PhantomData<(T, P, B)>, } @@ -234,12 +254,14 @@ where srv: S, expect: X, upgrade: Option, + on_connect: Option Box>>, ) -> H1ServiceHandler { H1ServiceHandler { srv: CloneableService::new(srv), expect: CloneableService::new(expect), upgrade: upgrade.map(|s| CloneableService::new(s)), cfg, + on_connect, _t: PhantomData, } } @@ -292,12 +314,21 @@ where } fn call(&mut self, req: Self::Request) -> Self::Future { + let io = req.into_parts().0; + + let on_connect = if let Some(ref on_connect) = self.on_connect { + Some(on_connect(&io)) + } else { + None + }; + Dispatcher::new( - req.into_parts().0, + io, self.cfg.clone(), self.srv.clone(), self.expect.clone(), self.upgrade.clone(), + on_connect, ) } } diff --git a/actix-http/src/h2/dispatcher.rs b/actix-http/src/h2/dispatcher.rs index e66ff63c3..48d32993d 100644 --- a/actix-http/src/h2/dispatcher.rs +++ b/actix-http/src/h2/dispatcher.rs @@ -22,6 +22,7 @@ use tokio_timer::Delay; use crate::body::{Body, BodySize, MessageBody, ResponseBody}; use crate::config::ServiceConfig; use crate::error::{DispatchError, Error, ParseError, PayloadError, ResponseError}; +use crate::helpers::DataFactory; use crate::message::ResponseHead; use crate::payload::Payload; use crate::request::Request; @@ -33,6 +34,7 @@ const CHUNK_SIZE: usize = 16_384; pub struct Dispatcher, B: MessageBody> { service: CloneableService, connection: Connection, + on_connect: Option>, config: ServiceConfig, peer_addr: Option, ka_expire: Instant, @@ -49,9 +51,10 @@ where S::Response: Into>, B: MessageBody + 'static, { - pub fn new( + pub(crate) fn new( service: CloneableService, connection: Connection, + on_connect: Option>, config: ServiceConfig, timeout: Option, peer_addr: Option, @@ -77,6 +80,7 @@ where config, peer_addr, connection, + on_connect, ka_expire, ka_timer, _t: PhantomData, diff --git a/actix-http/src/h2/service.rs b/actix-http/src/h2/service.rs index b4191f03a..efc400da1 100644 --- a/actix-http/src/h2/service.rs +++ b/actix-http/src/h2/service.rs @@ -1,6 +1,6 @@ use std::fmt::Debug; use std::marker::PhantomData; -use std::{io, net}; +use std::{io, net, rc}; use actix_codec::{AsyncRead, AsyncWrite, Framed}; use actix_server_config::{Io, IoStream, ServerConfig as SrvConfig}; @@ -16,6 +16,7 @@ use log::error; use crate::body::MessageBody; use crate::config::{KeepAlive, ServiceConfig}; use crate::error::{DispatchError, Error, ParseError, ResponseError}; +use crate::helpers::DataFactory; use crate::payload::Payload; use crate::request::Request; use crate::response::Response; @@ -26,6 +27,7 @@ use super::dispatcher::Dispatcher; pub struct H2Service { srv: S, cfg: ServiceConfig, + on_connect: Option Box>>, _t: PhantomData<(T, P, B)>, } @@ -43,6 +45,7 @@ where H2Service { cfg, + on_connect: None, srv: service.into_new_service(), _t: PhantomData, } @@ -52,10 +55,20 @@ where pub fn with_config>(cfg: ServiceConfig, service: F) -> Self { H2Service { cfg, + on_connect: None, srv: service.into_new_service(), _t: PhantomData, } } + + /// Set on connect callback. + pub(crate) fn on_connect( + mut self, + f: Option Box>>, + ) -> Self { + self.on_connect = f; + self + } } impl NewService for H2Service @@ -79,6 +92,7 @@ where H2ServiceResponse { fut: self.srv.new_service(cfg).into_future(), cfg: Some(self.cfg.clone()), + on_connect: self.on_connect.clone(), _t: PhantomData, } } @@ -88,6 +102,7 @@ where pub struct H2ServiceResponse { fut: ::Future, cfg: Option, + on_connect: Option Box>>, _t: PhantomData<(T, P, B)>, } @@ -107,6 +122,7 @@ where let service = try_ready!(self.fut.poll()); Ok(Async::Ready(H2ServiceHandler::new( self.cfg.take().unwrap(), + self.on_connect.clone(), service, ))) } @@ -116,6 +132,7 @@ where pub struct H2ServiceHandler { srv: CloneableService, cfg: ServiceConfig, + on_connect: Option Box>>, _t: PhantomData<(T, P, B)>, } @@ -127,9 +144,14 @@ where S::Response: Into>, B: MessageBody + 'static, { - fn new(cfg: ServiceConfig, srv: S) -> H2ServiceHandler { + fn new( + cfg: ServiceConfig, + on_connect: Option Box>>, + srv: S, + ) -> H2ServiceHandler { H2ServiceHandler { cfg, + on_connect, srv: CloneableService::new(srv), _t: PhantomData, } @@ -161,11 +183,18 @@ where fn call(&mut self, req: Self::Request) -> Self::Future { let io = req.into_parts().0; let peer_addr = io.peer_addr(); + let on_connect = if let Some(ref on_connect) = self.on_connect { + Some(on_connect(&io)) + } else { + None + }; + H2ServiceHandlerResponse { state: State::Handshake( Some(self.srv.clone()), Some(self.cfg.clone()), peer_addr, + on_connect, server::handshake(io), ), } @@ -181,6 +210,7 @@ where Option>, Option, Option, + Option>, Handshake, ), } @@ -216,12 +246,14 @@ where ref mut srv, ref mut config, ref peer_addr, + ref mut on_connect, ref mut handshake, ) => match handshake.poll() { Ok(Async::Ready(conn)) => { self.state = State::Incoming(Dispatcher::new( srv.take().unwrap(), conn, + on_connect.take(), config.take().unwrap(), None, peer_addr.clone(), diff --git a/actix-http/src/helpers.rs b/actix-http/src/helpers.rs index e8dbcd82a..e4583ee37 100644 --- a/actix-http/src/helpers.rs +++ b/actix-http/src/helpers.rs @@ -3,6 +3,8 @@ use std::{io, mem, ptr, slice}; use bytes::{BufMut, BytesMut}; use http::Version; +use crate::extensions::Extensions; + const DEC_DIGITS_LUT: &[u8] = b"0001020304050607080910111213141516171819\ 2021222324252627282930313233343536373839\ 4041424344454647484950515253545556575859\ @@ -180,6 +182,18 @@ impl<'a> io::Write for Writer<'a> { } } +pub(crate) trait DataFactory { + fn set(&self, ext: &mut Extensions); +} + +pub(crate) struct Data(pub(crate) T); + +impl DataFactory for Data { + fn set(&self, ext: &mut Extensions) { + ext.insert(self.0.clone()) + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/actix-http/src/service.rs b/actix-http/src/service.rs index b762f3cb9..1ac018803 100644 --- a/actix-http/src/service.rs +++ b/actix-http/src/service.rs @@ -1,5 +1,5 @@ use std::marker::PhantomData; -use std::{fmt, io, net}; +use std::{fmt, io, net, rc}; use actix_codec::{AsyncRead, AsyncWrite, Framed}; use actix_server_config::{ @@ -15,6 +15,7 @@ use crate::body::MessageBody; use crate::builder::HttpServiceBuilder; use crate::config::{KeepAlive, ServiceConfig}; use crate::error::{DispatchError, Error}; +use crate::helpers::DataFactory; use crate::request::Request; use crate::response::Response; use crate::{h1, h2::Dispatcher}; @@ -25,6 +26,7 @@ pub struct HttpService, + on_connect: Option Box>>, _t: PhantomData<(T, P, B)>, } @@ -61,6 +63,7 @@ where srv: service.into_new_service(), expect: h1::ExpectHandler, upgrade: None, + on_connect: None, _t: PhantomData, } } @@ -75,6 +78,7 @@ where srv: service.into_new_service(), expect: h1::ExpectHandler, upgrade: None, + on_connect: None, _t: PhantomData, } } @@ -104,6 +108,7 @@ where cfg: self.cfg, srv: self.srv, upgrade: self.upgrade, + on_connect: self.on_connect, _t: PhantomData, } } @@ -127,9 +132,19 @@ where cfg: self.cfg, srv: self.srv, expect: self.expect, + on_connect: self.on_connect, _t: PhantomData, } } + + /// Set on connect callback. + pub(crate) fn on_connect( + mut self, + f: Option Box>>, + ) -> Self { + self.on_connect = f; + self + } } impl NewService for HttpService @@ -167,6 +182,7 @@ where fut_upg: self.upgrade.as_ref().map(|f| f.new_service(cfg)), expect: None, upgrade: None, + on_connect: self.on_connect.clone(), cfg: Some(self.cfg.clone()), _t: PhantomData, } @@ -180,6 +196,7 @@ pub struct HttpServiceResponse, expect: Option, upgrade: Option, + on_connect: Option Box>>, cfg: Option, _t: PhantomData<(T, P, B)>, } @@ -229,6 +246,7 @@ where service, self.expect.take().unwrap(), self.upgrade.take(), + self.on_connect.clone(), ))) } } @@ -239,6 +257,7 @@ pub struct HttpServiceHandler { expect: CloneableService, upgrade: Option>, cfg: ServiceConfig, + on_connect: Option Box>>, _t: PhantomData<(T, P, B, X)>, } @@ -259,9 +278,11 @@ where srv: S, expect: X, upgrade: Option, + on_connect: Option Box>>, ) -> HttpServiceHandler { HttpServiceHandler { cfg, + on_connect, srv: CloneableService::new(srv), expect: CloneableService::new(expect), upgrade: upgrade.map(|s| CloneableService::new(s)), @@ -319,6 +340,13 @@ where fn call(&mut self, req: Self::Request) -> Self::Future { let (io, _, proto) = req.into_parts(); + + let on_connect = if let Some(ref on_connect) = self.on_connect { + Some(on_connect(&io)) + } else { + None + }; + match proto { Protocol::Http2 => { let peer_addr = io.peer_addr(); @@ -332,6 +360,7 @@ where self.cfg.clone(), self.srv.clone(), peer_addr, + on_connect, ))), } } @@ -342,6 +371,7 @@ where self.srv.clone(), self.expect.clone(), self.upgrade.clone(), + on_connect, )), }, _ => HttpServiceHandlerResponse { @@ -352,6 +382,7 @@ where self.srv.clone(), self.expect.clone(), self.upgrade.clone(), + on_connect, ))), }, } @@ -380,6 +411,7 @@ where CloneableService, CloneableService, Option>, + Option>, )>, ), Handshake( @@ -388,6 +420,7 @@ where ServiceConfig, CloneableService, Option, + Option>, )>, ), } @@ -448,7 +481,8 @@ where } else { panic!() } - let (io, buf, cfg, srv, expect, upgrade) = data.take().unwrap(); + let (io, buf, cfg, srv, expect, upgrade, on_connect) = + data.take().unwrap(); if buf[..14] == HTTP2_PREFACE[..] { let peer_addr = io.peer_addr(); let io = Io { @@ -460,6 +494,7 @@ where cfg, srv, peer_addr, + on_connect, ))); } else { self.state = State::H1(h1::Dispatcher::with_timeout( @@ -471,6 +506,7 @@ where srv, expect, upgrade, + on_connect, )) } self.poll() @@ -488,8 +524,10 @@ where } else { panic!() }; - let (_, cfg, srv, peer_addr) = data.take().unwrap(); - self.state = State::H2(Dispatcher::new(srv, conn, cfg, None, peer_addr)); + let (_, cfg, srv, peer_addr, on_connect) = data.take().unwrap(); + self.state = State::H2(Dispatcher::new( + srv, conn, on_connect, cfg, None, peer_addr, + )); self.poll() } }