Make queues configurable

This commit is contained in:
Rafael Caricio 2023-03-11 16:38:32 +01:00
parent aac0b44c7f
commit fd92b25190
Signed by: rafaelcaricio
GPG key ID: 3C86DBCE8E93C947
12 changed files with 217 additions and 274 deletions

View file

@ -31,8 +31,9 @@ async fn main() {
// Register the task types I want to use and start the worker pool // Register the task types I want to use and start the worker pool
let (join_handle, queue) = WorkerPool::new(task_store, move |_| my_app_context.clone()) let (join_handle, queue) = WorkerPool::new(task_store, move |_| my_app_context.clone())
.register_task_type::<MyTask>(1, RetentionMode::RemoveDone) .register_task_type::<MyTask>()
.register_task_type::<MyFailingTask>(1, RetentionMode::RemoveDone) .register_task_type::<MyFailingTask>()
.configure_queue("default", 3, RetentionMode::RemoveDone)
.start(async move { .start(async move {
let _ = rx.changed().await; let _ = rx.changed().await;
}) })

View file

@ -1,2 +1 @@
DROP TABLE backie_tasks; DROP TABLE backie_tasks;
DROP FUNCTION backie_notify_new_tasks;

View file

@ -18,12 +18,3 @@ CREATE TABLE backie_tasks (
--- create uniqueness index --- create uniqueness index
CREATE UNIQUE INDEX backie_tasks_uniq_hash_index ON backie_tasks(uniq_hash) WHERE uniq_hash IS NOT NULL; CREATE UNIQUE INDEX backie_tasks_uniq_hash_index ON backie_tasks(uniq_hash) WHERE uniq_hash IS NOT NULL;
CREATE FUNCTION backie_notify_new_tasks() returns trigger as $$
BEGIN
perform pg_notify('backie::tasks', 'created');
return new;
END;
$$ language plpgsql;
CREATE TRIGGER backie_notify_workers after insert on backie_tasks for each statement execute procedure backie_notify_new_tasks();

View file

@ -3,6 +3,12 @@ use thiserror::Error;
/// Library errors /// Library errors
#[derive(Debug, Error)] #[derive(Debug, Error)]
pub enum BackieError { pub enum BackieError {
#[error("Queue \"{0}\" needs to be configured because of registered tasks: {1:?}")]
QueueNotConfigured(String, Vec<String>),
#[error("Provided task is not serializable to JSON: {0}")]
NonSerializableTask(#[from] serde_json::Error),
#[error("Queue processing error: {0}")] #[error("Queue processing error: {0}")]
QueueProcessingError(#[from] AsyncQueueError), QueueProcessingError(#[from] AsyncQueueError),
@ -13,40 +19,11 @@ pub enum BackieError {
WorkerShutdownError(#[from] tokio::sync::watch::error::RecvError), WorkerShutdownError(#[from] tokio::sync::watch::error::RecvError),
} }
/// List of error types that can occur while working with cron schedules.
#[derive(Debug, Error)]
pub enum CronError {
/// A problem occured during cron schedule parsing.
#[error(transparent)]
LibraryError(#[from] cron::error::Error),
/// [`Scheduled`] enum variant is not provided
#[error("You have to implement method `cron()` in your Runnable")]
TaskNotSchedulableError,
/// The next execution can not be determined using the current [`Scheduled::CronPattern`]
#[error("No timestamps match with this cron pattern")]
NoTimestampsError,
}
#[derive(Debug, Error)] #[derive(Debug, Error)]
pub enum AsyncQueueError { pub enum AsyncQueueError {
#[error(transparent)] #[error(transparent)]
PgError(#[from] diesel::result::Error), PgError(#[from] diesel::result::Error),
#[error("Task serialization error: {0}")]
SerdeError(#[from] serde_json::Error),
#[error(transparent)]
CronError(#[from] CronError),
#[error("Task is not in progress, operation not allowed")]
TaskNotRunning,
#[error("Task with name {0} is not registered")] #[error("Task with name {0} is not registered")]
TaskNotRegistered(String), TaskNotRegistered(String),
} }
impl From<cron::error::Error> for AsyncQueueError {
fn from(error: cron::error::Error) -> Self {
AsyncQueueError::CronError(CronError::LibraryError(error))
}
}

View file

@ -4,7 +4,7 @@ use chrono::{DateTime, Utc};
/// Represents a schedule for scheduled tasks. /// Represents a schedule for scheduled tasks.
/// ///
/// It's used in the [`AsyncRunnable::cron`] and [`Runnable::cron`] /// It's used in the [`BackgroundTask::cron`]
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub enum Scheduled { pub enum Scheduled {
/// A cron pattern for a periodic task /// A cron pattern for a periodic task
@ -38,8 +38,8 @@ impl Default for RetentionMode {
} }
} }
pub use queue::PgTaskStore;
pub use runnable::BackgroundTask; pub use runnable::BackgroundTask;
pub use store::{PgTaskStore, TaskStore};
pub use task::CurrentTask; pub use task::CurrentTask;
pub use worker_pool::WorkerPool; pub use worker_pool::WorkerPool;

View file

@ -1,7 +1,7 @@
use crate::errors::AsyncQueueError; use crate::errors::AsyncQueueError;
use crate::schema::backie_tasks; use crate::schema::backie_tasks;
use crate::task::Task; use crate::task::Task;
use crate::task::{NewTask, TaskHash, TaskId}; use crate::task::{NewTask, TaskId};
use chrono::DateTime; use chrono::DateTime;
use chrono::Duration; use chrono::Duration;
use chrono::Utc; use chrono::Utc;
@ -10,21 +10,6 @@ use diesel::ExpressionMethods;
use diesel_async::{pg::AsyncPgConnection, RunQueryDsl}; use diesel_async::{pg::AsyncPgConnection, RunQueryDsl};
impl Task { impl Task {
pub(crate) async fn remove_all(
connection: &mut AsyncPgConnection,
) -> Result<u64, AsyncQueueError> {
Ok(diesel::delete(backie_tasks::table)
.execute(connection)
.await? as u64)
}
pub(crate) async fn remove_all_scheduled(
connection: &mut AsyncPgConnection,
) -> Result<u64, AsyncQueueError> {
let query = backie_tasks::table.filter(backie_tasks::running_at.is_null());
Ok(diesel::delete(query).execute(connection).await? as u64)
}
pub(crate) async fn remove( pub(crate) async fn remove(
connection: &mut AsyncPgConnection, connection: &mut AsyncPgConnection,
id: TaskId, id: TaskId,
@ -33,26 +18,6 @@ impl Task {
Ok(diesel::delete(query).execute(connection).await? as u64) Ok(diesel::delete(query).execute(connection).await? as u64)
} }
pub(crate) async fn remove_by_hash(
connection: &mut AsyncPgConnection,
task_hash: TaskHash,
) -> Result<bool, AsyncQueueError> {
let query = backie_tasks::table.filter(backie_tasks::uniq_hash.eq(task_hash));
let qty = diesel::delete(query).execute(connection).await?;
Ok(qty > 0)
}
pub(crate) async fn find_by_id(
connection: &mut AsyncPgConnection,
id: TaskId,
) -> Result<Task, AsyncQueueError> {
let task = backie_tasks::table
.filter(backie_tasks::id.eq(id))
.first::<Task>(connection)
.await?;
Ok(task)
}
pub(crate) async fn fail_with_message( pub(crate) async fn fail_with_message(
connection: &mut AsyncPgConnection, connection: &mut AsyncPgConnection,
id: TaskId, id: TaskId,
@ -148,15 +113,4 @@ impl Task {
.get_result::<Task>(connection) .get_result::<Task>(connection)
.await?) .await?)
} }
pub(crate) async fn find_by_uniq_hash(
connection: &mut AsyncPgConnection,
hash: TaskHash,
) -> Option<Task> {
backie_tasks::table
.filter(backie_tasks::uniq_hash.eq(hash))
.first::<Task>(connection)
.await
.ok()
}
} }

View file

@ -1,10 +1,7 @@
use crate::errors::AsyncQueueError; use crate::errors::BackieError;
use crate::runnable::BackgroundTask; use crate::runnable::BackgroundTask;
use crate::task::{NewTask, Task, TaskHash, TaskId, TaskState}; use crate::store::{PgTaskStore, TaskStore};
use diesel::result::Error::QueryBuilderError; use crate::task::{NewTask, TaskHash};
use diesel_async::scoped_futures::ScopedFutureExt;
use diesel_async::AsyncConnection;
use diesel_async::{pg::AsyncPgConnection, pooled_connection::bb8::Pool};
use std::sync::Arc; use std::sync::Arc;
use std::time::Duration; use std::time::Duration;
@ -18,7 +15,7 @@ impl Queue {
Queue { task_store } Queue { task_store }
} }
pub async fn enqueue<BT>(&self, background_task: BT) -> Result<(), AsyncQueueError> pub async fn enqueue<BT>(&self, background_task: BT) -> Result<(), BackieError>
where where
BT: BackgroundTask, BT: BackgroundTask,
{ {
@ -29,138 +26,11 @@ impl Queue {
} }
} }
/// An async queue that is used to manipulate tasks, it uses PostgreSQL as storage.
#[derive(Debug, Clone)]
pub struct PgTaskStore {
pool: Pool<AsyncPgConnection>,
}
impl PgTaskStore {
pub fn new(pool: Pool<AsyncPgConnection>) -> Self {
PgTaskStore { pool }
}
pub(crate) async fn pull_next_task(
&self,
queue_name: &str,
) -> Result<Option<Task>, AsyncQueueError> {
let mut connection = self
.pool
.get()
.await
.map_err(|e| QueryBuilderError(e.into()))?;
connection
.transaction::<Option<Task>, AsyncQueueError, _>(|conn| {
async move {
let Some(pending_task) = Task::fetch_next_pending(conn, queue_name).await else {
return Ok(None);
};
Task::set_running(conn, pending_task).await.map(Some)
}
.scope_boxed()
})
.await
}
pub(crate) async fn create_task(&self, new_task: NewTask) -> Result<Task, AsyncQueueError> {
let mut connection = self
.pool
.get()
.await
.map_err(|e| QueryBuilderError(e.into()))?;
Task::insert(&mut connection, new_task).await
}
pub(crate) async fn find_task_by_id(&self, id: TaskId) -> Result<Task, AsyncQueueError> {
let mut connection = self
.pool
.get()
.await
.map_err(|e| QueryBuilderError(e.into()))?;
Task::find_by_id(&mut connection, id).await
}
pub(crate) async fn set_task_state(
&self,
id: TaskId,
state: TaskState,
) -> Result<(), AsyncQueueError> {
let mut connection = self
.pool
.get()
.await
.map_err(|e| QueryBuilderError(e.into()))?;
match state {
TaskState::Done => Task::set_done(&mut connection, id).await?,
TaskState::Failed(error_msg) => {
Task::fail_with_message(&mut connection, id, &error_msg).await?
}
_ => return Ok(()),
};
Ok(())
}
pub(crate) async fn keep_task_alive(&self, id: TaskId) -> Result<(), AsyncQueueError> {
let mut connection = self
.pool
.get()
.await
.map_err(|e| QueryBuilderError(e.into()))?;
connection
.transaction::<(), AsyncQueueError, _>(|conn| {
async move {
let task = Task::find_by_id(conn, id).await?;
Task::set_running(conn, task).await?;
Ok(())
}
.scope_boxed()
})
.await
}
pub(crate) async fn remove_task(&self, id: TaskId) -> Result<u64, AsyncQueueError> {
let mut connection = self
.pool
.get()
.await
.map_err(|e| QueryBuilderError(e.into()))?;
let result = Task::remove(&mut connection, id).await?;
Ok(result)
}
pub(crate) async fn remove_all_tasks(&self) -> Result<u64, AsyncQueueError> {
let mut connection = self
.pool
.get()
.await
.map_err(|e| QueryBuilderError(e.into()))?;
Task::remove_all(&mut connection).await
}
pub(crate) async fn schedule_task_retry(
&self,
id: TaskId,
backoff_seconds: u32,
error: &str,
) -> Result<Task, AsyncQueueError> {
let mut connection = self
.pool
.get()
.await
.map_err(|e| QueryBuilderError(e.into()))?;
let task = Task::schedule_retry(&mut connection, id, backoff_seconds, error).await?;
Ok(task)
}
}
#[cfg(test)] #[cfg(test)]
mod async_queue_tests { mod async_queue_tests {
use super::*; use super::*;
use crate::CurrentTask; use crate::CurrentTask;
use async_trait::async_trait; use async_trait::async_trait;
use diesel_async::pooled_connection::{bb8::Pool, AsyncDieselConnectionManager};
use diesel_async::AsyncPgConnection;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
#[derive(Serialize, Deserialize)] #[derive(Serialize, Deserialize)]
@ -351,16 +221,4 @@ mod async_queue_tests {
// //
// queue.remove_all_tasks().await.unwrap(); // queue.remove_all_tasks().await.unwrap();
// } // }
async fn pool() -> Pool<AsyncPgConnection> {
let manager = AsyncDieselConnectionManager::<AsyncPgConnection>::new(
option_env!("DATABASE_URL").expect("DATABASE_URL must be set"),
);
Pool::builder()
.max_size(1)
.min_idle(Some(1))
.build(manager)
.await
.unwrap()
}
} }

View file

@ -2,10 +2,39 @@ use crate::task::{CurrentTask, TaskHash};
use async_trait::async_trait; use async_trait::async_trait;
use serde::{de::DeserializeOwned, ser::Serialize}; use serde::{de::DeserializeOwned, ser::Serialize};
/// Task that can be executed by the queue. /// The [`BackgroundTask`] trait is used to define the behaviour of a task. You must implement this
///
/// The `BackgroundTask` trait is used to define the behaviour of a task. You must implement this
/// trait for all tasks you want to execute. /// trait for all tasks you want to execute.
///
/// The [`BackgroundTask::TASK_NAME`] attribute must be unique for the whole application. This
/// attribute is critical for reconstructing the task back from the database.
///
/// The [`BackgroundTask::AppData`] can be used to argument the task with application specific
/// contextual information. This is useful for example to pass a database connection pool to the
/// task or other application configuration.
///
/// The [`BackgroundTask::run`] method is the main method of the task. It is executed by the
/// the task queue workers.
///
///
/// # Example
/// ```rust
/// use async_trait::async_trait;
/// use backie::{BackgroundTask, CurrentTask};
/// use serde::{Deserialize, Serialize};
///
/// #[derive(Serialize, Deserialize)]
/// pub struct MyTask {}
///
/// impl BackgroundTask for MyTask {
/// const TASK_NAME: &'static str = "my_task_unique_name";
/// type AppData = ();
///
/// async fn run(&self, task: CurrentTask, context: Self::AppData) -> Result<(), anyhow::Error> {
/// // Do something
/// Ok(())
/// }
/// }
/// ```
#[async_trait] #[async_trait]
pub trait BackgroundTask: Serialize + DeserializeOwned + Sync + Send + 'static { pub trait BackgroundTask: Serialize + DeserializeOwned + Sync + Send + 'static {
/// Unique name of the task. /// Unique name of the task.
@ -15,7 +44,7 @@ pub trait BackgroundTask: Serialize + DeserializeOwned + Sync + Send + 'static {
/// Task queue where this task will be executed. /// Task queue where this task will be executed.
/// ///
/// Used to define which workers are going to be executing this task. It uses the default /// Used to route to which workers are going to be executing this task. It uses the default
/// task queue if not changed. /// task queue if not changed.
const QUEUE: &'static str = "default"; const QUEUE: &'static str = "default";

View file

@ -1 +1,101 @@
use crate::errors::AsyncQueueError;
use crate::task::{NewTask, Task, TaskId, TaskState};
use diesel::result::Error::QueryBuilderError;
use diesel_async::scoped_futures::ScopedFutureExt;
use diesel_async::AsyncConnection;
use diesel_async::{pg::AsyncPgConnection, pooled_connection::bb8::Pool};
/// An async queue that is used to manipulate tasks, it uses PostgreSQL as storage.
#[derive(Debug, Clone)]
pub struct PgTaskStore {
pool: Pool<AsyncPgConnection>,
}
impl PgTaskStore {
pub fn new(pool: Pool<AsyncPgConnection>) -> Self {
PgTaskStore { pool }
}
}
#[async_trait::async_trait]
impl TaskStore for PgTaskStore {
async fn pull_next_task(&self, queue_name: &str) -> Result<Option<Task>, AsyncQueueError> {
let mut connection = self
.pool
.get()
.await
.map_err(|e| QueryBuilderError(e.into()))?;
connection
.transaction::<Option<Task>, AsyncQueueError, _>(|conn| {
async move {
let Some(pending_task) = Task::fetch_next_pending(conn, queue_name).await else {
return Ok(None);
};
Task::set_running(conn, pending_task).await.map(Some)
}
.scope_boxed()
})
.await
}
async fn create_task(&self, new_task: NewTask) -> Result<Task, AsyncQueueError> {
let mut connection = self
.pool
.get()
.await
.map_err(|e| QueryBuilderError(e.into()))?;
Task::insert(&mut connection, new_task).await
}
async fn set_task_state(&self, id: TaskId, state: TaskState) -> Result<(), AsyncQueueError> {
let mut connection = self
.pool
.get()
.await
.map_err(|e| QueryBuilderError(e.into()))?;
match state {
TaskState::Done => Task::set_done(&mut connection, id).await?,
TaskState::Failed(error_msg) => {
Task::fail_with_message(&mut connection, id, &error_msg).await?
}
_ => return Ok(()),
};
Ok(())
}
async fn remove_task(&self, id: TaskId) -> Result<u64, AsyncQueueError> {
let mut connection = self
.pool
.get()
.await
.map_err(|e| QueryBuilderError(e.into()))?;
let result = Task::remove(&mut connection, id).await?;
Ok(result)
}
async fn schedule_task_retry(
&self,
id: TaskId,
backoff_seconds: u32,
error: &str,
) -> Result<Task, AsyncQueueError> {
let mut connection = self
.pool
.get()
.await
.map_err(|e| QueryBuilderError(e.into()))?;
let task = Task::schedule_retry(&mut connection, id, backoff_seconds, error).await?;
Ok(task)
}
}
#[async_trait::async_trait]
pub trait TaskStore {
async fn pull_next_task(&self, queue_name: &str) -> Result<Option<Task>, AsyncQueueError>;
async fn create_task(&self, new_task: NewTask) -> Result<Task, AsyncQueueError>;
async fn set_task_state(&self, id: TaskId, state: TaskState) -> Result<(), AsyncQueueError>;
async fn remove_task(&self, id: TaskId) -> Result<u64, AsyncQueueError>;
async fn schedule_task_retry(
&self,
id: TaskId,
backoff_seconds: u32,
error: &str,
) -> Result<Task, AsyncQueueError>;
}

View file

@ -115,7 +115,7 @@ impl Task {
#[derive(Insertable, Debug, Eq, PartialEq, Clone)] #[derive(Insertable, Debug, Eq, PartialEq, Clone)]
#[diesel(table_name = backie_tasks)] #[diesel(table_name = backie_tasks)]
pub(crate) struct NewTask { pub struct NewTask {
task_name: String, task_name: String,
queue_name: String, queue_name: String,
uniq_hash: Option<TaskHash>, uniq_hash: Option<TaskHash>,

View file

@ -1,5 +1,6 @@
use crate::errors::{AsyncQueueError, BackieError}; use crate::errors::{AsyncQueueError, BackieError};
use crate::runnable::BackgroundTask; use crate::runnable::BackgroundTask;
use crate::store::TaskStore;
use crate::task::{CurrentTask, Task, TaskState}; use crate::task::{CurrentTask, Task, TaskState};
use crate::{PgTaskStore, RetentionMode}; use crate::{PgTaskStore, RetentionMode};
use futures::future::FutureExt; use futures::future::FutureExt;
@ -162,7 +163,7 @@ where
Ok(_) => self.finalize_task(task, result).await?, Ok(_) => self.finalize_task(task, result).await?,
Err(error) => { Err(error) => {
if task.retries < task.max_retries { if task.retries < task.max_retries {
let backoff_seconds = 5; // TODO: runnable.backoff(task.retries as u32); let backoff_seconds = 5; // TODO: runnable_task.backoff(task.retries as u32);
log::debug!( log::debug!(
"Task {} failed to run and will be retried in {} seconds", "Task {} failed to run and will be retried in {} seconds",
@ -226,8 +227,6 @@ where
mod async_worker_tests { mod async_worker_tests {
use super::*; use super::*;
use async_trait::async_trait; use async_trait::async_trait;
use diesel_async::pooled_connection::{bb8::Pool, AsyncDieselConnectionManager};
use diesel_async::AsyncPgConnection;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
#[derive(thiserror::Error, Debug)] #[derive(thiserror::Error, Debug)]
@ -332,16 +331,4 @@ mod async_worker_tests {
Ok(()) Ok(())
} }
} }
async fn pool() -> Pool<AsyncPgConnection> {
let manager = AsyncDieselConnectionManager::<AsyncPgConnection>::new(
option_env!("DATABASE_URL").expect("DATABASE_URL must be set"),
);
Pool::builder()
.max_size(1)
.min_idle(Some(1))
.build(manager)
.await
.unwrap()
}
} }

View file

@ -2,7 +2,7 @@ use crate::errors::BackieError;
use crate::queue::Queue; use crate::queue::Queue;
use crate::worker::{runnable, ExecuteTaskFn}; use crate::worker::{runnable, ExecuteTaskFn};
use crate::worker::{StateFn, Worker}; use crate::worker::{StateFn, Worker};
use crate::{BackgroundTask, CurrentTask, PgTaskStore, RetentionMode}; use crate::{BackgroundTask, PgTaskStore, RetentionMode};
use std::collections::BTreeMap; use std::collections::BTreeMap;
use std::future::Future; use std::future::Future;
use std::sync::Arc; use std::sync::Arc;
@ -31,6 +31,9 @@ where
/// The types of task the worker pool can execute and the loaders for them. /// The types of task the worker pool can execute and the loaders for them.
task_registry: BTreeMap<String, ExecuteTaskFn<AppData>>, task_registry: BTreeMap<String, ExecuteTaskFn<AppData>>,
/// The queue names for the registered tasks.
queue_tasks: BTreeMap<String, Vec<String>>,
/// Number of workers that will be spawned per queue. /// Number of workers that will be spawned per queue.
worker_queues: BTreeMap<String, (RetentionMode, u32)>, worker_queues: BTreeMap<String, (RetentionMode, u32)>,
} }
@ -40,11 +43,11 @@ where
AppData: Clone + Send + 'static, AppData: Clone + Send + 'static,
{ {
/// Create a new worker pool. /// Create a new worker pool.
pub fn new<A>(queue_store: PgTaskStore, application_data_fn: A) -> Self pub fn new<A>(task_store: PgTaskStore, application_data_fn: A) -> Self
where where
A: Fn(Queue) -> AppData + Send + Sync + 'static, A: Fn(Queue) -> AppData + Send + Sync + 'static,
{ {
let queue_store = Arc::new(queue_store); let queue_store = Arc::new(task_store);
let queue = Queue::new(queue_store.clone()); let queue = Queue::new(queue_store.clone());
let application_data_fn = { let application_data_fn = {
let queue = queue.clone(); let queue = queue.clone();
@ -55,22 +58,36 @@ where
queue, queue,
application_data_fn: Arc::new(application_data_fn), application_data_fn: Arc::new(application_data_fn),
task_registry: BTreeMap::new(), task_registry: BTreeMap::new(),
queue_tasks: BTreeMap::new(),
worker_queues: BTreeMap::new(), worker_queues: BTreeMap::new(),
} }
} }
/// Register a task type with the worker pool. /// Register a task type with the worker pool.
pub fn register_task_type<BT>(mut self, num_workers: u32, retention_mode: RetentionMode) -> Self pub fn register_task_type<BT>(mut self) -> Self
where where
BT: BackgroundTask<AppData = AppData>, BT: BackgroundTask<AppData = AppData>,
{ {
self.worker_queues self.queue_tasks
.insert(BT::QUEUE.to_string(), (retention_mode, num_workers)); .entry(BT::QUEUE.to_string())
.or_insert_with(Vec::new)
.push(BT::TASK_NAME.to_string());
self.task_registry self.task_registry
.insert(BT::TASK_NAME.to_string(), Arc::new(runnable::<BT>)); .insert(BT::TASK_NAME.to_string(), Arc::new(runnable::<BT>));
self self
} }
pub fn configure_queue(
mut self,
queue_name: impl ToString,
num_workers: u32,
retention_mode: RetentionMode,
) -> Self {
self.worker_queues
.insert(queue_name.to_string(), (retention_mode, num_workers));
self
}
pub async fn start<F>( pub async fn start<F>(
self, self,
graceful_shutdown: F, graceful_shutdown: F,
@ -78,6 +95,13 @@ where
where where
F: Future<Output = ()> + Send + 'static, F: Future<Output = ()> + Send + 'static,
{ {
// Validate that all registered tasks queues are configured
for (queue_name, tasks_for_queue) in self.queue_tasks.into_iter() {
if !self.worker_queues.contains_key(&queue_name) {
return Err(BackieError::QueueNotConfigured(queue_name, tasks_for_queue));
}
}
let (tx, rx) = tokio::sync::watch::channel(()); let (tx, rx) = tokio::sync::watch::channel(());
// Spawn all individual workers per queue // Spawn all individual workers per queue
@ -94,8 +118,10 @@ where
let worker_name = format!("worker-{queue_name}-{idx}"); let worker_name = format!("worker-{queue_name}-{idx}");
// TODO: grab the join handle for every worker for graceful shutdown // TODO: grab the join handle for every worker for graceful shutdown
tokio::spawn(async move { tokio::spawn(async move {
worker.run_tasks().await; match worker.run_tasks().await {
log::info!("Worker {} stopped", worker_name); Ok(()) => log::info!("Worker {worker_name} stopped sucessfully"),
Err(err) => log::error!("Worker {worker_name} stopped due to error: {err}"),
}
}); });
} }
} }
@ -117,6 +143,7 @@ where
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use crate::task::CurrentTask;
use async_trait::async_trait; use async_trait::async_trait;
use diesel_async::pooled_connection::{bb8::Pool, AsyncDieselConnectionManager}; use diesel_async::pooled_connection::{bb8::Pool, AsyncDieselConnectionManager};
use diesel_async::AsyncPgConnection; use diesel_async::AsyncPgConnection;
@ -164,17 +191,35 @@ mod tests {
} }
} }
#[tokio::test]
async fn validate_all_registered_tasks_queues_are_configured() {
let my_app_context = ApplicationContext::new();
let result = WorkerPool::new(task_store().await, move |_| my_app_context.clone())
.register_task_type::<GreetingTask>()
.start(futures::future::ready(()))
.await;
assert!(matches!(result, Err(BackieError::QueueNotConfigured(..))));
if let Err(err) = result {
assert_eq!(
err.to_string(),
"Queue \"default\" needs to be configured because of registered tasks: [\"my_task\"]"
);
}
}
#[tokio::test] #[tokio::test]
async fn test_worker_pool() { async fn test_worker_pool() {
let my_app_context = ApplicationContext::new(); let my_app_context = ApplicationContext::new();
let task_store = PgTaskStore::new(pool().await); let (join_handle, queue) =
WorkerPool::new(task_store().await, move |_| my_app_context.clone())
let (join_handle, queue) = WorkerPool::new(task_store, move |_| my_app_context.clone()) .register_task_type::<GreetingTask>()
.register_task_type::<GreetingTask>(1, RetentionMode::RemoveDone) .configure_queue(GreetingTask::QUEUE, 1, RetentionMode::RemoveDone)
.start(futures::future::ready(())) .start(futures::future::ready(()))
.await .await
.unwrap(); .unwrap();
queue queue
.enqueue(GreetingTask { .enqueue(GreetingTask {
@ -186,15 +231,17 @@ mod tests {
join_handle.await.unwrap(); join_handle.await.unwrap();
} }
async fn pool() -> Pool<AsyncPgConnection> { async fn task_store() -> PgTaskStore {
let manager = AsyncDieselConnectionManager::<AsyncPgConnection>::new( let manager = AsyncDieselConnectionManager::<AsyncPgConnection>::new(
option_env!("DATABASE_URL").expect("DATABASE_URL must be set"), option_env!("DATABASE_URL").expect("DATABASE_URL must be set"),
); );
Pool::builder() let pool = Pool::builder()
.max_size(1) .max_size(1)
.min_idle(Some(1)) .min_idle(Some(1))
.build(manager) .build(manager)
.await .await
.unwrap() .unwrap();
PgTaskStore::new(pool)
} }
} }