Make queues configurable
This commit is contained in:
parent
aac0b44c7f
commit
fd92b25190
12 changed files with 217 additions and 274 deletions
|
@ -31,8 +31,9 @@ async fn main() {
|
|||
|
||||
// Register the task types I want to use and start the worker pool
|
||||
let (join_handle, queue) = WorkerPool::new(task_store, move |_| my_app_context.clone())
|
||||
.register_task_type::<MyTask>(1, RetentionMode::RemoveDone)
|
||||
.register_task_type::<MyFailingTask>(1, RetentionMode::RemoveDone)
|
||||
.register_task_type::<MyTask>()
|
||||
.register_task_type::<MyFailingTask>()
|
||||
.configure_queue("default", 3, RetentionMode::RemoveDone)
|
||||
.start(async move {
|
||||
let _ = rx.changed().await;
|
||||
})
|
||||
|
|
|
@ -1,2 +1 @@
|
|||
DROP TABLE backie_tasks;
|
||||
DROP FUNCTION backie_notify_new_tasks;
|
||||
|
|
|
@ -18,12 +18,3 @@ CREATE TABLE backie_tasks (
|
|||
|
||||
--- create uniqueness index
|
||||
CREATE UNIQUE INDEX backie_tasks_uniq_hash_index ON backie_tasks(uniq_hash) WHERE uniq_hash IS NOT NULL;
|
||||
|
||||
CREATE FUNCTION backie_notify_new_tasks() returns trigger as $$
|
||||
BEGIN
|
||||
perform pg_notify('backie::tasks', 'created');
|
||||
return new;
|
||||
END;
|
||||
$$ language plpgsql;
|
||||
|
||||
CREATE TRIGGER backie_notify_workers after insert on backie_tasks for each statement execute procedure backie_notify_new_tasks();
|
||||
|
|
|
@ -3,6 +3,12 @@ use thiserror::Error;
|
|||
/// Library errors
|
||||
#[derive(Debug, Error)]
|
||||
pub enum BackieError {
|
||||
#[error("Queue \"{0}\" needs to be configured because of registered tasks: {1:?}")]
|
||||
QueueNotConfigured(String, Vec<String>),
|
||||
|
||||
#[error("Provided task is not serializable to JSON: {0}")]
|
||||
NonSerializableTask(#[from] serde_json::Error),
|
||||
|
||||
#[error("Queue processing error: {0}")]
|
||||
QueueProcessingError(#[from] AsyncQueueError),
|
||||
|
||||
|
@ -13,40 +19,11 @@ pub enum BackieError {
|
|||
WorkerShutdownError(#[from] tokio::sync::watch::error::RecvError),
|
||||
}
|
||||
|
||||
/// List of error types that can occur while working with cron schedules.
|
||||
#[derive(Debug, Error)]
|
||||
pub enum CronError {
|
||||
/// A problem occured during cron schedule parsing.
|
||||
#[error(transparent)]
|
||||
LibraryError(#[from] cron::error::Error),
|
||||
/// [`Scheduled`] enum variant is not provided
|
||||
#[error("You have to implement method `cron()` in your Runnable")]
|
||||
TaskNotSchedulableError,
|
||||
/// The next execution can not be determined using the current [`Scheduled::CronPattern`]
|
||||
#[error("No timestamps match with this cron pattern")]
|
||||
NoTimestampsError,
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum AsyncQueueError {
|
||||
#[error(transparent)]
|
||||
PgError(#[from] diesel::result::Error),
|
||||
|
||||
#[error("Task serialization error: {0}")]
|
||||
SerdeError(#[from] serde_json::Error),
|
||||
|
||||
#[error(transparent)]
|
||||
CronError(#[from] CronError),
|
||||
|
||||
#[error("Task is not in progress, operation not allowed")]
|
||||
TaskNotRunning,
|
||||
|
||||
#[error("Task with name {0} is not registered")]
|
||||
TaskNotRegistered(String),
|
||||
}
|
||||
|
||||
impl From<cron::error::Error> for AsyncQueueError {
|
||||
fn from(error: cron::error::Error) -> Self {
|
||||
AsyncQueueError::CronError(CronError::LibraryError(error))
|
||||
}
|
||||
}
|
||||
|
|
|
@ -4,7 +4,7 @@ use chrono::{DateTime, Utc};
|
|||
|
||||
/// Represents a schedule for scheduled tasks.
|
||||
///
|
||||
/// It's used in the [`AsyncRunnable::cron`] and [`Runnable::cron`]
|
||||
/// It's used in the [`BackgroundTask::cron`]
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum Scheduled {
|
||||
/// A cron pattern for a periodic task
|
||||
|
@ -38,8 +38,8 @@ impl Default for RetentionMode {
|
|||
}
|
||||
}
|
||||
|
||||
pub use queue::PgTaskStore;
|
||||
pub use runnable::BackgroundTask;
|
||||
pub use store::{PgTaskStore, TaskStore};
|
||||
pub use task::CurrentTask;
|
||||
pub use worker_pool::WorkerPool;
|
||||
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
use crate::errors::AsyncQueueError;
|
||||
use crate::schema::backie_tasks;
|
||||
use crate::task::Task;
|
||||
use crate::task::{NewTask, TaskHash, TaskId};
|
||||
use crate::task::{NewTask, TaskId};
|
||||
use chrono::DateTime;
|
||||
use chrono::Duration;
|
||||
use chrono::Utc;
|
||||
|
@ -10,21 +10,6 @@ use diesel::ExpressionMethods;
|
|||
use diesel_async::{pg::AsyncPgConnection, RunQueryDsl};
|
||||
|
||||
impl Task {
|
||||
pub(crate) async fn remove_all(
|
||||
connection: &mut AsyncPgConnection,
|
||||
) -> Result<u64, AsyncQueueError> {
|
||||
Ok(diesel::delete(backie_tasks::table)
|
||||
.execute(connection)
|
||||
.await? as u64)
|
||||
}
|
||||
|
||||
pub(crate) async fn remove_all_scheduled(
|
||||
connection: &mut AsyncPgConnection,
|
||||
) -> Result<u64, AsyncQueueError> {
|
||||
let query = backie_tasks::table.filter(backie_tasks::running_at.is_null());
|
||||
Ok(diesel::delete(query).execute(connection).await? as u64)
|
||||
}
|
||||
|
||||
pub(crate) async fn remove(
|
||||
connection: &mut AsyncPgConnection,
|
||||
id: TaskId,
|
||||
|
@ -33,26 +18,6 @@ impl Task {
|
|||
Ok(diesel::delete(query).execute(connection).await? as u64)
|
||||
}
|
||||
|
||||
pub(crate) async fn remove_by_hash(
|
||||
connection: &mut AsyncPgConnection,
|
||||
task_hash: TaskHash,
|
||||
) -> Result<bool, AsyncQueueError> {
|
||||
let query = backie_tasks::table.filter(backie_tasks::uniq_hash.eq(task_hash));
|
||||
let qty = diesel::delete(query).execute(connection).await?;
|
||||
Ok(qty > 0)
|
||||
}
|
||||
|
||||
pub(crate) async fn find_by_id(
|
||||
connection: &mut AsyncPgConnection,
|
||||
id: TaskId,
|
||||
) -> Result<Task, AsyncQueueError> {
|
||||
let task = backie_tasks::table
|
||||
.filter(backie_tasks::id.eq(id))
|
||||
.first::<Task>(connection)
|
||||
.await?;
|
||||
Ok(task)
|
||||
}
|
||||
|
||||
pub(crate) async fn fail_with_message(
|
||||
connection: &mut AsyncPgConnection,
|
||||
id: TaskId,
|
||||
|
@ -148,15 +113,4 @@ impl Task {
|
|||
.get_result::<Task>(connection)
|
||||
.await?)
|
||||
}
|
||||
|
||||
pub(crate) async fn find_by_uniq_hash(
|
||||
connection: &mut AsyncPgConnection,
|
||||
hash: TaskHash,
|
||||
) -> Option<Task> {
|
||||
backie_tasks::table
|
||||
.filter(backie_tasks::uniq_hash.eq(hash))
|
||||
.first::<Task>(connection)
|
||||
.await
|
||||
.ok()
|
||||
}
|
||||
}
|
||||
|
|
150
src/queue.rs
150
src/queue.rs
|
@ -1,10 +1,7 @@
|
|||
use crate::errors::AsyncQueueError;
|
||||
use crate::errors::BackieError;
|
||||
use crate::runnable::BackgroundTask;
|
||||
use crate::task::{NewTask, Task, TaskHash, TaskId, TaskState};
|
||||
use diesel::result::Error::QueryBuilderError;
|
||||
use diesel_async::scoped_futures::ScopedFutureExt;
|
||||
use diesel_async::AsyncConnection;
|
||||
use diesel_async::{pg::AsyncPgConnection, pooled_connection::bb8::Pool};
|
||||
use crate::store::{PgTaskStore, TaskStore};
|
||||
use crate::task::{NewTask, TaskHash};
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
|
@ -18,7 +15,7 @@ impl Queue {
|
|||
Queue { task_store }
|
||||
}
|
||||
|
||||
pub async fn enqueue<BT>(&self, background_task: BT) -> Result<(), AsyncQueueError>
|
||||
pub async fn enqueue<BT>(&self, background_task: BT) -> Result<(), BackieError>
|
||||
where
|
||||
BT: BackgroundTask,
|
||||
{
|
||||
|
@ -29,138 +26,11 @@ impl Queue {
|
|||
}
|
||||
}
|
||||
|
||||
/// An async queue that is used to manipulate tasks, it uses PostgreSQL as storage.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct PgTaskStore {
|
||||
pool: Pool<AsyncPgConnection>,
|
||||
}
|
||||
|
||||
impl PgTaskStore {
|
||||
pub fn new(pool: Pool<AsyncPgConnection>) -> Self {
|
||||
PgTaskStore { pool }
|
||||
}
|
||||
|
||||
pub(crate) async fn pull_next_task(
|
||||
&self,
|
||||
queue_name: &str,
|
||||
) -> Result<Option<Task>, AsyncQueueError> {
|
||||
let mut connection = self
|
||||
.pool
|
||||
.get()
|
||||
.await
|
||||
.map_err(|e| QueryBuilderError(e.into()))?;
|
||||
connection
|
||||
.transaction::<Option<Task>, AsyncQueueError, _>(|conn| {
|
||||
async move {
|
||||
let Some(pending_task) = Task::fetch_next_pending(conn, queue_name).await else {
|
||||
return Ok(None);
|
||||
};
|
||||
|
||||
Task::set_running(conn, pending_task).await.map(Some)
|
||||
}
|
||||
.scope_boxed()
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
pub(crate) async fn create_task(&self, new_task: NewTask) -> Result<Task, AsyncQueueError> {
|
||||
let mut connection = self
|
||||
.pool
|
||||
.get()
|
||||
.await
|
||||
.map_err(|e| QueryBuilderError(e.into()))?;
|
||||
Task::insert(&mut connection, new_task).await
|
||||
}
|
||||
|
||||
pub(crate) async fn find_task_by_id(&self, id: TaskId) -> Result<Task, AsyncQueueError> {
|
||||
let mut connection = self
|
||||
.pool
|
||||
.get()
|
||||
.await
|
||||
.map_err(|e| QueryBuilderError(e.into()))?;
|
||||
Task::find_by_id(&mut connection, id).await
|
||||
}
|
||||
|
||||
pub(crate) async fn set_task_state(
|
||||
&self,
|
||||
id: TaskId,
|
||||
state: TaskState,
|
||||
) -> Result<(), AsyncQueueError> {
|
||||
let mut connection = self
|
||||
.pool
|
||||
.get()
|
||||
.await
|
||||
.map_err(|e| QueryBuilderError(e.into()))?;
|
||||
match state {
|
||||
TaskState::Done => Task::set_done(&mut connection, id).await?,
|
||||
TaskState::Failed(error_msg) => {
|
||||
Task::fail_with_message(&mut connection, id, &error_msg).await?
|
||||
}
|
||||
_ => return Ok(()),
|
||||
};
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) async fn keep_task_alive(&self, id: TaskId) -> Result<(), AsyncQueueError> {
|
||||
let mut connection = self
|
||||
.pool
|
||||
.get()
|
||||
.await
|
||||
.map_err(|e| QueryBuilderError(e.into()))?;
|
||||
connection
|
||||
.transaction::<(), AsyncQueueError, _>(|conn| {
|
||||
async move {
|
||||
let task = Task::find_by_id(conn, id).await?;
|
||||
Task::set_running(conn, task).await?;
|
||||
Ok(())
|
||||
}
|
||||
.scope_boxed()
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
pub(crate) async fn remove_task(&self, id: TaskId) -> Result<u64, AsyncQueueError> {
|
||||
let mut connection = self
|
||||
.pool
|
||||
.get()
|
||||
.await
|
||||
.map_err(|e| QueryBuilderError(e.into()))?;
|
||||
let result = Task::remove(&mut connection, id).await?;
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
pub(crate) async fn remove_all_tasks(&self) -> Result<u64, AsyncQueueError> {
|
||||
let mut connection = self
|
||||
.pool
|
||||
.get()
|
||||
.await
|
||||
.map_err(|e| QueryBuilderError(e.into()))?;
|
||||
Task::remove_all(&mut connection).await
|
||||
}
|
||||
|
||||
pub(crate) async fn schedule_task_retry(
|
||||
&self,
|
||||
id: TaskId,
|
||||
backoff_seconds: u32,
|
||||
error: &str,
|
||||
) -> Result<Task, AsyncQueueError> {
|
||||
let mut connection = self
|
||||
.pool
|
||||
.get()
|
||||
.await
|
||||
.map_err(|e| QueryBuilderError(e.into()))?;
|
||||
let task = Task::schedule_retry(&mut connection, id, backoff_seconds, error).await?;
|
||||
Ok(task)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod async_queue_tests {
|
||||
use super::*;
|
||||
use crate::CurrentTask;
|
||||
use async_trait::async_trait;
|
||||
use diesel_async::pooled_connection::{bb8::Pool, AsyncDieselConnectionManager};
|
||||
use diesel_async::AsyncPgConnection;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
|
@ -351,16 +221,4 @@ mod async_queue_tests {
|
|||
//
|
||||
// queue.remove_all_tasks().await.unwrap();
|
||||
// }
|
||||
|
||||
async fn pool() -> Pool<AsyncPgConnection> {
|
||||
let manager = AsyncDieselConnectionManager::<AsyncPgConnection>::new(
|
||||
option_env!("DATABASE_URL").expect("DATABASE_URL must be set"),
|
||||
);
|
||||
Pool::builder()
|
||||
.max_size(1)
|
||||
.min_idle(Some(1))
|
||||
.build(manager)
|
||||
.await
|
||||
.unwrap()
|
||||
}
|
||||
}
|
||||
|
|
|
@ -2,10 +2,39 @@ use crate::task::{CurrentTask, TaskHash};
|
|||
use async_trait::async_trait;
|
||||
use serde::{de::DeserializeOwned, ser::Serialize};
|
||||
|
||||
/// Task that can be executed by the queue.
|
||||
///
|
||||
/// The `BackgroundTask` trait is used to define the behaviour of a task. You must implement this
|
||||
/// The [`BackgroundTask`] trait is used to define the behaviour of a task. You must implement this
|
||||
/// trait for all tasks you want to execute.
|
||||
///
|
||||
/// The [`BackgroundTask::TASK_NAME`] attribute must be unique for the whole application. This
|
||||
/// attribute is critical for reconstructing the task back from the database.
|
||||
///
|
||||
/// The [`BackgroundTask::AppData`] can be used to argument the task with application specific
|
||||
/// contextual information. This is useful for example to pass a database connection pool to the
|
||||
/// task or other application configuration.
|
||||
///
|
||||
/// The [`BackgroundTask::run`] method is the main method of the task. It is executed by the
|
||||
/// the task queue workers.
|
||||
///
|
||||
///
|
||||
/// # Example
|
||||
/// ```rust
|
||||
/// use async_trait::async_trait;
|
||||
/// use backie::{BackgroundTask, CurrentTask};
|
||||
/// use serde::{Deserialize, Serialize};
|
||||
///
|
||||
/// #[derive(Serialize, Deserialize)]
|
||||
/// pub struct MyTask {}
|
||||
///
|
||||
/// impl BackgroundTask for MyTask {
|
||||
/// const TASK_NAME: &'static str = "my_task_unique_name";
|
||||
/// type AppData = ();
|
||||
///
|
||||
/// async fn run(&self, task: CurrentTask, context: Self::AppData) -> Result<(), anyhow::Error> {
|
||||
/// // Do something
|
||||
/// Ok(())
|
||||
/// }
|
||||
/// }
|
||||
/// ```
|
||||
#[async_trait]
|
||||
pub trait BackgroundTask: Serialize + DeserializeOwned + Sync + Send + 'static {
|
||||
/// Unique name of the task.
|
||||
|
@ -15,7 +44,7 @@ pub trait BackgroundTask: Serialize + DeserializeOwned + Sync + Send + 'static {
|
|||
|
||||
/// Task queue where this task will be executed.
|
||||
///
|
||||
/// Used to define which workers are going to be executing this task. It uses the default
|
||||
/// Used to route to which workers are going to be executing this task. It uses the default
|
||||
/// task queue if not changed.
|
||||
const QUEUE: &'static str = "default";
|
||||
|
||||
|
|
100
src/store.rs
100
src/store.rs
|
@ -1 +1,101 @@
|
|||
use crate::errors::AsyncQueueError;
|
||||
use crate::task::{NewTask, Task, TaskId, TaskState};
|
||||
use diesel::result::Error::QueryBuilderError;
|
||||
use diesel_async::scoped_futures::ScopedFutureExt;
|
||||
use diesel_async::AsyncConnection;
|
||||
use diesel_async::{pg::AsyncPgConnection, pooled_connection::bb8::Pool};
|
||||
|
||||
/// An async queue that is used to manipulate tasks, it uses PostgreSQL as storage.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct PgTaskStore {
|
||||
pool: Pool<AsyncPgConnection>,
|
||||
}
|
||||
|
||||
impl PgTaskStore {
|
||||
pub fn new(pool: Pool<AsyncPgConnection>) -> Self {
|
||||
PgTaskStore { pool }
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl TaskStore for PgTaskStore {
|
||||
async fn pull_next_task(&self, queue_name: &str) -> Result<Option<Task>, AsyncQueueError> {
|
||||
let mut connection = self
|
||||
.pool
|
||||
.get()
|
||||
.await
|
||||
.map_err(|e| QueryBuilderError(e.into()))?;
|
||||
connection
|
||||
.transaction::<Option<Task>, AsyncQueueError, _>(|conn| {
|
||||
async move {
|
||||
let Some(pending_task) = Task::fetch_next_pending(conn, queue_name).await else {
|
||||
return Ok(None);
|
||||
};
|
||||
|
||||
Task::set_running(conn, pending_task).await.map(Some)
|
||||
}
|
||||
.scope_boxed()
|
||||
})
|
||||
.await
|
||||
}
|
||||
async fn create_task(&self, new_task: NewTask) -> Result<Task, AsyncQueueError> {
|
||||
let mut connection = self
|
||||
.pool
|
||||
.get()
|
||||
.await
|
||||
.map_err(|e| QueryBuilderError(e.into()))?;
|
||||
Task::insert(&mut connection, new_task).await
|
||||
}
|
||||
async fn set_task_state(&self, id: TaskId, state: TaskState) -> Result<(), AsyncQueueError> {
|
||||
let mut connection = self
|
||||
.pool
|
||||
.get()
|
||||
.await
|
||||
.map_err(|e| QueryBuilderError(e.into()))?;
|
||||
match state {
|
||||
TaskState::Done => Task::set_done(&mut connection, id).await?,
|
||||
TaskState::Failed(error_msg) => {
|
||||
Task::fail_with_message(&mut connection, id, &error_msg).await?
|
||||
}
|
||||
_ => return Ok(()),
|
||||
};
|
||||
Ok(())
|
||||
}
|
||||
async fn remove_task(&self, id: TaskId) -> Result<u64, AsyncQueueError> {
|
||||
let mut connection = self
|
||||
.pool
|
||||
.get()
|
||||
.await
|
||||
.map_err(|e| QueryBuilderError(e.into()))?;
|
||||
let result = Task::remove(&mut connection, id).await?;
|
||||
Ok(result)
|
||||
}
|
||||
async fn schedule_task_retry(
|
||||
&self,
|
||||
id: TaskId,
|
||||
backoff_seconds: u32,
|
||||
error: &str,
|
||||
) -> Result<Task, AsyncQueueError> {
|
||||
let mut connection = self
|
||||
.pool
|
||||
.get()
|
||||
.await
|
||||
.map_err(|e| QueryBuilderError(e.into()))?;
|
||||
let task = Task::schedule_retry(&mut connection, id, backoff_seconds, error).await?;
|
||||
Ok(task)
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
pub trait TaskStore {
|
||||
async fn pull_next_task(&self, queue_name: &str) -> Result<Option<Task>, AsyncQueueError>;
|
||||
async fn create_task(&self, new_task: NewTask) -> Result<Task, AsyncQueueError>;
|
||||
async fn set_task_state(&self, id: TaskId, state: TaskState) -> Result<(), AsyncQueueError>;
|
||||
async fn remove_task(&self, id: TaskId) -> Result<u64, AsyncQueueError>;
|
||||
async fn schedule_task_retry(
|
||||
&self,
|
||||
id: TaskId,
|
||||
backoff_seconds: u32,
|
||||
error: &str,
|
||||
) -> Result<Task, AsyncQueueError>;
|
||||
}
|
||||
|
|
|
@ -115,7 +115,7 @@ impl Task {
|
|||
|
||||
#[derive(Insertable, Debug, Eq, PartialEq, Clone)]
|
||||
#[diesel(table_name = backie_tasks)]
|
||||
pub(crate) struct NewTask {
|
||||
pub struct NewTask {
|
||||
task_name: String,
|
||||
queue_name: String,
|
||||
uniq_hash: Option<TaskHash>,
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
use crate::errors::{AsyncQueueError, BackieError};
|
||||
use crate::runnable::BackgroundTask;
|
||||
use crate::store::TaskStore;
|
||||
use crate::task::{CurrentTask, Task, TaskState};
|
||||
use crate::{PgTaskStore, RetentionMode};
|
||||
use futures::future::FutureExt;
|
||||
|
@ -162,7 +163,7 @@ where
|
|||
Ok(_) => self.finalize_task(task, result).await?,
|
||||
Err(error) => {
|
||||
if task.retries < task.max_retries {
|
||||
let backoff_seconds = 5; // TODO: runnable.backoff(task.retries as u32);
|
||||
let backoff_seconds = 5; // TODO: runnable_task.backoff(task.retries as u32);
|
||||
|
||||
log::debug!(
|
||||
"Task {} failed to run and will be retried in {} seconds",
|
||||
|
@ -226,8 +227,6 @@ where
|
|||
mod async_worker_tests {
|
||||
use super::*;
|
||||
use async_trait::async_trait;
|
||||
use diesel_async::pooled_connection::{bb8::Pool, AsyncDieselConnectionManager};
|
||||
use diesel_async::AsyncPgConnection;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(thiserror::Error, Debug)]
|
||||
|
@ -332,16 +331,4 @@ mod async_worker_tests {
|
|||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
async fn pool() -> Pool<AsyncPgConnection> {
|
||||
let manager = AsyncDieselConnectionManager::<AsyncPgConnection>::new(
|
||||
option_env!("DATABASE_URL").expect("DATABASE_URL must be set"),
|
||||
);
|
||||
Pool::builder()
|
||||
.max_size(1)
|
||||
.min_idle(Some(1))
|
||||
.build(manager)
|
||||
.await
|
||||
.unwrap()
|
||||
}
|
||||
}
|
||||
|
|
|
@ -2,7 +2,7 @@ use crate::errors::BackieError;
|
|||
use crate::queue::Queue;
|
||||
use crate::worker::{runnable, ExecuteTaskFn};
|
||||
use crate::worker::{StateFn, Worker};
|
||||
use crate::{BackgroundTask, CurrentTask, PgTaskStore, RetentionMode};
|
||||
use crate::{BackgroundTask, PgTaskStore, RetentionMode};
|
||||
use std::collections::BTreeMap;
|
||||
use std::future::Future;
|
||||
use std::sync::Arc;
|
||||
|
@ -31,6 +31,9 @@ where
|
|||
/// The types of task the worker pool can execute and the loaders for them.
|
||||
task_registry: BTreeMap<String, ExecuteTaskFn<AppData>>,
|
||||
|
||||
/// The queue names for the registered tasks.
|
||||
queue_tasks: BTreeMap<String, Vec<String>>,
|
||||
|
||||
/// Number of workers that will be spawned per queue.
|
||||
worker_queues: BTreeMap<String, (RetentionMode, u32)>,
|
||||
}
|
||||
|
@ -40,11 +43,11 @@ where
|
|||
AppData: Clone + Send + 'static,
|
||||
{
|
||||
/// Create a new worker pool.
|
||||
pub fn new<A>(queue_store: PgTaskStore, application_data_fn: A) -> Self
|
||||
pub fn new<A>(task_store: PgTaskStore, application_data_fn: A) -> Self
|
||||
where
|
||||
A: Fn(Queue) -> AppData + Send + Sync + 'static,
|
||||
{
|
||||
let queue_store = Arc::new(queue_store);
|
||||
let queue_store = Arc::new(task_store);
|
||||
let queue = Queue::new(queue_store.clone());
|
||||
let application_data_fn = {
|
||||
let queue = queue.clone();
|
||||
|
@ -55,22 +58,36 @@ where
|
|||
queue,
|
||||
application_data_fn: Arc::new(application_data_fn),
|
||||
task_registry: BTreeMap::new(),
|
||||
queue_tasks: BTreeMap::new(),
|
||||
worker_queues: BTreeMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Register a task type with the worker pool.
|
||||
pub fn register_task_type<BT>(mut self, num_workers: u32, retention_mode: RetentionMode) -> Self
|
||||
pub fn register_task_type<BT>(mut self) -> Self
|
||||
where
|
||||
BT: BackgroundTask<AppData = AppData>,
|
||||
{
|
||||
self.worker_queues
|
||||
.insert(BT::QUEUE.to_string(), (retention_mode, num_workers));
|
||||
self.queue_tasks
|
||||
.entry(BT::QUEUE.to_string())
|
||||
.or_insert_with(Vec::new)
|
||||
.push(BT::TASK_NAME.to_string());
|
||||
self.task_registry
|
||||
.insert(BT::TASK_NAME.to_string(), Arc::new(runnable::<BT>));
|
||||
self
|
||||
}
|
||||
|
||||
pub fn configure_queue(
|
||||
mut self,
|
||||
queue_name: impl ToString,
|
||||
num_workers: u32,
|
||||
retention_mode: RetentionMode,
|
||||
) -> Self {
|
||||
self.worker_queues
|
||||
.insert(queue_name.to_string(), (retention_mode, num_workers));
|
||||
self
|
||||
}
|
||||
|
||||
pub async fn start<F>(
|
||||
self,
|
||||
graceful_shutdown: F,
|
||||
|
@ -78,6 +95,13 @@ where
|
|||
where
|
||||
F: Future<Output = ()> + Send + 'static,
|
||||
{
|
||||
// Validate that all registered tasks queues are configured
|
||||
for (queue_name, tasks_for_queue) in self.queue_tasks.into_iter() {
|
||||
if !self.worker_queues.contains_key(&queue_name) {
|
||||
return Err(BackieError::QueueNotConfigured(queue_name, tasks_for_queue));
|
||||
}
|
||||
}
|
||||
|
||||
let (tx, rx) = tokio::sync::watch::channel(());
|
||||
|
||||
// Spawn all individual workers per queue
|
||||
|
@ -94,8 +118,10 @@ where
|
|||
let worker_name = format!("worker-{queue_name}-{idx}");
|
||||
// TODO: grab the join handle for every worker for graceful shutdown
|
||||
tokio::spawn(async move {
|
||||
worker.run_tasks().await;
|
||||
log::info!("Worker {} stopped", worker_name);
|
||||
match worker.run_tasks().await {
|
||||
Ok(()) => log::info!("Worker {worker_name} stopped sucessfully"),
|
||||
Err(err) => log::error!("Worker {worker_name} stopped due to error: {err}"),
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
@ -117,6 +143,7 @@ where
|
|||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::task::CurrentTask;
|
||||
use async_trait::async_trait;
|
||||
use diesel_async::pooled_connection::{bb8::Pool, AsyncDieselConnectionManager};
|
||||
use diesel_async::AsyncPgConnection;
|
||||
|
@ -164,14 +191,32 @@ mod tests {
|
|||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn validate_all_registered_tasks_queues_are_configured() {
|
||||
let my_app_context = ApplicationContext::new();
|
||||
|
||||
let result = WorkerPool::new(task_store().await, move |_| my_app_context.clone())
|
||||
.register_task_type::<GreetingTask>()
|
||||
.start(futures::future::ready(()))
|
||||
.await;
|
||||
|
||||
assert!(matches!(result, Err(BackieError::QueueNotConfigured(..))));
|
||||
if let Err(err) = result {
|
||||
assert_eq!(
|
||||
err.to_string(),
|
||||
"Queue \"default\" needs to be configured because of registered tasks: [\"my_task\"]"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_worker_pool() {
|
||||
let my_app_context = ApplicationContext::new();
|
||||
|
||||
let task_store = PgTaskStore::new(pool().await);
|
||||
|
||||
let (join_handle, queue) = WorkerPool::new(task_store, move |_| my_app_context.clone())
|
||||
.register_task_type::<GreetingTask>(1, RetentionMode::RemoveDone)
|
||||
let (join_handle, queue) =
|
||||
WorkerPool::new(task_store().await, move |_| my_app_context.clone())
|
||||
.register_task_type::<GreetingTask>()
|
||||
.configure_queue(GreetingTask::QUEUE, 1, RetentionMode::RemoveDone)
|
||||
.start(futures::future::ready(()))
|
||||
.await
|
||||
.unwrap();
|
||||
|
@ -186,15 +231,17 @@ mod tests {
|
|||
join_handle.await.unwrap();
|
||||
}
|
||||
|
||||
async fn pool() -> Pool<AsyncPgConnection> {
|
||||
async fn task_store() -> PgTaskStore {
|
||||
let manager = AsyncDieselConnectionManager::<AsyncPgConnection>::new(
|
||||
option_env!("DATABASE_URL").expect("DATABASE_URL must be set"),
|
||||
);
|
||||
Pool::builder()
|
||||
let pool = Pool::builder()
|
||||
.max_size(1)
|
||||
.min_idle(Some(1))
|
||||
.build(manager)
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
|
||||
PgTaskStore::new(pool)
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue