Handle tasks that panic
This commit is contained in:
parent
10e01390b8
commit
716eeae4b1
5 changed files with 123 additions and 11 deletions
46
src/catch_unwind.rs
Normal file
46
src/catch_unwind.rs
Normal file
|
@ -0,0 +1,46 @@
|
|||
use crate::worker::TaskExecError;
|
||||
use futures::future::BoxFuture;
|
||||
use futures::FutureExt;
|
||||
use std::future::Future;
|
||||
use std::pin::Pin;
|
||||
use std::task::Context;
|
||||
use std::task::Poll;
|
||||
|
||||
pub(crate) struct CatchUnwindFuture<F: Future + Send + 'static> {
|
||||
inner: BoxFuture<'static, F::Output>,
|
||||
}
|
||||
|
||||
impl<F: Future + Send + 'static> CatchUnwindFuture<F> {
|
||||
pub fn create(f: F) -> CatchUnwindFuture<F> {
|
||||
Self { inner: f.boxed() }
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: Future + Send + 'static> Future for CatchUnwindFuture<F> {
|
||||
type Output = Result<F::Output, TaskExecError>;
|
||||
|
||||
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
let inner = &mut self.inner;
|
||||
|
||||
match catch_unwind(move || inner.poll_unpin(cx)) {
|
||||
Ok(Poll::Pending) => Poll::Pending,
|
||||
Ok(Poll::Ready(value)) => Poll::Ready(Ok(value)),
|
||||
Err(cause) => Poll::Ready(Err(cause)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn catch_unwind<F: FnOnce() -> R, R>(f: F) -> Result<R, TaskExecError> {
|
||||
match std::panic::catch_unwind(std::panic::AssertUnwindSafe(f)) {
|
||||
Ok(res) => Ok(res),
|
||||
Err(cause) => match cause.downcast_ref::<&'static str>() {
|
||||
None => match cause.downcast_ref::<String>() {
|
||||
None => Err(TaskExecError::Panicked(
|
||||
"Sorry, unknown panic message".to_string(),
|
||||
)),
|
||||
Some(message) => Err(TaskExecError::Panicked(message.to_string())),
|
||||
},
|
||||
Some(message) => Err(TaskExecError::Panicked(message.to_string())),
|
||||
},
|
||||
}
|
||||
}
|
|
@ -30,6 +30,7 @@ pub use task::{CurrentTask, Task, TaskId, TaskState};
|
|||
pub use worker::Worker;
|
||||
pub use worker_pool::{QueueConfig, WorkerPool};
|
||||
|
||||
mod catch_unwind;
|
||||
pub mod errors;
|
||||
mod queries;
|
||||
mod queue;
|
||||
|
|
|
@ -106,7 +106,7 @@ pub mod test_store {
|
|||
|
||||
#[derive(Default, Clone)]
|
||||
pub struct MemoryTaskStore {
|
||||
tasks: Arc<Mutex<BTreeMap<TaskId, Task>>>,
|
||||
pub tasks: Arc<Mutex<BTreeMap<TaskId, Task>>>,
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
use crate::catch_unwind::CatchUnwindFuture;
|
||||
use crate::errors::{AsyncQueueError, BackieError};
|
||||
use crate::runnable::BackgroundTask;
|
||||
use crate::store::TaskStore;
|
||||
|
@ -10,7 +11,6 @@ use std::future::Future;
|
|||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use thiserror::Error;
|
||||
|
||||
pub type ExecuteTaskFn<AppData> = Arc<
|
||||
dyn Fn(
|
||||
|
@ -24,13 +24,16 @@ pub type ExecuteTaskFn<AppData> = Arc<
|
|||
|
||||
pub type StateFn<AppData> = Arc<dyn Fn() -> AppData + Send + Sync>;
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum TaskExecError {
|
||||
#[error("Task deserialization failed: {0}")]
|
||||
TaskDeserializationFailed(#[from] serde_json::Error),
|
||||
|
||||
#[error("Task execution failed: {0}")]
|
||||
ExecutionFailed(#[from] anyhow::Error),
|
||||
|
||||
#[error("Task deserialization failed: {0}")]
|
||||
TaskDeserializationFailed(#[from] serde_json::Error),
|
||||
#[error("Task panicked with: {0}")]
|
||||
Panicked(String),
|
||||
}
|
||||
|
||||
pub(crate) fn runnable<BT>(
|
||||
|
@ -144,9 +147,18 @@ where
|
|||
.get(&task.task_name)
|
||||
.ok_or_else(|| AsyncQueueError::TaskNotRegistered(task.task_name.clone()))?;
|
||||
|
||||
// TODO: catch panics
|
||||
let result: Result<(), TaskExecError> =
|
||||
runnable_task_caller(task_info, task.payload.clone(), (self.app_data_fn)()).await;
|
||||
// catch panics
|
||||
let result: Result<(), TaskExecError> = CatchUnwindFuture::create({
|
||||
let task_payload = task.payload.clone();
|
||||
let app_data = (self.app_data_fn)();
|
||||
let runnable_task_caller = runnable_task_caller.clone();
|
||||
async move { runnable_task_caller(task_info, task_payload, app_data).await }
|
||||
})
|
||||
.await
|
||||
.and_then(|result| {
|
||||
result?;
|
||||
Ok(())
|
||||
});
|
||||
|
||||
match &result {
|
||||
Ok(_) => self.finalize_task(task, result).await?,
|
||||
|
@ -159,7 +171,9 @@ where
|
|||
task.id,
|
||||
backoff_seconds
|
||||
);
|
||||
|
||||
let error_message = format!("{}", error);
|
||||
|
||||
self.store
|
||||
.schedule_task_retry(task.id, backoff_seconds, &error_message)
|
||||
.await?;
|
||||
|
|
|
@ -379,7 +379,7 @@ mod tests {
|
|||
#[derive(Clone)]
|
||||
struct NotifyFinishedContext {
|
||||
/// Used to notify the task ran
|
||||
tx: Arc<Mutex<Option<tokio::sync::oneshot::Sender<()>>>>,
|
||||
notify_finished: Arc<Mutex<Option<tokio::sync::oneshot::Sender<()>>>>,
|
||||
}
|
||||
|
||||
/// A task that notifies the test that it ran
|
||||
|
@ -398,7 +398,7 @@ mod tests {
|
|||
context: Self::AppData,
|
||||
) -> Result<(), anyhow::Error> {
|
||||
// Notify the test that the task ran
|
||||
match context.tx.lock().await.take() {
|
||||
match context.notify_finished.lock().await.take() {
|
||||
None => println!("Cannot notify, already done that!"),
|
||||
Some(tx) => {
|
||||
tx.send(()).unwrap();
|
||||
|
@ -412,7 +412,7 @@ mod tests {
|
|||
let (tx, rx) = tokio::sync::oneshot::channel();
|
||||
|
||||
let my_app_context = NotifyFinishedContext {
|
||||
tx: Arc::new(Mutex::new(Some(tx))),
|
||||
notify_finished: Arc::new(Mutex::new(Some(tx))),
|
||||
};
|
||||
|
||||
let (join_handle, queue) =
|
||||
|
@ -529,6 +529,57 @@ mod tests {
|
|||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn task_can_panic_and_not_affect_worker() {
|
||||
#[derive(Clone, serde::Serialize, serde::Deserialize)]
|
||||
struct BrokenTask;
|
||||
|
||||
#[async_trait]
|
||||
impl BackgroundTask for BrokenTask {
|
||||
const TASK_NAME: &'static str = "panic_me";
|
||||
type AppData = ();
|
||||
|
||||
async fn run(
|
||||
&self,
|
||||
_task: CurrentTask,
|
||||
_context: Self::AppData,
|
||||
) -> Result<(), anyhow::Error> {
|
||||
panic!("Oh no!");
|
||||
}
|
||||
}
|
||||
|
||||
let (notify_stop_worker_pool, should_stop) = tokio::sync::oneshot::channel();
|
||||
|
||||
let task_store = memory_store().await;
|
||||
|
||||
let (worker_pool_finished, queue) = WorkerPool::new(task_store.clone(), |_| ())
|
||||
.register_task_type::<BrokenTask>()
|
||||
.configure_queue("default".into())
|
||||
.start(async move {
|
||||
should_stop.await.unwrap();
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Enqueue a task that will panic
|
||||
queue.enqueue(BrokenTask).await.unwrap();
|
||||
|
||||
notify_stop_worker_pool.send(()).unwrap();
|
||||
worker_pool_finished.await.unwrap();
|
||||
|
||||
let raw_task = task_store
|
||||
.tasks
|
||||
.lock()
|
||||
.await
|
||||
.first_entry()
|
||||
.unwrap()
|
||||
.remove();
|
||||
assert_eq!(
|
||||
serde_json::to_string(&raw_task.error_info.unwrap()).unwrap(),
|
||||
"{\"error\":\"Task panicked with: Oh no!\"}"
|
||||
);
|
||||
}
|
||||
|
||||
/// This test will make sure that the worker pool will only stop after all workers are done.
|
||||
/// We create a KeepAliveTask that will keep running until we notify it to stop.
|
||||
/// We stop the worker pool and make sure that the KeepAliveTask is still running.
|
||||
|
|
Loading…
Reference in a new issue