diff --git a/Cargo.toml b/Cargo.toml
index f5055122f..37586ca2c 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -28,8 +28,6 @@ default = []
# tls
tls = ["native-tls", "tokio-tls"]
-# http2 = ["h2"]
-
[dependencies]
log = "0.3"
time = "0.1"
diff --git a/cov.sh b/cov.sh
deleted file mode 100644
index 8e9fd237b..000000000
--- a/cov.sh
+++ /dev/null
@@ -1,4 +0,0 @@
-#!/bin/bash
-
-for file in target/debug/actix_web-*[^\.d]; do mkdir -p "target/cov/$(basename $file)"; /usr/local/bin/kcov --exclude-pattern=/.cargo,/usr/lib --verify "target/cov/$(basename $file)" "$file"; done &&
-for file in target/debug/test_*[^\.d]; do mkdir -p "target/cov/$(basename $file)"; /usr/local/bin/kcov --exclude-pattern=/.cargo,/usr/lib --verify "target/cov/$(basename $file)" "$file"; done
diff --git a/src/context.rs b/src/context.rs
index 04f39e3fd..b6fd4425c 100644
--- a/src/context.rs
+++ b/src/context.rs
@@ -47,6 +47,7 @@ impl ActorContext for HttpContext where A: Actor + Route
{
/// Stop actor execution
fn stop(&mut self) {
+ self.stream.push_back(Frame::Payload(None));
self.items.stop();
self.address.close();
if self.state == ActorState::Running {
@@ -141,7 +142,6 @@ impl HttpContext where A: Actor + Route {
/// Indicate end of streamimng payload. Also this method calls `Self::close`.
pub fn write_eof(&mut self) {
- self.stream.push_back(Frame::Payload(None));
self.stop();
}
diff --git a/src/h1.rs b/src/h1.rs
index ba7b23061..78ad3ad32 100644
--- a/src/h1.rs
+++ b/src/h1.rs
@@ -1,22 +1,282 @@
use std::{self, io, ptr};
+use std::rc::Rc;
+use std::cell::UnsafeCell;
+use std::time::Duration;
+use std::collections::VecDeque;
+use actix::Arbiter;
use httparse;
use http::{Method, Version, HttpTryFrom, HeaderMap};
use http::header::{self, HeaderName, HeaderValue};
use bytes::{Bytes, BytesMut, BufMut};
-use futures::{Async, Poll};
-use tokio_io::AsyncRead;
+use futures::{Future, Poll, Async};
+use tokio_io::{AsyncRead, AsyncWrite};
+use tokio_core::reactor::Timeout;
use percent_encoding;
+use task::Task;
+use server::HttpHandler;
use error::ParseError;
+use httpcodes::HTTPNotFound;
use httprequest::HttpRequest;
use payload::{Payload, PayloadError, PayloadSender};
+use h1writer::H1Writer;
-const MAX_HEADERS: usize = 100;
+const KEEPALIVE_PERIOD: u64 = 15; // seconds
const INIT_BUFFER_SIZE: usize = 8192;
const MAX_BUFFER_SIZE: usize = 131_072;
+const MAX_HEADERS: usize = 100;
+const MAX_PIPELINED_MESSAGES: usize = 16;
const HTTP2_PREFACE: [u8; 14] = *b"PRI * HTTP/2.0";
+pub(crate) enum Http1Result {
+ Done,
+ Upgrade,
+}
+
+pub(crate) struct Http1 {
+ router: Rc>,
+ #[allow(dead_code)]
+ addr: A,
+ stream: H1Writer,
+ reader: Reader,
+ read_buf: BytesMut,
+ error: bool,
+ tasks: VecDeque,
+ keepalive: bool,
+ keepalive_timer: Option,
+ h2: bool,
+}
+
+struct Entry {
+ task: Task,
+ req: UnsafeCell,
+ eof: bool,
+ error: bool,
+ finished: bool,
+}
+
+impl Http1
+ where T: AsyncRead + AsyncWrite + 'static,
+ A: 'static,
+ H: HttpHandler + 'static
+{
+ pub fn new(stream: T, addr: A, router: Rc>) -> Self {
+ Http1{ router: router,
+ addr: addr,
+ stream: H1Writer::new(stream),
+ reader: Reader::new(),
+ read_buf: BytesMut::new(),
+ error: false,
+ tasks: VecDeque::new(),
+ keepalive: true,
+ keepalive_timer: None,
+ h2: false }
+ }
+
+ pub fn into_inner(mut self) -> (T, A, Rc>, Bytes) {
+ (self.stream.into_inner(), self.addr, self.router, self.read_buf.freeze())
+ }
+
+ pub fn poll(&mut self) -> Poll {
+ // keep-alive timer
+ if let Some(ref mut timeout) = self.keepalive_timer {
+ match timeout.poll() {
+ Ok(Async::Ready(_)) =>
+ return Ok(Async::Ready(Http1Result::Done)),
+ Ok(Async::NotReady) => (),
+ Err(_) => unreachable!(),
+ }
+ }
+
+ loop {
+ let mut not_ready = true;
+
+ // check in-flight messages
+ let mut io = false;
+ let mut idx = 0;
+ while idx < self.tasks.len() {
+ let item = &mut self.tasks[idx];
+
+ if !io && !item.eof {
+ if item.error {
+ return Err(())
+ }
+
+ // this is anoying
+ let req = unsafe {item.req.get().as_mut().unwrap()};
+ match item.task.poll_io(&mut self.stream, req)
+ {
+ Ok(Async::Ready(ready)) => {
+ not_ready = false;
+
+ // overide keep-alive state
+ if self.keepalive {
+ self.keepalive = self.stream.keepalive();
+ }
+ self.stream = H1Writer::new(self.stream.into_inner());
+
+ item.eof = true;
+ if ready {
+ item.finished = true;
+ }
+ },
+ Ok(Async::NotReady) => {
+ // no more IO for this iteration
+ io = true;
+ },
+ Err(_) => {
+ // it is not possible to recover from error
+ // during task handling, so just drop connection
+ return Err(())
+ }
+ }
+ } else if !item.finished {
+ match item.task.poll() {
+ Ok(Async::NotReady) => (),
+ Ok(Async::Ready(_)) => {
+ not_ready = false;
+ item.finished = true;
+ },
+ Err(_) =>
+ item.error = true,
+ }
+ }
+ idx += 1;
+ }
+
+ // cleanup finished tasks
+ while !self.tasks.is_empty() {
+ if self.tasks[0].eof && self.tasks[0].finished {
+ self.tasks.pop_front();
+ } else {
+ break
+ }
+ }
+
+ // no keep-alive
+ if !self.keepalive && self.tasks.is_empty() {
+ if self.h2 {
+ return Ok(Async::Ready(Http1Result::Upgrade))
+ } else {
+ return Ok(Async::Ready(Http1Result::Done))
+ }
+ }
+
+ // read incoming data
+ if !self.error && !self.h2 && self.tasks.len() < MAX_PIPELINED_MESSAGES {
+ match self.reader.parse(self.stream.get_mut(), &mut self.read_buf) {
+ Ok(Async::Ready(Item::Http1(mut req, payload))) => {
+ not_ready = false;
+
+ // stop keepalive timer
+ self.keepalive_timer.take();
+
+ // start request processing
+ let mut task = None;
+ for h in self.router.iter() {
+ if req.path().starts_with(h.prefix()) {
+ task = Some(h.handle(&mut req, payload));
+ break
+ }
+ }
+
+ self.tasks.push_back(
+ Entry {task: task.unwrap_or_else(|| Task::reply(HTTPNotFound)),
+ req: UnsafeCell::new(req),
+ eof: false,
+ error: false,
+ finished: false});
+ }
+ Ok(Async::Ready(Item::Http2)) => {
+ self.h2 = true;
+ }
+ Err(ReaderError::Disconnect) => {
+ not_ready = false;
+ self.error = true;
+ self.stream.disconnected();
+ for entry in &mut self.tasks {
+ entry.task.disconnected()
+ }
+ },
+ Err(err) => {
+ // notify all tasks
+ not_ready = false;
+ self.stream.disconnected();
+ for entry in &mut self.tasks {
+ entry.task.disconnected()
+ }
+
+ // kill keepalive
+ self.keepalive = false;
+ self.keepalive_timer.take();
+
+ // on parse error, stop reading stream but
+ // tasks need to be completed
+ self.error = true;
+
+ if self.tasks.is_empty() {
+ if let ReaderError::Error(err) = err {
+ self.tasks.push_back(
+ Entry {task: Task::reply(err),
+ req: UnsafeCell::new(HttpRequest::for_error()),
+ eof: false,
+ error: false,
+ finished: false});
+ }
+ }
+ }
+ Ok(Async::NotReady) => {
+ // start keep-alive timer, this is also slow request timeout
+ if self.tasks.is_empty() {
+ if self.keepalive {
+ if self.keepalive_timer.is_none() {
+ trace!("Start keep-alive timer");
+ let mut timeout = Timeout::new(
+ Duration::new(KEEPALIVE_PERIOD, 0),
+ Arbiter::handle()).unwrap();
+ // register timeout
+ let _ = timeout.poll();
+ self.keepalive_timer = Some(timeout);
+ }
+ } else {
+ // keep-alive disable, drop connection
+ return Ok(Async::Ready(Http1Result::Done))
+ }
+ }
+ return Ok(Async::NotReady)
+ }
+ }
+ }
+
+ // check for parse error
+ if self.tasks.is_empty() {
+ if self.error || self.keepalive_timer.is_none() {
+ return Ok(Async::Ready(Http1Result::Done))
+ }
+ else if self.h2 {
+ return Ok(Async::Ready(Http1Result::Upgrade))
+ }
+ }
+
+ if not_ready {
+ return Ok(Async::NotReady)
+ }
+ }
+ }
+}
+
+#[derive(Debug)]
+enum Item {
+ Http1(HttpRequest, Payload),
+ Http2,
+}
+
+struct Reader {
+ h1: bool,
+ payload: Option,
+}
+
enum Decoding {
Paused,
Ready,
@@ -28,19 +288,9 @@ struct PayloadInfo {
decoder: Decoder,
}
-pub(crate) struct Reader {
- read_buf: BytesMut,
- payload: Option,
-}
-
#[derive(Debug)]
-pub(crate) enum ReaderItem {
- Http1(HttpRequest, Payload),
- Http2,
-}
-
-#[derive(Debug)]
-pub(crate) enum ReaderError {
+enum ReaderError {
+ Disconnect,
Payload,
Error(ParseError),
}
@@ -55,19 +305,19 @@ enum Message {
impl Reader {
pub fn new() -> Reader {
Reader {
- read_buf: BytesMut::new(),
+ h1: false,
payload: None,
}
}
- fn decode(&mut self) -> std::result::Result
+ fn decode(&mut self, buf: &mut BytesMut) -> std::result::Result
{
if let Some(ref mut payload) = self.payload {
if payload.tx.maybe_paused() {
return Ok(Decoding::Paused)
}
loop {
- match payload.decoder.decode(&mut self.read_buf) {
+ match payload.decoder.decode(buf) {
Ok(Async::Ready(Some(bytes))) => {
payload.tx.feed_data(bytes)
},
@@ -87,18 +337,18 @@ impl Reader {
}
}
- pub fn parse(&mut self, io: &mut T) -> Poll<(HttpRequest, Payload), ReaderError>
+ pub fn parse(&mut self, io: &mut T, buf: &mut BytesMut) -> Poll-
where T: AsyncRead
{
loop {
- match self.decode()? {
+ match self.decode(buf)? {
Decoding::Paused => return Ok(Async::NotReady),
Decoding::Ready => {
self.payload = None;
break
},
Decoding::NotReady => {
- match self.read_from_io(io) {
+ match self.read_from_io(io, buf) {
Ok(Async::Ready(0)) => {
if let Some(ref mut payload) = self.payload {
payload.tx.set_error(PayloadError::Incomplete);
@@ -123,7 +373,7 @@ impl Reader {
}
loop {
- match Reader::parse_message(&mut self.read_buf).map_err(ReaderError::Error)? {
+ match Reader::parse_message(buf).map_err(ReaderError::Error)? {
Message::Http1(msg, decoder) => {
let payload = if let Some(decoder) = decoder {
let (tx, rx) = Payload::new(false);
@@ -134,7 +384,7 @@ impl Reader {
self.payload = Some(payload);
loop {
- match self.decode()? {
+ match self.decode(buf)? {
Decoding::Paused =>
break,
Decoding::Ready => {
@@ -142,7 +392,7 @@ impl Reader {
break
},
Decoding::NotReady => {
- match self.read_from_io(io) {
+ match self.read_from_io(io, buf) {
Ok(Async::Ready(0)) => {
trace!("parse eof");
if let Some(ref mut payload) = self.payload {
@@ -171,21 +421,26 @@ impl Reader {
let (_, rx) = Payload::new(true);
rx
};
- return Ok(Async::Ready((msg, payload)));
+ self.h1 = true;
+ return Ok(Async::Ready(Item::Http1(msg, payload)));
},
Message::Http2 => {
+ if self.h1 {
+ return Err(ReaderError::Error(ParseError::Version))
+ }
+ return Ok(Async::Ready(Item::Http2));
},
Message::NotReady => {
- if self.read_buf.capacity() >= MAX_BUFFER_SIZE {
+ if buf.capacity() >= MAX_BUFFER_SIZE {
debug!("MAX_BUFFER_SIZE reached, closing");
return Err(ReaderError::Error(ParseError::TooLarge));
}
},
}
- match self.read_from_io(io) {
+ match self.read_from_io(io, buf) {
Ok(Async::Ready(0)) => {
- trace!("Eof during parse");
- return Err(ReaderError::Error(ParseError::Incomplete));
+ debug!("Ignored premature client disconnection");
+ return Err(ReaderError::Disconnect);
},
Ok(Async::Ready(_)) => (),
Ok(Async::NotReady) =>
@@ -196,17 +451,19 @@ impl Reader {
}
}
- fn read_from_io(&mut self, io: &mut T) -> Poll {
- if self.read_buf.remaining_mut() < INIT_BUFFER_SIZE {
- self.read_buf.reserve(INIT_BUFFER_SIZE);
+ fn read_from_io(&mut self, io: &mut T, buf: &mut BytesMut)
+ -> Poll
+ {
+ if buf.remaining_mut() < INIT_BUFFER_SIZE {
+ buf.reserve(INIT_BUFFER_SIZE);
unsafe { // Zero out unused memory
- let buf = self.read_buf.bytes_mut();
- let len = buf.len();
- ptr::write_bytes(buf.as_mut_ptr(), 0, len);
+ let b = buf.bytes_mut();
+ let len = b.len();
+ ptr::write_bytes(b.as_mut_ptr(), 0, len);
}
}
unsafe {
- let n = match io.read(self.read_buf.bytes_mut()) {
+ let n = match io.read(buf.bytes_mut()) {
Ok(n) => n,
Err(e) => {
if e.kind() == io::ErrorKind::WouldBlock {
@@ -215,18 +472,17 @@ impl Reader {
return Err(e)
}
};
- self.read_buf.advance_mut(n);
+ buf.advance_mut(n);
Ok(Async::Ready(n))
}
}
fn parse_message(buf: &mut BytesMut) -> Result
{
- println!("BUF: {:?}", buf);
- if buf.is_empty() || buf.len() < 14 {
+ if buf.is_empty() {
return Ok(Message::NotReady);
}
- if &buf[..14] == &HTTP2_PREFACE[..] {
+ if buf.len() >= 14 && &buf[..14] == &HTTP2_PREFACE[..] {
return Ok(Message::Http2)
}
@@ -368,7 +624,7 @@ fn record_header_indices(bytes: &[u8],
/// If a message body does not include a Transfer-Encoding, it *should*
/// include a Content-Length header.
#[derive(Debug, Clone, PartialEq)]
-pub struct Decoder {
+struct Decoder {
kind: Kind,
}
@@ -424,7 +680,7 @@ enum ChunkedState {
}
impl Decoder {
- pub fn is_eof(&self) -> bool {
+ /*pub fn is_eof(&self) -> bool {
trace!("is_eof? {:?}", self);
match self.kind {
Kind::Length(0) |
@@ -432,7 +688,7 @@ impl Decoder {
Kind::Eof(true) => true,
_ => false,
}
- }
+ }*/
}
impl Decoder {
@@ -633,7 +889,7 @@ mod tests {
use futures::{Async};
use tokio_io::AsyncRead;
use http::{Version, Method};
- use super::{Reader, ReaderError};
+ use super::*;
struct Buffer {
buf: Bytes,
@@ -682,8 +938,8 @@ mod tests {
macro_rules! parse_ready {
($e:expr) => (
- match Reader::new().parse($e) {
- Ok(Async::Ready((req, payload))) => (req, payload),
+ match Reader::new().parse($e, &mut BytesMut::new()) {
+ Ok(Async::Ready(Item::Http1(req, payload))) => (req, payload),
Ok(_) => panic!("Eof during parsing http request"),
Err(err) => panic!("Error during parsing http request: {:?}", err),
}
@@ -693,7 +949,7 @@ mod tests {
macro_rules! reader_parse_ready {
($e:expr) => (
match $e {
- Ok(Async::Ready((req, payload))) => (req, payload),
+ Ok(Async::Ready(Item::Http1(req, payload))) => (req, payload),
Ok(_) => panic!("Eof during parsing http request"),
Err(err) => panic!("Error during parsing http request: {:?}", err),
}
@@ -701,22 +957,28 @@ mod tests {
}
macro_rules! expect_parse_err {
- ($e:expr) => (match Reader::new().parse($e) {
- Err(err) => match err {
- ReaderError::Error(_) => (),
- _ => panic!("Parse error expected"),
- },
- _ => panic!("Error expected"),
- })
+ ($e:expr) => ({
+ let mut buf = BytesMut::new();
+ match Reader::new().parse($e, &mut buf) {
+ Err(err) => match err {
+ ReaderError::Error(_) => (),
+ _ => panic!("Parse error expected"),
+ },
+ val => {
+ panic!("Error expected")
+ }
+ }}
+ )
}
#[test]
fn test_parse() {
let mut buf = Buffer::new("GET /test HTTP/1.1\r\n\r\n");
+ let mut readbuf = BytesMut::new();
let mut reader = Reader::new();
- match reader.parse(&mut buf) {
- Ok(Async::Ready((req, payload))) => {
+ match reader.parse(&mut buf, &mut readbuf) {
+ Ok(Async::Ready(Item::Http1(req, payload))) => {
assert_eq!(req.version(), Version::HTTP_11);
assert_eq!(*req.method(), Method::GET);
assert_eq!(req.path(), "/test");
@@ -729,16 +991,17 @@ mod tests {
#[test]
fn test_parse_partial() {
let mut buf = Buffer::new("PUT /test HTTP/1");
+ let mut readbuf = BytesMut::new();
let mut reader = Reader::new();
- match reader.parse(&mut buf) {
+ match reader.parse(&mut buf, &mut readbuf) {
Ok(Async::NotReady) => (),
_ => panic!("Error"),
}
buf.feed_data(".1\r\n\r\n");
- match reader.parse(&mut buf) {
- Ok(Async::Ready((req, payload))) => {
+ match reader.parse(&mut buf, &mut readbuf) {
+ Ok(Async::Ready(Item::Http1(req, payload))) => {
assert_eq!(req.version(), Version::HTTP_11);
assert_eq!(*req.method(), Method::PUT);
assert_eq!(req.path(), "/test");
@@ -751,10 +1014,11 @@ mod tests {
#[test]
fn test_parse_post() {
let mut buf = Buffer::new("POST /test2 HTTP/1.0\r\n\r\n");
+ let mut readbuf = BytesMut::new();
let mut reader = Reader::new();
- match reader.parse(&mut buf) {
- Ok(Async::Ready((req, payload))) => {
+ match reader.parse(&mut buf, &mut readbuf) {
+ Ok(Async::Ready(Item::Http1(req, payload))) => {
assert_eq!(req.version(), Version::HTTP_10);
assert_eq!(*req.method(), Method::POST);
assert_eq!(req.path(), "/test2");
@@ -767,10 +1031,11 @@ mod tests {
#[test]
fn test_parse_body() {
let mut buf = Buffer::new("GET /test HTTP/1.1\r\nContent-Length: 4\r\n\r\nbody");
+ let mut readbuf = BytesMut::new();
let mut reader = Reader::new();
- match reader.parse(&mut buf) {
- Ok(Async::Ready((req, mut payload))) => {
+ match reader.parse(&mut buf, &mut readbuf) {
+ Ok(Async::Ready(Item::Http1(req, mut payload))) => {
assert_eq!(req.version(), Version::HTTP_11);
assert_eq!(*req.method(), Method::GET);
assert_eq!(req.path(), "/test");
@@ -784,10 +1049,11 @@ mod tests {
fn test_parse_body_crlf() {
let mut buf = Buffer::new(
"\r\nGET /test HTTP/1.1\r\nContent-Length: 4\r\n\r\nbody");
+ let mut readbuf = BytesMut::new();
let mut reader = Reader::new();
- match reader.parse(&mut buf) {
- Ok(Async::Ready((req, mut payload))) => {
+ match reader.parse(&mut buf, &mut readbuf) {
+ Ok(Async::Ready(Item::Http1(req, mut payload))) => {
assert_eq!(req.version(), Version::HTTP_11);
assert_eq!(*req.method(), Method::GET);
assert_eq!(req.path(), "/test");
@@ -800,13 +1066,14 @@ mod tests {
#[test]
fn test_parse_partial_eof() {
let mut buf = Buffer::new("GET /test HTTP/1.1\r\n");
+ let mut readbuf = BytesMut::new();
let mut reader = Reader::new();
- not_ready!{ reader.parse(&mut buf) }
+ not_ready!{ reader.parse(&mut buf, &mut readbuf) }
buf.feed_data("\r\n");
- match reader.parse(&mut buf) {
- Ok(Async::Ready((req, payload))) => {
+ match reader.parse(&mut buf, &mut readbuf) {
+ Ok(Async::Ready(Item::Http1(req, payload))) => {
assert_eq!(req.version(), Version::HTTP_11);
assert_eq!(*req.method(), Method::GET);
assert_eq!(req.path(), "/test");
@@ -819,19 +1086,20 @@ mod tests {
#[test]
fn test_headers_split_field() {
let mut buf = Buffer::new("GET /test HTTP/1.1\r\n");
+ let mut readbuf = BytesMut::new();
let mut reader = Reader::new();
- not_ready!{ reader.parse(&mut buf) }
+ not_ready!{ reader.parse(&mut buf, &mut readbuf) }
buf.feed_data("t");
- not_ready!{ reader.parse(&mut buf) }
+ not_ready!{ reader.parse(&mut buf, &mut readbuf) }
buf.feed_data("es");
- not_ready!{ reader.parse(&mut buf) }
+ not_ready!{ reader.parse(&mut buf, &mut readbuf) }
buf.feed_data("t: value\r\n\r\n");
- match reader.parse(&mut buf) {
- Ok(Async::Ready((req, payload))) => {
+ match reader.parse(&mut buf, &mut readbuf) {
+ Ok(Async::Ready(Item::Http1(req, payload))) => {
assert_eq!(req.version(), Version::HTTP_11);
assert_eq!(*req.method(), Method::GET);
assert_eq!(req.path(), "/test");
@@ -848,10 +1116,11 @@ mod tests {
"GET /test HTTP/1.1\r\n\
Set-Cookie: c1=cookie1\r\n\
Set-Cookie: c2=cookie2\r\n\r\n");
+ let mut readbuf = BytesMut::new();
let mut reader = Reader::new();
- match reader.parse(&mut buf) {
- Ok(Async::Ready((req, _))) => {
+ match reader.parse(&mut buf, &mut readbuf) {
+ Ok(Async::Ready(Item::Http1(req, _))) => {
let val: Vec<_> = req.headers().get_all("Set-Cookie")
.iter().map(|v| v.to_str().unwrap().to_owned()).collect();
assert_eq!(val[0], "c1=cookie1");
@@ -1081,14 +1350,15 @@ mod tests {
let mut buf = Buffer::new(
"GET /test HTTP/1.1\r\n\
transfer-encoding: chunked\r\n\r\n");
+ let mut readbuf = BytesMut::new();
let mut reader = Reader::new();
- let (req, mut payload) = reader_parse_ready!(reader.parse(&mut buf));
+ let (req, mut payload) = reader_parse_ready!(reader.parse(&mut buf, &mut readbuf));
assert!(req.chunked().unwrap());
assert!(!payload.eof());
buf.feed_data("4\r\ndata\r\n4\r\nline\r\n0\r\n\r\n");
- not_ready!(reader.parse(&mut buf));
+ not_ready!(reader.parse(&mut buf, &mut readbuf));
assert!(!payload.eof());
assert_eq!(payload.readall().unwrap().as_ref(), b"dataline");
assert!(payload.eof());
@@ -1099,10 +1369,11 @@ mod tests {
let mut buf = Buffer::new(
"GET /test HTTP/1.1\r\n\
transfer-encoding: chunked\r\n\r\n");
+ let mut readbuf = BytesMut::new();
let mut reader = Reader::new();
- let (req, mut payload) = reader_parse_ready!(reader.parse(&mut buf));
+ let (req, mut payload) = reader_parse_ready!(reader.parse(&mut buf, &mut readbuf));
assert!(req.chunked().unwrap());
assert!(!payload.eof());
@@ -1111,7 +1382,7 @@ mod tests {
POST /test2 HTTP/1.1\r\n\
transfer-encoding: chunked\r\n\r\n");
- let (req2, payload2) = reader_parse_ready!(reader.parse(&mut buf));
+ let (req2, payload2) = reader_parse_ready!(reader.parse(&mut buf, &mut readbuf));
assert_eq!(*req2.method(), Method::POST);
assert!(req2.chunked().unwrap());
assert!(!payload2.eof());
@@ -1125,37 +1396,38 @@ mod tests {
let mut buf = Buffer::new(
"GET /test HTTP/1.1\r\n\
transfer-encoding: chunked\r\n\r\n");
+ let mut readbuf = BytesMut::new();
let mut reader = Reader::new();
- let (req, mut payload) = reader_parse_ready!(reader.parse(&mut buf));
+ let (req, mut payload) = reader_parse_ready!(reader.parse(&mut buf, &mut readbuf));
assert!(req.chunked().unwrap());
assert!(!payload.eof());
buf.feed_data("4\r\ndata\r");
- not_ready!(reader.parse(&mut buf));
+ not_ready!(reader.parse(&mut buf, &mut readbuf));
buf.feed_data("\n4");
- not_ready!(reader.parse(&mut buf));
+ not_ready!(reader.parse(&mut buf, &mut readbuf));
buf.feed_data("\r");
- not_ready!(reader.parse(&mut buf));
+ not_ready!(reader.parse(&mut buf, &mut readbuf));
buf.feed_data("\n");
- not_ready!(reader.parse(&mut buf));
+ not_ready!(reader.parse(&mut buf, &mut readbuf));
buf.feed_data("li");
- not_ready!(reader.parse(&mut buf));
+ not_ready!(reader.parse(&mut buf, &mut readbuf));
buf.feed_data("ne\r\n0\r\n");
- not_ready!(reader.parse(&mut buf));
+ not_ready!(reader.parse(&mut buf, &mut readbuf));
//buf.feed_data("test: test\r\n");
- //not_ready!(reader.parse(&mut buf));
+ //not_ready!(reader.parse(&mut buf, &mut readbuf));
assert_eq!(payload.readall().unwrap().as_ref(), b"dataline");
assert!(!payload.eof());
buf.feed_data("\r\n");
- not_ready!(reader.parse(&mut buf));
+ not_ready!(reader.parse(&mut buf, &mut readbuf));
assert!(payload.eof());
}
@@ -1164,14 +1436,15 @@ mod tests {
let mut buf = Buffer::new(
"GET /test HTTP/1.1\r\n\
transfer-encoding: chunked\r\n\r\n");
+ let mut readbuf = BytesMut::new();
let mut reader = Reader::new();
- let (req, mut payload) = reader_parse_ready!(reader.parse(&mut buf));
+ let (req, mut payload) = reader_parse_ready!(reader.parse(&mut buf, &mut readbuf));
assert!(req.chunked().unwrap());
assert!(!payload.eof());
buf.feed_data("4;test\r\ndata\r\n4\r\nline\r\n0\r\n\r\n"); // test: test\r\n\r\n")
- not_ready!(reader.parse(&mut buf));
+ not_ready!(reader.parse(&mut buf, &mut readbuf));
assert!(!payload.eof());
assert_eq!(payload.readall().unwrap().as_ref(), b"dataline");
assert!(payload.eof());
@@ -1193,4 +1466,16 @@ mod tests {
Err(err) => panic!("{:?}", err),
}
}*/
+
+ #[test]
+ fn test_http2_prefix() {
+ let mut buf = Buffer::new("PRI * HTTP/2.0\r\n\r\n");
+ let mut readbuf = BytesMut::new();
+
+ let mut reader = Reader::new();
+ match reader.parse(&mut buf, &mut readbuf) {
+ Ok(Async::Ready(Item::Http2)) => (),
+ Ok(_) | Err(_) => panic!("Error during parsing http request"),
+ }
+ }
}
diff --git a/src/h1writer.rs b/src/h1writer.rs
new file mode 100644
index 000000000..98f2aa4fa
--- /dev/null
+++ b/src/h1writer.rs
@@ -0,0 +1,351 @@
+use std::{cmp, io};
+use std::fmt::Write;
+use bytes::BytesMut;
+use futures::{Async, Poll};
+use tokio_io::AsyncWrite;
+use http::{Version, StatusCode};
+use http::header::{HeaderValue,
+ CONNECTION, CONTENT_TYPE, CONTENT_LENGTH, TRANSFER_ENCODING, DATE};
+
+use date;
+use body::Body;
+use httprequest::HttpRequest;
+use httpresponse::HttpResponse;
+
+const AVERAGE_HEADER_SIZE: usize = 30; // totally scientific
+const MAX_WRITE_BUFFER_SIZE: usize = 65_536; // max buffer size 64k
+
+
+pub(crate) enum WriterState {
+ Done,
+ Pause,
+}
+
+/// Send stream
+pub(crate) trait Writer {
+ fn start(&mut self, req: &mut HttpRequest, resp: &mut HttpResponse)
+ -> Result;
+
+ fn write(&mut self, payload: &[u8]) -> Result;
+
+ fn write_eof(&mut self) -> Result;
+
+ fn poll_complete(&mut self) -> Poll<(), io::Error>;
+}
+
+
+pub(crate) struct H1Writer {
+ stream: Option,
+ buffer: BytesMut,
+ started: bool,
+ encoder: Encoder,
+ upgrade: bool,
+ keepalive: bool,
+ disconnected: bool,
+}
+
+impl H1Writer {
+
+ pub fn new(stream: T) -> H1Writer {
+ H1Writer {
+ stream: Some(stream),
+ buffer: BytesMut::new(),
+ started: false,
+ encoder: Encoder::length(0),
+ upgrade: false,
+ keepalive: false,
+ disconnected: false,
+ }
+ }
+
+ pub fn get_mut(&mut self) -> &mut T {
+ self.stream.as_mut().unwrap()
+ }
+
+ pub fn into_inner(&mut self) -> T {
+ self.stream.take().unwrap()
+ }
+
+ pub fn disconnected(&mut self) {
+ let len = self.buffer.len();
+ self.buffer.split_to(len);
+ }
+
+ pub fn keepalive(&self) -> bool {
+ self.keepalive && !self.upgrade
+ }
+
+ fn write_to_stream(&mut self) -> Result {
+ if let Some(ref mut stream) = self.stream {
+ while !self.buffer.is_empty() {
+ match stream.write(self.buffer.as_ref()) {
+ Ok(n) => {
+ self.buffer.split_to(n);
+ },
+ Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
+ if self.buffer.len() > MAX_WRITE_BUFFER_SIZE {
+ return Ok(WriterState::Pause)
+ } else {
+ return Ok(WriterState::Done)
+ }
+ }
+ Err(err) =>
+ return Err(err),
+ }
+ }
+ }
+ return Ok(WriterState::Done)
+ }
+}
+
+impl Writer for H1Writer {
+
+ fn start(&mut self, req: &mut HttpRequest, msg: &mut HttpResponse)
+ -> Result
+ {
+ trace!("Prepare message status={:?}", msg.status);
+
+ // prepare task
+ let mut extra = 0;
+ let body = msg.replace_body(Body::Empty);
+ let version = msg.version().unwrap_or_else(|| req.version());
+ self.started = true;
+ self.keepalive = msg.keep_alive().unwrap_or_else(|| req.keep_alive());
+
+ match body {
+ Body::Empty => {
+ if msg.chunked() {
+ error!("Chunked transfer is enabled but body is set to Empty");
+ }
+ msg.headers.insert(CONTENT_LENGTH, HeaderValue::from_static("0"));
+ msg.headers.remove(TRANSFER_ENCODING);
+ self.encoder = Encoder::length(0);
+ },
+ Body::Length(n) => {
+ if msg.chunked() {
+ error!("Chunked transfer is enabled but body with specific length is specified");
+ }
+ msg.headers.insert(
+ CONTENT_LENGTH,
+ HeaderValue::from_str(format!("{}", n).as_str()).unwrap());
+ msg.headers.remove(TRANSFER_ENCODING);
+ self.encoder = Encoder::length(n);
+ },
+ Body::Binary(ref bytes) => {
+ extra = bytes.len();
+ msg.headers.insert(
+ CONTENT_LENGTH,
+ HeaderValue::from_str(format!("{}", bytes.len()).as_str()).unwrap());
+ msg.headers.remove(TRANSFER_ENCODING);
+ self.encoder = Encoder::length(0);
+ }
+ Body::Streaming => {
+ if msg.chunked() {
+ if version < Version::HTTP_11 {
+ error!("Chunked transfer encoding is forbidden for {:?}", version);
+ }
+ msg.headers.remove(CONTENT_LENGTH);
+ msg.headers.insert(TRANSFER_ENCODING, HeaderValue::from_static("chunked"));
+ self.encoder = Encoder::chunked();
+ } else {
+ self.encoder = Encoder::eof();
+ }
+ }
+ Body::Upgrade => {
+ msg.headers.insert(CONNECTION, HeaderValue::from_static("upgrade"));
+ self.encoder = Encoder::eof();
+ }
+ }
+
+ // Connection upgrade
+ if msg.upgrade() {
+ msg.headers.insert(CONNECTION, HeaderValue::from_static("upgrade"));
+ }
+ // keep-alive
+ else if self.keepalive {
+ if version < Version::HTTP_11 {
+ msg.headers.insert(CONNECTION, HeaderValue::from_static("keep-alive"));
+ }
+ } else if version >= Version::HTTP_11 {
+ msg.headers.insert(CONNECTION, HeaderValue::from_static("close"));
+ }
+
+ // render message
+ let init_cap = 100 + msg.headers.len() * AVERAGE_HEADER_SIZE + extra;
+ self.buffer.reserve(init_cap);
+
+ if version == Version::HTTP_11 && msg.status == StatusCode::OK {
+ self.buffer.extend(b"HTTP/1.1 200 OK\r\n");
+ } else {
+ let _ = write!(self.buffer, "{:?} {}\r\n", version, msg.status);
+ }
+ for (key, value) in &msg.headers {
+ let t: &[u8] = key.as_ref();
+ self.buffer.extend(t);
+ self.buffer.extend(b": ");
+ self.buffer.extend(value.as_ref());
+ self.buffer.extend(b"\r\n");
+ }
+
+ // using http::h1::date is quite a lot faster than generating
+ // a unique Date header each time like req/s goes up about 10%
+ if !msg.headers.contains_key(DATE) {
+ self.buffer.reserve(date::DATE_VALUE_LENGTH + 8);
+ self.buffer.extend(b"Date: ");
+ date::extend(&mut self.buffer);
+ self.buffer.extend(b"\r\n");
+ }
+
+ // default content-type
+ if !msg.headers.contains_key(CONTENT_TYPE) {
+ self.buffer.extend(b"ContentType: application/octet-stream\r\n".as_ref());
+ }
+
+ self.buffer.extend(b"\r\n");
+
+ if let Body::Binary(ref bytes) = body {
+ self.buffer.extend_from_slice(bytes.as_ref());
+ return Ok(WriterState::Done)
+ }
+ msg.replace_body(body);
+
+ Ok(WriterState::Done)
+ }
+
+ fn write(&mut self, payload: &[u8]) -> Result {
+ if !self.disconnected {
+ if self.started {
+ // TODO: add warning, write after EOF
+ self.encoder.encode(&mut self.buffer, payload);
+ } else {
+ // might be response for EXCEPT
+ self.buffer.extend_from_slice(payload)
+ }
+ }
+
+ if self.buffer.len() > MAX_WRITE_BUFFER_SIZE {
+ return Ok(WriterState::Pause)
+ } else {
+ return Ok(WriterState::Done)
+ }
+ }
+
+ fn write_eof(&mut self) -> Result {
+ if !self.encoder.encode_eof(&mut self.buffer) {
+ //debug!("last payload item, but it is not EOF ");
+ Err(io::Error::new(io::ErrorKind::Other,
+ "Last payload item, but eof is not reached"))
+ } else {
+ if self.buffer.len() > MAX_WRITE_BUFFER_SIZE {
+ return Ok(WriterState::Pause)
+ } else {
+ return Ok(WriterState::Done)
+ }
+ }
+ }
+
+ fn poll_complete(&mut self) -> Poll<(), io::Error> {
+ match self.write_to_stream() {
+ Ok(WriterState::Done) => Ok(Async::Ready(())),
+ Ok(WriterState::Pause) => Ok(Async::NotReady),
+ Err(err) => Err(err)
+ }
+ }
+}
+
+/// Encoders to handle different Transfer-Encodings.
+#[derive(Debug, Clone)]
+struct Encoder {
+ kind: Kind,
+}
+
+#[derive(Debug, PartialEq, Clone)]
+enum Kind {
+ /// An Encoder for when Transfer-Encoding includes `chunked`.
+ Chunked(bool),
+ /// An Encoder for when Content-Length is set.
+ ///
+ /// Enforces that the body is not longer than the Content-Length header.
+ Length(u64),
+ /// An Encoder for when Content-Length is not known.
+ ///
+ /// Appliction decides when to stop writing.
+ Eof,
+}
+
+impl Encoder {
+
+ pub fn eof() -> Encoder {
+ Encoder {
+ kind: Kind::Eof,
+ }
+ }
+
+ pub fn chunked() -> Encoder {
+ Encoder {
+ kind: Kind::Chunked(false),
+ }
+ }
+
+ pub fn length(len: u64) -> Encoder {
+ Encoder {
+ kind: Kind::Length(len),
+ }
+ }
+
+ /// Encode message. Return `EOF` state of encoder
+ pub fn encode(&mut self, dst: &mut BytesMut, msg: &[u8]) -> bool {
+ match self.kind {
+ Kind::Eof => {
+ dst.extend(msg);
+ msg.is_empty()
+ },
+ Kind::Chunked(ref mut eof) => {
+ if *eof {
+ return true;
+ }
+
+ if msg.is_empty() {
+ *eof = true;
+ dst.extend(b"0\r\n\r\n");
+ } else {
+ write!(dst, "{:X}\r\n", msg.len()).unwrap();
+ dst.extend(msg);
+ dst.extend(b"\r\n");
+ }
+ *eof
+ },
+ Kind::Length(ref mut remaining) => {
+ if msg.is_empty() {
+ return *remaining == 0
+ }
+ let max = cmp::min(*remaining, msg.len() as u64);
+ trace!("sized write = {}", max);
+ dst.extend(msg[..max as usize].as_ref());
+
+ *remaining -= max as u64;
+ trace!("encoded {} bytes, remaining = {}", max, remaining);
+ *remaining == 0
+ },
+ }
+ }
+
+ /// Encode eof. Return `EOF` state of encoder
+ pub fn encode_eof(&mut self, dst: &mut BytesMut) -> bool {
+ match self.kind {
+ Kind::Eof => true,
+ Kind::Chunked(ref mut eof) => {
+ if *eof {
+ return true;
+ }
+
+ *eof = true;
+ dst.extend(b"0\r\n\r\n");
+ true
+ },
+ Kind::Length(ref mut remaining) => {
+ return *remaining == 0
+ },
+ }
+ }
+}
diff --git a/src/h2.rs b/src/h2.rs
index 21ab1e191..ecf429c0a 100644
--- a/src/h2.rs
+++ b/src/h2.rs
@@ -1,9 +1,147 @@
-use std::{io, cmp};
+use std::{io, cmp, mem};
+use std::rc::Rc;
use std::io::{Read, Write};
+use std::cell::UnsafeCell;
+use std::collections::VecDeque;
+
+use http::request::Parts;
+use http2::{RecvStream};
+use http2::server::{Server, Handshake, Respond};
use bytes::{Buf, Bytes};
-use futures::Poll;
+use futures::{Async, Poll, Future, Stream};
use tokio_io::{AsyncRead, AsyncWrite};
+use task::Task;
+use server::HttpHandler;
+use httpcodes::HTTPNotFound;
+use httprequest::HttpRequest;
+use payload::{Payload, PayloadError, PayloadSender};
+
+
+pub(crate) struct Http2
+ where T: AsyncRead + AsyncWrite + 'static, A: 'static, H: 'static
+{
+ router: Rc>,
+ #[allow(dead_code)]
+ addr: A,
+ state: State>,
+ error: bool,
+ tasks: VecDeque,
+}
+
+enum State {
+ Handshake(Handshake),
+ Server(Server),
+ Empty,
+}
+
+impl Http2
+ where T: AsyncRead + AsyncWrite + 'static,
+ A: 'static,
+ H: HttpHandler + 'static
+{
+ pub fn new(stream: T, addr: A, router: Rc>, buf: Bytes) -> Self {
+ Http2{ router: router,
+ addr: addr,
+ error: false,
+ tasks: VecDeque::new(),
+ state: State::Handshake(
+ Server::handshake(IoWrapper{unread: Some(buf), inner: stream})) }
+ }
+
+ pub fn poll(&mut self) -> Poll<(), ()> {
+ // handshake
+ self.state = if let State::Handshake(ref mut handshake) = self.state {
+ match handshake.poll() {
+ Ok(Async::Ready(srv)) => {
+ State::Server(srv)
+ },
+ Ok(Async::NotReady) =>
+ return Ok(Async::NotReady),
+ Err(err) => {
+ trace!("Error handling connection: {}", err);
+ return Err(())
+ }
+ }
+ } else {
+ mem::replace(&mut self.state, State::Empty)
+ };
+
+ // get request
+ let poll = if let State::Server(ref mut server) = self.state {
+ server.poll()
+ } else {
+ unreachable!("Http2::poll() state was not advanced completely!")
+ };
+
+ match poll {
+ Ok(Async::NotReady) => {
+ // Ok(Async::NotReady);
+ ()
+ }
+ Err(err) => {
+ trace!("Connection error: {}", err);
+ self.error = true;
+ },
+ Ok(Async::Ready(None)) => {
+
+ },
+ Ok(Async::Ready(Some((req, resp)))) => {
+ let (parts, body) = req.into_parts();
+ let entry = Entry::new(parts, body, resp, &self.router);
+ }
+ }
+
+ Ok(Async::Ready(()))
+ }
+}
+
+struct Entry {
+ task: Task,
+ req: UnsafeCell,
+ payload: PayloadSender,
+ recv: RecvStream,
+ respond: Respond,
+ eof: bool,
+ error: bool,
+ finished: bool,
+}
+
+impl Entry {
+ fn new(parts: Parts,
+ recv: RecvStream,
+ resp: Respond,
+ router: &Rc>) -> Entry
+ where H: HttpHandler + 'static
+ {
+ let path = parts.uri.path().to_owned();
+ let query = parts.uri.query().unwrap_or("").to_owned();
+
+ println!("PARTS: {:?}", parts);
+ let mut req = HttpRequest::new(
+ parts.method, path, parts.version, parts.headers, query);
+ let (psender, payload) = Payload::new(false);
+
+ // start request processing
+ let mut task = None;
+ for h in router.iter() {
+ if req.path().starts_with(h.prefix()) {
+ task = Some(h.handle(&mut req, payload));
+ break
+ }
+ }
+ println!("REQ: {:?}", req);
+
+ Entry {task: task.unwrap_or_else(|| Task::reply(HTTPNotFound)),
+ req: UnsafeCell::new(req),
+ payload: psender,
+ recv: recv,
+ respond: resp,
+ eof: false,
+ error: false,
+ finished: false}
+ }
+}
struct IoWrapper {
unread: Option,
@@ -14,9 +152,9 @@ impl Read for IoWrapper {
fn read(&mut self, buf: &mut [u8]) -> io::Result {
if let Some(mut bytes) = self.unread.take() {
let size = cmp::min(buf.len(), bytes.len());
- buf.copy_from_slice(&bytes[..size]);
- bytes.split_to(size);
- if !bytes.is_empty() {
+ buf[..size].copy_from_slice(&bytes[..size]);
+ if bytes.len() > size {
+ bytes.split_to(size);
self.unread = Some(bytes);
}
Ok(size)
diff --git a/src/lib.rs b/src/lib.rs
index 838307950..9dc538124 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -20,6 +20,7 @@ extern crate mime_guess;
extern crate url;
extern crate percent_encoding;
extern crate actix;
+extern crate h2 as http2;
#[cfg(feature="tls")]
extern crate native_tls;
@@ -45,6 +46,7 @@ mod wsframe;
mod wsproto;
mod h1;
mod h2;
+mod h1writer;
pub mod ws;
pub mod dev;
diff --git a/src/server.rs b/src/server.rs
index 3d852e50d..55f85b6d3 100644
--- a/src/server.rs
+++ b/src/server.rs
@@ -1,13 +1,9 @@
-use std::{io, net};
+use std::{io, net, mem};
use std::rc::Rc;
-use std::cell::UnsafeCell;
-use std::time::Duration;
use std::marker::PhantomData;
-use std::collections::VecDeque;
use actix::dev::*;
use futures::{Future, Poll, Async, Stream};
-use tokio_core::reactor::Timeout;
use tokio_core::net::{TcpListener, TcpStream};
use tokio_io::{AsyncRead, AsyncWrite};
@@ -17,9 +13,9 @@ use native_tls::TlsAcceptor;
use tokio_tls::{TlsStream, TlsAcceptorExt};
use h1;
+use h2;
use task::Task;
use payload::Payload;
-use httpcodes::HTTPNotFound;
use httprequest::HttpRequest;
/// Low level http request handler
@@ -153,11 +149,10 @@ impl HttpServer, net::SocketAddr, H> {
println!("SSL");
TlsAcceptorExt::accept_async(acc.as_ref(), stream)
.map(move |t| {
- println!("connected {:?} {:?}", t, addr);
IoStream(t, addr)
})
.map_err(|err| {
- println!("ERR: {:?}", err);
+ trace!("Error during handling tls connection: {}", err);
io::Error::new(io::ErrorKind::Other, err)
})
}));
@@ -195,42 +190,25 @@ impl Handler, io::Error> for HttpServer
-> Response>
{
Arbiter::handle().spawn(
- HttpChannel{router: Rc::clone(&self.h),
- addr: msg.1,
- stream: msg.0,
- reader: h1::Reader::new(),
- error: false,
- items: VecDeque::new(),
- inactive: VecDeque::new(),
- keepalive: true,
- keepalive_timer: None,
+ HttpChannel{
+ proto: Protocol::H1(h1::Http1::new(msg.0, msg.1, Rc::clone(&self.h)))
});
Self::empty()
}
}
-struct Entry {
- task: Task,
- req: UnsafeCell,
- eof: bool,
- error: bool,
- finished: bool,
+enum Protocol
+ where T: AsyncRead + AsyncWrite + 'static, A: 'static, H: 'static
+{
+ H1(h1::Http1),
+ H2(h2::Http2),
+ None,
}
-const KEEPALIVE_PERIOD: u64 = 15; // seconds
-const MAX_PIPELINED_MESSAGES: usize = 16;
-
-pub struct HttpChannel {
- router: Rc>,
- #[allow(dead_code)]
- addr: A,
- stream: T,
- reader: h1::Reader,
- error: bool,
- items: VecDeque,
- inactive: VecDeque,
- keepalive: bool,
- keepalive_timer: Option,
+pub struct HttpChannel
+ where T: AsyncRead + AsyncWrite + 'static, A: 'static, H: 'static
+{
+ proto: Protocol,
}
/*impl Drop for HttpChannel {
@@ -240,193 +218,45 @@ pub struct HttpChannel {
}*/
impl Actor for HttpChannel
- where T: AsyncRead + AsyncWrite + 'static,
- A: 'static,
- H: HttpHandler + 'static
+ where T: AsyncRead + AsyncWrite + 'static, A: 'static, H: HttpHandler + 'static
{
type Context = Context;
}
impl Future for HttpChannel
- where T: AsyncRead + AsyncWrite + 'static,
- A: 'static,
- H: HttpHandler + 'static
+ where T: AsyncRead + AsyncWrite + 'static, A: 'static, H: HttpHandler + 'static
{
type Item = ();
type Error = ();
fn poll(&mut self) -> Poll {
- // keep-alive timer
- if let Some(ref mut timeout) = self.keepalive_timer {
- match timeout.poll() {
- Ok(Async::Ready(_)) =>
- return Ok(Async::Ready(())),
- Ok(Async::NotReady) => (),
- Err(_) => unreachable!(),
+ match self.proto {
+ Protocol::H1(ref mut h1) => {
+ match h1.poll() {
+ Ok(Async::Ready(h1::Http1Result::Done)) =>
+ return Ok(Async::Ready(())),
+ Ok(Async::Ready(h1::Http1Result::Upgrade)) => (),
+ Ok(Async::NotReady) =>
+ return Ok(Async::NotReady),
+ Err(_) =>
+ return Err(()),
+ }
}
+ Protocol::H2(ref mut h2) =>
+ return h2.poll(),
+ Protocol::None =>
+ unreachable!()
}
- loop {
- let mut not_ready = true;
-
- // check in-flight messages
- let mut idx = 0;
- while idx < self.items.len() {
- if idx == 0 {
- if self.items[idx].error {
- return Err(())
- }
-
- // this is anoying
- let req = unsafe {self.items[idx].req.get().as_mut().unwrap()};
- match self.items[idx].task.poll_io(&mut self.stream, req)
- {
- Ok(Async::Ready(ready)) => {
- not_ready = false;
- let mut item = self.items.pop_front().unwrap();
-
- // overide keep-alive state
- if self.keepalive {
- self.keepalive = item.task.keepalive();
- }
- if !ready {
- item.eof = true;
- self.inactive.push_back(item);
- }
-
- // no keep-alive
- if ready && !self.keepalive &&
- self.items.is_empty() && self.inactive.is_empty()
- {
- return Ok(Async::Ready(()))
- }
- continue
- },
- Ok(Async::NotReady) => (),
- Err(_) => {
- // it is not possible to recover from error
- // during task handling, so just drop connection
- return Err(())
- }
- }
- } else if !self.items[idx].finished && !self.items[idx].error {
- match self.items[idx].task.poll() {
- Ok(Async::NotReady) => (),
- Ok(Async::Ready(_)) => {
- not_ready = false;
- self.items[idx].finished = true;
- },
- Err(_) =>
- self.items[idx].error = true,
- }
- }
- idx += 1;
- }
-
- // check inactive tasks
- let mut idx = 0;
- while idx < self.inactive.len() {
- if idx == 0 && self.inactive[idx].error && self.inactive[idx].finished {
- let _ = self.inactive.pop_front();
- continue
- }
-
- if !self.inactive[idx].finished && !self.inactive[idx].error {
- match self.inactive[idx].task.poll() {
- Ok(Async::NotReady) => (),
- Ok(Async::Ready(_)) => {
- not_ready = false;
- self.inactive[idx].finished = true
- }
- Err(_) =>
- self.inactive[idx].error = true,
- }
- }
- idx += 1;
- }
-
- // read incoming data
- if !self.error && self.items.len() < MAX_PIPELINED_MESSAGES {
- match self.reader.parse(&mut self.stream) {
- Ok(Async::Ready((mut req, payload))) => {
- not_ready = false;
-
- // stop keepalive timer
- self.keepalive_timer.take();
-
- // start request processing
- let mut task = None;
- for h in self.router.iter() {
- if req.path().starts_with(h.prefix()) {
- task = Some(h.handle(&mut req, payload));
- break
- }
- }
-
- self.items.push_back(
- Entry {task: task.unwrap_or_else(|| Task::reply(HTTPNotFound)),
- req: UnsafeCell::new(req),
- eof: false,
- error: false,
- finished: false});
- }
- Err(err) => {
- // notify all tasks
- not_ready = false;
- for entry in &mut self.items {
- entry.task.disconnected()
- }
-
- // kill keepalive
- self.keepalive = false;
- self.keepalive_timer.take();
-
- // on parse error, stop reading stream but
- // tasks need to be completed
- self.error = true;
-
- if self.items.is_empty() {
- if let h1::ReaderError::Error(err) = err {
- self.items.push_back(
- Entry {task: Task::reply(err),
- req: UnsafeCell::new(HttpRequest::for_error()),
- eof: false,
- error: false,
- finished: false});
- }
- }
- }
- Ok(Async::NotReady) => {
- // start keep-alive timer, this is also slow request timeout
- if self.items.is_empty() && self.inactive.is_empty() {
- if self.keepalive {
- if self.keepalive_timer.is_none() {
- trace!("Start keep-alive timer");
- let mut timeout = Timeout::new(
- Duration::new(KEEPALIVE_PERIOD, 0),
- Arbiter::handle()).unwrap();
- // register timeout
- let _ = timeout.poll();
- self.keepalive_timer = Some(timeout);
- }
- } else {
- // keep-alive disable, drop connection
- return Ok(Async::Ready(()))
- }
- }
- return Ok(Async::NotReady)
- }
- }
- }
-
- // check for parse error
- if self.items.is_empty() && self.inactive.is_empty() && self.error {
- return Ok(Async::Ready(()))
- }
-
- if not_ready {
- return Ok(Async::NotReady)
+ // upgrade to h2
+ let proto = mem::replace(&mut self.proto, Protocol::None);
+ match proto {
+ Protocol::H1(h1) => {
+ let (stream, addr, router, buf) = h1.into_inner();
+ self.proto = Protocol::H2(h2::Http2::new(stream, addr, router, buf));
+ return self.poll()
}
+ _ => unreachable!()
}
}
}
diff --git a/src/task.rs b/src/task.rs
index ec3f6bd59..073cde62b 100644
--- a/src/task.rs
+++ b/src/task.rs
@@ -1,27 +1,18 @@
-use std::{mem, cmp, io};
+use std::{mem, io};
use std::rc::Rc;
-use std::fmt::Write;
use std::cell::RefCell;
use std::collections::VecDeque;
-use http::{StatusCode, Version};
-use http::header::{HeaderValue,
- CONNECTION, CONTENT_TYPE, CONTENT_LENGTH, TRANSFER_ENCODING, DATE};
-use bytes::BytesMut;
use futures::{Async, Future, Poll, Stream};
use futures::task::{Task as FutureTask, current as current_task};
-use tokio_io::AsyncWrite;
-use date;
-use body::Body;
+use h1writer::{Writer, WriterState};
use route::Frame;
use application::Middleware;
use httprequest::HttpRequest;
use httpresponse::HttpResponse;
type FrameStream = Stream
- ;
-const AVERAGE_HEADER_SIZE: usize = 30; // totally scientific
-const MAX_WRITE_BUFFER_SIZE: usize = 65_536; // max buffer size 64k
#[derive(PartialEq, Debug)]
enum TaskRunningState {
@@ -34,6 +25,16 @@ impl TaskRunningState {
fn is_done(&self) -> bool {
*self == TaskRunningState::Done
}
+ fn pause(&mut self) {
+ if *self != TaskRunningState::Done {
+ *self = TaskRunningState::Paused
+ }
+ }
+ fn resume(&mut self) {
+ if *self != TaskRunningState::Done {
+ *self = TaskRunningState::Running
+ }
+ }
}
#[derive(PartialEq, Debug)]
@@ -100,17 +101,12 @@ impl Future for DrainFut {
}
}
-
pub struct Task {
state: TaskRunningState,
iostate: TaskIOState,
frames: VecDeque,
stream: TaskStream,
- encoder: Encoder,
- buffer: BytesMut,
drain: Vec>>,
- upgrade: bool,
- keepalive: bool,
prepared: Option,
disconnected: bool,
middlewares: Option>>>,
@@ -129,10 +125,6 @@ impl Task {
frames: frames,
drain: Vec::new(),
stream: TaskStream::None,
- encoder: Encoder::length(0),
- buffer: BytesMut::new(),
- upgrade: false,
- keepalive: false,
prepared: None,
disconnected: false,
middlewares: None,
@@ -147,11 +139,7 @@ impl Task {
iostate: TaskIOState::ReadingMessage,
frames: VecDeque::new(),
stream: TaskStream::Stream(Box::new(stream)),
- encoder: Encoder::length(0),
- buffer: BytesMut::new(),
drain: Vec::new(),
- upgrade: false,
- keepalive: false,
prepared: None,
disconnected: false,
middlewares: None,
@@ -165,158 +153,26 @@ impl Task {
iostate: TaskIOState::ReadingMessage,
frames: VecDeque::new(),
stream: TaskStream::Context(Box::new(ctx)),
- encoder: Encoder::length(0),
- buffer: BytesMut::new(),
drain: Vec::new(),
- upgrade: false,
- keepalive: false,
prepared: None,
disconnected: false,
middlewares: None,
}
}
- pub(crate) fn keepalive(&self) -> bool {
- self.keepalive && !self.upgrade
- }
-
pub(crate) fn set_middlewares(&mut self, middlewares: Rc>>) {
self.middlewares = Some(middlewares);
}
pub(crate) fn disconnected(&mut self) {
- let len = self.buffer.len();
- self.buffer.split_to(len);
self.disconnected = true;
if let TaskStream::Context(ref mut ctx) = self.stream {
ctx.disconnected();
}
}
- fn prepare(&mut self, req: &mut HttpRequest, msg: HttpResponse)
- {
- trace!("Prepare message status={:?}", msg.status);
-
- // run middlewares
- let mut msg = if let Some(middlewares) = self.middlewares.take() {
- let mut msg = msg;
- for middleware in middlewares.iter() {
- msg = middleware.response(req, msg);
- }
- self.middlewares = Some(middlewares);
- msg
- } else {
- msg
- };
-
- // prepare task
- let mut extra = 0;
- let body = msg.replace_body(Body::Empty);
- let version = msg.version().unwrap_or_else(|| req.version());
- self.keepalive = msg.keep_alive().unwrap_or_else(|| req.keep_alive());
-
- match body {
- Body::Empty => {
- if msg.chunked() {
- error!("Chunked transfer is enabled but body is set to Empty");
- }
- msg.headers.insert(CONTENT_LENGTH, HeaderValue::from_static("0"));
- msg.headers.remove(TRANSFER_ENCODING);
- self.encoder = Encoder::length(0);
- },
- Body::Length(n) => {
- if msg.chunked() {
- error!("Chunked transfer is enabled but body with specific length is specified");
- }
- msg.headers.insert(
- CONTENT_LENGTH,
- HeaderValue::from_str(format!("{}", n).as_str()).unwrap());
- msg.headers.remove(TRANSFER_ENCODING);
- self.encoder = Encoder::length(n);
- },
- Body::Binary(ref bytes) => {
- extra = bytes.len();
- msg.headers.insert(
- CONTENT_LENGTH,
- HeaderValue::from_str(format!("{}", bytes.len()).as_str()).unwrap());
- msg.headers.remove(TRANSFER_ENCODING);
- self.encoder = Encoder::length(0);
- }
- Body::Streaming => {
- if msg.chunked() {
- if version < Version::HTTP_11 {
- error!("Chunked transfer encoding is forbidden for {:?}", version);
- }
- msg.headers.remove(CONTENT_LENGTH);
- msg.headers.insert(TRANSFER_ENCODING, HeaderValue::from_static("chunked"));
- self.encoder = Encoder::chunked();
- } else {
- self.encoder = Encoder::eof();
- }
- }
- Body::Upgrade => {
- msg.headers.insert(CONNECTION, HeaderValue::from_static("upgrade"));
- self.encoder = Encoder::eof();
- }
- }
-
- // Connection upgrade
- if msg.upgrade() {
- msg.headers.insert(CONNECTION, HeaderValue::from_static("upgrade"));
- }
- // keep-alive
- else if self.keepalive {
- if version < Version::HTTP_11 {
- msg.headers.insert(CONNECTION, HeaderValue::from_static("keep-alive"));
- }
- } else if version >= Version::HTTP_11 {
- msg.headers.insert(CONNECTION, HeaderValue::from_static("close"));
- }
-
- // render message
- let init_cap = 100 + msg.headers.len() * AVERAGE_HEADER_SIZE + extra;
- self.buffer.reserve(init_cap);
-
- if version == Version::HTTP_11 && msg.status == StatusCode::OK {
- self.buffer.extend(b"HTTP/1.1 200 OK\r\n");
- } else {
- let _ = write!(self.buffer, "{:?} {}\r\n", version, msg.status);
- }
- for (key, value) in &msg.headers {
- let t: &[u8] = key.as_ref();
- self.buffer.extend(t);
- self.buffer.extend(b": ");
- self.buffer.extend(value.as_ref());
- self.buffer.extend(b"\r\n");
- }
-
- // using http::h1::date is quite a lot faster than generating
- // a unique Date header each time like req/s goes up about 10%
- if !msg.headers.contains_key(DATE) {
- self.buffer.reserve(date::DATE_VALUE_LENGTH + 8);
- self.buffer.extend(b"Date: ");
- date::extend(&mut self.buffer);
- self.buffer.extend(b"\r\n");
- }
-
- // default content-type
- if !msg.headers.contains_key(CONTENT_TYPE) {
- self.buffer.extend(b"ContentType: application/octet-stream\r\n".as_ref());
- }
-
- self.buffer.extend(b"\r\n");
-
- if let Body::Binary(ref bytes) = body {
- self.buffer.extend_from_slice(bytes.as_ref());
- self.prepared = Some(msg);
- return
- }
- msg.replace_body(body);
- self.prepared = Some(msg);
- }
-
pub(crate) fn poll_io(&mut self, io: &mut T, req: &mut HttpRequest) -> Poll
- where T: AsyncWrite
+ where T: Writer
{
trace!("POLL-IO frames:{:?}", self.frames.len());
// response is completed
@@ -328,87 +184,76 @@ impl Task {
match self.poll() {
Ok(Async::Ready(_)) => {
self.state = TaskRunningState::Done;
- }
+ },
Ok(Async::NotReady) => (),
Err(_) => return Err(())
}
}
// use exiting frames
- while let Some(frame) = self.frames.pop_front() {
- trace!("IO Frame: {:?}", frame);
- match frame {
- Frame::Message(response) => {
- if !self.disconnected {
- self.prepare(req, response);
+ if self.state != TaskRunningState::Paused {
+ while let Some(frame) = self.frames.pop_front() {
+ trace!("IO Frame: {:?}", frame);
+ let res = match frame {
+ Frame::Message(mut response) => {
+ trace!("Prepare message status={:?}", response.status);
+
+ // run middlewares
+ let mut response =
+ if let Some(middlewares) = self.middlewares.take() {
+ let mut response = response;
+ for middleware in middlewares.iter() {
+ response = middleware.response(req, response);
+ }
+ self.middlewares = Some(middlewares);
+ response
+ } else {
+ response
+ };
+
+ let result = io.start(req, &mut response);
+ self.prepared = Some(response);
+ result
}
- }
- Frame::Payload(Some(chunk)) => {
- if !self.disconnected {
- if self.prepared.is_some() {
- // TODO: add warning, write after EOF
- self.encoder.encode(&mut self.buffer, chunk.as_ref());
- } else {
- // might be response for EXCEPT
- self.buffer.extend_from_slice(chunk.as_ref())
- }
+ Frame::Payload(Some(chunk)) => {
+ io.write(chunk.as_ref())
+ },
+ Frame::Payload(None) => {
+ self.iostate = TaskIOState::Done;
+ io.write_eof()
+ },
+ Frame::Drain(fut) => {
+ self.drain.push(fut);
+ break
}
- },
- Frame::Payload(None) => {
- if !self.disconnected &&
- !self.encoder.encode(&mut self.buffer, [].as_ref())
- {
- // TODO: add error "not eof""
- debug!("last payload item, but it is not EOF ");
- return Err(())
+ };
+
+ match res {
+ Ok(WriterState::Pause) => {
+ self.state.pause();
+ break
}
- break
- },
- Frame::Drain(fut) => {
- self.drain.push(fut);
- break
+ Ok(WriterState::Done) => self.state.resume(),
+ Err(_) => return Err(())
}
}
}
}
- // write bytes to TcpStream
- if !self.disconnected {
- while !self.buffer.is_empty() {
- match io.write(self.buffer.as_ref()) {
- Ok(n) => {
- self.buffer.split_to(n);
- },
- Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
- break
- }
- Err(_) => return Err(()),
- }
+ // flush io
+ match io.poll_complete() {
+ Ok(Async::Ready(())) => self.state.resume(),
+ Ok(Async::NotReady) => {
+ return Ok(Async::NotReady)
}
- }
-
- // should pause task
- if self.state != TaskRunningState::Done {
- if self.buffer.len() > MAX_WRITE_BUFFER_SIZE {
- self.state = TaskRunningState::Paused;
- } else if self.state == TaskRunningState::Paused {
- self.state = TaskRunningState::Running;
+ Err(err) => {
+ trace!("Error sending data: {}", err);
+ return Err(())
}
- } else {
- // at this point we wont get any more Frames
- self.iostate = TaskIOState::Done;
}
// drain
- if self.buffer.is_empty() && !self.drain.is_empty() {
- match io.flush() {
- Ok(_) => (),
- Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
- return Ok(Async::NotReady)
- }
- Err(_) => return Err(()),
- }
-
+ if !self.drain.is_empty() {
for fut in &mut self.drain {
fut.borrow_mut().set()
}
@@ -416,7 +261,7 @@ impl Task {
}
// response is completed
- if (self.buffer.is_empty() || self.disconnected) && self.iostate.is_done() {
+ if self.iostate.is_done() {
// run middlewares
if let Some(ref mut resp) = self.prepared {
if let Some(middlewares) = self.middlewares.take() {
@@ -443,8 +288,8 @@ impl Task {
error!("Non expected frame {:?}", frame);
return Err(())
}
- self.upgrade = msg.upgrade();
- if self.upgrade || msg.body().has_body() {
+ let upgrade = msg.upgrade();
+ if upgrade || msg.body().has_body() {
self.iostate = TaskIOState::ReadingPayload;
} else {
self.iostate = TaskIOState::Done;
@@ -489,89 +334,3 @@ impl Future for Task {
result
}
}
-
-/// Encoders to handle different Transfer-Encodings.
-#[derive(Debug, Clone)]
-struct Encoder {
- kind: Kind,
-}
-
-#[derive(Debug, PartialEq, Clone)]
-enum Kind {
- /// An Encoder for when Transfer-Encoding includes `chunked`.
- Chunked(bool),
- /// An Encoder for when Content-Length is set.
- ///
- /// Enforces that the body is not longer than the Content-Length header.
- Length(u64),
- /// An Encoder for when Content-Length is not known.
- ///
- /// Appliction decides when to stop writing.
- Eof,
-}
-
-impl Encoder {
-
- pub fn eof() -> Encoder {
- Encoder {
- kind: Kind::Eof,
- }
- }
-
- pub fn chunked() -> Encoder {
- Encoder {
- kind: Kind::Chunked(false),
- }
- }
-
- pub fn length(len: u64) -> Encoder {
- Encoder {
- kind: Kind::Length(len),
- }
- }
-
- /*pub fn is_eof(&self) -> bool {
- match self.kind {
- Kind::Eof | Kind::Length(0) => true,
- Kind::Chunked(eof) => eof,
- _ => false,
- }
- }*/
-
- /// Encode message. Return `EOF` state of encoder
- pub fn encode(&mut self, dst: &mut BytesMut, msg: &[u8]) -> bool {
- match self.kind {
- Kind::Eof => {
- dst.extend(msg);
- msg.is_empty()
- },
- Kind::Chunked(ref mut eof) => {
- if *eof {
- return true;
- }
-
- if msg.is_empty() {
- *eof = true;
- dst.extend(b"0\r\n\r\n");
- } else {
- write!(dst, "{:X}\r\n", msg.len()).unwrap();
- dst.extend(msg);
- dst.extend(b"\r\n");
- }
- *eof
- },
- Kind::Length(ref mut remaining) => {
- if msg.is_empty() {
- return *remaining == 0
- }
- let max = cmp::min(*remaining, msg.len() as u64);
- trace!("sized write = {}", max);
- dst.extend(msg[..max as usize].as_ref());
-
- *remaining -= max as u64;
- trace!("encoded {} bytes, remaining = {}", max, remaining);
- *remaining == 0
- },
- }
- }
-}