1
0
Fork 0
mirror of https://github.com/actix/actix-web.git synced 2024-11-18 15:41:17 +00:00

replace cloneable service with httpflow abstraction (#1876)

This commit is contained in:
fakeshadow 2021-01-07 02:43:52 +08:00 committed by GitHub
parent 57a3722146
commit a03dbe2dcf
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 190 additions and 187 deletions

View file

@ -1,39 +0,0 @@
use std::cell::RefCell;
use std::rc::Rc;
use std::task::{Context, Poll};
use actix_service::Service;
/// Service that allows to turn non-clone service to a service with `Clone` impl
///
/// # Panics
/// CloneableService might panic with some creative use of thread local storage.
/// See https://github.com/actix/actix-web/issues/1295 for example
#[doc(hidden)]
pub(crate) struct CloneableService<T>(Rc<RefCell<T>>);
impl<T> CloneableService<T> {
pub(crate) fn new(service: T) -> Self {
Self(Rc::new(RefCell::new(service)))
}
}
impl<T> Clone for CloneableService<T> {
fn clone(&self) -> Self {
Self(self.0.clone())
}
}
impl<T: Service<Req>, Req> Service<Req> for CloneableService<T> {
type Response = T::Response;
type Error = T::Error;
type Future = T::Future;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.0.borrow_mut().poll_ready(cx)
}
fn call(&mut self, req: Req) -> Self::Future {
self.0.borrow_mut().call(req)
}
}

View file

@ -1,9 +1,11 @@
use std::{ use std::{
cell::RefCell,
collections::VecDeque, collections::VecDeque,
fmt, fmt,
future::Future, future::Future,
io, mem, net, io, mem, net,
pin::Pin, pin::Pin,
rc::Rc,
task::{Context, Poll}, task::{Context, Poll},
}; };
@ -15,17 +17,14 @@ use bytes::{Buf, BytesMut};
use log::{error, trace}; use log::{error, trace};
use pin_project::pin_project; use pin_project::pin_project;
use crate::cloneable::CloneableService; use crate::body::{Body, BodySize, MessageBody, ResponseBody};
use crate::config::ServiceConfig; use crate::config::ServiceConfig;
use crate::error::{DispatchError, Error}; use crate::error::{DispatchError, Error};
use crate::error::{ParseError, PayloadError}; use crate::error::{ParseError, PayloadError};
use crate::httpmessage::HttpMessage;
use crate::request::Request; use crate::request::Request;
use crate::response::Response; use crate::response::Response;
use crate::{ use crate::service::HttpFlow;
body::{Body, BodySize, MessageBody, ResponseBody}, use crate::OnConnectData;
Extensions,
};
use super::codec::Codec; use super::codec::Codec;
use super::payload::{Payload, PayloadSender, PayloadStatus}; use super::payload::{Payload, PayloadSender, PayloadStatus};
@ -78,7 +77,7 @@ where
U::Error: fmt::Display, U::Error: fmt::Display,
{ {
Normal(#[pin] InnerDispatcher<T, S, B, X, U>), Normal(#[pin] InnerDispatcher<T, S, B, X, U>),
Upgrade(Pin<Box<U::Future>>), Upgrade(#[pin] U::Future),
} }
#[pin_project(project = InnerDispatcherProj)] #[pin_project(project = InnerDispatcherProj)]
@ -92,10 +91,8 @@ where
U: Service<(Request, Framed<T, Codec>), Response = ()>, U: Service<(Request, Framed<T, Codec>), Response = ()>,
U::Error: fmt::Display, U::Error: fmt::Display,
{ {
service: CloneableService<S>, services: Rc<RefCell<HttpFlow<S, X, U>>>,
expect: CloneableService<X>, on_connect_data: OnConnectData,
upgrade: Option<CloneableService<U>>,
on_connect_data: Extensions,
flags: Flags, flags: Flags,
peer_addr: Option<net::SocketAddr>, peer_addr: Option<net::SocketAddr>,
error: Option<DispatchError>, error: Option<DispatchError>,
@ -180,10 +177,8 @@ where
pub(crate) fn new( pub(crate) fn new(
stream: T, stream: T,
config: ServiceConfig, config: ServiceConfig,
service: CloneableService<S>, services: Rc<RefCell<HttpFlow<S, X, U>>>,
expect: CloneableService<X>, on_connect_data: OnConnectData,
upgrade: Option<CloneableService<U>>,
on_connect_data: Extensions,
peer_addr: Option<net::SocketAddr>, peer_addr: Option<net::SocketAddr>,
) -> Self { ) -> Self {
Dispatcher::with_timeout( Dispatcher::with_timeout(
@ -192,9 +187,7 @@ where
config, config,
BytesMut::with_capacity(HW_BUFFER_SIZE), BytesMut::with_capacity(HW_BUFFER_SIZE),
None, None,
service, services,
expect,
upgrade,
on_connect_data, on_connect_data,
peer_addr, peer_addr,
) )
@ -207,10 +200,8 @@ where
config: ServiceConfig, config: ServiceConfig,
read_buf: BytesMut, read_buf: BytesMut,
timeout: Option<Sleep>, timeout: Option<Sleep>,
service: CloneableService<S>, services: Rc<RefCell<HttpFlow<S, X, U>>>,
expect: CloneableService<X>, on_connect_data: OnConnectData,
upgrade: Option<CloneableService<U>>,
on_connect_data: Extensions,
peer_addr: Option<net::SocketAddr>, peer_addr: Option<net::SocketAddr>,
) -> Self { ) -> Self {
let keepalive = config.keep_alive_enabled(); let keepalive = config.keep_alive_enabled();
@ -239,9 +230,7 @@ where
io: Some(io), io: Some(io),
codec, codec,
read_buf, read_buf,
service, services,
expect,
upgrade,
on_connect_data, on_connect_data,
flags, flags,
peer_addr, peer_addr,
@ -395,7 +384,8 @@ where
Poll::Ready(Ok(req)) => { Poll::Ready(Ok(req)) => {
self.as_mut().send_continue(); self.as_mut().send_continue();
this = self.as_mut().project(); this = self.as_mut().project();
this.state.set(State::ServiceCall(this.service.call(req))); let fut = this.services.borrow_mut().service.call(req);
this.state.set(State::ServiceCall(fut));
continue; continue;
} }
Poll::Ready(Err(e)) => { Poll::Ready(Err(e)) => {
@ -483,12 +473,14 @@ where
// Handle `EXPECT: 100-Continue` header // Handle `EXPECT: 100-Continue` header
if req.head().expect() { if req.head().expect() {
// set dispatcher state so the future is pinned. // set dispatcher state so the future is pinned.
let task = self.as_mut().project().expect.call(req); let mut this = self.as_mut().project();
self.as_mut().project().state.set(State::ExpectCall(task)); let task = this.services.borrow_mut().expect.call(req);
this.state.set(State::ExpectCall(task));
} else { } else {
// the same as above. // the same as above.
let task = self.as_mut().project().service.call(req); let mut this = self.as_mut().project();
self.as_mut().project().state.set(State::ServiceCall(task)); let task = this.services.borrow_mut().service.call(req);
this.state.set(State::ServiceCall(task));
}; };
// eagerly poll the future for once(or twice if expect is resolved immediately). // eagerly poll the future for once(or twice if expect is resolved immediately).
@ -499,8 +491,9 @@ where
// expect is resolved. continue loop and poll the service call branch. // expect is resolved. continue loop and poll the service call branch.
Poll::Ready(Ok(req)) => { Poll::Ready(Ok(req)) => {
self.as_mut().send_continue(); self.as_mut().send_continue();
let task = self.as_mut().project().service.call(req); let mut this = self.as_mut().project();
self.as_mut().project().state.set(State::ServiceCall(task)); let task = this.services.borrow_mut().service.call(req);
this.state.set(State::ServiceCall(task));
continue; continue;
} }
// future is pending. return Ok(()) to notify that a new state is // future is pending. return Ok(()) to notify that a new state is
@ -568,9 +561,11 @@ where
req.head_mut().peer_addr = *this.peer_addr; req.head_mut().peer_addr = *this.peer_addr;
// merge on_connect_ext data into request extensions // merge on_connect_ext data into request extensions
req.extensions_mut().drain_from(this.on_connect_data); this.on_connect_data.merge_into(&mut req);
if pl == MessageType::Stream && this.upgrade.is_some() { if pl == MessageType::Stream
&& this.services.borrow().upgrade.is_some()
{
this.messages.push_back(DispatcherMessage::Upgrade(req)); this.messages.push_back(DispatcherMessage::Upgrade(req));
break; break;
} }
@ -834,12 +829,17 @@ where
); );
parts.write_buf = mem::take(inner_p.write_buf); parts.write_buf = mem::take(inner_p.write_buf);
let framed = Framed::from_parts(parts); let framed = Framed::from_parts(parts);
let upgrade = let upgrade = inner_p
inner_p.upgrade.take().unwrap().call((req, framed)); .services
.borrow_mut()
.upgrade
.take()
.unwrap()
.call((req, framed));
self.as_mut() self.as_mut()
.project() .project()
.inner .inner
.set(DispatcherState::Upgrade(Box::pin(upgrade))); .set(DispatcherState::Upgrade(upgrade));
return self.poll(cx); return self.poll(cx);
} }
@ -890,7 +890,7 @@ where
} }
} }
} }
DispatcherStateProj::Upgrade(fut) => fut.as_mut().poll(cx).map_err(|e| { DispatcherStateProj::Upgrade(fut) => fut.poll(cx).map_err(|e| {
error!("Upgrade handler error: {}", e); error!("Upgrade handler error: {}", e);
DispatchError::Upgrade DispatchError::Upgrade
}), }),
@ -1028,13 +1028,13 @@ mod tests {
lazy(|cx| { lazy(|cx| {
let buf = TestBuffer::new("GET /test HTTP/1\r\n\r\n"); let buf = TestBuffer::new("GET /test HTTP/1\r\n\r\n");
let services = HttpFlow::new(ok_service(), ExpectHandler, None);
let h1 = Dispatcher::<_, _, _, _, UpgradeHandler>::new( let h1 = Dispatcher::<_, _, _, _, UpgradeHandler>::new(
buf, buf,
ServiceConfig::default(), ServiceConfig::default(),
CloneableService::new(ok_service()), services,
CloneableService::new(ExpectHandler), OnConnectData::default(),
None,
Extensions::new(),
None, None,
); );
@ -1068,13 +1068,13 @@ mod tests {
let cfg = ServiceConfig::new(KeepAlive::Disabled, 1, 1, false, None); let cfg = ServiceConfig::new(KeepAlive::Disabled, 1, 1, false, None);
let services = HttpFlow::new(echo_path_service(), ExpectHandler, None);
let h1 = Dispatcher::<_, _, _, _, UpgradeHandler>::new( let h1 = Dispatcher::<_, _, _, _, UpgradeHandler>::new(
buf, buf,
cfg, cfg,
CloneableService::new(echo_path_service()), services,
CloneableService::new(ExpectHandler), OnConnectData::default(),
None,
Extensions::new(),
None, None,
); );
@ -1122,13 +1122,13 @@ mod tests {
let cfg = ServiceConfig::new(KeepAlive::Disabled, 1, 1, false, None); let cfg = ServiceConfig::new(KeepAlive::Disabled, 1, 1, false, None);
let services = HttpFlow::new(echo_path_service(), ExpectHandler, None);
let h1 = Dispatcher::<_, _, _, _, UpgradeHandler>::new( let h1 = Dispatcher::<_, _, _, _, UpgradeHandler>::new(
buf, buf,
cfg, cfg,
CloneableService::new(echo_path_service()), services,
CloneableService::new(ExpectHandler), OnConnectData::default(),
None,
Extensions::new(),
None, None,
); );
@ -1171,13 +1171,14 @@ mod tests {
lazy(|cx| { lazy(|cx| {
let mut buf = TestSeqBuffer::empty(); let mut buf = TestSeqBuffer::empty();
let cfg = ServiceConfig::new(KeepAlive::Disabled, 0, 0, false, None); let cfg = ServiceConfig::new(KeepAlive::Disabled, 0, 0, false, None);
let services = HttpFlow::new(echo_payload_service(), ExpectHandler, None);
let h1 = Dispatcher::<_, _, _, _, UpgradeHandler>::new( let h1 = Dispatcher::<_, _, _, _, UpgradeHandler>::new(
buf.clone(), buf.clone(),
cfg, cfg,
CloneableService::new(echo_payload_service()), services,
CloneableService::new(ExpectHandler), OnConnectData::default(),
None,
Extensions::new(),
None, None,
); );
@ -1242,13 +1243,14 @@ mod tests {
lazy(|cx| { lazy(|cx| {
let mut buf = TestSeqBuffer::empty(); let mut buf = TestSeqBuffer::empty();
let cfg = ServiceConfig::new(KeepAlive::Disabled, 0, 0, false, None); let cfg = ServiceConfig::new(KeepAlive::Disabled, 0, 0, false, None);
let services = HttpFlow::new(echo_path_service(), ExpectHandler, None);
let h1 = Dispatcher::<_, _, _, _, UpgradeHandler>::new( let h1 = Dispatcher::<_, _, _, _, UpgradeHandler>::new(
buf.clone(), buf.clone(),
cfg, cfg,
CloneableService::new(echo_path_service()), services,
CloneableService::new(ExpectHandler), OnConnectData::default(),
None,
Extensions::new(),
None, None,
); );
@ -1301,13 +1303,15 @@ mod tests {
lazy(|cx| { lazy(|cx| {
let mut buf = TestSeqBuffer::empty(); let mut buf = TestSeqBuffer::empty();
let cfg = ServiceConfig::new(KeepAlive::Disabled, 0, 0, false, None); let cfg = ServiceConfig::new(KeepAlive::Disabled, 0, 0, false, None);
let services =
HttpFlow::new(ok_service(), ExpectHandler, Some(UpgradeHandler));
let h1 = Dispatcher::<_, _, _, _, UpgradeHandler>::new( let h1 = Dispatcher::<_, _, _, _, UpgradeHandler>::new(
buf.clone(), buf.clone(),
cfg, cfg,
CloneableService::new(ok_service()), services,
CloneableService::new(ExpectHandler), OnConnectData::default(),
Some(CloneableService::new(UpgradeHandler)),
Extensions::new(),
None, None,
); );

View file

@ -1,3 +1,4 @@
use std::cell::RefCell;
use std::future::Future; use std::future::Future;
use std::marker::PhantomData; use std::marker::PhantomData;
use std::pin::Pin; use std::pin::Pin;
@ -12,12 +13,12 @@ use futures_core::ready;
use futures_util::future::ready; use futures_util::future::ready;
use crate::body::MessageBody; use crate::body::MessageBody;
use crate::cloneable::CloneableService;
use crate::config::ServiceConfig; use crate::config::ServiceConfig;
use crate::error::{DispatchError, Error}; use crate::error::{DispatchError, Error};
use crate::request::Request; use crate::request::Request;
use crate::response::Response; use crate::response::Response;
use crate::{ConnectCallback, Extensions}; use crate::service::HttpFlow;
use crate::{ConnectCallback, OnConnectData};
use super::codec::Codec; use super::codec::Codec;
use super::dispatcher::Dispatcher; use super::dispatcher::Dispatcher;
@ -299,7 +300,7 @@ where
upgrade: Option<U::Service>, upgrade: Option<U::Service>,
on_connect_ext: Option<Rc<ConnectCallback<T>>>, on_connect_ext: Option<Rc<ConnectCallback<T>>>,
cfg: Option<ServiceConfig>, cfg: Option<ServiceConfig>,
_phantom: PhantomData<(T, B)>, _phantom: PhantomData<B>,
} }
impl<T, S, B, X, U> Future for H1ServiceResponse<T, S, B, X, U> impl<T, S, B, X, U> Future for H1ServiceResponse<T, S, B, X, U>
@ -366,9 +367,7 @@ where
X: Service<Request>, X: Service<Request>,
U: Service<(Request, Framed<T, Codec>)>, U: Service<(Request, Framed<T, Codec>)>,
{ {
srv: CloneableService<S>, services: Rc<RefCell<HttpFlow<S, X, U>>>,
expect: CloneableService<X>,
upgrade: Option<CloneableService<U>>,
on_connect_ext: Option<Rc<ConnectCallback<T>>>, on_connect_ext: Option<Rc<ConnectCallback<T>>>,
cfg: ServiceConfig, cfg: ServiceConfig,
_phantom: PhantomData<B>, _phantom: PhantomData<B>,
@ -387,15 +386,13 @@ where
{ {
fn new( fn new(
cfg: ServiceConfig, cfg: ServiceConfig,
srv: S, service: S,
expect: X, expect: X,
upgrade: Option<U>, upgrade: Option<U>,
on_connect_ext: Option<Rc<ConnectCallback<T>>>, on_connect_ext: Option<Rc<ConnectCallback<T>>>,
) -> H1ServiceHandler<T, S, B, X, U> { ) -> H1ServiceHandler<T, S, B, X, U> {
H1ServiceHandler { H1ServiceHandler {
srv: CloneableService::new(srv), services: HttpFlow::new(service, expect, upgrade),
expect: CloneableService::new(expect),
upgrade: upgrade.map(CloneableService::new),
cfg, cfg,
on_connect_ext, on_connect_ext,
_phantom: PhantomData, _phantom: PhantomData,
@ -421,7 +418,8 @@ where
type Future = Dispatcher<T, S, B, X, U>; type Future = Dispatcher<T, S, B, X, U>;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
let ready = self let mut services = self.services.borrow_mut();
let ready = services
.expect .expect
.poll_ready(cx) .poll_ready(cx)
.map_err(|e| { .map_err(|e| {
@ -431,8 +429,8 @@ where
})? })?
.is_ready(); .is_ready();
let ready = self let ready = services
.srv .service
.poll_ready(cx) .poll_ready(cx)
.map_err(|e| { .map_err(|e| {
let e = e.into(); let e = e.into();
@ -442,7 +440,7 @@ where
.is_ready() .is_ready()
&& ready; && ready;
let ready = if let Some(ref mut upg) = self.upgrade { let ready = if let Some(ref mut upg) = services.upgrade {
upg.poll_ready(cx) upg.poll_ready(cx)
.map_err(|e| { .map_err(|e| {
let e = e.into(); let e = e.into();
@ -463,19 +461,14 @@ where
} }
fn call(&mut self, (io, addr): (T, Option<net::SocketAddr>)) -> Self::Future { fn call(&mut self, (io, addr): (T, Option<net::SocketAddr>)) -> Self::Future {
let mut connect_extensions = Extensions::new(); let on_connect_data =
if let Some(ref handler) = self.on_connect_ext { OnConnectData::from_io(&io, self.on_connect_ext.as_deref());
// run on_connect_ext callback, populating connect extensions
handler(&io, &mut connect_extensions);
}
Dispatcher::new( Dispatcher::new(
io, io,
self.cfg.clone(), self.cfg.clone(),
self.srv.clone(), self.services.clone(),
self.expect.clone(), on_connect_data,
self.upgrade.clone(),
connect_extensions,
addr, addr,
) )
} }

View file

@ -1,7 +1,9 @@
use std::cell::RefCell;
use std::future::Future; use std::future::Future;
use std::marker::PhantomData; use std::marker::PhantomData;
use std::net; use std::net;
use std::pin::Pin; use std::pin::Pin;
use std::rc::Rc;
use std::task::{Context, Poll}; use std::task::{Context, Poll};
use std::{cmp, convert::TryFrom}; use std::{cmp, convert::TryFrom};
@ -16,29 +18,28 @@ use http::header::{HeaderValue, CONNECTION, CONTENT_LENGTH, DATE, TRANSFER_ENCOD
use log::{error, trace}; use log::{error, trace};
use crate::body::{BodySize, MessageBody, ResponseBody}; use crate::body::{BodySize, MessageBody, ResponseBody};
use crate::cloneable::CloneableService;
use crate::config::ServiceConfig; use crate::config::ServiceConfig;
use crate::error::{DispatchError, Error}; use crate::error::{DispatchError, Error};
use crate::httpmessage::HttpMessage;
use crate::message::ResponseHead; use crate::message::ResponseHead;
use crate::payload::Payload; use crate::payload::Payload;
use crate::request::Request; use crate::request::Request;
use crate::response::Response; use crate::response::Response;
use crate::Extensions; use crate::service::HttpFlow;
use crate::OnConnectData;
const CHUNK_SIZE: usize = 16_384; const CHUNK_SIZE: usize = 16_384;
/// Dispatcher for HTTP/2 protocol. /// Dispatcher for HTTP/2 protocol.
#[pin_project::pin_project] #[pin_project::pin_project]
pub struct Dispatcher<T, S, B> pub struct Dispatcher<T, S, B, X, U>
where where
T: AsyncRead + AsyncWrite + Unpin, T: AsyncRead + AsyncWrite + Unpin,
S: Service<Request>, S: Service<Request>,
B: MessageBody, B: MessageBody,
{ {
service: CloneableService<S>, services: Rc<RefCell<HttpFlow<S, X, U>>>,
connection: Connection<T, Bytes>, connection: Connection<T, Bytes>,
on_connect_data: Extensions, on_connect_data: OnConnectData,
config: ServiceConfig, config: ServiceConfig,
peer_addr: Option<net::SocketAddr>, peer_addr: Option<net::SocketAddr>,
ka_expire: Instant, ka_expire: Instant,
@ -46,7 +47,7 @@ where
_phantom: PhantomData<B>, _phantom: PhantomData<B>,
} }
impl<T, S, B> Dispatcher<T, S, B> impl<T, S, B, X, U> Dispatcher<T, S, B, X, U>
where where
T: AsyncRead + AsyncWrite + Unpin, T: AsyncRead + AsyncWrite + Unpin,
S: Service<Request>, S: Service<Request>,
@ -55,9 +56,9 @@ where
B: MessageBody, B: MessageBody,
{ {
pub(crate) fn new( pub(crate) fn new(
service: CloneableService<S>, services: Rc<RefCell<HttpFlow<S, X, U>>>,
connection: Connection<T, Bytes>, connection: Connection<T, Bytes>,
on_connect_data: Extensions, on_connect_data: OnConnectData,
config: ServiceConfig, config: ServiceConfig,
timeout: Option<Sleep>, timeout: Option<Sleep>,
peer_addr: Option<net::SocketAddr>, peer_addr: Option<net::SocketAddr>,
@ -79,7 +80,7 @@ where
}; };
Dispatcher { Dispatcher {
service, services,
config, config,
peer_addr, peer_addr,
connection, connection,
@ -91,7 +92,7 @@ where
} }
} }
impl<T, S, B> Future for Dispatcher<T, S, B> impl<T, S, B, X, U> Future for Dispatcher<T, S, B, X, U>
where where
T: AsyncRead + AsyncWrite + Unpin, T: AsyncRead + AsyncWrite + Unpin,
S: Service<Request>, S: Service<Request>,
@ -133,11 +134,11 @@ where
head.peer_addr = this.peer_addr; head.peer_addr = this.peer_addr;
// merge on_connect_ext data into request extensions // merge on_connect_ext data into request extensions
req.extensions_mut().drain_from(&mut this.on_connect_data); this.on_connect_data.merge_into(&mut req);
let svc = ServiceResponse::<S::Future, S::Response, S::Error, B> { let svc = ServiceResponse::<S::Future, S::Response, S::Error, B> {
state: ServiceResponseState::ServiceCall( state: ServiceResponseState::ServiceCall(
this.service.call(req), this.services.borrow_mut().service.call(req),
Some(res), Some(res),
), ),
config: this.config.clone(), config: this.config.clone(),

View file

@ -1,3 +1,4 @@
use std::cell::RefCell;
use std::future::Future; use std::future::Future;
use std::marker::PhantomData; use std::marker::PhantomData;
use std::pin::Pin; use std::pin::Pin;
@ -17,12 +18,12 @@ use h2::server::{self, Handshake};
use log::error; use log::error;
use crate::body::MessageBody; use crate::body::MessageBody;
use crate::cloneable::CloneableService;
use crate::config::ServiceConfig; use crate::config::ServiceConfig;
use crate::error::{DispatchError, Error}; use crate::error::{DispatchError, Error};
use crate::request::Request; use crate::request::Request;
use crate::response::Response; use crate::response::Response;
use crate::{ConnectCallback, Extensions}; use crate::service::HttpFlow;
use crate::{ConnectCallback, OnConnectData};
use super::dispatcher::Dispatcher; use super::dispatcher::Dispatcher;
@ -248,7 +249,7 @@ pub struct H2ServiceHandler<T, S, B>
where where
S: Service<Request>, S: Service<Request>,
{ {
srv: CloneableService<S>, services: Rc<RefCell<HttpFlow<S, (), ()>>>,
cfg: ServiceConfig, cfg: ServiceConfig,
on_connect_ext: Option<Rc<ConnectCallback<T>>>, on_connect_ext: Option<Rc<ConnectCallback<T>>>,
_phantom: PhantomData<B>, _phantom: PhantomData<B>,
@ -265,12 +266,12 @@ where
fn new( fn new(
cfg: ServiceConfig, cfg: ServiceConfig,
on_connect_ext: Option<Rc<ConnectCallback<T>>>, on_connect_ext: Option<Rc<ConnectCallback<T>>>,
srv: S, service: S,
) -> H2ServiceHandler<T, S, B> { ) -> H2ServiceHandler<T, S, B> {
H2ServiceHandler { H2ServiceHandler {
services: HttpFlow::new(service, (), None),
cfg, cfg,
on_connect_ext, on_connect_ext,
srv: CloneableService::new(srv),
_phantom: PhantomData, _phantom: PhantomData,
} }
} }
@ -290,7 +291,11 @@ where
type Future = H2ServiceHandlerResponse<T, S, B>; type Future = H2ServiceHandlerResponse<T, S, B>;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.srv.poll_ready(cx).map_err(|e| { self.services
.borrow_mut()
.service
.poll_ready(cx)
.map_err(|e| {
let e = e.into(); let e = e.into();
error!("Service readiness error: {:?}", e); error!("Service readiness error: {:?}", e);
DispatchError::Service(e) DispatchError::Service(e)
@ -298,18 +303,15 @@ where
} }
fn call(&mut self, (io, addr): (T, Option<net::SocketAddr>)) -> Self::Future { fn call(&mut self, (io, addr): (T, Option<net::SocketAddr>)) -> Self::Future {
let mut connect_extensions = Extensions::new(); let on_connect_data =
if let Some(ref handler) = self.on_connect_ext { OnConnectData::from_io(&io, self.on_connect_ext.as_deref());
// run on_connect_ext callback, populating connect extensions
handler(&io, &mut connect_extensions);
}
H2ServiceHandlerResponse { H2ServiceHandlerResponse {
state: State::Handshake( state: State::Handshake(
Some(self.srv.clone()), Some(self.services.clone()),
Some(self.cfg.clone()), Some(self.cfg.clone()),
addr, addr,
Some(connect_extensions), on_connect_data,
server::handshake(io), server::handshake(io),
), ),
} }
@ -321,12 +323,12 @@ where
T: AsyncRead + AsyncWrite + Unpin, T: AsyncRead + AsyncWrite + Unpin,
S::Future: 'static, S::Future: 'static,
{ {
Incoming(Dispatcher<T, S, B>), Incoming(Dispatcher<T, S, B, (), ()>),
Handshake( Handshake(
Option<CloneableService<S>>, Option<Rc<RefCell<HttpFlow<S, (), ()>>>>,
Option<ServiceConfig>, Option<ServiceConfig>,
Option<net::SocketAddr>, Option<net::SocketAddr>,
Option<Extensions>, OnConnectData,
Handshake<T, Bytes>, Handshake<T, Bytes>,
), ),
} }
@ -365,10 +367,11 @@ where
ref mut handshake, ref mut handshake,
) => match ready!(Pin::new(handshake).poll(cx)) { ) => match ready!(Pin::new(handshake).poll(cx)) {
Ok(conn) => { Ok(conn) => {
let on_connect_data = std::mem::take(on_connect_data);
self.state = State::Incoming(Dispatcher::new( self.state = State::Incoming(Dispatcher::new(
srv.take().unwrap(), srv.take().unwrap(),
conn, conn,
on_connect_data.take().unwrap(), on_connect_data,
config.take().unwrap(), config.take().unwrap(),
None, None,
*peer_addr, *peer_addr,

View file

@ -19,7 +19,6 @@ mod macros;
pub mod body; pub mod body;
mod builder; mod builder;
pub mod client; pub mod client;
mod cloneable;
mod config; mod config;
#[cfg(feature = "compress")] #[cfg(feature = "compress")]
pub mod encoding; pub mod encoding;
@ -81,3 +80,36 @@ pub enum Protocol {
} }
type ConnectCallback<IO> = dyn Fn(&IO, &mut Extensions); type ConnectCallback<IO> = dyn Fn(&IO, &mut Extensions);
/// Container for data that extract with ConnectCallback.
pub(crate) struct OnConnectData(Option<Extensions>);
impl Default for OnConnectData {
fn default() -> Self {
Self(None)
}
}
impl OnConnectData {
// construct self from io.
pub(crate) fn from_io<T>(
io: &T,
on_connect_ext: Option<&ConnectCallback<T>>,
) -> Self {
let ext = on_connect_ext.map(|handler| {
let mut extensions = Extensions::new();
handler(io, &mut extensions);
extensions
});
Self(ext)
}
// merge self to given request's head extension.
#[inline]
pub(crate) fn merge_into(&mut self, req: &mut Request) {
if let Some(ref mut ext) = self.0 {
req.head.extensions.get_mut().drain_from(ext);
}
}
}

View file

@ -1,3 +1,4 @@
use std::cell::RefCell;
use std::marker::PhantomData; use std::marker::PhantomData;
use std::pin::Pin; use std::pin::Pin;
use std::task::{Context, Poll}; use std::task::{Context, Poll};
@ -14,12 +15,11 @@ use pin_project::pin_project;
use crate::body::MessageBody; use crate::body::MessageBody;
use crate::builder::HttpServiceBuilder; use crate::builder::HttpServiceBuilder;
use crate::cloneable::CloneableService;
use crate::config::{KeepAlive, ServiceConfig}; use crate::config::{KeepAlive, ServiceConfig};
use crate::error::{DispatchError, Error}; use crate::error::{DispatchError, Error};
use crate::request::Request; use crate::request::Request;
use crate::response::Response; use crate::response::Response;
use crate::{h1, h2::Dispatcher, ConnectCallback, Extensions, Protocol}; use crate::{h1, h2::Dispatcher, ConnectCallback, OnConnectData, Protocol};
/// A `ServiceFactory` for HTTP/1.1 or HTTP/2 protocol. /// A `ServiceFactory` for HTTP/1.1 or HTTP/2 protocol.
pub struct HttpService<T, S, B, X = h1::ExpectHandler, U = h1::UpgradeHandler> { pub struct HttpService<T, S, B, X = h1::ExpectHandler, U = h1::UpgradeHandler> {
@ -371,7 +371,7 @@ where
upgrade: Option<U::Service>, upgrade: Option<U::Service>,
on_connect_ext: Option<Rc<ConnectCallback<T>>>, on_connect_ext: Option<Rc<ConnectCallback<T>>>,
cfg: ServiceConfig, cfg: ServiceConfig,
_phantom: PhantomData<(T, B)>, _phantom: PhantomData<B>,
} }
impl<T, S, B, X, U> Future for HttpServiceResponse<T, S, B, X, U> impl<T, S, B, X, U> Future for HttpServiceResponse<T, S, B, X, U>
@ -441,14 +441,29 @@ where
X: Service<Request>, X: Service<Request>,
U: Service<(Request, Framed<T, h1::Codec>)>, U: Service<(Request, Framed<T, h1::Codec>)>,
{ {
srv: CloneableService<S>, services: Rc<RefCell<HttpFlow<S, X, U>>>,
expect: CloneableService<X>,
upgrade: Option<CloneableService<U>>,
cfg: ServiceConfig, cfg: ServiceConfig,
on_connect_ext: Option<Rc<ConnectCallback<T>>>, on_connect_ext: Option<Rc<ConnectCallback<T>>>,
_phantom: PhantomData<B>, _phantom: PhantomData<B>,
} }
// a collection of service for http.
pub(super) struct HttpFlow<S, X, U> {
pub(super) service: S,
pub(super) expect: X,
pub(super) upgrade: Option<U>,
}
impl<S, X, U> HttpFlow<S, X, U> {
pub(super) fn new(service: S, expect: X, upgrade: Option<U>) -> Rc<RefCell<Self>> {
Rc::new(RefCell::new(Self {
service,
expect,
upgrade,
}))
}
}
impl<T, S, B, X, U> HttpServiceHandler<T, S, B, X, U> impl<T, S, B, X, U> HttpServiceHandler<T, S, B, X, U>
where where
S: Service<Request>, S: Service<Request>,
@ -463,7 +478,7 @@ where
{ {
fn new( fn new(
cfg: ServiceConfig, cfg: ServiceConfig,
srv: S, service: S,
expect: X, expect: X,
upgrade: Option<U>, upgrade: Option<U>,
on_connect_ext: Option<Rc<ConnectCallback<T>>>, on_connect_ext: Option<Rc<ConnectCallback<T>>>,
@ -471,9 +486,7 @@ where
HttpServiceHandler { HttpServiceHandler {
cfg, cfg,
on_connect_ext, on_connect_ext,
srv: CloneableService::new(srv), services: HttpFlow::new(service, expect, upgrade),
expect: CloneableService::new(expect),
upgrade: upgrade.map(CloneableService::new),
_phantom: PhantomData, _phantom: PhantomData,
} }
} }
@ -498,7 +511,8 @@ where
type Future = HttpServiceHandlerResponse<T, S, B, X, U>; type Future = HttpServiceHandlerResponse<T, S, B, X, U>;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
let ready = self let mut services = self.services.borrow_mut();
let ready = services
.expect .expect
.poll_ready(cx) .poll_ready(cx)
.map_err(|e| { .map_err(|e| {
@ -508,8 +522,8 @@ where
})? })?
.is_ready(); .is_ready();
let ready = self let ready = services
.srv .service
.poll_ready(cx) .poll_ready(cx)
.map_err(|e| { .map_err(|e| {
let e = e.into(); let e = e.into();
@ -519,7 +533,7 @@ where
.is_ready() .is_ready()
&& ready; && ready;
let ready = if let Some(ref mut upg) = self.upgrade { let ready = if let Some(ref mut upg) = services.upgrade {
upg.poll_ready(cx) upg.poll_ready(cx)
.map_err(|e| { .map_err(|e| {
let e = e.into(); let e = e.into();
@ -543,19 +557,16 @@ where
&mut self, &mut self,
(io, proto, peer_addr): (T, Protocol, Option<net::SocketAddr>), (io, proto, peer_addr): (T, Protocol, Option<net::SocketAddr>),
) -> Self::Future { ) -> Self::Future {
let mut connect_extensions = Extensions::new(); let on_connect_data =
OnConnectData::from_io(&io, self.on_connect_ext.as_deref());
if let Some(ref handler) = self.on_connect_ext {
handler(&io, &mut connect_extensions);
}
match proto { match proto {
Protocol::Http2 => HttpServiceHandlerResponse { Protocol::Http2 => HttpServiceHandlerResponse {
state: State::H2Handshake(Some(( state: State::H2Handshake(Some((
server::handshake(io), server::handshake(io),
self.cfg.clone(), self.cfg.clone(),
self.srv.clone(), self.services.clone(),
connect_extensions, on_connect_data,
peer_addr, peer_addr,
))), ))),
}, },
@ -564,10 +575,8 @@ where
state: State::H1(h1::Dispatcher::new( state: State::H1(h1::Dispatcher::new(
io, io,
self.cfg.clone(), self.cfg.clone(),
self.srv.clone(), self.services.clone(),
self.expect.clone(), on_connect_data,
self.upgrade.clone(),
connect_extensions,
peer_addr, peer_addr,
)), )),
}, },
@ -589,13 +598,13 @@ where
U::Error: fmt::Display, U::Error: fmt::Display,
{ {
H1(#[pin] h1::Dispatcher<T, S, B, X, U>), H1(#[pin] h1::Dispatcher<T, S, B, X, U>),
H2(#[pin] Dispatcher<T, S, B>), H2(#[pin] Dispatcher<T, S, B, X, U>),
H2Handshake( H2Handshake(
Option<( Option<(
Handshake<T, Bytes>, Handshake<T, Bytes>,
ServiceConfig, ServiceConfig,
CloneableService<S>, Rc<RefCell<HttpFlow<S, X, U>>>,
Extensions, OnConnectData,
Option<net::SocketAddr>, Option<net::SocketAddr>,
)>, )>,
), ),