Compare commits

...

10 commits

12 changed files with 348 additions and 191 deletions

.gitignore vendored (1 change)
View file

@ -2,3 +2,4 @@
Cargo.lock
docs/content/docs/CHANGELOG.md
docs/content/docs/README.md
.DS_Store

View file

@ -1,6 +1,6 @@
[package]
name = "backie"
version = "0.3.0"
version = "0.6.0"
authors = [
"Rafael Caricio <rafael@caricio.com>",
]
@ -17,7 +17,6 @@ chrono = "0.4"
log = "0.4"
serde = { version = "1", features = ["derive"] }
serde_json = "1"
anyhow = "1"
thiserror = "1"
uuid = { version = "1.1", features = ["v4", "serde"] }
async-trait = "0.1"

View file

@ -1,5 +1,6 @@
# Backie 🚲
<p align="center"><img src="logo.png" alt="Backie" width="400"></p>
---
Async persistent background task processing for Rust applications with Tokio. Queue asynchronous tasks
to be processed by workers. It's designed to be easy to use and horizontally scalable. It uses Postgres as
a storage backend and can also be extended to support other types of storage.
@ -31,6 +32,25 @@ Here are some of the Backie's key features:
- Task timeout: Tasks are retried if they are not completed in time
- Scheduling of tasks: Tasks can be scheduled to be executed at a specific time
## Task execution protocol
The following diagram shows the protocol used to execute tasks:
```mermaid
stateDiagram-v2
[*] --> Ready
Ready --> Running: Task is picked up by a worker
Running --> Done: Task is finished
Running --> Failed: Task failed
Failed --> Ready: Task is retried
Failed --> [*]: Task is not retried anymore, max retries reached
Done --> [*]
```
When a task goes from `Running` to `Failed`, it is retried. The number of retries is controlled by the
[`BackgroundTask::MAX_RETRIES`] constant; the default implementation uses `3` retries.
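For illustration, an individual task type can override this constant. The sketch below assumes the `BackgroundTask` shape shown later in this changeset; `FlakyTask` is a hypothetical example, not part of the crate:
```rust
use async_trait::async_trait;
use backie::{BackgroundTask, CurrentTask};
use serde::{Deserialize, Serialize};

#[derive(Serialize, Deserialize)]
struct FlakyTask {
    url: String,
}

#[async_trait]
impl BackgroundTask for FlakyTask {
    const TASK_NAME: &'static str = "flaky_task";
    // Allow up to 10 attempts before the task reaches the terminal Failed state.
    const MAX_RETRIES: i32 = 10;
    type AppData = ();
    type Error = Box<dyn std::error::Error + Send + Sync + 'static>;

    async fn run(&self, _task: CurrentTask, _ctx: Self::AppData) -> Result<(), Self::Error> {
        // Work that may fail transiently; a failed run sends the task back to Ready
        // until the retry budget is exhausted.
        Ok(())
    }
}
```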
## Safety
This crate uses `#![forbid(unsafe_code)]` to ensure everything is implemented in 100% safe Rust.
@ -53,7 +73,6 @@ If you are not already using, you will also want to include the following depend
```toml
[dependencies]
async-trait = "0.1"
anyhow = "1"
serde = { version = "1.0", features = ["derive"] }
diesel = { version = "2.0", features = ["postgres", "serde_json", "chrono", "uuid"] }
diesel-async = { version = "0.2", features = ["postgres", "bb8"] }
@ -75,6 +94,9 @@ the whole application. This attribute is critical for reconstructing the task ba
The [`BackgroundTask::AppData`] can be used to augment the task with your application-specific contextual information.
This is useful, for example, to pass a database connection pool or other application configuration to the task.
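As a sketch of that idea (the type and field names here are hypothetical; the runnable example elsewhere in this changeset uses a similar `MyApplicationContext`):
```rust
use diesel_async::pg::AsyncPgConnection;
use diesel_async::pooled_connection::bb8::Pool;

// Hypothetical application context handed to tasks through `AppData`.
// It must satisfy the trait bound `Clone + Send + 'static`.
#[derive(Clone)]
pub struct MyContext {
    pub db_pool: Pool<AsyncPgConnection>,
    pub base_url: String,
}
```
A task would then declare `type AppData = MyContext;` and receive a clone of this context in `run`.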
The [`BackgroundTask::Error`] is the error type that will be returned by the [`BackgroundTask::run`] method. You can
use this to define your own error type for your tasks.
The [`BackgroundTask::run`] method is where you define the behaviour of your background task execution. This method
will be called by the task queue workers.
@ -92,8 +114,9 @@ pub struct MyTask {
impl BackgroundTask for MyTask {
const TASK_NAME: &'static str = "my_task_unique_name";
type AppData = ();
type Error = ();
async fn run(&self, task: CurrentTask, context: Self::AppData) -> Result<(), anyhow::Error> {
async fn run(&self, task: CurrentTask, context: Self::AppData) -> Result<(), Self::Error> {
// Do something
Ok(())
}

View file

@ -1,20 +1,33 @@
use async_trait::async_trait;
use backie::{BackgroundTask, CurrentTask};
use backie::{BackgroundTask, CurrentTask, QueueConfig, RetentionMode};
use backie::{PgTaskStore, Queue, WorkerPool};
use diesel_async::pg::AsyncPgConnection;
use diesel_async::pooled_connection::{bb8::Pool, AsyncDieselConnectionManager};
use serde::{Deserialize, Serialize};
use std::time::Duration;
use std::collections::BTreeMap;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::Mutex;
use tokio::task::JoinSet;
#[derive(Clone, Debug)]
pub struct MyApplicationContext {
app_name: String,
notify_finished: Arc<Mutex<Option<tokio::sync::oneshot::Sender<()>>>>,
}
impl MyApplicationContext {
pub fn new(app_name: &str) -> Self {
pub fn new(app_name: &str, notify_finished: tokio::sync::oneshot::Sender<()>) -> Self {
Self {
app_name: app_name.to_string(),
notify_finished: Arc::new(Mutex::new(Some(notify_finished))),
}
}
pub async fn notify_finished(&self) {
let mut lock = self.notify_finished.lock().await;
if let Some(sender) = lock.take() {
sender.send(()).unwrap();
}
}
}
@ -34,14 +47,9 @@ impl MyTask {
impl BackgroundTask for MyTask {
const TASK_NAME: &'static str = "my_task";
type AppData = MyApplicationContext;
type Error = anyhow::Error;
async fn run(&self, task: CurrentTask, ctx: Self::AppData) -> Result<(), anyhow::Error> {
// let new_task = MyTask::new(self.number + 1);
// queue
// .insert_task(&new_task)
// .await
// .unwrap();
async fn run(&self, task: CurrentTask, ctx: Self::AppData) -> Result<(), Self::Error> {
log::info!(
"[{}] Hello from {}! the current number is {}",
task.id(),
@ -70,19 +78,9 @@ impl MyFailingTask {
impl BackgroundTask for MyFailingTask {
const TASK_NAME: &'static str = "my_failing_task";
type AppData = MyApplicationContext;
type Error = anyhow::Error;
async fn run(&self, task: CurrentTask, _ctx: Self::AppData) -> Result<(), anyhow::Error> {
// let new_task = MyFailingTask::new(self.number + 1);
// queue
// .insert_task(&new_task)
// .await
// .unwrap();
// task.id();
// task.keep_alive().await?;
// task.previous_error();
// task.retry_count();
async fn run(&self, task: CurrentTask, _ctx: Self::AppData) -> Result<(), Self::Error> {
log::info!("[{}] the current number is {}", task.id(), self.number);
tokio::time::sleep(Duration::from_secs(3)).await;
@ -91,58 +89,124 @@ impl BackgroundTask for MyFailingTask {
}
}
#[derive(Serialize, Deserialize)]
struct EmptyTask {
pub idx: u64,
}
#[async_trait]
impl BackgroundTask for EmptyTask {
const TASK_NAME: &'static str = "empty_task";
const QUEUE: &'static str = "loaded_queue";
type AppData = MyApplicationContext;
type Error = anyhow::Error;
async fn run(&self, _task: CurrentTask, _ctx: Self::AppData) -> Result<(), Self::Error> {
Ok(())
}
}
#[derive(Serialize, Deserialize)]
struct FinalTask;
#[async_trait]
impl BackgroundTask for FinalTask {
const TASK_NAME: &'static str = "final_task";
const QUEUE: &'static str = "loaded_queue";
type AppData = MyApplicationContext;
type Error = anyhow::Error;
async fn run(&self, _task: CurrentTask, ctx: Self::AppData) -> Result<(), Self::Error> {
ctx.notify_finished().await;
Ok(())
}
}
#[tokio::main]
async fn main() {
async fn main() -> anyhow::Result<()> {
env_logger::init();
let connection_url = "postgres://postgres:password@localhost/backie";
log::info!("Starting...");
let max_pool_size: u32 = 3;
let manager = AsyncDieselConnectionManager::<AsyncPgConnection>::new(connection_url);
let pool = Pool::builder()
.max_size(max_pool_size)
.max_size(300)
.min_idle(Some(1))
.build(manager)
.await
.unwrap();
log::info!("Pool created ...");
let task_store = PgTaskStore::new(pool);
let (tx, mut rx) = tokio::sync::watch::channel(false);
let (notify_finished, wait_done) = tokio::sync::oneshot::channel();
// Some global application context I want to pass to my background tasks
let my_app_context = MyApplicationContext::new("Backie Example App");
let my_app_context = MyApplicationContext::new("Backie Example App", notify_finished);
// queue.enqueue(task1).await.unwrap();
// queue.enqueue(task2).await.unwrap();
// queue.enqueue(task3).await.unwrap();
// Store all spawned enqueue tasks so they can be joined later
let mut tasks = JoinSet::new();
for i in 0..1_000 {
tasks.spawn({
let pool = pool.clone();
async move {
let mut connection = pool.get().await.unwrap();
let task = EmptyTask { idx: i };
task.enqueue(&mut connection).await.unwrap();
}
});
}
while let Some(result) = tasks.join_next().await {
let _ = result?;
}
(FinalTask {})
.enqueue(&mut pool.get().await.unwrap())
.await
.unwrap();
log::info!("Tasks created ...");
let started = Instant::now();
// Register the task types I want to use and start the worker pool
let (join_handle, _queue) =
WorkerPool::new(task_store.clone(), move |_| my_app_context.clone())
.register_task_type::<MyTask>()
.register_task_type::<MyFailingTask>()
.configure_queue("default".into())
.start(async move {
let _ = rx.changed().await;
})
.await
.unwrap();
let join_handle = WorkerPool::new(PgTaskStore::new(pool.clone()), move || my_app_context.clone())
.register_task_type::<MyTask>()
.register_task_type::<MyFailingTask>()
.register_task_type::<EmptyTask>()
.register_task_type::<FinalTask>()
.configure_queue("default".into())
.configure_queue(
QueueConfig::new("loaded_queue")
.pull_interval(Duration::from_millis(100))
.retention_mode(RetentionMode::RemoveDone)
.num_workers(300),
)
.start(async move {
let _ = rx.changed().await;
})
.await
.unwrap();
log::info!("Workers started ...");
let task1 = MyTask::new(0);
let task2 = MyTask::new(20_000);
let task3 = MyFailingTask::new(50_000);
let queue = Queue::new(task_store); // or use the `queue` instance returned by the worker pool
queue.enqueue(task1).await.unwrap();
queue.enqueue(task2).await.unwrap();
queue.enqueue(task3).await.unwrap();
log::info!("Tasks created ...");
wait_done.await.unwrap();
let elapsed = started.elapsed();
println!("Ran 50k jobs in {} seconds", elapsed.as_secs());
// Wait for Ctrl+C
let _ = tokio::signal::ctrl_c().await;
// let _ = tokio::signal::ctrl_c().await;
log::info!("Stopping ...");
tx.send(true).unwrap();
join_handle.await.unwrap();
log::info!("Workers Stopped!");
Ok(())
}

logo.png — new binary file (56 KiB), content not shown

View file

@ -7,7 +7,10 @@ use chrono::Duration;
use chrono::Utc;
use diesel::prelude::*;
use diesel::ExpressionMethods;
use diesel_async::{pg::AsyncPgConnection, RunQueryDsl};
use diesel::query_builder::{Query, QueryFragment, QueryId};
use diesel::sql_types::{HasSqlType, SingleValue};
use diesel_async::return_futures::GetResult;
use diesel_async::{pg::AsyncPgConnection, AsyncConnection, RunQueryDsl};
impl Task {
pub(crate) async fn remove(

View file

@ -2,34 +2,42 @@ use crate::errors::BackieError;
use crate::runnable::BackgroundTask;
use crate::store::TaskStore;
use crate::task::NewTask;
use std::sync::Arc;
use std::time::Duration;
#[derive(Clone)]
pub struct Queue<S>
where
S: TaskStore + Clone,
S: TaskStore,
{
task_store: Arc<S>,
task_store: S,
}
impl<S> Queue<S>
where
S: TaskStore + Clone,
S: TaskStore,
{
pub fn new(task_store: S) -> Self {
Queue {
task_store: Arc::new(task_store),
}
Queue { task_store }
}
pub async fn enqueue<BT>(&self, background_task: BT) -> Result<(), BackieError>
where
BT: BackgroundTask,
{
// TODO: Add option to specify the timeout of a task
self.task_store
.create_task(NewTask::new(background_task, Duration::from_secs(10))?)
.create_task(NewTask::with_timeout(background_task, Duration::from_secs(10))?)
.await?;
Ok(())
}
}
impl<S> Clone for Queue<S>
where
S: TaskStore + Clone,
{
fn clone(&self) -> Self {
Self {
task_store: self.task_store.clone(),
}
}
}

View file

@ -1,6 +1,7 @@
use crate::task::{CurrentTask, TaskHash};
use async_trait::async_trait;
use serde::{de::DeserializeOwned, ser::Serialize};
use std::fmt::Debug;
/// The [`BackgroundTask`] trait is used to define the behaviour of a task. You must implement this
/// trait for all tasks you want to execute.
@ -29,8 +30,9 @@ use serde::{de::DeserializeOwned, ser::Serialize};
/// impl BackgroundTask for MyTask {
/// const TASK_NAME: &'static str = "my_task_unique_name";
/// type AppData = ();
/// type Error = Box<dyn std::error::Error + Send + Sync + 'static>;
///
/// async fn run(&self, task: CurrentTask, context: Self::AppData) -> Result<(), anyhow::Error> {
/// async fn run(&self, task: CurrentTask, context: Self::AppData) -> Result<(), Self::Error> {
/// // Do something
/// Ok(())
/// }
@ -51,14 +53,17 @@ pub trait BackgroundTask: Serialize + DeserializeOwned + Sync + Send + 'static {
/// Number of retries for tasks.
///
/// By default, it is set to 5.
const MAX_RETRIES: i32 = 5;
/// By default, it is set to 3.
const MAX_RETRIES: i32 = 3;
/// The application data provided to this task at runtime.
type AppData: Clone + Send + 'static;
/// An application custom error type.
type Error: Debug + Send + 'static;
/// Execute the task. This method should define its logic
async fn run(&self, task: CurrentTask, context: Self::AppData) -> Result<(), anyhow::Error>;
async fn run(&self, task: CurrentTask, context: Self::AppData) -> Result<(), Self::Error>;
/// If set to true, no new tasks with the same metadata will be inserted
/// By default it is set to false.

View file

@ -1,5 +1,7 @@
use crate::errors::AsyncQueueError;
use crate::task::{NewTask, Task, TaskId, TaskState};
use crate::BackgroundTask;
use async_trait::async_trait;
use diesel::result::Error::QueryBuilderError;
use diesel_async::scoped_futures::ScopedFutureExt;
use diesel_async::AsyncConnection;
@ -17,6 +19,23 @@ impl PgTaskStore {
}
}
/// A trait that is used to enqueue tasks for the PostgreSQL backend.
#[async_trait::async_trait]
pub trait PgQueueTask {
async fn enqueue(self, connection: &mut AsyncPgConnection) -> Result<(), AsyncQueueError>;
}
impl<T> PgQueueTask for T
where
T: BackgroundTask,
{
async fn enqueue(self, connection: &mut AsyncPgConnection) -> Result<(), AsyncQueueError> {
let new_task = NewTask::new::<T>(self)?;
Task::insert(connection, new_task).await?;
Ok(())
}
}
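A usage sketch, mirroring the example file later in this changeset (`pool` is the bb8 connection pool built there, and unwraps stand in for real error handling):
```rust
// Any BackgroundTask can be enqueued directly on a pooled Postgres connection
// thanks to the blanket PgQueueTask impl above.
let mut connection = pool.get().await.unwrap();
let task = EmptyTask { idx: 1 };
task.enqueue(&mut connection).await.unwrap();
```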
#[async_trait::async_trait]
impl TaskStore for PgTaskStore {
async fn pull_next_task(

View file

@ -117,9 +117,9 @@ pub struct NewTask {
}
impl NewTask {
pub(crate) fn new<T>(background_task: T, timeout: Duration) -> Result<Self, serde_json::Error>
where
T: BackgroundTask,
pub(crate) fn with_timeout<T>(background_task: T, timeout: Duration) -> Result<Self, serde_json::Error>
where
T: BackgroundTask,
{
let max_retries = background_task.max_retries();
let uniq_hash = background_task.uniq();
@ -134,6 +134,13 @@ impl NewTask {
max_retries,
})
}
pub(crate) fn new<T>(background_task: T) -> Result<Self, serde_json::Error>
where
T: BackgroundTask,
{
Self::with_timeout(background_task, Duration::from_secs(120))
}
}
#[cfg(test)]

View file

@ -30,7 +30,7 @@ pub enum TaskExecError {
TaskDeserializationFailed(#[from] serde_json::Error),
#[error("Task execution failed: {0}")]
ExecutionFailed(#[from] anyhow::Error),
ExecutionFailed(String),
#[error("Task panicked with: {0}")]
Panicked(String),
@ -46,8 +46,10 @@ where
{
Box::pin(async move {
let background_task: BT = serde_json::from_value(payload)?;
background_task.run(task_info, app_context).await?;
Ok(())
match background_task.run(task_info, app_context).await {
Ok(_) => Ok(()),
Err(err) => Err(TaskExecError::ExecutionFailed(format!("{:?}", err))),
}
})
}
@ -57,7 +59,7 @@ where
AppData: Clone + Send + 'static,
S: TaskStore + Clone,
{
store: Arc<S>,
store: S,
queue_name: String,
@ -79,7 +81,7 @@ where
S: TaskStore + Clone,
{
pub(crate) fn new(
store: Arc<S>,
store: S,
queue_name: String,
retention_mode: RetentionMode,
pull_interval: Duration,
@ -250,8 +252,9 @@ mod async_worker_tests {
impl BackgroundTask for WorkerAsyncTask {
const TASK_NAME: &'static str = "WorkerAsyncTask";
type AppData = ();
type Error = ();
async fn run(&self, _: CurrentTask, _: Self::AppData) -> Result<(), anyhow::Error> {
async fn run(&self, _: CurrentTask, _: Self::AppData) -> Result<(), ()> {
Ok(())
}
}
@ -265,8 +268,9 @@ mod async_worker_tests {
impl BackgroundTask for WorkerAsyncTaskSchedule {
const TASK_NAME: &'static str = "WorkerAsyncTaskSchedule";
type AppData = ();
type Error = ();
async fn run(&self, _task: CurrentTask, _data: Self::AppData) -> Result<(), anyhow::Error> {
async fn run(&self, _task: CurrentTask, _data: Self::AppData) -> Result<(), ()> {
Ok(())
}
@ -284,11 +288,12 @@ mod async_worker_tests {
impl BackgroundTask for AsyncFailedTask {
const TASK_NAME: &'static str = "AsyncFailedTask";
type AppData = ();
type Error = TaskError;
async fn run(&self, _task: CurrentTask, _data: Self::AppData) -> Result<(), anyhow::Error> {
async fn run(&self, _task: CurrentTask, _data: Self::AppData) -> Result<(), TaskError> {
let message = format!("number {} is wrong :(", self.number);
Err(TaskError::Custom(message).into())
Err(TaskError::Custom(message))
}
fn max_retries(&self) -> i32 {
@ -303,9 +308,10 @@ mod async_worker_tests {
impl BackgroundTask for AsyncRetryTask {
const TASK_NAME: &'static str = "AsyncRetryTask";
type AppData = ();
type Error = TaskError;
async fn run(&self, _task: CurrentTask, _data: Self::AppData) -> Result<(), anyhow::Error> {
Err(TaskError::SomethingWrong.into())
async fn run(&self, _task: CurrentTask, _data: Self::AppData) -> Result<(), Self::Error> {
Err(TaskError::SomethingWrong)
}
}
@ -316,8 +322,9 @@ mod async_worker_tests {
impl BackgroundTask for AsyncTaskType1 {
const TASK_NAME: &'static str = "AsyncTaskType1";
type AppData = ();
type Error = ();
async fn run(&self, _task: CurrentTask, _data: Self::AppData) -> Result<(), anyhow::Error> {
async fn run(&self, _task: CurrentTask, _data: Self::AppData) -> Result<(), Self::Error> {
Ok(())
}
}
@ -329,8 +336,9 @@ mod async_worker_tests {
impl BackgroundTask for AsyncTaskType2 {
const TASK_NAME: &'static str = "AsyncTaskType2";
type AppData = ();
type Error = ();
async fn run(&self, _task: CurrentTask, _data: Self::AppData) -> Result<(), anyhow::Error> {
async fn run(&self, _task: CurrentTask, _data: Self::AppData) -> Result<(), ()> {
Ok(())
}
}

View file

@ -1,5 +1,4 @@
use crate::errors::BackieError;
use crate::queue::Queue;
use crate::runnable::BackgroundTask;
use crate::store::TaskStore;
use crate::worker::{runnable, ExecuteTaskFn};
@ -19,10 +18,7 @@ where
S: TaskStore + Clone,
{
/// Storage of tasks.
task_store: Arc<S>,
/// Queue used to spawn tasks.
queue: Queue<S>,
task_store: S,
/// Make possible to load the application data.
///
@ -49,16 +45,10 @@ where
/// Create a new worker pool.
pub fn new<A>(task_store: S, application_data_fn: A) -> Self
where
A: Fn(Queue<S>) -> AppData + Send + Sync + 'static,
A: Fn() -> AppData + Send + Sync + 'static,
{
let queue = Queue::new(task_store.clone());
let application_data_fn = {
let queue = queue.clone();
move || application_data_fn(queue.clone())
};
Self {
task_store: Arc::new(task_store),
queue,
task_store,
application_data_fn: Arc::new(application_data_fn),
task_registry: BTreeMap::new(),
queue_tasks: BTreeMap::new(),
@ -85,10 +75,7 @@ where
self
}
pub async fn start<F>(
self,
graceful_shutdown: F,
) -> Result<(JoinHandle<()>, Queue<S>), BackieError>
pub async fn start<F>(self, graceful_shutdown: F) -> Result<JoinHandle<()>, BackieError>
where
F: Future<Output = ()> + Send + 'static,
{
@ -127,28 +114,25 @@ where
}
}
Ok((
tokio::spawn(async move {
graceful_shutdown.await;
if let Err(err) = tx.send(()) {
log::warn!("Failed to send shutdown signal to worker pool: {}", err);
Ok(tokio::spawn(async move {
graceful_shutdown.await;
if let Err(err) = tx.send(()) {
log::warn!("Failed to send shutdown signal to worker pool: {}", err);
} else {
// Wait for all workers to finish processing
let results = join_all(worker_handles)
.await
.into_iter()
.filter(Result::is_err)
.map(Result::unwrap_err)
.collect::<Vec<_>>();
if !results.is_empty() {
log::error!("Worker pool stopped with errors: {:?}", results);
} else {
// Wait for all workers to finish processing
let results = join_all(worker_handles)
.await
.into_iter()
.filter(Result::is_err)
.map(Result::unwrap_err)
.collect::<Vec<_>>();
if !results.is_empty() {
log::error!("Worker pool stopped with errors: {:?}", results);
} else {
log::info!("Worker pool stopped gracefully");
}
log::info!("Worker pool stopped gracefully");
}
}),
self.queue,
))
}
}))
}
}
@ -232,6 +216,7 @@ mod tests {
use crate::store::test_store::MemoryTaskStore;
use crate::store::PgTaskStore;
use crate::task::CurrentTask;
use crate::Queue;
use async_trait::async_trait;
use diesel_async::pooled_connection::{bb8::Pool, AsyncDieselConnectionManager};
use diesel_async::AsyncPgConnection;
@ -240,7 +225,7 @@ mod tests {
use tokio::sync::Mutex;
#[derive(Clone, Debug)]
struct ApplicationContext {
pub struct ApplicationContext {
app_name: String,
}
@ -261,17 +246,50 @@ mod tests {
person: String,
}
/// This tests that one can customize the task parameters for the application.
#[async_trait]
impl BackgroundTask for GreetingTask {
const TASK_NAME: &'static str = "my_task";
trait MyAppTask {
const TASK_NAME: &'static str;
const QUEUE: &'static str = "default";
async fn run(
&self,
task_info: CurrentTask,
app_context: ApplicationContext,
) -> Result<(), ()>;
}
#[async_trait]
impl<T> BackgroundTask for T
where
T: MyAppTask + serde::de::DeserializeOwned + serde::ser::Serialize + Sync + Send + 'static,
{
const TASK_NAME: &'static str = T::TASK_NAME;
const QUEUE: &'static str = T::QUEUE;
type AppData = ApplicationContext;
type Error = ();
async fn run(
&self,
task_info: CurrentTask,
app_context: Self::AppData,
) -> Result<(), anyhow::Error> {
) -> Result<(), Self::Error> {
self.run(task_info, app_context).await
}
}
#[async_trait]
impl MyAppTask for GreetingTask {
const TASK_NAME: &'static str = "my_task";
async fn run(
&self,
task_info: CurrentTask,
app_context: ApplicationContext,
) -> Result<(), ()> {
println!(
"[{}] Hello {}! I'm {}.",
task_info.id(),
@ -292,12 +310,9 @@ mod tests {
const QUEUE: &'static str = "other_queue";
type AppData = ApplicationContext;
type Error = ();
async fn run(
&self,
task: CurrentTask,
context: Self::AppData,
) -> Result<(), anyhow::Error> {
async fn run(&self, task: CurrentTask, context: Self::AppData) -> Result<(), Self::Error> {
println!(
"[{}] Other task with {}!",
task.id(),
@ -311,7 +326,7 @@ mod tests {
async fn validate_all_registered_tasks_queues_are_configured() {
let my_app_context = ApplicationContext::new();
let result = WorkerPool::new(memory_store().await, move |_| my_app_context.clone())
let result = WorkerPool::new(memory_store(), move || my_app_context.clone())
.register_task_type::<GreetingTask>()
.start(futures::future::ready(()))
.await;
@ -329,14 +344,16 @@ mod tests {
async fn test_worker_pool_with_task() {
let my_app_context = ApplicationContext::new();
let (join_handle, queue) =
WorkerPool::new(memory_store().await, move |_| my_app_context.clone())
.register_task_type::<GreetingTask>()
.configure_queue(GreetingTask::QUEUE.into())
.start(futures::future::ready(()))
.await
.unwrap();
let task_store = memory_store();
let join_handle = WorkerPool::new(task_store.clone(), move || my_app_context.clone())
.register_task_type::<GreetingTask>()
.configure_queue(<GreetingTask as MyAppTask>::QUEUE.into())
.start(futures::future::ready(()))
.await
.unwrap();
let queue = Queue::new(task_store);
queue
.enqueue(GreetingTask {
person: "Rafael".to_string(),
@ -351,16 +368,17 @@ mod tests {
async fn test_worker_pool_with_multiple_task_types() {
let my_app_context = ApplicationContext::new();
let (join_handle, queue) =
WorkerPool::new(memory_store().await, move |_| my_app_context.clone())
.register_task_type::<GreetingTask>()
.register_task_type::<OtherTask>()
.configure_queue("default".into())
.configure_queue("other_queue".into())
.start(futures::future::ready(()))
.await
.unwrap();
let task_store = memory_store();
let join_handle = WorkerPool::new(task_store.clone(), move || my_app_context.clone())
.register_task_type::<GreetingTask>()
.register_task_type::<OtherTask>()
.configure_queue("default".into())
.configure_queue("other_queue".into())
.start(futures::future::ready(()))
.await
.unwrap();
let queue = Queue::new(task_store.clone());
queue
.enqueue(GreetingTask {
person: "Rafael".to_string(),
@ -391,11 +409,9 @@ mod tests {
type AppData = NotifyFinishedContext;
async fn run(
&self,
task: CurrentTask,
context: Self::AppData,
) -> Result<(), anyhow::Error> {
type Error = ();
async fn run(&self, task: CurrentTask, context: Self::AppData) -> Result<(), ()> {
// Notify the test that the task ran
match context.notify_finished.lock().await.take() {
None => println!("Cannot notify, already done that!"),
@ -414,17 +430,19 @@ mod tests {
notify_finished: Arc::new(Mutex::new(Some(tx))),
};
let (join_handle, queue) =
WorkerPool::new(memory_store().await, move |_| my_app_context.clone())
.register_task_type::<NotifyFinished>()
.configure_queue("default".into())
.start(async move {
rx.await.unwrap();
println!("Worker pool got notified to stop");
})
.await
.unwrap();
let memory_store = memory_store();
let join_handle = WorkerPool::new(memory_store.clone(), move || my_app_context.clone())
.register_task_type::<NotifyFinished>()
.configure_queue("default".into())
.start(async move {
rx.await.unwrap();
println!("Worker pool got notified to stop");
})
.await
.unwrap();
let queue = Queue::new(memory_store);
// Notifies the worker pool to stop after the task is executed
queue.enqueue(NotifyFinished).await.unwrap();
@ -455,11 +473,13 @@ mod tests {
type AppData = NotifyUnknownRanContext;
type Error = ();
async fn run(
&self,
task: CurrentTask,
context: Self::AppData,
) -> Result<(), anyhow::Error> {
) -> Result<(), Self::Error> {
// Notify the test that the task ran
match context.should_stop.lock().await.take() {
None => println!("Cannot notify, already done that!"),
@ -481,11 +501,9 @@ mod tests {
type AppData = NotifyUnknownRanContext;
async fn run(
&self,
task: CurrentTask,
context: Self::AppData,
) -> Result<(), anyhow::Error> {
type Error = ();
async fn run(&self, task: CurrentTask, context: Self::AppData) -> Result<(), ()> {
println!("[{}] Unknown task ran!", task.id());
context.unknown_task_ran.store(true, Ordering::Relaxed);
Ok(())
@ -499,11 +517,11 @@ mod tests {
unknown_task_ran: Arc::new(AtomicBool::new(false)),
};
let task_store = memory_store().await;
let task_store = memory_store();
let (join_handle, queue) = WorkerPool::new(task_store, {
let join_handle = WorkerPool::new(task_store.clone(), {
let my_app_context = my_app_context.clone();
move |_| my_app_context.clone()
move || my_app_context.clone()
})
.register_task_type::<NotifyStopDuringRun>()
.configure_queue("default".into())
@ -514,6 +532,7 @@ mod tests {
.await
.unwrap();
let queue = Queue::new(task_store);
// Enqueue a task that is not registered
queue.enqueue(UnknownTask).await.unwrap();
@ -537,21 +556,18 @@ mod tests {
impl BackgroundTask for BrokenTask {
const TASK_NAME: &'static str = "panic_me";
type AppData = ();
type Error = ();
async fn run(
&self,
_task: CurrentTask,
_context: Self::AppData,
) -> Result<(), anyhow::Error> {
async fn run(&self, _task: CurrentTask, _context: Self::AppData) -> Result<(), ()> {
panic!("Oh no!");
}
}
let (notify_stop_worker_pool, should_stop) = tokio::sync::oneshot::channel();
let task_store = memory_store().await;
let task_store = memory_store();
let (worker_pool_finished, queue) = WorkerPool::new(task_store.clone(), |_| ())
let worker_pool_finished = WorkerPool::new(task_store.clone(), || ())
.register_task_type::<BrokenTask>()
.configure_queue("default".into())
.start(async move {
@ -560,6 +576,7 @@ mod tests {
.await
.unwrap();
let queue = Queue::new(task_store.clone());
// Enqueue a task that will panic
queue.enqueue(BrokenTask).await.unwrap();
@ -609,11 +626,13 @@ mod tests {
type AppData = PlayerContext;
type Error = ();
async fn run(
&self,
_task: CurrentTask,
context: Self::AppData,
) -> Result<(), anyhow::Error> {
) -> Result<(), Self::Error> {
loop {
let msg = context.ping_rx.lock().await.recv().await.unwrap();
match msg {
@ -643,11 +662,11 @@ mod tests {
ping_rx: Arc::new(Mutex::new(ping_rx)),
};
let task_store = memory_store().await;
let task_store = memory_store();
let (worker_pool_finished, queue) = WorkerPool::new(task_store, {
let worker_pool_finished = WorkerPool::new(task_store.clone(), {
let player_context = player_context.clone();
move |_| player_context.clone()
move || player_context.clone()
})
.register_task_type::<KeepAliveTask>()
.configure_queue("default".into())
@ -658,6 +677,7 @@ mod tests {
.await
.unwrap();
let queue = Queue::new(task_store);
queue.enqueue(KeepAliveTask).await.unwrap();
// Make sure task is running
@ -683,7 +703,7 @@ mod tests {
ping_tx.send(PingPongGame::StopThisNow).await.unwrap();
}
async fn memory_store() -> MemoryTaskStore {
fn memory_store() -> MemoryTaskStore {
MemoryTaskStore::default()
}
@ -692,15 +712,15 @@ mod tests {
async fn test_worker_pool_with_pg_store() {
let my_app_context = ApplicationContext::new();
let (join_handle, _queue) =
WorkerPool::new(pg_task_store().await, move |_| my_app_context.clone())
.register_task_type::<GreetingTask>()
.configure_queue(
QueueConfig::new(GreetingTask::QUEUE).retention_mode(RetentionMode::RemoveDone),
)
.start(futures::future::ready(()))
.await
.unwrap();
let join_handle = WorkerPool::new(pg_task_store().await, move || my_app_context.clone())
.register_task_type::<GreetingTask>()
.configure_queue(
QueueConfig::new(<GreetingTask as MyAppTask>::QUEUE)
.retention_mode(RetentionMode::RemoveDone),
)
.start(futures::future::ready(()))
.await
.unwrap();
join_handle.await.unwrap();
}