From 05f5ba00845bdbf32ef4f2c72b19d0695e0ab60b Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Fri, 9 Mar 2018 16:21:14 -0800 Subject: [PATCH] refactor keep-alive; fixed write to socket for upgraded connection --- src/server/h1.rs | 244 ++++++++++++++++++++--------------------- src/server/h1writer.rs | 59 +++++----- src/server/mod.rs | 13 +++ src/server/settings.rs | 16 ++- src/server/srv.rs | 17 +-- src/server/worker.rs | 19 +++- 6 files changed, 187 insertions(+), 181 deletions(-) diff --git a/src/server/h1.rs b/src/server/h1.rs index a55ac2799..097804ba2 100644 --- a/src/server/h1.rs +++ b/src/server/h1.rs @@ -192,134 +192,112 @@ impl Http1 let retry = self.reader.need_read() == PayloadStatus::Read; - loop { - // check in-flight messages - let mut io = false; - let mut idx = 0; - while idx < self.tasks.len() { - let item = &mut self.tasks[idx]; + // check in-flight messages + let mut io = false; + let mut idx = 0; + while idx < self.tasks.len() { + let item = &mut self.tasks[idx]; - if !io && !item.flags.contains(EntryFlags::EOF) { - // io is corrupted, send buffer - if item.flags.contains(EntryFlags::ERROR) { + if !io && !item.flags.contains(EntryFlags::EOF) { + // io is corrupted, send buffer + if item.flags.contains(EntryFlags::ERROR) { + if let Ok(Async::NotReady) = self.stream.poll_completed(true) { + return Ok(Async::NotReady) + } + return Err(()) + } + + match item.pipe.poll_io(&mut self.stream) { + Ok(Async::Ready(ready)) => { + // override keep-alive state + if self.stream.keepalive() { + self.flags.insert(Flags::KEEPALIVE); + } else { + self.flags.remove(Flags::KEEPALIVE); + } + // prepare stream for next response + self.stream.reset(); + + if ready { + item.flags.insert(EntryFlags::EOF | EntryFlags::FINISHED); + } else { + item.flags.insert(EntryFlags::FINISHED); + } + }, + // no more IO for this iteration + Ok(Async::NotReady) => { + if self.reader.need_read() == PayloadStatus::Read && !retry { + return Ok(Async::Ready(true)); + } + io = true; + } + Err(err) => { + // it is not possible to recover from error + // during pipe handling, so just drop connection + error!("Unhandled error: {}", err); + item.flags.insert(EntryFlags::ERROR); + + // check stream state, we still can have valid data in buffer if let Ok(Async::NotReady) = self.stream.poll_completed(true) { return Ok(Async::NotReady) } return Err(()) } - - match item.pipe.poll_io(&mut self.stream) { - Ok(Async::Ready(ready)) => { - // override keep-alive state - if self.stream.keepalive() { - self.flags.insert(Flags::KEEPALIVE); - } else { - self.flags.remove(Flags::KEEPALIVE); - } - // prepare stream for next response - self.stream.reset(); - - if ready { - item.flags.insert(EntryFlags::EOF | EntryFlags::FINISHED); - } else { - item.flags.insert(EntryFlags::FINISHED); - } - }, - // no more IO for this iteration - Ok(Async::NotReady) => { - if self.reader.need_read() == PayloadStatus::Read && !retry { - return Ok(Async::Ready(true)); - } - io = true; - } - Err(err) => { - // it is not possible to recover from error - // during pipe handling, so just drop connection - error!("Unhandled error: {}", err); - item.flags.insert(EntryFlags::ERROR); - - // check stream state, we still can have valid data in buffer - if let Ok(Async::NotReady) = self.stream.poll_completed(true) { - return Ok(Async::NotReady) - } - return Err(()) - } - } - } else if !item.flags.contains(EntryFlags::FINISHED) { - match item.pipe.poll() { - Ok(Async::NotReady) => (), - Ok(Async::Ready(_)) => item.flags.insert(EntryFlags::FINISHED), - Err(err) => { - item.flags.insert(EntryFlags::ERROR); - error!("Unhandled error: {}", err); - } + } + } else if !item.flags.contains(EntryFlags::FINISHED) { + match item.pipe.poll() { + Ok(Async::NotReady) => (), + Ok(Async::Ready(_)) => item.flags.insert(EntryFlags::FINISHED), + Err(err) => { + item.flags.insert(EntryFlags::ERROR); + error!("Unhandled error: {}", err); } } - idx += 1; } + idx += 1; + } - // cleanup finished tasks - let mut popped = false; - while !self.tasks.is_empty() { - if self.tasks[0].flags.contains(EntryFlags::EOF | EntryFlags::FINISHED) { - popped = true; - self.tasks.pop_front(); - } else { - break - } - } - if need_read && popped { - return self.poll_io() + // cleanup finished tasks + let mut popped = false; + while !self.tasks.is_empty() { + if self.tasks[0].flags.contains(EntryFlags::EOF | EntryFlags::FINISHED) { + popped = true; + self.tasks.pop_front(); + } else { + break } + } + if need_read && popped { + return self.poll_io() + } - // no keep-alive - if !self.flags.contains(Flags::KEEPALIVE) && self.tasks.is_empty() { - // check stream state - if !self.poll_completed(true)? { - return Ok(Async::NotReady) - } + // check stream state + if !self.poll_completed(true)? { + return Ok(Async::NotReady) + } + + // deal with keep-alive + if self.tasks.is_empty() { + // no keep-alive situations + if self.flags.contains(Flags::ERROR) + || !self.flags.contains(Flags::KEEPALIVE) + || !self.settings.keep_alive_enabled() + { return Ok(Async::Ready(false)) } - // start keep-alive timer, this also is slow request timeout - if self.tasks.is_empty() { - // check stream state - if self.flags.contains(Flags::ERROR) { - return Ok(Async::Ready(false)) - } - - if self.settings.keep_alive_enabled() { - let keep_alive = self.settings.keep_alive(); - if keep_alive > 0 && self.flags.contains(Flags::KEEPALIVE) { - if self.keepalive_timer.is_none() { - trace!("Start keep-alive timer"); - let mut to = Timeout::new( - Duration::new(keep_alive, 0), Arbiter::handle()).unwrap(); - // register timeout - let _ = to.poll(); - self.keepalive_timer = Some(to); - } - } else { - // check stream state - if !self.poll_completed(true)? { - return Ok(Async::NotReady) - } - // keep-alive is disabled, drop connection - return Ok(Async::Ready(false)) - } - } else if !self.poll_completed(false)? || - self.flags.contains(Flags::KEEPALIVE) { - // check stream state or - // if keep-alive unset, rely on operating system - return Ok(Async::NotReady) - } else { - return Ok(Async::Ready(false)) - } - } else { - self.poll_completed(false)?; - return Ok(Async::NotReady) + // start keep-alive timer + let keep_alive = self.settings.keep_alive(); + if self.keepalive_timer.is_none() && keep_alive > 0 { + trace!("Start keep-alive timer"); + let mut timer = Timeout::new( + Duration::new(keep_alive, 0), Arbiter::handle()).unwrap(); + // register timer + let _ = timer.poll(); + self.keepalive_timer = Some(timer); } } + Ok(Async::NotReady) } } @@ -868,7 +846,7 @@ mod tests { use httpmessage::HttpMessage; use application::HttpApplication; use server::settings::WorkerSettings; - use server::IoStream; + use server::{IoStream, KeepAlive}; struct Buffer { buf: Bytes, @@ -939,7 +917,8 @@ mod tests { macro_rules! parse_ready { ($e:expr) => ({ - let settings = WorkerSettings::::new(Vec::new(), None); + let settings = WorkerSettings::::new( + Vec::new(), KeepAlive::Os); match Reader::new().parse($e, &mut BytesMut::new(), &settings) { Ok(Async::Ready(req)) => req, Ok(_) => panic!("Eof during parsing http request"), @@ -961,7 +940,8 @@ mod tests { macro_rules! expect_parse_err { ($e:expr) => ({ let mut buf = BytesMut::new(); - let settings = WorkerSettings::::new(Vec::new(), None); + let settings = WorkerSettings::::new( + Vec::new(), KeepAlive::Os); match Reader::new().parse($e, &mut buf, &settings) { Err(err) => match err { @@ -979,7 +959,8 @@ mod tests { fn test_parse() { let mut buf = Buffer::new("GET /test HTTP/1.1\r\n\r\n"); let mut readbuf = BytesMut::new(); - let settings = WorkerSettings::::new(Vec::new(), None); + let settings = WorkerSettings::::new( + Vec::new(), KeepAlive::Os); let mut reader = Reader::new(); match reader.parse(&mut buf, &mut readbuf, &settings) { @@ -996,7 +977,8 @@ mod tests { fn test_parse_partial() { let mut buf = Buffer::new("PUT /test HTTP/1"); let mut readbuf = BytesMut::new(); - let settings = WorkerSettings::::new(Vec::new(), None); + let settings = WorkerSettings::::new( + Vec::new(), KeepAlive::Os); let mut reader = Reader::new(); match reader.parse(&mut buf, &mut readbuf, &settings) { @@ -1019,7 +1001,8 @@ mod tests { fn test_parse_post() { let mut buf = Buffer::new("POST /test2 HTTP/1.0\r\n\r\n"); let mut readbuf = BytesMut::new(); - let settings = WorkerSettings::::new(Vec::new(), None); + let settings = WorkerSettings::::new( + Vec::new(), KeepAlive::Os); let mut reader = Reader::new(); match reader.parse(&mut buf, &mut readbuf, &settings) { @@ -1036,7 +1019,8 @@ mod tests { fn test_parse_body() { let mut buf = Buffer::new("GET /test HTTP/1.1\r\nContent-Length: 4\r\n\r\nbody"); let mut readbuf = BytesMut::new(); - let settings = WorkerSettings::::new(Vec::new(), None); + let settings = WorkerSettings::::new( + Vec::new(), KeepAlive::Os); let mut reader = Reader::new(); match reader.parse(&mut buf, &mut readbuf, &settings) { @@ -1055,7 +1039,8 @@ mod tests { let mut buf = Buffer::new( "\r\nGET /test HTTP/1.1\r\nContent-Length: 4\r\n\r\nbody"); let mut readbuf = BytesMut::new(); - let settings = WorkerSettings::::new(Vec::new(), None); + let settings = WorkerSettings::::new( + Vec::new(), KeepAlive::Os); let mut reader = Reader::new(); match reader.parse(&mut buf, &mut readbuf, &settings) { @@ -1073,7 +1058,8 @@ mod tests { fn test_parse_partial_eof() { let mut buf = Buffer::new("GET /test HTTP/1.1\r\n"); let mut readbuf = BytesMut::new(); - let settings = WorkerSettings::::new(Vec::new(), None); + let settings = WorkerSettings::::new( + Vec::new(), KeepAlive::Os); let mut reader = Reader::new(); not_ready!{ reader.parse(&mut buf, &mut readbuf, &settings) } @@ -1093,7 +1079,8 @@ mod tests { fn test_headers_split_field() { let mut buf = Buffer::new("GET /test HTTP/1.1\r\n"); let mut readbuf = BytesMut::new(); - let settings = WorkerSettings::::new(Vec::new(), None); + let settings = WorkerSettings::::new( + Vec::new(), KeepAlive::Os); let mut reader = Reader::new(); not_ready!{ reader.parse(&mut buf, &mut readbuf, &settings) } @@ -1123,7 +1110,8 @@ mod tests { Set-Cookie: c1=cookie1\r\n\ Set-Cookie: c2=cookie2\r\n\r\n"); let mut readbuf = BytesMut::new(); - let settings = WorkerSettings::::new(Vec::new(), None); + let settings = WorkerSettings::::new( + Vec::new(), KeepAlive::Os); let mut reader = Reader::new(); match reader.parse(&mut buf, &mut readbuf, &settings) { @@ -1358,7 +1346,8 @@ mod tests { "GET /test HTTP/1.1\r\n\ transfer-encoding: chunked\r\n\r\n"); let mut readbuf = BytesMut::new(); - let settings = WorkerSettings::::new(Vec::new(), None); + let settings = WorkerSettings::::new( + Vec::new(), KeepAlive::Os); let mut reader = Reader::new(); let mut req = reader_parse_ready!(reader.parse(&mut buf, &mut readbuf, &settings)); @@ -1379,7 +1368,8 @@ mod tests { "GET /test HTTP/1.1\r\n\ transfer-encoding: chunked\r\n\r\n"); let mut readbuf = BytesMut::new(); - let settings = WorkerSettings::::new(Vec::new(), None); + let settings = WorkerSettings::::new( + Vec::new(), KeepAlive::Os); let mut reader = Reader::new(); @@ -1408,7 +1398,8 @@ mod tests { "GET /test HTTP/1.1\r\n\ transfer-encoding: chunked\r\n\r\n"); let mut readbuf = BytesMut::new(); - let settings = WorkerSettings::::new(Vec::new(), None); + let settings = WorkerSettings::::new( + Vec::new(), KeepAlive::Os); let mut reader = Reader::new(); let mut req = reader_parse_ready!(reader.parse(&mut buf, &mut readbuf, &settings)); @@ -1458,7 +1449,8 @@ mod tests { "GET /test HTTP/1.1\r\n\ transfer-encoding: chunked\r\n\r\n"); let mut readbuf = BytesMut::new(); - let settings = WorkerSettings::::new(Vec::new(), None); + let settings = WorkerSettings::::new( + Vec::new(), KeepAlive::Os); let mut reader = Reader::new(); let mut req = reader_parse_ready!(reader.parse(&mut buf, &mut readbuf, &settings)); diff --git a/src/server/h1writer.rs b/src/server/h1writer.rs index e77e60ca7..72e34a1e6 100644 --- a/src/server/h1writer.rs +++ b/src/server/h1writer.rs @@ -68,27 +68,22 @@ impl H1Writer { self.flags.contains(Flags::KEEPALIVE) && !self.flags.contains(Flags::UPGRADE) } - fn write_to_stream(&mut self) -> io::Result { - while !self.buffer.is_empty() { - match self.stream.write(self.buffer.as_ref()) { + fn write_data(&mut self, data: &[u8]) -> io::Result<(usize, bool)> { + let mut written = 0; + while written < data.len() { + match self.stream.write(&data[written..]) { Ok(0) => { self.disconnected(); - return Ok(WriterState::Done); - }, - Ok(n) => { - let _ = self.buffer.split_to(n); + return Ok((0, true)); }, + Ok(n) => written += n, Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { - if self.buffer.len() > self.buffer_capacity { - return Ok(WriterState::Pause) - } else { - return Ok(WriterState::Done) - } + return Ok((written, false)) } Err(err) => return Err(err), } } - Ok(WriterState::Done) + Ok((written, false)) } } @@ -216,18 +211,12 @@ impl Writer for H1Writer { // shortcut for upgraded connection if self.flags.contains(Flags::UPGRADE) { if self.buffer.is_empty() { - match self.stream.write(payload.as_ref()) { - Ok(0) => { - self.disconnected(); + match self.write_data(payload.as_ref())? { + (_, true) => return Ok(WriterState::Done), + (n, false) => if payload.len() < n { + self.buffer.extend_from_slice(&payload.as_ref()[n..]); return Ok(WriterState::Done); - }, - Ok(n) => if payload.len() < n { - self.buffer.extend_from_slice(&payload.as_ref()[n..]) - }, - Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { - return Ok(WriterState::Done) } - Err(err) => return Err(err), } } else { self.buffer.extend(payload); @@ -264,16 +253,22 @@ impl Writer for H1Writer { #[inline] fn poll_completed(&mut self, shutdown: bool) -> Poll<(), io::Error> { - match self.write_to_stream() { - Ok(WriterState::Done) => { - if shutdown { - self.stream.shutdown() - } else { - Ok(Async::Ready(())) + if !self.buffer.is_empty() { + let buf: &[u8] = unsafe{mem::transmute(self.buffer.as_ref())}; + match self.write_data(buf)? { + (_, true) => (), + (n, false) => { + let _ = self.buffer.split_to(n); + if self.buffer.len() > self.buffer_capacity { + return Ok(Async::NotReady) + } } - }, - Ok(WriterState::Pause) => Ok(Async::NotReady), - Err(err) => Err(err) + } + } + if shutdown { + self.stream.shutdown() + } else { + Ok(Async::Ready(())) } } } diff --git a/src/server/mod.rs b/src/server/mod.rs index 9f644a1e9..b1b4793c9 100644 --- a/src/server/mod.rs +++ b/src/server/mod.rs @@ -31,6 +31,19 @@ use httpresponse::HttpResponse; /// max buffer size 64k pub(crate) const MAX_WRITE_BUFFER_SIZE: usize = 65_536; +#[derive(Debug, PartialEq, Clone, Copy)] +/// Server keep-alive setting +pub enum KeepAlive { + /// Keep alive in seconds + Timeout(usize), + /// Use `SO_KEEPALIVE` socket option, value in seconds + Tcp(usize), + /// Relay on OS to shutdown tcp connection + Os, + /// Disabled + Disabled, +} + /// Pause accepting incoming connections /// /// If socket contains some pending connection, they might be dropped. diff --git a/src/server/settings.rs b/src/server/settings.rs index 7c7299ec8..82b190b85 100644 --- a/src/server/settings.rs +++ b/src/server/settings.rs @@ -5,6 +5,7 @@ use std::cell::{Cell, RefCell, RefMut, UnsafeCell}; use futures_cpupool::{Builder, CpuPool}; use helpers; +use super::KeepAlive; use super::channel::Node; use super::shared::{SharedBytes, SharedBytesPool}; @@ -97,8 +98,8 @@ impl ServerSettings { pub(crate) struct WorkerSettings { h: RefCell>, - enabled: bool, keep_alive: u64, + ka_enabled: bool, bytes: Rc, messages: Rc, channels: Cell, @@ -106,11 +107,16 @@ pub(crate) struct WorkerSettings { } impl WorkerSettings { - pub(crate) fn new(h: Vec, keep_alive: Option) -> WorkerSettings { + pub(crate) fn new(h: Vec, keep_alive: KeepAlive) -> WorkerSettings { + let (keep_alive, ka_enabled) = match keep_alive { + KeepAlive::Timeout(val) => (val as u64, true), + KeepAlive::Os | KeepAlive::Tcp(_) => (0, true), + KeepAlive::Disabled => (0, false), + }; + WorkerSettings { + keep_alive, ka_enabled, h: RefCell::new(h), - enabled: if let Some(ka) = keep_alive { ka > 0 } else { false }, - keep_alive: keep_alive.unwrap_or(0), bytes: Rc::new(SharedBytesPool::new()), messages: Rc::new(helpers::SharedMessagePool::new()), channels: Cell::new(0), @@ -135,7 +141,7 @@ impl WorkerSettings { } pub fn keep_alive_enabled(&self) -> bool { - self.enabled + self.ka_enabled } pub fn get_shared_bytes(&self) -> SharedBytes { diff --git a/src/server/srv.rs b/src/server/srv.rs index cd6e2edfe..d0c180b5c 100644 --- a/src/server/srv.rs +++ b/src/server/srv.rs @@ -20,13 +20,12 @@ use native_tls::TlsAcceptor; use openssl::ssl::{AlpnError, SslAcceptorBuilder}; use helpers; -use super::{IntoHttpHandler, IoStream}; +use super::{IntoHttpHandler, IoStream, KeepAlive}; use super::{PauseServer, ResumeServer, StopServer}; use super::channel::{HttpChannel, WrapperStream}; use super::worker::{Conn, Worker, StreamHandlerType, StopWorker}; use super::settings::{ServerSettings, WorkerSettings}; - /// An HTTP Server pub struct HttpServer where H: IntoHttpHandler + 'static { @@ -34,7 +33,7 @@ pub struct HttpServer where H: IntoHttpHandler + 'static threads: usize, backlog: i32, host: Option, - keep_alive: Option, + keep_alive: KeepAlive, factory: Arc Vec + Send + Sync>, #[cfg_attr(feature="cargo-clippy", allow(type_complexity))] workers: Vec<(usize, Addr>)>, @@ -83,7 +82,7 @@ impl HttpServer where H: IntoHttpHandler + 'static threads: num_cpus::get(), backlog: 2048, host: None, - keep_alive: None, + keep_alive: KeepAlive::Os, factory: Arc::new(f), workers: Vec::new(), sockets: HashMap::new(), @@ -124,14 +123,8 @@ impl HttpServer where H: IntoHttpHandler + 'static /// Set server keep-alive setting. /// - /// By default keep alive is enabled. - /// - /// - `Some(75)` - enable - /// - /// - `Some(0)` - disable - /// - /// - `None` - use `SO_KEEPALIVE` socket option - pub fn keep_alive(mut self, val: Option) -> Self { + /// By default keep alive is set to a `Os`. + pub fn keep_alive(mut self, val: KeepAlive) -> Self { self.keep_alive = val; self } diff --git a/src/server/worker.rs b/src/server/worker.rs index 5257d8615..02fa7453c 100644 --- a/src/server/worker.rs +++ b/src/server/worker.rs @@ -23,7 +23,7 @@ use actix::*; use actix::msgs::StopArbiter; use helpers; -use server::HttpHandler; +use server::{HttpHandler, KeepAlive}; use server::channel::HttpChannel; use server::settings::WorkerSettings; @@ -48,21 +48,30 @@ impl Message for StopWorker { /// Http worker /// /// Worker accepts Socket objects via unbounded channel and start requests processing. -pub(crate) struct Worker where H: HttpHandler + 'static { +pub(crate) +struct Worker where H: HttpHandler + 'static { settings: Rc>, hnd: Handle, handler: StreamHandlerType, + tcp_ka: Option, } impl Worker { - pub(crate) fn new(h: Vec, handler: StreamHandlerType, keep_alive: Option) + pub(crate) fn new(h: Vec, handler: StreamHandlerType, keep_alive: KeepAlive) -> Worker { + let tcp_ka = if let KeepAlive::Tcp(val) = keep_alive { + Some(time::Duration::new(val as u64, 0)) + } else { + None + }; + Worker { settings: Rc::new(WorkerSettings::new(h, keep_alive)), hnd: Arbiter::handle().clone(), handler, + tcp_ka, } } @@ -106,9 +115,7 @@ impl Handler> for Worker fn handle(&mut self, msg: Conn, _: &mut Context) { - if !self.settings.keep_alive_enabled() && - msg.io.set_keepalive(Some(time::Duration::new(75, 0))).is_err() - { + if self.tcp_ka.is_some() && msg.io.set_keepalive(self.tcp_ka).is_err() { error!("Can not set socket keep-alive option"); } self.handler.handle(Rc::clone(&self.settings), &self.hnd, msg);