diff --git a/actix-http/src/cloneable.rs b/actix-http/src/cloneable.rs deleted file mode 100644 index 5f0b1ea28..000000000 --- a/actix-http/src/cloneable.rs +++ /dev/null @@ -1,39 +0,0 @@ -use std::cell::RefCell; -use std::rc::Rc; -use std::task::{Context, Poll}; - -use actix_service::Service; - -/// Service that allows to turn non-clone service to a service with `Clone` impl -/// -/// # Panics -/// CloneableService might panic with some creative use of thread local storage. -/// See https://github.com/actix/actix-web/issues/1295 for example -#[doc(hidden)] -pub(crate) struct CloneableService(Rc>); - -impl CloneableService { - pub(crate) fn new(service: T) -> Self { - Self(Rc::new(RefCell::new(service))) - } -} - -impl Clone for CloneableService { - fn clone(&self) -> Self { - Self(self.0.clone()) - } -} - -impl, Req> Service for CloneableService { - type Response = T::Response; - type Error = T::Error; - type Future = T::Future; - - fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { - self.0.borrow_mut().poll_ready(cx) - } - - fn call(&mut self, req: Req) -> Self::Future { - self.0.borrow_mut().call(req) - } -} diff --git a/actix-http/src/h1/dispatcher.rs b/actix-http/src/h1/dispatcher.rs index a9510dc1e..60552d102 100644 --- a/actix-http/src/h1/dispatcher.rs +++ b/actix-http/src/h1/dispatcher.rs @@ -1,9 +1,11 @@ use std::{ + cell::RefCell, collections::VecDeque, fmt, future::Future, io, mem, net, pin::Pin, + rc::Rc, task::{Context, Poll}, }; @@ -15,17 +17,14 @@ use bytes::{Buf, BytesMut}; use log::{error, trace}; use pin_project::pin_project; -use crate::cloneable::CloneableService; +use crate::body::{Body, BodySize, MessageBody, ResponseBody}; use crate::config::ServiceConfig; use crate::error::{DispatchError, Error}; use crate::error::{ParseError, PayloadError}; -use crate::httpmessage::HttpMessage; use crate::request::Request; use crate::response::Response; -use crate::{ - body::{Body, BodySize, MessageBody, ResponseBody}, - Extensions, -}; +use crate::service::HttpFlow; +use crate::OnConnectData; use super::codec::Codec; use super::payload::{Payload, PayloadSender, PayloadStatus}; @@ -78,7 +77,7 @@ where U::Error: fmt::Display, { Normal(#[pin] InnerDispatcher), - Upgrade(Pin>), + Upgrade(#[pin] U::Future), } #[pin_project(project = InnerDispatcherProj)] @@ -92,10 +91,8 @@ where U: Service<(Request, Framed), Response = ()>, U::Error: fmt::Display, { - service: CloneableService, - expect: CloneableService, - upgrade: Option>, - on_connect_data: Extensions, + services: Rc>>, + on_connect_data: OnConnectData, flags: Flags, peer_addr: Option, error: Option, @@ -180,10 +177,8 @@ where pub(crate) fn new( stream: T, config: ServiceConfig, - service: CloneableService, - expect: CloneableService, - upgrade: Option>, - on_connect_data: Extensions, + services: Rc>>, + on_connect_data: OnConnectData, peer_addr: Option, ) -> Self { Dispatcher::with_timeout( @@ -192,9 +187,7 @@ where config, BytesMut::with_capacity(HW_BUFFER_SIZE), None, - service, - expect, - upgrade, + services, on_connect_data, peer_addr, ) @@ -207,10 +200,8 @@ where config: ServiceConfig, read_buf: BytesMut, timeout: Option, - service: CloneableService, - expect: CloneableService, - upgrade: Option>, - on_connect_data: Extensions, + services: Rc>>, + on_connect_data: OnConnectData, peer_addr: Option, ) -> Self { let keepalive = config.keep_alive_enabled(); @@ -239,9 +230,7 @@ where io: Some(io), codec, read_buf, - service, - expect, - upgrade, + services, on_connect_data, flags, peer_addr, @@ -395,7 +384,8 @@ where Poll::Ready(Ok(req)) => { self.as_mut().send_continue(); this = self.as_mut().project(); - this.state.set(State::ServiceCall(this.service.call(req))); + let fut = this.services.borrow_mut().service.call(req); + this.state.set(State::ServiceCall(fut)); continue; } Poll::Ready(Err(e)) => { @@ -483,12 +473,14 @@ where // Handle `EXPECT: 100-Continue` header if req.head().expect() { // set dispatcher state so the future is pinned. - let task = self.as_mut().project().expect.call(req); - self.as_mut().project().state.set(State::ExpectCall(task)); + let mut this = self.as_mut().project(); + let task = this.services.borrow_mut().expect.call(req); + this.state.set(State::ExpectCall(task)); } else { // the same as above. - let task = self.as_mut().project().service.call(req); - self.as_mut().project().state.set(State::ServiceCall(task)); + let mut this = self.as_mut().project(); + let task = this.services.borrow_mut().service.call(req); + this.state.set(State::ServiceCall(task)); }; // eagerly poll the future for once(or twice if expect is resolved immediately). @@ -499,8 +491,9 @@ where // expect is resolved. continue loop and poll the service call branch. Poll::Ready(Ok(req)) => { self.as_mut().send_continue(); - let task = self.as_mut().project().service.call(req); - self.as_mut().project().state.set(State::ServiceCall(task)); + let mut this = self.as_mut().project(); + let task = this.services.borrow_mut().service.call(req); + this.state.set(State::ServiceCall(task)); continue; } // future is pending. return Ok(()) to notify that a new state is @@ -568,9 +561,11 @@ where req.head_mut().peer_addr = *this.peer_addr; // merge on_connect_ext data into request extensions - req.extensions_mut().drain_from(this.on_connect_data); + this.on_connect_data.merge_into(&mut req); - if pl == MessageType::Stream && this.upgrade.is_some() { + if pl == MessageType::Stream + && this.services.borrow().upgrade.is_some() + { this.messages.push_back(DispatcherMessage::Upgrade(req)); break; } @@ -834,12 +829,17 @@ where ); parts.write_buf = mem::take(inner_p.write_buf); let framed = Framed::from_parts(parts); - let upgrade = - inner_p.upgrade.take().unwrap().call((req, framed)); + let upgrade = inner_p + .services + .borrow_mut() + .upgrade + .take() + .unwrap() + .call((req, framed)); self.as_mut() .project() .inner - .set(DispatcherState::Upgrade(Box::pin(upgrade))); + .set(DispatcherState::Upgrade(upgrade)); return self.poll(cx); } @@ -890,7 +890,7 @@ where } } } - DispatcherStateProj::Upgrade(fut) => fut.as_mut().poll(cx).map_err(|e| { + DispatcherStateProj::Upgrade(fut) => fut.poll(cx).map_err(|e| { error!("Upgrade handler error: {}", e); DispatchError::Upgrade }), @@ -1028,13 +1028,13 @@ mod tests { lazy(|cx| { let buf = TestBuffer::new("GET /test HTTP/1\r\n\r\n"); + let services = HttpFlow::new(ok_service(), ExpectHandler, None); + let h1 = Dispatcher::<_, _, _, _, UpgradeHandler>::new( buf, ServiceConfig::default(), - CloneableService::new(ok_service()), - CloneableService::new(ExpectHandler), - None, - Extensions::new(), + services, + OnConnectData::default(), None, ); @@ -1068,13 +1068,13 @@ mod tests { let cfg = ServiceConfig::new(KeepAlive::Disabled, 1, 1, false, None); + let services = HttpFlow::new(echo_path_service(), ExpectHandler, None); + let h1 = Dispatcher::<_, _, _, _, UpgradeHandler>::new( buf, cfg, - CloneableService::new(echo_path_service()), - CloneableService::new(ExpectHandler), - None, - Extensions::new(), + services, + OnConnectData::default(), None, ); @@ -1122,13 +1122,13 @@ mod tests { let cfg = ServiceConfig::new(KeepAlive::Disabled, 1, 1, false, None); + let services = HttpFlow::new(echo_path_service(), ExpectHandler, None); + let h1 = Dispatcher::<_, _, _, _, UpgradeHandler>::new( buf, cfg, - CloneableService::new(echo_path_service()), - CloneableService::new(ExpectHandler), - None, - Extensions::new(), + services, + OnConnectData::default(), None, ); @@ -1171,13 +1171,14 @@ mod tests { lazy(|cx| { let mut buf = TestSeqBuffer::empty(); let cfg = ServiceConfig::new(KeepAlive::Disabled, 0, 0, false, None); + + let services = HttpFlow::new(echo_payload_service(), ExpectHandler, None); + let h1 = Dispatcher::<_, _, _, _, UpgradeHandler>::new( buf.clone(), cfg, - CloneableService::new(echo_payload_service()), - CloneableService::new(ExpectHandler), - None, - Extensions::new(), + services, + OnConnectData::default(), None, ); @@ -1242,13 +1243,14 @@ mod tests { lazy(|cx| { let mut buf = TestSeqBuffer::empty(); let cfg = ServiceConfig::new(KeepAlive::Disabled, 0, 0, false, None); + + let services = HttpFlow::new(echo_path_service(), ExpectHandler, None); + let h1 = Dispatcher::<_, _, _, _, UpgradeHandler>::new( buf.clone(), cfg, - CloneableService::new(echo_path_service()), - CloneableService::new(ExpectHandler), - None, - Extensions::new(), + services, + OnConnectData::default(), None, ); @@ -1301,13 +1303,15 @@ mod tests { lazy(|cx| { let mut buf = TestSeqBuffer::empty(); let cfg = ServiceConfig::new(KeepAlive::Disabled, 0, 0, false, None); + + let services = + HttpFlow::new(ok_service(), ExpectHandler, Some(UpgradeHandler)); + let h1 = Dispatcher::<_, _, _, _, UpgradeHandler>::new( buf.clone(), cfg, - CloneableService::new(ok_service()), - CloneableService::new(ExpectHandler), - Some(CloneableService::new(UpgradeHandler)), - Extensions::new(), + services, + OnConnectData::default(), None, ); diff --git a/actix-http/src/h1/service.rs b/actix-http/src/h1/service.rs index 34b7e31a1..19272c133 100644 --- a/actix-http/src/h1/service.rs +++ b/actix-http/src/h1/service.rs @@ -1,3 +1,4 @@ +use std::cell::RefCell; use std::future::Future; use std::marker::PhantomData; use std::pin::Pin; @@ -12,12 +13,12 @@ use futures_core::ready; use futures_util::future::ready; use crate::body::MessageBody; -use crate::cloneable::CloneableService; use crate::config::ServiceConfig; use crate::error::{DispatchError, Error}; use crate::request::Request; use crate::response::Response; -use crate::{ConnectCallback, Extensions}; +use crate::service::HttpFlow; +use crate::{ConnectCallback, OnConnectData}; use super::codec::Codec; use super::dispatcher::Dispatcher; @@ -299,7 +300,7 @@ where upgrade: Option, on_connect_ext: Option>>, cfg: Option, - _phantom: PhantomData<(T, B)>, + _phantom: PhantomData, } impl Future for H1ServiceResponse @@ -366,9 +367,7 @@ where X: Service, U: Service<(Request, Framed)>, { - srv: CloneableService, - expect: CloneableService, - upgrade: Option>, + services: Rc>>, on_connect_ext: Option>>, cfg: ServiceConfig, _phantom: PhantomData, @@ -387,15 +386,13 @@ where { fn new( cfg: ServiceConfig, - srv: S, + service: S, expect: X, upgrade: Option, on_connect_ext: Option>>, ) -> H1ServiceHandler { H1ServiceHandler { - srv: CloneableService::new(srv), - expect: CloneableService::new(expect), - upgrade: upgrade.map(CloneableService::new), + services: HttpFlow::new(service, expect, upgrade), cfg, on_connect_ext, _phantom: PhantomData, @@ -421,7 +418,8 @@ where type Future = Dispatcher; fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { - let ready = self + let mut services = self.services.borrow_mut(); + let ready = services .expect .poll_ready(cx) .map_err(|e| { @@ -431,8 +429,8 @@ where })? .is_ready(); - let ready = self - .srv + let ready = services + .service .poll_ready(cx) .map_err(|e| { let e = e.into(); @@ -442,7 +440,7 @@ where .is_ready() && ready; - let ready = if let Some(ref mut upg) = self.upgrade { + let ready = if let Some(ref mut upg) = services.upgrade { upg.poll_ready(cx) .map_err(|e| { let e = e.into(); @@ -463,19 +461,14 @@ where } fn call(&mut self, (io, addr): (T, Option)) -> Self::Future { - let mut connect_extensions = Extensions::new(); - if let Some(ref handler) = self.on_connect_ext { - // run on_connect_ext callback, populating connect extensions - handler(&io, &mut connect_extensions); - } + let on_connect_data = + OnConnectData::from_io(&io, self.on_connect_ext.as_deref()); Dispatcher::new( io, self.cfg.clone(), - self.srv.clone(), - self.expect.clone(), - self.upgrade.clone(), - connect_extensions, + self.services.clone(), + on_connect_data, addr, ) } diff --git a/actix-http/src/h2/dispatcher.rs b/actix-http/src/h2/dispatcher.rs index b8828edd0..621035869 100644 --- a/actix-http/src/h2/dispatcher.rs +++ b/actix-http/src/h2/dispatcher.rs @@ -1,7 +1,9 @@ +use std::cell::RefCell; use std::future::Future; use std::marker::PhantomData; use std::net; use std::pin::Pin; +use std::rc::Rc; use std::task::{Context, Poll}; use std::{cmp, convert::TryFrom}; @@ -16,29 +18,28 @@ use http::header::{HeaderValue, CONNECTION, CONTENT_LENGTH, DATE, TRANSFER_ENCOD use log::{error, trace}; use crate::body::{BodySize, MessageBody, ResponseBody}; -use crate::cloneable::CloneableService; use crate::config::ServiceConfig; use crate::error::{DispatchError, Error}; -use crate::httpmessage::HttpMessage; use crate::message::ResponseHead; use crate::payload::Payload; use crate::request::Request; use crate::response::Response; -use crate::Extensions; +use crate::service::HttpFlow; +use crate::OnConnectData; const CHUNK_SIZE: usize = 16_384; /// Dispatcher for HTTP/2 protocol. #[pin_project::pin_project] -pub struct Dispatcher +pub struct Dispatcher where T: AsyncRead + AsyncWrite + Unpin, S: Service, B: MessageBody, { - service: CloneableService, + services: Rc>>, connection: Connection, - on_connect_data: Extensions, + on_connect_data: OnConnectData, config: ServiceConfig, peer_addr: Option, ka_expire: Instant, @@ -46,7 +47,7 @@ where _phantom: PhantomData, } -impl Dispatcher +impl Dispatcher where T: AsyncRead + AsyncWrite + Unpin, S: Service, @@ -55,9 +56,9 @@ where B: MessageBody, { pub(crate) fn new( - service: CloneableService, + services: Rc>>, connection: Connection, - on_connect_data: Extensions, + on_connect_data: OnConnectData, config: ServiceConfig, timeout: Option, peer_addr: Option, @@ -79,7 +80,7 @@ where }; Dispatcher { - service, + services, config, peer_addr, connection, @@ -91,7 +92,7 @@ where } } -impl Future for Dispatcher +impl Future for Dispatcher where T: AsyncRead + AsyncWrite + Unpin, S: Service, @@ -133,11 +134,11 @@ where head.peer_addr = this.peer_addr; // merge on_connect_ext data into request extensions - req.extensions_mut().drain_from(&mut this.on_connect_data); + this.on_connect_data.merge_into(&mut req); let svc = ServiceResponse:: { state: ServiceResponseState::ServiceCall( - this.service.call(req), + this.services.borrow_mut().service.call(req), Some(res), ), config: this.config.clone(), diff --git a/actix-http/src/h2/service.rs b/actix-http/src/h2/service.rs index 462f5c2c1..f94aae79e 100644 --- a/actix-http/src/h2/service.rs +++ b/actix-http/src/h2/service.rs @@ -1,3 +1,4 @@ +use std::cell::RefCell; use std::future::Future; use std::marker::PhantomData; use std::pin::Pin; @@ -17,12 +18,12 @@ use h2::server::{self, Handshake}; use log::error; use crate::body::MessageBody; -use crate::cloneable::CloneableService; use crate::config::ServiceConfig; use crate::error::{DispatchError, Error}; use crate::request::Request; use crate::response::Response; -use crate::{ConnectCallback, Extensions}; +use crate::service::HttpFlow; +use crate::{ConnectCallback, OnConnectData}; use super::dispatcher::Dispatcher; @@ -248,7 +249,7 @@ pub struct H2ServiceHandler where S: Service, { - srv: CloneableService, + services: Rc>>, cfg: ServiceConfig, on_connect_ext: Option>>, _phantom: PhantomData, @@ -265,12 +266,12 @@ where fn new( cfg: ServiceConfig, on_connect_ext: Option>>, - srv: S, + service: S, ) -> H2ServiceHandler { H2ServiceHandler { + services: HttpFlow::new(service, (), None), cfg, on_connect_ext, - srv: CloneableService::new(srv), _phantom: PhantomData, } } @@ -290,26 +291,27 @@ where type Future = H2ServiceHandlerResponse; fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { - self.srv.poll_ready(cx).map_err(|e| { - let e = e.into(); - error!("Service readiness error: {:?}", e); - DispatchError::Service(e) - }) + self.services + .borrow_mut() + .service + .poll_ready(cx) + .map_err(|e| { + let e = e.into(); + error!("Service readiness error: {:?}", e); + DispatchError::Service(e) + }) } fn call(&mut self, (io, addr): (T, Option)) -> Self::Future { - let mut connect_extensions = Extensions::new(); - if let Some(ref handler) = self.on_connect_ext { - // run on_connect_ext callback, populating connect extensions - handler(&io, &mut connect_extensions); - } + let on_connect_data = + OnConnectData::from_io(&io, self.on_connect_ext.as_deref()); H2ServiceHandlerResponse { state: State::Handshake( - Some(self.srv.clone()), + Some(self.services.clone()), Some(self.cfg.clone()), addr, - Some(connect_extensions), + on_connect_data, server::handshake(io), ), } @@ -321,12 +323,12 @@ where T: AsyncRead + AsyncWrite + Unpin, S::Future: 'static, { - Incoming(Dispatcher), + Incoming(Dispatcher), Handshake( - Option>, + Option>>>, Option, Option, - Option, + OnConnectData, Handshake, ), } @@ -365,10 +367,11 @@ where ref mut handshake, ) => match ready!(Pin::new(handshake).poll(cx)) { Ok(conn) => { + let on_connect_data = std::mem::take(on_connect_data); self.state = State::Incoming(Dispatcher::new( srv.take().unwrap(), conn, - on_connect_data.take().unwrap(), + on_connect_data, config.take().unwrap(), None, *peer_addr, diff --git a/actix-http/src/lib.rs b/actix-http/src/lib.rs index 94cc50a76..0c58df2ed 100644 --- a/actix-http/src/lib.rs +++ b/actix-http/src/lib.rs @@ -19,7 +19,6 @@ mod macros; pub mod body; mod builder; pub mod client; -mod cloneable; mod config; #[cfg(feature = "compress")] pub mod encoding; @@ -81,3 +80,36 @@ pub enum Protocol { } type ConnectCallback = dyn Fn(&IO, &mut Extensions); + +/// Container for data that extract with ConnectCallback. +pub(crate) struct OnConnectData(Option); + +impl Default for OnConnectData { + fn default() -> Self { + Self(None) + } +} + +impl OnConnectData { + // construct self from io. + pub(crate) fn from_io( + io: &T, + on_connect_ext: Option<&ConnectCallback>, + ) -> Self { + let ext = on_connect_ext.map(|handler| { + let mut extensions = Extensions::new(); + handler(io, &mut extensions); + extensions + }); + + Self(ext) + } + + // merge self to given request's head extension. + #[inline] + pub(crate) fn merge_into(&mut self, req: &mut Request) { + if let Some(ref mut ext) = self.0 { + req.head.extensions.get_mut().drain_from(ext); + } + } +} diff --git a/actix-http/src/service.rs b/actix-http/src/service.rs index af625b1bf..eb16a6e70 100644 --- a/actix-http/src/service.rs +++ b/actix-http/src/service.rs @@ -1,3 +1,4 @@ +use std::cell::RefCell; use std::marker::PhantomData; use std::pin::Pin; use std::task::{Context, Poll}; @@ -14,12 +15,11 @@ use pin_project::pin_project; use crate::body::MessageBody; use crate::builder::HttpServiceBuilder; -use crate::cloneable::CloneableService; use crate::config::{KeepAlive, ServiceConfig}; use crate::error::{DispatchError, Error}; use crate::request::Request; use crate::response::Response; -use crate::{h1, h2::Dispatcher, ConnectCallback, Extensions, Protocol}; +use crate::{h1, h2::Dispatcher, ConnectCallback, OnConnectData, Protocol}; /// A `ServiceFactory` for HTTP/1.1 or HTTP/2 protocol. pub struct HttpService { @@ -371,7 +371,7 @@ where upgrade: Option, on_connect_ext: Option>>, cfg: ServiceConfig, - _phantom: PhantomData<(T, B)>, + _phantom: PhantomData, } impl Future for HttpServiceResponse @@ -441,14 +441,29 @@ where X: Service, U: Service<(Request, Framed)>, { - srv: CloneableService, - expect: CloneableService, - upgrade: Option>, + services: Rc>>, cfg: ServiceConfig, on_connect_ext: Option>>, _phantom: PhantomData, } +// a collection of service for http. +pub(super) struct HttpFlow { + pub(super) service: S, + pub(super) expect: X, + pub(super) upgrade: Option, +} + +impl HttpFlow { + pub(super) fn new(service: S, expect: X, upgrade: Option) -> Rc> { + Rc::new(RefCell::new(Self { + service, + expect, + upgrade, + })) + } +} + impl HttpServiceHandler where S: Service, @@ -463,7 +478,7 @@ where { fn new( cfg: ServiceConfig, - srv: S, + service: S, expect: X, upgrade: Option, on_connect_ext: Option>>, @@ -471,9 +486,7 @@ where HttpServiceHandler { cfg, on_connect_ext, - srv: CloneableService::new(srv), - expect: CloneableService::new(expect), - upgrade: upgrade.map(CloneableService::new), + services: HttpFlow::new(service, expect, upgrade), _phantom: PhantomData, } } @@ -498,7 +511,8 @@ where type Future = HttpServiceHandlerResponse; fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { - let ready = self + let mut services = self.services.borrow_mut(); + let ready = services .expect .poll_ready(cx) .map_err(|e| { @@ -508,8 +522,8 @@ where })? .is_ready(); - let ready = self - .srv + let ready = services + .service .poll_ready(cx) .map_err(|e| { let e = e.into(); @@ -519,7 +533,7 @@ where .is_ready() && ready; - let ready = if let Some(ref mut upg) = self.upgrade { + let ready = if let Some(ref mut upg) = services.upgrade { upg.poll_ready(cx) .map_err(|e| { let e = e.into(); @@ -543,19 +557,16 @@ where &mut self, (io, proto, peer_addr): (T, Protocol, Option), ) -> Self::Future { - let mut connect_extensions = Extensions::new(); - - if let Some(ref handler) = self.on_connect_ext { - handler(&io, &mut connect_extensions); - } + let on_connect_data = + OnConnectData::from_io(&io, self.on_connect_ext.as_deref()); match proto { Protocol::Http2 => HttpServiceHandlerResponse { state: State::H2Handshake(Some(( server::handshake(io), self.cfg.clone(), - self.srv.clone(), - connect_extensions, + self.services.clone(), + on_connect_data, peer_addr, ))), }, @@ -564,10 +575,8 @@ where state: State::H1(h1::Dispatcher::new( io, self.cfg.clone(), - self.srv.clone(), - self.expect.clone(), - self.upgrade.clone(), - connect_extensions, + self.services.clone(), + on_connect_data, peer_addr, )), }, @@ -589,13 +598,13 @@ where U::Error: fmt::Display, { H1(#[pin] h1::Dispatcher), - H2(#[pin] Dispatcher), + H2(#[pin] Dispatcher), H2Handshake( Option<( Handshake, ServiceConfig, - CloneableService, - Extensions, + Rc>>, + OnConnectData, Option, )>, ),