From aa1144e54f5971ebdb813a264bde32d8a32ce6e3 Mon Sep 17 00:00:00 2001 From: Rafael Caricio Date: Mon, 13 Mar 2023 17:46:59 +0100 Subject: [PATCH] Allow definition of custom error type --- Cargo.toml | 1 - README.md | 7 ++- examples/simple_worker/src/main.rs | 6 ++- src/runnable.rs | 9 +++- src/worker.rs | 30 +++++++---- src/worker_pool.rs | 84 ++++++++++++++++++++---------- 6 files changed, 91 insertions(+), 46 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 85d9de4..1d8feb1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,7 +17,6 @@ chrono = "0.4" log = "0.4" serde = { version = "1", features = ["derive"] } serde_json = "1" -anyhow = "1" thiserror = "1" uuid = { version = "1.1", features = ["v4", "serde"] } async-trait = "0.1" diff --git a/README.md b/README.md index 518bfc1..c7ec6f8 100644 --- a/README.md +++ b/README.md @@ -53,7 +53,6 @@ If you are not already using, you will also want to include the following depend ```toml [dependencies] async-trait = "0.1" -anyhow = "1" serde = { version = "1.0", features = ["derive"] } diesel = { version = "2.0", features = ["postgres", "serde_json", "chrono", "uuid"] } diesel-async = { version = "0.2", features = ["postgres", "bb8"] } @@ -75,6 +74,9 @@ the whole application. This attribute is critical for reconstructing the task ba The [`BackgroundTask::AppData`] can be used to argument the task with your application specific contextual information. This is useful for example to pass a database connection pool to the task or other application configuration. +The [`BackgroundTask::Error`] is the error type that will be returned by the [`BackgroundTask::run`] method. You can +use this to define your own error type for your tasks. + The [`BackgroundTask::run`] method is where you define the behaviour of your background task execution. This method will be called by the task queue workers. @@ -92,8 +94,9 @@ pub struct MyTask { impl BackgroundTask for MyTask { const TASK_NAME: &'static str = "my_task_unique_name"; type AppData = (); + type Error = (); - async fn run(&self, task: CurrentTask, context: Self::AppData) -> Result<(), anyhow::Error> { + async fn run(&self, task: CurrentTask, context: Self::AppData) -> Result<(), Self::Error> { // Do something Ok(()) } diff --git a/examples/simple_worker/src/main.rs b/examples/simple_worker/src/main.rs index 9f39b86..08f571f 100644 --- a/examples/simple_worker/src/main.rs +++ b/examples/simple_worker/src/main.rs @@ -34,8 +34,9 @@ impl MyTask { impl BackgroundTask for MyTask { const TASK_NAME: &'static str = "my_task"; type AppData = MyApplicationContext; + type Error = anyhow::Error; - async fn run(&self, task: CurrentTask, ctx: Self::AppData) -> Result<(), anyhow::Error> { + async fn run(&self, task: CurrentTask, ctx: Self::AppData) -> Result<(), Self::Error> { // let new_task = MyTask::new(self.number + 1); // queue // .insert_task(&new_task) @@ -70,8 +71,9 @@ impl MyFailingTask { impl BackgroundTask for MyFailingTask { const TASK_NAME: &'static str = "my_failing_task"; type AppData = MyApplicationContext; + type Error = anyhow::Error; - async fn run(&self, task: CurrentTask, _ctx: Self::AppData) -> Result<(), anyhow::Error> { + async fn run(&self, task: CurrentTask, _ctx: Self::AppData) -> Result<(), Self::Error> { // let new_task = MyFailingTask::new(self.number + 1); // queue // .insert_task(&new_task) diff --git a/src/runnable.rs b/src/runnable.rs index 4079ee3..e272a2e 100644 --- a/src/runnable.rs +++ b/src/runnable.rs @@ -1,6 +1,7 @@ use crate::task::{CurrentTask, TaskHash}; use async_trait::async_trait; use serde::{de::DeserializeOwned, ser::Serialize}; +use std::fmt::Debug; /// The [`BackgroundTask`] trait is used to define the behaviour of a task. You must implement this /// trait for all tasks you want to execute. @@ -29,8 +30,9 @@ use serde::{de::DeserializeOwned, ser::Serialize}; /// impl BackgroundTask for MyTask { /// const TASK_NAME: &'static str = "my_task_unique_name"; /// type AppData = (); +/// type Error = Box; /// -/// async fn run(&self, task: CurrentTask, context: Self::AppData) -> Result<(), anyhow::Error> { +/// async fn run(&self, task: CurrentTask, context: Self::AppData) -> Result<(), Self::Error> { /// // Do something /// Ok(()) /// } @@ -57,8 +59,11 @@ pub trait BackgroundTask: Serialize + DeserializeOwned + Sync + Send + 'static { /// The application data provided to this task at runtime. type AppData: Clone + Send + 'static; + /// An application custom error type. + type Error: Debug + Send + 'static; + /// Execute the task. This method should define its logic - async fn run(&self, task: CurrentTask, context: Self::AppData) -> Result<(), anyhow::Error>; + async fn run(&self, task: CurrentTask, context: Self::AppData) -> Result<(), Self::Error>; /// If set to true, no new tasks with the same metadata will be inserted /// By default it is set to false. diff --git a/src/worker.rs b/src/worker.rs index ec827ec..c545b2e 100644 --- a/src/worker.rs +++ b/src/worker.rs @@ -30,7 +30,7 @@ pub enum TaskExecError { TaskDeserializationFailed(#[from] serde_json::Error), #[error("Task execution failed: {0}")] - ExecutionFailed(#[from] anyhow::Error), + ExecutionFailed(String), #[error("Task panicked with: {0}")] Panicked(String), @@ -46,8 +46,10 @@ where { Box::pin(async move { let background_task: BT = serde_json::from_value(payload)?; - background_task.run(task_info, app_context).await?; - Ok(()) + match background_task.run(task_info, app_context).await { + Ok(_) => Ok(()), + Err(err) => Err(TaskExecError::ExecutionFailed(format!("{:?}", err))), + } }) } @@ -250,8 +252,9 @@ mod async_worker_tests { impl BackgroundTask for WorkerAsyncTask { const TASK_NAME: &'static str = "WorkerAsyncTask"; type AppData = (); + type Error = (); - async fn run(&self, _: CurrentTask, _: Self::AppData) -> Result<(), anyhow::Error> { + async fn run(&self, _: CurrentTask, _: Self::AppData) -> Result<(), ()> { Ok(()) } } @@ -265,8 +268,9 @@ mod async_worker_tests { impl BackgroundTask for WorkerAsyncTaskSchedule { const TASK_NAME: &'static str = "WorkerAsyncTaskSchedule"; type AppData = (); + type Error = (); - async fn run(&self, _task: CurrentTask, _data: Self::AppData) -> Result<(), anyhow::Error> { + async fn run(&self, _task: CurrentTask, _data: Self::AppData) -> Result<(), ()> { Ok(()) } @@ -284,11 +288,12 @@ mod async_worker_tests { impl BackgroundTask for AsyncFailedTask { const TASK_NAME: &'static str = "AsyncFailedTask"; type AppData = (); + type Error = TaskError; - async fn run(&self, _task: CurrentTask, _data: Self::AppData) -> Result<(), anyhow::Error> { + async fn run(&self, _task: CurrentTask, _data: Self::AppData) -> Result<(), TaskError> { let message = format!("number {} is wrong :(", self.number); - Err(TaskError::Custom(message).into()) + Err(TaskError::Custom(message)) } fn max_retries(&self) -> i32 { @@ -303,9 +308,10 @@ mod async_worker_tests { impl BackgroundTask for AsyncRetryTask { const TASK_NAME: &'static str = "AsyncRetryTask"; type AppData = (); + type Error = TaskError; - async fn run(&self, _task: CurrentTask, _data: Self::AppData) -> Result<(), anyhow::Error> { - Err(TaskError::SomethingWrong.into()) + async fn run(&self, _task: CurrentTask, _data: Self::AppData) -> Result<(), Self::Error> { + Err(TaskError::SomethingWrong) } } @@ -316,8 +322,9 @@ mod async_worker_tests { impl BackgroundTask for AsyncTaskType1 { const TASK_NAME: &'static str = "AsyncTaskType1"; type AppData = (); + type Error = (); - async fn run(&self, _task: CurrentTask, _data: Self::AppData) -> Result<(), anyhow::Error> { + async fn run(&self, _task: CurrentTask, _data: Self::AppData) -> Result<(), Self::Error> { Ok(()) } } @@ -329,8 +336,9 @@ mod async_worker_tests { impl BackgroundTask for AsyncTaskType2 { const TASK_NAME: &'static str = "AsyncTaskType2"; type AppData = (); + type Error = (); - async fn run(&self, _task: CurrentTask, _data: Self::AppData) -> Result<(), anyhow::Error> { + async fn run(&self, _task: CurrentTask, _data: Self::AppData) -> Result<(), ()> { Ok(()) } } diff --git a/src/worker_pool.rs b/src/worker_pool.rs index 7c15bac..d9c4e53 100644 --- a/src/worker_pool.rs +++ b/src/worker_pool.rs @@ -240,7 +240,7 @@ mod tests { use tokio::sync::Mutex; #[derive(Clone, Debug)] - struct ApplicationContext { + pub struct ApplicationContext { app_name: String, } @@ -261,17 +261,50 @@ mod tests { person: String, } + /// This tests that one can customize the task parameters for the application. #[async_trait] - impl BackgroundTask for GreetingTask { - const TASK_NAME: &'static str = "my_task"; + trait MyAppTask { + const TASK_NAME: &'static str; + const QUEUE: &'static str = "default"; + + async fn run( + &self, + task_info: CurrentTask, + app_context: ApplicationContext, + ) -> Result<(), ()>; + } + + #[async_trait] + impl BackgroundTask for T + where + T: MyAppTask + serde::de::DeserializeOwned + serde::ser::Serialize + Sync + Send + 'static, + { + const TASK_NAME: &'static str = T::TASK_NAME; + + const QUEUE: &'static str = T::QUEUE; type AppData = ApplicationContext; + type Error = (); + async fn run( &self, task_info: CurrentTask, app_context: Self::AppData, - ) -> Result<(), anyhow::Error> { + ) -> Result<(), Self::Error> { + self.run(task_info, app_context).await + } + } + + #[async_trait] + impl MyAppTask for GreetingTask { + const TASK_NAME: &'static str = "my_task"; + + async fn run( + &self, + task_info: CurrentTask, + app_context: ApplicationContext, + ) -> Result<(), ()> { println!( "[{}] Hello {}! I'm {}.", task_info.id(), @@ -292,12 +325,9 @@ mod tests { const QUEUE: &'static str = "other_queue"; type AppData = ApplicationContext; + type Error = (); - async fn run( - &self, - task: CurrentTask, - context: Self::AppData, - ) -> Result<(), anyhow::Error> { + async fn run(&self, task: CurrentTask, context: Self::AppData) -> Result<(), Self::Error> { println!( "[{}] Other task with {}!", task.id(), @@ -332,7 +362,7 @@ mod tests { let (join_handle, queue) = WorkerPool::new(memory_store().await, move |_| my_app_context.clone()) .register_task_type::() - .configure_queue(GreetingTask::QUEUE.into()) + .configure_queue(::QUEUE.into()) .start(futures::future::ready(())) .await .unwrap(); @@ -391,11 +421,9 @@ mod tests { type AppData = NotifyFinishedContext; - async fn run( - &self, - task: CurrentTask, - context: Self::AppData, - ) -> Result<(), anyhow::Error> { + type Error = (); + + async fn run(&self, task: CurrentTask, context: Self::AppData) -> Result<(), ()> { // Notify the test that the task ran match context.notify_finished.lock().await.take() { None => println!("Cannot notify, already done that!"), @@ -455,11 +483,13 @@ mod tests { type AppData = NotifyUnknownRanContext; + type Error = (); + async fn run( &self, task: CurrentTask, context: Self::AppData, - ) -> Result<(), anyhow::Error> { + ) -> Result<(), Self::Error> { // Notify the test that the task ran match context.should_stop.lock().await.take() { None => println!("Cannot notify, already done that!"), @@ -481,11 +511,9 @@ mod tests { type AppData = NotifyUnknownRanContext; - async fn run( - &self, - task: CurrentTask, - context: Self::AppData, - ) -> Result<(), anyhow::Error> { + type Error = (); + + async fn run(&self, task: CurrentTask, context: Self::AppData) -> Result<(), ()> { println!("[{}] Unknown task ran!", task.id()); context.unknown_task_ran.store(true, Ordering::Relaxed); Ok(()) @@ -537,12 +565,9 @@ mod tests { impl BackgroundTask for BrokenTask { const TASK_NAME: &'static str = "panic_me"; type AppData = (); + type Error = (); - async fn run( - &self, - _task: CurrentTask, - _context: Self::AppData, - ) -> Result<(), anyhow::Error> { + async fn run(&self, _task: CurrentTask, _context: Self::AppData) -> Result<(), ()> { panic!("Oh no!"); } } @@ -609,11 +634,13 @@ mod tests { type AppData = PlayerContext; + type Error = (); + async fn run( &self, _task: CurrentTask, context: Self::AppData, - ) -> Result<(), anyhow::Error> { + ) -> Result<(), Self::Error> { loop { let msg = context.ping_rx.lock().await.recv().await.unwrap(); match msg { @@ -696,7 +723,8 @@ mod tests { WorkerPool::new(pg_task_store().await, move |_| my_app_context.clone()) .register_task_type::() .configure_queue( - QueueConfig::new(GreetingTask::QUEUE).retention_mode(RetentionMode::RemoveDone), + QueueConfig::new(::QUEUE) + .retention_mode(RetentionMode::RemoveDone), ) .start(futures::future::ready(())) .await