diff --git a/src/process.rs b/src/process.rs
index 29bf910..46c2023 100644
--- a/src/process.rs
+++ b/src/process.rs
@@ -5,7 +5,11 @@ use std::{
     future::Future,
     pin::Pin,
     process::{ExitStatus, Stdio},
-    task::{Context, Poll},
+    sync::{
+        atomic::{AtomicU8, Ordering},
+        Arc, Mutex,
+    },
+    task::{Context, Poll, Wake, Waker},
     time::{Duration, Instant},
 };
 use tokio::{
@@ -72,14 +76,23 @@ struct DropHandle {
     inner: JoinHandle<()>,
 }
 
-pub(crate) struct ProcessRead<I> {
-    inner: I,
+struct ProcessReadState {
+    flags: AtomicU8,
+    parent: Mutex<Option<Waker>>,
+}
+
+struct ProcessReadWaker {
+    state: Arc<ProcessReadState>,
+    flag: u8,
+}
+
+pub(crate) struct ProcessRead {
+    inner: ChildStdout,
     err_recv: RecvFut<'static, std::io::Error>,
-    err_closed: bool,
-    #[allow(dead_code)]
     handle: DropHandle,
-    eof: bool,
-    sleep: Pin<Box<Sleep>>,
+    closed: bool,
+    state: Arc<ProcessReadState>,
+    span: Span,
 }
 
 #[derive(Debug, thiserror::Error)]
@@ -191,21 +204,21 @@ impl Process {
         }
     }
 
-    pub(crate) fn bytes_read(self, input: Bytes) -> ProcessRead<impl AsyncRead + Unpin> {
+    pub(crate) fn bytes_read(self, input: Bytes) -> ProcessRead {
         self.spawn_fn(move |mut stdin| {
             let mut input = input;
             async move { stdin.write_all_buf(&mut input).await }
         })
     }
 
-    pub(crate) fn read(self) -> ProcessRead<impl AsyncRead + Unpin> {
+    pub(crate) fn read(self) -> ProcessRead {
         self.spawn_fn(|_| async { Ok(()) })
     }
 
     #[allow(unknown_lints)]
     #[allow(clippy::let_with_type_underscore)]
     #[tracing::instrument(level = "trace", skip_all)]
-    fn spawn_fn<F, Fut>(self, f: F) -> ProcessRead<impl AsyncRead + Unpin>
+    fn spawn_fn<F, Fut>(self, f: F) -> ProcessRead
     where
         F: FnOnce(ChildStdin) -> Fut + 'static,
         Fut: Future<Output = std::io::Result<()>>,
@@ -223,7 +236,11 @@ impl Process {
         let (tx, rx) = crate::sync::channel::<std::io::Error>(1);
         let rx = rx.into_recv_async();
 
-        let span = tracing::info_span!(parent: None, "Background process task", %command);
+        let background_span =
+            tracing::info_span!(parent: None, "Background process task", %command);
+        background_span.follows_from(Span::current());
+
+        let span = tracing::info_span!(parent: None, "Foreground process task", %command);
         span.follows_from(Span::current());
 
         let handle = crate::sync::spawn(
@@ -250,81 +267,133 @@ impl Process {
                 let _ = tx.send(error);
                 let _ = child.kill().await;
             }
-            .instrument(span),
+            .instrument(background_span),
         );
 
-        let sleep = tokio::time::sleep(timeout);
-
         ProcessRead {
             inner: stdout,
             err_recv: rx,
-            err_closed: false,
             handle: DropHandle { inner: handle },
-            eof: false,
-            sleep: Box::pin(sleep),
+            closed: false,
+            state: ProcessReadState::new_woken(),
+            span,
         }
     }
 }
 
-impl<I> AsyncRead for ProcessRead<I>
-where
-    I: AsyncRead + Unpin,
-{
+impl ProcessReadState {
+    fn new_woken() -> Arc<Self> {
+        Arc::new(Self {
+            flags: AtomicU8::new(0xff),
+            parent: Mutex::new(None),
+        })
+    }
+
+    fn clone_parent(&self) -> Option<Waker> {
+        let guard = self.parent.lock().unwrap();
+        guard.as_ref().map(|w| w.clone())
+    }
+
+    fn into_parts(self) -> (AtomicU8, Option<Waker>) {
+        let ProcessReadState { flags, parent } = self;
+
+        let parent = parent.lock().unwrap().take();
+
+        (flags, parent)
+    }
+}
+
+impl ProcessRead {
+    fn get_waker(&self, flag: u8) -> Option<Waker> {
+        let mask = 0xff ^ flag;
+        let previous = self.state.flags.fetch_and(mask, Ordering::AcqRel);
+        let active = previous & flag;
+
+        if active == flag {
+            Some(
+                Arc::new(ProcessReadWaker {
+                    state: self.state.clone(),
+                    flag,
+                })
+                .into(),
+            )
+        } else {
+            None
+        }
+    }
+
+    fn set_parent_waker(&self, parent: &Waker) {
+        let mut guard = self.state.parent.lock().unwrap();
+        if let Some(waker) = guard.as_mut() {
+            if !waker.will_wake(parent) {
+                *waker = parent.clone();
+            }
+        } else {
+            *guard = Some(parent.clone());
+        }
+    }
+}
+
+const RECV_WAKER: u8 = 0b_0010;
+const HANDLE_WAKER: u8 = 0b_0100;
+
+impl AsyncRead for ProcessRead {
     fn poll_read(
         mut self: Pin<&mut Self>,
         cx: &mut Context<'_>,
         buf: &mut ReadBuf<'_>,
     ) -> Poll<std::io::Result<()>> {
-        if !self.err_closed {
-            if let Poll::Ready(res) = Pin::new(&mut self.err_recv).poll(cx) {
-                self.err_closed = true;
+        self.set_parent_waker(cx.waker());
 
-                if let Ok(err) = res {
-                    return Poll::Ready(Err(err));
-                }
+        let span = self.span.clone();
+        let guard = span.enter();
 
-                if self.eof {
-                    return Poll::Ready(Ok(()));
-                }
-            }
-
-            if let Poll::Ready(()) = self.sleep.as_mut().poll(cx) {
-                self.err_closed = true;
-
-                return Poll::Ready(Err(std::io::ErrorKind::TimedOut.into()));
-            }
-        }
-
-        if !self.eof {
+        let value = loop {
+            // always poll for bytes when poll_read is called
            let before_size = buf.filled().len();
 
-            return match Pin::new(&mut self.inner).poll_read(cx, buf) {
-                Poll::Ready(Ok(())) => {
-                    if buf.filled().len() == before_size {
-                        self.eof = true;
+            if let Poll::Ready(res) = Pin::new(&mut self.inner).poll_read(cx, buf) {
+                if let Err(e) = res {
+                    self.closed = true;
 
-                        if !self.err_closed {
-                            // reached end of stream & haven't received process signal
-                            return Poll::Pending;
-                        }
+                    break Poll::Ready(Err(e));
+                } else if buf.filled().len() == before_size {
+                    self.closed = true;
+
+                    break Poll::Ready(Ok(()));
+                } else {
+                    break Poll::Ready(Ok(()));
+                }
+            } else if let Some(waker) = self.get_waker(RECV_WAKER) {
+                // only poll recv if we've been explicitly woken
+                let mut recv_cx = Context::from_waker(&waker);
+
+                if let Poll::Ready(res) = Pin::new(&mut self.err_recv).poll(&mut recv_cx) {
+                    if let Ok(err) = res {
+                        self.closed = true;
+
+                        break Poll::Ready(Err(err));
                    }
-
-                    Poll::Ready(Ok(()))
                }
-                Poll::Ready(Err(e)) => {
-                    self.eof = true;
+            } else if let Some(waker) = self.get_waker(HANDLE_WAKER) {
+                // only poll handle if we've been explicitly woken
+                let mut handle_cx = Context::from_waker(&waker);
 
-                    Poll::Ready(Err(e))
+                if let Poll::Ready(res) = Pin::new(&mut self.handle.inner).poll(&mut handle_cx) {
+                    if let Err(e) = res {
+                        self.closed = true;
+
+                        break Poll::Ready(Err(std::io::Error::new(std::io::ErrorKind::Other, e)));
+                    }
                }
-                Poll::Pending => Poll::Pending,
-            };
-        }
+            } else if self.closed {
+                break Poll::Ready(Ok(()));
+            } else {
+                break Poll::Pending;
+            }
+        };
 
-        if self.err_closed && self.eof {
-            return Poll::Ready(Ok(()));
-        }
+        drop(guard);
 
-        Poll::Pending
+        value
     }
 }
 
@@ -334,6 +403,40 @@ impl Drop for DropHandle {
     }
 }
 
+impl Wake for ProcessReadWaker {
+    fn wake(self: Arc<Self>) {
+        match Arc::try_unwrap(self) {
+            Ok(ProcessReadWaker { state, flag }) => match Arc::try_unwrap(state) {
+                Ok(state) => {
+                    let (flags, parent) = state.into_parts();
+
+                    flags.fetch_and(flag, Ordering::AcqRel);
+
+                    if let Some(parent) = parent {
+                        parent.wake();
+                    }
+                }
+                Err(state) => {
+                    state.flags.fetch_or(flag, Ordering::AcqRel);
+
+                    if let Some(waker) = state.clone_parent() {
+                        waker.wake();
+                    }
+                }
+            },
+            Err(this) => this.wake_by_ref(),
+        }
+    }
+
+    fn wake_by_ref(self: &Arc<Self>) {
+        self.state.flags.fetch_or(self.flag, Ordering::AcqRel);
+
+        if let Some(parent) = self.state.clone_parent() {
+            parent.wake();
+        }
+    }
+}
+
 impl std::fmt::Display for StatusError {
     fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
         write!(f, "Command failed with bad status: {}", self.0)
diff --git a/src/repo/migrate.rs b/src/repo/migrate.rs
index 4545b27..5f32fba 100644
--- a/src/repo/migrate.rs
+++ b/src/repo/migrate.rs
@@ -1,4 +1,7 @@
-use std::sync::{Arc, OnceLock};
+use std::{
+    sync::{Arc, OnceLock},
+    time::Duration,
+};
 
 use streem::IntoStreamer;
 use tokio::{sync::Semaphore, task::JoinSet};
@@ -6,7 +9,7 @@ use tokio::{sync::Semaphore, task::JoinSet};
 use crate::{
     config::Configuration,
     details::Details,
-    error::Error,
+    error::{Error, UploadError},
     repo::{ArcRepo, DeleteToken, Hash},
     repo_04::{
         AliasRepo as _, HashRepo as _, IdentifierRepo as _, SettingsRepo as _,
@@ -41,7 +44,7 @@ pub(crate) async fn migrate_repo(old_repo: ArcRepo, new_repo: ArcRepo) -> Result
     let mut index = 0;
     while let Some(res) = hash_stream.next().await {
         if let Ok(hash) = res {
-            let _ = migrate_hash(old_repo.clone(), new_repo.clone(), hash).await;
+            migrate_hash(old_repo.clone(), new_repo.clone(), hash).await;
         } else {
             tracing::warn!("Failed to read hash, skipping");
         }
@@ -61,6 +64,12 @@ pub(crate) async fn migrate_repo(old_repo: ArcRepo, new_repo: ArcRepo) -> Result
             .await?;
     }
 
+    if let Some(generator_state) = old_repo.get(crate::NOT_FOUND_KEY).await? {
+        new_repo
+            .set(crate::NOT_FOUND_KEY, generator_state.to_vec().into())
+            .await?;
+    }
+
     tracing::info!("Migration complete");
 
     Ok(())
@@ -181,7 +190,7 @@ async fn migrate_hash_04<S: Store>(
 ) {
     let mut hash_failures = 0;
 
-    while let Err(e) = do_migrate_hash_04(
+    while let Err(e) = timed_migrate_hash_04(
         &tmp_dir,
         &old_repo,
         &new_repo,
@@ -275,6 +284,22 @@ async fn do_migrate_hash(old_repo: &ArcRepo, new_repo: &ArcRepo, hash: Hash) ->
     Ok(())
 }
 
+async fn timed_migrate_hash_04<S: Store>(
+    tmp_dir: &TmpDir,
+    old_repo: &OldSledRepo,
+    new_repo: &ArcRepo,
+    store: &S,
+    config: &Configuration,
+    old_hash: sled::IVec,
+) -> Result<(), Error> {
+    tokio::time::timeout(
+        Duration::from_secs(config.media.external_validation_timeout * 6),
+        do_migrate_hash_04(tmp_dir, old_repo, new_repo, store, config, old_hash),
+    )
+    .await
+    .map_err(|_| UploadError::ProcessTimeout)?
+}
+
 #[tracing::instrument(skip_all)]
 async fn do_migrate_hash_04(
     tmp_dir: &TmpDir,