mirror of
https://github.com/actix/actix-web.git
synced 2025-01-04 14:28:50 +00:00
add upgrade service support to h1 dispatcher
This commit is contained in:
parent
43d325a139
commit
561f83d044
6 changed files with 271 additions and 101 deletions
|
@ -1,9 +1,18 @@
|
||||||
# Changes
|
# Changes
|
||||||
|
|
||||||
|
## [0.1.0-alpha.5] - 2019-04-xx
|
||||||
|
|
||||||
|
### Added
|
||||||
|
|
||||||
|
* Allow to use custom service for upgrade requests
|
||||||
|
|
||||||
|
|
||||||
## [0.1.0-alpha.4] - 2019-04-08
|
## [0.1.0-alpha.4] - 2019-04-08
|
||||||
|
|
||||||
### Added
|
### Added
|
||||||
|
|
||||||
|
* Allow to use custom `Expect` handler
|
||||||
|
|
||||||
* Add minimal `std::error::Error` impl for `Error`
|
* Add minimal `std::error::Error` impl for `Error`
|
||||||
|
|
||||||
### Changed
|
### Changed
|
||||||
|
|
|
@ -115,6 +115,27 @@ where
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Provide service for custom `Connection: UPGRADE` support.
|
||||||
|
///
|
||||||
|
/// If service is provided then normal requests handling get halted
|
||||||
|
/// and this service get called with original request and framed object.
|
||||||
|
pub fn upgrade<F, U1>(self, upgrade: F) -> HttpServiceBuilder<T, S, X, U1>
|
||||||
|
where
|
||||||
|
F: IntoNewService<U1>,
|
||||||
|
U1: NewService<Request = (Request, Framed<T, Codec>), Response = ()>,
|
||||||
|
U1::Error: fmt::Display,
|
||||||
|
U1::InitError: fmt::Debug,
|
||||||
|
{
|
||||||
|
HttpServiceBuilder {
|
||||||
|
keep_alive: self.keep_alive,
|
||||||
|
client_timeout: self.client_timeout,
|
||||||
|
client_disconnect: self.client_disconnect,
|
||||||
|
expect: self.expect,
|
||||||
|
upgrade: Some(upgrade.into_new_service()),
|
||||||
|
_t: PhantomData,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Finish service configuration and create *http service* for HTTP/1 protocol.
|
/// Finish service configuration and create *http service* for HTTP/1 protocol.
|
||||||
pub fn h1<F, P, B>(self, service: F) -> H1Service<T, P, S, B, X, U>
|
pub fn h1<F, P, B>(self, service: F) -> H1Service<T, P, S, B, X, U>
|
||||||
where
|
where
|
||||||
|
|
|
@ -350,6 +350,9 @@ pub enum DispatchError {
|
||||||
/// Service error
|
/// Service error
|
||||||
Service(Error),
|
Service(Error),
|
||||||
|
|
||||||
|
/// Upgrade service error
|
||||||
|
Upgrade,
|
||||||
|
|
||||||
/// An `io::Error` that occurred while trying to read or write to a network
|
/// An `io::Error` that occurred while trying to read or write to a network
|
||||||
/// stream.
|
/// stream.
|
||||||
#[display(fmt = "IO error: {}", _0)]
|
#[display(fmt = "IO error: {}", _0)]
|
||||||
|
|
|
@ -2,7 +2,7 @@ use std::collections::VecDeque;
|
||||||
use std::time::Instant;
|
use std::time::Instant;
|
||||||
use std::{fmt, io};
|
use std::{fmt, io};
|
||||||
|
|
||||||
use actix_codec::{AsyncRead, AsyncWrite, Decoder, Encoder, Framed};
|
use actix_codec::{AsyncRead, AsyncWrite, Decoder, Encoder, Framed, FramedParts};
|
||||||
use actix_service::Service;
|
use actix_service::Service;
|
||||||
use actix_utils::cloneable::CloneableService;
|
use actix_utils::cloneable::CloneableService;
|
||||||
use bitflags::bitflags;
|
use bitflags::bitflags;
|
||||||
|
@ -34,7 +34,7 @@ bitflags! {
|
||||||
const SHUTDOWN = 0b0000_1000;
|
const SHUTDOWN = 0b0000_1000;
|
||||||
const READ_DISCONNECT = 0b0001_0000;
|
const READ_DISCONNECT = 0b0001_0000;
|
||||||
const WRITE_DISCONNECT = 0b0010_0000;
|
const WRITE_DISCONNECT = 0b0010_0000;
|
||||||
const DROPPING = 0b0100_0000;
|
const UPGRADE = 0b0100_0000;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -49,7 +49,22 @@ where
|
||||||
U: Service<Request = (Request, Framed<T, Codec>), Response = ()>,
|
U: Service<Request = (Request, Framed<T, Codec>), Response = ()>,
|
||||||
U::Error: fmt::Display,
|
U::Error: fmt::Display,
|
||||||
{
|
{
|
||||||
inner: Option<InnerDispatcher<T, S, B, X, U>>,
|
inner: DispatcherState<T, S, B, X, U>,
|
||||||
|
}
|
||||||
|
|
||||||
|
enum DispatcherState<T, S, B, X, U>
|
||||||
|
where
|
||||||
|
S: Service<Request = Request>,
|
||||||
|
S::Error: Into<Error>,
|
||||||
|
B: MessageBody,
|
||||||
|
X: Service<Request = Request, Response = Request>,
|
||||||
|
X::Error: Into<Error>,
|
||||||
|
U: Service<Request = (Request, Framed<T, Codec>), Response = ()>,
|
||||||
|
U::Error: fmt::Display,
|
||||||
|
{
|
||||||
|
Normal(InnerDispatcher<T, S, B, X, U>),
|
||||||
|
Upgrade(U::Future),
|
||||||
|
None,
|
||||||
}
|
}
|
||||||
|
|
||||||
struct InnerDispatcher<T, S, B, X, U>
|
struct InnerDispatcher<T, S, B, X, U>
|
||||||
|
@ -83,6 +98,7 @@ where
|
||||||
|
|
||||||
enum DispatcherMessage {
|
enum DispatcherMessage {
|
||||||
Item(Request),
|
Item(Request),
|
||||||
|
Upgrade(Request),
|
||||||
Error(Response<()>),
|
Error(Response<()>),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -121,18 +137,24 @@ where
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<S, B, X> fmt::Debug for State<S, B, X>
|
enum PollResponse {
|
||||||
where
|
Upgrade(Request),
|
||||||
S: Service<Request = Request>,
|
DoNothing,
|
||||||
X: Service<Request = Request, Response = Request>,
|
DrainWriteBuf,
|
||||||
B: MessageBody,
|
}
|
||||||
{
|
|
||||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
impl PartialEq for PollResponse {
|
||||||
|
fn eq(&self, other: &PollResponse) -> bool {
|
||||||
match self {
|
match self {
|
||||||
State::None => write!(f, "State::None"),
|
PollResponse::DrainWriteBuf => match other {
|
||||||
State::ExpectCall(_) => write!(f, "State::ExceptCall"),
|
PollResponse::DrainWriteBuf => true,
|
||||||
State::ServiceCall(_) => write!(f, "State::ServiceCall"),
|
_ => false,
|
||||||
State::SendPayload(_) => write!(f, "State::SendPayload"),
|
},
|
||||||
|
PollResponse::DoNothing => match other {
|
||||||
|
PollResponse::DoNothing => true,
|
||||||
|
_ => false,
|
||||||
|
},
|
||||||
|
_ => false,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -197,7 +219,7 @@ where
|
||||||
};
|
};
|
||||||
|
|
||||||
Dispatcher {
|
Dispatcher {
|
||||||
inner: Some(InnerDispatcher {
|
inner: DispatcherState::Normal(InnerDispatcher {
|
||||||
io,
|
io,
|
||||||
codec,
|
codec,
|
||||||
read_buf,
|
read_buf,
|
||||||
|
@ -230,7 +252,10 @@ where
|
||||||
U::Error: fmt::Display,
|
U::Error: fmt::Display,
|
||||||
{
|
{
|
||||||
fn can_read(&self) -> bool {
|
fn can_read(&self) -> bool {
|
||||||
if self.flags.contains(Flags::READ_DISCONNECT) {
|
if self
|
||||||
|
.flags
|
||||||
|
.intersects(Flags::READ_DISCONNECT | Flags::UPGRADE)
|
||||||
|
{
|
||||||
false
|
false
|
||||||
} else if let Some(ref info) = self.payload {
|
} else if let Some(ref info) = self.payload {
|
||||||
info.need_read() == PayloadStatus::Read
|
info.need_read() == PayloadStatus::Read
|
||||||
|
@ -315,7 +340,7 @@ where
|
||||||
.extend_from_slice(b"HTTP/1.1 100 Continue\r\n\r\n");
|
.extend_from_slice(b"HTTP/1.1 100 Continue\r\n\r\n");
|
||||||
}
|
}
|
||||||
|
|
||||||
fn poll_response(&mut self) -> Result<bool, DispatchError> {
|
fn poll_response(&mut self) -> Result<PollResponse, DispatchError> {
|
||||||
loop {
|
loop {
|
||||||
let state = match self.state {
|
let state = match self.state {
|
||||||
State::None => match self.messages.pop_front() {
|
State::None => match self.messages.pop_front() {
|
||||||
|
@ -325,6 +350,9 @@ where
|
||||||
Some(DispatcherMessage::Error(res)) => {
|
Some(DispatcherMessage::Error(res)) => {
|
||||||
Some(self.send_response(res, ResponseBody::Other(Body::Empty))?)
|
Some(self.send_response(res, ResponseBody::Other(Body::Empty))?)
|
||||||
}
|
}
|
||||||
|
Some(DispatcherMessage::Upgrade(req)) => {
|
||||||
|
return Ok(PollResponse::Upgrade(req));
|
||||||
|
}
|
||||||
None => None,
|
None => None,
|
||||||
},
|
},
|
||||||
State::ExpectCall(ref mut fut) => match fut.poll() {
|
State::ExpectCall(ref mut fut) => match fut.poll() {
|
||||||
|
@ -374,10 +402,10 @@ where
|
||||||
)?;
|
)?;
|
||||||
self.state = State::None;
|
self.state = State::None;
|
||||||
}
|
}
|
||||||
Async::NotReady => return Ok(false),
|
Async::NotReady => return Ok(PollResponse::DoNothing),
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
return Ok(true);
|
return Ok(PollResponse::DrainWriteBuf);
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
@ -405,7 +433,7 @@ where
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(false)
|
Ok(PollResponse::DoNothing)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn handle_request(&mut self, req: Request) -> Result<State<S, B, X>, DispatchError> {
|
fn handle_request(&mut self, req: Request) -> Result<State<S, B, X>, DispatchError> {
|
||||||
|
@ -461,16 +489,19 @@ where
|
||||||
|
|
||||||
match msg {
|
match msg {
|
||||||
Message::Item(mut req) => {
|
Message::Item(mut req) => {
|
||||||
match self.codec.message_type() {
|
let pl = self.codec.message_type();
|
||||||
MessageType::Payload | MessageType::Stream => {
|
|
||||||
|
if pl == MessageType::Stream && self.upgrade.is_some() {
|
||||||
|
self.messages.push_back(DispatcherMessage::Upgrade(req));
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
if pl == MessageType::Payload || pl == MessageType::Stream {
|
||||||
let (ps, pl) = Payload::create(false);
|
let (ps, pl) = Payload::create(false);
|
||||||
let (req1, _) =
|
let (req1, _) =
|
||||||
req.replace_payload(crate::Payload::H1(pl));
|
req.replace_payload(crate::Payload::H1(pl));
|
||||||
req = req1;
|
req = req1;
|
||||||
self.payload = Some(ps);
|
self.payload = Some(ps);
|
||||||
}
|
}
|
||||||
_ => (),
|
|
||||||
}
|
|
||||||
|
|
||||||
// handle request early
|
// handle request early
|
||||||
if self.state.is_empty() {
|
if self.state.is_empty() {
|
||||||
|
@ -633,7 +664,8 @@ where
|
||||||
|
|
||||||
#[inline]
|
#[inline]
|
||||||
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
|
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
|
||||||
let inner = self.inner.as_mut().unwrap();
|
match self.inner {
|
||||||
|
DispatcherState::Normal(ref mut inner) => {
|
||||||
inner.poll_keepalive()?;
|
inner.poll_keepalive()?;
|
||||||
|
|
||||||
if inner.flags.contains(Flags::SHUTDOWN) {
|
if inner.flags.contains(Flags::SHUTDOWN) {
|
||||||
|
@ -654,7 +686,9 @@ where
|
||||||
} else {
|
} else {
|
||||||
// read socket into a buf
|
// read socket into a buf
|
||||||
if !inner.flags.contains(Flags::READ_DISCONNECT) {
|
if !inner.flags.contains(Flags::READ_DISCONNECT) {
|
||||||
if let Some(true) = read_available(&mut inner.io, &mut inner.read_buf)? {
|
if let Some(true) =
|
||||||
|
read_available(&mut inner.io, &mut inner.read_buf)?
|
||||||
|
{
|
||||||
inner.flags.insert(Flags::READ_DISCONNECT)
|
inner.flags.insert(Flags::READ_DISCONNECT)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -664,12 +698,34 @@ where
|
||||||
if inner.write_buf.remaining_mut() < LW_BUFFER_SIZE {
|
if inner.write_buf.remaining_mut() < LW_BUFFER_SIZE {
|
||||||
inner.write_buf.reserve(HW_BUFFER_SIZE);
|
inner.write_buf.reserve(HW_BUFFER_SIZE);
|
||||||
}
|
}
|
||||||
let need_write = inner.poll_response()?;
|
let result = inner.poll_response()?;
|
||||||
|
let drain = result == PollResponse::DrainWriteBuf;
|
||||||
|
|
||||||
|
// switch to upgrade handler
|
||||||
|
if let PollResponse::Upgrade(req) = result {
|
||||||
|
if let DispatcherState::Normal(inner) =
|
||||||
|
std::mem::replace(&mut self.inner, DispatcherState::None)
|
||||||
|
{
|
||||||
|
let mut parts = FramedParts::with_read_buf(
|
||||||
|
inner.io,
|
||||||
|
inner.codec,
|
||||||
|
inner.read_buf,
|
||||||
|
);
|
||||||
|
parts.write_buf = inner.write_buf;
|
||||||
|
let framed = Framed::from_parts(parts);
|
||||||
|
self.inner = DispatcherState::Upgrade(
|
||||||
|
inner.upgrade.unwrap().call((req, framed)),
|
||||||
|
);
|
||||||
|
return self.poll();
|
||||||
|
} else {
|
||||||
|
panic!()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// we didnt get WouldBlock from write operation,
|
// we didnt get WouldBlock from write operation,
|
||||||
// so data get written to kernel completely (OSX)
|
// so data get written to kernel completely (OSX)
|
||||||
// and we have to write again otherwise response can get stuck
|
// and we have to write again otherwise response can get stuck
|
||||||
if inner.poll_flush()? || !need_write {
|
if inner.poll_flush()? || !drain {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -709,6 +765,13 @@ where
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
DispatcherState::Upgrade(ref mut fut) => fut.poll().map_err(|e| {
|
||||||
|
error!("Upgrade handler error: {}", e);
|
||||||
|
DispatchError::Upgrade
|
||||||
|
}),
|
||||||
|
DispatcherState::None => panic!(),
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn read_available<T>(io: &mut T, buf: &mut BytesMut) -> Result<Option<bool>, io::Error>
|
fn read_available<T>(io: &mut T, buf: &mut BytesMut) -> Result<Option<bool>, io::Error>
|
||||||
|
|
|
@ -31,9 +31,7 @@ const STR: &str = "Hello World Hello World Hello World Hello World Hello World \
|
||||||
fn test_h1_v2() {
|
fn test_h1_v2() {
|
||||||
env_logger::init();
|
env_logger::init();
|
||||||
let mut srv = TestServer::new(move || {
|
let mut srv = TestServer::new(move || {
|
||||||
HttpService::build()
|
HttpService::build().finish(|_| future::ok::<_, ()>(Response::Ok().body(STR)))
|
||||||
.finish(|_| future::ok::<_, ()>(Response::Ok().body(STR)))
|
|
||||||
.map(|_| ())
|
|
||||||
});
|
});
|
||||||
let response = srv.block_on(srv.get("/").send()).unwrap();
|
let response = srv.block_on(srv.get("/").send()).unwrap();
|
||||||
assert!(response.status().is_success());
|
assert!(response.status().is_success());
|
||||||
|
|
76
actix-http/tests/test_ws.rs
Normal file
76
actix-http/tests/test_ws.rs
Normal file
|
@ -0,0 +1,76 @@
|
||||||
|
use actix_codec::{AsyncRead, AsyncWrite, Framed};
|
||||||
|
use actix_http::{body, h1, ws, Error, HttpService, Request, Response};
|
||||||
|
use actix_http_test::TestServer;
|
||||||
|
use actix_utils::framed::FramedTransport;
|
||||||
|
use bytes::{Bytes, BytesMut};
|
||||||
|
use futures::future::{self, ok};
|
||||||
|
use futures::{Future, Sink, Stream};
|
||||||
|
|
||||||
|
fn ws_service<T: AsyncRead + AsyncWrite>(
|
||||||
|
(req, framed): (Request, Framed<T, h1::Codec>),
|
||||||
|
) -> impl Future<Item = (), Error = Error> {
|
||||||
|
let res = ws::handshake(&req).unwrap().message_body(());
|
||||||
|
|
||||||
|
framed
|
||||||
|
.send((res, body::BodySize::None).into())
|
||||||
|
.map_err(|_| panic!())
|
||||||
|
.and_then(|framed| {
|
||||||
|
FramedTransport::new(framed.into_framed(ws::Codec::new()), service)
|
||||||
|
.map_err(|_| panic!())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn service(msg: ws::Frame) -> impl Future<Item = ws::Message, Error = Error> {
|
||||||
|
let msg = match msg {
|
||||||
|
ws::Frame::Ping(msg) => ws::Message::Pong(msg),
|
||||||
|
ws::Frame::Text(text) => {
|
||||||
|
ws::Message::Text(String::from_utf8_lossy(&text.unwrap()).to_string())
|
||||||
|
}
|
||||||
|
ws::Frame::Binary(bin) => ws::Message::Binary(bin.unwrap().freeze()),
|
||||||
|
ws::Frame::Close(reason) => ws::Message::Close(reason),
|
||||||
|
_ => panic!(),
|
||||||
|
};
|
||||||
|
ok(msg)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_simple() {
|
||||||
|
let mut srv = TestServer::new(|| {
|
||||||
|
HttpService::build()
|
||||||
|
.upgrade(ws_service)
|
||||||
|
.finish(|_| future::ok::<_, ()>(Response::NotFound()))
|
||||||
|
});
|
||||||
|
|
||||||
|
// client service
|
||||||
|
let framed = srv.ws().unwrap();
|
||||||
|
let framed = srv
|
||||||
|
.block_on(framed.send(ws::Message::Text("text".to_string())))
|
||||||
|
.unwrap();
|
||||||
|
let (item, framed) = srv.block_on(framed.into_future()).map_err(|_| ()).unwrap();
|
||||||
|
assert_eq!(item, Some(ws::Frame::Text(Some(BytesMut::from("text")))));
|
||||||
|
|
||||||
|
let framed = srv
|
||||||
|
.block_on(framed.send(ws::Message::Binary("text".into())))
|
||||||
|
.unwrap();
|
||||||
|
let (item, framed) = srv.block_on(framed.into_future()).map_err(|_| ()).unwrap();
|
||||||
|
assert_eq!(
|
||||||
|
item,
|
||||||
|
Some(ws::Frame::Binary(Some(Bytes::from_static(b"text").into())))
|
||||||
|
);
|
||||||
|
|
||||||
|
let framed = srv
|
||||||
|
.block_on(framed.send(ws::Message::Ping("text".into())))
|
||||||
|
.unwrap();
|
||||||
|
let (item, framed) = srv.block_on(framed.into_future()).map_err(|_| ()).unwrap();
|
||||||
|
assert_eq!(item, Some(ws::Frame::Pong("text".to_string().into())));
|
||||||
|
|
||||||
|
let framed = srv
|
||||||
|
.block_on(framed.send(ws::Message::Close(Some(ws::CloseCode::Normal.into()))))
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let (item, _framed) = srv.block_on(framed.into_future()).map_err(|_| ()).unwrap();
|
||||||
|
assert_eq!(
|
||||||
|
item,
|
||||||
|
Some(ws::Frame::Close(Some(ws::CloseCode::Normal.into())))
|
||||||
|
);
|
||||||
|
}
|
Loading…
Reference in a new issue