From f2d20514fa41cb640906cca9a69c04d557248ad8 Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Sat, 7 Oct 2017 21:48:00 -0700 Subject: [PATCH] websocket support --- Cargo.toml | 6 +- README.md | 7 +- src/httpcodes.rs | 13 +- src/httpmessage.rs | 81 ++++++-- src/lib.rs | 8 + src/main.rs | 62 +++++- src/reader.rs | 10 +- src/resource.rs | 9 +- src/route.rs | 12 +- src/router.rs | 2 +- src/task.rs | 42 ++-- src/ws.rs | 248 +++++++++++++++++++++++ src/wsframe.rs | 496 +++++++++++++++++++++++++++++++++++++++++++++ src/wsproto.rs | 300 +++++++++++++++++++++++++++ 14 files changed, 1232 insertions(+), 64 deletions(-) create mode 100644 src/ws.rs create mode 100644 src/wsframe.rs create mode 100644 src/wsproto.rs diff --git a/Cargo.toml b/Cargo.toml index a315dce5f..0860f5aaf 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -24,8 +24,12 @@ path = "src/main.rs" [dependencies] time = "0.1" http = "0.1" -httparse = "*" +httparse = "0.1" hyper = "0.11" +unicase = "2.0" +slab = "0.4" +sha1 = "0.2" +rand = "0.3" route-recognizer = "0.1" # tokio diff --git a/README.md b/README.md index 3eaec08f0..82a74bd79 100644 --- a/README.md +++ b/README.md @@ -14,6 +14,7 @@ Actix Http is licensed under the [Apache-2.0 license](http://opensource.org/lice * HTTP 1.1 and 1.0 support * Streaming and pipelining support + * WebSockets support * Configurable request routing ## Usage @@ -47,8 +48,7 @@ impl Actor for MyRoute { impl Route for MyRoute { type State = (); - fn request(req: HttpRequest, - payload: Option, + fn request(req: HttpRequest, payload: Option, ctx: &mut HttpContext) -> HttpMessage { Self::http_reply(req, httpcodes::HTTPOk) @@ -60,7 +60,8 @@ fn main() { // create routing map with `MyRoute` route let mut routes = RoutingMap::default(); - routes.add_resource("/") + routes + .add_resource("/") .post::(); // start http server diff --git a/src/httpcodes.rs b/src/httpcodes.rs index 57cea818c..e3b332180 100644 --- a/src/httpcodes.rs +++ b/src/httpcodes.rs @@ -7,8 +7,6 @@ use task::Task; use route::{Payload, RouteHandler}; use httpmessage::{Body, HttpRequest, HttpResponse, IntoHttpResponse}; -pub struct StaticResponse(StatusCode); - pub const HTTPOk: StaticResponse = StaticResponse(StatusCode::OK); pub const HTTPCreated: StaticResponse = StaticResponse(StatusCode::CREATED); pub const HTTPNoContent: StaticResponse = StaticResponse(StatusCode::NO_CONTENT); @@ -17,6 +15,15 @@ pub const HTTPNotFound: StaticResponse = StaticResponse(StatusCode::NOT_FOUND); pub const HTTPMethodNotAllowed: StaticResponse = StaticResponse(StatusCode::METHOD_NOT_ALLOWED); +pub struct StaticResponse(StatusCode); + +impl StaticResponse { + pub fn with_reason(self, req: HttpRequest, reason: &'static str) -> HttpResponse { + HttpResponse::new(req, self.0, Body::Empty) + .set_reason(reason) + } +} + impl RouteHandler for StaticResponse { fn handle(&self, req: HttpRequest, _: Option, _: Rc) -> Task { @@ -25,7 +32,7 @@ impl RouteHandler for StaticResponse { } impl IntoHttpResponse for StaticResponse { - fn into_response(self, req: HttpRequest) -> HttpResponse { + fn response(self, req: HttpRequest) -> HttpResponse { HttpResponse::new(req, self.0, Body::Empty) } } diff --git a/src/httpmessage.rs b/src/httpmessage.rs index b7faa4a87..6826af8c3 100644 --- a/src/httpmessage.rs +++ b/src/httpmessage.rs @@ -12,6 +12,13 @@ use hyper::header::{Connection, ConnectionOption, use Params; use error::Error; +#[derive(Copy, Clone, PartialEq, Debug)] +pub enum ConnectionType { + Close, + KeepAlive, + Upgrade, +} + pub trait Message { fn version(&self) -> Version; @@ -61,14 +68,6 @@ pub trait Message { Ok(false) } } - - fn is_upgrade(&self) -> bool { - if let Some(&Connection(ref conn)) = self.headers().get() { - conn.contains(&ConnectionOption::from_str("upgrade").unwrap()) - } else { - false - } - } } @@ -159,6 +158,14 @@ impl HttpRequest { params: params } } + + pub fn is_upgrade(&self) -> bool { + if let Some(&Connection(ref conn)) = self.headers().get() { + conn.contains(&ConnectionOption::from_str("upgrade").unwrap()) + } else { + false + } + } } /// Represents various types of http message body. @@ -173,6 +180,8 @@ pub enum Body { /// Unspecified streaming response. Developer is responsible for setting /// right `Content-Length` or `Transfer-Encoding` headers. Streaming, + /// Upgrade connection. + Upgrade, } impl Body { @@ -188,7 +197,7 @@ impl Body { /// Implements by something that can be converted to `HttpMessage` pub trait IntoHttpResponse { /// Convert into response. - fn into_response(self, req: HttpRequest) -> HttpResponse; + fn response(self, req: HttpRequest) -> HttpResponse; } #[derive(Debug)] @@ -198,10 +207,11 @@ pub struct HttpResponse { pub version: Version, pub headers: Headers, pub status: StatusCode, + reason: Option<&'static str>, body: Body, chunked: bool, - keep_alive: Option, compression: Option, + connection_type: Option, } impl Message for HttpResponse { @@ -223,13 +233,20 @@ impl HttpResponse { version: version, headers: Default::default(), status: status, + reason: None, body: body, chunked: false, - keep_alive: None, compression: None, + connection_type: None, } } + /// Original prequest + #[inline] + pub fn request(&self) -> &HttpRequest { + &self.request + } + /// Get the HTTP version of this response. #[inline] pub fn version(&self) -> Version { @@ -256,37 +273,55 @@ impl HttpResponse { /// Set the `StatusCode` for this response. #[inline] - pub fn set_status(&mut self, status: StatusCode) -> &mut Self { + pub fn set_status(mut self, status: StatusCode) -> Self { self.status = status; self } /// Set a header and move the Response. #[inline] - pub fn set_header(&mut self, header: H) -> &mut Self { + pub fn set_header(mut self, header: H) -> Self { self.headers.set(header); self } - /// Set the headers and move the Response. + /// Set the headers. #[inline] - pub fn with_headers(&mut self, headers: Headers) -> &mut Self { + pub fn with_headers(mut self, headers: Headers) -> Self { self.headers = headers; self } + /// Set the custom reason for the response. + #[inline] + pub fn set_reason(mut self, reason: &'static str) -> Self { + self.reason = Some(reason); + self + } + + /// Set connection type + pub fn set_connection_type(mut self, conn: ConnectionType) -> Self { + self.connection_type = Some(conn); + self + } + + /// Connection upgrade status + pub fn upgrade(&self) -> bool { + self.connection_type == Some(ConnectionType::Upgrade) + } + /// Keep-alive status for this connection pub fn keep_alive(&self) -> bool { - if let Some(ka) = self.keep_alive { - ka + if let Some(ConnectionType::KeepAlive) = self.connection_type { + true } else { self.request.should_keep_alive() } } - + /// Force close connection, even if it is marked as keep-alive pub fn force_close(&mut self) { - self.keep_alive = Some(false); + self.connection_type = Some(ConnectionType::Close); } /// is chunked encoding enabled @@ -310,8 +345,14 @@ impl HttpResponse { &self.body } + /// Set a body + pub fn set_body>(mut self, body: B) -> Self { + self.body = body.into(); + self + } + /// Set a body and return previous body value - pub fn set_body>(&mut self, body: B) -> Body { + pub fn replace_body>(&mut self, body: B) -> Body { mem::replace(&mut self.body, body.into()) } } diff --git a/src/lib.rs b/src/lib.rs index 3d71fb6a1..b4111ae2d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -4,12 +4,16 @@ extern crate log; extern crate time; extern crate bytes; +extern crate rand; +extern crate sha1; #[macro_use] extern crate futures; extern crate tokio_core; extern crate tokio_io; extern crate tokio_proto; +#[macro_use] extern crate hyper; +extern crate unicase; extern crate http; extern crate httparse; extern crate route_recognizer; @@ -28,6 +32,10 @@ mod task; mod reader; mod server; +pub mod ws; +mod wsframe; +mod wsproto; + pub mod httpcodes; pub use application::HttpApplication; pub use route::{Route, RouteFactory, RouteHandler, Payload, PayloadItem}; diff --git a/src/main.rs b/src/main.rs index 18ed55ea2..5e230c900 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,4 +1,4 @@ -#![allow(dead_code)] +#![allow(dead_code, unused_variables)] extern crate actix; extern crate actix_http; extern crate tokio_core; @@ -25,9 +25,9 @@ impl Route for MyRoute { { if let Some(pl) = payload { ctx.add_stream(pl); - Self::http_stream(MyRoute{req: Some(req)}) + HttpMessage::stream(MyRoute{req: Some(req)}) } else { - Self::http_reply(req, httpcodes::HTTPOk) + HttpMessage::reply_with(req, httpcodes::HTTPOk) } } } @@ -45,7 +45,7 @@ impl Handler for MyRoute { { println!("CHUNK: {:?}", msg); if let Some(req) = self.req.take() { - ctx.start(httpcodes::HTTPOk.into_response(req)); + ctx.start(httpcodes::HTTPOk.response(req)); ctx.write_eof(); } @@ -53,6 +53,57 @@ impl Handler for MyRoute { } } +struct MyWS {} + +impl Actor for MyWS { + type Context = HttpContext; +} + +impl Route for MyWS { + type State = (); + + fn request(req: HttpRequest, + payload: Option, + ctx: &mut HttpContext) -> HttpMessage + { + if let Some(payload) = payload { + match ws::do_handshake(req) { + Ok(resp) => { + ctx.start(resp); + ctx.add_stream(ws::WsStream::new(payload)); + HttpMessage::stream(MyWS{}) + }, + Err(err) => + HttpMessage::reply(err) + } + } else { + HttpMessage::reply_with(req, httpcodes::HTTPBadRequest) + } + } +} + +impl ResponseType for MyWS { + type Item = (); + type Error = (); +} + +impl StreamHandler for MyWS {} + +impl Handler for MyWS { + fn handle(&mut self, msg: ws::Message, ctx: &mut HttpContext) + -> Response + { + println!("WS: {:?}", msg); + match msg { + ws::Message::Ping(msg) => ws::WsWriter::pong(ctx, msg), + ws::Message::Text(text) => ws::WsWriter::text(ctx, text), + ws::Message::Binary(bin) => ws::WsWriter::binary(ctx, bin), + _ => (), + } + Self::empty() + } +} + fn main() { let _ = env_logger::init(); @@ -71,6 +122,9 @@ fn main() { routes.add_resource("/test") .post::(); + routes.add_resource("/ws/") + .get::(); + let http = HttpServer::new(routes); http.serve::<()>( &net::SocketAddr::from_str("127.0.0.1:9080").unwrap()).unwrap(); diff --git a/src/reader.rs b/src/reader.rs index d6ffca492..f79a9ca7b 100644 --- a/src/reader.rs +++ b/src/reader.rs @@ -120,6 +120,7 @@ impl Reader { match self.decode()? { Decoding::Paused => return Ok(Async::NotReady), Decoding::Ready => { + println!("decode ready"); self.payload = None; break }, @@ -149,6 +150,7 @@ impl Reader { Decoding::Paused => break, Decoding::Ready => { + println!("decoded 3"); self.payload = None; break }, @@ -280,12 +282,14 @@ pub fn parse(buf: &mut BytesMut) -> Result) }); let msg = HttpRequest::new(method, uri, version, headers); - - let _upgrade = msg.is_upgrade(); + let upgrade = msg.is_upgrade() || *msg.method() == Method::CONNECT; let chunked = msg.is_chunked()?; + if upgrade { + Ok(Some((msg, Some(Decoder::eof())))) + } // Content-Length - if let Some(&ContentLength(len)) = msg.headers().get() { + else if let Some(&ContentLength(len)) = msg.headers().get() { if chunked { return Err(Error::Header) } diff --git a/src/resource.rs b/src/resource.rs index 80b06bf43..d49fbdc00 100644 --- a/src/resource.rs +++ b/src/resource.rs @@ -101,8 +101,13 @@ impl HttpMessage where A: Actor> + Route } /// Send response - pub fn reply(req: HttpRequest, msg: I) -> Self { - HttpMessage(HttpMessageItem::Message(msg.into_response(req))) + pub fn reply(msg: HttpResponse) -> Self { + HttpMessage(HttpMessageItem::Message(msg)) + } + + /// Send response + pub fn reply_with(req: HttpRequest, msg: I) -> Self { + HttpMessage(HttpMessageItem::Message(msg.response(req))) } pub(crate) fn into(self, mut ctx: HttpContext) -> Task { diff --git a/src/route.rs b/src/route.rs index c8c3b4b97..cbbbaae28 100644 --- a/src/route.rs +++ b/src/route.rs @@ -8,7 +8,7 @@ use futures::unsync::mpsc::Receiver; use task::Task; use context::HttpContext; use resource::HttpMessage; -use httpmessage::{HttpRequest, HttpResponse, IntoHttpResponse}; +use httpmessage::{HttpRequest, HttpResponse}; /// Stream of `PayloadItem`'s pub type Payload = Receiver; @@ -60,16 +60,6 @@ pub trait Route: Actor> { fn factory() -> RouteFactory { RouteFactory(PhantomData) } - - /// Create async response - fn http_stream(act: Self) -> HttpMessage { - HttpMessage::stream(act) - } - - /// Send response - fn http_reply(req: HttpRequest, msg: I) -> HttpMessage { - HttpMessage::reply(req, msg) - } } diff --git a/src/router.rs b/src/router.rs index 02fe2ab26..91c34ecfe 100644 --- a/src/router.rs +++ b/src/router.rs @@ -91,7 +91,7 @@ impl Router { } } - Task::reply(IntoHttpResponse::into_response(HTTPNotFound, req)) + Task::reply(IntoHttpResponse::response(HTTPNotFound, req)) } } } diff --git a/src/task.rs b/src/task.rs index 034f91b9a..1744fd421 100644 --- a/src/task.rs +++ b/src/task.rs @@ -8,8 +8,9 @@ use bytes::BytesMut; use futures::{Async, Future, Poll, Stream}; use tokio_core::net::TcpStream; -use hyper::header::{Date, Connection, ContentType, - ContentLength, Encoding, TransferEncoding}; +use unicase::Ascii; +use hyper::header::{Date, Connection, ConnectionOption, + ContentType, ContentLength, Encoding, TransferEncoding}; use date; use route::Frame; @@ -53,6 +54,7 @@ pub struct Task { stream: Option>, encoder: Encoder, buffer: BytesMut, + upgraded: bool, } impl Task { @@ -69,6 +71,7 @@ impl Task { stream: None, encoder: Encoder::length(0), buffer: BytesMut::new(), + upgraded: false, } } @@ -82,6 +85,7 @@ impl Task { stream: Some(Box::new(stream)), encoder: Encoder::length(0), buffer: BytesMut::new(), + upgraded: false, } } @@ -90,7 +94,8 @@ impl Task { trace!("Prepare message status={:?}", msg.status); let mut extra = 0; - let body = msg.set_body(Body::Empty); + let body = msg.replace_body(Body::Empty); + match body { Body::Empty => { if msg.chunked() { @@ -126,21 +131,24 @@ impl Task { self.encoder = Encoder::eof(); } } - } - - // keep-alive - if !msg.headers.has::() { - if msg.keep_alive() { - if msg.version < Version::HTTP_11 { - msg.headers.set(Connection::keep_alive()); - } - } else if msg.version >= Version::HTTP_11 { - msg.headers.set(Connection::close()); + Body::Upgrade => { + msg.headers.set(Connection(vec![ + ConnectionOption::ConnectionHeader(Ascii::new("upgrade".to_owned()))])); + self.encoder = Encoder::eof(); } } + // keep-alive + if msg.keep_alive() { + if msg.version < Version::HTTP_11 { + msg.headers.set(Connection::keep_alive()); + } + } else if msg.version >= Version::HTTP_11 { + msg.headers.set(Connection::close()); + } + // render message - let init_cap = 30 + msg.headers.len() * AVERAGE_HEADER_SIZE + extra; + let init_cap = 100 + msg.headers.len() * AVERAGE_HEADER_SIZE + extra; self.buffer.reserve(init_cap); if msg.version == Version::HTTP_11 && msg.status == StatusCode::OK { @@ -149,6 +157,7 @@ impl Task { } else { let _ = write!(self.buffer, "{:?} {}\r\n{}", msg.version, msg.status, msg.headers); } + // using http::h1::date is quite a lot faster than generating // a unique Date header each time like req/s goes up about 10% if !msg.headers.has::() { @@ -169,7 +178,7 @@ impl Task { self.buffer.extend(bytes); return } - msg.set_body(body); + msg.replace_body(body); } pub(crate) fn poll_io(&mut self, io: &mut TcpStream) -> Poll { @@ -261,7 +270,8 @@ impl Future for Task { error!("Non expected frame {:?}", frame); return Err(()) } - if msg.body().has_body() { + self.upgraded = msg.upgrade(); + if self.upgraded || msg.body().has_body() { self.iostate = TaskIOState::ReadingPayload; } else { self.iostate = TaskIOState::Done; diff --git a/src/ws.rs b/src/ws.rs new file mode 100644 index 000000000..07a3ce667 --- /dev/null +++ b/src/ws.rs @@ -0,0 +1,248 @@ +//! `WebSocket` context implementation + +#![allow(dead_code, unused_variables, unused_imports)] + +use std::io; +use std::vec::Vec; +use std::borrow::Cow; + +use http::{Method, StatusCode}; +use bytes::{Bytes, BytesMut}; +use futures::{Async, Future, Poll, Stream}; +use hyper::header; + +use actix::Actor; + +use context::HttpContext; +use route::{Route, Payload, PayloadItem}; +use httpcodes::{HTTPBadRequest, HTTPMethodNotAllowed}; +use httpmessage::{Body, ConnectionType, HttpRequest, HttpResponse, IntoHttpResponse}; + +use wsframe; +pub use wsproto::*; + +header! { (WebSocketAccept, "SEC-WEBSOCKET-ACCEPT") => [String] } +header! { (WebSocketKey, "SEC-WEBSOCKET-KEY") => [String] } +header! { (WebSocketVersion, "SEC-WEBSOCKET-VERSION") => [String] } +header! { (WebSocketProtocol, "SEC-WEBSOCKET-PROTOCOL") => [String] } + + +#[derive(Debug)] +pub enum Message { + Text(String), + Binary(Vec), + Ping(String), + Pong(String), + Close, + Closed, + Error +} + +#[derive(Debug)] +pub enum SendMessage { + Text(String), + Binary(Vec), + Close(CloseCode), + Ping, + Pong, +} + +/// Prepare `WebSocket` handshake. +/// +/// It return HTTP response code, response headers, websocket parser, +/// websocket writer. It does not perform any IO. +/// +/// `protocols` is a sequence of known protocols. On successful handshake, +/// the returned response headers contain the first protocol in this list +/// which the server also knows. +pub fn do_handshake(req: HttpRequest) -> Result { + // WebSocket accepts only GET + if *req.method() != Method::GET { + return Err(HTTPMethodNotAllowed.response(req)) + } + + // Check for "UPGRADE" to websocket header + let has_hdr = if let Some::<&header::Upgrade>(hdr) = req.headers().get() { + hdr.0.contains(&header::Protocol::new(header::ProtocolName::WebSocket, None)) + } else { + false + }; + if !has_hdr { + return Err(HTTPMethodNotAllowed.with_reason(req, "No WebSocket UPGRADE header found")) + } + + // Upgrade connection + if !req.is_upgrade() { + return Err(HTTPBadRequest.with_reason(req, "No CONNECTION upgrade")) + } + + // check supported version + if !req.headers().has::() { + return Err(HTTPBadRequest.with_reason(req, "No websocket version header is required")) + } + let supported_ver = { + let hdr = req.headers().get::().unwrap(); + match hdr.0.as_str() { + "13" | "8" | "7" => true, + _ => false, + } + }; + if !supported_ver { + return Err(HTTPBadRequest.with_reason(req, "Unsupported version")) + } + + // check client handshake for validity + let key = if let Some::<&WebSocketKey>(hdr) = req.headers().get() { + Some(hash_key(hdr.0.as_bytes())) + } else { + None + }; + let key = if let Some(key) = key { + key + } else { + return Err(HTTPBadRequest.with_reason(req, "Handshake error")); + }; + + Ok(HttpResponse::new(req, StatusCode::SWITCHING_PROTOCOLS, Body::Empty) + .set_connection_type(ConnectionType::Upgrade) + .set_header( + header::Upgrade(vec![header::Protocol::new(header::ProtocolName::WebSocket, None)])) + .set_header( + header::TransferEncoding(vec![header::Encoding::Chunked])) + .set_header( + WebSocketAccept(key)) + .set_body(Body::Upgrade) + ) +} + + +/// Struct represent stream of `ws::Message` items +pub struct WsStream { + rx: Payload, + buf: BytesMut, +} + +impl WsStream { + pub fn new(rx: Payload) -> WsStream { + WsStream { rx: rx, buf: BytesMut::new() } + } +} + +impl Stream for WsStream { + type Item = Message; + type Error = (); + + fn poll(&mut self) -> Poll, Self::Error> { + let mut done = false; + + loop { + match self.rx.poll() { + Ok(Async::Ready(Some(item))) => { + match item { + PayloadItem::Eof => + return Ok(Async::Ready(None)), + PayloadItem::Chunk(chunk) => { + self.buf.extend(chunk) + } + } + } + Ok(Async::Ready(None)) => done = true, + Ok(Async::NotReady) => {}, + Err(err) => return Err(err), + } + + match wsframe::Frame::parse(&mut self.buf) { + Ok(Some(frame)) => { + trace!("Frame {}", frame); + let (finished, opcode, payload) = frame.unpack(); + + match opcode { + OpCode::Continue => continue, + OpCode::Bad => + return Ok(Async::Ready(Some(Message::Error))), + OpCode::Close => + return Ok(Async::Ready(Some(Message::Closed))), + OpCode::Ping => + return Ok(Async::Ready(Some( + Message::Ping(String::from_utf8_lossy(&payload).into())))), + OpCode::Pong => + return Ok(Async::Ready(Some( + Message::Pong(String::from_utf8_lossy(&payload).into())))), + OpCode::Binary => + return Ok(Async::Ready(Some(Message::Binary(payload)))), + OpCode::Text => { + match String::from_utf8(payload) { + Ok(s) => + return Ok(Async::Ready(Some(Message::Text(s)))), + Err(err) => + return Ok(Async::Ready(Some(Message::Error))), + } + } + } + } + Ok(None) => if done { + return Ok(Async::Ready(None)) + } else { + return Ok(Async::NotReady) + }, + Err(err) => + return Err(()), + } + } + } +} + + +/// `WebSocket` writer +pub struct WsWriter; + +impl WsWriter { + + pub fn text(ctx: &mut HttpContext, text: String) + where A: Actor> + Route + { + let mut frame = wsframe::Frame::message(Vec::from(text.as_str()), OpCode::Text, true); + let mut buf = Vec::new(); + frame.format(&mut buf).unwrap(); + + ctx.write( + Bytes::from(buf.as_slice()) + ); + } + + pub fn binary(ctx: &mut HttpContext, data: Vec) + where A: Actor> + Route + { + let mut frame = wsframe::Frame::message(data, OpCode::Binary, true); + let mut buf = Vec::new(); + frame.format(&mut buf).unwrap(); + + ctx.write( + Bytes::from(buf.as_slice()) + ); + } + + pub fn ping(ctx: &mut HttpContext, message: String) + where A: Actor> + Route + { + let mut frame = wsframe::Frame::ping(Vec::from(message.as_str())); + let mut buf = Vec::new(); + frame.format(&mut buf).unwrap(); + + ctx.write( + Bytes::from(buf.as_slice()) + ) + } + + pub fn pong(ctx: &mut HttpContext, message: String) + where A: Actor> + Route + { + let mut frame = wsframe::Frame::pong(Vec::from(message.as_str())); + let mut buf = Vec::new(); + frame.format(&mut buf).unwrap(); + + ctx.write( + Bytes::from(buf.as_slice()) + ) + } +} diff --git a/src/wsframe.rs b/src/wsframe.rs new file mode 100644 index 000000000..b78f1d8a0 --- /dev/null +++ b/src/wsframe.rs @@ -0,0 +1,496 @@ +#![allow(dead_code, unused_variables)] +use std::fmt; +use std::mem::transmute; +use std::io::{Write, Error, ErrorKind}; +use std::default::Default; +use std::iter::FromIterator; + +use rand; +use bytes::BytesMut; + +use wsproto::{OpCode, CloseCode}; + + +fn apply_mask(buf: &mut [u8], mask: &[u8; 4]) { + let iter = buf.iter_mut().zip(mask.iter().cycle()); + for (byte, &key) in iter { + *byte ^= key + } +} + +#[inline] +fn generate_mask() -> [u8; 4] { + unsafe { transmute(rand::random::()) } +} + +/// A struct representing a `WebSocket` frame. +#[derive(Debug, Clone)] +pub struct Frame { + finished: bool, + rsv1: bool, + rsv2: bool, + rsv3: bool, + opcode: OpCode, + mask: Option<[u8; 4]>, + payload: Vec, +} + +impl Frame { + + /// Desctructe frame + pub fn unpack(self) -> (bool, OpCode, Vec) { + (self.finished, self.opcode, self.payload) + } + + /// Get the length of the frame. + /// This is the length of the header + the length of the payload. + #[inline] + pub fn len(&self) -> usize { + let mut header_length = 2; + let payload_len = self.payload().len(); + if payload_len > 125 { + if payload_len <= u16::max_value() as usize { + header_length += 2; + } else { + header_length += 8; + } + } + + if self.is_masked() { + header_length += 4; + } + + header_length + payload_len + } + + /// Test whether the frame is a final frame. + #[inline] + pub fn is_final(&self) -> bool { + self.finished + } + + /// Test whether the first reserved bit is set. + #[inline] + pub fn has_rsv1(&self) -> bool { + self.rsv1 + } + + /// Test whether the second reserved bit is set. + #[inline] + pub fn has_rsv2(&self) -> bool { + self.rsv2 + } + + /// Test whether the third reserved bit is set. + #[inline] + pub fn has_rsv3(&self) -> bool { + self.rsv3 + } + + /// Get the OpCode of the frame. + #[inline] + pub fn opcode(&self) -> OpCode { + self.opcode + } + + /// Test whether this is a control frame. + #[inline] + pub fn is_control(&self) -> bool { + self.opcode.is_control() + } + + /// Get a reference to the frame's payload. + #[inline] + pub fn payload(&self) -> &Vec { + &self.payload + } + + // Test whether the frame is masked. + #[doc(hidden)] + #[inline] + pub fn is_masked(&self) -> bool { + self.mask.is_some() + } + + // Get an optional reference to the frame's mask. + #[doc(hidden)] + #[allow(dead_code)] + #[inline] + pub fn mask(&self) -> Option<&[u8; 4]> { + self.mask.as_ref() + } + + /// Make this frame a final frame. + #[allow(dead_code)] + #[inline] + pub fn set_final(&mut self, is_final: bool) -> &mut Frame { + self.finished = is_final; + self + } + + /// Set the first reserved bit. + #[inline] + pub fn set_rsv1(&mut self, has_rsv1: bool) -> &mut Frame { + self.rsv1 = has_rsv1; + self + } + + /// Set the second reserved bit. + #[inline] + pub fn set_rsv2(&mut self, has_rsv2: bool) -> &mut Frame { + self.rsv2 = has_rsv2; + self + } + + /// Set the third reserved bit. + #[inline] + pub fn set_rsv3(&mut self, has_rsv3: bool) -> &mut Frame { + self.rsv3 = has_rsv3; + self + } + + /// Set the OpCode. + #[allow(dead_code)] + #[inline] + pub fn set_opcode(&mut self, opcode: OpCode) -> &mut Frame { + self.opcode = opcode; + self + } + + /// Edit the frame's payload. + #[allow(dead_code)] + #[inline] + pub fn payload_mut(&mut self) -> &mut Vec { + &mut self.payload + } + + // Generate a new mask for this frame. + // + // This method simply generates and stores the mask. It does not change the payload data. + // Instead, the payload data will be masked with the generated mask when the frame is sent + // to the other endpoint. + #[doc(hidden)] + #[inline] + pub fn set_mask(&mut self) -> &mut Frame { + self.mask = Some(generate_mask()); + self + } + + // This method unmasks the payload and should only be called on frames that are actually + // masked. In other words, those frames that have just been received from a client endpoint. + #[doc(hidden)] + #[inline] + pub fn remove_mask(&mut self) -> &mut Frame { + self.mask.and_then(|mask| { + Some(apply_mask(&mut self.payload, &mask)) + }); + self.mask = None; + self + } + + /// Consume the frame into its payload. + pub fn into_data(self) -> Vec { + self.payload + } + + /// Create a new data frame. + #[inline] + pub fn message(data: Vec, code: OpCode, finished: bool) -> Frame { + debug_assert!(match code { + OpCode::Text | OpCode::Binary | OpCode::Continue => true, + _ => false, + }, "Invalid opcode for data frame."); + + Frame { + finished: finished, + opcode: code, + payload: data, + .. Frame::default() + } + } + + /// Create a new Pong control frame. + #[inline] + pub fn pong(data: Vec) -> Frame { + Frame { + opcode: OpCode::Pong, + payload: data, + .. Frame::default() + } + } + + /// Create a new Ping control frame. + #[inline] + pub fn ping(data: Vec) -> Frame { + Frame { + opcode: OpCode::Ping, + payload: data, + .. Frame::default() + } + } + + /// Create a new Close control frame. + #[inline] + pub fn close(code: CloseCode, reason: &str) -> Frame { + let raw: [u8; 2] = unsafe { + let u: u16 = code.into(); + transmute(u.to_be()) + }; + + let payload = if let CloseCode::Empty = code { + Vec::new() + } else { + Vec::from_iter( + raw[..].iter() + .chain(reason.as_bytes().iter()) + .cloned()) + }; + + Frame { + payload: payload, + .. Frame::default() + } + } + + /// Parse the input stream into a frame. + pub fn parse(buf: &mut BytesMut) -> Result, Error> { + let mut idx = 2; + + let (frame, length) = { + let mut size = buf.len(); + + if size < 2 { + return Ok(None) + } + let mut head = [0u8; 2]; + size -= 2; + head.copy_from_slice(&buf[..2]); + + trace!("Parsed headers {:?}", head); + + let first = head[0]; + let second = head[1]; + trace!("First: {:b}", first); + trace!("Second: {:b}", second); + + let finished = first & 0x80 != 0; + + let rsv1 = first & 0x40 != 0; + let rsv2 = first & 0x20 != 0; + let rsv3 = first & 0x10 != 0; + + let opcode = OpCode::from(first & 0x0F); + trace!("Opcode: {:?}", opcode); + + let masked = second & 0x80 != 0; + trace!("Masked: {:?}", masked); + + let mut header_length = 2; + let mut length = u64::from(second & 0x7F); + + if length == 126 { + let mut length_bytes = [0u8; 2]; + if size < 2 { + return Ok(None) + } + length_bytes.copy_from_slice(&buf[idx..idx+2]); + size -= 2; + idx += 2; + + length = u64::from(unsafe{ + let mut wide: u16 = transmute(length_bytes); + wide = u16::from_be(wide); + wide}); + header_length += 2; + } else if length == 127 { + let mut length_bytes = [0u8; 8]; + if size < 8 { + return Ok(None) + } + length_bytes.copy_from_slice(&buf[idx..idx+8]); + size -= 8; + idx += 2; + + unsafe { length = transmute(length_bytes); } + length = u64::from_be(length); + header_length += 8; + } + trace!("Payload length: {}", length); + + let mask = if masked { + let mut mask_bytes = [0u8; 4]; + if size < 4 { + return Ok(None) + } else { + header_length += 4; + size -= 4; + mask_bytes.copy_from_slice(&buf[idx..idx+4]); + idx += 4; + Some(mask_bytes) + } + } else { + None + }; + + let length = length as usize; + if size < length { + return Ok(None) + } + + let mut data = Vec::with_capacity(length); + if length > 0 { + data.extend_from_slice(&buf[idx..idx+length]); + } + + // Disallow bad opcode + if let OpCode::Bad = opcode { + return Err( + Error::new( + ErrorKind::Other, + format!("Encountered invalid opcode: {}", first & 0x0F))) + } + + // control frames must have length <= 125 + match opcode { + OpCode::Ping | OpCode::Pong if length > 125 => { + return Err( + Error::new( + ErrorKind::Other, + format!("Rejected WebSocket handshake.Received control frame with length: {}.", length))) + } + OpCode::Close if length > 125 => { + debug!("Received close frame with payload length exceeding 125. Morphing to protocol close frame."); + return Ok(Some(Frame::close(CloseCode::Protocol, "Received close frame with payload length exceeding 125."))) + } + _ => () + } + + // unmask + if let Some(ref mask) = mask { + apply_mask(&mut data, mask); + } + + let frame = Frame { + finished: finished, + rsv1: rsv1, + rsv2: rsv2, + rsv3: rsv3, + opcode: opcode, + mask: mask, + payload: data, + }; + + (frame, header_length + length) + }; + + buf.split_to(length); + Ok(Some(frame)) + } + + /// Write a frame out to a buffer + pub fn format(&mut self, w: &mut W) -> Result<(), Error> + where W: Write + { + let mut one = 0u8; + let code: u8 = self.opcode.into(); + if self.is_final() { + one |= 0x80; + } + if self.has_rsv1() { + one |= 0x40; + } + if self.has_rsv2() { + one |= 0x20; + } + if self.has_rsv3() { + one |= 0x10; + } + one |= code; + + let mut two = 0u8; + + if self.is_masked() { + two |= 0x80; + } + + if self.payload.len() < 126 { + two |= self.payload.len() as u8; + let headers = [one, two]; + try!(w.write_all(&headers)); + } else if self.payload.len() <= 65_535 { + two |= 126; + let length_bytes: [u8; 2] = unsafe { + let short = self.payload.len() as u16; + transmute(short.to_be()) + }; + let headers = [one, two, length_bytes[0], length_bytes[1]]; + try!(w.write_all(&headers)); + } else { + two |= 127; + let length_bytes: [u8; 8] = unsafe { + let long = self.payload.len() as u64; + transmute(long.to_be()) + }; + let headers = [ + one, + two, + length_bytes[0], + length_bytes[1], + length_bytes[2], + length_bytes[3], + length_bytes[4], + length_bytes[5], + length_bytes[6], + length_bytes[7], + ]; + try!(w.write_all(&headers)); + } + + if self.is_masked() { + let mask = self.mask.take().unwrap(); + apply_mask(&mut self.payload, &mask); + try!(w.write_all(&mask)); + } + + try!(w.write_all(&self.payload)); + Ok(()) + } +} + +impl Default for Frame { + fn default() -> Frame { + Frame { + finished: true, + rsv1: false, + rsv2: false, + rsv3: false, + opcode: OpCode::Close, + mask: None, + payload: Vec::new(), + } + } +} + +impl fmt::Display for Frame { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, + " + + final: {} + reserved: {} {} {} + opcode: {} + length: {} + payload length: {} + payload: 0x{} +", + self.finished, + self.rsv1, + self.rsv2, + self.rsv3, + self.opcode, + // self.mask.map(|mask| format!("{:?}", mask)).unwrap_or("NONE".into()), + self.len(), + self.payload.len(), + self.payload.iter().map(|byte| format!("{:x}", byte)).collect::()) + } +} diff --git a/src/wsproto.rs b/src/wsproto.rs new file mode 100644 index 000000000..e8b9895c4 --- /dev/null +++ b/src/wsproto.rs @@ -0,0 +1,300 @@ +#![allow(dead_code, unused_variables)] +use std::fmt; +use std::convert::{Into, From}; +use std::mem::transmute; + +use rand; +use sha1; + + +use self::OpCode::*; +/// Operation codes as part of rfc6455. +#[derive(Debug, Eq, PartialEq, Clone, Copy)] +pub enum OpCode { + /// Indicates a continuation frame of a fragmented message. + Continue, + /// Indicates a text data frame. + Text, + /// Indicates a binary data frame. + Binary, + /// Indicates a close control frame. + Close, + /// Indicates a ping control frame. + Ping, + /// Indicates a pong control frame. + Pong, + /// Indicates an invalid opcode was received. + Bad, +} + +impl OpCode { + + /// Test whether the opcode indicates a control frame. + pub fn is_control(&self) -> bool { + match *self { + Text | Binary | Continue => false, + _ => true, + } + } + +} + +impl fmt::Display for OpCode { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match *self { + Continue => write!(f, "CONTINUE"), + Text => write!(f, "TEXT"), + Binary => write!(f, "BINARY"), + Close => write!(f, "CLOSE"), + Ping => write!(f, "PING"), + Pong => write!(f, "PONG"), + Bad => write!(f, "BAD"), + } + } +} + +impl Into for OpCode { + + fn into(self) -> u8 { + match self { + Continue => 0, + Text => 1, + Binary => 2, + Close => 8, + Ping => 9, + Pong => 10, + Bad => { + debug_assert!(false, "Attempted to convert invalid opcode to u8. This is a bug."); + 8 // if this somehow happens, a close frame will help us tear down quickly + } + } + } +} + +impl From for OpCode { + + fn from(byte: u8) -> OpCode { + match byte { + 0 => Continue, + 1 => Text, + 2 => Binary, + 8 => Close, + 9 => Ping, + 10 => Pong, + _ => Bad + } + } +} + +use self::CloseCode::*; +/// Status code used to indicate why an endpoint is closing the `WebSocket` connection. +#[derive(Debug, Eq, PartialEq, Clone, Copy)] +pub enum CloseCode { + /// Indicates a normal closure, meaning that the purpose for + /// which the connection was established has been fulfilled. + Normal, + /// Indicates that an endpoint is "going away", such as a server + /// going down or a browser having navigated away from a page. + Away, + /// Indicates that an endpoint is terminating the connection due + /// to a protocol error. + Protocol, + /// Indicates that an endpoint is terminating the connection + /// because it has received a type of data it cannot accept (e.g., an + /// endpoint that understands only text data MAY send this if it + /// receives a binary message). + Unsupported, + /// Indicates that no status code was included in a closing frame. This + /// close code makes it possible to use a single method, `on_close` to + /// handle even cases where no close code was provided. + Status, + /// Indicates an abnormal closure. If the abnormal closure was due to an + /// error, this close code will not be used. Instead, the `on_error` method + /// of the handler will be called with the error. However, if the connection + /// is simply dropped, without an error, this close code will be sent to the + /// handler. + Abnormal, + /// Indicates that an endpoint is terminating the connection + /// because it has received data within a message that was not + /// consistent with the type of the message (e.g., non-UTF-8 [RFC3629] + /// data within a text message). + Invalid, + /// Indicates that an endpoint is terminating the connection + /// because it has received a message that violates its policy. This + /// is a generic status code that can be returned when there is no + /// other more suitable status code (e.g., Unsupported or Size) or if there + /// is a need to hide specific details about the policy. + Policy, + /// Indicates that an endpoint is terminating the connection + /// because it has received a message that is too big for it to + /// process. + Size, + /// Indicates that an endpoint (client) is terminating the + /// connection because it has expected the server to negotiate one or + /// more extension, but the server didn't return them in the response + /// message of the WebSocket handshake. The list of extensions that + /// are needed should be given as the reason for closing. + /// Note that this status code is not used by the server, because it + /// can fail the WebSocket handshake instead. + Extension, + /// Indicates that a server is terminating the connection because + /// it encountered an unexpected condition that prevented it from + /// fulfilling the request. + Error, + /// Indicates that the server is restarting. A client may choose to reconnect, + /// and if it does, it should use a randomized delay of 5-30 seconds between attempts. + Restart, + /// Indicates that the server is overloaded and the client should either connect + /// to a different IP (when multiple targets exist), or reconnect to the same IP + /// when a user has performed an action. + Again, + #[doc(hidden)] + Tls, + #[doc(hidden)] + Empty, + #[doc(hidden)] + Other(u16), +} + +impl Into for CloseCode { + + fn into(self) -> u16 { + match self { + Normal => 1000, + Away => 1001, + Protocol => 1002, + Unsupported => 1003, + Status => 1005, + Abnormal => 1006, + Invalid => 1007, + Policy => 1008, + Size => 1009, + Extension => 1010, + Error => 1011, + Restart => 1012, + Again => 1013, + Tls => 1015, + Empty => 0, + Other(code) => code, + } + } +} + +impl From for CloseCode { + + fn from(code: u16) -> CloseCode { + match code { + 1000 => Normal, + 1001 => Away, + 1002 => Protocol, + 1003 => Unsupported, + 1005 => Status, + 1006 => Abnormal, + 1007 => Invalid, + 1008 => Policy, + 1009 => Size, + 1010 => Extension, + 1011 => Error, + 1012 => Restart, + 1013 => Again, + 1015 => Tls, + 0 => Empty, + _ => Other(code), + } + } +} + + +static WS_GUID: &'static str = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; +static BASE64: &'static [u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; + + +fn generate_key() -> String { + let key: [u8; 16] = unsafe { + transmute(rand::random::<(u64, u64)>()) + }; + encode_base64(&key) +} + +// TODO: hash is always same size, we dont need String +pub(crate) fn hash_key(key: &[u8]) -> String { + let mut hasher = sha1::Sha1::new(); + + hasher.update(key); + hasher.update(WS_GUID.as_bytes()); + + encode_base64(&hasher.digest().bytes()) +} + + +// This code is based on rustc_serialize base64 STANDARD +fn encode_base64(data: &[u8]) -> String { + let len = data.len(); + let mod_len = len % 3; + + let mut encoded = vec![b'='; (len + 2) / 3 * 4]; + { + let mut in_iter = data[..len - mod_len].iter().map(|&c| u32::from(c)); + let mut out_iter = encoded.iter_mut(); + + let enc = |val| BASE64[val as usize]; + let mut write = |val| *out_iter.next().unwrap() = val; + + while let (Some(one), Some(two), Some(three)) = (in_iter.next(), in_iter.next(), in_iter.next()) { + let g24 = one << 16 | two << 8 | three; + write(enc((g24 >> 18) & 63)); + write(enc((g24 >> 12) & 63)); + write(enc((g24 >> 6 ) & 63)); + write(enc(g24 & 63)); + } + + match mod_len { + 1 => { + let pad = u32::from(data[len-1]) << 16; + write(enc((pad >> 18) & 63)); + write(enc((pad >> 12) & 63)); + } + 2 => { + let pad = u32::from(data[len-2]) << 16 | u32::from(data[len-1]) << 8; + write(enc((pad >> 18) & 63)); + write(enc((pad >> 12) & 63)); + write(enc((pad >> 6) & 63)); + } + _ => (), + } + } + + String::from_utf8(encoded).unwrap() +} + + +mod test { + #![allow(unused_imports, unused_variables, dead_code)] + use super::*; + + #[test] + fn opcode_from_u8() { + let byte = 2u8; + assert_eq!(OpCode::from(byte), OpCode::Binary); + } + + #[test] + fn opcode_into_u8() { + let text = OpCode::Text; + let byte: u8 = text.into(); + assert_eq!(byte, 1u8); + } + + #[test] + fn closecode_from_u16() { + let byte = 1008u16; + assert_eq!(CloseCode::from(byte), CloseCode::Policy); + } + + #[test] + fn closecode_into_u16() { + let text = CloseCode::Away; + let byte: u16 = text.into(); + assert_eq!(byte, 1001u16); + } +}