1
0
Fork 0
mirror of https://github.com/actix/actix-web.git synced 2025-01-21 22:48:07 +00:00

add client http codec; websockets client

This commit is contained in:
Nikolay Kim 2018-10-22 18:18:05 -07:00
parent 9b94eaa6a8
commit 09c94cb06b
20 changed files with 1802 additions and 148 deletions

View file

@ -54,7 +54,10 @@ time = "0.1"
encoding = "0.2"
lazy_static = "1.0"
serde_urlencoded = "0.5.3"
cookie = { version="0.11", features=["percent-encode"] }
percent-encoding = "1.0"
url = { version="1.7", features=["query_encoding"] }
# io
net2 = "0.2"

6
src/client/mod.rs Normal file
View file

@ -0,0 +1,6 @@
//! Http client api
mod request;
mod response;
pub use self::request::{ClientRequest, ClientRequestBuilder};
pub use self::response::ClientResponse;

564
src/client/request.rs Normal file
View file

@ -0,0 +1,564 @@
use std::fmt;
use std::fmt::Write as FmtWrite;
use std::io::Write;
use bytes::{BufMut, BytesMut};
use cookie::{Cookie, CookieJar};
use percent_encoding::{percent_encode, USERINFO_ENCODE_SET};
use urlcrate::Url;
use header::{self, Header, IntoHeaderValue};
use http::{
uri, Error as HttpError, HeaderMap, HeaderName, HeaderValue, HttpTryFrom, Method,
Uri, Version,
};
/// An HTTP Client Request
///
/// ```rust
/// # extern crate actix_web;
/// # extern crate futures;
/// # extern crate tokio;
/// # use futures::Future;
/// # use std::process;
/// use actix_web::{actix, client};
///
/// fn main() {
/// actix::run(
/// || client::ClientRequest::get("http://www.rust-lang.org") // <- Create request builder
/// .header("User-Agent", "Actix-web")
/// .finish().unwrap()
/// .send() // <- Send http request
/// .map_err(|_| ())
/// .and_then(|response| { // <- server http response
/// println!("Response: {:?}", response);
/// # actix::System::current().stop();
/// Ok(())
/// }),
/// );
/// }
/// ```
pub struct ClientRequest {
uri: Uri,
method: Method,
version: Version,
headers: HeaderMap,
chunked: bool,
upgrade: bool,
}
impl Default for ClientRequest {
fn default() -> ClientRequest {
ClientRequest {
uri: Uri::default(),
method: Method::default(),
version: Version::HTTP_11,
headers: HeaderMap::with_capacity(16),
chunked: false,
upgrade: false,
}
}
}
impl ClientRequest {
/// Create request builder for `GET` request
pub fn get<U: AsRef<str>>(uri: U) -> ClientRequestBuilder {
let mut builder = ClientRequest::build();
builder.method(Method::GET).uri(uri);
builder
}
/// Create request builder for `HEAD` request
pub fn head<U: AsRef<str>>(uri: U) -> ClientRequestBuilder {
let mut builder = ClientRequest::build();
builder.method(Method::HEAD).uri(uri);
builder
}
/// Create request builder for `POST` request
pub fn post<U: AsRef<str>>(uri: U) -> ClientRequestBuilder {
let mut builder = ClientRequest::build();
builder.method(Method::POST).uri(uri);
builder
}
/// Create request builder for `PUT` request
pub fn put<U: AsRef<str>>(uri: U) -> ClientRequestBuilder {
let mut builder = ClientRequest::build();
builder.method(Method::PUT).uri(uri);
builder
}
/// Create request builder for `DELETE` request
pub fn delete<U: AsRef<str>>(uri: U) -> ClientRequestBuilder {
let mut builder = ClientRequest::build();
builder.method(Method::DELETE).uri(uri);
builder
}
}
impl ClientRequest {
/// Create client request builder
pub fn build() -> ClientRequestBuilder {
ClientRequestBuilder {
request: Some(ClientRequest::default()),
err: None,
cookies: None,
default_headers: true,
}
}
/// Get the request URI
#[inline]
pub fn uri(&self) -> &Uri {
&self.uri
}
/// Set client request URI
#[inline]
pub fn set_uri(&mut self, uri: Uri) {
self.uri = uri
}
/// Get the request method
#[inline]
pub fn method(&self) -> &Method {
&self.method
}
/// Set HTTP `Method` for the request
#[inline]
pub fn set_method(&mut self, method: Method) {
self.method = method
}
/// Get HTTP version for the request
#[inline]
pub fn version(&self) -> Version {
self.version
}
/// Set http `Version` for the request
#[inline]
pub fn set_version(&mut self, version: Version) {
self.version = version
}
/// Get the headers from the request
#[inline]
pub fn headers(&self) -> &HeaderMap {
&self.headers
}
/// Get a mutable reference to the headers
#[inline]
pub fn headers_mut(&mut self) -> &mut HeaderMap {
&mut self.headers
}
/// is chunked encoding enabled
#[inline]
pub fn chunked(&self) -> bool {
self.chunked
}
/// is upgrade request
#[inline]
pub fn upgrade(&self) -> bool {
self.upgrade
}
}
impl fmt::Debug for ClientRequest {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
writeln!(
f,
"\nClientRequest {:?} {}:{}",
self.version, self.method, self.uri
)?;
writeln!(f, " headers:")?;
for (key, val) in self.headers.iter() {
writeln!(f, " {:?}: {:?}", key, val)?;
}
Ok(())
}
}
/// An HTTP Client request builder
///
/// This type can be used to construct an instance of `ClientRequest` through a
/// builder-like pattern.
pub struct ClientRequestBuilder {
request: Option<ClientRequest>,
err: Option<HttpError>,
cookies: Option<CookieJar>,
default_headers: bool,
}
impl ClientRequestBuilder {
/// Set HTTP URI of request.
#[inline]
pub fn uri<U: AsRef<str>>(&mut self, uri: U) -> &mut Self {
match Url::parse(uri.as_ref()) {
Ok(url) => self._uri(url.as_str()),
Err(_) => self._uri(uri.as_ref()),
}
}
fn _uri(&mut self, url: &str) -> &mut Self {
match Uri::try_from(url) {
Ok(uri) => {
if let Some(parts) = parts(&mut self.request, &self.err) {
parts.uri = uri;
}
}
Err(e) => self.err = Some(e.into()),
}
self
}
/// Set HTTP method of this request.
#[inline]
pub fn method(&mut self, method: Method) -> &mut Self {
if let Some(parts) = parts(&mut self.request, &self.err) {
parts.method = method;
}
self
}
/// Set HTTP method of this request.
#[inline]
pub fn get_method(&mut self) -> &Method {
let parts = self.request.as_ref().expect("cannot reuse request builder");
&parts.method
}
/// Set HTTP version of this request.
///
/// By default requests's HTTP version depends on network stream
#[inline]
pub fn version(&mut self, version: Version) -> &mut Self {
if let Some(parts) = parts(&mut self.request, &self.err) {
parts.version = version;
}
self
}
/// Set a header.
///
/// ```rust
/// # extern crate mime;
/// # extern crate actix_web;
/// # use actix_web::client::*;
/// #
/// use actix_web::{client, http};
///
/// fn main() {
/// let req = client::ClientRequest::build()
/// .set(http::header::Date::now())
/// .set(http::header::ContentType(mime::TEXT_HTML))
/// .finish()
/// .unwrap();
/// }
/// ```
#[doc(hidden)]
pub fn set<H: Header>(&mut self, hdr: H) -> &mut Self {
if let Some(parts) = parts(&mut self.request, &self.err) {
match hdr.try_into() {
Ok(value) => {
parts.headers.insert(H::name(), value);
}
Err(e) => self.err = Some(e.into()),
}
}
self
}
/// Append a header.
///
/// Header gets appended to existing header.
/// To override header use `set_header()` method.
///
/// ```rust
/// # extern crate http;
/// # extern crate actix_web;
/// # use actix_web::client::*;
/// #
/// use http::header;
///
/// fn main() {
/// let req = ClientRequest::build()
/// .header("X-TEST", "value")
/// .header(header::CONTENT_TYPE, "application/json")
/// .finish()
/// .unwrap();
/// }
/// ```
pub fn header<K, V>(&mut self, key: K, value: V) -> &mut Self
where
HeaderName: HttpTryFrom<K>,
V: IntoHeaderValue,
{
if let Some(parts) = parts(&mut self.request, &self.err) {
match HeaderName::try_from(key) {
Ok(key) => match value.try_into() {
Ok(value) => {
parts.headers.append(key, value);
}
Err(e) => self.err = Some(e.into()),
},
Err(e) => self.err = Some(e.into()),
};
}
self
}
/// Set a header.
pub fn set_header<K, V>(&mut self, key: K, value: V) -> &mut Self
where
HeaderName: HttpTryFrom<K>,
V: IntoHeaderValue,
{
if let Some(parts) = parts(&mut self.request, &self.err) {
match HeaderName::try_from(key) {
Ok(key) => match value.try_into() {
Ok(value) => {
parts.headers.insert(key, value);
}
Err(e) => self.err = Some(e.into()),
},
Err(e) => self.err = Some(e.into()),
};
}
self
}
/// Set a header only if it is not yet set.
pub fn set_header_if_none<K, V>(&mut self, key: K, value: V) -> &mut Self
where
HeaderName: HttpTryFrom<K>,
V: IntoHeaderValue,
{
if let Some(parts) = parts(&mut self.request, &self.err) {
match HeaderName::try_from(key) {
Ok(key) => if !parts.headers.contains_key(&key) {
match value.try_into() {
Ok(value) => {
parts.headers.insert(key, value);
}
Err(e) => self.err = Some(e.into()),
}
},
Err(e) => self.err = Some(e.into()),
};
}
self
}
/// Enable connection upgrade
#[inline]
pub fn upgrade(&mut self) -> &mut Self {
if let Some(parts) = parts(&mut self.request, &self.err) {
parts.upgrade = true;
}
self
}
/// Set request's content type
#[inline]
pub fn content_type<V>(&mut self, value: V) -> &mut Self
where
HeaderValue: HttpTryFrom<V>,
{
if let Some(parts) = parts(&mut self.request, &self.err) {
match HeaderValue::try_from(value) {
Ok(value) => {
parts.headers.insert(header::CONTENT_TYPE, value);
}
Err(e) => self.err = Some(e.into()),
};
}
self
}
/// Set content length
#[inline]
pub fn content_length(&mut self, len: u64) -> &mut Self {
let mut wrt = BytesMut::new().writer();
let _ = write!(wrt, "{}", len);
self.header(header::CONTENT_LENGTH, wrt.get_mut().take().freeze())
}
/// Set a cookie
///
/// ```rust
/// # extern crate actix_web;
/// use actix_web::{client, http};
///
/// fn main() {
/// let req = client::ClientRequest::build()
/// .cookie(
/// http::Cookie::build("name", "value")
/// .domain("www.rust-lang.org")
/// .path("/")
/// .secure(true)
/// .http_only(true)
/// .finish(),
/// )
/// .finish()
/// .unwrap();
/// }
/// ```
pub fn cookie<'c>(&mut self, cookie: Cookie<'c>) -> &mut Self {
if self.cookies.is_none() {
let mut jar = CookieJar::new();
jar.add(cookie.into_owned());
self.cookies = Some(jar)
} else {
self.cookies.as_mut().unwrap().add(cookie.into_owned());
}
self
}
/// Do not add default request headers.
/// By default `Accept-Encoding` and `User-Agent` headers are set.
pub fn no_default_headers(&mut self) -> &mut Self {
self.default_headers = false;
self
}
/// This method calls provided closure with builder reference if
/// value is `true`.
pub fn if_true<F>(&mut self, value: bool, f: F) -> &mut Self
where
F: FnOnce(&mut ClientRequestBuilder),
{
if value {
f(self);
}
self
}
/// This method calls provided closure with builder reference if
/// value is `Some`.
pub fn if_some<T, F>(&mut self, value: Option<T>, f: F) -> &mut Self
where
F: FnOnce(T, &mut ClientRequestBuilder),
{
if let Some(val) = value {
f(val, self);
}
self
}
/// Set a body and generate `ClientRequest`.
///
/// `ClientRequestBuilder` can not be used after this call.
pub fn finish(&mut self) -> Result<ClientRequest, HttpError> {
if let Some(e) = self.err.take() {
return Err(e);
}
if self.default_headers {
// enable br only for https
let https = if let Some(parts) = parts(&mut self.request, &self.err) {
parts
.uri
.scheme_part()
.map(|s| s == &uri::Scheme::HTTPS)
.unwrap_or(true)
} else {
true
};
if https {
self.set_header_if_none(header::ACCEPT_ENCODING, "br, gzip, deflate");
} else {
self.set_header_if_none(header::ACCEPT_ENCODING, "gzip, deflate");
}
// set request host header
if let Some(parts) = parts(&mut self.request, &self.err) {
if let Some(host) = parts.uri.host() {
if !parts.headers.contains_key(header::HOST) {
let mut wrt = BytesMut::with_capacity(host.len() + 5).writer();
let _ = match parts.uri.port() {
None | Some(80) | Some(443) => write!(wrt, "{}", host),
Some(port) => write!(wrt, "{}:{}", host, port),
};
match wrt.get_mut().take().freeze().try_into() {
Ok(value) => {
parts.headers.insert(header::HOST, value);
}
Err(e) => self.err = Some(e.into()),
}
}
}
}
// user agent
self.set_header_if_none(
header::USER_AGENT,
concat!("actix-http/", env!("CARGO_PKG_VERSION")),
);
}
let mut request = self.request.take().expect("cannot reuse request builder");
// set cookies
if let Some(ref mut jar) = self.cookies {
let mut cookie = String::new();
for c in jar.delta() {
let name = percent_encode(c.name().as_bytes(), USERINFO_ENCODE_SET);
let value = percent_encode(c.value().as_bytes(), USERINFO_ENCODE_SET);
let _ = write!(&mut cookie, "; {}={}", name, value);
}
request.headers.insert(
header::COOKIE,
HeaderValue::from_str(&cookie.as_str()[2..]).unwrap(),
);
}
Ok(request)
}
/// This method construct new `ClientRequestBuilder`
pub fn take(&mut self) -> ClientRequestBuilder {
ClientRequestBuilder {
request: self.request.take(),
err: self.err.take(),
cookies: self.cookies.take(),
default_headers: self.default_headers,
}
}
}
#[inline]
fn parts<'a>(
parts: &'a mut Option<ClientRequest>, err: &Option<HttpError>,
) -> Option<&'a mut ClientRequest> {
if err.is_some() {
return None;
}
parts.as_mut()
}
impl fmt::Debug for ClientRequestBuilder {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
if let Some(ref parts) = self.request {
writeln!(
f,
"\nClientRequestBuilder {:?} {}:{}",
parts.version, parts.method, parts.uri
)?;
writeln!(f, " headers:")?;
for (key, val) in parts.headers.iter() {
writeln!(f, " {:?}: {:?}", key, val)?;
}
Ok(())
} else {
write!(f, "ClientRequestBuilder(Consumed)")
}
}
}

128
src/client/response.rs Normal file
View file

@ -0,0 +1,128 @@
use std::cell::{Cell, Ref, RefCell, RefMut};
use std::fmt;
use std::rc::Rc;
use http::{HeaderMap, Method, StatusCode, Version};
use extensions::Extensions;
use httpmessage::HttpMessage;
use payload::Payload;
use request::{Message, MessageFlags, MessagePool};
use uri::Url;
/// Client Response
pub struct ClientResponse {
pub(crate) inner: Rc<Message>,
}
impl HttpMessage for ClientResponse {
type Stream = Payload;
fn headers(&self) -> &HeaderMap {
&self.inner.headers
}
#[inline]
fn payload(&self) -> Payload {
if let Some(payload) = self.inner.payload.borrow_mut().take() {
payload
} else {
Payload::empty()
}
}
}
impl ClientResponse {
/// Create new Request instance
pub fn new() -> ClientResponse {
ClientResponse::with_pool(MessagePool::pool())
}
/// Create new Request instance with pool
pub(crate) fn with_pool(pool: &'static MessagePool) -> ClientResponse {
ClientResponse {
inner: Rc::new(Message {
pool,
method: Method::GET,
status: StatusCode::OK,
url: Url::default(),
version: Version::HTTP_11,
headers: HeaderMap::with_capacity(16),
flags: Cell::new(MessageFlags::empty()),
payload: RefCell::new(None),
extensions: RefCell::new(Extensions::new()),
}),
}
}
#[inline]
pub(crate) fn inner(&self) -> &Message {
self.inner.as_ref()
}
#[inline]
pub(crate) fn inner_mut(&mut self) -> &mut Message {
Rc::get_mut(&mut self.inner).expect("Multiple copies exist")
}
/// Read the Request Version.
#[inline]
pub fn version(&self) -> Version {
self.inner().version
}
/// Get the status from the server.
#[inline]
pub fn status(&self) -> StatusCode {
self.inner().status
}
#[inline]
/// Returns Request's headers.
pub fn headers(&self) -> &HeaderMap {
&self.inner().headers
}
#[inline]
/// Returns mutable Request's headers.
pub fn headers_mut(&mut self) -> &mut HeaderMap {
&mut self.inner_mut().headers
}
/// Checks if a connection should be kept alive.
#[inline]
pub fn keep_alive(&self) -> bool {
self.inner().flags.get().contains(MessageFlags::KEEPALIVE)
}
/// Request extensions
#[inline]
pub fn extensions(&self) -> Ref<Extensions> {
self.inner().extensions.borrow()
}
/// Mutable reference to a the request's extensions
#[inline]
pub fn extensions_mut(&self) -> RefMut<Extensions> {
self.inner().extensions.borrow_mut()
}
}
impl Drop for ClientResponse {
fn drop(&mut self) {
if Rc::strong_count(&self.inner) == 1 {
self.inner.pool.release(self.inner.clone());
}
}
}
impl fmt::Debug for ClientResponse {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
writeln!(f, "\nClientResponse {:?} {}", self.version(), self.status(),)?;
writeln!(f, " headers:")?;
for (key, val) in self.headers().iter() {
writeln!(f, " {:?}: {:?}", key, val)?;
}
Ok(())
}
}

217
src/h1/client.rs Normal file
View file

@ -0,0 +1,217 @@
#![allow(unused_imports, unused_variables, dead_code)]
use std::io::{self, Write};
use bytes::{BufMut, Bytes, BytesMut};
use tokio_codec::{Decoder, Encoder};
use super::decoder::{PayloadDecoder, PayloadItem, PayloadType, ResponseDecoder};
use super::encoder::{RequestEncoder, ResponseLength};
use super::{Message, MessageType};
use body::{Binary, Body};
use client::{ClientRequest, ClientResponse};
use config::ServiceConfig;
use error::ParseError;
use helpers;
use http::header::{HeaderValue, CONNECTION, CONTENT_LENGTH, DATE, TRANSFER_ENCODING};
use http::{Method, Version};
use request::MessagePool;
bitflags! {
struct Flags: u8 {
const HEAD = 0b0000_0001;
const UPGRADE = 0b0000_0010;
const KEEPALIVE = 0b0000_0100;
const KEEPALIVE_ENABLED = 0b0000_1000;
const UNHANDLED = 0b0001_0000;
}
}
const AVERAGE_HEADER_SIZE: usize = 30;
/// HTTP/1 Codec
pub struct ClientCodec {
config: ServiceConfig,
decoder: ResponseDecoder,
payload: Option<PayloadDecoder>,
version: Version,
// encoder part
flags: Flags,
headers_size: u32,
te: RequestEncoder,
}
impl Default for ClientCodec {
fn default() -> Self {
ClientCodec::new(ServiceConfig::default())
}
}
impl ClientCodec {
/// Create HTTP/1 codec.
///
/// `keepalive_enabled` how response `connection` header get generated.
pub fn new(config: ServiceConfig) -> Self {
ClientCodec::with_pool(MessagePool::pool(), config)
}
/// Create HTTP/1 codec with request's pool
pub(crate) fn with_pool(pool: &'static MessagePool, config: ServiceConfig) -> Self {
let flags = if config.keep_alive_enabled() {
Flags::KEEPALIVE_ENABLED
} else {
Flags::empty()
};
ClientCodec {
config,
decoder: ResponseDecoder::with_pool(pool),
payload: None,
version: Version::HTTP_11,
flags,
headers_size: 0,
te: RequestEncoder::default(),
}
}
/// Check if request is upgrade
pub fn upgrade(&self) -> bool {
self.flags.contains(Flags::UPGRADE)
}
/// Check if last response is keep-alive
pub fn keepalive(&self) -> bool {
self.flags.contains(Flags::KEEPALIVE)
}
/// Check last request's message type
pub fn message_type(&self) -> MessageType {
if self.flags.contains(Flags::UNHANDLED) {
MessageType::Unhandled
} else if self.payload.is_none() {
MessageType::None
} else {
MessageType::Payload
}
}
/// prepare transfer encoding
pub fn prepare_te(&mut self, res: &mut ClientRequest) {
self.te
.update(res, self.flags.contains(Flags::HEAD), self.version);
}
fn encode_response(
&mut self, msg: ClientRequest, buffer: &mut BytesMut,
) -> io::Result<()> {
// Connection upgrade
if msg.upgrade() {
self.flags.insert(Flags::UPGRADE);
}
// render message
{
// status line
writeln!(
Writer(buffer),
"{} {} {:?}\r",
msg.method(),
msg.uri()
.path_and_query()
.map(|u| u.as_str())
.unwrap_or("/"),
msg.version()
).map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
// write headers
buffer.reserve(msg.headers().len() * AVERAGE_HEADER_SIZE);
for (key, value) in msg.headers() {
let v = value.as_ref();
let k = key.as_str().as_bytes();
buffer.reserve(k.len() + v.len() + 4);
buffer.put_slice(k);
buffer.put_slice(b": ");
buffer.put_slice(v);
buffer.put_slice(b"\r\n");
}
// set date header
if !msg.headers().contains_key(DATE) {
self.config.set_date(buffer);
} else {
buffer.extend_from_slice(b"\r\n");
}
}
Ok(())
}
}
impl Decoder for ClientCodec {
type Item = Message<ClientResponse>;
type Error = ParseError;
fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
if self.payload.is_some() {
Ok(match self.payload.as_mut().unwrap().decode(src)? {
Some(PayloadItem::Chunk(chunk)) => Some(Message::Chunk(Some(chunk))),
Some(PayloadItem::Eof) => Some(Message::Chunk(None)),
None => None,
})
} else if self.flags.contains(Flags::UNHANDLED) {
Ok(None)
} else if let Some((req, payload)) = self.decoder.decode(src)? {
self.flags
.set(Flags::HEAD, req.inner.method == Method::HEAD);
self.version = req.inner.version;
if self.flags.contains(Flags::KEEPALIVE_ENABLED) {
self.flags.set(Flags::KEEPALIVE, req.keep_alive());
}
match payload {
PayloadType::None => self.payload = None,
PayloadType::Payload(pl) => self.payload = Some(pl),
PayloadType::Unhandled => {
self.payload = None;
self.flags.insert(Flags::UNHANDLED);
}
};
Ok(Some(Message::Item(req)))
} else {
Ok(None)
}
}
}
impl Encoder for ClientCodec {
type Item = Message<ClientRequest>;
type Error = io::Error;
fn encode(
&mut self, item: Self::Item, dst: &mut BytesMut,
) -> Result<(), Self::Error> {
match item {
Message::Item(res) => {
self.encode_response(res, dst)?;
}
Message::Chunk(Some(bytes)) => {
self.te.encode(bytes.as_ref(), dst)?;
}
Message::Chunk(None) => {
self.te.encode_eof(dst)?;
}
}
Ok(())
}
}
pub struct Writer<'a>(pub &'a mut BytesMut);
impl<'a> io::Write for Writer<'a> {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
self.0.extend_from_slice(buf);
Ok(buf.len())
}
fn flush(&mut self) -> io::Result<()> {
Ok(())
}
}

View file

@ -4,15 +4,16 @@ use std::io::{self, Write};
use bytes::{BufMut, Bytes, BytesMut};
use tokio_codec::{Decoder, Encoder};
use super::decoder::{PayloadDecoder, PayloadItem, RequestDecoder, RequestPayloadType};
use super::decoder::{PayloadDecoder, PayloadItem, PayloadType, RequestDecoder};
use super::encoder::{ResponseEncoder, ResponseLength};
use super::{Message, MessageType};
use body::{Binary, Body};
use config::ServiceConfig;
use error::ParseError;
use helpers;
use http::header::{HeaderValue, CONNECTION, CONTENT_LENGTH, DATE, TRANSFER_ENCODING};
use http::{Method, Version};
use request::{Request, RequestPool};
use request::{MessagePool, Request};
use response::Response;
bitflags! {
@ -27,32 +28,6 @@ bitflags! {
const AVERAGE_HEADER_SIZE: usize = 30;
#[derive(Debug)]
/// Http response
pub enum OutMessage {
/// Http response message
Response(Response),
/// Payload chunk
Chunk(Option<Binary>),
}
/// Incoming http/1 request
#[derive(Debug)]
pub enum InMessage {
/// Request
Message(Request, InMessageType),
/// Payload chunk
Chunk(Option<Bytes>),
}
/// Incoming request type
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum InMessageType {
None,
Payload,
Unhandled,
}
/// HTTP/1 Codec
pub struct Codec {
config: ServiceConfig,
@ -71,11 +46,11 @@ impl Codec {
///
/// `keepalive_enabled` how response `connection` header get generated.
pub fn new(config: ServiceConfig) -> Self {
Codec::with_pool(RequestPool::pool(), config)
Codec::with_pool(MessagePool::pool(), config)
}
/// Create HTTP/1 codec with request's pool
pub(crate) fn with_pool(pool: &'static RequestPool, config: ServiceConfig) -> Self {
pub(crate) fn with_pool(pool: &'static MessagePool, config: ServiceConfig) -> Self {
let flags = if config.keep_alive_enabled() {
Flags::KEEPALIVE_ENABLED
} else {
@ -104,13 +79,13 @@ impl Codec {
}
/// Check last request's message type
pub fn message_type(&self) -> InMessageType {
pub fn message_type(&self) -> MessageType {
if self.flags.contains(Flags::UNHANDLED) {
InMessageType::Unhandled
MessageType::Unhandled
} else if self.payload.is_none() {
InMessageType::None
MessageType::None
} else {
InMessageType::Payload
MessageType::Payload
}
}
@ -256,14 +231,14 @@ impl Codec {
}
impl Decoder for Codec {
type Item = InMessage;
type Item = Message<Request>;
type Error = ParseError;
fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
if self.payload.is_some() {
Ok(match self.payload.as_mut().unwrap().decode(src)? {
Some(PayloadItem::Chunk(chunk)) => Some(InMessage::Chunk(Some(chunk))),
Some(PayloadItem::Eof) => Some(InMessage::Chunk(None)),
Some(PayloadItem::Chunk(chunk)) => Some(Message::Chunk(Some(chunk))),
Some(PayloadItem::Eof) => Some(Message::Chunk(None)),
None => None,
})
} else if self.flags.contains(Flags::UNHANDLED) {
@ -275,22 +250,15 @@ impl Decoder for Codec {
if self.flags.contains(Flags::KEEPALIVE_ENABLED) {
self.flags.set(Flags::KEEPALIVE, req.keep_alive());
}
let payload = match payload {
RequestPayloadType::None => {
self.payload = None;
InMessageType::None
}
RequestPayloadType::Payload(pl) => {
self.payload = Some(pl);
InMessageType::Payload
}
RequestPayloadType::Unhandled => {
match payload {
PayloadType::None => self.payload = None,
PayloadType::Payload(pl) => self.payload = Some(pl),
PayloadType::Unhandled => {
self.payload = None;
self.flags.insert(Flags::UNHANDLED);
InMessageType::Unhandled
}
};
Ok(Some(InMessage::Message(req, payload)))
}
Ok(Some(Message::Item(req)))
} else {
Ok(None)
}
@ -298,20 +266,20 @@ impl Decoder for Codec {
}
impl Encoder for Codec {
type Item = OutMessage;
type Item = Message<Response>;
type Error = io::Error;
fn encode(
&mut self, item: Self::Item, dst: &mut BytesMut,
) -> Result<(), Self::Error> {
match item {
OutMessage::Response(res) => {
Message::Item(res) => {
self.encode_response(res, dst)?;
}
OutMessage::Chunk(Some(bytes)) => {
Message::Chunk(Some(bytes)) => {
self.te.encode(bytes.as_ref(), dst)?;
}
OutMessage::Chunk(None) => {
Message::Chunk(None) => {
self.te.encode_eof(dst)?;
}
}

View file

@ -5,38 +5,43 @@ use futures::{Async, Poll};
use httparse;
use tokio_codec::Decoder;
use client::ClientResponse;
use error::ParseError;
use http::header::{HeaderName, HeaderValue};
use http::{header, HttpTryFrom, Method, Uri, Version};
use request::{MessageFlags, Request, RequestPool};
use http::{header, HttpTryFrom, Method, StatusCode, Uri, Version};
use request::{MessageFlags, MessagePool, Request};
use uri::Url;
const MAX_BUFFER_SIZE: usize = 131_072;
const MAX_HEADERS: usize = 96;
pub struct RequestDecoder(&'static RequestPool);
/// Client request decoder
pub struct RequestDecoder(&'static MessagePool);
/// Server response decoder
pub struct ResponseDecoder(&'static MessagePool);
/// Incoming request type
pub enum RequestPayloadType {
pub enum PayloadType {
None,
Payload(PayloadDecoder),
Unhandled,
}
impl RequestDecoder {
pub(crate) fn with_pool(pool: &'static RequestPool) -> RequestDecoder {
pub(crate) fn with_pool(pool: &'static MessagePool) -> RequestDecoder {
RequestDecoder(pool)
}
}
impl Default for RequestDecoder {
fn default() -> RequestDecoder {
RequestDecoder::with_pool(RequestPool::pool())
RequestDecoder::with_pool(MessagePool::pool())
}
}
impl Decoder for RequestDecoder {
type Item = (Request, RequestPayloadType);
type Item = (Request, PayloadType);
type Error = ParseError;
fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
@ -77,7 +82,7 @@ impl Decoder for RequestDecoder {
let slice = src.split_to(len).freeze();
// convert headers
let mut msg = RequestPool::get(self.0);
let mut msg = MessagePool::get_request(self.0);
{
let inner = msg.inner_mut();
inner
@ -156,18 +161,165 @@ impl Decoder for RequestDecoder {
// https://tools.ietf.org/html/rfc7230#section-3.3.3
let decoder = if chunked {
// Chunked encoding
RequestPayloadType::Payload(PayloadDecoder::chunked())
PayloadType::Payload(PayloadDecoder::chunked())
} else if let Some(len) = content_length {
// Content-Length
RequestPayloadType::Payload(PayloadDecoder::length(len))
PayloadType::Payload(PayloadDecoder::length(len))
} else if has_upgrade || msg.inner.method == Method::CONNECT {
// upgrade(websocket) or connect
RequestPayloadType::Unhandled
PayloadType::Unhandled
} else if src.len() >= MAX_BUFFER_SIZE {
error!("MAX_BUFFER_SIZE unprocessed data reached, closing");
return Err(ParseError::TooLarge);
} else {
RequestPayloadType::None
PayloadType::None
};
Ok(Some((msg, decoder)))
}
}
impl ResponseDecoder {
pub(crate) fn with_pool(pool: &'static MessagePool) -> ResponseDecoder {
ResponseDecoder(pool)
}
}
impl Default for ResponseDecoder {
fn default() -> ResponseDecoder {
ResponseDecoder::with_pool(MessagePool::pool())
}
}
impl Decoder for ResponseDecoder {
type Item = (ClientResponse, PayloadType);
type Error = ParseError;
fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
// Parse http message
let mut chunked = false;
let mut content_length = None;
let msg = {
// Unsafe: we read only this data only after httparse parses headers into.
// performance bump for pipeline benchmarks.
let mut headers: [HeaderIndex; MAX_HEADERS] =
unsafe { mem::uninitialized() };
let (len, version, status, headers_len) = {
let mut parsed: [httparse::Header; MAX_HEADERS] =
unsafe { mem::uninitialized() };
let mut res = httparse::Response::new(&mut parsed);
match res.parse(src)? {
httparse::Status::Complete(len) => {
let version = if res.version.unwrap() == 1 {
Version::HTTP_11
} else {
Version::HTTP_10
};
let status = StatusCode::from_u16(res.code.unwrap())
.map_err(|_| ParseError::Status)?;
HeaderIndex::record(src, res.headers, &mut headers);
(len, version, status, res.headers.len())
}
httparse::Status::Partial => return Ok(None),
}
};
let slice = src.split_to(len).freeze();
// convert headers
let mut msg = MessagePool::get_response(self.0);
{
let inner = msg.inner_mut();
inner
.flags
.get_mut()
.set(MessageFlags::KEEPALIVE, version != Version::HTTP_10);
for idx in headers[..headers_len].iter() {
if let Ok(name) =
HeaderName::from_bytes(&slice[idx.name.0..idx.name.1])
{
// Unsafe: httparse check header value for valid utf-8
let value = unsafe {
HeaderValue::from_shared_unchecked(
slice.slice(idx.value.0, idx.value.1),
)
};
match name {
header::CONTENT_LENGTH => {
if let Ok(s) = value.to_str() {
if let Ok(len) = s.parse::<u64>() {
content_length = Some(len);
} else {
debug!("illegal Content-Length: {:?}", len);
return Err(ParseError::Header);
}
} else {
debug!("illegal Content-Length: {:?}", len);
return Err(ParseError::Header);
}
}
// transfer-encoding
header::TRANSFER_ENCODING => {
if let Ok(s) = value.to_str() {
chunked = s.to_lowercase().contains("chunked");
} else {
return Err(ParseError::Header);
}
}
// connection keep-alive state
header::CONNECTION => {
let ka = if let Ok(conn) = value.to_str() {
if version == Version::HTTP_10
&& conn.contains("keep-alive")
{
true
} else {
version == Version::HTTP_11 && !(conn
.contains("close")
|| conn.contains("upgrade"))
}
} else {
false
};
inner.flags.get_mut().set(MessageFlags::KEEPALIVE, ka);
}
_ => (),
}
inner.headers.append(name, value);
} else {
return Err(ParseError::Header);
}
}
inner.status = status;
inner.version = version;
}
msg
};
// https://tools.ietf.org/html/rfc7230#section-3.3.3
let decoder = if chunked {
// Chunked encoding
PayloadType::Payload(PayloadDecoder::chunked())
} else if let Some(len) = content_length {
// Content-Length
PayloadType::Payload(PayloadDecoder::length(len))
} else if msg.inner.status == StatusCode::SWITCHING_PROTOCOLS
|| msg.inner.method == Method::CONNECT
{
// switching protocol or connect
PayloadType::Unhandled
} else if src.len() >= MAX_BUFFER_SIZE {
error!("MAX_BUFFER_SIZE unprocessed data reached, closing");
return Err(ParseError::TooLarge);
} else {
PayloadType::None
};
Ok(Some((msg, decoder)))
@ -488,36 +640,30 @@ mod tests {
use super::*;
use error::ParseError;
use h1::{InMessage, InMessageType};
use h1::Message;
use httpmessage::HttpMessage;
use request::Request;
impl RequestPayloadType {
impl PayloadType {
fn unwrap(self) -> PayloadDecoder {
match self {
RequestPayloadType::Payload(pl) => pl,
PayloadType::Payload(pl) => pl,
_ => panic!(),
}
}
fn is_unhandled(&self) -> bool {
match self {
RequestPayloadType::Unhandled => true,
PayloadType::Unhandled => true,
_ => false,
}
}
}
impl InMessage {
impl Message<Request> {
fn message(self) -> Request {
match self {
InMessage::Message(req, _) => req,
_ => panic!("error"),
}
}
fn is_payload(&self) -> bool {
match *self {
InMessage::Message(_, payload) => payload == InMessageType::Payload,
Message::Item(req) => req,
_ => panic!("error"),
}
}

View file

@ -18,8 +18,8 @@ use error::DispatchError;
use request::Request;
use response::Response;
use super::codec::{Codec, InMessage, InMessageType, OutMessage};
use super::H1ServiceResult;
use super::codec::Codec;
use super::{H1ServiceResult, Message, MessageType};
const MAX_PIPELINED_MESSAGES: usize = 16;
@ -48,14 +48,14 @@ where
state: State<S>,
payload: Option<PayloadSender>,
messages: VecDeque<Message>,
messages: VecDeque<DispatcherMessage>,
unhandled: Option<Request>,
ka_expire: Instant,
ka_timer: Option<Delay>,
}
enum Message {
enum DispatcherMessage {
Item(Request),
Error(Response),
}
@ -63,8 +63,8 @@ enum Message {
enum State<S: Service> {
None,
ServiceCall(S::Future),
SendResponse(Option<(OutMessage, Body)>),
SendPayload(Option<BodyStream>, Option<OutMessage>),
SendResponse(Option<(Message<Response>, Body)>),
SendPayload(Option<BodyStream>, Option<Message<Response>>),
}
impl<S: Service> State<S> {
@ -176,11 +176,12 @@ where
State::None => loop {
break if let Some(msg) = self.messages.pop_front() {
match msg {
Message::Item(req) => Some(self.handle_request(req)?),
Message::Error(res) => Some(State::SendResponse(Some((
OutMessage::Response(res),
Body::Empty,
)))),
DispatcherMessage::Item(req) => {
Some(self.handle_request(req)?)
}
DispatcherMessage::Error(res) => Some(State::SendResponse(
Some((Message::Item(res), Body::Empty)),
)),
}
} else {
None
@ -196,10 +197,7 @@ where
.get_codec_mut()
.prepare_te(&mut res);
let body = res.replace_body(Body::Empty);
Some(State::SendResponse(Some((
OutMessage::Response(res),
body,
))))
Some(State::SendResponse(Some((Message::Item(res), body))))
}
Async::NotReady => None,
}
@ -216,9 +214,9 @@ where
self.flags.remove(Flags::FLUSHED);
match body {
Body::Empty => Some(State::None),
Body::Binary(bin) => Some(State::SendPayload(
Body::Binary(mut bin) => Some(State::SendPayload(
None,
Some(OutMessage::Chunk(bin.into())),
Some(Message::Chunk(Some(bin.take()))),
)),
Body::Streaming(stream) => {
Some(State::SendPayload(Some(stream), None))
@ -257,7 +255,7 @@ where
.framed
.as_mut()
.unwrap()
.start_send(OutMessage::Chunk(Some(item.into())))
.start_send(Message::Chunk(Some(item.into())))
{
Ok(AsyncSink::Ready) => {
self.flags.remove(Flags::FLUSHED);
@ -271,7 +269,7 @@ where
},
Ok(Async::Ready(None)) => Some(State::SendPayload(
None,
Some(OutMessage::Chunk(None)),
Some(Message::Chunk(None)),
)),
Ok(Async::NotReady) => return Ok(()),
// Err(err) => return Err(DispatchError::Io(err)),
@ -312,7 +310,7 @@ where
.get_codec_mut()
.prepare_te(&mut res);
let body = res.replace_body(Body::Empty);
Ok(State::SendResponse(Some((OutMessage::Response(res), body))))
Ok(State::SendResponse(Some((Message::Item(res), body))))
}
Async::NotReady => Ok(State::ServiceCall(task)),
}
@ -333,14 +331,20 @@ where
self.flags.insert(Flags::STARTED);
match msg {
InMessage::Message(req, payload) => {
match payload {
InMessageType::Payload => {
Message::Item(req) => {
match self
.framed
.as_ref()
.unwrap()
.get_codec()
.message_type()
{
MessageType::Payload => {
let (ps, pl) = Payload::new(false);
*req.inner.payload.borrow_mut() = Some(pl);
self.payload = Some(ps);
}
InMessageType::Unhandled => {
MessageType::Unhandled => {
self.unhandled = Some(req);
return Ok(updated);
}
@ -351,10 +355,10 @@ where
if self.state.is_empty() {
self.state = self.handle_request(req)?;
} else {
self.messages.push_back(Message::Item(req));
self.messages.push_back(DispatcherMessage::Item(req));
}
}
InMessage::Chunk(Some(chunk)) => {
Message::Chunk(Some(chunk)) => {
if let Some(ref mut payload) = self.payload {
payload.feed_data(chunk);
} else {
@ -362,19 +366,19 @@ where
"Internal server error: unexpected payload chunk"
);
self.flags.insert(Flags::DISCONNECTED);
self.messages.push_back(Message::Error(
self.messages.push_back(DispatcherMessage::Error(
Response::InternalServerError().finish(),
));
self.error = Some(DispatchError::InternalError);
}
}
InMessage::Chunk(None) => {
Message::Chunk(None) => {
if let Some(mut payload) = self.payload.take() {
payload.feed_eof();
} else {
error!("Internal server error: unexpected eof");
self.flags.insert(Flags::DISCONNECTED);
self.messages.push_back(Message::Error(
self.messages.push_back(DispatcherMessage::Error(
Response::InternalServerError().finish(),
));
self.error = Some(DispatchError::InternalError);
@ -398,8 +402,9 @@ where
}
// Malformed requests should be responded with 400
self.messages
.push_back(Message::Error(Response::BadRequest().finish()));
self.messages.push_back(DispatcherMessage::Error(
Response::BadRequest().finish(),
));
self.flags.insert(Flags::DISCONNECTED);
self.error = Some(e.into());
break;
@ -443,9 +448,7 @@ where
trace!("Slow request timeout");
self.flags.insert(Flags::STARTED | Flags::DISCONNECTED);
self.state = State::SendResponse(Some((
OutMessage::Response(
Response::RequestTimeout().finish(),
),
Message::Item(Response::RequestTimeout().finish()),
Body::Empty,
)));
}

View file

@ -9,6 +9,7 @@ use http::header::{HeaderValue, ACCEPT_ENCODING, CONTENT_LENGTH};
use http::{StatusCode, Version};
use body::{Binary, Body};
use client::ClientRequest;
use header::ContentEncoding;
use http::Method;
use request::Request;
@ -165,6 +166,39 @@ impl ResponseEncoder {
}
}
#[derive(Debug)]
pub(crate) struct RequestEncoder {
head: bool,
pub length: ResponseLength,
pub te: TransferEncoding,
}
impl Default for RequestEncoder {
fn default() -> Self {
RequestEncoder {
head: false,
length: ResponseLength::None,
te: TransferEncoding::empty(),
}
}
}
impl RequestEncoder {
/// Encode message
pub fn encode(&mut self, msg: &[u8], buf: &mut BytesMut) -> io::Result<bool> {
self.te.encode(msg, buf)
}
/// Encode eof
pub fn encode_eof(&mut self, buf: &mut BytesMut) -> io::Result<()> {
self.te.encode_eof(buf)
}
pub fn update(&mut self, resp: &mut ClientRequest, head: bool, version: Version) {
self.head = head;
}
}
/// Encoders to handle different Transfer-Encodings.
#[derive(Debug)]
pub(crate) struct TransferEncoding {

View file

@ -1,13 +1,16 @@
//! HTTP/1 implementation
use actix_net::codec::Framed;
use bytes::Bytes;
mod client;
mod codec;
mod decoder;
mod dispatcher;
mod encoder;
mod service;
pub use self::codec::{Codec, InMessage, InMessageType, OutMessage};
pub use self::client::ClientCodec;
pub use self::codec::Codec;
pub use self::decoder::{PayloadDecoder, RequestDecoder};
pub use self::dispatcher::Dispatcher;
pub use self::service::{H1Service, H1ServiceHandler, OneRequest};
@ -20,3 +23,26 @@ pub enum H1ServiceResult<T> {
Shutdown(T),
Unhandled(Request, Framed<T, Codec>),
}
#[derive(Debug)]
/// Codec message
pub enum Message<T> {
/// Http message
Item(T),
/// Payload chunk
Chunk(Option<Bytes>),
}
impl<T> From<T> for Message<T> {
fn from(item: T) -> Self {
Message::Item(item)
}
}
/// Incoming request type
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MessageType {
None,
Payload,
Unhandled,
}

View file

@ -13,9 +13,9 @@ use error::{DispatchError, ParseError};
use request::Request;
use response::Response;
use super::codec::{Codec, InMessage};
use super::codec::Codec;
use super::dispatcher::Dispatcher;
use super::H1ServiceResult;
use super::{H1ServiceResult, Message};
/// `NewService` implementation for HTTP1 transport
pub struct H1Service<T, S> {
@ -344,10 +344,10 @@ where
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
match self.framed.as_mut().unwrap().poll()? {
Async::Ready(Some(req)) => match req {
InMessage::Message(req, _) => {
Message::Item(req) => {
Ok(Async::Ready((req, self.framed.take().unwrap())))
}
InMessage::Chunk(_) => unreachable!("Something is wrong"),
Message::Chunk(_) => unreachable!("Something is wrong"),
},
Async::Ready(None) => Err(ParseError::Incomplete),
Async::NotReady => Ok(Async::NotReady),

View file

@ -85,6 +85,7 @@ extern crate http as modhttp;
extern crate httparse;
extern crate mime;
extern crate net2;
extern crate percent_encoding;
extern crate rand;
extern crate serde;
extern crate serde_json;
@ -95,12 +96,14 @@ extern crate tokio_current_thread;
extern crate tokio_io;
extern crate tokio_tcp;
extern crate tokio_timer;
extern crate url as urlcrate;
#[cfg(test)]
#[macro_use]
extern crate serde_derive;
mod body;
pub mod client;
mod config;
mod extensions;
mod header;

View file

@ -3,8 +3,9 @@ use std::collections::VecDeque;
use std::fmt;
use std::rc::Rc;
use http::{header, HeaderMap, Method, Uri, Version};
use http::{header, HeaderMap, Method, StatusCode, Uri, Version};
use client::ClientResponse;
use extensions::Extensions;
use httpmessage::HttpMessage;
use payload::Payload;
@ -17,23 +18,24 @@ bitflags! {
}
}
/// Request's context
/// Request
pub struct Request {
pub(crate) inner: Rc<InnerRequest>,
pub(crate) inner: Rc<Message>,
}
pub(crate) struct InnerRequest {
pub(crate) struct Message {
pub(crate) version: Version,
pub(crate) status: StatusCode,
pub(crate) method: Method,
pub(crate) url: Url,
pub(crate) flags: Cell<MessageFlags>,
pub(crate) headers: HeaderMap,
pub(crate) extensions: RefCell<Extensions>,
pub(crate) payload: RefCell<Option<Payload>>,
pool: &'static RequestPool,
pub(crate) pool: &'static MessagePool,
}
impl InnerRequest {
impl Message {
#[inline]
/// Reset request instance
pub fn reset(&mut self) {
@ -64,15 +66,16 @@ impl HttpMessage for Request {
impl Request {
/// Create new Request instance
pub fn new() -> Request {
Request::with_pool(RequestPool::pool())
Request::with_pool(MessagePool::pool())
}
/// Create new Request instance with pool
pub(crate) fn with_pool(pool: &'static RequestPool) -> Request {
pub(crate) fn with_pool(pool: &'static MessagePool) -> Request {
Request {
inner: Rc::new(InnerRequest {
inner: Rc::new(Message {
pool,
method: Method::GET,
status: StatusCode::OK,
url: Url::default(),
version: Version::HTTP_11,
headers: HeaderMap::with_capacity(16),
@ -84,12 +87,12 @@ impl Request {
}
#[inline]
pub(crate) fn inner(&self) -> &InnerRequest {
pub(crate) fn inner(&self) -> &Message {
self.inner.as_ref()
}
#[inline]
pub(crate) fn inner_mut(&mut self) -> &mut InnerRequest {
pub(crate) fn inner_mut(&mut self) -> &mut Message {
Rc::get_mut(&mut self.inner).expect("Multiple copies exist")
}
@ -201,24 +204,24 @@ impl fmt::Debug for Request {
}
/// Request's objects pool
pub(crate) struct RequestPool(RefCell<VecDeque<Rc<InnerRequest>>>);
pub(crate) struct MessagePool(RefCell<VecDeque<Rc<Message>>>);
thread_local!(static POOL: &'static RequestPool = RequestPool::create());
thread_local!(static POOL: &'static MessagePool = MessagePool::create());
impl RequestPool {
fn create() -> &'static RequestPool {
let pool = RequestPool(RefCell::new(VecDeque::with_capacity(128)));
impl MessagePool {
fn create() -> &'static MessagePool {
let pool = MessagePool(RefCell::new(VecDeque::with_capacity(128)));
Box::leak(Box::new(pool))
}
/// Get default request's pool
pub fn pool() -> &'static RequestPool {
pub fn pool() -> &'static MessagePool {
POOL.with(|p| *p)
}
/// Get Request object
#[inline]
pub fn get(pool: &'static RequestPool) -> Request {
pub fn get_request(pool: &'static MessagePool) -> Request {
if let Some(mut msg) = pool.0.borrow_mut().pop_front() {
if let Some(r) = Rc::get_mut(&mut msg) {
r.reset();
@ -228,9 +231,21 @@ impl RequestPool {
Request::with_pool(pool)
}
/// Get Client Response object
#[inline]
pub fn get_response(pool: &'static MessagePool) -> ClientResponse {
if let Some(mut msg) = pool.0.borrow_mut().pop_front() {
if let Some(r) = Rc::get_mut(&mut msg) {
r.reset();
}
return ClientResponse { inner: msg };
}
ClientResponse::with_pool(pool)
}
#[inline]
/// Release request instance
pub(crate) fn release(&self, msg: Rc<InnerRequest>) {
pub(crate) fn release(&self, msg: Rc<Message>) {
let v = &mut self.0.borrow_mut();
if v.len() < 128 {
v.push_front(msg);

View file

@ -8,7 +8,7 @@ use futures::{Async, AsyncSink, Future, Poll, Sink};
use tokio_io::AsyncWrite;
use error::ResponseError;
use h1::{Codec, OutMessage};
use h1::{Codec, Message};
use response::Response;
pub struct SendError<T, R, E>(PhantomData<(T, R, E)>);
@ -59,7 +59,7 @@ where
Ok(r) => Either::A(ok(r)),
Err((e, framed)) => Either::B(SendErrorFut {
framed: Some(framed),
res: Some(OutMessage::Response(e.error_response())),
res: Some(Message::Item(e.error_response())),
err: Some(e),
_t: PhantomData,
}),
@ -68,7 +68,7 @@ where
}
pub struct SendErrorFut<T, R, E> {
res: Option<OutMessage>,
res: Option<Message<Response>>,
framed: Option<Framed<T, Codec>>,
err: Option<E>,
_t: PhantomData<R>,
@ -149,14 +149,14 @@ where
fn call(&mut self, (res, framed): Self::Request) -> Self::Future {
SendResponseFut {
res: Some(OutMessage::Response(res)),
res: Some(Message::Item(res)),
framed: Some(framed),
}
}
}
pub struct SendResponseFut<T> {
res: Option<OutMessage>,
res: Option<Message<Response>>,
framed: Option<Framed<T, Codec>>,
}

95
src/ws/client/connect.rs Normal file
View file

@ -0,0 +1,95 @@
//! Http client request
use std::str;
use cookie::Cookie;
use http::header::{HeaderName, HeaderValue};
use http::{Error as HttpError, HttpTryFrom};
use client::{ClientRequest, ClientRequestBuilder};
use header::IntoHeaderValue;
use super::ClientError;
/// `WebSocket` connection
pub struct Connect {
pub(super) request: ClientRequestBuilder,
pub(super) err: Option<ClientError>,
pub(super) http_err: Option<HttpError>,
pub(super) origin: Option<HeaderValue>,
pub(super) protocols: Option<String>,
pub(super) max_size: usize,
pub(super) server_mode: bool,
}
impl Connect {
/// Create new websocket connection
pub fn new<S: AsRef<str>>(uri: S) -> Connect {
let mut cl = Connect {
request: ClientRequest::build(),
err: None,
http_err: None,
origin: None,
protocols: None,
max_size: 65_536,
server_mode: false,
};
cl.request.uri(uri.as_ref());
cl
}
/// Set supported websocket protocols
pub fn protocols<U, V>(mut self, protos: U) -> Self
where
U: IntoIterator<Item = V> + 'static,
V: AsRef<str>,
{
let mut protos = protos
.into_iter()
.fold(String::new(), |acc, s| acc + s.as_ref() + ",");
protos.pop();
self.protocols = Some(protos);
self
}
/// Set cookie for handshake request
pub fn cookie(mut self, cookie: Cookie) -> Self {
self.request.cookie(cookie);
self
}
/// Set request Origin
pub fn origin<V>(mut self, origin: V) -> Self
where
HeaderValue: HttpTryFrom<V>,
{
match HeaderValue::try_from(origin) {
Ok(value) => self.origin = Some(value),
Err(e) => self.http_err = Some(e.into()),
}
self
}
/// Set max frame size
///
/// By default max size is set to 64kb
pub fn max_frame_size(mut self, size: usize) -> Self {
self.max_size = size;
self
}
/// Disable payload masking. By default ws client masks frame payload.
pub fn server_mode(mut self) -> Self {
self.server_mode = true;
self
}
/// Set request header
pub fn header<K, V>(mut self, key: K, value: V) -> Self
where
HeaderName: HttpTryFrom<K>,
V: IntoHeaderValue,
{
self.request.header(key, value);
self
}
}

87
src/ws/client/error.rs Normal file
View file

@ -0,0 +1,87 @@
//! Http client request
use std::io;
use actix_net::connector::ConnectorError;
use http::header::HeaderValue;
use http::StatusCode;
use error::ParseError;
use http::Error as HttpError;
use ws::ProtocolError;
/// Websocket client error
#[derive(Fail, Debug)]
pub enum ClientError {
/// Invalid url
#[fail(display = "Invalid url")]
InvalidUrl,
/// Invalid response status
#[fail(display = "Invalid response status")]
InvalidResponseStatus(StatusCode),
/// Invalid upgrade header
#[fail(display = "Invalid upgrade header")]
InvalidUpgradeHeader,
/// Invalid connection header
#[fail(display = "Invalid connection header")]
InvalidConnectionHeader(HeaderValue),
/// Missing CONNECTION header
#[fail(display = "Missing CONNECTION header")]
MissingConnectionHeader,
/// Missing SEC-WEBSOCKET-ACCEPT header
#[fail(display = "Missing SEC-WEBSOCKET-ACCEPT header")]
MissingWebSocketAcceptHeader,
/// Invalid challenge response
#[fail(display = "Invalid challenge response")]
InvalidChallengeResponse(String, HeaderValue),
/// Http parsing error
#[fail(display = "Http parsing error")]
Http(HttpError),
// /// Url parsing error
// #[fail(display = "Url parsing error")]
// Url(UrlParseError),
/// Response parsing error
#[fail(display = "Response parsing error")]
ParseError(ParseError),
/// Protocol error
#[fail(display = "{}", _0)]
Protocol(#[cause] ProtocolError),
/// Connect error
#[fail(display = "{:?}", _0)]
Connect(ConnectorError),
/// IO Error
#[fail(display = "{}", _0)]
Io(io::Error),
/// "Disconnected"
#[fail(display = "Disconnected")]
Disconnected,
}
impl From<HttpError> for ClientError {
fn from(err: HttpError) -> ClientError {
ClientError::Http(err)
}
}
impl From<ConnectorError> for ClientError {
fn from(err: ConnectorError) -> ClientError {
ClientError::Connect(err)
}
}
impl From<ProtocolError> for ClientError {
fn from(err: ProtocolError) -> ClientError {
ClientError::Protocol(err)
}
}
impl From<io::Error> for ClientError {
fn from(err: io::Error) -> ClientError {
ClientError::Io(err)
}
}
impl From<ParseError> for ClientError {
fn from(err: ParseError) -> ClientError {
ClientError::ParseError(err)
}
}

48
src/ws/client/mod.rs Normal file
View file

@ -0,0 +1,48 @@
mod connect;
mod error;
mod service;
pub use self::connect::Connect;
pub use self::error::ClientError;
pub use self::service::Client;
#[derive(PartialEq, Hash, Debug, Clone, Copy)]
pub(crate) enum Protocol {
Http,
Https,
Ws,
Wss,
}
impl Protocol {
fn from(s: &str) -> Option<Protocol> {
match s {
"http" => Some(Protocol::Http),
"https" => Some(Protocol::Https),
"ws" => Some(Protocol::Ws),
"wss" => Some(Protocol::Wss),
_ => None,
}
}
fn is_http(self) -> bool {
match self {
Protocol::Https | Protocol::Http => true,
_ => false,
}
}
fn is_secure(self) -> bool {
match self {
Protocol::Https | Protocol::Wss => true,
_ => false,
}
}
fn port(self) -> u16 {
match self {
Protocol::Http | Protocol::Ws => 80,
Protocol::Https | Protocol::Wss => 443,
}
}
}

270
src/ws/client/service.rs Normal file
View file

@ -0,0 +1,270 @@
//! websockets client
use std::marker::PhantomData;
use actix_net::codec::Framed;
use actix_net::connector::{ConnectorError, DefaultConnector};
use actix_net::service::Service;
use base64;
use futures::future::{err, Either, FutureResult};
use futures::{Async, Future, Poll, Sink, Stream};
use http::header::{self, HeaderValue};
use http::{HttpTryFrom, StatusCode};
use rand;
use sha1::Sha1;
use tokio_io::{AsyncRead, AsyncWrite};
use client::ClientResponse;
use h1;
use ws::Codec;
use super::{ClientError, Connect, Protocol};
/// WebSocket's client
pub struct Client<T>
where
T: Service<Error = ConnectorError>,
T::Response: AsyncRead + AsyncWrite,
{
connector: T,
}
impl<T> Client<T>
where
T: Service<Request = String, Error = ConnectorError>,
T::Response: AsyncRead + AsyncWrite,
{
/// Create new websocket's client factory
pub fn new(connector: T) -> Self {
Client { connector }
}
}
impl Default for Client<DefaultConnector<String>> {
fn default() -> Self {
Client::new(DefaultConnector::default())
}
}
impl<T> Clone for Client<T>
where
T: Service<Request = String, Error = ConnectorError> + Clone,
T::Response: AsyncRead + AsyncWrite,
{
fn clone(&self) -> Self {
Client {
connector: self.connector.clone(),
}
}
}
impl<T> Service for Client<T>
where
T: Service<Request = String, Error = ConnectorError>,
T::Response: AsyncRead + AsyncWrite + 'static,
T::Future: 'static,
{
type Request = Connect;
type Response = Framed<T::Response, Codec>;
type Error = ClientError;
type Future = Either<
FutureResult<Self::Response, Self::Error>,
ClientResponseFut<T::Response>,
>;
fn poll_ready(&mut self) -> Poll<(), Self::Error> {
self.connector.poll_ready().map_err(ClientError::from)
}
fn call(&mut self, mut req: Self::Request) -> Self::Future {
if let Some(e) = req.err.take() {
Either::A(err(e))
} else if let Some(e) = req.http_err.take() {
Either::A(err(e.into()))
} else {
// origin
if let Some(origin) = req.origin.take() {
req.request.set_header(header::ORIGIN, origin);
}
req.request.upgrade();
req.request.set_header(header::UPGRADE, "websocket");
req.request.set_header(header::CONNECTION, "upgrade");
req.request.set_header(header::SEC_WEBSOCKET_VERSION, "13");
if let Some(protocols) = req.protocols.take() {
req.request
.set_header(header::SEC_WEBSOCKET_PROTOCOL, protocols.as_str());
}
let mut request = match req.request.finish() {
Ok(req) => req,
Err(e) => return Either::A(err(e.into())),
};
if request.uri().host().is_none() {
return Either::A(err(ClientError::InvalidUrl));
}
// supported protocols
let proto = if let Some(scheme) = request.uri().scheme_part() {
match Protocol::from(scheme.as_str()) {
Some(proto) => proto,
None => return Either::A(err(ClientError::InvalidUrl)),
}
} else {
return Either::A(err(ClientError::InvalidUrl));
};
// Generate a random key for the `Sec-WebSocket-Key` header.
// a base64-encoded (see Section 4 of [RFC4648]) value that,
// when decoded, is 16 bytes in length (RFC 6455)
let sec_key: [u8; 16] = rand::random();
let key = base64::encode(&sec_key);
request.headers_mut().insert(
header::SEC_WEBSOCKET_KEY,
HeaderValue::try_from(key.as_str()).unwrap(),
);
// prep connection
let host = {
let uri = request.uri();
format!(
"{}:{}",
uri.host().unwrap(),
uri.port().unwrap_or_else(|| proto.port())
)
};
let fut = Box::new(
self.connector
.call(host)
.map_err(|e| ClientError::from(e))
.and_then(move |io| {
// h1 protocol
let framed = Framed::new(io, h1::ClientCodec::default());
framed
.send(request.into())
.map_err(|e| ClientError::from(e))
.and_then(|framed| {
framed
.into_future()
.map_err(|(e, _)| ClientError::from(e))
})
}),
);
// start handshake
Either::B(ClientResponseFut {
key,
fut,
max_size: req.max_size,
server_mode: req.server_mode,
_t: PhantomData,
})
}
}
}
/// Future that implementes client websocket handshake process.
///
/// It resolves to a `Framed<T, ws::Codec>` instance.
pub struct ClientResponseFut<T>
where
T: AsyncRead + AsyncWrite,
{
fut: Box<
Future<
Item = (
Option<h1::Message<ClientResponse>>,
Framed<T, h1::ClientCodec>,
),
Error = ClientError,
>,
>,
key: String,
max_size: usize,
server_mode: bool,
_t: PhantomData<T>,
}
impl<T> Future for ClientResponseFut<T>
where
T: AsyncRead + AsyncWrite,
{
type Item = Framed<T, Codec>;
type Error = ClientError;
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
let (item, framed) = try_ready!(self.fut.poll());
let res = match item {
Some(h1::Message::Item(res)) => res,
Some(h1::Message::Chunk(_)) => unreachable!(),
None => return Err(ClientError::Disconnected),
};
// verify response
if res.status() != StatusCode::SWITCHING_PROTOCOLS {
return Err(ClientError::InvalidResponseStatus(res.status()));
}
// Check for "UPGRADE" to websocket header
let has_hdr = if let Some(hdr) = res.headers().get(header::UPGRADE) {
if let Ok(s) = hdr.to_str() {
s.to_lowercase().contains("websocket")
} else {
false
}
} else {
false
};
if !has_hdr {
trace!("Invalid upgrade header");
return Err(ClientError::InvalidUpgradeHeader);
}
// Check for "CONNECTION" header
if let Some(conn) = res.headers().get(header::CONNECTION) {
if let Ok(s) = conn.to_str() {
if !s.to_lowercase().contains("upgrade") {
trace!("Invalid connection header: {}", s);
return Err(ClientError::InvalidConnectionHeader(conn.clone()));
}
} else {
trace!("Invalid connection header: {:?}", conn);
return Err(ClientError::InvalidConnectionHeader(conn.clone()));
}
} else {
trace!("Missing connection header");
return Err(ClientError::MissingConnectionHeader);
}
if let Some(key) = res.headers().get(header::SEC_WEBSOCKET_ACCEPT) {
// field is constructed by concatenating /key/
// with the string "258EAFA5-E914-47DA-95CA-C5AB0DC85B11" (RFC 6455)
const WS_GUID: &[u8] = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
let mut sha1 = Sha1::new();
sha1.update(self.key.as_ref());
sha1.update(WS_GUID);
let encoded = base64::encode(&sha1.digest().bytes());
if key.as_bytes() != encoded.as_bytes() {
trace!(
"Invalid challenge response: expected: {} received: {:?}",
encoded,
key
);
return Err(ClientError::InvalidChallengeResponse(encoded, key.clone()));
}
} else {
trace!("Missing SEC-WEBSOCKET-ACCEPT header");
return Err(ClientError::MissingWebSocketAcceptHeader);
};
// websockets codec
let codec = if self.server_mode {
Codec::new().max_size(self.max_size)
} else {
Codec::new().max_size(self.max_size).client_mode()
};
Ok(Async::Ready(framed.into_framed(codec)))
}
}

View file

@ -10,6 +10,7 @@ use http::{header, Method, StatusCode};
use request::Request;
use response::{ConnectionType, Response, ResponseBuilder};
mod client;
mod codec;
mod frame;
mod mask;
@ -17,6 +18,7 @@ mod proto;
mod service;
mod transport;
pub use self::client::{Client, ClientError, Connect};
pub use self::codec::{Codec, Frame, Message};
pub use self::frame::Parser;
pub use self::proto::{CloseCode, CloseReason, OpCode};

View file

@ -11,12 +11,12 @@ use actix::System;
use actix_net::codec::Framed;
use actix_net::framed::IntoFramed;
use actix_net::server::Server;
use actix_net::service::NewServiceExt;
use actix_net::service::{NewServiceExt, Service};
use actix_net::stream::TakeItem;
use actix_web::{test, ws as web_ws};
use bytes::Bytes;
use futures::future::{ok, Either};
use futures::{Future, Sink, Stream};
use bytes::{Bytes, BytesMut};
use futures::future::{lazy, ok, Either};
use futures::{Future, IntoFuture, Sink, Stream};
use actix_http::{h1, ws, ResponseError, ServiceConfig};
@ -51,14 +51,14 @@ fn test_simple() {
.and_then(TakeItem::new().map_err(|_| ()))
.and_then(|(req, framed): (_, Framed<_, _>)| {
// validate request
if let Some(h1::InMessage::Message(req, _)) = req {
if let Some(h1::Message::Item(req)) = req {
match ws::verify_handshake(&req) {
Err(e) => {
// validation failed
let resp = e.error_response();
Either::A(
framed
.send(h1::OutMessage::Response(resp))
.send(h1::Message::Item(resp))
.map_err(|_| ())
.map(|_| ()),
)
@ -66,7 +66,7 @@ fn test_simple() {
Ok(_) => Either::B(
// send response
framed
.send(h1::OutMessage::Response(
.send(h1::Message::Item(
ws::handshake_response(&req).finish(),
)).map_err(|_| ())
.and_then(|framed| {
@ -116,4 +116,43 @@ fn test_simple() {
)))
);
}
// client service
let mut client = sys
.block_on(lazy(|| Ok::<_, ()>(ws::Client::default()).into_future()))
.unwrap();
let framed = sys
.block_on(client.call(ws::Connect::new(format!("http://{}/", addr))))
.unwrap();
let framed = sys
.block_on(framed.send(ws::Message::Text("text".to_string())))
.unwrap();
let (item, framed) = sys.block_on(framed.into_future()).map_err(|_| ()).unwrap();
assert_eq!(item, Some(ws::Frame::Text(Some(BytesMut::from("text")))));
let framed = sys
.block_on(framed.send(ws::Message::Binary("text".into())))
.unwrap();
let (item, framed) = sys.block_on(framed.into_future()).map_err(|_| ()).unwrap();
assert_eq!(
item,
Some(ws::Frame::Binary(Some(Bytes::from_static(b"text").into())))
);
let framed = sys
.block_on(framed.send(ws::Message::Ping("text".into())))
.unwrap();
let (item, framed) = sys.block_on(framed.into_future()).map_err(|_| ()).unwrap();
assert_eq!(item, Some(ws::Frame::Pong("text".to_string().into())));
let framed = sys
.block_on(framed.send(ws::Message::Close(Some(ws::CloseCode::Normal.into()))))
.unwrap();
let (item, _framed) = sys.block_on(framed.into_future()).map_err(|_| ()).unwrap();
assert_eq!(
item,
Some(ws::Frame::Close(Some(ws::CloseCode::Normal.into())))
)
}