diff --git a/actix-http/Cargo.toml b/actix-http/Cargo.toml index 809d4d677..fefe05c4c 100644 --- a/actix-http/Cargo.toml +++ b/actix-http/Cargo.toml @@ -49,7 +49,7 @@ fail = ["failure"] [dependencies] actix-service = "0.3.4" -actix-codec = "0.1.1" +actix-codec = "0.1.2" actix-connect = "0.1.0" actix-utils = "0.3.4" actix-server-config = "0.1.0" diff --git a/actix-http/src/client/connection.rs b/actix-http/src/client/connection.rs index e8c1201aa..267c85d31 100644 --- a/actix-http/src/client/connection.rs +++ b/actix-http/src/client/connection.rs @@ -1,11 +1,13 @@ -use std::{fmt, time}; +use std::{fmt, io, time}; -use actix_codec::{AsyncRead, AsyncWrite}; -use bytes::Bytes; -use futures::Future; +use actix_codec::{AsyncRead, AsyncWrite, Framed}; +use bytes::{Buf, Bytes}; +use futures::future::{err, Either, Future, FutureResult}; +use futures::Poll; use h2::client::SendRequest; use crate::body::MessageBody; +use crate::h1::ClientCodec; use crate::message::{RequestHead, ResponseHead}; use crate::payload::Payload; @@ -19,6 +21,7 @@ pub(crate) enum ConnectionType { } pub trait Connection { + type Io: AsyncRead + AsyncWrite; type Future: Future; /// Send request and body @@ -27,6 +30,14 @@ pub trait Connection { head: RequestHead, body: B, ) -> Self::Future; + + type TunnelFuture: Future< + Item = (ResponseHead, Framed), + Error = SendRequestError, + >; + + /// Send request, returns Response and Framed + fn open_tunnel(self, head: RequestHead) -> Self::TunnelFuture; } pub(crate) trait ConnectionLifetime: AsyncRead + AsyncWrite + 'static { @@ -80,6 +91,7 @@ impl Connection for IoConnection where T: AsyncRead + AsyncWrite + 'static, { + type Io = T; type Future = Box>; fn send_request( @@ -104,6 +116,35 @@ where )), } } + + type TunnelFuture = Either< + Box< + Future< + Item = (ResponseHead, Framed), + Error = SendRequestError, + >, + >, + FutureResult<(ResponseHead, Framed), SendRequestError>, + >; + + /// Send request, returns Response and Framed + fn open_tunnel(mut self, head: RequestHead) -> Self::TunnelFuture { + match self.io.take().unwrap() { + ConnectionType::H1(io) => { + Either::A(Box::new(h1proto::open_tunnel(io, head))) + } + ConnectionType::H2(io) => { + if let Some(mut pool) = self.pool.take() { + pool.release(IoConnection::new( + ConnectionType::H2(io), + self.created, + None, + )); + } + Either::B(err(SendRequestError::TunnelNotSupported)) + } + } + } } #[allow(dead_code)] @@ -117,6 +158,7 @@ where A: AsyncRead + AsyncWrite + 'static, B: AsyncRead + AsyncWrite + 'static, { + type Io = EitherIo; type Future = Box>; fn send_request( @@ -129,4 +171,99 @@ where EitherConnection::B(con) => con.send_request(head, body), } } + + type TunnelFuture = Box< + Future< + Item = (ResponseHead, Framed), + Error = SendRequestError, + >, + >; + + /// Send request, returns Response and Framed + fn open_tunnel(self, head: RequestHead) -> Self::TunnelFuture { + match self { + EitherConnection::A(con) => Box::new( + con.open_tunnel(head) + .map(|(head, framed)| (head, framed.map_io(|io| EitherIo::A(io)))), + ), + EitherConnection::B(con) => Box::new( + con.open_tunnel(head) + .map(|(head, framed)| (head, framed.map_io(|io| EitherIo::B(io)))), + ), + } + } +} + +pub enum EitherIo { + A(A), + B(B), +} + +impl io::Read for EitherIo +where + A: io::Read, + B: io::Read, +{ + fn read(&mut self, buf: &mut [u8]) -> io::Result { + match self { + EitherIo::A(ref mut val) => val.read(buf), + EitherIo::B(ref mut val) => val.read(buf), + } + } +} + +impl AsyncRead for EitherIo +where + A: AsyncRead, + B: AsyncRead, +{ + unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [u8]) -> bool { + match self { + EitherIo::A(ref val) => val.prepare_uninitialized_buffer(buf), + EitherIo::B(ref val) => val.prepare_uninitialized_buffer(buf), + } + } +} + +impl io::Write for EitherIo +where + A: io::Write, + B: io::Write, +{ + fn write(&mut self, buf: &[u8]) -> io::Result { + match self { + EitherIo::A(ref mut val) => val.write(buf), + EitherIo::B(ref mut val) => val.write(buf), + } + } + + fn flush(&mut self) -> io::Result<()> { + match self { + EitherIo::A(ref mut val) => val.flush(), + EitherIo::B(ref mut val) => val.flush(), + } + } +} + +impl AsyncWrite for EitherIo +where + A: AsyncWrite, + B: AsyncWrite, +{ + fn shutdown(&mut self) -> Poll<(), io::Error> { + match self { + EitherIo::A(ref mut val) => val.shutdown(), + EitherIo::B(ref mut val) => val.shutdown(), + } + } + + fn write_buf(&mut self, buf: &mut U) -> Poll + where + Self: Sized, + { + match self { + EitherIo::A(ref mut val) => val.write_buf(buf), + EitherIo::B(ref mut val) => val.write_buf(buf), + } + } } diff --git a/actix-http/src/client/error.rs b/actix-http/src/client/error.rs index 69ec49585..e67db5462 100644 --- a/actix-http/src/client/error.rs +++ b/actix-http/src/client/error.rs @@ -105,6 +105,9 @@ pub enum SendRequestError { /// Http2 error #[display(fmt = "{}", _0)] H2(h2::Error), + /// Tunnels are not supported for http2 connection + #[display(fmt = "Tunnels are not supported for http2 connection")] + TunnelNotSupported, /// Error sending request body Body(Error), } diff --git a/actix-http/src/client/h1proto.rs b/actix-http/src/client/h1proto.rs index b7b8d4a0a..5fec9c4f1 100644 --- a/actix-http/src/client/h1proto.rs +++ b/actix-http/src/client/h1proto.rs @@ -70,6 +70,32 @@ where }) } +pub(crate) fn open_tunnel( + io: T, + head: RequestHead, +) -> impl Future), Error = SendRequestError> +where + T: AsyncRead + AsyncWrite + 'static, +{ + // create Framed and send reqest + Framed::new(io, h1::ClientCodec::default()) + .send((head, BodySize::None).into()) + .from_err() + // read response + .and_then(|framed| { + framed + .into_future() + .map_err(|(e, _)| SendRequestError::from(e)) + .and_then(|(head, framed)| { + if let Some(head) = head { + Ok((head, framed)) + } else { + Err(SendRequestError::from(ConnectError::Disconnected)) + } + }) + }) +} + #[doc(hidden)] /// HTTP client connection pub struct H1Connection { diff --git a/actix-http/src/client/pool.rs b/actix-http/src/client/pool.rs index a94b1e52a..aff11181b 100644 --- a/actix-http/src/client/pool.rs +++ b/actix-http/src/client/pool.rs @@ -411,66 +411,6 @@ where } } -// struct ConnectorPoolSupport -// where -// Io: AsyncRead + AsyncWrite + 'static, -// { -// connector: T, -// inner: Rc>>, -// } - -// impl Future for ConnectorPoolSupport -// where -// Io: AsyncRead + AsyncWrite + 'static, -// T: Service, -// T::Future: 'static, -// { -// type Item = (); -// type Error = (); - -// fn poll(&mut self) -> Poll { -// let mut inner = self.inner.as_ref().borrow_mut(); -// inner.task.register(); - -// // check waiters -// loop { -// let (key, token) = { -// if let Some((key, token)) = inner.waiters_queue.get_index(0) { -// (key.clone(), *token) -// } else { -// break; -// } -// }; -// match inner.acquire(&key) { -// Acquire::NotAvailable => break, -// Acquire::Acquired(io, created) => { -// let (_, tx) = inner.waiters.remove(token); -// if let Err(conn) = tx.send(Ok(IoConnection::new( -// io, -// created, -// Some(Acquired(key.clone(), Some(self.inner.clone()))), -// ))) { -// let (io, created) = conn.unwrap().into_inner(); -// inner.release_conn(&key, io, created); -// } -// } -// Acquire::Available => { -// let (connect, tx) = inner.waiters.remove(token); -// OpenWaitingConnection::spawn( -// key.clone(), -// tx, -// self.inner.clone(), -// self.connector.call(connect), -// ); -// } -// } -// let _ = inner.waiters_queue.swap_remove_index(0); -// } - -// Ok(Async::NotReady) -// } -// } - struct CloseConnection { io: T, timeout: Delay, diff --git a/actix-http/src/h1/decoder.rs b/actix-http/src/h1/decoder.rs index 9b97713fb..dfd9fe25c 100644 --- a/actix-http/src/h1/decoder.rs +++ b/actix-http/src/h1/decoder.rs @@ -611,7 +611,6 @@ mod tests { use super::*; use crate::error::ParseError; use crate::httpmessage::HttpMessage; - use crate::message::Head; impl PayloadType { fn unwrap(self) -> PayloadDecoder { diff --git a/actix-http/src/ws/client/connect.rs b/actix-http/src/ws/client/connect.rs deleted file mode 100644 index 2760967e0..000000000 --- a/actix-http/src/ws/client/connect.rs +++ /dev/null @@ -1,109 +0,0 @@ -//! Http client request -use std::str; - -#[cfg(feature = "cookies")] -use cookie::Cookie; -use http::header::{HeaderName, HeaderValue}; -use http::{Error as HttpError, HttpTryFrom, Uri}; - -use super::ClientError; -use crate::header::IntoHeaderValue; -use crate::message::RequestHead; - -/// `WebSocket` connection -pub struct Connect { - pub(super) head: RequestHead, - pub(super) err: Option, - pub(super) http_err: Option, - pub(super) origin: Option, - pub(super) protocols: Option, - pub(super) max_size: usize, - pub(super) server_mode: bool, -} - -impl Connect { - /// Create new websocket connection - pub fn new>(uri: S) -> Connect { - let mut cl = Connect { - head: RequestHead::default(), - err: None, - http_err: None, - origin: None, - protocols: None, - max_size: 65_536, - server_mode: false, - }; - - match Uri::try_from(uri.as_ref()) { - Ok(uri) => cl.head.uri = uri, - Err(e) => cl.http_err = Some(e.into()), - } - - cl - } - - /// Set supported websocket protocols - pub fn protocols(mut self, protos: U) -> Self - where - U: IntoIterator + 'static, - V: AsRef, - { - let mut protos = protos - .into_iter() - .fold(String::new(), |acc, s| acc + s.as_ref() + ","); - protos.pop(); - self.protocols = Some(protos); - self - } - - // #[cfg(feature = "cookies")] - // /// Set cookie for handshake request - // pub fn cookie(mut self, cookie: Cookie) -> Self { - // self.request.cookie(cookie); - // self - // } - - /// Set request Origin - pub fn origin(mut self, origin: V) -> Self - where - HeaderValue: HttpTryFrom, - { - match HeaderValue::try_from(origin) { - Ok(value) => self.origin = Some(value), - Err(e) => self.http_err = Some(e.into()), - } - self - } - - /// Set max frame size - /// - /// By default max size is set to 64kb - pub fn max_frame_size(mut self, size: usize) -> Self { - self.max_size = size; - self - } - - /// Disable payload masking. By default ws client masks frame payload. - pub fn server_mode(mut self) -> Self { - self.server_mode = true; - self - } - - /// Set request header - pub fn header(mut self, key: K, value: V) -> Self - where - HeaderName: HttpTryFrom, - V: IntoHeaderValue, - { - match HeaderName::try_from(key) { - Ok(key) => match value.try_into() { - Ok(value) => { - self.head.headers.append(key, value); - } - Err(e) => self.http_err = Some(e.into()), - }, - Err(e) => self.http_err = Some(e.into()), - } - self - } -} diff --git a/actix-http/src/ws/client/mod.rs b/actix-http/src/ws/client/mod.rs deleted file mode 100644 index a5c221967..000000000 --- a/actix-http/src/ws/client/mod.rs +++ /dev/null @@ -1,48 +0,0 @@ -mod connect; -mod error; -mod service; - -pub use self::connect::Connect; -pub use self::error::ClientError; -pub use self::service::Client; - -#[derive(PartialEq, Hash, Debug, Clone, Copy)] -pub(crate) enum Protocol { - Http, - Https, - Ws, - Wss, -} - -impl Protocol { - fn from(s: &str) -> Option { - match s { - "http" => Some(Protocol::Http), - "https" => Some(Protocol::Https), - "ws" => Some(Protocol::Ws), - "wss" => Some(Protocol::Wss), - _ => None, - } - } - - // fn is_http(self) -> bool { - // match self { - // Protocol::Https | Protocol::Http => true, - // _ => false, - // } - // } - - // fn is_secure(self) -> bool { - // match self { - // Protocol::Https | Protocol::Wss => true, - // _ => false, - // } - // } - - fn port(self) -> u16 { - match self { - Protocol::Http | Protocol::Ws => 80, - Protocol::Https | Protocol::Wss => 443, - } - } -} diff --git a/actix-http/src/ws/client/service.rs b/actix-http/src/ws/client/service.rs deleted file mode 100644 index bc86e516a..000000000 --- a/actix-http/src/ws/client/service.rs +++ /dev/null @@ -1,272 +0,0 @@ -//! websockets client -use std::marker::PhantomData; - -use actix_codec::{AsyncRead, AsyncWrite, Framed}; -use actix_connect::{default_connector, Connect as TcpConnect, ConnectError}; -use actix_service::{apply_fn, Service}; -use base64; -use futures::future::{err, Either, FutureResult}; -use futures::{try_ready, Async, Future, Poll, Sink, Stream}; -use http::header::{self, HeaderValue}; -use http::{HttpTryFrom, StatusCode}; -use log::trace; -use rand; -use sha1::Sha1; - -use crate::body::BodySize; -use crate::h1; -use crate::message::{ConnectionType, ResponseHead}; -use crate::ws::Codec; - -use super::{ClientError, Connect, Protocol}; - -/// WebSocket's client -pub struct Client { - connector: T, -} - -impl Client<()> { - /// Create client with default connector. - pub fn default() -> Client< - impl Service< - Request = TcpConnect, - Response = impl AsyncRead + AsyncWrite, - Error = ConnectError, - > + Clone, - > { - Client::new(apply_fn(default_connector(), |msg: TcpConnect<_>, srv| { - srv.call(msg).map(|stream| stream.into_parts().0) - })) - } -} - -impl Client -where - T: Service, Error = ConnectError>, - T::Response: AsyncRead + AsyncWrite, -{ - /// Create new websocket's client factory - pub fn new(connector: T) -> Self { - Client { connector } - } -} - -impl Clone for Client -where - T: Service, Error = ConnectError> + Clone, - T::Response: AsyncRead + AsyncWrite, -{ - fn clone(&self) -> Self { - Client { - connector: self.connector.clone(), - } - } -} - -impl Service for Client -where - T: Service, Error = ConnectError>, - T::Response: AsyncRead + AsyncWrite + 'static, - T::Future: 'static, -{ - type Request = Connect; - type Response = Framed; - type Error = ClientError; - type Future = Either< - FutureResult, - ClientResponseFut, - >; - - fn poll_ready(&mut self) -> Poll<(), Self::Error> { - self.connector.poll_ready().map_err(ClientError::from) - } - - fn call(&mut self, mut req: Connect) -> Self::Future { - if let Some(e) = req.err.take() { - Either::A(err(e)) - } else if let Some(e) = req.http_err.take() { - Either::A(err(e.into())) - } else { - // origin - if let Some(origin) = req.origin.take() { - req.head.headers.insert(header::ORIGIN, origin); - } - - req.head.set_connection_type(ConnectionType::Upgrade); - req.head - .headers - .insert(header::UPGRADE, HeaderValue::from_static("websocket")); - req.head.headers.insert( - header::SEC_WEBSOCKET_VERSION, - HeaderValue::from_static("13"), - ); - - if let Some(protocols) = req.protocols.take() { - req.head.headers.insert( - header::SEC_WEBSOCKET_PROTOCOL, - HeaderValue::try_from(protocols.as_str()).unwrap(), - ); - } - if let Some(e) = req.http_err { - return Either::A(err(e.into())); - }; - - let mut request = req.head; - if request.uri.host().is_none() { - return Either::A(err(ClientError::InvalidUrl)); - } - - // supported protocols - let proto = if let Some(scheme) = request.uri.scheme_part() { - match Protocol::from(scheme.as_str()) { - Some(proto) => proto, - None => return Either::A(err(ClientError::InvalidUrl)), - } - } else { - return Either::A(err(ClientError::InvalidUrl)); - }; - - // Generate a random key for the `Sec-WebSocket-Key` header. - // a base64-encoded (see Section 4 of [RFC4648]) value that, - // when decoded, is 16 bytes in length (RFC 6455) - let sec_key: [u8; 16] = rand::random(); - let key = base64::encode(&sec_key); - - request.headers.insert( - header::SEC_WEBSOCKET_KEY, - HeaderValue::try_from(key.as_str()).unwrap(), - ); - - // prep connection - let connect = TcpConnect::new(request.uri.host().unwrap().to_string()) - .set_port(request.uri.port_u16().unwrap_or_else(|| proto.port())); - - let fut = Box::new( - self.connector - .call(connect) - .map_err(ClientError::from) - .and_then(move |io| { - // h1 protocol - let framed = Framed::new(io, h1::ClientCodec::default()); - framed - .send((request, BodySize::None).into()) - .map_err(ClientError::from) - .and_then(|framed| { - framed - .into_future() - .map_err(|(e, _)| ClientError::from(e)) - }) - }), - ); - - // start handshake - Either::B(ClientResponseFut { - key, - fut, - max_size: req.max_size, - server_mode: req.server_mode, - _t: PhantomData, - }) - } - } -} - -/// Future that implementes client websocket handshake process. -/// -/// It resolves to a `Framed` instance. -pub struct ClientResponseFut -where - T: AsyncRead + AsyncWrite, -{ - fut: Box< - Future< - Item = (Option, Framed), - Error = ClientError, - >, - >, - key: String, - max_size: usize, - server_mode: bool, - _t: PhantomData, -} - -impl Future for ClientResponseFut -where - T: AsyncRead + AsyncWrite, -{ - type Item = Framed; - type Error = ClientError; - - fn poll(&mut self) -> Poll { - let (item, framed) = try_ready!(self.fut.poll()); - - let res = match item { - Some(res) => res, - None => return Err(ClientError::Disconnected), - }; - - // verify response - if res.status != StatusCode::SWITCHING_PROTOCOLS { - return Err(ClientError::InvalidResponseStatus(res.status)); - } - // Check for "UPGRADE" to websocket header - let has_hdr = if let Some(hdr) = res.headers.get(header::UPGRADE) { - if let Ok(s) = hdr.to_str() { - s.to_lowercase().contains("websocket") - } else { - false - } - } else { - false - }; - if !has_hdr { - trace!("Invalid upgrade header"); - return Err(ClientError::InvalidUpgradeHeader); - } - // Check for "CONNECTION" header - if let Some(conn) = res.headers.get(header::CONNECTION) { - if let Ok(s) = conn.to_str() { - if !s.to_lowercase().contains("upgrade") { - trace!("Invalid connection header: {}", s); - return Err(ClientError::InvalidConnectionHeader(conn.clone())); - } - } else { - trace!("Invalid connection header: {:?}", conn); - return Err(ClientError::InvalidConnectionHeader(conn.clone())); - } - } else { - trace!("Missing connection header"); - return Err(ClientError::MissingConnectionHeader); - } - - if let Some(key) = res.headers.get(header::SEC_WEBSOCKET_ACCEPT) { - // field is constructed by concatenating /key/ - // with the string "258EAFA5-E914-47DA-95CA-C5AB0DC85B11" (RFC 6455) - const WS_GUID: &[u8] = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; - let mut sha1 = Sha1::new(); - sha1.update(self.key.as_ref()); - sha1.update(WS_GUID); - let encoded = base64::encode(&sha1.digest().bytes()); - if key.as_bytes() != encoded.as_bytes() { - trace!( - "Invalid challenge response: expected: {} received: {:?}", - encoded, - key - ); - return Err(ClientError::InvalidChallengeResponse(encoded, key.clone())); - } - } else { - trace!("Missing SEC-WEBSOCKET-ACCEPT header"); - return Err(ClientError::MissingWebSocketAcceptHeader); - }; - - // websockets codec - let codec = if self.server_mode { - Codec::new().max_size(self.max_size) - } else { - Codec::new().max_size(self.max_size).client_mode() - }; - - Ok(Async::Ready(framed.into_framed(codec))) - } -} diff --git a/actix-http/src/ws/mod.rs b/actix-http/src/ws/mod.rs index 88fabde94..065c34d93 100644 --- a/actix-http/src/ws/mod.rs +++ b/actix-http/src/ws/mod.rs @@ -13,7 +13,6 @@ use crate::httpmessage::HttpMessage; use crate::request::Request; use crate::response::{Response, ResponseBuilder}; -mod client; mod codec; mod frame; mod mask; @@ -21,7 +20,6 @@ mod proto; mod service; mod transport; -pub use self::client::{Client, ClientError, Connect}; pub use self::codec::{Codec, Frame, Message}; pub use self::frame::Parser; pub use self::proto::{hash_key, CloseCode, CloseReason, OpCode}; diff --git a/awc/Cargo.toml b/awc/Cargo.toml index e08169c96..c915475f8 100644 --- a/awc/Cargo.toml +++ b/awc/Cargo.toml @@ -39,13 +39,16 @@ flate2-zlib = ["actix-http/flate2-zlib"] flate2-rust = ["actix-http/flate2-rust"] [dependencies] +actix-codec = "0.1.1" actix-service = "0.3.4" actix-http = { path = "../actix-http/" } base64 = "0.10.1" bytes = "0.4" +derive_more = "0.14" futures = "0.1" log =" 0.4" percent-encoding = "1.0" +rand = "0.6" serde = "1.0" serde_json = "1.0" serde_urlencoded = "0.5.3" @@ -58,8 +61,11 @@ actix-rt = "0.2.1" actix-web = { path = "..", features=["ssl"] } actix-http = { path = "../actix-http/", features=["ssl"] } actix-http-test = { path = "../test-server/", features=["ssl"] } +actix-utils = "0.3.4" +actix-server = { version = "0.4.0", features=["ssl"] } brotli2 = { version="^0.3.2" } flate2 = { version="^1.0.2" } env_logger = "0.6" mime = "0.3" rand = "0.6" +tokio-tcp = "0.1" \ No newline at end of file diff --git a/awc/src/connect.rs b/awc/src/connect.rs index a07662791..77cd1fbff 100644 --- a/awc/src/connect.rs +++ b/awc/src/connect.rs @@ -1,8 +1,12 @@ +use std::io; + +use actix_codec::{AsyncRead, AsyncWrite, Framed}; use actix_http::body::Body; use actix_http::client::{ConnectError, Connection, SendRequestError}; -use actix_http::{http, RequestHead}; +use actix_http::h1::ClientCodec; +use actix_http::{http, RequestHead, ResponseHead}; use actix_service::Service; -use futures::Future; +use futures::{Future, Poll}; use crate::response::ClientResponse; @@ -14,13 +18,26 @@ pub(crate) trait Connect { head: RequestHead, body: Body, ) -> Box>; + + /// Send request, returns Response and Framed + fn open_tunnel( + &mut self, + head: RequestHead, + ) -> Box< + Future< + Item = (ResponseHead, Framed), + Error = SendRequestError, + >, + >; } impl Connect for ConnectorWrapper where T: Service, T::Response: Connection, + ::Io: 'static, ::Future: 'static, + ::TunnelFuture: 'static, T::Future: 'static, { fn send_request( @@ -38,4 +55,77 @@ where .map(|(head, payload)| ClientResponse::new(head, payload)), ) } + + fn open_tunnel( + &mut self, + head: RequestHead, + ) -> Box< + Future< + Item = (ResponseHead, Framed), + Error = SendRequestError, + >, + > { + Box::new( + self.0 + // connect to the host + .call(head.uri.clone()) + .from_err() + // send request + .and_then(move |connection| connection.open_tunnel(head)) + .map(|(head, framed)| { + let framed = framed.map_io(|io| BoxedSocket(Box::new(Socket(io)))); + (head, framed) + }), + ) + } +} + +trait AsyncSocket { + fn as_read(&self) -> &AsyncRead; + fn as_read_mut(&mut self) -> &mut AsyncRead; + fn as_write(&mut self) -> &mut AsyncWrite; +} + +struct Socket(T); + +impl AsyncSocket for Socket { + fn as_read(&self) -> &AsyncRead { + &self.0 + } + fn as_read_mut(&mut self) -> &mut AsyncRead { + &mut self.0 + } + fn as_write(&mut self) -> &mut AsyncWrite { + &mut self.0 + } +} + +pub struct BoxedSocket(Box); + +impl io::Read for BoxedSocket { + fn read(&mut self, buf: &mut [u8]) -> io::Result { + self.0.as_read_mut().read(buf) + } +} + +impl AsyncRead for BoxedSocket { + unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [u8]) -> bool { + self.0.as_read().prepare_uninitialized_buffer(buf) + } +} + +impl io::Write for BoxedSocket { + fn write(&mut self, buf: &[u8]) -> io::Result { + self.0.as_write().write(buf) + } + + fn flush(&mut self) -> io::Result<()> { + self.0.as_write().flush() + } +} + +impl AsyncWrite for BoxedSocket { + fn shutdown(&mut self) -> Poll<(), io::Error> { + self.0.as_write().shutdown() + } } diff --git a/actix-http/src/ws/client/error.rs b/awc/src/error.rs similarity index 55% rename from actix-http/src/ws/client/error.rs rename to awc/src/error.rs index ae1e39967..d3f1c1a17 100644 --- a/actix-http/src/ws/client/error.rs +++ b/awc/src/error.rs @@ -1,19 +1,14 @@ -//! Http client request -use std::io; +//! Http client errors +pub use actix_http::client::{ConnectError, InvalidUrl, SendRequestError}; +pub use actix_http::error::PayloadError; +pub use actix_http::ws::ProtocolError as WsProtocolError; -use actix_connect::ConnectError; +use actix_http::http::{header::HeaderValue, Error as HttpError, StatusCode}; use derive_more::{Display, From}; -use http::{header::HeaderValue, Error as HttpError, StatusCode}; - -use crate::error::ParseError; -use crate::ws::ProtocolError; /// Websocket client error #[derive(Debug, Display, From)] -pub enum ClientError { - /// Invalid url - #[display(fmt = "Invalid url")] - InvalidUrl, +pub enum WsClientError { /// Invalid response status #[display(fmt = "Invalid response status")] InvalidResponseStatus(StatusCode), @@ -32,22 +27,22 @@ pub enum ClientError { /// Invalid challenge response #[display(fmt = "Invalid challenge response")] InvalidChallengeResponse(String, HeaderValue), - /// Http parsing error - #[display(fmt = "Http parsing error")] - Http(HttpError), - /// Response parsing error - #[display(fmt = "Response parsing error: {}", _0)] - ParseError(ParseError), /// Protocol error #[display(fmt = "{}", _0)] - Protocol(ProtocolError), - /// Connect error - #[display(fmt = "Connector error: {:?}", _0)] - Connect(ConnectError), - /// IO Error + Protocol(WsProtocolError), + /// Send request error #[display(fmt = "{}", _0)] - Io(io::Error), - /// "Disconnected" - #[display(fmt = "Disconnected")] - Disconnected, + SendRequest(SendRequestError), +} + +impl From for WsClientError { + fn from(err: InvalidUrl) -> Self { + WsClientError::SendRequest(err.into()) + } +} + +impl From for WsClientError { + fn from(err: HttpError) -> Self { + WsClientError::SendRequest(err.into()) + } } diff --git a/awc/src/lib.rs b/awc/src/lib.rs index 3bad8caa3..9f5ca1f28 100644 --- a/awc/src/lib.rs +++ b/awc/src/lib.rs @@ -23,8 +23,6 @@ use std::cell::RefCell; use std::rc::Rc; -pub use actix_http::client::{ConnectError, InvalidUrl, SendRequestError}; -pub use actix_http::error::PayloadError; pub use actix_http::http; use actix_http::client::Connector; @@ -33,13 +31,16 @@ use actix_http::RequestHead; mod builder; mod connect; +pub mod error; mod request; mod response; pub mod test; +mod ws; pub use self::builder::ClientBuilder; pub use self::request::ClientRequest; pub use self::response::ClientResponse; +pub use self::ws::WebsocketsRequest; use self::connect::{Connect, ConnectorWrapper}; @@ -165,4 +166,11 @@ impl Client { { ClientRequest::new(Method::OPTIONS, url, self.connector.clone()) } + + pub fn ws(&self, url: U) -> WebsocketsRequest + where + Uri: HttpTryFrom, + { + WebsocketsRequest::new(url, self.connector.clone()) + } } diff --git a/awc/src/request.rs b/awc/src/request.rs index 7beb737e1..c0962ebf1 100644 --- a/awc/src/request.rs +++ b/awc/src/request.rs @@ -12,7 +12,6 @@ use serde::Serialize; use serde_json; use actix_http::body::{Body, BodyStream}; -use actix_http::client::{InvalidUrl, SendRequestError}; use actix_http::encoding::Decoder; use actix_http::http::header::{self, ContentEncoding, Header, IntoHeaderValue}; use actix_http::http::{ @@ -21,8 +20,9 @@ use actix_http::http::{ }; use actix_http::{Error, Payload, RequestHead}; +use crate::connect::Connect; +use crate::error::{InvalidUrl, PayloadError, SendRequestError}; use crate::response::ClientResponse; -use crate::{Connect, PayloadError}; #[cfg(any(feature = "brotli", feature = "flate2-zlib", feature = "flate2-rust"))] const HTTPS_ENCODING: &str = "br, gzip, deflate"; diff --git a/awc/src/ws.rs b/awc/src/ws.rs new file mode 100644 index 000000000..f959e62c5 --- /dev/null +++ b/awc/src/ws.rs @@ -0,0 +1,398 @@ +//! Websockets client +use std::cell::RefCell; +use std::io::Write; +use std::rc::Rc; +use std::{fmt, str}; + +use actix_codec::Framed; +use actix_http::{ws, Payload, RequestHead}; +use bytes::{BufMut, BytesMut}; +#[cfg(feature = "cookies")] +use cookie::{Cookie, CookieJar}; +use futures::future::{err, Either, Future}; + +use crate::connect::{BoxedSocket, Connect}; +use crate::error::{InvalidUrl, WsClientError}; +use crate::http::header::{ + self, HeaderName, HeaderValue, IntoHeaderValue, AUTHORIZATION, +}; +use crate::http::{ + ConnectionType, Error as HttpError, HttpTryFrom, Method, StatusCode, Uri, Version, +}; +use crate::response::ClientResponse; + +/// `WebSocket` connection +pub struct WebsocketsRequest { + head: RequestHead, + err: Option, + origin: Option, + protocols: Option, + max_size: usize, + server_mode: bool, + default_headers: bool, + #[cfg(feature = "cookies")] + cookies: Option, + connector: Rc>, +} + +impl WebsocketsRequest { + /// Create new websocket connection + pub(crate) fn new(uri: U, connector: Rc>) -> Self + where + Uri: HttpTryFrom, + { + let mut err = None; + let mut head = RequestHead::default(); + head.method = Method::GET; + head.version = Version::HTTP_11; + + match Uri::try_from(uri) { + Ok(uri) => head.uri = uri, + Err(e) => err = Some(e.into()), + } + + WebsocketsRequest { + head, + err, + connector, + origin: None, + protocols: None, + max_size: 65_536, + server_mode: false, + #[cfg(feature = "cookies")] + cookies: None, + default_headers: true, + } + } + + /// Set supported websocket protocols + pub fn protocols(mut self, protos: U) -> Self + where + U: IntoIterator + 'static, + V: AsRef, + { + let mut protos = protos + .into_iter() + .fold(String::new(), |acc, s| acc + s.as_ref() + ","); + protos.pop(); + self.protocols = Some(protos); + self + } + + #[cfg(feature = "cookies")] + /// Set a cookie + pub fn cookie<'c>(mut self, cookie: Cookie<'c>) -> Self { + if self.cookies.is_none() { + let mut jar = CookieJar::new(); + jar.add(cookie.into_owned()); + self.cookies = Some(jar) + } else { + self.cookies.as_mut().unwrap().add(cookie.into_owned()); + } + self + } + + /// Set request Origin + pub fn origin(mut self, origin: V) -> Self + where + HeaderValue: HttpTryFrom, + { + match HeaderValue::try_from(origin) { + Ok(value) => self.origin = Some(value), + Err(e) => self.err = Some(e.into()), + } + self + } + + /// Set max frame size + /// + /// By default max size is set to 64kb + pub fn max_frame_size(mut self, size: usize) -> Self { + self.max_size = size; + self + } + + /// Disable payload masking. By default ws client masks frame payload. + pub fn server_mode(mut self) -> Self { + self.server_mode = true; + self + } + + /// Do not add default request headers. + /// By default `Date` and `User-Agent` headers are set. + pub fn no_default_headers(mut self) -> Self { + self.default_headers = false; + self + } + + /// Append a header. + /// + /// Header gets appended to existing header. + /// To override header use `set_header()` method. + pub fn header(mut self, key: K, value: V) -> Self + where + HeaderName: HttpTryFrom, + V: IntoHeaderValue, + { + match HeaderName::try_from(key) { + Ok(key) => match value.try_into() { + Ok(value) => { + self.head.headers.append(key, value); + } + Err(e) => self.err = Some(e.into()), + }, + Err(e) => self.err = Some(e.into()), + } + self + } + + /// Insert a header, replaces existing header. + pub fn set_header(mut self, key: K, value: V) -> Self + where + HeaderName: HttpTryFrom, + V: IntoHeaderValue, + { + match HeaderName::try_from(key) { + Ok(key) => match value.try_into() { + Ok(value) => { + self.head.headers.insert(key, value); + } + Err(e) => self.err = Some(e.into()), + }, + Err(e) => self.err = Some(e.into()), + } + self + } + + /// Insert a header only if it is not yet set. + pub fn set_header_if_none(mut self, key: K, value: V) -> Self + where + HeaderName: HttpTryFrom, + V: IntoHeaderValue, + { + match HeaderName::try_from(key) { + Ok(key) => { + if !self.head.headers.contains_key(&key) { + match value.try_into() { + Ok(value) => { + self.head.headers.insert(key, value); + } + Err(e) => self.err = Some(e.into()), + } + } + } + Err(e) => self.err = Some(e.into()), + } + self + } + + /// Set HTTP basic authorization header + pub fn basic_auth(self, username: U, password: Option

) -> Self + where + U: fmt::Display, + P: fmt::Display, + { + let auth = match password { + Some(password) => format!("{}:{}", username, password), + None => format!("{}", username), + }; + self.header(AUTHORIZATION, format!("Basic {}", base64::encode(&auth))) + } + + /// Set HTTP bearer authentication header + pub fn bearer_auth(self, token: T) -> Self + where + T: fmt::Display, + { + self.header(AUTHORIZATION, format!("Bearer {}", token)) + } + + /// Complete request construction and connect. + pub fn connect( + mut self, + ) -> impl Future< + Item = (ClientResponse, Framed), + Error = WsClientError, + > { + if let Some(e) = self.err.take() { + return Either::A(err(e.into())); + } + + // validate uri + let uri = &self.head.uri; + if uri.host().is_none() { + return Either::A(err(InvalidUrl::MissingHost.into())); + } else if uri.scheme_part().is_none() { + return Either::A(err(InvalidUrl::MissingScheme.into())); + } else if let Some(scheme) = uri.scheme_part() { + match scheme.as_str() { + "http" | "ws" | "https" | "wss" => (), + _ => return Either::A(err(InvalidUrl::UnknownScheme.into())), + } + } else { + return Either::A(err(InvalidUrl::UnknownScheme.into())); + } + + // set default headers + let mut slf = if self.default_headers { + // set request host header + if let Some(host) = self.head.uri.host() { + if !self.head.headers.contains_key(header::HOST) { + let mut wrt = BytesMut::with_capacity(host.len() + 5).writer(); + + let _ = match self.head.uri.port_u16() { + None | Some(80) | Some(443) => write!(wrt, "{}", host), + Some(port) => write!(wrt, "{}:{}", host, port), + }; + + match wrt.get_mut().take().freeze().try_into() { + Ok(value) => { + self.head.headers.insert(header::HOST, value); + } + Err(e) => return Either::A(err(HttpError::from(e).into())), + } + } + } + + // user agent + self.set_header_if_none( + header::USER_AGENT, + concat!("awc/", env!("CARGO_PKG_VERSION")), + ) + } else { + self + }; + + #[allow(unused_mut)] + let mut head = slf.head; + + #[cfg(feature = "cookies")] + { + use percent_encoding::{percent_encode, USERINFO_ENCODE_SET}; + use std::fmt::Write; + + // set cookies + if let Some(ref mut jar) = slf.cookies { + let mut cookie = String::new(); + for c in jar.delta() { + let name = percent_encode(c.name().as_bytes(), USERINFO_ENCODE_SET); + let value = + percent_encode(c.value().as_bytes(), USERINFO_ENCODE_SET); + let _ = write!(&mut cookie, "; {}={}", name, value); + } + head.headers.insert( + header::COOKIE, + HeaderValue::from_str(&cookie.as_str()[2..]).unwrap(), + ); + } + } + + // origin + if let Some(origin) = slf.origin.take() { + head.headers.insert(header::ORIGIN, origin); + } + + head.set_connection_type(ConnectionType::Upgrade); + head.headers + .insert(header::UPGRADE, HeaderValue::from_static("websocket")); + head.headers.insert( + header::SEC_WEBSOCKET_VERSION, + HeaderValue::from_static("13"), + ); + + if let Some(protocols) = slf.protocols.take() { + head.headers.insert( + header::SEC_WEBSOCKET_PROTOCOL, + HeaderValue::try_from(protocols.as_str()).unwrap(), + ); + } + + // Generate a random key for the `Sec-WebSocket-Key` header. + // a base64-encoded (see Section 4 of [RFC4648]) value that, + // when decoded, is 16 bytes in length (RFC 6455) + let sec_key: [u8; 16] = rand::random(); + let key = base64::encode(&sec_key); + + head.headers.insert( + header::SEC_WEBSOCKET_KEY, + HeaderValue::try_from(key.as_str()).unwrap(), + ); + + let max_size = slf.max_size; + let server_mode = slf.server_mode; + + let fut = slf + .connector + .borrow_mut() + .open_tunnel(head) + .from_err() + .and_then(move |(head, framed)| { + // verify response + if head.status != StatusCode::SWITCHING_PROTOCOLS { + return Err(WsClientError::InvalidResponseStatus(head.status)); + } + // Check for "UPGRADE" to websocket header + let has_hdr = if let Some(hdr) = head.headers.get(header::UPGRADE) { + if let Ok(s) = hdr.to_str() { + s.to_ascii_lowercase().contains("websocket") + } else { + false + } + } else { + false + }; + if !has_hdr { + log::trace!("Invalid upgrade header"); + return Err(WsClientError::InvalidUpgradeHeader); + } + // Check for "CONNECTION" header + if let Some(conn) = head.headers.get(header::CONNECTION) { + if let Ok(s) = conn.to_str() { + if !s.to_ascii_lowercase().contains("upgrade") { + log::trace!("Invalid connection header: {}", s); + return Err(WsClientError::InvalidConnectionHeader( + conn.clone(), + )); + } + } else { + log::trace!("Invalid connection header: {:?}", conn); + return Err(WsClientError::InvalidConnectionHeader(conn.clone())); + } + } else { + log::trace!("Missing connection header"); + return Err(WsClientError::MissingConnectionHeader); + } + + if let Some(hdr_key) = head.headers.get(header::SEC_WEBSOCKET_ACCEPT) { + let encoded = ws::hash_key(key.as_ref()); + if hdr_key.as_bytes() != encoded.as_bytes() { + log::trace!( + "Invalid challenge response: expected: {} received: {:?}", + encoded, + key + ); + return Err(WsClientError::InvalidChallengeResponse( + encoded, + hdr_key.clone(), + )); + } + } else { + log::trace!("Missing SEC-WEBSOCKET-ACCEPT header"); + return Err(WsClientError::MissingWebSocketAcceptHeader); + }; + + // response and ws framed + Ok(( + ClientResponse::new(head, Payload::None), + framed.map_codec(|_| { + if server_mode { + ws::Codec::new().max_size(max_size) + } else { + ws::Codec::new().max_size(max_size).client_mode() + } + }), + )) + }); + Either::B(fut) + } +} diff --git a/actix-http/tests/test_ws.rs b/awc/tests/test_ws.rs similarity index 100% rename from actix-http/tests/test_ws.rs rename to awc/tests/test_ws.rs diff --git a/src/lib.rs b/src/lib.rs index 5b0ce7841..d3d66c616 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -393,8 +393,8 @@ pub mod client { //! })); //! } //! ``` - pub use awc::{ - test, Client, ClientBuilder, ClientRequest, ClientResponse, ConnectError, - InvalidUrl, PayloadError, SendRequestError, + pub use awc::error::{ + ConnectError, InvalidUrl, PayloadError, SendRequestError, WsClientError, }; + pub use awc::{test, Client, ClientBuilder, ClientRequest, ClientResponse}; } diff --git a/test-server/src/lib.rs b/test-server/src/lib.rs index 7cd94d4d2..07a0e0b4c 100644 --- a/test-server/src/lib.rs +++ b/test-server/src/lib.rs @@ -7,7 +7,6 @@ use actix_http::client::Connector; use actix_http::ws; use actix_rt::{Runtime, System}; use actix_server::{Server, StreamServiceFactory}; -use actix_service::Service; use awc::{Client, ClientRequest}; use futures::future::{lazy, Future}; use http::Method; @@ -205,16 +204,19 @@ impl TestServerRuntime { pub fn ws_at( &mut self, path: &str, - ) -> Result, ws::ClientError> { + ) -> Result, awc::error::WsClientError> + { let url = self.url(path); + let connect = self.client.ws(url).connect(); self.rt - .block_on(lazy(|| ws::Client::default().call(ws::Connect::new(url)))) + .block_on(lazy(move || connect.map(|(_, framed)| framed))) } /// Connect to a websocket server pub fn ws( &mut self, - ) -> Result, ws::ClientError> { + ) -> Result, awc::error::WsClientError> + { self.ws_at("/") } }