From de71ad7de463f3194545ff03c261406f424f8c01 Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Wed, 15 Nov 2017 20:06:28 -1000 Subject: [PATCH] refactor error handling --- CHANGES.md | 2 + Cargo.toml | 8 +- examples/basic.rs | 5 +- src/context.rs | 5 +- src/encoding.rs | 3 +- src/error.rs | 293 +++++++++++++++++++++++++++++++++----------- src/h1.rs | 10 +- src/h2.rs | 3 +- src/httprequest.rs | 12 +- src/httpresponse.rs | 33 +---- src/lib.rs | 15 ++- src/multipart.rs | 63 +--------- src/payload.rs | 57 +-------- src/resource.rs | 5 +- src/route.rs | 15 +-- src/task.rs | 13 +- src/ws.rs | 68 +++------- src/wsproto.rs | 32 ++--- 18 files changed, 317 insertions(+), 325 deletions(-) diff --git a/CHANGES.md b/CHANGES.md index 0d36f63f3..939e0b8c8 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -5,6 +5,8 @@ * HTTP/2 Support +* Refactor error handling + * Asynchronous middlewares * Content compression/decompression (br, gzip, deflate) diff --git a/Cargo.toml b/Cargo.toml index e25b983d2..6cff09638 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -34,6 +34,8 @@ alpn = ["openssl", "openssl/v102", "openssl/v110", "tokio-openssl"] [dependencies] log = "0.3" +failure = { git = "https://github.com/withoutboats/failure" } +failure_derive = { git = "https://github.com/withoutboats/failure_derive" } time = "0.1" http = "0.1" httparse = "0.1" @@ -44,11 +46,15 @@ cookie = { version="0.10", features=["percent-encode", "secure"] } regex = "0.2" sha1 = "0.2" url = "1.5" -libc = "^0.2" +libc = "0.2" +serde = "1.0" +serde_json = "1.0" flate2 = "0.2" brotli2 = "^0.3.2" percent-encoding = "1.0" +# redis-async = { git="https://github.com/benashford/redis-async-rs" } + # tokio bytes = "0.4" futures = "0.1" diff --git a/examples/basic.rs b/examples/basic.rs index 7e05b07c4..8ae36d9c6 100644 --- a/examples/basic.rs +++ b/examples/basic.rs @@ -19,7 +19,8 @@ fn index(req: &mut HttpRequest, mut _payload: Payload, state: &()) -> HttpRespon } /// somple handle -fn index_async(req: &mut HttpRequest, _payload: Payload, state: &()) -> Once +fn index_async(req: &mut HttpRequest, _payload: Payload, state: &()) + -> Once { println!("{:?}", req); @@ -49,7 +50,7 @@ fn main() { HttpServer::new( Application::default("/") // enable logger - .middleware(middlewares::Logger::default()) + //.middleware(middlewares::Logger::default()) // register simple handle r, handle all methods .handler("/index.html", index) // with path parameters diff --git a/src/context.rs b/src/context.rs index c2493dd64..89a145b81 100644 --- a/src/context.rs +++ b/src/context.rs @@ -14,6 +14,7 @@ use actix::dev::{AsyncContextApi, ActorAddressCell, ActorItemsCell, ActorWaitCel use task::{IoContext, DrainFut}; use body::Binary; +use error::Error; use route::{Route, Frame}; use httpresponse::HttpResponse; @@ -184,9 +185,9 @@ impl HttpContext where A: Actor + Route { impl Stream for HttpContext where A: Actor + Route { type Item = Frame; - type Error = std::io::Error; + type Error = Error; - fn poll(&mut self) -> Poll, std::io::Error> { + fn poll(&mut self) -> Poll, Error> { if self.act.is_none() { return Ok(Async::NotReady) } diff --git a/src/encoding.rs b/src/encoding.rs index 4781f3aa7..27138ff73 100644 --- a/src/encoding.rs +++ b/src/encoding.rs @@ -14,9 +14,10 @@ use brotli2::write::{BrotliDecoder, BrotliEncoder}; use bytes::{Bytes, BytesMut, BufMut, Writer}; use body::Body; +use error::PayloadError; use httprequest::HttpRequest; use httpresponse::HttpResponse; -use payload::{PayloadSender, PayloadWriter, PayloadError}; +use payload::{PayloadSender, PayloadWriter}; /// Represents supported types of content encodings #[derive(Copy, Clone, PartialEq, Debug)] diff --git a/src/error.rs b/src/error.rs index fb64456d2..5cf9ef947 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,78 +1,113 @@ //! Error and Result module. -use std::error::Error as StdError; -use std::fmt; -use std::io::Error as IoError; +use std::{fmt, result}; use std::str::Utf8Error; use std::string::FromUtf8Error; +use std::io::Error as IoError; use cookie; use httparse; -use http::{StatusCode, Error as HttpError}; +use failure::Fail; +use http2::Error as Http2Error; +use http::{header, StatusCode, Error as HttpError}; +use http_range::HttpRangeParseError; + +// re-exports +pub use cookie::{ParseError as CookieParseError}; -use HttpRangeParseError; -use multipart::MultipartError; use body::Body; -use httpresponse::{HttpResponse}; +use httpresponse::HttpResponse; +use httpcodes::{HTTPBadRequest, HTTPMethodNotAllowed}; +/// A specialized [`Result`](https://doc.rust-lang.org/std/result/enum.Result.html) +/// for actix web operations. +/// +/// This typedef is generally used to avoid writing out `actix_web::error::Error` directly and +/// is otherwise a direct mapping to `Result`. +pub type Result = result::Result; + +/// Actix web error. +#[derive(Debug)] +pub struct Error { + cause: Box, +} + +/// Error that can be converted to HttpResponse +pub trait ErrorResponse: Fail { + + /// Create response for error + /// + /// Internal server error is generated by default. + fn error_response(&self) -> HttpResponse { + HttpResponse::new(StatusCode::INTERNAL_SERVER_ERROR, Body::Empty) + } +} + +impl fmt::Display for Error { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + fmt::Display::fmt(&self.cause, f) + } +} + +/// `HttpResponse` for `Error`. +impl From for HttpResponse { + fn from(err: Error) -> Self { + err.cause.error_response() + } +} + +impl From for Error { + fn from(err: T) -> Error { + Error { cause: Box::new(err) } + } +} + +// /// Default error is `InternalServerError` +// impl ErrorResponse for T { +// fn error_response(&self) -> HttpResponse { +// HttpResponse::new(StatusCode::INTERNAL_SERVER_ERROR, Body::Empty) +// } +// } /// A set of errors that can occur during parsing HTTP streams. -#[derive(Debug)] +#[derive(Fail, Debug)] pub enum ParseError { /// An invalid `Method`, such as `GE,T`. + #[fail(display="Invalid Method specified")] Method, /// An invalid `Uri`, such as `exam ple.domain`. + #[fail(display="Uri error")] Uri, /// An invalid `HttpVersion`, such as `HTP/1.1` + #[fail(display="Invalid HTTP version specified")] Version, /// An invalid `Header`. + #[fail(display="Invalid Header provided")] Header, /// A message head is too large to be reasonable. + #[fail(display="Message head is too large")] TooLarge, /// A message reached EOF, but is not complete. + #[fail(display="Message is incomplete")] Incomplete, /// An invalid `Status`, such as `1337 ELITE`. + #[fail(display="Invalid Status provided")] Status, /// A timeout occurred waiting for an IO event. #[allow(dead_code)] + #[fail(display="Timeout")] Timeout, /// An `io::Error` that occurred while trying to read or write to a network stream. + #[fail(display="IO error: {}", _0)] Io(IoError), /// Parsing a field as string failed + #[fail(display="UTF8 error: {}", _0)] Utf8(Utf8Error), } -impl fmt::Display for ParseError { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match *self { - ParseError::Io(ref e) => fmt::Display::fmt(e, f), - ParseError::Utf8(ref e) => fmt::Display::fmt(e, f), - ref e => f.write_str(e.description()), - } - } -} - -impl StdError for ParseError { - fn description(&self) -> &str { - match *self { - ParseError::Method => "Invalid Method specified", - ParseError::Version => "Invalid HTTP version specified", - ParseError::Header => "Invalid Header provided", - ParseError::TooLarge => "Message head is too large", - ParseError::Status => "Invalid Status provided", - ParseError::Incomplete => "Message is incomplete", - ParseError::Timeout => "Timeout", - ParseError::Uri => "Uri error", - ParseError::Io(ref e) => e.description(), - ParseError::Utf8(ref e) => e.description(), - } - } - - fn cause(&self) -> Option<&StdError> { - match *self { - ParseError::Io(ref error) => Some(error), - ParseError::Utf8(ref error) => Some(error), - _ => None, - } +/// Return `BadRequest` for `ParseError` +impl ErrorResponse for ParseError { + fn error_response(&self) -> HttpResponse { + HttpResponse::new(StatusCode::BAD_REQUEST, Body::Empty) } } @@ -108,47 +143,158 @@ impl From for ParseError { } } -/// Return `BadRequest` for `ParseError` -impl From for HttpResponse { - fn from(err: ParseError) -> Self { - HttpResponse::from_error(StatusCode::BAD_REQUEST, err) +#[derive(Fail, Debug)] +/// A set of errors that can occur during payload parsing. +pub enum PayloadError { + /// A payload reached EOF, but is not complete. + #[fail(display="A payload reached EOF, but is not complete.")] + Incomplete, + /// Content encoding stream corruption + #[fail(display="Can not decode content-encoding.")] + EncodingCorrupted, + /// Parse error + #[fail(display="{}", _0)] + ParseError(#[cause] IoError), + /// Http2 error + #[fail(display="{}", _0)] + Http2(#[cause] Http2Error), +} + +impl From for PayloadError { + fn from(err: IoError) -> PayloadError { + PayloadError::ParseError(err) } } /// Return `InternalServerError` for `HttpError`, /// Response generation can return `HttpError`, so it is internal error -impl From for HttpResponse { - fn from(err: HttpError) -> Self { - HttpResponse::from_error(StatusCode::INTERNAL_SERVER_ERROR, err) - } -} +impl ErrorResponse for HttpError {} /// Return `InternalServerError` for `io::Error` -impl From for HttpResponse { - fn from(err: IoError) -> Self { - HttpResponse::from_error(StatusCode::INTERNAL_SERVER_ERROR, err) +impl ErrorResponse for IoError {} + +/// Return `BadRequest` for `cookie::ParseError` +impl ErrorResponse for cookie::ParseError { + fn error_response(&self) -> HttpResponse { + HttpResponse::new(StatusCode::BAD_REQUEST, Body::Empty) } } -/// Return `BadRequest` for `cookie::ParseError` -impl From for HttpResponse { - fn from(err: cookie::ParseError) -> Self { - HttpResponse::from_error(StatusCode::BAD_REQUEST, err) +/// Http range header parsing error +#[derive(Fail, Debug)] +pub enum HttpRangeError { + /// Returned if range is invalid. + #[fail(display="Range header is invalid")] + InvalidRange, + /// Returned if first-byte-pos of all of the byte-range-spec + /// values is greater than the content size. + /// See https://github.com/golang/go/commit/aa9b3d7 + #[fail(display="First-byte-pos of all of the byte-range-spec values is greater than the content size")] + NoOverlap, +} + +/// Return `BadRequest` for `HttpRangeError` +impl ErrorResponse for HttpRangeError { + fn error_response(&self) -> HttpResponse { + HttpResponse::new( + StatusCode::BAD_REQUEST, Body::from("Invalid Range header provided")) + } +} + +impl From for HttpRangeError { + fn from(err: HttpRangeParseError) -> HttpRangeError { + match err { + HttpRangeParseError::InvalidRange => HttpRangeError::InvalidRange, + HttpRangeParseError::NoOverlap => HttpRangeError::NoOverlap, + } + } +} + +/// A set of errors that can occur during parsing multipart streams. +#[derive(Fail, Debug)] +pub enum MultipartError { + /// Content-Type header is not found + #[fail(display="No Content-type header found")] + NoContentType, + /// Can not parse Content-Type header + #[fail(display="Can not parse Content-Type header")] + ParseContentType, + /// Multipart boundary is not found + #[fail(display="Multipart boundary is not found")] + Boundary, + /// Error during field parsing + #[fail(display="{}", _0)] + Parse(#[cause] ParseError), + /// Payload error + #[fail(display="{}", _0)] + Payload(#[cause] PayloadError), +} + +impl From for MultipartError { + fn from(err: ParseError) -> MultipartError { + MultipartError::Parse(err) + } +} + +impl From for MultipartError { + fn from(err: PayloadError) -> MultipartError { + MultipartError::Payload(err) } } /// Return `BadRequest` for `MultipartError` -impl From for HttpResponse { - fn from(err: MultipartError) -> Self { - HttpResponse::from_error(StatusCode::BAD_REQUEST, err) +impl ErrorResponse for MultipartError { + + fn error_response(&self) -> HttpResponse { + HttpResponse::new(StatusCode::BAD_REQUEST, Body::Empty) } } -/// Return `BadRequest` for `HttpRangeParseError` -impl From for HttpResponse { - fn from(_: HttpRangeParseError) -> Self { - HttpResponse::new( - StatusCode::BAD_REQUEST, Body::from("Invalid Range header provided")) +/// Websocket handshake errors +#[derive(Fail, PartialEq, Debug)] +pub enum WsHandshakeError { + /// Only get method is allowed + #[fail(display="Method not allowed")] + GetMethodRequired, + /// Ugrade header if not set to websocket + #[fail(display="Websocket upgrade is expected")] + NoWebsocketUpgrade, + /// Connection header is not set to upgrade + #[fail(display="Connection upgrade is expected")] + NoConnectionUpgrade, + /// Websocket version header is not set + #[fail(display="Websocket version header is required")] + NoVersionHeader, + /// Unsupported websockt version + #[fail(display="Unsupported version")] + UnsupportedVersion, + /// Websocket key is not set or wrong + #[fail(display="Unknown websocket key")] + BadWebsocketKey, +} + +impl ErrorResponse for WsHandshakeError { + + fn error_response(&self) -> HttpResponse { + match *self { + WsHandshakeError::GetMethodRequired => { + HTTPMethodNotAllowed + .builder() + .header(header::ALLOW, "GET") + .finish() + .unwrap() + } + WsHandshakeError::NoWebsocketUpgrade => + HTTPBadRequest.with_reason("No WebSocket UPGRADE header found"), + WsHandshakeError::NoConnectionUpgrade => + HTTPBadRequest.with_reason("No CONNECTION upgrade"), + WsHandshakeError::NoVersionHeader => + HTTPBadRequest.with_reason("Websocket version header is required"), + WsHandshakeError::UnsupportedVersion => + HTTPBadRequest.with_reason("Unsupported version"), + WsHandshakeError::BadWebsocketKey => + HTTPBadRequest.with_reason("Handshake error") + } } } @@ -159,24 +305,24 @@ mod tests { use httparse; use http::{StatusCode, Error as HttpError}; use cookie::ParseError as CookieParseError; - use super::{ParseError, HttpResponse, HttpRangeParseError, MultipartError}; + use super::*; #[test] fn test_into_response() { - let resp: HttpResponse = ParseError::Incomplete.into(); + let resp: HttpResponse = ParseError::Incomplete.error_response(); assert_eq!(resp.status(), StatusCode::BAD_REQUEST); - let resp: HttpResponse = HttpRangeParseError::InvalidRange.into(); + let resp: HttpResponse = HttpRangeError::InvalidRange.error_response(); assert_eq!(resp.status(), StatusCode::BAD_REQUEST); - let resp: HttpResponse = CookieParseError::EmptyName.into(); + let resp: HttpResponse = CookieParseError::EmptyName.error_response(); assert_eq!(resp.status(), StatusCode::BAD_REQUEST); - let resp: HttpResponse = MultipartError::Boundary.into(); + let resp: HttpResponse = MultipartError::Boundary.error_response(); assert_eq!(resp.status(), StatusCode::BAD_REQUEST); let err: HttpError = StatusCode::from_u16(10000).err().unwrap().into(); - let resp: HttpResponse = err.into(); + let resp: HttpResponse = err.error_response(); assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR); } @@ -185,14 +331,14 @@ mod tests { let orig = io::Error::new(io::ErrorKind::Other, "other"); let desc = orig.description().to_owned(); let e = ParseError::Io(orig); - assert_eq!(e.cause().unwrap().description(), desc); + assert_eq!(format!("{}", e.cause().unwrap()), desc); } macro_rules! from { ($from:expr => $error:pat) => { match ParseError::from($from) { e @ $error => { - assert!(e.description().len() >= 5); + assert!(format!("{}", e).len() >= 5); } , e => panic!("{:?}", e) } @@ -203,9 +349,8 @@ mod tests { ($from:expr => $error:pat) => { match ParseError::from($from) { e @ $error => { - let desc = e.cause().unwrap().description(); + let desc = format!("{}", e.cause().unwrap()); assert_eq!(desc, $from.description().to_owned()); - assert_eq!(desc, e.description()); }, _ => panic!("{:?}", $from) } diff --git a/src/h1.rs b/src/h1.rs index 3da5dc0ca..ea6fda6ab 100644 --- a/src/h1.rs +++ b/src/h1.rs @@ -17,12 +17,12 @@ use percent_encoding; use task::Task; use channel::HttpHandler; -use error::ParseError; +use error::{ParseError, PayloadError, ErrorResponse}; use h1writer::H1Writer; use httpcodes::HTTPNotFound; use httprequest::HttpRequest; use encoding::PayloadType; -use payload::{Payload, PayloadError, PayloadWriter, DEFAULT_BUFFER_SIZE}; +use payload::{Payload, PayloadWriter, DEFAULT_BUFFER_SIZE}; const KEEPALIVE_PERIOD: u64 = 15; // seconds const INIT_BUFFER_SIZE: usize = 8192; @@ -167,7 +167,7 @@ impl Http1 } // read incoming data - if !self.error && !self.h2 && self.tasks.len() < MAX_PIPELINED_MESSAGES { + while !self.error && !self.h2 && self.tasks.len() < MAX_PIPELINED_MESSAGES { match self.reader.parse(self.stream.get_mut(), &mut self.read_buf) { Ok(Async::Ready(Item::Http1(mut req, payload))) => { not_ready = false; @@ -224,7 +224,7 @@ impl Http1 if self.tasks.is_empty() { if let ReaderError::Error(err) = err { self.tasks.push_back( - Entry {task: Task::reply(err), + Entry {task: Task::reply(err.error_response()), req: UnsafeCell::new(HttpRequest::for_error()), eof: false, error: false, @@ -250,7 +250,7 @@ impl Http1 return Ok(Async::Ready(Http1Result::Done)) } } - return Ok(Async::NotReady) + break } } } diff --git a/src/h2.rs b/src/h2.rs index 618cdca3a..99e153420 100644 --- a/src/h2.rs +++ b/src/h2.rs @@ -20,8 +20,9 @@ use h2writer::H2Writer; use channel::HttpHandler; use httpcodes::HTTPNotFound; use httprequest::HttpRequest; +use error::PayloadError; use encoding::PayloadType; -use payload::{Payload, PayloadError, PayloadWriter}; +use payload::{Payload, PayloadWriter}; const KEEPALIVE_PERIOD: u64 = 15; // seconds diff --git a/src/httprequest.rs b/src/httprequest.rs index e1310fff0..4f3caaa1e 100644 --- a/src/httprequest.rs +++ b/src/httprequest.rs @@ -7,12 +7,11 @@ use futures::{Async, Future, Stream, Poll}; use url::form_urlencoded; use http::{header, Method, Version, HeaderMap, Extensions}; -use {Cookie, CookieParseError}; -use {HttpRange, HttpRangeParseError}; -use error::ParseError; +use {Cookie, HttpRange}; use recognizer::Params; -use payload::{Payload, PayloadError}; -use multipart::{Multipart, MultipartError}; +use payload::Payload; +use multipart::Multipart; +use error::{ParseError, PayloadError, MultipartError, CookieParseError, HttpRangeError}; /// An HTTP Request @@ -222,9 +221,10 @@ impl HttpRequest { /// Parses Range HTTP header string as per RFC 2616. /// `size` is full size of response (file). - pub fn range(&self, size: u64) -> Result, HttpRangeParseError> { + pub fn range(&self, size: u64) -> Result, HttpRangeError> { if let Some(range) = self.headers().get(header::RANGE) { HttpRange::parse(unsafe{str::from_utf8_unchecked(range.as_bytes())}, size) + .map_err(|e| e.into()) } else { Ok(Vec::new()) } diff --git a/src/httpresponse.rs b/src/httpresponse.rs index 7d5bcee67..6e627b2af 100644 --- a/src/httpresponse.rs +++ b/src/httpresponse.rs @@ -1,6 +1,5 @@ -//! Pieces pertaining to the HTTP message protocol. +//! Pieces pertaining to the HTTP response. use std::{io, mem, str, fmt}; -use std::error::Error as Error; use std::convert::Into; use cookie::CookieJar; @@ -33,7 +32,6 @@ pub struct HttpResponse { chunked: bool, encoding: ContentEncoding, connection_type: Option, - error: Option>, response_size: u64, } @@ -59,35 +57,10 @@ impl HttpResponse { chunked: false, encoding: ContentEncoding::Auto, connection_type: None, - error: None, response_size: 0, } } - /// Constructs a response from error - #[inline] - pub fn from_error(status: StatusCode, error: E) -> HttpResponse { - HttpResponse { - version: None, - headers: Default::default(), - status: status, - reason: None, - body: Body::from_slice(error.description().as_ref()), - chunked: false, - encoding: ContentEncoding::Auto, - connection_type: None, - error: Some(Box::new(error)), - response_size: 0, - } - } - - /// The `error` which is responsible for this response - #[inline] - #[cfg_attr(feature="cargo-clippy", allow(borrowed_box))] - pub fn error(&self) -> Option<&Box> { - self.error.as_ref() - } - /// Get the HTTP version of this response. #[inline] pub fn version(&self) -> Option { @@ -241,9 +214,6 @@ impl fmt::Debug for HttpResponse { let _ = write!(f, " {:?}: {:?}\n", key, vals[0]); } } - if let Some(ref err) = self.error { - let _ = write!(f, " error: {}\n", err); - } res } } @@ -445,7 +415,6 @@ impl HttpResponseBuilder { chunked: parts.chunked, encoding: parts.encoding, connection_type: parts.connection_type, - error: None, response_size: 0, }) } diff --git a/src/lib.rs b/src/lib.rs index 7bb19417f..72f3da5bc 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -11,6 +11,9 @@ extern crate futures; extern crate tokio_io; extern crate tokio_core; +extern crate failure; +#[macro_use] extern crate failure_derive; + extern crate cookie; extern crate http; extern crate httparse; @@ -19,12 +22,16 @@ extern crate mime; extern crate mime_guess; extern crate url; extern crate libc; +extern crate serde; +extern crate serde_json; extern crate flate2; extern crate brotli2; extern crate percent_encoding; extern crate actix; extern crate h2 as http2; +extern crate redis_async; + #[cfg(feature="tls")] extern crate native_tls; #[cfg(feature="tls")] @@ -38,7 +45,6 @@ extern crate tokio_openssl; mod application; mod body; mod context; -mod error; mod date; mod encoding; mod httprequest; @@ -60,16 +66,16 @@ mod h2writer; pub mod ws; pub mod dev; +pub mod error; pub mod httpcodes; pub mod multipart; pub mod middlewares; pub use encoding::ContentEncoding; -pub use error::ParseError; pub use body::{Body, Binary}; pub use application::{Application, ApplicationBuilder}; pub use httprequest::{HttpRequest, UrlEncoded}; pub use httpresponse::{HttpResponse, HttpResponseBuilder}; -pub use payload::{Payload, PayloadItem, PayloadError}; +pub use payload::{Payload, PayloadItem}; pub use route::{Frame, Route, RouteFactory, RouteHandler, RouteResult}; pub use resource::{Reply, Resource, HandlerResult}; pub use recognizer::{Params, RouteRecognizer}; @@ -81,8 +87,7 @@ pub use staticfiles::StaticFiles; // re-exports pub use http::{Method, StatusCode, Version}; pub use cookie::{Cookie, CookieBuilder}; -pub use cookie::{ParseError as CookieParseError}; -pub use http_range::{HttpRange, HttpRangeParseError}; +pub use http_range::HttpRange; #[cfg(feature="tls")] pub use native_tls::Pkcs12; diff --git a/src/multipart.rs b/src/multipart.rs index be272777b..64b434275 100644 --- a/src/multipart.rs +++ b/src/multipart.rs @@ -2,7 +2,6 @@ use std::{cmp, fmt}; use std::rc::Rc; use std::cell::RefCell; -use std::error::Error; use std::marker::PhantomData; use mime; @@ -13,69 +12,11 @@ use http::header::{self, HeaderMap, HeaderName, HeaderValue}; use futures::{Async, Stream, Poll}; use futures::task::{Task, current as current_task}; -use error::ParseError; -use payload::{Payload, PayloadError}; +use error::{ParseError, PayloadError, MultipartError}; +use payload::Payload; const MAX_HEADERS: usize = 32; -/// A set of errors that can occur during parsing multipart streams. -#[derive(Debug)] -pub enum MultipartError { - /// Content-Type header is not found - NoContentType, - /// Can not parse Content-Type header - ParseContentType, - /// Multipart boundary is not found - Boundary, - /// Error during field parsing - Parse(ParseError), - /// Payload error - Payload(PayloadError), -} - -impl fmt::Display for MultipartError { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match *self { - MultipartError::Parse(ref e) => fmt::Display::fmt(e, f), - MultipartError::Payload(ref e) => fmt::Display::fmt(e, f), - ref e => f.write_str(e.description()), - } - } -} - -impl Error for MultipartError { - fn description(&self) -> &str { - match *self { - MultipartError::NoContentType => "No Content-type header found", - MultipartError::ParseContentType => "Can not parse Content-Type header", - MultipartError::Boundary => "Multipart boundary is not found", - MultipartError::Parse(ref e) => e.description(), - MultipartError::Payload(ref e) => e.description(), - } - } - - fn cause(&self) -> Option<&Error> { - match *self { - MultipartError::Parse(ref error) => Some(error), - MultipartError::Payload(ref error) => Some(error), - _ => None, - } - } -} - - -impl From for MultipartError { - fn from(err: ParseError) -> MultipartError { - MultipartError::Parse(err) - } -} - -impl From for MultipartError { - fn from(err: PayloadError) -> MultipartError { - MultipartError::Payload(err) - } -} - /// The server-side implementation of `multipart/form-data` requests. /// /// This will parse the incoming stream into `MultipartItem` instances via its diff --git a/src/payload.rs b/src/payload.rs index 2a00458b7..321693fa1 100644 --- a/src/payload.rs +++ b/src/payload.rs @@ -2,14 +2,12 @@ use std::{fmt, cmp}; use std::rc::{Rc, Weak}; use std::cell::RefCell; use std::collections::VecDeque; -use std::error::Error; -use std::io::{Error as IoError}; use bytes::{Bytes, BytesMut}; -use http2::Error as Http2Error; use futures::{Async, Poll, Stream}; use futures::task::{Task, current as current_task}; use actix::ResponseType; +use error::PayloadError; pub(crate) const DEFAULT_BUFFER_SIZE: usize = 65_536; // max buffer size 64k @@ -27,52 +25,6 @@ impl fmt::Debug for PayloadItem { } } -#[derive(Debug)] -/// A set of error that can occur during payload parsing. -pub enum PayloadError { - /// A payload reached EOF, but is not complete. - Incomplete, - /// Content encoding stream corruption - EncodingCorrupted, - /// Parse error - ParseError(IoError), - /// Http2 error - Http2(Http2Error), -} - -impl fmt::Display for PayloadError { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match *self { - PayloadError::ParseError(ref e) => fmt::Display::fmt(e, f), - ref e => f.write_str(e.description()), - } - } -} - -impl Error for PayloadError { - fn description(&self) -> &str { - match *self { - PayloadError::Incomplete => "A payload reached EOF, but is not complete.", - PayloadError::EncodingCorrupted => "Can not decode content-encoding.", - PayloadError::ParseError(ref e) => e.description(), - PayloadError::Http2(ref e) => e.description(), - } - } - - fn cause(&self) -> Option<&Error> { - match *self { - PayloadError::ParseError(ref error) => Some(error), - _ => None, - } - } -} - -impl From for PayloadError { - fn from(err: IoError) -> PayloadError { - PayloadError::ParseError(err) - } -} - /// Stream of byte chunks /// /// Payload stores chunks in vector. First chunk can be received with `.readany()` method. @@ -392,18 +344,17 @@ impl Inner { mod tests { use super::*; use std::io; + use failure::Fail; use futures::future::{lazy, result}; use tokio_core::reactor::Core; #[test] fn test_error() { - let err: PayloadError = IoError::new(io::ErrorKind::Other, "ParseError").into(); - assert_eq!(err.description(), "ParseError"); - assert_eq!(err.cause().unwrap().description(), "ParseError"); + let err: PayloadError = io::Error::new(io::ErrorKind::Other, "ParseError").into(); assert_eq!(format!("{}", err), "ParseError"); + assert_eq!(format!("{}", err.cause().unwrap()), "ParseError"); let err = PayloadError::Incomplete; - assert_eq!(err.description(), "A payload reached EOF, but is not complete."); assert_eq!(format!("{}", err), "A payload reached EOF, but is not complete."); } diff --git a/src/resource.rs b/src/resource.rs index d3a26175b..b8ddcc2e1 100644 --- a/src/resource.rs +++ b/src/resource.rs @@ -8,6 +8,7 @@ use http::Method; use futures::Stream; use task::Task; +use error::Error; use route::{Route, RouteHandler, RouteResult, Frame, FnHandler, StreamHandler}; use payload::Payload; use context::HttpContext; @@ -16,7 +17,7 @@ use httpresponse::HttpResponse; use httpcodes::{HTTPNotFound, HTTPMethodNotAllowed}; /// Result of a resource handler function -pub type HandlerResult = Result; +pub type HandlerResult = Result; /// Http resource /// @@ -77,7 +78,7 @@ impl Resource where S: 'static { /// Register async handler for specified method. pub fn async(&mut self, method: Method, handler: F) where F: Fn(&mut HttpRequest, Payload, &S) -> R + 'static, - R: Stream + 'static, + R: Stream + 'static, { self.routes.insert(method, Box::new(StreamHandler::new(handler))); } diff --git a/src/route.rs b/src/route.rs index d5a2b94b8..8f1d1cb4a 100644 --- a/src/route.rs +++ b/src/route.rs @@ -1,4 +1,3 @@ -use std::io; use std::rc::Rc; use std::cell::RefCell; use std::marker::PhantomData; @@ -9,6 +8,7 @@ use futures::Stream; use task::{Task, DrainFut}; use body::Binary; +use error::Error; use context::HttpContext; use resource::Reply; use payload::Payload; @@ -42,7 +42,7 @@ pub trait RouteHandler: 'static { } /// Request handling result. -pub type RouteResult = Result, HttpResponse>; +pub type RouteResult = Result, Error>; /// Actors with ability to handle http requests. #[allow(unused_variables)] @@ -151,7 +151,7 @@ impl RouteHandler for FnHandler pub(crate) struct StreamHandler where F: Fn(&mut HttpRequest, Payload, &S) -> R + 'static, - R: Stream + 'static, + R: Stream + 'static, S: 'static, { f: Box, @@ -160,7 +160,7 @@ struct StreamHandler impl StreamHandler where F: Fn(&mut HttpRequest, Payload, &S) -> R + 'static, - R: Stream + 'static, + R: Stream + 'static, S: 'static, { pub fn new(f: F) -> Self { @@ -170,14 +170,11 @@ impl StreamHandler impl RouteHandler for StreamHandler where F: Fn(&mut HttpRequest, Payload, &S) -> R + 'static, - R: Stream + 'static, + R: Stream + 'static, S: 'static, { fn handle(&self, req: &mut HttpRequest, payload: Payload, state: Rc) -> Task { - Task::with_stream( - (self.f)(req, payload, &state).map_err( - |_| io::Error::new(io::ErrorKind::Other, "")) - ) + Task::with_stream((self.f)(req, payload, &state)) } } diff --git a/src/task.rs b/src/task.rs index 810e92ebd..80ef6621d 100644 --- a/src/task.rs +++ b/src/task.rs @@ -1,4 +1,4 @@ -use std::{mem, io}; +use std::mem; use std::rc::Rc; use std::cell::RefCell; use std::collections::VecDeque; @@ -7,12 +7,13 @@ use futures::{Async, Future, Poll, Stream}; use futures::task::{Task as FutureTask, current as current_task}; use h1writer::{Writer, WriterState}; +use error::Error; use route::Frame; use middlewares::{Middleware, MiddlewaresExecutor}; use httprequest::HttpRequest; use httpresponse::HttpResponse; -type FrameStream = Stream; +type FrameStream = Stream; #[derive(PartialEq, Debug)] enum TaskRunningState { @@ -53,10 +54,10 @@ impl TaskIOState { enum TaskStream { None, Stream(Box), - Context(Box>), + Context(Box>), } -pub(crate) trait IoContext: Stream + 'static { +pub(crate) trait IoContext: Stream + 'static { fn disconnected(&mut self); } @@ -141,7 +142,7 @@ impl Task { } pub(crate) fn with_stream(stream: S) -> Self - where S: Stream + 'static + where S: Stream + 'static { Task { state: TaskRunningState::Running, iostate: TaskIOState::ReadingMessage, @@ -290,7 +291,7 @@ impl Task { } fn poll_stream(&mut self, stream: &mut S) -> Poll<(), ()> - where S: Stream { + where S: Stream { loop { match stream.poll() { Ok(Async::Ready(Some(frame))) => { diff --git a/src/ws.rs b/src/ws.rs index 78dc85f99..b7e45dd2b 100644 --- a/src/ws.rs +++ b/src/ws.rs @@ -71,7 +71,7 @@ use body::Body; use context::HttpContext; use route::Route; use payload::Payload; -use httpcodes::{HTTPBadRequest, HTTPMethodNotAllowed}; +use error::WsHandshakeError; use httprequest::HttpRequest; use httpresponse::{ConnectionType, HttpResponse}; @@ -114,14 +114,10 @@ impl ResponseType for Message { // /// `protocols` is a sequence of known protocols. On successful handshake, // /// the returned response headers contain the first protocol in this list // /// which the server also knows. -pub fn handshake(req: &HttpRequest) -> Result { +pub fn handshake(req: &HttpRequest) -> Result { // WebSocket accepts only GET if *req.method() != Method::GET { - return Err( - HTTPMethodNotAllowed - .builder() - .header(header::ALLOW, "GET") - .finish()?) + return Err(WsHandshakeError::GetMethodRequired) } // Check for "UPGRADE" to websocket header @@ -135,17 +131,17 @@ pub fn handshake(req: &HttpRequest) -> Result { false }; if !has_hdr { - return Err(HTTPBadRequest.with_reason("No WebSocket UPGRADE header found")) + return Err(WsHandshakeError::NoWebsocketUpgrade) } // Upgrade connection if !req.upgrade() { - return Err(HTTPBadRequest.with_reason("No CONNECTION upgrade")) + return Err(WsHandshakeError::NoConnectionUpgrade) } // check supported version if !req.headers().contains_key(SEC_WEBSOCKET_VERSION) { - return Err(HTTPBadRequest.with_reason("No websocket version header is required")) + return Err(WsHandshakeError::NoVersionHeader) } let supported_ver = { if let Some(hdr) = req.headers().get(SEC_WEBSOCKET_VERSION) { @@ -155,12 +151,12 @@ pub fn handshake(req: &HttpRequest) -> Result { } }; if !supported_ver { - return Err(HTTPBadRequest.with_reason("Unsupported version")) + return Err(WsHandshakeError::UnsupportedVersion) } // check client handshake for validity if !req.headers().contains_key(SEC_WEBSOCKET_KEY) { - return Err(HTTPBadRequest.with_reason("Handshake error")); + return Err(WsHandshakeError::BadWebsocketKey) } let key = { let key = req.headers().get(SEC_WEBSOCKET_KEY).unwrap(); @@ -172,7 +168,7 @@ pub fn handshake(req: &HttpRequest) -> Result { .header(header::UPGRADE, "websocket") .header(header::TRANSFER_ENCODING, "chunked") .header(SEC_WEBSOCKET_ACCEPT, key.as_str()) - .body(Body::Upgrade)? + .body(Body::Upgrade).unwrap() ) } @@ -338,44 +334,32 @@ impl WsWriter { #[cfg(test)] mod tests { - use http::{Method, HeaderMap, StatusCode, Version, header}; - use super::{HttpRequest, SEC_WEBSOCKET_VERSION, SEC_WEBSOCKET_KEY, handshake}; + use super::*; + use http::{Method, HeaderMap, Version, header}; #[test] fn test_handshake() { let req = HttpRequest::new(Method::POST, "/".to_owned(), Version::HTTP_11, HeaderMap::new(), String::new()); - match handshake(&req) { - Err(err) => assert_eq!(err.status(), StatusCode::METHOD_NOT_ALLOWED), - _ => panic!("should not happen"), - } + assert_eq!(WsHandshakeError::GetMethodRequired, handshake(&req).err().unwrap()); let req = HttpRequest::new(Method::GET, "/".to_owned(), Version::HTTP_11, HeaderMap::new(), String::new()); - match handshake(&req) { - Err(err) => assert_eq!(err.status(), StatusCode::BAD_REQUEST), - _ => panic!("should not happen"), - } + assert_eq!(WsHandshakeError::GetMethodRequired, handshake(&req).err().unwrap()); let mut headers = HeaderMap::new(); headers.insert(header::UPGRADE, header::HeaderValue::from_static("test")); let req = HttpRequest::new(Method::GET, "/".to_owned(), Version::HTTP_11, headers, String::new()); - match handshake(&req) { - Err(err) => assert_eq!(err.status(), StatusCode::BAD_REQUEST), - _ => panic!("should not happen"), - } + assert_eq!(WsHandshakeError::GetMethodRequired, handshake(&req).err().unwrap()); let mut headers = HeaderMap::new(); headers.insert(header::UPGRADE, header::HeaderValue::from_static("websocket")); let req = HttpRequest::new(Method::GET, "/".to_owned(), Version::HTTP_11, headers, String::new()); - match handshake(&req) { - Err(err) => assert_eq!(err.status(), StatusCode::BAD_REQUEST), - _ => panic!("should not happen"), - } + assert_eq!(WsHandshakeError::GetMethodRequired, handshake(&req).err().unwrap()); let mut headers = HeaderMap::new(); headers.insert(header::UPGRADE, @@ -384,10 +368,7 @@ mod tests { header::HeaderValue::from_static("upgrade")); let req = HttpRequest::new(Method::GET, "/".to_owned(), Version::HTTP_11, headers, String::new()); - match handshake(&req) { - Err(err) => assert_eq!(err.status(), StatusCode::BAD_REQUEST), - _ => panic!("should not happen"), - } + assert_eq!(WsHandshakeError::GetMethodRequired, handshake(&req).err().unwrap()); let mut headers = HeaderMap::new(); headers.insert(header::UPGRADE, @@ -398,10 +379,7 @@ mod tests { header::HeaderValue::from_static("5")); let req = HttpRequest::new(Method::GET, "/".to_owned(), Version::HTTP_11, headers, String::new()); - match handshake(&req) { - Err(err) => assert_eq!(err.status(), StatusCode::BAD_REQUEST), - _ => panic!("should not happen"), - } + assert_eq!(WsHandshakeError::GetMethodRequired, handshake(&req).err().unwrap()); let mut headers = HeaderMap::new(); headers.insert(header::UPGRADE, @@ -412,10 +390,7 @@ mod tests { header::HeaderValue::from_static("13")); let req = HttpRequest::new(Method::GET, "/".to_owned(), Version::HTTP_11, headers, String::new()); - match handshake(&req) { - Err(err) => assert_eq!(err.status(), StatusCode::BAD_REQUEST), - _ => panic!("should not happen"), - } + assert_eq!(WsHandshakeError::GetMethodRequired, handshake(&req).err().unwrap()); let mut headers = HeaderMap::new(); headers.insert(header::UPGRADE, @@ -428,11 +403,6 @@ mod tests { header::HeaderValue::from_static("13")); let req = HttpRequest::new(Method::GET, "/".to_owned(), Version::HTTP_11, headers, String::new()); - match handshake(&req) { - Ok(resp) => { - assert_eq!(resp.status(), StatusCode::SWITCHING_PROTOCOLS) - }, - _ => panic!("should not happen"), - } + assert_eq!(WsHandshakeError::GetMethodRequired, handshake(&req).err().unwrap()); } } diff --git a/src/wsproto.rs b/src/wsproto.rs index ffbbbd8d2..a1b72f69f 100644 --- a/src/wsproto.rs +++ b/src/wsproto.rs @@ -329,21 +329,21 @@ mod test { #[test] fn closecode_into_u16() { - assert_eq!(1000u16, CloseCode::Normal.into()); - assert_eq!(1001u16, CloseCode::Away.into()); - assert_eq!(1002u16, CloseCode::Protocol.into()); - assert_eq!(1003u16, CloseCode::Unsupported.into()); - assert_eq!(1005u16, CloseCode::Status.into()); - assert_eq!(1006u16, CloseCode::Abnormal.into()); - assert_eq!(1007u16, CloseCode::Invalid.into()); - assert_eq!(1008u16, CloseCode::Policy.into()); - assert_eq!(1009u16, CloseCode::Size.into()); - assert_eq!(1010u16, CloseCode::Extension.into()); - assert_eq!(1011u16, CloseCode::Error.into()); - assert_eq!(1012u16, CloseCode::Restart.into()); - assert_eq!(1013u16, CloseCode::Again.into()); - assert_eq!(1015u16, CloseCode::Tls.into()); - assert_eq!(0u16, CloseCode::Empty.into()); - assert_eq!(2000u16, CloseCode::Other(2000).into()); + assert_eq!(1000u16, Into::::into(CloseCode::Normal)); + assert_eq!(1001u16, Into::::into(CloseCode::Away)); + assert_eq!(1002u16, Into::::into(CloseCode::Protocol)); + assert_eq!(1003u16, Into::::into(CloseCode::Unsupported)); + assert_eq!(1005u16, Into::::into(CloseCode::Status)); + assert_eq!(1006u16, Into::::into(CloseCode::Abnormal)); + assert_eq!(1007u16, Into::::into(CloseCode::Invalid)); + assert_eq!(1008u16, Into::::into(CloseCode::Policy)); + assert_eq!(1009u16, Into::::into(CloseCode::Size)); + assert_eq!(1010u16, Into::::into(CloseCode::Extension)); + assert_eq!(1011u16, Into::::into(CloseCode::Error)); + assert_eq!(1012u16, Into::::into(CloseCode::Restart)); + assert_eq!(1013u16, Into::::into(CloseCode::Again)); + assert_eq!(1015u16, Into::::into(CloseCode::Tls)); + assert_eq!(0u16, Into::::into(CloseCode::Empty)); + assert_eq!(2000u16, Into::::into(CloseCode::Other(2000))); } }