1
0
Fork 0
mirror of https://github.com/actix/actix-web.git synced 2025-01-04 22:38:44 +00:00

refactor encoder/decoder impl

This commit is contained in:
Nikolay Kim 2018-11-18 17:52:56 -08:00
parent 8fea1367c7
commit adad203314
17 changed files with 255 additions and 234 deletions

View file

@ -91,9 +91,11 @@ native-tls = { version="0.2", optional = true }
# openssl # openssl
openssl = { version="0.10", optional = true } openssl = { version="0.10", optional = true }
#rustls # rustls
rustls = { version = "^0.14", optional = true } rustls = { version = "^0.14", optional = true }
backtrace="*"
[dev-dependencies] [dev-dependencies]
actix-web = "0.7" actix-web = "0.7"
env_logger = "0.5" env_logger = "0.5"

View file

@ -9,7 +9,7 @@ use error::{Error, PayloadError};
/// Type represent streaming payload /// Type represent streaming payload
pub type PayloadStream = Box<dyn Stream<Item = Bytes, Error = PayloadError>>; pub type PayloadStream = Box<dyn Stream<Item = Bytes, Error = PayloadError>>;
#[derive(Debug, PartialEq)] #[derive(Debug, PartialEq, Copy, Clone)]
/// Different type of body /// Different type of body
pub enum BodyLength { pub enum BodyLength {
None, None,
@ -76,10 +76,11 @@ impl MessageBody for Body {
Body::None => Ok(Async::Ready(None)), Body::None => Ok(Async::Ready(None)),
Body::Empty => Ok(Async::Ready(None)), Body::Empty => Ok(Async::Ready(None)),
Body::Bytes(ref mut bin) => { Body::Bytes(ref mut bin) => {
if bin.len() == 0 { let len = bin.len();
if len == 0 {
Ok(Async::Ready(None)) Ok(Async::Ready(None))
} else { } else {
Ok(Async::Ready(Some(bin.slice_to(bin.len())))) Ok(Async::Ready(Some(bin.split_to(len))))
} }
} }
Body::Message(ref mut body) => body.poll_next(), Body::Message(ref mut body) => body.poll_next(),

View file

@ -33,7 +33,7 @@ where
.from_err() .from_err()
// create Framed and send reqest // create Framed and send reqest
.map(|io| Framed::new(io, h1::ClientCodec::default())) .map(|io| Framed::new(io, h1::ClientCodec::default()))
.and_then(|framed| framed.send((head, len).into()).from_err()) .and_then(move |framed| framed.send((head, len).into()).from_err())
// send request body // send request body
.and_then(move |framed| match body.length() { .and_then(move |framed| match body.length() {
BodyLength::None | BodyLength::Empty | BodyLength::Sized(0) => { BodyLength::None | BodyLength::Empty | BodyLength::Sized(0) => {

View file

@ -16,7 +16,7 @@ use http::{
uri, Error as HttpError, HeaderMap, HeaderName, HeaderValue, HttpTryFrom, Method, uri, Error as HttpError, HeaderMap, HeaderName, HeaderValue, HttpTryFrom, Method,
Uri, Version, Uri, Version,
}; };
use message::RequestHead; use message::{Head, RequestHead};
use super::response::ClientResponse; use super::response::ClientResponse;
use super::{pipeline, Connect, Connection, ConnectorError, SendRequestError}; use super::{pipeline, Connect, Connection, ConnectorError, SendRequestError};
@ -365,8 +365,21 @@ impl ClientRequestBuilder {
where where
V: IntoHeaderValue, V: IntoHeaderValue,
{ {
{
if let Some(parts) = parts(&mut self.head, &self.err) {
parts.set_upgrade();
}
}
self.set_header(header::UPGRADE, value) self.set_header(header::UPGRADE, value)
.set_header(header::CONNECTION, "upgrade") }
/// Close connection
#[inline]
pub fn close(&mut self) -> &mut Self {
if let Some(parts) = parts(&mut self.head, &self.err) {
parts.force_close();
}
self
} }
/// Set request's content type /// Set request's content type

View file

@ -8,7 +8,7 @@ use http::{HeaderMap, StatusCode, Version};
use body::PayloadStream; use body::PayloadStream;
use error::PayloadError; use error::PayloadError;
use httpmessage::HttpMessage; use httpmessage::HttpMessage;
use message::{MessageFlags, ResponseHead}; use message::{Head, ResponseHead};
use super::pipeline::Payload; use super::pipeline::Payload;
@ -81,7 +81,7 @@ impl ClientResponse {
/// Checks if a connection should be kept alive. /// Checks if a connection should be kept alive.
#[inline] #[inline]
pub fn keep_alive(&self) -> bool { pub fn keep_alive(&self) -> bool {
self.head().flags.contains(MessageFlags::KEEPALIVE) self.head().keep_alive()
} }
} }

View file

@ -16,7 +16,7 @@ use http::header::{
HeaderValue, CONNECTION, CONTENT_LENGTH, DATE, TRANSFER_ENCODING, UPGRADE, HeaderValue, CONNECTION, CONTENT_LENGTH, DATE, TRANSFER_ENCODING, UPGRADE,
}; };
use http::{Method, Version}; use http::{Method, Version};
use message::{MessagePool, RequestHead}; use message::{Head, MessagePool, RequestHead};
bitflags! { bitflags! {
struct Flags: u8 { struct Flags: u8 {
@ -135,7 +135,7 @@ fn prn_version(ver: Version) -> &'static str {
} }
impl ClientCodecInner { impl ClientCodecInner {
fn encode_response( fn encode_request(
&mut self, &mut self,
msg: RequestHead, msg: RequestHead,
length: BodyLength, length: BodyLength,
@ -146,7 +146,7 @@ impl ClientCodecInner {
// status line // status line
write!( write!(
Writer(buffer), Writer(buffer),
"{} {} {}", "{} {} {}\r\n",
msg.method, msg.method,
msg.uri.path_and_query().map(|u| u.as_str()).unwrap_or("/"), msg.uri.path_and_query().map(|u| u.as_str()).unwrap_or("/"),
prn_version(msg.version) prn_version(msg.version)
@ -156,38 +156,26 @@ impl ClientCodecInner {
buffer.reserve(msg.headers.len() * AVERAGE_HEADER_SIZE); buffer.reserve(msg.headers.len() * AVERAGE_HEADER_SIZE);
// content length // content length
let mut len_is_set = true;
match length { match length {
BodyLength::Sized(len) => helpers::write_content_length(len, buffer), BodyLength::Sized(len) => helpers::write_content_length(len, buffer),
BodyLength::Sized64(len) => { BodyLength::Sized64(len) => {
buffer.extend_from_slice(b"\r\ncontent-length: "); buffer.extend_from_slice(b"content-length: ");
write!(buffer.writer(), "{}", len)?; write!(buffer.writer(), "{}", len)?;
buffer.extend_from_slice(b"\r\n"); buffer.extend_from_slice(b"\r\n");
} }
BodyLength::Chunked => { BodyLength::Chunked => {
buffer.extend_from_slice(b"\r\ntransfer-encoding: chunked\r\n") buffer.extend_from_slice(b"transfer-encoding: chunked\r\n")
}
BodyLength::Empty => {
len_is_set = false;
buffer.extend_from_slice(b"\r\n")
}
BodyLength::None | BodyLength::Stream => {
buffer.extend_from_slice(b"\r\n")
} }
BodyLength::Empty => buffer.extend_from_slice(b"content-length: 0\r\n"),
BodyLength::None | BodyLength::Stream => (),
} }
let mut has_date = false; let mut has_date = false;
for (key, value) in &msg.headers { for (key, value) in &msg.headers {
match *key { match *key {
TRANSFER_ENCODING => continue, TRANSFER_ENCODING | CONNECTION | CONTENT_LENGTH => continue,
CONTENT_LENGTH => match length {
BodyLength::None => (),
BodyLength::Empty => len_is_set = true,
_ => continue,
},
DATE => has_date = true, DATE => has_date = true,
UPGRADE => self.flags.insert(Flags::UPGRADE),
_ => (), _ => (),
} }
@ -197,12 +185,19 @@ impl ClientCodecInner {
buffer.put_slice(b"\r\n"); buffer.put_slice(b"\r\n");
} }
// set content length // Connection header
if !len_is_set { if msg.upgrade() {
buffer.extend_from_slice(b"content-length: 0\r\n") self.flags.set(Flags::UPGRADE, msg.upgrade());
buffer.extend_from_slice(b"connection: upgrade\r\n");
} else if msg.keep_alive() {
if self.version < Version::HTTP_11 {
buffer.extend_from_slice(b"connection: keep-alive\r\n");
}
} else if self.version >= Version::HTTP_11 {
buffer.extend_from_slice(b"connection: close\r\n");
} }
// set date header // Date header
if !has_date { if !has_date {
self.config.set_date(buffer); self.config.set_date(buffer);
} else { } else {
@ -276,7 +271,7 @@ impl Encoder for ClientCodec {
) -> Result<(), Self::Error> { ) -> Result<(), Self::Error> {
match item { match item {
Message::Item((msg, btype)) => { Message::Item((msg, btype)) => {
self.inner.encode_response(msg, btype, dst)?; self.inner.encode_request(msg, btype, dst)?;
} }
Message::Chunk(Some(bytes)) => { Message::Chunk(Some(bytes)) => {
self.inner.te.encode(bytes.as_ref(), dst)?; self.inner.te.encode(bytes.as_ref(), dst)?;

View file

@ -13,8 +13,8 @@ use config::ServiceConfig;
use error::ParseError; use error::ParseError;
use helpers; use helpers;
use http::header::{HeaderValue, CONNECTION, CONTENT_LENGTH, DATE, TRANSFER_ENCODING}; use http::header::{HeaderValue, CONNECTION, CONTENT_LENGTH, DATE, TRANSFER_ENCODING};
use http::{Method, Version}; use http::{Method, StatusCode, Version};
use message::ResponseHead; use message::{Head, ResponseHead};
use request::Request; use request::Request;
use response::Response; use response::Response;
@ -99,69 +99,71 @@ impl Codec {
} }
/// prepare transfer encoding /// prepare transfer encoding
pub fn prepare_te(&mut self, head: &mut ResponseHead, length: &mut BodyLength) { fn prepare_te(&mut self, head: &mut ResponseHead, length: BodyLength) {
self.te self.te
.update(head, self.flags.contains(Flags::HEAD), self.version, length); .update(head, self.flags.contains(Flags::HEAD), self.version, length);
} }
fn encode_response( fn encode_response(
&mut self, &mut self,
mut msg: Response<()>, msg: &mut ResponseHead,
length: BodyLength,
buffer: &mut BytesMut, buffer: &mut BytesMut,
) -> io::Result<()> { ) -> io::Result<()> {
let ka = self.flags.contains(Flags::KEEPALIVE_ENABLED) && msg msg.version = self.version;
.keep_alive()
.unwrap_or_else(|| self.flags.contains(Flags::KEEPALIVE));
// Connection upgrade // Connection upgrade
if msg.upgrade() { if msg.upgrade() {
self.flags.insert(Flags::UPGRADE); self.flags.insert(Flags::UPGRADE);
self.flags.remove(Flags::KEEPALIVE); self.flags.remove(Flags::KEEPALIVE);
msg.headers_mut() msg.headers
.insert(CONNECTION, HeaderValue::from_static("upgrade")); .insert(CONNECTION, HeaderValue::from_static("upgrade"));
} }
// keep-alive // keep-alive
else if ka { else if self.flags.contains(Flags::KEEPALIVE_ENABLED) && msg.keep_alive() {
self.flags.insert(Flags::KEEPALIVE); self.flags.insert(Flags::KEEPALIVE);
if self.version < Version::HTTP_11 { if self.version < Version::HTTP_11 {
msg.headers_mut() msg.headers
.insert(CONNECTION, HeaderValue::from_static("keep-alive")); .insert(CONNECTION, HeaderValue::from_static("keep-alive"));
} }
} else if self.version >= Version::HTTP_11 { } else if self.version >= Version::HTTP_11 {
self.flags.remove(Flags::KEEPALIVE); self.flags.remove(Flags::KEEPALIVE);
msg.headers_mut() msg.headers
.insert(CONNECTION, HeaderValue::from_static("close")); .insert(CONNECTION, HeaderValue::from_static("close"));
} }
// render message // render message
{ {
let reason = msg.reason().as_bytes(); let reason = msg.reason().as_bytes();
buffer buffer.reserve(256 + msg.headers.len() * AVERAGE_HEADER_SIZE + reason.len());
.reserve(256 + msg.headers().len() * AVERAGE_HEADER_SIZE + reason.len());
// status line // status line
helpers::write_status_line(self.version, msg.status().as_u16(), buffer); helpers::write_status_line(self.version, msg.status.as_u16(), buffer);
buffer.extend_from_slice(reason); buffer.extend_from_slice(reason);
// content length // content length
let mut len_is_set = true; match msg.status {
match self.te.length { StatusCode::NO_CONTENT
BodyLength::Chunked => { | StatusCode::CONTINUE
buffer.extend_from_slice(b"\r\ntransfer-encoding: chunked\r\n") | StatusCode::SWITCHING_PROTOCOLS
} | StatusCode::PROCESSING => buffer.extend_from_slice(b"\r\n"),
BodyLength::Empty => { _ => match length {
len_is_set = false; BodyLength::Chunked => {
buffer.extend_from_slice(b"\r\n") buffer.extend_from_slice(b"\r\ntransfer-encoding: chunked\r\n")
} }
BodyLength::Sized(len) => helpers::write_content_length(len, buffer), BodyLength::Empty => {
BodyLength::Sized64(len) => { buffer.extend_from_slice(b"\r\ncontent-length: 0\r\n");
buffer.extend_from_slice(b"\r\ncontent-length: "); }
write!(buffer.writer(), "{}", len)?; BodyLength::Sized(len) => helpers::write_content_length(len, buffer),
buffer.extend_from_slice(b"\r\n"); BodyLength::Sized64(len) => {
} buffer.extend_from_slice(b"\r\ncontent-length: ");
BodyLength::None | BodyLength::Stream => { write!(buffer.writer(), "{}", len)?;
buffer.extend_from_slice(b"\r\n") buffer.extend_from_slice(b"\r\n");
} }
BodyLength::None | BodyLength::Stream => {
buffer.extend_from_slice(b"\r\n")
}
},
} }
// write headers // write headers
@ -169,16 +171,9 @@ impl Codec {
let mut has_date = false; let mut has_date = false;
let mut remaining = buffer.remaining_mut(); let mut remaining = buffer.remaining_mut();
let mut buf = unsafe { &mut *(buffer.bytes_mut() as *mut [u8]) }; let mut buf = unsafe { &mut *(buffer.bytes_mut() as *mut [u8]) };
for (key, value) in msg.headers() { for (key, value) in &msg.headers {
match *key { match *key {
TRANSFER_ENCODING => continue, TRANSFER_ENCODING | CONTENT_LENGTH => continue,
CONTENT_LENGTH => match self.te.length {
BodyLength::None => (),
BodyLength::Empty => {
len_is_set = true;
}
_ => continue,
},
DATE => { DATE => {
has_date = true; has_date = true;
} }
@ -213,9 +208,6 @@ impl Codec {
unsafe { unsafe {
buffer.advance_mut(pos); buffer.advance_mut(pos);
} }
if !len_is_set {
buffer.extend_from_slice(b"content-length: 0\r\n")
}
// optimized date header, set_date writes \r\n // optimized date header, set_date writes \r\n
if !has_date { if !has_date {
@ -268,7 +260,7 @@ impl Decoder for Codec {
} }
impl Encoder for Codec { impl Encoder for Codec {
type Item = Message<Response<()>>; type Item = Message<(Response<()>, BodyLength)>;
type Error = io::Error; type Error = io::Error;
fn encode( fn encode(
@ -277,8 +269,9 @@ impl Encoder for Codec {
dst: &mut BytesMut, dst: &mut BytesMut,
) -> Result<(), Self::Error> { ) -> Result<(), Self::Error> {
match item { match item {
Message::Item(res) => { Message::Item((mut res, length)) => {
self.encode_response(res, dst)?; self.prepare_te(res.head_mut(), length);
self.encode_response(res.head_mut(), length, dst)?;
} }
Message::Chunk(Some(bytes)) => { Message::Chunk(Some(bytes)) => {
self.te.encode(bytes.as_ref(), dst)?; self.te.encode(bytes.as_ref(), dst)?;

View file

@ -10,7 +10,7 @@ use client::ClientResponse;
use error::ParseError; use error::ParseError;
use http::header::{HeaderName, HeaderValue}; use http::header::{HeaderName, HeaderValue};
use http::{header, HeaderMap, HttpTryFrom, Method, StatusCode, Uri, Version}; use http::{header, HeaderMap, HttpTryFrom, Method, StatusCode, Uri, Version};
use message::MessageFlags; use message::Head;
use request::Request; use request::Request;
const MAX_BUFFER_SIZE: usize = 131_072; const MAX_BUFFER_SIZE: usize = 131_072;
@ -50,6 +50,8 @@ pub(crate) enum PayloadLength {
pub(crate) trait MessageTypeDecoder: Sized { pub(crate) trait MessageTypeDecoder: Sized {
fn keep_alive(&mut self); fn keep_alive(&mut self);
fn force_close(&mut self);
fn headers_mut(&mut self) -> &mut HeaderMap; fn headers_mut(&mut self) -> &mut HeaderMap;
fn decode(src: &mut BytesMut) -> Result<Option<(Self, PayloadType)>, ParseError>; fn decode(src: &mut BytesMut) -> Result<Option<(Self, PayloadType)>, ParseError>;
@ -137,6 +139,8 @@ pub(crate) trait MessageTypeDecoder: Sized {
if ka { if ka {
self.keep_alive(); self.keep_alive();
} else {
self.force_close();
} }
// https://tools.ietf.org/html/rfc7230#section-3.3.3 // https://tools.ietf.org/html/rfc7230#section-3.3.3
@ -160,7 +164,11 @@ pub(crate) trait MessageTypeDecoder: Sized {
impl MessageTypeDecoder for Request { impl MessageTypeDecoder for Request {
fn keep_alive(&mut self) { fn keep_alive(&mut self) {
self.inner_mut().flags.set(MessageFlags::KEEPALIVE); self.inner_mut().head.set_keep_alive()
}
fn force_close(&mut self) {
self.inner_mut().head.force_close()
} }
fn headers_mut(&mut self) -> &mut HeaderMap { fn headers_mut(&mut self) -> &mut HeaderMap {
@ -234,7 +242,11 @@ impl MessageTypeDecoder for Request {
impl MessageTypeDecoder for ClientResponse { impl MessageTypeDecoder for ClientResponse {
fn keep_alive(&mut self) { fn keep_alive(&mut self) {
self.head.flags.insert(MessageFlags::KEEPALIVE); self.head.set_keep_alive();
}
fn force_close(&mut self) {
self.head.force_close();
} }
fn headers_mut(&mut self) -> &mut HeaderMap { fn headers_mut(&mut self) -> &mut HeaderMap {

View file

@ -30,7 +30,6 @@ bitflags! {
const KEEPALIVE_ENABLED = 0b0000_0010; const KEEPALIVE_ENABLED = 0b0000_0010;
const KEEPALIVE = 0b0000_0100; const KEEPALIVE = 0b0000_0100;
const POLLED = 0b0000_1000; const POLLED = 0b0000_1000;
const FLUSHED = 0b0001_0000;
const SHUTDOWN = 0b0010_0000; const SHUTDOWN = 0b0010_0000;
const DISCONNECTED = 0b0100_0000; const DISCONNECTED = 0b0100_0000;
} }
@ -105,9 +104,9 @@ where
) -> Self { ) -> Self {
let keepalive = config.keep_alive_enabled(); let keepalive = config.keep_alive_enabled();
let flags = if keepalive { let flags = if keepalive {
Flags::KEEPALIVE | Flags::KEEPALIVE_ENABLED | Flags::FLUSHED Flags::KEEPALIVE | Flags::KEEPALIVE_ENABLED
} else { } else {
Flags::FLUSHED Flags::empty()
}; };
let framed = Framed::new(stream, Codec::new(config.clone())); let framed = Framed::new(stream, Codec::new(config.clone()));
@ -167,7 +166,7 @@ where
/// Flush stream /// Flush stream
fn poll_flush(&mut self) -> Poll<(), DispatchError<S::Error>> { fn poll_flush(&mut self) -> Poll<(), DispatchError<S::Error>> {
if !self.flags.contains(Flags::FLUSHED) { if !self.framed.is_write_buf_empty() {
match self.framed.poll_complete() { match self.framed.poll_complete() {
Ok(Async::NotReady) => Ok(Async::NotReady), Ok(Async::NotReady) => Ok(Async::NotReady),
Err(err) => { Err(err) => {
@ -179,7 +178,6 @@ where
if self.payload.is_some() && self.state.is_empty() { if self.payload.is_some() && self.state.is_empty() {
return Err(DispatchError::PayloadIsNotConsumed); return Err(DispatchError::PayloadIsNotConsumed);
} }
self.flags.insert(Flags::FLUSHED);
Ok(Async::Ready(())) Ok(Async::Ready(()))
} }
} }
@ -194,7 +192,7 @@ where
body: B1, body: B1,
) -> Result<State<S, B1>, DispatchError<S::Error>> { ) -> Result<State<S, B1>, DispatchError<S::Error>> {
self.framed self.framed
.force_send(Message::Item(message)) .force_send(Message::Item((message, body.length())))
.map_err(|err| { .map_err(|err| {
if let Some(mut payload) = self.payload.take() { if let Some(mut payload) = self.payload.take() {
payload.set_error(PayloadError::Incomplete(None)); payload.set_error(PayloadError::Incomplete(None));
@ -204,7 +202,6 @@ where
self.flags self.flags
.set(Flags::KEEPALIVE, self.framed.get_codec().keepalive()); .set(Flags::KEEPALIVE, self.framed.get_codec().keepalive());
self.flags.remove(Flags::FLUSHED);
match body.length() { match body.length() {
BodyLength::None | BodyLength::Empty => Ok(State::None), BodyLength::None | BodyLength::Empty => Ok(State::None),
_ => Ok(State::SendPayload(body)), _ => Ok(State::SendPayload(body)),
@ -228,10 +225,7 @@ where
State::ServiceCall(mut fut) => { State::ServiceCall(mut fut) => {
match fut.poll().map_err(DispatchError::Service)? { match fut.poll().map_err(DispatchError::Service)? {
Async::Ready(mut res) => { Async::Ready(mut res) => {
let (mut res, body) = res.replace_body(()); let (res, body) = res.replace_body(());
self.framed
.get_codec_mut()
.prepare_te(res.head_mut(), &mut body.length());
Some(self.send_response(res, body)?) Some(self.send_response(res, body)?)
} }
Async::NotReady => { Async::NotReady => {
@ -248,13 +242,11 @@ where
.map_err(|_| DispatchError::Unknown)? .map_err(|_| DispatchError::Unknown)?
{ {
Async::Ready(Some(item)) => { Async::Ready(Some(item)) => {
self.flags.remove(Flags::FLUSHED);
self.framed self.framed
.force_send(Message::Chunk(Some(item)))?; .force_send(Message::Chunk(Some(item)))?;
continue; continue;
} }
Async::Ready(None) => { Async::Ready(None) => {
self.flags.remove(Flags::FLUSHED);
self.framed.force_send(Message::Chunk(None))?; self.framed.force_send(Message::Chunk(None))?;
} }
Async::NotReady => { Async::NotReady => {
@ -296,10 +288,7 @@ where
let mut task = self.service.call(req); let mut task = self.service.call(req);
match task.poll().map_err(DispatchError::Service)? { match task.poll().map_err(DispatchError::Service)? {
Async::Ready(res) => { Async::Ready(res) => {
let (mut res, body) = res.replace_body(()); let (res, body) = res.replace_body(());
self.framed
.get_codec_mut()
.prepare_te(res.head_mut(), &mut body.length());
self.send_response(res, body) self.send_response(res, body)
} }
Async::NotReady => Ok(State::ServiceCall(task)), Async::NotReady => Ok(State::ServiceCall(task)),
@ -408,7 +397,7 @@ where
/// keep-alive timer /// keep-alive timer
fn poll_keepalive(&mut self) -> Result<(), DispatchError<S::Error>> { fn poll_keepalive(&mut self) -> Result<(), DispatchError<S::Error>> {
if self.ka_timer.is_some() { if self.ka_timer.is_none() {
return Ok(()); return Ok(());
} }
match self.ka_timer.as_mut().unwrap().poll().map_err(|e| { match self.ka_timer.as_mut().unwrap().poll().map_err(|e| {
@ -421,7 +410,7 @@ where
return Err(DispatchError::DisconnectTimeout); return Err(DispatchError::DisconnectTimeout);
} else if self.ka_timer.as_mut().unwrap().deadline() >= self.ka_expire { } else if self.ka_timer.as_mut().unwrap().deadline() >= self.ka_expire {
// check for any outstanding response processing // check for any outstanding response processing
if self.state.is_empty() && self.flags.contains(Flags::FLUSHED) { if self.state.is_empty() && self.framed.is_write_buf_empty() {
if self.flags.contains(Flags::STARTED) { if self.flags.contains(Flags::STARTED) {
trace!("Keep-alive timeout, close connection"); trace!("Keep-alive timeout, close connection");
self.flags.insert(Flags::SHUTDOWN); self.flags.insert(Flags::SHUTDOWN);
@ -490,12 +479,14 @@ where
inner.poll_response()?; inner.poll_response()?;
inner.poll_flush()?; inner.poll_flush()?;
if inner.flags.contains(Flags::DISCONNECTED) {
return Ok(Async::Ready(H1ServiceResult::Disconnected));
}
// keep-alive and stream errors // keep-alive and stream errors
if inner.state.is_empty() && inner.flags.contains(Flags::FLUSHED) { if inner.state.is_empty() && inner.framed.is_write_buf_empty() {
if let Some(err) = inner.error.take() { if let Some(err) = inner.error.take() {
return Err(err); return Err(err);
} else if inner.flags.contains(Flags::DISCONNECTED) {
return Ok(Async::Ready(H1ServiceResult::Disconnected));
} }
// unhandled request (upgrade or connect) // unhandled request (upgrade or connect)
else if inner.unhandled.is_some() { else if inner.unhandled.is_some() {

View file

@ -48,22 +48,13 @@ impl ResponseEncoder {
resp: &mut ResponseHead, resp: &mut ResponseHead,
head: bool, head: bool,
version: Version, version: Version,
length: &mut BodyLength, length: BodyLength,
) { ) {
self.head = head; self.head = head;
let transfer = match length { let transfer = match length {
BodyLength::Empty => { BodyLength::Empty => TransferEncoding::empty(),
match resp.status { BodyLength::Sized(len) => TransferEncoding::length(len as u64),
StatusCode::NO_CONTENT BodyLength::Sized64(len) => TransferEncoding::length(len),
| StatusCode::CONTINUE
| StatusCode::SWITCHING_PROTOCOLS
| StatusCode::PROCESSING => *length = BodyLength::None,
_ => (),
}
TransferEncoding::empty()
}
BodyLength::Sized(len) => TransferEncoding::length(*len as u64),
BodyLength::Sized64(len) => TransferEncoding::length(*len),
BodyLength::Chunked => TransferEncoding::chunked(), BodyLength::Chunked => TransferEncoding::chunked(),
BodyLength::Stream => TransferEncoding::eof(), BodyLength::Stream => TransferEncoding::eof(),
BodyLength::None => TransferEncoding::length(0), BodyLength::None => TransferEncoding::length(0),

View file

@ -109,6 +109,8 @@ extern crate serde_derive;
#[cfg(feature = "ssl")] #[cfg(feature = "ssl")]
extern crate openssl; extern crate openssl;
extern crate backtrace;
pub mod body; pub mod body;
pub mod client; pub mod client;
mod config; mod config;
@ -173,5 +175,4 @@ pub mod http {
pub use header::*; pub use header::*;
} }
pub use header::ContentEncoding; pub use header::ContentEncoding;
pub use response::ConnectionType;
} }

View file

@ -12,12 +12,41 @@ use uri::Url;
pub trait Head: Default + 'static { pub trait Head: Default + 'static {
fn clear(&mut self); fn clear(&mut self);
fn flags(&self) -> MessageFlags;
fn flags_mut(&mut self) -> &mut MessageFlags;
fn pool() -> &'static MessagePool<Self>; fn pool() -> &'static MessagePool<Self>;
/// Set upgrade
fn set_upgrade(&mut self) {
*self.flags_mut() = MessageFlags::UPGRADE;
}
/// Check if request is upgrade request
fn upgrade(&self) -> bool {
self.flags().contains(MessageFlags::UPGRADE)
}
/// Set keep-alive
fn set_keep_alive(&mut self) {
*self.flags_mut() = MessageFlags::KEEP_ALIVE;
}
/// Check if request is keep-alive
fn keep_alive(&self) -> bool;
/// Set force-close connection
fn force_close(&mut self) {
*self.flags_mut() = MessageFlags::FORCE_CLOSE;
}
} }
bitflags! { bitflags! {
pub(crate) struct MessageFlags: u8 { pub struct MessageFlags: u8 {
const KEEPALIVE = 0b0000_0001; const KEEP_ALIVE = 0b0000_0001;
const FORCE_CLOSE = 0b0000_0010;
const UPGRADE = 0b0000_0100;
} }
} }
@ -47,6 +76,25 @@ impl Head for RequestHead {
self.flags = MessageFlags::empty(); self.flags = MessageFlags::empty();
} }
fn flags(&self) -> MessageFlags {
self.flags
}
fn flags_mut(&mut self) -> &mut MessageFlags {
&mut self.flags
}
/// Check if request is keep-alive
fn keep_alive(&self) -> bool {
if self.flags().contains(MessageFlags::FORCE_CLOSE) {
false
} else if self.flags().contains(MessageFlags::KEEP_ALIVE) {
true
} else {
self.version <= Version::HTTP_11
}
}
fn pool() -> &'static MessagePool<Self> { fn pool() -> &'static MessagePool<Self> {
REQUEST_POOL.with(|p| *p) REQUEST_POOL.with(|p| *p)
} }
@ -79,11 +127,44 @@ impl Head for ResponseHead {
self.flags = MessageFlags::empty(); self.flags = MessageFlags::empty();
} }
fn flags(&self) -> MessageFlags {
self.flags
}
fn flags_mut(&mut self) -> &mut MessageFlags {
&mut self.flags
}
/// Check if response is keep-alive
fn keep_alive(&self) -> bool {
if self.flags().contains(MessageFlags::FORCE_CLOSE) {
false
} else if self.flags().contains(MessageFlags::KEEP_ALIVE) {
true
} else {
self.version <= Version::HTTP_11
}
}
fn pool() -> &'static MessagePool<Self> { fn pool() -> &'static MessagePool<Self> {
RESPONSE_POOL.with(|p| *p) RESPONSE_POOL.with(|p| *p)
} }
} }
impl ResponseHead {
/// Get custom reason for the response
#[inline]
pub fn reason(&self) -> &str {
if let Some(reason) = self.reason {
reason
} else {
self.status
.canonical_reason()
.unwrap_or("<unknown status code>")
}
}
}
pub struct Message<T: Head> { pub struct Message<T: Head> {
pub head: T, pub head: T,
pub url: Url, pub url: Url,

View file

@ -8,7 +8,7 @@ use extensions::Extensions;
use httpmessage::HttpMessage; use httpmessage::HttpMessage;
use payload::Payload; use payload::Payload;
use message::{Message, MessageFlags, MessagePool, RequestHead}; use message::{Head, Message, MessagePool, RequestHead};
/// Request /// Request
pub struct Request { pub struct Request {
@ -116,7 +116,7 @@ impl Request {
/// Checks if a connection should be kept alive. /// Checks if a connection should be kept alive.
#[inline] #[inline]
pub fn keep_alive(&self) -> bool { pub fn keep_alive(&self) -> bool {
self.inner().flags.get().contains(MessageFlags::KEEPALIVE) self.inner().head.keep_alive()
} }
/// Request extensions /// Request extensions

View file

@ -20,17 +20,6 @@ use message::{Head, MessageFlags, ResponseHead};
/// max write buffer size 64k /// max write buffer size 64k
pub(crate) const MAX_WRITE_BUFFER_SIZE: usize = 65_536; pub(crate) const MAX_WRITE_BUFFER_SIZE: usize = 65_536;
/// Represents various types of connection
#[derive(Copy, Clone, PartialEq, Debug)]
pub enum ConnectionType {
/// Close connection after response
Close,
/// Keep connection alive after response
KeepAlive,
/// Connection is upgraded to different type
Upgrade,
}
/// An HTTP Response /// An HTTP Response
pub struct Response<B: MessageBody = Body>(Box<InnerResponse>, B); pub struct Response<B: MessageBody = Body>(Box<InnerResponse>, B);
@ -124,27 +113,6 @@ impl<B: MessageBody> Response<B> {
&mut self.get_mut().head.status &mut self.get_mut().head.status
} }
/// Get custom reason for the response
#[inline]
pub fn reason(&self) -> &str {
if let Some(reason) = self.get_ref().head.reason {
reason
} else {
self.get_ref()
.head
.status
.canonical_reason()
.unwrap_or("<unknown status code>")
}
}
/// Set the custom reason for the response
#[inline]
pub fn set_reason(&mut self, reason: &'static str) -> &mut Self {
self.get_mut().head.reason = Some(reason);
self
}
/// Get the headers from the response /// Get the headers from the response
#[inline] #[inline]
pub fn headers(&self) -> &HeaderMap { pub fn headers(&self) -> &HeaderMap {
@ -207,28 +175,15 @@ impl<B: MessageBody> Response<B> {
count count
} }
/// Set connection type
pub fn set_connection_type(&mut self, conn: ConnectionType) -> &mut Self {
self.get_mut().connection_type = Some(conn);
self
}
/// Connection upgrade status /// Connection upgrade status
#[inline] #[inline]
pub fn upgrade(&self) -> bool { pub fn upgrade(&self) -> bool {
self.get_ref().connection_type == Some(ConnectionType::Upgrade) self.get_ref().head.upgrade()
} }
/// Keep-alive status for this connection /// Keep-alive status for this connection
pub fn keep_alive(&self) -> Option<bool> { pub fn keep_alive(&self) -> bool {
if let Some(ct) = self.get_ref().connection_type { self.get_ref().head.keep_alive()
match ct {
ConnectionType::KeepAlive => Some(true),
ConnectionType::Close | ConnectionType::Upgrade => Some(false),
}
} else {
None
}
} }
/// Get body os this response /// Get body os this response
@ -275,19 +230,20 @@ impl<B: MessageBody> Response<B> {
} }
} }
impl fmt::Debug for Response { impl<B: MessageBody> fmt::Debug for Response<B> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
let res = writeln!( let res = writeln!(
f, f,
"\nResponse {:?} {}{}", "\nResponse {:?} {}{}",
self.get_ref().head.version, self.get_ref().head.version,
self.get_ref().head.status, self.get_ref().head.status,
self.get_ref().head.reason.unwrap_or("") self.get_ref().head.reason.unwrap_or(""),
); );
let _ = writeln!(f, " headers:"); let _ = writeln!(f, " headers:");
for (key, val) in self.get_ref().head.headers.iter() { for (key, val) in self.get_ref().head.headers.iter() {
let _ = writeln!(f, " {:?}: {:?}", key, val); let _ = writeln!(f, " {:?}: {:?}", key, val);
} }
let _ = writeln!(f, " body: {:?}", self.body().length());
res res
} }
} }
@ -400,27 +356,31 @@ impl ResponseBuilder {
self self
} }
/// Set connection type /// Set connection type to KeepAlive
#[inline] #[inline]
#[doc(hidden)] pub fn keep_alive(&mut self) -> &mut Self {
pub fn connection_type(&mut self, conn: ConnectionType) -> &mut Self {
if let Some(parts) = parts(&mut self.response, &self.err) { if let Some(parts) = parts(&mut self.response, &self.err) {
parts.connection_type = Some(conn); parts.head.set_keep_alive();
} }
self self
} }
/// Set connection type to Upgrade /// Set connection type to Upgrade
#[inline] #[inline]
#[doc(hidden)]
pub fn upgrade(&mut self) -> &mut Self { pub fn upgrade(&mut self) -> &mut Self {
self.connection_type(ConnectionType::Upgrade) if let Some(parts) = parts(&mut self.response, &self.err) {
parts.head.set_upgrade();
}
self
} }
/// Force close connection, even if it is marked as keep-alive /// Force close connection, even if it is marked as keep-alive
#[inline] #[inline]
pub fn force_close(&mut self) -> &mut Self { pub fn force_close(&mut self) -> &mut Self {
self.connection_type(ConnectionType::Close) if let Some(parts) = parts(&mut self.response, &self.err) {
parts.head.force_close();
}
self
} }
/// Set response content type /// Set response content type
@ -719,8 +679,6 @@ impl From<BytesMut> for Response {
struct InnerResponse { struct InnerResponse {
head: ResponseHead, head: ResponseHead,
connection_type: Option<ConnectionType>,
write_capacity: usize,
response_size: u64, response_size: u64,
error: Option<Error>, error: Option<Error>,
pool: &'static ResponsePool, pool: &'static ResponsePool,
@ -728,7 +686,6 @@ struct InnerResponse {
pub(crate) struct ResponseParts { pub(crate) struct ResponseParts {
head: ResponseHead, head: ResponseHead,
connection_type: Option<ConnectionType>,
error: Option<Error>, error: Option<Error>,
} }
@ -744,9 +701,7 @@ impl InnerResponse {
flags: MessageFlags::empty(), flags: MessageFlags::empty(),
}, },
pool, pool,
connection_type: None,
response_size: 0, response_size: 0,
write_capacity: MAX_WRITE_BUFFER_SIZE,
error: None, error: None,
} }
} }
@ -755,7 +710,6 @@ impl InnerResponse {
fn into_parts(self) -> ResponseParts { fn into_parts(self) -> ResponseParts {
ResponseParts { ResponseParts {
head: self.head, head: self.head,
connection_type: self.connection_type,
error: self.error, error: self.error,
} }
} }
@ -763,9 +717,7 @@ impl InnerResponse {
fn from_parts(parts: ResponseParts) -> InnerResponse { fn from_parts(parts: ResponseParts) -> InnerResponse {
InnerResponse { InnerResponse {
head: parts.head, head: parts.head,
connection_type: parts.connection_type,
response_size: 0, response_size: 0,
write_capacity: MAX_WRITE_BUFFER_SIZE,
error: parts.error, error: parts.error,
pool: ResponsePool::pool(), pool: ResponsePool::pool(),
} }
@ -838,10 +790,8 @@ impl ResponsePool {
let mut p = inner.pool.0.borrow_mut(); let mut p = inner.pool.0.borrow_mut();
if p.len() < 128 { if p.len() < 128 {
inner.head.clear(); inner.head.clear();
inner.connection_type = None;
inner.response_size = 0; inner.response_size = 0;
inner.error = None; inner.error = None;
inner.write_capacity = MAX_WRITE_BUFFER_SIZE;
p.push_front(inner); p.push_front(inner);
} }
} }
@ -937,7 +887,7 @@ mod tests {
#[test] #[test]
fn test_force_close() { fn test_force_close() {
let resp = Response::build(StatusCode::OK).force_close().finish(); let resp = Response::build(StatusCode::OK).force_close().finish();
assert!(!resp.keep_alive().unwrap()) assert!(!resp.keep_alive())
} }
#[test] #[test]

View file

@ -3,10 +3,10 @@ use std::marker::PhantomData;
use actix_net::codec::Framed; use actix_net::codec::Framed;
use actix_net::service::{NewService, Service}; use actix_net::service::{NewService, Service};
use futures::future::{ok, Either, FutureResult}; use futures::future::{ok, Either, FutureResult};
use futures::{Async, AsyncSink, Future, Poll, Sink}; use futures::{Async, Future, Poll, Sink};
use tokio_io::{AsyncRead, AsyncWrite}; use tokio_io::{AsyncRead, AsyncWrite};
use body::MessageBody; use body::{BodyLength, MessageBody};
use error::{Error, ResponseError}; use error::{Error, ResponseError};
use h1::{Codec, Message}; use h1::{Codec, Message};
use response::Response; use response::Response;
@ -15,7 +15,7 @@ pub struct SendError<T, R, E>(PhantomData<(T, R, E)>);
impl<T, R, E> Default for SendError<T, R, E> impl<T, R, E> Default for SendError<T, R, E>
where where
T: AsyncWrite, T: AsyncRead + AsyncWrite,
E: ResponseError, E: ResponseError,
{ {
fn default() -> Self { fn default() -> Self {
@ -25,7 +25,7 @@ where
impl<T, R, E> NewService for SendError<T, R, E> impl<T, R, E> NewService for SendError<T, R, E>
where where
T: AsyncWrite, T: AsyncRead + AsyncWrite,
E: ResponseError, E: ResponseError,
{ {
type Request = Result<R, (E, Framed<T, Codec>)>; type Request = Result<R, (E, Framed<T, Codec>)>;
@ -42,7 +42,7 @@ where
impl<T, R, E> Service for SendError<T, R, E> impl<T, R, E> Service for SendError<T, R, E>
where where
T: AsyncWrite, T: AsyncRead + AsyncWrite,
E: ResponseError, E: ResponseError,
{ {
type Request = Result<R, (E, Framed<T, Codec>)>; type Request = Result<R, (E, Framed<T, Codec>)>;
@ -62,7 +62,7 @@ where
let (res, _body) = res.replace_body(()); let (res, _body) = res.replace_body(());
Either::B(SendErrorFut { Either::B(SendErrorFut {
framed: Some(framed), framed: Some(framed),
res: Some(res.into()), res: Some((res, BodyLength::Empty).into()),
err: Some(e), err: Some(e),
_t: PhantomData, _t: PhantomData,
}) })
@ -72,7 +72,7 @@ where
} }
pub struct SendErrorFut<T, R, E> { pub struct SendErrorFut<T, R, E> {
res: Option<Message<Response<()>>>, res: Option<Message<(Response<()>, BodyLength)>>,
framed: Option<Framed<T, Codec>>, framed: Option<Framed<T, Codec>>,
err: Option<E>, err: Option<E>,
_t: PhantomData<R>, _t: PhantomData<R>,
@ -81,22 +81,15 @@ pub struct SendErrorFut<T, R, E> {
impl<T, R, E> Future for SendErrorFut<T, R, E> impl<T, R, E> Future for SendErrorFut<T, R, E>
where where
E: ResponseError, E: ResponseError,
T: AsyncWrite, T: AsyncRead + AsyncWrite,
{ {
type Item = R; type Item = R;
type Error = (E, Framed<T, Codec>); type Error = (E, Framed<T, Codec>);
fn poll(&mut self) -> Poll<Self::Item, Self::Error> { fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
if let Some(res) = self.res.take() { if let Some(res) = self.res.take() {
match self.framed.as_mut().unwrap().start_send(res) { if let Err(_) = self.framed.as_mut().unwrap().force_send(res) {
Ok(AsyncSink::Ready) => (), return Err((self.err.take().unwrap(), self.framed.take().unwrap()));
Ok(AsyncSink::NotReady(res)) => {
self.res = Some(res);
return Ok(Async::NotReady);
}
Err(_) => {
return Err((self.err.take().unwrap(), self.framed.take().unwrap()))
}
} }
} }
match self.framed.as_mut().unwrap().poll_complete() { match self.framed.as_mut().unwrap().poll_complete() {
@ -123,20 +116,15 @@ where
B: MessageBody, B: MessageBody,
{ {
pub fn send( pub fn send(
mut framed: Framed<T, Codec>, framed: Framed<T, Codec>,
res: Response<B>, res: Response<B>,
) -> impl Future<Item = Framed<T, Codec>, Error = Error> { ) -> impl Future<Item = Framed<T, Codec>, Error = Error> {
// extract body from response // extract body from response
let (mut res, body) = res.replace_body(()); let (res, body) = res.replace_body(());
// init codec
framed
.get_codec_mut()
.prepare_te(&mut res.head_mut(), &mut body.length());
// write response // write response
SendResponseFut { SendResponseFut {
res: Some(Message::Item(res)), res: Some(Message::Item((res, body.length()))),
body: Some(body), body: Some(body),
framed: Some(framed), framed: Some(framed),
} }
@ -174,13 +162,10 @@ where
Ok(Async::Ready(())) Ok(Async::Ready(()))
} }
fn call(&mut self, (res, mut framed): Self::Request) -> Self::Future { fn call(&mut self, (res, framed): Self::Request) -> Self::Future {
let (mut res, body) = res.replace_body(()); let (res, body) = res.replace_body(());
framed
.get_codec_mut()
.prepare_te(res.head_mut(), &mut body.length());
SendResponseFut { SendResponseFut {
res: Some(Message::Item(res)), res: Some(Message::Item((res, body.length()))),
body: Some(body), body: Some(body),
framed: Some(framed), framed: Some(framed),
} }
@ -188,7 +173,7 @@ where
} }
pub struct SendResponseFut<T, B> { pub struct SendResponseFut<T, B> {
res: Option<Message<Response<()>>>, res: Option<Message<(Response<()>, BodyLength)>>,
body: Option<B>, body: Option<B>,
framed: Option<Framed<T, Codec>>, framed: Option<Framed<T, Codec>>,
} }

View file

@ -8,7 +8,7 @@ use std::io;
use error::ResponseError; use error::ResponseError;
use http::{header, Method, StatusCode}; use http::{header, Method, StatusCode};
use request::Request; use request::Request;
use response::{ConnectionType, Response, ResponseBuilder}; use response::{Response, ResponseBuilder};
mod client; mod client;
mod codec; mod codec;
@ -183,7 +183,7 @@ pub fn handshake_response(req: &Request) -> ResponseBuilder {
}; };
Response::build(StatusCode::SWITCHING_PROTOCOLS) Response::build(StatusCode::SWITCHING_PROTOCOLS)
.connection_type(ConnectionType::Upgrade) .upgrade()
.header(header::UPGRADE, "websocket") .header(header::UPGRADE, "websocket")
.header(header::TRANSFER_ENCODING, "chunked") .header(header::TRANSFER_ENCODING, "chunked")
.header(header::SEC_WEBSOCKET_ACCEPT, key.as_str()) .header(header::SEC_WEBSOCKET_ACCEPT, key.as_str())

View file

@ -12,10 +12,13 @@ use actix_net::server::Server;
use actix_net::service::NewServiceExt; use actix_net::service::NewServiceExt;
use actix_web::{client, test, HttpMessage}; use actix_web::{client, test, HttpMessage};
use bytes::Bytes; use bytes::Bytes;
use futures::future::{self, ok}; use futures::future::{self, lazy, ok};
use futures::stream::once; use futures::stream::once;
use actix_http::{body, h1, http, Body, Error, KeepAlive, Request, Response}; use actix_http::{
body, client as client2, h1, http, Body, Error, HttpMessage as HttpMessage2,
KeepAlive, Request, Response,
};
#[test] #[test]
fn test_h1_v2() { fn test_h1_v2() {
@ -181,14 +184,19 @@ fn test_headers() {
.unwrap() .unwrap()
.run() .run()
}); });
thread::sleep(time::Duration::from_millis(400)); thread::sleep(time::Duration::from_millis(200));
let mut sys = System::new("test"); let mut sys = System::new("test");
let req = client::ClientRequest::get(format!("http://{}/", addr)) let mut connector = sys
.block_on(lazy(|| {
Ok::<_, ()>(client2::Connector::default().service())
})).unwrap();
let req = client2::ClientRequest::get(format!("http://{}/", addr))
.finish() .finish()
.unwrap(); .unwrap();
let response = sys.block_on(req.send()).unwrap(); let response = sys.block_on(req.send(&mut connector)).unwrap();
assert!(response.status().is_success()); assert!(response.status().is_success());
// read response // read response
@ -249,9 +257,7 @@ fn test_head_empty() {
thread::spawn(move || { thread::spawn(move || {
Server::new() Server::new()
.bind("test", addr, move || { .bind("test", addr, move || {
h1::H1Service::new(|_| { h1::H1Service::new(|_| ok::<_, ()>(Response::Ok().body(STR))).map(|_| ())
ok::<_, ()>(Response::Ok().content_length(STR.len() as u64).finish())
}).map(|_| ())
}).unwrap() }).unwrap()
.run() .run()
}); });