Make possible to provide app state to tasks

This commit is contained in:
Rafael Caricio 2023-03-10 23:41:34 +01:00
parent 7fcb63f75c
commit aac0b44c7f
Signed by: rafaelcaricio
GPG key ID: 3C86DBCE8E93C947
19 changed files with 776 additions and 1151 deletions

1
.gitignore vendored
View file

@ -1,5 +1,4 @@
**/target
Cargo.lock
src/schema.rs
docs/content/docs/CHANGELOG.md
docs/content/docs/README.md

View file

@ -19,16 +19,13 @@ cron = "0.12"
chrono = "0.4"
hex = "0.4"
log = "0.4"
serde = "1.0"
serde_derive = "1.0"
serde_json = "1.0"
serde = { version = "1", features = ["derive"] }
serde_json = "1"
sha2 = "0.10"
thiserror = "1.0"
typed-builder = "0.13"
typetag = "0.2"
anyhow = "1"
thiserror = "1"
uuid = { version = "1.1", features = ["v4", "serde"] }
async-trait = "0.1"
async-recursion = "1"
futures = "0.3"
diesel = { version = "2.0", features = ["postgres", "serde_json", "chrono", "uuid"] }
diesel-derive-newtype = "2.0.0-rc.0"

View file

@ -1,12 +0,0 @@
[package]
name = "simple_cron_async_worker"
version = "0.1.0"
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
fang = { path = "../../../" , features = ["asynk"]}
env_logger = "0.9.0"
log = "0.4.0"
tokio = { version = "1", features = ["full"] }

View file

@ -1,33 +0,0 @@
use fang::async_trait;
use fang::asynk::async_queue::AsyncQueueable;
use fang::serde::{Deserialize, Serialize};
use fang::typetag;
use fang::AsyncRunnable;
use fang::FangError;
use fang::Scheduled;
#[derive(Serialize, Deserialize)]
#[serde(crate = "fang::serde")]
pub struct MyCronTask {}
#[async_trait]
#[typetag::serde]
impl AsyncRunnable for MyCronTask {
async fn run(&self, _queue: &mut dyn AsyncQueueable) -> Result<(), FangError> {
log::info!("CRON!!!!!!!!!!!!!!!",);
Ok(())
}
fn cron(&self) -> Option<Scheduled> {
// sec min hour day of month month day of week year
// be careful works only with UTC hour.
// https://www.timeanddate.com/worldclock/timezone/utc
let expression = "0/20 * * * Aug-Sep * 2022/1";
Some(Scheduled::CronPattern(expression.to_string()))
}
fn uniq(&self) -> bool {
true
}
}

View file

@ -1,41 +0,0 @@
use fang::asynk::async_queue::AsyncQueue;
use fang::asynk::async_queue::AsyncQueueable;
use fang::asynk::async_worker_pool::AsyncWorkerPool;
use fang::AsyncRunnable;
use fang::NoTls;
use simple_cron_async_worker::MyCronTask;
use std::time::Duration;
#[tokio::main]
async fn main() {
env_logger::init();
log::info!("Starting...");
let max_pool_size: u32 = 3;
let mut queue = AsyncQueue::builder()
.uri("postgres://postgres:postgres@localhost/fang")
.max_pool_size(max_pool_size)
.build();
queue.connect(NoTls).await.unwrap();
log::info!("Queue connected...");
let mut pool: AsyncWorkerPool<AsyncQueue<NoTls>> = AsyncWorkerPool::builder()
.number_of_workers(10_u32)
.queue(queue.clone())
.build();
log::info!("Pool created ...");
pool.start().await;
log::info!("Workers started ...");
let task = MyCronTask {};
queue
.schedule_task(&task as &dyn AsyncRunnable)
.await
.unwrap();
tokio::time::sleep(Duration::from_secs(100)).await;
}

View file

@ -5,6 +5,7 @@ edition = "2021"
[dependencies]
backie = { path = "../../" }
anyhow = "1"
env_logger = "0.9.0"
log = "0.4.0"
tokio = { version = "1", features = ["full"] }
@ -12,4 +13,3 @@ diesel-async = { version = "0.2", features = ["postgres", "bb8"] }
diesel = { version = "2.0", features = ["postgres"] }
async-trait = "0.1"
serde = { version = "1.0", features = ["derive"] }
typetag = "0.2"

View file

@ -1,7 +1,20 @@
use std::time::Duration;
use async_trait::async_trait;
use serde::{Serialize, Deserialize};
use backie::{RunnableTask, Queueable};
use backie::{BackgroundTask, CurrentTask};
use serde::{Deserialize, Serialize};
use std::time::Duration;
#[derive(Clone, Debug)]
pub struct MyApplicationContext {
app_name: String,
}
impl MyApplicationContext {
pub fn new(app_name: &str) -> Self {
Self {
app_name: app_name.to_string(),
}
}
}
#[derive(Serialize, Deserialize)]
pub struct MyTask {
@ -26,37 +39,51 @@ impl MyFailingTask {
}
#[async_trait]
#[typetag::serde]
impl RunnableTask for MyTask {
async fn run(&self, _queue: &mut dyn Queueable) -> Result<(), Box<dyn std::error::Error + Send + 'static>> {
impl BackgroundTask for MyTask {
const TASK_NAME: &'static str = "my_task";
type AppData = MyApplicationContext;
async fn run(&self, task: CurrentTask, ctx: Self::AppData) -> Result<(), anyhow::Error> {
// let new_task = MyTask::new(self.number + 1);
// queue
// .insert_task(&new_task as &dyn AsyncRunnable)
// .insert_task(&new_task)
// .await
// .unwrap();
log::info!("the current number is {}", self.number);
log::info!(
"[{}] Hello from {}! the current number is {}",
task.id(),
ctx.app_name,
self.number
);
tokio::time::sleep(Duration::from_secs(3)).await;
log::info!("done..");
log::info!("[{}] done..", task.id());
Ok(())
}
}
#[async_trait]
#[typetag::serde]
impl RunnableTask for MyFailingTask {
async fn run(&self, _queue: &mut dyn Queueable) -> Result<(), Box<dyn std::error::Error + Send + 'static>> {
impl BackgroundTask for MyFailingTask {
const TASK_NAME: &'static str = "my_failing_task";
type AppData = MyApplicationContext;
async fn run(&self, task: CurrentTask, _ctx: Self::AppData) -> Result<(), anyhow::Error> {
// let new_task = MyFailingTask::new(self.number + 1);
// queue
// .insert_task(&new_task as &dyn AsyncRunnable)
// .insert_task(&new_task)
// .await
// .unwrap();
log::info!("the current number is {}", self.number);
// task.id();
// task.keep_alive().await?;
// task.previous_error();
// task.retry_count();
log::info!("[{}] the current number is {}", task.id(), self.number);
tokio::time::sleep(Duration::from_secs(3)).await;
log::info!("done..");
log::info!("[{}] done..", task.id());
//
// let b = true;
//

View file

@ -1,9 +1,9 @@
use simple_worker::MyFailingTask;
use simple_worker::MyTask;
use std::time::Duration;
use backie::{PgTaskStore, RetentionMode, WorkerPool};
use diesel_async::pg::AsyncPgConnection;
use diesel_async::pooled_connection::{bb8::Pool, AsyncDieselConnectionManager};
use backie::{PgAsyncQueue, WorkerPool, Queueable};
use simple_worker::MyApplicationContext;
use simple_worker::MyFailingTask;
use simple_worker::MyTask;
#[tokio::main]
async fn main() {
@ -22,49 +22,38 @@ async fn main() {
.unwrap();
log::info!("Pool created ...");
let mut queue = PgAsyncQueue::new(pool);
let task_store = PgTaskStore::new(pool);
let (tx, mut rx) = tokio::sync::watch::channel(false);
let executor_task = tokio::spawn({
let mut queue = queue.clone();
async move {
let mut workers_pool: WorkerPool<PgAsyncQueue> = WorkerPool::builder()
.number_of_workers(10_u32)
.queue(queue)
.build();
// Some global application context I want to pass to my background tasks
let my_app_context = MyApplicationContext::new("Backie Example App");
log::info!("Workers starting ...");
workers_pool.start(async move {
rx.changed().await;
}).await;
log::info!("Workers stopped!");
}
});
// Register the task types I want to use and start the worker pool
let (join_handle, queue) = WorkerPool::new(task_store, move |_| my_app_context.clone())
.register_task_type::<MyTask>(1, RetentionMode::RemoveDone)
.register_task_type::<MyFailingTask>(1, RetentionMode::RemoveDone)
.start(async move {
let _ = rx.changed().await;
})
.await
.unwrap();
log::info!("Workers started ...");
let task1 = MyTask::new(0);
let task2 = MyTask::new(20_000);
let task3 = MyFailingTask::new(50_000);
queue
.create_task(&task1)
.await
.unwrap();
queue
.create_task(&task2)
.await
.unwrap();
queue
.create_task(&task3)
.await
.unwrap();
queue.enqueue(task1).await.unwrap();
queue.enqueue(task2).await.unwrap();
queue.enqueue(task3).await.unwrap();
log::info!("Tasks created ...");
tokio::signal::ctrl_c().await;
// Wait for Ctrl+C
let _ = tokio::signal::ctrl_c().await;
log::info!("Stopping ...");
tx.send(true).unwrap();
executor_task.await.unwrap();
log::info!("Stopped!");
join_handle.await.unwrap();
log::info!("Workers Stopped!");
}

View file

@ -2,20 +2,20 @@ CREATE EXTENSION IF NOT EXISTS "uuid-ossp";
CREATE TABLE backie_tasks (
id uuid PRIMARY KEY DEFAULT uuid_generate_v4(),
payload jsonb NOT NULL,
error_message TEXT DEFAULT NULL,
task_type VARCHAR DEFAULT 'common' NOT NULL,
task_name VARCHAR NOT NULL,
queue_name VARCHAR DEFAULT 'common' NOT NULL,
uniq_hash CHAR(64) DEFAULT NULL,
retries INTEGER DEFAULT 0 NOT NULL,
payload jsonb NOT NULL,
timeout_msecs INT8 NOT NULL,
created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(),
scheduled_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(),
running_at TIMESTAMP WITH TIME ZONE DEFAULT NULL,
done_at TIMESTAMP WITH TIME ZONE DEFAULT NULL
done_at TIMESTAMP WITH TIME ZONE DEFAULT NULL,
error_info jsonb DEFAULT NULL,
retries INTEGER DEFAULT 0 NOT NULL,
max_retries INTEGER DEFAULT 0 NOT NULL
);
CREATE INDEX backie_tasks_type_index ON backie_tasks(task_type);
CREATE INDEX backie_tasks_created_at_index ON backie_tasks(created_at);
CREATE INDEX backie_tasks_uniq_hash ON backie_tasks(uniq_hash);
--- create uniqueness index
CREATE UNIQUE INDEX backie_tasks_uniq_hash_index ON backie_tasks(uniq_hash) WHERE uniq_hash IS NOT NULL;

View file

@ -1,25 +1,16 @@
use serde_json::Error as SerdeError;
use std::fmt::Display;
use thiserror::Error;
/// Library errors
#[derive(Debug, Error)]
pub enum BackieError {
#[error("Queue processing error: {0}")]
QueueProcessingError(#[from] AsyncQueueError),
SerializationError(#[from] SerdeError),
ShutdownError(#[from] tokio::sync::watch::error::SendError<()>),
}
impl Display for BackieError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
BackieError::QueueProcessingError(error) => {
write!(f, "Queue processing error: {}", error)
}
BackieError::SerializationError(error) => write!(f, "Serialization error: {}", error),
BackieError::ShutdownError(error) => write!(f, "Shutdown error: {}", error),
}
}
#[error("Worker Pool shutdown error: {0}")]
WorkerPoolShutdownError(#[from] tokio::sync::watch::error::SendError<()>),
#[error("Worker shutdown error: {0}")]
WorkerShutdownError(#[from] tokio::sync::watch::error::RecvError),
}
/// List of error types that can occur while working with cron schedules.
@ -29,7 +20,7 @@ pub enum CronError {
#[error(transparent)]
LibraryError(#[from] cron::error::Error),
/// [`Scheduled`] enum variant is not provided
#[error("You have to implement method `cron()` in your AsyncRunnable")]
#[error("You have to implement method `cron()` in your Runnable")]
TaskNotSchedulableError,
/// The next execution can not be determined using the current [`Scheduled::CronPattern`]
#[error("No timestamps match with this cron pattern")]
@ -41,7 +32,7 @@ pub enum AsyncQueueError {
#[error(transparent)]
PgError(#[from] diesel::result::Error),
#[error(transparent)]
#[error("Task serialization error: {0}")]
SerdeError(#[from] serde_json::Error),
#[error(transparent)]
@ -49,6 +40,9 @@ pub enum AsyncQueueError {
#[error("Task is not in progress, operation not allowed")]
TaskNotRunning,
#[error("Task with name {0} is not registered")]
TaskNotRegistered(String),
}
impl From<cron::error::Error> for AsyncQueueError {

View file

@ -38,9 +38,9 @@ impl Default for RetentionMode {
}
}
pub use queue::PgAsyncQueue;
pub use queue::Queueable;
pub use runnable::RunnableTask;
pub use queue::PgTaskStore;
pub use runnable::BackgroundTask;
pub use task::CurrentTask;
pub use worker_pool::WorkerPool;
pub mod errors;
@ -48,6 +48,7 @@ mod queries;
pub mod queue;
pub mod runnable;
mod schema;
pub mod store;
pub mod task;
pub mod worker;
pub mod worker_pool;

View file

@ -1,8 +1,7 @@
use crate::errors::AsyncQueueError;
use crate::runnable::RunnableTask;
use crate::schema::backie_tasks;
use crate::task::Task;
use crate::task::{NewTask, TaskHash, TaskId, TaskType};
use crate::task::{NewTask, TaskHash, TaskId};
use chrono::DateTime;
use chrono::Duration;
use chrono::Utc;
@ -43,14 +42,6 @@ impl Task {
Ok(qty > 0)
}
pub(crate) async fn remove_by_type(
connection: &mut AsyncPgConnection,
task_type: TaskType,
) -> Result<u64, AsyncQueueError> {
let query = backie_tasks::table.filter(backie_tasks::task_type.eq(task_type));
Ok(diesel::delete(query).execute(connection).await? as u64)
}
pub(crate) async fn find_by_id(
connection: &mut AsyncPgConnection,
id: TaskId,
@ -67,10 +58,13 @@ impl Task {
id: TaskId,
error_message: &str,
) -> Result<Task, AsyncQueueError> {
let error = serde_json::json!({
"error": error_message,
});
let query = backie_tasks::table.filter(backie_tasks::id.eq(id));
Ok(diesel::update(query)
.set((
backie_tasks::error_message.eq(error_message),
backie_tasks::error_info.eq(Some(error)),
backie_tasks::done_at.eq(Utc::now()),
))
.get_result::<Task>(connection)
@ -81,18 +75,22 @@ impl Task {
connection: &mut AsyncPgConnection,
id: TaskId,
backoff_seconds: u32,
error: &str,
error_message: &str,
) -> Result<Task, AsyncQueueError> {
use crate::schema::backie_tasks::dsl;
let now = Utc::now();
let scheduled_at = now + Duration::seconds(backoff_seconds as i64);
let error = serde_json::json!({
"error": error_message,
});
let task = diesel::update(backie_tasks::table.filter(backie_tasks::id.eq(id)))
.set((
backie_tasks::error_message.eq(error),
backie_tasks::error_info.eq(Some(error)),
backie_tasks::retries.eq(dsl::retries + 1),
backie_tasks::created_at.eq(scheduled_at),
backie_tasks::scheduled_at.eq(scheduled_at),
backie_tasks::running_at.eq::<Option<DateTime<Utc>>>(None),
))
.get_result::<Task>(connection)
@ -103,14 +101,14 @@ impl Task {
pub(crate) async fn fetch_next_pending(
connection: &mut AsyncPgConnection,
task_type: TaskType,
queue_name: &str,
) -> Option<Task> {
backie_tasks::table
.filter(backie_tasks::created_at.lt(Utc::now())) // skip tasks scheduled for the future
.filter(backie_tasks::scheduled_at.lt(Utc::now())) // skip tasks scheduled for the future
.order(backie_tasks::created_at.asc()) // get the oldest task first
.filter(backie_tasks::running_at.is_null()) // that is not marked as running already
.filter(backie_tasks::done_at.is_null()) // and not marked as done
.filter(backie_tasks::task_type.eq(task_type))
.filter(backie_tasks::queue_name.eq(queue_name))
.limit(1)
.for_update()
.skip_locked()
@ -143,39 +141,13 @@ impl Task {
pub(crate) async fn insert(
connection: &mut AsyncPgConnection,
runnable: &dyn RunnableTask,
new_task: NewTask,
) -> Result<Task, AsyncQueueError> {
let payload = serde_json::to_value(runnable)?;
match runnable.uniq() {
None => {
let new_task = NewTask::builder()
.uniq_hash(None)
.task_type(runnable.task_type())
.payload(payload)
.build();
Ok(diesel::insert_into(backie_tasks::table)
.values(new_task)
.get_result::<Task>(connection)
.await?)
}
Some(hash) => match Self::find_by_uniq_hash(connection, hash.clone()).await {
Some(task) => Ok(task),
None => {
let new_task = NewTask::builder()
.uniq_hash(Some(hash))
.task_type(runnable.task_type())
.payload(payload)
.build();
Ok(diesel::insert_into(backie_tasks::table)
.values(new_task)
.get_result::<Task>(connection)
.await?)
}
},
}
}
pub(crate) async fn find_by_uniq_hash(
connection: &mut AsyncPgConnection,

View file

@ -1,86 +1,48 @@
use crate::errors::AsyncQueueError;
use crate::runnable::RunnableTask;
use crate::task::{Task, TaskHash, TaskId, TaskType};
use async_trait::async_trait;
use crate::runnable::BackgroundTask;
use crate::task::{NewTask, Task, TaskHash, TaskId, TaskState};
use diesel::result::Error::QueryBuilderError;
use diesel_async::scoped_futures::ScopedFutureExt;
use diesel_async::AsyncConnection;
use diesel_async::{pg::AsyncPgConnection, pooled_connection::bb8::Pool};
use std::sync::Arc;
use std::time::Duration;
/// This trait defines operations for an asynchronous queue.
/// The trait can be implemented for different storage backends.
/// For now, the trait is only implemented for PostgreSQL. More backends are planned to be implemented in the future.
#[async_trait]
pub trait Queueable: Send {
/// Pull pending tasks from the queue to execute them.
///
/// This method returns one task of the `task_type` type. If `task_type` is `None` it will try to
/// fetch a task of the type `common`. The returned task is marked as running and must be executed.
async fn pull_next_task(
&mut self,
kind: Option<TaskType>,
) -> Result<Option<Task>, AsyncQueueError>;
#[derive(Clone)]
pub struct Queue {
task_store: Arc<PgTaskStore>,
}
/// Enqueue a task to the queue, The task will be executed as soon as possible by the worker of the same type
/// created by an AsyncWorkerPool.
async fn create_task(&mut self, task: &dyn RunnableTask) -> Result<Task, AsyncQueueError>;
impl Queue {
pub(crate) fn new(task_store: Arc<PgTaskStore>) -> Self {
Queue { task_store }
}
/// Retrieve a task by its `id`.
async fn find_task_by_id(&mut self, id: TaskId) -> Result<Task, AsyncQueueError>;
/// Update the state of a task to failed and set an error_message.
async fn set_task_failed(
&mut self,
id: TaskId,
error_message: &str,
) -> Result<Task, AsyncQueueError>;
/// Update the state of a task to done.
async fn set_task_done(&mut self, id: TaskId) -> Result<Task, AsyncQueueError>;
/// Update the state of a task to inform that it's still in progress.
async fn keep_task_alive(&mut self, id: TaskId) -> Result<(), AsyncQueueError>;
/// Remove a task by its id.
async fn remove_task(&mut self, id: TaskId) -> Result<u64, AsyncQueueError>;
/// The method will remove all tasks from the queue
async fn remove_all_tasks(&mut self) -> Result<u64, AsyncQueueError>;
/// Remove all tasks that are scheduled in the future.
async fn remove_all_scheduled_tasks(&mut self) -> Result<u64, AsyncQueueError>;
/// Remove a task by its metadata (struct fields values)
async fn remove_task_by_hash(&mut self, task_hash: TaskHash) -> Result<bool, AsyncQueueError>;
/// Removes all tasks that have the specified `task_type`.
async fn remove_tasks_type(&mut self, task_type: TaskType) -> Result<u64, AsyncQueueError>;
async fn schedule_task_retry(
&mut self,
id: TaskId,
backoff_seconds: u32,
error: &str,
) -> Result<Task, AsyncQueueError>;
pub async fn enqueue<BT>(&self, background_task: BT) -> Result<(), AsyncQueueError>
where
BT: BackgroundTask,
{
self.task_store
.create_task(NewTask::new(background_task, Duration::from_secs(10))?)
.await?;
Ok(())
}
}
/// An async queue that is used to manipulate tasks, it uses PostgreSQL as storage.
#[derive(Debug, Clone)]
pub struct PgAsyncQueue {
pub struct PgTaskStore {
pool: Pool<AsyncPgConnection>,
}
impl PgAsyncQueue {
impl PgTaskStore {
pub fn new(pool: Pool<AsyncPgConnection>) -> Self {
PgAsyncQueue { pool }
PgTaskStore { pool }
}
}
#[async_trait]
impl Queueable for PgAsyncQueue {
async fn pull_next_task(
&mut self,
task_type: Option<TaskType>,
pub(crate) async fn pull_next_task(
&self,
queue_name: &str,
) -> Result<Option<Task>, AsyncQueueError> {
let mut connection = self
.pool
@ -90,7 +52,7 @@ impl Queueable for PgAsyncQueue {
connection
.transaction::<Option<Task>, AsyncQueueError, _>(|conn| {
async move {
let Some(pending_task) = Task::fetch_next_pending(conn, task_type.unwrap_or_default()).await else {
let Some(pending_task) = Task::fetch_next_pending(conn, queue_name).await else {
return Ok(None);
};
@ -101,16 +63,16 @@ impl Queueable for PgAsyncQueue {
.await
}
async fn create_task(&mut self, runnable: &dyn RunnableTask) -> Result<Task, AsyncQueueError> {
pub(crate) async fn create_task(&self, new_task: NewTask) -> Result<Task, AsyncQueueError> {
let mut connection = self
.pool
.get()
.await
.map_err(|e| QueryBuilderError(e.into()))?;
Ok(Task::insert(&mut connection, runnable).await?)
Task::insert(&mut connection, new_task).await
}
async fn find_task_by_id(&mut self, id: TaskId) -> Result<Task, AsyncQueueError> {
pub(crate) async fn find_task_by_id(&self, id: TaskId) -> Result<Task, AsyncQueueError> {
let mut connection = self
.pool
.get()
@ -119,29 +81,27 @@ impl Queueable for PgAsyncQueue {
Task::find_by_id(&mut connection, id).await
}
async fn set_task_failed(
&mut self,
pub(crate) async fn set_task_state(
&self,
id: TaskId,
error_message: &str,
) -> Result<Task, AsyncQueueError> {
state: TaskState,
) -> Result<(), AsyncQueueError> {
let mut connection = self
.pool
.get()
.await
.map_err(|e| QueryBuilderError(e.into()))?;
Task::fail_with_message(&mut connection, id, error_message).await
match state {
TaskState::Done => Task::set_done(&mut connection, id).await?,
TaskState::Failed(error_msg) => {
Task::fail_with_message(&mut connection, id, &error_msg).await?
}
_ => return Ok(()),
};
Ok(())
}
async fn set_task_done(&mut self, id: TaskId) -> Result<Task, AsyncQueueError> {
let mut connection = self
.pool
.get()
.await
.map_err(|e| QueryBuilderError(e.into()))?;
Task::set_done(&mut connection, id).await
}
async fn keep_task_alive(&mut self, id: TaskId) -> Result<(), AsyncQueueError> {
pub(crate) async fn keep_task_alive(&self, id: TaskId) -> Result<(), AsyncQueueError> {
let mut connection = self
.pool
.get()
@ -159,7 +119,7 @@ impl Queueable for PgAsyncQueue {
.await
}
async fn remove_task(&mut self, id: TaskId) -> Result<u64, AsyncQueueError> {
pub(crate) async fn remove_task(&self, id: TaskId) -> Result<u64, AsyncQueueError> {
let mut connection = self
.pool
.get()
@ -169,7 +129,7 @@ impl Queueable for PgAsyncQueue {
Ok(result)
}
async fn remove_all_tasks(&mut self) -> Result<u64, AsyncQueueError> {
pub(crate) async fn remove_all_tasks(&self) -> Result<u64, AsyncQueueError> {
let mut connection = self
.pool
.get()
@ -178,37 +138,8 @@ impl Queueable for PgAsyncQueue {
Task::remove_all(&mut connection).await
}
async fn remove_all_scheduled_tasks(&mut self) -> Result<u64, AsyncQueueError> {
let mut connection = self
.pool
.get()
.await
.map_err(|e| QueryBuilderError(e.into()))?;
let result = Task::remove_all_scheduled(&mut connection).await?;
Ok(result)
}
async fn remove_task_by_hash(&mut self, task_hash: TaskHash) -> Result<bool, AsyncQueueError> {
let mut connection = self
.pool
.get()
.await
.map_err(|e| QueryBuilderError(e.into()))?;
Task::remove_by_hash(&mut connection, task_hash).await
}
async fn remove_tasks_type(&mut self, task_type: TaskType) -> Result<u64, AsyncQueueError> {
let mut connection = self
.pool
.get()
.await
.map_err(|e| QueryBuilderError(e.into()))?;
let result = Task::remove_by_type(&mut connection, task_type).await?;
Ok(result)
}
async fn schedule_task_retry(
&mut self,
pub(crate) async fn schedule_task_retry(
&self,
id: TaskId,
backoff_seconds: u32,
error: &str,
@ -226,11 +157,8 @@ impl Queueable for PgAsyncQueue {
#[cfg(test)]
mod async_queue_tests {
use super::*;
use crate::task::TaskState;
use crate::Scheduled;
use crate::CurrentTask;
use async_trait::async_trait;
use chrono::DateTime;
use chrono::Utc;
use diesel_async::pooled_connection::{bb8::Pool, AsyncDieselConnectionManager};
use diesel_async::AsyncPgConnection;
use serde::{Deserialize, Serialize};
@ -240,13 +168,12 @@ mod async_queue_tests {
pub number: u16,
}
#[typetag::serde]
#[async_trait]
impl RunnableTask for AsyncTask {
async fn run(
&self,
_queueable: &mut dyn Queueable,
) -> Result<(), Box<(dyn std::error::Error + Send + 'static)>> {
impl BackgroundTask for AsyncTask {
const TASK_NAME: &'static str = "AsyncUniqTask";
type AppData = ();
async fn run(&self, _task: CurrentTask, _: Self::AppData) -> Result<(), anyhow::Error> {
Ok(())
}
}
@ -256,13 +183,12 @@ mod async_queue_tests {
pub number: u16,
}
#[typetag::serde]
#[async_trait]
impl RunnableTask for AsyncUniqTask {
async fn run(
&self,
_queueable: &mut dyn Queueable,
) -> Result<(), Box<(dyn std::error::Error + Send + 'static)>> {
impl BackgroundTask for AsyncUniqTask {
const TASK_NAME: &'static str = "AsyncUniqTask";
type AppData = ();
async fn run(&self, _task: CurrentTask, _: Self::AppData) -> Result<(), anyhow::Error> {
Ok(())
}
@ -277,286 +203,154 @@ mod async_queue_tests {
pub datetime: String,
}
#[typetag::serde]
#[async_trait]
impl RunnableTask for AsyncTaskSchedule {
async fn run(
&self,
_queueable: &mut dyn Queueable,
) -> Result<(), Box<(dyn std::error::Error + Send + 'static)>> {
impl BackgroundTask for AsyncTaskSchedule {
const TASK_NAME: &'static str = "AsyncUniqTask";
type AppData = ();
async fn run(&self, _task: CurrentTask, _: Self::AppData) -> Result<(), anyhow::Error> {
Ok(())
}
fn cron(&self) -> Option<Scheduled> {
let datetime = self.datetime.parse::<DateTime<Utc>>().ok()?;
Some(Scheduled::ScheduleOnce(datetime))
}
}
#[tokio::test]
async fn insert_task_creates_new_task() {
let pool = pool().await;
let mut test = PgAsyncQueue::new(pool);
let task = test.create_task(&AsyncTask { number: 1 }).await.unwrap();
let metadata = task.payload.as_object().unwrap();
let number = metadata["number"].as_u64();
let type_task = metadata["type"].as_str();
assert_eq!(Some(1), number);
assert_eq!(Some("AsyncTask"), type_task);
test.remove_all_tasks().await.unwrap();
}
#[tokio::test]
async fn update_task_state_test() {
let pool = pool().await;
let mut test = PgAsyncQueue::new(pool);
let task = test.create_task(&AsyncTask { number: 1 }).await.unwrap();
let metadata = task.payload.as_object().unwrap();
let number = metadata["number"].as_u64();
let type_task = metadata["type"].as_str();
let id = task.id;
assert_eq!(Some(1), number);
assert_eq!(Some("AsyncTask"), type_task);
let finished_task = test.set_task_done(task.id).await.unwrap();
assert_eq!(id, finished_task.id);
assert_eq!(TaskState::Done, finished_task.state());
test.remove_all_tasks().await.unwrap();
}
#[tokio::test]
async fn failed_task_query_test() {
let pool = pool().await;
let mut test = PgAsyncQueue::new(pool);
let task = test.create_task(&AsyncTask { number: 1 }).await.unwrap();
let metadata = task.payload.as_object().unwrap();
let number = metadata["number"].as_u64();
let type_task = metadata["type"].as_str();
let id = task.id;
assert_eq!(Some(1), number);
assert_eq!(Some("AsyncTask"), type_task);
let failed_task = test.set_task_failed(task.id, "Some error").await.unwrap();
assert_eq!(id, failed_task.id);
assert_eq!(Some("Some error"), failed_task.error_message.as_deref());
assert_eq!(TaskState::Failed, failed_task.state());
test.remove_all_tasks().await.unwrap();
}
#[tokio::test]
async fn remove_all_tasks_test() {
let pool = pool().await;
let mut test = PgAsyncQueue::new(pool);
let task = test.create_task(&AsyncTask { number: 1 }).await.unwrap();
let metadata = task.payload.as_object().unwrap();
let number = metadata["number"].as_u64();
let type_task = metadata["type"].as_str();
assert_eq!(Some(1), number);
assert_eq!(Some("AsyncTask"), type_task);
let task = test.create_task(&AsyncTask { number: 2 }).await.unwrap();
let metadata = task.payload.as_object().unwrap();
let number = metadata["number"].as_u64();
let type_task = metadata["type"].as_str();
assert_eq!(Some(2), number);
assert_eq!(Some("AsyncTask"), type_task);
let result = test.remove_all_tasks().await.unwrap();
assert_eq!(2, result);
// fn cron(&self) -> Option<Scheduled> {
// let datetime = self.datetime.parse::<DateTime<Utc>>().ok()?;
// Some(Scheduled::ScheduleOnce(datetime))
// }
}
// #[tokio::test]
// async fn schedule_task_test() {
// async fn insert_task_creates_new_task() {
// let pool = pool().await;
// let mut test = PgAsyncQueue::new(pool);
// let mut queue = PgTaskStore::new(pool);
//
// let datetime = (Utc::now() + Duration::seconds(7)).round_subsecs(0);
//
// let task = &AsyncTaskSchedule {
// number: 1,
// datetime: datetime.to_string(),
// };
//
// let task = test.schedule_task(task).await.unwrap();
// let task = queue.create_task(AsyncTask { number: 1 }).await.unwrap();
//
// let metadata = task.payload.as_object().unwrap();
// let number = metadata["number"].as_u64();
// let type_task = metadata["type"].as_str();
//
// assert_eq!(Some(1), number);
// assert_eq!(Some("AsyncTaskSchedule"), type_task);
// assert_eq!(task.scheduled_at, datetime);
// assert_eq!(Some("AsyncTask"), type_task);
//
// queue.remove_all_tasks().await.unwrap();
// }
//
// #[tokio::test]
// async fn update_task_state_test() {
// let pool = pool().await;
// let mut test = PgTaskStore::new(pool);
//
// let task = test.create_task(&AsyncTask { number: 1 }).await.unwrap();
//
// let metadata = task.payload.as_object().unwrap();
// let number = metadata["number"].as_u64();
// let type_task = metadata["type"].as_str();
// let id = task.id;
//
// assert_eq!(Some(1), number);
// assert_eq!(Some("AsyncTask"), type_task);
//
// let finished_task = test.set_task_state(task.id, TaskState::Done).await.unwrap();
//
// assert_eq!(id, finished_task.id);
// assert_eq!(TaskState::Done, finished_task.state());
//
// test.remove_all_tasks().await.unwrap();
// }
//
// #[tokio::test]
// async fn remove_all_scheduled_tasks_test() {
// async fn failed_task_query_test() {
// let pool = pool().await;
// let mut test = PgAsyncQueue::new(pool);
// let mut test = PgTaskStore::new(pool);
//
// let datetime = (Utc::now() + Duration::seconds(7)).round_subsecs(0);
// let task = test.create_task(&AsyncTask { number: 1 }).await.unwrap();
//
// let task1 = &AsyncTaskSchedule {
// number: 1,
// datetime: datetime.to_string(),
// };
// let metadata = task.payload.as_object().unwrap();
// let number = metadata["number"].as_u64();
// let type_task = metadata["type"].as_str();
// let id = task.id;
//
// let task2 = &AsyncTaskSchedule {
// number: 2,
// datetime: datetime.to_string(),
// };
// assert_eq!(Some(1), number);
// assert_eq!(Some("AsyncTask"), type_task);
//
// test.schedule_task(task1).await.unwrap();
// test.schedule_task(task2).await.unwrap();
// let failed_task = test.set_task_state(task.id, TaskState::Failed("Some error".to_string())).await.unwrap();
//
// let number = test.remove_all_scheduled_tasks().await.unwrap();
//
// assert_eq!(2, number);
// assert_eq!(id, failed_task.id);
// assert_eq!(Some("Some error"), failed_task.error_message.as_deref());
// assert_eq!(TaskState::Failed, failed_task.state());
//
// test.remove_all_tasks().await.unwrap();
// }
#[tokio::test]
async fn pull_next_task_test() {
let pool = pool().await;
let mut test = PgAsyncQueue::new(pool);
let task = test.create_task(&AsyncTask { number: 1 }).await.unwrap();
let metadata = task.payload.as_object().unwrap();
let number = metadata["number"].as_u64();
let type_task = metadata["type"].as_str();
assert_eq!(Some(1), number);
assert_eq!(Some("AsyncTask"), type_task);
let task = test.create_task(&AsyncTask { number: 2 }).await.unwrap();
let metadata = task.payload.as_object().unwrap();
let number = metadata["number"].as_u64();
let type_task = metadata["type"].as_str();
assert_eq!(Some(2), number);
assert_eq!(Some("AsyncTask"), type_task);
let task = test.pull_next_task(None).await.unwrap().unwrap();
let metadata = task.payload.as_object().unwrap();
let number = metadata["number"].as_u64();
let type_task = metadata["type"].as_str();
assert_eq!(Some(1), number);
assert_eq!(Some("AsyncTask"), type_task);
let task = test.pull_next_task(None).await.unwrap().unwrap();
let metadata = task.payload.as_object().unwrap();
let number = metadata["number"].as_u64();
let type_task = metadata["type"].as_str();
assert_eq!(Some(2), number);
assert_eq!(Some("AsyncTask"), type_task);
test.remove_all_tasks().await.unwrap();
}
#[tokio::test]
async fn remove_tasks_type_test() {
let pool = pool().await;
let mut test = PgAsyncQueue::new(pool);
let task = test.create_task(&AsyncTask { number: 1 }).await.unwrap();
let metadata = task.payload.as_object().unwrap();
let number = metadata["number"].as_u64();
let type_task = metadata["type"].as_str();
assert_eq!(Some(1), number);
assert_eq!(Some("AsyncTask"), type_task);
let task = test.create_task(&AsyncTask { number: 2 }).await.unwrap();
let metadata = task.payload.as_object().unwrap();
let number = metadata["number"].as_u64();
let type_task = metadata["type"].as_str();
assert_eq!(Some(2), number);
assert_eq!(Some("AsyncTask"), type_task);
let result = test
.remove_tasks_type(TaskType::from("nonexistentType"))
.await
.unwrap();
assert_eq!(0, result);
let result = test.remove_tasks_type(TaskType::default()).await.unwrap();
assert_eq!(2, result);
test.remove_all_tasks().await.unwrap();
}
#[tokio::test]
async fn remove_tasks_by_metadata() {
let pool = pool().await;
let mut test = PgAsyncQueue::new(pool);
let task = test
.create_task(&AsyncUniqTask { number: 1 })
.await
.unwrap();
let metadata = task.payload.as_object().unwrap();
let number = metadata["number"].as_u64();
let type_task = metadata["type"].as_str();
assert_eq!(Some(1), number);
assert_eq!(Some("AsyncUniqTask"), type_task);
let task = test
.create_task(&AsyncUniqTask { number: 2 })
.await
.unwrap();
let metadata = task.payload.as_object().unwrap();
let number = metadata["number"].as_u64();
let type_task = metadata["type"].as_str();
assert_eq!(Some(2), number);
assert_eq!(Some("AsyncUniqTask"), type_task);
let result = test
.remove_task_by_hash(AsyncUniqTask { number: 0 }.uniq().unwrap())
.await
.unwrap();
assert!(!result, "Should **not** remove task");
let result = test
.remove_task_by_hash(AsyncUniqTask { number: 1 }.uniq().unwrap())
.await
.unwrap();
assert!(result, "Should remove task");
test.remove_all_tasks().await.unwrap();
}
//
// #[tokio::test]
// async fn remove_all_tasks_test() {
// let pool = pool().await;
// let mut test = PgTaskStore::new(pool);
//
// let task = test.create_task(&AsyncTask { number: 1 }).await.unwrap();
//
// let metadata = task.payload.as_object().unwrap();
// let number = metadata["number"].as_u64();
// let type_task = metadata["type"].as_str();
//
// assert_eq!(Some(1), number);
// assert_eq!(Some("AsyncTask"), type_task);
//
// let task = test.create_task(&AsyncTask { number: 2 }).await.unwrap();
//
// let metadata = task.payload.as_object().unwrap();
// let number = metadata["number"].as_u64();
// let type_task = metadata["type"].as_str();
//
// assert_eq!(Some(2), number);
// assert_eq!(Some("AsyncTask"), type_task);
//
// let result = test.remove_all_tasks().await.unwrap();
// assert_eq!(2, result);
// }
//
// #[tokio::test]
// async fn pull_next_task_test() {
// let pool = pool().await;
// let mut queue = PgTaskStore::new(pool);
//
// let task = queue.create_task(&AsyncTask { number: 1 }).await.unwrap();
//
// let metadata = task.payload.as_object().unwrap();
// let number = metadata["number"].as_u64();
// let type_task = metadata["type"].as_str();
//
// assert_eq!(Some(1), number);
// assert_eq!(Some("AsyncTask"), type_task);
//
// let task = queue.create_task(&AsyncTask { number: 2 }).await.unwrap();
//
// let metadata = task.payload.as_object().unwrap();
// let number = metadata["number"].as_u64();
// let type_task = metadata["type"].as_str();
//
// assert_eq!(Some(2), number);
// assert_eq!(Some("AsyncTask"), type_task);
//
// let task = queue.pull_next_task(None).await.unwrap().unwrap();
//
// let metadata = task.payload.as_object().unwrap();
// let number = metadata["number"].as_u64();
// let type_task = metadata["type"].as_str();
//
// assert_eq!(Some(1), number);
// assert_eq!(Some("AsyncTask"), type_task);
//
// let task = queue.pull_next_task(None).await.unwrap().unwrap();
// let metadata = task.payload.as_object().unwrap();
// let number = metadata["number"].as_u64();
// let type_task = metadata["type"].as_str();
//
// assert_eq!(Some(2), number);
// assert_eq!(Some("AsyncTask"), type_task);
//
// queue.remove_all_tasks().await.unwrap();
// }
async fn pool() -> Pool<AsyncPgConnection> {
let manager = AsyncDieselConnectionManager::<AsyncPgConnection>::new(

View file

@ -1,27 +1,34 @@
use crate::queue::Queueable;
use crate::task::TaskHash;
use crate::task::TaskType;
use crate::Scheduled;
use crate::task::{CurrentTask, TaskHash};
use async_trait::async_trait;
use std::error::Error;
pub const RETRIES_NUMBER: i32 = 5;
use serde::{de::DeserializeOwned, ser::Serialize};
/// Task that can be executed by the queue.
///
/// The `RunnableTask` trait is used to define the behaviour of a task. You must implement this
/// The `BackgroundTask` trait is used to define the behaviour of a task. You must implement this
/// trait for all tasks you want to execute.
#[typetag::serde(tag = "type")]
#[async_trait]
pub trait RunnableTask: Send + Sync {
/// Execute the task. This method should define its logic
async fn run(&self, queue: &mut dyn Queueable) -> Result<(), Box<dyn Error + Send + 'static>>;
pub trait BackgroundTask: Serialize + DeserializeOwned + Sync + Send + 'static {
/// Unique name of the task.
///
/// This MUST be unique for the whole application.
const TASK_NAME: &'static str;
/// Define the type of the task.
/// The `common` task type is used by default
fn task_type(&self) -> TaskType {
TaskType::default()
}
/// Task queue where this task will be executed.
///
/// Used to define which workers are going to be executing this task. It uses the default
/// task queue if not changed.
const QUEUE: &'static str = "default";
/// Number of retries for tasks.
///
/// By default, it is set to 5.
const MAX_RETRIES: i32 = 5;
/// The application data provided to this task at runtime.
type AppData: Clone + Send + 'static;
/// Execute the task. This method should define its logic
async fn run(&self, task: CurrentTask, context: Self::AppData) -> Result<(), anyhow::Error>;
/// If set to true, no new tasks with the same metadata will be inserted
/// By default it is set to false.
@ -29,27 +36,10 @@ pub trait RunnableTask: Send + Sync {
None
}
/// This method defines if a task is periodic or it should be executed once in the future.
///
/// Be careful it works only with the UTC timezone.
///
/// Example:
///
/// ```rust
/// fn cron(&self) -> Option<Scheduled> {
/// let expression = "0/20 * * * Aug-Sep * 2022/1";
/// Some(Scheduled::CronPattern(expression.to_string()))
/// }
///```
/// In order to schedule a task once, use the `Scheduled::ScheduleOnce` enum variant.
fn cron(&self) -> Option<Scheduled> {
None
}
/// Define the maximum number of retries the task will be retried.
/// By default the number of retries is 20.
fn max_retries(&self) -> i32 {
RETRIES_NUMBER
Self::MAX_RETRIES
}
/// Define the backoff mode

19
src/schema.rs Normal file
View file

@ -0,0 +1,19 @@
// @generated automatically by Diesel CLI.
diesel::table! {
backie_tasks (id) {
id -> Uuid,
task_name -> Varchar,
queue_name -> Varchar,
uniq_hash -> Nullable<Bpchar>,
payload -> Jsonb,
timeout_msecs -> Int8,
created_at -> Timestamptz,
scheduled_at -> Timestamptz,
running_at -> Nullable<Timestamptz>,
done_at -> Nullable<Timestamptz>,
error_info -> Nullable<Jsonb>,
retries -> Int4,
max_retries -> Int4,
}
}

1
src/store.rs Normal file
View file

@ -0,0 +1 @@

View file

@ -1,4 +1,5 @@
use crate::schema::backie_tasks;
use crate::BackgroundTask;
use chrono::DateTime;
use chrono::Utc;
use diesel::prelude::*;
@ -8,11 +9,11 @@ use sha2::{Digest, Sha256};
use std::borrow::Cow;
use std::fmt;
use std::fmt::Display;
use typed_builder::TypedBuilder;
use std::time::Duration;
use uuid::Uuid;
/// States of a task.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum TaskState {
/// The task is ready to be executed.
Ready,
@ -21,7 +22,7 @@ pub enum TaskState {
Running,
/// The task has failed to execute.
Failed,
Failed(String),
/// The task finished successfully.
Done,
@ -36,24 +37,6 @@ impl Display for TaskId {
}
}
#[derive(Clone, Debug, Hash, PartialEq, Eq, DieselNewType, Serialize)]
pub struct TaskType(Cow<'static, str>);
impl Default for TaskType {
fn default() -> Self {
Self(Cow::from("default"))
}
}
impl<S> From<S> for TaskType
where
S: AsRef<str> + 'static,
{
fn from(s: S) -> Self {
TaskType(Cow::from(s.as_ref().to_owned()))
}
}
#[derive(Clone, Debug, Hash, PartialEq, Eq, DieselNewType, Serialize)]
pub struct TaskHash(Cow<'static, str>);
@ -70,42 +53,55 @@ impl TaskHash {
}
}
#[derive(Queryable, Identifiable, Debug, Eq, PartialEq, Clone, TypedBuilder)]
#[derive(Queryable, Identifiable, Debug, Eq, PartialEq, Clone)]
#[diesel(table_name = backie_tasks)]
pub struct Task {
#[builder(setter(into))]
/// Unique identifier of the task.
pub id: TaskId,
#[builder(setter(into))]
pub payload: serde_json::Value,
/// Name of the type of task.
pub task_name: String,
#[builder(setter(into))]
pub error_message: Option<String>,
/// Queue name that the task belongs to.
pub queue_name: String,
#[builder(setter(into))]
pub task_type: TaskType,
#[builder(setter(into))]
/// Unique hash is used to identify and avoid duplicate tasks.
pub uniq_hash: Option<TaskHash>,
#[builder(setter(into))]
pub retries: i32,
/// Representation of the task.
pub payload: serde_json::Value,
#[builder(setter(into))]
/// Max timeout that the task can run for.
pub timeout_msecs: i64,
/// Creation time of the task.
pub created_at: DateTime<Utc>,
#[builder(setter(into))]
/// Date time when the task is scheduled to run.
pub scheduled_at: DateTime<Utc>,
/// Date time when the task is started to run.
pub running_at: Option<DateTime<Utc>>,
#[builder(setter(into))]
/// Date time when the task is finished.
pub done_at: Option<DateTime<Utc>>,
/// Failure reason, when the task is failed.
pub error_info: Option<serde_json::Value>,
/// Number of times a task was retried.
pub retries: i32,
/// Maximum number of retries allow for this task before it is maked as failure.
pub max_retries: i32,
}
impl Task {
pub fn state(&self) -> TaskState {
if self.done_at.is_some() {
if self.error_message.is_some() {
TaskState::Failed
if self.error_info.is_some() {
// TODO: use a proper error type
TaskState::Failed(self.error_info.clone().unwrap().to_string())
} else {
TaskState::Done
}
@ -117,22 +113,61 @@ impl Task {
}
}
#[derive(Insertable, Debug, Eq, PartialEq, Clone, TypedBuilder)]
#[derive(Insertable, Debug, Eq, PartialEq, Clone)]
#[diesel(table_name = backie_tasks)]
pub struct NewTask {
#[builder(setter(into))]
payload: serde_json::Value,
#[builder(setter(into))]
task_type: TaskType,
#[builder(setter(into))]
pub(crate) struct NewTask {
task_name: String,
queue_name: String,
uniq_hash: Option<TaskHash>,
payload: serde_json::Value,
timeout_msecs: i64,
max_retries: i32,
}
pub struct TaskInfo {
impl NewTask {
pub(crate) fn new<T>(background_task: T, timeout: Duration) -> Result<Self, serde_json::Error>
where
T: BackgroundTask,
{
let max_retries = background_task.max_retries();
let uniq_hash = background_task.uniq();
let payload = serde_json::to_value(background_task)?;
Ok(Self {
task_name: T::TASK_NAME.to_string(),
queue_name: T::QUEUE.to_string(),
uniq_hash,
payload,
timeout_msecs: timeout.as_millis() as i64,
max_retries,
})
}
}
pub struct CurrentTask {
id: TaskId,
error_message: Option<String>,
retries: i32,
created_at: DateTime<Utc>,
}
impl CurrentTask {
pub(crate) fn new(task: &Task) -> Self {
Self {
id: task.id.clone(),
retries: task.retries,
created_at: task.created_at,
}
}
pub fn id(&self) -> TaskId {
self.id
}
pub fn retry_count(&self) -> i32 {
self.retries
}
pub fn created_at(&self) -> DateTime<Utc> {
self.created_at
}
}

View file

@ -1,53 +1,104 @@
use crate::errors::BackieError;
use crate::queue::Queueable;
use crate::runnable::RunnableTask;
use crate::task::{Task, TaskType};
use crate::RetentionMode;
use crate::Scheduled::*;
use crate::errors::{AsyncQueueError, BackieError};
use crate::runnable::BackgroundTask;
use crate::task::{CurrentTask, Task, TaskState};
use crate::{PgTaskStore, RetentionMode};
use futures::future::FutureExt;
use futures::select;
use std::error::Error;
use typed_builder::TypedBuilder;
use std::collections::BTreeMap;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use thiserror::Error;
/// it executes tasks only of task_type type, it sleeps when there are no tasks in the queue
#[derive(TypedBuilder)]
pub struct Worker<Q>
where
Q: Queueable + Clone + Sync + 'static,
{
#[builder(setter(into))]
pub queue: Q,
pub type ExecuteTaskFn<AppData> = Arc<
dyn Fn(
CurrentTask,
serde_json::Value,
AppData,
) -> Pin<Box<dyn Future<Output = Result<(), TaskExecError>> + Send>>
+ Send
+ Sync,
>;
#[builder(default, setter(into))]
pub task_type: Option<TaskType>,
pub type StateFn<AppData> = Arc<dyn Fn() -> AppData + Send + Sync>;
#[builder(default, setter(into))]
pub retention_mode: RetentionMode,
#[derive(Debug, Error)]
pub enum TaskExecError {
#[error("Task execution failed: {0}")]
ExecutionFailed(#[from] anyhow::Error),
#[builder(default, setter(into))]
pub shutdown: Option<tokio::sync::watch::Receiver<()>>,
#[error("Task deserialization failed: {0}")]
TaskDeserializationFailed(#[from] serde_json::Error),
}
impl<Q> Worker<Q>
pub(crate) fn runnable<BT>(
task_info: CurrentTask,
payload: serde_json::Value,
app_context: BT::AppData,
) -> Pin<Box<dyn Future<Output = Result<(), TaskExecError>> + Send>>
where
Q: Queueable + Clone + Sync + 'static,
BT: BackgroundTask,
{
Box::pin(async move {
let background_task: BT = serde_json::from_value(payload)?;
background_task.run(task_info, app_context).await?;
Ok(())
})
}
/// Worker that executes tasks.
pub struct Worker<AppData>
where
AppData: Clone + Send + 'static,
{
store: Arc<PgTaskStore>,
queue_name: String,
retention_mode: RetentionMode,
task_registry: BTreeMap<String, ExecuteTaskFn<AppData>>,
app_data_fn: StateFn<AppData>,
/// Notification for the worker to stop.
shutdown: Option<tokio::sync::watch::Receiver<()>>,
}
impl<AppData> Worker<AppData>
where
AppData: Clone + Send + 'static,
{
pub(crate) fn new(
store: Arc<PgTaskStore>,
queue_name: String,
retention_mode: RetentionMode,
task_registry: BTreeMap<String, ExecuteTaskFn<AppData>>,
app_data_fn: StateFn<AppData>,
shutdown: Option<tokio::sync::watch::Receiver<()>>,
) -> Self {
Self {
store,
queue_name,
retention_mode,
task_registry,
app_data_fn,
shutdown,
}
}
pub(crate) async fn run_tasks(&mut self) -> Result<(), BackieError> {
loop {
// Need to check if has to stop before pulling next task
match self.queue.pull_next_task(self.task_type.clone()).await? {
Some(task) => {
let actual_task: Box<dyn RunnableTask> =
serde_json::from_value(task.payload.clone())?;
// check if task is scheduled or not
if let Some(CronPattern(_)) = actual_task.cron() {
// program task
//self.queue.schedule_task(&*actual_task).await?;
// Check if has to stop before pulling next task
if let Some(ref shutdown) = self.shutdown {
if shutdown.has_changed()? {
return Ok(());
}
// run scheduled task
// TODO: what do we do if the task fails? it's an internal error, inform the logs
let _ = self.run(task, actual_task).await;
};
match self.store.pull_next_task(&self.queue_name).await? {
Some(task) => {
self.run(task).await?;
}
None => {
// Listen to watchable future
@ -73,41 +124,45 @@ where
}
}
#[cfg(test)]
pub async fn run_tasks_until_none(&mut self) -> Result<(), BackieError> {
loop {
match self.queue.pull_next_task(self.task_type.clone()).await? {
Some(task) => {
let actual_task: Box<dyn RunnableTask> =
serde_json::from_value(task.payload.clone()).unwrap();
// #[cfg(test)]
// pub async fn run_tasks_until_none(&mut self) -> Result<(), BackieError> {
// loop {
// match self.store.pull_next_task(self.queue_name.clone()).await? {
// Some(task) => {
// let actual_task: Box<dyn BackgroundTask> =
// serde_json::from_value(task.payload.clone()).unwrap();
//
// // check if task is scheduled or not
// if let Some(CronPattern(_)) = actual_task.cron() {
// // program task
// // self.queue.schedule_task(&*actual_task).await?;
// }
// // run scheduled task
// self.run(task, actual_task).await?;
// }
// None => {
// return Ok(());
// }
// };
// }
// }
// check if task is scheduled or not
if let Some(CronPattern(_)) = actual_task.cron() {
// program task
// self.queue.schedule_task(&*actual_task).await?;
}
// run scheduled task
self.run(task, actual_task).await?;
}
None => {
return Ok(());
}
};
}
}
async fn run(&self, task: Task) -> Result<(), BackieError> {
let task_info = CurrentTask::new(&task);
let runnable_task_caller = self
.task_registry
.get(&task.task_name)
.ok_or_else(|| AsyncQueueError::TaskNotRegistered(task.task_name.clone()))?;
async fn run(
&mut self,
task: Task,
runnable: Box<dyn RunnableTask>,
) -> Result<(), BackieError> {
// TODO: catch panics
let result = runnable.run(&mut self.queue).await;
match result {
let result: Result<(), TaskExecError> =
runnable_task_caller(task_info, task.payload.clone(), (self.app_data_fn)()).await;
match &result {
Ok(_) => self.finalize_task(task, result).await?,
Err(error) => {
if task.retries < runnable.max_retries() {
let backoff_seconds = runnable.backoff(task.retries as u32);
if task.retries < task.max_retries {
let backoff_seconds = 5; // TODO: runnable.backoff(task.retries as u32);
log::debug!(
"Task {} failed to run and will be retried in {} seconds",
@ -115,12 +170,12 @@ where
backoff_seconds
);
let error_message = format!("{}", error);
self.queue
self.store
.schedule_task_retry(task.id, backoff_seconds, &error_message)
.await?;
} else {
log::debug!("Task {} failed and reached the maximum retries", task.id);
self.finalize_task(task, Err(error)).await?;
self.finalize_task(task, result).await?;
}
}
}
@ -128,36 +183,36 @@ where
}
async fn finalize_task(
&mut self,
&self,
task: Task,
result: Result<(), Box<dyn Error + Send + 'static>>,
result: Result<(), TaskExecError>,
) -> Result<(), BackieError> {
match self.retention_mode {
RetentionMode::KeepAll => match result {
Ok(_) => {
self.queue.set_task_done(task.id).await?;
self.store.set_task_state(task.id, TaskState::Done).await?;
log::debug!("Task {} done and kept in the database", task.id);
}
Err(error) => {
log::debug!("Task {} failed and kept in the database", task.id);
self.queue
.set_task_failed(task.id, &format!("{}", error))
self.store
.set_task_state(task.id, TaskState::Failed(format!("{}", error)))
.await?;
}
},
RetentionMode::RemoveAll => {
log::debug!("Task {} finalized and deleted from the database", task.id);
self.queue.remove_task(task.id).await?;
self.store.remove_task(task.id).await?;
}
RetentionMode::RemoveDone => match result {
Ok(_) => {
log::debug!("Task {} done and deleted from the database", task.id);
self.queue.remove_task(task.id).await?;
self.store.remove_task(task.id).await?;
}
Err(error) => {
log::debug!("Task {} failed and kept in the database", task.id);
self.queue
.set_task_failed(task.id, &format!("{}", error))
self.store
.set_task_state(task.id, TaskState::Failed(format!("{}", error)))
.await?;
}
},
@ -169,35 +224,19 @@ where
#[cfg(test)]
mod async_worker_tests {
use std::fmt::Display;
use super::*;
use crate::queue::PgAsyncQueue;
use crate::queue::Queueable;
use crate::task::TaskState;
use crate::worker::Task;
use crate::RetentionMode;
use crate::Scheduled;
use async_trait::async_trait;
use chrono::Duration;
use chrono::Utc;
use diesel_async::pooled_connection::{bb8::Pool, AsyncDieselConnectionManager};
use diesel_async::AsyncPgConnection;
use serde::{Deserialize, Serialize};
use thiserror::Error;
#[derive(Error, Debug)]
#[derive(thiserror::Error, Debug)]
enum TaskError {
#[error("Something went wrong")]
SomethingWrong,
Custom(String),
}
impl Display for TaskError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
TaskError::SomethingWrong => write!(f, "Something went wrong"),
TaskError::Custom(message) => write!(f, "{}", message),
}
}
#[error("{0}")]
Custom(String),
}
#[derive(Serialize, Deserialize)]
@ -205,13 +244,12 @@ mod async_worker_tests {
pub number: u16,
}
#[typetag::serde]
#[async_trait]
impl RunnableTask for WorkerAsyncTask {
async fn run(
&self,
_queueable: &mut dyn Queueable,
) -> Result<(), Box<(dyn std::error::Error + Send + 'static)>> {
impl BackgroundTask for WorkerAsyncTask {
const TASK_NAME: &'static str = "WorkerAsyncTask";
type AppData = ();
async fn run(&self, _: CurrentTask, _: Self::AppData) -> Result<(), anyhow::Error> {
Ok(())
}
}
@ -221,18 +259,18 @@ mod async_worker_tests {
pub number: u16,
}
#[typetag::serde]
#[async_trait]
impl RunnableTask for WorkerAsyncTaskSchedule {
async fn run(
&self,
_queueable: &mut dyn Queueable,
) -> Result<(), Box<(dyn std::error::Error + Send + 'static)>> {
impl BackgroundTask for WorkerAsyncTaskSchedule {
const TASK_NAME: &'static str = "WorkerAsyncTaskSchedule";
type AppData = ();
async fn run(&self, _task: CurrentTask, _data: Self::AppData) -> Result<(), anyhow::Error> {
Ok(())
}
fn cron(&self) -> Option<Scheduled> {
Some(Scheduled::ScheduleOnce(Utc::now() + Duration::seconds(1)))
}
// fn cron(&self) -> Option<Scheduled> {
// Some(Scheduled::ScheduleOnce(Utc::now() + Duration::seconds(1)))
// }
}
#[derive(Serialize, Deserialize)]
@ -240,16 +278,15 @@ mod async_worker_tests {
pub number: u16,
}
#[typetag::serde]
#[async_trait]
impl RunnableTask for AsyncFailedTask {
async fn run(
&self,
_queueable: &mut dyn Queueable,
) -> Result<(), Box<(dyn std::error::Error + Send + 'static)>> {
impl BackgroundTask for AsyncFailedTask {
const TASK_NAME: &'static str = "AsyncFailedTask";
type AppData = ();
async fn run(&self, _task: CurrentTask, _data: Self::AppData) -> Result<(), anyhow::Error> {
let message = format!("number {} is wrong :(", self.number);
Err(Box::new(TaskError::Custom(message)))
Err(TaskError::Custom(message).into())
}
fn max_retries(&self) -> i32 {
@ -260,282 +297,40 @@ mod async_worker_tests {
#[derive(Serialize, Deserialize, Clone)]
struct AsyncRetryTask {}
#[typetag::serde]
#[async_trait]
impl RunnableTask for AsyncRetryTask {
async fn run(
&self,
_queueable: &mut dyn Queueable,
) -> Result<(), Box<(dyn std::error::Error + Send + 'static)>> {
Err(Box::new(TaskError::SomethingWrong))
}
impl BackgroundTask for AsyncRetryTask {
const TASK_NAME: &'static str = "AsyncRetryTask";
type AppData = ();
fn max_retries(&self) -> i32 {
2
async fn run(&self, _task: CurrentTask, _data: Self::AppData) -> Result<(), anyhow::Error> {
Err(TaskError::SomethingWrong.into())
}
}
#[derive(Serialize, Deserialize)]
struct AsyncTaskType1 {}
#[typetag::serde]
#[async_trait]
impl RunnableTask for AsyncTaskType1 {
async fn run(
&self,
_queueable: &mut dyn Queueable,
) -> Result<(), Box<(dyn std::error::Error + Send + 'static)>> {
Ok(())
}
impl BackgroundTask for AsyncTaskType1 {
const TASK_NAME: &'static str = "AsyncTaskType1";
type AppData = ();
fn task_type(&self) -> TaskType {
"type1".into()
async fn run(&self, _task: CurrentTask, _data: Self::AppData) -> Result<(), anyhow::Error> {
Ok(())
}
}
#[derive(Serialize, Deserialize)]
struct AsyncTaskType2 {}
#[typetag::serde]
#[async_trait]
impl RunnableTask for AsyncTaskType2 {
async fn run(
&self,
_queueable: &mut dyn Queueable,
) -> Result<(), Box<(dyn std::error::Error + Send + 'static)>> {
impl BackgroundTask for AsyncTaskType2 {
const TASK_NAME: &'static str = "AsyncTaskType2";
type AppData = ();
async fn run(&self, _task: CurrentTask, _data: Self::AppData) -> Result<(), anyhow::Error> {
Ok(())
}
fn task_type(&self) -> TaskType {
TaskType::from("type2")
}
}
#[tokio::test]
async fn execute_and_finishes_task() {
let pool = pool().await;
let mut test = PgAsyncQueue::new(pool);
let actual_task = WorkerAsyncTask { number: 1 };
let task = insert_task(&mut test, &actual_task).await;
let id = task.id;
let mut worker = Worker::<PgAsyncQueue>::builder()
.queue(test.clone())
.retention_mode(RetentionMode::KeepAll)
.build();
worker.run(task, Box::new(actual_task)).await.unwrap();
let task_finished = test.find_task_by_id(id).await.unwrap();
assert_eq!(id, task_finished.id);
assert_eq!(TaskState::Done, task_finished.state());
test.remove_all_tasks().await.unwrap();
}
// #[tokio::test]
// async fn schedule_task_test() {
// let pool = pool().await;
// let mut test = PgAsyncQueue::new(pool);
//
// let actual_task = WorkerAsyncTaskSchedule { number: 1 };
//
// let task = test.schedule_task(&actual_task).await.unwrap();
//
// let id = task.id;
//
// let mut worker = AsyncWorker::<PgAsyncQueue>::builder()
// .queue(test.clone())
// .retention_mode(RetentionMode::KeepAll)
// .build();
//
// worker.run_tasks_until_none().await.unwrap();
//
// let task = worker.queue.find_task_by_id(id).await.unwrap();
//
// assert_eq!(id, task.id);
// assert_eq!(TaskState::Ready, task.state());
//
// tokio::time::sleep(core::time::Duration::from_secs(3)).await;
//
// worker.run_tasks_until_none().await.unwrap();
//
// let task = test.find_task_by_id(id).await.unwrap();
// assert_eq!(id, task.id);
// assert_eq!(TaskState::Done, task.state());
//
// test.remove_all_tasks().await.unwrap();
// }
#[tokio::test]
async fn retries_task_test() {
let pool = pool().await;
let mut test = PgAsyncQueue::new(pool);
let actual_task = AsyncRetryTask {};
let task = test.create_task(&actual_task).await.unwrap();
let id = task.id;
let mut worker = Worker::<PgAsyncQueue>::builder()
.queue(test.clone())
.retention_mode(RetentionMode::KeepAll)
.build();
worker.run_tasks_until_none().await.unwrap();
let task = worker.queue.find_task_by_id(id).await.unwrap();
assert_eq!(id, task.id);
assert_eq!(TaskState::Ready, task.state());
assert_eq!(1, task.retries);
assert!(task.error_message.is_some());
tokio::time::sleep(core::time::Duration::from_secs(5)).await;
worker.run_tasks_until_none().await.unwrap();
let task = worker.queue.find_task_by_id(id).await.unwrap();
assert_eq!(id, task.id);
assert_eq!(TaskState::Ready, task.state());
assert_eq!(2, task.retries);
tokio::time::sleep(core::time::Duration::from_secs(10)).await;
worker.run_tasks_until_none().await.unwrap();
let task = test.find_task_by_id(id).await.unwrap();
assert_eq!(id, task.id);
assert_eq!(TaskState::Failed, task.state());
assert_eq!("Something went wrong".to_string(), task.error_message.unwrap());
test.remove_all_tasks().await.unwrap();
}
#[tokio::test]
async fn worker_shutsdown_when_notified() {
let pool = pool().await;
let queue = PgAsyncQueue::new(pool);
let (tx, rx) = tokio::sync::watch::channel(());
let mut worker = Worker::<PgAsyncQueue>::builder()
.queue(queue)
.shutdown(rx)
.build();
let handle = tokio::spawn(async move {
worker.run_tasks().await.unwrap();
true
});
tx.send(()).unwrap();
select! {
_ = handle.fuse() => {}
_ = tokio::time::sleep(core::time::Duration::from_secs(1)).fuse() => panic!("Worker did not shutdown")
}
}
#[tokio::test]
async fn saves_error_for_failed_task() {
let pool = pool().await;
let mut test = PgAsyncQueue::new(pool);
let failed_task = AsyncFailedTask { number: 1 };
let task = insert_task(&mut test, &failed_task).await;
let id = task.id;
let mut worker = Worker::<PgAsyncQueue>::builder()
.queue(test.clone())
.retention_mode(RetentionMode::KeepAll)
.build();
worker.run(task, Box::new(failed_task)).await.unwrap();
let task_finished = test.find_task_by_id(id).await.unwrap();
assert_eq!(id, task_finished.id);
assert_eq!(TaskState::Failed, task_finished.state());
assert_eq!(
"number 1 is wrong :(".to_string(),
task_finished.error_message.unwrap()
);
test.remove_all_tasks().await.unwrap();
}
#[tokio::test]
async fn executes_task_only_of_specific_type() {
let pool = pool().await;
let mut test = PgAsyncQueue::new(pool);
let task1 = insert_task(&mut test, &AsyncTaskType1 {}).await;
let task12 = insert_task(&mut test, &AsyncTaskType1 {}).await;
let task2 = insert_task(&mut test, &AsyncTaskType2 {}).await;
let id1 = task1.id;
let id12 = task12.id;
let id2 = task2.id;
let mut worker = Worker::<PgAsyncQueue>::builder()
.queue(test.clone())
.task_type(TaskType::from("type1"))
.retention_mode(RetentionMode::KeepAll)
.build();
worker.run_tasks_until_none().await.unwrap();
let task1 = test.find_task_by_id(id1).await.unwrap();
let task12 = test.find_task_by_id(id12).await.unwrap();
let task2 = test.find_task_by_id(id2).await.unwrap();
assert_eq!(id1, task1.id);
assert_eq!(id12, task12.id);
assert_eq!(id2, task2.id);
assert_eq!(TaskState::Done, task1.state());
assert_eq!(TaskState::Done, task12.state());
assert_eq!(TaskState::Ready, task2.state());
test.remove_all_tasks().await.unwrap();
}
#[tokio::test]
async fn remove_when_finished() {
let pool = pool().await;
let mut test = PgAsyncQueue::new(pool);
let task1 = insert_task(&mut test, &AsyncTaskType1 {}).await;
let task12 = insert_task(&mut test, &AsyncTaskType1 {}).await;
let task2 = insert_task(&mut test, &AsyncTaskType2 {}).await;
let _id1 = task1.id;
let _id12 = task12.id;
let id2 = task2.id;
let mut worker = Worker::<PgAsyncQueue>::builder()
.queue(test.clone())
.task_type(TaskType::from("type1"))
.build();
worker.run_tasks_until_none().await.unwrap();
let task = test
.pull_next_task(Some(TaskType::from("type1")))
.await
.unwrap();
assert_eq!(None, task);
let task2 = test
.pull_next_task(Some(TaskType::from("type2")))
.await
.unwrap()
.unwrap();
assert_eq!(id2, task2.id);
test.remove_all_tasks().await.unwrap();
}
async fn insert_task(test: &mut PgAsyncQueue, task: &dyn RunnableTask) -> Task {
test.create_task(task).await.unwrap()
}
async fn pool() -> Pool<AsyncPgConnection> {

View file

@ -1,102 +1,200 @@
use crate::errors::BackieError;
use crate::queue::Queueable;
use crate::task::TaskType;
use crate::worker::Worker;
use crate::RetentionMode;
use async_recursion::async_recursion;
use log::error;
use crate::queue::Queue;
use crate::worker::{runnable, ExecuteTaskFn};
use crate::worker::{StateFn, Worker};
use crate::{BackgroundTask, CurrentTask, PgTaskStore, RetentionMode};
use std::collections::BTreeMap;
use std::future::Future;
use tokio::sync::watch::Receiver;
use typed_builder::TypedBuilder;
use std::sync::Arc;
use tokio::task::JoinHandle;
#[derive(TypedBuilder, Clone)]
pub struct WorkerPool<AQueue>
pub type AppDataFn<AppData> = Arc<dyn Fn(Queue) -> AppData + Send + Sync>;
#[derive(Clone)]
pub struct WorkerPool<AppData>
where
AQueue: Queueable + Clone + Sync + 'static,
AppData: Clone + Send + 'static,
{
#[builder(setter(into))]
/// the AsyncWorkerPool uses a queue to control the tasks that will be executed.
pub queue: AQueue,
/// Storage of tasks.
queue_store: Arc<PgTaskStore>, // TODO: make this generic/dynamic referenced
/// retention_mode controls if tasks should be persisted after execution
#[builder(default, setter(into))]
pub retention_mode: RetentionMode,
/// Queue used to spawn tasks.
queue: Queue,
/// the number of workers of the AsyncWorkerPool.
#[builder(setter(into))]
pub number_of_workers: u32,
/// Make possible to load the application data.
///
/// The application data is loaded when the worker pool is started and is passed to the tasks.
/// The loading function accepts a queue instance in case the application data depends on it. This
/// is interesting for situations where the application wants to allow tasks to spawn other tasks.
application_data_fn: StateFn<AppData>,
/// The type of tasks that will be executed by `AsyncWorkerPool`.
#[builder(default, setter(into))]
pub task_type: Option<TaskType>,
/// The types of task the worker pool can execute and the loaders for them.
task_registry: BTreeMap<String, ExecuteTaskFn<AppData>>,
/// Number of workers that will be spawned per queue.
worker_queues: BTreeMap<String, (RetentionMode, u32)>,
}
// impl<TypedBuilderFields, Q> AsyncWorkerBuilder<TypedBuilderFields, Q>
// where
// TypedBuilderFields: Clone,
// Q: Queueable + Clone + Sync + 'static,
// {
// pub fn with_graceful_shutdown<F>(self, signal: F) -> Self<TypedBuilderFields, Q>
// where
// F: Future<Output = ()>,
// {
// self
// }
// }
impl<AQueue> WorkerPool<AQueue>
impl<AppData> WorkerPool<AppData>
where
AQueue: Queueable + Clone + Sync + 'static,
AppData: Clone + Send + 'static,
{
/// Starts the configured number of workers
/// This is necessary in order to execute tasks.
pub async fn start<F>(&mut self, graceful_shutdown: F) -> Result<(), BackieError>
/// Create a new worker pool.
pub fn new<A>(queue_store: PgTaskStore, application_data_fn: A) -> Self
where
A: Fn(Queue) -> AppData + Send + Sync + 'static,
{
let queue_store = Arc::new(queue_store);
let queue = Queue::new(queue_store.clone());
let application_data_fn = {
let queue = queue.clone();
move || application_data_fn(queue.clone())
};
Self {
queue_store,
queue,
application_data_fn: Arc::new(application_data_fn),
task_registry: BTreeMap::new(),
worker_queues: BTreeMap::new(),
}
}
/// Register a task type with the worker pool.
pub fn register_task_type<BT>(mut self, num_workers: u32, retention_mode: RetentionMode) -> Self
where
BT: BackgroundTask<AppData = AppData>,
{
self.worker_queues
.insert(BT::QUEUE.to_string(), (retention_mode, num_workers));
self.task_registry
.insert(BT::TASK_NAME.to_string(), Arc::new(runnable::<BT>));
self
}
pub async fn start<F>(
self,
graceful_shutdown: F,
) -> Result<(JoinHandle<()>, Queue), BackieError>
where
F: Future<Output = ()> + Send + 'static,
{
let (tx, rx) = tokio::sync::watch::channel(());
for idx in 0..self.number_of_workers {
let pool = self.clone();
// TODO: the worker pool keeps track of the number of workers and spawns new workers as needed.
// There should be always a minimum number of workers active waiting for tasks to execute
// or for a gracefull shutdown.
tokio::spawn(Self::supervise_task(pool, rx.clone(), 0, idx));
}
graceful_shutdown.await;
tx.send(())?;
log::info!("Worker pool stopped gracefully");
Ok(())
}
#[async_recursion]
async fn supervise_task(
pool: WorkerPool<AQueue>,
receiver: Receiver<()>,
restarts: u64,
worker_number: u32,
) {
let restarts = restarts + 1;
let inner_pool = pool.clone();
let inner_receiver = receiver.clone();
let join_handle = tokio::spawn(async move {
let mut worker: Worker<AQueue> = Worker::builder()
.queue(inner_pool.queue.clone())
.retention_mode(inner_pool.retention_mode)
.task_type(inner_pool.task_type.clone())
.shutdown(inner_receiver)
.build();
worker.run_tasks().await
});
if (join_handle.await).is_err() {
error!(
"Worker {} stopped. Restarting. the number of restarts {}",
worker_number, restarts,
// Spawn all individual workers per queue
for (queue_name, (retention_mode, num_workers)) in self.worker_queues.iter() {
for idx in 0..*num_workers {
let mut worker: Worker<AppData> = Worker::new(
self.queue_store.clone(),
queue_name.clone(),
retention_mode.clone(),
self.task_registry.clone(),
self.application_data_fn.clone(),
Some(rx.clone()),
);
Self::supervise_task(pool, receiver, restarts, worker_number).await;
let worker_name = format!("worker-{queue_name}-{idx}");
// TODO: grab the join handle for every worker for graceful shutdown
tokio::spawn(async move {
worker.run_tasks().await;
log::info!("Worker {} stopped", worker_name);
});
}
}
Ok((
tokio::spawn(async move {
graceful_shutdown.await;
if let Err(err) = tx.send(()) {
log::warn!("Failed to send shutdown signal to worker pool: {}", err);
} else {
log::info!("Worker pool stopped gracefully");
}
}),
self.queue,
))
}
}
#[cfg(test)]
mod tests {
use super::*;
use async_trait::async_trait;
use diesel_async::pooled_connection::{bb8::Pool, AsyncDieselConnectionManager};
use diesel_async::AsyncPgConnection;
#[derive(Clone, Debug)]
struct ApplicationContext {
app_name: String,
}
impl ApplicationContext {
fn new() -> Self {
Self {
app_name: "Backie".to_string(),
}
}
fn get_app_name(&self) -> String {
self.app_name.clone()
}
}
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
struct GreetingTask {
person: String,
}
#[async_trait]
impl BackgroundTask for GreetingTask {
const TASK_NAME: &'static str = "my_task";
type AppData = ApplicationContext;
async fn run(
&self,
task_info: CurrentTask,
app_context: Self::AppData,
) -> Result<(), anyhow::Error> {
println!(
"[{}] Hello {}! I'm {}.",
task_info.id(),
self.person,
app_context.get_app_name()
);
Ok(())
}
}
#[tokio::test]
async fn test_worker_pool() {
let my_app_context = ApplicationContext::new();
let task_store = PgTaskStore::new(pool().await);
let (join_handle, queue) = WorkerPool::new(task_store, move |_| my_app_context.clone())
.register_task_type::<GreetingTask>(1, RetentionMode::RemoveDone)
.start(futures::future::ready(()))
.await
.unwrap();
queue
.enqueue(GreetingTask {
person: "Rafael".to_string(),
})
.await
.unwrap();
join_handle.await.unwrap();
}
async fn pool() -> Pool<AsyncPgConnection> {
let manager = AsyncDieselConnectionManager::<AsyncPgConnection>::new(
option_env!("DATABASE_URL").expect("DATABASE_URL must be set"),
);
Pool::builder()
.max_size(1)
.min_idle(Some(1))
.build(manager)
.await
.unwrap()
}
}