diff --git a/.gitignore b/.gitignore index 9db20c8..147e43c 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,4 @@ **/target Cargo.lock -src/schema.rs docs/content/docs/CHANGELOG.md docs/content/docs/README.md diff --git a/Cargo.toml b/Cargo.toml index e3d707f..fe77e32 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,16 +19,13 @@ cron = "0.12" chrono = "0.4" hex = "0.4" log = "0.4" -serde = "1.0" -serde_derive = "1.0" -serde_json = "1.0" +serde = { version = "1", features = ["derive"] } +serde_json = "1" sha2 = "0.10" -thiserror = "1.0" -typed-builder = "0.13" -typetag = "0.2" +anyhow = "1" +thiserror = "1" uuid = { version = "1.1", features = ["v4", "serde"] } async-trait = "0.1" -async-recursion = "1" futures = "0.3" diesel = { version = "2.0", features = ["postgres", "serde_json", "chrono", "uuid"] } diesel-derive-newtype = "2.0.0-rc.0" diff --git a/examples/simple_cron_async_worker/Cargo.toml b/examples/simple_cron_async_worker/Cargo.toml deleted file mode 100644 index 8e88150..0000000 --- a/examples/simple_cron_async_worker/Cargo.toml +++ /dev/null @@ -1,12 +0,0 @@ -[package] -name = "simple_cron_async_worker" -version = "0.1.0" -edition = "2021" - -# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html - -[dependencies] -fang = { path = "../../../" , features = ["asynk"]} -env_logger = "0.9.0" -log = "0.4.0" -tokio = { version = "1", features = ["full"] } diff --git a/examples/simple_cron_async_worker/src/lib.rs b/examples/simple_cron_async_worker/src/lib.rs deleted file mode 100644 index 2bb972b..0000000 --- a/examples/simple_cron_async_worker/src/lib.rs +++ /dev/null @@ -1,33 +0,0 @@ -use fang::async_trait; -use fang::asynk::async_queue::AsyncQueueable; -use fang::serde::{Deserialize, Serialize}; -use fang::typetag; -use fang::AsyncRunnable; -use fang::FangError; -use fang::Scheduled; - -#[derive(Serialize, Deserialize)] -#[serde(crate = "fang::serde")] -pub struct MyCronTask {} - -#[async_trait] -#[typetag::serde] -impl AsyncRunnable for MyCronTask { - async fn run(&self, _queue: &mut dyn AsyncQueueable) -> Result<(), FangError> { - log::info!("CRON!!!!!!!!!!!!!!!",); - - Ok(()) - } - - fn cron(&self) -> Option { - // sec min hour day of month month day of week year - // be careful works only with UTC hour. - // https://www.timeanddate.com/worldclock/timezone/utc - let expression = "0/20 * * * Aug-Sep * 2022/1"; - Some(Scheduled::CronPattern(expression.to_string())) - } - - fn uniq(&self) -> bool { - true - } -} diff --git a/examples/simple_cron_async_worker/src/main.rs b/examples/simple_cron_async_worker/src/main.rs deleted file mode 100644 index 37c6197..0000000 --- a/examples/simple_cron_async_worker/src/main.rs +++ /dev/null @@ -1,41 +0,0 @@ -use fang::asynk::async_queue::AsyncQueue; -use fang::asynk::async_queue::AsyncQueueable; -use fang::asynk::async_worker_pool::AsyncWorkerPool; -use fang::AsyncRunnable; -use fang::NoTls; -use simple_cron_async_worker::MyCronTask; -use std::time::Duration; - -#[tokio::main] -async fn main() { - env_logger::init(); - - log::info!("Starting..."); - let max_pool_size: u32 = 3; - let mut queue = AsyncQueue::builder() - .uri("postgres://postgres:postgres@localhost/fang") - .max_pool_size(max_pool_size) - .build(); - - queue.connect(NoTls).await.unwrap(); - log::info!("Queue connected..."); - - let mut pool: AsyncWorkerPool> = AsyncWorkerPool::builder() - .number_of_workers(10_u32) - .queue(queue.clone()) - .build(); - - log::info!("Pool created ..."); - - pool.start().await; - log::info!("Workers started ..."); - - let task = MyCronTask {}; - - queue - .schedule_task(&task as &dyn AsyncRunnable) - .await - .unwrap(); - - tokio::time::sleep(Duration::from_secs(100)).await; -} diff --git a/examples/simple_worker/Cargo.toml b/examples/simple_worker/Cargo.toml index efd1d1a..5b3a9c1 100644 --- a/examples/simple_worker/Cargo.toml +++ b/examples/simple_worker/Cargo.toml @@ -5,6 +5,7 @@ edition = "2021" [dependencies] backie = { path = "../../" } +anyhow = "1" env_logger = "0.9.0" log = "0.4.0" tokio = { version = "1", features = ["full"] } @@ -12,4 +13,3 @@ diesel-async = { version = "0.2", features = ["postgres", "bb8"] } diesel = { version = "2.0", features = ["postgres"] } async-trait = "0.1" serde = { version = "1.0", features = ["derive"] } -typetag = "0.2" diff --git a/examples/simple_worker/src/lib.rs b/examples/simple_worker/src/lib.rs index ee847f0..4776108 100644 --- a/examples/simple_worker/src/lib.rs +++ b/examples/simple_worker/src/lib.rs @@ -1,7 +1,20 @@ -use std::time::Duration; use async_trait::async_trait; -use serde::{Serialize, Deserialize}; -use backie::{RunnableTask, Queueable}; +use backie::{BackgroundTask, CurrentTask}; +use serde::{Deserialize, Serialize}; +use std::time::Duration; + +#[derive(Clone, Debug)] +pub struct MyApplicationContext { + app_name: String, +} + +impl MyApplicationContext { + pub fn new(app_name: &str) -> Self { + Self { + app_name: app_name.to_string(), + } + } +} #[derive(Serialize, Deserialize)] pub struct MyTask { @@ -26,37 +39,51 @@ impl MyFailingTask { } #[async_trait] -#[typetag::serde] -impl RunnableTask for MyTask { - async fn run(&self, _queue: &mut dyn Queueable) -> Result<(), Box> { +impl BackgroundTask for MyTask { + const TASK_NAME: &'static str = "my_task"; + type AppData = MyApplicationContext; + + async fn run(&self, task: CurrentTask, ctx: Self::AppData) -> Result<(), anyhow::Error> { // let new_task = MyTask::new(self.number + 1); // queue - // .insert_task(&new_task as &dyn AsyncRunnable) + // .insert_task(&new_task) // .await // .unwrap(); - log::info!("the current number is {}", self.number); + log::info!( + "[{}] Hello from {}! the current number is {}", + task.id(), + ctx.app_name, + self.number + ); tokio::time::sleep(Duration::from_secs(3)).await; - log::info!("done.."); + log::info!("[{}] done..", task.id()); Ok(()) } } #[async_trait] -#[typetag::serde] -impl RunnableTask for MyFailingTask { - async fn run(&self, _queue: &mut dyn Queueable) -> Result<(), Box> { +impl BackgroundTask for MyFailingTask { + const TASK_NAME: &'static str = "my_failing_task"; + type AppData = MyApplicationContext; + + async fn run(&self, task: CurrentTask, _ctx: Self::AppData) -> Result<(), anyhow::Error> { // let new_task = MyFailingTask::new(self.number + 1); // queue - // .insert_task(&new_task as &dyn AsyncRunnable) + // .insert_task(&new_task) // .await // .unwrap(); - log::info!("the current number is {}", self.number); + // task.id(); + // task.keep_alive().await?; + // task.previous_error(); + // task.retry_count(); + + log::info!("[{}] the current number is {}", task.id(), self.number); tokio::time::sleep(Duration::from_secs(3)).await; - log::info!("done.."); + log::info!("[{}] done..", task.id()); // // let b = true; // diff --git a/examples/simple_worker/src/main.rs b/examples/simple_worker/src/main.rs index ed207aa..541e564 100644 --- a/examples/simple_worker/src/main.rs +++ b/examples/simple_worker/src/main.rs @@ -1,9 +1,9 @@ -use simple_worker::MyFailingTask; -use simple_worker::MyTask; -use std::time::Duration; +use backie::{PgTaskStore, RetentionMode, WorkerPool}; use diesel_async::pg::AsyncPgConnection; use diesel_async::pooled_connection::{bb8::Pool, AsyncDieselConnectionManager}; -use backie::{PgAsyncQueue, WorkerPool, Queueable}; +use simple_worker::MyApplicationContext; +use simple_worker::MyFailingTask; +use simple_worker::MyTask; #[tokio::main] async fn main() { @@ -22,49 +22,38 @@ async fn main() { .unwrap(); log::info!("Pool created ..."); - let mut queue = PgAsyncQueue::new(pool); + let task_store = PgTaskStore::new(pool); let (tx, mut rx) = tokio::sync::watch::channel(false); - let executor_task = tokio::spawn({ - let mut queue = queue.clone(); - async move { - let mut workers_pool: WorkerPool = WorkerPool::builder() - .number_of_workers(10_u32) - .queue(queue) - .build(); + // Some global application context I want to pass to my background tasks + let my_app_context = MyApplicationContext::new("Backie Example App"); - log::info!("Workers starting ..."); - workers_pool.start(async move { - rx.changed().await; - }).await; - log::info!("Workers stopped!"); - } - }); + // Register the task types I want to use and start the worker pool + let (join_handle, queue) = WorkerPool::new(task_store, move |_| my_app_context.clone()) + .register_task_type::(1, RetentionMode::RemoveDone) + .register_task_type::(1, RetentionMode::RemoveDone) + .start(async move { + let _ = rx.changed().await; + }) + .await + .unwrap(); + + log::info!("Workers started ..."); let task1 = MyTask::new(0); let task2 = MyTask::new(20_000); let task3 = MyFailingTask::new(50_000); - queue - .create_task(&task1) - .await - .unwrap(); - - queue - .create_task(&task2) - .await - .unwrap(); - - queue - .create_task(&task3) - .await - .unwrap(); - + queue.enqueue(task1).await.unwrap(); + queue.enqueue(task2).await.unwrap(); + queue.enqueue(task3).await.unwrap(); log::info!("Tasks created ..."); - tokio::signal::ctrl_c().await; + + // Wait for Ctrl+C + let _ = tokio::signal::ctrl_c().await; log::info!("Stopping ..."); tx.send(true).unwrap(); - executor_task.await.unwrap(); - log::info!("Stopped!"); + join_handle.await.unwrap(); + log::info!("Workers Stopped!"); } diff --git a/migrations/2023-03-06-151907_create_backie_tasks/up.sql b/migrations/2023-03-06-151907_create_backie_tasks/up.sql index 1baaac5..919fa9d 100644 --- a/migrations/2023-03-06-151907_create_backie_tasks/up.sql +++ b/migrations/2023-03-06-151907_create_backie_tasks/up.sql @@ -2,20 +2,20 @@ CREATE EXTENSION IF NOT EXISTS "uuid-ossp"; CREATE TABLE backie_tasks ( id uuid PRIMARY KEY DEFAULT uuid_generate_v4(), - payload jsonb NOT NULL, - error_message TEXT DEFAULT NULL, - task_type VARCHAR DEFAULT 'common' NOT NULL, + task_name VARCHAR NOT NULL, + queue_name VARCHAR DEFAULT 'common' NOT NULL, uniq_hash CHAR(64) DEFAULT NULL, - retries INTEGER DEFAULT 0 NOT NULL, + payload jsonb NOT NULL, + timeout_msecs INT8 NOT NULL, created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(), + scheduled_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(), running_at TIMESTAMP WITH TIME ZONE DEFAULT NULL, - done_at TIMESTAMP WITH TIME ZONE DEFAULT NULL + done_at TIMESTAMP WITH TIME ZONE DEFAULT NULL, + error_info jsonb DEFAULT NULL, + retries INTEGER DEFAULT 0 NOT NULL, + max_retries INTEGER DEFAULT 0 NOT NULL ); -CREATE INDEX backie_tasks_type_index ON backie_tasks(task_type); -CREATE INDEX backie_tasks_created_at_index ON backie_tasks(created_at); -CREATE INDEX backie_tasks_uniq_hash ON backie_tasks(uniq_hash); - --- create uniqueness index CREATE UNIQUE INDEX backie_tasks_uniq_hash_index ON backie_tasks(uniq_hash) WHERE uniq_hash IS NOT NULL; diff --git a/src/errors.rs b/src/errors.rs index db2fcf1..b3b2a04 100644 --- a/src/errors.rs +++ b/src/errors.rs @@ -1,25 +1,16 @@ -use serde_json::Error as SerdeError; -use std::fmt::Display; use thiserror::Error; /// Library errors #[derive(Debug, Error)] pub enum BackieError { + #[error("Queue processing error: {0}")] QueueProcessingError(#[from] AsyncQueueError), - SerializationError(#[from] SerdeError), - ShutdownError(#[from] tokio::sync::watch::error::SendError<()>), -} -impl Display for BackieError { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - BackieError::QueueProcessingError(error) => { - write!(f, "Queue processing error: {}", error) - } - BackieError::SerializationError(error) => write!(f, "Serialization error: {}", error), - BackieError::ShutdownError(error) => write!(f, "Shutdown error: {}", error), - } - } + #[error("Worker Pool shutdown error: {0}")] + WorkerPoolShutdownError(#[from] tokio::sync::watch::error::SendError<()>), + + #[error("Worker shutdown error: {0}")] + WorkerShutdownError(#[from] tokio::sync::watch::error::RecvError), } /// List of error types that can occur while working with cron schedules. @@ -29,7 +20,7 @@ pub enum CronError { #[error(transparent)] LibraryError(#[from] cron::error::Error), /// [`Scheduled`] enum variant is not provided - #[error("You have to implement method `cron()` in your AsyncRunnable")] + #[error("You have to implement method `cron()` in your Runnable")] TaskNotSchedulableError, /// The next execution can not be determined using the current [`Scheduled::CronPattern`] #[error("No timestamps match with this cron pattern")] @@ -41,7 +32,7 @@ pub enum AsyncQueueError { #[error(transparent)] PgError(#[from] diesel::result::Error), - #[error(transparent)] + #[error("Task serialization error: {0}")] SerdeError(#[from] serde_json::Error), #[error(transparent)] @@ -49,6 +40,9 @@ pub enum AsyncQueueError { #[error("Task is not in progress, operation not allowed")] TaskNotRunning, + + #[error("Task with name {0} is not registered")] + TaskNotRegistered(String), } impl From for AsyncQueueError { diff --git a/src/lib.rs b/src/lib.rs index e4958b1..23a7c92 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -38,9 +38,9 @@ impl Default for RetentionMode { } } -pub use queue::PgAsyncQueue; -pub use queue::Queueable; -pub use runnable::RunnableTask; +pub use queue::PgTaskStore; +pub use runnable::BackgroundTask; +pub use task::CurrentTask; pub use worker_pool::WorkerPool; pub mod errors; @@ -48,6 +48,7 @@ mod queries; pub mod queue; pub mod runnable; mod schema; +pub mod store; pub mod task; pub mod worker; pub mod worker_pool; diff --git a/src/queries.rs b/src/queries.rs index 2640ef1..0343bcf 100644 --- a/src/queries.rs +++ b/src/queries.rs @@ -1,8 +1,7 @@ use crate::errors::AsyncQueueError; -use crate::runnable::RunnableTask; use crate::schema::backie_tasks; use crate::task::Task; -use crate::task::{NewTask, TaskHash, TaskId, TaskType}; +use crate::task::{NewTask, TaskHash, TaskId}; use chrono::DateTime; use chrono::Duration; use chrono::Utc; @@ -43,14 +42,6 @@ impl Task { Ok(qty > 0) } - pub(crate) async fn remove_by_type( - connection: &mut AsyncPgConnection, - task_type: TaskType, - ) -> Result { - let query = backie_tasks::table.filter(backie_tasks::task_type.eq(task_type)); - Ok(diesel::delete(query).execute(connection).await? as u64) - } - pub(crate) async fn find_by_id( connection: &mut AsyncPgConnection, id: TaskId, @@ -67,10 +58,13 @@ impl Task { id: TaskId, error_message: &str, ) -> Result { + let error = serde_json::json!({ + "error": error_message, + }); let query = backie_tasks::table.filter(backie_tasks::id.eq(id)); Ok(diesel::update(query) .set(( - backie_tasks::error_message.eq(error_message), + backie_tasks::error_info.eq(Some(error)), backie_tasks::done_at.eq(Utc::now()), )) .get_result::(connection) @@ -81,18 +75,22 @@ impl Task { connection: &mut AsyncPgConnection, id: TaskId, backoff_seconds: u32, - error: &str, + error_message: &str, ) -> Result { use crate::schema::backie_tasks::dsl; let now = Utc::now(); let scheduled_at = now + Duration::seconds(backoff_seconds as i64); + let error = serde_json::json!({ + "error": error_message, + }); + let task = diesel::update(backie_tasks::table.filter(backie_tasks::id.eq(id))) .set(( - backie_tasks::error_message.eq(error), + backie_tasks::error_info.eq(Some(error)), backie_tasks::retries.eq(dsl::retries + 1), - backie_tasks::created_at.eq(scheduled_at), + backie_tasks::scheduled_at.eq(scheduled_at), backie_tasks::running_at.eq::>>(None), )) .get_result::(connection) @@ -103,14 +101,14 @@ impl Task { pub(crate) async fn fetch_next_pending( connection: &mut AsyncPgConnection, - task_type: TaskType, + queue_name: &str, ) -> Option { backie_tasks::table - .filter(backie_tasks::created_at.lt(Utc::now())) // skip tasks scheduled for the future + .filter(backie_tasks::scheduled_at.lt(Utc::now())) // skip tasks scheduled for the future .order(backie_tasks::created_at.asc()) // get the oldest task first .filter(backie_tasks::running_at.is_null()) // that is not marked as running already .filter(backie_tasks::done_at.is_null()) // and not marked as done - .filter(backie_tasks::task_type.eq(task_type)) + .filter(backie_tasks::queue_name.eq(queue_name)) .limit(1) .for_update() .skip_locked() @@ -143,38 +141,12 @@ impl Task { pub(crate) async fn insert( connection: &mut AsyncPgConnection, - runnable: &dyn RunnableTask, + new_task: NewTask, ) -> Result { - let payload = serde_json::to_value(runnable)?; - match runnable.uniq() { - None => { - let new_task = NewTask::builder() - .uniq_hash(None) - .task_type(runnable.task_type()) - .payload(payload) - .build(); - - Ok(diesel::insert_into(backie_tasks::table) - .values(new_task) - .get_result::(connection) - .await?) - } - Some(hash) => match Self::find_by_uniq_hash(connection, hash.clone()).await { - Some(task) => Ok(task), - None => { - let new_task = NewTask::builder() - .uniq_hash(Some(hash)) - .task_type(runnable.task_type()) - .payload(payload) - .build(); - - Ok(diesel::insert_into(backie_tasks::table) - .values(new_task) - .get_result::(connection) - .await?) - } - }, - } + Ok(diesel::insert_into(backie_tasks::table) + .values(new_task) + .get_result::(connection) + .await?) } pub(crate) async fn find_by_uniq_hash( diff --git a/src/queue.rs b/src/queue.rs index 8455716..281d783 100644 --- a/src/queue.rs +++ b/src/queue.rs @@ -1,86 +1,48 @@ use crate::errors::AsyncQueueError; -use crate::runnable::RunnableTask; -use crate::task::{Task, TaskHash, TaskId, TaskType}; -use async_trait::async_trait; +use crate::runnable::BackgroundTask; +use crate::task::{NewTask, Task, TaskHash, TaskId, TaskState}; use diesel::result::Error::QueryBuilderError; use diesel_async::scoped_futures::ScopedFutureExt; use diesel_async::AsyncConnection; use diesel_async::{pg::AsyncPgConnection, pooled_connection::bb8::Pool}; +use std::sync::Arc; +use std::time::Duration; -/// This trait defines operations for an asynchronous queue. -/// The trait can be implemented for different storage backends. -/// For now, the trait is only implemented for PostgreSQL. More backends are planned to be implemented in the future. -#[async_trait] -pub trait Queueable: Send { - /// Pull pending tasks from the queue to execute them. - /// - /// This method returns one task of the `task_type` type. If `task_type` is `None` it will try to - /// fetch a task of the type `common`. The returned task is marked as running and must be executed. - async fn pull_next_task( - &mut self, - kind: Option, - ) -> Result, AsyncQueueError>; +#[derive(Clone)] +pub struct Queue { + task_store: Arc, +} - /// Enqueue a task to the queue, The task will be executed as soon as possible by the worker of the same type - /// created by an AsyncWorkerPool. - async fn create_task(&mut self, task: &dyn RunnableTask) -> Result; +impl Queue { + pub(crate) fn new(task_store: Arc) -> Self { + Queue { task_store } + } - /// Retrieve a task by its `id`. - async fn find_task_by_id(&mut self, id: TaskId) -> Result; - - /// Update the state of a task to failed and set an error_message. - async fn set_task_failed( - &mut self, - id: TaskId, - error_message: &str, - ) -> Result; - - /// Update the state of a task to done. - async fn set_task_done(&mut self, id: TaskId) -> Result; - - /// Update the state of a task to inform that it's still in progress. - async fn keep_task_alive(&mut self, id: TaskId) -> Result<(), AsyncQueueError>; - - /// Remove a task by its id. - async fn remove_task(&mut self, id: TaskId) -> Result; - - /// The method will remove all tasks from the queue - async fn remove_all_tasks(&mut self) -> Result; - - /// Remove all tasks that are scheduled in the future. - async fn remove_all_scheduled_tasks(&mut self) -> Result; - - /// Remove a task by its metadata (struct fields values) - async fn remove_task_by_hash(&mut self, task_hash: TaskHash) -> Result; - - /// Removes all tasks that have the specified `task_type`. - async fn remove_tasks_type(&mut self, task_type: TaskType) -> Result; - - async fn schedule_task_retry( - &mut self, - id: TaskId, - backoff_seconds: u32, - error: &str, - ) -> Result; + pub async fn enqueue(&self, background_task: BT) -> Result<(), AsyncQueueError> + where + BT: BackgroundTask, + { + self.task_store + .create_task(NewTask::new(background_task, Duration::from_secs(10))?) + .await?; + Ok(()) + } } /// An async queue that is used to manipulate tasks, it uses PostgreSQL as storage. #[derive(Debug, Clone)] -pub struct PgAsyncQueue { +pub struct PgTaskStore { pool: Pool, } -impl PgAsyncQueue { +impl PgTaskStore { pub fn new(pool: Pool) -> Self { - PgAsyncQueue { pool } + PgTaskStore { pool } } -} -#[async_trait] -impl Queueable for PgAsyncQueue { - async fn pull_next_task( - &mut self, - task_type: Option, + pub(crate) async fn pull_next_task( + &self, + queue_name: &str, ) -> Result, AsyncQueueError> { let mut connection = self .pool @@ -90,7 +52,7 @@ impl Queueable for PgAsyncQueue { connection .transaction::, AsyncQueueError, _>(|conn| { async move { - let Some(pending_task) = Task::fetch_next_pending(conn, task_type.unwrap_or_default()).await else { + let Some(pending_task) = Task::fetch_next_pending(conn, queue_name).await else { return Ok(None); }; @@ -101,16 +63,16 @@ impl Queueable for PgAsyncQueue { .await } - async fn create_task(&mut self, runnable: &dyn RunnableTask) -> Result { + pub(crate) async fn create_task(&self, new_task: NewTask) -> Result { let mut connection = self .pool .get() .await .map_err(|e| QueryBuilderError(e.into()))?; - Ok(Task::insert(&mut connection, runnable).await?) + Task::insert(&mut connection, new_task).await } - async fn find_task_by_id(&mut self, id: TaskId) -> Result { + pub(crate) async fn find_task_by_id(&self, id: TaskId) -> Result { let mut connection = self .pool .get() @@ -119,29 +81,27 @@ impl Queueable for PgAsyncQueue { Task::find_by_id(&mut connection, id).await } - async fn set_task_failed( - &mut self, + pub(crate) async fn set_task_state( + &self, id: TaskId, - error_message: &str, - ) -> Result { + state: TaskState, + ) -> Result<(), AsyncQueueError> { let mut connection = self .pool .get() .await .map_err(|e| QueryBuilderError(e.into()))?; - Task::fail_with_message(&mut connection, id, error_message).await + match state { + TaskState::Done => Task::set_done(&mut connection, id).await?, + TaskState::Failed(error_msg) => { + Task::fail_with_message(&mut connection, id, &error_msg).await? + } + _ => return Ok(()), + }; + Ok(()) } - async fn set_task_done(&mut self, id: TaskId) -> Result { - let mut connection = self - .pool - .get() - .await - .map_err(|e| QueryBuilderError(e.into()))?; - Task::set_done(&mut connection, id).await - } - - async fn keep_task_alive(&mut self, id: TaskId) -> Result<(), AsyncQueueError> { + pub(crate) async fn keep_task_alive(&self, id: TaskId) -> Result<(), AsyncQueueError> { let mut connection = self .pool .get() @@ -159,7 +119,7 @@ impl Queueable for PgAsyncQueue { .await } - async fn remove_task(&mut self, id: TaskId) -> Result { + pub(crate) async fn remove_task(&self, id: TaskId) -> Result { let mut connection = self .pool .get() @@ -169,7 +129,7 @@ impl Queueable for PgAsyncQueue { Ok(result) } - async fn remove_all_tasks(&mut self) -> Result { + pub(crate) async fn remove_all_tasks(&self) -> Result { let mut connection = self .pool .get() @@ -178,37 +138,8 @@ impl Queueable for PgAsyncQueue { Task::remove_all(&mut connection).await } - async fn remove_all_scheduled_tasks(&mut self) -> Result { - let mut connection = self - .pool - .get() - .await - .map_err(|e| QueryBuilderError(e.into()))?; - let result = Task::remove_all_scheduled(&mut connection).await?; - Ok(result) - } - - async fn remove_task_by_hash(&mut self, task_hash: TaskHash) -> Result { - let mut connection = self - .pool - .get() - .await - .map_err(|e| QueryBuilderError(e.into()))?; - Task::remove_by_hash(&mut connection, task_hash).await - } - - async fn remove_tasks_type(&mut self, task_type: TaskType) -> Result { - let mut connection = self - .pool - .get() - .await - .map_err(|e| QueryBuilderError(e.into()))?; - let result = Task::remove_by_type(&mut connection, task_type).await?; - Ok(result) - } - - async fn schedule_task_retry( - &mut self, + pub(crate) async fn schedule_task_retry( + &self, id: TaskId, backoff_seconds: u32, error: &str, @@ -226,11 +157,8 @@ impl Queueable for PgAsyncQueue { #[cfg(test)] mod async_queue_tests { use super::*; - use crate::task::TaskState; - use crate::Scheduled; + use crate::CurrentTask; use async_trait::async_trait; - use chrono::DateTime; - use chrono::Utc; use diesel_async::pooled_connection::{bb8::Pool, AsyncDieselConnectionManager}; use diesel_async::AsyncPgConnection; use serde::{Deserialize, Serialize}; @@ -240,13 +168,12 @@ mod async_queue_tests { pub number: u16, } - #[typetag::serde] #[async_trait] - impl RunnableTask for AsyncTask { - async fn run( - &self, - _queueable: &mut dyn Queueable, - ) -> Result<(), Box<(dyn std::error::Error + Send + 'static)>> { + impl BackgroundTask for AsyncTask { + const TASK_NAME: &'static str = "AsyncUniqTask"; + type AppData = (); + + async fn run(&self, _task: CurrentTask, _: Self::AppData) -> Result<(), anyhow::Error> { Ok(()) } } @@ -256,13 +183,12 @@ mod async_queue_tests { pub number: u16, } - #[typetag::serde] #[async_trait] - impl RunnableTask for AsyncUniqTask { - async fn run( - &self, - _queueable: &mut dyn Queueable, - ) -> Result<(), Box<(dyn std::error::Error + Send + 'static)>> { + impl BackgroundTask for AsyncUniqTask { + const TASK_NAME: &'static str = "AsyncUniqTask"; + type AppData = (); + + async fn run(&self, _task: CurrentTask, _: Self::AppData) -> Result<(), anyhow::Error> { Ok(()) } @@ -277,286 +203,154 @@ mod async_queue_tests { pub datetime: String, } - #[typetag::serde] #[async_trait] - impl RunnableTask for AsyncTaskSchedule { - async fn run( - &self, - _queueable: &mut dyn Queueable, - ) -> Result<(), Box<(dyn std::error::Error + Send + 'static)>> { + impl BackgroundTask for AsyncTaskSchedule { + const TASK_NAME: &'static str = "AsyncUniqTask"; + type AppData = (); + + async fn run(&self, _task: CurrentTask, _: Self::AppData) -> Result<(), anyhow::Error> { Ok(()) } - fn cron(&self) -> Option { - let datetime = self.datetime.parse::>().ok()?; - Some(Scheduled::ScheduleOnce(datetime)) - } - } - - #[tokio::test] - async fn insert_task_creates_new_task() { - let pool = pool().await; - let mut test = PgAsyncQueue::new(pool); - - let task = test.create_task(&AsyncTask { number: 1 }).await.unwrap(); - - let metadata = task.payload.as_object().unwrap(); - let number = metadata["number"].as_u64(); - let type_task = metadata["type"].as_str(); - - assert_eq!(Some(1), number); - assert_eq!(Some("AsyncTask"), type_task); - - test.remove_all_tasks().await.unwrap(); - } - - #[tokio::test] - async fn update_task_state_test() { - let pool = pool().await; - let mut test = PgAsyncQueue::new(pool); - - let task = test.create_task(&AsyncTask { number: 1 }).await.unwrap(); - - let metadata = task.payload.as_object().unwrap(); - let number = metadata["number"].as_u64(); - let type_task = metadata["type"].as_str(); - let id = task.id; - - assert_eq!(Some(1), number); - assert_eq!(Some("AsyncTask"), type_task); - - let finished_task = test.set_task_done(task.id).await.unwrap(); - - assert_eq!(id, finished_task.id); - assert_eq!(TaskState::Done, finished_task.state()); - - test.remove_all_tasks().await.unwrap(); - } - - #[tokio::test] - async fn failed_task_query_test() { - let pool = pool().await; - let mut test = PgAsyncQueue::new(pool); - - let task = test.create_task(&AsyncTask { number: 1 }).await.unwrap(); - - let metadata = task.payload.as_object().unwrap(); - let number = metadata["number"].as_u64(); - let type_task = metadata["type"].as_str(); - let id = task.id; - - assert_eq!(Some(1), number); - assert_eq!(Some("AsyncTask"), type_task); - - let failed_task = test.set_task_failed(task.id, "Some error").await.unwrap(); - - assert_eq!(id, failed_task.id); - assert_eq!(Some("Some error"), failed_task.error_message.as_deref()); - assert_eq!(TaskState::Failed, failed_task.state()); - - test.remove_all_tasks().await.unwrap(); - } - - #[tokio::test] - async fn remove_all_tasks_test() { - let pool = pool().await; - let mut test = PgAsyncQueue::new(pool); - - let task = test.create_task(&AsyncTask { number: 1 }).await.unwrap(); - - let metadata = task.payload.as_object().unwrap(); - let number = metadata["number"].as_u64(); - let type_task = metadata["type"].as_str(); - - assert_eq!(Some(1), number); - assert_eq!(Some("AsyncTask"), type_task); - - let task = test.create_task(&AsyncTask { number: 2 }).await.unwrap(); - - let metadata = task.payload.as_object().unwrap(); - let number = metadata["number"].as_u64(); - let type_task = metadata["type"].as_str(); - - assert_eq!(Some(2), number); - assert_eq!(Some("AsyncTask"), type_task); - - let result = test.remove_all_tasks().await.unwrap(); - assert_eq!(2, result); + // fn cron(&self) -> Option { + // let datetime = self.datetime.parse::>().ok()?; + // Some(Scheduled::ScheduleOnce(datetime)) + // } } // #[tokio::test] - // async fn schedule_task_test() { + // async fn insert_task_creates_new_task() { // let pool = pool().await; - // let mut test = PgAsyncQueue::new(pool); + // let mut queue = PgTaskStore::new(pool); // - // let datetime = (Utc::now() + Duration::seconds(7)).round_subsecs(0); - // - // let task = &AsyncTaskSchedule { - // number: 1, - // datetime: datetime.to_string(), - // }; - // - // let task = test.schedule_task(task).await.unwrap(); + // let task = queue.create_task(AsyncTask { number: 1 }).await.unwrap(); // // let metadata = task.payload.as_object().unwrap(); // let number = metadata["number"].as_u64(); // let type_task = metadata["type"].as_str(); // // assert_eq!(Some(1), number); - // assert_eq!(Some("AsyncTaskSchedule"), type_task); - // assert_eq!(task.scheduled_at, datetime); + // assert_eq!(Some("AsyncTask"), type_task); + // + // queue.remove_all_tasks().await.unwrap(); + // } + // + // #[tokio::test] + // async fn update_task_state_test() { + // let pool = pool().await; + // let mut test = PgTaskStore::new(pool); + // + // let task = test.create_task(&AsyncTask { number: 1 }).await.unwrap(); + // + // let metadata = task.payload.as_object().unwrap(); + // let number = metadata["number"].as_u64(); + // let type_task = metadata["type"].as_str(); + // let id = task.id; + // + // assert_eq!(Some(1), number); + // assert_eq!(Some("AsyncTask"), type_task); + // + // let finished_task = test.set_task_state(task.id, TaskState::Done).await.unwrap(); + // + // assert_eq!(id, finished_task.id); + // assert_eq!(TaskState::Done, finished_task.state()); // // test.remove_all_tasks().await.unwrap(); // } // // #[tokio::test] - // async fn remove_all_scheduled_tasks_test() { + // async fn failed_task_query_test() { // let pool = pool().await; - // let mut test = PgAsyncQueue::new(pool); + // let mut test = PgTaskStore::new(pool); // - // let datetime = (Utc::now() + Duration::seconds(7)).round_subsecs(0); + // let task = test.create_task(&AsyncTask { number: 1 }).await.unwrap(); // - // let task1 = &AsyncTaskSchedule { - // number: 1, - // datetime: datetime.to_string(), - // }; + // let metadata = task.payload.as_object().unwrap(); + // let number = metadata["number"].as_u64(); + // let type_task = metadata["type"].as_str(); + // let id = task.id; // - // let task2 = &AsyncTaskSchedule { - // number: 2, - // datetime: datetime.to_string(), - // }; + // assert_eq!(Some(1), number); + // assert_eq!(Some("AsyncTask"), type_task); // - // test.schedule_task(task1).await.unwrap(); - // test.schedule_task(task2).await.unwrap(); + // let failed_task = test.set_task_state(task.id, TaskState::Failed("Some error".to_string())).await.unwrap(); // - // let number = test.remove_all_scheduled_tasks().await.unwrap(); - // - // assert_eq!(2, number); + // assert_eq!(id, failed_task.id); + // assert_eq!(Some("Some error"), failed_task.error_message.as_deref()); + // assert_eq!(TaskState::Failed, failed_task.state()); // // test.remove_all_tasks().await.unwrap(); // } - - #[tokio::test] - async fn pull_next_task_test() { - let pool = pool().await; - let mut test = PgAsyncQueue::new(pool); - - let task = test.create_task(&AsyncTask { number: 1 }).await.unwrap(); - - let metadata = task.payload.as_object().unwrap(); - let number = metadata["number"].as_u64(); - let type_task = metadata["type"].as_str(); - - assert_eq!(Some(1), number); - assert_eq!(Some("AsyncTask"), type_task); - - let task = test.create_task(&AsyncTask { number: 2 }).await.unwrap(); - - let metadata = task.payload.as_object().unwrap(); - let number = metadata["number"].as_u64(); - let type_task = metadata["type"].as_str(); - - assert_eq!(Some(2), number); - assert_eq!(Some("AsyncTask"), type_task); - - let task = test.pull_next_task(None).await.unwrap().unwrap(); - - let metadata = task.payload.as_object().unwrap(); - let number = metadata["number"].as_u64(); - let type_task = metadata["type"].as_str(); - - assert_eq!(Some(1), number); - assert_eq!(Some("AsyncTask"), type_task); - - let task = test.pull_next_task(None).await.unwrap().unwrap(); - let metadata = task.payload.as_object().unwrap(); - let number = metadata["number"].as_u64(); - let type_task = metadata["type"].as_str(); - - assert_eq!(Some(2), number); - assert_eq!(Some("AsyncTask"), type_task); - - test.remove_all_tasks().await.unwrap(); - } - - #[tokio::test] - async fn remove_tasks_type_test() { - let pool = pool().await; - let mut test = PgAsyncQueue::new(pool); - - let task = test.create_task(&AsyncTask { number: 1 }).await.unwrap(); - - let metadata = task.payload.as_object().unwrap(); - let number = metadata["number"].as_u64(); - let type_task = metadata["type"].as_str(); - - assert_eq!(Some(1), number); - assert_eq!(Some("AsyncTask"), type_task); - - let task = test.create_task(&AsyncTask { number: 2 }).await.unwrap(); - - let metadata = task.payload.as_object().unwrap(); - let number = metadata["number"].as_u64(); - let type_task = metadata["type"].as_str(); - - assert_eq!(Some(2), number); - assert_eq!(Some("AsyncTask"), type_task); - - let result = test - .remove_tasks_type(TaskType::from("nonexistentType")) - .await - .unwrap(); - assert_eq!(0, result); - - let result = test.remove_tasks_type(TaskType::default()).await.unwrap(); - assert_eq!(2, result); - - test.remove_all_tasks().await.unwrap(); - } - - #[tokio::test] - async fn remove_tasks_by_metadata() { - let pool = pool().await; - let mut test = PgAsyncQueue::new(pool); - - let task = test - .create_task(&AsyncUniqTask { number: 1 }) - .await - .unwrap(); - - let metadata = task.payload.as_object().unwrap(); - let number = metadata["number"].as_u64(); - let type_task = metadata["type"].as_str(); - - assert_eq!(Some(1), number); - assert_eq!(Some("AsyncUniqTask"), type_task); - - let task = test - .create_task(&AsyncUniqTask { number: 2 }) - .await - .unwrap(); - - let metadata = task.payload.as_object().unwrap(); - let number = metadata["number"].as_u64(); - let type_task = metadata["type"].as_str(); - - assert_eq!(Some(2), number); - assert_eq!(Some("AsyncUniqTask"), type_task); - - let result = test - .remove_task_by_hash(AsyncUniqTask { number: 0 }.uniq().unwrap()) - .await - .unwrap(); - assert!(!result, "Should **not** remove task"); - - let result = test - .remove_task_by_hash(AsyncUniqTask { number: 1 }.uniq().unwrap()) - .await - .unwrap(); - assert!(result, "Should remove task"); - - test.remove_all_tasks().await.unwrap(); - } + // + // #[tokio::test] + // async fn remove_all_tasks_test() { + // let pool = pool().await; + // let mut test = PgTaskStore::new(pool); + // + // let task = test.create_task(&AsyncTask { number: 1 }).await.unwrap(); + // + // let metadata = task.payload.as_object().unwrap(); + // let number = metadata["number"].as_u64(); + // let type_task = metadata["type"].as_str(); + // + // assert_eq!(Some(1), number); + // assert_eq!(Some("AsyncTask"), type_task); + // + // let task = test.create_task(&AsyncTask { number: 2 }).await.unwrap(); + // + // let metadata = task.payload.as_object().unwrap(); + // let number = metadata["number"].as_u64(); + // let type_task = metadata["type"].as_str(); + // + // assert_eq!(Some(2), number); + // assert_eq!(Some("AsyncTask"), type_task); + // + // let result = test.remove_all_tasks().await.unwrap(); + // assert_eq!(2, result); + // } + // + // #[tokio::test] + // async fn pull_next_task_test() { + // let pool = pool().await; + // let mut queue = PgTaskStore::new(pool); + // + // let task = queue.create_task(&AsyncTask { number: 1 }).await.unwrap(); + // + // let metadata = task.payload.as_object().unwrap(); + // let number = metadata["number"].as_u64(); + // let type_task = metadata["type"].as_str(); + // + // assert_eq!(Some(1), number); + // assert_eq!(Some("AsyncTask"), type_task); + // + // let task = queue.create_task(&AsyncTask { number: 2 }).await.unwrap(); + // + // let metadata = task.payload.as_object().unwrap(); + // let number = metadata["number"].as_u64(); + // let type_task = metadata["type"].as_str(); + // + // assert_eq!(Some(2), number); + // assert_eq!(Some("AsyncTask"), type_task); + // + // let task = queue.pull_next_task(None).await.unwrap().unwrap(); + // + // let metadata = task.payload.as_object().unwrap(); + // let number = metadata["number"].as_u64(); + // let type_task = metadata["type"].as_str(); + // + // assert_eq!(Some(1), number); + // assert_eq!(Some("AsyncTask"), type_task); + // + // let task = queue.pull_next_task(None).await.unwrap().unwrap(); + // let metadata = task.payload.as_object().unwrap(); + // let number = metadata["number"].as_u64(); + // let type_task = metadata["type"].as_str(); + // + // assert_eq!(Some(2), number); + // assert_eq!(Some("AsyncTask"), type_task); + // + // queue.remove_all_tasks().await.unwrap(); + // } async fn pool() -> Pool { let manager = AsyncDieselConnectionManager::::new( diff --git a/src/runnable.rs b/src/runnable.rs index e2bfb1d..f6a9742 100644 --- a/src/runnable.rs +++ b/src/runnable.rs @@ -1,27 +1,34 @@ -use crate::queue::Queueable; -use crate::task::TaskHash; -use crate::task::TaskType; -use crate::Scheduled; +use crate::task::{CurrentTask, TaskHash}; use async_trait::async_trait; -use std::error::Error; - -pub const RETRIES_NUMBER: i32 = 5; +use serde::{de::DeserializeOwned, ser::Serialize}; /// Task that can be executed by the queue. /// -/// The `RunnableTask` trait is used to define the behaviour of a task. You must implement this +/// The `BackgroundTask` trait is used to define the behaviour of a task. You must implement this /// trait for all tasks you want to execute. -#[typetag::serde(tag = "type")] #[async_trait] -pub trait RunnableTask: Send + Sync { - /// Execute the task. This method should define its logic - async fn run(&self, queue: &mut dyn Queueable) -> Result<(), Box>; +pub trait BackgroundTask: Serialize + DeserializeOwned + Sync + Send + 'static { + /// Unique name of the task. + /// + /// This MUST be unique for the whole application. + const TASK_NAME: &'static str; - /// Define the type of the task. - /// The `common` task type is used by default - fn task_type(&self) -> TaskType { - TaskType::default() - } + /// Task queue where this task will be executed. + /// + /// Used to define which workers are going to be executing this task. It uses the default + /// task queue if not changed. + const QUEUE: &'static str = "default"; + + /// Number of retries for tasks. + /// + /// By default, it is set to 5. + const MAX_RETRIES: i32 = 5; + + /// The application data provided to this task at runtime. + type AppData: Clone + Send + 'static; + + /// Execute the task. This method should define its logic + async fn run(&self, task: CurrentTask, context: Self::AppData) -> Result<(), anyhow::Error>; /// If set to true, no new tasks with the same metadata will be inserted /// By default it is set to false. @@ -29,27 +36,10 @@ pub trait RunnableTask: Send + Sync { None } - /// This method defines if a task is periodic or it should be executed once in the future. - /// - /// Be careful it works only with the UTC timezone. - /// - /// Example: - /// - /// ```rust - /// fn cron(&self) -> Option { - /// let expression = "0/20 * * * Aug-Sep * 2022/1"; - /// Some(Scheduled::CronPattern(expression.to_string())) - /// } - ///``` - /// In order to schedule a task once, use the `Scheduled::ScheduleOnce` enum variant. - fn cron(&self) -> Option { - None - } - /// Define the maximum number of retries the task will be retried. /// By default the number of retries is 20. fn max_retries(&self) -> i32 { - RETRIES_NUMBER + Self::MAX_RETRIES } /// Define the backoff mode diff --git a/src/schema.rs b/src/schema.rs new file mode 100644 index 0000000..0b28f0c --- /dev/null +++ b/src/schema.rs @@ -0,0 +1,19 @@ +// @generated automatically by Diesel CLI. + +diesel::table! { + backie_tasks (id) { + id -> Uuid, + task_name -> Varchar, + queue_name -> Varchar, + uniq_hash -> Nullable, + payload -> Jsonb, + timeout_msecs -> Int8, + created_at -> Timestamptz, + scheduled_at -> Timestamptz, + running_at -> Nullable, + done_at -> Nullable, + error_info -> Nullable, + retries -> Int4, + max_retries -> Int4, + } +} diff --git a/src/store.rs b/src/store.rs new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/src/store.rs @@ -0,0 +1 @@ + diff --git a/src/task.rs b/src/task.rs index 6f4f00d..3cfa5c5 100644 --- a/src/task.rs +++ b/src/task.rs @@ -1,4 +1,5 @@ use crate::schema::backie_tasks; +use crate::BackgroundTask; use chrono::DateTime; use chrono::Utc; use diesel::prelude::*; @@ -8,11 +9,11 @@ use sha2::{Digest, Sha256}; use std::borrow::Cow; use std::fmt; use std::fmt::Display; -use typed_builder::TypedBuilder; +use std::time::Duration; use uuid::Uuid; /// States of a task. -#[derive(Copy, Clone, Debug, Eq, PartialEq)] +#[derive(Clone, Debug, Eq, PartialEq)] pub enum TaskState { /// The task is ready to be executed. Ready, @@ -21,7 +22,7 @@ pub enum TaskState { Running, /// The task has failed to execute. - Failed, + Failed(String), /// The task finished successfully. Done, @@ -36,24 +37,6 @@ impl Display for TaskId { } } -#[derive(Clone, Debug, Hash, PartialEq, Eq, DieselNewType, Serialize)] -pub struct TaskType(Cow<'static, str>); - -impl Default for TaskType { - fn default() -> Self { - Self(Cow::from("default")) - } -} - -impl From for TaskType -where - S: AsRef + 'static, -{ - fn from(s: S) -> Self { - TaskType(Cow::from(s.as_ref().to_owned())) - } -} - #[derive(Clone, Debug, Hash, PartialEq, Eq, DieselNewType, Serialize)] pub struct TaskHash(Cow<'static, str>); @@ -70,42 +53,55 @@ impl TaskHash { } } -#[derive(Queryable, Identifiable, Debug, Eq, PartialEq, Clone, TypedBuilder)] +#[derive(Queryable, Identifiable, Debug, Eq, PartialEq, Clone)] #[diesel(table_name = backie_tasks)] pub struct Task { - #[builder(setter(into))] + /// Unique identifier of the task. pub id: TaskId, - #[builder(setter(into))] - pub payload: serde_json::Value, + /// Name of the type of task. + pub task_name: String, - #[builder(setter(into))] - pub error_message: Option, + /// Queue name that the task belongs to. + pub queue_name: String, - #[builder(setter(into))] - pub task_type: TaskType, - - #[builder(setter(into))] + /// Unique hash is used to identify and avoid duplicate tasks. pub uniq_hash: Option, - #[builder(setter(into))] - pub retries: i32, + /// Representation of the task. + pub payload: serde_json::Value, - #[builder(setter(into))] + /// Max timeout that the task can run for. + pub timeout_msecs: i64, + + /// Creation time of the task. pub created_at: DateTime, - #[builder(setter(into))] + /// Date time when the task is scheduled to run. + pub scheduled_at: DateTime, + + /// Date time when the task is started to run. pub running_at: Option>, - #[builder(setter(into))] + /// Date time when the task is finished. pub done_at: Option>, + + /// Failure reason, when the task is failed. + pub error_info: Option, + + /// Number of times a task was retried. + pub retries: i32, + + /// Maximum number of retries allow for this task before it is maked as failure. + pub max_retries: i32, } impl Task { pub fn state(&self) -> TaskState { if self.done_at.is_some() { - if self.error_message.is_some() { - TaskState::Failed + if self.error_info.is_some() { + // TODO: use a proper error type + TaskState::Failed(self.error_info.clone().unwrap().to_string()) } else { TaskState::Done } @@ -117,22 +113,61 @@ impl Task { } } -#[derive(Insertable, Debug, Eq, PartialEq, Clone, TypedBuilder)] +#[derive(Insertable, Debug, Eq, PartialEq, Clone)] #[diesel(table_name = backie_tasks)] -pub struct NewTask { - #[builder(setter(into))] - payload: serde_json::Value, - - #[builder(setter(into))] - task_type: TaskType, - - #[builder(setter(into))] +pub(crate) struct NewTask { + task_name: String, + queue_name: String, uniq_hash: Option, + payload: serde_json::Value, + timeout_msecs: i64, + max_retries: i32, } -pub struct TaskInfo { +impl NewTask { + pub(crate) fn new(background_task: T, timeout: Duration) -> Result + where + T: BackgroundTask, + { + let max_retries = background_task.max_retries(); + let uniq_hash = background_task.uniq(); + let payload = serde_json::to_value(background_task)?; + + Ok(Self { + task_name: T::TASK_NAME.to_string(), + queue_name: T::QUEUE.to_string(), + uniq_hash, + payload, + timeout_msecs: timeout.as_millis() as i64, + max_retries, + }) + } +} + +pub struct CurrentTask { id: TaskId, - error_message: Option, retries: i32, created_at: DateTime, } + +impl CurrentTask { + pub(crate) fn new(task: &Task) -> Self { + Self { + id: task.id.clone(), + retries: task.retries, + created_at: task.created_at, + } + } + + pub fn id(&self) -> TaskId { + self.id + } + + pub fn retry_count(&self) -> i32 { + self.retries + } + + pub fn created_at(&self) -> DateTime { + self.created_at + } +} diff --git a/src/worker.rs b/src/worker.rs index aecc558..a75dee4 100644 --- a/src/worker.rs +++ b/src/worker.rs @@ -1,53 +1,104 @@ -use crate::errors::BackieError; -use crate::queue::Queueable; -use crate::runnable::RunnableTask; -use crate::task::{Task, TaskType}; -use crate::RetentionMode; -use crate::Scheduled::*; +use crate::errors::{AsyncQueueError, BackieError}; +use crate::runnable::BackgroundTask; +use crate::task::{CurrentTask, Task, TaskState}; +use crate::{PgTaskStore, RetentionMode}; use futures::future::FutureExt; use futures::select; -use std::error::Error; -use typed_builder::TypedBuilder; +use std::collections::BTreeMap; +use std::future::Future; +use std::pin::Pin; +use std::sync::Arc; +use thiserror::Error; -/// it executes tasks only of task_type type, it sleeps when there are no tasks in the queue -#[derive(TypedBuilder)] -pub struct Worker -where - Q: Queueable + Clone + Sync + 'static, -{ - #[builder(setter(into))] - pub queue: Q, +pub type ExecuteTaskFn = Arc< + dyn Fn( + CurrentTask, + serde_json::Value, + AppData, + ) -> Pin> + Send>> + + Send + + Sync, +>; - #[builder(default, setter(into))] - pub task_type: Option, +pub type StateFn = Arc AppData + Send + Sync>; - #[builder(default, setter(into))] - pub retention_mode: RetentionMode, +#[derive(Debug, Error)] +pub enum TaskExecError { + #[error("Task execution failed: {0}")] + ExecutionFailed(#[from] anyhow::Error), - #[builder(default, setter(into))] - pub shutdown: Option>, + #[error("Task deserialization failed: {0}")] + TaskDeserializationFailed(#[from] serde_json::Error), } -impl Worker +pub(crate) fn runnable( + task_info: CurrentTask, + payload: serde_json::Value, + app_context: BT::AppData, +) -> Pin> + Send>> where - Q: Queueable + Clone + Sync + 'static, + BT: BackgroundTask, { + Box::pin(async move { + let background_task: BT = serde_json::from_value(payload)?; + background_task.run(task_info, app_context).await?; + Ok(()) + }) +} + +/// Worker that executes tasks. +pub struct Worker +where + AppData: Clone + Send + 'static, +{ + store: Arc, + + queue_name: String, + + retention_mode: RetentionMode, + + task_registry: BTreeMap>, + + app_data_fn: StateFn, + + /// Notification for the worker to stop. + shutdown: Option>, +} + +impl Worker +where + AppData: Clone + Send + 'static, +{ + pub(crate) fn new( + store: Arc, + queue_name: String, + retention_mode: RetentionMode, + task_registry: BTreeMap>, + app_data_fn: StateFn, + shutdown: Option>, + ) -> Self { + Self { + store, + queue_name, + retention_mode, + task_registry, + app_data_fn, + shutdown, + } + } + pub(crate) async fn run_tasks(&mut self) -> Result<(), BackieError> { loop { - // Need to check if has to stop before pulling next task - match self.queue.pull_next_task(self.task_type.clone()).await? { - Some(task) => { - let actual_task: Box = - serde_json::from_value(task.payload.clone())?; + // Check if has to stop before pulling next task + if let Some(ref shutdown) = self.shutdown { + if shutdown.has_changed()? { + return Ok(()); + } + }; - // check if task is scheduled or not - if let Some(CronPattern(_)) = actual_task.cron() { - // program task - //self.queue.schedule_task(&*actual_task).await?; - } - // run scheduled task - // TODO: what do we do if the task fails? it's an internal error, inform the logs - let _ = self.run(task, actual_task).await; + match self.store.pull_next_task(&self.queue_name).await? { + Some(task) => { + self.run(task).await?; } None => { // Listen to watchable future @@ -73,41 +124,45 @@ where } } - #[cfg(test)] - pub async fn run_tasks_until_none(&mut self) -> Result<(), BackieError> { - loop { - match self.queue.pull_next_task(self.task_type.clone()).await? { - Some(task) => { - let actual_task: Box = - serde_json::from_value(task.payload.clone()).unwrap(); + // #[cfg(test)] + // pub async fn run_tasks_until_none(&mut self) -> Result<(), BackieError> { + // loop { + // match self.store.pull_next_task(self.queue_name.clone()).await? { + // Some(task) => { + // let actual_task: Box = + // serde_json::from_value(task.payload.clone()).unwrap(); + // + // // check if task is scheduled or not + // if let Some(CronPattern(_)) = actual_task.cron() { + // // program task + // // self.queue.schedule_task(&*actual_task).await?; + // } + // // run scheduled task + // self.run(task, actual_task).await?; + // } + // None => { + // return Ok(()); + // } + // }; + // } + // } - // check if task is scheduled or not - if let Some(CronPattern(_)) = actual_task.cron() { - // program task - // self.queue.schedule_task(&*actual_task).await?; - } - // run scheduled task - self.run(task, actual_task).await?; - } - None => { - return Ok(()); - } - }; - } - } + async fn run(&self, task: Task) -> Result<(), BackieError> { + let task_info = CurrentTask::new(&task); + let runnable_task_caller = self + .task_registry + .get(&task.task_name) + .ok_or_else(|| AsyncQueueError::TaskNotRegistered(task.task_name.clone()))?; - async fn run( - &mut self, - task: Task, - runnable: Box, - ) -> Result<(), BackieError> { // TODO: catch panics - let result = runnable.run(&mut self.queue).await; - match result { + let result: Result<(), TaskExecError> = + runnable_task_caller(task_info, task.payload.clone(), (self.app_data_fn)()).await; + + match &result { Ok(_) => self.finalize_task(task, result).await?, Err(error) => { - if task.retries < runnable.max_retries() { - let backoff_seconds = runnable.backoff(task.retries as u32); + if task.retries < task.max_retries { + let backoff_seconds = 5; // TODO: runnable.backoff(task.retries as u32); log::debug!( "Task {} failed to run and will be retried in {} seconds", @@ -115,12 +170,12 @@ where backoff_seconds ); let error_message = format!("{}", error); - self.queue + self.store .schedule_task_retry(task.id, backoff_seconds, &error_message) .await?; } else { log::debug!("Task {} failed and reached the maximum retries", task.id); - self.finalize_task(task, Err(error)).await?; + self.finalize_task(task, result).await?; } } } @@ -128,36 +183,36 @@ where } async fn finalize_task( - &mut self, + &self, task: Task, - result: Result<(), Box>, + result: Result<(), TaskExecError>, ) -> Result<(), BackieError> { match self.retention_mode { RetentionMode::KeepAll => match result { Ok(_) => { - self.queue.set_task_done(task.id).await?; + self.store.set_task_state(task.id, TaskState::Done).await?; log::debug!("Task {} done and kept in the database", task.id); } Err(error) => { log::debug!("Task {} failed and kept in the database", task.id); - self.queue - .set_task_failed(task.id, &format!("{}", error)) + self.store + .set_task_state(task.id, TaskState::Failed(format!("{}", error))) .await?; } }, RetentionMode::RemoveAll => { log::debug!("Task {} finalized and deleted from the database", task.id); - self.queue.remove_task(task.id).await?; + self.store.remove_task(task.id).await?; } RetentionMode::RemoveDone => match result { Ok(_) => { log::debug!("Task {} done and deleted from the database", task.id); - self.queue.remove_task(task.id).await?; + self.store.remove_task(task.id).await?; } Err(error) => { log::debug!("Task {} failed and kept in the database", task.id); - self.queue - .set_task_failed(task.id, &format!("{}", error)) + self.store + .set_task_state(task.id, TaskState::Failed(format!("{}", error))) .await?; } }, @@ -169,35 +224,19 @@ where #[cfg(test)] mod async_worker_tests { - use std::fmt::Display; use super::*; - use crate::queue::PgAsyncQueue; - use crate::queue::Queueable; - use crate::task::TaskState; - use crate::worker::Task; - use crate::RetentionMode; - use crate::Scheduled; use async_trait::async_trait; - use chrono::Duration; - use chrono::Utc; use diesel_async::pooled_connection::{bb8::Pool, AsyncDieselConnectionManager}; use diesel_async::AsyncPgConnection; use serde::{Deserialize, Serialize}; - use thiserror::Error; - #[derive(Error, Debug)] + #[derive(thiserror::Error, Debug)] enum TaskError { + #[error("Something went wrong")] SomethingWrong, - Custom(String), - } - impl Display for TaskError { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - TaskError::SomethingWrong => write!(f, "Something went wrong"), - TaskError::Custom(message) => write!(f, "{}", message), - } - } + #[error("{0}")] + Custom(String), } #[derive(Serialize, Deserialize)] @@ -205,13 +244,12 @@ mod async_worker_tests { pub number: u16, } - #[typetag::serde] #[async_trait] - impl RunnableTask for WorkerAsyncTask { - async fn run( - &self, - _queueable: &mut dyn Queueable, - ) -> Result<(), Box<(dyn std::error::Error + Send + 'static)>> { + impl BackgroundTask for WorkerAsyncTask { + const TASK_NAME: &'static str = "WorkerAsyncTask"; + type AppData = (); + + async fn run(&self, _: CurrentTask, _: Self::AppData) -> Result<(), anyhow::Error> { Ok(()) } } @@ -221,18 +259,18 @@ mod async_worker_tests { pub number: u16, } - #[typetag::serde] #[async_trait] - impl RunnableTask for WorkerAsyncTaskSchedule { - async fn run( - &self, - _queueable: &mut dyn Queueable, - ) -> Result<(), Box<(dyn std::error::Error + Send + 'static)>> { + impl BackgroundTask for WorkerAsyncTaskSchedule { + const TASK_NAME: &'static str = "WorkerAsyncTaskSchedule"; + type AppData = (); + + async fn run(&self, _task: CurrentTask, _data: Self::AppData) -> Result<(), anyhow::Error> { Ok(()) } - fn cron(&self) -> Option { - Some(Scheduled::ScheduleOnce(Utc::now() + Duration::seconds(1))) - } + + // fn cron(&self) -> Option { + // Some(Scheduled::ScheduleOnce(Utc::now() + Duration::seconds(1))) + // } } #[derive(Serialize, Deserialize)] @@ -240,16 +278,15 @@ mod async_worker_tests { pub number: u16, } - #[typetag::serde] #[async_trait] - impl RunnableTask for AsyncFailedTask { - async fn run( - &self, - _queueable: &mut dyn Queueable, - ) -> Result<(), Box<(dyn std::error::Error + Send + 'static)>> { + impl BackgroundTask for AsyncFailedTask { + const TASK_NAME: &'static str = "AsyncFailedTask"; + type AppData = (); + + async fn run(&self, _task: CurrentTask, _data: Self::AppData) -> Result<(), anyhow::Error> { let message = format!("number {} is wrong :(", self.number); - Err(Box::new(TaskError::Custom(message))) + Err(TaskError::Custom(message).into()) } fn max_retries(&self) -> i32 { @@ -260,282 +297,40 @@ mod async_worker_tests { #[derive(Serialize, Deserialize, Clone)] struct AsyncRetryTask {} - #[typetag::serde] #[async_trait] - impl RunnableTask for AsyncRetryTask { - async fn run( - &self, - _queueable: &mut dyn Queueable, - ) -> Result<(), Box<(dyn std::error::Error + Send + 'static)>> { - Err(Box::new(TaskError::SomethingWrong)) - } + impl BackgroundTask for AsyncRetryTask { + const TASK_NAME: &'static str = "AsyncRetryTask"; + type AppData = (); - fn max_retries(&self) -> i32 { - 2 + async fn run(&self, _task: CurrentTask, _data: Self::AppData) -> Result<(), anyhow::Error> { + Err(TaskError::SomethingWrong.into()) } } #[derive(Serialize, Deserialize)] struct AsyncTaskType1 {} - #[typetag::serde] #[async_trait] - impl RunnableTask for AsyncTaskType1 { - async fn run( - &self, - _queueable: &mut dyn Queueable, - ) -> Result<(), Box<(dyn std::error::Error + Send + 'static)>> { - Ok(()) - } + impl BackgroundTask for AsyncTaskType1 { + const TASK_NAME: &'static str = "AsyncTaskType1"; + type AppData = (); - fn task_type(&self) -> TaskType { - "type1".into() + async fn run(&self, _task: CurrentTask, _data: Self::AppData) -> Result<(), anyhow::Error> { + Ok(()) } } #[derive(Serialize, Deserialize)] struct AsyncTaskType2 {} - #[typetag::serde] #[async_trait] - impl RunnableTask for AsyncTaskType2 { - async fn run( - &self, - _queueable: &mut dyn Queueable, - ) -> Result<(), Box<(dyn std::error::Error + Send + 'static)>> { + impl BackgroundTask for AsyncTaskType2 { + const TASK_NAME: &'static str = "AsyncTaskType2"; + type AppData = (); + + async fn run(&self, _task: CurrentTask, _data: Self::AppData) -> Result<(), anyhow::Error> { Ok(()) } - - fn task_type(&self) -> TaskType { - TaskType::from("type2") - } - } - - #[tokio::test] - async fn execute_and_finishes_task() { - let pool = pool().await; - let mut test = PgAsyncQueue::new(pool); - - let actual_task = WorkerAsyncTask { number: 1 }; - - let task = insert_task(&mut test, &actual_task).await; - let id = task.id; - - let mut worker = Worker::::builder() - .queue(test.clone()) - .retention_mode(RetentionMode::KeepAll) - .build(); - - worker.run(task, Box::new(actual_task)).await.unwrap(); - let task_finished = test.find_task_by_id(id).await.unwrap(); - assert_eq!(id, task_finished.id); - assert_eq!(TaskState::Done, task_finished.state()); - - test.remove_all_tasks().await.unwrap(); - } - - // #[tokio::test] - // async fn schedule_task_test() { - // let pool = pool().await; - // let mut test = PgAsyncQueue::new(pool); - // - // let actual_task = WorkerAsyncTaskSchedule { number: 1 }; - // - // let task = test.schedule_task(&actual_task).await.unwrap(); - // - // let id = task.id; - // - // let mut worker = AsyncWorker::::builder() - // .queue(test.clone()) - // .retention_mode(RetentionMode::KeepAll) - // .build(); - // - // worker.run_tasks_until_none().await.unwrap(); - // - // let task = worker.queue.find_task_by_id(id).await.unwrap(); - // - // assert_eq!(id, task.id); - // assert_eq!(TaskState::Ready, task.state()); - // - // tokio::time::sleep(core::time::Duration::from_secs(3)).await; - // - // worker.run_tasks_until_none().await.unwrap(); - // - // let task = test.find_task_by_id(id).await.unwrap(); - // assert_eq!(id, task.id); - // assert_eq!(TaskState::Done, task.state()); - // - // test.remove_all_tasks().await.unwrap(); - // } - - #[tokio::test] - async fn retries_task_test() { - let pool = pool().await; - let mut test = PgAsyncQueue::new(pool); - - let actual_task = AsyncRetryTask {}; - - let task = test.create_task(&actual_task).await.unwrap(); - - let id = task.id; - - let mut worker = Worker::::builder() - .queue(test.clone()) - .retention_mode(RetentionMode::KeepAll) - .build(); - - worker.run_tasks_until_none().await.unwrap(); - - let task = worker.queue.find_task_by_id(id).await.unwrap(); - - assert_eq!(id, task.id); - assert_eq!(TaskState::Ready, task.state()); - assert_eq!(1, task.retries); - assert!(task.error_message.is_some()); - - tokio::time::sleep(core::time::Duration::from_secs(5)).await; - worker.run_tasks_until_none().await.unwrap(); - - let task = worker.queue.find_task_by_id(id).await.unwrap(); - - assert_eq!(id, task.id); - assert_eq!(TaskState::Ready, task.state()); - assert_eq!(2, task.retries); - - tokio::time::sleep(core::time::Duration::from_secs(10)).await; - worker.run_tasks_until_none().await.unwrap(); - - let task = test.find_task_by_id(id).await.unwrap(); - assert_eq!(id, task.id); - assert_eq!(TaskState::Failed, task.state()); - assert_eq!("Something went wrong".to_string(), task.error_message.unwrap()); - - test.remove_all_tasks().await.unwrap(); - } - - #[tokio::test] - async fn worker_shutsdown_when_notified() { - let pool = pool().await; - let queue = PgAsyncQueue::new(pool); - - let (tx, rx) = tokio::sync::watch::channel(()); - - let mut worker = Worker::::builder() - .queue(queue) - .shutdown(rx) - .build(); - - let handle = tokio::spawn(async move { - worker.run_tasks().await.unwrap(); - true - }); - - tx.send(()).unwrap(); - select! { - _ = handle.fuse() => {} - _ = tokio::time::sleep(core::time::Duration::from_secs(1)).fuse() => panic!("Worker did not shutdown") - } - } - - #[tokio::test] - async fn saves_error_for_failed_task() { - let pool = pool().await; - let mut test = PgAsyncQueue::new(pool); - - let failed_task = AsyncFailedTask { number: 1 }; - - let task = insert_task(&mut test, &failed_task).await; - let id = task.id; - - let mut worker = Worker::::builder() - .queue(test.clone()) - .retention_mode(RetentionMode::KeepAll) - .build(); - - worker.run(task, Box::new(failed_task)).await.unwrap(); - let task_finished = test.find_task_by_id(id).await.unwrap(); - - assert_eq!(id, task_finished.id); - assert_eq!(TaskState::Failed, task_finished.state()); - assert_eq!( - "number 1 is wrong :(".to_string(), - task_finished.error_message.unwrap() - ); - - test.remove_all_tasks().await.unwrap(); - } - - #[tokio::test] - async fn executes_task_only_of_specific_type() { - let pool = pool().await; - let mut test = PgAsyncQueue::new(pool); - - let task1 = insert_task(&mut test, &AsyncTaskType1 {}).await; - let task12 = insert_task(&mut test, &AsyncTaskType1 {}).await; - let task2 = insert_task(&mut test, &AsyncTaskType2 {}).await; - - let id1 = task1.id; - let id12 = task12.id; - let id2 = task2.id; - - let mut worker = Worker::::builder() - .queue(test.clone()) - .task_type(TaskType::from("type1")) - .retention_mode(RetentionMode::KeepAll) - .build(); - - worker.run_tasks_until_none().await.unwrap(); - let task1 = test.find_task_by_id(id1).await.unwrap(); - let task12 = test.find_task_by_id(id12).await.unwrap(); - let task2 = test.find_task_by_id(id2).await.unwrap(); - - assert_eq!(id1, task1.id); - assert_eq!(id12, task12.id); - assert_eq!(id2, task2.id); - assert_eq!(TaskState::Done, task1.state()); - assert_eq!(TaskState::Done, task12.state()); - assert_eq!(TaskState::Ready, task2.state()); - - test.remove_all_tasks().await.unwrap(); - } - - #[tokio::test] - async fn remove_when_finished() { - let pool = pool().await; - let mut test = PgAsyncQueue::new(pool); - - let task1 = insert_task(&mut test, &AsyncTaskType1 {}).await; - let task12 = insert_task(&mut test, &AsyncTaskType1 {}).await; - let task2 = insert_task(&mut test, &AsyncTaskType2 {}).await; - - let _id1 = task1.id; - let _id12 = task12.id; - let id2 = task2.id; - - let mut worker = Worker::::builder() - .queue(test.clone()) - .task_type(TaskType::from("type1")) - .build(); - - worker.run_tasks_until_none().await.unwrap(); - let task = test - .pull_next_task(Some(TaskType::from("type1"))) - .await - .unwrap(); - assert_eq!(None, task); - - let task2 = test - .pull_next_task(Some(TaskType::from("type2"))) - .await - .unwrap() - .unwrap(); - assert_eq!(id2, task2.id); - - test.remove_all_tasks().await.unwrap(); - } - - async fn insert_task(test: &mut PgAsyncQueue, task: &dyn RunnableTask) -> Task { - test.create_task(task).await.unwrap() } async fn pool() -> Pool { diff --git a/src/worker_pool.rs b/src/worker_pool.rs index 63a0874..f41a33b 100644 --- a/src/worker_pool.rs +++ b/src/worker_pool.rs @@ -1,102 +1,200 @@ use crate::errors::BackieError; -use crate::queue::Queueable; -use crate::task::TaskType; -use crate::worker::Worker; -use crate::RetentionMode; -use async_recursion::async_recursion; -use log::error; +use crate::queue::Queue; +use crate::worker::{runnable, ExecuteTaskFn}; +use crate::worker::{StateFn, Worker}; +use crate::{BackgroundTask, CurrentTask, PgTaskStore, RetentionMode}; +use std::collections::BTreeMap; use std::future::Future; -use tokio::sync::watch::Receiver; -use typed_builder::TypedBuilder; +use std::sync::Arc; +use tokio::task::JoinHandle; -#[derive(TypedBuilder, Clone)] -pub struct WorkerPool +pub type AppDataFn = Arc AppData + Send + Sync>; + +#[derive(Clone)] +pub struct WorkerPool where - AQueue: Queueable + Clone + Sync + 'static, + AppData: Clone + Send + 'static, { - #[builder(setter(into))] - /// the AsyncWorkerPool uses a queue to control the tasks that will be executed. - pub queue: AQueue, + /// Storage of tasks. + queue_store: Arc, // TODO: make this generic/dynamic referenced - /// retention_mode controls if tasks should be persisted after execution - #[builder(default, setter(into))] - pub retention_mode: RetentionMode, + /// Queue used to spawn tasks. + queue: Queue, - /// the number of workers of the AsyncWorkerPool. - #[builder(setter(into))] - pub number_of_workers: u32, + /// Make possible to load the application data. + /// + /// The application data is loaded when the worker pool is started and is passed to the tasks. + /// The loading function accepts a queue instance in case the application data depends on it. This + /// is interesting for situations where the application wants to allow tasks to spawn other tasks. + application_data_fn: StateFn, - /// The type of tasks that will be executed by `AsyncWorkerPool`. - #[builder(default, setter(into))] - pub task_type: Option, + /// The types of task the worker pool can execute and the loaders for them. + task_registry: BTreeMap>, + + /// Number of workers that will be spawned per queue. + worker_queues: BTreeMap, } -// impl AsyncWorkerBuilder -// where -// TypedBuilderFields: Clone, -// Q: Queueable + Clone + Sync + 'static, -// { -// pub fn with_graceful_shutdown(self, signal: F) -> Self -// where -// F: Future, -// { -// self -// } -// } - -impl WorkerPool +impl WorkerPool where - AQueue: Queueable + Clone + Sync + 'static, + AppData: Clone + Send + 'static, { - /// Starts the configured number of workers - /// This is necessary in order to execute tasks. - pub async fn start(&mut self, graceful_shutdown: F) -> Result<(), BackieError> + /// Create a new worker pool. + pub fn new(queue_store: PgTaskStore, application_data_fn: A) -> Self + where + A: Fn(Queue) -> AppData + Send + Sync + 'static, + { + let queue_store = Arc::new(queue_store); + let queue = Queue::new(queue_store.clone()); + let application_data_fn = { + let queue = queue.clone(); + move || application_data_fn(queue.clone()) + }; + Self { + queue_store, + queue, + application_data_fn: Arc::new(application_data_fn), + task_registry: BTreeMap::new(), + worker_queues: BTreeMap::new(), + } + } + + /// Register a task type with the worker pool. + pub fn register_task_type(mut self, num_workers: u32, retention_mode: RetentionMode) -> Self + where + BT: BackgroundTask, + { + self.worker_queues + .insert(BT::QUEUE.to_string(), (retention_mode, num_workers)); + self.task_registry + .insert(BT::TASK_NAME.to_string(), Arc::new(runnable::)); + self + } + + pub async fn start( + self, + graceful_shutdown: F, + ) -> Result<(JoinHandle<()>, Queue), BackieError> where F: Future + Send + 'static, { let (tx, rx) = tokio::sync::watch::channel(()); - for idx in 0..self.number_of_workers { - let pool = self.clone(); - // TODO: the worker pool keeps track of the number of workers and spawns new workers as needed. - // There should be always a minimum number of workers active waiting for tasks to execute - // or for a gracefull shutdown. - tokio::spawn(Self::supervise_task(pool, rx.clone(), 0, idx)); + + // Spawn all individual workers per queue + for (queue_name, (retention_mode, num_workers)) in self.worker_queues.iter() { + for idx in 0..*num_workers { + let mut worker: Worker = Worker::new( + self.queue_store.clone(), + queue_name.clone(), + retention_mode.clone(), + self.task_registry.clone(), + self.application_data_fn.clone(), + Some(rx.clone()), + ); + let worker_name = format!("worker-{queue_name}-{idx}"); + // TODO: grab the join handle for every worker for graceful shutdown + tokio::spawn(async move { + worker.run_tasks().await; + log::info!("Worker {} stopped", worker_name); + }); + } } - graceful_shutdown.await; - tx.send(())?; - log::info!("Worker pool stopped gracefully"); - Ok(()) - } - #[async_recursion] - async fn supervise_task( - pool: WorkerPool, - receiver: Receiver<()>, - restarts: u64, - worker_number: u32, - ) { - let restarts = restarts + 1; - - let inner_pool = pool.clone(); - let inner_receiver = receiver.clone(); - - let join_handle = tokio::spawn(async move { - let mut worker: Worker = Worker::builder() - .queue(inner_pool.queue.clone()) - .retention_mode(inner_pool.retention_mode) - .task_type(inner_pool.task_type.clone()) - .shutdown(inner_receiver) - .build(); - - worker.run_tasks().await - }); - - if (join_handle.await).is_err() { - error!( - "Worker {} stopped. Restarting. the number of restarts {}", - worker_number, restarts, - ); - Self::supervise_task(pool, receiver, restarts, worker_number).await; - } + Ok(( + tokio::spawn(async move { + graceful_shutdown.await; + if let Err(err) = tx.send(()) { + log::warn!("Failed to send shutdown signal to worker pool: {}", err); + } else { + log::info!("Worker pool stopped gracefully"); + } + }), + self.queue, + )) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use async_trait::async_trait; + use diesel_async::pooled_connection::{bb8::Pool, AsyncDieselConnectionManager}; + use diesel_async::AsyncPgConnection; + + #[derive(Clone, Debug)] + struct ApplicationContext { + app_name: String, + } + + impl ApplicationContext { + fn new() -> Self { + Self { + app_name: "Backie".to_string(), + } + } + + fn get_app_name(&self) -> String { + self.app_name.clone() + } + } + + #[derive(Clone, Debug, serde::Serialize, serde::Deserialize)] + struct GreetingTask { + person: String, + } + + #[async_trait] + impl BackgroundTask for GreetingTask { + const TASK_NAME: &'static str = "my_task"; + + type AppData = ApplicationContext; + + async fn run( + &self, + task_info: CurrentTask, + app_context: Self::AppData, + ) -> Result<(), anyhow::Error> { + println!( + "[{}] Hello {}! I'm {}.", + task_info.id(), + self.person, + app_context.get_app_name() + ); + Ok(()) + } + } + + #[tokio::test] + async fn test_worker_pool() { + let my_app_context = ApplicationContext::new(); + + let task_store = PgTaskStore::new(pool().await); + + let (join_handle, queue) = WorkerPool::new(task_store, move |_| my_app_context.clone()) + .register_task_type::(1, RetentionMode::RemoveDone) + .start(futures::future::ready(())) + .await + .unwrap(); + + queue + .enqueue(GreetingTask { + person: "Rafael".to_string(), + }) + .await + .unwrap(); + + join_handle.await.unwrap(); + } + + async fn pool() -> Pool { + let manager = AsyncDieselConnectionManager::::new( + option_env!("DATABASE_URL").expect("DATABASE_URL must be set"), + ); + Pool::builder() + .max_size(1) + .min_idle(Some(1)) + .build(manager) + .await + .unwrap() } }