diff --git a/src/worker_pool.rs b/src/worker_pool.rs index 05edf36..a8653af 100644 --- a/src/worker_pool.rs +++ b/src/worker_pool.rs @@ -166,6 +166,8 @@ mod tests { use async_trait::async_trait; use diesel_async::pooled_connection::{bb8::Pool, AsyncDieselConnectionManager}; use diesel_async::AsyncPgConnection; + use futures::FutureExt; + use std::sync::atomic::{AtomicBool, Ordering}; use tokio::sync::Mutex; #[derive(Clone, Debug)] @@ -236,36 +238,6 @@ mod tests { } } - #[derive(Clone)] - struct NotifyFinishedContext { - tx: Arc>>>, - } - - #[derive(serde::Serialize, serde::Deserialize)] - struct NotifyFinished; - - #[async_trait] - impl BackgroundTask for NotifyFinished { - const TASK_NAME: &'static str = "notify_finished"; - - type AppData = NotifyFinishedContext; - - async fn run( - &self, - task: CurrentTask, - context: Self::AppData, - ) -> Result<(), anyhow::Error> { - match context.tx.lock().await.take() { - None => println!("Cannot notify, already done that!"), - Some(tx) => { - tx.send(()).unwrap(); - println!("[{}] Notify finished did it's job!", task.id()) - } - }; - Ok(()) - } - } - #[tokio::test] async fn validate_all_registered_tasks_queues_are_configured() { let my_app_context = ApplicationContext::new(); @@ -334,6 +306,39 @@ mod tests { #[tokio::test] async fn test_worker_pool_stop_after_task_execute() { + #[derive(Clone)] + struct NotifyFinishedContext { + /// Used to notify the task ran + tx: Arc>>>, + } + + /// A task that notifies the test that it ran + #[derive(serde::Serialize, serde::Deserialize)] + struct NotifyFinished; + + #[async_trait] + impl BackgroundTask for NotifyFinished { + const TASK_NAME: &'static str = "notify_finished"; + + type AppData = NotifyFinishedContext; + + async fn run( + &self, + task: CurrentTask, + context: Self::AppData, + ) -> Result<(), anyhow::Error> { + // Notify the test that the task ran + match context.tx.lock().await.take() { + None => println!("Cannot notify, already done that!"), + Some(tx) => { + tx.send(()).unwrap(); + println!("[{}] Notify finished did it's job!", task.id()) + } + }; + Ok(()) + } + } + let (tx, rx) = tokio::sync::oneshot::channel(); let my_app_context = NotifyFinishedContext { @@ -362,6 +367,42 @@ mod tests { #[tokio::test] async fn test_worker_pool_try_to_run_unknown_task() { + #[derive(Clone)] + struct NotifyUnknownRanContext { + /// Notify that application should stop + should_stop: Arc>>>, + + /// Used to mark if the unknown task ran + unknown_task_ran: Arc, + } + + /// A task that notifies the test that it ran + #[derive(serde::Serialize, serde::Deserialize)] + struct NotifyStopDuringRun; + + #[async_trait] + impl BackgroundTask for NotifyStopDuringRun { + const TASK_NAME: &'static str = "notify_finished"; + + type AppData = NotifyUnknownRanContext; + + async fn run( + &self, + task: CurrentTask, + context: Self::AppData, + ) -> Result<(), anyhow::Error> { + // Notify the test that the task ran + match context.should_stop.lock().await.take() { + None => println!("Cannot notify, already done that!"), + Some(tx) => { + tx.send(()).unwrap(); + println!("[{}] Notify finished did it's job!", task.id()) + } + }; + Ok(()) + } + } + #[derive(Clone, serde::Serialize, serde::Deserialize)] struct UnknownTask; @@ -369,43 +410,157 @@ mod tests { impl BackgroundTask for UnknownTask { const TASK_NAME: &'static str = "unknown_task"; - type AppData = NotifyFinishedContext; + type AppData = NotifyUnknownRanContext; async fn run( &self, task: CurrentTask, - _context: Self::AppData, + context: Self::AppData, ) -> Result<(), anyhow::Error> { println!("[{}] Unknown task ran!", task.id()); + context.unknown_task_ran.store(true, Ordering::Relaxed); Ok(()) } } let (tx, rx) = tokio::sync::oneshot::channel(); - let my_app_context = NotifyFinishedContext { - tx: Arc::new(Mutex::new(Some(tx))), + let my_app_context = NotifyUnknownRanContext { + should_stop: Arc::new(Mutex::new(Some(tx))), + unknown_task_ran: Arc::new(AtomicBool::new(false)), }; let task_store = memory_store().await; - let (join_handle, queue) = WorkerPool::new(task_store, move |_| my_app_context.clone()) - .register_task_type::() - .configure_queue("default", 1, RetentionMode::default()) - .start(async move { - rx.await.unwrap(); - println!("Worker pool got notified to stop"); - }) - .await - .unwrap(); + let (join_handle, queue) = WorkerPool::new(task_store, { + let my_app_context = my_app_context.clone(); + move |_| my_app_context.clone() + }) + .register_task_type::() + .configure_queue("default", 1, RetentionMode::default()) + .start(async move { + rx.await.unwrap(); + println!("Worker pool got notified to stop"); + }) + .await + .unwrap(); // Enqueue a task that is not registered queue.enqueue(UnknownTask).await.unwrap(); // Notifies the worker pool to stop for this test - queue.enqueue(NotifyFinished).await.unwrap(); + queue.enqueue(NotifyStopDuringRun).await.unwrap(); join_handle.await.unwrap(); + + assert!( + !my_app_context.unknown_task_ran.load(Ordering::Relaxed), + "Unknown task ran but it is not registered in the worker pool!" + ); + } + + /// This test will make sure that the worker pool will only stop after all workers are done. + /// We create a KeepAliveTask that will keep running until we notify it to stop. + /// We stop the worker pool and make sure that the KeepAliveTask is still running. + /// Then we notify the KeepAliveTask to stop and make sure that the worker pool stops. + #[tokio::test] + async fn tasks_only_stop_running_when_finished() { + #[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)] + enum PingPongGame { + Ping, + Pong, + StopThisNow, + } + + #[derive(Clone)] + struct PlayerContext { + /// Used to communicate with the running task + pong_tx: Arc>, + ping_rx: Arc>>, + } + + /// Task that will respond to the ping pong game and keep alive as long as we need + #[derive(Clone, serde::Serialize, serde::Deserialize)] + struct KeepAliveTask; + + #[async_trait] + impl BackgroundTask for KeepAliveTask { + const TASK_NAME: &'static str = "keep_alive_task"; + + type AppData = PlayerContext; + + async fn run( + &self, + _task: CurrentTask, + context: Self::AppData, + ) -> Result<(), anyhow::Error> { + loop { + let msg = context.ping_rx.lock().await.recv().await.unwrap(); + match msg { + PingPongGame::Ping => { + println!("Pong!"); + context.pong_tx.send(PingPongGame::Pong).await.unwrap(); + } + PingPongGame::Pong => { + context.pong_tx.send(PingPongGame::Ping).await.unwrap(); + } + PingPongGame::StopThisNow => { + println!("Got stop signal, stopping the ping pong game now!"); + break; + } + } + } + Ok(()) + } + } + + let (notify_stop_worker_pool, should_stop) = tokio::sync::oneshot::channel(); + let (pong_tx, mut pong_rx) = tokio::sync::mpsc::channel(1); + let (ping_tx, ping_rx) = tokio::sync::mpsc::channel(1); + + let player_context = PlayerContext { + pong_tx: Arc::new(pong_tx), + ping_rx: Arc::new(Mutex::new(ping_rx)), + }; + + let task_store = memory_store().await; + + let (worker_pool_finished, queue) = WorkerPool::new(task_store, { + let player_context = player_context.clone(); + move |_| player_context.clone() + }) + .register_task_type::() + .configure_queue("default", 1, RetentionMode::default()) + .start(async move { + should_stop.await.unwrap(); + println!("Worker pool got notified to stop"); + }) + .await + .unwrap(); + + queue.enqueue(KeepAliveTask).await.unwrap(); + + // Make sure task is running + println!("Ping!"); + ping_tx.send(PingPongGame::Ping).await.unwrap(); + assert_eq!(pong_rx.recv().await.unwrap(), PingPongGame::Pong); + + // Notify to stop the worker pool + notify_stop_worker_pool.send(()).unwrap(); + + // Make sure task is still running + println!("Ping!"); + ping_tx.send(PingPongGame::Ping).await.unwrap(); + assert_eq!(pong_rx.recv().await.unwrap(), PingPongGame::Pong); + + // is_none() means that the worker pool is still waiting for tasks to finish, which is what we want! + assert!( + worker_pool_finished.now_or_never().is_none(), + "Worker pool finished before task stopped!" + ); + + // Notify to stop the task, which will stop the worker pool + ping_tx.send(PingPongGame::StopThisNow).await.unwrap(); } async fn memory_store() -> MemoryTaskStore {