Make possible to provide app state to tasks
This commit is contained in:
parent
7fcb63f75c
commit
aac0b44c7f
19 changed files with 776 additions and 1151 deletions
1
.gitignore
vendored
1
.gitignore
vendored
|
@ -1,5 +1,4 @@
|
|||
**/target
|
||||
Cargo.lock
|
||||
src/schema.rs
|
||||
docs/content/docs/CHANGELOG.md
|
||||
docs/content/docs/README.md
|
||||
|
|
11
Cargo.toml
11
Cargo.toml
|
@ -19,16 +19,13 @@ cron = "0.12"
|
|||
chrono = "0.4"
|
||||
hex = "0.4"
|
||||
log = "0.4"
|
||||
serde = "1.0"
|
||||
serde_derive = "1.0"
|
||||
serde_json = "1.0"
|
||||
serde = { version = "1", features = ["derive"] }
|
||||
serde_json = "1"
|
||||
sha2 = "0.10"
|
||||
thiserror = "1.0"
|
||||
typed-builder = "0.13"
|
||||
typetag = "0.2"
|
||||
anyhow = "1"
|
||||
thiserror = "1"
|
||||
uuid = { version = "1.1", features = ["v4", "serde"] }
|
||||
async-trait = "0.1"
|
||||
async-recursion = "1"
|
||||
futures = "0.3"
|
||||
diesel = { version = "2.0", features = ["postgres", "serde_json", "chrono", "uuid"] }
|
||||
diesel-derive-newtype = "2.0.0-rc.0"
|
||||
|
|
|
@ -1,12 +0,0 @@
|
|||
[package]
|
||||
name = "simple_cron_async_worker"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
|
||||
[dependencies]
|
||||
fang = { path = "../../../" , features = ["asynk"]}
|
||||
env_logger = "0.9.0"
|
||||
log = "0.4.0"
|
||||
tokio = { version = "1", features = ["full"] }
|
|
@ -1,33 +0,0 @@
|
|||
use fang::async_trait;
|
||||
use fang::asynk::async_queue::AsyncQueueable;
|
||||
use fang::serde::{Deserialize, Serialize};
|
||||
use fang::typetag;
|
||||
use fang::AsyncRunnable;
|
||||
use fang::FangError;
|
||||
use fang::Scheduled;
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
#[serde(crate = "fang::serde")]
|
||||
pub struct MyCronTask {}
|
||||
|
||||
#[async_trait]
|
||||
#[typetag::serde]
|
||||
impl AsyncRunnable for MyCronTask {
|
||||
async fn run(&self, _queue: &mut dyn AsyncQueueable) -> Result<(), FangError> {
|
||||
log::info!("CRON!!!!!!!!!!!!!!!",);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn cron(&self) -> Option<Scheduled> {
|
||||
// sec min hour day of month month day of week year
|
||||
// be careful works only with UTC hour.
|
||||
// https://www.timeanddate.com/worldclock/timezone/utc
|
||||
let expression = "0/20 * * * Aug-Sep * 2022/1";
|
||||
Some(Scheduled::CronPattern(expression.to_string()))
|
||||
}
|
||||
|
||||
fn uniq(&self) -> bool {
|
||||
true
|
||||
}
|
||||
}
|
|
@ -1,41 +0,0 @@
|
|||
use fang::asynk::async_queue::AsyncQueue;
|
||||
use fang::asynk::async_queue::AsyncQueueable;
|
||||
use fang::asynk::async_worker_pool::AsyncWorkerPool;
|
||||
use fang::AsyncRunnable;
|
||||
use fang::NoTls;
|
||||
use simple_cron_async_worker::MyCronTask;
|
||||
use std::time::Duration;
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() {
|
||||
env_logger::init();
|
||||
|
||||
log::info!("Starting...");
|
||||
let max_pool_size: u32 = 3;
|
||||
let mut queue = AsyncQueue::builder()
|
||||
.uri("postgres://postgres:postgres@localhost/fang")
|
||||
.max_pool_size(max_pool_size)
|
||||
.build();
|
||||
|
||||
queue.connect(NoTls).await.unwrap();
|
||||
log::info!("Queue connected...");
|
||||
|
||||
let mut pool: AsyncWorkerPool<AsyncQueue<NoTls>> = AsyncWorkerPool::builder()
|
||||
.number_of_workers(10_u32)
|
||||
.queue(queue.clone())
|
||||
.build();
|
||||
|
||||
log::info!("Pool created ...");
|
||||
|
||||
pool.start().await;
|
||||
log::info!("Workers started ...");
|
||||
|
||||
let task = MyCronTask {};
|
||||
|
||||
queue
|
||||
.schedule_task(&task as &dyn AsyncRunnable)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
tokio::time::sleep(Duration::from_secs(100)).await;
|
||||
}
|
|
@ -5,6 +5,7 @@ edition = "2021"
|
|||
|
||||
[dependencies]
|
||||
backie = { path = "../../" }
|
||||
anyhow = "1"
|
||||
env_logger = "0.9.0"
|
||||
log = "0.4.0"
|
||||
tokio = { version = "1", features = ["full"] }
|
||||
|
@ -12,4 +13,3 @@ diesel-async = { version = "0.2", features = ["postgres", "bb8"] }
|
|||
diesel = { version = "2.0", features = ["postgres"] }
|
||||
async-trait = "0.1"
|
||||
serde = { version = "1.0", features = ["derive"] }
|
||||
typetag = "0.2"
|
||||
|
|
|
@ -1,7 +1,20 @@
|
|||
use std::time::Duration;
|
||||
use async_trait::async_trait;
|
||||
use serde::{Serialize, Deserialize};
|
||||
use backie::{RunnableTask, Queueable};
|
||||
use backie::{BackgroundTask, CurrentTask};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::time::Duration;
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct MyApplicationContext {
|
||||
app_name: String,
|
||||
}
|
||||
|
||||
impl MyApplicationContext {
|
||||
pub fn new(app_name: &str) -> Self {
|
||||
Self {
|
||||
app_name: app_name.to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub struct MyTask {
|
||||
|
@ -26,37 +39,51 @@ impl MyFailingTask {
|
|||
}
|
||||
|
||||
#[async_trait]
|
||||
#[typetag::serde]
|
||||
impl RunnableTask for MyTask {
|
||||
async fn run(&self, _queue: &mut dyn Queueable) -> Result<(), Box<dyn std::error::Error + Send + 'static>> {
|
||||
impl BackgroundTask for MyTask {
|
||||
const TASK_NAME: &'static str = "my_task";
|
||||
type AppData = MyApplicationContext;
|
||||
|
||||
async fn run(&self, task: CurrentTask, ctx: Self::AppData) -> Result<(), anyhow::Error> {
|
||||
// let new_task = MyTask::new(self.number + 1);
|
||||
// queue
|
||||
// .insert_task(&new_task as &dyn AsyncRunnable)
|
||||
// .insert_task(&new_task)
|
||||
// .await
|
||||
// .unwrap();
|
||||
|
||||
log::info!("the current number is {}", self.number);
|
||||
log::info!(
|
||||
"[{}] Hello from {}! the current number is {}",
|
||||
task.id(),
|
||||
ctx.app_name,
|
||||
self.number
|
||||
);
|
||||
tokio::time::sleep(Duration::from_secs(3)).await;
|
||||
|
||||
log::info!("done..");
|
||||
log::info!("[{}] done..", task.id());
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
#[typetag::serde]
|
||||
impl RunnableTask for MyFailingTask {
|
||||
async fn run(&self, _queue: &mut dyn Queueable) -> Result<(), Box<dyn std::error::Error + Send + 'static>> {
|
||||
impl BackgroundTask for MyFailingTask {
|
||||
const TASK_NAME: &'static str = "my_failing_task";
|
||||
type AppData = MyApplicationContext;
|
||||
|
||||
async fn run(&self, task: CurrentTask, _ctx: Self::AppData) -> Result<(), anyhow::Error> {
|
||||
// let new_task = MyFailingTask::new(self.number + 1);
|
||||
// queue
|
||||
// .insert_task(&new_task as &dyn AsyncRunnable)
|
||||
// .insert_task(&new_task)
|
||||
// .await
|
||||
// .unwrap();
|
||||
|
||||
log::info!("the current number is {}", self.number);
|
||||
// task.id();
|
||||
// task.keep_alive().await?;
|
||||
// task.previous_error();
|
||||
// task.retry_count();
|
||||
|
||||
log::info!("[{}] the current number is {}", task.id(), self.number);
|
||||
tokio::time::sleep(Duration::from_secs(3)).await;
|
||||
|
||||
log::info!("done..");
|
||||
log::info!("[{}] done..", task.id());
|
||||
//
|
||||
// let b = true;
|
||||
//
|
||||
|
|
|
@ -1,9 +1,9 @@
|
|||
use simple_worker::MyFailingTask;
|
||||
use simple_worker::MyTask;
|
||||
use std::time::Duration;
|
||||
use backie::{PgTaskStore, RetentionMode, WorkerPool};
|
||||
use diesel_async::pg::AsyncPgConnection;
|
||||
use diesel_async::pooled_connection::{bb8::Pool, AsyncDieselConnectionManager};
|
||||
use backie::{PgAsyncQueue, WorkerPool, Queueable};
|
||||
use simple_worker::MyApplicationContext;
|
||||
use simple_worker::MyFailingTask;
|
||||
use simple_worker::MyTask;
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() {
|
||||
|
@ -22,49 +22,38 @@ async fn main() {
|
|||
.unwrap();
|
||||
log::info!("Pool created ...");
|
||||
|
||||
let mut queue = PgAsyncQueue::new(pool);
|
||||
let task_store = PgTaskStore::new(pool);
|
||||
|
||||
let (tx, mut rx) = tokio::sync::watch::channel(false);
|
||||
|
||||
let executor_task = tokio::spawn({
|
||||
let mut queue = queue.clone();
|
||||
async move {
|
||||
let mut workers_pool: WorkerPool<PgAsyncQueue> = WorkerPool::builder()
|
||||
.number_of_workers(10_u32)
|
||||
.queue(queue)
|
||||
.build();
|
||||
// Some global application context I want to pass to my background tasks
|
||||
let my_app_context = MyApplicationContext::new("Backie Example App");
|
||||
|
||||
log::info!("Workers starting ...");
|
||||
workers_pool.start(async move {
|
||||
rx.changed().await;
|
||||
}).await;
|
||||
log::info!("Workers stopped!");
|
||||
}
|
||||
});
|
||||
// Register the task types I want to use and start the worker pool
|
||||
let (join_handle, queue) = WorkerPool::new(task_store, move |_| my_app_context.clone())
|
||||
.register_task_type::<MyTask>(1, RetentionMode::RemoveDone)
|
||||
.register_task_type::<MyFailingTask>(1, RetentionMode::RemoveDone)
|
||||
.start(async move {
|
||||
let _ = rx.changed().await;
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
log::info!("Workers started ...");
|
||||
|
||||
let task1 = MyTask::new(0);
|
||||
let task2 = MyTask::new(20_000);
|
||||
let task3 = MyFailingTask::new(50_000);
|
||||
|
||||
queue
|
||||
.create_task(&task1)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
queue
|
||||
.create_task(&task2)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
queue
|
||||
.create_task(&task3)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
queue.enqueue(task1).await.unwrap();
|
||||
queue.enqueue(task2).await.unwrap();
|
||||
queue.enqueue(task3).await.unwrap();
|
||||
log::info!("Tasks created ...");
|
||||
tokio::signal::ctrl_c().await;
|
||||
|
||||
// Wait for Ctrl+C
|
||||
let _ = tokio::signal::ctrl_c().await;
|
||||
log::info!("Stopping ...");
|
||||
tx.send(true).unwrap();
|
||||
executor_task.await.unwrap();
|
||||
log::info!("Stopped!");
|
||||
join_handle.await.unwrap();
|
||||
log::info!("Workers Stopped!");
|
||||
}
|
||||
|
|
|
@ -2,20 +2,20 @@ CREATE EXTENSION IF NOT EXISTS "uuid-ossp";
|
|||
|
||||
CREATE TABLE backie_tasks (
|
||||
id uuid PRIMARY KEY DEFAULT uuid_generate_v4(),
|
||||
payload jsonb NOT NULL,
|
||||
error_message TEXT DEFAULT NULL,
|
||||
task_type VARCHAR DEFAULT 'common' NOT NULL,
|
||||
task_name VARCHAR NOT NULL,
|
||||
queue_name VARCHAR DEFAULT 'common' NOT NULL,
|
||||
uniq_hash CHAR(64) DEFAULT NULL,
|
||||
retries INTEGER DEFAULT 0 NOT NULL,
|
||||
payload jsonb NOT NULL,
|
||||
timeout_msecs INT8 NOT NULL,
|
||||
created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(),
|
||||
scheduled_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(),
|
||||
running_at TIMESTAMP WITH TIME ZONE DEFAULT NULL,
|
||||
done_at TIMESTAMP WITH TIME ZONE DEFAULT NULL
|
||||
done_at TIMESTAMP WITH TIME ZONE DEFAULT NULL,
|
||||
error_info jsonb DEFAULT NULL,
|
||||
retries INTEGER DEFAULT 0 NOT NULL,
|
||||
max_retries INTEGER DEFAULT 0 NOT NULL
|
||||
);
|
||||
|
||||
CREATE INDEX backie_tasks_type_index ON backie_tasks(task_type);
|
||||
CREATE INDEX backie_tasks_created_at_index ON backie_tasks(created_at);
|
||||
CREATE INDEX backie_tasks_uniq_hash ON backie_tasks(uniq_hash);
|
||||
|
||||
--- create uniqueness index
|
||||
CREATE UNIQUE INDEX backie_tasks_uniq_hash_index ON backie_tasks(uniq_hash) WHERE uniq_hash IS NOT NULL;
|
||||
|
||||
|
|
|
@ -1,25 +1,16 @@
|
|||
use serde_json::Error as SerdeError;
|
||||
use std::fmt::Display;
|
||||
use thiserror::Error;
|
||||
|
||||
/// Library errors
|
||||
#[derive(Debug, Error)]
|
||||
pub enum BackieError {
|
||||
#[error("Queue processing error: {0}")]
|
||||
QueueProcessingError(#[from] AsyncQueueError),
|
||||
SerializationError(#[from] SerdeError),
|
||||
ShutdownError(#[from] tokio::sync::watch::error::SendError<()>),
|
||||
}
|
||||
|
||||
impl Display for BackieError {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
BackieError::QueueProcessingError(error) => {
|
||||
write!(f, "Queue processing error: {}", error)
|
||||
}
|
||||
BackieError::SerializationError(error) => write!(f, "Serialization error: {}", error),
|
||||
BackieError::ShutdownError(error) => write!(f, "Shutdown error: {}", error),
|
||||
}
|
||||
}
|
||||
#[error("Worker Pool shutdown error: {0}")]
|
||||
WorkerPoolShutdownError(#[from] tokio::sync::watch::error::SendError<()>),
|
||||
|
||||
#[error("Worker shutdown error: {0}")]
|
||||
WorkerShutdownError(#[from] tokio::sync::watch::error::RecvError),
|
||||
}
|
||||
|
||||
/// List of error types that can occur while working with cron schedules.
|
||||
|
@ -29,7 +20,7 @@ pub enum CronError {
|
|||
#[error(transparent)]
|
||||
LibraryError(#[from] cron::error::Error),
|
||||
/// [`Scheduled`] enum variant is not provided
|
||||
#[error("You have to implement method `cron()` in your AsyncRunnable")]
|
||||
#[error("You have to implement method `cron()` in your Runnable")]
|
||||
TaskNotSchedulableError,
|
||||
/// The next execution can not be determined using the current [`Scheduled::CronPattern`]
|
||||
#[error("No timestamps match with this cron pattern")]
|
||||
|
@ -41,7 +32,7 @@ pub enum AsyncQueueError {
|
|||
#[error(transparent)]
|
||||
PgError(#[from] diesel::result::Error),
|
||||
|
||||
#[error(transparent)]
|
||||
#[error("Task serialization error: {0}")]
|
||||
SerdeError(#[from] serde_json::Error),
|
||||
|
||||
#[error(transparent)]
|
||||
|
@ -49,6 +40,9 @@ pub enum AsyncQueueError {
|
|||
|
||||
#[error("Task is not in progress, operation not allowed")]
|
||||
TaskNotRunning,
|
||||
|
||||
#[error("Task with name {0} is not registered")]
|
||||
TaskNotRegistered(String),
|
||||
}
|
||||
|
||||
impl From<cron::error::Error> for AsyncQueueError {
|
||||
|
|
|
@ -38,9 +38,9 @@ impl Default for RetentionMode {
|
|||
}
|
||||
}
|
||||
|
||||
pub use queue::PgAsyncQueue;
|
||||
pub use queue::Queueable;
|
||||
pub use runnable::RunnableTask;
|
||||
pub use queue::PgTaskStore;
|
||||
pub use runnable::BackgroundTask;
|
||||
pub use task::CurrentTask;
|
||||
pub use worker_pool::WorkerPool;
|
||||
|
||||
pub mod errors;
|
||||
|
@ -48,6 +48,7 @@ mod queries;
|
|||
pub mod queue;
|
||||
pub mod runnable;
|
||||
mod schema;
|
||||
pub mod store;
|
||||
pub mod task;
|
||||
pub mod worker;
|
||||
pub mod worker_pool;
|
||||
|
|
|
@ -1,8 +1,7 @@
|
|||
use crate::errors::AsyncQueueError;
|
||||
use crate::runnable::RunnableTask;
|
||||
use crate::schema::backie_tasks;
|
||||
use crate::task::Task;
|
||||
use crate::task::{NewTask, TaskHash, TaskId, TaskType};
|
||||
use crate::task::{NewTask, TaskHash, TaskId};
|
||||
use chrono::DateTime;
|
||||
use chrono::Duration;
|
||||
use chrono::Utc;
|
||||
|
@ -43,14 +42,6 @@ impl Task {
|
|||
Ok(qty > 0)
|
||||
}
|
||||
|
||||
pub(crate) async fn remove_by_type(
|
||||
connection: &mut AsyncPgConnection,
|
||||
task_type: TaskType,
|
||||
) -> Result<u64, AsyncQueueError> {
|
||||
let query = backie_tasks::table.filter(backie_tasks::task_type.eq(task_type));
|
||||
Ok(diesel::delete(query).execute(connection).await? as u64)
|
||||
}
|
||||
|
||||
pub(crate) async fn find_by_id(
|
||||
connection: &mut AsyncPgConnection,
|
||||
id: TaskId,
|
||||
|
@ -67,10 +58,13 @@ impl Task {
|
|||
id: TaskId,
|
||||
error_message: &str,
|
||||
) -> Result<Task, AsyncQueueError> {
|
||||
let error = serde_json::json!({
|
||||
"error": error_message,
|
||||
});
|
||||
let query = backie_tasks::table.filter(backie_tasks::id.eq(id));
|
||||
Ok(diesel::update(query)
|
||||
.set((
|
||||
backie_tasks::error_message.eq(error_message),
|
||||
backie_tasks::error_info.eq(Some(error)),
|
||||
backie_tasks::done_at.eq(Utc::now()),
|
||||
))
|
||||
.get_result::<Task>(connection)
|
||||
|
@ -81,18 +75,22 @@ impl Task {
|
|||
connection: &mut AsyncPgConnection,
|
||||
id: TaskId,
|
||||
backoff_seconds: u32,
|
||||
error: &str,
|
||||
error_message: &str,
|
||||
) -> Result<Task, AsyncQueueError> {
|
||||
use crate::schema::backie_tasks::dsl;
|
||||
|
||||
let now = Utc::now();
|
||||
let scheduled_at = now + Duration::seconds(backoff_seconds as i64);
|
||||
|
||||
let error = serde_json::json!({
|
||||
"error": error_message,
|
||||
});
|
||||
|
||||
let task = diesel::update(backie_tasks::table.filter(backie_tasks::id.eq(id)))
|
||||
.set((
|
||||
backie_tasks::error_message.eq(error),
|
||||
backie_tasks::error_info.eq(Some(error)),
|
||||
backie_tasks::retries.eq(dsl::retries + 1),
|
||||
backie_tasks::created_at.eq(scheduled_at),
|
||||
backie_tasks::scheduled_at.eq(scheduled_at),
|
||||
backie_tasks::running_at.eq::<Option<DateTime<Utc>>>(None),
|
||||
))
|
||||
.get_result::<Task>(connection)
|
||||
|
@ -103,14 +101,14 @@ impl Task {
|
|||
|
||||
pub(crate) async fn fetch_next_pending(
|
||||
connection: &mut AsyncPgConnection,
|
||||
task_type: TaskType,
|
||||
queue_name: &str,
|
||||
) -> Option<Task> {
|
||||
backie_tasks::table
|
||||
.filter(backie_tasks::created_at.lt(Utc::now())) // skip tasks scheduled for the future
|
||||
.filter(backie_tasks::scheduled_at.lt(Utc::now())) // skip tasks scheduled for the future
|
||||
.order(backie_tasks::created_at.asc()) // get the oldest task first
|
||||
.filter(backie_tasks::running_at.is_null()) // that is not marked as running already
|
||||
.filter(backie_tasks::done_at.is_null()) // and not marked as done
|
||||
.filter(backie_tasks::task_type.eq(task_type))
|
||||
.filter(backie_tasks::queue_name.eq(queue_name))
|
||||
.limit(1)
|
||||
.for_update()
|
||||
.skip_locked()
|
||||
|
@ -143,39 +141,13 @@ impl Task {
|
|||
|
||||
pub(crate) async fn insert(
|
||||
connection: &mut AsyncPgConnection,
|
||||
runnable: &dyn RunnableTask,
|
||||
new_task: NewTask,
|
||||
) -> Result<Task, AsyncQueueError> {
|
||||
let payload = serde_json::to_value(runnable)?;
|
||||
match runnable.uniq() {
|
||||
None => {
|
||||
let new_task = NewTask::builder()
|
||||
.uniq_hash(None)
|
||||
.task_type(runnable.task_type())
|
||||
.payload(payload)
|
||||
.build();
|
||||
|
||||
Ok(diesel::insert_into(backie_tasks::table)
|
||||
.values(new_task)
|
||||
.get_result::<Task>(connection)
|
||||
.await?)
|
||||
}
|
||||
Some(hash) => match Self::find_by_uniq_hash(connection, hash.clone()).await {
|
||||
Some(task) => Ok(task),
|
||||
None => {
|
||||
let new_task = NewTask::builder()
|
||||
.uniq_hash(Some(hash))
|
||||
.task_type(runnable.task_type())
|
||||
.payload(payload)
|
||||
.build();
|
||||
|
||||
Ok(diesel::insert_into(backie_tasks::table)
|
||||
.values(new_task)
|
||||
.get_result::<Task>(connection)
|
||||
.await?)
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn find_by_uniq_hash(
|
||||
connection: &mut AsyncPgConnection,
|
||||
|
|
562
src/queue.rs
562
src/queue.rs
|
@ -1,86 +1,48 @@
|
|||
use crate::errors::AsyncQueueError;
|
||||
use crate::runnable::RunnableTask;
|
||||
use crate::task::{Task, TaskHash, TaskId, TaskType};
|
||||
use async_trait::async_trait;
|
||||
use crate::runnable::BackgroundTask;
|
||||
use crate::task::{NewTask, Task, TaskHash, TaskId, TaskState};
|
||||
use diesel::result::Error::QueryBuilderError;
|
||||
use diesel_async::scoped_futures::ScopedFutureExt;
|
||||
use diesel_async::AsyncConnection;
|
||||
use diesel_async::{pg::AsyncPgConnection, pooled_connection::bb8::Pool};
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
/// This trait defines operations for an asynchronous queue.
|
||||
/// The trait can be implemented for different storage backends.
|
||||
/// For now, the trait is only implemented for PostgreSQL. More backends are planned to be implemented in the future.
|
||||
#[async_trait]
|
||||
pub trait Queueable: Send {
|
||||
/// Pull pending tasks from the queue to execute them.
|
||||
///
|
||||
/// This method returns one task of the `task_type` type. If `task_type` is `None` it will try to
|
||||
/// fetch a task of the type `common`. The returned task is marked as running and must be executed.
|
||||
async fn pull_next_task(
|
||||
&mut self,
|
||||
kind: Option<TaskType>,
|
||||
) -> Result<Option<Task>, AsyncQueueError>;
|
||||
#[derive(Clone)]
|
||||
pub struct Queue {
|
||||
task_store: Arc<PgTaskStore>,
|
||||
}
|
||||
|
||||
/// Enqueue a task to the queue, The task will be executed as soon as possible by the worker of the same type
|
||||
/// created by an AsyncWorkerPool.
|
||||
async fn create_task(&mut self, task: &dyn RunnableTask) -> Result<Task, AsyncQueueError>;
|
||||
impl Queue {
|
||||
pub(crate) fn new(task_store: Arc<PgTaskStore>) -> Self {
|
||||
Queue { task_store }
|
||||
}
|
||||
|
||||
/// Retrieve a task by its `id`.
|
||||
async fn find_task_by_id(&mut self, id: TaskId) -> Result<Task, AsyncQueueError>;
|
||||
|
||||
/// Update the state of a task to failed and set an error_message.
|
||||
async fn set_task_failed(
|
||||
&mut self,
|
||||
id: TaskId,
|
||||
error_message: &str,
|
||||
) -> Result<Task, AsyncQueueError>;
|
||||
|
||||
/// Update the state of a task to done.
|
||||
async fn set_task_done(&mut self, id: TaskId) -> Result<Task, AsyncQueueError>;
|
||||
|
||||
/// Update the state of a task to inform that it's still in progress.
|
||||
async fn keep_task_alive(&mut self, id: TaskId) -> Result<(), AsyncQueueError>;
|
||||
|
||||
/// Remove a task by its id.
|
||||
async fn remove_task(&mut self, id: TaskId) -> Result<u64, AsyncQueueError>;
|
||||
|
||||
/// The method will remove all tasks from the queue
|
||||
async fn remove_all_tasks(&mut self) -> Result<u64, AsyncQueueError>;
|
||||
|
||||
/// Remove all tasks that are scheduled in the future.
|
||||
async fn remove_all_scheduled_tasks(&mut self) -> Result<u64, AsyncQueueError>;
|
||||
|
||||
/// Remove a task by its metadata (struct fields values)
|
||||
async fn remove_task_by_hash(&mut self, task_hash: TaskHash) -> Result<bool, AsyncQueueError>;
|
||||
|
||||
/// Removes all tasks that have the specified `task_type`.
|
||||
async fn remove_tasks_type(&mut self, task_type: TaskType) -> Result<u64, AsyncQueueError>;
|
||||
|
||||
async fn schedule_task_retry(
|
||||
&mut self,
|
||||
id: TaskId,
|
||||
backoff_seconds: u32,
|
||||
error: &str,
|
||||
) -> Result<Task, AsyncQueueError>;
|
||||
pub async fn enqueue<BT>(&self, background_task: BT) -> Result<(), AsyncQueueError>
|
||||
where
|
||||
BT: BackgroundTask,
|
||||
{
|
||||
self.task_store
|
||||
.create_task(NewTask::new(background_task, Duration::from_secs(10))?)
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// An async queue that is used to manipulate tasks, it uses PostgreSQL as storage.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct PgAsyncQueue {
|
||||
pub struct PgTaskStore {
|
||||
pool: Pool<AsyncPgConnection>,
|
||||
}
|
||||
|
||||
impl PgAsyncQueue {
|
||||
impl PgTaskStore {
|
||||
pub fn new(pool: Pool<AsyncPgConnection>) -> Self {
|
||||
PgAsyncQueue { pool }
|
||||
PgTaskStore { pool }
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Queueable for PgAsyncQueue {
|
||||
async fn pull_next_task(
|
||||
&mut self,
|
||||
task_type: Option<TaskType>,
|
||||
pub(crate) async fn pull_next_task(
|
||||
&self,
|
||||
queue_name: &str,
|
||||
) -> Result<Option<Task>, AsyncQueueError> {
|
||||
let mut connection = self
|
||||
.pool
|
||||
|
@ -90,7 +52,7 @@ impl Queueable for PgAsyncQueue {
|
|||
connection
|
||||
.transaction::<Option<Task>, AsyncQueueError, _>(|conn| {
|
||||
async move {
|
||||
let Some(pending_task) = Task::fetch_next_pending(conn, task_type.unwrap_or_default()).await else {
|
||||
let Some(pending_task) = Task::fetch_next_pending(conn, queue_name).await else {
|
||||
return Ok(None);
|
||||
};
|
||||
|
||||
|
@ -101,16 +63,16 @@ impl Queueable for PgAsyncQueue {
|
|||
.await
|
||||
}
|
||||
|
||||
async fn create_task(&mut self, runnable: &dyn RunnableTask) -> Result<Task, AsyncQueueError> {
|
||||
pub(crate) async fn create_task(&self, new_task: NewTask) -> Result<Task, AsyncQueueError> {
|
||||
let mut connection = self
|
||||
.pool
|
||||
.get()
|
||||
.await
|
||||
.map_err(|e| QueryBuilderError(e.into()))?;
|
||||
Ok(Task::insert(&mut connection, runnable).await?)
|
||||
Task::insert(&mut connection, new_task).await
|
||||
}
|
||||
|
||||
async fn find_task_by_id(&mut self, id: TaskId) -> Result<Task, AsyncQueueError> {
|
||||
pub(crate) async fn find_task_by_id(&self, id: TaskId) -> Result<Task, AsyncQueueError> {
|
||||
let mut connection = self
|
||||
.pool
|
||||
.get()
|
||||
|
@ -119,29 +81,27 @@ impl Queueable for PgAsyncQueue {
|
|||
Task::find_by_id(&mut connection, id).await
|
||||
}
|
||||
|
||||
async fn set_task_failed(
|
||||
&mut self,
|
||||
pub(crate) async fn set_task_state(
|
||||
&self,
|
||||
id: TaskId,
|
||||
error_message: &str,
|
||||
) -> Result<Task, AsyncQueueError> {
|
||||
state: TaskState,
|
||||
) -> Result<(), AsyncQueueError> {
|
||||
let mut connection = self
|
||||
.pool
|
||||
.get()
|
||||
.await
|
||||
.map_err(|e| QueryBuilderError(e.into()))?;
|
||||
Task::fail_with_message(&mut connection, id, error_message).await
|
||||
match state {
|
||||
TaskState::Done => Task::set_done(&mut connection, id).await?,
|
||||
TaskState::Failed(error_msg) => {
|
||||
Task::fail_with_message(&mut connection, id, &error_msg).await?
|
||||
}
|
||||
_ => return Ok(()),
|
||||
};
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn set_task_done(&mut self, id: TaskId) -> Result<Task, AsyncQueueError> {
|
||||
let mut connection = self
|
||||
.pool
|
||||
.get()
|
||||
.await
|
||||
.map_err(|e| QueryBuilderError(e.into()))?;
|
||||
Task::set_done(&mut connection, id).await
|
||||
}
|
||||
|
||||
async fn keep_task_alive(&mut self, id: TaskId) -> Result<(), AsyncQueueError> {
|
||||
pub(crate) async fn keep_task_alive(&self, id: TaskId) -> Result<(), AsyncQueueError> {
|
||||
let mut connection = self
|
||||
.pool
|
||||
.get()
|
||||
|
@ -159,7 +119,7 @@ impl Queueable for PgAsyncQueue {
|
|||
.await
|
||||
}
|
||||
|
||||
async fn remove_task(&mut self, id: TaskId) -> Result<u64, AsyncQueueError> {
|
||||
pub(crate) async fn remove_task(&self, id: TaskId) -> Result<u64, AsyncQueueError> {
|
||||
let mut connection = self
|
||||
.pool
|
||||
.get()
|
||||
|
@ -169,7 +129,7 @@ impl Queueable for PgAsyncQueue {
|
|||
Ok(result)
|
||||
}
|
||||
|
||||
async fn remove_all_tasks(&mut self) -> Result<u64, AsyncQueueError> {
|
||||
pub(crate) async fn remove_all_tasks(&self) -> Result<u64, AsyncQueueError> {
|
||||
let mut connection = self
|
||||
.pool
|
||||
.get()
|
||||
|
@ -178,37 +138,8 @@ impl Queueable for PgAsyncQueue {
|
|||
Task::remove_all(&mut connection).await
|
||||
}
|
||||
|
||||
async fn remove_all_scheduled_tasks(&mut self) -> Result<u64, AsyncQueueError> {
|
||||
let mut connection = self
|
||||
.pool
|
||||
.get()
|
||||
.await
|
||||
.map_err(|e| QueryBuilderError(e.into()))?;
|
||||
let result = Task::remove_all_scheduled(&mut connection).await?;
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
async fn remove_task_by_hash(&mut self, task_hash: TaskHash) -> Result<bool, AsyncQueueError> {
|
||||
let mut connection = self
|
||||
.pool
|
||||
.get()
|
||||
.await
|
||||
.map_err(|e| QueryBuilderError(e.into()))?;
|
||||
Task::remove_by_hash(&mut connection, task_hash).await
|
||||
}
|
||||
|
||||
async fn remove_tasks_type(&mut self, task_type: TaskType) -> Result<u64, AsyncQueueError> {
|
||||
let mut connection = self
|
||||
.pool
|
||||
.get()
|
||||
.await
|
||||
.map_err(|e| QueryBuilderError(e.into()))?;
|
||||
let result = Task::remove_by_type(&mut connection, task_type).await?;
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
async fn schedule_task_retry(
|
||||
&mut self,
|
||||
pub(crate) async fn schedule_task_retry(
|
||||
&self,
|
||||
id: TaskId,
|
||||
backoff_seconds: u32,
|
||||
error: &str,
|
||||
|
@ -226,11 +157,8 @@ impl Queueable for PgAsyncQueue {
|
|||
#[cfg(test)]
|
||||
mod async_queue_tests {
|
||||
use super::*;
|
||||
use crate::task::TaskState;
|
||||
use crate::Scheduled;
|
||||
use crate::CurrentTask;
|
||||
use async_trait::async_trait;
|
||||
use chrono::DateTime;
|
||||
use chrono::Utc;
|
||||
use diesel_async::pooled_connection::{bb8::Pool, AsyncDieselConnectionManager};
|
||||
use diesel_async::AsyncPgConnection;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
@ -240,13 +168,12 @@ mod async_queue_tests {
|
|||
pub number: u16,
|
||||
}
|
||||
|
||||
#[typetag::serde]
|
||||
#[async_trait]
|
||||
impl RunnableTask for AsyncTask {
|
||||
async fn run(
|
||||
&self,
|
||||
_queueable: &mut dyn Queueable,
|
||||
) -> Result<(), Box<(dyn std::error::Error + Send + 'static)>> {
|
||||
impl BackgroundTask for AsyncTask {
|
||||
const TASK_NAME: &'static str = "AsyncUniqTask";
|
||||
type AppData = ();
|
||||
|
||||
async fn run(&self, _task: CurrentTask, _: Self::AppData) -> Result<(), anyhow::Error> {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
@ -256,13 +183,12 @@ mod async_queue_tests {
|
|||
pub number: u16,
|
||||
}
|
||||
|
||||
#[typetag::serde]
|
||||
#[async_trait]
|
||||
impl RunnableTask for AsyncUniqTask {
|
||||
async fn run(
|
||||
&self,
|
||||
_queueable: &mut dyn Queueable,
|
||||
) -> Result<(), Box<(dyn std::error::Error + Send + 'static)>> {
|
||||
impl BackgroundTask for AsyncUniqTask {
|
||||
const TASK_NAME: &'static str = "AsyncUniqTask";
|
||||
type AppData = ();
|
||||
|
||||
async fn run(&self, _task: CurrentTask, _: Self::AppData) -> Result<(), anyhow::Error> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
@ -277,286 +203,154 @@ mod async_queue_tests {
|
|||
pub datetime: String,
|
||||
}
|
||||
|
||||
#[typetag::serde]
|
||||
#[async_trait]
|
||||
impl RunnableTask for AsyncTaskSchedule {
|
||||
async fn run(
|
||||
&self,
|
||||
_queueable: &mut dyn Queueable,
|
||||
) -> Result<(), Box<(dyn std::error::Error + Send + 'static)>> {
|
||||
impl BackgroundTask for AsyncTaskSchedule {
|
||||
const TASK_NAME: &'static str = "AsyncUniqTask";
|
||||
type AppData = ();
|
||||
|
||||
async fn run(&self, _task: CurrentTask, _: Self::AppData) -> Result<(), anyhow::Error> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn cron(&self) -> Option<Scheduled> {
|
||||
let datetime = self.datetime.parse::<DateTime<Utc>>().ok()?;
|
||||
Some(Scheduled::ScheduleOnce(datetime))
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn insert_task_creates_new_task() {
|
||||
let pool = pool().await;
|
||||
let mut test = PgAsyncQueue::new(pool);
|
||||
|
||||
let task = test.create_task(&AsyncTask { number: 1 }).await.unwrap();
|
||||
|
||||
let metadata = task.payload.as_object().unwrap();
|
||||
let number = metadata["number"].as_u64();
|
||||
let type_task = metadata["type"].as_str();
|
||||
|
||||
assert_eq!(Some(1), number);
|
||||
assert_eq!(Some("AsyncTask"), type_task);
|
||||
|
||||
test.remove_all_tasks().await.unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn update_task_state_test() {
|
||||
let pool = pool().await;
|
||||
let mut test = PgAsyncQueue::new(pool);
|
||||
|
||||
let task = test.create_task(&AsyncTask { number: 1 }).await.unwrap();
|
||||
|
||||
let metadata = task.payload.as_object().unwrap();
|
||||
let number = metadata["number"].as_u64();
|
||||
let type_task = metadata["type"].as_str();
|
||||
let id = task.id;
|
||||
|
||||
assert_eq!(Some(1), number);
|
||||
assert_eq!(Some("AsyncTask"), type_task);
|
||||
|
||||
let finished_task = test.set_task_done(task.id).await.unwrap();
|
||||
|
||||
assert_eq!(id, finished_task.id);
|
||||
assert_eq!(TaskState::Done, finished_task.state());
|
||||
|
||||
test.remove_all_tasks().await.unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn failed_task_query_test() {
|
||||
let pool = pool().await;
|
||||
let mut test = PgAsyncQueue::new(pool);
|
||||
|
||||
let task = test.create_task(&AsyncTask { number: 1 }).await.unwrap();
|
||||
|
||||
let metadata = task.payload.as_object().unwrap();
|
||||
let number = metadata["number"].as_u64();
|
||||
let type_task = metadata["type"].as_str();
|
||||
let id = task.id;
|
||||
|
||||
assert_eq!(Some(1), number);
|
||||
assert_eq!(Some("AsyncTask"), type_task);
|
||||
|
||||
let failed_task = test.set_task_failed(task.id, "Some error").await.unwrap();
|
||||
|
||||
assert_eq!(id, failed_task.id);
|
||||
assert_eq!(Some("Some error"), failed_task.error_message.as_deref());
|
||||
assert_eq!(TaskState::Failed, failed_task.state());
|
||||
|
||||
test.remove_all_tasks().await.unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn remove_all_tasks_test() {
|
||||
let pool = pool().await;
|
||||
let mut test = PgAsyncQueue::new(pool);
|
||||
|
||||
let task = test.create_task(&AsyncTask { number: 1 }).await.unwrap();
|
||||
|
||||
let metadata = task.payload.as_object().unwrap();
|
||||
let number = metadata["number"].as_u64();
|
||||
let type_task = metadata["type"].as_str();
|
||||
|
||||
assert_eq!(Some(1), number);
|
||||
assert_eq!(Some("AsyncTask"), type_task);
|
||||
|
||||
let task = test.create_task(&AsyncTask { number: 2 }).await.unwrap();
|
||||
|
||||
let metadata = task.payload.as_object().unwrap();
|
||||
let number = metadata["number"].as_u64();
|
||||
let type_task = metadata["type"].as_str();
|
||||
|
||||
assert_eq!(Some(2), number);
|
||||
assert_eq!(Some("AsyncTask"), type_task);
|
||||
|
||||
let result = test.remove_all_tasks().await.unwrap();
|
||||
assert_eq!(2, result);
|
||||
// fn cron(&self) -> Option<Scheduled> {
|
||||
// let datetime = self.datetime.parse::<DateTime<Utc>>().ok()?;
|
||||
// Some(Scheduled::ScheduleOnce(datetime))
|
||||
// }
|
||||
}
|
||||
|
||||
// #[tokio::test]
|
||||
// async fn schedule_task_test() {
|
||||
// async fn insert_task_creates_new_task() {
|
||||
// let pool = pool().await;
|
||||
// let mut test = PgAsyncQueue::new(pool);
|
||||
// let mut queue = PgTaskStore::new(pool);
|
||||
//
|
||||
// let datetime = (Utc::now() + Duration::seconds(7)).round_subsecs(0);
|
||||
//
|
||||
// let task = &AsyncTaskSchedule {
|
||||
// number: 1,
|
||||
// datetime: datetime.to_string(),
|
||||
// };
|
||||
//
|
||||
// let task = test.schedule_task(task).await.unwrap();
|
||||
// let task = queue.create_task(AsyncTask { number: 1 }).await.unwrap();
|
||||
//
|
||||
// let metadata = task.payload.as_object().unwrap();
|
||||
// let number = metadata["number"].as_u64();
|
||||
// let type_task = metadata["type"].as_str();
|
||||
//
|
||||
// assert_eq!(Some(1), number);
|
||||
// assert_eq!(Some("AsyncTaskSchedule"), type_task);
|
||||
// assert_eq!(task.scheduled_at, datetime);
|
||||
// assert_eq!(Some("AsyncTask"), type_task);
|
||||
//
|
||||
// queue.remove_all_tasks().await.unwrap();
|
||||
// }
|
||||
//
|
||||
// #[tokio::test]
|
||||
// async fn update_task_state_test() {
|
||||
// let pool = pool().await;
|
||||
// let mut test = PgTaskStore::new(pool);
|
||||
//
|
||||
// let task = test.create_task(&AsyncTask { number: 1 }).await.unwrap();
|
||||
//
|
||||
// let metadata = task.payload.as_object().unwrap();
|
||||
// let number = metadata["number"].as_u64();
|
||||
// let type_task = metadata["type"].as_str();
|
||||
// let id = task.id;
|
||||
//
|
||||
// assert_eq!(Some(1), number);
|
||||
// assert_eq!(Some("AsyncTask"), type_task);
|
||||
//
|
||||
// let finished_task = test.set_task_state(task.id, TaskState::Done).await.unwrap();
|
||||
//
|
||||
// assert_eq!(id, finished_task.id);
|
||||
// assert_eq!(TaskState::Done, finished_task.state());
|
||||
//
|
||||
// test.remove_all_tasks().await.unwrap();
|
||||
// }
|
||||
//
|
||||
// #[tokio::test]
|
||||
// async fn remove_all_scheduled_tasks_test() {
|
||||
// async fn failed_task_query_test() {
|
||||
// let pool = pool().await;
|
||||
// let mut test = PgAsyncQueue::new(pool);
|
||||
// let mut test = PgTaskStore::new(pool);
|
||||
//
|
||||
// let datetime = (Utc::now() + Duration::seconds(7)).round_subsecs(0);
|
||||
// let task = test.create_task(&AsyncTask { number: 1 }).await.unwrap();
|
||||
//
|
||||
// let task1 = &AsyncTaskSchedule {
|
||||
// number: 1,
|
||||
// datetime: datetime.to_string(),
|
||||
// };
|
||||
// let metadata = task.payload.as_object().unwrap();
|
||||
// let number = metadata["number"].as_u64();
|
||||
// let type_task = metadata["type"].as_str();
|
||||
// let id = task.id;
|
||||
//
|
||||
// let task2 = &AsyncTaskSchedule {
|
||||
// number: 2,
|
||||
// datetime: datetime.to_string(),
|
||||
// };
|
||||
// assert_eq!(Some(1), number);
|
||||
// assert_eq!(Some("AsyncTask"), type_task);
|
||||
//
|
||||
// test.schedule_task(task1).await.unwrap();
|
||||
// test.schedule_task(task2).await.unwrap();
|
||||
// let failed_task = test.set_task_state(task.id, TaskState::Failed("Some error".to_string())).await.unwrap();
|
||||
//
|
||||
// let number = test.remove_all_scheduled_tasks().await.unwrap();
|
||||
//
|
||||
// assert_eq!(2, number);
|
||||
// assert_eq!(id, failed_task.id);
|
||||
// assert_eq!(Some("Some error"), failed_task.error_message.as_deref());
|
||||
// assert_eq!(TaskState::Failed, failed_task.state());
|
||||
//
|
||||
// test.remove_all_tasks().await.unwrap();
|
||||
// }
|
||||
|
||||
#[tokio::test]
|
||||
async fn pull_next_task_test() {
|
||||
let pool = pool().await;
|
||||
let mut test = PgAsyncQueue::new(pool);
|
||||
|
||||
let task = test.create_task(&AsyncTask { number: 1 }).await.unwrap();
|
||||
|
||||
let metadata = task.payload.as_object().unwrap();
|
||||
let number = metadata["number"].as_u64();
|
||||
let type_task = metadata["type"].as_str();
|
||||
|
||||
assert_eq!(Some(1), number);
|
||||
assert_eq!(Some("AsyncTask"), type_task);
|
||||
|
||||
let task = test.create_task(&AsyncTask { number: 2 }).await.unwrap();
|
||||
|
||||
let metadata = task.payload.as_object().unwrap();
|
||||
let number = metadata["number"].as_u64();
|
||||
let type_task = metadata["type"].as_str();
|
||||
|
||||
assert_eq!(Some(2), number);
|
||||
assert_eq!(Some("AsyncTask"), type_task);
|
||||
|
||||
let task = test.pull_next_task(None).await.unwrap().unwrap();
|
||||
|
||||
let metadata = task.payload.as_object().unwrap();
|
||||
let number = metadata["number"].as_u64();
|
||||
let type_task = metadata["type"].as_str();
|
||||
|
||||
assert_eq!(Some(1), number);
|
||||
assert_eq!(Some("AsyncTask"), type_task);
|
||||
|
||||
let task = test.pull_next_task(None).await.unwrap().unwrap();
|
||||
let metadata = task.payload.as_object().unwrap();
|
||||
let number = metadata["number"].as_u64();
|
||||
let type_task = metadata["type"].as_str();
|
||||
|
||||
assert_eq!(Some(2), number);
|
||||
assert_eq!(Some("AsyncTask"), type_task);
|
||||
|
||||
test.remove_all_tasks().await.unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn remove_tasks_type_test() {
|
||||
let pool = pool().await;
|
||||
let mut test = PgAsyncQueue::new(pool);
|
||||
|
||||
let task = test.create_task(&AsyncTask { number: 1 }).await.unwrap();
|
||||
|
||||
let metadata = task.payload.as_object().unwrap();
|
||||
let number = metadata["number"].as_u64();
|
||||
let type_task = metadata["type"].as_str();
|
||||
|
||||
assert_eq!(Some(1), number);
|
||||
assert_eq!(Some("AsyncTask"), type_task);
|
||||
|
||||
let task = test.create_task(&AsyncTask { number: 2 }).await.unwrap();
|
||||
|
||||
let metadata = task.payload.as_object().unwrap();
|
||||
let number = metadata["number"].as_u64();
|
||||
let type_task = metadata["type"].as_str();
|
||||
|
||||
assert_eq!(Some(2), number);
|
||||
assert_eq!(Some("AsyncTask"), type_task);
|
||||
|
||||
let result = test
|
||||
.remove_tasks_type(TaskType::from("nonexistentType"))
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(0, result);
|
||||
|
||||
let result = test.remove_tasks_type(TaskType::default()).await.unwrap();
|
||||
assert_eq!(2, result);
|
||||
|
||||
test.remove_all_tasks().await.unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn remove_tasks_by_metadata() {
|
||||
let pool = pool().await;
|
||||
let mut test = PgAsyncQueue::new(pool);
|
||||
|
||||
let task = test
|
||||
.create_task(&AsyncUniqTask { number: 1 })
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let metadata = task.payload.as_object().unwrap();
|
||||
let number = metadata["number"].as_u64();
|
||||
let type_task = metadata["type"].as_str();
|
||||
|
||||
assert_eq!(Some(1), number);
|
||||
assert_eq!(Some("AsyncUniqTask"), type_task);
|
||||
|
||||
let task = test
|
||||
.create_task(&AsyncUniqTask { number: 2 })
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let metadata = task.payload.as_object().unwrap();
|
||||
let number = metadata["number"].as_u64();
|
||||
let type_task = metadata["type"].as_str();
|
||||
|
||||
assert_eq!(Some(2), number);
|
||||
assert_eq!(Some("AsyncUniqTask"), type_task);
|
||||
|
||||
let result = test
|
||||
.remove_task_by_hash(AsyncUniqTask { number: 0 }.uniq().unwrap())
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(!result, "Should **not** remove task");
|
||||
|
||||
let result = test
|
||||
.remove_task_by_hash(AsyncUniqTask { number: 1 }.uniq().unwrap())
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(result, "Should remove task");
|
||||
|
||||
test.remove_all_tasks().await.unwrap();
|
||||
}
|
||||
//
|
||||
// #[tokio::test]
|
||||
// async fn remove_all_tasks_test() {
|
||||
// let pool = pool().await;
|
||||
// let mut test = PgTaskStore::new(pool);
|
||||
//
|
||||
// let task = test.create_task(&AsyncTask { number: 1 }).await.unwrap();
|
||||
//
|
||||
// let metadata = task.payload.as_object().unwrap();
|
||||
// let number = metadata["number"].as_u64();
|
||||
// let type_task = metadata["type"].as_str();
|
||||
//
|
||||
// assert_eq!(Some(1), number);
|
||||
// assert_eq!(Some("AsyncTask"), type_task);
|
||||
//
|
||||
// let task = test.create_task(&AsyncTask { number: 2 }).await.unwrap();
|
||||
//
|
||||
// let metadata = task.payload.as_object().unwrap();
|
||||
// let number = metadata["number"].as_u64();
|
||||
// let type_task = metadata["type"].as_str();
|
||||
//
|
||||
// assert_eq!(Some(2), number);
|
||||
// assert_eq!(Some("AsyncTask"), type_task);
|
||||
//
|
||||
// let result = test.remove_all_tasks().await.unwrap();
|
||||
// assert_eq!(2, result);
|
||||
// }
|
||||
//
|
||||
// #[tokio::test]
|
||||
// async fn pull_next_task_test() {
|
||||
// let pool = pool().await;
|
||||
// let mut queue = PgTaskStore::new(pool);
|
||||
//
|
||||
// let task = queue.create_task(&AsyncTask { number: 1 }).await.unwrap();
|
||||
//
|
||||
// let metadata = task.payload.as_object().unwrap();
|
||||
// let number = metadata["number"].as_u64();
|
||||
// let type_task = metadata["type"].as_str();
|
||||
//
|
||||
// assert_eq!(Some(1), number);
|
||||
// assert_eq!(Some("AsyncTask"), type_task);
|
||||
//
|
||||
// let task = queue.create_task(&AsyncTask { number: 2 }).await.unwrap();
|
||||
//
|
||||
// let metadata = task.payload.as_object().unwrap();
|
||||
// let number = metadata["number"].as_u64();
|
||||
// let type_task = metadata["type"].as_str();
|
||||
//
|
||||
// assert_eq!(Some(2), number);
|
||||
// assert_eq!(Some("AsyncTask"), type_task);
|
||||
//
|
||||
// let task = queue.pull_next_task(None).await.unwrap().unwrap();
|
||||
//
|
||||
// let metadata = task.payload.as_object().unwrap();
|
||||
// let number = metadata["number"].as_u64();
|
||||
// let type_task = metadata["type"].as_str();
|
||||
//
|
||||
// assert_eq!(Some(1), number);
|
||||
// assert_eq!(Some("AsyncTask"), type_task);
|
||||
//
|
||||
// let task = queue.pull_next_task(None).await.unwrap().unwrap();
|
||||
// let metadata = task.payload.as_object().unwrap();
|
||||
// let number = metadata["number"].as_u64();
|
||||
// let type_task = metadata["type"].as_str();
|
||||
//
|
||||
// assert_eq!(Some(2), number);
|
||||
// assert_eq!(Some("AsyncTask"), type_task);
|
||||
//
|
||||
// queue.remove_all_tasks().await.unwrap();
|
||||
// }
|
||||
|
||||
async fn pool() -> Pool<AsyncPgConnection> {
|
||||
let manager = AsyncDieselConnectionManager::<AsyncPgConnection>::new(
|
||||
|
|
|
@ -1,27 +1,34 @@
|
|||
use crate::queue::Queueable;
|
||||
use crate::task::TaskHash;
|
||||
use crate::task::TaskType;
|
||||
use crate::Scheduled;
|
||||
use crate::task::{CurrentTask, TaskHash};
|
||||
use async_trait::async_trait;
|
||||
use std::error::Error;
|
||||
|
||||
pub const RETRIES_NUMBER: i32 = 5;
|
||||
use serde::{de::DeserializeOwned, ser::Serialize};
|
||||
|
||||
/// Task that can be executed by the queue.
|
||||
///
|
||||
/// The `RunnableTask` trait is used to define the behaviour of a task. You must implement this
|
||||
/// The `BackgroundTask` trait is used to define the behaviour of a task. You must implement this
|
||||
/// trait for all tasks you want to execute.
|
||||
#[typetag::serde(tag = "type")]
|
||||
#[async_trait]
|
||||
pub trait RunnableTask: Send + Sync {
|
||||
/// Execute the task. This method should define its logic
|
||||
async fn run(&self, queue: &mut dyn Queueable) -> Result<(), Box<dyn Error + Send + 'static>>;
|
||||
pub trait BackgroundTask: Serialize + DeserializeOwned + Sync + Send + 'static {
|
||||
/// Unique name of the task.
|
||||
///
|
||||
/// This MUST be unique for the whole application.
|
||||
const TASK_NAME: &'static str;
|
||||
|
||||
/// Define the type of the task.
|
||||
/// The `common` task type is used by default
|
||||
fn task_type(&self) -> TaskType {
|
||||
TaskType::default()
|
||||
}
|
||||
/// Task queue where this task will be executed.
|
||||
///
|
||||
/// Used to define which workers are going to be executing this task. It uses the default
|
||||
/// task queue if not changed.
|
||||
const QUEUE: &'static str = "default";
|
||||
|
||||
/// Number of retries for tasks.
|
||||
///
|
||||
/// By default, it is set to 5.
|
||||
const MAX_RETRIES: i32 = 5;
|
||||
|
||||
/// The application data provided to this task at runtime.
|
||||
type AppData: Clone + Send + 'static;
|
||||
|
||||
/// Execute the task. This method should define its logic
|
||||
async fn run(&self, task: CurrentTask, context: Self::AppData) -> Result<(), anyhow::Error>;
|
||||
|
||||
/// If set to true, no new tasks with the same metadata will be inserted
|
||||
/// By default it is set to false.
|
||||
|
@ -29,27 +36,10 @@ pub trait RunnableTask: Send + Sync {
|
|||
None
|
||||
}
|
||||
|
||||
/// This method defines if a task is periodic or it should be executed once in the future.
|
||||
///
|
||||
/// Be careful it works only with the UTC timezone.
|
||||
///
|
||||
/// Example:
|
||||
///
|
||||
/// ```rust
|
||||
/// fn cron(&self) -> Option<Scheduled> {
|
||||
/// let expression = "0/20 * * * Aug-Sep * 2022/1";
|
||||
/// Some(Scheduled::CronPattern(expression.to_string()))
|
||||
/// }
|
||||
///```
|
||||
/// In order to schedule a task once, use the `Scheduled::ScheduleOnce` enum variant.
|
||||
fn cron(&self) -> Option<Scheduled> {
|
||||
None
|
||||
}
|
||||
|
||||
/// Define the maximum number of retries the task will be retried.
|
||||
/// By default the number of retries is 20.
|
||||
fn max_retries(&self) -> i32 {
|
||||
RETRIES_NUMBER
|
||||
Self::MAX_RETRIES
|
||||
}
|
||||
|
||||
/// Define the backoff mode
|
||||
|
|
19
src/schema.rs
Normal file
19
src/schema.rs
Normal file
|
@ -0,0 +1,19 @@
|
|||
// @generated automatically by Diesel CLI.
|
||||
|
||||
diesel::table! {
|
||||
backie_tasks (id) {
|
||||
id -> Uuid,
|
||||
task_name -> Varchar,
|
||||
queue_name -> Varchar,
|
||||
uniq_hash -> Nullable<Bpchar>,
|
||||
payload -> Jsonb,
|
||||
timeout_msecs -> Int8,
|
||||
created_at -> Timestamptz,
|
||||
scheduled_at -> Timestamptz,
|
||||
running_at -> Nullable<Timestamptz>,
|
||||
done_at -> Nullable<Timestamptz>,
|
||||
error_info -> Nullable<Jsonb>,
|
||||
retries -> Int4,
|
||||
max_retries -> Int4,
|
||||
}
|
||||
}
|
1
src/store.rs
Normal file
1
src/store.rs
Normal file
|
@ -0,0 +1 @@
|
|||
|
133
src/task.rs
133
src/task.rs
|
@ -1,4 +1,5 @@
|
|||
use crate::schema::backie_tasks;
|
||||
use crate::BackgroundTask;
|
||||
use chrono::DateTime;
|
||||
use chrono::Utc;
|
||||
use diesel::prelude::*;
|
||||
|
@ -8,11 +9,11 @@ use sha2::{Digest, Sha256};
|
|||
use std::borrow::Cow;
|
||||
use std::fmt;
|
||||
use std::fmt::Display;
|
||||
use typed_builder::TypedBuilder;
|
||||
use std::time::Duration;
|
||||
use uuid::Uuid;
|
||||
|
||||
/// States of a task.
|
||||
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
|
||||
#[derive(Clone, Debug, Eq, PartialEq)]
|
||||
pub enum TaskState {
|
||||
/// The task is ready to be executed.
|
||||
Ready,
|
||||
|
@ -21,7 +22,7 @@ pub enum TaskState {
|
|||
Running,
|
||||
|
||||
/// The task has failed to execute.
|
||||
Failed,
|
||||
Failed(String),
|
||||
|
||||
/// The task finished successfully.
|
||||
Done,
|
||||
|
@ -36,24 +37,6 @@ impl Display for TaskId {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Hash, PartialEq, Eq, DieselNewType, Serialize)]
|
||||
pub struct TaskType(Cow<'static, str>);
|
||||
|
||||
impl Default for TaskType {
|
||||
fn default() -> Self {
|
||||
Self(Cow::from("default"))
|
||||
}
|
||||
}
|
||||
|
||||
impl<S> From<S> for TaskType
|
||||
where
|
||||
S: AsRef<str> + 'static,
|
||||
{
|
||||
fn from(s: S) -> Self {
|
||||
TaskType(Cow::from(s.as_ref().to_owned()))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Hash, PartialEq, Eq, DieselNewType, Serialize)]
|
||||
pub struct TaskHash(Cow<'static, str>);
|
||||
|
||||
|
@ -70,42 +53,55 @@ impl TaskHash {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Queryable, Identifiable, Debug, Eq, PartialEq, Clone, TypedBuilder)]
|
||||
#[derive(Queryable, Identifiable, Debug, Eq, PartialEq, Clone)]
|
||||
#[diesel(table_name = backie_tasks)]
|
||||
pub struct Task {
|
||||
#[builder(setter(into))]
|
||||
/// Unique identifier of the task.
|
||||
pub id: TaskId,
|
||||
|
||||
#[builder(setter(into))]
|
||||
pub payload: serde_json::Value,
|
||||
/// Name of the type of task.
|
||||
pub task_name: String,
|
||||
|
||||
#[builder(setter(into))]
|
||||
pub error_message: Option<String>,
|
||||
/// Queue name that the task belongs to.
|
||||
pub queue_name: String,
|
||||
|
||||
#[builder(setter(into))]
|
||||
pub task_type: TaskType,
|
||||
|
||||
#[builder(setter(into))]
|
||||
/// Unique hash is used to identify and avoid duplicate tasks.
|
||||
pub uniq_hash: Option<TaskHash>,
|
||||
|
||||
#[builder(setter(into))]
|
||||
pub retries: i32,
|
||||
/// Representation of the task.
|
||||
pub payload: serde_json::Value,
|
||||
|
||||
#[builder(setter(into))]
|
||||
/// Max timeout that the task can run for.
|
||||
pub timeout_msecs: i64,
|
||||
|
||||
/// Creation time of the task.
|
||||
pub created_at: DateTime<Utc>,
|
||||
|
||||
#[builder(setter(into))]
|
||||
/// Date time when the task is scheduled to run.
|
||||
pub scheduled_at: DateTime<Utc>,
|
||||
|
||||
/// Date time when the task is started to run.
|
||||
pub running_at: Option<DateTime<Utc>>,
|
||||
|
||||
#[builder(setter(into))]
|
||||
/// Date time when the task is finished.
|
||||
pub done_at: Option<DateTime<Utc>>,
|
||||
|
||||
/// Failure reason, when the task is failed.
|
||||
pub error_info: Option<serde_json::Value>,
|
||||
|
||||
/// Number of times a task was retried.
|
||||
pub retries: i32,
|
||||
|
||||
/// Maximum number of retries allow for this task before it is maked as failure.
|
||||
pub max_retries: i32,
|
||||
}
|
||||
|
||||
impl Task {
|
||||
pub fn state(&self) -> TaskState {
|
||||
if self.done_at.is_some() {
|
||||
if self.error_message.is_some() {
|
||||
TaskState::Failed
|
||||
if self.error_info.is_some() {
|
||||
// TODO: use a proper error type
|
||||
TaskState::Failed(self.error_info.clone().unwrap().to_string())
|
||||
} else {
|
||||
TaskState::Done
|
||||
}
|
||||
|
@ -117,22 +113,61 @@ impl Task {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Insertable, Debug, Eq, PartialEq, Clone, TypedBuilder)]
|
||||
#[derive(Insertable, Debug, Eq, PartialEq, Clone)]
|
||||
#[diesel(table_name = backie_tasks)]
|
||||
pub struct NewTask {
|
||||
#[builder(setter(into))]
|
||||
payload: serde_json::Value,
|
||||
|
||||
#[builder(setter(into))]
|
||||
task_type: TaskType,
|
||||
|
||||
#[builder(setter(into))]
|
||||
pub(crate) struct NewTask {
|
||||
task_name: String,
|
||||
queue_name: String,
|
||||
uniq_hash: Option<TaskHash>,
|
||||
payload: serde_json::Value,
|
||||
timeout_msecs: i64,
|
||||
max_retries: i32,
|
||||
}
|
||||
|
||||
pub struct TaskInfo {
|
||||
impl NewTask {
|
||||
pub(crate) fn new<T>(background_task: T, timeout: Duration) -> Result<Self, serde_json::Error>
|
||||
where
|
||||
T: BackgroundTask,
|
||||
{
|
||||
let max_retries = background_task.max_retries();
|
||||
let uniq_hash = background_task.uniq();
|
||||
let payload = serde_json::to_value(background_task)?;
|
||||
|
||||
Ok(Self {
|
||||
task_name: T::TASK_NAME.to_string(),
|
||||
queue_name: T::QUEUE.to_string(),
|
||||
uniq_hash,
|
||||
payload,
|
||||
timeout_msecs: timeout.as_millis() as i64,
|
||||
max_retries,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
pub struct CurrentTask {
|
||||
id: TaskId,
|
||||
error_message: Option<String>,
|
||||
retries: i32,
|
||||
created_at: DateTime<Utc>,
|
||||
}
|
||||
|
||||
impl CurrentTask {
|
||||
pub(crate) fn new(task: &Task) -> Self {
|
||||
Self {
|
||||
id: task.id.clone(),
|
||||
retries: task.retries,
|
||||
created_at: task.created_at,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn id(&self) -> TaskId {
|
||||
self.id
|
||||
}
|
||||
|
||||
pub fn retry_count(&self) -> i32 {
|
||||
self.retries
|
||||
}
|
||||
|
||||
pub fn created_at(&self) -> DateTime<Utc> {
|
||||
self.created_at
|
||||
}
|
||||
}
|
||||
|
|
549
src/worker.rs
549
src/worker.rs
|
@ -1,53 +1,104 @@
|
|||
use crate::errors::BackieError;
|
||||
use crate::queue::Queueable;
|
||||
use crate::runnable::RunnableTask;
|
||||
use crate::task::{Task, TaskType};
|
||||
use crate::RetentionMode;
|
||||
use crate::Scheduled::*;
|
||||
use crate::errors::{AsyncQueueError, BackieError};
|
||||
use crate::runnable::BackgroundTask;
|
||||
use crate::task::{CurrentTask, Task, TaskState};
|
||||
use crate::{PgTaskStore, RetentionMode};
|
||||
use futures::future::FutureExt;
|
||||
use futures::select;
|
||||
use std::error::Error;
|
||||
use typed_builder::TypedBuilder;
|
||||
use std::collections::BTreeMap;
|
||||
use std::future::Future;
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use thiserror::Error;
|
||||
|
||||
/// it executes tasks only of task_type type, it sleeps when there are no tasks in the queue
|
||||
#[derive(TypedBuilder)]
|
||||
pub struct Worker<Q>
|
||||
where
|
||||
Q: Queueable + Clone + Sync + 'static,
|
||||
{
|
||||
#[builder(setter(into))]
|
||||
pub queue: Q,
|
||||
pub type ExecuteTaskFn<AppData> = Arc<
|
||||
dyn Fn(
|
||||
CurrentTask,
|
||||
serde_json::Value,
|
||||
AppData,
|
||||
) -> Pin<Box<dyn Future<Output = Result<(), TaskExecError>> + Send>>
|
||||
+ Send
|
||||
+ Sync,
|
||||
>;
|
||||
|
||||
#[builder(default, setter(into))]
|
||||
pub task_type: Option<TaskType>,
|
||||
pub type StateFn<AppData> = Arc<dyn Fn() -> AppData + Send + Sync>;
|
||||
|
||||
#[builder(default, setter(into))]
|
||||
pub retention_mode: RetentionMode,
|
||||
#[derive(Debug, Error)]
|
||||
pub enum TaskExecError {
|
||||
#[error("Task execution failed: {0}")]
|
||||
ExecutionFailed(#[from] anyhow::Error),
|
||||
|
||||
#[builder(default, setter(into))]
|
||||
pub shutdown: Option<tokio::sync::watch::Receiver<()>>,
|
||||
#[error("Task deserialization failed: {0}")]
|
||||
TaskDeserializationFailed(#[from] serde_json::Error),
|
||||
}
|
||||
|
||||
impl<Q> Worker<Q>
|
||||
pub(crate) fn runnable<BT>(
|
||||
task_info: CurrentTask,
|
||||
payload: serde_json::Value,
|
||||
app_context: BT::AppData,
|
||||
) -> Pin<Box<dyn Future<Output = Result<(), TaskExecError>> + Send>>
|
||||
where
|
||||
Q: Queueable + Clone + Sync + 'static,
|
||||
BT: BackgroundTask,
|
||||
{
|
||||
Box::pin(async move {
|
||||
let background_task: BT = serde_json::from_value(payload)?;
|
||||
background_task.run(task_info, app_context).await?;
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
|
||||
/// Worker that executes tasks.
|
||||
pub struct Worker<AppData>
|
||||
where
|
||||
AppData: Clone + Send + 'static,
|
||||
{
|
||||
store: Arc<PgTaskStore>,
|
||||
|
||||
queue_name: String,
|
||||
|
||||
retention_mode: RetentionMode,
|
||||
|
||||
task_registry: BTreeMap<String, ExecuteTaskFn<AppData>>,
|
||||
|
||||
app_data_fn: StateFn<AppData>,
|
||||
|
||||
/// Notification for the worker to stop.
|
||||
shutdown: Option<tokio::sync::watch::Receiver<()>>,
|
||||
}
|
||||
|
||||
impl<AppData> Worker<AppData>
|
||||
where
|
||||
AppData: Clone + Send + 'static,
|
||||
{
|
||||
pub(crate) fn new(
|
||||
store: Arc<PgTaskStore>,
|
||||
queue_name: String,
|
||||
retention_mode: RetentionMode,
|
||||
task_registry: BTreeMap<String, ExecuteTaskFn<AppData>>,
|
||||
app_data_fn: StateFn<AppData>,
|
||||
shutdown: Option<tokio::sync::watch::Receiver<()>>,
|
||||
) -> Self {
|
||||
Self {
|
||||
store,
|
||||
queue_name,
|
||||
retention_mode,
|
||||
task_registry,
|
||||
app_data_fn,
|
||||
shutdown,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn run_tasks(&mut self) -> Result<(), BackieError> {
|
||||
loop {
|
||||
// Need to check if has to stop before pulling next task
|
||||
match self.queue.pull_next_task(self.task_type.clone()).await? {
|
||||
Some(task) => {
|
||||
let actual_task: Box<dyn RunnableTask> =
|
||||
serde_json::from_value(task.payload.clone())?;
|
||||
|
||||
// check if task is scheduled or not
|
||||
if let Some(CronPattern(_)) = actual_task.cron() {
|
||||
// program task
|
||||
//self.queue.schedule_task(&*actual_task).await?;
|
||||
// Check if has to stop before pulling next task
|
||||
if let Some(ref shutdown) = self.shutdown {
|
||||
if shutdown.has_changed()? {
|
||||
return Ok(());
|
||||
}
|
||||
// run scheduled task
|
||||
// TODO: what do we do if the task fails? it's an internal error, inform the logs
|
||||
let _ = self.run(task, actual_task).await;
|
||||
};
|
||||
|
||||
match self.store.pull_next_task(&self.queue_name).await? {
|
||||
Some(task) => {
|
||||
self.run(task).await?;
|
||||
}
|
||||
None => {
|
||||
// Listen to watchable future
|
||||
|
@ -73,41 +124,45 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub async fn run_tasks_until_none(&mut self) -> Result<(), BackieError> {
|
||||
loop {
|
||||
match self.queue.pull_next_task(self.task_type.clone()).await? {
|
||||
Some(task) => {
|
||||
let actual_task: Box<dyn RunnableTask> =
|
||||
serde_json::from_value(task.payload.clone()).unwrap();
|
||||
// #[cfg(test)]
|
||||
// pub async fn run_tasks_until_none(&mut self) -> Result<(), BackieError> {
|
||||
// loop {
|
||||
// match self.store.pull_next_task(self.queue_name.clone()).await? {
|
||||
// Some(task) => {
|
||||
// let actual_task: Box<dyn BackgroundTask> =
|
||||
// serde_json::from_value(task.payload.clone()).unwrap();
|
||||
//
|
||||
// // check if task is scheduled or not
|
||||
// if let Some(CronPattern(_)) = actual_task.cron() {
|
||||
// // program task
|
||||
// // self.queue.schedule_task(&*actual_task).await?;
|
||||
// }
|
||||
// // run scheduled task
|
||||
// self.run(task, actual_task).await?;
|
||||
// }
|
||||
// None => {
|
||||
// return Ok(());
|
||||
// }
|
||||
// };
|
||||
// }
|
||||
// }
|
||||
|
||||
// check if task is scheduled or not
|
||||
if let Some(CronPattern(_)) = actual_task.cron() {
|
||||
// program task
|
||||
// self.queue.schedule_task(&*actual_task).await?;
|
||||
}
|
||||
// run scheduled task
|
||||
self.run(task, actual_task).await?;
|
||||
}
|
||||
None => {
|
||||
return Ok(());
|
||||
}
|
||||
};
|
||||
}
|
||||
}
|
||||
async fn run(&self, task: Task) -> Result<(), BackieError> {
|
||||
let task_info = CurrentTask::new(&task);
|
||||
let runnable_task_caller = self
|
||||
.task_registry
|
||||
.get(&task.task_name)
|
||||
.ok_or_else(|| AsyncQueueError::TaskNotRegistered(task.task_name.clone()))?;
|
||||
|
||||
async fn run(
|
||||
&mut self,
|
||||
task: Task,
|
||||
runnable: Box<dyn RunnableTask>,
|
||||
) -> Result<(), BackieError> {
|
||||
// TODO: catch panics
|
||||
let result = runnable.run(&mut self.queue).await;
|
||||
match result {
|
||||
let result: Result<(), TaskExecError> =
|
||||
runnable_task_caller(task_info, task.payload.clone(), (self.app_data_fn)()).await;
|
||||
|
||||
match &result {
|
||||
Ok(_) => self.finalize_task(task, result).await?,
|
||||
Err(error) => {
|
||||
if task.retries < runnable.max_retries() {
|
||||
let backoff_seconds = runnable.backoff(task.retries as u32);
|
||||
if task.retries < task.max_retries {
|
||||
let backoff_seconds = 5; // TODO: runnable.backoff(task.retries as u32);
|
||||
|
||||
log::debug!(
|
||||
"Task {} failed to run and will be retried in {} seconds",
|
||||
|
@ -115,12 +170,12 @@ where
|
|||
backoff_seconds
|
||||
);
|
||||
let error_message = format!("{}", error);
|
||||
self.queue
|
||||
self.store
|
||||
.schedule_task_retry(task.id, backoff_seconds, &error_message)
|
||||
.await?;
|
||||
} else {
|
||||
log::debug!("Task {} failed and reached the maximum retries", task.id);
|
||||
self.finalize_task(task, Err(error)).await?;
|
||||
self.finalize_task(task, result).await?;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -128,36 +183,36 @@ where
|
|||
}
|
||||
|
||||
async fn finalize_task(
|
||||
&mut self,
|
||||
&self,
|
||||
task: Task,
|
||||
result: Result<(), Box<dyn Error + Send + 'static>>,
|
||||
result: Result<(), TaskExecError>,
|
||||
) -> Result<(), BackieError> {
|
||||
match self.retention_mode {
|
||||
RetentionMode::KeepAll => match result {
|
||||
Ok(_) => {
|
||||
self.queue.set_task_done(task.id).await?;
|
||||
self.store.set_task_state(task.id, TaskState::Done).await?;
|
||||
log::debug!("Task {} done and kept in the database", task.id);
|
||||
}
|
||||
Err(error) => {
|
||||
log::debug!("Task {} failed and kept in the database", task.id);
|
||||
self.queue
|
||||
.set_task_failed(task.id, &format!("{}", error))
|
||||
self.store
|
||||
.set_task_state(task.id, TaskState::Failed(format!("{}", error)))
|
||||
.await?;
|
||||
}
|
||||
},
|
||||
RetentionMode::RemoveAll => {
|
||||
log::debug!("Task {} finalized and deleted from the database", task.id);
|
||||
self.queue.remove_task(task.id).await?;
|
||||
self.store.remove_task(task.id).await?;
|
||||
}
|
||||
RetentionMode::RemoveDone => match result {
|
||||
Ok(_) => {
|
||||
log::debug!("Task {} done and deleted from the database", task.id);
|
||||
self.queue.remove_task(task.id).await?;
|
||||
self.store.remove_task(task.id).await?;
|
||||
}
|
||||
Err(error) => {
|
||||
log::debug!("Task {} failed and kept in the database", task.id);
|
||||
self.queue
|
||||
.set_task_failed(task.id, &format!("{}", error))
|
||||
self.store
|
||||
.set_task_state(task.id, TaskState::Failed(format!("{}", error)))
|
||||
.await?;
|
||||
}
|
||||
},
|
||||
|
@ -169,35 +224,19 @@ where
|
|||
|
||||
#[cfg(test)]
|
||||
mod async_worker_tests {
|
||||
use std::fmt::Display;
|
||||
use super::*;
|
||||
use crate::queue::PgAsyncQueue;
|
||||
use crate::queue::Queueable;
|
||||
use crate::task::TaskState;
|
||||
use crate::worker::Task;
|
||||
use crate::RetentionMode;
|
||||
use crate::Scheduled;
|
||||
use async_trait::async_trait;
|
||||
use chrono::Duration;
|
||||
use chrono::Utc;
|
||||
use diesel_async::pooled_connection::{bb8::Pool, AsyncDieselConnectionManager};
|
||||
use diesel_async::AsyncPgConnection;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use thiserror::Error;
|
||||
|
||||
#[derive(Error, Debug)]
|
||||
#[derive(thiserror::Error, Debug)]
|
||||
enum TaskError {
|
||||
#[error("Something went wrong")]
|
||||
SomethingWrong,
|
||||
Custom(String),
|
||||
}
|
||||
|
||||
impl Display for TaskError {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
TaskError::SomethingWrong => write!(f, "Something went wrong"),
|
||||
TaskError::Custom(message) => write!(f, "{}", message),
|
||||
}
|
||||
}
|
||||
#[error("{0}")]
|
||||
Custom(String),
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
|
@ -205,13 +244,12 @@ mod async_worker_tests {
|
|||
pub number: u16,
|
||||
}
|
||||
|
||||
#[typetag::serde]
|
||||
#[async_trait]
|
||||
impl RunnableTask for WorkerAsyncTask {
|
||||
async fn run(
|
||||
&self,
|
||||
_queueable: &mut dyn Queueable,
|
||||
) -> Result<(), Box<(dyn std::error::Error + Send + 'static)>> {
|
||||
impl BackgroundTask for WorkerAsyncTask {
|
||||
const TASK_NAME: &'static str = "WorkerAsyncTask";
|
||||
type AppData = ();
|
||||
|
||||
async fn run(&self, _: CurrentTask, _: Self::AppData) -> Result<(), anyhow::Error> {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
@ -221,18 +259,18 @@ mod async_worker_tests {
|
|||
pub number: u16,
|
||||
}
|
||||
|
||||
#[typetag::serde]
|
||||
#[async_trait]
|
||||
impl RunnableTask for WorkerAsyncTaskSchedule {
|
||||
async fn run(
|
||||
&self,
|
||||
_queueable: &mut dyn Queueable,
|
||||
) -> Result<(), Box<(dyn std::error::Error + Send + 'static)>> {
|
||||
impl BackgroundTask for WorkerAsyncTaskSchedule {
|
||||
const TASK_NAME: &'static str = "WorkerAsyncTaskSchedule";
|
||||
type AppData = ();
|
||||
|
||||
async fn run(&self, _task: CurrentTask, _data: Self::AppData) -> Result<(), anyhow::Error> {
|
||||
Ok(())
|
||||
}
|
||||
fn cron(&self) -> Option<Scheduled> {
|
||||
Some(Scheduled::ScheduleOnce(Utc::now() + Duration::seconds(1)))
|
||||
}
|
||||
|
||||
// fn cron(&self) -> Option<Scheduled> {
|
||||
// Some(Scheduled::ScheduleOnce(Utc::now() + Duration::seconds(1)))
|
||||
// }
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
|
@ -240,16 +278,15 @@ mod async_worker_tests {
|
|||
pub number: u16,
|
||||
}
|
||||
|
||||
#[typetag::serde]
|
||||
#[async_trait]
|
||||
impl RunnableTask for AsyncFailedTask {
|
||||
async fn run(
|
||||
&self,
|
||||
_queueable: &mut dyn Queueable,
|
||||
) -> Result<(), Box<(dyn std::error::Error + Send + 'static)>> {
|
||||
impl BackgroundTask for AsyncFailedTask {
|
||||
const TASK_NAME: &'static str = "AsyncFailedTask";
|
||||
type AppData = ();
|
||||
|
||||
async fn run(&self, _task: CurrentTask, _data: Self::AppData) -> Result<(), anyhow::Error> {
|
||||
let message = format!("number {} is wrong :(", self.number);
|
||||
|
||||
Err(Box::new(TaskError::Custom(message)))
|
||||
Err(TaskError::Custom(message).into())
|
||||
}
|
||||
|
||||
fn max_retries(&self) -> i32 {
|
||||
|
@ -260,282 +297,40 @@ mod async_worker_tests {
|
|||
#[derive(Serialize, Deserialize, Clone)]
|
||||
struct AsyncRetryTask {}
|
||||
|
||||
#[typetag::serde]
|
||||
#[async_trait]
|
||||
impl RunnableTask for AsyncRetryTask {
|
||||
async fn run(
|
||||
&self,
|
||||
_queueable: &mut dyn Queueable,
|
||||
) -> Result<(), Box<(dyn std::error::Error + Send + 'static)>> {
|
||||
Err(Box::new(TaskError::SomethingWrong))
|
||||
}
|
||||
impl BackgroundTask for AsyncRetryTask {
|
||||
const TASK_NAME: &'static str = "AsyncRetryTask";
|
||||
type AppData = ();
|
||||
|
||||
fn max_retries(&self) -> i32 {
|
||||
2
|
||||
async fn run(&self, _task: CurrentTask, _data: Self::AppData) -> Result<(), anyhow::Error> {
|
||||
Err(TaskError::SomethingWrong.into())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
struct AsyncTaskType1 {}
|
||||
|
||||
#[typetag::serde]
|
||||
#[async_trait]
|
||||
impl RunnableTask for AsyncTaskType1 {
|
||||
async fn run(
|
||||
&self,
|
||||
_queueable: &mut dyn Queueable,
|
||||
) -> Result<(), Box<(dyn std::error::Error + Send + 'static)>> {
|
||||
Ok(())
|
||||
}
|
||||
impl BackgroundTask for AsyncTaskType1 {
|
||||
const TASK_NAME: &'static str = "AsyncTaskType1";
|
||||
type AppData = ();
|
||||
|
||||
fn task_type(&self) -> TaskType {
|
||||
"type1".into()
|
||||
async fn run(&self, _task: CurrentTask, _data: Self::AppData) -> Result<(), anyhow::Error> {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
struct AsyncTaskType2 {}
|
||||
|
||||
#[typetag::serde]
|
||||
#[async_trait]
|
||||
impl RunnableTask for AsyncTaskType2 {
|
||||
async fn run(
|
||||
&self,
|
||||
_queueable: &mut dyn Queueable,
|
||||
) -> Result<(), Box<(dyn std::error::Error + Send + 'static)>> {
|
||||
impl BackgroundTask for AsyncTaskType2 {
|
||||
const TASK_NAME: &'static str = "AsyncTaskType2";
|
||||
type AppData = ();
|
||||
|
||||
async fn run(&self, _task: CurrentTask, _data: Self::AppData) -> Result<(), anyhow::Error> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn task_type(&self) -> TaskType {
|
||||
TaskType::from("type2")
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn execute_and_finishes_task() {
|
||||
let pool = pool().await;
|
||||
let mut test = PgAsyncQueue::new(pool);
|
||||
|
||||
let actual_task = WorkerAsyncTask { number: 1 };
|
||||
|
||||
let task = insert_task(&mut test, &actual_task).await;
|
||||
let id = task.id;
|
||||
|
||||
let mut worker = Worker::<PgAsyncQueue>::builder()
|
||||
.queue(test.clone())
|
||||
.retention_mode(RetentionMode::KeepAll)
|
||||
.build();
|
||||
|
||||
worker.run(task, Box::new(actual_task)).await.unwrap();
|
||||
let task_finished = test.find_task_by_id(id).await.unwrap();
|
||||
assert_eq!(id, task_finished.id);
|
||||
assert_eq!(TaskState::Done, task_finished.state());
|
||||
|
||||
test.remove_all_tasks().await.unwrap();
|
||||
}
|
||||
|
||||
// #[tokio::test]
|
||||
// async fn schedule_task_test() {
|
||||
// let pool = pool().await;
|
||||
// let mut test = PgAsyncQueue::new(pool);
|
||||
//
|
||||
// let actual_task = WorkerAsyncTaskSchedule { number: 1 };
|
||||
//
|
||||
// let task = test.schedule_task(&actual_task).await.unwrap();
|
||||
//
|
||||
// let id = task.id;
|
||||
//
|
||||
// let mut worker = AsyncWorker::<PgAsyncQueue>::builder()
|
||||
// .queue(test.clone())
|
||||
// .retention_mode(RetentionMode::KeepAll)
|
||||
// .build();
|
||||
//
|
||||
// worker.run_tasks_until_none().await.unwrap();
|
||||
//
|
||||
// let task = worker.queue.find_task_by_id(id).await.unwrap();
|
||||
//
|
||||
// assert_eq!(id, task.id);
|
||||
// assert_eq!(TaskState::Ready, task.state());
|
||||
//
|
||||
// tokio::time::sleep(core::time::Duration::from_secs(3)).await;
|
||||
//
|
||||
// worker.run_tasks_until_none().await.unwrap();
|
||||
//
|
||||
// let task = test.find_task_by_id(id).await.unwrap();
|
||||
// assert_eq!(id, task.id);
|
||||
// assert_eq!(TaskState::Done, task.state());
|
||||
//
|
||||
// test.remove_all_tasks().await.unwrap();
|
||||
// }
|
||||
|
||||
#[tokio::test]
|
||||
async fn retries_task_test() {
|
||||
let pool = pool().await;
|
||||
let mut test = PgAsyncQueue::new(pool);
|
||||
|
||||
let actual_task = AsyncRetryTask {};
|
||||
|
||||
let task = test.create_task(&actual_task).await.unwrap();
|
||||
|
||||
let id = task.id;
|
||||
|
||||
let mut worker = Worker::<PgAsyncQueue>::builder()
|
||||
.queue(test.clone())
|
||||
.retention_mode(RetentionMode::KeepAll)
|
||||
.build();
|
||||
|
||||
worker.run_tasks_until_none().await.unwrap();
|
||||
|
||||
let task = worker.queue.find_task_by_id(id).await.unwrap();
|
||||
|
||||
assert_eq!(id, task.id);
|
||||
assert_eq!(TaskState::Ready, task.state());
|
||||
assert_eq!(1, task.retries);
|
||||
assert!(task.error_message.is_some());
|
||||
|
||||
tokio::time::sleep(core::time::Duration::from_secs(5)).await;
|
||||
worker.run_tasks_until_none().await.unwrap();
|
||||
|
||||
let task = worker.queue.find_task_by_id(id).await.unwrap();
|
||||
|
||||
assert_eq!(id, task.id);
|
||||
assert_eq!(TaskState::Ready, task.state());
|
||||
assert_eq!(2, task.retries);
|
||||
|
||||
tokio::time::sleep(core::time::Duration::from_secs(10)).await;
|
||||
worker.run_tasks_until_none().await.unwrap();
|
||||
|
||||
let task = test.find_task_by_id(id).await.unwrap();
|
||||
assert_eq!(id, task.id);
|
||||
assert_eq!(TaskState::Failed, task.state());
|
||||
assert_eq!("Something went wrong".to_string(), task.error_message.unwrap());
|
||||
|
||||
test.remove_all_tasks().await.unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn worker_shutsdown_when_notified() {
|
||||
let pool = pool().await;
|
||||
let queue = PgAsyncQueue::new(pool);
|
||||
|
||||
let (tx, rx) = tokio::sync::watch::channel(());
|
||||
|
||||
let mut worker = Worker::<PgAsyncQueue>::builder()
|
||||
.queue(queue)
|
||||
.shutdown(rx)
|
||||
.build();
|
||||
|
||||
let handle = tokio::spawn(async move {
|
||||
worker.run_tasks().await.unwrap();
|
||||
true
|
||||
});
|
||||
|
||||
tx.send(()).unwrap();
|
||||
select! {
|
||||
_ = handle.fuse() => {}
|
||||
_ = tokio::time::sleep(core::time::Duration::from_secs(1)).fuse() => panic!("Worker did not shutdown")
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn saves_error_for_failed_task() {
|
||||
let pool = pool().await;
|
||||
let mut test = PgAsyncQueue::new(pool);
|
||||
|
||||
let failed_task = AsyncFailedTask { number: 1 };
|
||||
|
||||
let task = insert_task(&mut test, &failed_task).await;
|
||||
let id = task.id;
|
||||
|
||||
let mut worker = Worker::<PgAsyncQueue>::builder()
|
||||
.queue(test.clone())
|
||||
.retention_mode(RetentionMode::KeepAll)
|
||||
.build();
|
||||
|
||||
worker.run(task, Box::new(failed_task)).await.unwrap();
|
||||
let task_finished = test.find_task_by_id(id).await.unwrap();
|
||||
|
||||
assert_eq!(id, task_finished.id);
|
||||
assert_eq!(TaskState::Failed, task_finished.state());
|
||||
assert_eq!(
|
||||
"number 1 is wrong :(".to_string(),
|
||||
task_finished.error_message.unwrap()
|
||||
);
|
||||
|
||||
test.remove_all_tasks().await.unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn executes_task_only_of_specific_type() {
|
||||
let pool = pool().await;
|
||||
let mut test = PgAsyncQueue::new(pool);
|
||||
|
||||
let task1 = insert_task(&mut test, &AsyncTaskType1 {}).await;
|
||||
let task12 = insert_task(&mut test, &AsyncTaskType1 {}).await;
|
||||
let task2 = insert_task(&mut test, &AsyncTaskType2 {}).await;
|
||||
|
||||
let id1 = task1.id;
|
||||
let id12 = task12.id;
|
||||
let id2 = task2.id;
|
||||
|
||||
let mut worker = Worker::<PgAsyncQueue>::builder()
|
||||
.queue(test.clone())
|
||||
.task_type(TaskType::from("type1"))
|
||||
.retention_mode(RetentionMode::KeepAll)
|
||||
.build();
|
||||
|
||||
worker.run_tasks_until_none().await.unwrap();
|
||||
let task1 = test.find_task_by_id(id1).await.unwrap();
|
||||
let task12 = test.find_task_by_id(id12).await.unwrap();
|
||||
let task2 = test.find_task_by_id(id2).await.unwrap();
|
||||
|
||||
assert_eq!(id1, task1.id);
|
||||
assert_eq!(id12, task12.id);
|
||||
assert_eq!(id2, task2.id);
|
||||
assert_eq!(TaskState::Done, task1.state());
|
||||
assert_eq!(TaskState::Done, task12.state());
|
||||
assert_eq!(TaskState::Ready, task2.state());
|
||||
|
||||
test.remove_all_tasks().await.unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn remove_when_finished() {
|
||||
let pool = pool().await;
|
||||
let mut test = PgAsyncQueue::new(pool);
|
||||
|
||||
let task1 = insert_task(&mut test, &AsyncTaskType1 {}).await;
|
||||
let task12 = insert_task(&mut test, &AsyncTaskType1 {}).await;
|
||||
let task2 = insert_task(&mut test, &AsyncTaskType2 {}).await;
|
||||
|
||||
let _id1 = task1.id;
|
||||
let _id12 = task12.id;
|
||||
let id2 = task2.id;
|
||||
|
||||
let mut worker = Worker::<PgAsyncQueue>::builder()
|
||||
.queue(test.clone())
|
||||
.task_type(TaskType::from("type1"))
|
||||
.build();
|
||||
|
||||
worker.run_tasks_until_none().await.unwrap();
|
||||
let task = test
|
||||
.pull_next_task(Some(TaskType::from("type1")))
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(None, task);
|
||||
|
||||
let task2 = test
|
||||
.pull_next_task(Some(TaskType::from("type2")))
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
assert_eq!(id2, task2.id);
|
||||
|
||||
test.remove_all_tasks().await.unwrap();
|
||||
}
|
||||
|
||||
async fn insert_task(test: &mut PgAsyncQueue, task: &dyn RunnableTask) -> Task {
|
||||
test.create_task(task).await.unwrap()
|
||||
}
|
||||
|
||||
async fn pool() -> Pool<AsyncPgConnection> {
|
||||
|
|
|
@ -1,102 +1,200 @@
|
|||
use crate::errors::BackieError;
|
||||
use crate::queue::Queueable;
|
||||
use crate::task::TaskType;
|
||||
use crate::worker::Worker;
|
||||
use crate::RetentionMode;
|
||||
use async_recursion::async_recursion;
|
||||
use log::error;
|
||||
use crate::queue::Queue;
|
||||
use crate::worker::{runnable, ExecuteTaskFn};
|
||||
use crate::worker::{StateFn, Worker};
|
||||
use crate::{BackgroundTask, CurrentTask, PgTaskStore, RetentionMode};
|
||||
use std::collections::BTreeMap;
|
||||
use std::future::Future;
|
||||
use tokio::sync::watch::Receiver;
|
||||
use typed_builder::TypedBuilder;
|
||||
use std::sync::Arc;
|
||||
use tokio::task::JoinHandle;
|
||||
|
||||
#[derive(TypedBuilder, Clone)]
|
||||
pub struct WorkerPool<AQueue>
|
||||
pub type AppDataFn<AppData> = Arc<dyn Fn(Queue) -> AppData + Send + Sync>;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct WorkerPool<AppData>
|
||||
where
|
||||
AQueue: Queueable + Clone + Sync + 'static,
|
||||
AppData: Clone + Send + 'static,
|
||||
{
|
||||
#[builder(setter(into))]
|
||||
/// the AsyncWorkerPool uses a queue to control the tasks that will be executed.
|
||||
pub queue: AQueue,
|
||||
/// Storage of tasks.
|
||||
queue_store: Arc<PgTaskStore>, // TODO: make this generic/dynamic referenced
|
||||
|
||||
/// retention_mode controls if tasks should be persisted after execution
|
||||
#[builder(default, setter(into))]
|
||||
pub retention_mode: RetentionMode,
|
||||
/// Queue used to spawn tasks.
|
||||
queue: Queue,
|
||||
|
||||
/// the number of workers of the AsyncWorkerPool.
|
||||
#[builder(setter(into))]
|
||||
pub number_of_workers: u32,
|
||||
/// Make possible to load the application data.
|
||||
///
|
||||
/// The application data is loaded when the worker pool is started and is passed to the tasks.
|
||||
/// The loading function accepts a queue instance in case the application data depends on it. This
|
||||
/// is interesting for situations where the application wants to allow tasks to spawn other tasks.
|
||||
application_data_fn: StateFn<AppData>,
|
||||
|
||||
/// The type of tasks that will be executed by `AsyncWorkerPool`.
|
||||
#[builder(default, setter(into))]
|
||||
pub task_type: Option<TaskType>,
|
||||
/// The types of task the worker pool can execute and the loaders for them.
|
||||
task_registry: BTreeMap<String, ExecuteTaskFn<AppData>>,
|
||||
|
||||
/// Number of workers that will be spawned per queue.
|
||||
worker_queues: BTreeMap<String, (RetentionMode, u32)>,
|
||||
}
|
||||
|
||||
// impl<TypedBuilderFields, Q> AsyncWorkerBuilder<TypedBuilderFields, Q>
|
||||
// where
|
||||
// TypedBuilderFields: Clone,
|
||||
// Q: Queueable + Clone + Sync + 'static,
|
||||
// {
|
||||
// pub fn with_graceful_shutdown<F>(self, signal: F) -> Self<TypedBuilderFields, Q>
|
||||
// where
|
||||
// F: Future<Output = ()>,
|
||||
// {
|
||||
// self
|
||||
// }
|
||||
// }
|
||||
|
||||
impl<AQueue> WorkerPool<AQueue>
|
||||
impl<AppData> WorkerPool<AppData>
|
||||
where
|
||||
AQueue: Queueable + Clone + Sync + 'static,
|
||||
AppData: Clone + Send + 'static,
|
||||
{
|
||||
/// Starts the configured number of workers
|
||||
/// This is necessary in order to execute tasks.
|
||||
pub async fn start<F>(&mut self, graceful_shutdown: F) -> Result<(), BackieError>
|
||||
/// Create a new worker pool.
|
||||
pub fn new<A>(queue_store: PgTaskStore, application_data_fn: A) -> Self
|
||||
where
|
||||
A: Fn(Queue) -> AppData + Send + Sync + 'static,
|
||||
{
|
||||
let queue_store = Arc::new(queue_store);
|
||||
let queue = Queue::new(queue_store.clone());
|
||||
let application_data_fn = {
|
||||
let queue = queue.clone();
|
||||
move || application_data_fn(queue.clone())
|
||||
};
|
||||
Self {
|
||||
queue_store,
|
||||
queue,
|
||||
application_data_fn: Arc::new(application_data_fn),
|
||||
task_registry: BTreeMap::new(),
|
||||
worker_queues: BTreeMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Register a task type with the worker pool.
|
||||
pub fn register_task_type<BT>(mut self, num_workers: u32, retention_mode: RetentionMode) -> Self
|
||||
where
|
||||
BT: BackgroundTask<AppData = AppData>,
|
||||
{
|
||||
self.worker_queues
|
||||
.insert(BT::QUEUE.to_string(), (retention_mode, num_workers));
|
||||
self.task_registry
|
||||
.insert(BT::TASK_NAME.to_string(), Arc::new(runnable::<BT>));
|
||||
self
|
||||
}
|
||||
|
||||
pub async fn start<F>(
|
||||
self,
|
||||
graceful_shutdown: F,
|
||||
) -> Result<(JoinHandle<()>, Queue), BackieError>
|
||||
where
|
||||
F: Future<Output = ()> + Send + 'static,
|
||||
{
|
||||
let (tx, rx) = tokio::sync::watch::channel(());
|
||||
for idx in 0..self.number_of_workers {
|
||||
let pool = self.clone();
|
||||
// TODO: the worker pool keeps track of the number of workers and spawns new workers as needed.
|
||||
// There should be always a minimum number of workers active waiting for tasks to execute
|
||||
// or for a gracefull shutdown.
|
||||
tokio::spawn(Self::supervise_task(pool, rx.clone(), 0, idx));
|
||||
}
|
||||
graceful_shutdown.await;
|
||||
tx.send(())?;
|
||||
log::info!("Worker pool stopped gracefully");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[async_recursion]
|
||||
async fn supervise_task(
|
||||
pool: WorkerPool<AQueue>,
|
||||
receiver: Receiver<()>,
|
||||
restarts: u64,
|
||||
worker_number: u32,
|
||||
) {
|
||||
let restarts = restarts + 1;
|
||||
|
||||
let inner_pool = pool.clone();
|
||||
let inner_receiver = receiver.clone();
|
||||
|
||||
let join_handle = tokio::spawn(async move {
|
||||
let mut worker: Worker<AQueue> = Worker::builder()
|
||||
.queue(inner_pool.queue.clone())
|
||||
.retention_mode(inner_pool.retention_mode)
|
||||
.task_type(inner_pool.task_type.clone())
|
||||
.shutdown(inner_receiver)
|
||||
.build();
|
||||
|
||||
worker.run_tasks().await
|
||||
});
|
||||
|
||||
if (join_handle.await).is_err() {
|
||||
error!(
|
||||
"Worker {} stopped. Restarting. the number of restarts {}",
|
||||
worker_number, restarts,
|
||||
// Spawn all individual workers per queue
|
||||
for (queue_name, (retention_mode, num_workers)) in self.worker_queues.iter() {
|
||||
for idx in 0..*num_workers {
|
||||
let mut worker: Worker<AppData> = Worker::new(
|
||||
self.queue_store.clone(),
|
||||
queue_name.clone(),
|
||||
retention_mode.clone(),
|
||||
self.task_registry.clone(),
|
||||
self.application_data_fn.clone(),
|
||||
Some(rx.clone()),
|
||||
);
|
||||
Self::supervise_task(pool, receiver, restarts, worker_number).await;
|
||||
let worker_name = format!("worker-{queue_name}-{idx}");
|
||||
// TODO: grab the join handle for every worker for graceful shutdown
|
||||
tokio::spawn(async move {
|
||||
worker.run_tasks().await;
|
||||
log::info!("Worker {} stopped", worker_name);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
Ok((
|
||||
tokio::spawn(async move {
|
||||
graceful_shutdown.await;
|
||||
if let Err(err) = tx.send(()) {
|
||||
log::warn!("Failed to send shutdown signal to worker pool: {}", err);
|
||||
} else {
|
||||
log::info!("Worker pool stopped gracefully");
|
||||
}
|
||||
}),
|
||||
self.queue,
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use async_trait::async_trait;
|
||||
use diesel_async::pooled_connection::{bb8::Pool, AsyncDieselConnectionManager};
|
||||
use diesel_async::AsyncPgConnection;
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
struct ApplicationContext {
|
||||
app_name: String,
|
||||
}
|
||||
|
||||
impl ApplicationContext {
|
||||
fn new() -> Self {
|
||||
Self {
|
||||
app_name: "Backie".to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
fn get_app_name(&self) -> String {
|
||||
self.app_name.clone()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
|
||||
struct GreetingTask {
|
||||
person: String,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl BackgroundTask for GreetingTask {
|
||||
const TASK_NAME: &'static str = "my_task";
|
||||
|
||||
type AppData = ApplicationContext;
|
||||
|
||||
async fn run(
|
||||
&self,
|
||||
task_info: CurrentTask,
|
||||
app_context: Self::AppData,
|
||||
) -> Result<(), anyhow::Error> {
|
||||
println!(
|
||||
"[{}] Hello {}! I'm {}.",
|
||||
task_info.id(),
|
||||
self.person,
|
||||
app_context.get_app_name()
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_worker_pool() {
|
||||
let my_app_context = ApplicationContext::new();
|
||||
|
||||
let task_store = PgTaskStore::new(pool().await);
|
||||
|
||||
let (join_handle, queue) = WorkerPool::new(task_store, move |_| my_app_context.clone())
|
||||
.register_task_type::<GreetingTask>(1, RetentionMode::RemoveDone)
|
||||
.start(futures::future::ready(()))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
queue
|
||||
.enqueue(GreetingTask {
|
||||
person: "Rafael".to_string(),
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
join_handle.await.unwrap();
|
||||
}
|
||||
|
||||
async fn pool() -> Pool<AsyncPgConnection> {
|
||||
let manager = AsyncDieselConnectionManager::<AsyncPgConnection>::new(
|
||||
option_env!("DATABASE_URL").expect("DATABASE_URL must be set"),
|
||||
);
|
||||
Pool::builder()
|
||||
.max_size(1)
|
||||
.min_idle(Some(1))
|
||||
.build(manager)
|
||||
.await
|
||||
.unwrap()
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue