diff --git a/actix-http/CHANGES.md b/actix-http/CHANGES.md index 797cde99b..1eaccfb2e 100644 --- a/actix-http/CHANGES.md +++ b/actix-http/CHANGES.md @@ -1,7 +1,11 @@ # Changes ## Unreleased - 2021-xx-xx +### Added +* Add timeout for canceling HTTP/2 server side connection handshake. Default to 5 seconds. [#2483] +* HTTP/2 handshake timeout can be configured with `ServiceConfig::client_timeout`. [#2483] +[#2483]: https://github.com/actix/actix-web/pull/2483 ## 3.0.0-beta.14 - 2021-11-30 ### Changed diff --git a/actix-http/src/h2/dispatcher.rs b/actix-http/src/h2/dispatcher.rs index 8efd3e831..607997eb7 100644 --- a/actix-http/src/h2/dispatcher.rs +++ b/actix-http/src/h2/dispatcher.rs @@ -10,7 +10,7 @@ use std::{ }; use actix_codec::{AsyncRead, AsyncWrite}; -use actix_rt::time::Sleep; +use actix_rt::time::{sleep, Sleep}; use actix_service::Service; use actix_utils::future::poll_fn; use bytes::{Bytes, BytesMut}; @@ -55,9 +55,16 @@ where on_connect_data: OnConnectData, config: ServiceConfig, peer_addr: Option, + timer: Option>>, ) -> Self { - let ping_pong = config.keep_alive_timer().map(|timer| H2PingPong { - timer: Box::pin(timer), + let ping_pong = config.keep_alive().map(|dur| H2PingPong { + timer: timer + .map(|mut timer| { + // reset timer if it's received from new function. + timer.as_mut().reset(config.now() + dur); + timer + }) + .unwrap_or_else(|| Box::pin(sleep(dur))), on_flight: false, ping_pong: connection.ping_pong().unwrap(), }); diff --git a/actix-http/src/h2/mod.rs b/actix-http/src/h2/mod.rs index 7eff44ac1..25d53403e 100644 --- a/actix-http/src/h2/mod.rs +++ b/actix-http/src/h2/mod.rs @@ -1,20 +1,30 @@ //! HTTP/2 protocol. use std::{ + future::Future, pin::Pin, task::{Context, Poll}, }; +use actix_codec::{AsyncRead, AsyncWrite}; +use actix_rt::time::Sleep; use bytes::Bytes; use futures_core::{ready, Stream}; -use h2::RecvStream; +use h2::{ + server::{handshake, Connection, Handshake}, + RecvStream, +}; mod dispatcher; mod service; pub use self::dispatcher::Dispatcher; pub use self::service::H2Service; -use crate::error::PayloadError; + +use crate::{ + config::ServiceConfig, + error::{DispatchError, PayloadError}, +}; /// HTTP/2 peer stream. pub struct Payload { @@ -50,3 +60,44 @@ impl Stream for Payload { } } } + +pub(crate) fn handshake_with_timeout( + io: T, + config: &ServiceConfig, +) -> HandshakeWithTimeout +where + T: AsyncRead + AsyncWrite + Unpin, +{ + HandshakeWithTimeout { + handshake: handshake(io), + timer: config.client_timer().map(Box::pin), + } +} + +pub(crate) struct HandshakeWithTimeout { + handshake: Handshake, + timer: Option>>, +} + +impl Future for HandshakeWithTimeout +where + T: AsyncRead + AsyncWrite + Unpin, +{ + type Output = Result<(Connection, Option>>), DispatchError>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.get_mut(); + + match Pin::new(&mut this.handshake).poll(cx)? { + // return the timer on success handshake. It can be re-used for h2 ping-pong. + Poll::Ready(conn) => Poll::Ready(Ok((conn, this.timer.take()))), + Poll::Pending => match this.timer.as_mut() { + Some(timer) => { + ready!(timer.as_mut().poll(cx)); + Poll::Ready(Err(DispatchError::SlowRequestTimeout)) + } + None => Poll::Pending, + }, + } + } +} diff --git a/actix-http/src/h2/service.rs b/actix-http/src/h2/service.rs index 798740234..0ad17ec0a 100644 --- a/actix-http/src/h2/service.rs +++ b/actix-http/src/h2/service.rs @@ -15,9 +15,7 @@ use actix_service::{ ServiceFactoryExt as _, }; use actix_utils::future::ready; -use bytes::Bytes; use futures_core::{future::LocalBoxFuture, ready}; -use h2::server::{handshake as h2_handshake, Handshake as H2Handshake}; use log::error; use crate::{ @@ -28,7 +26,7 @@ use crate::{ ConnectCallback, OnConnectData, Request, Response, }; -use super::dispatcher::Dispatcher; +use super::{dispatcher::Dispatcher, handshake_with_timeout, HandshakeWithTimeout}; /// `ServiceFactory` implementation for HTTP/2 transport pub struct H2Service { @@ -297,7 +295,7 @@ where Some(self.cfg.clone()), addr, on_connect_data, - h2_handshake(io), + handshake_with_timeout(io, &self.cfg), ), } } @@ -314,7 +312,7 @@ where Option, Option, OnConnectData, - H2Handshake, + HandshakeWithTimeout, ), } @@ -352,7 +350,7 @@ where ref mut on_connect_data, ref mut handshake, ) => match ready!(Pin::new(handshake).poll(cx)) { - Ok(conn) => { + Ok((conn, timer)) => { let on_connect_data = std::mem::take(on_connect_data); self.state = State::Incoming(Dispatcher::new( srv.take().unwrap(), @@ -360,12 +358,13 @@ where on_connect_data, config.take().unwrap(), *peer_addr, + timer, )); self.poll(cx) } Err(err) => { trace!("H2 handshake error: {}", err); - Poll::Ready(Err(err.into())) + Poll::Ready(Err(err)) } }, } diff --git a/actix-http/src/service.rs b/actix-http/src/service.rs index a47dda738..fb0cccb38 100644 --- a/actix-http/src/service.rs +++ b/actix-http/src/service.rs @@ -9,13 +9,11 @@ use std::{ task::{Context, Poll}, }; -use ::h2::server::{handshake as h2_handshake, Handshake as H2Handshake}; use actix_codec::{AsyncRead, AsyncWrite, Framed}; use actix_rt::net::TcpStream; use actix_service::{ fn_service, IntoServiceFactory, Service, ServiceFactory, ServiceFactoryExt as _, }; -use bytes::Bytes; use futures_core::{future::LocalBoxFuture, ready}; use pin_project::pin_project; @@ -522,7 +520,7 @@ where match proto { Protocol::Http2 => HttpServiceHandlerResponse { state: State::H2Handshake(Some(( - h2_handshake(io), + h2::handshake_with_timeout(io, &self.cfg), self.cfg.clone(), self.flow.clone(), on_connect_data, @@ -567,7 +565,7 @@ where H2(#[pin] h2::Dispatcher), H2Handshake( Option<( - H2Handshake, + h2::HandshakeWithTimeout, ServiceConfig, Rc>, OnConnectData, @@ -625,7 +623,7 @@ where StateProj::H2(disp) => disp.poll(cx), StateProj::H2Handshake(data) => { match ready!(Pin::new(&mut data.as_mut().unwrap().0).poll(cx)) { - Ok(conn) => { + Ok((conn, timer)) => { let (_, cfg, srv, on_connect_data, peer_addr) = data.take().unwrap(); self.as_mut().project().state.set(State::H2( @@ -635,13 +633,14 @@ where on_connect_data, cfg, peer_addr, + timer, ), )); self.poll(cx) } Err(err) => { trace!("H2 handshake error: {}", err); - Poll::Ready(Err(err.into())) + Poll::Ready(Err(err)) } } } diff --git a/actix-http/tests/test_h2_ping_pong.rs b/actix-http/tests/test_h2_ping_pong.rs deleted file mode 100644 index 30ce9aa51..000000000 --- a/actix-http/tests/test_h2_ping_pong.rs +++ /dev/null @@ -1,77 +0,0 @@ -use std::io; - -use actix_http::{error::Error, HttpService, Response}; -use actix_server::Server; - -#[actix_rt::test] -async fn h2_ping_pong() -> io::Result<()> { - let (tx, rx) = std::sync::mpsc::sync_channel(1); - - let lst = std::net::TcpListener::bind("127.0.0.1:0")?; - - let addr = lst.local_addr().unwrap(); - - let join = std::thread::spawn(move || { - actix_rt::System::new().block_on(async move { - let srv = Server::build() - .disable_signals() - .workers(1) - .listen("h2_ping_pong", lst, || { - HttpService::build() - .keep_alive(3) - .h2(|_| async { Ok::<_, Error>(Response::ok()) }) - .tcp() - })? - .run(); - - tx.send(srv.handle()).unwrap(); - - srv.await - }) - }); - - let handle = rx.recv().unwrap(); - - let (sync_tx, rx) = std::sync::mpsc::sync_channel(1); - - // use a separate thread for h2 client so it can be blocked. - std::thread::spawn(move || { - tokio::runtime::Builder::new_current_thread() - .enable_all() - .build() - .unwrap() - .block_on(async move { - let stream = tokio::net::TcpStream::connect(addr).await.unwrap(); - - let (mut tx, conn) = h2::client::handshake(stream).await.unwrap(); - - tokio::spawn(async move { conn.await.unwrap() }); - - let (res, _) = tx.send_request(::http::Request::new(()), true).unwrap(); - let res = res.await.unwrap(); - - assert_eq!(res.status().as_u16(), 200); - - sync_tx.send(()).unwrap(); - - // intentionally block the client thread so it can not answer ping pong. - std::thread::sleep(std::time::Duration::from_secs(1000)); - }) - }); - - rx.recv().unwrap(); - - let now = std::time::Instant::now(); - - // stop server gracefully. this step would take up to 30 seconds. - handle.stop(true).await; - - // join server thread. only when connection are all gone this step would finish. - join.join().unwrap()?; - - // check the time used for join server thread so it's known that the server shutdown - // is from keep alive and not server graceful shutdown timeout. - assert!(now.elapsed() < std::time::Duration::from_secs(30)); - - Ok(()) -} diff --git a/actix-http/tests/test_h2_timer.rs b/actix-http/tests/test_h2_timer.rs new file mode 100644 index 000000000..2b9c26e4a --- /dev/null +++ b/actix-http/tests/test_h2_timer.rs @@ -0,0 +1,153 @@ +use std::io; + +use actix_http::{error::Error, HttpService, Response}; +use actix_server::Server; +use tokio::io::AsyncWriteExt; + +#[actix_rt::test] +async fn h2_ping_pong() -> io::Result<()> { + let (tx, rx) = std::sync::mpsc::sync_channel(1); + + let lst = std::net::TcpListener::bind("127.0.0.1:0")?; + + let addr = lst.local_addr().unwrap(); + + let join = std::thread::spawn(move || { + actix_rt::System::new().block_on(async move { + let srv = Server::build() + .disable_signals() + .workers(1) + .listen("h2_ping_pong", lst, || { + HttpService::build() + .keep_alive(3) + .h2(|_| async { Ok::<_, Error>(Response::ok()) }) + .tcp() + })? + .run(); + + tx.send(srv.handle()).unwrap(); + + srv.await + }) + }); + + let handle = rx.recv().unwrap(); + + let (sync_tx, rx) = std::sync::mpsc::sync_channel(1); + + // use a separate thread for h2 client so it can be blocked. + std::thread::spawn(move || { + tokio::runtime::Builder::new_current_thread() + .enable_all() + .build() + .unwrap() + .block_on(async move { + let stream = tokio::net::TcpStream::connect(addr).await.unwrap(); + + let (mut tx, conn) = h2::client::handshake(stream).await.unwrap(); + + tokio::spawn(async move { conn.await.unwrap() }); + + let (res, _) = tx.send_request(::http::Request::new(()), true).unwrap(); + let res = res.await.unwrap(); + + assert_eq!(res.status().as_u16(), 200); + + sync_tx.send(()).unwrap(); + + // intentionally block the client thread so it can not answer ping pong. + std::thread::sleep(std::time::Duration::from_secs(1000)); + }) + }); + + rx.recv().unwrap(); + + let now = std::time::Instant::now(); + + // stop server gracefully. this step would take up to 30 seconds. + handle.stop(true).await; + + // join server thread. only when connection are all gone this step would finish. + join.join().unwrap()?; + + // check the time used for join server thread so it's known that the server shutdown + // is from keep alive and not server graceful shutdown timeout. + assert!(now.elapsed() < std::time::Duration::from_secs(30)); + + Ok(()) +} + +#[actix_rt::test] +async fn h2_handshake_timeout() -> io::Result<()> { + let (tx, rx) = std::sync::mpsc::sync_channel(1); + + let lst = std::net::TcpListener::bind("127.0.0.1:0")?; + + let addr = lst.local_addr().unwrap(); + + let join = std::thread::spawn(move || { + actix_rt::System::new().block_on(async move { + let srv = Server::build() + .disable_signals() + .workers(1) + .listen("h2_ping_pong", lst, || { + HttpService::build() + .keep_alive(30) + // set first request timeout to 5 seconds. + // this is the timeout used for http2 handshake. + .client_timeout(5000) + .h2(|_| async { Ok::<_, Error>(Response::ok()) }) + .tcp() + })? + .run(); + + tx.send(srv.handle()).unwrap(); + + srv.await + }) + }); + + let handle = rx.recv().unwrap(); + + let (sync_tx, rx) = std::sync::mpsc::sync_channel(1); + + // use a separate thread for tcp client so it can be blocked. + std::thread::spawn(move || { + tokio::runtime::Builder::new_current_thread() + .enable_all() + .build() + .unwrap() + .block_on(async move { + let mut stream = tokio::net::TcpStream::connect(addr).await.unwrap(); + + // do not send the last new line intentionally. + // This should hang the server handshake + let malicious_buf = b"PRI * HTTP/2.0\r\n\r\nSM\r\n"; + stream.write_all(malicious_buf).await.unwrap(); + stream.flush().await.unwrap(); + + sync_tx.send(()).unwrap(); + + // intentionally block the client thread so it sit idle and not do handshake. + std::thread::sleep(std::time::Duration::from_secs(1000)); + + drop(stream) + }) + }); + + rx.recv().unwrap(); + + let now = std::time::Instant::now(); + + // stop server gracefully. this step would take up to 30 seconds. + handle.stop(true).await; + + // join server thread. only when connection are all gone this step would finish. + join.join().unwrap()?; + + // check the time used for join server thread so it's known that the server shutdown + // is from handshake timeout and not server graceful shutdown timeout. + assert!(now.elapsed() < std::time::Duration::from_secs(30)); + + Ok(()) +}