Handle tasks that panic

This commit is contained in:
Rafael Caricio 2023-03-12 18:33:00 +01:00
parent 10e01390b8
commit 716eeae4b1
Signed by: rafaelcaricio
GPG key ID: 3C86DBCE8E93C947
5 changed files with 123 additions and 11 deletions

46
src/catch_unwind.rs Normal file
View file

@ -0,0 +1,46 @@
use crate::worker::TaskExecError;
use futures::future::BoxFuture;
use futures::FutureExt;
use std::future::Future;
use std::pin::Pin;
use std::task::Context;
use std::task::Poll;
pub(crate) struct CatchUnwindFuture<F: Future + Send + 'static> {
inner: BoxFuture<'static, F::Output>,
}
impl<F: Future + Send + 'static> CatchUnwindFuture<F> {
pub fn create(f: F) -> CatchUnwindFuture<F> {
Self { inner: f.boxed() }
}
}
impl<F: Future + Send + 'static> Future for CatchUnwindFuture<F> {
type Output = Result<F::Output, TaskExecError>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let inner = &mut self.inner;
match catch_unwind(move || inner.poll_unpin(cx)) {
Ok(Poll::Pending) => Poll::Pending,
Ok(Poll::Ready(value)) => Poll::Ready(Ok(value)),
Err(cause) => Poll::Ready(Err(cause)),
}
}
}
fn catch_unwind<F: FnOnce() -> R, R>(f: F) -> Result<R, TaskExecError> {
match std::panic::catch_unwind(std::panic::AssertUnwindSafe(f)) {
Ok(res) => Ok(res),
Err(cause) => match cause.downcast_ref::<&'static str>() {
None => match cause.downcast_ref::<String>() {
None => Err(TaskExecError::Panicked(
"Sorry, unknown panic message".to_string(),
)),
Some(message) => Err(TaskExecError::Panicked(message.to_string())),
},
Some(message) => Err(TaskExecError::Panicked(message.to_string())),
},
}
}

View file

@ -30,6 +30,7 @@ pub use task::{CurrentTask, Task, TaskId, TaskState};
pub use worker::Worker;
pub use worker_pool::{QueueConfig, WorkerPool};
mod catch_unwind;
pub mod errors;
mod queries;
mod queue;

View file

@ -106,7 +106,7 @@ pub mod test_store {
#[derive(Default, Clone)]
pub struct MemoryTaskStore {
tasks: Arc<Mutex<BTreeMap<TaskId, Task>>>,
pub tasks: Arc<Mutex<BTreeMap<TaskId, Task>>>,
}
#[async_trait::async_trait]

View file

@ -1,3 +1,4 @@
use crate::catch_unwind::CatchUnwindFuture;
use crate::errors::{AsyncQueueError, BackieError};
use crate::runnable::BackgroundTask;
use crate::store::TaskStore;
@ -10,7 +11,6 @@ use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::time::Duration;
use thiserror::Error;
pub type ExecuteTaskFn<AppData> = Arc<
dyn Fn(
@ -24,13 +24,16 @@ pub type ExecuteTaskFn<AppData> = Arc<
pub type StateFn<AppData> = Arc<dyn Fn() -> AppData + Send + Sync>;
#[derive(Debug, Error)]
#[derive(Debug, thiserror::Error)]
pub enum TaskExecError {
#[error("Task deserialization failed: {0}")]
TaskDeserializationFailed(#[from] serde_json::Error),
#[error("Task execution failed: {0}")]
ExecutionFailed(#[from] anyhow::Error),
#[error("Task deserialization failed: {0}")]
TaskDeserializationFailed(#[from] serde_json::Error),
#[error("Task panicked with: {0}")]
Panicked(String),
}
pub(crate) fn runnable<BT>(
@ -144,9 +147,18 @@ where
.get(&task.task_name)
.ok_or_else(|| AsyncQueueError::TaskNotRegistered(task.task_name.clone()))?;
// TODO: catch panics
let result: Result<(), TaskExecError> =
runnable_task_caller(task_info, task.payload.clone(), (self.app_data_fn)()).await;
// catch panics
let result: Result<(), TaskExecError> = CatchUnwindFuture::create({
let task_payload = task.payload.clone();
let app_data = (self.app_data_fn)();
let runnable_task_caller = runnable_task_caller.clone();
async move { runnable_task_caller(task_info, task_payload, app_data).await }
})
.await
.and_then(|result| {
result?;
Ok(())
});
match &result {
Ok(_) => self.finalize_task(task, result).await?,
@ -159,7 +171,9 @@ where
task.id,
backoff_seconds
);
let error_message = format!("{}", error);
self.store
.schedule_task_retry(task.id, backoff_seconds, &error_message)
.await?;

View file

@ -379,7 +379,7 @@ mod tests {
#[derive(Clone)]
struct NotifyFinishedContext {
/// Used to notify the task ran
tx: Arc<Mutex<Option<tokio::sync::oneshot::Sender<()>>>>,
notify_finished: Arc<Mutex<Option<tokio::sync::oneshot::Sender<()>>>>,
}
/// A task that notifies the test that it ran
@ -398,7 +398,7 @@ mod tests {
context: Self::AppData,
) -> Result<(), anyhow::Error> {
// Notify the test that the task ran
match context.tx.lock().await.take() {
match context.notify_finished.lock().await.take() {
None => println!("Cannot notify, already done that!"),
Some(tx) => {
tx.send(()).unwrap();
@ -412,7 +412,7 @@ mod tests {
let (tx, rx) = tokio::sync::oneshot::channel();
let my_app_context = NotifyFinishedContext {
tx: Arc::new(Mutex::new(Some(tx))),
notify_finished: Arc::new(Mutex::new(Some(tx))),
};
let (join_handle, queue) =
@ -529,6 +529,57 @@ mod tests {
);
}
#[tokio::test]
async fn task_can_panic_and_not_affect_worker() {
#[derive(Clone, serde::Serialize, serde::Deserialize)]
struct BrokenTask;
#[async_trait]
impl BackgroundTask for BrokenTask {
const TASK_NAME: &'static str = "panic_me";
type AppData = ();
async fn run(
&self,
_task: CurrentTask,
_context: Self::AppData,
) -> Result<(), anyhow::Error> {
panic!("Oh no!");
}
}
let (notify_stop_worker_pool, should_stop) = tokio::sync::oneshot::channel();
let task_store = memory_store().await;
let (worker_pool_finished, queue) = WorkerPool::new(task_store.clone(), |_| ())
.register_task_type::<BrokenTask>()
.configure_queue("default".into())
.start(async move {
should_stop.await.unwrap();
})
.await
.unwrap();
// Enqueue a task that will panic
queue.enqueue(BrokenTask).await.unwrap();
notify_stop_worker_pool.send(()).unwrap();
worker_pool_finished.await.unwrap();
let raw_task = task_store
.tasks
.lock()
.await
.first_entry()
.unwrap()
.remove();
assert_eq!(
serde_json::to_string(&raw_task.error_info.unwrap()).unwrap(),
"{\"error\":\"Task panicked with: Oh no!\"}"
);
}
/// This test will make sure that the worker pool will only stop after all workers are done.
/// We create a KeepAliveTask that will keep running until we notify it to stop.
/// We stop the worker pool and make sure that the KeepAliveTask is still running.