From 86583049fa59fbc8a2ae77dfa49335166fdd6219 Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Wed, 25 Oct 2017 16:25:26 -0700 Subject: [PATCH] Fix disconnection handling --- CHANGES.md | 2 + Cargo.toml | 2 +- src/body.rs | 36 +++++----- src/context.rs | 20 +++++- src/resource.rs | 2 +- src/server.rs | 31 +++++++-- src/task.rs | 182 ++++++++++++++++++++++++++++++++---------------- 7 files changed, 188 insertions(+), 87 deletions(-) diff --git a/CHANGES.md b/CHANGES.md index 54c66adfb..76866d2fc 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -11,6 +11,8 @@ * Re-use `BinaryBody` for `Frame::Payload` +* Fix disconnection handling. + ## 0.1.0 (2017-10-23) diff --git a/Cargo.toml b/Cargo.toml index c04ff1152..bc67ab77b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -49,7 +49,7 @@ tokio-proto = "0.1" # h2 = { git = 'https://github.com/carllerche/h2', optional = true } [dependencies.actix] -#version = "0.3" +#version = ">=0.3.1" #path = "../actix" git = "https://github.com/actix/actix.git" default-features = false diff --git a/src/body.rs b/src/body.rs index 155710958..9e8b3ecec 100644 --- a/src/body.rs +++ b/src/body.rs @@ -49,7 +49,7 @@ impl Body { } /// Create body from slice (copy) - pub fn from_slice<'a>(s: &'a [u8]) -> Body { + pub fn from_slice(s: &[u8]) -> Body { Body::Binary(BinaryBody::Bytes(Bytes::from(s))) } } @@ -61,19 +61,23 @@ impl From for Body where T: Into{ } impl BinaryBody { + pub fn is_empty(&self) -> bool { + self.len() == 0 + } + pub fn len(&self) -> usize { - match self { - &BinaryBody::Bytes(ref bytes) => bytes.len(), - &BinaryBody::Slice(slice) => slice.len(), - &BinaryBody::SharedBytes(ref bytes) => bytes.len(), - &BinaryBody::ArcSharedBytes(ref bytes) => bytes.len(), - &BinaryBody::SharedString(ref s) => s.len(), - &BinaryBody::ArcSharedString(ref s) => s.len(), + match *self { + BinaryBody::Bytes(ref bytes) => bytes.len(), + BinaryBody::Slice(slice) => slice.len(), + BinaryBody::SharedBytes(ref bytes) => bytes.len(), + BinaryBody::ArcSharedBytes(ref bytes) => bytes.len(), + BinaryBody::SharedString(ref s) => s.len(), + BinaryBody::ArcSharedString(ref s) => s.len(), } } /// Create binary body from slice - pub fn from_slice<'a>(s: &'a [u8]) -> BinaryBody { + pub fn from_slice(s: &[u8]) -> BinaryBody { BinaryBody::Bytes(Bytes::from(s)) } } @@ -164,13 +168,13 @@ impl<'a> From<&'a Arc> for BinaryBody { impl AsRef<[u8]> for BinaryBody { fn as_ref(&self) -> &[u8] { - match self { - &BinaryBody::Bytes(ref bytes) => bytes.as_ref(), - &BinaryBody::Slice(slice) => slice, - &BinaryBody::SharedBytes(ref bytes) => bytes.as_ref(), - &BinaryBody::ArcSharedBytes(ref bytes) => bytes.as_ref(), - &BinaryBody::SharedString(ref s) => s.as_bytes(), - &BinaryBody::ArcSharedString(ref s) => s.as_bytes(), + match *self { + BinaryBody::Bytes(ref bytes) => bytes.as_ref(), + BinaryBody::Slice(slice) => slice, + BinaryBody::SharedBytes(ref bytes) => bytes.as_ref(), + BinaryBody::ArcSharedBytes(ref bytes) => bytes.as_ref(), + BinaryBody::SharedString(ref s) => s.as_bytes(), + BinaryBody::ArcSharedString(ref s) => s.as_bytes(), } } } diff --git a/src/context.rs b/src/context.rs index 208852060..3606db593 100644 --- a/src/context.rs +++ b/src/context.rs @@ -10,6 +10,7 @@ use actix::fut::ActorFuture; use actix::dev::{AsyncContextApi, ActorAddressCell, ActorItemsCell, ActorWaitCell, SpawnHandle, Envelope, ToEnvelope, RemoteEnvelope}; +use task::IoContext; use body::BinaryBody; use route::{Route, Frame}; use httpresponse::HttpResponse; @@ -26,10 +27,20 @@ pub struct HttpContext where A: Actor> + Route, stream: VecDeque, wait: ActorWaitCell, app_state: Rc<::State>, + disconnected: bool, } +impl IoContext for HttpContext where A: Actor + Route { -impl ActorContext for HttpContext where A: Actor + Route + fn disconnected(&mut self) { + self.disconnected = true; + if self.state == ActorState::Running { + self.state = ActorState::Stopping; + } + } +} + +impl ActorContext for HttpContext where A: Actor + Route { /// Stop actor execution fn stop(&mut self) { @@ -95,6 +106,7 @@ impl HttpContext where A: Actor + Route { wait: ActorWaitCell::default(), stream: VecDeque::new(), app_state: state, + disconnected: false, } } @@ -124,6 +136,11 @@ impl HttpContext where A: Actor + Route { pub fn write_eof(&mut self) { self.stream.push_back(Frame::Payload(None)) } + + /// Check if connection still open + pub fn connected(&self) -> bool { + !self.disconnected + } } impl HttpContext where A: Actor + Route { @@ -157,7 +174,6 @@ impl Stream for HttpContext where A: Actor + Route if self.act.is_none() { return Ok(Async::NotReady) } - let act: &mut A = unsafe { std::mem::transmute(self.act.as_mut().unwrap() as &mut A) }; diff --git a/src/resource.rs b/src/resource.rs index a69e27ca8..a030d5ec9 100644 --- a/src/resource.rs +++ b/src/resource.rs @@ -158,7 +158,7 @@ impl Reply where A: Actor + Route }, ReplyItem::Actor(act) => { ctx.set_actor(act); - Task::with_stream(ctx) + Task::with_context(ctx) } } } diff --git a/src/server.rs b/src/server.rs index 8331b48f5..e1e6a2281 100644 --- a/src/server.rs +++ b/src/server.rs @@ -171,11 +171,11 @@ pub struct HttpChannel { keepalive_timer: Option, } -/*impl Drop for HttpChannel { +impl Drop for HttpChannel { fn drop(&mut self) { println!("Drop http channel"); } -}*/ +} impl Actor for HttpChannel where T: AsyncRead + AsyncWrite + 'static, @@ -205,6 +205,8 @@ impl Future for HttpChannel } loop { + let mut not_ready = true; + // check in-flight messages let mut idx = 0; while idx < self.items.len() { @@ -218,6 +220,7 @@ impl Future for HttpChannel match self.items[idx].task.poll_io(&mut self.stream, req) { Ok(Async::Ready(ready)) => { + not_ready = false; let mut item = self.items.pop_front().unwrap(); // overide keep-alive state @@ -247,8 +250,10 @@ impl Future for HttpChannel } else if !self.items[idx].finished && !self.items[idx].error { match self.items[idx].task.poll() { Ok(Async::NotReady) => (), - Ok(Async::Ready(_)) => - self.items[idx].finished = true, + Ok(Async::Ready(_)) => { + not_ready = false; + self.items[idx].finished = true; + }, Err(_) => self.items[idx].error = true, } @@ -267,8 +272,10 @@ impl Future for HttpChannel if !self.inactive[idx].finished && !self.inactive[idx].error { match self.inactive[idx].task.poll() { Ok(Async::NotReady) => (), - Ok(Async::Ready(_)) => - self.inactive[idx].finished = true, + Ok(Async::Ready(_)) => { + not_ready = false; + self.inactive[idx].finished = true + } Err(_) => self.inactive[idx].error = true, } @@ -280,6 +287,8 @@ impl Future for HttpChannel if !self.error && self.items.len() < MAX_PIPELINED_MESSAGES { match self.reader.parse(&mut self.stream) { Ok(Async::Ready((mut req, payload))) => { + not_ready = false; + // stop keepalive timer self.keepalive_timer.take(); @@ -300,6 +309,12 @@ impl Future for HttpChannel finished: false}); } Err(err) => { + // notify all tasks + not_ready = false; + for entry in &mut self.items { + entry.task.disconnected() + } + // kill keepalive self.keepalive = false; self.keepalive_timer.take(); @@ -344,6 +359,10 @@ impl Future for HttpChannel if self.items.is_empty() && self.inactive.is_empty() && self.error { return Ok(Async::Ready(())) } + + if not_ready { + return Ok(Async::NotReady) + } } } } diff --git a/src/task.rs b/src/task.rs index 7e29cbf5a..f5f155e6d 100644 --- a/src/task.rs +++ b/src/task.rs @@ -1,4 +1,4 @@ -use std::{cmp, io}; +use std::{mem, cmp, io}; use std::rc::Rc; use std::fmt::Write; use std::collections::VecDeque; @@ -47,16 +47,27 @@ impl TaskIOState { } } +enum TaskStream { + None, + Stream(Box), + Context(Box>), +} + +pub(crate) trait IoContext: Stream + 'static { + fn disconnected(&mut self); +} + pub struct Task { state: TaskRunningState, iostate: TaskIOState, frames: VecDeque, - stream: Option>, + stream: TaskStream, encoder: Encoder, buffer: BytesMut, upgrade: bool, keepalive: bool, prepared: Option, + disconnected: bool, middlewares: Option>>>, } @@ -71,12 +82,13 @@ impl Task { state: TaskRunningState::Running, iostate: TaskIOState::Done, frames: frames, - stream: None, + stream: TaskStream::None, encoder: Encoder::length(0), buffer: BytesMut::new(), upgrade: false, keepalive: false, prepared: None, + disconnected: false, middlewares: None, } } @@ -88,12 +100,30 @@ impl Task { state: TaskRunningState::Running, iostate: TaskIOState::ReadingMessage, frames: VecDeque::new(), - stream: Some(Box::new(stream)), + stream: TaskStream::Stream(Box::new(stream)), encoder: Encoder::length(0), buffer: BytesMut::new(), upgrade: false, keepalive: false, prepared: None, + disconnected: false, + middlewares: None, + } + } + + pub(crate) fn with_context(ctx: C) -> Self + { + Task { + state: TaskRunningState::Running, + iostate: TaskIOState::ReadingMessage, + frames: VecDeque::new(), + stream: TaskStream::Context(Box::new(ctx)), + encoder: Encoder::length(0), + buffer: BytesMut::new(), + upgrade: false, + keepalive: false, + prepared: None, + disconnected: false, middlewares: None, } } @@ -106,6 +136,15 @@ impl Task { self.middlewares = Some(middlewares); } + pub(crate) fn disconnected(&mut self) { + let len = self.buffer.len(); + self.buffer.split_to(len); + self.disconnected = true; + if let TaskStream::Context(ref mut ctx) = self.stream { + ctx.disconnected(); + } + } + fn prepare(&mut self, req: &mut HttpRequest, msg: HttpResponse) { trace!("Prepare message status={:?}", msg.status); @@ -252,20 +291,26 @@ impl Task { trace!("IO Frame: {:?}", frame); match frame { Frame::Message(response) => { - self.prepare(req, response); + if !self.disconnected { + self.prepare(req, response); + } } Frame::Payload(Some(chunk)) => { - if self.prepared.is_some() { - // TODO: add warning, write after EOF - self.encoder.encode(&mut self.buffer, chunk.as_ref()); - } else { - // might be response for EXCEPT - self.buffer.extend_from_slice(chunk.as_ref()) + if !self.disconnected { + if self.prepared.is_some() { + // TODO: add warning, write after EOF + self.encoder.encode(&mut self.buffer, chunk.as_ref()); + } else { + // might be response for EXCEPT + self.buffer.extend_from_slice(chunk.as_ref()) + } } }, Frame::Payload(None) => { - // TODO: add error "not eof"" - if !self.encoder.encode(&mut self.buffer, [].as_ref()) { + if !self.disconnected && + !self.encoder.encode(&mut self.buffer, [].as_ref()) + { + // TODO: add error "not eof"" debug!("last payload item, but it is not EOF "); return Err(()) } @@ -276,15 +321,17 @@ impl Task { } // write bytes to TcpStream - while !self.buffer.is_empty() { - match io.write(self.buffer.as_ref()) { - Ok(n) => { - self.buffer.split_to(n); - }, - Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { - break + if !self.disconnected { + while !self.buffer.is_empty() { + match io.write(self.buffer.as_ref()) { + Ok(n) => { + self.buffer.split_to(n); + }, + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { + break + } + Err(_) => return Err(()), } - Err(_) => return Err(()), } } @@ -295,10 +342,13 @@ impl Task { } else if self.state == TaskRunningState::Paused { self.state = TaskRunningState::Running; } + } else { + // at this point we wont get any more Frames + self.iostate = TaskIOState::Done; } // response is completed - if self.buffer.is_empty() && self.iostate.is_done() { + if (self.buffer.is_empty() || self.disconnected) && self.iostate.is_done() { // run middlewares if let Some(ref mut resp) = self.prepared { if let Some(middlewares) = self.middlewares.take() { @@ -313,6 +363,46 @@ impl Task { Ok(Async::NotReady) } } + + fn poll_stream(&mut self, stream: &mut S) -> Poll<(), ()> + where S: Stream + { + loop { + match stream.poll() { + Ok(Async::Ready(Some(frame))) => { + match frame { + Frame::Message(ref msg) => { + if self.iostate != TaskIOState::ReadingMessage { + error!("Non expected frame {:?}", frame); + return Err(()) + } + self.upgrade = msg.upgrade(); + if self.upgrade || msg.body().has_body() { + self.iostate = TaskIOState::ReadingPayload; + } else { + self.iostate = TaskIOState::Done; + } + }, + Frame::Payload(ref chunk) => { + if chunk.is_none() { + self.iostate = TaskIOState::Done; + } else if self.iostate != TaskIOState::ReadingPayload { + error!("Non expected frame {:?}", self.iostate); + return Err(()) + } + }, + } + self.frames.push_back(frame) + }, + Ok(Async::Ready(None)) => + return Ok(Async::Ready(())), + Ok(Async::NotReady) => + return Ok(Async::NotReady), + Err(_) => + return Err(()) + } + } + } } impl Future for Task { @@ -320,45 +410,15 @@ impl Future for Task { type Error = (); fn poll(&mut self) -> Poll { - if let Some(ref mut stream) = self.stream { - loop { - match stream.poll() { - Ok(Async::Ready(Some(frame))) => { - match frame { - Frame::Message(ref msg) => { - if self.iostate != TaskIOState::ReadingMessage { - error!("Non expected frame {:?}", frame); - return Err(()) - } - self.upgrade = msg.upgrade(); - if self.upgrade || msg.body().has_body() { - self.iostate = TaskIOState::ReadingPayload; - } else { - self.iostate = TaskIOState::Done; - } - }, - Frame::Payload(ref chunk) => { - if chunk.is_none() { - self.iostate = TaskIOState::Done; - } else if self.iostate != TaskIOState::ReadingPayload { - error!("Non expected frame {:?}", self.iostate); - return Err(()) - } - }, - } - self.frames.push_back(frame) - }, - Ok(Async::Ready(None)) => - return Ok(Async::Ready(())), - Ok(Async::NotReady) => - return Ok(Async::NotReady), - Err(_) => - return Err(()) - } - } - } else { - Ok(Async::Ready(())) - } + let mut s = mem::replace(&mut self.stream, TaskStream::None); + + let result = match s { + TaskStream::None => Ok(Async::Ready(())), + TaskStream::Stream(ref mut stream) => self.poll_stream(stream), + TaskStream::Context(ref mut context) => self.poll_stream(context), + }; + self.stream = s; + result } }