diff --git a/Cargo.toml b/Cargo.toml index 8d1211a50..96104a700 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -60,9 +60,10 @@ url = "1.6" cookie = { version="0.10", features=["percent-encode", "secure"] } # io -mio = "0.6" +mio = "^0.6.13" net2 = "0.2" bytes = "0.4" +byteorder = "1" futures = "0.1" tokio-io = "0.1" tokio-core = "0.1" @@ -108,4 +109,5 @@ members = [ "examples/websocket", "examples/websocket-chat", "examples/web-cors/backend", + "tools/wsload/", ] diff --git a/src/client/writer.rs b/src/client/writer.rs index f77d4c988..4370b37b6 100644 --- a/src/client/writer.rs +++ b/src/client/writer.rs @@ -4,16 +4,17 @@ use std::fmt::Write; use bytes::BufMut; use futures::{Async, Poll}; use tokio_io::AsyncWrite; -// use http::header::{HeaderValue, CONNECTION, DATE}; use body::Binary; -use server::{WriterState, MAX_WRITE_BUFFER_SIZE}; +use server::WriterState; use server::shared::SharedBytes; use client::ClientRequest; -const AVERAGE_HEADER_SIZE: usize = 30; // totally scientific +const LOW_WATERMARK: usize = 1024; +const HIGH_WATERMARK: usize = 8 * LOW_WATERMARK; +const AVERAGE_HEADER_SIZE: usize = 30; bitflags! { struct Flags: u8 { @@ -29,6 +30,8 @@ pub(crate) struct HttpClientWriter { written: u64, headers_size: u32, buffer: SharedBytes, + low: usize, + high: usize, } impl HttpClientWriter { @@ -39,6 +42,8 @@ impl HttpClientWriter { written: 0, headers_size: 0, buffer: buf, + low: LOW_WATERMARK, + high: HIGH_WATERMARK, } } @@ -50,6 +55,12 @@ impl HttpClientWriter { self.flags.contains(Flags::KEEPALIVE) && !self.flags.contains(Flags::UPGRADE) } + /// Set write buffer capacity + pub fn set_buffer_capacity(&mut self, low_watermark: usize, high_watermark: usize) { + self.low = low_watermark; + self.high = high_watermark; + } + fn write_to_stream(&mut self, stream: &mut T) -> io::Result { while !self.buffer.is_empty() { match stream.write(self.buffer.as_ref()) { @@ -61,7 +72,7 @@ impl HttpClientWriter { let _ = self.buffer.split_to(n); }, Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { - if self.buffer.len() > MAX_WRITE_BUFFER_SIZE { + if self.buffer.len() > self.high { return Ok(WriterState::Pause) } else { return Ok(WriterState::Done) @@ -117,7 +128,7 @@ impl HttpClientWriter { self.buffer.extend_from_slice(payload.as_ref()) } - if self.buffer.len() > MAX_WRITE_BUFFER_SIZE { + if self.buffer.len() > self.high { Ok(WriterState::Pause) } else { Ok(WriterState::Done) @@ -125,7 +136,7 @@ impl HttpClientWriter { } pub fn write_eof(&mut self) -> io::Result { - if self.buffer.len() > MAX_WRITE_BUFFER_SIZE { + if self.buffer.len() > self.high { Ok(WriterState::Pause) } else { Ok(WriterState::Done) diff --git a/src/lib.rs b/src/lib.rs index c42bef718..faca69ce6 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -51,6 +51,7 @@ extern crate log; extern crate time; extern crate base64; extern crate bytes; +extern crate byteorder; extern crate sha1; extern crate regex; #[macro_use] diff --git a/src/ws/client.rs b/src/ws/client.rs index 7e4c4f9b1..bc857bd81 100644 --- a/src/ws/client.rs +++ b/src/ws/client.rs @@ -92,9 +92,9 @@ impl From for WsClientError { } } -/// WebSocket client +/// `WebSocket` client /// -/// Example of WebSocket client usage is available in +/// Example of `WebSocket` client usage is available in /// [websocket example]( /// https://github.com/actix/actix-web/blob/master/examples/websocket/src/client.rs#L24) pub struct WsClient { @@ -317,7 +317,7 @@ impl Future for WsHandshake { return Err(WsClientError::InvalidChallengeResponse) } - let inner = Rc::new(UnsafeCell::new(Inner{inner: inner})); + let inner = Rc::new(UnsafeCell::new(inner)); Ok(Async::Ready( (WsClientReader{inner: Rc::clone(&inner)}, WsClientWriter{inner: inner}))) @@ -332,12 +332,8 @@ impl Future for WsHandshake { } -struct Inner { - inner: WsInner, -} - pub struct WsClientReader { - inner: Rc> + inner: Rc> } impl fmt::Debug for WsClientReader { @@ -348,7 +344,7 @@ impl fmt::Debug for WsClientReader { impl WsClientReader { #[inline] - fn as_mut(&mut self) -> &mut Inner { + fn as_mut(&mut self) -> &mut WsInner { unsafe{ &mut *self.inner.get() } } } @@ -361,10 +357,10 @@ impl Stream for WsClientReader { let inner = self.as_mut(); let mut done = false; - match utils::read_from_io(&mut inner.inner.conn, &mut inner.inner.parser_buf) { + match utils::read_from_io(&mut inner.conn, &mut inner.parser_buf) { Ok(Async::Ready(0)) => { done = true; - inner.inner.closed = true; + inner.closed = true; }, Ok(Async::Ready(_)) | Ok(Async::NotReady) => (), Err(err) => @@ -372,10 +368,10 @@ impl Stream for WsClientReader { } // write - let _ = inner.inner.writer.poll_completed(&mut inner.inner.conn, false); + let _ = inner.writer.poll_completed(&mut inner.conn, false); // read - match Frame::parse(&mut inner.inner.parser_buf) { + match Frame::parse(&mut inner.parser_buf) { Ok(Some(frame)) => { // trace!("WsFrame {}", frame); let (_finished, opcode, payload) = frame.unpack(); @@ -385,8 +381,8 @@ impl Stream for WsClientReader { OpCode::Bad => Ok(Async::Ready(Some(Message::Error))), OpCode::Close => { - inner.inner.closed = true; - inner.inner.error_sent = true; + inner.closed = true; + inner.error_sent = true; Ok(Async::Ready(Some(Message::Closed))) }, OpCode::Ping => @@ -413,9 +409,9 @@ impl Stream for WsClientReader { Ok(None) => { if done { Ok(Async::Ready(None)) - } else if inner.inner.closed { - if !inner.inner.error_sent { - inner.inner.error_sent = true; + } else if inner.closed { + if !inner.error_sent { + inner.error_sent = true; Ok(Async::Ready(Some(Message::Closed))) } else { Ok(Async::Ready(None)) @@ -425,8 +421,8 @@ impl Stream for WsClientReader { } }, Err(err) => { - inner.inner.closed = true; - inner.inner.error_sent = true; + inner.closed = true; + inner.error_sent = true; Err(err.into()) } } @@ -434,12 +430,12 @@ impl Stream for WsClientReader { } pub struct WsClientWriter { - inner: Rc> + inner: Rc> } impl WsClientWriter { #[inline] - fn as_mut(&mut self) -> &mut Inner { + fn as_mut(&mut self) -> &mut WsInner { unsafe{ &mut *self.inner.get() } } } @@ -449,8 +445,8 @@ impl WsClientWriter { /// Write payload #[inline] fn write>(&mut self, data: B) { - if !self.as_mut().inner.closed { - let _ = self.as_mut().inner.writer.write(&data.into()); + if !self.as_mut().closed { + let _ = self.as_mut().writer.write(&data.into()); } else { warn!("Trying to write to disconnected response"); } diff --git a/src/ws/context.rs b/src/ws/context.rs index d977a69b5..55b4e67ea 100644 --- a/src/ws/context.rs +++ b/src/ws/context.rs @@ -211,11 +211,8 @@ impl ActorHttpContext for WebsocketContext where A: Actor) }; - if self.inner.alive() { - match self.inner.poll(ctx) { - Ok(Async::NotReady) | Ok(Async::Ready(())) => (), - Err(_) => return Err(ErrorInternalServerError("error").into()), - } + if self.inner.alive() && self.inner.poll(ctx).is_err() { + return Err(ErrorInternalServerError("error").into()) } // frames diff --git a/src/ws/frame.rs b/src/ws/frame.rs index 015127f0a..d149f37a1 100644 --- a/src/ws/frame.rs +++ b/src/ws/frame.rs @@ -2,6 +2,7 @@ use std::{fmt, mem}; use std::io::{Write, Error, ErrorKind}; use std::iter::FromIterator; use bytes::BytesMut; +use byteorder::{ByteOrder, BigEndian}; use body::Binary; use ws::proto::{OpCode, CloseCode}; @@ -39,7 +40,6 @@ impl Frame { header_length += 8; } } - if self.mask.is_some() { header_length += 4; } @@ -84,136 +84,107 @@ impl Frame { /// Parse the input stream into a frame. pub fn parse(buf: &mut BytesMut) -> Result, Error> { let mut idx = 2; + let mut size = buf.len(); - let (frame, length) = { - let mut size = buf.len(); + if size < 2 { + return Ok(None) + } + size -= 2; + let first = buf[0]; + let second = buf[1]; + let finished = first & 0x80 != 0; + let rsv1 = first & 0x40 != 0; + let rsv2 = first & 0x20 != 0; + let rsv3 = first & 0x10 != 0; + let opcode = OpCode::from(first & 0x0F); + let masked = second & 0x80 != 0; + let len = second & 0x7F; + + let length = if len == 126 { if size < 2 { return Ok(None) } - let mut head = [0u8; 2]; + let len = u64::from(BigEndian::read_u16(&buf[idx..])); size -= 2; - head.copy_from_slice(&buf[..2]); - - trace!("Parsed headers {:?}", head); - - let first = head[0]; - let second = head[1]; - trace!("First: {:b}", first); - trace!("Second: {:b}", second); - - let finished = first & 0x80 != 0; - - let rsv1 = first & 0x40 != 0; - let rsv2 = first & 0x20 != 0; - let rsv3 = first & 0x10 != 0; - - let opcode = OpCode::from(first & 0x0F); - trace!("Opcode: {:?}", opcode); - - let masked = second & 0x80 != 0; - trace!("Masked: {:?}", masked); - - let mut header_length = 2; - let mut length = u64::from(second & 0x7F); - - if length == 126 { - if size < 2 { - return Ok(None) - } - let mut length_bytes = [0u8; 2]; - length_bytes.copy_from_slice(&buf[idx..idx+2]); - size -= 2; - idx += 2; - - length = u64::from(unsafe{ - let mut wide: u16 = mem::transmute(length_bytes); - wide = u16::from_be(wide); - wide}); - header_length += 2; - } else if length == 127 { - if size < 8 { - return Ok(None) - } - let mut length_bytes = [0u8; 8]; - length_bytes.copy_from_slice(&buf[idx..idx+8]); - size -= 8; - idx += 8; - - unsafe { length = mem::transmute(length_bytes); } - length = u64::from_be(length); - header_length += 8; - } - trace!("Payload length: {}", length); - - let mask = if masked { - let mut mask_bytes = [0u8; 4]; - if size < 4 { - return Ok(None) - } else { - header_length += 4; - size -= 4; - mask_bytes.copy_from_slice(&buf[idx..idx+4]); - idx += 4; - Some(mask_bytes) - } - } else { - None - }; - - let length = length as usize; - if size < length { + idx += 2; + len + } else if len == 127 { + if size < 8 { return Ok(None) } + let len = BigEndian::read_u64(&buf[idx..]); + size -= 8; + idx += 8; + len + } else { + u64::from(len) + }; - let mut data = Vec::with_capacity(length); - if length > 0 { - data.extend_from_slice(&buf[idx..idx+length]); + let mask = if masked { + let mut mask_bytes = [0u8; 4]; + if size < 4 { + return Ok(None) + } else { + size -= 4; + mask_bytes.copy_from_slice(&buf[idx..idx+4]); + idx += 4; + Some(mask_bytes) } + } else { + None + }; - // Disallow bad opcode - if let OpCode::Bad = opcode { + let length = length as usize; + if size < length { + return Ok(None) + } + + // get body + buf.split_to(idx); + let mut data = if length > 0 { + buf.split_to(length) + } else { + BytesMut::new() + }; + + // Disallow bad opcode + if let OpCode::Bad = opcode { + return Err( + Error::new( + ErrorKind::Other, + format!("Encountered invalid opcode: {}", first & 0x0F))) + } + + // control frames must have length <= 125 + match opcode { + OpCode::Ping | OpCode::Pong if length > 125 => { return Err( Error::new( ErrorKind::Other, - format!("Encountered invalid opcode: {}", first & 0x0F))) + format!("Rejected WebSocket handshake.Received control frame with length: {}.", length))) } - - // control frames must have length <= 125 - match opcode { - OpCode::Ping | OpCode::Pong if length > 125 => { - return Err( - Error::new( - ErrorKind::Other, - format!("Rejected WebSocket handshake.Received control frame with length: {}.", length))) - } - OpCode::Close if length > 125 => { - debug!("Received close frame with payload length exceeding 125. Morphing to protocol close frame."); - return Ok(Some(Frame::close(CloseCode::Protocol, "Received close frame with payload length exceeding 125."))) - } - _ => () + OpCode::Close if length > 125 => { + debug!("Received close frame with payload length exceeding 125. Morphing to protocol close frame."); + return Ok(Some(Frame::close(CloseCode::Protocol, "Received close frame with payload length exceeding 125."))) } + _ => () + } - // unmask - if let Some(ref mask) = mask { - apply_mask(&mut data, mask); - } + // unmask + if let Some(ref mask) = mask { + apply_mask(&mut data, mask); + } - let frame = Frame { - finished: finished, - rsv1: rsv1, - rsv2: rsv2, - rsv3: rsv3, - opcode: opcode, - mask: mask, - payload: data.into(), - }; - - (frame, header_length + length) - }; - - buf.split_to(length); - Ok(Some(frame)) + Ok(Some(Frame { + finished: finished, + rsv1: rsv1, + rsv2: rsv2, + rsv3: rsv3, + opcode: opcode, + mask: mask, + payload: data.into(), + })) } /// Write a frame out to a buffer diff --git a/tools/wsload/Cargo.toml b/tools/wsload/Cargo.toml new file mode 100644 index 000000000..606615a0b --- /dev/null +++ b/tools/wsload/Cargo.toml @@ -0,0 +1,21 @@ +[package] +name = "wsclient" +version = "0.1.0" +authors = ["Nikolay Kim "] +workspace = "../.." + +[[bin]] +name = "wsclient" +path = "src/wsclient.rs" + +[dependencies] +env_logger = "*" +futures = "0.1" +clap = "2" +url = "1.6" +rand = "0.4" +time = "*" +num_cpus = "1" +tokio-core = "0.1" +actix = { git = "https://github.com/actix/actix.git" } +actix-web = { path="../../" } diff --git a/tools/wsload/src/wsclient.rs b/tools/wsload/src/wsclient.rs new file mode 100644 index 000000000..e6438c634 --- /dev/null +++ b/tools/wsload/src/wsclient.rs @@ -0,0 +1,238 @@ +//! Simple websocket client. + +#![allow(unused_variables)] +extern crate actix; +extern crate actix_web; +extern crate env_logger; +extern crate futures; +extern crate tokio_core; +extern crate url; +extern crate clap; +extern crate rand; +extern crate time; +extern crate num_cpus; + +use std::time::Duration; +use std::sync::Arc; +use std::sync::atomic::{AtomicUsize, Ordering}; +use futures::Future; +use rand::{thread_rng, Rng}; + +use actix::prelude::*; +use actix_web::ws::{Message, WsClientError, WsClient, WsClientWriter}; + + +fn main() { + ::std::env::set_var("RUST_LOG", "actix_web=info"); + let _ = env_logger::init(); + + let matches = clap::App::new("ws tool") + .version("0.1") + .about("Applies load to websocket server") + .args_from_usage( + " 'WebSocket url' + [bin]... -b, 'use binary frames' + -s, --size=[NUMBER] 'size of PUBLISH packet payload to send in KB' + -w, --warm-up=[SECONDS] 'seconds before counter values are considered for reporting' + -r, --sample-rate=[SECONDS] 'seconds between average reports' + -c, --concurrency=[NUMBER] 'number of websockt connections to open and use concurrently for sending' + -t, --threads=[NUMBER] 'number of threads to use'", + ) + .get_matches(); + + let bin: bool = matches.value_of("bin").is_some(); + let ws_url = matches.value_of("url").unwrap().to_owned(); + let _ = url::Url::parse(&ws_url).map_err(|e| { + println!("Invalid url: {}", ws_url); + std::process::exit(0); + }); + + let threads = parse_u64_default(matches.value_of("threads"), num_cpus::get() as u64); + let concurrency = parse_u64_default(matches.value_of("concurrency"), 1); + let payload_size: usize = match matches.value_of("size") { + Some(s) => parse_u64_default(Some(s), 0) as usize * 1024, + None => 1024, + }; + let warmup_seconds = parse_u64_default(matches.value_of("warm-up"), 2) as u64; + let sample_rate = parse_u64_default(matches.value_of("sample-rate"), 1) as usize; + + let perf_counters = Arc::new(PerfCounters::new()); + let payload = Arc::new(thread_rng() + .gen_ascii_chars() + .take(payload_size) + .collect::()); + + let sys = actix::System::new("ws-client"); + + let mut report = true; + for t in 0..threads { + let pl = payload.clone(); + let ws = ws_url.clone(); + let perf = perf_counters.clone(); + let addr = Arbiter::new(format!("test {}", t)); + + addr.send(actix::msgs::Execute::new(move || -> Result<(), ()> { + let mut reps = report; + for _ in 0..concurrency { + let pl2 = pl.clone(); + let perf2 = perf.clone(); + + Arbiter::handle().spawn( + WsClient::new(&ws).connect().unwrap() + .map_err(|e| { + println!("Error: {}", e); + Arbiter::system().send(actix::msgs::SystemExit(0)); + () + }) + .map(move |(reader, writer)| { + let addr: SyncAddress<_> = ChatClient::create(move |ctx| { + ChatClient::add_stream(reader, ctx); + ChatClient{conn: writer, + payload: pl2, + report: reps, + bin: bin, + ts: time::precise_time_ns(), + perf_counters: perf2, + sample_rate_secs: sample_rate, + } + }); + }) + ); + reps = false; + } + Ok(()) + })); + report = false; + } + + let _ = sys.run(); +} + +fn parse_u64_default(input: Option<&str>, default: u64) -> u64 { + input.map(|v| v.parse().expect(&format!("not a valid number: {}", v))) + .unwrap_or(default) +} + +struct ChatClient{ + conn: WsClientWriter, + payload: Arc, + ts: u64, + bin: bool, + report: bool, + perf_counters: Arc, + sample_rate_secs: usize, +} + +impl Actor for ChatClient { + type Context = Context; + + fn started(&mut self, ctx: &mut Context) { + self.send_text(); + if self.report { + self.sample_rate(ctx); + } + } + + fn stopping(&mut self, _: &mut Context) -> bool { + Arbiter::system().send(actix::msgs::SystemExit(0)); + true + } +} + +impl ChatClient { + fn sample_rate(&self, ctx: &mut Context) { + ctx.run_later(Duration::new(self.sample_rate_secs as u64, 0), |act, ctx| { + let req_count = act.perf_counters.pull_request_count(); + if req_count != 0 { + let latency = act.perf_counters.pull_latency_ns(); + let latency_max = act.perf_counters.pull_latency_max_ns(); + println!( + "rate: {}, throughput: {:?} kb, latency: {}, latency max: {}", + req_count / act.sample_rate_secs, + (((req_count * act.payload.len()) as f64) / 1024.0) / + act.sample_rate_secs as f64, + time::Duration::nanoseconds((latency / req_count as u64) as i64), + time::Duration::nanoseconds(latency_max as i64) + ); + } + + act.sample_rate(ctx); + }); + } + + fn send_text(&mut self) { + self.ts = time::precise_time_ns(); + if self.bin { + self.conn.binary(&self.payload); + } else { + self.conn.text(&self.payload); + } + } +} + +/// Handle server websocket messages +impl StreamHandler for ChatClient { + + fn finished(&mut self, ctx: &mut Context) { + ctx.stop() + } + + fn handle(&mut self, msg: Message, ctx: &mut Context) { + match msg { + Message::Text(txt) => { + if txt == self.payload.as_ref().as_str() { + self.perf_counters.register_request(); + self.perf_counters.register_latency(time::precise_time_ns() - self.ts); + self.send_text(); + } else { + println!("not eaqual"); + } + }, + _ => () + } + } +} + + +pub struct PerfCounters { + req: AtomicUsize, + lat: AtomicUsize, + lat_max: AtomicUsize +} + +impl PerfCounters { + pub fn new() -> PerfCounters { + PerfCounters { + req: AtomicUsize::new(0), + lat: AtomicUsize::new(0), + lat_max: AtomicUsize::new(0), + } + } + + pub fn pull_request_count(&self) -> usize { + self.req.swap(0, Ordering::SeqCst) + } + + pub fn pull_latency_ns(&self) -> u64 { + self.lat.swap(0, Ordering::SeqCst) as u64 + } + + pub fn pull_latency_max_ns(&self) -> u64 { + self.lat_max.swap(0, Ordering::SeqCst) as u64 + } + + pub fn register_request(&self) { + self.req.fetch_add(1, Ordering::SeqCst); + } + + pub fn register_latency(&self, nanos: u64) { + let nanos = nanos as usize; + self.lat.fetch_add(nanos, Ordering::SeqCst); + loop { + let current = self.lat_max.load(Ordering::SeqCst); + if current >= nanos || self.lat_max.compare_and_swap(current, nanos, Ordering::SeqCst) == current { + break; + } + } + } +}