diff --git a/actix-framed/src/service.rs b/actix-framed/src/service.rs index bc730074c..5fb74fa1f 100644 --- a/actix-framed/src/service.rs +++ b/actix-framed/src/service.rs @@ -1,12 +1,14 @@ use std::marker::PhantomData; use actix_codec::{AsyncRead, AsyncWrite, Framed}; +use actix_http::body::{BodySize, MessageBody, ResponseBody}; use actix_http::error::{Error, ResponseError}; +use actix_http::h1::{Codec, Message}; use actix_http::ws::{verify_handshake, HandshakeError}; -use actix_http::{h1, Request}; +use actix_http::{Request, Response}; use actix_service::{NewService, Service}; use futures::future::{ok, Either, FutureResult}; -use futures::{Async, Future, IntoFuture, Poll}; +use futures::{Async, Future, IntoFuture, Poll, Sink}; /// Service that verifies incoming request if it is valid websocket /// upgrade request. In case of error returns `HandshakeError` @@ -21,9 +23,9 @@ impl Default for VerifyWebSockets { } impl NewService for VerifyWebSockets { - type Request = (Request, Framed); - type Response = (Request, Framed); - type Error = (HandshakeError, Framed); + type Request = (Request, Framed); + type Response = (Request, Framed); + type Error = (HandshakeError, Framed); type InitError = (); type Service = VerifyWebSockets; type Future = FutureResult; @@ -34,16 +36,16 @@ impl NewService for VerifyWebSockets { } impl Service for VerifyWebSockets { - type Request = (Request, Framed); - type Response = (Request, Framed); - type Error = (HandshakeError, Framed); + type Request = (Request, Framed); + type Response = (Request, Framed); + type Error = (HandshakeError, Framed); type Future = FutureResult; fn poll_ready(&mut self) -> Poll<(), Self::Error> { Ok(Async::Ready(())) } - fn call(&mut self, (req, framed): (Request, Framed)) -> Self::Future { + fn call(&mut self, (req, framed): (Request, Framed)) -> Self::Future { match verify_handshake(req.head()) { Err(e) => Err((e, framed)).into_future(), Ok(_) => Ok((req, framed)).into_future(), @@ -70,7 +72,7 @@ where R: 'static, E: ResponseError + 'static, { - type Request = Result)>; + type Request = Result)>; type Response = R; type Error = Error; type InitError = (); @@ -88,7 +90,7 @@ where R: 'static, E: ResponseError + 'static, { - type Request = Result)>; + type Request = Result)>; type Response = R; type Error = Error; type Future = Either, Box>>; @@ -97,16 +99,123 @@ where Ok(Async::Ready(())) } - fn call(&mut self, req: Result)>) -> Self::Future { + fn call(&mut self, req: Result)>) -> Self::Future { match req { Ok(r) => Either::A(ok(r)), Err((e, framed)) => { let res = e.render_response(); let e = Error::from(e); Either::B(Box::new( - h1::SendResponse::new(framed, res).then(move |_| Err(e)), + SendResponse::new(framed, res).then(move |_| Err(e)), )) } } } } + +/// Send http/1 response +pub struct SendResponse { + res: Option, BodySize)>>, + body: Option>, + framed: Option>, +} + +impl SendResponse +where + B: MessageBody, +{ + pub fn new(framed: Framed, response: Response) -> Self { + let (res, body) = response.into_parts(); + + SendResponse { + res: Some((res, body.size()).into()), + body: Some(body), + framed: Some(framed), + } + } +} + +impl Future for SendResponse +where + T: AsyncRead + AsyncWrite, + B: MessageBody, +{ + type Item = Framed; + type Error = (Error, Framed); + + fn poll(&mut self) -> Poll { + loop { + let mut body_ready = self.body.is_some(); + + // send body + if self.res.is_none() && self.body.is_some() { + while body_ready + && self.body.is_some() + && !self.framed.as_ref().unwrap().is_write_buf_full() + { + match self + .body + .as_mut() + .unwrap() + .poll_next() + .map_err(|e| (e, self.framed.take().unwrap()))? + { + Async::Ready(item) => { + // body is done + if item.is_none() { + let _ = self.body.take(); + } + self.framed + .as_mut() + .unwrap() + .force_send(Message::Chunk(item)) + .map_err(|e| (e.into(), self.framed.take().unwrap()))?; + } + Async::NotReady => body_ready = false, + } + } + } + + // flush write buffer + if !self.framed.as_ref().unwrap().is_write_buf_empty() { + match self + .framed + .as_mut() + .unwrap() + .poll_complete() + .map_err(|e| (e.into(), self.framed.take().unwrap()))? + { + Async::Ready(_) => { + if body_ready { + continue; + } else { + return Ok(Async::NotReady); + } + } + Async::NotReady => return Ok(Async::NotReady), + } + } + + // send response + if let Some(res) = self.res.take() { + self.framed + .as_mut() + .unwrap() + .force_send(res) + .map_err(|e| (e.into(), self.framed.take().unwrap()))?; + continue; + } + + if self.body.is_some() { + if body_ready { + continue; + } else { + return Ok(Async::NotReady); + } + } else { + break; + } + } + Ok(Async::Ready(self.framed.take().unwrap())) + } +} diff --git a/actix-http/src/h1/utils.rs b/actix-http/src/h1/utils.rs index c7b7ccec3..fdc4cf0bc 100644 --- a/actix-http/src/h1/utils.rs +++ b/actix-http/src/h1/utils.rs @@ -34,35 +34,23 @@ where B: MessageBody, { type Item = Framed; - type Error = (Error, Framed); + type Error = Error; fn poll(&mut self) -> Poll { loop { let mut body_ready = self.body.is_some(); + let framed = self.framed.as_mut().unwrap(); // send body if self.res.is_none() && self.body.is_some() { - while body_ready - && self.body.is_some() - && !self.framed.as_ref().unwrap().is_write_buf_full() - { - match self - .body - .as_mut() - .unwrap() - .poll_next() - .map_err(|e| (e, self.framed.take().unwrap()))? - { + while body_ready && self.body.is_some() && !framed.is_write_buf_full() { + match self.body.as_mut().unwrap().poll_next()? { Async::Ready(item) => { // body is done if item.is_none() { let _ = self.body.take(); } - self.framed - .as_mut() - .unwrap() - .force_send(Message::Chunk(item)) - .map_err(|e| (e.into(), self.framed.take().unwrap()))?; + framed.force_send(Message::Chunk(item))?; } Async::NotReady => body_ready = false, } @@ -70,14 +58,8 @@ where } // flush write buffer - if !self.framed.as_ref().unwrap().is_write_buf_empty() { - match self - .framed - .as_mut() - .unwrap() - .poll_complete() - .map_err(|e| (e.into(), self.framed.take().unwrap()))? - { + if !framed.is_write_buf_empty() { + match framed.poll_complete()? { Async::Ready(_) => { if body_ready { continue; @@ -91,11 +73,7 @@ where // send response if let Some(res) = self.res.take() { - self.framed - .as_mut() - .unwrap() - .force_send(res) - .map_err(|e| (e.into(), self.framed.take().unwrap()))?; + framed.force_send(res)?; continue; }