Compare commits

...

10 commits

12 changed files with 348 additions and 191 deletions

1
.gitignore vendored
View file

@ -2,3 +2,4 @@
Cargo.lock Cargo.lock
docs/content/docs/CHANGELOG.md docs/content/docs/CHANGELOG.md
docs/content/docs/README.md docs/content/docs/README.md
.DS_Store

View file

@ -1,6 +1,6 @@
[package] [package]
name = "backie" name = "backie"
version = "0.3.0" version = "0.6.0"
authors = [ authors = [
"Rafael Caricio <rafael@caricio.com>", "Rafael Caricio <rafael@caricio.com>",
] ]
@ -17,7 +17,6 @@ chrono = "0.4"
log = "0.4" log = "0.4"
serde = { version = "1", features = ["derive"] } serde = { version = "1", features = ["derive"] }
serde_json = "1" serde_json = "1"
anyhow = "1"
thiserror = "1" thiserror = "1"
uuid = { version = "1.1", features = ["v4", "serde"] } uuid = { version = "1.1", features = ["v4", "serde"] }
async-trait = "0.1" async-trait = "0.1"

View file

@ -1,5 +1,6 @@
# Backie 🚲 <p align="center"><img src="logo.png" alt="Backie" width="400"></p>
---
Async persistent background task processing for Rust applications with Tokio. Queue asynchronous tasks Async persistent background task processing for Rust applications with Tokio. Queue asynchronous tasks
to be processed by workers. It's designed to be easy to use and horizontally scalable. It uses Postgres as to be processed by workers. It's designed to be easy to use and horizontally scalable. It uses Postgres as
a storage backend and can also be extended to support other types of storage. a storage backend and can also be extended to support other types of storage.
@ -31,6 +32,25 @@ Here are some of the Backie's key features:
- Task timeout: Tasks are retried if they are not completed in time - Task timeout: Tasks are retried if they are not completed in time
- Scheduling of tasks: Tasks can be scheduled to be executed at a specific time - Scheduling of tasks: Tasks can be scheduled to be executed at a specific time
## Task execution protocol
The following diagram shows the protocol used to execute tasks:
```mermaid
stateDiagram-v2
[*] --> Ready
Ready --> Running: Task is picked up by a worker
Running --> Done: Task is finished
Running --> Failed: Task failed
Failed --> Ready: Task is retried
Failed --> [*]: Task is not retried anymore, max retries reached
Done --> [*]
```
When a task goes from `Running` to `Failed` it is retried. The number of retries is controlled by the
[`BackgroundTask::MAX_RETRIES`] attribute. The default implementation uses `3` retries.
## Safety ## Safety
This crate uses `#![forbid(unsafe_code)]` to ensure everything is implemented in 100% safe Rust. This crate uses `#![forbid(unsafe_code)]` to ensure everything is implemented in 100% safe Rust.
@ -53,7 +73,6 @@ If you are not already using, you will also want to include the following depend
```toml ```toml
[dependencies] [dependencies]
async-trait = "0.1" async-trait = "0.1"
anyhow = "1"
serde = { version = "1.0", features = ["derive"] } serde = { version = "1.0", features = ["derive"] }
diesel = { version = "2.0", features = ["postgres", "serde_json", "chrono", "uuid"] } diesel = { version = "2.0", features = ["postgres", "serde_json", "chrono", "uuid"] }
diesel-async = { version = "0.2", features = ["postgres", "bb8"] } diesel-async = { version = "0.2", features = ["postgres", "bb8"] }
@ -75,6 +94,9 @@ the whole application. This attribute is critical for reconstructing the task ba
The [`BackgroundTask::AppData`] can be used to argument the task with your application specific contextual information. The [`BackgroundTask::AppData`] can be used to argument the task with your application specific contextual information.
This is useful for example to pass a database connection pool to the task or other application configuration. This is useful for example to pass a database connection pool to the task or other application configuration.
The [`BackgroundTask::Error`] is the error type that will be returned by the [`BackgroundTask::run`] method. You can
use this to define your own error type for your tasks.
The [`BackgroundTask::run`] method is where you define the behaviour of your background task execution. This method The [`BackgroundTask::run`] method is where you define the behaviour of your background task execution. This method
will be called by the task queue workers. will be called by the task queue workers.
@ -92,8 +114,9 @@ pub struct MyTask {
impl BackgroundTask for MyTask { impl BackgroundTask for MyTask {
const TASK_NAME: &'static str = "my_task_unique_name"; const TASK_NAME: &'static str = "my_task_unique_name";
type AppData = (); type AppData = ();
type Error = ();
async fn run(&self, task: CurrentTask, context: Self::AppData) -> Result<(), anyhow::Error> { async fn run(&self, task: CurrentTask, context: Self::AppData) -> Result<(), Self::Error> {
// Do something // Do something
Ok(()) Ok(())
} }

View file

@ -1,20 +1,33 @@
use async_trait::async_trait; use async_trait::async_trait;
use backie::{BackgroundTask, CurrentTask}; use backie::{BackgroundTask, CurrentTask, QueueConfig, RetentionMode};
use backie::{PgTaskStore, Queue, WorkerPool}; use backie::{PgTaskStore, Queue, WorkerPool};
use diesel_async::pg::AsyncPgConnection; use diesel_async::pg::AsyncPgConnection;
use diesel_async::pooled_connection::{bb8::Pool, AsyncDieselConnectionManager}; use diesel_async::pooled_connection::{bb8::Pool, AsyncDieselConnectionManager};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::time::Duration; use std::collections::BTreeMap;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::Mutex;
use tokio::task::JoinSet;
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub struct MyApplicationContext { pub struct MyApplicationContext {
app_name: String, app_name: String,
notify_finished: Arc<Mutex<Option<tokio::sync::oneshot::Sender<()>>>>,
} }
impl MyApplicationContext { impl MyApplicationContext {
pub fn new(app_name: &str) -> Self { pub fn new(app_name: &str, notify_finished: tokio::sync::oneshot::Sender<()>) -> Self {
Self { Self {
app_name: app_name.to_string(), app_name: app_name.to_string(),
notify_finished: Arc::new(Mutex::new(Some(notify_finished))),
}
}
pub async fn notify_finished(&self) {
let mut lock = self.notify_finished.lock().await;
if let Some(sender) = lock.take() {
sender.send(()).unwrap();
} }
} }
} }
@ -34,14 +47,9 @@ impl MyTask {
impl BackgroundTask for MyTask { impl BackgroundTask for MyTask {
const TASK_NAME: &'static str = "my_task"; const TASK_NAME: &'static str = "my_task";
type AppData = MyApplicationContext; type AppData = MyApplicationContext;
type Error = anyhow::Error;
async fn run(&self, task: CurrentTask, ctx: Self::AppData) -> Result<(), anyhow::Error> { async fn run(&self, task: CurrentTask, ctx: Self::AppData) -> Result<(), Self::Error> {
// let new_task = MyTask::new(self.number + 1);
// queue
// .insert_task(&new_task)
// .await
// .unwrap();
log::info!( log::info!(
"[{}] Hello from {}! the current number is {}", "[{}] Hello from {}! the current number is {}",
task.id(), task.id(),
@ -70,19 +78,9 @@ impl MyFailingTask {
impl BackgroundTask for MyFailingTask { impl BackgroundTask for MyFailingTask {
const TASK_NAME: &'static str = "my_failing_task"; const TASK_NAME: &'static str = "my_failing_task";
type AppData = MyApplicationContext; type AppData = MyApplicationContext;
type Error = anyhow::Error;
async fn run(&self, task: CurrentTask, _ctx: Self::AppData) -> Result<(), anyhow::Error> { async fn run(&self, task: CurrentTask, _ctx: Self::AppData) -> Result<(), Self::Error> {
// let new_task = MyFailingTask::new(self.number + 1);
// queue
// .insert_task(&new_task)
// .await
// .unwrap();
// task.id();
// task.keep_alive().await?;
// task.previous_error();
// task.retry_count();
log::info!("[{}] the current number is {}", task.id(), self.number); log::info!("[{}] the current number is {}", task.id(), self.number);
tokio::time::sleep(Duration::from_secs(3)).await; tokio::time::sleep(Duration::from_secs(3)).await;
@ -91,58 +89,124 @@ impl BackgroundTask for MyFailingTask {
} }
} }
#[derive(Serialize, Deserialize)]
struct EmptyTask {
pub idx: u64,
}
#[async_trait]
impl BackgroundTask for EmptyTask {
const TASK_NAME: &'static str = "empty_task";
const QUEUE: &'static str = "loaded_queue";
type AppData = MyApplicationContext;
type Error = anyhow::Error;
async fn run(&self, _task: CurrentTask, _ctx: Self::AppData) -> Result<(), Self::Error> {
Ok(())
}
}
#[derive(Serialize, Deserialize)]
struct FinalTask;
#[async_trait]
impl BackgroundTask for FinalTask {
const TASK_NAME: &'static str = "final_task";
const QUEUE: &'static str = "loaded_queue";
type AppData = MyApplicationContext;
type Error = anyhow::Error;
async fn run(&self, _task: CurrentTask, ctx: Self::AppData) -> Result<(), Self::Error> {
ctx.notify_finished().await;
Ok(())
}
}
#[tokio::main] #[tokio::main]
async fn main() { async fn main() -> anyhow::Result<()> {
env_logger::init(); env_logger::init();
let connection_url = "postgres://postgres:password@localhost/backie"; let connection_url = "postgres://postgres:password@localhost/backie";
log::info!("Starting..."); log::info!("Starting...");
let max_pool_size: u32 = 3;
let manager = AsyncDieselConnectionManager::<AsyncPgConnection>::new(connection_url); let manager = AsyncDieselConnectionManager::<AsyncPgConnection>::new(connection_url);
let pool = Pool::builder() let pool = Pool::builder()
.max_size(max_pool_size) .max_size(300)
.min_idle(Some(1)) .min_idle(Some(1))
.build(manager) .build(manager)
.await .await
.unwrap(); .unwrap();
log::info!("Pool created ..."); log::info!("Pool created ...");
let task_store = PgTaskStore::new(pool);
let (tx, mut rx) = tokio::sync::watch::channel(false); let (tx, mut rx) = tokio::sync::watch::channel(false);
let (notify_finished, wait_done) = tokio::sync::oneshot::channel();
// Some global application context I want to pass to my background tasks // Some global application context I want to pass to my background tasks
let my_app_context = MyApplicationContext::new("Backie Example App"); let my_app_context = MyApplicationContext::new("Backie Example App", notify_finished);
// queue.enqueue(task1).await.unwrap();
// queue.enqueue(task2).await.unwrap();
// queue.enqueue(task3).await.unwrap();
// Store all task to join them later
let mut tasks = JoinSet::new();
for i in 0..1_000 {
tasks.spawn({
let pool = pool.clone();
async move {
let mut connection = pool.get().await.unwrap();
let task = EmptyTask { idx: i };
task.enqueue(&mut connection).await.unwrap();
}
});
}
while let Some(result) = tasks.join_next().await {
let _ = result?;
}
(FinalTask {})
.enqueue(&mut pool.get().await.unwrap())
.await
.unwrap();
log::info!("Tasks created ...");
let started = Instant::now();
// Register the task types I want to use and start the worker pool // Register the task types I want to use and start the worker pool
let (join_handle, _queue) = let join_handle = WorkerPool::new(PgTaskStore::new(pool.clone()), move || my_app_context.clone())
WorkerPool::new(task_store.clone(), move |_| my_app_context.clone()) .register_task_type::<MyTask>()
.register_task_type::<MyTask>() .register_task_type::<MyFailingTask>()
.register_task_type::<MyFailingTask>() .register_task_type::<EmptyTask>()
.configure_queue("default".into()) .register_task_type::<FinalTask>()
.start(async move { .configure_queue("default".into())
let _ = rx.changed().await; .configure_queue(
}) QueueConfig::new("loaded_queue")
.await .pull_interval(Duration::from_millis(100))
.unwrap(); .retention_mode(RetentionMode::RemoveDone)
.num_workers(300),
)
.start(async move {
let _ = rx.changed().await;
})
.await
.unwrap();
log::info!("Workers started ..."); log::info!("Workers started ...");
let task1 = MyTask::new(0); wait_done.await.unwrap();
let task2 = MyTask::new(20_000); let elapsed = started.elapsed();
let task3 = MyFailingTask::new(50_000); println!("Ran 50k jobs in {} seconds", elapsed.as_secs());
let queue = Queue::new(task_store); // or use the `queue` instance returned by the worker pool
queue.enqueue(task1).await.unwrap();
queue.enqueue(task2).await.unwrap();
queue.enqueue(task3).await.unwrap();
log::info!("Tasks created ...");
// Wait for Ctrl+C // Wait for Ctrl+C
let _ = tokio::signal::ctrl_c().await; // let _ = tokio::signal::ctrl_c().await;
log::info!("Stopping ..."); log::info!("Stopping ...");
tx.send(true).unwrap(); tx.send(true).unwrap();
join_handle.await.unwrap(); join_handle.await.unwrap();
log::info!("Workers Stopped!"); log::info!("Workers Stopped!");
Ok(())
} }

BIN
logo.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 56 KiB

View file

@ -7,7 +7,10 @@ use chrono::Duration;
use chrono::Utc; use chrono::Utc;
use diesel::prelude::*; use diesel::prelude::*;
use diesel::ExpressionMethods; use diesel::ExpressionMethods;
use diesel_async::{pg::AsyncPgConnection, RunQueryDsl}; use diesel::query_builder::{Query, QueryFragment, QueryId};
use diesel::sql_types::{HasSqlType, SingleValue};
use diesel_async::return_futures::GetResult;
use diesel_async::{pg::AsyncPgConnection, AsyncConnection, RunQueryDsl};
impl Task { impl Task {
pub(crate) async fn remove( pub(crate) async fn remove(

View file

@ -2,34 +2,42 @@ use crate::errors::BackieError;
use crate::runnable::BackgroundTask; use crate::runnable::BackgroundTask;
use crate::store::TaskStore; use crate::store::TaskStore;
use crate::task::NewTask; use crate::task::NewTask;
use std::sync::Arc;
use std::time::Duration; use std::time::Duration;
#[derive(Clone)]
pub struct Queue<S> pub struct Queue<S>
where where
S: TaskStore + Clone, S: TaskStore,
{ {
task_store: Arc<S>, task_store: S,
} }
impl<S> Queue<S> impl<S> Queue<S>
where where
S: TaskStore + Clone, S: TaskStore,
{ {
pub fn new(task_store: S) -> Self { pub fn new(task_store: S) -> Self {
Queue { Queue { task_store }
task_store: Arc::new(task_store),
}
} }
pub async fn enqueue<BT>(&self, background_task: BT) -> Result<(), BackieError> pub async fn enqueue<BT>(&self, background_task: BT) -> Result<(), BackieError>
where where
BT: BackgroundTask, BT: BackgroundTask,
{ {
// TODO: Add option to specify the timeout of a task
self.task_store self.task_store
.create_task(NewTask::new(background_task, Duration::from_secs(10))?) .create_task(NewTask::with_timeout(background_task, Duration::from_secs(10))?)
.await?; .await?;
Ok(()) Ok(())
} }
} }
impl<S> Clone for Queue<S>
where
S: TaskStore + Clone,
{
fn clone(&self) -> Self {
Self {
task_store: self.task_store.clone(),
}
}
}

View file

@ -1,6 +1,7 @@
use crate::task::{CurrentTask, TaskHash}; use crate::task::{CurrentTask, TaskHash};
use async_trait::async_trait; use async_trait::async_trait;
use serde::{de::DeserializeOwned, ser::Serialize}; use serde::{de::DeserializeOwned, ser::Serialize};
use std::fmt::Debug;
/// The [`BackgroundTask`] trait is used to define the behaviour of a task. You must implement this /// The [`BackgroundTask`] trait is used to define the behaviour of a task. You must implement this
/// trait for all tasks you want to execute. /// trait for all tasks you want to execute.
@ -29,8 +30,9 @@ use serde::{de::DeserializeOwned, ser::Serialize};
/// impl BackgroundTask for MyTask { /// impl BackgroundTask for MyTask {
/// const TASK_NAME: &'static str = "my_task_unique_name"; /// const TASK_NAME: &'static str = "my_task_unique_name";
/// type AppData = (); /// type AppData = ();
/// type Error = Box<dyn std::error::Error + Send + Sync + 'static>;
/// ///
/// async fn run(&self, task: CurrentTask, context: Self::AppData) -> Result<(), anyhow::Error> { /// async fn run(&self, task: CurrentTask, context: Self::AppData) -> Result<(), Self::Error> {
/// // Do something /// // Do something
/// Ok(()) /// Ok(())
/// } /// }
@ -51,14 +53,17 @@ pub trait BackgroundTask: Serialize + DeserializeOwned + Sync + Send + 'static {
/// Number of retries for tasks. /// Number of retries for tasks.
/// ///
/// By default, it is set to 5. /// By default, it is set to 3.
const MAX_RETRIES: i32 = 5; const MAX_RETRIES: i32 = 3;
/// The application data provided to this task at runtime. /// The application data provided to this task at runtime.
type AppData: Clone + Send + 'static; type AppData: Clone + Send + 'static;
/// An application custom error type.
type Error: Debug + Send + 'static;
/// Execute the task. This method should define its logic /// Execute the task. This method should define its logic
async fn run(&self, task: CurrentTask, context: Self::AppData) -> Result<(), anyhow::Error>; async fn run(&self, task: CurrentTask, context: Self::AppData) -> Result<(), Self::Error>;
/// If set to true, no new tasks with the same metadata will be inserted /// If set to true, no new tasks with the same metadata will be inserted
/// By default it is set to false. /// By default it is set to false.

View file

@ -1,5 +1,7 @@
use crate::errors::AsyncQueueError; use crate::errors::AsyncQueueError;
use crate::task::{NewTask, Task, TaskId, TaskState}; use crate::task::{NewTask, Task, TaskId, TaskState};
use crate::BackgroundTask;
use async_trait::async_trait;
use diesel::result::Error::QueryBuilderError; use diesel::result::Error::QueryBuilderError;
use diesel_async::scoped_futures::ScopedFutureExt; use diesel_async::scoped_futures::ScopedFutureExt;
use diesel_async::AsyncConnection; use diesel_async::AsyncConnection;
@ -17,6 +19,23 @@ impl PgTaskStore {
} }
} }
/// A trait that is used to enqueue tasks for the PostgreSQL backend.
#[async_trait::async_trait]
pub trait PgQueueTask {
async fn enqueue(self, connection: &mut AsyncPgConnection) -> Result<(), AsyncQueueError>;
}
impl<T> PgQueueTask for T
where
T: BackgroundTask,
{
async fn enqueue(self, connection: &mut AsyncPgConnection) -> Result<(), AsyncQueueError> {
let new_task = NewTask::new::<T>(self)?;
Task::insert(connection, new_task).await?;
Ok(())
}
}
#[async_trait::async_trait] #[async_trait::async_trait]
impl TaskStore for PgTaskStore { impl TaskStore for PgTaskStore {
async fn pull_next_task( async fn pull_next_task(

View file

@ -117,9 +117,9 @@ pub struct NewTask {
} }
impl NewTask { impl NewTask {
pub(crate) fn new<T>(background_task: T, timeout: Duration) -> Result<Self, serde_json::Error> pub(crate) fn with_timeout<T>(background_task: T, timeout: Duration) -> Result<Self, serde_json::Error>
where where
T: BackgroundTask, T: BackgroundTask,
{ {
let max_retries = background_task.max_retries(); let max_retries = background_task.max_retries();
let uniq_hash = background_task.uniq(); let uniq_hash = background_task.uniq();
@ -134,6 +134,13 @@ impl NewTask {
max_retries, max_retries,
}) })
} }
pub(crate) fn new<T>(background_task: T) -> Result<Self, serde_json::Error>
where
T: BackgroundTask,
{
Self::with_timeout(background_task, Duration::from_secs(120))
}
} }
#[cfg(test)] #[cfg(test)]

View file

@ -30,7 +30,7 @@ pub enum TaskExecError {
TaskDeserializationFailed(#[from] serde_json::Error), TaskDeserializationFailed(#[from] serde_json::Error),
#[error("Task execution failed: {0}")] #[error("Task execution failed: {0}")]
ExecutionFailed(#[from] anyhow::Error), ExecutionFailed(String),
#[error("Task panicked with: {0}")] #[error("Task panicked with: {0}")]
Panicked(String), Panicked(String),
@ -46,8 +46,10 @@ where
{ {
Box::pin(async move { Box::pin(async move {
let background_task: BT = serde_json::from_value(payload)?; let background_task: BT = serde_json::from_value(payload)?;
background_task.run(task_info, app_context).await?; match background_task.run(task_info, app_context).await {
Ok(()) Ok(_) => Ok(()),
Err(err) => Err(TaskExecError::ExecutionFailed(format!("{:?}", err))),
}
}) })
} }
@ -57,7 +59,7 @@ where
AppData: Clone + Send + 'static, AppData: Clone + Send + 'static,
S: TaskStore + Clone, S: TaskStore + Clone,
{ {
store: Arc<S>, store: S,
queue_name: String, queue_name: String,
@ -79,7 +81,7 @@ where
S: TaskStore + Clone, S: TaskStore + Clone,
{ {
pub(crate) fn new( pub(crate) fn new(
store: Arc<S>, store: S,
queue_name: String, queue_name: String,
retention_mode: RetentionMode, retention_mode: RetentionMode,
pull_interval: Duration, pull_interval: Duration,
@ -250,8 +252,9 @@ mod async_worker_tests {
impl BackgroundTask for WorkerAsyncTask { impl BackgroundTask for WorkerAsyncTask {
const TASK_NAME: &'static str = "WorkerAsyncTask"; const TASK_NAME: &'static str = "WorkerAsyncTask";
type AppData = (); type AppData = ();
type Error = ();
async fn run(&self, _: CurrentTask, _: Self::AppData) -> Result<(), anyhow::Error> { async fn run(&self, _: CurrentTask, _: Self::AppData) -> Result<(), ()> {
Ok(()) Ok(())
} }
} }
@ -265,8 +268,9 @@ mod async_worker_tests {
impl BackgroundTask for WorkerAsyncTaskSchedule { impl BackgroundTask for WorkerAsyncTaskSchedule {
const TASK_NAME: &'static str = "WorkerAsyncTaskSchedule"; const TASK_NAME: &'static str = "WorkerAsyncTaskSchedule";
type AppData = (); type AppData = ();
type Error = ();
async fn run(&self, _task: CurrentTask, _data: Self::AppData) -> Result<(), anyhow::Error> { async fn run(&self, _task: CurrentTask, _data: Self::AppData) -> Result<(), ()> {
Ok(()) Ok(())
} }
@ -284,11 +288,12 @@ mod async_worker_tests {
impl BackgroundTask for AsyncFailedTask { impl BackgroundTask for AsyncFailedTask {
const TASK_NAME: &'static str = "AsyncFailedTask"; const TASK_NAME: &'static str = "AsyncFailedTask";
type AppData = (); type AppData = ();
type Error = TaskError;
async fn run(&self, _task: CurrentTask, _data: Self::AppData) -> Result<(), anyhow::Error> { async fn run(&self, _task: CurrentTask, _data: Self::AppData) -> Result<(), TaskError> {
let message = format!("number {} is wrong :(", self.number); let message = format!("number {} is wrong :(", self.number);
Err(TaskError::Custom(message).into()) Err(TaskError::Custom(message))
} }
fn max_retries(&self) -> i32 { fn max_retries(&self) -> i32 {
@ -303,9 +308,10 @@ mod async_worker_tests {
impl BackgroundTask for AsyncRetryTask { impl BackgroundTask for AsyncRetryTask {
const TASK_NAME: &'static str = "AsyncRetryTask"; const TASK_NAME: &'static str = "AsyncRetryTask";
type AppData = (); type AppData = ();
type Error = TaskError;
async fn run(&self, _task: CurrentTask, _data: Self::AppData) -> Result<(), anyhow::Error> { async fn run(&self, _task: CurrentTask, _data: Self::AppData) -> Result<(), Self::Error> {
Err(TaskError::SomethingWrong.into()) Err(TaskError::SomethingWrong)
} }
} }
@ -316,8 +322,9 @@ mod async_worker_tests {
impl BackgroundTask for AsyncTaskType1 { impl BackgroundTask for AsyncTaskType1 {
const TASK_NAME: &'static str = "AsyncTaskType1"; const TASK_NAME: &'static str = "AsyncTaskType1";
type AppData = (); type AppData = ();
type Error = ();
async fn run(&self, _task: CurrentTask, _data: Self::AppData) -> Result<(), anyhow::Error> { async fn run(&self, _task: CurrentTask, _data: Self::AppData) -> Result<(), Self::Error> {
Ok(()) Ok(())
} }
} }
@ -329,8 +336,9 @@ mod async_worker_tests {
impl BackgroundTask for AsyncTaskType2 { impl BackgroundTask for AsyncTaskType2 {
const TASK_NAME: &'static str = "AsyncTaskType2"; const TASK_NAME: &'static str = "AsyncTaskType2";
type AppData = (); type AppData = ();
type Error = ();
async fn run(&self, _task: CurrentTask, _data: Self::AppData) -> Result<(), anyhow::Error> { async fn run(&self, _task: CurrentTask, _data: Self::AppData) -> Result<(), ()> {
Ok(()) Ok(())
} }
} }

View file

@ -1,5 +1,4 @@
use crate::errors::BackieError; use crate::errors::BackieError;
use crate::queue::Queue;
use crate::runnable::BackgroundTask; use crate::runnable::BackgroundTask;
use crate::store::TaskStore; use crate::store::TaskStore;
use crate::worker::{runnable, ExecuteTaskFn}; use crate::worker::{runnable, ExecuteTaskFn};
@ -19,10 +18,7 @@ where
S: TaskStore + Clone, S: TaskStore + Clone,
{ {
/// Storage of tasks. /// Storage of tasks.
task_store: Arc<S>, task_store: S,
/// Queue used to spawn tasks.
queue: Queue<S>,
/// Make possible to load the application data. /// Make possible to load the application data.
/// ///
@ -49,16 +45,10 @@ where
/// Create a new worker pool. /// Create a new worker pool.
pub fn new<A>(task_store: S, application_data_fn: A) -> Self pub fn new<A>(task_store: S, application_data_fn: A) -> Self
where where
A: Fn(Queue<S>) -> AppData + Send + Sync + 'static, A: Fn() -> AppData + Send + Sync + 'static,
{ {
let queue = Queue::new(task_store.clone());
let application_data_fn = {
let queue = queue.clone();
move || application_data_fn(queue.clone())
};
Self { Self {
task_store: Arc::new(task_store), task_store,
queue,
application_data_fn: Arc::new(application_data_fn), application_data_fn: Arc::new(application_data_fn),
task_registry: BTreeMap::new(), task_registry: BTreeMap::new(),
queue_tasks: BTreeMap::new(), queue_tasks: BTreeMap::new(),
@ -85,10 +75,7 @@ where
self self
} }
pub async fn start<F>( pub async fn start<F>(self, graceful_shutdown: F) -> Result<JoinHandle<()>, BackieError>
self,
graceful_shutdown: F,
) -> Result<(JoinHandle<()>, Queue<S>), BackieError>
where where
F: Future<Output = ()> + Send + 'static, F: Future<Output = ()> + Send + 'static,
{ {
@ -127,28 +114,25 @@ where
} }
} }
Ok(( Ok(tokio::spawn(async move {
tokio::spawn(async move { graceful_shutdown.await;
graceful_shutdown.await; if let Err(err) = tx.send(()) {
if let Err(err) = tx.send(()) { log::warn!("Failed to send shutdown signal to worker pool: {}", err);
log::warn!("Failed to send shutdown signal to worker pool: {}", err); } else {
// Wait for all workers to finish processing
let results = join_all(worker_handles)
.await
.into_iter()
.filter(Result::is_err)
.map(Result::unwrap_err)
.collect::<Vec<_>>();
if !results.is_empty() {
log::error!("Worker pool stopped with errors: {:?}", results);
} else { } else {
// Wait for all workers to finish processing log::info!("Worker pool stopped gracefully");
let results = join_all(worker_handles)
.await
.into_iter()
.filter(Result::is_err)
.map(Result::unwrap_err)
.collect::<Vec<_>>();
if !results.is_empty() {
log::error!("Worker pool stopped with errors: {:?}", results);
} else {
log::info!("Worker pool stopped gracefully");
}
} }
}), }
self.queue, }))
))
} }
} }
@ -232,6 +216,7 @@ mod tests {
use crate::store::test_store::MemoryTaskStore; use crate::store::test_store::MemoryTaskStore;
use crate::store::PgTaskStore; use crate::store::PgTaskStore;
use crate::task::CurrentTask; use crate::task::CurrentTask;
use crate::Queue;
use async_trait::async_trait; use async_trait::async_trait;
use diesel_async::pooled_connection::{bb8::Pool, AsyncDieselConnectionManager}; use diesel_async::pooled_connection::{bb8::Pool, AsyncDieselConnectionManager};
use diesel_async::AsyncPgConnection; use diesel_async::AsyncPgConnection;
@ -240,7 +225,7 @@ mod tests {
use tokio::sync::Mutex; use tokio::sync::Mutex;
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
struct ApplicationContext { pub struct ApplicationContext {
app_name: String, app_name: String,
} }
@ -261,17 +246,50 @@ mod tests {
person: String, person: String,
} }
/// This tests that one can customize the task parameters for the application.
#[async_trait] #[async_trait]
impl BackgroundTask for GreetingTask { trait MyAppTask {
const TASK_NAME: &'static str = "my_task"; const TASK_NAME: &'static str;
const QUEUE: &'static str = "default";
async fn run(
&self,
task_info: CurrentTask,
app_context: ApplicationContext,
) -> Result<(), ()>;
}
#[async_trait]
impl<T> BackgroundTask for T
where
T: MyAppTask + serde::de::DeserializeOwned + serde::ser::Serialize + Sync + Send + 'static,
{
const TASK_NAME: &'static str = T::TASK_NAME;
const QUEUE: &'static str = T::QUEUE;
type AppData = ApplicationContext; type AppData = ApplicationContext;
type Error = ();
async fn run( async fn run(
&self, &self,
task_info: CurrentTask, task_info: CurrentTask,
app_context: Self::AppData, app_context: Self::AppData,
) -> Result<(), anyhow::Error> { ) -> Result<(), Self::Error> {
self.run(task_info, app_context).await
}
}
#[async_trait]
impl MyAppTask for GreetingTask {
const TASK_NAME: &'static str = "my_task";
async fn run(
&self,
task_info: CurrentTask,
app_context: ApplicationContext,
) -> Result<(), ()> {
println!( println!(
"[{}] Hello {}! I'm {}.", "[{}] Hello {}! I'm {}.",
task_info.id(), task_info.id(),
@ -292,12 +310,9 @@ mod tests {
const QUEUE: &'static str = "other_queue"; const QUEUE: &'static str = "other_queue";
type AppData = ApplicationContext; type AppData = ApplicationContext;
type Error = ();
async fn run( async fn run(&self, task: CurrentTask, context: Self::AppData) -> Result<(), Self::Error> {
&self,
task: CurrentTask,
context: Self::AppData,
) -> Result<(), anyhow::Error> {
println!( println!(
"[{}] Other task with {}!", "[{}] Other task with {}!",
task.id(), task.id(),
@ -311,7 +326,7 @@ mod tests {
async fn validate_all_registered_tasks_queues_are_configured() { async fn validate_all_registered_tasks_queues_are_configured() {
let my_app_context = ApplicationContext::new(); let my_app_context = ApplicationContext::new();
let result = WorkerPool::new(memory_store().await, move |_| my_app_context.clone()) let result = WorkerPool::new(memory_store(), move || my_app_context.clone())
.register_task_type::<GreetingTask>() .register_task_type::<GreetingTask>()
.start(futures::future::ready(())) .start(futures::future::ready(()))
.await; .await;
@ -329,14 +344,16 @@ mod tests {
async fn test_worker_pool_with_task() { async fn test_worker_pool_with_task() {
let my_app_context = ApplicationContext::new(); let my_app_context = ApplicationContext::new();
let (join_handle, queue) = let task_store = memory_store();
WorkerPool::new(memory_store().await, move |_| my_app_context.clone())
.register_task_type::<GreetingTask>()
.configure_queue(GreetingTask::QUEUE.into())
.start(futures::future::ready(()))
.await
.unwrap();
let join_handle = WorkerPool::new(task_store.clone(), move || my_app_context.clone())
.register_task_type::<GreetingTask>()
.configure_queue(<GreetingTask as MyAppTask>::QUEUE.into())
.start(futures::future::ready(()))
.await
.unwrap();
let queue = Queue::new(task_store);
queue queue
.enqueue(GreetingTask { .enqueue(GreetingTask {
person: "Rafael".to_string(), person: "Rafael".to_string(),
@ -351,16 +368,17 @@ mod tests {
async fn test_worker_pool_with_multiple_task_types() { async fn test_worker_pool_with_multiple_task_types() {
let my_app_context = ApplicationContext::new(); let my_app_context = ApplicationContext::new();
let (join_handle, queue) = let task_store = memory_store();
WorkerPool::new(memory_store().await, move |_| my_app_context.clone()) let join_handle = WorkerPool::new(task_store.clone(), move || my_app_context.clone())
.register_task_type::<GreetingTask>() .register_task_type::<GreetingTask>()
.register_task_type::<OtherTask>() .register_task_type::<OtherTask>()
.configure_queue("default".into()) .configure_queue("default".into())
.configure_queue("other_queue".into()) .configure_queue("other_queue".into())
.start(futures::future::ready(())) .start(futures::future::ready(()))
.await .await
.unwrap(); .unwrap();
let queue = Queue::new(task_store.clone());
queue queue
.enqueue(GreetingTask { .enqueue(GreetingTask {
person: "Rafael".to_string(), person: "Rafael".to_string(),
@ -391,11 +409,9 @@ mod tests {
type AppData = NotifyFinishedContext; type AppData = NotifyFinishedContext;
async fn run( type Error = ();
&self,
task: CurrentTask, async fn run(&self, task: CurrentTask, context: Self::AppData) -> Result<(), ()> {
context: Self::AppData,
) -> Result<(), anyhow::Error> {
// Notify the test that the task ran // Notify the test that the task ran
match context.notify_finished.lock().await.take() { match context.notify_finished.lock().await.take() {
None => println!("Cannot notify, already done that!"), None => println!("Cannot notify, already done that!"),
@ -414,17 +430,19 @@ mod tests {
notify_finished: Arc::new(Mutex::new(Some(tx))), notify_finished: Arc::new(Mutex::new(Some(tx))),
}; };
let (join_handle, queue) = let memory_store = memory_store();
WorkerPool::new(memory_store().await, move |_| my_app_context.clone())
.register_task_type::<NotifyFinished>()
.configure_queue("default".into())
.start(async move {
rx.await.unwrap();
println!("Worker pool got notified to stop");
})
.await
.unwrap();
let join_handle = WorkerPool::new(memory_store.clone(), move || my_app_context.clone())
.register_task_type::<NotifyFinished>()
.configure_queue("default".into())
.start(async move {
rx.await.unwrap();
println!("Worker pool got notified to stop");
})
.await
.unwrap();
let queue = Queue::new(memory_store);
// Notifies the worker pool to stop after the task is executed // Notifies the worker pool to stop after the task is executed
queue.enqueue(NotifyFinished).await.unwrap(); queue.enqueue(NotifyFinished).await.unwrap();
@ -455,11 +473,13 @@ mod tests {
type AppData = NotifyUnknownRanContext; type AppData = NotifyUnknownRanContext;
type Error = ();
async fn run( async fn run(
&self, &self,
task: CurrentTask, task: CurrentTask,
context: Self::AppData, context: Self::AppData,
) -> Result<(), anyhow::Error> { ) -> Result<(), Self::Error> {
// Notify the test that the task ran // Notify the test that the task ran
match context.should_stop.lock().await.take() { match context.should_stop.lock().await.take() {
None => println!("Cannot notify, already done that!"), None => println!("Cannot notify, already done that!"),
@ -481,11 +501,9 @@ mod tests {
type AppData = NotifyUnknownRanContext; type AppData = NotifyUnknownRanContext;
async fn run( type Error = ();
&self,
task: CurrentTask, async fn run(&self, task: CurrentTask, context: Self::AppData) -> Result<(), ()> {
context: Self::AppData,
) -> Result<(), anyhow::Error> {
println!("[{}] Unknown task ran!", task.id()); println!("[{}] Unknown task ran!", task.id());
context.unknown_task_ran.store(true, Ordering::Relaxed); context.unknown_task_ran.store(true, Ordering::Relaxed);
Ok(()) Ok(())
@ -499,11 +517,11 @@ mod tests {
unknown_task_ran: Arc::new(AtomicBool::new(false)), unknown_task_ran: Arc::new(AtomicBool::new(false)),
}; };
let task_store = memory_store().await; let task_store = memory_store();
let (join_handle, queue) = WorkerPool::new(task_store, { let join_handle = WorkerPool::new(task_store.clone(), {
let my_app_context = my_app_context.clone(); let my_app_context = my_app_context.clone();
move |_| my_app_context.clone() move || my_app_context.clone()
}) })
.register_task_type::<NotifyStopDuringRun>() .register_task_type::<NotifyStopDuringRun>()
.configure_queue("default".into()) .configure_queue("default".into())
@ -514,6 +532,7 @@ mod tests {
.await .await
.unwrap(); .unwrap();
let queue = Queue::new(task_store);
// Enqueue a task that is not registered // Enqueue a task that is not registered
queue.enqueue(UnknownTask).await.unwrap(); queue.enqueue(UnknownTask).await.unwrap();
@ -537,21 +556,18 @@ mod tests {
impl BackgroundTask for BrokenTask { impl BackgroundTask for BrokenTask {
const TASK_NAME: &'static str = "panic_me"; const TASK_NAME: &'static str = "panic_me";
type AppData = (); type AppData = ();
type Error = ();
async fn run( async fn run(&self, _task: CurrentTask, _context: Self::AppData) -> Result<(), ()> {
&self,
_task: CurrentTask,
_context: Self::AppData,
) -> Result<(), anyhow::Error> {
panic!("Oh no!"); panic!("Oh no!");
} }
} }
let (notify_stop_worker_pool, should_stop) = tokio::sync::oneshot::channel(); let (notify_stop_worker_pool, should_stop) = tokio::sync::oneshot::channel();
let task_store = memory_store().await; let task_store = memory_store();
let (worker_pool_finished, queue) = WorkerPool::new(task_store.clone(), |_| ()) let worker_pool_finished = WorkerPool::new(task_store.clone(), || ())
.register_task_type::<BrokenTask>() .register_task_type::<BrokenTask>()
.configure_queue("default".into()) .configure_queue("default".into())
.start(async move { .start(async move {
@ -560,6 +576,7 @@ mod tests {
.await .await
.unwrap(); .unwrap();
let queue = Queue::new(task_store.clone());
// Enqueue a task that will panic // Enqueue a task that will panic
queue.enqueue(BrokenTask).await.unwrap(); queue.enqueue(BrokenTask).await.unwrap();
@ -609,11 +626,13 @@ mod tests {
type AppData = PlayerContext; type AppData = PlayerContext;
type Error = ();
async fn run( async fn run(
&self, &self,
_task: CurrentTask, _task: CurrentTask,
context: Self::AppData, context: Self::AppData,
) -> Result<(), anyhow::Error> { ) -> Result<(), Self::Error> {
loop { loop {
let msg = context.ping_rx.lock().await.recv().await.unwrap(); let msg = context.ping_rx.lock().await.recv().await.unwrap();
match msg { match msg {
@ -643,11 +662,11 @@ mod tests {
ping_rx: Arc::new(Mutex::new(ping_rx)), ping_rx: Arc::new(Mutex::new(ping_rx)),
}; };
let task_store = memory_store().await; let task_store = memory_store();
let (worker_pool_finished, queue) = WorkerPool::new(task_store, { let worker_pool_finished = WorkerPool::new(task_store.clone(), {
let player_context = player_context.clone(); let player_context = player_context.clone();
move |_| player_context.clone() move || player_context.clone()
}) })
.register_task_type::<KeepAliveTask>() .register_task_type::<KeepAliveTask>()
.configure_queue("default".into()) .configure_queue("default".into())
@ -658,6 +677,7 @@ mod tests {
.await .await
.unwrap(); .unwrap();
let queue = Queue::new(task_store);
queue.enqueue(KeepAliveTask).await.unwrap(); queue.enqueue(KeepAliveTask).await.unwrap();
// Make sure task is running // Make sure task is running
@ -683,7 +703,7 @@ mod tests {
ping_tx.send(PingPongGame::StopThisNow).await.unwrap(); ping_tx.send(PingPongGame::StopThisNow).await.unwrap();
} }
async fn memory_store() -> MemoryTaskStore { fn memory_store() -> MemoryTaskStore {
MemoryTaskStore::default() MemoryTaskStore::default()
} }
@ -692,15 +712,15 @@ mod tests {
async fn test_worker_pool_with_pg_store() { async fn test_worker_pool_with_pg_store() {
let my_app_context = ApplicationContext::new(); let my_app_context = ApplicationContext::new();
let (join_handle, _queue) = let join_handle = WorkerPool::new(pg_task_store().await, move || my_app_context.clone())
WorkerPool::new(pg_task_store().await, move |_| my_app_context.clone()) .register_task_type::<GreetingTask>()
.register_task_type::<GreetingTask>() .configure_queue(
.configure_queue( QueueConfig::new(<GreetingTask as MyAppTask>::QUEUE)
QueueConfig::new(GreetingTask::QUEUE).retention_mode(RetentionMode::RemoveDone), .retention_mode(RetentionMode::RemoveDone),
) )
.start(futures::future::ready(())) .start(futures::future::ready(()))
.await .await
.unwrap(); .unwrap();
join_handle.await.unwrap(); join_handle.await.unwrap();
} }