diff --git a/Cargo.toml b/Cargo.toml index f5055122f..37586ca2c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -28,8 +28,6 @@ default = [] # tls tls = ["native-tls", "tokio-tls"] -# http2 = ["h2"] - [dependencies] log = "0.3" time = "0.1" diff --git a/cov.sh b/cov.sh deleted file mode 100644 index 8e9fd237b..000000000 --- a/cov.sh +++ /dev/null @@ -1,4 +0,0 @@ -#!/bin/bash - -for file in target/debug/actix_web-*[^\.d]; do mkdir -p "target/cov/$(basename $file)"; /usr/local/bin/kcov --exclude-pattern=/.cargo,/usr/lib --verify "target/cov/$(basename $file)" "$file"; done && -for file in target/debug/test_*[^\.d]; do mkdir -p "target/cov/$(basename $file)"; /usr/local/bin/kcov --exclude-pattern=/.cargo,/usr/lib --verify "target/cov/$(basename $file)" "$file"; done diff --git a/src/context.rs b/src/context.rs index 04f39e3fd..b6fd4425c 100644 --- a/src/context.rs +++ b/src/context.rs @@ -47,6 +47,7 @@ impl ActorContext for HttpContext where A: Actor + Route { /// Stop actor execution fn stop(&mut self) { + self.stream.push_back(Frame::Payload(None)); self.items.stop(); self.address.close(); if self.state == ActorState::Running { @@ -141,7 +142,6 @@ impl HttpContext where A: Actor + Route { /// Indicate end of streamimng payload. Also this method calls `Self::close`. pub fn write_eof(&mut self) { - self.stream.push_back(Frame::Payload(None)); self.stop(); } diff --git a/src/h1.rs b/src/h1.rs index ba7b23061..78ad3ad32 100644 --- a/src/h1.rs +++ b/src/h1.rs @@ -1,22 +1,282 @@ use std::{self, io, ptr}; +use std::rc::Rc; +use std::cell::UnsafeCell; +use std::time::Duration; +use std::collections::VecDeque; +use actix::Arbiter; use httparse; use http::{Method, Version, HttpTryFrom, HeaderMap}; use http::header::{self, HeaderName, HeaderValue}; use bytes::{Bytes, BytesMut, BufMut}; -use futures::{Async, Poll}; -use tokio_io::AsyncRead; +use futures::{Future, Poll, Async}; +use tokio_io::{AsyncRead, AsyncWrite}; +use tokio_core::reactor::Timeout; use percent_encoding; +use task::Task; +use server::HttpHandler; use error::ParseError; +use httpcodes::HTTPNotFound; use httprequest::HttpRequest; use payload::{Payload, PayloadError, PayloadSender}; +use h1writer::H1Writer; -const MAX_HEADERS: usize = 100; +const KEEPALIVE_PERIOD: u64 = 15; // seconds const INIT_BUFFER_SIZE: usize = 8192; const MAX_BUFFER_SIZE: usize = 131_072; +const MAX_HEADERS: usize = 100; +const MAX_PIPELINED_MESSAGES: usize = 16; const HTTP2_PREFACE: [u8; 14] = *b"PRI * HTTP/2.0"; +pub(crate) enum Http1Result { + Done, + Upgrade, +} + +pub(crate) struct Http1 { + router: Rc>, + #[allow(dead_code)] + addr: A, + stream: H1Writer, + reader: Reader, + read_buf: BytesMut, + error: bool, + tasks: VecDeque, + keepalive: bool, + keepalive_timer: Option, + h2: bool, +} + +struct Entry { + task: Task, + req: UnsafeCell, + eof: bool, + error: bool, + finished: bool, +} + +impl Http1 + where T: AsyncRead + AsyncWrite + 'static, + A: 'static, + H: HttpHandler + 'static +{ + pub fn new(stream: T, addr: A, router: Rc>) -> Self { + Http1{ router: router, + addr: addr, + stream: H1Writer::new(stream), + reader: Reader::new(), + read_buf: BytesMut::new(), + error: false, + tasks: VecDeque::new(), + keepalive: true, + keepalive_timer: None, + h2: false } + } + + pub fn into_inner(mut self) -> (T, A, Rc>, Bytes) { + (self.stream.into_inner(), self.addr, self.router, self.read_buf.freeze()) + } + + pub fn poll(&mut self) -> Poll { + // keep-alive timer + if let Some(ref mut timeout) = self.keepalive_timer { + match timeout.poll() { + Ok(Async::Ready(_)) => + return Ok(Async::Ready(Http1Result::Done)), + Ok(Async::NotReady) => (), + Err(_) => unreachable!(), + } + } + + loop { + let mut not_ready = true; + + // check in-flight messages + let mut io = false; + let mut idx = 0; + while idx < self.tasks.len() { + let item = &mut self.tasks[idx]; + + if !io && !item.eof { + if item.error { + return Err(()) + } + + // this is anoying + let req = unsafe {item.req.get().as_mut().unwrap()}; + match item.task.poll_io(&mut self.stream, req) + { + Ok(Async::Ready(ready)) => { + not_ready = false; + + // overide keep-alive state + if self.keepalive { + self.keepalive = self.stream.keepalive(); + } + self.stream = H1Writer::new(self.stream.into_inner()); + + item.eof = true; + if ready { + item.finished = true; + } + }, + Ok(Async::NotReady) => { + // no more IO for this iteration + io = true; + }, + Err(_) => { + // it is not possible to recover from error + // during task handling, so just drop connection + return Err(()) + } + } + } else if !item.finished { + match item.task.poll() { + Ok(Async::NotReady) => (), + Ok(Async::Ready(_)) => { + not_ready = false; + item.finished = true; + }, + Err(_) => + item.error = true, + } + } + idx += 1; + } + + // cleanup finished tasks + while !self.tasks.is_empty() { + if self.tasks[0].eof && self.tasks[0].finished { + self.tasks.pop_front(); + } else { + break + } + } + + // no keep-alive + if !self.keepalive && self.tasks.is_empty() { + if self.h2 { + return Ok(Async::Ready(Http1Result::Upgrade)) + } else { + return Ok(Async::Ready(Http1Result::Done)) + } + } + + // read incoming data + if !self.error && !self.h2 && self.tasks.len() < MAX_PIPELINED_MESSAGES { + match self.reader.parse(self.stream.get_mut(), &mut self.read_buf) { + Ok(Async::Ready(Item::Http1(mut req, payload))) => { + not_ready = false; + + // stop keepalive timer + self.keepalive_timer.take(); + + // start request processing + let mut task = None; + for h in self.router.iter() { + if req.path().starts_with(h.prefix()) { + task = Some(h.handle(&mut req, payload)); + break + } + } + + self.tasks.push_back( + Entry {task: task.unwrap_or_else(|| Task::reply(HTTPNotFound)), + req: UnsafeCell::new(req), + eof: false, + error: false, + finished: false}); + } + Ok(Async::Ready(Item::Http2)) => { + self.h2 = true; + } + Err(ReaderError::Disconnect) => { + not_ready = false; + self.error = true; + self.stream.disconnected(); + for entry in &mut self.tasks { + entry.task.disconnected() + } + }, + Err(err) => { + // notify all tasks + not_ready = false; + self.stream.disconnected(); + for entry in &mut self.tasks { + entry.task.disconnected() + } + + // kill keepalive + self.keepalive = false; + self.keepalive_timer.take(); + + // on parse error, stop reading stream but + // tasks need to be completed + self.error = true; + + if self.tasks.is_empty() { + if let ReaderError::Error(err) = err { + self.tasks.push_back( + Entry {task: Task::reply(err), + req: UnsafeCell::new(HttpRequest::for_error()), + eof: false, + error: false, + finished: false}); + } + } + } + Ok(Async::NotReady) => { + // start keep-alive timer, this is also slow request timeout + if self.tasks.is_empty() { + if self.keepalive { + if self.keepalive_timer.is_none() { + trace!("Start keep-alive timer"); + let mut timeout = Timeout::new( + Duration::new(KEEPALIVE_PERIOD, 0), + Arbiter::handle()).unwrap(); + // register timeout + let _ = timeout.poll(); + self.keepalive_timer = Some(timeout); + } + } else { + // keep-alive disable, drop connection + return Ok(Async::Ready(Http1Result::Done)) + } + } + return Ok(Async::NotReady) + } + } + } + + // check for parse error + if self.tasks.is_empty() { + if self.error || self.keepalive_timer.is_none() { + return Ok(Async::Ready(Http1Result::Done)) + } + else if self.h2 { + return Ok(Async::Ready(Http1Result::Upgrade)) + } + } + + if not_ready { + return Ok(Async::NotReady) + } + } + } +} + +#[derive(Debug)] +enum Item { + Http1(HttpRequest, Payload), + Http2, +} + +struct Reader { + h1: bool, + payload: Option, +} + enum Decoding { Paused, Ready, @@ -28,19 +288,9 @@ struct PayloadInfo { decoder: Decoder, } -pub(crate) struct Reader { - read_buf: BytesMut, - payload: Option, -} - #[derive(Debug)] -pub(crate) enum ReaderItem { - Http1(HttpRequest, Payload), - Http2, -} - -#[derive(Debug)] -pub(crate) enum ReaderError { +enum ReaderError { + Disconnect, Payload, Error(ParseError), } @@ -55,19 +305,19 @@ enum Message { impl Reader { pub fn new() -> Reader { Reader { - read_buf: BytesMut::new(), + h1: false, payload: None, } } - fn decode(&mut self) -> std::result::Result + fn decode(&mut self, buf: &mut BytesMut) -> std::result::Result { if let Some(ref mut payload) = self.payload { if payload.tx.maybe_paused() { return Ok(Decoding::Paused) } loop { - match payload.decoder.decode(&mut self.read_buf) { + match payload.decoder.decode(buf) { Ok(Async::Ready(Some(bytes))) => { payload.tx.feed_data(bytes) }, @@ -87,18 +337,18 @@ impl Reader { } } - pub fn parse(&mut self, io: &mut T) -> Poll<(HttpRequest, Payload), ReaderError> + pub fn parse(&mut self, io: &mut T, buf: &mut BytesMut) -> Poll where T: AsyncRead { loop { - match self.decode()? { + match self.decode(buf)? { Decoding::Paused => return Ok(Async::NotReady), Decoding::Ready => { self.payload = None; break }, Decoding::NotReady => { - match self.read_from_io(io) { + match self.read_from_io(io, buf) { Ok(Async::Ready(0)) => { if let Some(ref mut payload) = self.payload { payload.tx.set_error(PayloadError::Incomplete); @@ -123,7 +373,7 @@ impl Reader { } loop { - match Reader::parse_message(&mut self.read_buf).map_err(ReaderError::Error)? { + match Reader::parse_message(buf).map_err(ReaderError::Error)? { Message::Http1(msg, decoder) => { let payload = if let Some(decoder) = decoder { let (tx, rx) = Payload::new(false); @@ -134,7 +384,7 @@ impl Reader { self.payload = Some(payload); loop { - match self.decode()? { + match self.decode(buf)? { Decoding::Paused => break, Decoding::Ready => { @@ -142,7 +392,7 @@ impl Reader { break }, Decoding::NotReady => { - match self.read_from_io(io) { + match self.read_from_io(io, buf) { Ok(Async::Ready(0)) => { trace!("parse eof"); if let Some(ref mut payload) = self.payload { @@ -171,21 +421,26 @@ impl Reader { let (_, rx) = Payload::new(true); rx }; - return Ok(Async::Ready((msg, payload))); + self.h1 = true; + return Ok(Async::Ready(Item::Http1(msg, payload))); }, Message::Http2 => { + if self.h1 { + return Err(ReaderError::Error(ParseError::Version)) + } + return Ok(Async::Ready(Item::Http2)); }, Message::NotReady => { - if self.read_buf.capacity() >= MAX_BUFFER_SIZE { + if buf.capacity() >= MAX_BUFFER_SIZE { debug!("MAX_BUFFER_SIZE reached, closing"); return Err(ReaderError::Error(ParseError::TooLarge)); } }, } - match self.read_from_io(io) { + match self.read_from_io(io, buf) { Ok(Async::Ready(0)) => { - trace!("Eof during parse"); - return Err(ReaderError::Error(ParseError::Incomplete)); + debug!("Ignored premature client disconnection"); + return Err(ReaderError::Disconnect); }, Ok(Async::Ready(_)) => (), Ok(Async::NotReady) => @@ -196,17 +451,19 @@ impl Reader { } } - fn read_from_io(&mut self, io: &mut T) -> Poll { - if self.read_buf.remaining_mut() < INIT_BUFFER_SIZE { - self.read_buf.reserve(INIT_BUFFER_SIZE); + fn read_from_io(&mut self, io: &mut T, buf: &mut BytesMut) + -> Poll + { + if buf.remaining_mut() < INIT_BUFFER_SIZE { + buf.reserve(INIT_BUFFER_SIZE); unsafe { // Zero out unused memory - let buf = self.read_buf.bytes_mut(); - let len = buf.len(); - ptr::write_bytes(buf.as_mut_ptr(), 0, len); + let b = buf.bytes_mut(); + let len = b.len(); + ptr::write_bytes(b.as_mut_ptr(), 0, len); } } unsafe { - let n = match io.read(self.read_buf.bytes_mut()) { + let n = match io.read(buf.bytes_mut()) { Ok(n) => n, Err(e) => { if e.kind() == io::ErrorKind::WouldBlock { @@ -215,18 +472,17 @@ impl Reader { return Err(e) } }; - self.read_buf.advance_mut(n); + buf.advance_mut(n); Ok(Async::Ready(n)) } } fn parse_message(buf: &mut BytesMut) -> Result { - println!("BUF: {:?}", buf); - if buf.is_empty() || buf.len() < 14 { + if buf.is_empty() { return Ok(Message::NotReady); } - if &buf[..14] == &HTTP2_PREFACE[..] { + if buf.len() >= 14 && &buf[..14] == &HTTP2_PREFACE[..] { return Ok(Message::Http2) } @@ -368,7 +624,7 @@ fn record_header_indices(bytes: &[u8], /// If a message body does not include a Transfer-Encoding, it *should* /// include a Content-Length header. #[derive(Debug, Clone, PartialEq)] -pub struct Decoder { +struct Decoder { kind: Kind, } @@ -424,7 +680,7 @@ enum ChunkedState { } impl Decoder { - pub fn is_eof(&self) -> bool { + /*pub fn is_eof(&self) -> bool { trace!("is_eof? {:?}", self); match self.kind { Kind::Length(0) | @@ -432,7 +688,7 @@ impl Decoder { Kind::Eof(true) => true, _ => false, } - } + }*/ } impl Decoder { @@ -633,7 +889,7 @@ mod tests { use futures::{Async}; use tokio_io::AsyncRead; use http::{Version, Method}; - use super::{Reader, ReaderError}; + use super::*; struct Buffer { buf: Bytes, @@ -682,8 +938,8 @@ mod tests { macro_rules! parse_ready { ($e:expr) => ( - match Reader::new().parse($e) { - Ok(Async::Ready((req, payload))) => (req, payload), + match Reader::new().parse($e, &mut BytesMut::new()) { + Ok(Async::Ready(Item::Http1(req, payload))) => (req, payload), Ok(_) => panic!("Eof during parsing http request"), Err(err) => panic!("Error during parsing http request: {:?}", err), } @@ -693,7 +949,7 @@ mod tests { macro_rules! reader_parse_ready { ($e:expr) => ( match $e { - Ok(Async::Ready((req, payload))) => (req, payload), + Ok(Async::Ready(Item::Http1(req, payload))) => (req, payload), Ok(_) => panic!("Eof during parsing http request"), Err(err) => panic!("Error during parsing http request: {:?}", err), } @@ -701,22 +957,28 @@ mod tests { } macro_rules! expect_parse_err { - ($e:expr) => (match Reader::new().parse($e) { - Err(err) => match err { - ReaderError::Error(_) => (), - _ => panic!("Parse error expected"), - }, - _ => panic!("Error expected"), - }) + ($e:expr) => ({ + let mut buf = BytesMut::new(); + match Reader::new().parse($e, &mut buf) { + Err(err) => match err { + ReaderError::Error(_) => (), + _ => panic!("Parse error expected"), + }, + val => { + panic!("Error expected") + } + }} + ) } #[test] fn test_parse() { let mut buf = Buffer::new("GET /test HTTP/1.1\r\n\r\n"); + let mut readbuf = BytesMut::new(); let mut reader = Reader::new(); - match reader.parse(&mut buf) { - Ok(Async::Ready((req, payload))) => { + match reader.parse(&mut buf, &mut readbuf) { + Ok(Async::Ready(Item::Http1(req, payload))) => { assert_eq!(req.version(), Version::HTTP_11); assert_eq!(*req.method(), Method::GET); assert_eq!(req.path(), "/test"); @@ -729,16 +991,17 @@ mod tests { #[test] fn test_parse_partial() { let mut buf = Buffer::new("PUT /test HTTP/1"); + let mut readbuf = BytesMut::new(); let mut reader = Reader::new(); - match reader.parse(&mut buf) { + match reader.parse(&mut buf, &mut readbuf) { Ok(Async::NotReady) => (), _ => panic!("Error"), } buf.feed_data(".1\r\n\r\n"); - match reader.parse(&mut buf) { - Ok(Async::Ready((req, payload))) => { + match reader.parse(&mut buf, &mut readbuf) { + Ok(Async::Ready(Item::Http1(req, payload))) => { assert_eq!(req.version(), Version::HTTP_11); assert_eq!(*req.method(), Method::PUT); assert_eq!(req.path(), "/test"); @@ -751,10 +1014,11 @@ mod tests { #[test] fn test_parse_post() { let mut buf = Buffer::new("POST /test2 HTTP/1.0\r\n\r\n"); + let mut readbuf = BytesMut::new(); let mut reader = Reader::new(); - match reader.parse(&mut buf) { - Ok(Async::Ready((req, payload))) => { + match reader.parse(&mut buf, &mut readbuf) { + Ok(Async::Ready(Item::Http1(req, payload))) => { assert_eq!(req.version(), Version::HTTP_10); assert_eq!(*req.method(), Method::POST); assert_eq!(req.path(), "/test2"); @@ -767,10 +1031,11 @@ mod tests { #[test] fn test_parse_body() { let mut buf = Buffer::new("GET /test HTTP/1.1\r\nContent-Length: 4\r\n\r\nbody"); + let mut readbuf = BytesMut::new(); let mut reader = Reader::new(); - match reader.parse(&mut buf) { - Ok(Async::Ready((req, mut payload))) => { + match reader.parse(&mut buf, &mut readbuf) { + Ok(Async::Ready(Item::Http1(req, mut payload))) => { assert_eq!(req.version(), Version::HTTP_11); assert_eq!(*req.method(), Method::GET); assert_eq!(req.path(), "/test"); @@ -784,10 +1049,11 @@ mod tests { fn test_parse_body_crlf() { let mut buf = Buffer::new( "\r\nGET /test HTTP/1.1\r\nContent-Length: 4\r\n\r\nbody"); + let mut readbuf = BytesMut::new(); let mut reader = Reader::new(); - match reader.parse(&mut buf) { - Ok(Async::Ready((req, mut payload))) => { + match reader.parse(&mut buf, &mut readbuf) { + Ok(Async::Ready(Item::Http1(req, mut payload))) => { assert_eq!(req.version(), Version::HTTP_11); assert_eq!(*req.method(), Method::GET); assert_eq!(req.path(), "/test"); @@ -800,13 +1066,14 @@ mod tests { #[test] fn test_parse_partial_eof() { let mut buf = Buffer::new("GET /test HTTP/1.1\r\n"); + let mut readbuf = BytesMut::new(); let mut reader = Reader::new(); - not_ready!{ reader.parse(&mut buf) } + not_ready!{ reader.parse(&mut buf, &mut readbuf) } buf.feed_data("\r\n"); - match reader.parse(&mut buf) { - Ok(Async::Ready((req, payload))) => { + match reader.parse(&mut buf, &mut readbuf) { + Ok(Async::Ready(Item::Http1(req, payload))) => { assert_eq!(req.version(), Version::HTTP_11); assert_eq!(*req.method(), Method::GET); assert_eq!(req.path(), "/test"); @@ -819,19 +1086,20 @@ mod tests { #[test] fn test_headers_split_field() { let mut buf = Buffer::new("GET /test HTTP/1.1\r\n"); + let mut readbuf = BytesMut::new(); let mut reader = Reader::new(); - not_ready!{ reader.parse(&mut buf) } + not_ready!{ reader.parse(&mut buf, &mut readbuf) } buf.feed_data("t"); - not_ready!{ reader.parse(&mut buf) } + not_ready!{ reader.parse(&mut buf, &mut readbuf) } buf.feed_data("es"); - not_ready!{ reader.parse(&mut buf) } + not_ready!{ reader.parse(&mut buf, &mut readbuf) } buf.feed_data("t: value\r\n\r\n"); - match reader.parse(&mut buf) { - Ok(Async::Ready((req, payload))) => { + match reader.parse(&mut buf, &mut readbuf) { + Ok(Async::Ready(Item::Http1(req, payload))) => { assert_eq!(req.version(), Version::HTTP_11); assert_eq!(*req.method(), Method::GET); assert_eq!(req.path(), "/test"); @@ -848,10 +1116,11 @@ mod tests { "GET /test HTTP/1.1\r\n\ Set-Cookie: c1=cookie1\r\n\ Set-Cookie: c2=cookie2\r\n\r\n"); + let mut readbuf = BytesMut::new(); let mut reader = Reader::new(); - match reader.parse(&mut buf) { - Ok(Async::Ready((req, _))) => { + match reader.parse(&mut buf, &mut readbuf) { + Ok(Async::Ready(Item::Http1(req, _))) => { let val: Vec<_> = req.headers().get_all("Set-Cookie") .iter().map(|v| v.to_str().unwrap().to_owned()).collect(); assert_eq!(val[0], "c1=cookie1"); @@ -1081,14 +1350,15 @@ mod tests { let mut buf = Buffer::new( "GET /test HTTP/1.1\r\n\ transfer-encoding: chunked\r\n\r\n"); + let mut readbuf = BytesMut::new(); let mut reader = Reader::new(); - let (req, mut payload) = reader_parse_ready!(reader.parse(&mut buf)); + let (req, mut payload) = reader_parse_ready!(reader.parse(&mut buf, &mut readbuf)); assert!(req.chunked().unwrap()); assert!(!payload.eof()); buf.feed_data("4\r\ndata\r\n4\r\nline\r\n0\r\n\r\n"); - not_ready!(reader.parse(&mut buf)); + not_ready!(reader.parse(&mut buf, &mut readbuf)); assert!(!payload.eof()); assert_eq!(payload.readall().unwrap().as_ref(), b"dataline"); assert!(payload.eof()); @@ -1099,10 +1369,11 @@ mod tests { let mut buf = Buffer::new( "GET /test HTTP/1.1\r\n\ transfer-encoding: chunked\r\n\r\n"); + let mut readbuf = BytesMut::new(); let mut reader = Reader::new(); - let (req, mut payload) = reader_parse_ready!(reader.parse(&mut buf)); + let (req, mut payload) = reader_parse_ready!(reader.parse(&mut buf, &mut readbuf)); assert!(req.chunked().unwrap()); assert!(!payload.eof()); @@ -1111,7 +1382,7 @@ mod tests { POST /test2 HTTP/1.1\r\n\ transfer-encoding: chunked\r\n\r\n"); - let (req2, payload2) = reader_parse_ready!(reader.parse(&mut buf)); + let (req2, payload2) = reader_parse_ready!(reader.parse(&mut buf, &mut readbuf)); assert_eq!(*req2.method(), Method::POST); assert!(req2.chunked().unwrap()); assert!(!payload2.eof()); @@ -1125,37 +1396,38 @@ mod tests { let mut buf = Buffer::new( "GET /test HTTP/1.1\r\n\ transfer-encoding: chunked\r\n\r\n"); + let mut readbuf = BytesMut::new(); let mut reader = Reader::new(); - let (req, mut payload) = reader_parse_ready!(reader.parse(&mut buf)); + let (req, mut payload) = reader_parse_ready!(reader.parse(&mut buf, &mut readbuf)); assert!(req.chunked().unwrap()); assert!(!payload.eof()); buf.feed_data("4\r\ndata\r"); - not_ready!(reader.parse(&mut buf)); + not_ready!(reader.parse(&mut buf, &mut readbuf)); buf.feed_data("\n4"); - not_ready!(reader.parse(&mut buf)); + not_ready!(reader.parse(&mut buf, &mut readbuf)); buf.feed_data("\r"); - not_ready!(reader.parse(&mut buf)); + not_ready!(reader.parse(&mut buf, &mut readbuf)); buf.feed_data("\n"); - not_ready!(reader.parse(&mut buf)); + not_ready!(reader.parse(&mut buf, &mut readbuf)); buf.feed_data("li"); - not_ready!(reader.parse(&mut buf)); + not_ready!(reader.parse(&mut buf, &mut readbuf)); buf.feed_data("ne\r\n0\r\n"); - not_ready!(reader.parse(&mut buf)); + not_ready!(reader.parse(&mut buf, &mut readbuf)); //buf.feed_data("test: test\r\n"); - //not_ready!(reader.parse(&mut buf)); + //not_ready!(reader.parse(&mut buf, &mut readbuf)); assert_eq!(payload.readall().unwrap().as_ref(), b"dataline"); assert!(!payload.eof()); buf.feed_data("\r\n"); - not_ready!(reader.parse(&mut buf)); + not_ready!(reader.parse(&mut buf, &mut readbuf)); assert!(payload.eof()); } @@ -1164,14 +1436,15 @@ mod tests { let mut buf = Buffer::new( "GET /test HTTP/1.1\r\n\ transfer-encoding: chunked\r\n\r\n"); + let mut readbuf = BytesMut::new(); let mut reader = Reader::new(); - let (req, mut payload) = reader_parse_ready!(reader.parse(&mut buf)); + let (req, mut payload) = reader_parse_ready!(reader.parse(&mut buf, &mut readbuf)); assert!(req.chunked().unwrap()); assert!(!payload.eof()); buf.feed_data("4;test\r\ndata\r\n4\r\nline\r\n0\r\n\r\n"); // test: test\r\n\r\n") - not_ready!(reader.parse(&mut buf)); + not_ready!(reader.parse(&mut buf, &mut readbuf)); assert!(!payload.eof()); assert_eq!(payload.readall().unwrap().as_ref(), b"dataline"); assert!(payload.eof()); @@ -1193,4 +1466,16 @@ mod tests { Err(err) => panic!("{:?}", err), } }*/ + + #[test] + fn test_http2_prefix() { + let mut buf = Buffer::new("PRI * HTTP/2.0\r\n\r\n"); + let mut readbuf = BytesMut::new(); + + let mut reader = Reader::new(); + match reader.parse(&mut buf, &mut readbuf) { + Ok(Async::Ready(Item::Http2)) => (), + Ok(_) | Err(_) => panic!("Error during parsing http request"), + } + } } diff --git a/src/h1writer.rs b/src/h1writer.rs new file mode 100644 index 000000000..98f2aa4fa --- /dev/null +++ b/src/h1writer.rs @@ -0,0 +1,351 @@ +use std::{cmp, io}; +use std::fmt::Write; +use bytes::BytesMut; +use futures::{Async, Poll}; +use tokio_io::AsyncWrite; +use http::{Version, StatusCode}; +use http::header::{HeaderValue, + CONNECTION, CONTENT_TYPE, CONTENT_LENGTH, TRANSFER_ENCODING, DATE}; + +use date; +use body::Body; +use httprequest::HttpRequest; +use httpresponse::HttpResponse; + +const AVERAGE_HEADER_SIZE: usize = 30; // totally scientific +const MAX_WRITE_BUFFER_SIZE: usize = 65_536; // max buffer size 64k + + +pub(crate) enum WriterState { + Done, + Pause, +} + +/// Send stream +pub(crate) trait Writer { + fn start(&mut self, req: &mut HttpRequest, resp: &mut HttpResponse) + -> Result; + + fn write(&mut self, payload: &[u8]) -> Result; + + fn write_eof(&mut self) -> Result; + + fn poll_complete(&mut self) -> Poll<(), io::Error>; +} + + +pub(crate) struct H1Writer { + stream: Option, + buffer: BytesMut, + started: bool, + encoder: Encoder, + upgrade: bool, + keepalive: bool, + disconnected: bool, +} + +impl H1Writer { + + pub fn new(stream: T) -> H1Writer { + H1Writer { + stream: Some(stream), + buffer: BytesMut::new(), + started: false, + encoder: Encoder::length(0), + upgrade: false, + keepalive: false, + disconnected: false, + } + } + + pub fn get_mut(&mut self) -> &mut T { + self.stream.as_mut().unwrap() + } + + pub fn into_inner(&mut self) -> T { + self.stream.take().unwrap() + } + + pub fn disconnected(&mut self) { + let len = self.buffer.len(); + self.buffer.split_to(len); + } + + pub fn keepalive(&self) -> bool { + self.keepalive && !self.upgrade + } + + fn write_to_stream(&mut self) -> Result { + if let Some(ref mut stream) = self.stream { + while !self.buffer.is_empty() { + match stream.write(self.buffer.as_ref()) { + Ok(n) => { + self.buffer.split_to(n); + }, + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { + if self.buffer.len() > MAX_WRITE_BUFFER_SIZE { + return Ok(WriterState::Pause) + } else { + return Ok(WriterState::Done) + } + } + Err(err) => + return Err(err), + } + } + } + return Ok(WriterState::Done) + } +} + +impl Writer for H1Writer { + + fn start(&mut self, req: &mut HttpRequest, msg: &mut HttpResponse) + -> Result + { + trace!("Prepare message status={:?}", msg.status); + + // prepare task + let mut extra = 0; + let body = msg.replace_body(Body::Empty); + let version = msg.version().unwrap_or_else(|| req.version()); + self.started = true; + self.keepalive = msg.keep_alive().unwrap_or_else(|| req.keep_alive()); + + match body { + Body::Empty => { + if msg.chunked() { + error!("Chunked transfer is enabled but body is set to Empty"); + } + msg.headers.insert(CONTENT_LENGTH, HeaderValue::from_static("0")); + msg.headers.remove(TRANSFER_ENCODING); + self.encoder = Encoder::length(0); + }, + Body::Length(n) => { + if msg.chunked() { + error!("Chunked transfer is enabled but body with specific length is specified"); + } + msg.headers.insert( + CONTENT_LENGTH, + HeaderValue::from_str(format!("{}", n).as_str()).unwrap()); + msg.headers.remove(TRANSFER_ENCODING); + self.encoder = Encoder::length(n); + }, + Body::Binary(ref bytes) => { + extra = bytes.len(); + msg.headers.insert( + CONTENT_LENGTH, + HeaderValue::from_str(format!("{}", bytes.len()).as_str()).unwrap()); + msg.headers.remove(TRANSFER_ENCODING); + self.encoder = Encoder::length(0); + } + Body::Streaming => { + if msg.chunked() { + if version < Version::HTTP_11 { + error!("Chunked transfer encoding is forbidden for {:?}", version); + } + msg.headers.remove(CONTENT_LENGTH); + msg.headers.insert(TRANSFER_ENCODING, HeaderValue::from_static("chunked")); + self.encoder = Encoder::chunked(); + } else { + self.encoder = Encoder::eof(); + } + } + Body::Upgrade => { + msg.headers.insert(CONNECTION, HeaderValue::from_static("upgrade")); + self.encoder = Encoder::eof(); + } + } + + // Connection upgrade + if msg.upgrade() { + msg.headers.insert(CONNECTION, HeaderValue::from_static("upgrade")); + } + // keep-alive + else if self.keepalive { + if version < Version::HTTP_11 { + msg.headers.insert(CONNECTION, HeaderValue::from_static("keep-alive")); + } + } else if version >= Version::HTTP_11 { + msg.headers.insert(CONNECTION, HeaderValue::from_static("close")); + } + + // render message + let init_cap = 100 + msg.headers.len() * AVERAGE_HEADER_SIZE + extra; + self.buffer.reserve(init_cap); + + if version == Version::HTTP_11 && msg.status == StatusCode::OK { + self.buffer.extend(b"HTTP/1.1 200 OK\r\n"); + } else { + let _ = write!(self.buffer, "{:?} {}\r\n", version, msg.status); + } + for (key, value) in &msg.headers { + let t: &[u8] = key.as_ref(); + self.buffer.extend(t); + self.buffer.extend(b": "); + self.buffer.extend(value.as_ref()); + self.buffer.extend(b"\r\n"); + } + + // using http::h1::date is quite a lot faster than generating + // a unique Date header each time like req/s goes up about 10% + if !msg.headers.contains_key(DATE) { + self.buffer.reserve(date::DATE_VALUE_LENGTH + 8); + self.buffer.extend(b"Date: "); + date::extend(&mut self.buffer); + self.buffer.extend(b"\r\n"); + } + + // default content-type + if !msg.headers.contains_key(CONTENT_TYPE) { + self.buffer.extend(b"ContentType: application/octet-stream\r\n".as_ref()); + } + + self.buffer.extend(b"\r\n"); + + if let Body::Binary(ref bytes) = body { + self.buffer.extend_from_slice(bytes.as_ref()); + return Ok(WriterState::Done) + } + msg.replace_body(body); + + Ok(WriterState::Done) + } + + fn write(&mut self, payload: &[u8]) -> Result { + if !self.disconnected { + if self.started { + // TODO: add warning, write after EOF + self.encoder.encode(&mut self.buffer, payload); + } else { + // might be response for EXCEPT + self.buffer.extend_from_slice(payload) + } + } + + if self.buffer.len() > MAX_WRITE_BUFFER_SIZE { + return Ok(WriterState::Pause) + } else { + return Ok(WriterState::Done) + } + } + + fn write_eof(&mut self) -> Result { + if !self.encoder.encode_eof(&mut self.buffer) { + //debug!("last payload item, but it is not EOF "); + Err(io::Error::new(io::ErrorKind::Other, + "Last payload item, but eof is not reached")) + } else { + if self.buffer.len() > MAX_WRITE_BUFFER_SIZE { + return Ok(WriterState::Pause) + } else { + return Ok(WriterState::Done) + } + } + } + + fn poll_complete(&mut self) -> Poll<(), io::Error> { + match self.write_to_stream() { + Ok(WriterState::Done) => Ok(Async::Ready(())), + Ok(WriterState::Pause) => Ok(Async::NotReady), + Err(err) => Err(err) + } + } +} + +/// Encoders to handle different Transfer-Encodings. +#[derive(Debug, Clone)] +struct Encoder { + kind: Kind, +} + +#[derive(Debug, PartialEq, Clone)] +enum Kind { + /// An Encoder for when Transfer-Encoding includes `chunked`. + Chunked(bool), + /// An Encoder for when Content-Length is set. + /// + /// Enforces that the body is not longer than the Content-Length header. + Length(u64), + /// An Encoder for when Content-Length is not known. + /// + /// Appliction decides when to stop writing. + Eof, +} + +impl Encoder { + + pub fn eof() -> Encoder { + Encoder { + kind: Kind::Eof, + } + } + + pub fn chunked() -> Encoder { + Encoder { + kind: Kind::Chunked(false), + } + } + + pub fn length(len: u64) -> Encoder { + Encoder { + kind: Kind::Length(len), + } + } + + /// Encode message. Return `EOF` state of encoder + pub fn encode(&mut self, dst: &mut BytesMut, msg: &[u8]) -> bool { + match self.kind { + Kind::Eof => { + dst.extend(msg); + msg.is_empty() + }, + Kind::Chunked(ref mut eof) => { + if *eof { + return true; + } + + if msg.is_empty() { + *eof = true; + dst.extend(b"0\r\n\r\n"); + } else { + write!(dst, "{:X}\r\n", msg.len()).unwrap(); + dst.extend(msg); + dst.extend(b"\r\n"); + } + *eof + }, + Kind::Length(ref mut remaining) => { + if msg.is_empty() { + return *remaining == 0 + } + let max = cmp::min(*remaining, msg.len() as u64); + trace!("sized write = {}", max); + dst.extend(msg[..max as usize].as_ref()); + + *remaining -= max as u64; + trace!("encoded {} bytes, remaining = {}", max, remaining); + *remaining == 0 + }, + } + } + + /// Encode eof. Return `EOF` state of encoder + pub fn encode_eof(&mut self, dst: &mut BytesMut) -> bool { + match self.kind { + Kind::Eof => true, + Kind::Chunked(ref mut eof) => { + if *eof { + return true; + } + + *eof = true; + dst.extend(b"0\r\n\r\n"); + true + }, + Kind::Length(ref mut remaining) => { + return *remaining == 0 + }, + } + } +} diff --git a/src/h2.rs b/src/h2.rs index 21ab1e191..ecf429c0a 100644 --- a/src/h2.rs +++ b/src/h2.rs @@ -1,9 +1,147 @@ -use std::{io, cmp}; +use std::{io, cmp, mem}; +use std::rc::Rc; use std::io::{Read, Write}; +use std::cell::UnsafeCell; +use std::collections::VecDeque; + +use http::request::Parts; +use http2::{RecvStream}; +use http2::server::{Server, Handshake, Respond}; use bytes::{Buf, Bytes}; -use futures::Poll; +use futures::{Async, Poll, Future, Stream}; use tokio_io::{AsyncRead, AsyncWrite}; +use task::Task; +use server::HttpHandler; +use httpcodes::HTTPNotFound; +use httprequest::HttpRequest; +use payload::{Payload, PayloadError, PayloadSender}; + + +pub(crate) struct Http2 + where T: AsyncRead + AsyncWrite + 'static, A: 'static, H: 'static +{ + router: Rc>, + #[allow(dead_code)] + addr: A, + state: State>, + error: bool, + tasks: VecDeque, +} + +enum State { + Handshake(Handshake), + Server(Server), + Empty, +} + +impl Http2 + where T: AsyncRead + AsyncWrite + 'static, + A: 'static, + H: HttpHandler + 'static +{ + pub fn new(stream: T, addr: A, router: Rc>, buf: Bytes) -> Self { + Http2{ router: router, + addr: addr, + error: false, + tasks: VecDeque::new(), + state: State::Handshake( + Server::handshake(IoWrapper{unread: Some(buf), inner: stream})) } + } + + pub fn poll(&mut self) -> Poll<(), ()> { + // handshake + self.state = if let State::Handshake(ref mut handshake) = self.state { + match handshake.poll() { + Ok(Async::Ready(srv)) => { + State::Server(srv) + }, + Ok(Async::NotReady) => + return Ok(Async::NotReady), + Err(err) => { + trace!("Error handling connection: {}", err); + return Err(()) + } + } + } else { + mem::replace(&mut self.state, State::Empty) + }; + + // get request + let poll = if let State::Server(ref mut server) = self.state { + server.poll() + } else { + unreachable!("Http2::poll() state was not advanced completely!") + }; + + match poll { + Ok(Async::NotReady) => { + // Ok(Async::NotReady); + () + } + Err(err) => { + trace!("Connection error: {}", err); + self.error = true; + }, + Ok(Async::Ready(None)) => { + + }, + Ok(Async::Ready(Some((req, resp)))) => { + let (parts, body) = req.into_parts(); + let entry = Entry::new(parts, body, resp, &self.router); + } + } + + Ok(Async::Ready(())) + } +} + +struct Entry { + task: Task, + req: UnsafeCell, + payload: PayloadSender, + recv: RecvStream, + respond: Respond, + eof: bool, + error: bool, + finished: bool, +} + +impl Entry { + fn new(parts: Parts, + recv: RecvStream, + resp: Respond, + router: &Rc>) -> Entry + where H: HttpHandler + 'static + { + let path = parts.uri.path().to_owned(); + let query = parts.uri.query().unwrap_or("").to_owned(); + + println!("PARTS: {:?}", parts); + let mut req = HttpRequest::new( + parts.method, path, parts.version, parts.headers, query); + let (psender, payload) = Payload::new(false); + + // start request processing + let mut task = None; + for h in router.iter() { + if req.path().starts_with(h.prefix()) { + task = Some(h.handle(&mut req, payload)); + break + } + } + println!("REQ: {:?}", req); + + Entry {task: task.unwrap_or_else(|| Task::reply(HTTPNotFound)), + req: UnsafeCell::new(req), + payload: psender, + recv: recv, + respond: resp, + eof: false, + error: false, + finished: false} + } +} struct IoWrapper { unread: Option, @@ -14,9 +152,9 @@ impl Read for IoWrapper { fn read(&mut self, buf: &mut [u8]) -> io::Result { if let Some(mut bytes) = self.unread.take() { let size = cmp::min(buf.len(), bytes.len()); - buf.copy_from_slice(&bytes[..size]); - bytes.split_to(size); - if !bytes.is_empty() { + buf[..size].copy_from_slice(&bytes[..size]); + if bytes.len() > size { + bytes.split_to(size); self.unread = Some(bytes); } Ok(size) diff --git a/src/lib.rs b/src/lib.rs index 838307950..9dc538124 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -20,6 +20,7 @@ extern crate mime_guess; extern crate url; extern crate percent_encoding; extern crate actix; +extern crate h2 as http2; #[cfg(feature="tls")] extern crate native_tls; @@ -45,6 +46,7 @@ mod wsframe; mod wsproto; mod h1; mod h2; +mod h1writer; pub mod ws; pub mod dev; diff --git a/src/server.rs b/src/server.rs index 3d852e50d..55f85b6d3 100644 --- a/src/server.rs +++ b/src/server.rs @@ -1,13 +1,9 @@ -use std::{io, net}; +use std::{io, net, mem}; use std::rc::Rc; -use std::cell::UnsafeCell; -use std::time::Duration; use std::marker::PhantomData; -use std::collections::VecDeque; use actix::dev::*; use futures::{Future, Poll, Async, Stream}; -use tokio_core::reactor::Timeout; use tokio_core::net::{TcpListener, TcpStream}; use tokio_io::{AsyncRead, AsyncWrite}; @@ -17,9 +13,9 @@ use native_tls::TlsAcceptor; use tokio_tls::{TlsStream, TlsAcceptorExt}; use h1; +use h2; use task::Task; use payload::Payload; -use httpcodes::HTTPNotFound; use httprequest::HttpRequest; /// Low level http request handler @@ -153,11 +149,10 @@ impl HttpServer, net::SocketAddr, H> { println!("SSL"); TlsAcceptorExt::accept_async(acc.as_ref(), stream) .map(move |t| { - println!("connected {:?} {:?}", t, addr); IoStream(t, addr) }) .map_err(|err| { - println!("ERR: {:?}", err); + trace!("Error during handling tls connection: {}", err); io::Error::new(io::ErrorKind::Other, err) }) })); @@ -195,42 +190,25 @@ impl Handler, io::Error> for HttpServer -> Response> { Arbiter::handle().spawn( - HttpChannel{router: Rc::clone(&self.h), - addr: msg.1, - stream: msg.0, - reader: h1::Reader::new(), - error: false, - items: VecDeque::new(), - inactive: VecDeque::new(), - keepalive: true, - keepalive_timer: None, + HttpChannel{ + proto: Protocol::H1(h1::Http1::new(msg.0, msg.1, Rc::clone(&self.h))) }); Self::empty() } } -struct Entry { - task: Task, - req: UnsafeCell, - eof: bool, - error: bool, - finished: bool, +enum Protocol + where T: AsyncRead + AsyncWrite + 'static, A: 'static, H: 'static +{ + H1(h1::Http1), + H2(h2::Http2), + None, } -const KEEPALIVE_PERIOD: u64 = 15; // seconds -const MAX_PIPELINED_MESSAGES: usize = 16; - -pub struct HttpChannel { - router: Rc>, - #[allow(dead_code)] - addr: A, - stream: T, - reader: h1::Reader, - error: bool, - items: VecDeque, - inactive: VecDeque, - keepalive: bool, - keepalive_timer: Option, +pub struct HttpChannel + where T: AsyncRead + AsyncWrite + 'static, A: 'static, H: 'static +{ + proto: Protocol, } /*impl Drop for HttpChannel { @@ -240,193 +218,45 @@ pub struct HttpChannel { }*/ impl Actor for HttpChannel - where T: AsyncRead + AsyncWrite + 'static, - A: 'static, - H: HttpHandler + 'static + where T: AsyncRead + AsyncWrite + 'static, A: 'static, H: HttpHandler + 'static { type Context = Context; } impl Future for HttpChannel - where T: AsyncRead + AsyncWrite + 'static, - A: 'static, - H: HttpHandler + 'static + where T: AsyncRead + AsyncWrite + 'static, A: 'static, H: HttpHandler + 'static { type Item = (); type Error = (); fn poll(&mut self) -> Poll { - // keep-alive timer - if let Some(ref mut timeout) = self.keepalive_timer { - match timeout.poll() { - Ok(Async::Ready(_)) => - return Ok(Async::Ready(())), - Ok(Async::NotReady) => (), - Err(_) => unreachable!(), + match self.proto { + Protocol::H1(ref mut h1) => { + match h1.poll() { + Ok(Async::Ready(h1::Http1Result::Done)) => + return Ok(Async::Ready(())), + Ok(Async::Ready(h1::Http1Result::Upgrade)) => (), + Ok(Async::NotReady) => + return Ok(Async::NotReady), + Err(_) => + return Err(()), + } } + Protocol::H2(ref mut h2) => + return h2.poll(), + Protocol::None => + unreachable!() } - loop { - let mut not_ready = true; - - // check in-flight messages - let mut idx = 0; - while idx < self.items.len() { - if idx == 0 { - if self.items[idx].error { - return Err(()) - } - - // this is anoying - let req = unsafe {self.items[idx].req.get().as_mut().unwrap()}; - match self.items[idx].task.poll_io(&mut self.stream, req) - { - Ok(Async::Ready(ready)) => { - not_ready = false; - let mut item = self.items.pop_front().unwrap(); - - // overide keep-alive state - if self.keepalive { - self.keepalive = item.task.keepalive(); - } - if !ready { - item.eof = true; - self.inactive.push_back(item); - } - - // no keep-alive - if ready && !self.keepalive && - self.items.is_empty() && self.inactive.is_empty() - { - return Ok(Async::Ready(())) - } - continue - }, - Ok(Async::NotReady) => (), - Err(_) => { - // it is not possible to recover from error - // during task handling, so just drop connection - return Err(()) - } - } - } else if !self.items[idx].finished && !self.items[idx].error { - match self.items[idx].task.poll() { - Ok(Async::NotReady) => (), - Ok(Async::Ready(_)) => { - not_ready = false; - self.items[idx].finished = true; - }, - Err(_) => - self.items[idx].error = true, - } - } - idx += 1; - } - - // check inactive tasks - let mut idx = 0; - while idx < self.inactive.len() { - if idx == 0 && self.inactive[idx].error && self.inactive[idx].finished { - let _ = self.inactive.pop_front(); - continue - } - - if !self.inactive[idx].finished && !self.inactive[idx].error { - match self.inactive[idx].task.poll() { - Ok(Async::NotReady) => (), - Ok(Async::Ready(_)) => { - not_ready = false; - self.inactive[idx].finished = true - } - Err(_) => - self.inactive[idx].error = true, - } - } - idx += 1; - } - - // read incoming data - if !self.error && self.items.len() < MAX_PIPELINED_MESSAGES { - match self.reader.parse(&mut self.stream) { - Ok(Async::Ready((mut req, payload))) => { - not_ready = false; - - // stop keepalive timer - self.keepalive_timer.take(); - - // start request processing - let mut task = None; - for h in self.router.iter() { - if req.path().starts_with(h.prefix()) { - task = Some(h.handle(&mut req, payload)); - break - } - } - - self.items.push_back( - Entry {task: task.unwrap_or_else(|| Task::reply(HTTPNotFound)), - req: UnsafeCell::new(req), - eof: false, - error: false, - finished: false}); - } - Err(err) => { - // notify all tasks - not_ready = false; - for entry in &mut self.items { - entry.task.disconnected() - } - - // kill keepalive - self.keepalive = false; - self.keepalive_timer.take(); - - // on parse error, stop reading stream but - // tasks need to be completed - self.error = true; - - if self.items.is_empty() { - if let h1::ReaderError::Error(err) = err { - self.items.push_back( - Entry {task: Task::reply(err), - req: UnsafeCell::new(HttpRequest::for_error()), - eof: false, - error: false, - finished: false}); - } - } - } - Ok(Async::NotReady) => { - // start keep-alive timer, this is also slow request timeout - if self.items.is_empty() && self.inactive.is_empty() { - if self.keepalive { - if self.keepalive_timer.is_none() { - trace!("Start keep-alive timer"); - let mut timeout = Timeout::new( - Duration::new(KEEPALIVE_PERIOD, 0), - Arbiter::handle()).unwrap(); - // register timeout - let _ = timeout.poll(); - self.keepalive_timer = Some(timeout); - } - } else { - // keep-alive disable, drop connection - return Ok(Async::Ready(())) - } - } - return Ok(Async::NotReady) - } - } - } - - // check for parse error - if self.items.is_empty() && self.inactive.is_empty() && self.error { - return Ok(Async::Ready(())) - } - - if not_ready { - return Ok(Async::NotReady) + // upgrade to h2 + let proto = mem::replace(&mut self.proto, Protocol::None); + match proto { + Protocol::H1(h1) => { + let (stream, addr, router, buf) = h1.into_inner(); + self.proto = Protocol::H2(h2::Http2::new(stream, addr, router, buf)); + return self.poll() } + _ => unreachable!() } } } diff --git a/src/task.rs b/src/task.rs index ec3f6bd59..073cde62b 100644 --- a/src/task.rs +++ b/src/task.rs @@ -1,27 +1,18 @@ -use std::{mem, cmp, io}; +use std::{mem, io}; use std::rc::Rc; -use std::fmt::Write; use std::cell::RefCell; use std::collections::VecDeque; -use http::{StatusCode, Version}; -use http::header::{HeaderValue, - CONNECTION, CONTENT_TYPE, CONTENT_LENGTH, TRANSFER_ENCODING, DATE}; -use bytes::BytesMut; use futures::{Async, Future, Poll, Stream}; use futures::task::{Task as FutureTask, current as current_task}; -use tokio_io::AsyncWrite; -use date; -use body::Body; +use h1writer::{Writer, WriterState}; use route::Frame; use application::Middleware; use httprequest::HttpRequest; use httpresponse::HttpResponse; type FrameStream = Stream; -const AVERAGE_HEADER_SIZE: usize = 30; // totally scientific -const MAX_WRITE_BUFFER_SIZE: usize = 65_536; // max buffer size 64k #[derive(PartialEq, Debug)] enum TaskRunningState { @@ -34,6 +25,16 @@ impl TaskRunningState { fn is_done(&self) -> bool { *self == TaskRunningState::Done } + fn pause(&mut self) { + if *self != TaskRunningState::Done { + *self = TaskRunningState::Paused + } + } + fn resume(&mut self) { + if *self != TaskRunningState::Done { + *self = TaskRunningState::Running + } + } } #[derive(PartialEq, Debug)] @@ -100,17 +101,12 @@ impl Future for DrainFut { } } - pub struct Task { state: TaskRunningState, iostate: TaskIOState, frames: VecDeque, stream: TaskStream, - encoder: Encoder, - buffer: BytesMut, drain: Vec>>, - upgrade: bool, - keepalive: bool, prepared: Option, disconnected: bool, middlewares: Option>>>, @@ -129,10 +125,6 @@ impl Task { frames: frames, drain: Vec::new(), stream: TaskStream::None, - encoder: Encoder::length(0), - buffer: BytesMut::new(), - upgrade: false, - keepalive: false, prepared: None, disconnected: false, middlewares: None, @@ -147,11 +139,7 @@ impl Task { iostate: TaskIOState::ReadingMessage, frames: VecDeque::new(), stream: TaskStream::Stream(Box::new(stream)), - encoder: Encoder::length(0), - buffer: BytesMut::new(), drain: Vec::new(), - upgrade: false, - keepalive: false, prepared: None, disconnected: false, middlewares: None, @@ -165,158 +153,26 @@ impl Task { iostate: TaskIOState::ReadingMessage, frames: VecDeque::new(), stream: TaskStream::Context(Box::new(ctx)), - encoder: Encoder::length(0), - buffer: BytesMut::new(), drain: Vec::new(), - upgrade: false, - keepalive: false, prepared: None, disconnected: false, middlewares: None, } } - pub(crate) fn keepalive(&self) -> bool { - self.keepalive && !self.upgrade - } - pub(crate) fn set_middlewares(&mut self, middlewares: Rc>>) { self.middlewares = Some(middlewares); } pub(crate) fn disconnected(&mut self) { - let len = self.buffer.len(); - self.buffer.split_to(len); self.disconnected = true; if let TaskStream::Context(ref mut ctx) = self.stream { ctx.disconnected(); } } - fn prepare(&mut self, req: &mut HttpRequest, msg: HttpResponse) - { - trace!("Prepare message status={:?}", msg.status); - - // run middlewares - let mut msg = if let Some(middlewares) = self.middlewares.take() { - let mut msg = msg; - for middleware in middlewares.iter() { - msg = middleware.response(req, msg); - } - self.middlewares = Some(middlewares); - msg - } else { - msg - }; - - // prepare task - let mut extra = 0; - let body = msg.replace_body(Body::Empty); - let version = msg.version().unwrap_or_else(|| req.version()); - self.keepalive = msg.keep_alive().unwrap_or_else(|| req.keep_alive()); - - match body { - Body::Empty => { - if msg.chunked() { - error!("Chunked transfer is enabled but body is set to Empty"); - } - msg.headers.insert(CONTENT_LENGTH, HeaderValue::from_static("0")); - msg.headers.remove(TRANSFER_ENCODING); - self.encoder = Encoder::length(0); - }, - Body::Length(n) => { - if msg.chunked() { - error!("Chunked transfer is enabled but body with specific length is specified"); - } - msg.headers.insert( - CONTENT_LENGTH, - HeaderValue::from_str(format!("{}", n).as_str()).unwrap()); - msg.headers.remove(TRANSFER_ENCODING); - self.encoder = Encoder::length(n); - }, - Body::Binary(ref bytes) => { - extra = bytes.len(); - msg.headers.insert( - CONTENT_LENGTH, - HeaderValue::from_str(format!("{}", bytes.len()).as_str()).unwrap()); - msg.headers.remove(TRANSFER_ENCODING); - self.encoder = Encoder::length(0); - } - Body::Streaming => { - if msg.chunked() { - if version < Version::HTTP_11 { - error!("Chunked transfer encoding is forbidden for {:?}", version); - } - msg.headers.remove(CONTENT_LENGTH); - msg.headers.insert(TRANSFER_ENCODING, HeaderValue::from_static("chunked")); - self.encoder = Encoder::chunked(); - } else { - self.encoder = Encoder::eof(); - } - } - Body::Upgrade => { - msg.headers.insert(CONNECTION, HeaderValue::from_static("upgrade")); - self.encoder = Encoder::eof(); - } - } - - // Connection upgrade - if msg.upgrade() { - msg.headers.insert(CONNECTION, HeaderValue::from_static("upgrade")); - } - // keep-alive - else if self.keepalive { - if version < Version::HTTP_11 { - msg.headers.insert(CONNECTION, HeaderValue::from_static("keep-alive")); - } - } else if version >= Version::HTTP_11 { - msg.headers.insert(CONNECTION, HeaderValue::from_static("close")); - } - - // render message - let init_cap = 100 + msg.headers.len() * AVERAGE_HEADER_SIZE + extra; - self.buffer.reserve(init_cap); - - if version == Version::HTTP_11 && msg.status == StatusCode::OK { - self.buffer.extend(b"HTTP/1.1 200 OK\r\n"); - } else { - let _ = write!(self.buffer, "{:?} {}\r\n", version, msg.status); - } - for (key, value) in &msg.headers { - let t: &[u8] = key.as_ref(); - self.buffer.extend(t); - self.buffer.extend(b": "); - self.buffer.extend(value.as_ref()); - self.buffer.extend(b"\r\n"); - } - - // using http::h1::date is quite a lot faster than generating - // a unique Date header each time like req/s goes up about 10% - if !msg.headers.contains_key(DATE) { - self.buffer.reserve(date::DATE_VALUE_LENGTH + 8); - self.buffer.extend(b"Date: "); - date::extend(&mut self.buffer); - self.buffer.extend(b"\r\n"); - } - - // default content-type - if !msg.headers.contains_key(CONTENT_TYPE) { - self.buffer.extend(b"ContentType: application/octet-stream\r\n".as_ref()); - } - - self.buffer.extend(b"\r\n"); - - if let Body::Binary(ref bytes) = body { - self.buffer.extend_from_slice(bytes.as_ref()); - self.prepared = Some(msg); - return - } - msg.replace_body(body); - self.prepared = Some(msg); - } - pub(crate) fn poll_io(&mut self, io: &mut T, req: &mut HttpRequest) -> Poll - where T: AsyncWrite + where T: Writer { trace!("POLL-IO frames:{:?}", self.frames.len()); // response is completed @@ -328,87 +184,76 @@ impl Task { match self.poll() { Ok(Async::Ready(_)) => { self.state = TaskRunningState::Done; - } + }, Ok(Async::NotReady) => (), Err(_) => return Err(()) } } // use exiting frames - while let Some(frame) = self.frames.pop_front() { - trace!("IO Frame: {:?}", frame); - match frame { - Frame::Message(response) => { - if !self.disconnected { - self.prepare(req, response); + if self.state != TaskRunningState::Paused { + while let Some(frame) = self.frames.pop_front() { + trace!("IO Frame: {:?}", frame); + let res = match frame { + Frame::Message(mut response) => { + trace!("Prepare message status={:?}", response.status); + + // run middlewares + let mut response = + if let Some(middlewares) = self.middlewares.take() { + let mut response = response; + for middleware in middlewares.iter() { + response = middleware.response(req, response); + } + self.middlewares = Some(middlewares); + response + } else { + response + }; + + let result = io.start(req, &mut response); + self.prepared = Some(response); + result } - } - Frame::Payload(Some(chunk)) => { - if !self.disconnected { - if self.prepared.is_some() { - // TODO: add warning, write after EOF - self.encoder.encode(&mut self.buffer, chunk.as_ref()); - } else { - // might be response for EXCEPT - self.buffer.extend_from_slice(chunk.as_ref()) - } + Frame::Payload(Some(chunk)) => { + io.write(chunk.as_ref()) + }, + Frame::Payload(None) => { + self.iostate = TaskIOState::Done; + io.write_eof() + }, + Frame::Drain(fut) => { + self.drain.push(fut); + break } - }, - Frame::Payload(None) => { - if !self.disconnected && - !self.encoder.encode(&mut self.buffer, [].as_ref()) - { - // TODO: add error "not eof"" - debug!("last payload item, but it is not EOF "); - return Err(()) + }; + + match res { + Ok(WriterState::Pause) => { + self.state.pause(); + break } - break - }, - Frame::Drain(fut) => { - self.drain.push(fut); - break + Ok(WriterState::Done) => self.state.resume(), + Err(_) => return Err(()) } } } } - // write bytes to TcpStream - if !self.disconnected { - while !self.buffer.is_empty() { - match io.write(self.buffer.as_ref()) { - Ok(n) => { - self.buffer.split_to(n); - }, - Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { - break - } - Err(_) => return Err(()), - } + // flush io + match io.poll_complete() { + Ok(Async::Ready(())) => self.state.resume(), + Ok(Async::NotReady) => { + return Ok(Async::NotReady) } - } - - // should pause task - if self.state != TaskRunningState::Done { - if self.buffer.len() > MAX_WRITE_BUFFER_SIZE { - self.state = TaskRunningState::Paused; - } else if self.state == TaskRunningState::Paused { - self.state = TaskRunningState::Running; + Err(err) => { + trace!("Error sending data: {}", err); + return Err(()) } - } else { - // at this point we wont get any more Frames - self.iostate = TaskIOState::Done; } // drain - if self.buffer.is_empty() && !self.drain.is_empty() { - match io.flush() { - Ok(_) => (), - Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { - return Ok(Async::NotReady) - } - Err(_) => return Err(()), - } - + if !self.drain.is_empty() { for fut in &mut self.drain { fut.borrow_mut().set() } @@ -416,7 +261,7 @@ impl Task { } // response is completed - if (self.buffer.is_empty() || self.disconnected) && self.iostate.is_done() { + if self.iostate.is_done() { // run middlewares if let Some(ref mut resp) = self.prepared { if let Some(middlewares) = self.middlewares.take() { @@ -443,8 +288,8 @@ impl Task { error!("Non expected frame {:?}", frame); return Err(()) } - self.upgrade = msg.upgrade(); - if self.upgrade || msg.body().has_body() { + let upgrade = msg.upgrade(); + if upgrade || msg.body().has_body() { self.iostate = TaskIOState::ReadingPayload; } else { self.iostate = TaskIOState::Done; @@ -489,89 +334,3 @@ impl Future for Task { result } } - -/// Encoders to handle different Transfer-Encodings. -#[derive(Debug, Clone)] -struct Encoder { - kind: Kind, -} - -#[derive(Debug, PartialEq, Clone)] -enum Kind { - /// An Encoder for when Transfer-Encoding includes `chunked`. - Chunked(bool), - /// An Encoder for when Content-Length is set. - /// - /// Enforces that the body is not longer than the Content-Length header. - Length(u64), - /// An Encoder for when Content-Length is not known. - /// - /// Appliction decides when to stop writing. - Eof, -} - -impl Encoder { - - pub fn eof() -> Encoder { - Encoder { - kind: Kind::Eof, - } - } - - pub fn chunked() -> Encoder { - Encoder { - kind: Kind::Chunked(false), - } - } - - pub fn length(len: u64) -> Encoder { - Encoder { - kind: Kind::Length(len), - } - } - - /*pub fn is_eof(&self) -> bool { - match self.kind { - Kind::Eof | Kind::Length(0) => true, - Kind::Chunked(eof) => eof, - _ => false, - } - }*/ - - /// Encode message. Return `EOF` state of encoder - pub fn encode(&mut self, dst: &mut BytesMut, msg: &[u8]) -> bool { - match self.kind { - Kind::Eof => { - dst.extend(msg); - msg.is_empty() - }, - Kind::Chunked(ref mut eof) => { - if *eof { - return true; - } - - if msg.is_empty() { - *eof = true; - dst.extend(b"0\r\n\r\n"); - } else { - write!(dst, "{:X}\r\n", msg.len()).unwrap(); - dst.extend(msg); - dst.extend(b"\r\n"); - } - *eof - }, - Kind::Length(ref mut remaining) => { - if msg.is_empty() { - return *remaining == 0 - } - let max = cmp::min(*remaining, msg.len() as u64); - trace!("sized write = {}", max); - dst.extend(msg[..max as usize].as_ref()); - - *remaining -= max as u64; - trace!("encoded {} bytes, remaining = {}", max, remaining); - *remaining == 0 - }, - } - } -}