2023-03-09 15:59:45 +00:00
|
|
|
use crate::errors::BackieError;
|
2023-03-10 22:41:34 +00:00
|
|
|
use crate::queue::Queue;
|
2023-03-11 16:49:23 +00:00
|
|
|
use crate::runnable::BackgroundTask;
|
|
|
|
use crate::store::TaskStore;
|
2023-03-10 22:41:34 +00:00
|
|
|
use crate::worker::{runnable, ExecuteTaskFn};
|
|
|
|
use crate::worker::{StateFn, Worker};
|
2023-03-11 16:49:23 +00:00
|
|
|
use crate::RetentionMode;
|
2023-03-11 23:18:15 +00:00
|
|
|
use futures::future::join_all;
|
2023-03-10 22:41:34 +00:00
|
|
|
use std::collections::BTreeMap;
|
2023-03-09 15:59:45 +00:00
|
|
|
use std::future::Future;
|
2023-03-10 22:41:34 +00:00
|
|
|
use std::sync::Arc;
|
2023-03-12 16:15:40 +00:00
|
|
|
use std::time::Duration;
|
2023-03-10 22:41:34 +00:00
|
|
|
use tokio::task::JoinHandle;
|
2022-07-31 13:32:37 +00:00
|
|
|
|
2023-03-10 22:41:34 +00:00
|
|
|
/// Pool of workers that executes registered background task types pulled from queues.
///
/// Built with [`WorkerPool::new`], configured via `register_task_type` /
/// `configure_queue`, and launched with `start`.
#[derive(Clone)]
pub struct WorkerPool<AppData, S>
where
    AppData: Clone + Send + 'static,
    S: TaskStore,
{
    /// Storage of tasks.
    task_store: Arc<S>,

    /// Queue used to spawn tasks.
    queue: Queue<S>,

    /// Make possible to load the application data.
    ///
    /// The application data is loaded when the worker pool is started and is passed to the tasks.
    /// The loading function accepts a queue instance in case the application data depends on it. This
    /// is interesting for situations where the application wants to allow tasks to spawn other tasks.
    application_data_fn: StateFn<AppData>,

    /// The types of task the worker pool can execute and the loaders for them.
    task_registry: BTreeMap<String, ExecuteTaskFn<AppData>>,

    /// The queue names for the registered tasks.
    ///
    /// Maps queue name -> task names registered for it; used by `start` to
    /// validate every queue with registered tasks has been configured.
    queue_tasks: BTreeMap<String, Vec<String>>,

    /// Per-queue configuration (queue name -> [`QueueConfig`]): number of
    /// workers, retention mode, and pull interval.
    worker_queues: BTreeMap<String, QueueConfig>,
}
|
|
|
|
|
2023-03-11 16:49:23 +00:00
|
|
|
impl<AppData, S> WorkerPool<AppData, S>
|
2022-07-31 13:32:37 +00:00
|
|
|
where
|
2023-03-10 22:41:34 +00:00
|
|
|
AppData: Clone + Send + 'static,
|
2023-03-11 16:49:23 +00:00
|
|
|
S: TaskStore,
|
2022-07-31 13:32:37 +00:00
|
|
|
{
|
2023-03-10 22:41:34 +00:00
|
|
|
/// Create a new worker pool.
|
2023-03-11 16:49:23 +00:00
|
|
|
pub fn new<A>(task_store: S, application_data_fn: A) -> Self
|
2023-03-10 22:41:34 +00:00
|
|
|
where
|
2023-03-11 16:49:23 +00:00
|
|
|
A: Fn(Queue<S>) -> AppData + Send + Sync + 'static,
|
2023-03-10 22:41:34 +00:00
|
|
|
{
|
2023-03-11 15:38:32 +00:00
|
|
|
let queue_store = Arc::new(task_store);
|
2023-03-10 22:41:34 +00:00
|
|
|
let queue = Queue::new(queue_store.clone());
|
|
|
|
let application_data_fn = {
|
|
|
|
let queue = queue.clone();
|
|
|
|
move || application_data_fn(queue.clone())
|
|
|
|
};
|
|
|
|
Self {
|
2023-03-11 16:49:23 +00:00
|
|
|
task_store: queue_store,
|
2023-03-10 22:41:34 +00:00
|
|
|
queue,
|
|
|
|
application_data_fn: Arc::new(application_data_fn),
|
|
|
|
task_registry: BTreeMap::new(),
|
2023-03-11 15:38:32 +00:00
|
|
|
queue_tasks: BTreeMap::new(),
|
2023-03-10 22:41:34 +00:00
|
|
|
worker_queues: BTreeMap::new(),
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
/// Register a task type with the worker pool.
|
2023-03-11 15:38:32 +00:00
|
|
|
pub fn register_task_type<BT>(mut self) -> Self
|
2023-03-10 22:41:34 +00:00
|
|
|
where
|
|
|
|
BT: BackgroundTask<AppData = AppData>,
|
|
|
|
{
|
2023-03-11 15:38:32 +00:00
|
|
|
self.queue_tasks
|
|
|
|
.entry(BT::QUEUE.to_string())
|
|
|
|
.or_insert_with(Vec::new)
|
|
|
|
.push(BT::TASK_NAME.to_string());
|
2023-03-10 22:41:34 +00:00
|
|
|
self.task_registry
|
|
|
|
.insert(BT::TASK_NAME.to_string(), Arc::new(runnable::<BT>));
|
|
|
|
self
|
|
|
|
}
|
|
|
|
|
2023-03-12 16:15:40 +00:00
|
|
|
pub fn configure_queue(mut self, config: QueueConfig) -> Self {
|
|
|
|
self.worker_queues.insert(config.name.clone(), config);
|
2023-03-11 15:38:32 +00:00
|
|
|
self
|
|
|
|
}
|
|
|
|
|
2023-03-10 22:41:34 +00:00
|
|
|
pub async fn start<F>(
|
|
|
|
self,
|
|
|
|
graceful_shutdown: F,
|
2023-03-11 16:49:23 +00:00
|
|
|
) -> Result<(JoinHandle<()>, Queue<S>), BackieError>
|
2023-03-09 15:59:45 +00:00
|
|
|
where
|
|
|
|
F: Future<Output = ()> + Send + 'static,
|
|
|
|
{
|
2023-03-11 15:38:32 +00:00
|
|
|
// Validate that all registered tasks queues are configured
|
|
|
|
for (queue_name, tasks_for_queue) in self.queue_tasks.into_iter() {
|
|
|
|
if !self.worker_queues.contains_key(&queue_name) {
|
|
|
|
return Err(BackieError::QueueNotConfigured(queue_name, tasks_for_queue));
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2023-03-09 15:59:45 +00:00
|
|
|
let (tx, rx) = tokio::sync::watch::channel(());
|
2023-03-10 22:41:34 +00:00
|
|
|
|
2023-03-11 23:18:15 +00:00
|
|
|
let mut worker_handles = Vec::new();
|
|
|
|
|
2023-03-10 22:41:34 +00:00
|
|
|
// Spawn all individual workers per queue
|
2023-03-12 16:15:40 +00:00
|
|
|
for (queue_name, queue_config) in self.worker_queues.iter() {
|
|
|
|
for idx in 0..queue_config.num_workers {
|
2023-03-11 16:49:23 +00:00
|
|
|
let mut worker: Worker<AppData, S> = Worker::new(
|
|
|
|
self.task_store.clone(),
|
2023-03-10 22:41:34 +00:00
|
|
|
queue_name.clone(),
|
2023-03-12 16:15:40 +00:00
|
|
|
queue_config.retention_mode,
|
|
|
|
queue_config.pull_interval,
|
2023-03-10 22:41:34 +00:00
|
|
|
self.task_registry.clone(),
|
|
|
|
self.application_data_fn.clone(),
|
|
|
|
Some(rx.clone()),
|
|
|
|
);
|
|
|
|
let worker_name = format!("worker-{queue_name}-{idx}");
|
2023-03-11 23:18:15 +00:00
|
|
|
// grabs the join handle for every worker for graceful shutdown
|
|
|
|
let join_handle = tokio::spawn(async move {
|
2023-03-11 15:38:32 +00:00
|
|
|
match worker.run_tasks().await {
|
2023-03-11 21:22:25 +00:00
|
|
|
Ok(()) => log::info!("Worker {worker_name} stopped successfully"),
|
2023-03-11 15:38:32 +00:00
|
|
|
Err(err) => log::error!("Worker {worker_name} stopped due to error: {err}"),
|
|
|
|
}
|
2023-03-10 22:41:34 +00:00
|
|
|
});
|
2023-03-11 23:18:15 +00:00
|
|
|
worker_handles.push(join_handle);
|
2023-03-10 22:41:34 +00:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
Ok((
|
|
|
|
tokio::spawn(async move {
|
|
|
|
graceful_shutdown.await;
|
|
|
|
if let Err(err) = tx.send(()) {
|
|
|
|
log::warn!("Failed to send shutdown signal to worker pool: {}", err);
|
|
|
|
} else {
|
2023-03-11 23:18:15 +00:00
|
|
|
// Wait for all workers to finish processing
|
|
|
|
let results = join_all(worker_handles)
|
|
|
|
.await
|
|
|
|
.into_iter()
|
|
|
|
.filter(Result::is_err)
|
|
|
|
.map(Result::unwrap_err)
|
|
|
|
.collect::<Vec<_>>();
|
|
|
|
if !results.is_empty() {
|
|
|
|
log::error!("Worker pool stopped with errors: {:?}", results);
|
|
|
|
} else {
|
|
|
|
log::info!("Worker pool stopped gracefully");
|
|
|
|
}
|
2023-03-10 22:41:34 +00:00
|
|
|
}
|
|
|
|
}),
|
|
|
|
self.queue,
|
|
|
|
))
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2023-03-12 16:15:40 +00:00
|
|
|
/// Configuration for a queue.
///
/// This is used to configure the number of workers, the retention mode, and the pulling interval
/// for a queue.
///
/// # Examples
///
/// Example of configuring a queue with all options:
/// ```
/// # use backie::QueueConfig;
/// # use backie::RetentionMode;
/// # use std::time::Duration;
/// let config = QueueConfig::new("default")
///     .num_workers(5)
///     .retention_mode(RetentionMode::KeepAll)
///     .pull_interval(Duration::from_secs(1));
/// ```
/// Example of queue configuration with default options:
/// ```
/// # use backie::QueueConfig;
/// let config = QueueConfig::new("default");
/// // Also possible to use the `From` trait:
/// let config: QueueConfig = "default".into();
/// ```
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct QueueConfig {
    // Queue name; used as the key in the worker pool's queue map.
    name: String,
    // How many workers are spawned for this queue (default: 1).
    num_workers: u32,
    // What happens to finished tasks of this queue.
    retention_mode: RetentionMode,
    // How often workers poll the backend store for new tasks (default: 1s).
    pull_interval: Duration,
}
|
|
|
|
|
|
|
|
impl QueueConfig {
|
|
|
|
/// Create a new queue configuration.
|
|
|
|
pub fn new(name: impl ToString) -> Self {
|
|
|
|
Self {
|
|
|
|
name: name.to_string(),
|
|
|
|
num_workers: 1,
|
|
|
|
retention_mode: RetentionMode::default(),
|
|
|
|
pull_interval: Duration::from_secs(1),
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
/// Set the number of workers for this queue.
|
|
|
|
pub fn num_workers(mut self, num_workers: u32) -> Self {
|
|
|
|
self.num_workers = num_workers;
|
|
|
|
self
|
|
|
|
}
|
|
|
|
|
|
|
|
/// Set the retention mode for this queue.
|
|
|
|
pub fn retention_mode(mut self, retention_mode: RetentionMode) -> Self {
|
|
|
|
self.retention_mode = retention_mode;
|
|
|
|
self
|
|
|
|
}
|
|
|
|
|
|
|
|
/// Set the pull interval for this queue.
|
|
|
|
///
|
|
|
|
/// This is the interval at which the queue will be checking for new tasks by calling
|
|
|
|
/// the backend storage.
|
|
|
|
pub fn pull_interval(mut self, pull_interval: Duration) -> Self {
|
|
|
|
self.pull_interval = pull_interval;
|
|
|
|
self
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
impl<S> From<S> for QueueConfig
|
|
|
|
where
|
|
|
|
S: ToString,
|
|
|
|
{
|
|
|
|
fn from(name: S) -> Self {
|
|
|
|
Self::new(name.to_string())
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2023-03-10 22:41:34 +00:00
|
|
|
#[cfg(test)]
mod tests {
    use super::*;
    use crate::store::test_store::MemoryTaskStore;
    use crate::store::PgTaskStore;
    use crate::task::CurrentTask;
    use async_trait::async_trait;
    use diesel_async::pooled_connection::{bb8::Pool, AsyncDieselConnectionManager};
    use diesel_async::AsyncPgConnection;
    use futures::FutureExt;
    use std::sync::atomic::{AtomicBool, Ordering};
    use tokio::sync::Mutex;

    /// Shared application state handed to every test task.
    #[derive(Clone, Debug)]
    struct ApplicationContext {
        app_name: String,
    }

    impl ApplicationContext {
        fn new() -> Self {
            Self {
                app_name: "Backie".to_string(),
            }
        }

        fn get_app_name(&self) -> String {
            self.app_name.clone()
        }
    }

    /// Simple task that prints a greeting; registered on the default queue.
    #[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
    struct GreetingTask {
        person: String,
    }

    #[async_trait]
    impl BackgroundTask for GreetingTask {
        const TASK_NAME: &'static str = "my_task";

        type AppData = ApplicationContext;

        async fn run(
            &self,
            task_info: CurrentTask,
            app_context: Self::AppData,
        ) -> Result<(), anyhow::Error> {
            println!(
                "[{}] Hello {}! I'm {}.",
                task_info.id(),
                self.person,
                app_context.get_app_name()
            );
            Ok(())
        }
    }

    /// Task registered on a non-default queue ("other_queue").
    #[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
    struct OtherTask;

    #[async_trait]
    impl BackgroundTask for OtherTask {
        const TASK_NAME: &'static str = "other_task";

        const QUEUE: &'static str = "other_queue";

        type AppData = ApplicationContext;

        async fn run(
            &self,
            task: CurrentTask,
            context: Self::AppData,
        ) -> Result<(), anyhow::Error> {
            println!(
                "[{}] Other task with {}!",
                task.id(),
                context.get_app_name()
            );
            Ok(())
        }
    }

    /// Starting a pool with a registered task but no configured queue must fail
    /// with `QueueNotConfigured` naming the queue and its tasks.
    #[tokio::test]
    async fn validate_all_registered_tasks_queues_are_configured() {
        let my_app_context = ApplicationContext::new();

        let result = WorkerPool::new(memory_store().await, move |_| my_app_context.clone())
            .register_task_type::<GreetingTask>()
            .start(futures::future::ready(()))
            .await;

        assert!(matches!(result, Err(BackieError::QueueNotConfigured(..))));
        if let Err(err) = result {
            assert_eq!(
                err.to_string(),
                "Queue \"default\" needs to be configured because of registered tasks: [\"my_task\"]"
            );
        }
    }

    /// Happy path: enqueue one task and shut the pool down immediately
    /// (the ready() future makes shutdown fire right after start).
    #[tokio::test]
    async fn test_worker_pool_with_task() {
        let my_app_context = ApplicationContext::new();

        let (join_handle, queue) =
            WorkerPool::new(memory_store().await, move |_| my_app_context.clone())
                .register_task_type::<GreetingTask>()
                .configure_queue(GreetingTask::QUEUE.into())
                .start(futures::future::ready(()))
                .await
                .unwrap();

        queue
            .enqueue(GreetingTask {
                person: "Rafael".to_string(),
            })
            .await
            .unwrap();

        join_handle.await.unwrap();
    }

    /// Two task types on two different queues can coexist in one pool.
    #[tokio::test]
    async fn test_worker_pool_with_multiple_task_types() {
        let my_app_context = ApplicationContext::new();

        let (join_handle, queue) =
            WorkerPool::new(memory_store().await, move |_| my_app_context.clone())
                .register_task_type::<GreetingTask>()
                .register_task_type::<OtherTask>()
                .configure_queue("default".into())
                .configure_queue("other_queue".into())
                .start(futures::future::ready(()))
                .await
                .unwrap();

        queue
            .enqueue(GreetingTask {
                person: "Rafael".to_string(),
            })
            .await
            .unwrap();

        queue.enqueue(OtherTask).await.unwrap();

        join_handle.await.unwrap();
    }

    /// The pool's shutdown future is driven by the first task execution: the
    /// task fires a oneshot that resolves the graceful-shutdown future.
    #[tokio::test]
    async fn test_worker_pool_stop_after_task_execute() {
        #[derive(Clone)]
        struct NotifyFinishedContext {
            /// Used to notify the task ran
            tx: Arc<Mutex<Option<tokio::sync::oneshot::Sender<()>>>>,
        }

        /// A task that notifies the test that it ran
        #[derive(serde::Serialize, serde::Deserialize)]
        struct NotifyFinished;

        #[async_trait]
        impl BackgroundTask for NotifyFinished {
            const TASK_NAME: &'static str = "notify_finished";

            type AppData = NotifyFinishedContext;

            async fn run(
                &self,
                task: CurrentTask,
                context: Self::AppData,
            ) -> Result<(), anyhow::Error> {
                // Notify the test that the task ran; take() ensures the oneshot
                // sender is used at most once even if the task runs again.
                match context.tx.lock().await.take() {
                    None => println!("Cannot notify, already done that!"),
                    Some(tx) => {
                        tx.send(()).unwrap();
                        println!("[{}] Notify finished did it's job!", task.id())
                    }
                };
                Ok(())
            }
        }

        let (tx, rx) = tokio::sync::oneshot::channel();

        let my_app_context = NotifyFinishedContext {
            tx: Arc::new(Mutex::new(Some(tx))),
        };

        let (join_handle, queue) =
            WorkerPool::new(memory_store().await, move |_| my_app_context.clone())
                .register_task_type::<NotifyFinished>()
                .configure_queue("default".into())
                .start(async move {
                    rx.await.unwrap();
                    println!("Worker pool got notified to stop");
                })
                .await
                .unwrap();

        // Notifies the worker pool to stop after the task is executed
        queue.enqueue(NotifyFinished).await.unwrap();

        // This makes sure the task can run multiple times and use the shared context
        queue.enqueue(NotifyFinished).await.unwrap();

        join_handle.await.unwrap();
    }

    /// An enqueued task whose type was never registered must not be executed.
    #[tokio::test]
    async fn test_worker_pool_try_to_run_unknown_task() {
        #[derive(Clone)]
        struct NotifyUnknownRanContext {
            /// Notify that application should stop
            should_stop: Arc<Mutex<Option<tokio::sync::oneshot::Sender<()>>>>,

            /// Used to mark if the unknown task ran
            unknown_task_ran: Arc<AtomicBool>,
        }

        /// A task that notifies the test that it ran
        #[derive(serde::Serialize, serde::Deserialize)]
        struct NotifyStopDuringRun;

        #[async_trait]
        impl BackgroundTask for NotifyStopDuringRun {
            const TASK_NAME: &'static str = "notify_finished";

            type AppData = NotifyUnknownRanContext;

            async fn run(
                &self,
                task: CurrentTask,
                context: Self::AppData,
            ) -> Result<(), anyhow::Error> {
                // Notify the test that the task ran
                match context.should_stop.lock().await.take() {
                    None => println!("Cannot notify, already done that!"),
                    Some(tx) => {
                        tx.send(()).unwrap();
                        println!("[{}] Notify finished did it's job!", task.id())
                    }
                };
                Ok(())
            }
        }

        // Deliberately NOT registered with the pool below.
        #[derive(Clone, serde::Serialize, serde::Deserialize)]
        struct UnknownTask;

        #[async_trait]
        impl BackgroundTask for UnknownTask {
            const TASK_NAME: &'static str = "unknown_task";

            type AppData = NotifyUnknownRanContext;

            async fn run(
                &self,
                task: CurrentTask,
                context: Self::AppData,
            ) -> Result<(), anyhow::Error> {
                println!("[{}] Unknown task ran!", task.id());
                // Flag checked by the final assertion; must stay false.
                context.unknown_task_ran.store(true, Ordering::Relaxed);
                Ok(())
            }
        }

        let (tx, rx) = tokio::sync::oneshot::channel();

        let my_app_context = NotifyUnknownRanContext {
            should_stop: Arc::new(Mutex::new(Some(tx))),
            unknown_task_ran: Arc::new(AtomicBool::new(false)),
        };

        let task_store = memory_store().await;

        let (join_handle, queue) = WorkerPool::new(task_store, {
            let my_app_context = my_app_context.clone();
            move |_| my_app_context.clone()
        })
        .register_task_type::<NotifyStopDuringRun>()
        .configure_queue("default".into())
        .start(async move {
            rx.await.unwrap();
            println!("Worker pool got notified to stop");
        })
        .await
        .unwrap();

        // Enqueue a task that is not registered
        queue.enqueue(UnknownTask).await.unwrap();

        // Notifies the worker pool to stop for this test
        queue.enqueue(NotifyStopDuringRun).await.unwrap();

        join_handle.await.unwrap();

        assert!(
            !my_app_context.unknown_task_ran.load(Ordering::Relaxed),
            "Unknown task ran but it is not registered in the worker pool!"
        );
    }

    /// This test will make sure that the worker pool will only stop after all workers are done.
    /// We create a KeepAliveTask that will keep running until we notify it to stop.
    /// We stop the worker pool and make sure that the KeepAliveTask is still running.
    /// Then we notify the KeepAliveTask to stop and make sure that the worker pool stops.
    #[tokio::test]
    async fn tasks_only_stop_running_when_finished() {
        #[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
        enum PingPongGame {
            Ping,
            Pong,
            StopThisNow,
        }

        #[derive(Clone)]
        struct PlayerContext {
            /// Used to communicate with the running task
            pong_tx: Arc<tokio::sync::mpsc::Sender<PingPongGame>>,
            ping_rx: Arc<Mutex<tokio::sync::mpsc::Receiver<PingPongGame>>>,
        }

        /// Task that will respond to the ping pong game and keep alive as long as we need
        #[derive(Clone, serde::Serialize, serde::Deserialize)]
        struct KeepAliveTask;

        #[async_trait]
        impl BackgroundTask for KeepAliveTask {
            const TASK_NAME: &'static str = "keep_alive_task";

            type AppData = PlayerContext;

            async fn run(
                &self,
                _task: CurrentTask,
                context: Self::AppData,
            ) -> Result<(), anyhow::Error> {
                // Loop until StopThisNow arrives, keeping this worker busy so
                // the pool cannot finish its shutdown before we allow it.
                loop {
                    let msg = context.ping_rx.lock().await.recv().await.unwrap();
                    match msg {
                        PingPongGame::Ping => {
                            println!("Pong!");
                            context.pong_tx.send(PingPongGame::Pong).await.unwrap();
                        }
                        PingPongGame::Pong => {
                            context.pong_tx.send(PingPongGame::Ping).await.unwrap();
                        }
                        PingPongGame::StopThisNow => {
                            println!("Got stop signal, stopping the ping pong game now!");
                            break;
                        }
                    }
                }
                Ok(())
            }
        }

        let (notify_stop_worker_pool, should_stop) = tokio::sync::oneshot::channel();
        let (pong_tx, mut pong_rx) = tokio::sync::mpsc::channel(1);
        let (ping_tx, ping_rx) = tokio::sync::mpsc::channel(1);

        let player_context = PlayerContext {
            pong_tx: Arc::new(pong_tx),
            ping_rx: Arc::new(Mutex::new(ping_rx)),
        };

        let task_store = memory_store().await;

        let (worker_pool_finished, queue) = WorkerPool::new(task_store, {
            let player_context = player_context.clone();
            move |_| player_context.clone()
        })
        .register_task_type::<KeepAliveTask>()
        .configure_queue("default".into())
        .start(async move {
            should_stop.await.unwrap();
            println!("Worker pool got notified to stop");
        })
        .await
        .unwrap();

        queue.enqueue(KeepAliveTask).await.unwrap();

        // Make sure task is running
        println!("Ping!");
        ping_tx.send(PingPongGame::Ping).await.unwrap();
        assert_eq!(pong_rx.recv().await.unwrap(), PingPongGame::Pong);

        // Notify to stop the worker pool
        notify_stop_worker_pool.send(()).unwrap();

        // Make sure task is still running
        println!("Ping!");
        ping_tx.send(PingPongGame::Ping).await.unwrap();
        assert_eq!(pong_rx.recv().await.unwrap(), PingPongGame::Pong);

        // is_none() means that the worker pool is still waiting for tasks to finish, which is what we want!
        assert!(
            worker_pool_finished.now_or_never().is_none(),
            "Worker pool finished before task stopped!"
        );

        // Notify to stop the task, which will stop the worker pool
        ping_tx.send(PingPongGame::StopThisNow).await.unwrap();
    }

    // In-memory store used by most tests; no external services needed.
    async fn memory_store() -> MemoryTaskStore {
        MemoryTaskStore::default()
    }

    // Integration test against a real Postgres store; #[ignore]d because it
    // requires a database (see pg_task_store / DATABASE_URL below).
    #[tokio::test]
    #[ignore]
    async fn test_worker_pool_with_pg_store() {
        let my_app_context = ApplicationContext::new();

        let (join_handle, _queue) =
            WorkerPool::new(pg_task_store().await, move |_| my_app_context.clone())
                .register_task_type::<GreetingTask>()
                .configure_queue(
                    QueueConfig::new(GreetingTask::QUEUE).retention_mode(RetentionMode::RemoveDone),
                )
                .start(futures::future::ready(()))
                .await
                .unwrap();

        join_handle.await.unwrap();
    }

    // NOTE(review): option_env! reads DATABASE_URL at *compile* time, so the
    // binary must be built with it set for the #[ignore]d test to connect.
    async fn pg_task_store() -> PgTaskStore {
        let manager = AsyncDieselConnectionManager::<AsyncPgConnection>::new(
            option_env!("DATABASE_URL").expect("DATABASE_URL must be set"),
        );
        let pool = Pool::builder()
            .max_size(1)
            .min_idle(Some(1))
            .build(manager)
            .await
            .unwrap();

        PgTaskStore::new(pool)
    }
}
|