From 3b860ebdc7aead34fede705b5a8cebe6b11da025 Mon Sep 17 00:00:00 2001 From: Rajasekharan Vengalil Date: Mon, 16 Dec 2019 23:34:25 -0800 Subject: [PATCH] Fix poll_ready call for WebSockets upgrade (#1219) * Fix poll_ready call for WebSockets upgrade * Poll upgrade service from H1ServiceHandler too --- actix-http/CHANGES.md | 6 +++ actix-http/src/h1/dispatcher.rs | 42 ++++------------ actix-http/src/h1/service.rs | 21 ++++++-- actix-http/src/service.rs | 21 ++++++-- actix-http/tests/test_ws.rs | 87 ++++++++++++++++++++++++++------- 5 files changed, 120 insertions(+), 57 deletions(-) diff --git a/actix-http/CHANGES.md b/actix-http/CHANGES.md index 212ce6a15..ae41610c3 100644 --- a/actix-http/CHANGES.md +++ b/actix-http/CHANGES.md @@ -1,5 +1,11 @@ # Changes +## [1.0.xx] - 2019-12-xx + +### Fixed + +* Poll upgrade service's readiness from HTTP service handlers + ## [1.0.0] - 2019-12-13 ### Added diff --git a/actix-http/src/h1/dispatcher.rs b/actix-http/src/h1/dispatcher.rs index 6b37be683..8a2a4f030 100644 --- a/actix-http/src/h1/dispatcher.rs +++ b/actix-http/src/h1/dispatcher.rs @@ -66,7 +66,6 @@ where U::Error: fmt::Display, { Normal(InnerDispatcher), - UpgradeReadiness(InnerDispatcher, Request), Upgrade(U::Future), None, } @@ -764,8 +763,16 @@ where if let DispatcherState::Normal(inner) = std::mem::replace(&mut self.inner, DispatcherState::None) { - self.inner = - DispatcherState::UpgradeReadiness(inner, req); + let mut parts = FramedParts::with_read_buf( + inner.io, + inner.codec, + inner.read_buf, + ); + parts.write_buf = inner.write_buf; + let framed = Framed::from_parts(parts); + self.inner = DispatcherState::Upgrade( + inner.upgrade.unwrap().call((req, framed)), + ); return self.poll(cx); } else { panic!() @@ -815,35 +822,6 @@ where } } } - DispatcherState::UpgradeReadiness(ref mut inner, _) => { - let upgrade = inner.upgrade.as_mut().unwrap(); - match upgrade.poll_ready(cx) { - Poll::Ready(Ok(_)) => { - if let DispatcherState::UpgradeReadiness(inner, req) = - std::mem::replace(&mut self.inner, DispatcherState::None) - { - let mut parts = FramedParts::with_read_buf( - inner.io, - inner.codec, - inner.read_buf, - ); - parts.write_buf = inner.write_buf; - let framed = Framed::from_parts(parts); - self.inner = DispatcherState::Upgrade( - inner.upgrade.unwrap().call((req, framed)), - ); - self.poll(cx) - } else { - panic!() - } - } - Poll::Pending => Poll::Pending, - Poll::Ready(Err(e)) => { - error!("Upgrade handler readiness check error: {}", e); - Poll::Ready(Err(DispatchError::Upgrade)) - } - } - } DispatcherState::Upgrade(ref mut fut) => { unsafe { Pin::new_unchecked(fut) }.poll(cx).map_err(|e| { error!("Upgrade handler error: {}", e); diff --git a/actix-http/src/h1/service.rs b/actix-http/src/h1/service.rs index fb5514da3..beb577f9a 100644 --- a/actix-http/src/h1/service.rs +++ b/actix-http/src/h1/service.rs @@ -72,7 +72,7 @@ where Request = (Request, Framed), Response = (), >, - U::Error: fmt::Display, + U::Error: fmt::Display + Into, U::InitError: fmt::Debug, { /// Create simple tcp stream service @@ -115,7 +115,7 @@ mod openssl { Request = (Request, Framed, Codec>), Response = (), >, - U::Error: fmt::Display, + U::Error: fmt::Display + Into, U::InitError: fmt::Debug, { /// Create openssl based service @@ -255,7 +255,7 @@ where X::Error: Into, X::InitError: fmt::Debug, U: ServiceFactory), Response = ()>, - U::Error: fmt::Display, + U::Error: fmt::Display + Into, U::InitError: fmt::Debug, { type Config = (); @@ -412,7 +412,7 @@ where X: Service, X::Error: Into, U: Service), Response = ()>, - U::Error: fmt::Display, + U::Error: fmt::Display + Into, { type Request = (T, Option); type Response = (); @@ -440,6 +440,19 @@ where })? .is_ready() && ready; + + let ready = if let Some(ref mut upg) = self.upgrade { + upg.poll_ready(cx) + .map_err(|e| { + let e = e.into(); + log::error!("Http service readiness error: {:?}", e); + DispatchError::Service(e) + })? + .is_ready() + && ready + } else { + ready + }; if ready { Poll::Ready(Ok(())) diff --git a/actix-http/src/service.rs b/actix-http/src/service.rs index 2d934cc19..457c5ca1f 100644 --- a/actix-http/src/service.rs +++ b/actix-http/src/service.rs @@ -169,7 +169,7 @@ where Request = (Request, Framed), Response = (), >, - U::Error: fmt::Display, + U::Error: fmt::Display + Into, U::InitError: fmt::Debug, ::Future: 'static, { @@ -214,7 +214,7 @@ mod openssl { Request = (Request, Framed, h1::Codec>), Response = (), >, - U::Error: fmt::Display, + U::Error: fmt::Display + Into, U::InitError: fmt::Debug, ::Future: 'static, { @@ -335,7 +335,7 @@ where Request = (Request, Framed), Response = (), >, - U::Error: fmt::Display, + U::Error: fmt::Display + Into, U::InitError: fmt::Debug, ::Future: 'static, { @@ -493,7 +493,7 @@ where X: Service, X::Error: Into, U: Service), Response = ()>, - U::Error: fmt::Display, + U::Error: fmt::Display + Into, { type Request = (T, Protocol, Option); type Response = (); @@ -522,6 +522,19 @@ where .is_ready() && ready; + let ready = if let Some(ref mut upg) = self.upgrade { + upg.poll_ready(cx) + .map_err(|e| { + let e = e.into(); + log::error!("Http service readiness error: {:?}", e); + DispatchError::Service(e) + })? + .is_ready() + && ready + } else { + ready + }; + if ready { Poll::Ready(Ok(())) } else { diff --git a/actix-http/tests/test_ws.rs b/actix-http/tests/test_ws.rs index 2c1d6cdc1..7152fee48 100644 --- a/actix-http/tests/test_ws.rs +++ b/actix-http/tests/test_ws.rs @@ -1,24 +1,70 @@ +use std::cell::Cell; +use std::marker::PhantomData; +use std::pin::Pin; +use std::sync::{Arc, Mutex}; + use actix_codec::{AsyncRead, AsyncWrite, Framed}; use actix_http::{body, h1, ws, Error, HttpService, Request, Response}; use actix_http_test::test_server; +use actix_service::{fn_factory, Service}; use actix_utils::framed::Dispatcher; use bytes::Bytes; use futures::future; -use futures::{SinkExt, StreamExt}; +use futures::task::{Context, Poll}; +use futures::{Future, SinkExt, StreamExt}; -async fn ws_service( - (req, mut framed): (Request, Framed), -) -> Result<(), Error> { - let res = ws::handshake(req.head()).unwrap().message_body(()); +struct WsService(Arc, Cell)>>); - framed - .send((res, body::BodySize::None).into()) - .await - .unwrap(); +impl WsService { + fn new() -> Self { + WsService(Arc::new(Mutex::new((PhantomData, Cell::new(false))))) + } - Dispatcher::new(framed.into_framed(ws::Codec::new()), service) - .await - .map_err(|_| panic!()) + fn set_polled(&mut self) { + *self.0.lock().unwrap().1.get_mut() = true; + } + + fn was_polled(&self) -> bool { + self.0.lock().unwrap().1.get() + } +} + +impl Clone for WsService { + fn clone(&self) -> Self { + WsService(self.0.clone()) + } +} + +impl Service for WsService +where + T: AsyncRead + AsyncWrite + Unpin + 'static, +{ + type Request = (Request, Framed); + type Response = (); + type Error = Error; + type Future = Pin>>>; + + fn poll_ready(&mut self, _ctx: &mut Context<'_>) -> Poll> { + self.set_polled(); + Poll::Ready(Ok(())) + } + + fn call(&mut self, (req, mut framed): Self::Request) -> Self::Future { + let fut = async move { + let res = ws::handshake(req.head()).unwrap().message_body(()); + + framed + .send((res, body::BodySize::None).into()) + .await + .unwrap(); + + Dispatcher::new(framed.into_framed(ws::Codec::new()), service) + .await + .map_err(|_| panic!()) + }; + + Box::pin(fut) + } } async fn service(msg: ws::Frame) -> Result { @@ -37,11 +83,16 @@ async fn service(msg: ws::Frame) -> Result { #[actix_rt::test] async fn test_simple() { - let mut srv = test_server(|| { - HttpService::build() - .upgrade(actix_service::fn_service(ws_service)) - .finish(|_| future::ok::<_, ()>(Response::NotFound())) - .tcp() + let ws_service = WsService::new(); + let mut srv = test_server({ + let ws_service = ws_service.clone(); + move || { + let ws_service = ws_service.clone(); + HttpService::build() + .upgrade(fn_factory(move || future::ok::<_, ()>(ws_service.clone()))) + .finish(|_| future::ok::<_, ()>(Response::NotFound())) + .tcp() + } }); // client service @@ -138,4 +189,6 @@ async fn test_simple() { item.unwrap().unwrap(), ws::Frame::Close(Some(ws::CloseCode::Normal.into())) ); + + assert!(ws_service.was_polled()); }