Make possible to provide app state to tasks
This commit is contained in:
parent
7fcb63f75c
commit
aac0b44c7f
19 changed files with 776 additions and 1151 deletions
1
.gitignore
vendored
1
.gitignore
vendored
|
@ -1,5 +1,4 @@
|
||||||
**/target
|
**/target
|
||||||
Cargo.lock
|
Cargo.lock
|
||||||
src/schema.rs
|
|
||||||
docs/content/docs/CHANGELOG.md
|
docs/content/docs/CHANGELOG.md
|
||||||
docs/content/docs/README.md
|
docs/content/docs/README.md
|
||||||
|
|
11
Cargo.toml
11
Cargo.toml
|
@ -19,16 +19,13 @@ cron = "0.12"
|
||||||
chrono = "0.4"
|
chrono = "0.4"
|
||||||
hex = "0.4"
|
hex = "0.4"
|
||||||
log = "0.4"
|
log = "0.4"
|
||||||
serde = "1.0"
|
serde = { version = "1", features = ["derive"] }
|
||||||
serde_derive = "1.0"
|
serde_json = "1"
|
||||||
serde_json = "1.0"
|
|
||||||
sha2 = "0.10"
|
sha2 = "0.10"
|
||||||
thiserror = "1.0"
|
anyhow = "1"
|
||||||
typed-builder = "0.13"
|
thiserror = "1"
|
||||||
typetag = "0.2"
|
|
||||||
uuid = { version = "1.1", features = ["v4", "serde"] }
|
uuid = { version = "1.1", features = ["v4", "serde"] }
|
||||||
async-trait = "0.1"
|
async-trait = "0.1"
|
||||||
async-recursion = "1"
|
|
||||||
futures = "0.3"
|
futures = "0.3"
|
||||||
diesel = { version = "2.0", features = ["postgres", "serde_json", "chrono", "uuid"] }
|
diesel = { version = "2.0", features = ["postgres", "serde_json", "chrono", "uuid"] }
|
||||||
diesel-derive-newtype = "2.0.0-rc.0"
|
diesel-derive-newtype = "2.0.0-rc.0"
|
||||||
|
|
|
@ -1,12 +0,0 @@
|
||||||
[package]
|
|
||||||
name = "simple_cron_async_worker"
|
|
||||||
version = "0.1.0"
|
|
||||||
edition = "2021"
|
|
||||||
|
|
||||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
|
||||||
|
|
||||||
[dependencies]
|
|
||||||
fang = { path = "../../../" , features = ["asynk"]}
|
|
||||||
env_logger = "0.9.0"
|
|
||||||
log = "0.4.0"
|
|
||||||
tokio = { version = "1", features = ["full"] }
|
|
|
@ -1,33 +0,0 @@
|
||||||
use fang::async_trait;
|
|
||||||
use fang::asynk::async_queue::AsyncQueueable;
|
|
||||||
use fang::serde::{Deserialize, Serialize};
|
|
||||||
use fang::typetag;
|
|
||||||
use fang::AsyncRunnable;
|
|
||||||
use fang::FangError;
|
|
||||||
use fang::Scheduled;
|
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize)]
|
|
||||||
#[serde(crate = "fang::serde")]
|
|
||||||
pub struct MyCronTask {}
|
|
||||||
|
|
||||||
#[async_trait]
|
|
||||||
#[typetag::serde]
|
|
||||||
impl AsyncRunnable for MyCronTask {
|
|
||||||
async fn run(&self, _queue: &mut dyn AsyncQueueable) -> Result<(), FangError> {
|
|
||||||
log::info!("CRON!!!!!!!!!!!!!!!",);
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
fn cron(&self) -> Option<Scheduled> {
|
|
||||||
// sec min hour day of month month day of week year
|
|
||||||
// be careful works only with UTC hour.
|
|
||||||
// https://www.timeanddate.com/worldclock/timezone/utc
|
|
||||||
let expression = "0/20 * * * Aug-Sep * 2022/1";
|
|
||||||
Some(Scheduled::CronPattern(expression.to_string()))
|
|
||||||
}
|
|
||||||
|
|
||||||
fn uniq(&self) -> bool {
|
|
||||||
true
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,41 +0,0 @@
|
||||||
use fang::asynk::async_queue::AsyncQueue;
|
|
||||||
use fang::asynk::async_queue::AsyncQueueable;
|
|
||||||
use fang::asynk::async_worker_pool::AsyncWorkerPool;
|
|
||||||
use fang::AsyncRunnable;
|
|
||||||
use fang::NoTls;
|
|
||||||
use simple_cron_async_worker::MyCronTask;
|
|
||||||
use std::time::Duration;
|
|
||||||
|
|
||||||
#[tokio::main]
|
|
||||||
async fn main() {
|
|
||||||
env_logger::init();
|
|
||||||
|
|
||||||
log::info!("Starting...");
|
|
||||||
let max_pool_size: u32 = 3;
|
|
||||||
let mut queue = AsyncQueue::builder()
|
|
||||||
.uri("postgres://postgres:postgres@localhost/fang")
|
|
||||||
.max_pool_size(max_pool_size)
|
|
||||||
.build();
|
|
||||||
|
|
||||||
queue.connect(NoTls).await.unwrap();
|
|
||||||
log::info!("Queue connected...");
|
|
||||||
|
|
||||||
let mut pool: AsyncWorkerPool<AsyncQueue<NoTls>> = AsyncWorkerPool::builder()
|
|
||||||
.number_of_workers(10_u32)
|
|
||||||
.queue(queue.clone())
|
|
||||||
.build();
|
|
||||||
|
|
||||||
log::info!("Pool created ...");
|
|
||||||
|
|
||||||
pool.start().await;
|
|
||||||
log::info!("Workers started ...");
|
|
||||||
|
|
||||||
let task = MyCronTask {};
|
|
||||||
|
|
||||||
queue
|
|
||||||
.schedule_task(&task as &dyn AsyncRunnable)
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
tokio::time::sleep(Duration::from_secs(100)).await;
|
|
||||||
}
|
|
|
@ -5,6 +5,7 @@ edition = "2021"
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
backie = { path = "../../" }
|
backie = { path = "../../" }
|
||||||
|
anyhow = "1"
|
||||||
env_logger = "0.9.0"
|
env_logger = "0.9.0"
|
||||||
log = "0.4.0"
|
log = "0.4.0"
|
||||||
tokio = { version = "1", features = ["full"] }
|
tokio = { version = "1", features = ["full"] }
|
||||||
|
@ -12,4 +13,3 @@ diesel-async = { version = "0.2", features = ["postgres", "bb8"] }
|
||||||
diesel = { version = "2.0", features = ["postgres"] }
|
diesel = { version = "2.0", features = ["postgres"] }
|
||||||
async-trait = "0.1"
|
async-trait = "0.1"
|
||||||
serde = { version = "1.0", features = ["derive"] }
|
serde = { version = "1.0", features = ["derive"] }
|
||||||
typetag = "0.2"
|
|
||||||
|
|
|
@ -1,7 +1,20 @@
|
||||||
use std::time::Duration;
|
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use serde::{Serialize, Deserialize};
|
use backie::{BackgroundTask, CurrentTask};
|
||||||
use backie::{RunnableTask, Queueable};
|
use serde::{Deserialize, Serialize};
|
||||||
|
use std::time::Duration;
|
||||||
|
|
||||||
|
#[derive(Clone, Debug)]
|
||||||
|
pub struct MyApplicationContext {
|
||||||
|
app_name: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl MyApplicationContext {
|
||||||
|
pub fn new(app_name: &str) -> Self {
|
||||||
|
Self {
|
||||||
|
app_name: app_name.to_string(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize)]
|
#[derive(Serialize, Deserialize)]
|
||||||
pub struct MyTask {
|
pub struct MyTask {
|
||||||
|
@ -26,37 +39,51 @@ impl MyFailingTask {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
#[typetag::serde]
|
impl BackgroundTask for MyTask {
|
||||||
impl RunnableTask for MyTask {
|
const TASK_NAME: &'static str = "my_task";
|
||||||
async fn run(&self, _queue: &mut dyn Queueable) -> Result<(), Box<dyn std::error::Error + Send + 'static>> {
|
type AppData = MyApplicationContext;
|
||||||
|
|
||||||
|
async fn run(&self, task: CurrentTask, ctx: Self::AppData) -> Result<(), anyhow::Error> {
|
||||||
// let new_task = MyTask::new(self.number + 1);
|
// let new_task = MyTask::new(self.number + 1);
|
||||||
// queue
|
// queue
|
||||||
// .insert_task(&new_task as &dyn AsyncRunnable)
|
// .insert_task(&new_task)
|
||||||
// .await
|
// .await
|
||||||
// .unwrap();
|
// .unwrap();
|
||||||
|
|
||||||
log::info!("the current number is {}", self.number);
|
log::info!(
|
||||||
|
"[{}] Hello from {}! the current number is {}",
|
||||||
|
task.id(),
|
||||||
|
ctx.app_name,
|
||||||
|
self.number
|
||||||
|
);
|
||||||
tokio::time::sleep(Duration::from_secs(3)).await;
|
tokio::time::sleep(Duration::from_secs(3)).await;
|
||||||
|
|
||||||
log::info!("done..");
|
log::info!("[{}] done..", task.id());
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
#[typetag::serde]
|
impl BackgroundTask for MyFailingTask {
|
||||||
impl RunnableTask for MyFailingTask {
|
const TASK_NAME: &'static str = "my_failing_task";
|
||||||
async fn run(&self, _queue: &mut dyn Queueable) -> Result<(), Box<dyn std::error::Error + Send + 'static>> {
|
type AppData = MyApplicationContext;
|
||||||
|
|
||||||
|
async fn run(&self, task: CurrentTask, _ctx: Self::AppData) -> Result<(), anyhow::Error> {
|
||||||
// let new_task = MyFailingTask::new(self.number + 1);
|
// let new_task = MyFailingTask::new(self.number + 1);
|
||||||
// queue
|
// queue
|
||||||
// .insert_task(&new_task as &dyn AsyncRunnable)
|
// .insert_task(&new_task)
|
||||||
// .await
|
// .await
|
||||||
// .unwrap();
|
// .unwrap();
|
||||||
|
|
||||||
log::info!("the current number is {}", self.number);
|
// task.id();
|
||||||
|
// task.keep_alive().await?;
|
||||||
|
// task.previous_error();
|
||||||
|
// task.retry_count();
|
||||||
|
|
||||||
|
log::info!("[{}] the current number is {}", task.id(), self.number);
|
||||||
tokio::time::sleep(Duration::from_secs(3)).await;
|
tokio::time::sleep(Duration::from_secs(3)).await;
|
||||||
|
|
||||||
log::info!("done..");
|
log::info!("[{}] done..", task.id());
|
||||||
//
|
//
|
||||||
// let b = true;
|
// let b = true;
|
||||||
//
|
//
|
||||||
|
|
|
@ -1,9 +1,9 @@
|
||||||
use simple_worker::MyFailingTask;
|
use backie::{PgTaskStore, RetentionMode, WorkerPool};
|
||||||
use simple_worker::MyTask;
|
|
||||||
use std::time::Duration;
|
|
||||||
use diesel_async::pg::AsyncPgConnection;
|
use diesel_async::pg::AsyncPgConnection;
|
||||||
use diesel_async::pooled_connection::{bb8::Pool, AsyncDieselConnectionManager};
|
use diesel_async::pooled_connection::{bb8::Pool, AsyncDieselConnectionManager};
|
||||||
use backie::{PgAsyncQueue, WorkerPool, Queueable};
|
use simple_worker::MyApplicationContext;
|
||||||
|
use simple_worker::MyFailingTask;
|
||||||
|
use simple_worker::MyTask;
|
||||||
|
|
||||||
#[tokio::main]
|
#[tokio::main]
|
||||||
async fn main() {
|
async fn main() {
|
||||||
|
@ -22,49 +22,38 @@ async fn main() {
|
||||||
.unwrap();
|
.unwrap();
|
||||||
log::info!("Pool created ...");
|
log::info!("Pool created ...");
|
||||||
|
|
||||||
let mut queue = PgAsyncQueue::new(pool);
|
let task_store = PgTaskStore::new(pool);
|
||||||
|
|
||||||
let (tx, mut rx) = tokio::sync::watch::channel(false);
|
let (tx, mut rx) = tokio::sync::watch::channel(false);
|
||||||
|
|
||||||
let executor_task = tokio::spawn({
|
// Some global application context I want to pass to my background tasks
|
||||||
let mut queue = queue.clone();
|
let my_app_context = MyApplicationContext::new("Backie Example App");
|
||||||
async move {
|
|
||||||
let mut workers_pool: WorkerPool<PgAsyncQueue> = WorkerPool::builder()
|
|
||||||
.number_of_workers(10_u32)
|
|
||||||
.queue(queue)
|
|
||||||
.build();
|
|
||||||
|
|
||||||
log::info!("Workers starting ...");
|
// Register the task types I want to use and start the worker pool
|
||||||
workers_pool.start(async move {
|
let (join_handle, queue) = WorkerPool::new(task_store, move |_| my_app_context.clone())
|
||||||
rx.changed().await;
|
.register_task_type::<MyTask>(1, RetentionMode::RemoveDone)
|
||||||
}).await;
|
.register_task_type::<MyFailingTask>(1, RetentionMode::RemoveDone)
|
||||||
log::info!("Workers stopped!");
|
.start(async move {
|
||||||
}
|
let _ = rx.changed().await;
|
||||||
});
|
})
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
log::info!("Workers started ...");
|
||||||
|
|
||||||
let task1 = MyTask::new(0);
|
let task1 = MyTask::new(0);
|
||||||
let task2 = MyTask::new(20_000);
|
let task2 = MyTask::new(20_000);
|
||||||
let task3 = MyFailingTask::new(50_000);
|
let task3 = MyFailingTask::new(50_000);
|
||||||
|
|
||||||
queue
|
queue.enqueue(task1).await.unwrap();
|
||||||
.create_task(&task1)
|
queue.enqueue(task2).await.unwrap();
|
||||||
.await
|
queue.enqueue(task3).await.unwrap();
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
queue
|
|
||||||
.create_task(&task2)
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
queue
|
|
||||||
.create_task(&task3)
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
log::info!("Tasks created ...");
|
log::info!("Tasks created ...");
|
||||||
tokio::signal::ctrl_c().await;
|
|
||||||
|
// Wait for Ctrl+C
|
||||||
|
let _ = tokio::signal::ctrl_c().await;
|
||||||
log::info!("Stopping ...");
|
log::info!("Stopping ...");
|
||||||
tx.send(true).unwrap();
|
tx.send(true).unwrap();
|
||||||
executor_task.await.unwrap();
|
join_handle.await.unwrap();
|
||||||
log::info!("Stopped!");
|
log::info!("Workers Stopped!");
|
||||||
}
|
}
|
||||||
|
|
|
@ -2,20 +2,20 @@ CREATE EXTENSION IF NOT EXISTS "uuid-ossp";
|
||||||
|
|
||||||
CREATE TABLE backie_tasks (
|
CREATE TABLE backie_tasks (
|
||||||
id uuid PRIMARY KEY DEFAULT uuid_generate_v4(),
|
id uuid PRIMARY KEY DEFAULT uuid_generate_v4(),
|
||||||
payload jsonb NOT NULL,
|
task_name VARCHAR NOT NULL,
|
||||||
error_message TEXT DEFAULT NULL,
|
queue_name VARCHAR DEFAULT 'common' NOT NULL,
|
||||||
task_type VARCHAR DEFAULT 'common' NOT NULL,
|
|
||||||
uniq_hash CHAR(64) DEFAULT NULL,
|
uniq_hash CHAR(64) DEFAULT NULL,
|
||||||
retries INTEGER DEFAULT 0 NOT NULL,
|
payload jsonb NOT NULL,
|
||||||
|
timeout_msecs INT8 NOT NULL,
|
||||||
created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(),
|
created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(),
|
||||||
|
scheduled_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(),
|
||||||
running_at TIMESTAMP WITH TIME ZONE DEFAULT NULL,
|
running_at TIMESTAMP WITH TIME ZONE DEFAULT NULL,
|
||||||
done_at TIMESTAMP WITH TIME ZONE DEFAULT NULL
|
done_at TIMESTAMP WITH TIME ZONE DEFAULT NULL,
|
||||||
|
error_info jsonb DEFAULT NULL,
|
||||||
|
retries INTEGER DEFAULT 0 NOT NULL,
|
||||||
|
max_retries INTEGER DEFAULT 0 NOT NULL
|
||||||
);
|
);
|
||||||
|
|
||||||
CREATE INDEX backie_tasks_type_index ON backie_tasks(task_type);
|
|
||||||
CREATE INDEX backie_tasks_created_at_index ON backie_tasks(created_at);
|
|
||||||
CREATE INDEX backie_tasks_uniq_hash ON backie_tasks(uniq_hash);
|
|
||||||
|
|
||||||
--- create uniqueness index
|
--- create uniqueness index
|
||||||
CREATE UNIQUE INDEX backie_tasks_uniq_hash_index ON backie_tasks(uniq_hash) WHERE uniq_hash IS NOT NULL;
|
CREATE UNIQUE INDEX backie_tasks_uniq_hash_index ON backie_tasks(uniq_hash) WHERE uniq_hash IS NOT NULL;
|
||||||
|
|
||||||
|
|
|
@ -1,25 +1,16 @@
|
||||||
use serde_json::Error as SerdeError;
|
|
||||||
use std::fmt::Display;
|
|
||||||
use thiserror::Error;
|
use thiserror::Error;
|
||||||
|
|
||||||
/// Library errors
|
/// Library errors
|
||||||
#[derive(Debug, Error)]
|
#[derive(Debug, Error)]
|
||||||
pub enum BackieError {
|
pub enum BackieError {
|
||||||
|
#[error("Queue processing error: {0}")]
|
||||||
QueueProcessingError(#[from] AsyncQueueError),
|
QueueProcessingError(#[from] AsyncQueueError),
|
||||||
SerializationError(#[from] SerdeError),
|
|
||||||
ShutdownError(#[from] tokio::sync::watch::error::SendError<()>),
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Display for BackieError {
|
#[error("Worker Pool shutdown error: {0}")]
|
||||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
WorkerPoolShutdownError(#[from] tokio::sync::watch::error::SendError<()>),
|
||||||
match self {
|
|
||||||
BackieError::QueueProcessingError(error) => {
|
#[error("Worker shutdown error: {0}")]
|
||||||
write!(f, "Queue processing error: {}", error)
|
WorkerShutdownError(#[from] tokio::sync::watch::error::RecvError),
|
||||||
}
|
|
||||||
BackieError::SerializationError(error) => write!(f, "Serialization error: {}", error),
|
|
||||||
BackieError::ShutdownError(error) => write!(f, "Shutdown error: {}", error),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// List of error types that can occur while working with cron schedules.
|
/// List of error types that can occur while working with cron schedules.
|
||||||
|
@ -29,7 +20,7 @@ pub enum CronError {
|
||||||
#[error(transparent)]
|
#[error(transparent)]
|
||||||
LibraryError(#[from] cron::error::Error),
|
LibraryError(#[from] cron::error::Error),
|
||||||
/// [`Scheduled`] enum variant is not provided
|
/// [`Scheduled`] enum variant is not provided
|
||||||
#[error("You have to implement method `cron()` in your AsyncRunnable")]
|
#[error("You have to implement method `cron()` in your Runnable")]
|
||||||
TaskNotSchedulableError,
|
TaskNotSchedulableError,
|
||||||
/// The next execution can not be determined using the current [`Scheduled::CronPattern`]
|
/// The next execution can not be determined using the current [`Scheduled::CronPattern`]
|
||||||
#[error("No timestamps match with this cron pattern")]
|
#[error("No timestamps match with this cron pattern")]
|
||||||
|
@ -41,7 +32,7 @@ pub enum AsyncQueueError {
|
||||||
#[error(transparent)]
|
#[error(transparent)]
|
||||||
PgError(#[from] diesel::result::Error),
|
PgError(#[from] diesel::result::Error),
|
||||||
|
|
||||||
#[error(transparent)]
|
#[error("Task serialization error: {0}")]
|
||||||
SerdeError(#[from] serde_json::Error),
|
SerdeError(#[from] serde_json::Error),
|
||||||
|
|
||||||
#[error(transparent)]
|
#[error(transparent)]
|
||||||
|
@ -49,6 +40,9 @@ pub enum AsyncQueueError {
|
||||||
|
|
||||||
#[error("Task is not in progress, operation not allowed")]
|
#[error("Task is not in progress, operation not allowed")]
|
||||||
TaskNotRunning,
|
TaskNotRunning,
|
||||||
|
|
||||||
|
#[error("Task with name {0} is not registered")]
|
||||||
|
TaskNotRegistered(String),
|
||||||
}
|
}
|
||||||
|
|
||||||
impl From<cron::error::Error> for AsyncQueueError {
|
impl From<cron::error::Error> for AsyncQueueError {
|
||||||
|
|
|
@ -38,9 +38,9 @@ impl Default for RetentionMode {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub use queue::PgAsyncQueue;
|
pub use queue::PgTaskStore;
|
||||||
pub use queue::Queueable;
|
pub use runnable::BackgroundTask;
|
||||||
pub use runnable::RunnableTask;
|
pub use task::CurrentTask;
|
||||||
pub use worker_pool::WorkerPool;
|
pub use worker_pool::WorkerPool;
|
||||||
|
|
||||||
pub mod errors;
|
pub mod errors;
|
||||||
|
@ -48,6 +48,7 @@ mod queries;
|
||||||
pub mod queue;
|
pub mod queue;
|
||||||
pub mod runnable;
|
pub mod runnable;
|
||||||
mod schema;
|
mod schema;
|
||||||
|
pub mod store;
|
||||||
pub mod task;
|
pub mod task;
|
||||||
pub mod worker;
|
pub mod worker;
|
||||||
pub mod worker_pool;
|
pub mod worker_pool;
|
||||||
|
|
|
@ -1,8 +1,7 @@
|
||||||
use crate::errors::AsyncQueueError;
|
use crate::errors::AsyncQueueError;
|
||||||
use crate::runnable::RunnableTask;
|
|
||||||
use crate::schema::backie_tasks;
|
use crate::schema::backie_tasks;
|
||||||
use crate::task::Task;
|
use crate::task::Task;
|
||||||
use crate::task::{NewTask, TaskHash, TaskId, TaskType};
|
use crate::task::{NewTask, TaskHash, TaskId};
|
||||||
use chrono::DateTime;
|
use chrono::DateTime;
|
||||||
use chrono::Duration;
|
use chrono::Duration;
|
||||||
use chrono::Utc;
|
use chrono::Utc;
|
||||||
|
@ -43,14 +42,6 @@ impl Task {
|
||||||
Ok(qty > 0)
|
Ok(qty > 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) async fn remove_by_type(
|
|
||||||
connection: &mut AsyncPgConnection,
|
|
||||||
task_type: TaskType,
|
|
||||||
) -> Result<u64, AsyncQueueError> {
|
|
||||||
let query = backie_tasks::table.filter(backie_tasks::task_type.eq(task_type));
|
|
||||||
Ok(diesel::delete(query).execute(connection).await? as u64)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub(crate) async fn find_by_id(
|
pub(crate) async fn find_by_id(
|
||||||
connection: &mut AsyncPgConnection,
|
connection: &mut AsyncPgConnection,
|
||||||
id: TaskId,
|
id: TaskId,
|
||||||
|
@ -67,10 +58,13 @@ impl Task {
|
||||||
id: TaskId,
|
id: TaskId,
|
||||||
error_message: &str,
|
error_message: &str,
|
||||||
) -> Result<Task, AsyncQueueError> {
|
) -> Result<Task, AsyncQueueError> {
|
||||||
|
let error = serde_json::json!({
|
||||||
|
"error": error_message,
|
||||||
|
});
|
||||||
let query = backie_tasks::table.filter(backie_tasks::id.eq(id));
|
let query = backie_tasks::table.filter(backie_tasks::id.eq(id));
|
||||||
Ok(diesel::update(query)
|
Ok(diesel::update(query)
|
||||||
.set((
|
.set((
|
||||||
backie_tasks::error_message.eq(error_message),
|
backie_tasks::error_info.eq(Some(error)),
|
||||||
backie_tasks::done_at.eq(Utc::now()),
|
backie_tasks::done_at.eq(Utc::now()),
|
||||||
))
|
))
|
||||||
.get_result::<Task>(connection)
|
.get_result::<Task>(connection)
|
||||||
|
@ -81,18 +75,22 @@ impl Task {
|
||||||
connection: &mut AsyncPgConnection,
|
connection: &mut AsyncPgConnection,
|
||||||
id: TaskId,
|
id: TaskId,
|
||||||
backoff_seconds: u32,
|
backoff_seconds: u32,
|
||||||
error: &str,
|
error_message: &str,
|
||||||
) -> Result<Task, AsyncQueueError> {
|
) -> Result<Task, AsyncQueueError> {
|
||||||
use crate::schema::backie_tasks::dsl;
|
use crate::schema::backie_tasks::dsl;
|
||||||
|
|
||||||
let now = Utc::now();
|
let now = Utc::now();
|
||||||
let scheduled_at = now + Duration::seconds(backoff_seconds as i64);
|
let scheduled_at = now + Duration::seconds(backoff_seconds as i64);
|
||||||
|
|
||||||
|
let error = serde_json::json!({
|
||||||
|
"error": error_message,
|
||||||
|
});
|
||||||
|
|
||||||
let task = diesel::update(backie_tasks::table.filter(backie_tasks::id.eq(id)))
|
let task = diesel::update(backie_tasks::table.filter(backie_tasks::id.eq(id)))
|
||||||
.set((
|
.set((
|
||||||
backie_tasks::error_message.eq(error),
|
backie_tasks::error_info.eq(Some(error)),
|
||||||
backie_tasks::retries.eq(dsl::retries + 1),
|
backie_tasks::retries.eq(dsl::retries + 1),
|
||||||
backie_tasks::created_at.eq(scheduled_at),
|
backie_tasks::scheduled_at.eq(scheduled_at),
|
||||||
backie_tasks::running_at.eq::<Option<DateTime<Utc>>>(None),
|
backie_tasks::running_at.eq::<Option<DateTime<Utc>>>(None),
|
||||||
))
|
))
|
||||||
.get_result::<Task>(connection)
|
.get_result::<Task>(connection)
|
||||||
|
@ -103,14 +101,14 @@ impl Task {
|
||||||
|
|
||||||
pub(crate) async fn fetch_next_pending(
|
pub(crate) async fn fetch_next_pending(
|
||||||
connection: &mut AsyncPgConnection,
|
connection: &mut AsyncPgConnection,
|
||||||
task_type: TaskType,
|
queue_name: &str,
|
||||||
) -> Option<Task> {
|
) -> Option<Task> {
|
||||||
backie_tasks::table
|
backie_tasks::table
|
||||||
.filter(backie_tasks::created_at.lt(Utc::now())) // skip tasks scheduled for the future
|
.filter(backie_tasks::scheduled_at.lt(Utc::now())) // skip tasks scheduled for the future
|
||||||
.order(backie_tasks::created_at.asc()) // get the oldest task first
|
.order(backie_tasks::created_at.asc()) // get the oldest task first
|
||||||
.filter(backie_tasks::running_at.is_null()) // that is not marked as running already
|
.filter(backie_tasks::running_at.is_null()) // that is not marked as running already
|
||||||
.filter(backie_tasks::done_at.is_null()) // and not marked as done
|
.filter(backie_tasks::done_at.is_null()) // and not marked as done
|
||||||
.filter(backie_tasks::task_type.eq(task_type))
|
.filter(backie_tasks::queue_name.eq(queue_name))
|
||||||
.limit(1)
|
.limit(1)
|
||||||
.for_update()
|
.for_update()
|
||||||
.skip_locked()
|
.skip_locked()
|
||||||
|
@ -143,39 +141,13 @@ impl Task {
|
||||||
|
|
||||||
pub(crate) async fn insert(
|
pub(crate) async fn insert(
|
||||||
connection: &mut AsyncPgConnection,
|
connection: &mut AsyncPgConnection,
|
||||||
runnable: &dyn RunnableTask,
|
new_task: NewTask,
|
||||||
) -> Result<Task, AsyncQueueError> {
|
) -> Result<Task, AsyncQueueError> {
|
||||||
let payload = serde_json::to_value(runnable)?;
|
|
||||||
match runnable.uniq() {
|
|
||||||
None => {
|
|
||||||
let new_task = NewTask::builder()
|
|
||||||
.uniq_hash(None)
|
|
||||||
.task_type(runnable.task_type())
|
|
||||||
.payload(payload)
|
|
||||||
.build();
|
|
||||||
|
|
||||||
Ok(diesel::insert_into(backie_tasks::table)
|
Ok(diesel::insert_into(backie_tasks::table)
|
||||||
.values(new_task)
|
.values(new_task)
|
||||||
.get_result::<Task>(connection)
|
.get_result::<Task>(connection)
|
||||||
.await?)
|
.await?)
|
||||||
}
|
}
|
||||||
Some(hash) => match Self::find_by_uniq_hash(connection, hash.clone()).await {
|
|
||||||
Some(task) => Ok(task),
|
|
||||||
None => {
|
|
||||||
let new_task = NewTask::builder()
|
|
||||||
.uniq_hash(Some(hash))
|
|
||||||
.task_type(runnable.task_type())
|
|
||||||
.payload(payload)
|
|
||||||
.build();
|
|
||||||
|
|
||||||
Ok(diesel::insert_into(backie_tasks::table)
|
|
||||||
.values(new_task)
|
|
||||||
.get_result::<Task>(connection)
|
|
||||||
.await?)
|
|
||||||
}
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub(crate) async fn find_by_uniq_hash(
|
pub(crate) async fn find_by_uniq_hash(
|
||||||
connection: &mut AsyncPgConnection,
|
connection: &mut AsyncPgConnection,
|
||||||
|
|
562
src/queue.rs
562
src/queue.rs
|
@ -1,86 +1,48 @@
|
||||||
use crate::errors::AsyncQueueError;
|
use crate::errors::AsyncQueueError;
|
||||||
use crate::runnable::RunnableTask;
|
use crate::runnable::BackgroundTask;
|
||||||
use crate::task::{Task, TaskHash, TaskId, TaskType};
|
use crate::task::{NewTask, Task, TaskHash, TaskId, TaskState};
|
||||||
use async_trait::async_trait;
|
|
||||||
use diesel::result::Error::QueryBuilderError;
|
use diesel::result::Error::QueryBuilderError;
|
||||||
use diesel_async::scoped_futures::ScopedFutureExt;
|
use diesel_async::scoped_futures::ScopedFutureExt;
|
||||||
use diesel_async::AsyncConnection;
|
use diesel_async::AsyncConnection;
|
||||||
use diesel_async::{pg::AsyncPgConnection, pooled_connection::bb8::Pool};
|
use diesel_async::{pg::AsyncPgConnection, pooled_connection::bb8::Pool};
|
||||||
|
use std::sync::Arc;
|
||||||
|
use std::time::Duration;
|
||||||
|
|
||||||
/// This trait defines operations for an asynchronous queue.
|
#[derive(Clone)]
|
||||||
/// The trait can be implemented for different storage backends.
|
pub struct Queue {
|
||||||
/// For now, the trait is only implemented for PostgreSQL. More backends are planned to be implemented in the future.
|
task_store: Arc<PgTaskStore>,
|
||||||
#[async_trait]
|
}
|
||||||
pub trait Queueable: Send {
|
|
||||||
/// Pull pending tasks from the queue to execute them.
|
|
||||||
///
|
|
||||||
/// This method returns one task of the `task_type` type. If `task_type` is `None` it will try to
|
|
||||||
/// fetch a task of the type `common`. The returned task is marked as running and must be executed.
|
|
||||||
async fn pull_next_task(
|
|
||||||
&mut self,
|
|
||||||
kind: Option<TaskType>,
|
|
||||||
) -> Result<Option<Task>, AsyncQueueError>;
|
|
||||||
|
|
||||||
/// Enqueue a task to the queue, The task will be executed as soon as possible by the worker of the same type
|
impl Queue {
|
||||||
/// created by an AsyncWorkerPool.
|
pub(crate) fn new(task_store: Arc<PgTaskStore>) -> Self {
|
||||||
async fn create_task(&mut self, task: &dyn RunnableTask) -> Result<Task, AsyncQueueError>;
|
Queue { task_store }
|
||||||
|
}
|
||||||
|
|
||||||
/// Retrieve a task by its `id`.
|
pub async fn enqueue<BT>(&self, background_task: BT) -> Result<(), AsyncQueueError>
|
||||||
async fn find_task_by_id(&mut self, id: TaskId) -> Result<Task, AsyncQueueError>;
|
where
|
||||||
|
BT: BackgroundTask,
|
||||||
/// Update the state of a task to failed and set an error_message.
|
{
|
||||||
async fn set_task_failed(
|
self.task_store
|
||||||
&mut self,
|
.create_task(NewTask::new(background_task, Duration::from_secs(10))?)
|
||||||
id: TaskId,
|
.await?;
|
||||||
error_message: &str,
|
Ok(())
|
||||||
) -> Result<Task, AsyncQueueError>;
|
}
|
||||||
|
|
||||||
/// Update the state of a task to done.
|
|
||||||
async fn set_task_done(&mut self, id: TaskId) -> Result<Task, AsyncQueueError>;
|
|
||||||
|
|
||||||
/// Update the state of a task to inform that it's still in progress.
|
|
||||||
async fn keep_task_alive(&mut self, id: TaskId) -> Result<(), AsyncQueueError>;
|
|
||||||
|
|
||||||
/// Remove a task by its id.
|
|
||||||
async fn remove_task(&mut self, id: TaskId) -> Result<u64, AsyncQueueError>;
|
|
||||||
|
|
||||||
/// The method will remove all tasks from the queue
|
|
||||||
async fn remove_all_tasks(&mut self) -> Result<u64, AsyncQueueError>;
|
|
||||||
|
|
||||||
/// Remove all tasks that are scheduled in the future.
|
|
||||||
async fn remove_all_scheduled_tasks(&mut self) -> Result<u64, AsyncQueueError>;
|
|
||||||
|
|
||||||
/// Remove a task by its metadata (struct fields values)
|
|
||||||
async fn remove_task_by_hash(&mut self, task_hash: TaskHash) -> Result<bool, AsyncQueueError>;
|
|
||||||
|
|
||||||
/// Removes all tasks that have the specified `task_type`.
|
|
||||||
async fn remove_tasks_type(&mut self, task_type: TaskType) -> Result<u64, AsyncQueueError>;
|
|
||||||
|
|
||||||
async fn schedule_task_retry(
|
|
||||||
&mut self,
|
|
||||||
id: TaskId,
|
|
||||||
backoff_seconds: u32,
|
|
||||||
error: &str,
|
|
||||||
) -> Result<Task, AsyncQueueError>;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// An async queue that is used to manipulate tasks, it uses PostgreSQL as storage.
|
/// An async queue that is used to manipulate tasks, it uses PostgreSQL as storage.
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub struct PgAsyncQueue {
|
pub struct PgTaskStore {
|
||||||
pool: Pool<AsyncPgConnection>,
|
pool: Pool<AsyncPgConnection>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl PgAsyncQueue {
|
impl PgTaskStore {
|
||||||
pub fn new(pool: Pool<AsyncPgConnection>) -> Self {
|
pub fn new(pool: Pool<AsyncPgConnection>) -> Self {
|
||||||
PgAsyncQueue { pool }
|
PgTaskStore { pool }
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
#[async_trait]
|
pub(crate) async fn pull_next_task(
|
||||||
impl Queueable for PgAsyncQueue {
|
&self,
|
||||||
async fn pull_next_task(
|
queue_name: &str,
|
||||||
&mut self,
|
|
||||||
task_type: Option<TaskType>,
|
|
||||||
) -> Result<Option<Task>, AsyncQueueError> {
|
) -> Result<Option<Task>, AsyncQueueError> {
|
||||||
let mut connection = self
|
let mut connection = self
|
||||||
.pool
|
.pool
|
||||||
|
@ -90,7 +52,7 @@ impl Queueable for PgAsyncQueue {
|
||||||
connection
|
connection
|
||||||
.transaction::<Option<Task>, AsyncQueueError, _>(|conn| {
|
.transaction::<Option<Task>, AsyncQueueError, _>(|conn| {
|
||||||
async move {
|
async move {
|
||||||
let Some(pending_task) = Task::fetch_next_pending(conn, task_type.unwrap_or_default()).await else {
|
let Some(pending_task) = Task::fetch_next_pending(conn, queue_name).await else {
|
||||||
return Ok(None);
|
return Ok(None);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -101,16 +63,16 @@ impl Queueable for PgAsyncQueue {
|
||||||
.await
|
.await
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn create_task(&mut self, runnable: &dyn RunnableTask) -> Result<Task, AsyncQueueError> {
|
pub(crate) async fn create_task(&self, new_task: NewTask) -> Result<Task, AsyncQueueError> {
|
||||||
let mut connection = self
|
let mut connection = self
|
||||||
.pool
|
.pool
|
||||||
.get()
|
.get()
|
||||||
.await
|
.await
|
||||||
.map_err(|e| QueryBuilderError(e.into()))?;
|
.map_err(|e| QueryBuilderError(e.into()))?;
|
||||||
Ok(Task::insert(&mut connection, runnable).await?)
|
Task::insert(&mut connection, new_task).await
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn find_task_by_id(&mut self, id: TaskId) -> Result<Task, AsyncQueueError> {
|
pub(crate) async fn find_task_by_id(&self, id: TaskId) -> Result<Task, AsyncQueueError> {
|
||||||
let mut connection = self
|
let mut connection = self
|
||||||
.pool
|
.pool
|
||||||
.get()
|
.get()
|
||||||
|
@ -119,29 +81,27 @@ impl Queueable for PgAsyncQueue {
|
||||||
Task::find_by_id(&mut connection, id).await
|
Task::find_by_id(&mut connection, id).await
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn set_task_failed(
|
pub(crate) async fn set_task_state(
|
||||||
&mut self,
|
&self,
|
||||||
id: TaskId,
|
id: TaskId,
|
||||||
error_message: &str,
|
state: TaskState,
|
||||||
) -> Result<Task, AsyncQueueError> {
|
) -> Result<(), AsyncQueueError> {
|
||||||
let mut connection = self
|
let mut connection = self
|
||||||
.pool
|
.pool
|
||||||
.get()
|
.get()
|
||||||
.await
|
.await
|
||||||
.map_err(|e| QueryBuilderError(e.into()))?;
|
.map_err(|e| QueryBuilderError(e.into()))?;
|
||||||
Task::fail_with_message(&mut connection, id, error_message).await
|
match state {
|
||||||
|
TaskState::Done => Task::set_done(&mut connection, id).await?,
|
||||||
|
TaskState::Failed(error_msg) => {
|
||||||
|
Task::fail_with_message(&mut connection, id, &error_msg).await?
|
||||||
|
}
|
||||||
|
_ => return Ok(()),
|
||||||
|
};
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn set_task_done(&mut self, id: TaskId) -> Result<Task, AsyncQueueError> {
|
pub(crate) async fn keep_task_alive(&self, id: TaskId) -> Result<(), AsyncQueueError> {
|
||||||
let mut connection = self
|
|
||||||
.pool
|
|
||||||
.get()
|
|
||||||
.await
|
|
||||||
.map_err(|e| QueryBuilderError(e.into()))?;
|
|
||||||
Task::set_done(&mut connection, id).await
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn keep_task_alive(&mut self, id: TaskId) -> Result<(), AsyncQueueError> {
|
|
||||||
let mut connection = self
|
let mut connection = self
|
||||||
.pool
|
.pool
|
||||||
.get()
|
.get()
|
||||||
|
@ -159,7 +119,7 @@ impl Queueable for PgAsyncQueue {
|
||||||
.await
|
.await
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn remove_task(&mut self, id: TaskId) -> Result<u64, AsyncQueueError> {
|
pub(crate) async fn remove_task(&self, id: TaskId) -> Result<u64, AsyncQueueError> {
|
||||||
let mut connection = self
|
let mut connection = self
|
||||||
.pool
|
.pool
|
||||||
.get()
|
.get()
|
||||||
|
@ -169,7 +129,7 @@ impl Queueable for PgAsyncQueue {
|
||||||
Ok(result)
|
Ok(result)
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn remove_all_tasks(&mut self) -> Result<u64, AsyncQueueError> {
|
pub(crate) async fn remove_all_tasks(&self) -> Result<u64, AsyncQueueError> {
|
||||||
let mut connection = self
|
let mut connection = self
|
||||||
.pool
|
.pool
|
||||||
.get()
|
.get()
|
||||||
|
@ -178,37 +138,8 @@ impl Queueable for PgAsyncQueue {
|
||||||
Task::remove_all(&mut connection).await
|
Task::remove_all(&mut connection).await
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn remove_all_scheduled_tasks(&mut self) -> Result<u64, AsyncQueueError> {
|
pub(crate) async fn schedule_task_retry(
|
||||||
let mut connection = self
|
&self,
|
||||||
.pool
|
|
||||||
.get()
|
|
||||||
.await
|
|
||||||
.map_err(|e| QueryBuilderError(e.into()))?;
|
|
||||||
let result = Task::remove_all_scheduled(&mut connection).await?;
|
|
||||||
Ok(result)
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn remove_task_by_hash(&mut self, task_hash: TaskHash) -> Result<bool, AsyncQueueError> {
|
|
||||||
let mut connection = self
|
|
||||||
.pool
|
|
||||||
.get()
|
|
||||||
.await
|
|
||||||
.map_err(|e| QueryBuilderError(e.into()))?;
|
|
||||||
Task::remove_by_hash(&mut connection, task_hash).await
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn remove_tasks_type(&mut self, task_type: TaskType) -> Result<u64, AsyncQueueError> {
|
|
||||||
let mut connection = self
|
|
||||||
.pool
|
|
||||||
.get()
|
|
||||||
.await
|
|
||||||
.map_err(|e| QueryBuilderError(e.into()))?;
|
|
||||||
let result = Task::remove_by_type(&mut connection, task_type).await?;
|
|
||||||
Ok(result)
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn schedule_task_retry(
|
|
||||||
&mut self,
|
|
||||||
id: TaskId,
|
id: TaskId,
|
||||||
backoff_seconds: u32,
|
backoff_seconds: u32,
|
||||||
error: &str,
|
error: &str,
|
||||||
|
@ -226,11 +157,8 @@ impl Queueable for PgAsyncQueue {
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod async_queue_tests {
|
mod async_queue_tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
use crate::task::TaskState;
|
use crate::CurrentTask;
|
||||||
use crate::Scheduled;
|
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use chrono::DateTime;
|
|
||||||
use chrono::Utc;
|
|
||||||
use diesel_async::pooled_connection::{bb8::Pool, AsyncDieselConnectionManager};
|
use diesel_async::pooled_connection::{bb8::Pool, AsyncDieselConnectionManager};
|
||||||
use diesel_async::AsyncPgConnection;
|
use diesel_async::AsyncPgConnection;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
@ -240,13 +168,12 @@ mod async_queue_tests {
|
||||||
pub number: u16,
|
pub number: u16,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[typetag::serde]
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
impl RunnableTask for AsyncTask {
|
impl BackgroundTask for AsyncTask {
|
||||||
async fn run(
|
const TASK_NAME: &'static str = "AsyncUniqTask";
|
||||||
&self,
|
type AppData = ();
|
||||||
_queueable: &mut dyn Queueable,
|
|
||||||
) -> Result<(), Box<(dyn std::error::Error + Send + 'static)>> {
|
async fn run(&self, _task: CurrentTask, _: Self::AppData) -> Result<(), anyhow::Error> {
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -256,13 +183,12 @@ mod async_queue_tests {
|
||||||
pub number: u16,
|
pub number: u16,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[typetag::serde]
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
impl RunnableTask for AsyncUniqTask {
|
impl BackgroundTask for AsyncUniqTask {
|
||||||
async fn run(
|
const TASK_NAME: &'static str = "AsyncUniqTask";
|
||||||
&self,
|
type AppData = ();
|
||||||
_queueable: &mut dyn Queueable,
|
|
||||||
) -> Result<(), Box<(dyn std::error::Error + Send + 'static)>> {
|
async fn run(&self, _task: CurrentTask, _: Self::AppData) -> Result<(), anyhow::Error> {
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -277,286 +203,154 @@ mod async_queue_tests {
|
||||||
pub datetime: String,
|
pub datetime: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[typetag::serde]
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
impl RunnableTask for AsyncTaskSchedule {
|
impl BackgroundTask for AsyncTaskSchedule {
|
||||||
async fn run(
|
const TASK_NAME: &'static str = "AsyncUniqTask";
|
||||||
&self,
|
type AppData = ();
|
||||||
_queueable: &mut dyn Queueable,
|
|
||||||
) -> Result<(), Box<(dyn std::error::Error + Send + 'static)>> {
|
async fn run(&self, _task: CurrentTask, _: Self::AppData) -> Result<(), anyhow::Error> {
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn cron(&self) -> Option<Scheduled> {
|
// fn cron(&self) -> Option<Scheduled> {
|
||||||
let datetime = self.datetime.parse::<DateTime<Utc>>().ok()?;
|
// let datetime = self.datetime.parse::<DateTime<Utc>>().ok()?;
|
||||||
Some(Scheduled::ScheduleOnce(datetime))
|
// Some(Scheduled::ScheduleOnce(datetime))
|
||||||
}
|
// }
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn insert_task_creates_new_task() {
|
|
||||||
let pool = pool().await;
|
|
||||||
let mut test = PgAsyncQueue::new(pool);
|
|
||||||
|
|
||||||
let task = test.create_task(&AsyncTask { number: 1 }).await.unwrap();
|
|
||||||
|
|
||||||
let metadata = task.payload.as_object().unwrap();
|
|
||||||
let number = metadata["number"].as_u64();
|
|
||||||
let type_task = metadata["type"].as_str();
|
|
||||||
|
|
||||||
assert_eq!(Some(1), number);
|
|
||||||
assert_eq!(Some("AsyncTask"), type_task);
|
|
||||||
|
|
||||||
test.remove_all_tasks().await.unwrap();
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn update_task_state_test() {
|
|
||||||
let pool = pool().await;
|
|
||||||
let mut test = PgAsyncQueue::new(pool);
|
|
||||||
|
|
||||||
let task = test.create_task(&AsyncTask { number: 1 }).await.unwrap();
|
|
||||||
|
|
||||||
let metadata = task.payload.as_object().unwrap();
|
|
||||||
let number = metadata["number"].as_u64();
|
|
||||||
let type_task = metadata["type"].as_str();
|
|
||||||
let id = task.id;
|
|
||||||
|
|
||||||
assert_eq!(Some(1), number);
|
|
||||||
assert_eq!(Some("AsyncTask"), type_task);
|
|
||||||
|
|
||||||
let finished_task = test.set_task_done(task.id).await.unwrap();
|
|
||||||
|
|
||||||
assert_eq!(id, finished_task.id);
|
|
||||||
assert_eq!(TaskState::Done, finished_task.state());
|
|
||||||
|
|
||||||
test.remove_all_tasks().await.unwrap();
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn failed_task_query_test() {
|
|
||||||
let pool = pool().await;
|
|
||||||
let mut test = PgAsyncQueue::new(pool);
|
|
||||||
|
|
||||||
let task = test.create_task(&AsyncTask { number: 1 }).await.unwrap();
|
|
||||||
|
|
||||||
let metadata = task.payload.as_object().unwrap();
|
|
||||||
let number = metadata["number"].as_u64();
|
|
||||||
let type_task = metadata["type"].as_str();
|
|
||||||
let id = task.id;
|
|
||||||
|
|
||||||
assert_eq!(Some(1), number);
|
|
||||||
assert_eq!(Some("AsyncTask"), type_task);
|
|
||||||
|
|
||||||
let failed_task = test.set_task_failed(task.id, "Some error").await.unwrap();
|
|
||||||
|
|
||||||
assert_eq!(id, failed_task.id);
|
|
||||||
assert_eq!(Some("Some error"), failed_task.error_message.as_deref());
|
|
||||||
assert_eq!(TaskState::Failed, failed_task.state());
|
|
||||||
|
|
||||||
test.remove_all_tasks().await.unwrap();
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn remove_all_tasks_test() {
|
|
||||||
let pool = pool().await;
|
|
||||||
let mut test = PgAsyncQueue::new(pool);
|
|
||||||
|
|
||||||
let task = test.create_task(&AsyncTask { number: 1 }).await.unwrap();
|
|
||||||
|
|
||||||
let metadata = task.payload.as_object().unwrap();
|
|
||||||
let number = metadata["number"].as_u64();
|
|
||||||
let type_task = metadata["type"].as_str();
|
|
||||||
|
|
||||||
assert_eq!(Some(1), number);
|
|
||||||
assert_eq!(Some("AsyncTask"), type_task);
|
|
||||||
|
|
||||||
let task = test.create_task(&AsyncTask { number: 2 }).await.unwrap();
|
|
||||||
|
|
||||||
let metadata = task.payload.as_object().unwrap();
|
|
||||||
let number = metadata["number"].as_u64();
|
|
||||||
let type_task = metadata["type"].as_str();
|
|
||||||
|
|
||||||
assert_eq!(Some(2), number);
|
|
||||||
assert_eq!(Some("AsyncTask"), type_task);
|
|
||||||
|
|
||||||
let result = test.remove_all_tasks().await.unwrap();
|
|
||||||
assert_eq!(2, result);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// #[tokio::test]
|
// #[tokio::test]
|
||||||
// async fn schedule_task_test() {
|
// async fn insert_task_creates_new_task() {
|
||||||
// let pool = pool().await;
|
// let pool = pool().await;
|
||||||
// let mut test = PgAsyncQueue::new(pool);
|
// let mut queue = PgTaskStore::new(pool);
|
||||||
//
|
//
|
||||||
// let datetime = (Utc::now() + Duration::seconds(7)).round_subsecs(0);
|
// let task = queue.create_task(AsyncTask { number: 1 }).await.unwrap();
|
||||||
//
|
|
||||||
// let task = &AsyncTaskSchedule {
|
|
||||||
// number: 1,
|
|
||||||
// datetime: datetime.to_string(),
|
|
||||||
// };
|
|
||||||
//
|
|
||||||
// let task = test.schedule_task(task).await.unwrap();
|
|
||||||
//
|
//
|
||||||
// let metadata = task.payload.as_object().unwrap();
|
// let metadata = task.payload.as_object().unwrap();
|
||||||
// let number = metadata["number"].as_u64();
|
// let number = metadata["number"].as_u64();
|
||||||
// let type_task = metadata["type"].as_str();
|
// let type_task = metadata["type"].as_str();
|
||||||
//
|
//
|
||||||
// assert_eq!(Some(1), number);
|
// assert_eq!(Some(1), number);
|
||||||
// assert_eq!(Some("AsyncTaskSchedule"), type_task);
|
// assert_eq!(Some("AsyncTask"), type_task);
|
||||||
// assert_eq!(task.scheduled_at, datetime);
|
//
|
||||||
|
// queue.remove_all_tasks().await.unwrap();
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// #[tokio::test]
|
||||||
|
// async fn update_task_state_test() {
|
||||||
|
// let pool = pool().await;
|
||||||
|
// let mut test = PgTaskStore::new(pool);
|
||||||
|
//
|
||||||
|
// let task = test.create_task(&AsyncTask { number: 1 }).await.unwrap();
|
||||||
|
//
|
||||||
|
// let metadata = task.payload.as_object().unwrap();
|
||||||
|
// let number = metadata["number"].as_u64();
|
||||||
|
// let type_task = metadata["type"].as_str();
|
||||||
|
// let id = task.id;
|
||||||
|
//
|
||||||
|
// assert_eq!(Some(1), number);
|
||||||
|
// assert_eq!(Some("AsyncTask"), type_task);
|
||||||
|
//
|
||||||
|
// let finished_task = test.set_task_state(task.id, TaskState::Done).await.unwrap();
|
||||||
|
//
|
||||||
|
// assert_eq!(id, finished_task.id);
|
||||||
|
// assert_eq!(TaskState::Done, finished_task.state());
|
||||||
//
|
//
|
||||||
// test.remove_all_tasks().await.unwrap();
|
// test.remove_all_tasks().await.unwrap();
|
||||||
// }
|
// }
|
||||||
//
|
//
|
||||||
// #[tokio::test]
|
// #[tokio::test]
|
||||||
// async fn remove_all_scheduled_tasks_test() {
|
// async fn failed_task_query_test() {
|
||||||
// let pool = pool().await;
|
// let pool = pool().await;
|
||||||
// let mut test = PgAsyncQueue::new(pool);
|
// let mut test = PgTaskStore::new(pool);
|
||||||
//
|
//
|
||||||
// let datetime = (Utc::now() + Duration::seconds(7)).round_subsecs(0);
|
// let task = test.create_task(&AsyncTask { number: 1 }).await.unwrap();
|
||||||
//
|
//
|
||||||
// let task1 = &AsyncTaskSchedule {
|
// let metadata = task.payload.as_object().unwrap();
|
||||||
// number: 1,
|
// let number = metadata["number"].as_u64();
|
||||||
// datetime: datetime.to_string(),
|
// let type_task = metadata["type"].as_str();
|
||||||
// };
|
// let id = task.id;
|
||||||
//
|
//
|
||||||
// let task2 = &AsyncTaskSchedule {
|
// assert_eq!(Some(1), number);
|
||||||
// number: 2,
|
// assert_eq!(Some("AsyncTask"), type_task);
|
||||||
// datetime: datetime.to_string(),
|
|
||||||
// };
|
|
||||||
//
|
//
|
||||||
// test.schedule_task(task1).await.unwrap();
|
// let failed_task = test.set_task_state(task.id, TaskState::Failed("Some error".to_string())).await.unwrap();
|
||||||
// test.schedule_task(task2).await.unwrap();
|
|
||||||
//
|
//
|
||||||
// let number = test.remove_all_scheduled_tasks().await.unwrap();
|
// assert_eq!(id, failed_task.id);
|
||||||
//
|
// assert_eq!(Some("Some error"), failed_task.error_message.as_deref());
|
||||||
// assert_eq!(2, number);
|
// assert_eq!(TaskState::Failed, failed_task.state());
|
||||||
//
|
//
|
||||||
// test.remove_all_tasks().await.unwrap();
|
// test.remove_all_tasks().await.unwrap();
|
||||||
// }
|
// }
|
||||||
|
//
|
||||||
#[tokio::test]
|
// #[tokio::test]
|
||||||
async fn pull_next_task_test() {
|
// async fn remove_all_tasks_test() {
|
||||||
let pool = pool().await;
|
// let pool = pool().await;
|
||||||
let mut test = PgAsyncQueue::new(pool);
|
// let mut test = PgTaskStore::new(pool);
|
||||||
|
//
|
||||||
let task = test.create_task(&AsyncTask { number: 1 }).await.unwrap();
|
// let task = test.create_task(&AsyncTask { number: 1 }).await.unwrap();
|
||||||
|
//
|
||||||
let metadata = task.payload.as_object().unwrap();
|
// let metadata = task.payload.as_object().unwrap();
|
||||||
let number = metadata["number"].as_u64();
|
// let number = metadata["number"].as_u64();
|
||||||
let type_task = metadata["type"].as_str();
|
// let type_task = metadata["type"].as_str();
|
||||||
|
//
|
||||||
assert_eq!(Some(1), number);
|
// assert_eq!(Some(1), number);
|
||||||
assert_eq!(Some("AsyncTask"), type_task);
|
// assert_eq!(Some("AsyncTask"), type_task);
|
||||||
|
//
|
||||||
let task = test.create_task(&AsyncTask { number: 2 }).await.unwrap();
|
// let task = test.create_task(&AsyncTask { number: 2 }).await.unwrap();
|
||||||
|
//
|
||||||
let metadata = task.payload.as_object().unwrap();
|
// let metadata = task.payload.as_object().unwrap();
|
||||||
let number = metadata["number"].as_u64();
|
// let number = metadata["number"].as_u64();
|
||||||
let type_task = metadata["type"].as_str();
|
// let type_task = metadata["type"].as_str();
|
||||||
|
//
|
||||||
assert_eq!(Some(2), number);
|
// assert_eq!(Some(2), number);
|
||||||
assert_eq!(Some("AsyncTask"), type_task);
|
// assert_eq!(Some("AsyncTask"), type_task);
|
||||||
|
//
|
||||||
let task = test.pull_next_task(None).await.unwrap().unwrap();
|
// let result = test.remove_all_tasks().await.unwrap();
|
||||||
|
// assert_eq!(2, result);
|
||||||
let metadata = task.payload.as_object().unwrap();
|
// }
|
||||||
let number = metadata["number"].as_u64();
|
//
|
||||||
let type_task = metadata["type"].as_str();
|
// #[tokio::test]
|
||||||
|
// async fn pull_next_task_test() {
|
||||||
assert_eq!(Some(1), number);
|
// let pool = pool().await;
|
||||||
assert_eq!(Some("AsyncTask"), type_task);
|
// let mut queue = PgTaskStore::new(pool);
|
||||||
|
//
|
||||||
let task = test.pull_next_task(None).await.unwrap().unwrap();
|
// let task = queue.create_task(&AsyncTask { number: 1 }).await.unwrap();
|
||||||
let metadata = task.payload.as_object().unwrap();
|
//
|
||||||
let number = metadata["number"].as_u64();
|
// let metadata = task.payload.as_object().unwrap();
|
||||||
let type_task = metadata["type"].as_str();
|
// let number = metadata["number"].as_u64();
|
||||||
|
// let type_task = metadata["type"].as_str();
|
||||||
assert_eq!(Some(2), number);
|
//
|
||||||
assert_eq!(Some("AsyncTask"), type_task);
|
// assert_eq!(Some(1), number);
|
||||||
|
// assert_eq!(Some("AsyncTask"), type_task);
|
||||||
test.remove_all_tasks().await.unwrap();
|
//
|
||||||
}
|
// let task = queue.create_task(&AsyncTask { number: 2 }).await.unwrap();
|
||||||
|
//
|
||||||
#[tokio::test]
|
// let metadata = task.payload.as_object().unwrap();
|
||||||
async fn remove_tasks_type_test() {
|
// let number = metadata["number"].as_u64();
|
||||||
let pool = pool().await;
|
// let type_task = metadata["type"].as_str();
|
||||||
let mut test = PgAsyncQueue::new(pool);
|
//
|
||||||
|
// assert_eq!(Some(2), number);
|
||||||
let task = test.create_task(&AsyncTask { number: 1 }).await.unwrap();
|
// assert_eq!(Some("AsyncTask"), type_task);
|
||||||
|
//
|
||||||
let metadata = task.payload.as_object().unwrap();
|
// let task = queue.pull_next_task(None).await.unwrap().unwrap();
|
||||||
let number = metadata["number"].as_u64();
|
//
|
||||||
let type_task = metadata["type"].as_str();
|
// let metadata = task.payload.as_object().unwrap();
|
||||||
|
// let number = metadata["number"].as_u64();
|
||||||
assert_eq!(Some(1), number);
|
// let type_task = metadata["type"].as_str();
|
||||||
assert_eq!(Some("AsyncTask"), type_task);
|
//
|
||||||
|
// assert_eq!(Some(1), number);
|
||||||
let task = test.create_task(&AsyncTask { number: 2 }).await.unwrap();
|
// assert_eq!(Some("AsyncTask"), type_task);
|
||||||
|
//
|
||||||
let metadata = task.payload.as_object().unwrap();
|
// let task = queue.pull_next_task(None).await.unwrap().unwrap();
|
||||||
let number = metadata["number"].as_u64();
|
// let metadata = task.payload.as_object().unwrap();
|
||||||
let type_task = metadata["type"].as_str();
|
// let number = metadata["number"].as_u64();
|
||||||
|
// let type_task = metadata["type"].as_str();
|
||||||
assert_eq!(Some(2), number);
|
//
|
||||||
assert_eq!(Some("AsyncTask"), type_task);
|
// assert_eq!(Some(2), number);
|
||||||
|
// assert_eq!(Some("AsyncTask"), type_task);
|
||||||
let result = test
|
//
|
||||||
.remove_tasks_type(TaskType::from("nonexistentType"))
|
// queue.remove_all_tasks().await.unwrap();
|
||||||
.await
|
// }
|
||||||
.unwrap();
|
|
||||||
assert_eq!(0, result);
|
|
||||||
|
|
||||||
let result = test.remove_tasks_type(TaskType::default()).await.unwrap();
|
|
||||||
assert_eq!(2, result);
|
|
||||||
|
|
||||||
test.remove_all_tasks().await.unwrap();
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn remove_tasks_by_metadata() {
|
|
||||||
let pool = pool().await;
|
|
||||||
let mut test = PgAsyncQueue::new(pool);
|
|
||||||
|
|
||||||
let task = test
|
|
||||||
.create_task(&AsyncUniqTask { number: 1 })
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
let metadata = task.payload.as_object().unwrap();
|
|
||||||
let number = metadata["number"].as_u64();
|
|
||||||
let type_task = metadata["type"].as_str();
|
|
||||||
|
|
||||||
assert_eq!(Some(1), number);
|
|
||||||
assert_eq!(Some("AsyncUniqTask"), type_task);
|
|
||||||
|
|
||||||
let task = test
|
|
||||||
.create_task(&AsyncUniqTask { number: 2 })
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
let metadata = task.payload.as_object().unwrap();
|
|
||||||
let number = metadata["number"].as_u64();
|
|
||||||
let type_task = metadata["type"].as_str();
|
|
||||||
|
|
||||||
assert_eq!(Some(2), number);
|
|
||||||
assert_eq!(Some("AsyncUniqTask"), type_task);
|
|
||||||
|
|
||||||
let result = test
|
|
||||||
.remove_task_by_hash(AsyncUniqTask { number: 0 }.uniq().unwrap())
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
assert!(!result, "Should **not** remove task");
|
|
||||||
|
|
||||||
let result = test
|
|
||||||
.remove_task_by_hash(AsyncUniqTask { number: 1 }.uniq().unwrap())
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
assert!(result, "Should remove task");
|
|
||||||
|
|
||||||
test.remove_all_tasks().await.unwrap();
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn pool() -> Pool<AsyncPgConnection> {
|
async fn pool() -> Pool<AsyncPgConnection> {
|
||||||
let manager = AsyncDieselConnectionManager::<AsyncPgConnection>::new(
|
let manager = AsyncDieselConnectionManager::<AsyncPgConnection>::new(
|
||||||
|
|
|
@ -1,27 +1,34 @@
|
||||||
use crate::queue::Queueable;
|
use crate::task::{CurrentTask, TaskHash};
|
||||||
use crate::task::TaskHash;
|
|
||||||
use crate::task::TaskType;
|
|
||||||
use crate::Scheduled;
|
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use std::error::Error;
|
use serde::{de::DeserializeOwned, ser::Serialize};
|
||||||
|
|
||||||
pub const RETRIES_NUMBER: i32 = 5;
|
|
||||||
|
|
||||||
/// Task that can be executed by the queue.
|
/// Task that can be executed by the queue.
|
||||||
///
|
///
|
||||||
/// The `RunnableTask` trait is used to define the behaviour of a task. You must implement this
|
/// The `BackgroundTask` trait is used to define the behaviour of a task. You must implement this
|
||||||
/// trait for all tasks you want to execute.
|
/// trait for all tasks you want to execute.
|
||||||
#[typetag::serde(tag = "type")]
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
pub trait RunnableTask: Send + Sync {
|
pub trait BackgroundTask: Serialize + DeserializeOwned + Sync + Send + 'static {
|
||||||
/// Execute the task. This method should define its logic
|
/// Unique name of the task.
|
||||||
async fn run(&self, queue: &mut dyn Queueable) -> Result<(), Box<dyn Error + Send + 'static>>;
|
///
|
||||||
|
/// This MUST be unique for the whole application.
|
||||||
|
const TASK_NAME: &'static str;
|
||||||
|
|
||||||
/// Define the type of the task.
|
/// Task queue where this task will be executed.
|
||||||
/// The `common` task type is used by default
|
///
|
||||||
fn task_type(&self) -> TaskType {
|
/// Used to define which workers are going to be executing this task. It uses the default
|
||||||
TaskType::default()
|
/// task queue if not changed.
|
||||||
}
|
const QUEUE: &'static str = "default";
|
||||||
|
|
||||||
|
/// Number of retries for tasks.
|
||||||
|
///
|
||||||
|
/// By default, it is set to 5.
|
||||||
|
const MAX_RETRIES: i32 = 5;
|
||||||
|
|
||||||
|
/// The application data provided to this task at runtime.
|
||||||
|
type AppData: Clone + Send + 'static;
|
||||||
|
|
||||||
|
/// Execute the task. This method should define its logic
|
||||||
|
async fn run(&self, task: CurrentTask, context: Self::AppData) -> Result<(), anyhow::Error>;
|
||||||
|
|
||||||
/// If set to true, no new tasks with the same metadata will be inserted
|
/// If set to true, no new tasks with the same metadata will be inserted
|
||||||
/// By default it is set to false.
|
/// By default it is set to false.
|
||||||
|
@ -29,27 +36,10 @@ pub trait RunnableTask: Send + Sync {
|
||||||
None
|
None
|
||||||
}
|
}
|
||||||
|
|
||||||
/// This method defines if a task is periodic or it should be executed once in the future.
|
|
||||||
///
|
|
||||||
/// Be careful it works only with the UTC timezone.
|
|
||||||
///
|
|
||||||
/// Example:
|
|
||||||
///
|
|
||||||
/// ```rust
|
|
||||||
/// fn cron(&self) -> Option<Scheduled> {
|
|
||||||
/// let expression = "0/20 * * * Aug-Sep * 2022/1";
|
|
||||||
/// Some(Scheduled::CronPattern(expression.to_string()))
|
|
||||||
/// }
|
|
||||||
///```
|
|
||||||
/// In order to schedule a task once, use the `Scheduled::ScheduleOnce` enum variant.
|
|
||||||
fn cron(&self) -> Option<Scheduled> {
|
|
||||||
None
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Define the maximum number of retries the task will be retried.
|
/// Define the maximum number of retries the task will be retried.
|
||||||
/// By default the number of retries is 20.
|
/// By default the number of retries is 20.
|
||||||
fn max_retries(&self) -> i32 {
|
fn max_retries(&self) -> i32 {
|
||||||
RETRIES_NUMBER
|
Self::MAX_RETRIES
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Define the backoff mode
|
/// Define the backoff mode
|
||||||
|
|
19
src/schema.rs
Normal file
19
src/schema.rs
Normal file
|
@ -0,0 +1,19 @@
|
||||||
|
// @generated automatically by Diesel CLI.
|
||||||
|
|
||||||
|
diesel::table! {
|
||||||
|
backie_tasks (id) {
|
||||||
|
id -> Uuid,
|
||||||
|
task_name -> Varchar,
|
||||||
|
queue_name -> Varchar,
|
||||||
|
uniq_hash -> Nullable<Bpchar>,
|
||||||
|
payload -> Jsonb,
|
||||||
|
timeout_msecs -> Int8,
|
||||||
|
created_at -> Timestamptz,
|
||||||
|
scheduled_at -> Timestamptz,
|
||||||
|
running_at -> Nullable<Timestamptz>,
|
||||||
|
done_at -> Nullable<Timestamptz>,
|
||||||
|
error_info -> Nullable<Jsonb>,
|
||||||
|
retries -> Int4,
|
||||||
|
max_retries -> Int4,
|
||||||
|
}
|
||||||
|
}
|
1
src/store.rs
Normal file
1
src/store.rs
Normal file
|
@ -0,0 +1 @@
|
||||||
|
|
133
src/task.rs
133
src/task.rs
|
@ -1,4 +1,5 @@
|
||||||
use crate::schema::backie_tasks;
|
use crate::schema::backie_tasks;
|
||||||
|
use crate::BackgroundTask;
|
||||||
use chrono::DateTime;
|
use chrono::DateTime;
|
||||||
use chrono::Utc;
|
use chrono::Utc;
|
||||||
use diesel::prelude::*;
|
use diesel::prelude::*;
|
||||||
|
@ -8,11 +9,11 @@ use sha2::{Digest, Sha256};
|
||||||
use std::borrow::Cow;
|
use std::borrow::Cow;
|
||||||
use std::fmt;
|
use std::fmt;
|
||||||
use std::fmt::Display;
|
use std::fmt::Display;
|
||||||
use typed_builder::TypedBuilder;
|
use std::time::Duration;
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
|
|
||||||
/// States of a task.
|
/// States of a task.
|
||||||
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
|
#[derive(Clone, Debug, Eq, PartialEq)]
|
||||||
pub enum TaskState {
|
pub enum TaskState {
|
||||||
/// The task is ready to be executed.
|
/// The task is ready to be executed.
|
||||||
Ready,
|
Ready,
|
||||||
|
@ -21,7 +22,7 @@ pub enum TaskState {
|
||||||
Running,
|
Running,
|
||||||
|
|
||||||
/// The task has failed to execute.
|
/// The task has failed to execute.
|
||||||
Failed,
|
Failed(String),
|
||||||
|
|
||||||
/// The task finished successfully.
|
/// The task finished successfully.
|
||||||
Done,
|
Done,
|
||||||
|
@ -36,24 +37,6 @@ impl Display for TaskId {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Debug, Hash, PartialEq, Eq, DieselNewType, Serialize)]
|
|
||||||
pub struct TaskType(Cow<'static, str>);
|
|
||||||
|
|
||||||
impl Default for TaskType {
|
|
||||||
fn default() -> Self {
|
|
||||||
Self(Cow::from("default"))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<S> From<S> for TaskType
|
|
||||||
where
|
|
||||||
S: AsRef<str> + 'static,
|
|
||||||
{
|
|
||||||
fn from(s: S) -> Self {
|
|
||||||
TaskType(Cow::from(s.as_ref().to_owned()))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Clone, Debug, Hash, PartialEq, Eq, DieselNewType, Serialize)]
|
#[derive(Clone, Debug, Hash, PartialEq, Eq, DieselNewType, Serialize)]
|
||||||
pub struct TaskHash(Cow<'static, str>);
|
pub struct TaskHash(Cow<'static, str>);
|
||||||
|
|
||||||
|
@ -70,42 +53,55 @@ impl TaskHash {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Queryable, Identifiable, Debug, Eq, PartialEq, Clone, TypedBuilder)]
|
#[derive(Queryable, Identifiable, Debug, Eq, PartialEq, Clone)]
|
||||||
#[diesel(table_name = backie_tasks)]
|
#[diesel(table_name = backie_tasks)]
|
||||||
pub struct Task {
|
pub struct Task {
|
||||||
#[builder(setter(into))]
|
/// Unique identifier of the task.
|
||||||
pub id: TaskId,
|
pub id: TaskId,
|
||||||
|
|
||||||
#[builder(setter(into))]
|
/// Name of the type of task.
|
||||||
pub payload: serde_json::Value,
|
pub task_name: String,
|
||||||
|
|
||||||
#[builder(setter(into))]
|
/// Queue name that the task belongs to.
|
||||||
pub error_message: Option<String>,
|
pub queue_name: String,
|
||||||
|
|
||||||
#[builder(setter(into))]
|
/// Unique hash is used to identify and avoid duplicate tasks.
|
||||||
pub task_type: TaskType,
|
|
||||||
|
|
||||||
#[builder(setter(into))]
|
|
||||||
pub uniq_hash: Option<TaskHash>,
|
pub uniq_hash: Option<TaskHash>,
|
||||||
|
|
||||||
#[builder(setter(into))]
|
/// Representation of the task.
|
||||||
pub retries: i32,
|
pub payload: serde_json::Value,
|
||||||
|
|
||||||
#[builder(setter(into))]
|
/// Max timeout that the task can run for.
|
||||||
|
pub timeout_msecs: i64,
|
||||||
|
|
||||||
|
/// Creation time of the task.
|
||||||
pub created_at: DateTime<Utc>,
|
pub created_at: DateTime<Utc>,
|
||||||
|
|
||||||
#[builder(setter(into))]
|
/// Date time when the task is scheduled to run.
|
||||||
|
pub scheduled_at: DateTime<Utc>,
|
||||||
|
|
||||||
|
/// Date time when the task is started to run.
|
||||||
pub running_at: Option<DateTime<Utc>>,
|
pub running_at: Option<DateTime<Utc>>,
|
||||||
|
|
||||||
#[builder(setter(into))]
|
/// Date time when the task is finished.
|
||||||
pub done_at: Option<DateTime<Utc>>,
|
pub done_at: Option<DateTime<Utc>>,
|
||||||
|
|
||||||
|
/// Failure reason, when the task is failed.
|
||||||
|
pub error_info: Option<serde_json::Value>,
|
||||||
|
|
||||||
|
/// Number of times a task was retried.
|
||||||
|
pub retries: i32,
|
||||||
|
|
||||||
|
/// Maximum number of retries allow for this task before it is maked as failure.
|
||||||
|
pub max_retries: i32,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Task {
|
impl Task {
|
||||||
pub fn state(&self) -> TaskState {
|
pub fn state(&self) -> TaskState {
|
||||||
if self.done_at.is_some() {
|
if self.done_at.is_some() {
|
||||||
if self.error_message.is_some() {
|
if self.error_info.is_some() {
|
||||||
TaskState::Failed
|
// TODO: use a proper error type
|
||||||
|
TaskState::Failed(self.error_info.clone().unwrap().to_string())
|
||||||
} else {
|
} else {
|
||||||
TaskState::Done
|
TaskState::Done
|
||||||
}
|
}
|
||||||
|
@ -117,22 +113,61 @@ impl Task {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Insertable, Debug, Eq, PartialEq, Clone, TypedBuilder)]
|
#[derive(Insertable, Debug, Eq, PartialEq, Clone)]
|
||||||
#[diesel(table_name = backie_tasks)]
|
#[diesel(table_name = backie_tasks)]
|
||||||
pub struct NewTask {
|
pub(crate) struct NewTask {
|
||||||
#[builder(setter(into))]
|
task_name: String,
|
||||||
payload: serde_json::Value,
|
queue_name: String,
|
||||||
|
|
||||||
#[builder(setter(into))]
|
|
||||||
task_type: TaskType,
|
|
||||||
|
|
||||||
#[builder(setter(into))]
|
|
||||||
uniq_hash: Option<TaskHash>,
|
uniq_hash: Option<TaskHash>,
|
||||||
|
payload: serde_json::Value,
|
||||||
|
timeout_msecs: i64,
|
||||||
|
max_retries: i32,
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct TaskInfo {
|
impl NewTask {
|
||||||
|
pub(crate) fn new<T>(background_task: T, timeout: Duration) -> Result<Self, serde_json::Error>
|
||||||
|
where
|
||||||
|
T: BackgroundTask,
|
||||||
|
{
|
||||||
|
let max_retries = background_task.max_retries();
|
||||||
|
let uniq_hash = background_task.uniq();
|
||||||
|
let payload = serde_json::to_value(background_task)?;
|
||||||
|
|
||||||
|
Ok(Self {
|
||||||
|
task_name: T::TASK_NAME.to_string(),
|
||||||
|
queue_name: T::QUEUE.to_string(),
|
||||||
|
uniq_hash,
|
||||||
|
payload,
|
||||||
|
timeout_msecs: timeout.as_millis() as i64,
|
||||||
|
max_retries,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct CurrentTask {
|
||||||
id: TaskId,
|
id: TaskId,
|
||||||
error_message: Option<String>,
|
|
||||||
retries: i32,
|
retries: i32,
|
||||||
created_at: DateTime<Utc>,
|
created_at: DateTime<Utc>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl CurrentTask {
|
||||||
|
pub(crate) fn new(task: &Task) -> Self {
|
||||||
|
Self {
|
||||||
|
id: task.id.clone(),
|
||||||
|
retries: task.retries,
|
||||||
|
created_at: task.created_at,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn id(&self) -> TaskId {
|
||||||
|
self.id
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn retry_count(&self) -> i32 {
|
||||||
|
self.retries
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn created_at(&self) -> DateTime<Utc> {
|
||||||
|
self.created_at
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
549
src/worker.rs
549
src/worker.rs
|
@ -1,53 +1,104 @@
|
||||||
use crate::errors::BackieError;
|
use crate::errors::{AsyncQueueError, BackieError};
|
||||||
use crate::queue::Queueable;
|
use crate::runnable::BackgroundTask;
|
||||||
use crate::runnable::RunnableTask;
|
use crate::task::{CurrentTask, Task, TaskState};
|
||||||
use crate::task::{Task, TaskType};
|
use crate::{PgTaskStore, RetentionMode};
|
||||||
use crate::RetentionMode;
|
|
||||||
use crate::Scheduled::*;
|
|
||||||
use futures::future::FutureExt;
|
use futures::future::FutureExt;
|
||||||
use futures::select;
|
use futures::select;
|
||||||
use std::error::Error;
|
use std::collections::BTreeMap;
|
||||||
use typed_builder::TypedBuilder;
|
use std::future::Future;
|
||||||
|
use std::pin::Pin;
|
||||||
|
use std::sync::Arc;
|
||||||
|
use thiserror::Error;
|
||||||
|
|
||||||
/// it executes tasks only of task_type type, it sleeps when there are no tasks in the queue
|
pub type ExecuteTaskFn<AppData> = Arc<
|
||||||
#[derive(TypedBuilder)]
|
dyn Fn(
|
||||||
pub struct Worker<Q>
|
CurrentTask,
|
||||||
where
|
serde_json::Value,
|
||||||
Q: Queueable + Clone + Sync + 'static,
|
AppData,
|
||||||
{
|
) -> Pin<Box<dyn Future<Output = Result<(), TaskExecError>> + Send>>
|
||||||
#[builder(setter(into))]
|
+ Send
|
||||||
pub queue: Q,
|
+ Sync,
|
||||||
|
>;
|
||||||
|
|
||||||
#[builder(default, setter(into))]
|
pub type StateFn<AppData> = Arc<dyn Fn() -> AppData + Send + Sync>;
|
||||||
pub task_type: Option<TaskType>,
|
|
||||||
|
|
||||||
#[builder(default, setter(into))]
|
#[derive(Debug, Error)]
|
||||||
pub retention_mode: RetentionMode,
|
pub enum TaskExecError {
|
||||||
|
#[error("Task execution failed: {0}")]
|
||||||
|
ExecutionFailed(#[from] anyhow::Error),
|
||||||
|
|
||||||
#[builder(default, setter(into))]
|
#[error("Task deserialization failed: {0}")]
|
||||||
pub shutdown: Option<tokio::sync::watch::Receiver<()>>,
|
TaskDeserializationFailed(#[from] serde_json::Error),
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<Q> Worker<Q>
|
pub(crate) fn runnable<BT>(
|
||||||
|
task_info: CurrentTask,
|
||||||
|
payload: serde_json::Value,
|
||||||
|
app_context: BT::AppData,
|
||||||
|
) -> Pin<Box<dyn Future<Output = Result<(), TaskExecError>> + Send>>
|
||||||
where
|
where
|
||||||
Q: Queueable + Clone + Sync + 'static,
|
BT: BackgroundTask,
|
||||||
{
|
{
|
||||||
|
Box::pin(async move {
|
||||||
|
let background_task: BT = serde_json::from_value(payload)?;
|
||||||
|
background_task.run(task_info, app_context).await?;
|
||||||
|
Ok(())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Worker that executes tasks.
|
||||||
|
pub struct Worker<AppData>
|
||||||
|
where
|
||||||
|
AppData: Clone + Send + 'static,
|
||||||
|
{
|
||||||
|
store: Arc<PgTaskStore>,
|
||||||
|
|
||||||
|
queue_name: String,
|
||||||
|
|
||||||
|
retention_mode: RetentionMode,
|
||||||
|
|
||||||
|
task_registry: BTreeMap<String, ExecuteTaskFn<AppData>>,
|
||||||
|
|
||||||
|
app_data_fn: StateFn<AppData>,
|
||||||
|
|
||||||
|
/// Notification for the worker to stop.
|
||||||
|
shutdown: Option<tokio::sync::watch::Receiver<()>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<AppData> Worker<AppData>
|
||||||
|
where
|
||||||
|
AppData: Clone + Send + 'static,
|
||||||
|
{
|
||||||
|
pub(crate) fn new(
|
||||||
|
store: Arc<PgTaskStore>,
|
||||||
|
queue_name: String,
|
||||||
|
retention_mode: RetentionMode,
|
||||||
|
task_registry: BTreeMap<String, ExecuteTaskFn<AppData>>,
|
||||||
|
app_data_fn: StateFn<AppData>,
|
||||||
|
shutdown: Option<tokio::sync::watch::Receiver<()>>,
|
||||||
|
) -> Self {
|
||||||
|
Self {
|
||||||
|
store,
|
||||||
|
queue_name,
|
||||||
|
retention_mode,
|
||||||
|
task_registry,
|
||||||
|
app_data_fn,
|
||||||
|
shutdown,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub(crate) async fn run_tasks(&mut self) -> Result<(), BackieError> {
|
pub(crate) async fn run_tasks(&mut self) -> Result<(), BackieError> {
|
||||||
loop {
|
loop {
|
||||||
// Need to check if has to stop before pulling next task
|
// Check if has to stop before pulling next task
|
||||||
match self.queue.pull_next_task(self.task_type.clone()).await? {
|
if let Some(ref shutdown) = self.shutdown {
|
||||||
Some(task) => {
|
if shutdown.has_changed()? {
|
||||||
let actual_task: Box<dyn RunnableTask> =
|
return Ok(());
|
||||||
serde_json::from_value(task.payload.clone())?;
|
|
||||||
|
|
||||||
// check if task is scheduled or not
|
|
||||||
if let Some(CronPattern(_)) = actual_task.cron() {
|
|
||||||
// program task
|
|
||||||
//self.queue.schedule_task(&*actual_task).await?;
|
|
||||||
}
|
}
|
||||||
// run scheduled task
|
};
|
||||||
// TODO: what do we do if the task fails? it's an internal error, inform the logs
|
|
||||||
let _ = self.run(task, actual_task).await;
|
match self.store.pull_next_task(&self.queue_name).await? {
|
||||||
|
Some(task) => {
|
||||||
|
self.run(task).await?;
|
||||||
}
|
}
|
||||||
None => {
|
None => {
|
||||||
// Listen to watchable future
|
// Listen to watchable future
|
||||||
|
@ -73,41 +124,45 @@ where
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
// #[cfg(test)]
|
||||||
pub async fn run_tasks_until_none(&mut self) -> Result<(), BackieError> {
|
// pub async fn run_tasks_until_none(&mut self) -> Result<(), BackieError> {
|
||||||
loop {
|
// loop {
|
||||||
match self.queue.pull_next_task(self.task_type.clone()).await? {
|
// match self.store.pull_next_task(self.queue_name.clone()).await? {
|
||||||
Some(task) => {
|
// Some(task) => {
|
||||||
let actual_task: Box<dyn RunnableTask> =
|
// let actual_task: Box<dyn BackgroundTask> =
|
||||||
serde_json::from_value(task.payload.clone()).unwrap();
|
// serde_json::from_value(task.payload.clone()).unwrap();
|
||||||
|
//
|
||||||
|
// // check if task is scheduled or not
|
||||||
|
// if let Some(CronPattern(_)) = actual_task.cron() {
|
||||||
|
// // program task
|
||||||
|
// // self.queue.schedule_task(&*actual_task).await?;
|
||||||
|
// }
|
||||||
|
// // run scheduled task
|
||||||
|
// self.run(task, actual_task).await?;
|
||||||
|
// }
|
||||||
|
// None => {
|
||||||
|
// return Ok(());
|
||||||
|
// }
|
||||||
|
// };
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
|
||||||
// check if task is scheduled or not
|
async fn run(&self, task: Task) -> Result<(), BackieError> {
|
||||||
if let Some(CronPattern(_)) = actual_task.cron() {
|
let task_info = CurrentTask::new(&task);
|
||||||
// program task
|
let runnable_task_caller = self
|
||||||
// self.queue.schedule_task(&*actual_task).await?;
|
.task_registry
|
||||||
}
|
.get(&task.task_name)
|
||||||
// run scheduled task
|
.ok_or_else(|| AsyncQueueError::TaskNotRegistered(task.task_name.clone()))?;
|
||||||
self.run(task, actual_task).await?;
|
|
||||||
}
|
|
||||||
None => {
|
|
||||||
return Ok(());
|
|
||||||
}
|
|
||||||
};
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn run(
|
|
||||||
&mut self,
|
|
||||||
task: Task,
|
|
||||||
runnable: Box<dyn RunnableTask>,
|
|
||||||
) -> Result<(), BackieError> {
|
|
||||||
// TODO: catch panics
|
// TODO: catch panics
|
||||||
let result = runnable.run(&mut self.queue).await;
|
let result: Result<(), TaskExecError> =
|
||||||
match result {
|
runnable_task_caller(task_info, task.payload.clone(), (self.app_data_fn)()).await;
|
||||||
|
|
||||||
|
match &result {
|
||||||
Ok(_) => self.finalize_task(task, result).await?,
|
Ok(_) => self.finalize_task(task, result).await?,
|
||||||
Err(error) => {
|
Err(error) => {
|
||||||
if task.retries < runnable.max_retries() {
|
if task.retries < task.max_retries {
|
||||||
let backoff_seconds = runnable.backoff(task.retries as u32);
|
let backoff_seconds = 5; // TODO: runnable.backoff(task.retries as u32);
|
||||||
|
|
||||||
log::debug!(
|
log::debug!(
|
||||||
"Task {} failed to run and will be retried in {} seconds",
|
"Task {} failed to run and will be retried in {} seconds",
|
||||||
|
@ -115,12 +170,12 @@ where
|
||||||
backoff_seconds
|
backoff_seconds
|
||||||
);
|
);
|
||||||
let error_message = format!("{}", error);
|
let error_message = format!("{}", error);
|
||||||
self.queue
|
self.store
|
||||||
.schedule_task_retry(task.id, backoff_seconds, &error_message)
|
.schedule_task_retry(task.id, backoff_seconds, &error_message)
|
||||||
.await?;
|
.await?;
|
||||||
} else {
|
} else {
|
||||||
log::debug!("Task {} failed and reached the maximum retries", task.id);
|
log::debug!("Task {} failed and reached the maximum retries", task.id);
|
||||||
self.finalize_task(task, Err(error)).await?;
|
self.finalize_task(task, result).await?;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -128,36 +183,36 @@ where
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn finalize_task(
|
async fn finalize_task(
|
||||||
&mut self,
|
&self,
|
||||||
task: Task,
|
task: Task,
|
||||||
result: Result<(), Box<dyn Error + Send + 'static>>,
|
result: Result<(), TaskExecError>,
|
||||||
) -> Result<(), BackieError> {
|
) -> Result<(), BackieError> {
|
||||||
match self.retention_mode {
|
match self.retention_mode {
|
||||||
RetentionMode::KeepAll => match result {
|
RetentionMode::KeepAll => match result {
|
||||||
Ok(_) => {
|
Ok(_) => {
|
||||||
self.queue.set_task_done(task.id).await?;
|
self.store.set_task_state(task.id, TaskState::Done).await?;
|
||||||
log::debug!("Task {} done and kept in the database", task.id);
|
log::debug!("Task {} done and kept in the database", task.id);
|
||||||
}
|
}
|
||||||
Err(error) => {
|
Err(error) => {
|
||||||
log::debug!("Task {} failed and kept in the database", task.id);
|
log::debug!("Task {} failed and kept in the database", task.id);
|
||||||
self.queue
|
self.store
|
||||||
.set_task_failed(task.id, &format!("{}", error))
|
.set_task_state(task.id, TaskState::Failed(format!("{}", error)))
|
||||||
.await?;
|
.await?;
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
RetentionMode::RemoveAll => {
|
RetentionMode::RemoveAll => {
|
||||||
log::debug!("Task {} finalized and deleted from the database", task.id);
|
log::debug!("Task {} finalized and deleted from the database", task.id);
|
||||||
self.queue.remove_task(task.id).await?;
|
self.store.remove_task(task.id).await?;
|
||||||
}
|
}
|
||||||
RetentionMode::RemoveDone => match result {
|
RetentionMode::RemoveDone => match result {
|
||||||
Ok(_) => {
|
Ok(_) => {
|
||||||
log::debug!("Task {} done and deleted from the database", task.id);
|
log::debug!("Task {} done and deleted from the database", task.id);
|
||||||
self.queue.remove_task(task.id).await?;
|
self.store.remove_task(task.id).await?;
|
||||||
}
|
}
|
||||||
Err(error) => {
|
Err(error) => {
|
||||||
log::debug!("Task {} failed and kept in the database", task.id);
|
log::debug!("Task {} failed and kept in the database", task.id);
|
||||||
self.queue
|
self.store
|
||||||
.set_task_failed(task.id, &format!("{}", error))
|
.set_task_state(task.id, TaskState::Failed(format!("{}", error)))
|
||||||
.await?;
|
.await?;
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
@ -169,35 +224,19 @@ where
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod async_worker_tests {
|
mod async_worker_tests {
|
||||||
use std::fmt::Display;
|
|
||||||
use super::*;
|
use super::*;
|
||||||
use crate::queue::PgAsyncQueue;
|
|
||||||
use crate::queue::Queueable;
|
|
||||||
use crate::task::TaskState;
|
|
||||||
use crate::worker::Task;
|
|
||||||
use crate::RetentionMode;
|
|
||||||
use crate::Scheduled;
|
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use chrono::Duration;
|
|
||||||
use chrono::Utc;
|
|
||||||
use diesel_async::pooled_connection::{bb8::Pool, AsyncDieselConnectionManager};
|
use diesel_async::pooled_connection::{bb8::Pool, AsyncDieselConnectionManager};
|
||||||
use diesel_async::AsyncPgConnection;
|
use diesel_async::AsyncPgConnection;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use thiserror::Error;
|
|
||||||
|
|
||||||
#[derive(Error, Debug)]
|
#[derive(thiserror::Error, Debug)]
|
||||||
enum TaskError {
|
enum TaskError {
|
||||||
|
#[error("Something went wrong")]
|
||||||
SomethingWrong,
|
SomethingWrong,
|
||||||
Custom(String),
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Display for TaskError {
|
#[error("{0}")]
|
||||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
Custom(String),
|
||||||
match self {
|
|
||||||
TaskError::SomethingWrong => write!(f, "Something went wrong"),
|
|
||||||
TaskError::Custom(message) => write!(f, "{}", message),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize)]
|
#[derive(Serialize, Deserialize)]
|
||||||
|
@ -205,13 +244,12 @@ mod async_worker_tests {
|
||||||
pub number: u16,
|
pub number: u16,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[typetag::serde]
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
impl RunnableTask for WorkerAsyncTask {
|
impl BackgroundTask for WorkerAsyncTask {
|
||||||
async fn run(
|
const TASK_NAME: &'static str = "WorkerAsyncTask";
|
||||||
&self,
|
type AppData = ();
|
||||||
_queueable: &mut dyn Queueable,
|
|
||||||
) -> Result<(), Box<(dyn std::error::Error + Send + 'static)>> {
|
async fn run(&self, _: CurrentTask, _: Self::AppData) -> Result<(), anyhow::Error> {
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -221,18 +259,18 @@ mod async_worker_tests {
|
||||||
pub number: u16,
|
pub number: u16,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[typetag::serde]
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
impl RunnableTask for WorkerAsyncTaskSchedule {
|
impl BackgroundTask for WorkerAsyncTaskSchedule {
|
||||||
async fn run(
|
const TASK_NAME: &'static str = "WorkerAsyncTaskSchedule";
|
||||||
&self,
|
type AppData = ();
|
||||||
_queueable: &mut dyn Queueable,
|
|
||||||
) -> Result<(), Box<(dyn std::error::Error + Send + 'static)>> {
|
async fn run(&self, _task: CurrentTask, _data: Self::AppData) -> Result<(), anyhow::Error> {
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
fn cron(&self) -> Option<Scheduled> {
|
|
||||||
Some(Scheduled::ScheduleOnce(Utc::now() + Duration::seconds(1)))
|
// fn cron(&self) -> Option<Scheduled> {
|
||||||
}
|
// Some(Scheduled::ScheduleOnce(Utc::now() + Duration::seconds(1)))
|
||||||
|
// }
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize)]
|
#[derive(Serialize, Deserialize)]
|
||||||
|
@ -240,16 +278,15 @@ mod async_worker_tests {
|
||||||
pub number: u16,
|
pub number: u16,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[typetag::serde]
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
impl RunnableTask for AsyncFailedTask {
|
impl BackgroundTask for AsyncFailedTask {
|
||||||
async fn run(
|
const TASK_NAME: &'static str = "AsyncFailedTask";
|
||||||
&self,
|
type AppData = ();
|
||||||
_queueable: &mut dyn Queueable,
|
|
||||||
) -> Result<(), Box<(dyn std::error::Error + Send + 'static)>> {
|
async fn run(&self, _task: CurrentTask, _data: Self::AppData) -> Result<(), anyhow::Error> {
|
||||||
let message = format!("number {} is wrong :(", self.number);
|
let message = format!("number {} is wrong :(", self.number);
|
||||||
|
|
||||||
Err(Box::new(TaskError::Custom(message)))
|
Err(TaskError::Custom(message).into())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn max_retries(&self) -> i32 {
|
fn max_retries(&self) -> i32 {
|
||||||
|
@ -260,282 +297,40 @@ mod async_worker_tests {
|
||||||
#[derive(Serialize, Deserialize, Clone)]
|
#[derive(Serialize, Deserialize, Clone)]
|
||||||
struct AsyncRetryTask {}
|
struct AsyncRetryTask {}
|
||||||
|
|
||||||
#[typetag::serde]
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
impl RunnableTask for AsyncRetryTask {
|
impl BackgroundTask for AsyncRetryTask {
|
||||||
async fn run(
|
const TASK_NAME: &'static str = "AsyncRetryTask";
|
||||||
&self,
|
type AppData = ();
|
||||||
_queueable: &mut dyn Queueable,
|
|
||||||
) -> Result<(), Box<(dyn std::error::Error + Send + 'static)>> {
|
|
||||||
Err(Box::new(TaskError::SomethingWrong))
|
|
||||||
}
|
|
||||||
|
|
||||||
fn max_retries(&self) -> i32 {
|
async fn run(&self, _task: CurrentTask, _data: Self::AppData) -> Result<(), anyhow::Error> {
|
||||||
2
|
Err(TaskError::SomethingWrong.into())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize)]
|
#[derive(Serialize, Deserialize)]
|
||||||
struct AsyncTaskType1 {}
|
struct AsyncTaskType1 {}
|
||||||
|
|
||||||
#[typetag::serde]
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
impl RunnableTask for AsyncTaskType1 {
|
impl BackgroundTask for AsyncTaskType1 {
|
||||||
async fn run(
|
const TASK_NAME: &'static str = "AsyncTaskType1";
|
||||||
&self,
|
type AppData = ();
|
||||||
_queueable: &mut dyn Queueable,
|
|
||||||
) -> Result<(), Box<(dyn std::error::Error + Send + 'static)>> {
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
fn task_type(&self) -> TaskType {
|
async fn run(&self, _task: CurrentTask, _data: Self::AppData) -> Result<(), anyhow::Error> {
|
||||||
"type1".into()
|
Ok(())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize)]
|
#[derive(Serialize, Deserialize)]
|
||||||
struct AsyncTaskType2 {}
|
struct AsyncTaskType2 {}
|
||||||
|
|
||||||
#[typetag::serde]
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
impl RunnableTask for AsyncTaskType2 {
|
impl BackgroundTask for AsyncTaskType2 {
|
||||||
async fn run(
|
const TASK_NAME: &'static str = "AsyncTaskType2";
|
||||||
&self,
|
type AppData = ();
|
||||||
_queueable: &mut dyn Queueable,
|
|
||||||
) -> Result<(), Box<(dyn std::error::Error + Send + 'static)>> {
|
async fn run(&self, _task: CurrentTask, _data: Self::AppData) -> Result<(), anyhow::Error> {
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn task_type(&self) -> TaskType {
|
|
||||||
TaskType::from("type2")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn execute_and_finishes_task() {
|
|
||||||
let pool = pool().await;
|
|
||||||
let mut test = PgAsyncQueue::new(pool);
|
|
||||||
|
|
||||||
let actual_task = WorkerAsyncTask { number: 1 };
|
|
||||||
|
|
||||||
let task = insert_task(&mut test, &actual_task).await;
|
|
||||||
let id = task.id;
|
|
||||||
|
|
||||||
let mut worker = Worker::<PgAsyncQueue>::builder()
|
|
||||||
.queue(test.clone())
|
|
||||||
.retention_mode(RetentionMode::KeepAll)
|
|
||||||
.build();
|
|
||||||
|
|
||||||
worker.run(task, Box::new(actual_task)).await.unwrap();
|
|
||||||
let task_finished = test.find_task_by_id(id).await.unwrap();
|
|
||||||
assert_eq!(id, task_finished.id);
|
|
||||||
assert_eq!(TaskState::Done, task_finished.state());
|
|
||||||
|
|
||||||
test.remove_all_tasks().await.unwrap();
|
|
||||||
}
|
|
||||||
|
|
||||||
// #[tokio::test]
|
|
||||||
// async fn schedule_task_test() {
|
|
||||||
// let pool = pool().await;
|
|
||||||
// let mut test = PgAsyncQueue::new(pool);
|
|
||||||
//
|
|
||||||
// let actual_task = WorkerAsyncTaskSchedule { number: 1 };
|
|
||||||
//
|
|
||||||
// let task = test.schedule_task(&actual_task).await.unwrap();
|
|
||||||
//
|
|
||||||
// let id = task.id;
|
|
||||||
//
|
|
||||||
// let mut worker = AsyncWorker::<PgAsyncQueue>::builder()
|
|
||||||
// .queue(test.clone())
|
|
||||||
// .retention_mode(RetentionMode::KeepAll)
|
|
||||||
// .build();
|
|
||||||
//
|
|
||||||
// worker.run_tasks_until_none().await.unwrap();
|
|
||||||
//
|
|
||||||
// let task = worker.queue.find_task_by_id(id).await.unwrap();
|
|
||||||
//
|
|
||||||
// assert_eq!(id, task.id);
|
|
||||||
// assert_eq!(TaskState::Ready, task.state());
|
|
||||||
//
|
|
||||||
// tokio::time::sleep(core::time::Duration::from_secs(3)).await;
|
|
||||||
//
|
|
||||||
// worker.run_tasks_until_none().await.unwrap();
|
|
||||||
//
|
|
||||||
// let task = test.find_task_by_id(id).await.unwrap();
|
|
||||||
// assert_eq!(id, task.id);
|
|
||||||
// assert_eq!(TaskState::Done, task.state());
|
|
||||||
//
|
|
||||||
// test.remove_all_tasks().await.unwrap();
|
|
||||||
// }
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn retries_task_test() {
|
|
||||||
let pool = pool().await;
|
|
||||||
let mut test = PgAsyncQueue::new(pool);
|
|
||||||
|
|
||||||
let actual_task = AsyncRetryTask {};
|
|
||||||
|
|
||||||
let task = test.create_task(&actual_task).await.unwrap();
|
|
||||||
|
|
||||||
let id = task.id;
|
|
||||||
|
|
||||||
let mut worker = Worker::<PgAsyncQueue>::builder()
|
|
||||||
.queue(test.clone())
|
|
||||||
.retention_mode(RetentionMode::KeepAll)
|
|
||||||
.build();
|
|
||||||
|
|
||||||
worker.run_tasks_until_none().await.unwrap();
|
|
||||||
|
|
||||||
let task = worker.queue.find_task_by_id(id).await.unwrap();
|
|
||||||
|
|
||||||
assert_eq!(id, task.id);
|
|
||||||
assert_eq!(TaskState::Ready, task.state());
|
|
||||||
assert_eq!(1, task.retries);
|
|
||||||
assert!(task.error_message.is_some());
|
|
||||||
|
|
||||||
tokio::time::sleep(core::time::Duration::from_secs(5)).await;
|
|
||||||
worker.run_tasks_until_none().await.unwrap();
|
|
||||||
|
|
||||||
let task = worker.queue.find_task_by_id(id).await.unwrap();
|
|
||||||
|
|
||||||
assert_eq!(id, task.id);
|
|
||||||
assert_eq!(TaskState::Ready, task.state());
|
|
||||||
assert_eq!(2, task.retries);
|
|
||||||
|
|
||||||
tokio::time::sleep(core::time::Duration::from_secs(10)).await;
|
|
||||||
worker.run_tasks_until_none().await.unwrap();
|
|
||||||
|
|
||||||
let task = test.find_task_by_id(id).await.unwrap();
|
|
||||||
assert_eq!(id, task.id);
|
|
||||||
assert_eq!(TaskState::Failed, task.state());
|
|
||||||
assert_eq!("Something went wrong".to_string(), task.error_message.unwrap());
|
|
||||||
|
|
||||||
test.remove_all_tasks().await.unwrap();
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn worker_shutsdown_when_notified() {
|
|
||||||
let pool = pool().await;
|
|
||||||
let queue = PgAsyncQueue::new(pool);
|
|
||||||
|
|
||||||
let (tx, rx) = tokio::sync::watch::channel(());
|
|
||||||
|
|
||||||
let mut worker = Worker::<PgAsyncQueue>::builder()
|
|
||||||
.queue(queue)
|
|
||||||
.shutdown(rx)
|
|
||||||
.build();
|
|
||||||
|
|
||||||
let handle = tokio::spawn(async move {
|
|
||||||
worker.run_tasks().await.unwrap();
|
|
||||||
true
|
|
||||||
});
|
|
||||||
|
|
||||||
tx.send(()).unwrap();
|
|
||||||
select! {
|
|
||||||
_ = handle.fuse() => {}
|
|
||||||
_ = tokio::time::sleep(core::time::Duration::from_secs(1)).fuse() => panic!("Worker did not shutdown")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn saves_error_for_failed_task() {
|
|
||||||
let pool = pool().await;
|
|
||||||
let mut test = PgAsyncQueue::new(pool);
|
|
||||||
|
|
||||||
let failed_task = AsyncFailedTask { number: 1 };
|
|
||||||
|
|
||||||
let task = insert_task(&mut test, &failed_task).await;
|
|
||||||
let id = task.id;
|
|
||||||
|
|
||||||
let mut worker = Worker::<PgAsyncQueue>::builder()
|
|
||||||
.queue(test.clone())
|
|
||||||
.retention_mode(RetentionMode::KeepAll)
|
|
||||||
.build();
|
|
||||||
|
|
||||||
worker.run(task, Box::new(failed_task)).await.unwrap();
|
|
||||||
let task_finished = test.find_task_by_id(id).await.unwrap();
|
|
||||||
|
|
||||||
assert_eq!(id, task_finished.id);
|
|
||||||
assert_eq!(TaskState::Failed, task_finished.state());
|
|
||||||
assert_eq!(
|
|
||||||
"number 1 is wrong :(".to_string(),
|
|
||||||
task_finished.error_message.unwrap()
|
|
||||||
);
|
|
||||||
|
|
||||||
test.remove_all_tasks().await.unwrap();
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn executes_task_only_of_specific_type() {
|
|
||||||
let pool = pool().await;
|
|
||||||
let mut test = PgAsyncQueue::new(pool);
|
|
||||||
|
|
||||||
let task1 = insert_task(&mut test, &AsyncTaskType1 {}).await;
|
|
||||||
let task12 = insert_task(&mut test, &AsyncTaskType1 {}).await;
|
|
||||||
let task2 = insert_task(&mut test, &AsyncTaskType2 {}).await;
|
|
||||||
|
|
||||||
let id1 = task1.id;
|
|
||||||
let id12 = task12.id;
|
|
||||||
let id2 = task2.id;
|
|
||||||
|
|
||||||
let mut worker = Worker::<PgAsyncQueue>::builder()
|
|
||||||
.queue(test.clone())
|
|
||||||
.task_type(TaskType::from("type1"))
|
|
||||||
.retention_mode(RetentionMode::KeepAll)
|
|
||||||
.build();
|
|
||||||
|
|
||||||
worker.run_tasks_until_none().await.unwrap();
|
|
||||||
let task1 = test.find_task_by_id(id1).await.unwrap();
|
|
||||||
let task12 = test.find_task_by_id(id12).await.unwrap();
|
|
||||||
let task2 = test.find_task_by_id(id2).await.unwrap();
|
|
||||||
|
|
||||||
assert_eq!(id1, task1.id);
|
|
||||||
assert_eq!(id12, task12.id);
|
|
||||||
assert_eq!(id2, task2.id);
|
|
||||||
assert_eq!(TaskState::Done, task1.state());
|
|
||||||
assert_eq!(TaskState::Done, task12.state());
|
|
||||||
assert_eq!(TaskState::Ready, task2.state());
|
|
||||||
|
|
||||||
test.remove_all_tasks().await.unwrap();
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn remove_when_finished() {
|
|
||||||
let pool = pool().await;
|
|
||||||
let mut test = PgAsyncQueue::new(pool);
|
|
||||||
|
|
||||||
let task1 = insert_task(&mut test, &AsyncTaskType1 {}).await;
|
|
||||||
let task12 = insert_task(&mut test, &AsyncTaskType1 {}).await;
|
|
||||||
let task2 = insert_task(&mut test, &AsyncTaskType2 {}).await;
|
|
||||||
|
|
||||||
let _id1 = task1.id;
|
|
||||||
let _id12 = task12.id;
|
|
||||||
let id2 = task2.id;
|
|
||||||
|
|
||||||
let mut worker = Worker::<PgAsyncQueue>::builder()
|
|
||||||
.queue(test.clone())
|
|
||||||
.task_type(TaskType::from("type1"))
|
|
||||||
.build();
|
|
||||||
|
|
||||||
worker.run_tasks_until_none().await.unwrap();
|
|
||||||
let task = test
|
|
||||||
.pull_next_task(Some(TaskType::from("type1")))
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
assert_eq!(None, task);
|
|
||||||
|
|
||||||
let task2 = test
|
|
||||||
.pull_next_task(Some(TaskType::from("type2")))
|
|
||||||
.await
|
|
||||||
.unwrap()
|
|
||||||
.unwrap();
|
|
||||||
assert_eq!(id2, task2.id);
|
|
||||||
|
|
||||||
test.remove_all_tasks().await.unwrap();
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn insert_task(test: &mut PgAsyncQueue, task: &dyn RunnableTask) -> Task {
|
|
||||||
test.create_task(task).await.unwrap()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn pool() -> Pool<AsyncPgConnection> {
|
async fn pool() -> Pool<AsyncPgConnection> {
|
||||||
|
|
|
@ -1,102 +1,200 @@
|
||||||
use crate::errors::BackieError;
|
use crate::errors::BackieError;
|
||||||
use crate::queue::Queueable;
|
use crate::queue::Queue;
|
||||||
use crate::task::TaskType;
|
use crate::worker::{runnable, ExecuteTaskFn};
|
||||||
use crate::worker::Worker;
|
use crate::worker::{StateFn, Worker};
|
||||||
use crate::RetentionMode;
|
use crate::{BackgroundTask, CurrentTask, PgTaskStore, RetentionMode};
|
||||||
use async_recursion::async_recursion;
|
use std::collections::BTreeMap;
|
||||||
use log::error;
|
|
||||||
use std::future::Future;
|
use std::future::Future;
|
||||||
use tokio::sync::watch::Receiver;
|
use std::sync::Arc;
|
||||||
use typed_builder::TypedBuilder;
|
use tokio::task::JoinHandle;
|
||||||
|
|
||||||
#[derive(TypedBuilder, Clone)]
|
pub type AppDataFn<AppData> = Arc<dyn Fn(Queue) -> AppData + Send + Sync>;
|
||||||
pub struct WorkerPool<AQueue>
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct WorkerPool<AppData>
|
||||||
where
|
where
|
||||||
AQueue: Queueable + Clone + Sync + 'static,
|
AppData: Clone + Send + 'static,
|
||||||
{
|
{
|
||||||
#[builder(setter(into))]
|
/// Storage of tasks.
|
||||||
/// the AsyncWorkerPool uses a queue to control the tasks that will be executed.
|
queue_store: Arc<PgTaskStore>, // TODO: make this generic/dynamic referenced
|
||||||
pub queue: AQueue,
|
|
||||||
|
|
||||||
/// retention_mode controls if tasks should be persisted after execution
|
/// Queue used to spawn tasks.
|
||||||
#[builder(default, setter(into))]
|
queue: Queue,
|
||||||
pub retention_mode: RetentionMode,
|
|
||||||
|
|
||||||
/// the number of workers of the AsyncWorkerPool.
|
/// Make possible to load the application data.
|
||||||
#[builder(setter(into))]
|
///
|
||||||
pub number_of_workers: u32,
|
/// The application data is loaded when the worker pool is started and is passed to the tasks.
|
||||||
|
/// The loading function accepts a queue instance in case the application data depends on it. This
|
||||||
|
/// is interesting for situations where the application wants to allow tasks to spawn other tasks.
|
||||||
|
application_data_fn: StateFn<AppData>,
|
||||||
|
|
||||||
/// The type of tasks that will be executed by `AsyncWorkerPool`.
|
/// The types of task the worker pool can execute and the loaders for them.
|
||||||
#[builder(default, setter(into))]
|
task_registry: BTreeMap<String, ExecuteTaskFn<AppData>>,
|
||||||
pub task_type: Option<TaskType>,
|
|
||||||
|
/// Number of workers that will be spawned per queue.
|
||||||
|
worker_queues: BTreeMap<String, (RetentionMode, u32)>,
|
||||||
}
|
}
|
||||||
|
|
||||||
// impl<TypedBuilderFields, Q> AsyncWorkerBuilder<TypedBuilderFields, Q>
|
impl<AppData> WorkerPool<AppData>
|
||||||
// where
|
|
||||||
// TypedBuilderFields: Clone,
|
|
||||||
// Q: Queueable + Clone + Sync + 'static,
|
|
||||||
// {
|
|
||||||
// pub fn with_graceful_shutdown<F>(self, signal: F) -> Self<TypedBuilderFields, Q>
|
|
||||||
// where
|
|
||||||
// F: Future<Output = ()>,
|
|
||||||
// {
|
|
||||||
// self
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
|
|
||||||
impl<AQueue> WorkerPool<AQueue>
|
|
||||||
where
|
where
|
||||||
AQueue: Queueable + Clone + Sync + 'static,
|
AppData: Clone + Send + 'static,
|
||||||
{
|
{
|
||||||
/// Starts the configured number of workers
|
/// Create a new worker pool.
|
||||||
/// This is necessary in order to execute tasks.
|
pub fn new<A>(queue_store: PgTaskStore, application_data_fn: A) -> Self
|
||||||
pub async fn start<F>(&mut self, graceful_shutdown: F) -> Result<(), BackieError>
|
where
|
||||||
|
A: Fn(Queue) -> AppData + Send + Sync + 'static,
|
||||||
|
{
|
||||||
|
let queue_store = Arc::new(queue_store);
|
||||||
|
let queue = Queue::new(queue_store.clone());
|
||||||
|
let application_data_fn = {
|
||||||
|
let queue = queue.clone();
|
||||||
|
move || application_data_fn(queue.clone())
|
||||||
|
};
|
||||||
|
Self {
|
||||||
|
queue_store,
|
||||||
|
queue,
|
||||||
|
application_data_fn: Arc::new(application_data_fn),
|
||||||
|
task_registry: BTreeMap::new(),
|
||||||
|
worker_queues: BTreeMap::new(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Register a task type with the worker pool.
|
||||||
|
pub fn register_task_type<BT>(mut self, num_workers: u32, retention_mode: RetentionMode) -> Self
|
||||||
|
where
|
||||||
|
BT: BackgroundTask<AppData = AppData>,
|
||||||
|
{
|
||||||
|
self.worker_queues
|
||||||
|
.insert(BT::QUEUE.to_string(), (retention_mode, num_workers));
|
||||||
|
self.task_registry
|
||||||
|
.insert(BT::TASK_NAME.to_string(), Arc::new(runnable::<BT>));
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn start<F>(
|
||||||
|
self,
|
||||||
|
graceful_shutdown: F,
|
||||||
|
) -> Result<(JoinHandle<()>, Queue), BackieError>
|
||||||
where
|
where
|
||||||
F: Future<Output = ()> + Send + 'static,
|
F: Future<Output = ()> + Send + 'static,
|
||||||
{
|
{
|
||||||
let (tx, rx) = tokio::sync::watch::channel(());
|
let (tx, rx) = tokio::sync::watch::channel(());
|
||||||
for idx in 0..self.number_of_workers {
|
|
||||||
let pool = self.clone();
|
|
||||||
// TODO: the worker pool keeps track of the number of workers and spawns new workers as needed.
|
|
||||||
// There should be always a minimum number of workers active waiting for tasks to execute
|
|
||||||
// or for a gracefull shutdown.
|
|
||||||
tokio::spawn(Self::supervise_task(pool, rx.clone(), 0, idx));
|
|
||||||
}
|
|
||||||
graceful_shutdown.await;
|
|
||||||
tx.send(())?;
|
|
||||||
log::info!("Worker pool stopped gracefully");
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
#[async_recursion]
|
// Spawn all individual workers per queue
|
||||||
async fn supervise_task(
|
for (queue_name, (retention_mode, num_workers)) in self.worker_queues.iter() {
|
||||||
pool: WorkerPool<AQueue>,
|
for idx in 0..*num_workers {
|
||||||
receiver: Receiver<()>,
|
let mut worker: Worker<AppData> = Worker::new(
|
||||||
restarts: u64,
|
self.queue_store.clone(),
|
||||||
worker_number: u32,
|
queue_name.clone(),
|
||||||
) {
|
retention_mode.clone(),
|
||||||
let restarts = restarts + 1;
|
self.task_registry.clone(),
|
||||||
|
self.application_data_fn.clone(),
|
||||||
let inner_pool = pool.clone();
|
Some(rx.clone()),
|
||||||
let inner_receiver = receiver.clone();
|
|
||||||
|
|
||||||
let join_handle = tokio::spawn(async move {
|
|
||||||
let mut worker: Worker<AQueue> = Worker::builder()
|
|
||||||
.queue(inner_pool.queue.clone())
|
|
||||||
.retention_mode(inner_pool.retention_mode)
|
|
||||||
.task_type(inner_pool.task_type.clone())
|
|
||||||
.shutdown(inner_receiver)
|
|
||||||
.build();
|
|
||||||
|
|
||||||
worker.run_tasks().await
|
|
||||||
});
|
|
||||||
|
|
||||||
if (join_handle.await).is_err() {
|
|
||||||
error!(
|
|
||||||
"Worker {} stopped. Restarting. the number of restarts {}",
|
|
||||||
worker_number, restarts,
|
|
||||||
);
|
);
|
||||||
Self::supervise_task(pool, receiver, restarts, worker_number).await;
|
let worker_name = format!("worker-{queue_name}-{idx}");
|
||||||
|
// TODO: grab the join handle for every worker for graceful shutdown
|
||||||
|
tokio::spawn(async move {
|
||||||
|
worker.run_tasks().await;
|
||||||
|
log::info!("Worker {} stopped", worker_name);
|
||||||
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Ok((
|
||||||
|
tokio::spawn(async move {
|
||||||
|
graceful_shutdown.await;
|
||||||
|
if let Err(err) = tx.send(()) {
|
||||||
|
log::warn!("Failed to send shutdown signal to worker pool: {}", err);
|
||||||
|
} else {
|
||||||
|
log::info!("Worker pool stopped gracefully");
|
||||||
|
}
|
||||||
|
}),
|
||||||
|
self.queue,
|
||||||
|
))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use async_trait::async_trait;
|
||||||
|
use diesel_async::pooled_connection::{bb8::Pool, AsyncDieselConnectionManager};
|
||||||
|
use diesel_async::AsyncPgConnection;
|
||||||
|
|
||||||
|
#[derive(Clone, Debug)]
|
||||||
|
struct ApplicationContext {
|
||||||
|
app_name: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ApplicationContext {
|
||||||
|
fn new() -> Self {
|
||||||
|
Self {
|
||||||
|
app_name: "Backie".to_string(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_app_name(&self) -> String {
|
||||||
|
self.app_name.clone()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
|
||||||
|
struct GreetingTask {
|
||||||
|
person: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl BackgroundTask for GreetingTask {
|
||||||
|
const TASK_NAME: &'static str = "my_task";
|
||||||
|
|
||||||
|
type AppData = ApplicationContext;
|
||||||
|
|
||||||
|
async fn run(
|
||||||
|
&self,
|
||||||
|
task_info: CurrentTask,
|
||||||
|
app_context: Self::AppData,
|
||||||
|
) -> Result<(), anyhow::Error> {
|
||||||
|
println!(
|
||||||
|
"[{}] Hello {}! I'm {}.",
|
||||||
|
task_info.id(),
|
||||||
|
self.person,
|
||||||
|
app_context.get_app_name()
|
||||||
|
);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_worker_pool() {
|
||||||
|
let my_app_context = ApplicationContext::new();
|
||||||
|
|
||||||
|
let task_store = PgTaskStore::new(pool().await);
|
||||||
|
|
||||||
|
let (join_handle, queue) = WorkerPool::new(task_store, move |_| my_app_context.clone())
|
||||||
|
.register_task_type::<GreetingTask>(1, RetentionMode::RemoveDone)
|
||||||
|
.start(futures::future::ready(()))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
queue
|
||||||
|
.enqueue(GreetingTask {
|
||||||
|
person: "Rafael".to_string(),
|
||||||
|
})
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
join_handle.await.unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn pool() -> Pool<AsyncPgConnection> {
|
||||||
|
let manager = AsyncDieselConnectionManager::<AsyncPgConnection>::new(
|
||||||
|
option_env!("DATABASE_URL").expect("DATABASE_URL must be set"),
|
||||||
|
);
|
||||||
|
Pool::builder()
|
||||||
|
.max_size(1)
|
||||||
|
.min_idle(Some(1))
|
||||||
|
.build(manager)
|
||||||
|
.await
|
||||||
|
.unwrap()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue