diff --git a/src/catch_unwind.rs b/src/catch_unwind.rs
new file mode 100644
index 0000000..262b006
--- /dev/null
+++ b/src/catch_unwind.rs
@@ -0,0 +1,46 @@
+use crate::worker::TaskExecError;
+use futures::future::BoxFuture;
+use futures::FutureExt;
+use std::future::Future;
+use std::pin::Pin;
+use std::task::Context;
+use std::task::Poll;
+
+pub(crate) struct CatchUnwindFuture<F: Future + Send + 'static> {
+    inner: BoxFuture<'static, F::Output>,
+}
+
+impl<F: Future + Send + 'static> CatchUnwindFuture<F> {
+    pub fn create(f: F) -> CatchUnwindFuture<F> {
+        Self { inner: f.boxed() }
+    }
+}
+
+impl<F: Future + Send + 'static> Future for CatchUnwindFuture<F> {
+    type Output = Result<F::Output, TaskExecError>;
+
+    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
+        let inner = &mut self.inner;
+
+        match catch_unwind(move || inner.poll_unpin(cx)) {
+            Ok(Poll::Pending) => Poll::Pending,
+            Ok(Poll::Ready(value)) => Poll::Ready(Ok(value)),
+            Err(cause) => Poll::Ready(Err(cause)),
+        }
+    }
+}
+
+fn catch_unwind<F: FnOnce() -> R, R>(f: F) -> Result<R, TaskExecError> {
+    match std::panic::catch_unwind(std::panic::AssertUnwindSafe(f)) {
+        Ok(res) => Ok(res),
+        Err(cause) => match cause.downcast_ref::<&'static str>() {
+            None => match cause.downcast_ref::<String>() {
+                None => Err(TaskExecError::Panicked(
+                    "Sorry, unknown panic message".to_string(),
+                )),
+                Some(message) => Err(TaskExecError::Panicked(message.to_string())),
+            },
+            Some(message) => Err(TaskExecError::Panicked(message.to_string())),
        },
    }
}
diff --git a/src/lib.rs b/src/lib.rs
index 418938e..4be5de4 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -30,6 +30,7 @@ pub use task::{CurrentTask, Task, TaskId, TaskState};
 pub use worker::Worker;
 pub use worker_pool::{QueueConfig, WorkerPool};
 
+mod catch_unwind;
 pub mod errors;
 mod queries;
 mod queue;
diff --git a/src/store.rs b/src/store.rs
index 390d375..24637e3 100644
--- a/src/store.rs
+++ b/src/store.rs
@@ -106,7 +106,7 @@ pub mod test_store {
 
     #[derive(Default, Clone)]
     pub struct MemoryTaskStore {
-        tasks: Arc<Mutex<BTreeMap<TaskId, Task>>>,
+        pub tasks: Arc<Mutex<BTreeMap<TaskId, Task>>>,
     }
 
     #[async_trait::async_trait]
diff --git a/src/worker.rs b/src/worker.rs
index d1c37d0..a59bf44 100644
--- a/src/worker.rs
+++ b/src/worker.rs
@@ -1,3 +1,4 @@
+use crate::catch_unwind::CatchUnwindFuture;
 use crate::errors::{AsyncQueueError, BackieError};
 use crate::runnable::BackgroundTask;
 use crate::store::TaskStore;
@@ -10,7 +11,6 @@ use std::future::Future;
 use std::pin::Pin;
 use std::sync::Arc;
 use std::time::Duration;
-use thiserror::Error;
 
 pub type ExecuteTaskFn<AppData> = Arc<
     dyn Fn(
@@ -24,13 +24,16 @@ pub type StateFn<AppData> = Arc<dyn Fn() -> AppData + Send + Sync>;
 
-#[derive(Debug, Error)]
+#[derive(Debug, thiserror::Error)]
 pub enum TaskExecError {
+    #[error("Task deserialization failed: {0}")]
+    TaskDeserializationFailed(#[from] serde_json::Error),
+
     #[error("Task execution failed: {0}")]
     ExecutionFailed(#[from] anyhow::Error),
 
-    #[error("Task deserialization failed: {0}")]
-    TaskDeserializationFailed(#[from] serde_json::Error),
+    #[error("Task panicked with: {0}")]
+    Panicked(String),
 }
 
 pub(crate) fn runnable(
@@ -144,9 +147,18 @@
             .get(&task.task_name)
             .ok_or_else(|| AsyncQueueError::TaskNotRegistered(task.task_name.clone()))?;
 
-        // TODO: catch panics
-        let result: Result<(), TaskExecError> =
-            runnable_task_caller(task_info, task.payload.clone(), (self.app_data_fn)()).await;
+        // catch panics
+        let result: Result<(), TaskExecError> = CatchUnwindFuture::create({
+            let task_payload = task.payload.clone();
+            let app_data = (self.app_data_fn)();
+            let runnable_task_caller = runnable_task_caller.clone();
+            async move { runnable_task_caller(task_info, task_payload, app_data).await }
+        })
+        .await
+        .and_then(|result| {
+            result?;
+            Ok(())
+        });
 
         match &result {
             Ok(_) => self.finalize_task(task, result).await?,
@@ -159,7 +171,9 @@
                         task.id,
                         backoff_seconds
                     );
+                    let error_message = format!("{}", error);
+
                     self.store
                         .schedule_task_retry(task.id, backoff_seconds, &error_message)
                         .await?;
diff --git a/src/worker_pool.rs b/src/worker_pool.rs
index d5379b6..3d8d0bc 100644
--- a/src/worker_pool.rs
+++ b/src/worker_pool.rs
@@ -379,7 +379,7 @@ mod tests {
     #[derive(Clone)]
     struct NotifyFinishedContext {
         /// Used to notify the task ran
-        tx: Arc<Mutex<Option<tokio::sync::oneshot::Sender<()>>>>,
+        notify_finished: Arc<Mutex<Option<tokio::sync::oneshot::Sender<()>>>>,
     }
 
     /// A task that notifies the test that it ran
@@ -398,7 +398,7 @@
             context: Self::AppData,
         ) -> Result<(), anyhow::Error> {
             // Notify the test that the task ran
-            match context.tx.lock().await.take() {
+            match context.notify_finished.lock().await.take() {
                 None => println!("Cannot notify, already done that!"),
                 Some(tx) => {
                     tx.send(()).unwrap();
@@ -412,7 +412,7 @@
         let (tx, rx) = tokio::sync::oneshot::channel();
 
         let my_app_context = NotifyFinishedContext {
-            tx: Arc::new(Mutex::new(Some(tx))),
+            notify_finished: Arc::new(Mutex::new(Some(tx))),
         };
 
         let (join_handle, queue) =
@@ -529,6 +529,57 @@
         );
     }
 
+    #[tokio::test]
+    async fn task_can_panic_and_not_affect_worker() {
+        #[derive(Clone, serde::Serialize, serde::Deserialize)]
+        struct BrokenTask;
+
+        #[async_trait]
+        impl BackgroundTask for BrokenTask {
+            const TASK_NAME: &'static str = "panic_me";
+            type AppData = ();
+
+            async fn run(
+                &self,
+                _task: CurrentTask,
+                _context: Self::AppData,
+            ) -> Result<(), anyhow::Error> {
+                panic!("Oh no!");
+            }
+        }
+
+        let (notify_stop_worker_pool, should_stop) = tokio::sync::oneshot::channel();
+
+        let task_store = memory_store().await;
+
+        let (worker_pool_finished, queue) = WorkerPool::new(task_store.clone(), |_| ())
+            .register_task_type::<BrokenTask>()
+            .configure_queue("default".into())
+            .start(async move {
+                should_stop.await.unwrap();
+            })
+            .await
+            .unwrap();
+
+        // Enqueue a task that will panic
+        queue.enqueue(BrokenTask).await.unwrap();
+
+        notify_stop_worker_pool.send(()).unwrap();
+        worker_pool_finished.await.unwrap();
+
+        let raw_task = task_store
+            .tasks
+            .lock()
+            .await
+            .first_entry()
+            .unwrap()
+            .remove();
+        assert_eq!(
+            serde_json::to_string(&raw_task.error_info.unwrap()).unwrap(),
+            "{\"error\":\"Task panicked with: Oh no!\"}"
+        );
+    }
+
     /// This test will make sure that the worker pool will only stop after all workers are done.
     /// We create a KeepAliveTask that will keep running until we notify it to stop.
     /// We stop the worker pool and make sure that the KeepAliveTask is still running.
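Note (not part of the patch): the `catch_unwind` helper added above tries two downcasts because `panic!("some literal")` carries a `&'static str` payload while `panic!("{}", value)` formats into a `String`. The following standalone sketch illustrates that pattern using only `std::panic`; the `panic_message` helper is a hypothetical name for illustration, not code from this crate.

```rust
use std::any::Any;
use std::panic::{catch_unwind, AssertUnwindSafe};

/// Hypothetical helper mirroring the downcast order used in src/catch_unwind.rs:
/// try `&'static str` first, then `String`, otherwise fall back to a stub message.
fn panic_message(cause: Box<dyn Any + Send>) -> String {
    if let Some(message) = cause.downcast_ref::<&'static str>() {
        message.to_string()
    } else if let Some(message) = cause.downcast_ref::<String>() {
        message.clone()
    } else {
        "Sorry, unknown panic message".to_string()
    }
}

fn main() {
    // Silence the default panic hook so the demo output stays readable.
    std::panic::set_hook(Box::new(|_| {}));

    // A string-literal panic carries a `&'static str` payload.
    let literal = catch_unwind(AssertUnwindSafe(|| {
        panic!("Oh no!");
    }))
    .unwrap_err();
    assert_eq!(panic_message(literal), "Oh no!");

    // A formatted panic carries a `String` payload.
    let formatted = catch_unwind(AssertUnwindSafe(|| {
        panic!("failed with code {}", 7);
    }))
    .unwrap_err();
    assert_eq!(panic_message(formatted), "failed with code 7");
}
```

The same ordering is what lets the worker surface "Task panicked with: Oh no!" in the new `task_can_panic_and_not_affect_worker` test instead of the unknown-message fallback.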