From afe9459ce1296ba33b76d0f9ef4d44e9a6f22024 Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Sat, 21 Oct 2017 22:59:09 -0700 Subject: [PATCH] pass request by ref; added middleware support --- Cargo.toml | 3 + examples/basic.rs | 8 +- examples/{websocket => }/static/index.html | 0 .../client.py => websocket-client.py} | 0 .../{websocket/src/main.rs => websocket.rs} | 11 +- examples/websocket/Cargo.toml | 12 - src/application.rs | 86 ++++-- src/dev.rs | 3 +- src/httpcodes.rs | 2 +- src/httprequest.rs | 35 ++- src/lib.rs | 4 +- src/logger.rs | 274 ++++++++++++++++++ src/resource.rs | 6 +- src/route.rs | 23 +- src/server.rs | 22 +- src/staticfiles.rs | 19 +- src/task.rs | 72 +++-- 17 files changed, 470 insertions(+), 110 deletions(-) rename examples/{websocket => }/static/index.html (100%) rename examples/{websocket/client.py => websocket-client.py} (100%) rename examples/{websocket/src/main.rs => websocket.rs} (80%) delete mode 100644 examples/websocket/Cargo.toml create mode 100644 src/logger.rs diff --git a/Cargo.toml b/Cargo.toml index a527e263b..009f38a50 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -54,6 +54,9 @@ git = "https://github.com/fafhrd91/actix.git" default-features = false features = [] +[dev-dependencies] +env_logger = "*" + [profile.release] lto = true opt-level = 3 diff --git a/examples/basic.rs b/examples/basic.rs index 46279498b..e119ce3b3 100644 --- a/examples/basic.rs +++ b/examples/basic.rs @@ -6,13 +6,13 @@ extern crate env_logger; use actix_web::*; /// somple handle -fn index(req: HttpRequest, payload: Payload, state: &()) -> HttpResponse { +fn index(req: &mut HttpRequest, payload: Payload, state: &()) -> HttpResponse { println!("{:?}", req); httpcodes::HTTPOk.into() } /// handle with path parameters like `/name/{name}/` -fn with_param(req: HttpRequest, payload: Payload, state: &()) -> HttpResponse { +fn with_param(req: &mut HttpRequest, payload: Payload, state: &()) -> HttpResponse { println!("{:?}", req); HttpResponse::builder(StatusCode::OK) @@ -22,10 +22,14 @@ fn with_param(req: HttpRequest, payload: Payload, state: &()) -> HttpResponse { } fn main() { + ::std::env::set_var("RUST_LOG", "actix_web=info"); + let _ = env_logger::init(); let sys = actix::System::new("ws-example"); HttpServer::new( Application::default("/") + // enable logger + .middleware(Logger::new(None)) // register simple handler, handle all methods .handler("/index.html", index) // with path parameters diff --git a/examples/websocket/static/index.html b/examples/static/index.html similarity index 100% rename from examples/websocket/static/index.html rename to examples/static/index.html diff --git a/examples/websocket/client.py b/examples/websocket-client.py similarity index 100% rename from examples/websocket/client.py rename to examples/websocket-client.py diff --git a/examples/websocket/src/main.rs b/examples/websocket.rs similarity index 80% rename from examples/websocket/src/main.rs rename to examples/websocket.rs index 2ab22989d..2c63687ce 100644 --- a/examples/websocket/src/main.rs +++ b/examples/websocket.rs @@ -1,6 +1,7 @@ #![allow(unused_variables)] extern crate actix; extern crate actix_web; +extern crate env_logger; use actix::*; use actix_web::*; @@ -15,7 +16,8 @@ impl Actor for MyWebSocket { impl Route for MyWebSocket { type State = (); - fn request(req: HttpRequest, payload: Payload, ctx: &mut HttpContext) -> Reply + fn request(req: &mut HttpRequest, + payload: Payload, ctx: &mut HttpContext) -> Reply { match ws::handshake(&req) { Ok(resp) => { @@ -59,12 +61,17 @@ impl Handler for MyWebSocket { } fn main() { + ::std::env::set_var("RUST_LOG", "actix_web=info"); + let _ = env_logger::init(); let sys = actix::System::new("ws-example"); HttpServer::new( Application::default("/") + // enable logger + .middleware(Logger::new(None)) + // websocket route .resource("/ws/", |r| r.get::()) - .route_handler("/", StaticFiles::new("static/", true))) + .route_handler("/", StaticFiles::new("examples/static/", true))) .serve::<_, ()>("127.0.0.1:8080").unwrap(); println!("Started http server: 127.0.0.1:8080"); diff --git a/examples/websocket/Cargo.toml b/examples/websocket/Cargo.toml deleted file mode 100644 index 33d9ed39e..000000000 --- a/examples/websocket/Cargo.toml +++ /dev/null @@ -1,12 +0,0 @@ -[package] -name = "websocket-example" -version = "0.1.0" -authors = ["Nikolay Kim "] - -[[bin]] -name = "websocket" -path = "src/main.rs" - -[dependencies] -actix = { git = "https://github.com/fafhrd91/actix.git" } -actix-web = { path = "../../" } diff --git a/src/application.rs b/src/application.rs index 1f4295085..95ff8d47a 100644 --- a/src/application.rs +++ b/src/application.rs @@ -12,6 +12,24 @@ use httpresponse::HttpResponse; use server::HttpHandler; +#[allow(unused_variables)] +pub trait Middleware { + + /// Method is called when request is ready. + fn start(&self, req: &mut HttpRequest) -> Result<(), HttpResponse> { + Ok(()) + } + + /// Method is called when handler returns response, + /// but before sending body stream to peer. + fn response(&self, req: &mut HttpRequest, resp: HttpResponse) -> HttpResponse { + resp + } + + /// Http interation is finished + fn finish(&self, req: &mut HttpRequest, resp: &HttpResponse) {} +} + /// Application pub struct Application { state: Rc, @@ -19,6 +37,26 @@ pub struct Application { default: Resource, handlers: HashMap>>, router: RouteRecognizer>, + middlewares: Rc>>, +} + +impl Application { + + fn run(&self, req: &mut HttpRequest, payload: Payload) -> Task { + if let Some((params, h)) = self.router.recognize(req.path()) { + if let Some(params) = params { + req.set_match_info(params); + } + h.handle(req, payload, Rc::clone(&self.state)) + } else { + for (prefix, handler) in &self.handlers { + if req.path().starts_with(prefix) { + return handler.handle(req, payload, Rc::clone(&self.state)) + } + } + self.default.handle(req, payload, Rc::clone(&self.state)) + } + } } impl HttpHandler for Application { @@ -27,21 +65,19 @@ impl HttpHandler for Application { &self.prefix } - fn handle(&self, req: HttpRequest, payload: Payload) -> Task { - if let Some((params, h)) = self.router.recognize(req.path()) { - if let Some(params) = params { - h.handle( - req.with_match_info(params), payload, Rc::clone(&self.state)) - } else { - h.handle(req, payload, Rc::clone(&self.state)) + fn handle(&self, req: &mut HttpRequest, payload: Payload) -> Task { + // run middlewares + if !self.middlewares.is_empty() { + for middleware in self.middlewares.iter() { + if let Err(resp) = middleware.start(req) { + return Task::reply(resp) + }; } + let mut task = self.run(req, payload); + task.set_middlewares(Rc::clone(&self.middlewares)); + task } else { - for (prefix, handler) in &self.handlers { - if req.path().starts_with(prefix) { - return handler.handle(req, payload, Rc::clone(&self.state)) - } - } - self.default.handle(req, payload, Rc::clone(&self.state)) + self.run(req, payload) } } } @@ -56,7 +92,9 @@ impl Application<()> { prefix: prefix.to_string(), default: Resource::default(), handlers: HashMap::new(), - resources: HashMap::new()}) + resources: HashMap::new(), + middlewares: Vec::new(), + }) } } } @@ -73,7 +111,9 @@ impl Application where S: 'static { prefix: prefix.to_string(), default: Resource::default(), handlers: HashMap::new(), - resources: HashMap::new()}) + resources: HashMap::new(), + middlewares: Vec::new(), + }) } } } @@ -84,6 +124,7 @@ struct ApplicationBuilderParts { default: Resource, handlers: HashMap>>, resources: HashMap>, + middlewares: Vec>, } /// Application builder @@ -192,7 +233,7 @@ impl ApplicationBuilder where S: 'static { /// } /// ``` pub fn handler(&mut self, path: P, handler: F) -> &mut Self - where F: Fn(HttpRequest, Payload, &S) -> R + 'static, + where F: Fn(&mut HttpRequest, Payload, &S) -> R + 'static, R: Into + 'static, P: ToString, { @@ -217,6 +258,15 @@ impl ApplicationBuilder where S: 'static { self } + /// Construct application + pub fn middleware(&mut self, mw: T) -> &mut Self + where T: Middleware + 'static + { + self.parts.as_mut().expect("Use after finish") + .middlewares.push(Box::new(mw)); + self + } + /// Construct application pub fn finish(&mut self) -> Application { let parts = self.parts.take().expect("Use after finish"); @@ -243,7 +293,9 @@ impl ApplicationBuilder where S: 'static { prefix: prefix.clone(), default: parts.default, handlers: handlers, - router: RouteRecognizer::new(prefix, routes) } + router: RouteRecognizer::new(prefix, routes), + middlewares: Rc::new(parts.middlewares), + } } } diff --git a/src/dev.rs b/src/dev.rs index aa209fb09..f87dae74c 100644 --- a/src/dev.rs +++ b/src/dev.rs @@ -10,13 +10,14 @@ pub use ws; pub use httpcodes; pub use error::ParseError; -pub use application::{Application, ApplicationBuilder}; +pub use application::{Application, ApplicationBuilder, Middleware}; pub use httprequest::HttpRequest; pub use httpresponse::{Body, HttpResponse, HttpResponseBuilder}; pub use payload::{Payload, PayloadItem, PayloadError}; pub use resource::{Reply, Resource}; pub use route::{Route, RouteFactory, RouteHandler}; pub use recognizer::Params; +pub use logger::Logger; pub use server::HttpServer; pub use context::HttpContext; pub use staticfiles::StaticFiles; diff --git a/src/httpcodes.rs b/src/httpcodes.rs index f575ad528..e54876900 100644 --- a/src/httpcodes.rs +++ b/src/httpcodes.rs @@ -61,7 +61,7 @@ impl StaticResponse { } impl RouteHandler for StaticResponse { - fn handle(&self, _: HttpRequest, _: Payload, _: Rc) -> Task { + fn handle(&self, _: &mut HttpRequest, _: Payload, _: Rc) -> Task { Task::reply(HttpResponse::new(self.0, Body::Empty)) } } diff --git a/src/httprequest.rs b/src/httprequest.rs index 45bdd77bf..6e936d653 100644 --- a/src/httprequest.rs +++ b/src/httprequest.rs @@ -4,7 +4,7 @@ use std::collections::HashMap; use bytes::BytesMut; use futures::{Async, Future, Stream, Poll}; use url::form_urlencoded; -use http::{header, Method, Version, Uri, HeaderMap}; +use http::{header, Method, Version, Uri, HeaderMap, Extensions}; use {Cookie, CookieParseError}; use {HttpRange, HttpRangeParseError}; @@ -22,6 +22,7 @@ pub struct HttpRequest { headers: HeaderMap, params: Params, cookies: Vec>, + extensions: Extensions, } impl HttpRequest { @@ -35,9 +36,28 @@ impl HttpRequest { headers: headers, params: Params::empty(), cookies: Vec::new(), + extensions: Extensions::new(), } } + pub(crate) fn for_error() -> HttpRequest { + HttpRequest { + method: Method::GET, + uri: Uri::default(), + version: Version::HTTP_11, + headers: HeaderMap::new(), + params: Params::empty(), + cookies: Vec::new(), + extensions: Extensions::new(), + } + } + + /// Protocol extensions. + #[inline] + pub fn extensions(&mut self) -> &mut Extensions { + &mut self.extensions + } + /// Read the Request Uri. #[inline] pub fn uri(&self) -> &Uri { &self.uri } @@ -111,16 +131,9 @@ impl HttpRequest { #[inline] pub fn match_info(&self) -> &Params { &self.params } - /// Create new request with Params object. - pub fn with_match_info(self, params: Params) -> Self { - HttpRequest { - method: self.method, - uri: self.uri, - version: self.version, - headers: self.headers, - params: params, - cookies: self.cookies, - } + /// Set request Params. + pub fn set_match_info(&mut self, params: Params) { + self.params = params; } /// Checks if a connection should be kept alive. diff --git a/src/lib.rs b/src/lib.rs index 34e45b21e..f1c3d1e87 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -32,6 +32,7 @@ mod date; mod decode; mod httprequest; mod httpresponse; +mod logger; mod payload; mod resource; mod recognizer; @@ -48,13 +49,14 @@ pub mod dev; pub mod httpcodes; pub mod multipart; pub use error::ParseError; -pub use application::{Application, ApplicationBuilder}; +pub use application::{Application, ApplicationBuilder, Middleware}; pub use httprequest::{HttpRequest, UrlEncoded}; pub use httpresponse::{Body, HttpResponse, HttpResponseBuilder}; pub use payload::{Payload, PayloadItem, PayloadError}; pub use route::{Route, RouteFactory, RouteHandler}; pub use resource::{Reply, Resource}; pub use recognizer::{Params, RouteRecognizer}; +pub use logger::Logger; pub use server::HttpServer; pub use context::HttpContext; pub use staticfiles::StaticFiles; diff --git a/src/logger.rs b/src/logger.rs new file mode 100644 index 000000000..1936bc437 --- /dev/null +++ b/src/logger.rs @@ -0,0 +1,274 @@ +//! Request logging middleware +use std::fmt; +use std::str::Chars; +use std::iter::Peekable; +use std::fmt::{Display, Formatter}; + +use time; + +use application::Middleware; +use httprequest::HttpRequest; +use httpresponse::HttpResponse; + +/// `Middleware` for logging request and response info to the terminal. +pub struct Logger { + format: Format, +} + +impl Logger { + /// Create `Logger` middlewares with the specified `format`. + /// If a `None` is passed in, uses the default format: + /// + /// ```ignore + /// {method} {uri} -> {status} ({response-time} ms) + /// ``` + /// + /// ```rust,ignore + /// let app = Application::default("/") + /// .middleware(Logger::new(None)) + /// .finish() + /// ``` + pub fn new(format: Option) -> Logger { + let format = format.unwrap_or_default(); + Logger { format: format.clone() } + } +} + +struct StartTime(time::Tm); + +impl Logger { + fn initialise(&self, req: &mut HttpRequest) { + req.extensions().insert(StartTime(time::now())); + } + + fn log(&self, req: &mut HttpRequest, resp: &HttpResponse) { + let entry_time = req.extensions().get::().unwrap().0; + + let response_time = time::now() - entry_time; + let response_time_ms = (response_time.num_seconds() * 1000) as f64 + (response_time.num_nanoseconds().unwrap_or(0) as f64) / 1000000.0; + + { + let render = |fmt: &mut Formatter, text: &FormatText| { + match *text { + FormatText::Str(ref string) => fmt.write_str(string), + FormatText::Method => req.method().fmt(fmt), + FormatText::URI => req.uri().fmt(fmt), + FormatText::Status => resp.status().fmt(fmt), + FormatText::ResponseTime => + fmt.write_fmt(format_args!("{} ms", response_time_ms)), + FormatText::RemoteAddr => Ok(()), //req.remote_addr.fmt(fmt), + FormatText::RequestTime => { + entry_time.strftime("%Y-%m-%dT%H:%M:%S.%fZ%z") + .unwrap() + .fmt(fmt) + } + } + }; + + info!("{}", self.format.display_with(&render)); + //println!("{}", self.format.display_with(&render)); + } + } +} + +impl Middleware for Logger { + fn start(&self, req: &mut HttpRequest) -> Result<(), HttpResponse> { + self.initialise(req); + Ok(()) + } + + fn finish(&self, req: &mut HttpRequest, resp: &HttpResponse) { + self.log(req, resp); + } +} + + +use self::FormatText::{Method, URI, Status, ResponseTime, RemoteAddr, RequestTime}; + +/// A formatting style for the `Logger`, consisting of multiple +/// `FormatText`s concatenated into one line. +#[derive(Clone)] +#[doc(hidden)] +pub struct Format(Vec); + +impl Default for Format { + /// Return the default formatting style for the `Logger`: + /// + /// ```ignore + /// {method} {uri} -> {status} ({response-time}) + /// // This will be written as: {method} {uri} -> {status} ({response-time}) + /// ``` + fn default() -> Format { + Format::new("{method} {uri} {status} ({response-time})").unwrap() + } +} + +impl Format { + /// Create a `Format` from a format string, which can contain the fields + /// `{method}`, `{uri}`, `{status}`, `{response-time}`, `{ip-addr}` and + /// `{request-time}`. + /// + /// Returns `None` if the format string syntax is incorrect. + pub fn new(s: &str) -> Option { + + let parser = FormatParser::new(s.chars().peekable()); + + let mut results = Vec::new(); + + for unit in parser { + match unit { + Some(unit) => results.push(unit), + None => return None + } + } + + Some(Format(results)) + } +} + +pub(crate) trait ContextDisplay<'a> { + type Item; + type Display: fmt::Display; + fn display_with(&'a self, + render: &'a Fn(&mut Formatter, &Self::Item) -> Result<(), fmt::Error>) + -> Self::Display; +} + +impl<'a> ContextDisplay<'a> for Format { + type Item = FormatText; + type Display = FormatDisplay<'a>; + fn display_with(&'a self, + render: &'a Fn(&mut Formatter, &FormatText) -> Result<(), fmt::Error>) + -> FormatDisplay<'a> { + FormatDisplay { + format: self, + render: render, + } + } +} + +struct FormatParser<'a> { + // The characters of the format string. + chars: Peekable>, + + // A reusable buffer for parsing style attributes. + object_buffer: String, + + finished: bool +} + +impl<'a> FormatParser<'a> { + fn new(chars: Peekable) -> FormatParser { + FormatParser { + chars: chars, + + // No attributes are longer than 14 characters, so we can avoid reallocating. + object_buffer: String::with_capacity(14), + + finished: false + } + } +} + +// Some(None) means there was a parse error and this FormatParser should be abandoned. +impl<'a> Iterator for FormatParser<'a> { + type Item = Option; + + fn next(&mut self) -> Option> { + // If the parser has been cancelled or errored for some reason. + if self.finished { return None } + + // Try to parse a new FormatText. + match self.chars.next() { + // Parse a recognized object. + // + // The allowed forms are: + // - {method} + // - {uri} + // - {status} + // - {response-time} + // - {ip-addr} + // - {request-time} + Some('{') => { + self.object_buffer.clear(); + + let mut chr = self.chars.next(); + while chr != None { + match chr.unwrap() { + // Finished parsing, parse buffer. + '}' => break, + c => self.object_buffer.push(c.clone()) + } + + chr = self.chars.next(); + } + + let text = match self.object_buffer.as_ref() { + "method" => Method, + "uri" => URI, + "status" => Status, + "response-time" => ResponseTime, + "request-time" => RequestTime, + "ip-addr" => RemoteAddr, + _ => { + // Error, so mark as finished. + self.finished = true; + return Some(None); + } + }; + + Some(Some(text)) + } + + // Parse a regular string part of the format string. + Some(c) => { + let mut buffer = String::new(); + buffer.push(c); + + loop { + match self.chars.peek() { + // Done parsing. + Some(&'{') | None => return Some(Some(FormatText::Str(buffer))), + + Some(_) => { + buffer.push(self.chars.next().unwrap()) + } + } + } + }, + + // Reached end of the format string. + None => None + } + } +} + +/// A string of text to be logged. This is either one of the data +/// fields supported by the `Logger`, or a custom `String`. +#[derive(Clone)] +#[doc(hidden)] +pub enum FormatText { + Str(String), + Method, + URI, + Status, + ResponseTime, + RemoteAddr, + RequestTime +} + + +pub(crate) struct FormatDisplay<'a> { + format: &'a Format, + render: &'a Fn(&mut Formatter, &FormatText) -> Result<(), fmt::Error>, +} + +impl<'a> fmt::Display for FormatDisplay<'a> { + fn fmt(&self, fmt: &mut Formatter) -> Result<(), fmt::Error> { + let Format(ref format) = *self.format; + for unit in format { + (self.render)(fmt, unit)?; + } + Ok(()) + } +} diff --git a/src/resource.rs b/src/resource.rs index d1a8aba3e..0ad10220a 100644 --- a/src/resource.rs +++ b/src/resource.rs @@ -58,7 +58,7 @@ impl Resource where S: 'static { /// Register handler for specified method. pub fn handler(&mut self, method: Method, handler: F) - where F: Fn(HttpRequest, Payload, &S) -> R + 'static, + where F: Fn(&mut HttpRequest, Payload, &S) -> R + 'static, R: Into + 'static, { self.routes.insert(method, Box::new(FnHandler::new(handler))); @@ -66,7 +66,7 @@ impl Resource where S: 'static { /// Register async handler for specified method. pub fn async(&mut self, method: Method, handler: F) - where F: Fn(HttpRequest, Payload, &S) -> R + 'static, + where F: Fn(&mut HttpRequest, Payload, &S) -> R + 'static, R: Stream + 'static, { self.routes.insert(method, Box::new(StreamHandler::new(handler))); @@ -119,7 +119,7 @@ impl Resource where S: 'static { impl RouteHandler for Resource { - fn handle(&self, req: HttpRequest, payload: Payload, state: Rc) -> Task { + fn handle(&self, req: &mut HttpRequest, payload: Payload, state: Rc) -> Task { if let Some(handler) = self.routes.get(req.method()) { handler.handle(req, payload, state) } else { diff --git a/src/route.rs b/src/route.rs index 487931600..48946fe85 100644 --- a/src/route.rs +++ b/src/route.rs @@ -27,7 +27,7 @@ pub enum Frame { #[allow(unused_variables)] pub trait RouteHandler: 'static { /// Handle request - fn handle(&self, req: HttpRequest, payload: Payload, state: Rc) -> Task; + fn handle(&self, req: &mut HttpRequest, payload: Payload, state: Rc) -> Task; /// Set route prefix fn set_prefix(&mut self, prefix: String) {} @@ -73,7 +73,8 @@ pub trait Route: Actor { /// request/response or websocket connection. /// In that case `HttpContext::start` and `HttpContext::write` has to be used /// for writing response. - fn request(req: HttpRequest, payload: Payload, ctx: &mut Self::Context) -> Reply; + fn request(req: &mut HttpRequest, + payload: Payload, ctx: &mut Self::Context) -> Reply; /// This method creates `RouteFactory` for this actor. fn factory() -> RouteFactory { @@ -88,7 +89,7 @@ impl RouteHandler for RouteFactory where A: Actor> + Route, S: 'static { - fn handle(&self, req: HttpRequest, payload: Payload, state: Rc) -> Task + fn handle(&self, req: &mut HttpRequest, payload: Payload, state: Rc) -> Task { let mut ctx = HttpContext::new(state); @@ -105,7 +106,7 @@ impl RouteHandler for RouteFactory /// Fn() route handler pub(crate) struct FnHandler - where F: Fn(HttpRequest, Payload, &S) -> R + 'static, + where F: Fn(&mut HttpRequest, Payload, &S) -> R + 'static, R: Into, S: 'static, { @@ -114,7 +115,7 @@ struct FnHandler } impl FnHandler - where F: Fn(HttpRequest, Payload, &S) -> R + 'static, + where F: Fn(&mut HttpRequest, Payload, &S) -> R + 'static, R: Into + 'static, S: 'static, { @@ -124,11 +125,11 @@ impl FnHandler } impl RouteHandler for FnHandler - where F: Fn(HttpRequest, Payload, &S) -> R + 'static, + where F: Fn(&mut HttpRequest, Payload, &S) -> R + 'static, R: Into + 'static, S: 'static, { - fn handle(&self, req: HttpRequest, payload: Payload, state: Rc) -> Task + fn handle(&self, req: &mut HttpRequest, payload: Payload, state: Rc) -> Task { Task::reply((self.f)(req, payload, &state).into()) } @@ -137,7 +138,7 @@ impl RouteHandler for FnHandler /// Async route handler pub(crate) struct StreamHandler - where F: Fn(HttpRequest, Payload, &S) -> R + 'static, + where F: Fn(&mut HttpRequest, Payload, &S) -> R + 'static, R: Stream + 'static, S: 'static, { @@ -146,7 +147,7 @@ struct StreamHandler } impl StreamHandler - where F: Fn(HttpRequest, Payload, &S) -> R + 'static, + where F: Fn(&mut HttpRequest, Payload, &S) -> R + 'static, R: Stream + 'static, S: 'static, { @@ -156,11 +157,11 @@ impl StreamHandler } impl RouteHandler for StreamHandler - where F: Fn(HttpRequest, Payload, &S) -> R + 'static, + where F: Fn(&mut HttpRequest, Payload, &S) -> R + 'static, R: Stream + 'static, S: 'static, { - fn handle(&self, req: HttpRequest, payload: Payload, state: Rc) -> Task + fn handle(&self, req: &mut HttpRequest, payload: Payload, state: Rc) -> Task { Task::with_stream( (self.f)(req, payload, &state).map_err( diff --git a/src/server.rs b/src/server.rs index 45f77525c..aa4ec4b9c 100644 --- a/src/server.rs +++ b/src/server.rs @@ -1,5 +1,6 @@ -use std::{io, mem, net}; +use std::{io, net}; use std::rc::Rc; +use std::cell::UnsafeCell; use std::time::Duration; use std::marker::PhantomData; use std::collections::VecDeque; @@ -10,7 +11,7 @@ use tokio_core::reactor::Timeout; use tokio_core::net::{TcpListener, TcpStream}; use tokio_io::{AsyncRead, AsyncWrite}; -use task::{Task, RequestInfo}; +use task::Task; use reader::{Reader, ReaderError}; use payload::Payload; use httpcodes::HTTPNotFound; @@ -21,7 +22,7 @@ pub trait HttpHandler: 'static { /// Http handler prefix fn prefix(&self) -> &str; /// Handle request - fn handle(&self, req: HttpRequest, payload: Payload) -> Task; + fn handle(&self, req: &mut HttpRequest, payload: Payload) -> Task; } /// An HTTP Server @@ -148,7 +149,7 @@ impl Handler, io::Error> for HttpServer struct Entry { task: Task, - req: RequestInfo, + req: UnsafeCell, eof: bool, error: bool, finished: bool, @@ -213,9 +214,7 @@ impl Future for HttpChannel } // this is anoying - let req: &RequestInfo = unsafe { - mem::transmute(&self.items[idx].req) - }; + let req = unsafe {self.items[idx].req.get().as_mut().unwrap()}; match self.items[idx].task.poll_io(&mut self.stream, req) { Ok(Async::Ready(ready)) => { @@ -280,23 +279,22 @@ impl Future for HttpChannel // read incoming data if !self.error && self.items.len() < MAX_PIPELINED_MESSAGES { match self.reader.parse(&mut self.stream) { - Ok(Async::Ready((req, payload))) => { + Ok(Async::Ready((mut req, payload))) => { // stop keepalive timer self.keepalive_timer.take(); // start request processing - let info = RequestInfo::new(&req); let mut task = None; for h in self.router.iter() { if req.path().starts_with(h.prefix()) { - task = Some(h.handle(req, payload)); + task = Some(h.handle(&mut req, payload)); break } } self.items.push_back( Entry {task: task.unwrap_or_else(|| Task::reply(HTTPNotFound)), - req: info, + req: UnsafeCell::new(req), eof: false, error: false, finished: false}); @@ -313,7 +311,7 @@ impl Future for HttpChannel if let ReaderError::Error(err) = err { self.items.push_back( Entry {task: Task::reply(err), - req: RequestInfo::for_error(), + req: UnsafeCell::new(HttpRequest::for_error()), eof: false, error: false, finished: false}); diff --git a/src/staticfiles.rs b/src/staticfiles.rs index c37c7d978..6ea1d2477 100644 --- a/src/staticfiles.rs +++ b/src/staticfiles.rs @@ -48,14 +48,19 @@ impl StaticFiles { pub fn new>(dir: D, index: bool) -> StaticFiles { let dir = dir.into(); - let (dir, access) = if let Ok(dir) = dir.canonicalize() { - if dir.is_dir() { - (dir, true) - } else { + let (dir, access) = match dir.canonicalize() { + Ok(dir) => { + if dir.is_dir() { + (dir, true) + } else { + warn!("Is not directory `{:?}`", dir); + (dir, false) + } + }, + Err(err) => { + warn!("Static files directory `{:?}` error: {}", dir, err); (dir, false) } - } else { - (dir, false) }; StaticFiles { @@ -134,7 +139,7 @@ impl RouteHandler for StaticFiles { } } - fn handle(&self, req: HttpRequest, payload: Payload, state: Rc) -> Task { + fn handle(&self, req: &mut HttpRequest, payload: Payload, state: Rc) -> Task { if !self.accessible { Task::reply(HTTPNotFound) } else { diff --git a/src/task.rs b/src/task.rs index 836c00481..782d710f0 100644 --- a/src/task.rs +++ b/src/task.rs @@ -1,4 +1,5 @@ use std::{cmp, io}; +use std::rc::Rc; use std::fmt::Write; use std::collections::VecDeque; @@ -11,6 +12,7 @@ use tokio_io::{AsyncRead, AsyncWrite}; use date; use route::Frame; +use application::Middleware; use httprequest::HttpRequest; use httpresponse::{Body, HttpResponse}; @@ -44,26 +46,6 @@ impl TaskIOState { } } -pub(crate) struct RequestInfo { - version: Version, - keep_alive: bool, -} - -impl RequestInfo { - pub fn new(req: &HttpRequest) -> Self { - RequestInfo { - version: req.version(), - keep_alive: req.keep_alive(), - } - } - pub fn for_error() -> Self { - RequestInfo { - version: Version::HTTP_11, - keep_alive: false, - } - } -} - pub struct Task { state: TaskRunningState, iostate: TaskIOState, @@ -73,7 +55,8 @@ pub struct Task { buffer: BytesMut, upgrade: bool, keepalive: bool, - prepared: bool, + prepared: Option, + middlewares: Option>>>, } impl Task { @@ -92,7 +75,8 @@ impl Task { buffer: BytesMut::new(), upgrade: false, keepalive: false, - prepared: false, + prepared: None, + middlewares: None, } } @@ -108,7 +92,8 @@ impl Task { buffer: BytesMut::new(), upgrade: false, keepalive: false, - prepared: false, + prepared: None, + middlewares: None, } } @@ -116,15 +101,31 @@ impl Task { self.keepalive && !self.upgrade } - fn prepare(&mut self, req: &RequestInfo, mut msg: HttpResponse) + pub(crate) fn set_middlewares(&mut self, middlewares: Rc>>) { + self.middlewares = Some(middlewares); + } + + fn prepare(&mut self, req: &mut HttpRequest, msg: HttpResponse) { trace!("Prepare message status={:?}", msg.status); + // run middlewares + let mut msg = if let Some(middlewares) = self.middlewares.take() { + let mut msg = msg; + for middleware in middlewares.iter() { + msg = middleware.response(req, msg); + } + self.middlewares = Some(middlewares); + msg + } else { + msg + }; + + // prepare task let mut extra = 0; let body = msg.replace_body(Body::Empty); - let version = msg.version().unwrap_or_else(|| req.version); - self.keepalive = msg.keep_alive().unwrap_or_else(|| req.keep_alive); - self.prepared = true; + let version = msg.version().unwrap_or_else(|| req.version()); + self.keepalive = msg.keep_alive().unwrap_or_else(|| req.keep_alive()); match body { Body::Empty => { @@ -219,12 +220,14 @@ impl Task { if let Body::Binary(ref bytes) = body { self.buffer.extend(bytes); + self.prepared = Some(msg); return } msg.replace_body(body); + self.prepared = Some(msg); } - pub(crate) fn poll_io(&mut self, io: &mut T, info: &RequestInfo) -> Poll + pub(crate) fn poll_io(&mut self, io: &mut T, req: &mut HttpRequest) -> Poll where T: AsyncRead + AsyncWrite { trace!("POLL-IO frames:{:?}", self.frames.len()); @@ -248,10 +251,10 @@ impl Task { trace!("IO Frame: {:?}", frame); match frame { Frame::Message(response) => { - self.prepare(info, response); + self.prepare(req, response); } Frame::Payload(Some(chunk)) => { - if self.prepared { + if self.prepared.is_some() { // TODO: add warning, write after EOF self.encoder.encode(&mut self.buffer, chunk.as_ref()); } else { @@ -295,6 +298,15 @@ impl Task { // response is completed if self.buffer.is_empty() && self.iostate.is_done() { + // run middlewares + if let Some(ref mut resp) = self.prepared { + if let Some(middlewares) = self.middlewares.take() { + for middleware in middlewares.iter() { + middleware.finish(req, resp); + } + } + } + Ok(Async::Ready(self.state.is_done())) } else { Ok(Async::NotReady)