Simplify process future by removing explicit channel, don't lock mutex as often

This commit is contained in:
asonix 2023-12-22 12:03:05 -06:00
parent 46ceac3432
commit db43392a3b

View file

@ -1,5 +1,4 @@
use actix_web::web::Bytes; use actix_web::web::Bytes;
use flume::r#async::RecvFut;
use std::{ use std::{
ffi::OsStr, ffi::OsStr,
future::Future, future::Future,
@ -73,7 +72,7 @@ impl std::fmt::Debug for Process {
} }
struct DropHandle { struct DropHandle {
inner: JoinHandle<()>, inner: JoinHandle<std::io::Result<()>>,
} }
struct ProcessReadState { struct ProcessReadState {
@ -88,7 +87,6 @@ struct ProcessReadWaker {
pub(crate) struct ProcessRead { pub(crate) struct ProcessRead {
inner: ChildStdout, inner: ChildStdout,
err_recv: RecvFut<'static, std::io::Error>,
handle: DropHandle, handle: DropHandle,
closed: bool, closed: bool,
state: Arc<ProcessReadState>, state: Arc<ProcessReadState>,
@ -233,9 +231,6 @@ impl Process {
let stdin = child.stdin.take().expect("stdin exists"); let stdin = child.stdin.take().expect("stdin exists");
let stdout = child.stdout.take().expect("stdout exists"); let stdout = child.stdout.take().expect("stdout exists");
let (tx, rx) = crate::sync::channel::<std::io::Error>(1);
let rx = rx.into_recv_async();
let background_span = let background_span =
tracing::info_span!(parent: None, "Background process task", %command); tracing::info_span!(parent: None, "Background process task", %command);
background_span.follows_from(Span::current()); background_span.follows_from(Span::current());
@ -255,7 +250,7 @@ impl Process {
let error = match child_fut.with_timeout(timeout).await { let error = match child_fut.with_timeout(timeout).await {
Ok(Ok(status)) if status.success() => { Ok(Ok(status)) if status.success() => {
guard.disarm(); guard.disarm();
return; return Ok(());
} }
Ok(Ok(status)) => { Ok(Ok(status)) => {
std::io::Error::new(std::io::ErrorKind::Other, StatusError(status)) std::io::Error::new(std::io::ErrorKind::Other, StatusError(status))
@ -264,15 +259,15 @@ impl Process {
Err(_) => std::io::ErrorKind::TimedOut.into(), Err(_) => std::io::ErrorKind::TimedOut.into(),
}; };
let _ = tx.send(error); child.kill().await?;
let _ = child.kill().await;
Err(error)
} }
.instrument(background_span), .instrument(background_span),
); );
ProcessRead { ProcessRead {
inner: stdout, inner: stdout,
err_recv: rx,
handle: DropHandle { inner: handle }, handle: DropHandle { inner: handle },
closed: false, closed: false,
state: ProcessReadState::new_woken(), state: ProcessReadState::new_woken(),
@ -291,7 +286,7 @@ impl ProcessReadState {
fn clone_parent(&self) -> Option<Waker> { fn clone_parent(&self) -> Option<Waker> {
let guard = self.parent.lock().unwrap(); let guard = self.parent.lock().unwrap();
guard.as_ref().map(|w| w.clone()) guard.as_ref().cloned()
} }
fn into_parts(self) -> (AtomicU8, Option<Waker>) { fn into_parts(self) -> (AtomicU8, Option<Waker>) {
@ -322,19 +317,26 @@ impl ProcessRead {
} }
} }
fn set_parent_waker(&self, parent: &Waker) { fn set_parent_waker(&self, parent: &Waker) -> bool {
let mut guard = self.state.parent.lock().unwrap(); let mut guard = self.state.parent.lock().unwrap();
if let Some(waker) = guard.as_mut() { if let Some(waker) = guard.as_mut() {
if !waker.will_wake(parent) { if !waker.will_wake(parent) {
*waker = parent.clone(); *waker = parent.clone();
true
} else {
false
} }
} else { } else {
*guard = Some(parent.clone()); *guard = Some(parent.clone());
true
} }
} }
fn mark_all_woken(&self) {
self.state.flags.store(0xff, Ordering::Release);
}
} }
const RECV_WAKER: u8 = 0b_0010;
const HANDLE_WAKER: u8 = 0b_0100; const HANDLE_WAKER: u8 = 0b_0100;
impl AsyncRead for ProcessRead { impl AsyncRead for ProcessRead {
@ -343,8 +345,6 @@ impl AsyncRead for ProcessRead {
cx: &mut Context<'_>, cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>, buf: &mut ReadBuf<'_>,
) -> Poll<std::io::Result<()>> { ) -> Poll<std::io::Result<()>> {
self.set_parent_waker(cx.waker());
let span = self.span.clone(); let span = self.span.clone();
let guard = span.enter(); let guard = span.enter();
@ -364,29 +364,30 @@ impl AsyncRead for ProcessRead {
} else { } else {
break Poll::Ready(Ok(())); break Poll::Ready(Ok(()));
} }
} else if let Some(waker) = self.get_waker(RECV_WAKER) {
// only poll recv if we've been explicitly woken
let mut recv_cx = Context::from_waker(&waker);
if let Poll::Ready(res) = Pin::new(&mut self.err_recv).poll(&mut recv_cx) {
if let Ok(err) = res {
self.closed = true;
break Poll::Ready(Err(err));
}
}
} else if let Some(waker) = self.get_waker(HANDLE_WAKER) { } else if let Some(waker) = self.get_waker(HANDLE_WAKER) {
// only poll handle if we've been explicitly woken // only poll handle if we've been explicitly woken
let mut handle_cx = Context::from_waker(&waker); let mut handle_cx = Context::from_waker(&waker);
if let Poll::Ready(res) = Pin::new(&mut self.handle.inner).poll(&mut handle_cx) { if let Poll::Ready(res) = Pin::new(&mut self.handle.inner).poll(&mut handle_cx) {
if let Err(e) = res { let error = match res {
self.closed = true; Ok(Ok(())) => continue,
break Poll::Ready(Err(std::io::Error::new(std::io::ErrorKind::Other, e))); Ok(Err(e)) => e,
} Err(e) => std::io::Error::new(std::io::ErrorKind::Other, e),
};
self.closed = true;
break Poll::Ready(Err(error));
} }
} else if self.closed { } else if self.closed {
// Stop if we're closed
break Poll::Ready(Ok(())); break Poll::Ready(Ok(()));
} else if self.set_parent_waker(cx.waker()) {
// if we updated the stored waker, mark all as woken an try polling again
// This doesn't actually "wake" the waker, it just allows the handle to be polled
// again next iteration
self.mark_all_woken();
} else { } else {
// if the waker hasn't changed and nothing polled ready, return pending
break Poll::Pending; break Poll::Pending;
} }
}; };