diff --git a/examples/simple_worker/src/main.rs b/examples/simple_worker/src/main.rs index 08f571f..8ba1beb 100644 --- a/examples/simple_worker/src/main.rs +++ b/examples/simple_worker/src/main.rs @@ -118,8 +118,8 @@ async fn main() { let my_app_context = MyApplicationContext::new("Backie Example App"); // Register the task types I want to use and start the worker pool - let (join_handle, _queue) = - WorkerPool::new(task_store.clone(), move |_| my_app_context.clone()) + let join_handle = + WorkerPool::new(task_store.clone(), move || my_app_context.clone()) .register_task_type::() .register_task_type::() .configure_queue("default".into()) @@ -135,7 +135,7 @@ async fn main() { let task2 = MyTask::new(20_000); let task3 = MyFailingTask::new(50_000); - let queue = Queue::new(task_store); // or use the `queue` instance returned by the worker pool + let queue = Queue::new(task_store); queue.enqueue(task1).await.unwrap(); queue.enqueue(task2).await.unwrap(); queue.enqueue(task3).await.unwrap(); diff --git a/src/queue.rs b/src/queue.rs index 8d0a20e..547d6e7 100644 --- a/src/queue.rs +++ b/src/queue.rs @@ -2,34 +2,42 @@ use crate::errors::BackieError; use crate::runnable::BackgroundTask; use crate::store::TaskStore; use crate::task::NewTask; -use std::sync::Arc; use std::time::Duration; -#[derive(Clone)] pub struct Queue where - S: TaskStore + Clone, + S: TaskStore, { - task_store: Arc, + task_store: S, } impl Queue where - S: TaskStore + Clone, + S: TaskStore, { pub fn new(task_store: S) -> Self { - Queue { - task_store: Arc::new(task_store), - } + Queue { task_store } } pub async fn enqueue(&self, background_task: BT) -> Result<(), BackieError> where BT: BackgroundTask, { + // TODO: Add option to specify the timeout of a task self.task_store .create_task(NewTask::new(background_task, Duration::from_secs(10))?) .await?; Ok(()) } } + +impl Clone for Queue +where + S: TaskStore + Clone, +{ + fn clone(&self) -> Self { + Self { + task_store: self.task_store.clone(), + } + } +} diff --git a/src/worker.rs b/src/worker.rs index c545b2e..8057f99 100644 --- a/src/worker.rs +++ b/src/worker.rs @@ -59,7 +59,7 @@ where AppData: Clone + Send + 'static, S: TaskStore + Clone, { - store: Arc, + store: S, queue_name: String, @@ -81,7 +81,7 @@ where S: TaskStore + Clone, { pub(crate) fn new( - store: Arc, + store: S, queue_name: String, retention_mode: RetentionMode, pull_interval: Duration, diff --git a/src/worker_pool.rs b/src/worker_pool.rs index d9c4e53..22e9873 100644 --- a/src/worker_pool.rs +++ b/src/worker_pool.rs @@ -1,5 +1,4 @@ use crate::errors::BackieError; -use crate::queue::Queue; use crate::runnable::BackgroundTask; use crate::store::TaskStore; use crate::worker::{runnable, ExecuteTaskFn}; @@ -19,10 +18,7 @@ where S: TaskStore + Clone, { /// Storage of tasks. - task_store: Arc, - - /// Queue used to spawn tasks. - queue: Queue, + task_store: S, /// Make possible to load the application data. /// @@ -49,16 +45,10 @@ where /// Create a new worker pool. pub fn new(task_store: S, application_data_fn: A) -> Self where - A: Fn(Queue) -> AppData + Send + Sync + 'static, + A: Fn() -> AppData + Send + Sync + 'static, { - let queue = Queue::new(task_store.clone()); - let application_data_fn = { - let queue = queue.clone(); - move || application_data_fn(queue.clone()) - }; Self { - task_store: Arc::new(task_store), - queue, + task_store, application_data_fn: Arc::new(application_data_fn), task_registry: BTreeMap::new(), queue_tasks: BTreeMap::new(), @@ -85,10 +75,7 @@ where self } - pub async fn start( - self, - graceful_shutdown: F, - ) -> Result<(JoinHandle<()>, Queue), BackieError> + pub async fn start(self, graceful_shutdown: F) -> Result, BackieError> where F: Future + Send + 'static, { @@ -127,28 +114,25 @@ where } } - Ok(( - tokio::spawn(async move { - graceful_shutdown.await; - if let Err(err) = tx.send(()) { - log::warn!("Failed to send shutdown signal to worker pool: {}", err); + Ok(tokio::spawn(async move { + graceful_shutdown.await; + if let Err(err) = tx.send(()) { + log::warn!("Failed to send shutdown signal to worker pool: {}", err); + } else { + // Wait for all workers to finish processing + let results = join_all(worker_handles) + .await + .into_iter() + .filter(Result::is_err) + .map(Result::unwrap_err) + .collect::>(); + if !results.is_empty() { + log::error!("Worker pool stopped with errors: {:?}", results); } else { - // Wait for all workers to finish processing - let results = join_all(worker_handles) - .await - .into_iter() - .filter(Result::is_err) - .map(Result::unwrap_err) - .collect::>(); - if !results.is_empty() { - log::error!("Worker pool stopped with errors: {:?}", results); - } else { - log::info!("Worker pool stopped gracefully"); - } + log::info!("Worker pool stopped gracefully"); } - }), - self.queue, - )) + } + })) } } @@ -232,6 +216,7 @@ mod tests { use crate::store::test_store::MemoryTaskStore; use crate::store::PgTaskStore; use crate::task::CurrentTask; + use crate::Queue; use async_trait::async_trait; use diesel_async::pooled_connection::{bb8::Pool, AsyncDieselConnectionManager}; use diesel_async::AsyncPgConnection; @@ -341,7 +326,7 @@ mod tests { async fn validate_all_registered_tasks_queues_are_configured() { let my_app_context = ApplicationContext::new(); - let result = WorkerPool::new(memory_store().await, move |_| my_app_context.clone()) + let result = WorkerPool::new(memory_store(), move || my_app_context.clone()) .register_task_type::() .start(futures::future::ready(())) .await; @@ -359,14 +344,16 @@ mod tests { async fn test_worker_pool_with_task() { let my_app_context = ApplicationContext::new(); - let (join_handle, queue) = - WorkerPool::new(memory_store().await, move |_| my_app_context.clone()) - .register_task_type::() - .configure_queue(::QUEUE.into()) - .start(futures::future::ready(())) - .await - .unwrap(); + let task_store = memory_store(); + let join_handle = WorkerPool::new(task_store.clone(), move || my_app_context.clone()) + .register_task_type::() + .configure_queue(::QUEUE.into()) + .start(futures::future::ready(())) + .await + .unwrap(); + + let queue = Queue::new(task_store); queue .enqueue(GreetingTask { person: "Rafael".to_string(), @@ -381,16 +368,17 @@ mod tests { async fn test_worker_pool_with_multiple_task_types() { let my_app_context = ApplicationContext::new(); - let (join_handle, queue) = - WorkerPool::new(memory_store().await, move |_| my_app_context.clone()) - .register_task_type::() - .register_task_type::() - .configure_queue("default".into()) - .configure_queue("other_queue".into()) - .start(futures::future::ready(())) - .await - .unwrap(); + let task_store = memory_store(); + let join_handle = WorkerPool::new(task_store.clone(), move || my_app_context.clone()) + .register_task_type::() + .register_task_type::() + .configure_queue("default".into()) + .configure_queue("other_queue".into()) + .start(futures::future::ready(())) + .await + .unwrap(); + let queue = Queue::new(task_store.clone()); queue .enqueue(GreetingTask { person: "Rafael".to_string(), @@ -442,17 +430,19 @@ mod tests { notify_finished: Arc::new(Mutex::new(Some(tx))), }; - let (join_handle, queue) = - WorkerPool::new(memory_store().await, move |_| my_app_context.clone()) - .register_task_type::() - .configure_queue("default".into()) - .start(async move { - rx.await.unwrap(); - println!("Worker pool got notified to stop"); - }) - .await - .unwrap(); + let memory_store = memory_store(); + let join_handle = WorkerPool::new(memory_store.clone(), move || my_app_context.clone()) + .register_task_type::() + .configure_queue("default".into()) + .start(async move { + rx.await.unwrap(); + println!("Worker pool got notified to stop"); + }) + .await + .unwrap(); + + let queue = Queue::new(memory_store); // Notifies the worker pool to stop after the task is executed queue.enqueue(NotifyFinished).await.unwrap(); @@ -527,11 +517,11 @@ mod tests { unknown_task_ran: Arc::new(AtomicBool::new(false)), }; - let task_store = memory_store().await; + let task_store = memory_store(); - let (join_handle, queue) = WorkerPool::new(task_store, { + let join_handle = WorkerPool::new(task_store.clone(), { let my_app_context = my_app_context.clone(); - move |_| my_app_context.clone() + move || my_app_context.clone() }) .register_task_type::() .configure_queue("default".into()) @@ -542,6 +532,7 @@ mod tests { .await .unwrap(); + let queue = Queue::new(task_store); // Enqueue a task that is not registered queue.enqueue(UnknownTask).await.unwrap(); @@ -574,9 +565,9 @@ mod tests { let (notify_stop_worker_pool, should_stop) = tokio::sync::oneshot::channel(); - let task_store = memory_store().await; + let task_store = memory_store(); - let (worker_pool_finished, queue) = WorkerPool::new(task_store.clone(), |_| ()) + let worker_pool_finished = WorkerPool::new(task_store.clone(), || ()) .register_task_type::() .configure_queue("default".into()) .start(async move { @@ -585,6 +576,7 @@ mod tests { .await .unwrap(); + let queue = Queue::new(task_store.clone()); // Enqueue a task that will panic queue.enqueue(BrokenTask).await.unwrap(); @@ -670,11 +662,11 @@ mod tests { ping_rx: Arc::new(Mutex::new(ping_rx)), }; - let task_store = memory_store().await; + let task_store = memory_store(); - let (worker_pool_finished, queue) = WorkerPool::new(task_store, { + let worker_pool_finished = WorkerPool::new(task_store.clone(), { let player_context = player_context.clone(); - move |_| player_context.clone() + move || player_context.clone() }) .register_task_type::() .configure_queue("default".into()) @@ -685,6 +677,7 @@ mod tests { .await .unwrap(); + let queue = Queue::new(task_store); queue.enqueue(KeepAliveTask).await.unwrap(); // Make sure task is running @@ -710,7 +703,7 @@ mod tests { ping_tx.send(PingPongGame::StopThisNow).await.unwrap(); } - async fn memory_store() -> MemoryTaskStore { + fn memory_store() -> MemoryTaskStore { MemoryTaskStore::default() } @@ -719,16 +712,15 @@ mod tests { async fn test_worker_pool_with_pg_store() { let my_app_context = ApplicationContext::new(); - let (join_handle, _queue) = - WorkerPool::new(pg_task_store().await, move |_| my_app_context.clone()) - .register_task_type::() - .configure_queue( - QueueConfig::new(::QUEUE) - .retention_mode(RetentionMode::RemoveDone), - ) - .start(futures::future::ready(())) - .await - .unwrap(); + let join_handle = WorkerPool::new(pg_task_store().await, move || my_app_context.clone()) + .register_task_type::() + .configure_queue( + QueueConfig::new(::QUEUE) + .retention_mode(RetentionMode::RemoveDone), + ) + .start(futures::future::ready(())) + .await + .unwrap(); join_handle.await.unwrap(); }