Handle tasks that panic
This commit is contained in:
parent
10e01390b8
commit
716eeae4b1
5 changed files with 123 additions and 11 deletions
46
src/catch_unwind.rs
Normal file
46
src/catch_unwind.rs
Normal file
|
@ -0,0 +1,46 @@
|
||||||
|
use crate::worker::TaskExecError;
|
||||||
|
use futures::future::BoxFuture;
|
||||||
|
use futures::FutureExt;
|
||||||
|
use std::future::Future;
|
||||||
|
use std::pin::Pin;
|
||||||
|
use std::task::Context;
|
||||||
|
use std::task::Poll;
|
||||||
|
|
||||||
|
pub(crate) struct CatchUnwindFuture<F: Future + Send + 'static> {
|
||||||
|
inner: BoxFuture<'static, F::Output>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<F: Future + Send + 'static> CatchUnwindFuture<F> {
|
||||||
|
pub fn create(f: F) -> CatchUnwindFuture<F> {
|
||||||
|
Self { inner: f.boxed() }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<F: Future + Send + 'static> Future for CatchUnwindFuture<F> {
|
||||||
|
type Output = Result<F::Output, TaskExecError>;
|
||||||
|
|
||||||
|
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||||
|
let inner = &mut self.inner;
|
||||||
|
|
||||||
|
match catch_unwind(move || inner.poll_unpin(cx)) {
|
||||||
|
Ok(Poll::Pending) => Poll::Pending,
|
||||||
|
Ok(Poll::Ready(value)) => Poll::Ready(Ok(value)),
|
||||||
|
Err(cause) => Poll::Ready(Err(cause)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn catch_unwind<F: FnOnce() -> R, R>(f: F) -> Result<R, TaskExecError> {
|
||||||
|
match std::panic::catch_unwind(std::panic::AssertUnwindSafe(f)) {
|
||||||
|
Ok(res) => Ok(res),
|
||||||
|
Err(cause) => match cause.downcast_ref::<&'static str>() {
|
||||||
|
None => match cause.downcast_ref::<String>() {
|
||||||
|
None => Err(TaskExecError::Panicked(
|
||||||
|
"Sorry, unknown panic message".to_string(),
|
||||||
|
)),
|
||||||
|
Some(message) => Err(TaskExecError::Panicked(message.to_string())),
|
||||||
|
},
|
||||||
|
Some(message) => Err(TaskExecError::Panicked(message.to_string())),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
|
@ -30,6 +30,7 @@ pub use task::{CurrentTask, Task, TaskId, TaskState};
|
||||||
pub use worker::Worker;
|
pub use worker::Worker;
|
||||||
pub use worker_pool::{QueueConfig, WorkerPool};
|
pub use worker_pool::{QueueConfig, WorkerPool};
|
||||||
|
|
||||||
|
mod catch_unwind;
|
||||||
pub mod errors;
|
pub mod errors;
|
||||||
mod queries;
|
mod queries;
|
||||||
mod queue;
|
mod queue;
|
||||||
|
|
|
@ -106,7 +106,7 @@ pub mod test_store {
|
||||||
|
|
||||||
#[derive(Default, Clone)]
|
#[derive(Default, Clone)]
|
||||||
pub struct MemoryTaskStore {
|
pub struct MemoryTaskStore {
|
||||||
tasks: Arc<Mutex<BTreeMap<TaskId, Task>>>,
|
pub tasks: Arc<Mutex<BTreeMap<TaskId, Task>>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[async_trait::async_trait]
|
#[async_trait::async_trait]
|
||||||
|
|
|
@ -1,3 +1,4 @@
|
||||||
|
use crate::catch_unwind::CatchUnwindFuture;
|
||||||
use crate::errors::{AsyncQueueError, BackieError};
|
use crate::errors::{AsyncQueueError, BackieError};
|
||||||
use crate::runnable::BackgroundTask;
|
use crate::runnable::BackgroundTask;
|
||||||
use crate::store::TaskStore;
|
use crate::store::TaskStore;
|
||||||
|
@ -10,7 +11,6 @@ use std::future::Future;
|
||||||
use std::pin::Pin;
|
use std::pin::Pin;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use std::time::Duration;
|
use std::time::Duration;
|
||||||
use thiserror::Error;
|
|
||||||
|
|
||||||
pub type ExecuteTaskFn<AppData> = Arc<
|
pub type ExecuteTaskFn<AppData> = Arc<
|
||||||
dyn Fn(
|
dyn Fn(
|
||||||
|
@ -24,13 +24,16 @@ pub type ExecuteTaskFn<AppData> = Arc<
|
||||||
|
|
||||||
pub type StateFn<AppData> = Arc<dyn Fn() -> AppData + Send + Sync>;
|
pub type StateFn<AppData> = Arc<dyn Fn() -> AppData + Send + Sync>;
|
||||||
|
|
||||||
#[derive(Debug, Error)]
|
#[derive(Debug, thiserror::Error)]
|
||||||
pub enum TaskExecError {
|
pub enum TaskExecError {
|
||||||
|
#[error("Task deserialization failed: {0}")]
|
||||||
|
TaskDeserializationFailed(#[from] serde_json::Error),
|
||||||
|
|
||||||
#[error("Task execution failed: {0}")]
|
#[error("Task execution failed: {0}")]
|
||||||
ExecutionFailed(#[from] anyhow::Error),
|
ExecutionFailed(#[from] anyhow::Error),
|
||||||
|
|
||||||
#[error("Task deserialization failed: {0}")]
|
#[error("Task panicked with: {0}")]
|
||||||
TaskDeserializationFailed(#[from] serde_json::Error),
|
Panicked(String),
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn runnable<BT>(
|
pub(crate) fn runnable<BT>(
|
||||||
|
@ -144,9 +147,18 @@ where
|
||||||
.get(&task.task_name)
|
.get(&task.task_name)
|
||||||
.ok_or_else(|| AsyncQueueError::TaskNotRegistered(task.task_name.clone()))?;
|
.ok_or_else(|| AsyncQueueError::TaskNotRegistered(task.task_name.clone()))?;
|
||||||
|
|
||||||
// TODO: catch panics
|
// catch panics
|
||||||
let result: Result<(), TaskExecError> =
|
let result: Result<(), TaskExecError> = CatchUnwindFuture::create({
|
||||||
runnable_task_caller(task_info, task.payload.clone(), (self.app_data_fn)()).await;
|
let task_payload = task.payload.clone();
|
||||||
|
let app_data = (self.app_data_fn)();
|
||||||
|
let runnable_task_caller = runnable_task_caller.clone();
|
||||||
|
async move { runnable_task_caller(task_info, task_payload, app_data).await }
|
||||||
|
})
|
||||||
|
.await
|
||||||
|
.and_then(|result| {
|
||||||
|
result?;
|
||||||
|
Ok(())
|
||||||
|
});
|
||||||
|
|
||||||
match &result {
|
match &result {
|
||||||
Ok(_) => self.finalize_task(task, result).await?,
|
Ok(_) => self.finalize_task(task, result).await?,
|
||||||
|
@ -159,7 +171,9 @@ where
|
||||||
task.id,
|
task.id,
|
||||||
backoff_seconds
|
backoff_seconds
|
||||||
);
|
);
|
||||||
|
|
||||||
let error_message = format!("{}", error);
|
let error_message = format!("{}", error);
|
||||||
|
|
||||||
self.store
|
self.store
|
||||||
.schedule_task_retry(task.id, backoff_seconds, &error_message)
|
.schedule_task_retry(task.id, backoff_seconds, &error_message)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
|
@ -379,7 +379,7 @@ mod tests {
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
struct NotifyFinishedContext {
|
struct NotifyFinishedContext {
|
||||||
/// Used to notify the task ran
|
/// Used to notify the task ran
|
||||||
tx: Arc<Mutex<Option<tokio::sync::oneshot::Sender<()>>>>,
|
notify_finished: Arc<Mutex<Option<tokio::sync::oneshot::Sender<()>>>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// A task that notifies the test that it ran
|
/// A task that notifies the test that it ran
|
||||||
|
@ -398,7 +398,7 @@ mod tests {
|
||||||
context: Self::AppData,
|
context: Self::AppData,
|
||||||
) -> Result<(), anyhow::Error> {
|
) -> Result<(), anyhow::Error> {
|
||||||
// Notify the test that the task ran
|
// Notify the test that the task ran
|
||||||
match context.tx.lock().await.take() {
|
match context.notify_finished.lock().await.take() {
|
||||||
None => println!("Cannot notify, already done that!"),
|
None => println!("Cannot notify, already done that!"),
|
||||||
Some(tx) => {
|
Some(tx) => {
|
||||||
tx.send(()).unwrap();
|
tx.send(()).unwrap();
|
||||||
|
@ -412,7 +412,7 @@ mod tests {
|
||||||
let (tx, rx) = tokio::sync::oneshot::channel();
|
let (tx, rx) = tokio::sync::oneshot::channel();
|
||||||
|
|
||||||
let my_app_context = NotifyFinishedContext {
|
let my_app_context = NotifyFinishedContext {
|
||||||
tx: Arc::new(Mutex::new(Some(tx))),
|
notify_finished: Arc::new(Mutex::new(Some(tx))),
|
||||||
};
|
};
|
||||||
|
|
||||||
let (join_handle, queue) =
|
let (join_handle, queue) =
|
||||||
|
@ -529,6 +529,57 @@ mod tests {
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn task_can_panic_and_not_affect_worker() {
|
||||||
|
#[derive(Clone, serde::Serialize, serde::Deserialize)]
|
||||||
|
struct BrokenTask;
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl BackgroundTask for BrokenTask {
|
||||||
|
const TASK_NAME: &'static str = "panic_me";
|
||||||
|
type AppData = ();
|
||||||
|
|
||||||
|
async fn run(
|
||||||
|
&self,
|
||||||
|
_task: CurrentTask,
|
||||||
|
_context: Self::AppData,
|
||||||
|
) -> Result<(), anyhow::Error> {
|
||||||
|
panic!("Oh no!");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let (notify_stop_worker_pool, should_stop) = tokio::sync::oneshot::channel();
|
||||||
|
|
||||||
|
let task_store = memory_store().await;
|
||||||
|
|
||||||
|
let (worker_pool_finished, queue) = WorkerPool::new(task_store.clone(), |_| ())
|
||||||
|
.register_task_type::<BrokenTask>()
|
||||||
|
.configure_queue("default".into())
|
||||||
|
.start(async move {
|
||||||
|
should_stop.await.unwrap();
|
||||||
|
})
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
// Enqueue a task that will panic
|
||||||
|
queue.enqueue(BrokenTask).await.unwrap();
|
||||||
|
|
||||||
|
notify_stop_worker_pool.send(()).unwrap();
|
||||||
|
worker_pool_finished.await.unwrap();
|
||||||
|
|
||||||
|
let raw_task = task_store
|
||||||
|
.tasks
|
||||||
|
.lock()
|
||||||
|
.await
|
||||||
|
.first_entry()
|
||||||
|
.unwrap()
|
||||||
|
.remove();
|
||||||
|
assert_eq!(
|
||||||
|
serde_json::to_string(&raw_task.error_info.unwrap()).unwrap(),
|
||||||
|
"{\"error\":\"Task panicked with: Oh no!\"}"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
/// This test will make sure that the worker pool will only stop after all workers are done.
|
/// This test will make sure that the worker pool will only stop after all workers are done.
|
||||||
/// We create a KeepAliveTask that will keep running until we notify it to stop.
|
/// We create a KeepAliveTask that will keep running until we notify it to stop.
|
||||||
/// We stop the worker pool and make sure that the KeepAliveTask is still running.
|
/// We stop the worker pool and make sure that the KeepAliveTask is still running.
|
||||||
|
|
Loading…
Reference in a new issue