1
0
Fork 0
mirror of https://github.com/actix/actix-web.git synced 2025-01-21 14:38:07 +00:00

refactor client response payload handling

This commit is contained in:
Nikolay Kim 2018-11-14 09:38:16 -08:00
parent 550c5f55b6
commit 6297fe0d41
11 changed files with 166 additions and 83 deletions

View file

@ -46,7 +46,8 @@ rust-tls = ["rustls", "actix-net/rust-tls"]
[dependencies] [dependencies]
actix = "0.7.5" actix = "0.7.5"
#actix-net = "0.2.0" #actix-net = "0.2.0"
actix-net = { git="https://github.com/actix/actix-net.git" } #actix-net = { git="https://github.com/actix/actix-net.git" }
actix-net = { path="../actix-net" }
base64 = "0.9" base64 = "0.9"
bitflags = "1.0" bitflags = "1.0"

View file

@ -4,10 +4,13 @@ use std::{fmt, mem};
use bytes::{Bytes, BytesMut}; use bytes::{Bytes, BytesMut};
use futures::{Async, Poll, Stream}; use futures::{Async, Poll, Stream};
use error::Error; use error::{Error, PayloadError};
/// Type represent streaming body /// Type represent streaming body
pub type BodyStream = Box<Stream<Item = Bytes, Error = Error>>; pub type BodyStream = Box<dyn Stream<Item = Bytes, Error = Error>>;
/// Type represent streaming payload
pub type PayloadStream = Box<dyn Stream<Item = Bytes, Error = PayloadError>>;
/// Different type of bory /// Different type of bory
pub enum BodyType { pub enum BodyType {

View file

@ -11,8 +11,8 @@ use super::error::{ConnectorError, SendRequestError};
use super::request::RequestHead; use super::request::RequestHead;
use super::response::ClientResponse; use super::response::ClientResponse;
use super::{Connect, Connection}; use super::{Connect, Connection};
use body::{BodyStream, BodyType, MessageBody}; use body::{BodyType, MessageBody, PayloadStream};
use error::Error; use error::PayloadError;
use h1; use h1;
pub fn send_request<T, Io, B>( pub fn send_request<T, Io, B>(
@ -44,7 +44,7 @@ where
let mut res = item.into_item().unwrap(); let mut res = item.into_item().unwrap();
match framed.get_codec().message_type() { match framed.get_codec().message_type() {
h1::MessageType::None => release_connection(framed), h1::MessageType::None => release_connection(framed),
_ => res.payload = Some(Payload::stream(framed)), _ => *res.payload.borrow_mut() = Some(Payload::stream(framed)),
} }
ok(res) ok(res)
} else { } else {
@ -129,41 +129,56 @@ where
} }
} }
struct Payload<Io> { struct EmptyPayload;
framed: Option<Framed<Connection<Io>, h1::ClientCodec>>,
impl Stream for EmptyPayload {
type Item = Bytes;
type Error = PayloadError;
fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> {
Ok(Async::Ready(None))
}
}
pub(crate) struct Payload<Io> {
framed: Option<Framed<Connection<Io>, h1::ClientPayloadCodec>>,
}
impl Payload<()> {
pub fn empty() -> PayloadStream {
Box::new(EmptyPayload)
}
} }
impl<Io: AsyncRead + AsyncWrite + 'static> Payload<Io> { impl<Io: AsyncRead + AsyncWrite + 'static> Payload<Io> {
fn stream(framed: Framed<Connection<Io>, h1::ClientCodec>) -> BodyStream { fn stream(framed: Framed<Connection<Io>, h1::ClientCodec>) -> PayloadStream {
Box::new(Payload { Box::new(Payload {
framed: Some(framed), framed: Some(framed.map_codec(|codec| codec.into_payload_codec())),
}) })
} }
} }
impl<Io: AsyncRead + AsyncWrite + 'static> Stream for Payload<Io> { impl<Io: AsyncRead + AsyncWrite + 'static> Stream for Payload<Io> {
type Item = Bytes; type Item = Bytes;
type Error = Error; type Error = PayloadError;
fn poll(&mut self) -> Poll<Option<Self::Item>, Error> { fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> {
match self.framed.as_mut().unwrap().poll()? { match self.framed.as_mut().unwrap().poll()? {
Async::NotReady => Ok(Async::NotReady), Async::NotReady => Ok(Async::NotReady),
Async::Ready(Some(chunk)) => match chunk { Async::Ready(Some(chunk)) => if let Some(chunk) = chunk {
h1::Message::Chunk(Some(chunk)) => Ok(Async::Ready(Some(chunk))), Ok(Async::Ready(Some(chunk)))
h1::Message::Chunk(None) => { } else {
release_connection(self.framed.take().unwrap()); release_connection(self.framed.take().unwrap());
Ok(Async::Ready(None)) Ok(Async::Ready(None))
}
h1::Message::Item(_) => unreachable!(),
}, },
Async::Ready(None) => Ok(Async::Ready(None)), Async::Ready(None) => Ok(Async::Ready(None)),
} }
} }
} }
fn release_connection<Io>(framed: Framed<Connection<Io>, h1::ClientCodec>) fn release_connection<T, U>(framed: Framed<Connection<T>, U>)
where where
Io: AsyncRead + AsyncWrite + 'static, T: AsyncRead + AsyncWrite + 'static,
{ {
let parts = framed.into_parts(); let parts = framed.into_parts();
if parts.read_buf.is_empty() && parts.write_buf.is_empty() { if parts.read_buf.is_empty() && parts.write_buf.is_empty() {

View file

@ -6,34 +6,37 @@ use bytes::Bytes;
use futures::{Async, Poll, Stream}; use futures::{Async, Poll, Stream};
use http::{HeaderMap, Method, StatusCode, Version}; use http::{HeaderMap, Method, StatusCode, Version};
use body::BodyStream; use body::PayloadStream;
use error::Error; use error::PayloadError;
use extensions::Extensions; use extensions::Extensions;
use httpmessage::HttpMessage;
use request::{Message, MessageFlags, MessagePool}; use request::{Message, MessageFlags, MessagePool};
use uri::Url; use uri::Url;
use super::pipeline::Payload;
/// Client Response /// Client Response
pub struct ClientResponse { pub struct ClientResponse {
pub(crate) inner: Rc<Message>, pub(crate) inner: Rc<Message>,
pub(crate) payload: Option<BodyStream>, pub(crate) payload: RefCell<Option<PayloadStream>>,
} }
// impl HttpMessage for ClientResponse { impl HttpMessage for ClientResponse {
// type Stream = Payload; type Stream = PayloadStream;
// fn headers(&self) -> &HeaderMap { fn headers(&self) -> &HeaderMap {
// &self.inner.headers &self.inner.headers
// } }
// #[inline] #[inline]
// fn payload(&self) -> Payload { fn payload(&self) -> Self::Stream {
// if let Some(payload) = self.inner.payload.borrow_mut().take() { if let Some(payload) = self.payload.borrow_mut().take() {
// payload payload
// } else { } else {
// Payload::empty() Payload::empty()
// } }
// } }
// } }
impl ClientResponse { impl ClientResponse {
/// Create new Request instance /// Create new Request instance
@ -55,7 +58,7 @@ impl ClientResponse {
payload: RefCell::new(None), payload: RefCell::new(None),
extensions: RefCell::new(Extensions::new()), extensions: RefCell::new(Extensions::new()),
}), }),
payload: None, payload: RefCell::new(None),
} }
} }
@ -114,10 +117,10 @@ impl ClientResponse {
impl Stream for ClientResponse { impl Stream for ClientResponse {
type Item = Bytes; type Item = Bytes;
type Error = Error; type Error = PayloadError;
fn poll(&mut self) -> Poll<Option<Self::Item>, Error> { fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> {
if let Some(ref mut payload) = self.payload { if let Some(ref mut payload) = &mut *self.payload.borrow_mut() {
payload.poll() payload.poll()
} else { } else {
Ok(Async::Ready(None)) Ok(Async::Ready(None))

View file

@ -339,7 +339,7 @@ impl From<httparse::Error> for ParseError {
pub enum PayloadError { pub enum PayloadError {
/// A payload reached EOF, but is not complete. /// A payload reached EOF, but is not complete.
#[fail(display = "A payload reached EOF, but is not complete.")] #[fail(display = "A payload reached EOF, but is not complete.")]
Incomplete, Incomplete(Option<io::Error>),
/// Content encoding stream corruption /// Content encoding stream corruption
#[fail(display = "Can not decode content-encoding.")] #[fail(display = "Can not decode content-encoding.")]
EncodingCorrupted, EncodingCorrupted,
@ -351,6 +351,12 @@ pub enum PayloadError {
UnknownLength, UnknownLength,
} }
impl From<io::Error> for PayloadError {
fn from(err: io::Error) -> Self {
PayloadError::Incomplete(Some(err))
}
}
/// `PayloadError` returns two possible results: /// `PayloadError` returns two possible results:
/// ///
/// - `Overflow` returns `PayloadTooLarge` /// - `Overflow` returns `PayloadTooLarge`

View file

@ -10,7 +10,7 @@ use super::{Message, MessageType};
use body::{Binary, Body, BodyType}; use body::{Binary, Body, BodyType};
use client::{ClientResponse, RequestHead}; use client::{ClientResponse, RequestHead};
use config::ServiceConfig; use config::ServiceConfig;
use error::ParseError; use error::{ParseError, PayloadError};
use helpers; use helpers;
use http::header::{ use http::header::{
HeaderValue, CONNECTION, CONTENT_LENGTH, DATE, TRANSFER_ENCODING, UPGRADE, HeaderValue, CONNECTION, CONTENT_LENGTH, DATE, TRANSFER_ENCODING, UPGRADE,
@ -32,6 +32,15 @@ const AVERAGE_HEADER_SIZE: usize = 30;
/// HTTP/1 Codec /// HTTP/1 Codec
pub struct ClientCodec { pub struct ClientCodec {
inner: ClientCodecInner,
}
/// HTTP/1 Payload Codec
pub struct ClientPayloadCodec {
inner: ClientCodecInner,
}
struct ClientCodecInner {
config: ServiceConfig, config: ServiceConfig,
decoder: ResponseDecoder, decoder: ResponseDecoder,
payload: Option<PayloadDecoder>, payload: Option<PayloadDecoder>,
@ -65,6 +74,7 @@ impl ClientCodec {
Flags::empty() Flags::empty()
}; };
ClientCodec { ClientCodec {
inner: ClientCodecInner {
config, config,
decoder: ResponseDecoder::with_pool(pool), decoder: ResponseDecoder::with_pool(pool),
payload: None, payload: None,
@ -73,24 +83,25 @@ impl ClientCodec {
flags, flags,
headers_size: 0, headers_size: 0,
te: RequestEncoder::default(), te: RequestEncoder::default(),
},
} }
} }
/// Check if request is upgrade /// Check if request is upgrade
pub fn upgrade(&self) -> bool { pub fn upgrade(&self) -> bool {
self.flags.contains(Flags::UPGRADE) self.inner.flags.contains(Flags::UPGRADE)
} }
/// Check if last response is keep-alive /// Check if last response is keep-alive
pub fn keepalive(&self) -> bool { pub fn keepalive(&self) -> bool {
self.flags.contains(Flags::KEEPALIVE) self.inner.flags.contains(Flags::KEEPALIVE)
} }
/// Check last request's message type /// Check last request's message type
pub fn message_type(&self) -> MessageType { pub fn message_type(&self) -> MessageType {
if self.flags.contains(Flags::STREAM) { if self.inner.flags.contains(Flags::STREAM) {
MessageType::Stream MessageType::Stream
} else if self.payload.is_none() { } else if self.inner.payload.is_none() {
MessageType::None MessageType::None
} else { } else {
MessageType::Payload MessageType::Payload
@ -99,10 +110,27 @@ impl ClientCodec {
/// prepare transfer encoding /// prepare transfer encoding
pub fn prepare_te(&mut self, head: &mut RequestHead, btype: BodyType) { pub fn prepare_te(&mut self, head: &mut RequestHead, btype: BodyType) {
self.te self.inner.te.update(
.update(head, self.flags.contains(Flags::HEAD), self.version); head,
self.inner.flags.contains(Flags::HEAD),
self.inner.version,
);
} }
/// Convert message codec to a payload codec
pub fn into_payload_codec(self) -> ClientPayloadCodec {
ClientPayloadCodec { inner: self.inner }
}
}
impl ClientPayloadCodec {
/// Transform payload codec to a message codec
pub fn into_message_codec(self) -> ClientCodec {
ClientCodec { inner: self.inner }
}
}
impl ClientCodecInner {
fn encode_response( fn encode_response(
&mut self, &mut self,
msg: RequestHead, msg: RequestHead,
@ -154,25 +182,26 @@ impl Decoder for ClientCodec {
type Error = ParseError; type Error = ParseError;
fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> { fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
if self.payload.is_some() { if self.inner.payload.is_some() {
Ok(match self.payload.as_mut().unwrap().decode(src)? { Ok(match self.inner.payload.as_mut().unwrap().decode(src)? {
Some(PayloadItem::Chunk(chunk)) => Some(Message::Chunk(Some(chunk))), Some(PayloadItem::Chunk(chunk)) => Some(Message::Chunk(Some(chunk))),
Some(PayloadItem::Eof) => Some(Message::Chunk(None)), Some(PayloadItem::Eof) => Some(Message::Chunk(None)),
None => None, None => None,
}) })
} else if let Some((req, payload)) = self.decoder.decode(src)? { } else if let Some((req, payload)) = self.inner.decoder.decode(src)? {
self.flags self.inner
.flags
.set(Flags::HEAD, req.inner.method == Method::HEAD); .set(Flags::HEAD, req.inner.method == Method::HEAD);
self.version = req.inner.version; self.inner.version = req.inner.version;
if self.flags.contains(Flags::KEEPALIVE_ENABLED) { if self.inner.flags.contains(Flags::KEEPALIVE_ENABLED) {
self.flags.set(Flags::KEEPALIVE, req.keep_alive()); self.inner.flags.set(Flags::KEEPALIVE, req.keep_alive());
} }
match payload { match payload {
PayloadType::None => self.payload = None, PayloadType::None => self.inner.payload = None,
PayloadType::Payload(pl) => self.payload = Some(pl), PayloadType::Payload(pl) => self.inner.payload = Some(pl),
PayloadType::Stream(pl) => { PayloadType::Stream(pl) => {
self.payload = Some(pl); self.inner.payload = Some(pl);
self.flags.insert(Flags::STREAM); self.inner.flags.insert(Flags::STREAM);
} }
}; };
Ok(Some(Message::Item(req))) Ok(Some(Message::Item(req)))
@ -182,6 +211,27 @@ impl Decoder for ClientCodec {
} }
} }
impl Decoder for ClientPayloadCodec {
type Item = Option<Bytes>;
type Error = PayloadError;
fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
assert!(
self.inner.payload.is_some(),
"Payload decoder is not specified"
);
Ok(match self.inner.payload.as_mut().unwrap().decode(src)? {
Some(PayloadItem::Chunk(chunk)) => Some(Some(chunk)),
Some(PayloadItem::Eof) => {
self.inner.payload.take();
Some(None)
}
None => None,
})
}
}
impl Encoder for ClientCodec { impl Encoder for ClientCodec {
type Item = Message<(RequestHead, BodyType)>; type Item = Message<(RequestHead, BodyType)>;
type Error = io::Error; type Error = io::Error;
@ -193,13 +243,13 @@ impl Encoder for ClientCodec {
) -> Result<(), Self::Error> { ) -> Result<(), Self::Error> {
match item { match item {
Message::Item((msg, btype)) => { Message::Item((msg, btype)) => {
self.encode_response(msg, btype, dst)?; self.inner.encode_response(msg, btype, dst)?;
} }
Message::Chunk(Some(bytes)) => { Message::Chunk(Some(bytes)) => {
self.te.encode(bytes.as_ref(), dst)?; self.inner.te.encode(bytes.as_ref(), dst)?;
} }
Message::Chunk(None) => { Message::Chunk(None) => {
self.te.encode_eof(dst)?; self.inner.te.encode_eof(dst)?;
} }
} }
Ok(()) Ok(())

View file

@ -143,7 +143,7 @@ where
fn client_disconnected(&mut self) { fn client_disconnected(&mut self) {
self.flags.insert(Flags::DISCONNECTED); self.flags.insert(Flags::DISCONNECTED);
if let Some(mut payload) = self.payload.take() { if let Some(mut payload) = self.payload.take() {
payload.set_error(PayloadError::Incomplete); payload.set_error(PayloadError::Incomplete(None));
} }
} }
@ -228,7 +228,7 @@ where
} }
Err(err) => { Err(err) => {
if let Some(mut payload) = self.payload.take() { if let Some(mut payload) = self.payload.take() {
payload.set_error(PayloadError::Incomplete); payload.set_error(PayloadError::Incomplete(None));
} }
return Err(DispatchError::Io(err)); return Err(DispatchError::Io(err));
} }
@ -236,7 +236,10 @@ where
} }
// Send payload // Send payload
State::SendPayload(ref mut stream, ref mut bin) => { State::SendPayload(ref mut stream, ref mut bin) => {
println!("SEND payload");
if let Some(item) = bin.take() { if let Some(item) = bin.take() {
let mut framed = self.framed.as_mut().unwrap();
if framed.is_
match self.framed.as_mut().unwrap().start_send(item) { match self.framed.as_mut().unwrap().start_send(item) {
Ok(AsyncSink::Ready) => { Ok(AsyncSink::Ready) => {
self.flags.remove(Flags::FLUSHED); self.flags.remove(Flags::FLUSHED);

View file

@ -9,7 +9,7 @@ mod dispatcher;
mod encoder; mod encoder;
mod service; mod service;
pub use self::client::ClientCodec; pub use self::client::{ClientCodec, ClientPayloadCodec};
pub use self::codec::Codec; pub use self::codec::Codec;
pub use self::decoder::{PayloadDecoder, RequestDecoder}; pub use self::decoder::{PayloadDecoder, RequestDecoder};
pub use self::dispatcher::Dispatcher; pub use self::dispatcher::Dispatcher;

View file

@ -527,7 +527,7 @@ mod tests {
#[test] #[test]
fn test_error() { fn test_error() {
let err = PayloadError::Incomplete; let err = PayloadError::Incomplete(None);
assert_eq!( assert_eq!(
format!("{}", err), format!("{}", err),
"A payload reached EOF, but is not complete." "A payload reached EOF, but is not complete."
@ -584,7 +584,7 @@ mod tests {
assert_eq!(Async::NotReady, payload.readany().ok().unwrap()); assert_eq!(Async::NotReady, payload.readany().ok().unwrap());
sender.set_error(PayloadError::Incomplete); sender.set_error(PayloadError::Incomplete(None));
payload.readany().err().unwrap(); payload.readany().err().unwrap();
let res: Result<(), ()> = Ok(()); let res: Result<(), ()> = Ok(());
result(res) result(res)
@ -644,7 +644,7 @@ mod tests {
); );
assert_eq!(payload.len, 4); assert_eq!(payload.len, 4);
sender.set_error(PayloadError::Incomplete); sender.set_error(PayloadError::Incomplete(None));
payload.read_exact(10).err().unwrap(); payload.read_exact(10).err().unwrap();
let res: Result<(), ()> = Ok(()); let res: Result<(), ()> = Ok(());
@ -677,7 +677,7 @@ mod tests {
); );
assert_eq!(payload.len, 0); assert_eq!(payload.len, 0);
sender.set_error(PayloadError::Incomplete); sender.set_error(PayloadError::Incomplete(None));
payload.read_until(b"b").err().unwrap(); payload.read_until(b"b").err().unwrap();
let res: Result<(), ()> = Ok(()); let res: Result<(), ()> = Ok(());

View file

@ -242,7 +242,7 @@ impl MessagePool {
} }
return ClientResponse { return ClientResponse {
inner: msg, inner: msg,
payload: None, payload: RefCell::new(None),
}; };
} }
ClientResponse::with_pool(pool) ClientResponse::with_pool(pool)

View file

@ -9,8 +9,10 @@ use std::{thread, time};
use actix::System; use actix::System;
use actix_net::server::Server; use actix_net::server::Server;
use actix_net::service::NewServiceExt; use actix_net::service::NewServiceExt;
use bytes::Bytes;
use futures::future::{self, lazy, ok}; use futures::future::{self, lazy, ok};
use actix_http::HttpMessage;
use actix_http::{client, h1, test, Request, Response}; use actix_http::{client, h1, test, Request, Response};
const STR: &str = "Hello World Hello World Hello World Hello World Hello World \ const STR: &str = "Hello World Hello World Hello World Hello World Hello World \
@ -73,8 +75,8 @@ fn test_h1_v2() {
assert!(response.status().is_success()); assert!(response.status().is_success());
// read response // read response
// let bytes = srv.execute(response.body()).unwrap(); let bytes = sys.block_on(response.body()).unwrap();
// assert_eq!(bytes, Bytes::from_static(STR.as_ref())); assert_eq!(bytes, Bytes::from_static(STR.as_ref()));
let request = client::ClientRequest::post(format!("http://{}/", addr)) let request = client::ClientRequest::post(format!("http://{}/", addr))
.finish() .finish()
@ -83,8 +85,8 @@ fn test_h1_v2() {
assert!(response.status().is_success()); assert!(response.status().is_success());
// read response // read response
// let bytes = srv.execute(response.body()).unwrap(); let bytes = sys.block_on(response.body()).unwrap();
// assert_eq!(bytes, Bytes::from_static(STR.as_ref())); assert_eq!(bytes, Bytes::from_static(STR.as_ref()));
} }
#[test] #[test]