Tasks are let run until completion

This commit is contained in:
Rafael Caricio 2023-03-12 15:52:13 +01:00
parent 2964dc2b88
commit 0f0a9c2238
Signed by: rafaelcaricio
GPG key ID: 3C86DBCE8E93C947

View file

@ -166,6 +166,8 @@ mod tests {
use async_trait::async_trait; use async_trait::async_trait;
use diesel_async::pooled_connection::{bb8::Pool, AsyncDieselConnectionManager}; use diesel_async::pooled_connection::{bb8::Pool, AsyncDieselConnectionManager};
use diesel_async::AsyncPgConnection; use diesel_async::AsyncPgConnection;
use futures::FutureExt;
use std::sync::atomic::{AtomicBool, Ordering};
use tokio::sync::Mutex; use tokio::sync::Mutex;
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
@ -236,36 +238,6 @@ mod tests {
} }
} }
#[derive(Clone)]
struct NotifyFinishedContext {
tx: Arc<Mutex<Option<tokio::sync::oneshot::Sender<()>>>>,
}
#[derive(serde::Serialize, serde::Deserialize)]
struct NotifyFinished;
#[async_trait]
impl BackgroundTask for NotifyFinished {
const TASK_NAME: &'static str = "notify_finished";
type AppData = NotifyFinishedContext;
async fn run(
&self,
task: CurrentTask,
context: Self::AppData,
) -> Result<(), anyhow::Error> {
match context.tx.lock().await.take() {
None => println!("Cannot notify, already done that!"),
Some(tx) => {
tx.send(()).unwrap();
println!("[{}] Notify finished did it's job!", task.id())
}
};
Ok(())
}
}
#[tokio::test] #[tokio::test]
async fn validate_all_registered_tasks_queues_are_configured() { async fn validate_all_registered_tasks_queues_are_configured() {
let my_app_context = ApplicationContext::new(); let my_app_context = ApplicationContext::new();
@ -334,6 +306,39 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_worker_pool_stop_after_task_execute() { async fn test_worker_pool_stop_after_task_execute() {
#[derive(Clone)]
struct NotifyFinishedContext {
/// Used to notify the task ran
tx: Arc<Mutex<Option<tokio::sync::oneshot::Sender<()>>>>,
}
/// A task that notifies the test that it ran
#[derive(serde::Serialize, serde::Deserialize)]
struct NotifyFinished;
#[async_trait]
impl BackgroundTask for NotifyFinished {
const TASK_NAME: &'static str = "notify_finished";
type AppData = NotifyFinishedContext;
async fn run(
&self,
task: CurrentTask,
context: Self::AppData,
) -> Result<(), anyhow::Error> {
// Notify the test that the task ran
match context.tx.lock().await.take() {
None => println!("Cannot notify, already done that!"),
Some(tx) => {
tx.send(()).unwrap();
println!("[{}] Notify finished did it's job!", task.id())
}
};
Ok(())
}
}
let (tx, rx) = tokio::sync::oneshot::channel(); let (tx, rx) = tokio::sync::oneshot::channel();
let my_app_context = NotifyFinishedContext { let my_app_context = NotifyFinishedContext {
@ -362,6 +367,42 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_worker_pool_try_to_run_unknown_task() { async fn test_worker_pool_try_to_run_unknown_task() {
#[derive(Clone)]
struct NotifyUnknownRanContext {
/// Notify that application should stop
should_stop: Arc<Mutex<Option<tokio::sync::oneshot::Sender<()>>>>,
/// Used to mark if the unknown task ran
unknown_task_ran: Arc<AtomicBool>,
}
/// A task that notifies the test that it ran
#[derive(serde::Serialize, serde::Deserialize)]
struct NotifyStopDuringRun;
#[async_trait]
impl BackgroundTask for NotifyStopDuringRun {
const TASK_NAME: &'static str = "notify_finished";
type AppData = NotifyUnknownRanContext;
async fn run(
&self,
task: CurrentTask,
context: Self::AppData,
) -> Result<(), anyhow::Error> {
// Notify the test that the task ran
match context.should_stop.lock().await.take() {
None => println!("Cannot notify, already done that!"),
Some(tx) => {
tx.send(()).unwrap();
println!("[{}] Notify finished did it's job!", task.id())
}
};
Ok(())
}
}
#[derive(Clone, serde::Serialize, serde::Deserialize)] #[derive(Clone, serde::Serialize, serde::Deserialize)]
struct UnknownTask; struct UnknownTask;
@ -369,28 +410,33 @@ mod tests {
impl BackgroundTask for UnknownTask { impl BackgroundTask for UnknownTask {
const TASK_NAME: &'static str = "unknown_task"; const TASK_NAME: &'static str = "unknown_task";
type AppData = NotifyFinishedContext; type AppData = NotifyUnknownRanContext;
async fn run( async fn run(
&self, &self,
task: CurrentTask, task: CurrentTask,
_context: Self::AppData, context: Self::AppData,
) -> Result<(), anyhow::Error> { ) -> Result<(), anyhow::Error> {
println!("[{}] Unknown task ran!", task.id()); println!("[{}] Unknown task ran!", task.id());
context.unknown_task_ran.store(true, Ordering::Relaxed);
Ok(()) Ok(())
} }
} }
let (tx, rx) = tokio::sync::oneshot::channel(); let (tx, rx) = tokio::sync::oneshot::channel();
let my_app_context = NotifyFinishedContext { let my_app_context = NotifyUnknownRanContext {
tx: Arc::new(Mutex::new(Some(tx))), should_stop: Arc::new(Mutex::new(Some(tx))),
unknown_task_ran: Arc::new(AtomicBool::new(false)),
}; };
let task_store = memory_store().await; let task_store = memory_store().await;
let (join_handle, queue) = WorkerPool::new(task_store, move |_| my_app_context.clone()) let (join_handle, queue) = WorkerPool::new(task_store, {
.register_task_type::<NotifyFinished>() let my_app_context = my_app_context.clone();
move |_| my_app_context.clone()
})
.register_task_type::<NotifyStopDuringRun>()
.configure_queue("default", 1, RetentionMode::default()) .configure_queue("default", 1, RetentionMode::default())
.start(async move { .start(async move {
rx.await.unwrap(); rx.await.unwrap();
@ -403,9 +449,118 @@ mod tests {
queue.enqueue(UnknownTask).await.unwrap(); queue.enqueue(UnknownTask).await.unwrap();
// Notifies the worker pool to stop for this test // Notifies the worker pool to stop for this test
queue.enqueue(NotifyFinished).await.unwrap(); queue.enqueue(NotifyStopDuringRun).await.unwrap();
join_handle.await.unwrap(); join_handle.await.unwrap();
assert!(
!my_app_context.unknown_task_ran.load(Ordering::Relaxed),
"Unknown task ran but it is not registered in the worker pool!"
);
}
/// This test will make sure that the worker pool will only stop after all workers are done.
/// We create a KeepAliveTask that will keep running until we notify it to stop.
/// We stop the worker pool and make sure that the KeepAliveTask is still running.
/// Then we notify the KeepAliveTask to stop and make sure that the worker pool stops.
#[tokio::test]
async fn tasks_only_stop_running_when_finished() {
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
enum PingPongGame {
Ping,
Pong,
StopThisNow,
}
#[derive(Clone)]
struct PlayerContext {
/// Used to communicate with the running task
pong_tx: Arc<tokio::sync::mpsc::Sender<PingPongGame>>,
ping_rx: Arc<Mutex<tokio::sync::mpsc::Receiver<PingPongGame>>>,
}
/// Task that will respond to the ping pong game and keep alive as long as we need
#[derive(Clone, serde::Serialize, serde::Deserialize)]
struct KeepAliveTask;
#[async_trait]
impl BackgroundTask for KeepAliveTask {
const TASK_NAME: &'static str = "keep_alive_task";
type AppData = PlayerContext;
async fn run(
&self,
_task: CurrentTask,
context: Self::AppData,
) -> Result<(), anyhow::Error> {
loop {
let msg = context.ping_rx.lock().await.recv().await.unwrap();
match msg {
PingPongGame::Ping => {
println!("Pong!");
context.pong_tx.send(PingPongGame::Pong).await.unwrap();
}
PingPongGame::Pong => {
context.pong_tx.send(PingPongGame::Ping).await.unwrap();
}
PingPongGame::StopThisNow => {
println!("Got stop signal, stopping the ping pong game now!");
break;
}
}
}
Ok(())
}
}
let (notify_stop_worker_pool, should_stop) = tokio::sync::oneshot::channel();
let (pong_tx, mut pong_rx) = tokio::sync::mpsc::channel(1);
let (ping_tx, ping_rx) = tokio::sync::mpsc::channel(1);
let player_context = PlayerContext {
pong_tx: Arc::new(pong_tx),
ping_rx: Arc::new(Mutex::new(ping_rx)),
};
let task_store = memory_store().await;
let (worker_pool_finished, queue) = WorkerPool::new(task_store, {
let player_context = player_context.clone();
move |_| player_context.clone()
})
.register_task_type::<KeepAliveTask>()
.configure_queue("default", 1, RetentionMode::default())
.start(async move {
should_stop.await.unwrap();
println!("Worker pool got notified to stop");
})
.await
.unwrap();
queue.enqueue(KeepAliveTask).await.unwrap();
// Make sure task is running
println!("Ping!");
ping_tx.send(PingPongGame::Ping).await.unwrap();
assert_eq!(pong_rx.recv().await.unwrap(), PingPongGame::Pong);
// Notify to stop the worker pool
notify_stop_worker_pool.send(()).unwrap();
// Make sure task is still running
println!("Ping!");
ping_tx.send(PingPongGame::Ping).await.unwrap();
assert_eq!(pong_rx.recv().await.unwrap(), PingPongGame::Pong);
// is_none() means that the worker pool is still waiting for tasks to finish, which is what we want!
assert!(
worker_pool_finished.now_or_never().is_none(),
"Worker pool finished before task stopped!"
);
// Notify to stop the task, which will stop the worker pool
ping_tx.send(PingPongGame::StopThisNow).await.unwrap();
} }
async fn memory_store() -> MemoryTaskStore { async fn memory_store() -> MemoryTaskStore {