[dep-upgrade-202402] fix shutdown issue introduced when upgrading hyper

This commit is contained in:
Alex Auvolat 2024-02-08 23:43:59 +01:00
parent bcbd15da84
commit 5c63193d1d
No known key found for this signature in database
GPG key ID: 0E496D15096376BE
6 changed files with 71 additions and 60 deletions

View file

@ -3,9 +3,9 @@ use std::sync::Arc;
use async_trait::async_trait; use async_trait::async_trait;
use futures::future::Future;
use http::header::{ACCESS_CONTROL_ALLOW_METHODS, ACCESS_CONTROL_ALLOW_ORIGIN, ALLOW}; use http::header::{ACCESS_CONTROL_ALLOW_METHODS, ACCESS_CONTROL_ALLOW_ORIGIN, ALLOW};
use hyper::{body::Incoming as IncomingBody, Request, Response, StatusCode}; use hyper::{body::Incoming as IncomingBody, Request, Response, StatusCode};
use tokio::sync::watch;
use opentelemetry::trace::SpanRef; use opentelemetry::trace::SpanRef;
@ -65,11 +65,11 @@ impl AdminApiServer {
pub async fn run( pub async fn run(
self, self,
bind_addr: UnixOrTCPSocketAddress, bind_addr: UnixOrTCPSocketAddress,
shutdown_signal: impl Future<Output = ()>, must_exit: watch::Receiver<bool>,
) -> Result<(), GarageError> { ) -> Result<(), GarageError> {
let region = self.garage.config.s3_api.s3_region.clone(); let region = self.garage.config.s3_api.s3_region.clone();
ApiServer::new(region, self) ApiServer::new(region, self)
.run_server(bind_addr, Some(0o220), shutdown_signal) .run_server(bind_addr, Some(0o220), must_exit)
.await .await
} }

View file

@ -18,6 +18,7 @@ use hyper_util::rt::TokioIo;
use tokio::io::{AsyncRead, AsyncWrite}; use tokio::io::{AsyncRead, AsyncWrite};
use tokio::net::{TcpListener, TcpStream, UnixListener, UnixStream}; use tokio::net::{TcpListener, TcpStream, UnixListener, UnixStream};
use tokio::sync::watch;
use opentelemetry::{ use opentelemetry::{
global, global,
@ -104,20 +105,17 @@ impl<A: ApiHandler> ApiServer<A> {
self: Arc<Self>, self: Arc<Self>,
bind_addr: UnixOrTCPSocketAddress, bind_addr: UnixOrTCPSocketAddress,
unix_bind_addr_mode: Option<u32>, unix_bind_addr_mode: Option<u32>,
shutdown_signal: impl Future<Output = ()>, must_exit: watch::Receiver<bool>,
) -> Result<(), GarageError> { ) -> Result<(), GarageError> {
info!( let server_name = format!("{} API", A::API_NAME_DISPLAY);
"{} API server listening on {}", info!("{} server listening on {}", server_name, bind_addr);
A::API_NAME_DISPLAY,
bind_addr
);
match bind_addr { match bind_addr {
UnixOrTCPSocketAddress::TCPSocket(addr) => { UnixOrTCPSocketAddress::TCPSocket(addr) => {
let listener = TcpListener::bind(addr).await?; let listener = TcpListener::bind(addr).await?;
let handler = move |request, socketaddr| self.clone().handler(request, socketaddr); let handler = move |request, socketaddr| self.clone().handler(request, socketaddr);
server_loop(listener, handler, shutdown_signal).await server_loop(server_name, listener, handler, must_exit).await
} }
UnixOrTCPSocketAddress::UnixSocket(ref path) => { UnixOrTCPSocketAddress::UnixSocket(ref path) => {
if path.exists() { if path.exists() {
@ -133,7 +131,7 @@ impl<A: ApiHandler> ApiServer<A> {
)?; )?;
let handler = move |request, socketaddr| self.clone().handler(request, socketaddr); let handler = move |request, socketaddr| self.clone().handler(request, socketaddr);
server_loop(listener, handler, shutdown_signal).await server_loop(server_name, listener, handler, must_exit).await
} }
} }
} }
@ -278,9 +276,10 @@ impl Accept for UnixListenerOn {
} }
pub async fn server_loop<A, H, F, E>( pub async fn server_loop<A, H, F, E>(
server_name: String,
listener: A, listener: A,
handler: H, handler: H,
shutdown_signal: impl Future<Output = ()>, mut must_exit: watch::Receiver<bool>,
) -> Result<(), GarageError> ) -> Result<(), GarageError>
where where
A: Accept, A: Accept,
@ -288,42 +287,57 @@ where
F: Future<Output = Result<Response<BoxBody<E>>, http::Error>> + Send + 'static, F: Future<Output = Result<Response<BoxBody<E>>, http::Error>> + Send + 'static,
E: Send + Sync + std::error::Error + 'static, E: Send + Sync + std::error::Error + 'static,
{ {
tokio::pin!(shutdown_signal);
let (conn_in, mut conn_out) = tokio::sync::mpsc::unbounded_channel(); let (conn_in, mut conn_out) = tokio::sync::mpsc::unbounded_channel();
let connection_collector = tokio::spawn(async move { let connection_collector = tokio::spawn({
let mut collection = FuturesUnordered::new(); let server_name = server_name.clone();
async move {
let mut connections = FuturesUnordered::new();
loop { loop {
let collect_next = async { let collect_next = async {
if collection.is_empty() { if connections.is_empty() {
futures::future::pending().await futures::future::pending().await
} else { } else {
collection.next().await connections.next().await
} }
}; };
tokio::select! { tokio::select! {
result = collect_next => { result = collect_next => {
trace!("HTTP connection finished: {:?}", result); trace!("{} server: HTTP connection finished: {:?}", server_name, result);
} }
new_fut = conn_out.recv() => { new_fut = conn_out.recv() => {
match new_fut { match new_fut {
Some(f) => collection.push(f), Some(f) => connections.push(f),
None => break, None => break,
} }
} }
} }
} }
debug!("Collecting last open HTTP connections."); if !connections.is_empty() {
while let Some(conn_res) = collection.next().await { info!(
trace!("HTTP connection finished: {:?}", conn_res); "{} server: {} connections still open",
server_name,
connections.len()
);
while let Some(conn_res) = connections.next().await {
trace!(
"{} server: HTTP connection finished: {:?}",
server_name,
conn_res
);
info!(
"{} server: {} connections still open",
server_name,
connections.len()
);
}
}
} }
debug!("No more HTTP connections to collect");
}); });
loop { while !*must_exit.borrow() {
let (stream, client_addr) = tokio::select! { let (stream, client_addr) = tokio::select! {
acc = listener.accept() => acc?, acc = listener.accept() => acc?,
_ = &mut shutdown_signal => break, _ = must_exit.changed() => continue,
}; };
let io = TokioIo::new(stream); let io = TokioIo::new(stream);
@ -343,6 +357,8 @@ where
conn_in.send(fut)?; conn_in.send(fut)?;
} }
info!("{} server exiting", server_name);
drop(conn_in);
connection_collector.await?; connection_collector.await?;
Ok(()) Ok(())

View file

@ -2,8 +2,8 @@ use std::sync::Arc;
use async_trait::async_trait; use async_trait::async_trait;
use futures::future::Future;
use hyper::{body::Incoming as IncomingBody, Method, Request, Response}; use hyper::{body::Incoming as IncomingBody, Method, Request, Response};
use tokio::sync::watch;
use opentelemetry::{trace::SpanRef, KeyValue}; use opentelemetry::{trace::SpanRef, KeyValue};
@ -42,10 +42,10 @@ impl K2VApiServer {
garage: Arc<Garage>, garage: Arc<Garage>,
bind_addr: UnixOrTCPSocketAddress, bind_addr: UnixOrTCPSocketAddress,
s3_region: String, s3_region: String,
shutdown_signal: impl Future<Output = ()>, must_exit: watch::Receiver<bool>,
) -> Result<(), GarageError> { ) -> Result<(), GarageError> {
ApiServer::new(s3_region, K2VApiServer { garage }) ApiServer::new(s3_region, K2VApiServer { garage })
.run_server(bind_addr, None, shutdown_signal) .run_server(bind_addr, None, must_exit)
.await .await
} }
} }

View file

@ -2,9 +2,9 @@ use std::sync::Arc;
use async_trait::async_trait; use async_trait::async_trait;
use futures::future::Future;
use hyper::header; use hyper::header;
use hyper::{body::Incoming as IncomingBody, Request, Response}; use hyper::{body::Incoming as IncomingBody, Request, Response};
use tokio::sync::watch;
use opentelemetry::{trace::SpanRef, KeyValue}; use opentelemetry::{trace::SpanRef, KeyValue};
@ -51,10 +51,10 @@ impl S3ApiServer {
garage: Arc<Garage>, garage: Arc<Garage>,
addr: UnixOrTCPSocketAddress, addr: UnixOrTCPSocketAddress,
s3_region: String, s3_region: String,
shutdown_signal: impl Future<Output = ()>, must_exit: watch::Receiver<bool>,
) -> Result<(), GarageError> { ) -> Result<(), GarageError> {
ApiServer::new(s3_region, S3ApiServer { garage }) ApiServer::new(s3_region, S3ApiServer { garage })
.run_server(addr, None, shutdown_signal) .run_server(addr, None, must_exit)
.await .await
} }

View file

@ -88,7 +88,7 @@ pub async fn run_server(config_file: PathBuf, secrets: Secrets) -> Result<(), Er
garage.clone(), garage.clone(),
s3_bind_addr.clone(), s3_bind_addr.clone(),
config.s3_api.s3_region.clone(), config.s3_api.s3_region.clone(),
wait_from(watch_cancel.clone()), watch_cancel.clone(),
)), )),
)); ));
} }
@ -103,7 +103,7 @@ pub async fn run_server(config_file: PathBuf, secrets: Secrets) -> Result<(), Er
garage.clone(), garage.clone(),
config.k2v_api.as_ref().unwrap().api_bind_addr.clone(), config.k2v_api.as_ref().unwrap().api_bind_addr.clone(),
config.s3_api.s3_region.clone(), config.s3_api.s3_region.clone(),
wait_from(watch_cancel.clone()), watch_cancel.clone(),
)), )),
)); ));
} }
@ -116,10 +116,7 @@ pub async fn run_server(config_file: PathBuf, secrets: Secrets) -> Result<(), Er
let web_server = WebServer::new(garage.clone(), web_config.root_domain.clone()); let web_server = WebServer::new(garage.clone(), web_config.root_domain.clone());
servers.push(( servers.push((
"Web", "Web",
tokio::spawn(web_server.run( tokio::spawn(web_server.run(web_config.bind_addr.clone(), watch_cancel.clone())),
web_config.bind_addr.clone(),
wait_from(watch_cancel.clone()),
)),
)); ));
} }
@ -127,9 +124,7 @@ pub async fn run_server(config_file: PathBuf, secrets: Secrets) -> Result<(), Er
info!("Launching Admin API server..."); info!("Launching Admin API server...");
servers.push(( servers.push((
"Admin", "Admin",
tokio::spawn( tokio::spawn(admin_server.run(admin_bind_addr.clone(), watch_cancel.clone())),
admin_server.run(admin_bind_addr.clone(), wait_from(watch_cancel.clone())),
),
)); ));
} }

View file

@ -2,7 +2,8 @@ use std::fs::{self, Permissions};
use std::os::unix::prelude::PermissionsExt; use std::os::unix::prelude::PermissionsExt;
use std::{convert::Infallible, sync::Arc}; use std::{convert::Infallible, sync::Arc};
use futures::future::Future; use tokio::net::{TcpListener, UnixListener};
use tokio::sync::watch;
use hyper::{ use hyper::{
body::Incoming as IncomingBody, body::Incoming as IncomingBody,
@ -10,8 +11,6 @@ use hyper::{
Method, Request, Response, StatusCode, Method, Request, Response, StatusCode,
}; };
use tokio::net::{TcpListener, UnixListener};
use opentelemetry::{ use opentelemetry::{
global, global,
metrics::{Counter, ValueRecorder}, metrics::{Counter, ValueRecorder},
@ -84,8 +83,9 @@ impl WebServer {
pub async fn run( pub async fn run(
self: Arc<Self>, self: Arc<Self>,
bind_addr: UnixOrTCPSocketAddress, bind_addr: UnixOrTCPSocketAddress,
shutdown_signal: impl Future<Output = ()>, must_exit: watch::Receiver<bool>,
) -> Result<(), GarageError> { ) -> Result<(), GarageError> {
let server_name = "Web".into();
info!("Web server listening on {}", bind_addr); info!("Web server listening on {}", bind_addr);
match bind_addr { match bind_addr {
@ -94,7 +94,7 @@ impl WebServer {
let handler = let handler =
move |stream, socketaddr| self.clone().handle_request(stream, socketaddr); move |stream, socketaddr| self.clone().handle_request(stream, socketaddr);
server_loop(listener, handler, shutdown_signal).await server_loop(server_name, listener, handler, must_exit).await
} }
UnixOrTCPSocketAddress::UnixSocket(ref path) => { UnixOrTCPSocketAddress::UnixSocket(ref path) => {
if path.exists() { if path.exists() {
@ -108,7 +108,7 @@ impl WebServer {
let handler = let handler =
move |stream, socketaddr| self.clone().handle_request(stream, socketaddr); move |stream, socketaddr| self.clone().handle_request(stream, socketaddr);
server_loop(listener, handler, shutdown_signal).await server_loop(server_name, listener, handler, must_exit).await
} }
} }
} }