Tasks are let run until completion
This commit is contained in:
parent
2964dc2b88
commit
0f0a9c2238
1 changed files with 199 additions and 44 deletions
|
@ -166,6 +166,8 @@ mod tests {
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use diesel_async::pooled_connection::{bb8::Pool, AsyncDieselConnectionManager};
|
use diesel_async::pooled_connection::{bb8::Pool, AsyncDieselConnectionManager};
|
||||||
use diesel_async::AsyncPgConnection;
|
use diesel_async::AsyncPgConnection;
|
||||||
|
use futures::FutureExt;
|
||||||
|
use std::sync::atomic::{AtomicBool, Ordering};
|
||||||
use tokio::sync::Mutex;
|
use tokio::sync::Mutex;
|
||||||
|
|
||||||
#[derive(Clone, Debug)]
|
#[derive(Clone, Debug)]
|
||||||
|
@ -236,36 +238,6 @@ mod tests {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone)]
|
|
||||||
struct NotifyFinishedContext {
|
|
||||||
tx: Arc<Mutex<Option<tokio::sync::oneshot::Sender<()>>>>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(serde::Serialize, serde::Deserialize)]
|
|
||||||
struct NotifyFinished;
|
|
||||||
|
|
||||||
#[async_trait]
|
|
||||||
impl BackgroundTask for NotifyFinished {
|
|
||||||
const TASK_NAME: &'static str = "notify_finished";
|
|
||||||
|
|
||||||
type AppData = NotifyFinishedContext;
|
|
||||||
|
|
||||||
async fn run(
|
|
||||||
&self,
|
|
||||||
task: CurrentTask,
|
|
||||||
context: Self::AppData,
|
|
||||||
) -> Result<(), anyhow::Error> {
|
|
||||||
match context.tx.lock().await.take() {
|
|
||||||
None => println!("Cannot notify, already done that!"),
|
|
||||||
Some(tx) => {
|
|
||||||
tx.send(()).unwrap();
|
|
||||||
println!("[{}] Notify finished did it's job!", task.id())
|
|
||||||
}
|
|
||||||
};
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn validate_all_registered_tasks_queues_are_configured() {
|
async fn validate_all_registered_tasks_queues_are_configured() {
|
||||||
let my_app_context = ApplicationContext::new();
|
let my_app_context = ApplicationContext::new();
|
||||||
|
@ -334,6 +306,39 @@ mod tests {
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_worker_pool_stop_after_task_execute() {
|
async fn test_worker_pool_stop_after_task_execute() {
|
||||||
|
#[derive(Clone)]
|
||||||
|
struct NotifyFinishedContext {
|
||||||
|
/// Used to notify the task ran
|
||||||
|
tx: Arc<Mutex<Option<tokio::sync::oneshot::Sender<()>>>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A task that notifies the test that it ran
|
||||||
|
#[derive(serde::Serialize, serde::Deserialize)]
|
||||||
|
struct NotifyFinished;
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl BackgroundTask for NotifyFinished {
|
||||||
|
const TASK_NAME: &'static str = "notify_finished";
|
||||||
|
|
||||||
|
type AppData = NotifyFinishedContext;
|
||||||
|
|
||||||
|
async fn run(
|
||||||
|
&self,
|
||||||
|
task: CurrentTask,
|
||||||
|
context: Self::AppData,
|
||||||
|
) -> Result<(), anyhow::Error> {
|
||||||
|
// Notify the test that the task ran
|
||||||
|
match context.tx.lock().await.take() {
|
||||||
|
None => println!("Cannot notify, already done that!"),
|
||||||
|
Some(tx) => {
|
||||||
|
tx.send(()).unwrap();
|
||||||
|
println!("[{}] Notify finished did it's job!", task.id())
|
||||||
|
}
|
||||||
|
};
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
let (tx, rx) = tokio::sync::oneshot::channel();
|
let (tx, rx) = tokio::sync::oneshot::channel();
|
||||||
|
|
||||||
let my_app_context = NotifyFinishedContext {
|
let my_app_context = NotifyFinishedContext {
|
||||||
|
@ -362,6 +367,42 @@ mod tests {
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_worker_pool_try_to_run_unknown_task() {
|
async fn test_worker_pool_try_to_run_unknown_task() {
|
||||||
|
#[derive(Clone)]
|
||||||
|
struct NotifyUnknownRanContext {
|
||||||
|
/// Notify that application should stop
|
||||||
|
should_stop: Arc<Mutex<Option<tokio::sync::oneshot::Sender<()>>>>,
|
||||||
|
|
||||||
|
/// Used to mark if the unknown task ran
|
||||||
|
unknown_task_ran: Arc<AtomicBool>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A task that notifies the test that it ran
|
||||||
|
#[derive(serde::Serialize, serde::Deserialize)]
|
||||||
|
struct NotifyStopDuringRun;
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl BackgroundTask for NotifyStopDuringRun {
|
||||||
|
const TASK_NAME: &'static str = "notify_finished";
|
||||||
|
|
||||||
|
type AppData = NotifyUnknownRanContext;
|
||||||
|
|
||||||
|
async fn run(
|
||||||
|
&self,
|
||||||
|
task: CurrentTask,
|
||||||
|
context: Self::AppData,
|
||||||
|
) -> Result<(), anyhow::Error> {
|
||||||
|
// Notify the test that the task ran
|
||||||
|
match context.should_stop.lock().await.take() {
|
||||||
|
None => println!("Cannot notify, already done that!"),
|
||||||
|
Some(tx) => {
|
||||||
|
tx.send(()).unwrap();
|
||||||
|
println!("[{}] Notify finished did it's job!", task.id())
|
||||||
|
}
|
||||||
|
};
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Clone, serde::Serialize, serde::Deserialize)]
|
#[derive(Clone, serde::Serialize, serde::Deserialize)]
|
||||||
struct UnknownTask;
|
struct UnknownTask;
|
||||||
|
|
||||||
|
@ -369,43 +410,157 @@ mod tests {
|
||||||
impl BackgroundTask for UnknownTask {
|
impl BackgroundTask for UnknownTask {
|
||||||
const TASK_NAME: &'static str = "unknown_task";
|
const TASK_NAME: &'static str = "unknown_task";
|
||||||
|
|
||||||
type AppData = NotifyFinishedContext;
|
type AppData = NotifyUnknownRanContext;
|
||||||
|
|
||||||
async fn run(
|
async fn run(
|
||||||
&self,
|
&self,
|
||||||
task: CurrentTask,
|
task: CurrentTask,
|
||||||
_context: Self::AppData,
|
context: Self::AppData,
|
||||||
) -> Result<(), anyhow::Error> {
|
) -> Result<(), anyhow::Error> {
|
||||||
println!("[{}] Unknown task ran!", task.id());
|
println!("[{}] Unknown task ran!", task.id());
|
||||||
|
context.unknown_task_ran.store(true, Ordering::Relaxed);
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
let (tx, rx) = tokio::sync::oneshot::channel();
|
let (tx, rx) = tokio::sync::oneshot::channel();
|
||||||
|
|
||||||
let my_app_context = NotifyFinishedContext {
|
let my_app_context = NotifyUnknownRanContext {
|
||||||
tx: Arc::new(Mutex::new(Some(tx))),
|
should_stop: Arc::new(Mutex::new(Some(tx))),
|
||||||
|
unknown_task_ran: Arc::new(AtomicBool::new(false)),
|
||||||
};
|
};
|
||||||
|
|
||||||
let task_store = memory_store().await;
|
let task_store = memory_store().await;
|
||||||
|
|
||||||
let (join_handle, queue) = WorkerPool::new(task_store, move |_| my_app_context.clone())
|
let (join_handle, queue) = WorkerPool::new(task_store, {
|
||||||
.register_task_type::<NotifyFinished>()
|
let my_app_context = my_app_context.clone();
|
||||||
.configure_queue("default", 1, RetentionMode::default())
|
move |_| my_app_context.clone()
|
||||||
.start(async move {
|
})
|
||||||
rx.await.unwrap();
|
.register_task_type::<NotifyStopDuringRun>()
|
||||||
println!("Worker pool got notified to stop");
|
.configure_queue("default", 1, RetentionMode::default())
|
||||||
})
|
.start(async move {
|
||||||
.await
|
rx.await.unwrap();
|
||||||
.unwrap();
|
println!("Worker pool got notified to stop");
|
||||||
|
})
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
// Enqueue a task that is not registered
|
// Enqueue a task that is not registered
|
||||||
queue.enqueue(UnknownTask).await.unwrap();
|
queue.enqueue(UnknownTask).await.unwrap();
|
||||||
|
|
||||||
// Notifies the worker pool to stop for this test
|
// Notifies the worker pool to stop for this test
|
||||||
queue.enqueue(NotifyFinished).await.unwrap();
|
queue.enqueue(NotifyStopDuringRun).await.unwrap();
|
||||||
|
|
||||||
join_handle.await.unwrap();
|
join_handle.await.unwrap();
|
||||||
|
|
||||||
|
assert!(
|
||||||
|
!my_app_context.unknown_task_ran.load(Ordering::Relaxed),
|
||||||
|
"Unknown task ran but it is not registered in the worker pool!"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// This test will make sure that the worker pool will only stop after all workers are done.
|
||||||
|
/// We create a KeepAliveTask that will keep running until we notify it to stop.
|
||||||
|
/// We stop the worker pool and make sure that the KeepAliveTask is still running.
|
||||||
|
/// Then we notify the KeepAliveTask to stop and make sure that the worker pool stops.
|
||||||
|
#[tokio::test]
|
||||||
|
async fn tasks_only_stop_running_when_finished() {
|
||||||
|
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
|
||||||
|
enum PingPongGame {
|
||||||
|
Ping,
|
||||||
|
Pong,
|
||||||
|
StopThisNow,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
struct PlayerContext {
|
||||||
|
/// Used to communicate with the running task
|
||||||
|
pong_tx: Arc<tokio::sync::mpsc::Sender<PingPongGame>>,
|
||||||
|
ping_rx: Arc<Mutex<tokio::sync::mpsc::Receiver<PingPongGame>>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Task that will respond to the ping pong game and keep alive as long as we need
|
||||||
|
#[derive(Clone, serde::Serialize, serde::Deserialize)]
|
||||||
|
struct KeepAliveTask;
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl BackgroundTask for KeepAliveTask {
|
||||||
|
const TASK_NAME: &'static str = "keep_alive_task";
|
||||||
|
|
||||||
|
type AppData = PlayerContext;
|
||||||
|
|
||||||
|
async fn run(
|
||||||
|
&self,
|
||||||
|
_task: CurrentTask,
|
||||||
|
context: Self::AppData,
|
||||||
|
) -> Result<(), anyhow::Error> {
|
||||||
|
loop {
|
||||||
|
let msg = context.ping_rx.lock().await.recv().await.unwrap();
|
||||||
|
match msg {
|
||||||
|
PingPongGame::Ping => {
|
||||||
|
println!("Pong!");
|
||||||
|
context.pong_tx.send(PingPongGame::Pong).await.unwrap();
|
||||||
|
}
|
||||||
|
PingPongGame::Pong => {
|
||||||
|
context.pong_tx.send(PingPongGame::Ping).await.unwrap();
|
||||||
|
}
|
||||||
|
PingPongGame::StopThisNow => {
|
||||||
|
println!("Got stop signal, stopping the ping pong game now!");
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let (notify_stop_worker_pool, should_stop) = tokio::sync::oneshot::channel();
|
||||||
|
let (pong_tx, mut pong_rx) = tokio::sync::mpsc::channel(1);
|
||||||
|
let (ping_tx, ping_rx) = tokio::sync::mpsc::channel(1);
|
||||||
|
|
||||||
|
let player_context = PlayerContext {
|
||||||
|
pong_tx: Arc::new(pong_tx),
|
||||||
|
ping_rx: Arc::new(Mutex::new(ping_rx)),
|
||||||
|
};
|
||||||
|
|
||||||
|
let task_store = memory_store().await;
|
||||||
|
|
||||||
|
let (worker_pool_finished, queue) = WorkerPool::new(task_store, {
|
||||||
|
let player_context = player_context.clone();
|
||||||
|
move |_| player_context.clone()
|
||||||
|
})
|
||||||
|
.register_task_type::<KeepAliveTask>()
|
||||||
|
.configure_queue("default", 1, RetentionMode::default())
|
||||||
|
.start(async move {
|
||||||
|
should_stop.await.unwrap();
|
||||||
|
println!("Worker pool got notified to stop");
|
||||||
|
})
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
queue.enqueue(KeepAliveTask).await.unwrap();
|
||||||
|
|
||||||
|
// Make sure task is running
|
||||||
|
println!("Ping!");
|
||||||
|
ping_tx.send(PingPongGame::Ping).await.unwrap();
|
||||||
|
assert_eq!(pong_rx.recv().await.unwrap(), PingPongGame::Pong);
|
||||||
|
|
||||||
|
// Notify to stop the worker pool
|
||||||
|
notify_stop_worker_pool.send(()).unwrap();
|
||||||
|
|
||||||
|
// Make sure task is still running
|
||||||
|
println!("Ping!");
|
||||||
|
ping_tx.send(PingPongGame::Ping).await.unwrap();
|
||||||
|
assert_eq!(pong_rx.recv().await.unwrap(), PingPongGame::Pong);
|
||||||
|
|
||||||
|
// is_none() means that the worker pool is still waiting for tasks to finish, which is what we want!
|
||||||
|
assert!(
|
||||||
|
worker_pool_finished.now_or_never().is_none(),
|
||||||
|
"Worker pool finished before task stopped!"
|
||||||
|
);
|
||||||
|
|
||||||
|
// Notify to stop the task, which will stop the worker pool
|
||||||
|
ping_tx.send(PingPongGame::StopThisNow).await.unwrap();
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn memory_store() -> MemoryTaskStore {
|
async fn memory_store() -> MemoryTaskStore {
|
||||||
|
|
Loading…
Reference in a new issue