Tasks are let run until completion
This commit is contained in:
parent
2964dc2b88
commit
0f0a9c2238
1 changed files with 199 additions and 44 deletions
|
@ -166,6 +166,8 @@ mod tests {
|
|||
use async_trait::async_trait;
|
||||
use diesel_async::pooled_connection::{bb8::Pool, AsyncDieselConnectionManager};
|
||||
use diesel_async::AsyncPgConnection;
|
||||
use futures::FutureExt;
|
||||
use std::sync::atomic::{AtomicBool, Ordering};
|
||||
use tokio::sync::Mutex;
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
|
@ -236,36 +238,6 @@ mod tests {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct NotifyFinishedContext {
|
||||
tx: Arc<Mutex<Option<tokio::sync::oneshot::Sender<()>>>>,
|
||||
}
|
||||
|
||||
#[derive(serde::Serialize, serde::Deserialize)]
|
||||
struct NotifyFinished;
|
||||
|
||||
#[async_trait]
|
||||
impl BackgroundTask for NotifyFinished {
|
||||
const TASK_NAME: &'static str = "notify_finished";
|
||||
|
||||
type AppData = NotifyFinishedContext;
|
||||
|
||||
async fn run(
|
||||
&self,
|
||||
task: CurrentTask,
|
||||
context: Self::AppData,
|
||||
) -> Result<(), anyhow::Error> {
|
||||
match context.tx.lock().await.take() {
|
||||
None => println!("Cannot notify, already done that!"),
|
||||
Some(tx) => {
|
||||
tx.send(()).unwrap();
|
||||
println!("[{}] Notify finished did it's job!", task.id())
|
||||
}
|
||||
};
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn validate_all_registered_tasks_queues_are_configured() {
|
||||
let my_app_context = ApplicationContext::new();
|
||||
|
@ -334,6 +306,39 @@ mod tests {
|
|||
|
||||
#[tokio::test]
|
||||
async fn test_worker_pool_stop_after_task_execute() {
|
||||
#[derive(Clone)]
|
||||
struct NotifyFinishedContext {
|
||||
/// Used to notify the task ran
|
||||
tx: Arc<Mutex<Option<tokio::sync::oneshot::Sender<()>>>>,
|
||||
}
|
||||
|
||||
/// A task that notifies the test that it ran
|
||||
#[derive(serde::Serialize, serde::Deserialize)]
|
||||
struct NotifyFinished;
|
||||
|
||||
#[async_trait]
|
||||
impl BackgroundTask for NotifyFinished {
|
||||
const TASK_NAME: &'static str = "notify_finished";
|
||||
|
||||
type AppData = NotifyFinishedContext;
|
||||
|
||||
async fn run(
|
||||
&self,
|
||||
task: CurrentTask,
|
||||
context: Self::AppData,
|
||||
) -> Result<(), anyhow::Error> {
|
||||
// Notify the test that the task ran
|
||||
match context.tx.lock().await.take() {
|
||||
None => println!("Cannot notify, already done that!"),
|
||||
Some(tx) => {
|
||||
tx.send(()).unwrap();
|
||||
println!("[{}] Notify finished did it's job!", task.id())
|
||||
}
|
||||
};
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
let (tx, rx) = tokio::sync::oneshot::channel();
|
||||
|
||||
let my_app_context = NotifyFinishedContext {
|
||||
|
@ -362,6 +367,42 @@ mod tests {
|
|||
|
||||
#[tokio::test]
|
||||
async fn test_worker_pool_try_to_run_unknown_task() {
|
||||
#[derive(Clone)]
|
||||
struct NotifyUnknownRanContext {
|
||||
/// Notify that application should stop
|
||||
should_stop: Arc<Mutex<Option<tokio::sync::oneshot::Sender<()>>>>,
|
||||
|
||||
/// Used to mark if the unknown task ran
|
||||
unknown_task_ran: Arc<AtomicBool>,
|
||||
}
|
||||
|
||||
/// A task that notifies the test that it ran
|
||||
#[derive(serde::Serialize, serde::Deserialize)]
|
||||
struct NotifyStopDuringRun;
|
||||
|
||||
#[async_trait]
|
||||
impl BackgroundTask for NotifyStopDuringRun {
|
||||
const TASK_NAME: &'static str = "notify_finished";
|
||||
|
||||
type AppData = NotifyUnknownRanContext;
|
||||
|
||||
async fn run(
|
||||
&self,
|
||||
task: CurrentTask,
|
||||
context: Self::AppData,
|
||||
) -> Result<(), anyhow::Error> {
|
||||
// Notify the test that the task ran
|
||||
match context.should_stop.lock().await.take() {
|
||||
None => println!("Cannot notify, already done that!"),
|
||||
Some(tx) => {
|
||||
tx.send(()).unwrap();
|
||||
println!("[{}] Notify finished did it's job!", task.id())
|
||||
}
|
||||
};
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, serde::Serialize, serde::Deserialize)]
|
||||
struct UnknownTask;
|
||||
|
||||
|
@ -369,43 +410,157 @@ mod tests {
|
|||
impl BackgroundTask for UnknownTask {
|
||||
const TASK_NAME: &'static str = "unknown_task";
|
||||
|
||||
type AppData = NotifyFinishedContext;
|
||||
type AppData = NotifyUnknownRanContext;
|
||||
|
||||
async fn run(
|
||||
&self,
|
||||
task: CurrentTask,
|
||||
_context: Self::AppData,
|
||||
context: Self::AppData,
|
||||
) -> Result<(), anyhow::Error> {
|
||||
println!("[{}] Unknown task ran!", task.id());
|
||||
context.unknown_task_ran.store(true, Ordering::Relaxed);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
let (tx, rx) = tokio::sync::oneshot::channel();
|
||||
|
||||
let my_app_context = NotifyFinishedContext {
|
||||
tx: Arc::new(Mutex::new(Some(tx))),
|
||||
let my_app_context = NotifyUnknownRanContext {
|
||||
should_stop: Arc::new(Mutex::new(Some(tx))),
|
||||
unknown_task_ran: Arc::new(AtomicBool::new(false)),
|
||||
};
|
||||
|
||||
let task_store = memory_store().await;
|
||||
|
||||
let (join_handle, queue) = WorkerPool::new(task_store, move |_| my_app_context.clone())
|
||||
.register_task_type::<NotifyFinished>()
|
||||
.configure_queue("default", 1, RetentionMode::default())
|
||||
.start(async move {
|
||||
rx.await.unwrap();
|
||||
println!("Worker pool got notified to stop");
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
let (join_handle, queue) = WorkerPool::new(task_store, {
|
||||
let my_app_context = my_app_context.clone();
|
||||
move |_| my_app_context.clone()
|
||||
})
|
||||
.register_task_type::<NotifyStopDuringRun>()
|
||||
.configure_queue("default", 1, RetentionMode::default())
|
||||
.start(async move {
|
||||
rx.await.unwrap();
|
||||
println!("Worker pool got notified to stop");
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Enqueue a task that is not registered
|
||||
queue.enqueue(UnknownTask).await.unwrap();
|
||||
|
||||
// Notifies the worker pool to stop for this test
|
||||
queue.enqueue(NotifyFinished).await.unwrap();
|
||||
queue.enqueue(NotifyStopDuringRun).await.unwrap();
|
||||
|
||||
join_handle.await.unwrap();
|
||||
|
||||
assert!(
|
||||
!my_app_context.unknown_task_ran.load(Ordering::Relaxed),
|
||||
"Unknown task ran but it is not registered in the worker pool!"
|
||||
);
|
||||
}
|
||||
|
||||
/// This test will make sure that the worker pool will only stop after all workers are done.
|
||||
/// We create a KeepAliveTask that will keep running until we notify it to stop.
|
||||
/// We stop the worker pool and make sure that the KeepAliveTask is still running.
|
||||
/// Then we notify the KeepAliveTask to stop and make sure that the worker pool stops.
|
||||
#[tokio::test]
|
||||
async fn tasks_only_stop_running_when_finished() {
|
||||
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
|
||||
enum PingPongGame {
|
||||
Ping,
|
||||
Pong,
|
||||
StopThisNow,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct PlayerContext {
|
||||
/// Used to communicate with the running task
|
||||
pong_tx: Arc<tokio::sync::mpsc::Sender<PingPongGame>>,
|
||||
ping_rx: Arc<Mutex<tokio::sync::mpsc::Receiver<PingPongGame>>>,
|
||||
}
|
||||
|
||||
/// Task that will respond to the ping pong game and keep alive as long as we need
|
||||
#[derive(Clone, serde::Serialize, serde::Deserialize)]
|
||||
struct KeepAliveTask;
|
||||
|
||||
#[async_trait]
|
||||
impl BackgroundTask for KeepAliveTask {
|
||||
const TASK_NAME: &'static str = "keep_alive_task";
|
||||
|
||||
type AppData = PlayerContext;
|
||||
|
||||
async fn run(
|
||||
&self,
|
||||
_task: CurrentTask,
|
||||
context: Self::AppData,
|
||||
) -> Result<(), anyhow::Error> {
|
||||
loop {
|
||||
let msg = context.ping_rx.lock().await.recv().await.unwrap();
|
||||
match msg {
|
||||
PingPongGame::Ping => {
|
||||
println!("Pong!");
|
||||
context.pong_tx.send(PingPongGame::Pong).await.unwrap();
|
||||
}
|
||||
PingPongGame::Pong => {
|
||||
context.pong_tx.send(PingPongGame::Ping).await.unwrap();
|
||||
}
|
||||
PingPongGame::StopThisNow => {
|
||||
println!("Got stop signal, stopping the ping pong game now!");
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
let (notify_stop_worker_pool, should_stop) = tokio::sync::oneshot::channel();
|
||||
let (pong_tx, mut pong_rx) = tokio::sync::mpsc::channel(1);
|
||||
let (ping_tx, ping_rx) = tokio::sync::mpsc::channel(1);
|
||||
|
||||
let player_context = PlayerContext {
|
||||
pong_tx: Arc::new(pong_tx),
|
||||
ping_rx: Arc::new(Mutex::new(ping_rx)),
|
||||
};
|
||||
|
||||
let task_store = memory_store().await;
|
||||
|
||||
let (worker_pool_finished, queue) = WorkerPool::new(task_store, {
|
||||
let player_context = player_context.clone();
|
||||
move |_| player_context.clone()
|
||||
})
|
||||
.register_task_type::<KeepAliveTask>()
|
||||
.configure_queue("default", 1, RetentionMode::default())
|
||||
.start(async move {
|
||||
should_stop.await.unwrap();
|
||||
println!("Worker pool got notified to stop");
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
queue.enqueue(KeepAliveTask).await.unwrap();
|
||||
|
||||
// Make sure task is running
|
||||
println!("Ping!");
|
||||
ping_tx.send(PingPongGame::Ping).await.unwrap();
|
||||
assert_eq!(pong_rx.recv().await.unwrap(), PingPongGame::Pong);
|
||||
|
||||
// Notify to stop the worker pool
|
||||
notify_stop_worker_pool.send(()).unwrap();
|
||||
|
||||
// Make sure task is still running
|
||||
println!("Ping!");
|
||||
ping_tx.send(PingPongGame::Ping).await.unwrap();
|
||||
assert_eq!(pong_rx.recv().await.unwrap(), PingPongGame::Pong);
|
||||
|
||||
// is_none() means that the worker pool is still waiting for tasks to finish, which is what we want!
|
||||
assert!(
|
||||
worker_pool_finished.now_or_never().is_none(),
|
||||
"Worker pool finished before task stopped!"
|
||||
);
|
||||
|
||||
// Notify to stop the task, which will stop the worker pool
|
||||
ping_tx.send(PingPongGame::StopThisNow).await.unwrap();
|
||||
}
|
||||
|
||||
async fn memory_store() -> MemoryTaskStore {
|
||||
|
|
Loading…
Reference in a new issue