1
0
Fork 0
mirror of https://github.com/actix/actix-web.git synced 2025-01-04 14:28:50 +00:00

special handling for upgraded pipeline

This commit is contained in:
Nikolay Kim 2018-02-10 00:05:20 -08:00
parent 2d049e4a9f
commit 3109f9be62
5 changed files with 91 additions and 144 deletions

View file

@ -128,7 +128,7 @@ impl Handler<ws::Message> for Ws {
fn handle(&mut self, msg: ws::Message, ctx: &mut Self::Context) {
match msg {
ws::Message::Text(text) => ctx.text(&text),
ws::Message::Text(text) => ctx.text(text),
_ => (),
}
}

View file

@ -96,15 +96,17 @@ impl<T: AsyncWrite> Writer for H1Writer<T> {
fn start(&mut self, req: &mut HttpMessage, msg: &mut HttpResponse) -> io::Result<WriterState> {
// prepare task
self.flags.insert(Flags::STARTED);
self.encoder = PayloadEncoder::new(self.buffer.clone(), req, msg);
if msg.keep_alive().unwrap_or_else(|| req.keep_alive()) {
self.flags.insert(Flags::KEEPALIVE);
self.flags.insert(Flags::STARTED | Flags::KEEPALIVE);
} else {
self.flags.insert(Flags::STARTED);
}
// Connection upgrade
let version = msg.version().unwrap_or_else(|| req.version);
if msg.upgrade() {
self.flags.insert(Flags::UPGRADE);
msg.headers_mut().insert(CONNECTION, HeaderValue::from_static("upgrade"));
}
// keep-alive
@ -177,8 +179,29 @@ impl<T: AsyncWrite> Writer for H1Writer<T> {
self.written += payload.len() as u64;
if !self.flags.contains(Flags::DISCONNECTED) {
if self.flags.contains(Flags::STARTED) {
// shortcut for upgraded connection
if self.flags.contains(Flags::UPGRADE) {
if self.buffer.is_empty() {
match self.stream.write(payload.as_ref()) {
Ok(0) => {
self.disconnected();
return Ok(WriterState::Done);
},
Ok(n) => if payload.len() < n {
self.buffer.extend_from_slice(&payload.as_ref()[n..])
},
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
return Ok(WriterState::Done)
}
Err(err) => return Err(err),
}
} else {
self.buffer.extend(payload);
}
} else {
// TODO: add warning, write after EOF
self.encoder.write(payload)?;
}
} else {
// might be response to EXCEPT
self.buffer.extend_from_slice(payload.as_ref())

View file

@ -28,8 +28,8 @@ use client::{ClientRequest, ClientRequestBuilder,
use client::{Connect, Connection, ClientConnector, ClientConnectorError};
use super::Message;
use super::frame::Frame;
use super::proto::{CloseCode, OpCode};
use super::frame::{Frame, FrameData};
pub type WsClientFuture =
Future<Item=(WsClientReader, WsClientWriter), Error=WsClientError>;
@ -444,17 +444,9 @@ impl WsClientWriter {
/// Write payload
#[inline]
fn write(&mut self, data: FrameData) {
fn write(&mut self, data: &Binary) {
if !self.as_mut().closed {
match data {
FrameData::Complete(data) => {
let _ = self.as_mut().writer.write(&data);
},
FrameData::Split(headers, payload) => {
let _ = self.as_mut().writer.write(&headers);
let _ = self.as_mut().writer.write(&payload);
}
}
let _ = self.as_mut().writer.write(data);
} else {
warn!("Trying to write to disconnected response");
}
@ -462,31 +454,31 @@ impl WsClientWriter {
/// Send text frame
#[inline]
pub fn text(&mut self, text: &str) {
self.write(Frame::message(Vec::from(text), OpCode::Text, true).generate(true));
pub fn text<T: Into<String>>(&mut self, text: T) {
self.write(&Frame::message(text.into(), OpCode::Text, true, true));
}
/// Send binary frame
#[inline]
pub fn binary<B: Into<Binary>>(&mut self, data: B) {
self.write(Frame::message(data, OpCode::Binary, true).generate(true));
self.write(&Frame::message(data, OpCode::Binary, true, true));
}
/// Send ping frame
#[inline]
pub fn ping(&mut self, message: &str) {
self.write(Frame::message(Vec::from(message), OpCode::Ping, true).generate(true));
self.write(&Frame::message(Vec::from(message), OpCode::Ping, true, true));
}
/// Send pong frame
#[inline]
pub fn pong(&mut self, message: &str) {
self.write(Frame::message(Vec::from(message), OpCode::Pong, true).generate(true));
self.write(&Frame::message(Vec::from(message), OpCode::Pong, true, true));
}
/// Send close frame
#[inline]
pub fn close(&mut self, code: CloseCode, reason: &str) {
self.write(Frame::close(code, reason).generate(true));
self.write(&Frame::close(code, reason, true));
}
}

View file

@ -14,7 +14,7 @@ use error::{Error, ErrorInternalServerError};
use httprequest::HttpRequest;
use context::{Frame as ContextFrame, ActorHttpContext, Drain};
use ws::frame::{Frame, FrameData};
use ws::frame::Frame;
use ws::proto::{OpCode, CloseCode};
@ -105,21 +105,13 @@ impl<A, S> WebsocketContext<A, S> where A: Actor<Context=Self> {
/// Write payload
#[inline]
fn write(&mut self, data: FrameData) {
fn write(&mut self, data: Binary) {
if !self.disconnected {
if self.stream.is_none() {
self.stream = Some(SmallVec::new());
}
let stream = self.stream.as_mut().unwrap();
match data {
FrameData::Complete(data) =>
stream.push(ContextFrame::Chunk(Some(data))),
FrameData::Split(headers, payload) => {
stream.push(ContextFrame::Chunk(Some(headers)));
stream.push(ContextFrame::Chunk(Some(payload)));
}
}
stream.push(ContextFrame::Chunk(Some(data)));
} else {
warn!("Trying to write to disconnected response");
}
@ -140,31 +132,31 @@ impl<A, S> WebsocketContext<A, S> where A: Actor<Context=Self> {
/// Send text frame
#[inline]
pub fn text<T: Into<String>>(&mut self, text: T) {
self.write(Frame::message(text.into(), OpCode::Text, true).generate(false));
self.write(Frame::message(text.into(), OpCode::Text, true, false));
}
/// Send binary frame
#[inline]
pub fn binary<B: Into<Binary>>(&mut self, data: B) {
self.write(Frame::message(data, OpCode::Binary, true).generate(false));
self.write(Frame::message(data, OpCode::Binary, true, false));
}
/// Send ping frame
#[inline]
pub fn ping(&mut self, message: &str) {
self.write(Frame::message(Vec::from(message), OpCode::Ping, true).generate(false));
self.write(Frame::message(Vec::from(message), OpCode::Ping, true, false));
}
/// Send pong frame
#[inline]
pub fn pong(&mut self, message: &str) {
self.write(Frame::message(Vec::from(message), OpCode::Pong, true).generate(false));
self.write(Frame::message(Vec::from(message), OpCode::Pong, true, false));
}
/// Send close frame
#[inline]
pub fn close(&mut self, code: CloseCode, reason: &str) {
self.write(Frame::close(code, reason).generate(false));
self.write(Frame::close(code, reason, false));
}
/// Returns drain future

View file

@ -9,14 +9,6 @@ use body::Binary;
use ws::proto::{OpCode, CloseCode};
use ws::mask::apply_mask;
#[derive(Debug, PartialEq)]
pub(crate) enum FrameData {
Complete(Binary),
Split(Binary, Binary),
}
const MAX_LEN: usize = 122;
/// A struct representing a `WebSocket` frame.
#[derive(Debug)]
pub(crate) struct Frame {
@ -35,20 +27,9 @@ impl Frame {
(self.finished, self.opcode, self.payload)
}
/// Create a new data frame.
#[inline]
pub fn message<B: Into<Binary>>(data: B, code: OpCode, finished: bool) -> Frame {
Frame {
finished: finished,
opcode: code,
payload: data.into(),
.. Frame::default()
}
}
/// Create a new Close control frame.
#[inline]
pub fn close(code: CloseCode, reason: &str) -> Frame {
pub fn close(code: CloseCode, reason: &str, genmask: bool) -> Binary {
let raw: [u8; 2] = unsafe {
let u: u16 = code.into();
mem::transmute(u.to_be())
@ -63,10 +44,7 @@ impl Frame {
.cloned())
};
Frame {
payload: payload.into(),
.. Frame::default()
}
Frame::message(payload, OpCode::Close, true, genmask)
}
/// Parse the input stream into a frame.
@ -162,7 +140,7 @@ impl Frame {
}
OpCode::Close if length > 125 => {
debug!("Received close frame with payload length exceeding 125. Morphing to protocol close frame.");
return Ok(Some(Frame::close(CloseCode::Protocol, "Received close frame with payload length exceeding 125.")))
return Ok(Some(Frame::default()))
}
_ => ()
}
@ -183,96 +161,61 @@ impl Frame {
}
/// Generate binary representation
pub fn generate(self, genmask: bool) -> FrameData {
let mut one = 0u8;
let code: u8 = self.opcode.into();
if self.finished {
one |= 0x80;
}
if self.rsv1 {
one |= 0x40;
}
if self.rsv2 {
one |= 0x20;
}
if self.rsv3 {
one |= 0x10;
}
one |= code;
let (two, mask_size) = if genmask {
(0x80, 4)
pub fn message<B: Into<Binary>>(data: B, code: OpCode,
finished: bool, genmask: bool) -> Binary
{
let payload = data.into();
let one: u8 = if finished {
0x80 | Into::<u8>::into(code)
} else {
(0, 0)
code.into()
};
let payload_len = payload.len();
let (two, p_len) = if genmask {
(0x80, payload_len + 4)
} else {
(0, payload_len)
};
let payload_len = self.payload.len();
let mut buf = if payload_len < MAX_LEN {
if genmask {
let len = payload_len + 6;
let mask: [u8; 4] = rand::random();
let mut buf = BytesMut::with_capacity(len);
let mut buf = if payload_len < 126 {
let mut buf = BytesMut::with_capacity(p_len + 2);
buf.put_slice(&[one, two | payload_len as u8]);
buf
} else if payload_len <= 65_535 {
let mut buf = BytesMut::with_capacity(p_len + 4);
buf.put_slice(&[one, two | 126]);
{
let buf_mut = unsafe{buf.bytes_mut()};
buf_mut[0] = one;
buf_mut[1] = two | payload_len as u8;
buf_mut[2..6].copy_from_slice(&mask);
buf_mut[6..payload_len+6].copy_from_slice(self.payload.as_ref());
apply_mask(&mut buf_mut[6..], &mask);
}
unsafe{buf.advance_mut(len)};
return FrameData::Complete(buf.into())
} else {
let len = payload_len + 2;
let mut buf = BytesMut::with_capacity(len);
{
let buf_mut = unsafe{buf.bytes_mut()};
buf_mut[0] = one;
buf_mut[1] = two | payload_len as u8;
buf_mut[2..payload_len+2].copy_from_slice(self.payload.as_ref());
}
unsafe{buf.advance_mut(len)};
return FrameData::Complete(buf.into())
}
} else if payload_len < 126 {
let mut buf = BytesMut::with_capacity(mask_size + 2);
{
let buf_mut = unsafe{buf.bytes_mut()};
buf_mut[0] = one;
buf_mut[1] = two | payload_len as u8;
BigEndian::write_u16(&mut buf_mut[..2], payload_len as u16);
}
unsafe{buf.advance_mut(2)};
buf
} else if payload_len <= 65_535 {
let mut buf = BytesMut::with_capacity(mask_size + 4);
{
let buf_mut = unsafe{buf.bytes_mut()};
buf_mut[0] = one;
buf_mut[1] = two | 126;
BigEndian::write_u16(&mut buf_mut[2..4], payload_len as u16);
}
unsafe{buf.advance_mut(4)};
buf
} else {
let mut buf = BytesMut::with_capacity(mask_size + 10);
let mut buf = BytesMut::with_capacity(p_len + 8);
buf.put_slice(&[one, two | 127]);
{
let buf_mut = unsafe{buf.bytes_mut()};
buf_mut[0] = one;
buf_mut[1] = two | 127;
BigEndian::write_u64(&mut buf_mut[2..10], payload_len as u64);
BigEndian::write_u64(&mut buf_mut[..8], payload_len as u64);
}
unsafe{buf.advance_mut(10)};
unsafe{buf.advance_mut(8)};
buf
};
if genmask {
let mut payload = Vec::from(self.payload.as_ref());
let mask: [u8; 4] = rand::random();
apply_mask(&mut payload, &mask);
buf.extend_from_slice(&mask);
FrameData::Split(buf.into(), payload.into())
unsafe {
{
let buf_mut = buf.bytes_mut();
buf_mut[..4].copy_from_slice(&mask);
buf_mut[4..payload_len+4].copy_from_slice(payload.as_ref());
apply_mask(&mut buf_mut[4..], &mask);
}
buf.advance_mut(payload_len + 4);
}
buf.into()
} else {
FrameData::Split(buf.into(), self.payload)
buf.put_slice(payload.as_ref());
buf.into()
}
}
}
@ -392,31 +335,28 @@ mod tests {
#[test]
fn test_ping_frame() {
let frame = Frame::message(Vec::from("data"), OpCode::Ping, true);
let res = frame.generate(false);
let frame = Frame::message(Vec::from("data"), OpCode::Ping, true, false);
let mut v = vec![137u8, 4u8];
v.extend(b"data");
assert_eq!(res, FrameData::Complete(v.into()));
assert_eq!(frame, v.into());
}
#[test]
fn test_pong_frame() {
let frame = Frame::message(Vec::from("data"), OpCode::Pong, true);
let res = frame.generate(false);
let frame = Frame::message(Vec::from("data"), OpCode::Pong, true, false);
let mut v = vec![138u8, 4u8];
v.extend(b"data");
assert_eq!(res, FrameData::Complete(v.into()));
assert_eq!(frame, v.into());
}
#[test]
fn test_close_frame() {
let frame = Frame::close(CloseCode::Normal, "data");
let res = frame.generate(false);
let frame = Frame::close(CloseCode::Normal, "data", false);
let mut v = vec![136u8, 6u8, 3u8, 232u8];
v.extend(b"data");
assert_eq!(res, FrameData::Complete(v.into()));
assert_eq!(frame, v.into());
}
}