1
0
Fork 0
mirror of https://github.com/actix/actix-web.git synced 2025-01-04 06:18:51 +00:00

Merge pull request #267 from joshleeb/trait-middleware-mut-self

Update Middleware Trait to Use `&mut self`
This commit is contained in:
Nikolay Kim 2018-06-02 08:54:30 -07:00 committed by GitHub
commit 2f476021d8
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
15 changed files with 111 additions and 90 deletions

View file

@ -8,6 +8,8 @@
* Min rustc version is 1.26 * Min rustc version is 1.26
* Use `&mut self` instead of `&self` for Middleware trait
### Removed ### Removed
* Remove `Route::with2()` and `Route::with3()` use tuple of extractors instead. * Remove `Route::with2()` and `Route::with3()` use tuple of extractors instead.

View file

@ -1,4 +1,4 @@
use std::cell::UnsafeCell; use std::cell::{RefCell, UnsafeCell};
use std::collections::HashMap; use std::collections::HashMap;
use std::rc::Rc; use std::rc::Rc;
@ -22,7 +22,7 @@ pub struct HttpApplication<S = ()> {
prefix_len: usize, prefix_len: usize,
router: Router, router: Router,
inner: Rc<UnsafeCell<Inner<S>>>, inner: Rc<UnsafeCell<Inner<S>>>,
middlewares: Rc<Vec<Box<Middleware<S>>>>, middlewares: Rc<RefCell<Vec<Box<Middleware<S>>>>>,
} }
pub(crate) struct Inner<S> { pub(crate) struct Inner<S> {
@ -612,7 +612,7 @@ where
HttpApplication { HttpApplication {
state: Rc::new(parts.state), state: Rc::new(parts.state),
router: router.clone(), router: router.clone(),
middlewares: Rc::new(parts.middlewares), middlewares: Rc::new(RefCell::new(parts.middlewares)),
prefix, prefix,
prefix_len, prefix_len,
inner, inner,

View file

@ -356,7 +356,7 @@ impl Cors {
} }
impl<S> Middleware<S> for Cors { impl<S> Middleware<S> for Cors {
fn start(&self, req: &mut HttpRequest<S>) -> Result<Started> { fn start(&mut self, req: &mut HttpRequest<S>) -> Result<Started> {
if self.inner.preflight && Method::OPTIONS == *req.method() { if self.inner.preflight && Method::OPTIONS == *req.method() {
self.validate_origin(req)?; self.validate_origin(req)?;
self.validate_allowed_method(req)?; self.validate_allowed_method(req)?;
@ -431,7 +431,7 @@ impl<S> Middleware<S> for Cors {
} }
fn response( fn response(
&self, req: &mut HttpRequest<S>, mut resp: HttpResponse, &mut self, req: &mut HttpRequest<S>, mut resp: HttpResponse,
) -> Result<Response> { ) -> Result<Response> {
match self.inner.origins { match self.inner.origins {
AllOrSome::All => { AllOrSome::All => {
@ -941,7 +941,7 @@ mod tests {
#[test] #[test]
fn validate_origin_allows_all_origins() { fn validate_origin_allows_all_origins() {
let cors = Cors::default(); let mut cors = Cors::default();
let mut req = let mut req =
TestRequest::with_header("Origin", "https://www.example.com").finish(); TestRequest::with_header("Origin", "https://www.example.com").finish();
@ -1010,7 +1010,7 @@ mod tests {
#[test] #[test]
#[should_panic(expected = "MissingOrigin")] #[should_panic(expected = "MissingOrigin")]
fn test_validate_missing_origin() { fn test_validate_missing_origin() {
let cors = Cors::build() let mut cors = Cors::build()
.allowed_origin("https://www.example.com") .allowed_origin("https://www.example.com")
.finish(); .finish();
@ -1021,7 +1021,7 @@ mod tests {
#[test] #[test]
#[should_panic(expected = "OriginNotAllowed")] #[should_panic(expected = "OriginNotAllowed")]
fn test_validate_not_allowed_origin() { fn test_validate_not_allowed_origin() {
let cors = Cors::build() let mut cors = Cors::build()
.allowed_origin("https://www.example.com") .allowed_origin("https://www.example.com")
.finish(); .finish();
@ -1033,7 +1033,7 @@ mod tests {
#[test] #[test]
fn test_validate_origin() { fn test_validate_origin() {
let cors = Cors::build() let mut cors = Cors::build()
.allowed_origin("https://www.example.com") .allowed_origin("https://www.example.com")
.finish(); .finish();
@ -1046,7 +1046,7 @@ mod tests {
#[test] #[test]
fn test_no_origin_response() { fn test_no_origin_response() {
let cors = Cors::build().finish(); let mut cors = Cors::build().finish();
let mut req = TestRequest::default().method(Method::GET).finish(); let mut req = TestRequest::default().method(Method::GET).finish();
let resp: HttpResponse = HttpResponse::Ok().into(); let resp: HttpResponse = HttpResponse::Ok().into();
@ -1072,7 +1072,7 @@ mod tests {
#[test] #[test]
fn test_response() { fn test_response() {
let cors = Cors::build() let mut cors = Cors::build()
.send_wildcard() .send_wildcard()
.disable_preflight() .disable_preflight()
.max_age(3600) .max_age(3600)
@ -1107,7 +1107,7 @@ mod tests {
resp.headers().get(header::VARY).unwrap().as_bytes() resp.headers().get(header::VARY).unwrap().as_bytes()
); );
let cors = Cors::build() let mut cors = Cors::build()
.disable_vary_header() .disable_vary_header()
.allowed_origin("https://www.example.com") .allowed_origin("https://www.example.com")
.finish(); .finish();

View file

@ -209,7 +209,7 @@ impl CsrfFilter {
} }
impl<S> Middleware<S> for CsrfFilter { impl<S> Middleware<S> for CsrfFilter {
fn start(&self, req: &mut HttpRequest<S>) -> Result<Started> { fn start(&mut self, req: &mut HttpRequest<S>) -> Result<Started> {
self.validate(req)?; self.validate(req)?;
Ok(Started::Done) Ok(Started::Done)
} }
@ -223,7 +223,7 @@ mod tests {
#[test] #[test]
fn test_safe() { fn test_safe() {
let csrf = CsrfFilter::new().allowed_origin("https://www.example.com"); let mut csrf = CsrfFilter::new().allowed_origin("https://www.example.com");
let mut req = TestRequest::with_header("Origin", "https://www.w3.org") let mut req = TestRequest::with_header("Origin", "https://www.w3.org")
.method(Method::HEAD) .method(Method::HEAD)
@ -234,7 +234,7 @@ mod tests {
#[test] #[test]
fn test_csrf() { fn test_csrf() {
let csrf = CsrfFilter::new().allowed_origin("https://www.example.com"); let mut csrf = CsrfFilter::new().allowed_origin("https://www.example.com");
let mut req = TestRequest::with_header("Origin", "https://www.w3.org") let mut req = TestRequest::with_header("Origin", "https://www.w3.org")
.method(Method::POST) .method(Method::POST)
@ -245,7 +245,7 @@ mod tests {
#[test] #[test]
fn test_referer() { fn test_referer() {
let csrf = CsrfFilter::new().allowed_origin("https://www.example.com"); let mut csrf = CsrfFilter::new().allowed_origin("https://www.example.com");
let mut req = TestRequest::with_header( let mut req = TestRequest::with_header(
"Referer", "Referer",
@ -258,9 +258,9 @@ mod tests {
#[test] #[test]
fn test_upgrade() { fn test_upgrade() {
let strict_csrf = CsrfFilter::new().allowed_origin("https://www.example.com"); let mut strict_csrf = CsrfFilter::new().allowed_origin("https://www.example.com");
let lax_csrf = CsrfFilter::new() let mut lax_csrf = CsrfFilter::new()
.allowed_origin("https://www.example.com") .allowed_origin("https://www.example.com")
.allow_upgrade(); .allow_upgrade();

View file

@ -75,7 +75,7 @@ impl DefaultHeaders {
impl<S> Middleware<S> for DefaultHeaders { impl<S> Middleware<S> for DefaultHeaders {
fn response( fn response(
&self, _: &mut HttpRequest<S>, mut resp: HttpResponse, &mut self, _: &mut HttpRequest<S>, mut resp: HttpResponse,
) -> Result<Response> { ) -> Result<Response> {
for (key, value) in self.headers.iter() { for (key, value) in self.headers.iter() {
if !resp.headers().contains_key(key) { if !resp.headers().contains_key(key) {
@ -100,7 +100,7 @@ mod tests {
#[test] #[test]
fn test_default_headers() { fn test_default_headers() {
let mw = DefaultHeaders::new().header(CONTENT_TYPE, "0001"); let mut mw = DefaultHeaders::new().header(CONTENT_TYPE, "0001");
let mut req = HttpRequest::default(); let mut req = HttpRequest::default();

View file

@ -71,7 +71,7 @@ impl<S> ErrorHandlers<S> {
impl<S: 'static> Middleware<S> for ErrorHandlers<S> { impl<S: 'static> Middleware<S> for ErrorHandlers<S> {
fn response( fn response(
&self, req: &mut HttpRequest<S>, resp: HttpResponse, &mut self, req: &mut HttpRequest<S>, resp: HttpResponse,
) -> Result<Response> { ) -> Result<Response> {
if let Some(handler) = self.handlers.get(&resp.status()) { if let Some(handler) = self.handlers.get(&resp.status()) {
handler(req, resp) handler(req, resp)
@ -95,7 +95,7 @@ mod tests {
#[test] #[test]
fn test_handler() { fn test_handler() {
let mw = let mut mw =
ErrorHandlers::new().handler(StatusCode::INTERNAL_SERVER_ERROR, render_500); ErrorHandlers::new().handler(StatusCode::INTERNAL_SERVER_ERROR, render_500);
let mut req = HttpRequest::default(); let mut req = HttpRequest::default();

View file

@ -183,7 +183,7 @@ unsafe impl Send for IdentityBox {}
unsafe impl Sync for IdentityBox {} unsafe impl Sync for IdentityBox {}
impl<S: 'static, T: IdentityPolicy<S>> Middleware<S> for IdentityService<T> { impl<S: 'static, T: IdentityPolicy<S>> Middleware<S> for IdentityService<T> {
fn start(&self, req: &mut HttpRequest<S>) -> Result<Started> { fn start(&mut self, req: &mut HttpRequest<S>) -> Result<Started> {
let mut req = req.clone(); let mut req = req.clone();
let fut = self let fut = self
@ -200,7 +200,7 @@ impl<S: 'static, T: IdentityPolicy<S>> Middleware<S> for IdentityService<T> {
} }
fn response( fn response(
&self, req: &mut HttpRequest<S>, resp: HttpResponse, &mut self, req: &mut HttpRequest<S>, resp: HttpResponse,
) -> Result<Response> { ) -> Result<Response> {
if let Some(mut id) = req.extensions_mut().remove::<IdentityBox>() { if let Some(mut id) = req.extensions_mut().remove::<IdentityBox>() {
id.0.write(resp) id.0.write(resp)

View file

@ -124,14 +124,14 @@ impl Logger {
} }
impl<S> Middleware<S> for Logger { impl<S> Middleware<S> for Logger {
fn start(&self, req: &mut HttpRequest<S>) -> Result<Started> { fn start(&mut self, req: &mut HttpRequest<S>) -> Result<Started> {
if !self.exclude.contains(req.path()) { if !self.exclude.contains(req.path()) {
req.extensions_mut().insert(StartTime(time::now())); req.extensions_mut().insert(StartTime(time::now()));
} }
Ok(Started::Done) Ok(Started::Done)
} }
fn finish(&self, req: &mut HttpRequest<S>, resp: &HttpResponse) -> Finished { fn finish(&mut self, req: &mut HttpRequest<S>, resp: &HttpResponse) -> Finished {
self.log(req, resp); self.log(req, resp);
Finished::Done Finished::Done
} }
@ -322,7 +322,7 @@ mod tests {
#[test] #[test]
fn test_logger() { fn test_logger() {
let logger = Logger::new("%% %{User-Agent}i %{X-Test}o %{HOME}e %D test"); let mut logger = Logger::new("%% %{User-Agent}i %{X-Test}o %{HOME}e %D test");
let mut headers = HeaderMap::new(); let mut headers = HeaderMap::new();
headers.insert( headers.insert(

View file

@ -51,20 +51,20 @@ pub enum Finished {
pub trait Middleware<S>: 'static { pub trait Middleware<S>: 'static {
/// Method is called when request is ready. It may return /// Method is called when request is ready. It may return
/// future, which should resolve before next middleware get called. /// future, which should resolve before next middleware get called.
fn start(&self, req: &mut HttpRequest<S>) -> Result<Started> { fn start(&mut self, req: &mut HttpRequest<S>) -> Result<Started> {
Ok(Started::Done) Ok(Started::Done)
} }
/// Method is called when handler returns response, /// Method is called when handler returns response,
/// but before sending http message to peer. /// but before sending http message to peer.
fn response( fn response(
&self, req: &mut HttpRequest<S>, resp: HttpResponse, &mut self, req: &mut HttpRequest<S>, resp: HttpResponse,
) -> Result<Response> { ) -> Result<Response> {
Ok(Response::Done(resp)) Ok(Response::Done(resp))
} }
/// Method is called after body stream get sent to peer. /// Method is called after body stream get sent to peer.
fn finish(&self, req: &mut HttpRequest<S>, resp: &HttpResponse) -> Finished { fn finish(&mut self, req: &mut HttpRequest<S>, resp: &HttpResponse) -> Finished {
Finished::Done Finished::Done
} }
} }

View file

@ -251,7 +251,7 @@ impl<S, T: SessionBackend<S>> SessionStorage<T, S> {
} }
impl<S: 'static, T: SessionBackend<S>> Middleware<S> for SessionStorage<T, S> { impl<S: 'static, T: SessionBackend<S>> Middleware<S> for SessionStorage<T, S> {
fn start(&self, req: &mut HttpRequest<S>) -> Result<Started> { fn start(&mut self, req: &mut HttpRequest<S>) -> Result<Started> {
let mut req = req.clone(); let mut req = req.clone();
let fut = self.0.from_request(&mut req).then(move |res| match res { let fut = self.0.from_request(&mut req).then(move |res| match res {
@ -266,7 +266,7 @@ impl<S: 'static, T: SessionBackend<S>> Middleware<S> for SessionStorage<T, S> {
} }
fn response( fn response(
&self, req: &mut HttpRequest<S>, resp: HttpResponse, &mut self, req: &mut HttpRequest<S>, resp: HttpResponse,
) -> Result<Response> { ) -> Result<Response> {
if let Some(s_box) = req.extensions_mut().remove::<Arc<SessionImplCell>>() { if let Some(s_box) = req.extensions_mut().remove::<Arc<SessionImplCell>>() {
s_box.0.borrow_mut().write(resp) s_box.0.borrow_mut().write(resp)

View file

@ -1,4 +1,4 @@
use std::cell::UnsafeCell; use std::cell::{RefCell, UnsafeCell};
use std::marker::PhantomData; use std::marker::PhantomData;
use std::rc::Rc; use std::rc::Rc;
use std::{io, mem}; use std::{io, mem};
@ -71,7 +71,7 @@ impl<S: 'static, H: PipelineHandler<S>> PipelineState<S, H> {
struct PipelineInfo<S> { struct PipelineInfo<S> {
req: UnsafeCell<HttpRequest<S>>, req: UnsafeCell<HttpRequest<S>>,
count: u16, count: u16,
mws: Rc<Vec<Box<Middleware<S>>>>, mws: Rc<RefCell<Vec<Box<Middleware<S>>>>>,
context: Option<Box<ActorHttpContext>>, context: Option<Box<ActorHttpContext>>,
error: Option<Error>, error: Option<Error>,
disconnected: Option<bool>, disconnected: Option<bool>,
@ -83,7 +83,7 @@ impl<S> PipelineInfo<S> {
PipelineInfo { PipelineInfo {
req: UnsafeCell::new(req), req: UnsafeCell::new(req),
count: 0, count: 0,
mws: Rc::new(Vec::new()), mws: Rc::new(RefCell::new(Vec::new())),
error: None, error: None,
context: None, context: None,
disconnected: None, disconnected: None,
@ -120,7 +120,7 @@ impl<S> PipelineInfo<S> {
impl<S: 'static, H: PipelineHandler<S>> Pipeline<S, H> { impl<S: 'static, H: PipelineHandler<S>> Pipeline<S, H> {
pub fn new( pub fn new(
req: HttpRequest<S>, mws: Rc<Vec<Box<Middleware<S>>>>, req: HttpRequest<S>, mws: Rc<RefCell<Vec<Box<Middleware<S>>>>>,
handler: Rc<UnsafeCell<H>>, htype: HandlerType, handler: Rc<UnsafeCell<H>>, htype: HandlerType,
) -> Pipeline<S, H> { ) -> Pipeline<S, H> {
let mut info = PipelineInfo { let mut info = PipelineInfo {
@ -243,13 +243,14 @@ impl<S: 'static, H: PipelineHandler<S>> StartMiddlewares<S, H> {
) -> PipelineState<S, H> { ) -> PipelineState<S, H> {
// execute middlewares, we need this stage because middlewares could be // execute middlewares, we need this stage because middlewares could be
// non-async and we can move to next state immediately // non-async and we can move to next state immediately
let len = info.mws.len() as u16; let len = info.mws.borrow().len() as u16;
loop { loop {
if info.count == len { if info.count == len {
let reply = unsafe { &mut *hnd.get() }.handle(info.req().clone(), htype); let reply = unsafe { &mut *hnd.get() }.handle(info.req().clone(), htype);
return WaitingResponse::init(info, reply); return WaitingResponse::init(info, reply);
} else { } else {
match info.mws[info.count as usize].start(info.req_mut()) { let state = info.mws.borrow_mut()[info.count as usize].start(info.req_mut());
match state {
Ok(Started::Done) => info.count += 1, Ok(Started::Done) => info.count += 1,
Ok(Started::Response(resp)) => { Ok(Started::Response(resp)) => {
return RunMiddlewares::init(info, resp) return RunMiddlewares::init(info, resp)
@ -269,7 +270,7 @@ impl<S: 'static, H: PipelineHandler<S>> StartMiddlewares<S, H> {
} }
fn poll(&mut self, info: &mut PipelineInfo<S>) -> Option<PipelineState<S, H>> { fn poll(&mut self, info: &mut PipelineInfo<S>) -> Option<PipelineState<S, H>> {
let len = info.mws.len() as u16; let len = info.mws.borrow().len() as u16;
'outer: loop { 'outer: loop {
match self.fut.as_mut().unwrap().poll() { match self.fut.as_mut().unwrap().poll() {
Ok(Async::NotReady) => return None, Ok(Async::NotReady) => return None,
@ -284,7 +285,9 @@ impl<S: 'static, H: PipelineHandler<S>> StartMiddlewares<S, H> {
.handle(info.req().clone(), self.htype); .handle(info.req().clone(), self.htype);
return Some(WaitingResponse::init(info, reply)); return Some(WaitingResponse::init(info, reply));
} else { } else {
match info.mws[info.count as usize].start(info.req_mut()) { let state = info.mws.borrow_mut()[info.count as usize]
.start(info.req_mut());
match state {
Ok(Started::Done) => info.count += 1, Ok(Started::Done) => info.count += 1,
Ok(Started::Response(resp)) => { Ok(Started::Response(resp)) => {
return Some(RunMiddlewares::init(info, resp)); return Some(RunMiddlewares::init(info, resp));
@ -353,10 +356,11 @@ impl<S: 'static, H> RunMiddlewares<S, H> {
return ProcessResponse::init(resp); return ProcessResponse::init(resp);
} }
let mut curr = 0; let mut curr = 0;
let len = info.mws.len(); let len = info.mws.borrow().len();
loop { loop {
resp = match info.mws[curr].response(info.req_mut(), resp) { let state = info.mws.borrow_mut()[curr].response(info.req_mut(), resp);
resp = match state {
Err(err) => { Err(err) => {
info.count = (curr + 1) as u16; info.count = (curr + 1) as u16;
return ProcessResponse::init(err.into()); return ProcessResponse::init(err.into());
@ -382,7 +386,7 @@ impl<S: 'static, H> RunMiddlewares<S, H> {
} }
fn poll(&mut self, info: &mut PipelineInfo<S>) -> Option<PipelineState<S, H>> { fn poll(&mut self, info: &mut PipelineInfo<S>) -> Option<PipelineState<S, H>> {
let len = info.mws.len(); let len = info.mws.borrow().len();
loop { loop {
// poll latest fut // poll latest fut
@ -399,7 +403,8 @@ impl<S: 'static, H> RunMiddlewares<S, H> {
if self.curr == len { if self.curr == len {
return Some(ProcessResponse::init(resp)); return Some(ProcessResponse::init(resp));
} else { } else {
match info.mws[self.curr].response(info.req_mut(), resp) { let state = info.mws.borrow_mut()[self.curr].response(info.req_mut(), resp);
match state {
Err(err) => return Some(ProcessResponse::init(err.into())), Err(err) => return Some(ProcessResponse::init(err.into())),
Ok(Response::Done(r)) => { Ok(Response::Done(r)) => {
self.curr += 1; self.curr += 1;
@ -723,7 +728,9 @@ impl<S: 'static, H> FinishingMiddlewares<S, H> {
} }
info.count -= 1; info.count -= 1;
match info.mws[info.count as usize].finish(info.req_mut(), &self.resp) { let state = info.mws.borrow_mut()[info.count as usize]
.finish(info.req_mut(), &self.resp);
match state {
Finished::Done => { Finished::Done => {
if info.count == 0 { if info.count == 0 {
return Some(Completed::init(info)); return Some(Completed::init(info));

View file

@ -1,5 +1,6 @@
use std::marker::PhantomData; use std::marker::PhantomData;
use std::rc::Rc; use std::rc::Rc;
use std::cell::RefCell;
use futures::Future; use futures::Future;
use http::{Method, StatusCode}; use http::{Method, StatusCode};
@ -37,7 +38,7 @@ pub struct ResourceHandler<S = ()> {
name: String, name: String,
state: PhantomData<S>, state: PhantomData<S>,
routes: SmallVec<[Route<S>; 3]>, routes: SmallVec<[Route<S>; 3]>,
middlewares: Rc<Vec<Box<Middleware<S>>>>, middlewares: Rc<RefCell<Vec<Box<Middleware<S>>>>>,
} }
impl<S> Default for ResourceHandler<S> { impl<S> Default for ResourceHandler<S> {
@ -46,7 +47,7 @@ impl<S> Default for ResourceHandler<S> {
name: String::new(), name: String::new(),
state: PhantomData, state: PhantomData,
routes: SmallVec::new(), routes: SmallVec::new(),
middlewares: Rc::new(Vec::new()), middlewares: Rc::new(RefCell::new(Vec::new())),
} }
} }
} }
@ -57,7 +58,7 @@ impl<S> ResourceHandler<S> {
name: String::new(), name: String::new(),
state: PhantomData, state: PhantomData,
routes: SmallVec::new(), routes: SmallVec::new(),
middlewares: Rc::new(Vec::new()), middlewares: Rc::new(RefCell::new(Vec::new())),
} }
} }
@ -276,6 +277,7 @@ impl<S: 'static> ResourceHandler<S> {
pub fn middleware<M: Middleware<S>>(&mut self, mw: M) { pub fn middleware<M: Middleware<S>>(&mut self, mw: M) {
Rc::get_mut(&mut self.middlewares) Rc::get_mut(&mut self.middlewares)
.unwrap() .unwrap()
.borrow_mut()
.push(Box::new(mw)); .push(Box::new(mw));
} }
@ -284,7 +286,7 @@ impl<S: 'static> ResourceHandler<S> {
) -> AsyncResult<HttpResponse> { ) -> AsyncResult<HttpResponse> {
for route in &mut self.routes { for route in &mut self.routes {
if route.check(&mut req) { if route.check(&mut req) {
return if self.middlewares.is_empty() { return if self.middlewares.borrow().is_empty() {
route.handle(req) route.handle(req)
} else { } else {
route.compose(req, Rc::clone(&self.middlewares)) route.compose(req, Rc::clone(&self.middlewares))

View file

@ -1,4 +1,4 @@
use std::cell::UnsafeCell; use std::cell::{RefCell, UnsafeCell};
use std::marker::PhantomData; use std::marker::PhantomData;
use std::rc::Rc; use std::rc::Rc;
@ -55,7 +55,7 @@ impl<S: 'static> Route<S> {
#[inline] #[inline]
pub(crate) fn compose( pub(crate) fn compose(
&mut self, req: HttpRequest<S>, mws: Rc<Vec<Box<Middleware<S>>>>, &mut self, req: HttpRequest<S>, mws: Rc<RefCell<Vec<Box<Middleware<S>>>>>,
) -> AsyncResult<HttpResponse> { ) -> AsyncResult<HttpResponse> {
AsyncResult::async(Box::new(Compose::new(req, mws, self.handler.clone()))) AsyncResult::async(Box::new(Compose::new(req, mws, self.handler.clone())))
} }
@ -263,7 +263,7 @@ struct Compose<S: 'static> {
struct ComposeInfo<S: 'static> { struct ComposeInfo<S: 'static> {
count: usize, count: usize,
req: HttpRequest<S>, req: HttpRequest<S>,
mws: Rc<Vec<Box<Middleware<S>>>>, mws: Rc<RefCell<Vec<Box<Middleware<S>>>>>,
handler: InnerHandler<S>, handler: InnerHandler<S>,
} }
@ -289,7 +289,7 @@ impl<S: 'static> ComposeState<S> {
impl<S: 'static> Compose<S> { impl<S: 'static> Compose<S> {
fn new( fn new(
req: HttpRequest<S>, mws: Rc<Vec<Box<Middleware<S>>>>, handler: InnerHandler<S>, req: HttpRequest<S>, mws: Rc<RefCell<Vec<Box<Middleware<S>>>>>, handler: InnerHandler<S>,
) -> Self { ) -> Self {
let mut info = ComposeInfo { let mut info = ComposeInfo {
count: 0, count: 0,
@ -332,13 +332,14 @@ type Fut = Box<Future<Item = Option<HttpResponse>, Error = Error>>;
impl<S: 'static> StartMiddlewares<S> { impl<S: 'static> StartMiddlewares<S> {
fn init(info: &mut ComposeInfo<S>) -> ComposeState<S> { fn init(info: &mut ComposeInfo<S>) -> ComposeState<S> {
let len = info.mws.len(); let len = info.mws.borrow().len();
loop { loop {
if info.count == len { if info.count == len {
let reply = info.handler.handle(info.req.clone()); let reply = info.handler.handle(info.req.clone());
return WaitingResponse::init(info, reply); return WaitingResponse::init(info, reply);
} else { } else {
match info.mws[info.count].start(&mut info.req) { let state = info.mws.borrow_mut()[info.count].start(&mut info.req);
match state {
Ok(MiddlewareStarted::Done) => info.count += 1, Ok(MiddlewareStarted::Done) => info.count += 1,
Ok(MiddlewareStarted::Response(resp)) => { Ok(MiddlewareStarted::Response(resp)) => {
return RunMiddlewares::init(info, resp) return RunMiddlewares::init(info, resp)
@ -356,7 +357,7 @@ impl<S: 'static> StartMiddlewares<S> {
} }
fn poll(&mut self, info: &mut ComposeInfo<S>) -> Option<ComposeState<S>> { fn poll(&mut self, info: &mut ComposeInfo<S>) -> Option<ComposeState<S>> {
let len = info.mws.len(); let len = info.mws.borrow().len();
'outer: loop { 'outer: loop {
match self.fut.as_mut().unwrap().poll() { match self.fut.as_mut().unwrap().poll() {
Ok(Async::NotReady) => return None, Ok(Async::NotReady) => return None,
@ -370,7 +371,8 @@ impl<S: 'static> StartMiddlewares<S> {
let reply = info.handler.handle(info.req.clone()); let reply = info.handler.handle(info.req.clone());
return Some(WaitingResponse::init(info, reply)); return Some(WaitingResponse::init(info, reply));
} else { } else {
match info.mws[info.count].start(&mut info.req) { let state = info.mws.borrow_mut()[info.count].start(&mut info.req);
match state {
Ok(MiddlewareStarted::Done) => info.count += 1, Ok(MiddlewareStarted::Done) => info.count += 1,
Ok(MiddlewareStarted::Response(resp)) => { Ok(MiddlewareStarted::Response(resp)) => {
return Some(RunMiddlewares::init(info, resp)); return Some(RunMiddlewares::init(info, resp));
@ -435,10 +437,11 @@ struct RunMiddlewares<S> {
impl<S: 'static> RunMiddlewares<S> { impl<S: 'static> RunMiddlewares<S> {
fn init(info: &mut ComposeInfo<S>, mut resp: HttpResponse) -> ComposeState<S> { fn init(info: &mut ComposeInfo<S>, mut resp: HttpResponse) -> ComposeState<S> {
let mut curr = 0; let mut curr = 0;
let len = info.mws.len(); let len = info.mws.borrow().len();
loop { loop {
resp = match info.mws[curr].response(&mut info.req, resp) { let state = info.mws.borrow_mut()[curr].response(&mut info.req, resp);
resp = match state {
Err(err) => { Err(err) => {
info.count = curr + 1; info.count = curr + 1;
return FinishingMiddlewares::init(info, err.into()); return FinishingMiddlewares::init(info, err.into());
@ -463,7 +466,7 @@ impl<S: 'static> RunMiddlewares<S> {
} }
fn poll(&mut self, info: &mut ComposeInfo<S>) -> Option<ComposeState<S>> { fn poll(&mut self, info: &mut ComposeInfo<S>) -> Option<ComposeState<S>> {
let len = info.mws.len(); let len = info.mws.borrow().len();
loop { loop {
// poll latest fut // poll latest fut
@ -480,7 +483,8 @@ impl<S: 'static> RunMiddlewares<S> {
if self.curr == len { if self.curr == len {
return Some(FinishingMiddlewares::init(info, resp)); return Some(FinishingMiddlewares::init(info, resp));
} else { } else {
match info.mws[self.curr].response(&mut info.req, resp) { let state = info.mws.borrow_mut()[self.curr].response(&mut info.req, resp);
match state {
Err(err) => { Err(err) => {
return Some(FinishingMiddlewares::init(info, err.into())) return Some(FinishingMiddlewares::init(info, err.into()))
} }
@ -548,9 +552,10 @@ impl<S: 'static> FinishingMiddlewares<S> {
} }
info.count -= 1; info.count -= 1;
match info.mws[info.count as usize]
.finish(&mut info.req, self.resp.as_ref().unwrap()) let state = info.mws.borrow_mut()[info.count as usize]
{ .finish(&mut info.req, self.resp.as_ref().unwrap());
match state {
MiddlewareFinished::Done => { MiddlewareFinished::Done => {
if info.count == 0 { if info.count == 0 {
return Some(Response::init(self.resp.take().unwrap())); return Some(Response::init(self.resp.take().unwrap()));

View file

@ -1,4 +1,4 @@
use std::cell::UnsafeCell; use std::cell::{RefCell, UnsafeCell};
use std::marker::PhantomData; use std::marker::PhantomData;
use std::mem; use std::mem;
use std::rc::Rc; use std::rc::Rc;
@ -56,7 +56,7 @@ type NestedInfo<S> = (Resource, Route<S>, Vec<Box<Predicate<S>>>);
pub struct Scope<S: 'static> { pub struct Scope<S: 'static> {
filters: Vec<Box<Predicate<S>>>, filters: Vec<Box<Predicate<S>>>,
nested: Vec<NestedInfo<S>>, nested: Vec<NestedInfo<S>>,
middlewares: Rc<Vec<Box<Middleware<S>>>>, middlewares: Rc<RefCell<Vec<Box<Middleware<S>>>>>,
default: Rc<UnsafeCell<ResourceHandler<S>>>, default: Rc<UnsafeCell<ResourceHandler<S>>>,
resources: ScopeResources<S>, resources: ScopeResources<S>,
} }
@ -70,7 +70,7 @@ impl<S: 'static> Scope<S> {
filters: Vec::new(), filters: Vec::new(),
nested: Vec::new(), nested: Vec::new(),
resources: Rc::new(Vec::new()), resources: Rc::new(Vec::new()),
middlewares: Rc::new(Vec::new()), middlewares: Rc::new(RefCell::new(Vec::new())),
default: Rc::new(UnsafeCell::new(ResourceHandler::default_not_found())), default: Rc::new(UnsafeCell::new(ResourceHandler::default_not_found())),
} }
} }
@ -134,7 +134,7 @@ impl<S: 'static> Scope<S> {
filters: Vec::new(), filters: Vec::new(),
nested: Vec::new(), nested: Vec::new(),
resources: Rc::new(Vec::new()), resources: Rc::new(Vec::new()),
middlewares: Rc::new(Vec::new()), middlewares: Rc::new(RefCell::new(Vec::new())),
default: Rc::new(UnsafeCell::new(ResourceHandler::default_not_found())), default: Rc::new(UnsafeCell::new(ResourceHandler::default_not_found())),
}; };
let mut scope = f(scope); let mut scope = f(scope);
@ -177,7 +177,7 @@ impl<S: 'static> Scope<S> {
filters: Vec::new(), filters: Vec::new(),
nested: Vec::new(), nested: Vec::new(),
resources: Rc::new(Vec::new()), resources: Rc::new(Vec::new()),
middlewares: Rc::new(Vec::new()), middlewares: Rc::new(RefCell::new(Vec::new())),
default: Rc::new(UnsafeCell::new(ResourceHandler::default_not_found())), default: Rc::new(UnsafeCell::new(ResourceHandler::default_not_found())),
}; };
let mut scope = f(scope); let mut scope = f(scope);
@ -314,6 +314,7 @@ impl<S: 'static> Scope<S> {
pub fn middleware<M: Middleware<S>>(mut self, mw: M) -> Scope<S> { pub fn middleware<M: Middleware<S>>(mut self, mw: M) -> Scope<S> {
Rc::get_mut(&mut self.middlewares) Rc::get_mut(&mut self.middlewares)
.expect("Can not use after configuration") .expect("Can not use after configuration")
.borrow_mut()
.push(Box::new(mw)); .push(Box::new(mw));
self self
} }
@ -329,7 +330,7 @@ impl<S: 'static> RouteHandler<S> for Scope<S> {
let default = unsafe { &mut *self.default.as_ref().get() }; let default = unsafe { &mut *self.default.as_ref().get() };
req.match_info_mut().remove("tail"); req.match_info_mut().remove("tail");
if self.middlewares.is_empty() { if self.middlewares.borrow().is_empty() {
let resource = unsafe { &mut *resource.get() }; let resource = unsafe { &mut *resource.get() };
return resource.handle(req, Some(default)); return resource.handle(req, Some(default));
} else { } else {
@ -371,7 +372,7 @@ impl<S: 'static> RouteHandler<S> for Scope<S> {
// default handler // default handler
let default = unsafe { &mut *self.default.as_ref().get() }; let default = unsafe { &mut *self.default.as_ref().get() };
if self.middlewares.is_empty() { if self.middlewares.borrow().is_empty() {
default.handle(req, None) default.handle(req, None)
} else { } else {
AsyncResult::async(Box::new(Compose::new( AsyncResult::async(Box::new(Compose::new(
@ -421,7 +422,7 @@ struct Compose<S: 'static> {
struct ComposeInfo<S: 'static> { struct ComposeInfo<S: 'static> {
count: usize, count: usize,
req: HttpRequest<S>, req: HttpRequest<S>,
mws: Rc<Vec<Box<Middleware<S>>>>, mws: Rc<RefCell<Vec<Box<Middleware<S>>>>>,
default: Option<Rc<UnsafeCell<ResourceHandler<S>>>>, default: Option<Rc<UnsafeCell<ResourceHandler<S>>>>,
resource: Rc<UnsafeCell<ResourceHandler<S>>>, resource: Rc<UnsafeCell<ResourceHandler<S>>>,
} }
@ -448,7 +449,7 @@ impl<S: 'static> ComposeState<S> {
impl<S: 'static> Compose<S> { impl<S: 'static> Compose<S> {
fn new( fn new(
req: HttpRequest<S>, mws: Rc<Vec<Box<Middleware<S>>>>, req: HttpRequest<S>, mws: Rc<RefCell<Vec<Box<Middleware<S>>>>>,
resource: Rc<UnsafeCell<ResourceHandler<S>>>, resource: Rc<UnsafeCell<ResourceHandler<S>>>,
default: Option<Rc<UnsafeCell<ResourceHandler<S>>>>, default: Option<Rc<UnsafeCell<ResourceHandler<S>>>>,
) -> Self { ) -> Self {
@ -494,7 +495,7 @@ type Fut = Box<Future<Item = Option<HttpResponse>, Error = Error>>;
impl<S: 'static> StartMiddlewares<S> { impl<S: 'static> StartMiddlewares<S> {
fn init(info: &mut ComposeInfo<S>) -> ComposeState<S> { fn init(info: &mut ComposeInfo<S>) -> ComposeState<S> {
let len = info.mws.len(); let len = info.mws.borrow().len();
loop { loop {
if info.count == len { if info.count == len {
let resource = unsafe { &mut *info.resource.get() }; let resource = unsafe { &mut *info.resource.get() };
@ -506,7 +507,8 @@ impl<S: 'static> StartMiddlewares<S> {
}; };
return WaitingResponse::init(info, reply); return WaitingResponse::init(info, reply);
} else { } else {
match info.mws[info.count].start(&mut info.req) { let state = info.mws.borrow_mut()[info.count].start(&mut info.req);
match state {
Ok(MiddlewareStarted::Done) => info.count += 1, Ok(MiddlewareStarted::Done) => info.count += 1,
Ok(MiddlewareStarted::Response(resp)) => { Ok(MiddlewareStarted::Response(resp)) => {
return RunMiddlewares::init(info, resp) return RunMiddlewares::init(info, resp)
@ -524,7 +526,7 @@ impl<S: 'static> StartMiddlewares<S> {
} }
fn poll(&mut self, info: &mut ComposeInfo<S>) -> Option<ComposeState<S>> { fn poll(&mut self, info: &mut ComposeInfo<S>) -> Option<ComposeState<S>> {
let len = info.mws.len(); let len = info.mws.borrow().len();
'outer: loop { 'outer: loop {
match self.fut.as_mut().unwrap().poll() { match self.fut.as_mut().unwrap().poll() {
Ok(Async::NotReady) => return None, Ok(Async::NotReady) => return None,
@ -544,7 +546,8 @@ impl<S: 'static> StartMiddlewares<S> {
}; };
return Some(WaitingResponse::init(info, reply)); return Some(WaitingResponse::init(info, reply));
} else { } else {
match info.mws[info.count].start(&mut info.req) { let state = info.mws.borrow_mut()[info.count].start(&mut info.req);
match state {
Ok(MiddlewareStarted::Done) => info.count += 1, Ok(MiddlewareStarted::Done) => info.count += 1,
Ok(MiddlewareStarted::Response(resp)) => { Ok(MiddlewareStarted::Response(resp)) => {
return Some(RunMiddlewares::init(info, resp)); return Some(RunMiddlewares::init(info, resp));
@ -604,10 +607,11 @@ struct RunMiddlewares<S> {
impl<S: 'static> RunMiddlewares<S> { impl<S: 'static> RunMiddlewares<S> {
fn init(info: &mut ComposeInfo<S>, mut resp: HttpResponse) -> ComposeState<S> { fn init(info: &mut ComposeInfo<S>, mut resp: HttpResponse) -> ComposeState<S> {
let mut curr = 0; let mut curr = 0;
let len = info.mws.len(); let len = info.mws.borrow().len();
loop { loop {
resp = match info.mws[curr].response(&mut info.req, resp) { let state = info.mws.borrow_mut()[curr].response(&mut info.req, resp);
resp = match state {
Err(err) => { Err(err) => {
info.count = curr + 1; info.count = curr + 1;
return FinishingMiddlewares::init(info, err.into()); return FinishingMiddlewares::init(info, err.into());
@ -632,7 +636,7 @@ impl<S: 'static> RunMiddlewares<S> {
} }
fn poll(&mut self, info: &mut ComposeInfo<S>) -> Option<ComposeState<S>> { fn poll(&mut self, info: &mut ComposeInfo<S>) -> Option<ComposeState<S>> {
let len = info.mws.len(); let len = info.mws.borrow().len();
loop { loop {
// poll latest fut // poll latest fut
@ -649,7 +653,8 @@ impl<S: 'static> RunMiddlewares<S> {
if self.curr == len { if self.curr == len {
return Some(FinishingMiddlewares::init(info, resp)); return Some(FinishingMiddlewares::init(info, resp));
} else { } else {
match info.mws[self.curr].response(&mut info.req, resp) { let state = info.mws.borrow_mut()[self.curr].response(&mut info.req, resp);
match state {
Err(err) => { Err(err) => {
return Some(FinishingMiddlewares::init(info, err.into())) return Some(FinishingMiddlewares::init(info, err.into()))
} }
@ -717,9 +722,9 @@ impl<S: 'static> FinishingMiddlewares<S> {
} }
info.count -= 1; info.count -= 1;
match info.mws[info.count as usize] let state = info.mws.borrow_mut()[info.count as usize]
.finish(&mut info.req, self.resp.as_ref().unwrap()) .finish(&mut info.req, self.resp.as_ref().unwrap());
{ match state {
MiddlewareFinished::Done => { MiddlewareFinished::Done => {
if info.count == 0 { if info.count == 0 {
return Some(Response::init(self.resp.take().unwrap())); return Some(Response::init(self.resp.take().unwrap()));

View file

@ -19,21 +19,21 @@ struct MiddlewareTest {
} }
impl<S> middleware::Middleware<S> for MiddlewareTest { impl<S> middleware::Middleware<S> for MiddlewareTest {
fn start(&self, _: &mut HttpRequest<S>) -> Result<middleware::Started> { fn start(&mut self, _: &mut HttpRequest<S>) -> Result<middleware::Started> {
self.start self.start
.store(self.start.load(Ordering::Relaxed) + 1, Ordering::Relaxed); .store(self.start.load(Ordering::Relaxed) + 1, Ordering::Relaxed);
Ok(middleware::Started::Done) Ok(middleware::Started::Done)
} }
fn response( fn response(
&self, _: &mut HttpRequest<S>, resp: HttpResponse, &mut self, _: &mut HttpRequest<S>, resp: HttpResponse,
) -> Result<middleware::Response> { ) -> Result<middleware::Response> {
self.response self.response
.store(self.response.load(Ordering::Relaxed) + 1, Ordering::Relaxed); .store(self.response.load(Ordering::Relaxed) + 1, Ordering::Relaxed);
Ok(middleware::Response::Done(resp)) Ok(middleware::Response::Done(resp))
} }
fn finish(&self, _: &mut HttpRequest<S>, _: &HttpResponse) -> middleware::Finished { fn finish(&mut self, _: &mut HttpRequest<S>, _: &HttpResponse) -> middleware::Finished {
self.finish self.finish
.store(self.finish.load(Ordering::Relaxed) + 1, Ordering::Relaxed); .store(self.finish.load(Ordering::Relaxed) + 1, Ordering::Relaxed);
middleware::Finished::Done middleware::Finished::Done
@ -431,7 +431,7 @@ struct MiddlewareAsyncTest {
} }
impl<S> middleware::Middleware<S> for MiddlewareAsyncTest { impl<S> middleware::Middleware<S> for MiddlewareAsyncTest {
fn start(&self, _: &mut HttpRequest<S>) -> Result<middleware::Started> { fn start(&mut self, _: &mut HttpRequest<S>) -> Result<middleware::Started> {
let to = Delay::new(Instant::now() + Duration::from_millis(10)); let to = Delay::new(Instant::now() + Duration::from_millis(10));
let start = Arc::clone(&self.start); let start = Arc::clone(&self.start);
@ -444,7 +444,7 @@ impl<S> middleware::Middleware<S> for MiddlewareAsyncTest {
} }
fn response( fn response(
&self, _: &mut HttpRequest<S>, resp: HttpResponse, &mut self, _: &mut HttpRequest<S>, resp: HttpResponse,
) -> Result<middleware::Response> { ) -> Result<middleware::Response> {
let to = Delay::new(Instant::now() + Duration::from_millis(10)); let to = Delay::new(Instant::now() + Duration::from_millis(10));
@ -457,7 +457,7 @@ impl<S> middleware::Middleware<S> for MiddlewareAsyncTest {
))) )))
} }
fn finish(&self, _: &mut HttpRequest<S>, _: &HttpResponse) -> middleware::Finished { fn finish(&mut self, _: &mut HttpRequest<S>, _: &HttpResponse) -> middleware::Finished {
let to = Delay::new(Instant::now() + Duration::from_millis(10)); let to = Delay::new(Instant::now() + Duration::from_millis(10));
let finish = Arc::clone(&self.finish); let finish = Arc::clone(&self.finish);