diff --git a/src/server/error.rs b/src/server/error.rs index eb3e88478..70f100998 100644 --- a/src/server/error.rs +++ b/src/server/error.rs @@ -44,6 +44,10 @@ pub enum HttpDispatchError { #[fail(display = "HTTP2 error: {}", _0)] Http2(http2::Error), + /// Payload is not consumed + #[fail(display = "Task is completed but request's payload is not consumed")] + PayloadIsNotConsumed, + /// Malformed request #[fail(display = "Malformed request")] MalformedRequest, diff --git a/src/server/h1.rs b/src/server/h1.rs index 205be9494..cd9134275 100644 --- a/src/server/h1.rs +++ b/src/server/h1.rs @@ -30,7 +30,7 @@ bitflags! { const READ_DISCONNECTED = 0b0001_0000; const WRITE_DISCONNECTED = 0b0010_0000; const POLLED = 0b0100_0000; - + const FLUSHED = 0b1000_0000; } } @@ -99,9 +99,9 @@ where }; let flags = if is_eof { - Flags::READ_DISCONNECTED + Flags::READ_DISCONNECTED | Flags::FLUSHED } else if settings.keep_alive_enabled() { - Flags::KEEPALIVE | Flags::KEEPALIVE_ENABLED + Flags::KEEPALIVE | Flags::KEEPALIVE_ENABLED | Flags::FLUSHED } else { Flags::empty() }; @@ -130,7 +130,7 @@ where } let mut disp = Http1Dispatcher { - flags: Flags::STARTED | Flags::READ_DISCONNECTED, + flags: Flags::STARTED | Flags::READ_DISCONNECTED | Flags::FLUSHED, stream: H1Writer::new(stream, settings.clone()), decoder: H1Decoder::new(), payload: None, @@ -177,7 +177,8 @@ where } if !checked || self.tasks.is_empty() { - self.flags.insert(Flags::WRITE_DISCONNECTED); + self.flags + .insert(Flags::WRITE_DISCONNECTED | Flags::FLUSHED); self.stream.disconnected(); // notify all tasks @@ -205,54 +206,70 @@ where // shutdown if self.flags.contains(Flags::SHUTDOWN) { - if self - .flags - .intersects(Flags::READ_DISCONNECTED | Flags::WRITE_DISCONNECTED) - { + if self.flags.intersects(Flags::WRITE_DISCONNECTED) { return Ok(Async::Ready(())); } - match self.stream.poll_completed(true) { - Ok(Async::NotReady) => return Ok(Async::NotReady), - Ok(Async::Ready(_)) => return Ok(Async::Ready(())), - Err(err) => { - debug!("Error sending data: {}", err); - return Err(err.into()); - } - } + return self.poll_flush(true); } - self.poll_io()?; - + // process incoming requests if !self.flags.contains(Flags::WRITE_DISCONNECTED) { - match self.poll_handler()? { - Async::Ready(true) => self.poll(), - Async::Ready(false) => { - self.flags.insert(Flags::SHUTDOWN); - self.poll() + self.poll_handler()?; + + // flush stream + self.poll_flush(false)?; + + // deal with keep-alive and stream eof (client-side write shutdown) + if self.tasks.is_empty() && self.flags.intersects(Flags::FLUSHED) { + // handle stream eof + if self + .flags + .intersects(Flags::READ_DISCONNECTED | Flags::WRITE_DISCONNECTED) + { + return Ok(Async::Ready(())); } - Async::NotReady => { - // deal with keep-alive and steam eof (client-side write shutdown) - if self.tasks.is_empty() { - // handle stream eof - if self.flags.intersects( - Flags::READ_DISCONNECTED | Flags::WRITE_DISCONNECTED, - ) { - return Ok(Async::Ready(())); - } - // no keep-alive - if self.flags.contains(Flags::STARTED) - && (!self.flags.contains(Flags::KEEPALIVE_ENABLED) - || !self.flags.contains(Flags::KEEPALIVE)) - { - self.flags.insert(Flags::SHUTDOWN); - return self.poll(); - } + // no keep-alive + if self.flags.contains(Flags::STARTED) + && (!self.flags.contains(Flags::KEEPALIVE_ENABLED) + || !self.flags.contains(Flags::KEEPALIVE)) + { + self.flags.insert(Flags::SHUTDOWN); + return self.poll(); + } + } + Ok(Async::NotReady) + } else if let Some(err) = self.error.take() { + Err(err) + } else { + Ok(Async::Ready(())) + } + } + + /// Flush stream + fn poll_flush(&mut self, shutdown: bool) -> Poll<(), HttpDispatchError> { + if shutdown || self.flags.contains(Flags::STARTED) { + match self.stream.poll_completed(shutdown) { + Ok(Async::NotReady) => { + // mark stream + if !self.stream.flushed() { + self.flags.remove(Flags::FLUSHED); } Ok(Async::NotReady) } + Err(err) => { + debug!("Error sending data: {}", err); + self.client_disconnected(false); + return Err(err.into()); + } + Ok(Async::Ready(_)) => { + // if payload is not consumed we can not use connection + if self.payload.is_some() && self.tasks.is_empty() { + return Err(HttpDispatchError::PayloadIsNotConsumed); + } + self.flags.insert(Flags::FLUSHED); + Ok(Async::Ready(())) + } } - } else if let Some(err) = self.error.take() { - Err(err) } else { Ok(Async::Ready(())) } @@ -317,20 +334,23 @@ where } #[inline] - /// read data from stream - pub(self) fn poll_io(&mut self) -> Result<(), HttpDispatchError> { + /// read data from the stream + pub(self) fn poll_io(&mut self) -> Result { if !self.flags.contains(Flags::POLLED) { - self.parse()?; + let updated = self.parse()?; self.flags.insert(Flags::POLLED); - return Ok(()); + return Ok(updated); } // read io from socket + let mut updated = false; if self.can_read() && self.tasks.len() < MAX_PIPELINED_MESSAGES { match self.stream.get_mut().read_available(&mut self.buf) { Ok(Async::Ready((read_some, disconnected))) => { if read_some { - self.parse()?; + if self.parse()? { + updated = true; + } } if disconnected { self.client_disconnected(true); @@ -343,13 +363,14 @@ where } } } - Ok(()) + Ok(updated) } - pub(self) fn poll_handler(&mut self) -> Poll { - let retry = self.can_read(); + pub(self) fn poll_handler(&mut self) -> Result<(), HttpDispatchError> { + self.poll_io()?; + let mut retry = self.can_read(); - // process first pipelined response, only one task can do io operation in http/1 + // process first pipelined response, only first task can do io operation in http/1 while !self.tasks.is_empty() { match self.tasks[0].poll_io(&mut self.stream) { Ok(Async::Ready(ready)) => { @@ -375,9 +396,12 @@ where } // if read-backpressure is enabled and we consumed some data. - // we may read more data + // we may read more dataand retry if !retry && self.can_read() { - return Ok(Async::Ready(true)); + if self.poll_io()? { + retry = self.can_read(); + continue; + } } break; } @@ -431,25 +455,7 @@ where } } - // flush stream - if self.flags.contains(Flags::STARTED) { - match self.stream.poll_completed(false) { - Ok(Async::NotReady) => return Ok(Async::NotReady), - Err(err) => { - debug!("Error sending data: {}", err); - self.client_disconnected(false); - return Err(err.into()); - } - Ok(Async::Ready(_)) => { - // if payload is not consumed we can not use connection - if self.payload.is_some() && self.tasks.is_empty() { - return Ok(Async::Ready(false)); - } - } - } - } - - Ok(Async::NotReady) + Ok(()) } fn push_response_entry(&mut self, status: StatusCode) { @@ -457,7 +463,7 @@ where .push_back(Entry::Error(ServerError::err(Version::HTTP_11, status))); } - pub(self) fn parse(&mut self) -> Result<(), HttpDispatchError> { + pub(self) fn parse(&mut self) -> Result { let mut updated = false; 'outer: loop { @@ -524,7 +530,7 @@ where payload.feed_data(chunk); } else { error!("Internal server error: unexpected payload chunk"); - self.flags.insert(Flags::READ_DISCONNECTED); + self.flags.insert(Flags::READ_DISCONNECTED | Flags::STARTED); self.push_response_entry(StatusCode::INTERNAL_SERVER_ERROR); self.error = Some(HttpDispatchError::InternalError); break; @@ -536,7 +542,7 @@ where payload.feed_eof(); } else { error!("Internal server error: unexpected eof"); - self.flags.insert(Flags::READ_DISCONNECTED); + self.flags.insert(Flags::READ_DISCONNECTED | Flags::STARTED); self.push_response_entry(StatusCode::INTERNAL_SERVER_ERROR); self.error = Some(HttpDispatchError::InternalError); break; @@ -559,7 +565,7 @@ where // Malformed requests should be responded with 400 self.push_response_entry(StatusCode::BAD_REQUEST); - self.flags.insert(Flags::READ_DISCONNECTED); + self.flags.insert(Flags::READ_DISCONNECTED | Flags::STARTED); self.error = Some(HttpDispatchError::MalformedRequest); break; } @@ -571,7 +577,7 @@ where self.ka_expire = expire; } } - Ok(()) + Ok(updated) } } diff --git a/src/server/h1writer.rs b/src/server/h1writer.rs index 3036aa089..5c32de3aa 100644 --- a/src/server/h1writer.rs +++ b/src/server/h1writer.rs @@ -62,6 +62,10 @@ impl H1Writer { self.flags = Flags::KEEPALIVE; } + pub fn flushed(&mut self) -> bool { + self.buffer.is_empty() + } + pub fn disconnected(&mut self) { self.flags.insert(Flags::DISCONNECTED); } diff --git a/tests/test_server.rs b/tests/test_server.rs index 269a1cd7d..03a89642e 100644 --- a/tests/test_server.rs +++ b/tests/test_server.rs @@ -1094,3 +1094,35 @@ fn test_slow_request() { sys.stop(); } + +#[test] +fn test_malformed_request() { + use actix::System; + use std::net; + use std::sync::mpsc; + let (tx, rx) = mpsc::channel(); + + let addr = test::TestServer::unused_addr(); + thread::spawn(move || { + System::run(move || { + let srv = server::new(|| { + vec![App::new().resource("/", |r| { + r.method(http::Method::GET).f(|_| HttpResponse::Ok()) + })] + }); + + let _ = srv.bind(addr).unwrap().start(); + let _ = tx.send(System::current()); + }); + }); + let sys = rx.recv().unwrap(); + thread::sleep(time::Duration::from_millis(200)); + + let mut stream = net::TcpStream::connect(addr).unwrap(); + let _ = stream.write_all(b"GET /test/tests/test HTTP1.1\r\n"); + let mut data = String::new(); + let _ = stream.read_to_string(&mut data); + assert!(data.starts_with("HTTP/1.1 400 Bad Request")); + + sys.stop(); +} diff --git a/tests/test_ws.rs b/tests/test_ws.rs index 3baa48eb7..522832e00 100644 --- a/tests/test_ws.rs +++ b/tests/test_ws.rs @@ -7,7 +7,7 @@ extern crate rand; use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::Arc; -use std::thread; +use std::{thread, time}; use bytes::Bytes; use futures::Stream; @@ -380,17 +380,17 @@ fn test_ws_stopped() { let num = Arc::new(AtomicUsize::new(0)); let num2 = num.clone(); - let _ = thread::spawn(move || { + let mut srv = test::TestServer::new(move |app| { let num3 = num2.clone(); - let mut srv = test::TestServer::new(move |app| { - let num4 = num3.clone(); - app.handler(move |req| ws::start(req, WsStopped(num4.clone()))) - }); + app.handler(move |req| ws::start(req, WsStopped(num3.clone()))) + }); + { let (reader, mut writer) = srv.ws().unwrap(); writer.text("text"); let (item, _) = srv.execute(reader.into_future()).unwrap(); assert_eq!(item, Some(ws::Message::Text("text".to_owned()))); - }).join(); + } + thread::sleep(time::Duration::from_millis(1000)); assert_eq!(num.load(Ordering::Relaxed), 1); }