diff --git a/Cargo.toml b/Cargo.toml index fe77e32..dc3832a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -31,3 +31,6 @@ diesel = { version = "2.0", features = ["postgres", "serde_json", "chrono", "uui diesel-derive-newtype = "2.0.0-rc.0" diesel-async = { version = "0.2", features = ["postgres", "bb8"] } tokio = { version = "1.25", features = ["rt", "time", "macros"] } + +[dev-dependencies] +itertools = "0.10" diff --git a/src/lib.rs b/src/lib.rs index 70070db..551ef29 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,22 +1,7 @@ +//#![warn(missing_docs)] +#![forbid(unsafe_code)] #![doc = include_str!("../README.md")] -use chrono::{DateTime, Utc}; - -/// Represents a schedule for scheduled tasks. -/// -/// It's used in the [`BackgroundTask::cron`] -#[derive(Debug, Clone)] -pub enum Scheduled { - /// A cron pattern for a periodic task - /// - /// For example, `Scheduled::CronPattern("0/20 * * * * * *")` - CronPattern(String), - /// A datetime for a scheduled task that will be executed once - /// - /// For example, `Scheduled::ScheduleOnce(chrono::Utc::now() + std::time::Duration::seconds(7i64))` - ScheduleOnce(DateTime), -} - /// All possible options for retaining tasks in the db after their execution. /// /// The default mode is [`RetentionMode::RemoveAll`] diff --git a/src/queries.rs b/src/queries.rs index 2e80340..e3bd774 100644 --- a/src/queries.rs +++ b/src/queries.rs @@ -44,9 +44,6 @@ impl Task { ) -> Result { use crate::schema::backie_tasks::dsl; - let now = Utc::now(); - let scheduled_at = now + Duration::seconds(backoff_seconds as i64); - let error = serde_json::json!({ "error": error_message, }); @@ -55,7 +52,8 @@ impl Task { .set(( backie_tasks::error_info.eq(Some(error)), backie_tasks::retries.eq(dsl::retries + 1), - backie_tasks::scheduled_at.eq(scheduled_at), + backie_tasks::scheduled_at + .eq(Utc::now() + Duration::seconds(backoff_seconds as i64)), backie_tasks::running_at.eq::>>(None), )) .get_result::(connection) diff --git a/src/queue.rs b/src/queue.rs index c21e474..4119091 100644 --- a/src/queue.rs +++ b/src/queue.rs @@ -1,17 +1,23 @@ use crate::errors::BackieError; use crate::runnable::BackgroundTask; -use crate::store::{PgTaskStore, TaskStore}; -use crate::task::{NewTask, TaskHash}; +use crate::store::TaskStore; +use crate::task::NewTask; use std::sync::Arc; use std::time::Duration; #[derive(Clone)] -pub struct Queue { - task_store: Arc, +pub struct Queue +where + S: TaskStore, +{ + task_store: Arc, } -impl Queue { - pub(crate) fn new(task_store: Arc) -> Self { +impl Queue +where + S: TaskStore, +{ + pub(crate) fn new(task_store: Arc) -> Self { Queue { task_store } } @@ -25,200 +31,3 @@ impl Queue { Ok(()) } } - -#[cfg(test)] -mod async_queue_tests { - use super::*; - use crate::CurrentTask; - use async_trait::async_trait; - use serde::{Deserialize, Serialize}; - - #[derive(Serialize, Deserialize)] - struct AsyncTask { - pub number: u16, - } - - #[async_trait] - impl BackgroundTask for AsyncTask { - const TASK_NAME: &'static str = "AsyncUniqTask"; - type AppData = (); - - async fn run(&self, _task: CurrentTask, _: Self::AppData) -> Result<(), anyhow::Error> { - Ok(()) - } - } - - #[derive(Serialize, Deserialize)] - struct AsyncUniqTask { - pub number: u16, - } - - #[async_trait] - impl BackgroundTask for AsyncUniqTask { - const TASK_NAME: &'static str = "AsyncUniqTask"; - type AppData = (); - - async fn run(&self, _task: CurrentTask, _: Self::AppData) -> Result<(), anyhow::Error> { - Ok(()) - } - - fn uniq(&self) -> Option { - TaskHash::default_for_task(self).ok() - } - } - - #[derive(Serialize, Deserialize)] - struct AsyncTaskSchedule { - pub number: u16, - pub datetime: String, - } - - #[async_trait] - impl BackgroundTask for AsyncTaskSchedule { - const TASK_NAME: &'static str = "AsyncUniqTask"; - type AppData = (); - - async fn run(&self, _task: CurrentTask, _: Self::AppData) -> Result<(), anyhow::Error> { - Ok(()) - } - - // fn cron(&self) -> Option { - // let datetime = self.datetime.parse::>().ok()?; - // Some(Scheduled::ScheduleOnce(datetime)) - // } - } - - // #[tokio::test] - // async fn insert_task_creates_new_task() { - // let pool = pool().await; - // let mut queue = PgTaskStore::new(pool); - // - // let task = queue.create_task(AsyncTask { number: 1 }).await.unwrap(); - // - // let metadata = task.payload.as_object().unwrap(); - // let number = metadata["number"].as_u64(); - // let type_task = metadata["type"].as_str(); - // - // assert_eq!(Some(1), number); - // assert_eq!(Some("AsyncTask"), type_task); - // - // queue.remove_all_tasks().await.unwrap(); - // } - // - // #[tokio::test] - // async fn update_task_state_test() { - // let pool = pool().await; - // let mut test = PgTaskStore::new(pool); - // - // let task = test.create_task(&AsyncTask { number: 1 }).await.unwrap(); - // - // let metadata = task.payload.as_object().unwrap(); - // let number = metadata["number"].as_u64(); - // let type_task = metadata["type"].as_str(); - // let id = task.id; - // - // assert_eq!(Some(1), number); - // assert_eq!(Some("AsyncTask"), type_task); - // - // let finished_task = test.set_task_state(task.id, TaskState::Done).await.unwrap(); - // - // assert_eq!(id, finished_task.id); - // assert_eq!(TaskState::Done, finished_task.state()); - // - // test.remove_all_tasks().await.unwrap(); - // } - // - // #[tokio::test] - // async fn failed_task_query_test() { - // let pool = pool().await; - // let mut test = PgTaskStore::new(pool); - // - // let task = test.create_task(&AsyncTask { number: 1 }).await.unwrap(); - // - // let metadata = task.payload.as_object().unwrap(); - // let number = metadata["number"].as_u64(); - // let type_task = metadata["type"].as_str(); - // let id = task.id; - // - // assert_eq!(Some(1), number); - // assert_eq!(Some("AsyncTask"), type_task); - // - // let failed_task = test.set_task_state(task.id, TaskState::Failed("Some error".to_string())).await.unwrap(); - // - // assert_eq!(id, failed_task.id); - // assert_eq!(Some("Some error"), failed_task.error_message.as_deref()); - // assert_eq!(TaskState::Failed, failed_task.state()); - // - // test.remove_all_tasks().await.unwrap(); - // } - // - // #[tokio::test] - // async fn remove_all_tasks_test() { - // let pool = pool().await; - // let mut test = PgTaskStore::new(pool); - // - // let task = test.create_task(&AsyncTask { number: 1 }).await.unwrap(); - // - // let metadata = task.payload.as_object().unwrap(); - // let number = metadata["number"].as_u64(); - // let type_task = metadata["type"].as_str(); - // - // assert_eq!(Some(1), number); - // assert_eq!(Some("AsyncTask"), type_task); - // - // let task = test.create_task(&AsyncTask { number: 2 }).await.unwrap(); - // - // let metadata = task.payload.as_object().unwrap(); - // let number = metadata["number"].as_u64(); - // let type_task = metadata["type"].as_str(); - // - // assert_eq!(Some(2), number); - // assert_eq!(Some("AsyncTask"), type_task); - // - // let result = test.remove_all_tasks().await.unwrap(); - // assert_eq!(2, result); - // } - // - // #[tokio::test] - // async fn pull_next_task_test() { - // let pool = pool().await; - // let mut queue = PgTaskStore::new(pool); - // - // let task = queue.create_task(&AsyncTask { number: 1 }).await.unwrap(); - // - // let metadata = task.payload.as_object().unwrap(); - // let number = metadata["number"].as_u64(); - // let type_task = metadata["type"].as_str(); - // - // assert_eq!(Some(1), number); - // assert_eq!(Some("AsyncTask"), type_task); - // - // let task = queue.create_task(&AsyncTask { number: 2 }).await.unwrap(); - // - // let metadata = task.payload.as_object().unwrap(); - // let number = metadata["number"].as_u64(); - // let type_task = metadata["type"].as_str(); - // - // assert_eq!(Some(2), number); - // assert_eq!(Some("AsyncTask"), type_task); - // - // let task = queue.pull_next_task(None).await.unwrap().unwrap(); - // - // let metadata = task.payload.as_object().unwrap(); - // let number = metadata["number"].as_u64(); - // let type_task = metadata["type"].as_str(); - // - // assert_eq!(Some(1), number); - // assert_eq!(Some("AsyncTask"), type_task); - // - // let task = queue.pull_next_task(None).await.unwrap().unwrap(); - // let metadata = task.payload.as_object().unwrap(); - // let number = metadata["number"].as_u64(); - // let type_task = metadata["type"].as_str(); - // - // assert_eq!(Some(2), number); - // assert_eq!(Some("AsyncTask"), type_task); - // - // queue.remove_all_tasks().await.unwrap(); - // } -} diff --git a/src/store.rs b/src/store.rs index 97e7e4b..fa46218 100644 --- a/src/store.rs +++ b/src/store.rs @@ -38,6 +38,7 @@ impl TaskStore for PgTaskStore { }) .await } + async fn create_task(&self, new_task: NewTask) -> Result { let mut connection = self .pool @@ -46,6 +47,7 @@ impl TaskStore for PgTaskStore { .map_err(|e| QueryBuilderError(e.into()))?; Task::insert(&mut connection, new_task).await } + async fn set_task_state(&self, id: TaskId, state: TaskState) -> Result<(), AsyncQueueError> { let mut connection = self .pool @@ -53,14 +55,17 @@ impl TaskStore for PgTaskStore { .await .map_err(|e| QueryBuilderError(e.into()))?; match state { - TaskState::Done => Task::set_done(&mut connection, id).await?, - TaskState::Failed(error_msg) => { - Task::fail_with_message(&mut connection, id, &error_msg).await? + TaskState::Done => { + Task::set_done(&mut connection, id).await?; } - _ => return Ok(()), + TaskState::Failed(error_msg) => { + Task::fail_with_message(&mut connection, id, &error_msg).await?; + } + _ => (), }; Ok(()) } + async fn remove_task(&self, id: TaskId) -> Result { let mut connection = self .pool @@ -70,6 +75,7 @@ impl TaskStore for PgTaskStore { let result = Task::remove(&mut connection, id).await?; Ok(result) } + async fn schedule_task_retry( &self, id: TaskId, @@ -86,8 +92,111 @@ impl TaskStore for PgTaskStore { } } +#[cfg(test)] +pub mod test_store { + use super::*; + use itertools::Itertools; + use std::collections::BTreeMap; + use std::sync::Arc; + use tokio::sync::Mutex; + + #[derive(Clone)] + pub struct MemoryTaskStore { + tasks: Arc>>, + } + + impl MemoryTaskStore { + pub fn new() -> Self { + MemoryTaskStore { + tasks: Arc::new(Mutex::new(BTreeMap::new())), + } + } + } + + #[async_trait::async_trait] + impl TaskStore for MemoryTaskStore { + async fn pull_next_task(&self, queue_name: &str) -> Result, AsyncQueueError> { + let mut tasks = self.tasks.lock().await; + let mut next_task = None; + for (_, task) in tasks + .iter_mut() + .sorted_by(|a, b| a.1.created_at.cmp(&b.1.created_at)) + { + if task.queue_name == queue_name && task.state() == TaskState::Ready { + task.running_at = Some(chrono::Utc::now()); + next_task = Some(task.clone()); + break; + } + } + Ok(next_task) + } + + async fn create_task(&self, new_task: NewTask) -> Result { + let mut tasks = self.tasks.lock().await; + let task = Task::from(new_task); + tasks.insert(task.id, task.clone()); + Ok(task) + } + + async fn set_task_state( + &self, + id: TaskId, + state: TaskState, + ) -> Result<(), AsyncQueueError> { + let mut tasks = self.tasks.lock().await; + let task = tasks.get_mut(&id).unwrap(); + + use TaskState::*; + match state { + Done => task.done_at = Some(chrono::Utc::now()), + Failed(error_msg) => { + let error_payload = serde_json::json!({ + "error": error_msg, + }); + task.error_info = Some(error_payload); + task.done_at = Some(chrono::Utc::now()); + } + _ => {} + } + + Ok(()) + } + + async fn remove_task(&self, id: TaskId) -> Result { + let mut tasks = self.tasks.lock().await; + let res = tasks.remove(&id); + if res.is_some() { + Ok(1) + } else { + Ok(0) + } + } + + async fn schedule_task_retry( + &self, + id: TaskId, + backoff_seconds: u32, + error: &str, + ) -> Result { + let mut tasks = self.tasks.lock().await; + let task = tasks.get_mut(&id).unwrap(); + + let error_payload = serde_json::json!({ + "error": error, + }); + task.error_info = Some(error_payload); + task.running_at = None; + task.retries += 1; + task.scheduled_at = + chrono::Utc::now() + chrono::Duration::seconds(backoff_seconds as i64); + + Ok(task.clone()) + } + } +} + #[async_trait::async_trait] -pub trait TaskStore { +pub trait TaskStore: Clone + Send + Sync + 'static { async fn pull_next_task(&self, queue_name: &str) -> Result, AsyncQueueError>; async fn create_task(&self, new_task: NewTask) -> Result; async fn set_task_state(&self, id: TaskId, state: TaskState) -> Result<(), AsyncQueueError>; diff --git a/src/task.rs b/src/task.rs index d074846..a64945b 100644 --- a/src/task.rs +++ b/src/task.rs @@ -28,7 +28,7 @@ pub enum TaskState { Done, } -#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, DieselNewType, Serialize)] +#[derive(Clone, Copy, Debug, Ord, PartialOrd, Hash, PartialEq, Eq, DieselNewType, Serialize)] pub struct TaskId(Uuid); impl Display for TaskId { @@ -144,6 +144,27 @@ impl NewTask { } } +#[cfg(test)] +impl From for Task { + fn from(new_task: NewTask) -> Self { + Self { + id: TaskId(Uuid::new_v4()), + task_name: new_task.task_name, + queue_name: new_task.queue_name, + uniq_hash: new_task.uniq_hash, + payload: new_task.payload, + timeout_msecs: new_task.timeout_msecs, + created_at: Utc::now(), + scheduled_at: Utc::now(), + running_at: None, + done_at: None, + error_info: None, + retries: 0, + max_retries: new_task.max_retries, + } + } +} + pub struct CurrentTask { id: TaskId, retries: i32, diff --git a/src/worker.rs b/src/worker.rs index cc7dfca..3fa2bf4 100644 --- a/src/worker.rs +++ b/src/worker.rs @@ -2,7 +2,7 @@ use crate::errors::{AsyncQueueError, BackieError}; use crate::runnable::BackgroundTask; use crate::store::TaskStore; use crate::task::{CurrentTask, Task, TaskState}; -use crate::{PgTaskStore, RetentionMode}; +use crate::RetentionMode; use futures::future::FutureExt; use futures::select; use std::collections::BTreeMap; @@ -48,11 +48,12 @@ where } /// Worker that executes tasks. -pub struct Worker +pub struct Worker where AppData: Clone + Send + 'static, + S: TaskStore, { - store: Arc, + store: Arc, queue_name: String, @@ -66,12 +67,13 @@ where shutdown: Option>, } -impl Worker +impl Worker where AppData: Clone + Send + 'static, + S: TaskStore, { pub(crate) fn new( - store: Arc, + store: Arc, queue_name: String, retention_mode: RetentionMode, task_registry: BTreeMap>, diff --git a/src/worker_pool.rs b/src/worker_pool.rs index 67c6487..2fcfd6b 100644 --- a/src/worker_pool.rs +++ b/src/worker_pool.rs @@ -1,25 +1,26 @@ use crate::errors::BackieError; use crate::queue::Queue; +use crate::runnable::BackgroundTask; +use crate::store::TaskStore; use crate::worker::{runnable, ExecuteTaskFn}; use crate::worker::{StateFn, Worker}; -use crate::{BackgroundTask, PgTaskStore, RetentionMode}; +use crate::RetentionMode; use std::collections::BTreeMap; use std::future::Future; use std::sync::Arc; use tokio::task::JoinHandle; -pub type AppDataFn = Arc AppData + Send + Sync>; - #[derive(Clone)] -pub struct WorkerPool +pub struct WorkerPool where AppData: Clone + Send + 'static, + S: TaskStore, { /// Storage of tasks. - queue_store: Arc, // TODO: make this generic/dynamic referenced + task_store: Arc, /// Queue used to spawn tasks. - queue: Queue, + queue: Queue, /// Make possible to load the application data. /// @@ -38,14 +39,15 @@ where worker_queues: BTreeMap, } -impl WorkerPool +impl WorkerPool where AppData: Clone + Send + 'static, + S: TaskStore, { /// Create a new worker pool. - pub fn new(task_store: PgTaskStore, application_data_fn: A) -> Self + pub fn new(task_store: S, application_data_fn: A) -> Self where - A: Fn(Queue) -> AppData + Send + Sync + 'static, + A: Fn(Queue) -> AppData + Send + Sync + 'static, { let queue_store = Arc::new(task_store); let queue = Queue::new(queue_store.clone()); @@ -54,7 +56,7 @@ where move || application_data_fn(queue.clone()) }; Self { - queue_store, + task_store: queue_store, queue, application_data_fn: Arc::new(application_data_fn), task_registry: BTreeMap::new(), @@ -91,7 +93,7 @@ where pub async fn start( self, graceful_shutdown: F, - ) -> Result<(JoinHandle<()>, Queue), BackieError> + ) -> Result<(JoinHandle<()>, Queue), BackieError> where F: Future + Send + 'static, { @@ -107,8 +109,8 @@ where // Spawn all individual workers per queue for (queue_name, (retention_mode, num_workers)) in self.worker_queues.iter() { for idx in 0..*num_workers { - let mut worker: Worker = Worker::new( - self.queue_store.clone(), + let mut worker: Worker = Worker::new( + self.task_store.clone(), queue_name.clone(), retention_mode.clone(), self.task_registry.clone(), @@ -143,7 +145,10 @@ where #[cfg(test)] mod tests { use super::*; + use crate::store::test_store::MemoryTaskStore; + use crate::store::PgTaskStore; use crate::task::CurrentTask; + use anyhow::Error; use async_trait::async_trait; use diesel_async::pooled_connection::{bb8::Pool, AsyncDieselConnectionManager}; use diesel_async::AsyncPgConnection; @@ -191,11 +196,32 @@ mod tests { } } + #[derive(Clone, Debug, serde::Serialize, serde::Deserialize)] + struct OtherTask; + + #[async_trait] + impl BackgroundTask for OtherTask { + const TASK_NAME: &'static str = "other_task"; + + const QUEUE: &'static str = "other_queue"; + + type AppData = ApplicationContext; + + async fn run(&self, task: CurrentTask, context: Self::AppData) -> Result<(), Error> { + println!( + "[{}] Other task with {}!", + task.id(), + context.get_app_name() + ); + Ok(()) + } + } + #[tokio::test] async fn validate_all_registered_tasks_queues_are_configured() { let my_app_context = ApplicationContext::new(); - let result = WorkerPool::new(task_store().await, move |_| my_app_context.clone()) + let result = WorkerPool::new(memory_store().await, move |_| my_app_context.clone()) .register_task_type::() .start(futures::future::ready(())) .await; @@ -210,11 +236,11 @@ mod tests { } #[tokio::test] - async fn test_worker_pool() { + async fn test_worker_pool_with_task() { let my_app_context = ApplicationContext::new(); let (join_handle, queue) = - WorkerPool::new(task_store().await, move |_| my_app_context.clone()) + WorkerPool::new(memory_store().await, move |_| my_app_context.clone()) .register_task_type::() .configure_queue(GreetingTask::QUEUE, 1, RetentionMode::RemoveDone) .start(futures::future::ready(())) @@ -231,7 +257,53 @@ mod tests { join_handle.await.unwrap(); } - async fn task_store() -> PgTaskStore { + #[tokio::test] + async fn test_worker_pool_with_multiple_task_types() { + let my_app_context = ApplicationContext::new(); + + let (join_handle, queue) = + WorkerPool::new(memory_store().await, move |_| my_app_context.clone()) + .register_task_type::() + .register_task_type::() + .configure_queue("default", 1, RetentionMode::default()) + .configure_queue("other_queue", 1, RetentionMode::default()) + .start(futures::future::ready(())) + .await + .unwrap(); + + queue + .enqueue(GreetingTask { + person: "Rafael".to_string(), + }) + .await + .unwrap(); + + queue.enqueue(OtherTask).await.unwrap(); + + join_handle.await.unwrap(); + } + + async fn memory_store() -> MemoryTaskStore { + MemoryTaskStore::new() + } + + #[tokio::test] + #[ignore] + async fn test_worker_pool_with_pg_store() { + let my_app_context = ApplicationContext::new(); + + let (join_handle, _queue) = + WorkerPool::new(pg_task_store().await, move |_| my_app_context.clone()) + .register_task_type::() + .configure_queue(GreetingTask::QUEUE, 1, RetentionMode::RemoveDone) + .start(futures::future::ready(())) + .await + .unwrap(); + + join_handle.await.unwrap(); + } + + async fn pg_task_store() -> PgTaskStore { let manager = AsyncDieselConnectionManager::::new( option_env!("DATABASE_URL").expect("DATABASE_URL must be set"), );