diff --git a/src/config.rs b/src/config.rs index 543e78acd..36b949c33 100644 --- a/src/config.rs +++ b/src/config.rs @@ -21,7 +21,7 @@ pub struct ServiceConfig(Rc); struct Inner { keep_alive: Option, client_timeout: u64, - client_shutdown: u64, + client_disconnect: u64, ka_enabled: bool, date: UnsafeCell<(bool, Date)>, } @@ -35,7 +35,7 @@ impl Clone for ServiceConfig { impl ServiceConfig { /// Create instance of `ServiceConfig` pub(crate) fn new( - keep_alive: KeepAlive, client_timeout: u64, client_shutdown: u64, + keep_alive: KeepAlive, client_timeout: u64, client_disconnect: u64, ) -> ServiceConfig { let (keep_alive, ka_enabled) = match keep_alive { KeepAlive::Timeout(val) => (val as u64, true), @@ -52,7 +52,7 @@ impl ServiceConfig { keep_alive, ka_enabled, client_timeout, - client_shutdown, + client_disconnect, date: UnsafeCell::new((false, Date::new())), })) } @@ -100,9 +100,9 @@ impl ServiceConfig { } } - /// Client shutdown timer - pub fn client_shutdown_timer(&self) -> Option { - let delay = self.0.client_shutdown; + /// Client disconnect timer + pub fn client_disconnect_timer(&self) -> Option { + let delay = self.0.client_disconnect; if delay != 0 { Some(self.now() + Duration::from_millis(delay)) } else { @@ -184,7 +184,7 @@ impl ServiceConfig { pub struct ServiceConfigBuilder { keep_alive: KeepAlive, client_timeout: u64, - client_shutdown: u64, + client_disconnect: u64, host: String, addr: net::SocketAddr, secure: bool, @@ -196,7 +196,7 @@ impl ServiceConfigBuilder { ServiceConfigBuilder { keep_alive: KeepAlive::Timeout(5), client_timeout: 5000, - client_shutdown: 5000, + client_disconnect: 0, secure: false, host: "localhost".to_owned(), addr: "127.0.0.1:8080".parse().unwrap(), @@ -204,10 +204,14 @@ impl ServiceConfigBuilder { } /// Enable secure flag for current server. + /// This flags also enables `client disconnect timeout`. /// /// By default this flag is set to false. pub fn secure(mut self) -> Self { self.secure = true; + if self.client_disconnect == 0 { + self.client_disconnect = 3000; + } self } @@ -233,16 +237,16 @@ impl ServiceConfigBuilder { self } - /// Set server connection shutdown timeout in milliseconds. + /// Set server connection disconnect timeout in milliseconds. /// - /// Defines a timeout for shutdown connection. If a shutdown procedure does not complete - /// within this time, the request is dropped. This timeout affects only secure connections. + /// Defines a timeout for disconnect connection. If a disconnect procedure does not complete + /// within this time, the request get dropped. This timeout affects secure connections. /// /// To disable timeout set value to 0. /// - /// By default client timeout is set to 5000 milliseconds. - pub fn client_shutdown(mut self, val: u64) -> Self { - self.client_shutdown = val; + /// By default disconnect timeout is set to 3000 milliseconds. + pub fn client_disconnect(mut self, val: u64) -> Self { + self.client_disconnect = val; self } @@ -277,9 +281,7 @@ impl ServiceConfigBuilder { /// Finish service configuration and create `ServiceConfig` object. pub fn finish(self) -> ServiceConfig { - let client_shutdown = if self.secure { self.client_shutdown } else { 0 }; - - ServiceConfig::new(self.keep_alive, self.client_timeout, client_shutdown) + ServiceConfig::new(self.keep_alive, self.client_timeout, self.client_disconnect) } } diff --git a/src/error.rs b/src/error.rs index 21aabac49..fb5df2328 100644 --- a/src/error.rs +++ b/src/error.rs @@ -397,9 +397,9 @@ pub enum DispatchError { // #[fail(display = "The first request did not complete within the specified timeout")] SlowRequestTimeout, - /// Shutdown timeout + /// Disconnect timeout. Makes sense for ssl streams. // #[fail(display = "Connection shutdown timeout")] - ShutdownTimeout, + DisconnectTimeout, /// Payload is not consumed // #[fail(display = "Task is completed but request's payload is not consumed")] diff --git a/src/h1/codec.rs b/src/h1/codec.rs index f1b526d52..ac54194ab 100644 --- a/src/h1/codec.rs +++ b/src/h1/codec.rs @@ -10,7 +10,7 @@ use body::Body; use error::ParseError; use helpers; use http::header::{HeaderValue, CONNECTION, CONTENT_LENGTH, DATE, TRANSFER_ENCODING}; -use http::Version; +use http::{Method, Version}; use httpresponse::HttpResponse; use request::RequestPool; use server::output::{ResponseInfo, ResponseLength}; @@ -27,6 +27,8 @@ pub enum OutMessage { pub struct Codec { decoder: H1Decoder, encoder: H1Writer, + head: bool, + version: Version, } impl Codec { @@ -40,6 +42,8 @@ impl Codec { Codec { decoder: H1Decoder::with_pool(pool), encoder: H1Writer::new(), + head: false, + version: Version::HTTP_11, } } } @@ -49,7 +53,17 @@ impl Decoder for Codec { type Error = ParseError; fn decode(&mut self, src: &mut BytesMut) -> Result, Self::Error> { - self.decoder.decode(src) + let res = self.decoder.decode(src); + + match res { + Ok(Some(InMessage::Message(ref req))) + | Ok(Some(InMessage::MessageWithPayload(ref req))) => { + self.head = req.inner.method == Method::HEAD; + self.version = req.inner.version; + } + _ => (), + } + res } } @@ -62,7 +76,7 @@ impl Encoder for Codec { ) -> Result<(), Self::Error> { match item { OutMessage::Response(res) => { - self.encoder.encode(res, dst)?; + self.encoder.encode(res, dst, self.head, self.version)?; } OutMessage::Payload(bytes) => { dst.extend_from_slice(&bytes); @@ -87,6 +101,7 @@ struct H1Writer { flags: Flags, written: u64, headers_size: u32, + info: ResponseInfo, } impl H1Writer { @@ -95,6 +110,7 @@ impl H1Writer { flags: Flags::empty(), written: 0, headers_size: 0, + info: ResponseInfo::default(), } } @@ -116,10 +132,11 @@ impl H1Writer { } fn encode( - &mut self, mut msg: HttpResponse, buffer: &mut BytesMut, + &mut self, mut msg: HttpResponse, buffer: &mut BytesMut, head: bool, + version: Version, ) -> io::Result<()> { // prepare task - let info = ResponseInfo::new(false); // req.inner.method == Method::HEAD); + self.info.update(&mut msg, head, version); //if msg.keep_alive().unwrap_or_else(|| req.keep_alive()) { //self.flags = Flags::STARTED | Flags::KEEPALIVE; @@ -166,7 +183,7 @@ impl H1Writer { buffer.extend_from_slice(reason); // content length - match info.length { + match self.info.length { ResponseLength::Chunked => { buffer.extend_from_slice(b"\r\ntransfer-encoding: chunked\r\n") } @@ -183,11 +200,6 @@ impl H1Writer { } ResponseLength::None => buffer.extend_from_slice(b"\r\n"), } - if let Some(ce) = info.content_encoding { - buffer.extend_from_slice(b"content-encoding: "); - buffer.extend_from_slice(ce.as_ref()); - buffer.extend_from_slice(b"\r\n"); - } // write headers let mut pos = 0; @@ -197,7 +209,7 @@ impl H1Writer { for (key, value) in msg.headers() { match *key { TRANSFER_ENCODING => continue, - CONTENT_LENGTH => match info.length { + CONTENT_LENGTH => match self.info.length { ResponseLength::None => (), _ => continue, }, diff --git a/src/h1/dispatcher.rs b/src/h1/dispatcher.rs index eda8ebf00..f777648ec 100644 --- a/src/h1/dispatcher.rs +++ b/src/h1/dispatcher.rs @@ -1,7 +1,7 @@ // #![allow(unused_imports, unused_variables, dead_code)] use std::collections::VecDeque; use std::fmt::{Debug, Display}; -// use std::time::{Duration, Instant}; +use std::time::Instant; use actix_net::service::Service; @@ -9,7 +9,7 @@ use futures::{Async, AsyncSink, Future, Poll, Sink, Stream}; use tokio_codec::Framed; // use tokio_current_thread::spawn; use tokio_io::{AsyncRead, AsyncWrite}; -// use tokio_timer::Delay; +use tokio_timer::Delay; use error::{ParseError, PayloadError}; use payload::{Payload, PayloadStatus, PayloadWriter}; @@ -47,12 +47,14 @@ where flags: Flags, framed: Framed, error: Option>, + config: ServiceConfig, state: State, payload: Option, messages: VecDeque, - config: ServiceConfig, + ka_expire: Instant, + ka_timer: Option, } enum State { @@ -81,9 +83,28 @@ where { /// Create http/1 dispatcher. pub fn new(stream: T, config: ServiceConfig, service: S) -> Self { - let flags = Flags::FLUSHED; + Dispatcher::with_timeout(stream, config, None, service) + } + + /// Create http/1 dispatcher with slow request timeout. + pub fn with_timeout( + stream: T, config: ServiceConfig, timeout: Option, service: S, + ) -> Self { + let flags = if config.keep_alive_enabled() { + Flags::KEEPALIVE | Flags::KEEPALIVE_ENABLED | Flags::FLUSHED + } else { + Flags::FLUSHED + }; let framed = Framed::new(stream, Codec::new()); + let (ka_expire, ka_timer) = if let Some(delay) = timeout { + (delay.deadline(), Some(delay)) + } else if let Some(delay) = config.keep_alive_timer() { + (delay.deadline(), Some(delay)) + } else { + (config.now(), None) + }; + Dispatcher { payload: None, state: State::None, @@ -93,6 +114,8 @@ where flags, framed, config, + ka_expire, + ka_timer, } } @@ -358,8 +381,64 @@ where } } + if self.ka_timer.is_some() && updated { + if let Some(expire) = self.config.keep_alive_expire() { + self.ka_expire = expire; + } + } Ok(updated) } + + /// keep-alive timer + fn poll_keepalive(&mut self) -> Result<(), DispatchError> { + if let Some(ref mut timer) = self.ka_timer { + match timer.poll() { + Ok(Async::Ready(_)) => { + if timer.deadline() >= self.ka_expire { + // check for any outstanding request handling + if self.state.is_empty() && self.messages.is_empty() { + // if we get timer during shutdown, just drop connection + if self.flags.contains(Flags::SHUTDOWN) { + return Err(DispatchError::DisconnectTimeout); + } else if !self.flags.contains(Flags::STARTED) { + // timeout on first request (slow request) return 408 + trace!("Slow request timeout"); + self.flags + .insert(Flags::STARTED | Flags::READ_DISCONNECTED); + self.state = + State::SendResponse(Some(OutMessage::Response( + HttpResponse::RequestTimeout().finish(), + ))); + } else { + trace!("Keep-alive timeout, close connection"); + self.flags.insert(Flags::SHUTDOWN); + + // start shutdown timer + if let Some(deadline) = + self.config.client_disconnect_timer() + { + timer.reset(deadline) + } else { + return Ok(()); + } + } + } else if let Some(deadline) = self.config.keep_alive_expire() { + timer.reset(deadline) + } + } else { + timer.reset(self.ka_expire) + } + } + Ok(Async::NotReady) => (), + Err(e) => { + error!("Timer error {:?}", e); + return Err(DispatchError::Unknown); + } + } + } + + Ok(()) + } } impl Future for Dispatcher @@ -373,6 +452,8 @@ where #[inline] fn poll(&mut self) -> Poll<(), Self::Error> { + self.poll_keepalive()?; + // shutdown if self.flags.contains(Flags::SHUTDOWN) { if self.flags.contains(Flags::WRITE_DISCONNECTED) { diff --git a/src/lib.rs b/src/lib.rs index b9f90c7a2..6215bc4fb 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -79,7 +79,7 @@ //! #![cfg_attr(actix_nightly, feature(tool_lints))] // #![warn(missing_docs)] -// #![allow(unused_imports, unused_variables, dead_code)] +#![allow(dead_code)] extern crate actix; extern crate actix_net; diff --git a/src/server/output.rs b/src/server/output.rs index cfc85e4bc..5fc6fc839 100644 --- a/src/server/output.rs +++ b/src/server/output.rs @@ -4,26 +4,15 @@ use std::io::Write; use std::str::FromStr; use std::{cmp, fmt, io, mem}; -#[cfg(feature = "brotli")] -use brotli2::write::BrotliEncoder; -use bytes::BytesMut; -#[cfg(feature = "flate2")] -use flate2::write::{GzEncoder, ZlibEncoder}; -#[cfg(feature = "flate2")] -use flate2::Compression; +use bytes::{Bytes, BytesMut}; use http::header::{HeaderValue, ACCEPT_ENCODING, CONTENT_LENGTH}; use http::{StatusCode, Version}; use body::{Binary, Body}; use header::ContentEncoding; +use http::Method; use httpresponse::HttpResponse; -use request::InnerRequest; - -// #[derive(Debug)] -// pub(crate) struct RequestInfo { -// pub version: Version, -// pub accept_encoding: Option, -// } +use request::Request; #[derive(Debug)] pub(crate) enum ResponseLength { @@ -38,285 +27,91 @@ pub(crate) enum ResponseLength { pub(crate) struct ResponseInfo { head: bool, pub length: ResponseLength, - pub content_encoding: Option<&'static str>, + pub te: TransferEncoding, +} + +impl Default for ResponseInfo { + fn default() -> Self { + ResponseInfo { + head: false, + length: ResponseLength::None, + te: TransferEncoding::empty(), + } + } } impl ResponseInfo { - pub fn new(head: bool) -> Self { - ResponseInfo { - head, - length: ResponseLength::None, - content_encoding: None, - } - } -} + pub fn update(&mut self, resp: &mut HttpResponse, head: bool, version: Version) { + self.head = head; -#[derive(Debug)] -pub(crate) enum Output { - Empty(BytesMut), - Buffer(BytesMut), - Encoder(ContentEncoder), - TE(TransferEncoding), - Done, -} - -impl Output { - pub fn take(&mut self) -> BytesMut { - match mem::replace(self, Output::Done) { - Output::Empty(bytes) => bytes, - Output::Buffer(bytes) => bytes, - Output::Encoder(mut enc) => enc.take_buf(), - Output::TE(mut te) => te.take(), - Output::Done => panic!(), - } - } - - pub fn take_option(&mut self) -> Option { - match mem::replace(self, Output::Done) { - Output::Empty(bytes) => Some(bytes), - Output::Buffer(bytes) => Some(bytes), - Output::Encoder(mut enc) => Some(enc.take_buf()), - Output::TE(mut te) => Some(te.take()), - Output::Done => None, - } - } - - pub fn as_ref(&mut self) -> &BytesMut { - match self { - Output::Empty(ref mut bytes) => bytes, - Output::Buffer(ref mut bytes) => bytes, - Output::Encoder(ref mut enc) => enc.buf_ref(), - Output::TE(ref mut te) => te.buf_ref(), - Output::Done => panic!(), - } - } - pub fn as_mut(&mut self) -> &mut BytesMut { - match self { - Output::Empty(ref mut bytes) => bytes, - Output::Buffer(ref mut bytes) => bytes, - Output::Encoder(ref mut enc) => enc.buf_mut(), - Output::TE(ref mut te) => te.buf_mut(), - Output::Done => panic!(), - } - } - pub fn split_to(&mut self, cap: usize) -> BytesMut { - match self { - Output::Empty(ref mut bytes) => bytes.split_to(cap), - Output::Buffer(ref mut bytes) => bytes.split_to(cap), - Output::Encoder(ref mut enc) => enc.buf_mut().split_to(cap), - Output::TE(ref mut te) => te.buf_mut().split_to(cap), - Output::Done => BytesMut::new(), - } - } - - pub fn len(&self) -> usize { - match self { - Output::Empty(ref bytes) => bytes.len(), - Output::Buffer(ref bytes) => bytes.len(), - Output::Encoder(ref enc) => enc.len(), - Output::TE(ref te) => te.len(), - Output::Done => 0, - } - } - - pub fn is_empty(&self) -> bool { - match self { - Output::Empty(ref bytes) => bytes.is_empty(), - Output::Buffer(ref bytes) => bytes.is_empty(), - Output::Encoder(ref enc) => enc.is_empty(), - Output::TE(ref te) => te.is_empty(), - Output::Done => true, - } - } - - pub fn write(&mut self, data: &[u8]) -> Result<(), io::Error> { - match self { - Output::Buffer(ref mut bytes) => { - bytes.extend_from_slice(data); - Ok(()) - } - Output::Encoder(ref mut enc) => enc.write(data), - Output::TE(ref mut te) => te.encode(data).map(|_| ()), - Output::Empty(_) | Output::Done => Ok(()), - } - } - - pub fn write_eof(&mut self) -> Result { - match self { - Output::Buffer(_) => Ok(true), - Output::Encoder(ref mut enc) => enc.write_eof(), - Output::TE(ref mut te) => Ok(te.encode_eof()), - Output::Empty(_) | Output::Done => Ok(true), - } - } - - pub(crate) fn for_server( - &mut self, info: &mut ResponseInfo, req: &InnerRequest, resp: &mut HttpResponse, - response_encoding: ContentEncoding, - ) { - let buf = self.take(); - let version = resp.version().unwrap_or_else(|| req.version); + let version = resp.version().unwrap_or_else(|| version); let mut len = 0; let has_body = match resp.body() { Body::Empty => false, Body::Binary(ref bin) => { len = bin.len(); - !(response_encoding == ContentEncoding::Auto && len < 96) + true } _ => true, }; - // Enable content encoding only if response does not contain Content-Encoding - // header - #[cfg(any(feature = "brotli", feature = "flate2"))] - let mut encoding = if has_body { - let encoding = match response_encoding { - ContentEncoding::Auto => { - // negotiate content-encoding - if let Some(val) = req.headers.get(ACCEPT_ENCODING) { - if let Ok(enc) = val.to_str() { - AcceptEncoding::parse(enc) - } else { - ContentEncoding::Identity - } - } else { - ContentEncoding::Identity - } - } - encoding => encoding, - }; - if encoding.is_compression() { - info.content_encoding = Some(encoding.as_str()); - } - encoding - } else { - ContentEncoding::Identity + let has_body = match resp.body() { + Body::Empty => false, + _ => true, }; - #[cfg(not(any(feature = "brotli", feature = "flate2")))] - let mut encoding = ContentEncoding::Identity; let transfer = match resp.body() { Body::Empty => { - if !info.head { - info.length = match resp.status() { + if !self.head { + self.length = match resp.status() { StatusCode::NO_CONTENT | StatusCode::CONTINUE | StatusCode::SWITCHING_PROTOCOLS | StatusCode::PROCESSING => ResponseLength::None, _ => ResponseLength::Zero, }; + } else { + self.length = ResponseLength::Zero; } - *self = Output::Empty(buf); - return; + TransferEncoding::empty() } Body::Binary(_) => { - #[cfg(any(feature = "brotli", feature = "flate2"))] - { - if !(encoding == ContentEncoding::Identity - || encoding == ContentEncoding::Auto) - { - let mut tmp = BytesMut::new(); - let mut transfer = TransferEncoding::eof(tmp); - let mut enc = match encoding { - #[cfg(feature = "flate2")] - ContentEncoding::Deflate => ContentEncoder::Deflate( - ZlibEncoder::new(transfer, Compression::fast()), - ), - #[cfg(feature = "flate2")] - ContentEncoding::Gzip => ContentEncoder::Gzip( - GzEncoder::new(transfer, Compression::fast()), - ), - #[cfg(feature = "brotli")] - ContentEncoding::Br => { - ContentEncoder::Br(BrotliEncoder::new(transfer, 3)) - } - ContentEncoding::Identity | ContentEncoding::Auto => { - unreachable!() - } - }; - - let bin = resp.replace_body(Body::Empty).binary(); - - // TODO return error! - let _ = enc.write(bin.as_ref()); - let _ = enc.write_eof(); - let body = enc.buf_mut().take(); - len = body.len(); - resp.replace_body(Binary::from(body)); - } - } - - info.length = ResponseLength::Length(len); - if info.head { - *self = Output::Empty(buf); - } else { - *self = Output::Buffer(buf); - } - return; + self.length = ResponseLength::Length(len); + TransferEncoding::length(len as u64) } Body::Streaming(_) => { if resp.upgrade() { - if version == Version::HTTP_2 { - error!("Connection upgrade is forbidden for HTTP/2"); - } - if encoding != ContentEncoding::Identity { - encoding = ContentEncoding::Identity; - info.content_encoding.take(); - } - TransferEncoding::eof(buf) + self.length = ResponseLength::None; + TransferEncoding::eof() } else { - if !(encoding == ContentEncoding::Identity - || encoding == ContentEncoding::Auto) - { - resp.headers_mut().remove(CONTENT_LENGTH); - } - Output::streaming_encoding(info, buf, version, resp) + self.streaming_encoding(version, resp) } } }; // check for head response - if info.head { + if self.head { resp.set_body(Body::Empty); - *self = Output::Empty(transfer.buf.unwrap()); - return; + } else { + self.te = transfer; } - - let enc = match encoding { - #[cfg(feature = "flate2")] - ContentEncoding::Deflate => { - ContentEncoder::Deflate(ZlibEncoder::new(transfer, Compression::fast())) - } - #[cfg(feature = "flate2")] - ContentEncoding::Gzip => { - ContentEncoder::Gzip(GzEncoder::new(transfer, Compression::fast())) - } - #[cfg(feature = "brotli")] - ContentEncoding::Br => ContentEncoder::Br(BrotliEncoder::new(transfer, 3)), - ContentEncoding::Identity | ContentEncoding::Auto => { - *self = Output::TE(transfer); - return; - } - }; - *self = Output::Encoder(enc); } fn streaming_encoding( - info: &mut ResponseInfo, buf: BytesMut, version: Version, - resp: &mut HttpResponse, + &mut self, version: Version, resp: &mut HttpResponse, ) -> TransferEncoding { match resp.chunked() { Some(true) => { // Enable transfer encoding if version == Version::HTTP_2 { - info.length = ResponseLength::None; - TransferEncoding::eof(buf) + self.length = ResponseLength::None; + TransferEncoding::eof() } else { - info.length = ResponseLength::Chunked; - TransferEncoding::chunked(buf) + self.length = ResponseLength::Chunked; + TransferEncoding::chunked() } } - Some(false) => TransferEncoding::eof(buf), + Some(false) => TransferEncoding::eof(), None => { // if Content-Length is specified, then use it as length hint let (len, chunked) = @@ -339,21 +134,21 @@ impl Output { if !chunked { if let Some(len) = len { - info.length = ResponseLength::Length64(len); - TransferEncoding::length(len, buf) + self.length = ResponseLength::Length64(len); + TransferEncoding::length(len) } else { - TransferEncoding::eof(buf) + TransferEncoding::eof() } } else { // Enable transfer encoding match version { Version::HTTP_11 => { - info.length = ResponseLength::Chunked; - TransferEncoding::chunked(buf) + self.length = ResponseLength::Chunked; + TransferEncoding::chunked() } _ => { - info.length = ResponseLength::None; - TransferEncoding::eof(buf) + self.length = ResponseLength::None; + TransferEncoding::eof() } } } @@ -362,178 +157,9 @@ impl Output { } } -pub(crate) enum ContentEncoder { - #[cfg(feature = "flate2")] - Deflate(ZlibEncoder), - #[cfg(feature = "flate2")] - Gzip(GzEncoder), - #[cfg(feature = "brotli")] - Br(BrotliEncoder), - Identity(TransferEncoding), -} - -impl fmt::Debug for ContentEncoder { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match *self { - #[cfg(feature = "brotli")] - ContentEncoder::Br(_) => writeln!(f, "ContentEncoder(Brotli)"), - #[cfg(feature = "flate2")] - ContentEncoder::Deflate(_) => writeln!(f, "ContentEncoder(Deflate)"), - #[cfg(feature = "flate2")] - ContentEncoder::Gzip(_) => writeln!(f, "ContentEncoder(Gzip)"), - ContentEncoder::Identity(_) => writeln!(f, "ContentEncoder(Identity)"), - } - } -} - -impl ContentEncoder { - #[inline] - pub fn len(&self) -> usize { - match *self { - #[cfg(feature = "brotli")] - ContentEncoder::Br(ref encoder) => encoder.get_ref().len(), - #[cfg(feature = "flate2")] - ContentEncoder::Deflate(ref encoder) => encoder.get_ref().len(), - #[cfg(feature = "flate2")] - ContentEncoder::Gzip(ref encoder) => encoder.get_ref().len(), - ContentEncoder::Identity(ref encoder) => encoder.len(), - } - } - - #[inline] - pub fn is_empty(&self) -> bool { - match *self { - #[cfg(feature = "brotli")] - ContentEncoder::Br(ref encoder) => encoder.get_ref().is_empty(), - #[cfg(feature = "flate2")] - ContentEncoder::Deflate(ref encoder) => encoder.get_ref().is_empty(), - #[cfg(feature = "flate2")] - ContentEncoder::Gzip(ref encoder) => encoder.get_ref().is_empty(), - ContentEncoder::Identity(ref encoder) => encoder.is_empty(), - } - } - - #[inline] - pub(crate) fn take_buf(&mut self) -> BytesMut { - match *self { - #[cfg(feature = "brotli")] - ContentEncoder::Br(ref mut encoder) => encoder.get_mut().take(), - #[cfg(feature = "flate2")] - ContentEncoder::Deflate(ref mut encoder) => encoder.get_mut().take(), - #[cfg(feature = "flate2")] - ContentEncoder::Gzip(ref mut encoder) => encoder.get_mut().take(), - ContentEncoder::Identity(ref mut encoder) => encoder.take(), - } - } - - #[inline] - pub(crate) fn buf_mut(&mut self) -> &mut BytesMut { - match *self { - #[cfg(feature = "brotli")] - ContentEncoder::Br(ref mut encoder) => encoder.get_mut().buf_mut(), - #[cfg(feature = "flate2")] - ContentEncoder::Deflate(ref mut encoder) => encoder.get_mut().buf_mut(), - #[cfg(feature = "flate2")] - ContentEncoder::Gzip(ref mut encoder) => encoder.get_mut().buf_mut(), - ContentEncoder::Identity(ref mut encoder) => encoder.buf_mut(), - } - } - - #[inline] - pub(crate) fn buf_ref(&mut self) -> &BytesMut { - match *self { - #[cfg(feature = "brotli")] - ContentEncoder::Br(ref mut encoder) => encoder.get_mut().buf_ref(), - #[cfg(feature = "flate2")] - ContentEncoder::Deflate(ref mut encoder) => encoder.get_mut().buf_ref(), - #[cfg(feature = "flate2")] - ContentEncoder::Gzip(ref mut encoder) => encoder.get_mut().buf_ref(), - ContentEncoder::Identity(ref mut encoder) => encoder.buf_ref(), - } - } - - #[cfg_attr(feature = "cargo-clippy", allow(clippy::inline_always))] - #[inline(always)] - pub fn write_eof(&mut self) -> Result { - let encoder = - mem::replace(self, ContentEncoder::Identity(TransferEncoding::empty())); - - match encoder { - #[cfg(feature = "brotli")] - ContentEncoder::Br(encoder) => match encoder.finish() { - Ok(mut writer) => { - writer.encode_eof(); - *self = ContentEncoder::Identity(writer); - Ok(true) - } - Err(err) => Err(err), - }, - #[cfg(feature = "flate2")] - ContentEncoder::Gzip(encoder) => match encoder.finish() { - Ok(mut writer) => { - writer.encode_eof(); - *self = ContentEncoder::Identity(writer); - Ok(true) - } - Err(err) => Err(err), - }, - #[cfg(feature = "flate2")] - ContentEncoder::Deflate(encoder) => match encoder.finish() { - Ok(mut writer) => { - writer.encode_eof(); - *self = ContentEncoder::Identity(writer); - Ok(true) - } - Err(err) => Err(err), - }, - ContentEncoder::Identity(mut writer) => { - let res = writer.encode_eof(); - *self = ContentEncoder::Identity(writer); - Ok(res) - } - } - } - - #[cfg_attr(feature = "cargo-clippy", allow(clippy::inline_always))] - #[inline(always)] - pub fn write(&mut self, data: &[u8]) -> Result<(), io::Error> { - match *self { - #[cfg(feature = "brotli")] - ContentEncoder::Br(ref mut encoder) => match encoder.write_all(data) { - Ok(_) => Ok(()), - Err(err) => { - trace!("Error decoding br encoding: {}", err); - Err(err) - } - }, - #[cfg(feature = "flate2")] - ContentEncoder::Gzip(ref mut encoder) => match encoder.write_all(data) { - Ok(_) => Ok(()), - Err(err) => { - trace!("Error decoding gzip encoding: {}", err); - Err(err) - } - }, - #[cfg(feature = "flate2")] - ContentEncoder::Deflate(ref mut encoder) => match encoder.write_all(data) { - Ok(_) => Ok(()), - Err(err) => { - trace!("Error decoding deflate encoding: {}", err); - Err(err) - } - }, - ContentEncoder::Identity(ref mut encoder) => { - encoder.encode(data)?; - Ok(()) - } - } - } -} - /// Encoders to handle different Transfer-Encodings. #[derive(Debug)] pub(crate) struct TransferEncoding { - buf: Option, kind: TransferEncodingKind, } @@ -552,65 +178,41 @@ enum TransferEncodingKind { } impl TransferEncoding { - fn take(&mut self) -> BytesMut { - self.buf.take().unwrap() - } - - fn buf_ref(&mut self) -> &BytesMut { - self.buf.as_ref().unwrap() - } - - fn len(&self) -> usize { - self.buf.as_ref().unwrap().len() - } - - fn is_empty(&self) -> bool { - self.buf.as_ref().unwrap().is_empty() - } - - fn buf_mut(&mut self) -> &mut BytesMut { - self.buf.as_mut().unwrap() - } - #[inline] pub fn empty() -> TransferEncoding { TransferEncoding { - buf: None, kind: TransferEncodingKind::Eof, } } #[inline] - pub fn eof(buf: BytesMut) -> TransferEncoding { + pub fn eof() -> TransferEncoding { TransferEncoding { - buf: Some(buf), kind: TransferEncodingKind::Eof, } } #[inline] - pub fn chunked(buf: BytesMut) -> TransferEncoding { + pub fn chunked() -> TransferEncoding { TransferEncoding { - buf: Some(buf), kind: TransferEncodingKind::Chunked(false), } } #[inline] - pub fn length(len: u64, buf: BytesMut) -> TransferEncoding { + pub fn length(len: u64) -> TransferEncoding { TransferEncoding { - buf: Some(buf), kind: TransferEncodingKind::Length(len), } } /// Encode message. Return `EOF` state of encoder #[inline] - pub fn encode(&mut self, msg: &[u8]) -> io::Result { + pub fn encode(&mut self, msg: &[u8], buf: &mut BytesMut) -> io::Result { match self.kind { TransferEncodingKind::Eof => { let eof = msg.is_empty(); - self.buf.as_mut().unwrap().extend_from_slice(msg); + buf.extend_from_slice(msg); Ok(eof) } TransferEncodingKind::Chunked(ref mut eof) => { @@ -620,17 +222,14 @@ impl TransferEncoding { if msg.is_empty() { *eof = true; - self.buf.as_mut().unwrap().extend_from_slice(b"0\r\n\r\n"); + buf.extend_from_slice(b"0\r\n\r\n"); } else { - let mut buf = BytesMut::new(); - writeln!(&mut buf, "{:X}\r", msg.len()) + writeln!(buf.as_mut(), "{:X}\r", msg.len()) .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?; - let b = self.buf.as_mut().unwrap(); - b.reserve(buf.len() + msg.len() + 2); - b.extend_from_slice(buf.as_ref()); - b.extend_from_slice(msg); - b.extend_from_slice(b"\r\n"); + buf.reserve(msg.len() + 2); + buf.extend_from_slice(msg); + buf.extend_from_slice(b"\r\n"); } Ok(*eof) } @@ -641,10 +240,7 @@ impl TransferEncoding { } let len = cmp::min(*remaining, msg.len() as u64); - self.buf - .as_mut() - .unwrap() - .extend_from_slice(&msg[..len as usize]); + buf.extend_from_slice(&msg[..len as usize]); *remaining -= len as u64; Ok(*remaining == 0) @@ -657,14 +253,14 @@ impl TransferEncoding { /// Encode eof. Return `EOF` state of encoder #[inline] - pub fn encode_eof(&mut self) -> bool { + pub fn encode_eof(&mut self, buf: &mut BytesMut) -> bool { match self.kind { TransferEncodingKind::Eof => true, TransferEncodingKind::Length(rem) => rem == 0, TransferEncodingKind::Chunked(ref mut eof) => { if !*eof { *eof = true; - self.buf.as_mut().unwrap().extend_from_slice(b"0\r\n\r\n"); + buf.extend_from_slice(b"0\r\n\r\n"); } true } @@ -675,9 +271,9 @@ impl TransferEncoding { impl io::Write for TransferEncoding { #[inline] fn write(&mut self, buf: &[u8]) -> io::Result { - if self.buf.is_some() { - self.encode(buf)?; - } + // if self.buf.is_some() { + // self.encode(buf)?; + // } Ok(buf.len()) }