Make possible to provide app state to tasks

This commit is contained in:
Rafael Caricio 2023-03-10 23:41:34 +01:00
parent 7fcb63f75c
commit aac0b44c7f
Signed by: rafaelcaricio
GPG key ID: 3C86DBCE8E93C947
19 changed files with 776 additions and 1151 deletions

1
.gitignore vendored
View file

@ -1,5 +1,4 @@
**/target **/target
Cargo.lock Cargo.lock
src/schema.rs
docs/content/docs/CHANGELOG.md docs/content/docs/CHANGELOG.md
docs/content/docs/README.md docs/content/docs/README.md

View file

@ -19,16 +19,13 @@ cron = "0.12"
chrono = "0.4" chrono = "0.4"
hex = "0.4" hex = "0.4"
log = "0.4" log = "0.4"
serde = "1.0" serde = { version = "1", features = ["derive"] }
serde_derive = "1.0" serde_json = "1"
serde_json = "1.0"
sha2 = "0.10" sha2 = "0.10"
thiserror = "1.0" anyhow = "1"
typed-builder = "0.13" thiserror = "1"
typetag = "0.2"
uuid = { version = "1.1", features = ["v4", "serde"] } uuid = { version = "1.1", features = ["v4", "serde"] }
async-trait = "0.1" async-trait = "0.1"
async-recursion = "1"
futures = "0.3" futures = "0.3"
diesel = { version = "2.0", features = ["postgres", "serde_json", "chrono", "uuid"] } diesel = { version = "2.0", features = ["postgres", "serde_json", "chrono", "uuid"] }
diesel-derive-newtype = "2.0.0-rc.0" diesel-derive-newtype = "2.0.0-rc.0"

View file

@ -1,12 +0,0 @@
[package]
name = "simple_cron_async_worker"
version = "0.1.0"
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
fang = { path = "../../../" , features = ["asynk"]}
env_logger = "0.9.0"
log = "0.4.0"
tokio = { version = "1", features = ["full"] }

View file

@ -1,33 +0,0 @@
use fang::async_trait;
use fang::asynk::async_queue::AsyncQueueable;
use fang::serde::{Deserialize, Serialize};
use fang::typetag;
use fang::AsyncRunnable;
use fang::FangError;
use fang::Scheduled;
#[derive(Serialize, Deserialize)]
#[serde(crate = "fang::serde")]
pub struct MyCronTask {}
#[async_trait]
#[typetag::serde]
impl AsyncRunnable for MyCronTask {
async fn run(&self, _queue: &mut dyn AsyncQueueable) -> Result<(), FangError> {
log::info!("CRON!!!!!!!!!!!!!!!",);
Ok(())
}
fn cron(&self) -> Option<Scheduled> {
// sec min hour day of month month day of week year
// be careful works only with UTC hour.
// https://www.timeanddate.com/worldclock/timezone/utc
let expression = "0/20 * * * Aug-Sep * 2022/1";
Some(Scheduled::CronPattern(expression.to_string()))
}
fn uniq(&self) -> bool {
true
}
}

View file

@ -1,41 +0,0 @@
use fang::asynk::async_queue::AsyncQueue;
use fang::asynk::async_queue::AsyncQueueable;
use fang::asynk::async_worker_pool::AsyncWorkerPool;
use fang::AsyncRunnable;
use fang::NoTls;
use simple_cron_async_worker::MyCronTask;
use std::time::Duration;
#[tokio::main]
async fn main() {
env_logger::init();
log::info!("Starting...");
let max_pool_size: u32 = 3;
let mut queue = AsyncQueue::builder()
.uri("postgres://postgres:postgres@localhost/fang")
.max_pool_size(max_pool_size)
.build();
queue.connect(NoTls).await.unwrap();
log::info!("Queue connected...");
let mut pool: AsyncWorkerPool<AsyncQueue<NoTls>> = AsyncWorkerPool::builder()
.number_of_workers(10_u32)
.queue(queue.clone())
.build();
log::info!("Pool created ...");
pool.start().await;
log::info!("Workers started ...");
let task = MyCronTask {};
queue
.schedule_task(&task as &dyn AsyncRunnable)
.await
.unwrap();
tokio::time::sleep(Duration::from_secs(100)).await;
}

View file

@ -5,6 +5,7 @@ edition = "2021"
[dependencies] [dependencies]
backie = { path = "../../" } backie = { path = "../../" }
anyhow = "1"
env_logger = "0.9.0" env_logger = "0.9.0"
log = "0.4.0" log = "0.4.0"
tokio = { version = "1", features = ["full"] } tokio = { version = "1", features = ["full"] }
@ -12,4 +13,3 @@ diesel-async = { version = "0.2", features = ["postgres", "bb8"] }
diesel = { version = "2.0", features = ["postgres"] } diesel = { version = "2.0", features = ["postgres"] }
async-trait = "0.1" async-trait = "0.1"
serde = { version = "1.0", features = ["derive"] } serde = { version = "1.0", features = ["derive"] }
typetag = "0.2"

View file

@ -1,7 +1,20 @@
use std::time::Duration;
use async_trait::async_trait; use async_trait::async_trait;
use serde::{Serialize, Deserialize}; use backie::{BackgroundTask, CurrentTask};
use backie::{RunnableTask, Queueable}; use serde::{Deserialize, Serialize};
use std::time::Duration;
#[derive(Clone, Debug)]
pub struct MyApplicationContext {
app_name: String,
}
impl MyApplicationContext {
pub fn new(app_name: &str) -> Self {
Self {
app_name: app_name.to_string(),
}
}
}
#[derive(Serialize, Deserialize)] #[derive(Serialize, Deserialize)]
pub struct MyTask { pub struct MyTask {
@ -26,37 +39,51 @@ impl MyFailingTask {
} }
#[async_trait] #[async_trait]
#[typetag::serde] impl BackgroundTask for MyTask {
impl RunnableTask for MyTask { const TASK_NAME: &'static str = "my_task";
async fn run(&self, _queue: &mut dyn Queueable) -> Result<(), Box<dyn std::error::Error + Send + 'static>> { type AppData = MyApplicationContext;
async fn run(&self, task: CurrentTask, ctx: Self::AppData) -> Result<(), anyhow::Error> {
// let new_task = MyTask::new(self.number + 1); // let new_task = MyTask::new(self.number + 1);
// queue // queue
// .insert_task(&new_task as &dyn AsyncRunnable) // .insert_task(&new_task)
// .await // .await
// .unwrap(); // .unwrap();
log::info!("the current number is {}", self.number); log::info!(
"[{}] Hello from {}! the current number is {}",
task.id(),
ctx.app_name,
self.number
);
tokio::time::sleep(Duration::from_secs(3)).await; tokio::time::sleep(Duration::from_secs(3)).await;
log::info!("done.."); log::info!("[{}] done..", task.id());
Ok(()) Ok(())
} }
} }
#[async_trait] #[async_trait]
#[typetag::serde] impl BackgroundTask for MyFailingTask {
impl RunnableTask for MyFailingTask { const TASK_NAME: &'static str = "my_failing_task";
async fn run(&self, _queue: &mut dyn Queueable) -> Result<(), Box<dyn std::error::Error + Send + 'static>> { type AppData = MyApplicationContext;
async fn run(&self, task: CurrentTask, _ctx: Self::AppData) -> Result<(), anyhow::Error> {
// let new_task = MyFailingTask::new(self.number + 1); // let new_task = MyFailingTask::new(self.number + 1);
// queue // queue
// .insert_task(&new_task as &dyn AsyncRunnable) // .insert_task(&new_task)
// .await // .await
// .unwrap(); // .unwrap();
log::info!("the current number is {}", self.number); // task.id();
// task.keep_alive().await?;
// task.previous_error();
// task.retry_count();
log::info!("[{}] the current number is {}", task.id(), self.number);
tokio::time::sleep(Duration::from_secs(3)).await; tokio::time::sleep(Duration::from_secs(3)).await;
log::info!("done.."); log::info!("[{}] done..", task.id());
// //
// let b = true; // let b = true;
// //

View file

@ -1,9 +1,9 @@
use simple_worker::MyFailingTask; use backie::{PgTaskStore, RetentionMode, WorkerPool};
use simple_worker::MyTask;
use std::time::Duration;
use diesel_async::pg::AsyncPgConnection; use diesel_async::pg::AsyncPgConnection;
use diesel_async::pooled_connection::{bb8::Pool, AsyncDieselConnectionManager}; use diesel_async::pooled_connection::{bb8::Pool, AsyncDieselConnectionManager};
use backie::{PgAsyncQueue, WorkerPool, Queueable}; use simple_worker::MyApplicationContext;
use simple_worker::MyFailingTask;
use simple_worker::MyTask;
#[tokio::main] #[tokio::main]
async fn main() { async fn main() {
@ -22,49 +22,38 @@ async fn main() {
.unwrap(); .unwrap();
log::info!("Pool created ..."); log::info!("Pool created ...");
let mut queue = PgAsyncQueue::new(pool); let task_store = PgTaskStore::new(pool);
let (tx, mut rx) = tokio::sync::watch::channel(false); let (tx, mut rx) = tokio::sync::watch::channel(false);
let executor_task = tokio::spawn({ // Some global application context I want to pass to my background tasks
let mut queue = queue.clone(); let my_app_context = MyApplicationContext::new("Backie Example App");
async move {
let mut workers_pool: WorkerPool<PgAsyncQueue> = WorkerPool::builder()
.number_of_workers(10_u32)
.queue(queue)
.build();
log::info!("Workers starting ..."); // Register the task types I want to use and start the worker pool
workers_pool.start(async move { let (join_handle, queue) = WorkerPool::new(task_store, move |_| my_app_context.clone())
rx.changed().await; .register_task_type::<MyTask>(1, RetentionMode::RemoveDone)
}).await; .register_task_type::<MyFailingTask>(1, RetentionMode::RemoveDone)
log::info!("Workers stopped!"); .start(async move {
} let _ = rx.changed().await;
}); })
.await
.unwrap();
log::info!("Workers started ...");
let task1 = MyTask::new(0); let task1 = MyTask::new(0);
let task2 = MyTask::new(20_000); let task2 = MyTask::new(20_000);
let task3 = MyFailingTask::new(50_000); let task3 = MyFailingTask::new(50_000);
queue queue.enqueue(task1).await.unwrap();
.create_task(&task1) queue.enqueue(task2).await.unwrap();
.await queue.enqueue(task3).await.unwrap();
.unwrap();
queue
.create_task(&task2)
.await
.unwrap();
queue
.create_task(&task3)
.await
.unwrap();
log::info!("Tasks created ..."); log::info!("Tasks created ...");
tokio::signal::ctrl_c().await;
// Wait for Ctrl+C
let _ = tokio::signal::ctrl_c().await;
log::info!("Stopping ..."); log::info!("Stopping ...");
tx.send(true).unwrap(); tx.send(true).unwrap();
executor_task.await.unwrap(); join_handle.await.unwrap();
log::info!("Stopped!"); log::info!("Workers Stopped!");
} }

View file

@ -2,20 +2,20 @@ CREATE EXTENSION IF NOT EXISTS "uuid-ossp";
CREATE TABLE backie_tasks ( CREATE TABLE backie_tasks (
id uuid PRIMARY KEY DEFAULT uuid_generate_v4(), id uuid PRIMARY KEY DEFAULT uuid_generate_v4(),
payload jsonb NOT NULL, task_name VARCHAR NOT NULL,
error_message TEXT DEFAULT NULL, queue_name VARCHAR DEFAULT 'common' NOT NULL,
task_type VARCHAR DEFAULT 'common' NOT NULL,
uniq_hash CHAR(64) DEFAULT NULL, uniq_hash CHAR(64) DEFAULT NULL,
retries INTEGER DEFAULT 0 NOT NULL, payload jsonb NOT NULL,
timeout_msecs INT8 NOT NULL,
created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(), created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(),
scheduled_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(),
running_at TIMESTAMP WITH TIME ZONE DEFAULT NULL, running_at TIMESTAMP WITH TIME ZONE DEFAULT NULL,
done_at TIMESTAMP WITH TIME ZONE DEFAULT NULL done_at TIMESTAMP WITH TIME ZONE DEFAULT NULL,
error_info jsonb DEFAULT NULL,
retries INTEGER DEFAULT 0 NOT NULL,
max_retries INTEGER DEFAULT 0 NOT NULL
); );
CREATE INDEX backie_tasks_type_index ON backie_tasks(task_type);
CREATE INDEX backie_tasks_created_at_index ON backie_tasks(created_at);
CREATE INDEX backie_tasks_uniq_hash ON backie_tasks(uniq_hash);
--- create uniqueness index --- create uniqueness index
CREATE UNIQUE INDEX backie_tasks_uniq_hash_index ON backie_tasks(uniq_hash) WHERE uniq_hash IS NOT NULL; CREATE UNIQUE INDEX backie_tasks_uniq_hash_index ON backie_tasks(uniq_hash) WHERE uniq_hash IS NOT NULL;

View file

@ -1,25 +1,16 @@
use serde_json::Error as SerdeError;
use std::fmt::Display;
use thiserror::Error; use thiserror::Error;
/// Library errors /// Library errors
#[derive(Debug, Error)] #[derive(Debug, Error)]
pub enum BackieError { pub enum BackieError {
#[error("Queue processing error: {0}")]
QueueProcessingError(#[from] AsyncQueueError), QueueProcessingError(#[from] AsyncQueueError),
SerializationError(#[from] SerdeError),
ShutdownError(#[from] tokio::sync::watch::error::SendError<()>),
}
impl Display for BackieError { #[error("Worker Pool shutdown error: {0}")]
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { WorkerPoolShutdownError(#[from] tokio::sync::watch::error::SendError<()>),
match self {
BackieError::QueueProcessingError(error) => { #[error("Worker shutdown error: {0}")]
write!(f, "Queue processing error: {}", error) WorkerShutdownError(#[from] tokio::sync::watch::error::RecvError),
}
BackieError::SerializationError(error) => write!(f, "Serialization error: {}", error),
BackieError::ShutdownError(error) => write!(f, "Shutdown error: {}", error),
}
}
} }
/// List of error types that can occur while working with cron schedules. /// List of error types that can occur while working with cron schedules.
@ -29,7 +20,7 @@ pub enum CronError {
#[error(transparent)] #[error(transparent)]
LibraryError(#[from] cron::error::Error), LibraryError(#[from] cron::error::Error),
/// [`Scheduled`] enum variant is not provided /// [`Scheduled`] enum variant is not provided
#[error("You have to implement method `cron()` in your AsyncRunnable")] #[error("You have to implement method `cron()` in your Runnable")]
TaskNotSchedulableError, TaskNotSchedulableError,
/// The next execution can not be determined using the current [`Scheduled::CronPattern`] /// The next execution can not be determined using the current [`Scheduled::CronPattern`]
#[error("No timestamps match with this cron pattern")] #[error("No timestamps match with this cron pattern")]
@ -41,7 +32,7 @@ pub enum AsyncQueueError {
#[error(transparent)] #[error(transparent)]
PgError(#[from] diesel::result::Error), PgError(#[from] diesel::result::Error),
#[error(transparent)] #[error("Task serialization error: {0}")]
SerdeError(#[from] serde_json::Error), SerdeError(#[from] serde_json::Error),
#[error(transparent)] #[error(transparent)]
@ -49,6 +40,9 @@ pub enum AsyncQueueError {
#[error("Task is not in progress, operation not allowed")] #[error("Task is not in progress, operation not allowed")]
TaskNotRunning, TaskNotRunning,
#[error("Task with name {0} is not registered")]
TaskNotRegistered(String),
} }
impl From<cron::error::Error> for AsyncQueueError { impl From<cron::error::Error> for AsyncQueueError {

View file

@ -38,9 +38,9 @@ impl Default for RetentionMode {
} }
} }
pub use queue::PgAsyncQueue; pub use queue::PgTaskStore;
pub use queue::Queueable; pub use runnable::BackgroundTask;
pub use runnable::RunnableTask; pub use task::CurrentTask;
pub use worker_pool::WorkerPool; pub use worker_pool::WorkerPool;
pub mod errors; pub mod errors;
@ -48,6 +48,7 @@ mod queries;
pub mod queue; pub mod queue;
pub mod runnable; pub mod runnable;
mod schema; mod schema;
pub mod store;
pub mod task; pub mod task;
pub mod worker; pub mod worker;
pub mod worker_pool; pub mod worker_pool;

View file

@ -1,8 +1,7 @@
use crate::errors::AsyncQueueError; use crate::errors::AsyncQueueError;
use crate::runnable::RunnableTask;
use crate::schema::backie_tasks; use crate::schema::backie_tasks;
use crate::task::Task; use crate::task::Task;
use crate::task::{NewTask, TaskHash, TaskId, TaskType}; use crate::task::{NewTask, TaskHash, TaskId};
use chrono::DateTime; use chrono::DateTime;
use chrono::Duration; use chrono::Duration;
use chrono::Utc; use chrono::Utc;
@ -43,14 +42,6 @@ impl Task {
Ok(qty > 0) Ok(qty > 0)
} }
pub(crate) async fn remove_by_type(
connection: &mut AsyncPgConnection,
task_type: TaskType,
) -> Result<u64, AsyncQueueError> {
let query = backie_tasks::table.filter(backie_tasks::task_type.eq(task_type));
Ok(diesel::delete(query).execute(connection).await? as u64)
}
pub(crate) async fn find_by_id( pub(crate) async fn find_by_id(
connection: &mut AsyncPgConnection, connection: &mut AsyncPgConnection,
id: TaskId, id: TaskId,
@ -67,10 +58,13 @@ impl Task {
id: TaskId, id: TaskId,
error_message: &str, error_message: &str,
) -> Result<Task, AsyncQueueError> { ) -> Result<Task, AsyncQueueError> {
let error = serde_json::json!({
"error": error_message,
});
let query = backie_tasks::table.filter(backie_tasks::id.eq(id)); let query = backie_tasks::table.filter(backie_tasks::id.eq(id));
Ok(diesel::update(query) Ok(diesel::update(query)
.set(( .set((
backie_tasks::error_message.eq(error_message), backie_tasks::error_info.eq(Some(error)),
backie_tasks::done_at.eq(Utc::now()), backie_tasks::done_at.eq(Utc::now()),
)) ))
.get_result::<Task>(connection) .get_result::<Task>(connection)
@ -81,18 +75,22 @@ impl Task {
connection: &mut AsyncPgConnection, connection: &mut AsyncPgConnection,
id: TaskId, id: TaskId,
backoff_seconds: u32, backoff_seconds: u32,
error: &str, error_message: &str,
) -> Result<Task, AsyncQueueError> { ) -> Result<Task, AsyncQueueError> {
use crate::schema::backie_tasks::dsl; use crate::schema::backie_tasks::dsl;
let now = Utc::now(); let now = Utc::now();
let scheduled_at = now + Duration::seconds(backoff_seconds as i64); let scheduled_at = now + Duration::seconds(backoff_seconds as i64);
let error = serde_json::json!({
"error": error_message,
});
let task = diesel::update(backie_tasks::table.filter(backie_tasks::id.eq(id))) let task = diesel::update(backie_tasks::table.filter(backie_tasks::id.eq(id)))
.set(( .set((
backie_tasks::error_message.eq(error), backie_tasks::error_info.eq(Some(error)),
backie_tasks::retries.eq(dsl::retries + 1), backie_tasks::retries.eq(dsl::retries + 1),
backie_tasks::created_at.eq(scheduled_at), backie_tasks::scheduled_at.eq(scheduled_at),
backie_tasks::running_at.eq::<Option<DateTime<Utc>>>(None), backie_tasks::running_at.eq::<Option<DateTime<Utc>>>(None),
)) ))
.get_result::<Task>(connection) .get_result::<Task>(connection)
@ -103,14 +101,14 @@ impl Task {
pub(crate) async fn fetch_next_pending( pub(crate) async fn fetch_next_pending(
connection: &mut AsyncPgConnection, connection: &mut AsyncPgConnection,
task_type: TaskType, queue_name: &str,
) -> Option<Task> { ) -> Option<Task> {
backie_tasks::table backie_tasks::table
.filter(backie_tasks::created_at.lt(Utc::now())) // skip tasks scheduled for the future .filter(backie_tasks::scheduled_at.lt(Utc::now())) // skip tasks scheduled for the future
.order(backie_tasks::created_at.asc()) // get the oldest task first .order(backie_tasks::created_at.asc()) // get the oldest task first
.filter(backie_tasks::running_at.is_null()) // that is not marked as running already .filter(backie_tasks::running_at.is_null()) // that is not marked as running already
.filter(backie_tasks::done_at.is_null()) // and not marked as done .filter(backie_tasks::done_at.is_null()) // and not marked as done
.filter(backie_tasks::task_type.eq(task_type)) .filter(backie_tasks::queue_name.eq(queue_name))
.limit(1) .limit(1)
.for_update() .for_update()
.skip_locked() .skip_locked()
@ -143,39 +141,13 @@ impl Task {
pub(crate) async fn insert( pub(crate) async fn insert(
connection: &mut AsyncPgConnection, connection: &mut AsyncPgConnection,
runnable: &dyn RunnableTask, new_task: NewTask,
) -> Result<Task, AsyncQueueError> { ) -> Result<Task, AsyncQueueError> {
let payload = serde_json::to_value(runnable)?;
match runnable.uniq() {
None => {
let new_task = NewTask::builder()
.uniq_hash(None)
.task_type(runnable.task_type())
.payload(payload)
.build();
Ok(diesel::insert_into(backie_tasks::table) Ok(diesel::insert_into(backie_tasks::table)
.values(new_task) .values(new_task)
.get_result::<Task>(connection) .get_result::<Task>(connection)
.await?) .await?)
} }
Some(hash) => match Self::find_by_uniq_hash(connection, hash.clone()).await {
Some(task) => Ok(task),
None => {
let new_task = NewTask::builder()
.uniq_hash(Some(hash))
.task_type(runnable.task_type())
.payload(payload)
.build();
Ok(diesel::insert_into(backie_tasks::table)
.values(new_task)
.get_result::<Task>(connection)
.await?)
}
},
}
}
pub(crate) async fn find_by_uniq_hash( pub(crate) async fn find_by_uniq_hash(
connection: &mut AsyncPgConnection, connection: &mut AsyncPgConnection,

View file

@ -1,86 +1,48 @@
use crate::errors::AsyncQueueError; use crate::errors::AsyncQueueError;
use crate::runnable::RunnableTask; use crate::runnable::BackgroundTask;
use crate::task::{Task, TaskHash, TaskId, TaskType}; use crate::task::{NewTask, Task, TaskHash, TaskId, TaskState};
use async_trait::async_trait;
use diesel::result::Error::QueryBuilderError; use diesel::result::Error::QueryBuilderError;
use diesel_async::scoped_futures::ScopedFutureExt; use diesel_async::scoped_futures::ScopedFutureExt;
use diesel_async::AsyncConnection; use diesel_async::AsyncConnection;
use diesel_async::{pg::AsyncPgConnection, pooled_connection::bb8::Pool}; use diesel_async::{pg::AsyncPgConnection, pooled_connection::bb8::Pool};
use std::sync::Arc;
use std::time::Duration;
/// This trait defines operations for an asynchronous queue. #[derive(Clone)]
/// The trait can be implemented for different storage backends. pub struct Queue {
/// For now, the trait is only implemented for PostgreSQL. More backends are planned to be implemented in the future. task_store: Arc<PgTaskStore>,
#[async_trait] }
pub trait Queueable: Send {
/// Pull pending tasks from the queue to execute them.
///
/// This method returns one task of the `task_type` type. If `task_type` is `None` it will try to
/// fetch a task of the type `common`. The returned task is marked as running and must be executed.
async fn pull_next_task(
&mut self,
kind: Option<TaskType>,
) -> Result<Option<Task>, AsyncQueueError>;
/// Enqueue a task to the queue, The task will be executed as soon as possible by the worker of the same type impl Queue {
/// created by an AsyncWorkerPool. pub(crate) fn new(task_store: Arc<PgTaskStore>) -> Self {
async fn create_task(&mut self, task: &dyn RunnableTask) -> Result<Task, AsyncQueueError>; Queue { task_store }
}
/// Retrieve a task by its `id`. pub async fn enqueue<BT>(&self, background_task: BT) -> Result<(), AsyncQueueError>
async fn find_task_by_id(&mut self, id: TaskId) -> Result<Task, AsyncQueueError>; where
BT: BackgroundTask,
/// Update the state of a task to failed and set an error_message. {
async fn set_task_failed( self.task_store
&mut self, .create_task(NewTask::new(background_task, Duration::from_secs(10))?)
id: TaskId, .await?;
error_message: &str, Ok(())
) -> Result<Task, AsyncQueueError>; }
/// Update the state of a task to done.
async fn set_task_done(&mut self, id: TaskId) -> Result<Task, AsyncQueueError>;
/// Update the state of a task to inform that it's still in progress.
async fn keep_task_alive(&mut self, id: TaskId) -> Result<(), AsyncQueueError>;
/// Remove a task by its id.
async fn remove_task(&mut self, id: TaskId) -> Result<u64, AsyncQueueError>;
/// The method will remove all tasks from the queue
async fn remove_all_tasks(&mut self) -> Result<u64, AsyncQueueError>;
/// Remove all tasks that are scheduled in the future.
async fn remove_all_scheduled_tasks(&mut self) -> Result<u64, AsyncQueueError>;
/// Remove a task by its metadata (struct fields values)
async fn remove_task_by_hash(&mut self, task_hash: TaskHash) -> Result<bool, AsyncQueueError>;
/// Removes all tasks that have the specified `task_type`.
async fn remove_tasks_type(&mut self, task_type: TaskType) -> Result<u64, AsyncQueueError>;
async fn schedule_task_retry(
&mut self,
id: TaskId,
backoff_seconds: u32,
error: &str,
) -> Result<Task, AsyncQueueError>;
} }
/// An async queue that is used to manipulate tasks, it uses PostgreSQL as storage. /// An async queue that is used to manipulate tasks, it uses PostgreSQL as storage.
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct PgAsyncQueue { pub struct PgTaskStore {
pool: Pool<AsyncPgConnection>, pool: Pool<AsyncPgConnection>,
} }
impl PgAsyncQueue { impl PgTaskStore {
pub fn new(pool: Pool<AsyncPgConnection>) -> Self { pub fn new(pool: Pool<AsyncPgConnection>) -> Self {
PgAsyncQueue { pool } PgTaskStore { pool }
}
} }
#[async_trait] pub(crate) async fn pull_next_task(
impl Queueable for PgAsyncQueue { &self,
async fn pull_next_task( queue_name: &str,
&mut self,
task_type: Option<TaskType>,
) -> Result<Option<Task>, AsyncQueueError> { ) -> Result<Option<Task>, AsyncQueueError> {
let mut connection = self let mut connection = self
.pool .pool
@ -90,7 +52,7 @@ impl Queueable for PgAsyncQueue {
connection connection
.transaction::<Option<Task>, AsyncQueueError, _>(|conn| { .transaction::<Option<Task>, AsyncQueueError, _>(|conn| {
async move { async move {
let Some(pending_task) = Task::fetch_next_pending(conn, task_type.unwrap_or_default()).await else { let Some(pending_task) = Task::fetch_next_pending(conn, queue_name).await else {
return Ok(None); return Ok(None);
}; };
@ -101,16 +63,16 @@ impl Queueable for PgAsyncQueue {
.await .await
} }
async fn create_task(&mut self, runnable: &dyn RunnableTask) -> Result<Task, AsyncQueueError> { pub(crate) async fn create_task(&self, new_task: NewTask) -> Result<Task, AsyncQueueError> {
let mut connection = self let mut connection = self
.pool .pool
.get() .get()
.await .await
.map_err(|e| QueryBuilderError(e.into()))?; .map_err(|e| QueryBuilderError(e.into()))?;
Ok(Task::insert(&mut connection, runnable).await?) Task::insert(&mut connection, new_task).await
} }
async fn find_task_by_id(&mut self, id: TaskId) -> Result<Task, AsyncQueueError> { pub(crate) async fn find_task_by_id(&self, id: TaskId) -> Result<Task, AsyncQueueError> {
let mut connection = self let mut connection = self
.pool .pool
.get() .get()
@ -119,29 +81,27 @@ impl Queueable for PgAsyncQueue {
Task::find_by_id(&mut connection, id).await Task::find_by_id(&mut connection, id).await
} }
async fn set_task_failed( pub(crate) async fn set_task_state(
&mut self, &self,
id: TaskId, id: TaskId,
error_message: &str, state: TaskState,
) -> Result<Task, AsyncQueueError> { ) -> Result<(), AsyncQueueError> {
let mut connection = self let mut connection = self
.pool .pool
.get() .get()
.await .await
.map_err(|e| QueryBuilderError(e.into()))?; .map_err(|e| QueryBuilderError(e.into()))?;
Task::fail_with_message(&mut connection, id, error_message).await match state {
TaskState::Done => Task::set_done(&mut connection, id).await?,
TaskState::Failed(error_msg) => {
Task::fail_with_message(&mut connection, id, &error_msg).await?
}
_ => return Ok(()),
};
Ok(())
} }
async fn set_task_done(&mut self, id: TaskId) -> Result<Task, AsyncQueueError> { pub(crate) async fn keep_task_alive(&self, id: TaskId) -> Result<(), AsyncQueueError> {
let mut connection = self
.pool
.get()
.await
.map_err(|e| QueryBuilderError(e.into()))?;
Task::set_done(&mut connection, id).await
}
async fn keep_task_alive(&mut self, id: TaskId) -> Result<(), AsyncQueueError> {
let mut connection = self let mut connection = self
.pool .pool
.get() .get()
@ -159,7 +119,7 @@ impl Queueable for PgAsyncQueue {
.await .await
} }
async fn remove_task(&mut self, id: TaskId) -> Result<u64, AsyncQueueError> { pub(crate) async fn remove_task(&self, id: TaskId) -> Result<u64, AsyncQueueError> {
let mut connection = self let mut connection = self
.pool .pool
.get() .get()
@ -169,7 +129,7 @@ impl Queueable for PgAsyncQueue {
Ok(result) Ok(result)
} }
async fn remove_all_tasks(&mut self) -> Result<u64, AsyncQueueError> { pub(crate) async fn remove_all_tasks(&self) -> Result<u64, AsyncQueueError> {
let mut connection = self let mut connection = self
.pool .pool
.get() .get()
@ -178,37 +138,8 @@ impl Queueable for PgAsyncQueue {
Task::remove_all(&mut connection).await Task::remove_all(&mut connection).await
} }
async fn remove_all_scheduled_tasks(&mut self) -> Result<u64, AsyncQueueError> { pub(crate) async fn schedule_task_retry(
let mut connection = self &self,
.pool
.get()
.await
.map_err(|e| QueryBuilderError(e.into()))?;
let result = Task::remove_all_scheduled(&mut connection).await?;
Ok(result)
}
async fn remove_task_by_hash(&mut self, task_hash: TaskHash) -> Result<bool, AsyncQueueError> {
let mut connection = self
.pool
.get()
.await
.map_err(|e| QueryBuilderError(e.into()))?;
Task::remove_by_hash(&mut connection, task_hash).await
}
async fn remove_tasks_type(&mut self, task_type: TaskType) -> Result<u64, AsyncQueueError> {
let mut connection = self
.pool
.get()
.await
.map_err(|e| QueryBuilderError(e.into()))?;
let result = Task::remove_by_type(&mut connection, task_type).await?;
Ok(result)
}
async fn schedule_task_retry(
&mut self,
id: TaskId, id: TaskId,
backoff_seconds: u32, backoff_seconds: u32,
error: &str, error: &str,
@ -226,11 +157,8 @@ impl Queueable for PgAsyncQueue {
#[cfg(test)] #[cfg(test)]
mod async_queue_tests { mod async_queue_tests {
use super::*; use super::*;
use crate::task::TaskState; use crate::CurrentTask;
use crate::Scheduled;
use async_trait::async_trait; use async_trait::async_trait;
use chrono::DateTime;
use chrono::Utc;
use diesel_async::pooled_connection::{bb8::Pool, AsyncDieselConnectionManager}; use diesel_async::pooled_connection::{bb8::Pool, AsyncDieselConnectionManager};
use diesel_async::AsyncPgConnection; use diesel_async::AsyncPgConnection;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
@ -240,13 +168,12 @@ mod async_queue_tests {
pub number: u16, pub number: u16,
} }
#[typetag::serde]
#[async_trait] #[async_trait]
impl RunnableTask for AsyncTask { impl BackgroundTask for AsyncTask {
async fn run( const TASK_NAME: &'static str = "AsyncUniqTask";
&self, type AppData = ();
_queueable: &mut dyn Queueable,
) -> Result<(), Box<(dyn std::error::Error + Send + 'static)>> { async fn run(&self, _task: CurrentTask, _: Self::AppData) -> Result<(), anyhow::Error> {
Ok(()) Ok(())
} }
} }
@ -256,13 +183,12 @@ mod async_queue_tests {
pub number: u16, pub number: u16,
} }
#[typetag::serde]
#[async_trait] #[async_trait]
impl RunnableTask for AsyncUniqTask { impl BackgroundTask for AsyncUniqTask {
async fn run( const TASK_NAME: &'static str = "AsyncUniqTask";
&self, type AppData = ();
_queueable: &mut dyn Queueable,
) -> Result<(), Box<(dyn std::error::Error + Send + 'static)>> { async fn run(&self, _task: CurrentTask, _: Self::AppData) -> Result<(), anyhow::Error> {
Ok(()) Ok(())
} }
@ -277,286 +203,154 @@ mod async_queue_tests {
pub datetime: String, pub datetime: String,
} }
#[typetag::serde]
#[async_trait] #[async_trait]
impl RunnableTask for AsyncTaskSchedule { impl BackgroundTask for AsyncTaskSchedule {
async fn run( const TASK_NAME: &'static str = "AsyncUniqTask";
&self, type AppData = ();
_queueable: &mut dyn Queueable,
) -> Result<(), Box<(dyn std::error::Error + Send + 'static)>> { async fn run(&self, _task: CurrentTask, _: Self::AppData) -> Result<(), anyhow::Error> {
Ok(()) Ok(())
} }
fn cron(&self) -> Option<Scheduled> { // fn cron(&self) -> Option<Scheduled> {
let datetime = self.datetime.parse::<DateTime<Utc>>().ok()?; // let datetime = self.datetime.parse::<DateTime<Utc>>().ok()?;
Some(Scheduled::ScheduleOnce(datetime)) // Some(Scheduled::ScheduleOnce(datetime))
} // }
}
#[tokio::test]
async fn insert_task_creates_new_task() {
let pool = pool().await;
let mut test = PgAsyncQueue::new(pool);
let task = test.create_task(&AsyncTask { number: 1 }).await.unwrap();
let metadata = task.payload.as_object().unwrap();
let number = metadata["number"].as_u64();
let type_task = metadata["type"].as_str();
assert_eq!(Some(1), number);
assert_eq!(Some("AsyncTask"), type_task);
test.remove_all_tasks().await.unwrap();
}
#[tokio::test]
async fn update_task_state_test() {
let pool = pool().await;
let mut test = PgAsyncQueue::new(pool);
let task = test.create_task(&AsyncTask { number: 1 }).await.unwrap();
let metadata = task.payload.as_object().unwrap();
let number = metadata["number"].as_u64();
let type_task = metadata["type"].as_str();
let id = task.id;
assert_eq!(Some(1), number);
assert_eq!(Some("AsyncTask"), type_task);
let finished_task = test.set_task_done(task.id).await.unwrap();
assert_eq!(id, finished_task.id);
assert_eq!(TaskState::Done, finished_task.state());
test.remove_all_tasks().await.unwrap();
}
#[tokio::test]
async fn failed_task_query_test() {
let pool = pool().await;
let mut test = PgAsyncQueue::new(pool);
let task = test.create_task(&AsyncTask { number: 1 }).await.unwrap();
let metadata = task.payload.as_object().unwrap();
let number = metadata["number"].as_u64();
let type_task = metadata["type"].as_str();
let id = task.id;
assert_eq!(Some(1), number);
assert_eq!(Some("AsyncTask"), type_task);
let failed_task = test.set_task_failed(task.id, "Some error").await.unwrap();
assert_eq!(id, failed_task.id);
assert_eq!(Some("Some error"), failed_task.error_message.as_deref());
assert_eq!(TaskState::Failed, failed_task.state());
test.remove_all_tasks().await.unwrap();
}
/// `remove_all_tasks` must delete every stored task and report the count.
#[tokio::test]
async fn remove_all_tasks_test() {
    let mut queue = PgAsyncQueue::new(pool().await);

    // First task round-trips through the payload.
    let first = queue.create_task(&AsyncTask { number: 1 }).await.unwrap();
    let payload = first.payload.as_object().unwrap();
    assert_eq!(Some(1), payload["number"].as_u64());
    assert_eq!(Some("AsyncTask"), payload["type"].as_str());

    // Second task round-trips as well.
    let second = queue.create_task(&AsyncTask { number: 2 }).await.unwrap();
    let payload = second.payload.as_object().unwrap();
    assert_eq!(Some(2), payload["number"].as_u64());
    assert_eq!(Some("AsyncTask"), payload["type"].as_str());

    // Both inserted tasks are removed.
    let removed = queue.remove_all_tasks().await.unwrap();
    assert_eq!(2, removed);
}
}
// #[tokio::test] // #[tokio::test]
// async fn schedule_task_test() { // async fn insert_task_creates_new_task() {
// let pool = pool().await; // let pool = pool().await;
// let mut test = PgAsyncQueue::new(pool); // let mut queue = PgTaskStore::new(pool);
// //
// let datetime = (Utc::now() + Duration::seconds(7)).round_subsecs(0); // let task = queue.create_task(AsyncTask { number: 1 }).await.unwrap();
//
// let task = &AsyncTaskSchedule {
// number: 1,
// datetime: datetime.to_string(),
// };
//
// let task = test.schedule_task(task).await.unwrap();
// //
// let metadata = task.payload.as_object().unwrap(); // let metadata = task.payload.as_object().unwrap();
// let number = metadata["number"].as_u64(); // let number = metadata["number"].as_u64();
// let type_task = metadata["type"].as_str(); // let type_task = metadata["type"].as_str();
// //
// assert_eq!(Some(1), number); // assert_eq!(Some(1), number);
// assert_eq!(Some("AsyncTaskSchedule"), type_task); // assert_eq!(Some("AsyncTask"), type_task);
// assert_eq!(task.scheduled_at, datetime); //
// queue.remove_all_tasks().await.unwrap();
// }
//
// #[tokio::test]
// async fn update_task_state_test() {
// let pool = pool().await;
// let mut test = PgTaskStore::new(pool);
//
// let task = test.create_task(&AsyncTask { number: 1 }).await.unwrap();
//
// let metadata = task.payload.as_object().unwrap();
// let number = metadata["number"].as_u64();
// let type_task = metadata["type"].as_str();
// let id = task.id;
//
// assert_eq!(Some(1), number);
// assert_eq!(Some("AsyncTask"), type_task);
//
// let finished_task = test.set_task_state(task.id, TaskState::Done).await.unwrap();
//
// assert_eq!(id, finished_task.id);
// assert_eq!(TaskState::Done, finished_task.state());
// //
// test.remove_all_tasks().await.unwrap(); // test.remove_all_tasks().await.unwrap();
// } // }
// //
// #[tokio::test] // #[tokio::test]
// async fn remove_all_scheduled_tasks_test() { // async fn failed_task_query_test() {
// let pool = pool().await; // let pool = pool().await;
// let mut test = PgAsyncQueue::new(pool); // let mut test = PgTaskStore::new(pool);
// //
// let datetime = (Utc::now() + Duration::seconds(7)).round_subsecs(0); // let task = test.create_task(&AsyncTask { number: 1 }).await.unwrap();
// //
// let task1 = &AsyncTaskSchedule { // let metadata = task.payload.as_object().unwrap();
// number: 1, // let number = metadata["number"].as_u64();
// datetime: datetime.to_string(), // let type_task = metadata["type"].as_str();
// }; // let id = task.id;
// //
// let task2 = &AsyncTaskSchedule { // assert_eq!(Some(1), number);
// number: 2, // assert_eq!(Some("AsyncTask"), type_task);
// datetime: datetime.to_string(),
// };
// //
// test.schedule_task(task1).await.unwrap(); // let failed_task = test.set_task_state(task.id, TaskState::Failed("Some error".to_string())).await.unwrap();
// test.schedule_task(task2).await.unwrap();
// //
// let number = test.remove_all_scheduled_tasks().await.unwrap(); // assert_eq!(id, failed_task.id);
// // assert_eq!(Some("Some error"), failed_task.error_message.as_deref());
// assert_eq!(2, number); // assert_eq!(TaskState::Failed, failed_task.state());
// //
// test.remove_all_tasks().await.unwrap(); // test.remove_all_tasks().await.unwrap();
// } // }
//
/// Pulling tasks must hand them back in the order they were inserted.
#[tokio::test]
async fn pull_next_task_test() {
    let mut queue = PgAsyncQueue::new(pool().await);

    // Enqueue two tasks, verifying each stored payload.
    let task = queue.create_task(&AsyncTask { number: 1 }).await.unwrap();
    let payload = task.payload.as_object().unwrap();
    assert_eq!(Some(1), payload["number"].as_u64());
    assert_eq!(Some("AsyncTask"), payload["type"].as_str());

    let task = queue.create_task(&AsyncTask { number: 2 }).await.unwrap();
    let payload = task.payload.as_object().unwrap();
    assert_eq!(Some(2), payload["number"].as_u64());
    assert_eq!(Some("AsyncTask"), payload["type"].as_str());

    // The first pull returns the first inserted task.
    let pulled = queue.pull_next_task(None).await.unwrap().unwrap();
    let payload = pulled.payload.as_object().unwrap();
    assert_eq!(Some(1), payload["number"].as_u64());
    assert_eq!(Some("AsyncTask"), payload["type"].as_str());

    // The second pull returns the remaining task.
    let pulled = queue.pull_next_task(None).await.unwrap().unwrap();
    let payload = pulled.payload.as_object().unwrap();
    assert_eq!(Some(2), payload["number"].as_u64());
    assert_eq!(Some("AsyncTask"), payload["type"].as_str());

    // Leave the table empty for the other tests.
    queue.remove_all_tasks().await.unwrap();
}
//
/// Removing by task type must only delete tasks of that type.
#[tokio::test]
async fn remove_tasks_type_test() {
    let mut queue = PgAsyncQueue::new(pool().await);

    // Two tasks of the default type.
    let task = queue.create_task(&AsyncTask { number: 1 }).await.unwrap();
    let payload = task.payload.as_object().unwrap();
    assert_eq!(Some(1), payload["number"].as_u64());
    assert_eq!(Some("AsyncTask"), payload["type"].as_str());

    let task = queue.create_task(&AsyncTask { number: 2 }).await.unwrap();
    let payload = task.payload.as_object().unwrap();
    assert_eq!(Some(2), payload["number"].as_u64());
    assert_eq!(Some("AsyncTask"), payload["type"].as_str());

    // A type that matches nothing removes nothing.
    let removed = queue
        .remove_tasks_type(TaskType::from("nonexistentType"))
        .await
        .unwrap();
    assert_eq!(0, removed);

    // The default type matches both inserted tasks.
    let removed = queue.remove_tasks_type(TaskType::default()).await.unwrap();
    assert_eq!(2, removed);

    // Leave the table empty for the other tests.
    queue.remove_all_tasks().await.unwrap();
}
/// Removing by uniqueness hash must only delete the matching task.
#[tokio::test]
async fn remove_tasks_by_metadata() {
    let mut queue = PgAsyncQueue::new(pool().await);

    // Two unique tasks with distinct hashes.
    let task = queue
        .create_task(&AsyncUniqTask { number: 1 })
        .await
        .unwrap();
    let payload = task.payload.as_object().unwrap();
    assert_eq!(Some(1), payload["number"].as_u64());
    assert_eq!(Some("AsyncUniqTask"), payload["type"].as_str());

    let task = queue
        .create_task(&AsyncUniqTask { number: 2 })
        .await
        .unwrap();
    let payload = task.payload.as_object().unwrap();
    assert_eq!(Some(2), payload["number"].as_u64());
    assert_eq!(Some("AsyncUniqTask"), payload["type"].as_str());

    // A hash that matches no stored task removes nothing.
    let removed = queue
        .remove_task_by_hash(AsyncUniqTask { number: 0 }.uniq().unwrap())
        .await
        .unwrap();
    assert!(!removed, "Should **not** remove task");

    // A matching hash removes the corresponding task.
    let removed = queue
        .remove_task_by_hash(AsyncUniqTask { number: 1 }.uniq().unwrap())
        .await
        .unwrap();
    assert!(removed, "Should remove task");

    // Leave the table empty for the other tests.
    queue.remove_all_tasks().await.unwrap();
}
async fn pool() -> Pool<AsyncPgConnection> { async fn pool() -> Pool<AsyncPgConnection> {
let manager = AsyncDieselConnectionManager::<AsyncPgConnection>::new( let manager = AsyncDieselConnectionManager::<AsyncPgConnection>::new(

View file

@ -1,27 +1,34 @@
use crate::queue::Queueable; use crate::task::{CurrentTask, TaskHash};
use crate::task::TaskHash;
use crate::task::TaskType;
use crate::Scheduled;
use async_trait::async_trait; use async_trait::async_trait;
use std::error::Error; use serde::{de::DeserializeOwned, ser::Serialize};
pub const RETRIES_NUMBER: i32 = 5;
/// Task that can be executed by the queue. /// Task that can be executed by the queue.
/// ///
/// The `RunnableTask` trait is used to define the behaviour of a task. You must implement this /// The `BackgroundTask` trait is used to define the behaviour of a task. You must implement this
/// trait for all tasks you want to execute. /// trait for all tasks you want to execute.
#[typetag::serde(tag = "type")]
#[async_trait] #[async_trait]
pub trait RunnableTask: Send + Sync { pub trait BackgroundTask: Serialize + DeserializeOwned + Sync + Send + 'static {
/// Execute the task. This method should define its logic /// Unique name of the task.
async fn run(&self, queue: &mut dyn Queueable) -> Result<(), Box<dyn Error + Send + 'static>>; ///
/// This MUST be unique for the whole application.
const TASK_NAME: &'static str;
/// Define the type of the task. /// Task queue where this task will be executed.
/// The `common` task type is used by default ///
fn task_type(&self) -> TaskType { /// Used to define which workers are going to be executing this task. It uses the default
TaskType::default() /// task queue if not changed.
} const QUEUE: &'static str = "default";
/// Number of retries for tasks.
///
/// By default, it is set to 5.
const MAX_RETRIES: i32 = 5;
/// The application data provided to this task at runtime.
type AppData: Clone + Send + 'static;
/// Execute the task. This method should define its logic
async fn run(&self, task: CurrentTask, context: Self::AppData) -> Result<(), anyhow::Error>;
/// If set to true, no new tasks with the same metadata will be inserted /// If set to true, no new tasks with the same metadata will be inserted
/// By default it is set to false. /// By default it is set to false.
@ -29,27 +36,10 @@ pub trait RunnableTask: Send + Sync {
None None
} }
/// This method defines if a task is periodic or it should be executed once in the future.
///
/// Be careful it works only with the UTC timezone.
///
/// Example:
///
/// ```rust
/// fn cron(&self) -> Option<Scheduled> {
/// let expression = "0/20 * * * Aug-Sep * 2022/1";
/// Some(Scheduled::CronPattern(expression.to_string()))
/// }
///```
/// In order to schedule a task once, use the `Scheduled::ScheduleOnce` enum variant.
fn cron(&self) -> Option<Scheduled> {
None
}
/// Define the maximum number of retries the task will be retried. /// Define the maximum number of retries the task will be retried.
/// By default the number of retries is 20. /// By default the number of retries is 20.
fn max_retries(&self) -> i32 { fn max_retries(&self) -> i32 {
RETRIES_NUMBER Self::MAX_RETRIES
} }
/// Define the backoff mode /// Define the backoff mode

19
src/schema.rs Normal file
View file

@ -0,0 +1,19 @@
// @generated automatically by Diesel CLI.

// Diesel schema for the `backie_tasks` table — the backing store for
// background tasks. NOTE(review): generated file; comments here will be
// lost on the next `diesel print-schema` run.
diesel::table! {
    backie_tasks (id) {
        // Primary key of the task row.
        id -> Uuid,
        // Name identifying the task implementation.
        task_name -> Varchar,
        // Queue the task belongs to.
        queue_name -> Varchar,
        // Optional uniqueness hash; nullable for non-unique tasks.
        uniq_hash -> Nullable<Bpchar>,
        // Serialized task payload (JSONB).
        payload -> Jsonb,
        // Execution timeout in milliseconds.
        timeout_msecs -> Int8,
        created_at -> Timestamptz,
        scheduled_at -> Timestamptz,
        // Null until the task starts running.
        running_at -> Nullable<Timestamptz>,
        // Null until the task finishes.
        done_at -> Nullable<Timestamptz>,
        // Failure details (JSONB); null when the task has not failed.
        error_info -> Nullable<Jsonb>,
        // Retries performed so far and the allowed maximum.
        retries -> Int4,
        max_retries -> Int4,
    }
}

1
src/store.rs Normal file
View file

@ -0,0 +1 @@

View file

@ -1,4 +1,5 @@
use crate::schema::backie_tasks; use crate::schema::backie_tasks;
use crate::BackgroundTask;
use chrono::DateTime; use chrono::DateTime;
use chrono::Utc; use chrono::Utc;
use diesel::prelude::*; use diesel::prelude::*;
@ -8,11 +9,11 @@ use sha2::{Digest, Sha256};
use std::borrow::Cow; use std::borrow::Cow;
use std::fmt; use std::fmt;
use std::fmt::Display; use std::fmt::Display;
use typed_builder::TypedBuilder; use std::time::Duration;
use uuid::Uuid; use uuid::Uuid;
/// States of a task. /// States of a task.
#[derive(Copy, Clone, Debug, Eq, PartialEq)] #[derive(Clone, Debug, Eq, PartialEq)]
pub enum TaskState { pub enum TaskState {
/// The task is ready to be executed. /// The task is ready to be executed.
Ready, Ready,
@ -21,7 +22,7 @@ pub enum TaskState {
Running, Running,
/// The task has failed to execute. /// The task has failed to execute.
Failed, Failed(String),
/// The task finished successfully. /// The task finished successfully.
Done, Done,
@ -36,24 +37,6 @@ impl Display for TaskId {
} }
} }
#[derive(Clone, Debug, Hash, PartialEq, Eq, DieselNewType, Serialize)]
pub struct TaskType(Cow<'static, str>);
impl Default for TaskType {
fn default() -> Self {
Self(Cow::from("default"))
}
}
impl<S> From<S> for TaskType
where
S: AsRef<str> + 'static,
{
fn from(s: S) -> Self {
TaskType(Cow::from(s.as_ref().to_owned()))
}
}
#[derive(Clone, Debug, Hash, PartialEq, Eq, DieselNewType, Serialize)] #[derive(Clone, Debug, Hash, PartialEq, Eq, DieselNewType, Serialize)]
pub struct TaskHash(Cow<'static, str>); pub struct TaskHash(Cow<'static, str>);
@ -70,42 +53,55 @@ impl TaskHash {
} }
} }
#[derive(Queryable, Identifiable, Debug, Eq, PartialEq, Clone, TypedBuilder)] #[derive(Queryable, Identifiable, Debug, Eq, PartialEq, Clone)]
#[diesel(table_name = backie_tasks)] #[diesel(table_name = backie_tasks)]
pub struct Task { pub struct Task {
#[builder(setter(into))] /// Unique identifier of the task.
pub id: TaskId, pub id: TaskId,
#[builder(setter(into))] /// Name of the type of task.
pub payload: serde_json::Value, pub task_name: String,
#[builder(setter(into))] /// Queue name that the task belongs to.
pub error_message: Option<String>, pub queue_name: String,
#[builder(setter(into))] /// Unique hash is used to identify and avoid duplicate tasks.
pub task_type: TaskType,
#[builder(setter(into))]
pub uniq_hash: Option<TaskHash>, pub uniq_hash: Option<TaskHash>,
#[builder(setter(into))] /// Representation of the task.
pub retries: i32, pub payload: serde_json::Value,
#[builder(setter(into))] /// Max timeout that the task can run for.
pub timeout_msecs: i64,
/// Creation time of the task.
pub created_at: DateTime<Utc>, pub created_at: DateTime<Utc>,
#[builder(setter(into))] /// Date time when the task is scheduled to run.
pub scheduled_at: DateTime<Utc>,
/// Date time when the task is started to run.
pub running_at: Option<DateTime<Utc>>, pub running_at: Option<DateTime<Utc>>,
#[builder(setter(into))] /// Date time when the task is finished.
pub done_at: Option<DateTime<Utc>>, pub done_at: Option<DateTime<Utc>>,
/// Failure reason, when the task is failed.
pub error_info: Option<serde_json::Value>,
/// Number of times a task was retried.
pub retries: i32,
/// Maximum number of retries allow for this task before it is maked as failure.
pub max_retries: i32,
} }
impl Task { impl Task {
pub fn state(&self) -> TaskState { pub fn state(&self) -> TaskState {
if self.done_at.is_some() { if self.done_at.is_some() {
if self.error_message.is_some() { if self.error_info.is_some() {
TaskState::Failed // TODO: use a proper error type
TaskState::Failed(self.error_info.clone().unwrap().to_string())
} else { } else {
TaskState::Done TaskState::Done
} }
@ -117,22 +113,61 @@ impl Task {
} }
} }
#[derive(Insertable, Debug, Eq, PartialEq, Clone, TypedBuilder)] #[derive(Insertable, Debug, Eq, PartialEq, Clone)]
#[diesel(table_name = backie_tasks)] #[diesel(table_name = backie_tasks)]
pub struct NewTask { pub(crate) struct NewTask {
#[builder(setter(into))] task_name: String,
payload: serde_json::Value, queue_name: String,
#[builder(setter(into))]
task_type: TaskType,
#[builder(setter(into))]
uniq_hash: Option<TaskHash>, uniq_hash: Option<TaskHash>,
payload: serde_json::Value,
timeout_msecs: i64,
max_retries: i32,
} }
pub struct TaskInfo { impl NewTask {
pub(crate) fn new<T>(background_task: T, timeout: Duration) -> Result<Self, serde_json::Error>
where
T: BackgroundTask,
{
let max_retries = background_task.max_retries();
let uniq_hash = background_task.uniq();
let payload = serde_json::to_value(background_task)?;
Ok(Self {
task_name: T::TASK_NAME.to_string(),
queue_name: T::QUEUE.to_string(),
uniq_hash,
payload,
timeout_msecs: timeout.as_millis() as i64,
max_retries,
})
}
}
pub struct CurrentTask {
id: TaskId, id: TaskId,
error_message: Option<String>,
retries: i32, retries: i32,
created_at: DateTime<Utc>, created_at: DateTime<Utc>,
} }
impl CurrentTask {
pub(crate) fn new(task: &Task) -> Self {
Self {
id: task.id.clone(),
retries: task.retries,
created_at: task.created_at,
}
}
pub fn id(&self) -> TaskId {
self.id
}
pub fn retry_count(&self) -> i32 {
self.retries
}
pub fn created_at(&self) -> DateTime<Utc> {
self.created_at
}
}

View file

@ -1,53 +1,104 @@
use crate::errors::BackieError; use crate::errors::{AsyncQueueError, BackieError};
use crate::queue::Queueable; use crate::runnable::BackgroundTask;
use crate::runnable::RunnableTask; use crate::task::{CurrentTask, Task, TaskState};
use crate::task::{Task, TaskType}; use crate::{PgTaskStore, RetentionMode};
use crate::RetentionMode;
use crate::Scheduled::*;
use futures::future::FutureExt; use futures::future::FutureExt;
use futures::select; use futures::select;
use std::error::Error; use std::collections::BTreeMap;
use typed_builder::TypedBuilder; use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use thiserror::Error;
/// it executes tasks only of task_type type, it sleeps when there are no tasks in the queue pub type ExecuteTaskFn<AppData> = Arc<
#[derive(TypedBuilder)] dyn Fn(
pub struct Worker<Q> CurrentTask,
where serde_json::Value,
Q: Queueable + Clone + Sync + 'static, AppData,
{ ) -> Pin<Box<dyn Future<Output = Result<(), TaskExecError>> + Send>>
#[builder(setter(into))] + Send
pub queue: Q, + Sync,
>;
#[builder(default, setter(into))] pub type StateFn<AppData> = Arc<dyn Fn() -> AppData + Send + Sync>;
pub task_type: Option<TaskType>,
#[builder(default, setter(into))] #[derive(Debug, Error)]
pub retention_mode: RetentionMode, pub enum TaskExecError {
#[error("Task execution failed: {0}")]
ExecutionFailed(#[from] anyhow::Error),
#[builder(default, setter(into))] #[error("Task deserialization failed: {0}")]
pub shutdown: Option<tokio::sync::watch::Receiver<()>>, TaskDeserializationFailed(#[from] serde_json::Error),
} }
impl<Q> Worker<Q> pub(crate) fn runnable<BT>(
task_info: CurrentTask,
payload: serde_json::Value,
app_context: BT::AppData,
) -> Pin<Box<dyn Future<Output = Result<(), TaskExecError>> + Send>>
where where
Q: Queueable + Clone + Sync + 'static, BT: BackgroundTask,
{ {
Box::pin(async move {
let background_task: BT = serde_json::from_value(payload)?;
background_task.run(task_info, app_context).await?;
Ok(())
})
}
/// Worker that executes tasks.
pub struct Worker<AppData>
where
AppData: Clone + Send + 'static,
{
store: Arc<PgTaskStore>,
queue_name: String,
retention_mode: RetentionMode,
task_registry: BTreeMap<String, ExecuteTaskFn<AppData>>,
app_data_fn: StateFn<AppData>,
/// Notification for the worker to stop.
shutdown: Option<tokio::sync::watch::Receiver<()>>,
}
impl<AppData> Worker<AppData>
where
AppData: Clone + Send + 'static,
{
pub(crate) fn new(
store: Arc<PgTaskStore>,
queue_name: String,
retention_mode: RetentionMode,
task_registry: BTreeMap<String, ExecuteTaskFn<AppData>>,
app_data_fn: StateFn<AppData>,
shutdown: Option<tokio::sync::watch::Receiver<()>>,
) -> Self {
Self {
store,
queue_name,
retention_mode,
task_registry,
app_data_fn,
shutdown,
}
}
pub(crate) async fn run_tasks(&mut self) -> Result<(), BackieError> { pub(crate) async fn run_tasks(&mut self) -> Result<(), BackieError> {
loop { loop {
// Need to check if has to stop before pulling next task // Check if has to stop before pulling next task
match self.queue.pull_next_task(self.task_type.clone()).await? { if let Some(ref shutdown) = self.shutdown {
Some(task) => { if shutdown.has_changed()? {
let actual_task: Box<dyn RunnableTask> = return Ok(());
serde_json::from_value(task.payload.clone())?;
// check if task is scheduled or not
if let Some(CronPattern(_)) = actual_task.cron() {
// program task
//self.queue.schedule_task(&*actual_task).await?;
} }
// run scheduled task };
// TODO: what do we do if the task fails? it's an internal error, inform the logs
let _ = self.run(task, actual_task).await; match self.store.pull_next_task(&self.queue_name).await? {
Some(task) => {
self.run(task).await?;
} }
None => { None => {
// Listen to watchable future // Listen to watchable future
@ -73,41 +124,45 @@ where
} }
} }
#[cfg(test)] // #[cfg(test)]
pub async fn run_tasks_until_none(&mut self) -> Result<(), BackieError> { // pub async fn run_tasks_until_none(&mut self) -> Result<(), BackieError> {
loop { // loop {
match self.queue.pull_next_task(self.task_type.clone()).await? { // match self.store.pull_next_task(self.queue_name.clone()).await? {
Some(task) => { // Some(task) => {
let actual_task: Box<dyn RunnableTask> = // let actual_task: Box<dyn BackgroundTask> =
serde_json::from_value(task.payload.clone()).unwrap(); // serde_json::from_value(task.payload.clone()).unwrap();
//
// // check if task is scheduled or not
// if let Some(CronPattern(_)) = actual_task.cron() {
// // program task
// // self.queue.schedule_task(&*actual_task).await?;
// }
// // run scheduled task
// self.run(task, actual_task).await?;
// }
// None => {
// return Ok(());
// }
// };
// }
// }
// check if task is scheduled or not async fn run(&self, task: Task) -> Result<(), BackieError> {
if let Some(CronPattern(_)) = actual_task.cron() { let task_info = CurrentTask::new(&task);
// program task let runnable_task_caller = self
// self.queue.schedule_task(&*actual_task).await?; .task_registry
} .get(&task.task_name)
// run scheduled task .ok_or_else(|| AsyncQueueError::TaskNotRegistered(task.task_name.clone()))?;
self.run(task, actual_task).await?;
}
None => {
return Ok(());
}
};
}
}
async fn run(
&mut self,
task: Task,
runnable: Box<dyn RunnableTask>,
) -> Result<(), BackieError> {
// TODO: catch panics // TODO: catch panics
let result = runnable.run(&mut self.queue).await; let result: Result<(), TaskExecError> =
match result { runnable_task_caller(task_info, task.payload.clone(), (self.app_data_fn)()).await;
match &result {
Ok(_) => self.finalize_task(task, result).await?, Ok(_) => self.finalize_task(task, result).await?,
Err(error) => { Err(error) => {
if task.retries < runnable.max_retries() { if task.retries < task.max_retries {
let backoff_seconds = runnable.backoff(task.retries as u32); let backoff_seconds = 5; // TODO: runnable.backoff(task.retries as u32);
log::debug!( log::debug!(
"Task {} failed to run and will be retried in {} seconds", "Task {} failed to run and will be retried in {} seconds",
@ -115,12 +170,12 @@ where
backoff_seconds backoff_seconds
); );
let error_message = format!("{}", error); let error_message = format!("{}", error);
self.queue self.store
.schedule_task_retry(task.id, backoff_seconds, &error_message) .schedule_task_retry(task.id, backoff_seconds, &error_message)
.await?; .await?;
} else { } else {
log::debug!("Task {} failed and reached the maximum retries", task.id); log::debug!("Task {} failed and reached the maximum retries", task.id);
self.finalize_task(task, Err(error)).await?; self.finalize_task(task, result).await?;
} }
} }
} }
@ -128,36 +183,36 @@ where
} }
async fn finalize_task( async fn finalize_task(
&mut self, &self,
task: Task, task: Task,
result: Result<(), Box<dyn Error + Send + 'static>>, result: Result<(), TaskExecError>,
) -> Result<(), BackieError> { ) -> Result<(), BackieError> {
match self.retention_mode { match self.retention_mode {
RetentionMode::KeepAll => match result { RetentionMode::KeepAll => match result {
Ok(_) => { Ok(_) => {
self.queue.set_task_done(task.id).await?; self.store.set_task_state(task.id, TaskState::Done).await?;
log::debug!("Task {} done and kept in the database", task.id); log::debug!("Task {} done and kept in the database", task.id);
} }
Err(error) => { Err(error) => {
log::debug!("Task {} failed and kept in the database", task.id); log::debug!("Task {} failed and kept in the database", task.id);
self.queue self.store
.set_task_failed(task.id, &format!("{}", error)) .set_task_state(task.id, TaskState::Failed(format!("{}", error)))
.await?; .await?;
} }
}, },
RetentionMode::RemoveAll => { RetentionMode::RemoveAll => {
log::debug!("Task {} finalized and deleted from the database", task.id); log::debug!("Task {} finalized and deleted from the database", task.id);
self.queue.remove_task(task.id).await?; self.store.remove_task(task.id).await?;
} }
RetentionMode::RemoveDone => match result { RetentionMode::RemoveDone => match result {
Ok(_) => { Ok(_) => {
log::debug!("Task {} done and deleted from the database", task.id); log::debug!("Task {} done and deleted from the database", task.id);
self.queue.remove_task(task.id).await?; self.store.remove_task(task.id).await?;
} }
Err(error) => { Err(error) => {
log::debug!("Task {} failed and kept in the database", task.id); log::debug!("Task {} failed and kept in the database", task.id);
self.queue self.store
.set_task_failed(task.id, &format!("{}", error)) .set_task_state(task.id, TaskState::Failed(format!("{}", error)))
.await?; .await?;
} }
}, },
@ -169,35 +224,19 @@ where
#[cfg(test)] #[cfg(test)]
mod async_worker_tests { mod async_worker_tests {
use std::fmt::Display;
use super::*; use super::*;
use crate::queue::PgAsyncQueue;
use crate::queue::Queueable;
use crate::task::TaskState;
use crate::worker::Task;
use crate::RetentionMode;
use crate::Scheduled;
use async_trait::async_trait; use async_trait::async_trait;
use chrono::Duration;
use chrono::Utc;
use diesel_async::pooled_connection::{bb8::Pool, AsyncDieselConnectionManager}; use diesel_async::pooled_connection::{bb8::Pool, AsyncDieselConnectionManager};
use diesel_async::AsyncPgConnection; use diesel_async::AsyncPgConnection;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use thiserror::Error;
#[derive(Error, Debug)] #[derive(thiserror::Error, Debug)]
enum TaskError { enum TaskError {
#[error("Something went wrong")]
SomethingWrong, SomethingWrong,
Custom(String),
}
impl Display for TaskError { #[error("{0}")]
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { Custom(String),
match self {
TaskError::SomethingWrong => write!(f, "Something went wrong"),
TaskError::Custom(message) => write!(f, "{}", message),
}
}
} }
#[derive(Serialize, Deserialize)] #[derive(Serialize, Deserialize)]
@ -205,13 +244,12 @@ mod async_worker_tests {
pub number: u16, pub number: u16,
} }
#[typetag::serde]
#[async_trait] #[async_trait]
impl RunnableTask for WorkerAsyncTask { impl BackgroundTask for WorkerAsyncTask {
async fn run( const TASK_NAME: &'static str = "WorkerAsyncTask";
&self, type AppData = ();
_queueable: &mut dyn Queueable,
) -> Result<(), Box<(dyn std::error::Error + Send + 'static)>> { async fn run(&self, _: CurrentTask, _: Self::AppData) -> Result<(), anyhow::Error> {
Ok(()) Ok(())
} }
} }
@ -221,18 +259,18 @@ mod async_worker_tests {
pub number: u16, pub number: u16,
} }
#[typetag::serde]
#[async_trait] #[async_trait]
impl RunnableTask for WorkerAsyncTaskSchedule { impl BackgroundTask for WorkerAsyncTaskSchedule {
async fn run( const TASK_NAME: &'static str = "WorkerAsyncTaskSchedule";
&self, type AppData = ();
_queueable: &mut dyn Queueable,
) -> Result<(), Box<(dyn std::error::Error + Send + 'static)>> { async fn run(&self, _task: CurrentTask, _data: Self::AppData) -> Result<(), anyhow::Error> {
Ok(()) Ok(())
} }
fn cron(&self) -> Option<Scheduled> {
Some(Scheduled::ScheduleOnce(Utc::now() + Duration::seconds(1))) // fn cron(&self) -> Option<Scheduled> {
} // Some(Scheduled::ScheduleOnce(Utc::now() + Duration::seconds(1)))
// }
} }
#[derive(Serialize, Deserialize)] #[derive(Serialize, Deserialize)]
@ -240,16 +278,15 @@ mod async_worker_tests {
pub number: u16, pub number: u16,
} }
#[typetag::serde]
#[async_trait] #[async_trait]
impl RunnableTask for AsyncFailedTask { impl BackgroundTask for AsyncFailedTask {
async fn run( const TASK_NAME: &'static str = "AsyncFailedTask";
&self, type AppData = ();
_queueable: &mut dyn Queueable,
) -> Result<(), Box<(dyn std::error::Error + Send + 'static)>> { async fn run(&self, _task: CurrentTask, _data: Self::AppData) -> Result<(), anyhow::Error> {
let message = format!("number {} is wrong :(", self.number); let message = format!("number {} is wrong :(", self.number);
Err(Box::new(TaskError::Custom(message))) Err(TaskError::Custom(message).into())
} }
fn max_retries(&self) -> i32 { fn max_retries(&self) -> i32 {
@ -260,282 +297,40 @@ mod async_worker_tests {
#[derive(Serialize, Deserialize, Clone)] #[derive(Serialize, Deserialize, Clone)]
struct AsyncRetryTask {} struct AsyncRetryTask {}
#[typetag::serde]
#[async_trait] #[async_trait]
impl RunnableTask for AsyncRetryTask { impl BackgroundTask for AsyncRetryTask {
async fn run( const TASK_NAME: &'static str = "AsyncRetryTask";
&self, type AppData = ();
_queueable: &mut dyn Queueable,
) -> Result<(), Box<(dyn std::error::Error + Send + 'static)>> {
Err(Box::new(TaskError::SomethingWrong))
}
fn max_retries(&self) -> i32 { async fn run(&self, _task: CurrentTask, _data: Self::AppData) -> Result<(), anyhow::Error> {
2 Err(TaskError::SomethingWrong.into())
} }
} }
#[derive(Serialize, Deserialize)] #[derive(Serialize, Deserialize)]
struct AsyncTaskType1 {} struct AsyncTaskType1 {}
#[typetag::serde]
#[async_trait] #[async_trait]
impl RunnableTask for AsyncTaskType1 { impl BackgroundTask for AsyncTaskType1 {
async fn run( const TASK_NAME: &'static str = "AsyncTaskType1";
&self, type AppData = ();
_queueable: &mut dyn Queueable,
) -> Result<(), Box<(dyn std::error::Error + Send + 'static)>> {
Ok(())
}
fn task_type(&self) -> TaskType { async fn run(&self, _task: CurrentTask, _data: Self::AppData) -> Result<(), anyhow::Error> {
"type1".into() Ok(())
} }
} }
#[derive(Serialize, Deserialize)] #[derive(Serialize, Deserialize)]
struct AsyncTaskType2 {} struct AsyncTaskType2 {}
#[typetag::serde]
#[async_trait] #[async_trait]
impl RunnableTask for AsyncTaskType2 { impl BackgroundTask for AsyncTaskType2 {
async fn run( const TASK_NAME: &'static str = "AsyncTaskType2";
&self, type AppData = ();
_queueable: &mut dyn Queueable,
) -> Result<(), Box<(dyn std::error::Error + Send + 'static)>> { async fn run(&self, _task: CurrentTask, _data: Self::AppData) -> Result<(), anyhow::Error> {
Ok(()) Ok(())
} }
fn task_type(&self) -> TaskType {
TaskType::from("type2")
}
}
#[tokio::test]
async fn execute_and_finishes_task() {
    // Happy path: a task is inserted, executed once by a worker, and ends
    // up in the `Done` state while being retained (RetentionMode::KeepAll).
    let connection_pool = pool().await;
    let mut queue = PgAsyncQueue::new(connection_pool);

    let runnable = WorkerAsyncTask { number: 1 };
    let stored_task = insert_task(&mut queue, &runnable).await;
    let task_id = stored_task.id;

    let mut worker = Worker::<PgAsyncQueue>::builder()
        .queue(queue.clone())
        .retention_mode(RetentionMode::KeepAll)
        .build();
    worker.run(stored_task, Box::new(runnable)).await.unwrap();

    // The task row must still exist (KeepAll) and be marked Done.
    let finished = queue.find_task_by_id(task_id).await.unwrap();
    assert_eq!(task_id, finished.id);
    assert_eq!(TaskState::Done, finished.state());

    queue.remove_all_tasks().await.unwrap();
}
// #[tokio::test]
// async fn schedule_task_test() {
// let pool = pool().await;
// let mut test = PgAsyncQueue::new(pool);
//
// let actual_task = WorkerAsyncTaskSchedule { number: 1 };
//
// let task = test.schedule_task(&actual_task).await.unwrap();
//
// let id = task.id;
//
// let mut worker = AsyncWorker::<PgAsyncQueue>::builder()
// .queue(test.clone())
// .retention_mode(RetentionMode::KeepAll)
// .build();
//
// worker.run_tasks_until_none().await.unwrap();
//
// let task = worker.queue.find_task_by_id(id).await.unwrap();
//
// assert_eq!(id, task.id);
// assert_eq!(TaskState::Ready, task.state());
//
// tokio::time::sleep(core::time::Duration::from_secs(3)).await;
//
// worker.run_tasks_until_none().await.unwrap();
//
// let task = test.find_task_by_id(id).await.unwrap();
// assert_eq!(id, task.id);
// assert_eq!(TaskState::Done, task.state());
//
// test.remove_all_tasks().await.unwrap();
// }
#[tokio::test]
async fn retries_task_test() {
    // Verifies the retry lifecycle of a failing task: each failed run bumps
    // `retries`, keeps the task in `Ready` until the retry budget is spent,
    // and finally marks it `Failed` with the error message recorded.
    let pool = pool().await;
    let mut test = PgAsyncQueue::new(pool);

    // AsyncRetryTask always fails and allows max_retries == 2.
    let actual_task = AsyncRetryTask {};

    let task = test.create_task(&actual_task).await.unwrap();
    let id = task.id;

    let mut worker = Worker::<PgAsyncQueue>::builder()
        .queue(test.clone())
        .retention_mode(RetentionMode::KeepAll)
        .build();

    // First attempt fails: task goes back to Ready with retries == 1 and an
    // error message stored.
    worker.run_tasks_until_none().await.unwrap();
    let task = worker.queue.find_task_by_id(id).await.unwrap();
    assert_eq!(id, task.id);
    assert_eq!(TaskState::Ready, task.state());
    assert_eq!(1, task.retries);
    assert!(task.error_message.is_some());

    // Sleep so the scheduled retry becomes due (presumably covers the backoff
    // delay for the first retry -- TODO confirm against the backoff policy).
    tokio::time::sleep(core::time::Duration::from_secs(5)).await;
    worker.run_tasks_until_none().await.unwrap();
    let task = worker.queue.find_task_by_id(id).await.unwrap();
    assert_eq!(id, task.id);
    assert_eq!(TaskState::Ready, task.state());
    assert_eq!(2, task.retries);

    // Third failure exhausts max_retries, so the task is permanently Failed.
    tokio::time::sleep(core::time::Duration::from_secs(10)).await;
    worker.run_tasks_until_none().await.unwrap();
    let task = test.find_task_by_id(id).await.unwrap();
    assert_eq!(id, task.id);
    assert_eq!(TaskState::Failed, task.state());
    assert_eq!("Something went wrong".to_string(), task.error_message.unwrap());

    test.remove_all_tasks().await.unwrap();
}
#[tokio::test]
async fn worker_shutsdown_when_notified() {
    // A worker blocked in `run_tasks` must exit promptly once the shutdown
    // watch channel fires; otherwise the 1s timeout branch panics the test.
    let connection_pool = pool().await;
    let queue = PgAsyncQueue::new(connection_pool);

    let (shutdown_tx, shutdown_rx) = tokio::sync::watch::channel(());
    let mut worker = Worker::<PgAsyncQueue>::builder()
        .queue(queue)
        .shutdown(shutdown_rx)
        .build();

    let worker_handle = tokio::spawn(async move {
        worker.run_tasks().await.unwrap();
        true
    });

    shutdown_tx.send(()).unwrap();

    select! {
        _ = worker_handle.fuse() => {}
        _ = tokio::time::sleep(core::time::Duration::from_secs(1)).fuse() => panic!("Worker did not shutdown")
    }
}
#[tokio::test]
async fn saves_error_for_failed_task() {
    // A task that fails must transition to `Failed` and have the error's
    // message persisted on the task row.
    let connection_pool = pool().await;
    let mut queue = PgAsyncQueue::new(connection_pool);

    let failing = AsyncFailedTask { number: 1 };
    let stored_task = insert_task(&mut queue, &failing).await;
    let task_id = stored_task.id;

    let mut worker = Worker::<PgAsyncQueue>::builder()
        .queue(queue.clone())
        .retention_mode(RetentionMode::KeepAll)
        .build();
    worker.run(stored_task, Box::new(failing)).await.unwrap();

    let finished = queue.find_task_by_id(task_id).await.unwrap();
    assert_eq!(task_id, finished.id);
    assert_eq!(TaskState::Failed, finished.state());
    // Message produced by AsyncFailedTask for number == 1.
    assert_eq!(
        "number 1 is wrong :(".to_string(),
        finished.error_message.unwrap()
    );

    queue.remove_all_tasks().await.unwrap();
}
#[tokio::test]
async fn executes_task_only_of_specific_type() {
    // A worker bound to task type "type1" must complete only "type1" tasks,
    // leaving the "type2" task untouched in `Ready`.
    let connection_pool = pool().await;
    let mut queue = PgAsyncQueue::new(connection_pool);

    let first = insert_task(&mut queue, &AsyncTaskType1 {}).await;
    let second = insert_task(&mut queue, &AsyncTaskType1 {}).await;
    let other = insert_task(&mut queue, &AsyncTaskType2 {}).await;
    let (id1, id12, id2) = (first.id, second.id, other.id);

    let mut worker = Worker::<PgAsyncQueue>::builder()
        .queue(queue.clone())
        .task_type(TaskType::from("type1"))
        .retention_mode(RetentionMode::KeepAll)
        .build();
    worker.run_tasks_until_none().await.unwrap();

    let first = queue.find_task_by_id(id1).await.unwrap();
    let second = queue.find_task_by_id(id12).await.unwrap();
    let other = queue.find_task_by_id(id2).await.unwrap();

    assert_eq!(id1, first.id);
    assert_eq!(id12, second.id);
    assert_eq!(id2, other.id);
    assert_eq!(TaskState::Done, first.state());
    assert_eq!(TaskState::Done, second.state());
    assert_eq!(TaskState::Ready, other.state());

    queue.remove_all_tasks().await.unwrap();
}
#[tokio::test]
async fn remove_when_finished() {
    // With the default retention mode, finished "type1" tasks are removed
    // from the queue, while the untouched "type2" task remains pullable.
    let connection_pool = pool().await;
    let mut queue = PgAsyncQueue::new(connection_pool);

    let first = insert_task(&mut queue, &AsyncTaskType1 {}).await;
    let second = insert_task(&mut queue, &AsyncTaskType1 {}).await;
    let other = insert_task(&mut queue, &AsyncTaskType2 {}).await;
    let _id1 = first.id;
    let _id12 = second.id;
    let id2 = other.id;

    let mut worker = Worker::<PgAsyncQueue>::builder()
        .queue(queue.clone())
        .task_type(TaskType::from("type1"))
        .build();
    worker.run_tasks_until_none().await.unwrap();

    // Nothing of "type1" is left to pull.
    let leftover = queue
        .pull_next_task(Some(TaskType::from("type1")))
        .await
        .unwrap();
    assert_eq!(None, leftover);

    // The "type2" task is still present and pullable.
    let pulled = queue
        .pull_next_task(Some(TaskType::from("type2")))
        .await
        .unwrap()
        .unwrap();
    assert_eq!(id2, pulled.id);

    queue.remove_all_tasks().await.unwrap();
}
/// Test helper: persist `task` in the given queue and return the stored row.
async fn insert_task(test: &mut PgAsyncQueue, task: &dyn RunnableTask) -> Task {
    test.create_task(task).await.unwrap()
} }
async fn pool() -> Pool<AsyncPgConnection> { async fn pool() -> Pool<AsyncPgConnection> {

View file

@ -1,102 +1,200 @@
use crate::errors::BackieError; use crate::errors::BackieError;
use crate::queue::Queueable; use crate::queue::Queue;
use crate::task::TaskType; use crate::worker::{runnable, ExecuteTaskFn};
use crate::worker::Worker; use crate::worker::{StateFn, Worker};
use crate::RetentionMode; use crate::{BackgroundTask, CurrentTask, PgTaskStore, RetentionMode};
use async_recursion::async_recursion; use std::collections::BTreeMap;
use log::error;
use std::future::Future; use std::future::Future;
use tokio::sync::watch::Receiver; use std::sync::Arc;
use typed_builder::TypedBuilder; use tokio::task::JoinHandle;
#[derive(TypedBuilder, Clone)] pub type AppDataFn<AppData> = Arc<dyn Fn(Queue) -> AppData + Send + Sync>;
pub struct WorkerPool<AQueue>
#[derive(Clone)]
pub struct WorkerPool<AppData>
where where
AQueue: Queueable + Clone + Sync + 'static, AppData: Clone + Send + 'static,
{ {
#[builder(setter(into))] /// Storage of tasks.
/// the AsyncWorkerPool uses a queue to control the tasks that will be executed. queue_store: Arc<PgTaskStore>, // TODO: make this generic/dynamic referenced
pub queue: AQueue,
/// retention_mode controls if tasks should be persisted after execution /// Queue used to spawn tasks.
#[builder(default, setter(into))] queue: Queue,
pub retention_mode: RetentionMode,
/// the number of workers of the AsyncWorkerPool. /// Make possible to load the application data.
#[builder(setter(into))] ///
pub number_of_workers: u32, /// The application data is loaded when the worker pool is started and is passed to the tasks.
/// The loading function accepts a queue instance in case the application data depends on it. This
/// is interesting for situations where the application wants to allow tasks to spawn other tasks.
application_data_fn: StateFn<AppData>,
/// The type of tasks that will be executed by `AsyncWorkerPool`. /// The types of task the worker pool can execute and the loaders for them.
#[builder(default, setter(into))] task_registry: BTreeMap<String, ExecuteTaskFn<AppData>>,
pub task_type: Option<TaskType>,
/// Number of workers that will be spawned per queue.
worker_queues: BTreeMap<String, (RetentionMode, u32)>,
} }
// impl<TypedBuilderFields, Q> AsyncWorkerBuilder<TypedBuilderFields, Q> impl<AppData> WorkerPool<AppData>
// where
// TypedBuilderFields: Clone,
// Q: Queueable + Clone + Sync + 'static,
// {
// pub fn with_graceful_shutdown<F>(self, signal: F) -> Self<TypedBuilderFields, Q>
// where
// F: Future<Output = ()>,
// {
// self
// }
// }
impl<AQueue> WorkerPool<AQueue>
where where
AQueue: Queueable + Clone + Sync + 'static, AppData: Clone + Send + 'static,
{ {
/// Starts the configured number of workers /// Create a new worker pool.
/// This is necessary in order to execute tasks. pub fn new<A>(queue_store: PgTaskStore, application_data_fn: A) -> Self
pub async fn start<F>(&mut self, graceful_shutdown: F) -> Result<(), BackieError> where
A: Fn(Queue) -> AppData + Send + Sync + 'static,
{
let queue_store = Arc::new(queue_store);
let queue = Queue::new(queue_store.clone());
let application_data_fn = {
let queue = queue.clone();
move || application_data_fn(queue.clone())
};
Self {
queue_store,
queue,
application_data_fn: Arc::new(application_data_fn),
task_registry: BTreeMap::new(),
worker_queues: BTreeMap::new(),
}
}
/// Register a task type with the worker pool.
pub fn register_task_type<BT>(mut self, num_workers: u32, retention_mode: RetentionMode) -> Self
where
BT: BackgroundTask<AppData = AppData>,
{
self.worker_queues
.insert(BT::QUEUE.to_string(), (retention_mode, num_workers));
self.task_registry
.insert(BT::TASK_NAME.to_string(), Arc::new(runnable::<BT>));
self
}
pub async fn start<F>(
self,
graceful_shutdown: F,
) -> Result<(JoinHandle<()>, Queue), BackieError>
where where
F: Future<Output = ()> + Send + 'static, F: Future<Output = ()> + Send + 'static,
{ {
let (tx, rx) = tokio::sync::watch::channel(()); let (tx, rx) = tokio::sync::watch::channel(());
for idx in 0..self.number_of_workers {
let pool = self.clone(); // Spawn all individual workers per queue
// TODO: the worker pool keeps track of the number of workers and spawns new workers as needed. for (queue_name, (retention_mode, num_workers)) in self.worker_queues.iter() {
// There should be always a minimum number of workers active waiting for tasks to execute for idx in 0..*num_workers {
// or for a gracefull shutdown. let mut worker: Worker<AppData> = Worker::new(
tokio::spawn(Self::supervise_task(pool, rx.clone(), 0, idx)); self.queue_store.clone(),
queue_name.clone(),
retention_mode.clone(),
self.task_registry.clone(),
self.application_data_fn.clone(),
Some(rx.clone()),
);
let worker_name = format!("worker-{queue_name}-{idx}");
// TODO: grab the join handle for every worker for graceful shutdown
tokio::spawn(async move {
worker.run_tasks().await;
log::info!("Worker {} stopped", worker_name);
});
} }
}
Ok((
tokio::spawn(async move {
graceful_shutdown.await; graceful_shutdown.await;
tx.send(())?; if let Err(err) = tx.send(()) {
log::warn!("Failed to send shutdown signal to worker pool: {}", err);
} else {
log::info!("Worker pool stopped gracefully"); log::info!("Worker pool stopped gracefully");
}
}),
self.queue,
))
}
}
#[cfg(test)]
mod tests {
use super::*;
use async_trait::async_trait;
use diesel_async::pooled_connection::{bb8::Pool, AsyncDieselConnectionManager};
use diesel_async::AsyncPgConnection;
#[derive(Clone, Debug)]
struct ApplicationContext {
app_name: String,
}
impl ApplicationContext {
fn new() -> Self {
Self {
app_name: "Backie".to_string(),
}
}
fn get_app_name(&self) -> String {
self.app_name.clone()
}
}
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
struct GreetingTask {
person: String,
}
#[async_trait]
impl BackgroundTask for GreetingTask {
const TASK_NAME: &'static str = "my_task";
type AppData = ApplicationContext;
async fn run(
&self,
task_info: CurrentTask,
app_context: Self::AppData,
) -> Result<(), anyhow::Error> {
println!(
"[{}] Hello {}! I'm {}.",
task_info.id(),
self.person,
app_context.get_app_name()
);
Ok(()) Ok(())
} }
}
#[async_recursion] #[tokio::test]
async fn supervise_task( async fn test_worker_pool() {
pool: WorkerPool<AQueue>, let my_app_context = ApplicationContext::new();
receiver: Receiver<()>,
restarts: u64,
worker_number: u32,
) {
let restarts = restarts + 1;
let inner_pool = pool.clone(); let task_store = PgTaskStore::new(pool().await);
let inner_receiver = receiver.clone();
let join_handle = tokio::spawn(async move { let (join_handle, queue) = WorkerPool::new(task_store, move |_| my_app_context.clone())
let mut worker: Worker<AQueue> = Worker::builder() .register_task_type::<GreetingTask>(1, RetentionMode::RemoveDone)
.queue(inner_pool.queue.clone()) .start(futures::future::ready(()))
.retention_mode(inner_pool.retention_mode) .await
.task_type(inner_pool.task_type.clone()) .unwrap();
.shutdown(inner_receiver)
.build();
worker.run_tasks().await queue
}); .enqueue(GreetingTask {
person: "Rafael".to_string(),
})
.await
.unwrap();
if (join_handle.await).is_err() { join_handle.await.unwrap();
error!( }
"Worker {} stopped. Restarting. the number of restarts {}",
worker_number, restarts, async fn pool() -> Pool<AsyncPgConnection> {
let manager = AsyncDieselConnectionManager::<AsyncPgConnection>::new(
option_env!("DATABASE_URL").expect("DATABASE_URL must be set"),
); );
Self::supervise_task(pool, receiver, restarts, worker_number).await; Pool::builder()
} .max_size(1)
.min_idle(Some(1))
.build(manager)
.await
.unwrap()
} }
} }