mirror of
https://github.com/actix/actix-web.git
synced 2025-01-02 13:28:44 +00:00
do not handle upgrade and connect requests
This commit is contained in:
parent
b960b5827c
commit
d39c018c93
7 changed files with 141 additions and 57 deletions
|
@ -4,7 +4,7 @@ use std::io::{self, Write};
|
|||
use bytes::{BufMut, Bytes, BytesMut};
|
||||
use tokio_codec::{Decoder, Encoder};
|
||||
|
||||
use super::decoder::{PayloadDecoder, PayloadItem, RequestDecoder};
|
||||
use super::decoder::{PayloadDecoder, PayloadItem, RequestDecoder, RequestPayloadType};
|
||||
use super::encoder::{ResponseEncoder, ResponseLength};
|
||||
use body::{Binary, Body};
|
||||
use config::ServiceConfig;
|
||||
|
@ -17,10 +17,11 @@ use response::Response;
|
|||
|
||||
bitflags! {
|
||||
struct Flags: u8 {
|
||||
const HEAD = 0b0000_0001;
|
||||
const UPGRADE = 0b0000_0010;
|
||||
const KEEPALIVE = 0b0000_0100;
|
||||
const KEEPALIVE_ENABLED = 0b0001_0000;
|
||||
const HEAD = 0b0000_0001;
|
||||
const UPGRADE = 0b0000_0010;
|
||||
const KEEPALIVE = 0b0000_0100;
|
||||
const KEEPALIVE_ENABLED = 0b0000_1000;
|
||||
const UNHANDLED = 0b0001_0000;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -39,11 +40,19 @@ pub enum OutMessage {
|
|||
#[derive(Debug)]
|
||||
pub enum InMessage {
|
||||
/// Request
|
||||
Message { req: Request, payload: bool },
|
||||
Message(Request, InMessageType),
|
||||
/// Payload chunk
|
||||
Chunk(Option<Bytes>),
|
||||
}
|
||||
|
||||
/// Incoming request type
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum InMessageType {
|
||||
None,
|
||||
Payload,
|
||||
Unhandled,
|
||||
}
|
||||
|
||||
/// HTTP/1 Codec
|
||||
pub struct Codec {
|
||||
config: ServiceConfig,
|
||||
|
@ -246,6 +255,8 @@ impl Decoder for Codec {
|
|||
Some(PayloadItem::Eof) => Some(InMessage::Chunk(None)),
|
||||
None => None,
|
||||
})
|
||||
} else if self.flags.contains(Flags::UNHANDLED) {
|
||||
Ok(None)
|
||||
} else if let Some((req, payload)) = self.decoder.decode(src)? {
|
||||
self.flags
|
||||
.set(Flags::HEAD, req.inner.method == Method::HEAD);
|
||||
|
@ -253,11 +264,21 @@ impl Decoder for Codec {
|
|||
if self.flags.contains(Flags::KEEPALIVE_ENABLED) {
|
||||
self.flags.set(Flags::KEEPALIVE, req.keep_alive());
|
||||
}
|
||||
self.payload = payload;
|
||||
Ok(Some(InMessage::Message {
|
||||
req,
|
||||
payload: self.payload.is_some(),
|
||||
}))
|
||||
let payload = match payload {
|
||||
RequestPayloadType::None => {
|
||||
self.payload = None;
|
||||
InMessageType::None
|
||||
}
|
||||
RequestPayloadType::Payload(pl) => {
|
||||
self.payload = Some(pl);
|
||||
InMessageType::Payload
|
||||
}
|
||||
RequestPayloadType::Unhandled => {
|
||||
self.payload = None;
|
||||
InMessageType::Unhandled
|
||||
}
|
||||
};
|
||||
Ok(Some(InMessage::Message(req, payload)))
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
|
|
|
@ -16,6 +16,13 @@ const MAX_HEADERS: usize = 96;
|
|||
|
||||
pub struct RequestDecoder(&'static RequestPool);
|
||||
|
||||
/// Incoming request type
|
||||
pub enum RequestPayloadType {
|
||||
None,
|
||||
Payload(PayloadDecoder),
|
||||
Unhandled,
|
||||
}
|
||||
|
||||
impl RequestDecoder {
|
||||
pub(crate) fn with_pool(pool: &'static RequestPool) -> RequestDecoder {
|
||||
RequestDecoder(pool)
|
||||
|
@ -29,7 +36,7 @@ impl Default for RequestDecoder {
|
|||
}
|
||||
|
||||
impl Decoder for RequestDecoder {
|
||||
type Item = (Request, Option<PayloadDecoder>);
|
||||
type Item = (Request, RequestPayloadType);
|
||||
type Error = ParseError;
|
||||
|
||||
fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
|
||||
|
@ -149,18 +156,18 @@ impl Decoder for RequestDecoder {
|
|||
// https://tools.ietf.org/html/rfc7230#section-3.3.3
|
||||
let decoder = if chunked {
|
||||
// Chunked encoding
|
||||
Some(PayloadDecoder::chunked())
|
||||
RequestPayloadType::Payload(PayloadDecoder::chunked())
|
||||
} else if let Some(len) = content_length {
|
||||
// Content-Length
|
||||
Some(PayloadDecoder::length(len))
|
||||
RequestPayloadType::Payload(PayloadDecoder::length(len))
|
||||
} else if has_upgrade || msg.inner.method == Method::CONNECT {
|
||||
// upgrade(websocket) or connect
|
||||
Some(PayloadDecoder::eof())
|
||||
RequestPayloadType::Unhandled
|
||||
} else if src.len() >= MAX_BUFFER_SIZE {
|
||||
error!("MAX_BUFFER_SIZE unprocessed data reached, closing");
|
||||
return Err(ParseError::TooLarge);
|
||||
} else {
|
||||
None
|
||||
RequestPayloadType::None
|
||||
};
|
||||
|
||||
Ok(Some((msg, decoder)))
|
||||
|
@ -481,20 +488,36 @@ mod tests {
|
|||
|
||||
use super::*;
|
||||
use error::ParseError;
|
||||
use h1::InMessage;
|
||||
use h1::{InMessage, InMessageType};
|
||||
use httpmessage::HttpMessage;
|
||||
use request::Request;
|
||||
|
||||
impl RequestPayloadType {
|
||||
fn unwrap(self) -> PayloadDecoder {
|
||||
match self {
|
||||
RequestPayloadType::Payload(pl) => pl,
|
||||
_ => panic!(),
|
||||
}
|
||||
}
|
||||
|
||||
fn is_unhandled(&self) -> bool {
|
||||
match self {
|
||||
RequestPayloadType::Unhandled => true,
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl InMessage {
|
||||
fn message(self) -> Request {
|
||||
match self {
|
||||
InMessage::Message { req, payload: _ } => req,
|
||||
InMessage::Message(req, _) => req,
|
||||
_ => panic!("error"),
|
||||
}
|
||||
}
|
||||
fn is_payload(&self) -> bool {
|
||||
match *self {
|
||||
InMessage::Message { req: _, payload } => payload,
|
||||
InMessage::Message(_, payload) => payload == InMessageType::Payload,
|
||||
_ => panic!("error"),
|
||||
}
|
||||
}
|
||||
|
@ -919,13 +942,9 @@ mod tests {
|
|||
);
|
||||
let mut reader = RequestDecoder::default();
|
||||
let (req, pl) = reader.decode(&mut buf).unwrap().unwrap();
|
||||
let mut pl = pl.unwrap();
|
||||
assert!(!req.keep_alive());
|
||||
assert!(req.upgrade());
|
||||
assert_eq!(
|
||||
pl.decode(&mut buf).unwrap().unwrap().chunk().as_ref(),
|
||||
b"some raw data"
|
||||
);
|
||||
assert!(pl.is_unhandled());
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
|
@ -18,7 +18,8 @@ use error::DispatchError;
|
|||
use request::Request;
|
||||
use response::Response;
|
||||
|
||||
use super::codec::{Codec, InMessage, OutMessage};
|
||||
use super::codec::{Codec, InMessage, InMessageType, OutMessage};
|
||||
use super::H1ServiceResult;
|
||||
|
||||
const MAX_PIPELINED_MESSAGES: usize = 16;
|
||||
|
||||
|
@ -41,13 +42,14 @@ where
|
|||
{
|
||||
service: S,
|
||||
flags: Flags,
|
||||
framed: Framed<T, Codec>,
|
||||
framed: Option<Framed<T, Codec>>,
|
||||
error: Option<DispatchError<S::Error>>,
|
||||
config: ServiceConfig,
|
||||
|
||||
state: State<S>,
|
||||
payload: Option<PayloadSender>,
|
||||
messages: VecDeque<Message>,
|
||||
unhandled: Option<Request>,
|
||||
|
||||
ka_expire: Instant,
|
||||
ka_timer: Option<Delay>,
|
||||
|
@ -112,9 +114,10 @@ where
|
|||
state: State::None,
|
||||
error: None,
|
||||
messages: VecDeque::new(),
|
||||
framed: Some(framed),
|
||||
unhandled: None,
|
||||
service,
|
||||
flags,
|
||||
framed,
|
||||
config,
|
||||
ka_expire,
|
||||
ka_timer,
|
||||
|
@ -144,7 +147,7 @@ where
|
|||
/// Flush stream
|
||||
fn poll_flush(&mut self) -> Poll<(), DispatchError<S::Error>> {
|
||||
if !self.flags.contains(Flags::FLUSHED) {
|
||||
match self.framed.poll_complete() {
|
||||
match self.framed.as_mut().unwrap().poll_complete() {
|
||||
Ok(Async::NotReady) => Ok(Async::NotReady),
|
||||
Err(err) => {
|
||||
debug!("Error sending data: {}", err);
|
||||
|
@ -187,7 +190,11 @@ where
|
|||
State::ServiceCall(ref mut fut) => {
|
||||
match fut.poll().map_err(DispatchError::Service)? {
|
||||
Async::Ready(mut res) => {
|
||||
self.framed.get_codec_mut().prepare_te(&mut res);
|
||||
self.framed
|
||||
.as_mut()
|
||||
.unwrap()
|
||||
.get_codec_mut()
|
||||
.prepare_te(&mut res);
|
||||
let body = res.replace_body(Body::Empty);
|
||||
Some(State::SendResponse(Some((
|
||||
OutMessage::Response(res),
|
||||
|
@ -200,11 +207,11 @@ where
|
|||
// send respons
|
||||
State::SendResponse(ref mut item) => {
|
||||
let (msg, body) = item.take().expect("SendResponse is empty");
|
||||
match self.framed.start_send(msg) {
|
||||
match self.framed.as_mut().unwrap().start_send(msg) {
|
||||
Ok(AsyncSink::Ready) => {
|
||||
self.flags.set(
|
||||
Flags::KEEPALIVE,
|
||||
self.framed.get_codec().keepalive(),
|
||||
self.framed.as_mut().unwrap().get_codec().keepalive(),
|
||||
);
|
||||
self.flags.remove(Flags::FLUSHED);
|
||||
match body {
|
||||
|
@ -233,7 +240,7 @@ where
|
|||
// Send payload
|
||||
State::SendPayload(ref mut stream, ref mut bin) => {
|
||||
if let Some(item) = bin.take() {
|
||||
match self.framed.start_send(item) {
|
||||
match self.framed.as_mut().unwrap().start_send(item) {
|
||||
Ok(AsyncSink::Ready) => {
|
||||
self.flags.remove(Flags::FLUSHED);
|
||||
}
|
||||
|
@ -248,6 +255,8 @@ where
|
|||
match stream.poll() {
|
||||
Ok(Async::Ready(Some(item))) => match self
|
||||
.framed
|
||||
.as_mut()
|
||||
.unwrap()
|
||||
.start_send(OutMessage::Chunk(Some(item.into())))
|
||||
{
|
||||
Ok(AsyncSink::Ready) => {
|
||||
|
@ -297,7 +306,11 @@ where
|
|||
let mut task = self.service.call(req);
|
||||
match task.poll().map_err(DispatchError::Service)? {
|
||||
Async::Ready(mut res) => {
|
||||
self.framed.get_codec_mut().prepare_te(&mut res);
|
||||
self.framed
|
||||
.as_mut()
|
||||
.unwrap()
|
||||
.get_codec_mut()
|
||||
.prepare_te(&mut res);
|
||||
let body = res.replace_body(Body::Empty);
|
||||
Ok(State::SendResponse(Some((OutMessage::Response(res), body))))
|
||||
}
|
||||
|
@ -314,17 +327,24 @@ where
|
|||
|
||||
let mut updated = false;
|
||||
'outer: loop {
|
||||
match self.framed.poll() {
|
||||
match self.framed.as_mut().unwrap().poll() {
|
||||
Ok(Async::Ready(Some(msg))) => {
|
||||
updated = true;
|
||||
self.flags.insert(Flags::STARTED);
|
||||
|
||||
match msg {
|
||||
InMessage::Message { req, payload } => {
|
||||
if payload {
|
||||
let (ps, pl) = Payload::new(false);
|
||||
*req.inner.payload.borrow_mut() = Some(pl);
|
||||
self.payload = Some(ps);
|
||||
InMessage::Message(req, payload) => {
|
||||
match payload {
|
||||
InMessageType::Payload => {
|
||||
let (ps, pl) = Payload::new(false);
|
||||
*req.inner.payload.borrow_mut() = Some(pl);
|
||||
self.payload = Some(ps);
|
||||
}
|
||||
InMessageType::Unhandled => {
|
||||
self.unhandled = Some(req);
|
||||
return Ok(updated);
|
||||
}
|
||||
_ => (),
|
||||
}
|
||||
|
||||
// handle request early
|
||||
|
@ -454,15 +474,16 @@ where
|
|||
S: Service<Request = Request, Response = Response>,
|
||||
S::Error: Debug,
|
||||
{
|
||||
type Item = ();
|
||||
type Item = H1ServiceResult<T>;
|
||||
type Error = DispatchError<S::Error>;
|
||||
|
||||
#[inline]
|
||||
fn poll(&mut self) -> Poll<(), Self::Error> {
|
||||
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
|
||||
if self.flags.contains(Flags::SHUTDOWN) {
|
||||
self.poll_keepalive()?;
|
||||
try_ready!(self.poll_flush());
|
||||
Ok(AsyncWrite::shutdown(self.framed.get_mut())?)
|
||||
let io = self.framed.take().unwrap().into_inner();
|
||||
Ok(Async::Ready(H1ServiceResult::Shutdown(io)))
|
||||
} else {
|
||||
self.poll_keepalive()?;
|
||||
self.poll_request()?;
|
||||
|
@ -474,15 +495,21 @@ where
|
|||
if let Some(err) = self.error.take() {
|
||||
Err(err)
|
||||
} else if self.flags.contains(Flags::DISCONNECTED) {
|
||||
Ok(Async::Ready(()))
|
||||
Ok(Async::Ready(H1ServiceResult::Disconnected))
|
||||
}
|
||||
// unhandled request (upgrade or connect)
|
||||
else if self.unhandled.is_some() {
|
||||
let req = self.unhandled.take().unwrap();
|
||||
let framed = self.framed.take().unwrap();
|
||||
Ok(Async::Ready(H1ServiceResult::Unhandled(req, framed)))
|
||||
}
|
||||
// disconnect if keep-alive is not enabled
|
||||
else if self.flags.contains(Flags::STARTED) && !self
|
||||
.flags
|
||||
.intersects(Flags::KEEPALIVE | Flags::KEEPALIVE_ENABLED)
|
||||
{
|
||||
self.flags.insert(Flags::SHUTDOWN);
|
||||
self.poll()
|
||||
let io = self.framed.take().unwrap().into_inner();
|
||||
Ok(Async::Ready(H1ServiceResult::Shutdown(io)))
|
||||
} else {
|
||||
Ok(Async::NotReady)
|
||||
}
|
||||
|
|
|
@ -1,11 +1,22 @@
|
|||
//! HTTP/1 implementation
|
||||
use actix_net::codec::Framed;
|
||||
|
||||
mod codec;
|
||||
mod decoder;
|
||||
mod dispatcher;
|
||||
mod encoder;
|
||||
mod service;
|
||||
|
||||
pub use self::codec::{Codec, InMessage, OutMessage};
|
||||
pub use self::codec::{Codec, InMessage, InMessageType, OutMessage};
|
||||
pub use self::decoder::{PayloadDecoder, RequestDecoder};
|
||||
pub use self::dispatcher::Dispatcher;
|
||||
pub use self::service::{H1Service, H1ServiceHandler};
|
||||
|
||||
use request::Request;
|
||||
|
||||
/// H1 service response type
|
||||
pub enum H1ServiceResult<T> {
|
||||
Disconnected,
|
||||
Shutdown(T),
|
||||
Unhandled(Request, Framed<T, Codec>),
|
||||
}
|
||||
|
|
|
@ -12,6 +12,7 @@ use request::Request;
|
|||
use response::Response;
|
||||
|
||||
use super::dispatcher::Dispatcher;
|
||||
use super::H1ServiceResult;
|
||||
|
||||
/// `NewService` implementation for HTTP1 transport
|
||||
pub struct H1Service<T, S> {
|
||||
|
@ -51,7 +52,7 @@ where
|
|||
S::Error: Debug,
|
||||
{
|
||||
type Request = T;
|
||||
type Response = ();
|
||||
type Response = H1ServiceResult<T>;
|
||||
type Error = DispatchError<S::Error>;
|
||||
type InitError = S::InitError;
|
||||
type Service = H1ServiceHandler<T, S::Service>;
|
||||
|
@ -243,7 +244,7 @@ where
|
|||
S::Error: Debug,
|
||||
{
|
||||
type Request = T;
|
||||
type Response = ();
|
||||
type Response = H1ServiceResult<T>;
|
||||
type Error = DispatchError<S::Error>;
|
||||
type Future = Dispatcher<T, S>;
|
||||
|
||||
|
|
|
@ -9,6 +9,7 @@ use std::{io::Read, io::Write, net, thread, time};
|
|||
|
||||
use actix::System;
|
||||
use actix_net::server::Server;
|
||||
use actix_net::service::NewServiceExt;
|
||||
use actix_web::{client, test, HttpMessage};
|
||||
use bytes::Bytes;
|
||||
use futures::future::{self, ok};
|
||||
|
@ -29,6 +30,7 @@ fn test_h1_v2() {
|
|||
.server_hostname("localhost")
|
||||
.server_address(addr)
|
||||
.finish(|_| future::ok::<_, ()>(Response::Ok().finish()))
|
||||
.map(|_| ())
|
||||
}).unwrap()
|
||||
.run();
|
||||
});
|
||||
|
@ -53,6 +55,7 @@ fn test_slow_request() {
|
|||
h1::H1Service::build()
|
||||
.client_timeout(100)
|
||||
.finish(|_| future::ok::<_, ()>(Response::Ok().finish()))
|
||||
.map(|_| ())
|
||||
}).unwrap()
|
||||
.run();
|
||||
});
|
||||
|
@ -72,6 +75,7 @@ fn test_malformed_request() {
|
|||
Server::new()
|
||||
.bind("test", addr, move || {
|
||||
h1::H1Service::new(|_| future::ok::<_, ()>(Response::Ok().finish()))
|
||||
.map(|_| ())
|
||||
}).unwrap()
|
||||
.run();
|
||||
});
|
||||
|
@ -106,7 +110,7 @@ fn test_content_length() {
|
|||
StatusCode::NOT_FOUND,
|
||||
];
|
||||
future::ok::<_, ()>(Response::new(statuses[indx]))
|
||||
})
|
||||
}).map(|_| ())
|
||||
}).unwrap()
|
||||
.run();
|
||||
});
|
||||
|
@ -172,7 +176,7 @@ fn test_headers() {
|
|||
);
|
||||
}
|
||||
future::ok::<_, ()>(builder.body(data.clone()))
|
||||
})
|
||||
}).map(|_| ())
|
||||
})
|
||||
.unwrap()
|
||||
.run()
|
||||
|
@ -221,6 +225,7 @@ fn test_body() {
|
|||
Server::new()
|
||||
.bind("test", addr, move || {
|
||||
h1::H1Service::new(|_| future::ok::<_, ()>(Response::Ok().body(STR)))
|
||||
.map(|_| ())
|
||||
}).unwrap()
|
||||
.run();
|
||||
});
|
||||
|
@ -246,7 +251,7 @@ fn test_head_empty() {
|
|||
.bind("test", addr, move || {
|
||||
h1::H1Service::new(|_| {
|
||||
ok::<_, ()>(Response::Ok().content_length(STR.len() as u64).finish())
|
||||
})
|
||||
}).map(|_| ())
|
||||
}).unwrap()
|
||||
.run()
|
||||
});
|
||||
|
@ -282,7 +287,7 @@ fn test_head_binary() {
|
|||
ok::<_, ()>(
|
||||
Response::Ok().content_length(STR.len() as u64).body(STR),
|
||||
)
|
||||
})
|
||||
}).map(|_| ())
|
||||
}).unwrap()
|
||||
.run()
|
||||
});
|
||||
|
@ -314,7 +319,7 @@ fn test_head_binary2() {
|
|||
thread::spawn(move || {
|
||||
Server::new()
|
||||
.bind("test", addr, move || {
|
||||
h1::H1Service::new(|_| ok::<_, ()>(Response::Ok().body(STR)))
|
||||
h1::H1Service::new(|_| ok::<_, ()>(Response::Ok().body(STR))).map(|_| ())
|
||||
}).unwrap()
|
||||
.run()
|
||||
});
|
||||
|
@ -349,7 +354,7 @@ fn test_body_length() {
|
|||
.content_length(STR.len() as u64)
|
||||
.body(Body::Streaming(Box::new(body))),
|
||||
)
|
||||
})
|
||||
}).map(|_| ())
|
||||
}).unwrap()
|
||||
.run()
|
||||
});
|
||||
|
@ -380,7 +385,7 @@ fn test_body_chunked_explicit() {
|
|||
.chunked()
|
||||
.body(Body::Streaming(Box::new(body))),
|
||||
)
|
||||
})
|
||||
}).map(|_| ())
|
||||
}).unwrap()
|
||||
.run()
|
||||
});
|
||||
|
@ -409,7 +414,7 @@ fn test_body_chunked_implicit() {
|
|||
h1::H1Service::new(|_| {
|
||||
let body = once(Ok(Bytes::from_static(STR.as_ref())));
|
||||
ok::<_, ()>(Response::Ok().body(Body::Streaming(Box::new(body))))
|
||||
})
|
||||
}).map(|_| ())
|
||||
}).unwrap()
|
||||
.run()
|
||||
});
|
||||
|
|
|
@ -51,7 +51,7 @@ fn test_simple() {
|
|||
.and_then(TakeItem::new().map_err(|_| ()))
|
||||
.and_then(|(req, framed): (_, Framed<_, _>)| {
|
||||
// validate request
|
||||
if let Some(h1::InMessage::Message { req, payload: _ }) = req {
|
||||
if let Some(h1::InMessage::Message(req, _)) = req {
|
||||
match ws::handshake(&req) {
|
||||
Err(e) => {
|
||||
// validation failed
|
||||
|
|
Loading…
Reference in a new issue