Compare commits
21 commits
Author | SHA1 | Date | |
---|---|---|---|
912127857f | |||
f198d3983a | |||
7a9eddf9e4 | |||
5c43dacb5b | |||
617bd71bd1 | |||
ed6a784e02 | |||
e1a8eeb7de | |||
c93d38de01 | |||
aa1144e54f | |||
2b42a27b72 | |||
64e2315999 | |||
253a82fecf | |||
979294296e | |||
c99486eaa6 | |||
c07781a79b | |||
042de9261f | |||
716eeae4b1 | |||
10e01390b8 | |||
82e6ef6dac | |||
0f0a9c2238 | |||
2964dc2b88 |
16 changed files with 772 additions and 338 deletions
1
.gitignore
vendored
1
.gitignore
vendored
|
@ -2,3 +2,4 @@
|
|||
Cargo.lock
|
||||
docs/content/docs/CHANGELOG.md
|
||||
docs/content/docs/README.md
|
||||
.DS_Store
|
||||
|
|
3
.gitmodules
vendored
3
.gitmodules
vendored
|
@ -1,3 +0,0 @@
|
|||
[submodule "docs/themes/adidoks"]
|
||||
path = docs/themes/adidoks
|
||||
url = https://github.com/aaranxu/adidoks.git
|
11
Cargo.toml
11
Cargo.toml
|
@ -1,25 +1,22 @@
|
|||
[package]
|
||||
name = "backie"
|
||||
version = "0.1.0"
|
||||
version = "0.6.0"
|
||||
authors = [
|
||||
"Rafael Caricio <rafael@caricio.com>",
|
||||
]
|
||||
description = "Async persistent background task processing for Rust applications with Tokio and PostgreSQL."
|
||||
repository = "https://code.caric.io/rafaelcaricio/backie"
|
||||
description = "Background task processing for Rust applications with Tokio, Diesel, and PostgreSQL."
|
||||
keywords = ["async", "background", "task", "jobs", "queue"]
|
||||
repository = "https://github.com/rafaelcaricio/backie"
|
||||
edition = "2021"
|
||||
license = "MIT"
|
||||
readme = "README.md"
|
||||
rust-version = "1.67"
|
||||
|
||||
[lib]
|
||||
doctest = false
|
||||
|
||||
[dependencies]
|
||||
chrono = "0.4"
|
||||
log = "0.4"
|
||||
serde = { version = "1", features = ["derive"] }
|
||||
serde_json = "1"
|
||||
anyhow = "1"
|
||||
thiserror = "1"
|
||||
uuid = { version = "1.1", features = ["v4", "serde"] }
|
||||
async-trait = "0.1"
|
||||
|
|
95
README.md
95
README.md
|
@ -1,4 +1,4 @@
|
|||
# Backie 🚲
|
||||
<p align="center"><img src="logo.png" alt="Backie" width="400"></p>
|
||||
|
||||
Async persistent background task processing for Rust applications with Tokio. Queue asynchronous tasks
|
||||
to be processed by workers. It's designed to be easy to use and horizontally scalable. It uses Postgres as
|
||||
|
@ -17,19 +17,47 @@ Backie started as a fork of
|
|||
|
||||
Here are some of the Backie's key features:
|
||||
|
||||
- Async workers: Workers are started as [Tokio](https://tokio.rs/) tasks
|
||||
- Application context: Tasks can access an shared user-provided application context
|
||||
- Single-purpose workers: Tasks are stored together but workers are configured to execute only tasks of a specific queue
|
||||
- Retries: Tasks are retried with a custom backoff mode
|
||||
- Graceful shutdown: provide a future to gracefully shutdown the workers, on-the-fly tasks are not interrupted
|
||||
- Recovery of unfinished tasks: Tasks that were not finished are retried on the next worker start
|
||||
- Unique tasks: Tasks are not duplicated in the queue if they provide a unique hash
|
||||
- **Guaranteed execution**: at least one execution of a task
|
||||
- **Async workers**: Workers are started as [Tokio](https://tokio.rs/) tasks
|
||||
- **Application context**: Tasks can access an shared user-provided application context
|
||||
- **Single-purpose workers**: Tasks are stored together but workers are configured to execute only tasks of a specific queue
|
||||
- **Retries**: Tasks are retried with a custom backoff mode
|
||||
- **Graceful shutdown**: provide a future to gracefully shutdown the workers, on-the-fly tasks are not interrupted
|
||||
- **Recovery of unfinished tasks**: Tasks that were not finished are retried on the next worker start
|
||||
- **Unique tasks**: Tasks are not duplicated in the queue if they provide a unique hash
|
||||
|
||||
## Other planned features
|
||||
|
||||
- Task timeout: Tasks are retried if they are not completed in time
|
||||
- Scheduling of tasks: Tasks can be scheduled to be executed at a specific time
|
||||
|
||||
## Task execution protocol
|
||||
|
||||
The following diagram shows the protocol used to execute tasks:
|
||||
|
||||
```mermaid
|
||||
stateDiagram-v2
|
||||
[*] --> Ready
|
||||
Ready --> Running: Task is picked up by a worker
|
||||
Running --> Done: Task is finished
|
||||
Running --> Failed: Task failed
|
||||
Failed --> Ready: Task is retried
|
||||
Failed --> [*]: Task is not retried anymore, max retries reached
|
||||
Done --> [*]
|
||||
```
|
||||
|
||||
When a task goes from `Running` to `Failed` it is retried. The number of retries is controlled by the
|
||||
[`BackgroundTask::MAX_RETRIES`] attribute. The default implementation uses `3` retries.
|
||||
|
||||
|
||||
## Safety
|
||||
|
||||
This crate uses `#![forbid(unsafe_code)]` to ensure everything is implemented in 100% safe Rust.
|
||||
|
||||
## Minimum supported Rust version
|
||||
|
||||
Backie's MSRV is 1.68.
|
||||
|
||||
## Installation
|
||||
|
||||
1. Add this to your `Cargo.toml`
|
||||
|
@ -52,8 +80,6 @@ diesel-async = { version = "0.2", features = ["postgres", "bb8"] }
|
|||
Those dependencies are required to use the `#[async_trait]` and `#[derive(Serialize, Deserialize)]` attributes
|
||||
in your task definitions and to connect to the Postgres database.
|
||||
|
||||
*Supports rustc 1.68+*
|
||||
|
||||
2. Create the `backie_tasks` table in the Postgres database. The migration can be found in [the migrations directory](https://github.com/rafaelcaricio/backie/blob/master/migrations/2023-03-06-151907_create_backie_tasks/up.sql).
|
||||
|
||||
## Usage
|
||||
|
@ -67,6 +93,9 @@ the whole application. This attribute is critical for reconstructing the task ba
|
|||
The [`BackgroundTask::AppData`] can be used to argument the task with your application specific contextual information.
|
||||
This is useful for example to pass a database connection pool to the task or other application configuration.
|
||||
|
||||
The [`BackgroundTask::Error`] is the error type that will be returned by the [`BackgroundTask::run`] method. You can
|
||||
use this to define your own error type for your tasks.
|
||||
|
||||
The [`BackgroundTask::run`] method is where you define the behaviour of your background task execution. This method
|
||||
will be called by the task queue workers.
|
||||
|
||||
|
@ -84,8 +113,9 @@ pub struct MyTask {
|
|||
impl BackgroundTask for MyTask {
|
||||
const TASK_NAME: &'static str = "my_task_unique_name";
|
||||
type AppData = ();
|
||||
type Error = ();
|
||||
|
||||
async fn run(&self, task: CurrentTask, context: Self::AppData) -> Result<(), anyhow::Error> {
|
||||
async fn run(&self, task: CurrentTask, context: Self::AppData) -> Result<(), Self::Error> {
|
||||
// Do something
|
||||
Ok(())
|
||||
}
|
||||
|
@ -98,44 +128,23 @@ First, we need to create a [`TaskStore`] trait instance. This is the object resp
|
|||
tasks from a database. Backie currently only supports Postgres as a storage backend via the provided
|
||||
[`PgTaskStore`]. You can implement other storage backends by implementing the [`TaskStore`] trait.
|
||||
|
||||
```rust
|
||||
let connection_url = "postgres://postgres:password@localhost/backie";
|
||||
|
||||
let manager = AsyncDieselConnectionManager::<AsyncPgConnection>::new(connection_url);
|
||||
let pool = Pool::builder()
|
||||
.max_size(3)
|
||||
.build(manager)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let task_store = PgTaskStore::new(pool);
|
||||
```
|
||||
|
||||
Then, we can use the `task_store` to start a worker pool using the [`WorkerPool`]. The [`WorkerPool`] is responsible
|
||||
for starting the workers and managing their lifecycle.
|
||||
|
||||
```rust
|
||||
// Register the task types I want to use and start the worker pool
|
||||
let (_, queue) = WorkerPool::new(task_store, |_|())
|
||||
.register_task_type::<MyTask>()
|
||||
.configure_queue("default", 1, RetentionMode::default())
|
||||
.start(futures::future::pending::<()>())
|
||||
.await
|
||||
.unwrap();
|
||||
```
|
||||
|
||||
With that, we are defining that we want to execute instances of `MyTask` and that the `default` queue should
|
||||
have 1 worker running using the default [`RetentionMode`] (remove from the database only successfully finished tasks).
|
||||
We also defined in the `start` method that the worker pool should run forever.
|
||||
A full example of starting a worker pool can be found in the [examples directory](https://github.com/rafaelcaricio/backie/blob/main/examples/simple_worker/src/main.rs).
|
||||
|
||||
### Queueing tasks
|
||||
|
||||
After stating the workers we get an instance of [`Queue`] which we can use to enqueue tasks:
|
||||
After stating the workers, we get an instance of [`Queue`] which we can use to enqueue tasks. It is also possible
|
||||
to directly create a [`Queue`] instance from with a [`TaskStore`] instance.
|
||||
|
||||
```rust
|
||||
let task = MyTask { info: "Hello world!".to_string() };
|
||||
queue.enqueue(task).await.unwrap();
|
||||
```
|
||||
This will enqueue the task and whenever a worker is available it will start processing. Workers don't need to be
|
||||
started before enqueuing tasks. Workers don't need to be in the same process as the queue as long as the workers have
|
||||
access to the same underlying storage system. This enables horizontal scaling of the workers.
|
||||
|
||||
## License
|
||||
|
||||
This project is licensed under the [MIT license][license].
|
||||
|
||||
## Contributing
|
||||
|
||||
|
@ -145,7 +154,7 @@ queue.enqueue(task).await.unwrap();
|
|||
4. Push to the branch (`git push origin my-new-feature`)
|
||||
5. Create a new Pull Request
|
||||
|
||||
## Thanks to related crates authors
|
||||
## Acknowledgements
|
||||
|
||||
I would like to thank the authors of the [Fang](https://github.com/ayrat555/fang) and [background_job](https://git.asonix.dog/asonix/background-jobs.git) crates which were the main inspiration for this project.
|
||||
|
||||
|
|
|
@ -6,7 +6,7 @@ edition = "2021"
|
|||
[dependencies]
|
||||
backie = { path = "../../" }
|
||||
anyhow = "1"
|
||||
env_logger = "0.9.0"
|
||||
env_logger = "0.10"
|
||||
log = "0.4.0"
|
||||
tokio = { version = "1", features = ["full"] }
|
||||
diesel-async = { version = "0.2", features = ["postgres", "bb8"] }
|
||||
|
|
|
@ -1,97 +0,0 @@
|
|||
use async_trait::async_trait;
|
||||
use backie::{BackgroundTask, CurrentTask};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::time::Duration;
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct MyApplicationContext {
|
||||
app_name: String,
|
||||
}
|
||||
|
||||
impl MyApplicationContext {
|
||||
pub fn new(app_name: &str) -> Self {
|
||||
Self {
|
||||
app_name: app_name.to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub struct MyTask {
|
||||
pub number: u16,
|
||||
}
|
||||
|
||||
impl MyTask {
|
||||
pub fn new(number: u16) -> Self {
|
||||
Self { number }
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub struct MyFailingTask {
|
||||
pub number: u16,
|
||||
}
|
||||
|
||||
impl MyFailingTask {
|
||||
pub fn new(number: u16) -> Self {
|
||||
Self { number }
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl BackgroundTask for MyTask {
|
||||
const TASK_NAME: &'static str = "my_task";
|
||||
type AppData = MyApplicationContext;
|
||||
|
||||
async fn run(&self, task: CurrentTask, ctx: Self::AppData) -> Result<(), anyhow::Error> {
|
||||
// let new_task = MyTask::new(self.number + 1);
|
||||
// queue
|
||||
// .insert_task(&new_task)
|
||||
// .await
|
||||
// .unwrap();
|
||||
|
||||
log::info!(
|
||||
"[{}] Hello from {}! the current number is {}",
|
||||
task.id(),
|
||||
ctx.app_name,
|
||||
self.number
|
||||
);
|
||||
tokio::time::sleep(Duration::from_secs(3)).await;
|
||||
|
||||
log::info!("[{}] done..", task.id());
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl BackgroundTask for MyFailingTask {
|
||||
const TASK_NAME: &'static str = "my_failing_task";
|
||||
type AppData = MyApplicationContext;
|
||||
|
||||
async fn run(&self, task: CurrentTask, _ctx: Self::AppData) -> Result<(), anyhow::Error> {
|
||||
// let new_task = MyFailingTask::new(self.number + 1);
|
||||
// queue
|
||||
// .insert_task(&new_task)
|
||||
// .await
|
||||
// .unwrap();
|
||||
|
||||
// task.id();
|
||||
// task.keep_alive().await?;
|
||||
// task.previous_error();
|
||||
// task.retry_count();
|
||||
|
||||
log::info!("[{}] the current number is {}", task.id(), self.number);
|
||||
tokio::time::sleep(Duration::from_secs(3)).await;
|
||||
|
||||
log::info!("[{}] done..", task.id());
|
||||
//
|
||||
// let b = true;
|
||||
//
|
||||
// if b {
|
||||
// panic!("Hello!");
|
||||
// } else {
|
||||
// Ok(())
|
||||
// }
|
||||
Ok(())
|
||||
}
|
||||
}
|
|
@ -1,9 +1,97 @@
|
|||
use backie::{PgTaskStore, RetentionMode, WorkerPool};
|
||||
use async_trait::async_trait;
|
||||
use backie::{BackgroundTask, CurrentTask};
|
||||
use backie::{PgTaskStore, Queue, WorkerPool};
|
||||
use diesel_async::pg::AsyncPgConnection;
|
||||
use diesel_async::pooled_connection::{bb8::Pool, AsyncDieselConnectionManager};
|
||||
use simple_worker::MyApplicationContext;
|
||||
use simple_worker::MyFailingTask;
|
||||
use simple_worker::MyTask;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::time::Duration;
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct MyApplicationContext {
|
||||
app_name: String,
|
||||
}
|
||||
|
||||
impl MyApplicationContext {
|
||||
pub fn new(app_name: &str) -> Self {
|
||||
Self {
|
||||
app_name: app_name.to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub struct MyTask {
|
||||
pub number: u16,
|
||||
}
|
||||
|
||||
impl MyTask {
|
||||
pub fn new(number: u16) -> Self {
|
||||
Self { number }
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl BackgroundTask for MyTask {
|
||||
const TASK_NAME: &'static str = "my_task";
|
||||
type AppData = MyApplicationContext;
|
||||
type Error = anyhow::Error;
|
||||
|
||||
async fn run(&self, task: CurrentTask, ctx: Self::AppData) -> Result<(), Self::Error> {
|
||||
// let new_task = MyTask::new(self.number + 1);
|
||||
// queue
|
||||
// .insert_task(&new_task)
|
||||
// .await
|
||||
// .unwrap();
|
||||
|
||||
log::info!(
|
||||
"[{}] Hello from {}! the current number is {}",
|
||||
task.id(),
|
||||
ctx.app_name,
|
||||
self.number
|
||||
);
|
||||
tokio::time::sleep(Duration::from_secs(3)).await;
|
||||
|
||||
log::info!("[{}] done..", task.id());
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub struct MyFailingTask {
|
||||
pub number: u16,
|
||||
}
|
||||
|
||||
impl MyFailingTask {
|
||||
pub fn new(number: u16) -> Self {
|
||||
Self { number }
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl BackgroundTask for MyFailingTask {
|
||||
const TASK_NAME: &'static str = "my_failing_task";
|
||||
type AppData = MyApplicationContext;
|
||||
type Error = anyhow::Error;
|
||||
|
||||
async fn run(&self, task: CurrentTask, _ctx: Self::AppData) -> Result<(), Self::Error> {
|
||||
// let new_task = MyFailingTask::new(self.number + 1);
|
||||
// queue
|
||||
// .insert_task(&new_task)
|
||||
// .await
|
||||
// .unwrap();
|
||||
|
||||
// task.id();
|
||||
// task.keep_alive().await?;
|
||||
// task.previous_error();
|
||||
// task.retry_count();
|
||||
|
||||
log::info!("[{}] the current number is {}", task.id(), self.number);
|
||||
tokio::time::sleep(Duration::from_secs(3)).await;
|
||||
|
||||
log::info!("[{}] done..", task.id());
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() {
|
||||
|
@ -30,15 +118,16 @@ async fn main() {
|
|||
let my_app_context = MyApplicationContext::new("Backie Example App");
|
||||
|
||||
// Register the task types I want to use and start the worker pool
|
||||
let (join_handle, queue) = WorkerPool::new(task_store, move |_| my_app_context.clone())
|
||||
.register_task_type::<MyTask>()
|
||||
.register_task_type::<MyFailingTask>()
|
||||
.configure_queue("default", 3, RetentionMode::RemoveDone)
|
||||
.start(async move {
|
||||
let _ = rx.changed().await;
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
let join_handle =
|
||||
WorkerPool::new(task_store.clone(), move || my_app_context.clone())
|
||||
.register_task_type::<MyTask>()
|
||||
.register_task_type::<MyFailingTask>()
|
||||
.configure_queue("default".into())
|
||||
.start(async move {
|
||||
let _ = rx.changed().await;
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
log::info!("Workers started ...");
|
||||
|
||||
|
@ -46,6 +135,7 @@ async fn main() {
|
|||
let task2 = MyTask::new(20_000);
|
||||
let task3 = MyFailingTask::new(50_000);
|
||||
|
||||
let queue = Queue::new(task_store);
|
||||
queue.enqueue(task1).await.unwrap();
|
||||
queue.enqueue(task2).await.unwrap();
|
||||
queue.enqueue(task3).await.unwrap();
|
||||
|
|
BIN
logo.png
Normal file
BIN
logo.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 131 KiB |
23
release.toml
Normal file
23
release.toml
Normal file
|
@ -0,0 +1,23 @@
|
|||
allow-branch = [
|
||||
"*",
|
||||
"!HEAD",
|
||||
]
|
||||
sign-commit = true
|
||||
sign-tag = true
|
||||
push-remote = "origin"
|
||||
release = true
|
||||
publish = true
|
||||
verify = true
|
||||
owners = []
|
||||
push = true
|
||||
push-options = []
|
||||
consolidate-commits = false
|
||||
pre-release-commit-message = "Release {{crate_name}} version {{version}}"
|
||||
pre-release-replacements = []
|
||||
tag-message = "Release {{version}}"
|
||||
tag-name = "{{version}}"
|
||||
tag = true
|
||||
enable-features = []
|
||||
enable-all-features = false
|
||||
dependent-version = "upgrade"
|
||||
metadata = "optional"
|
46
src/catch_unwind.rs
Normal file
46
src/catch_unwind.rs
Normal file
|
@ -0,0 +1,46 @@
|
|||
use crate::worker::TaskExecError;
|
||||
use futures::future::BoxFuture;
|
||||
use futures::FutureExt;
|
||||
use std::future::Future;
|
||||
use std::pin::Pin;
|
||||
use std::task::Context;
|
||||
use std::task::Poll;
|
||||
|
||||
pub(crate) struct CatchUnwindFuture<F: Future + Send + 'static> {
|
||||
inner: BoxFuture<'static, F::Output>,
|
||||
}
|
||||
|
||||
impl<F: Future + Send + 'static> CatchUnwindFuture<F> {
|
||||
pub fn create(f: F) -> CatchUnwindFuture<F> {
|
||||
Self { inner: f.boxed() }
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: Future + Send + 'static> Future for CatchUnwindFuture<F> {
|
||||
type Output = Result<F::Output, TaskExecError>;
|
||||
|
||||
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
let inner = &mut self.inner;
|
||||
|
||||
match catch_unwind(move || inner.poll_unpin(cx)) {
|
||||
Ok(Poll::Pending) => Poll::Pending,
|
||||
Ok(Poll::Ready(value)) => Poll::Ready(Ok(value)),
|
||||
Err(cause) => Poll::Ready(Err(cause)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn catch_unwind<F: FnOnce() -> R, R>(f: F) -> Result<R, TaskExecError> {
|
||||
match std::panic::catch_unwind(std::panic::AssertUnwindSafe(f)) {
|
||||
Ok(res) => Ok(res),
|
||||
Err(cause) => match cause.downcast_ref::<&'static str>() {
|
||||
None => match cause.downcast_ref::<String>() {
|
||||
None => Err(TaskExecError::Panicked(
|
||||
"Sorry, unknown panic message".to_string(),
|
||||
)),
|
||||
Some(message) => Err(TaskExecError::Panicked(message.to_string())),
|
||||
},
|
||||
Some(message) => Err(TaskExecError::Panicked(message.to_string())),
|
||||
},
|
||||
}
|
||||
}
|
|
@ -5,7 +5,7 @@
|
|||
/// All possible options for retaining tasks in the db after their execution.
|
||||
///
|
||||
/// The default mode is [`RetentionMode::RemoveAll`]
|
||||
#[derive(Copy, Clone, Eq, PartialEq, Debug)]
|
||||
#[derive(Copy, Clone, Eq, PartialEq, Debug, Hash)]
|
||||
pub enum RetentionMode {
|
||||
/// Keep all tasks
|
||||
KeepAll,
|
||||
|
@ -26,10 +26,11 @@ impl Default for RetentionMode {
|
|||
pub use queue::Queue;
|
||||
pub use runnable::BackgroundTask;
|
||||
pub use store::{PgTaskStore, TaskStore};
|
||||
pub use task::{CurrentTask, Task, TaskId, TaskState};
|
||||
pub use task::{CurrentTask, NewTask, Task, TaskId, TaskState};
|
||||
pub use worker::Worker;
|
||||
pub use worker_pool::WorkerPool;
|
||||
pub use worker_pool::{QueueConfig, WorkerPool};
|
||||
|
||||
mod catch_unwind;
|
||||
pub mod errors;
|
||||
mod queries;
|
||||
mod queue;
|
||||
|
|
18
src/queue.rs
18
src/queue.rs
|
@ -2,22 +2,20 @@ use crate::errors::BackieError;
|
|||
use crate::runnable::BackgroundTask;
|
||||
use crate::store::TaskStore;
|
||||
use crate::task::NewTask;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct Queue<S>
|
||||
where
|
||||
S: TaskStore,
|
||||
{
|
||||
task_store: Arc<S>,
|
||||
task_store: S,
|
||||
}
|
||||
|
||||
impl<S> Queue<S>
|
||||
where
|
||||
S: TaskStore,
|
||||
{
|
||||
pub fn new(task_store: Arc<S>) -> Self {
|
||||
pub fn new(task_store: S) -> Self {
|
||||
Queue { task_store }
|
||||
}
|
||||
|
||||
|
@ -25,9 +23,21 @@ where
|
|||
where
|
||||
BT: BackgroundTask,
|
||||
{
|
||||
// TODO: Add option to specify the timeout of a task
|
||||
self.task_store
|
||||
.create_task(NewTask::new(background_task, Duration::from_secs(10))?)
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl<S> Clone for Queue<S>
|
||||
where
|
||||
S: TaskStore + Clone,
|
||||
{
|
||||
fn clone(&self) -> Self {
|
||||
Self {
|
||||
task_store: self.task_store.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
use crate::task::{CurrentTask, TaskHash};
|
||||
use async_trait::async_trait;
|
||||
use serde::{de::DeserializeOwned, ser::Serialize};
|
||||
use std::fmt::Debug;
|
||||
|
||||
/// The [`BackgroundTask`] trait is used to define the behaviour of a task. You must implement this
|
||||
/// trait for all tasks you want to execute.
|
||||
|
@ -17,7 +18,7 @@ use serde::{de::DeserializeOwned, ser::Serialize};
|
|||
///
|
||||
///
|
||||
/// # Example
|
||||
/// ```rust
|
||||
/// ```
|
||||
/// use async_trait::async_trait;
|
||||
/// use backie::{BackgroundTask, CurrentTask};
|
||||
/// use serde::{Deserialize, Serialize};
|
||||
|
@ -25,11 +26,13 @@ use serde::{de::DeserializeOwned, ser::Serialize};
|
|||
/// #[derive(Serialize, Deserialize)]
|
||||
/// pub struct MyTask {}
|
||||
///
|
||||
/// #[async_trait]
|
||||
/// impl BackgroundTask for MyTask {
|
||||
/// const TASK_NAME: &'static str = "my_task_unique_name";
|
||||
/// type AppData = ();
|
||||
/// type Error = Box<dyn std::error::Error + Send + Sync + 'static>;
|
||||
///
|
||||
/// async fn run(&self, task: CurrentTask, context: Self::AppData) -> Result<(), anyhow::Error> {
|
||||
/// async fn run(&self, task: CurrentTask, context: Self::AppData) -> Result<(), Self::Error> {
|
||||
/// // Do something
|
||||
/// Ok(())
|
||||
/// }
|
||||
|
@ -50,14 +53,17 @@ pub trait BackgroundTask: Serialize + DeserializeOwned + Sync + Send + 'static {
|
|||
|
||||
/// Number of retries for tasks.
|
||||
///
|
||||
/// By default, it is set to 5.
|
||||
const MAX_RETRIES: i32 = 5;
|
||||
/// By default, it is set to 3.
|
||||
const MAX_RETRIES: i32 = 3;
|
||||
|
||||
/// The application data provided to this task at runtime.
|
||||
type AppData: Clone + Send + 'static;
|
||||
|
||||
/// An application custom error type.
|
||||
type Error: Debug + Send + 'static;
|
||||
|
||||
/// Execute the task. This method should define its logic
|
||||
async fn run(&self, task: CurrentTask, context: Self::AppData) -> Result<(), anyhow::Error>;
|
||||
async fn run(&self, task: CurrentTask, context: Self::AppData) -> Result<(), Self::Error>;
|
||||
|
||||
/// If set to true, no new tasks with the same metadata will be inserted
|
||||
/// By default it is set to false.
|
||||
|
|
16
src/store.rs
16
src/store.rs
|
@ -106,7 +106,7 @@ pub mod test_store {
|
|||
|
||||
#[derive(Default, Clone)]
|
||||
pub struct MemoryTaskStore {
|
||||
tasks: Arc<Mutex<BTreeMap<TaskId, Task>>>,
|
||||
pub tasks: Arc<Mutex<BTreeMap<TaskId, Task>>>,
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
|
@ -197,7 +197,7 @@ pub mod test_store {
|
|||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
pub trait TaskStore: Clone + Send + Sync + 'static {
|
||||
pub trait TaskStore: Send + Sync + 'static {
|
||||
async fn pull_next_task(
|
||||
&self,
|
||||
queue_name: &str,
|
||||
|
@ -213,3 +213,15 @@ pub trait TaskStore: Clone + Send + Sync + 'static {
|
|||
error: &str,
|
||||
) -> Result<Task, AsyncQueueError>;
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::store::test_store::MemoryTaskStore;
|
||||
|
||||
#[test]
|
||||
fn task_store_trait_is_object_safe() {
|
||||
let store = MemoryTaskStore::default();
|
||||
let _object = &store as &dyn TaskStore;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
use crate::catch_unwind::CatchUnwindFuture;
|
||||
use crate::errors::{AsyncQueueError, BackieError};
|
||||
use crate::runnable::BackgroundTask;
|
||||
use crate::store::TaskStore;
|
||||
|
@ -9,7 +10,7 @@ use std::collections::BTreeMap;
|
|||
use std::future::Future;
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use thiserror::Error;
|
||||
use std::time::Duration;
|
||||
|
||||
pub type ExecuteTaskFn<AppData> = Arc<
|
||||
dyn Fn(
|
||||
|
@ -23,13 +24,16 @@ pub type ExecuteTaskFn<AppData> = Arc<
|
|||
|
||||
pub type StateFn<AppData> = Arc<dyn Fn() -> AppData + Send + Sync>;
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum TaskExecError {
|
||||
#[error("Task execution failed: {0}")]
|
||||
ExecutionFailed(#[from] anyhow::Error),
|
||||
|
||||
#[error("Task deserialization failed: {0}")]
|
||||
TaskDeserializationFailed(#[from] serde_json::Error),
|
||||
|
||||
#[error("Task execution failed: {0}")]
|
||||
ExecutionFailed(String),
|
||||
|
||||
#[error("Task panicked with: {0}")]
|
||||
Panicked(String),
|
||||
}
|
||||
|
||||
pub(crate) fn runnable<BT>(
|
||||
|
@ -42,8 +46,10 @@ where
|
|||
{
|
||||
Box::pin(async move {
|
||||
let background_task: BT = serde_json::from_value(payload)?;
|
||||
background_task.run(task_info, app_context).await?;
|
||||
Ok(())
|
||||
match background_task.run(task_info, app_context).await {
|
||||
Ok(_) => Ok(()),
|
||||
Err(err) => Err(TaskExecError::ExecutionFailed(format!("{:?}", err))),
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -51,14 +57,16 @@ where
|
|||
pub struct Worker<AppData, S>
|
||||
where
|
||||
AppData: Clone + Send + 'static,
|
||||
S: TaskStore,
|
||||
S: TaskStore + Clone,
|
||||
{
|
||||
store: Arc<S>,
|
||||
store: S,
|
||||
|
||||
queue_name: String,
|
||||
|
||||
retention_mode: RetentionMode,
|
||||
|
||||
pull_interval: Duration,
|
||||
|
||||
task_registry: BTreeMap<String, ExecuteTaskFn<AppData>>,
|
||||
|
||||
app_data_fn: StateFn<AppData>,
|
||||
|
@ -70,12 +78,13 @@ where
|
|||
impl<AppData, S> Worker<AppData, S>
|
||||
where
|
||||
AppData: Clone + Send + 'static,
|
||||
S: TaskStore,
|
||||
S: TaskStore + Clone,
|
||||
{
|
||||
pub(crate) fn new(
|
||||
store: Arc<S>,
|
||||
store: S,
|
||||
queue_name: String,
|
||||
retention_mode: RetentionMode,
|
||||
pull_interval: Duration,
|
||||
task_registry: BTreeMap<String, ExecuteTaskFn<AppData>>,
|
||||
app_data_fn: StateFn<AppData>,
|
||||
shutdown: Option<tokio::sync::watch::Receiver<()>>,
|
||||
|
@ -84,6 +93,7 @@ where
|
|||
store,
|
||||
queue_name,
|
||||
retention_mode,
|
||||
pull_interval,
|
||||
task_registry,
|
||||
app_data_fn,
|
||||
shutdown,
|
||||
|
@ -120,11 +130,11 @@ where
|
|||
log::info!("Shutting down worker");
|
||||
return Ok(());
|
||||
}
|
||||
_ = tokio::time::sleep(std::time::Duration::from_secs(1)).fuse() => {}
|
||||
_ = tokio::time::sleep(self.pull_interval).fuse() => {}
|
||||
}
|
||||
}
|
||||
None => {
|
||||
tokio::time::sleep(std::time::Duration::from_secs(1)).await;
|
||||
tokio::time::sleep(self.pull_interval).await;
|
||||
}
|
||||
};
|
||||
}
|
||||
|
@ -139,9 +149,18 @@ where
|
|||
.get(&task.task_name)
|
||||
.ok_or_else(|| AsyncQueueError::TaskNotRegistered(task.task_name.clone()))?;
|
||||
|
||||
// TODO: catch panics
|
||||
let result: Result<(), TaskExecError> =
|
||||
runnable_task_caller(task_info, task.payload.clone(), (self.app_data_fn)()).await;
|
||||
// catch panics
|
||||
let result: Result<(), TaskExecError> = CatchUnwindFuture::create({
|
||||
let task_payload = task.payload.clone();
|
||||
let app_data = (self.app_data_fn)();
|
||||
let runnable_task_caller = runnable_task_caller.clone();
|
||||
async move { runnable_task_caller(task_info, task_payload, app_data).await }
|
||||
})
|
||||
.await
|
||||
.and_then(|result| {
|
||||
result?;
|
||||
Ok(())
|
||||
});
|
||||
|
||||
match &result {
|
||||
Ok(_) => self.finalize_task(task, result).await?,
|
||||
|
@ -154,7 +173,9 @@ where
|
|||
task.id,
|
||||
backoff_seconds
|
||||
);
|
||||
|
||||
let error_message = format!("{}", error);
|
||||
|
||||
self.store
|
||||
.schedule_task_retry(task.id, backoff_seconds, &error_message)
|
||||
.await?;
|
||||
|
@ -231,8 +252,9 @@ mod async_worker_tests {
|
|||
impl BackgroundTask for WorkerAsyncTask {
|
||||
const TASK_NAME: &'static str = "WorkerAsyncTask";
|
||||
type AppData = ();
|
||||
type Error = ();
|
||||
|
||||
async fn run(&self, _: CurrentTask, _: Self::AppData) -> Result<(), anyhow::Error> {
|
||||
async fn run(&self, _: CurrentTask, _: Self::AppData) -> Result<(), ()> {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
@ -246,8 +268,9 @@ mod async_worker_tests {
|
|||
impl BackgroundTask for WorkerAsyncTaskSchedule {
|
||||
const TASK_NAME: &'static str = "WorkerAsyncTaskSchedule";
|
||||
type AppData = ();
|
||||
type Error = ();
|
||||
|
||||
async fn run(&self, _task: CurrentTask, _data: Self::AppData) -> Result<(), anyhow::Error> {
|
||||
async fn run(&self, _task: CurrentTask, _data: Self::AppData) -> Result<(), ()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
@ -265,11 +288,12 @@ mod async_worker_tests {
|
|||
impl BackgroundTask for AsyncFailedTask {
|
||||
const TASK_NAME: &'static str = "AsyncFailedTask";
|
||||
type AppData = ();
|
||||
type Error = TaskError;
|
||||
|
||||
async fn run(&self, _task: CurrentTask, _data: Self::AppData) -> Result<(), anyhow::Error> {
|
||||
async fn run(&self, _task: CurrentTask, _data: Self::AppData) -> Result<(), TaskError> {
|
||||
let message = format!("number {} is wrong :(", self.number);
|
||||
|
||||
Err(TaskError::Custom(message).into())
|
||||
Err(TaskError::Custom(message))
|
||||
}
|
||||
|
||||
fn max_retries(&self) -> i32 {
|
||||
|
@ -284,9 +308,10 @@ mod async_worker_tests {
|
|||
impl BackgroundTask for AsyncRetryTask {
|
||||
const TASK_NAME: &'static str = "AsyncRetryTask";
|
||||
type AppData = ();
|
||||
type Error = TaskError;
|
||||
|
||||
async fn run(&self, _task: CurrentTask, _data: Self::AppData) -> Result<(), anyhow::Error> {
|
||||
Err(TaskError::SomethingWrong.into())
|
||||
async fn run(&self, _task: CurrentTask, _data: Self::AppData) -> Result<(), Self::Error> {
|
||||
Err(TaskError::SomethingWrong)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -297,8 +322,9 @@ mod async_worker_tests {
|
|||
impl BackgroundTask for AsyncTaskType1 {
|
||||
const TASK_NAME: &'static str = "AsyncTaskType1";
|
||||
type AppData = ();
|
||||
type Error = ();
|
||||
|
||||
async fn run(&self, _task: CurrentTask, _data: Self::AppData) -> Result<(), anyhow::Error> {
|
||||
async fn run(&self, _task: CurrentTask, _data: Self::AppData) -> Result<(), Self::Error> {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
@ -310,8 +336,9 @@ mod async_worker_tests {
|
|||
impl BackgroundTask for AsyncTaskType2 {
|
||||
const TASK_NAME: &'static str = "AsyncTaskType2";
|
||||
type AppData = ();
|
||||
type Error = ();
|
||||
|
||||
async fn run(&self, _task: CurrentTask, _data: Self::AppData) -> Result<(), anyhow::Error> {
|
||||
async fn run(&self, _task: CurrentTask, _data: Self::AppData) -> Result<(), ()> {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,26 +1,24 @@
|
|||
use crate::errors::BackieError;
|
||||
use crate::queue::Queue;
|
||||
use crate::runnable::BackgroundTask;
|
||||
use crate::store::TaskStore;
|
||||
use crate::worker::{runnable, ExecuteTaskFn};
|
||||
use crate::worker::{StateFn, Worker};
|
||||
use crate::RetentionMode;
|
||||
use futures::future::join_all;
|
||||
use std::collections::BTreeMap;
|
||||
use std::future::Future;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use tokio::task::JoinHandle;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct WorkerPool<AppData, S>
|
||||
where
|
||||
AppData: Clone + Send + 'static,
|
||||
S: TaskStore,
|
||||
S: TaskStore + Clone,
|
||||
{
|
||||
/// Storage of tasks.
|
||||
task_store: Arc<S>,
|
||||
|
||||
/// Queue used to spawn tasks.
|
||||
queue: Queue<S>,
|
||||
task_store: S,
|
||||
|
||||
/// Make possible to load the application data.
|
||||
///
|
||||
|
@ -36,28 +34,21 @@ where
|
|||
queue_tasks: BTreeMap<String, Vec<String>>,
|
||||
|
||||
/// Number of workers that will be spawned per queue.
|
||||
worker_queues: BTreeMap<String, (RetentionMode, u32)>,
|
||||
worker_queues: BTreeMap<String, QueueConfig>,
|
||||
}
|
||||
|
||||
impl<AppData, S> WorkerPool<AppData, S>
|
||||
where
|
||||
AppData: Clone + Send + 'static,
|
||||
S: TaskStore,
|
||||
S: TaskStore + Clone,
|
||||
{
|
||||
/// Create a new worker pool.
|
||||
pub fn new<A>(task_store: S, application_data_fn: A) -> Self
|
||||
where
|
||||
A: Fn(Queue<S>) -> AppData + Send + Sync + 'static,
|
||||
A: Fn() -> AppData + Send + Sync + 'static,
|
||||
{
|
||||
let queue_store = Arc::new(task_store);
|
||||
let queue = Queue::new(queue_store.clone());
|
||||
let application_data_fn = {
|
||||
let queue = queue.clone();
|
||||
move || application_data_fn(queue.clone())
|
||||
};
|
||||
Self {
|
||||
task_store: queue_store,
|
||||
queue,
|
||||
task_store,
|
||||
application_data_fn: Arc::new(application_data_fn),
|
||||
task_registry: BTreeMap::new(),
|
||||
queue_tasks: BTreeMap::new(),
|
||||
|
@ -79,21 +70,12 @@ where
|
|||
self
|
||||
}
|
||||
|
||||
pub fn configure_queue(
|
||||
mut self,
|
||||
queue_name: impl ToString,
|
||||
num_workers: u32,
|
||||
retention_mode: RetentionMode,
|
||||
) -> Self {
|
||||
self.worker_queues
|
||||
.insert(queue_name.to_string(), (retention_mode, num_workers));
|
||||
pub fn configure_queue(mut self, config: QueueConfig) -> Self {
|
||||
self.worker_queues.insert(config.name.clone(), config);
|
||||
self
|
||||
}
|
||||
|
||||
pub async fn start<F>(
|
||||
self,
|
||||
graceful_shutdown: F,
|
||||
) -> Result<(JoinHandle<()>, Queue<S>), BackieError>
|
||||
pub async fn start<F>(self, graceful_shutdown: F) -> Result<JoinHandle<()>, BackieError>
|
||||
where
|
||||
F: Future<Output = ()> + Send + 'static,
|
||||
{
|
||||
|
@ -106,39 +88,125 @@ where
|
|||
|
||||
let (tx, rx) = tokio::sync::watch::channel(());
|
||||
|
||||
let mut worker_handles = Vec::new();
|
||||
|
||||
// Spawn all individual workers per queue
|
||||
for (queue_name, (retention_mode, num_workers)) in self.worker_queues.iter() {
|
||||
for idx in 0..*num_workers {
|
||||
for (queue_name, queue_config) in self.worker_queues.iter() {
|
||||
for idx in 0..queue_config.num_workers {
|
||||
let mut worker: Worker<AppData, S> = Worker::new(
|
||||
self.task_store.clone(),
|
||||
queue_name.clone(),
|
||||
*retention_mode,
|
||||
queue_config.retention_mode,
|
||||
queue_config.pull_interval,
|
||||
self.task_registry.clone(),
|
||||
self.application_data_fn.clone(),
|
||||
Some(rx.clone()),
|
||||
);
|
||||
let worker_name = format!("worker-{queue_name}-{idx}");
|
||||
// TODO: grab the join handle for every worker for graceful shutdown
|
||||
tokio::spawn(async move {
|
||||
// grabs the join handle for every worker for graceful shutdown
|
||||
let join_handle = tokio::spawn(async move {
|
||||
match worker.run_tasks().await {
|
||||
Ok(()) => log::info!("Worker {worker_name} stopped successfully"),
|
||||
Err(err) => log::error!("Worker {worker_name} stopped due to error: {err}"),
|
||||
}
|
||||
});
|
||||
worker_handles.push(join_handle);
|
||||
}
|
||||
}
|
||||
|
||||
Ok((
|
||||
tokio::spawn(async move {
|
||||
graceful_shutdown.await;
|
||||
if let Err(err) = tx.send(()) {
|
||||
log::warn!("Failed to send shutdown signal to worker pool: {}", err);
|
||||
Ok(tokio::spawn(async move {
|
||||
graceful_shutdown.await;
|
||||
if let Err(err) = tx.send(()) {
|
||||
log::warn!("Failed to send shutdown signal to worker pool: {}", err);
|
||||
} else {
|
||||
// Wait for all workers to finish processing
|
||||
let results = join_all(worker_handles)
|
||||
.await
|
||||
.into_iter()
|
||||
.filter(Result::is_err)
|
||||
.map(Result::unwrap_err)
|
||||
.collect::<Vec<_>>();
|
||||
if !results.is_empty() {
|
||||
log::error!("Worker pool stopped with errors: {:?}", results);
|
||||
} else {
|
||||
log::info!("Worker pool stopped gracefully");
|
||||
}
|
||||
}),
|
||||
self.queue,
|
||||
))
|
||||
}
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
||||
/// Configuration for a queue.
|
||||
///
|
||||
/// This is used to configure the number of workers, the retention mode, and the pulling interval
|
||||
/// for a queue.
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// Example of configuring a queue with all options:
|
||||
/// ```
|
||||
/// # use backie::QueueConfig;
|
||||
/// # use backie::RetentionMode;
|
||||
/// # use std::time::Duration;
|
||||
/// let config = QueueConfig::new("default")
|
||||
/// .num_workers(5)
|
||||
/// .retention_mode(RetentionMode::KeepAll)
|
||||
/// .pull_interval(Duration::from_secs(1));
|
||||
/// ```
|
||||
/// Example of queue configuration with default options:
|
||||
/// ```
|
||||
/// # use backie::QueueConfig;
|
||||
/// let config = QueueConfig::new("default");
|
||||
/// // Also possible to use the `From` trait:
|
||||
/// let config: QueueConfig = "default".into();
|
||||
/// ```
|
||||
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
|
||||
pub struct QueueConfig {
|
||||
name: String,
|
||||
num_workers: u32,
|
||||
retention_mode: RetentionMode,
|
||||
pull_interval: Duration,
|
||||
}
|
||||
|
||||
impl QueueConfig {
|
||||
/// Create a new queue configuration.
|
||||
pub fn new(name: impl ToString) -> Self {
|
||||
Self {
|
||||
name: name.to_string(),
|
||||
num_workers: 1,
|
||||
retention_mode: RetentionMode::default(),
|
||||
pull_interval: Duration::from_secs(1),
|
||||
}
|
||||
}
|
||||
|
||||
/// Set the number of workers for this queue.
|
||||
pub fn num_workers(mut self, num_workers: u32) -> Self {
|
||||
self.num_workers = num_workers;
|
||||
self
|
||||
}
|
||||
|
||||
/// Set the retention mode for this queue.
|
||||
pub fn retention_mode(mut self, retention_mode: RetentionMode) -> Self {
|
||||
self.retention_mode = retention_mode;
|
||||
self
|
||||
}
|
||||
|
||||
/// Set the pull interval for this queue.
|
||||
///
|
||||
/// This is the interval at which the queue will be checking for new tasks by calling
|
||||
/// the backend storage.
|
||||
pub fn pull_interval(mut self, pull_interval: Duration) -> Self {
|
||||
self.pull_interval = pull_interval;
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl<S> From<S> for QueueConfig
|
||||
where
|
||||
S: ToString,
|
||||
{
|
||||
fn from(name: S) -> Self {
|
||||
Self::new(name.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -148,13 +216,16 @@ mod tests {
|
|||
use crate::store::test_store::MemoryTaskStore;
|
||||
use crate::store::PgTaskStore;
|
||||
use crate::task::CurrentTask;
|
||||
use crate::Queue;
|
||||
use async_trait::async_trait;
|
||||
use diesel_async::pooled_connection::{bb8::Pool, AsyncDieselConnectionManager};
|
||||
use diesel_async::AsyncPgConnection;
|
||||
use futures::FutureExt;
|
||||
use std::sync::atomic::{AtomicBool, Ordering};
|
||||
use tokio::sync::Mutex;
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
struct ApplicationContext {
|
||||
pub struct ApplicationContext {
|
||||
app_name: String,
|
||||
}
|
||||
|
||||
|
@ -175,17 +246,50 @@ mod tests {
|
|||
person: String,
|
||||
}
|
||||
|
||||
/// This tests that one can customize the task parameters for the application.
|
||||
#[async_trait]
|
||||
impl BackgroundTask for GreetingTask {
|
||||
const TASK_NAME: &'static str = "my_task";
|
||||
trait MyAppTask {
|
||||
const TASK_NAME: &'static str;
|
||||
const QUEUE: &'static str = "default";
|
||||
|
||||
async fn run(
|
||||
&self,
|
||||
task_info: CurrentTask,
|
||||
app_context: ApplicationContext,
|
||||
) -> Result<(), ()>;
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<T> BackgroundTask for T
|
||||
where
|
||||
T: MyAppTask + serde::de::DeserializeOwned + serde::ser::Serialize + Sync + Send + 'static,
|
||||
{
|
||||
const TASK_NAME: &'static str = T::TASK_NAME;
|
||||
|
||||
const QUEUE: &'static str = T::QUEUE;
|
||||
|
||||
type AppData = ApplicationContext;
|
||||
|
||||
type Error = ();
|
||||
|
||||
async fn run(
|
||||
&self,
|
||||
task_info: CurrentTask,
|
||||
app_context: Self::AppData,
|
||||
) -> Result<(), anyhow::Error> {
|
||||
) -> Result<(), Self::Error> {
|
||||
self.run(task_info, app_context).await
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl MyAppTask for GreetingTask {
|
||||
const TASK_NAME: &'static str = "my_task";
|
||||
|
||||
async fn run(
|
||||
&self,
|
||||
task_info: CurrentTask,
|
||||
app_context: ApplicationContext,
|
||||
) -> Result<(), ()> {
|
||||
println!(
|
||||
"[{}] Hello {}! I'm {}.",
|
||||
task_info.id(),
|
||||
|
@ -206,12 +310,9 @@ mod tests {
|
|||
const QUEUE: &'static str = "other_queue";
|
||||
|
||||
type AppData = ApplicationContext;
|
||||
type Error = ();
|
||||
|
||||
async fn run(
|
||||
&self,
|
||||
task: CurrentTask,
|
||||
context: Self::AppData,
|
||||
) -> Result<(), anyhow::Error> {
|
||||
async fn run(&self, task: CurrentTask, context: Self::AppData) -> Result<(), Self::Error> {
|
||||
println!(
|
||||
"[{}] Other task with {}!",
|
||||
task.id(),
|
||||
|
@ -221,41 +322,11 @@ mod tests {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct NotifyFinishedContext {
|
||||
tx: Arc<Mutex<Option<tokio::sync::oneshot::Sender<()>>>>,
|
||||
}
|
||||
|
||||
#[derive(serde::Serialize, serde::Deserialize)]
|
||||
struct NotifyFinished;
|
||||
|
||||
#[async_trait]
|
||||
impl BackgroundTask for NotifyFinished {
|
||||
const TASK_NAME: &'static str = "notify_finished";
|
||||
|
||||
type AppData = NotifyFinishedContext;
|
||||
|
||||
async fn run(
|
||||
&self,
|
||||
task: CurrentTask,
|
||||
context: Self::AppData,
|
||||
) -> Result<(), anyhow::Error> {
|
||||
match context.tx.lock().await.take() {
|
||||
None => println!("Cannot notify, already done that!"),
|
||||
Some(tx) => {
|
||||
tx.send(()).unwrap();
|
||||
println!("[{}] Notify finished did it's job!", task.id())
|
||||
}
|
||||
};
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn validate_all_registered_tasks_queues_are_configured() {
|
||||
let my_app_context = ApplicationContext::new();
|
||||
|
||||
let result = WorkerPool::new(memory_store().await, move |_| my_app_context.clone())
|
||||
let result = WorkerPool::new(memory_store(), move || my_app_context.clone())
|
||||
.register_task_type::<GreetingTask>()
|
||||
.start(futures::future::ready(()))
|
||||
.await;
|
||||
|
@ -273,14 +344,16 @@ mod tests {
|
|||
async fn test_worker_pool_with_task() {
|
||||
let my_app_context = ApplicationContext::new();
|
||||
|
||||
let (join_handle, queue) =
|
||||
WorkerPool::new(memory_store().await, move |_| my_app_context.clone())
|
||||
.register_task_type::<GreetingTask>()
|
||||
.configure_queue(GreetingTask::QUEUE, 1, RetentionMode::RemoveDone)
|
||||
.start(futures::future::ready(()))
|
||||
.await
|
||||
.unwrap();
|
||||
let task_store = memory_store();
|
||||
|
||||
let join_handle = WorkerPool::new(task_store.clone(), move || my_app_context.clone())
|
||||
.register_task_type::<GreetingTask>()
|
||||
.configure_queue(<GreetingTask as MyAppTask>::QUEUE.into())
|
||||
.start(futures::future::ready(()))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let queue = Queue::new(task_store);
|
||||
queue
|
||||
.enqueue(GreetingTask {
|
||||
person: "Rafael".to_string(),
|
||||
|
@ -295,16 +368,17 @@ mod tests {
|
|||
async fn test_worker_pool_with_multiple_task_types() {
|
||||
let my_app_context = ApplicationContext::new();
|
||||
|
||||
let (join_handle, queue) =
|
||||
WorkerPool::new(memory_store().await, move |_| my_app_context.clone())
|
||||
.register_task_type::<GreetingTask>()
|
||||
.register_task_type::<OtherTask>()
|
||||
.configure_queue("default", 1, RetentionMode::default())
|
||||
.configure_queue("other_queue", 1, RetentionMode::default())
|
||||
.start(futures::future::ready(()))
|
||||
.await
|
||||
.unwrap();
|
||||
let task_store = memory_store();
|
||||
let join_handle = WorkerPool::new(task_store.clone(), move || my_app_context.clone())
|
||||
.register_task_type::<GreetingTask>()
|
||||
.register_task_type::<OtherTask>()
|
||||
.configure_queue("default".into())
|
||||
.configure_queue("other_queue".into())
|
||||
.start(futures::future::ready(()))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let queue = Queue::new(task_store.clone());
|
||||
queue
|
||||
.enqueue(GreetingTask {
|
||||
person: "Rafael".to_string(),
|
||||
|
@ -319,23 +393,56 @@ mod tests {
|
|||
|
||||
#[tokio::test]
|
||||
async fn test_worker_pool_stop_after_task_execute() {
|
||||
#[derive(Clone)]
|
||||
struct NotifyFinishedContext {
|
||||
/// Used to notify the task ran
|
||||
notify_finished: Arc<Mutex<Option<tokio::sync::oneshot::Sender<()>>>>,
|
||||
}
|
||||
|
||||
/// A task that notifies the test that it ran
|
||||
#[derive(serde::Serialize, serde::Deserialize)]
|
||||
struct NotifyFinished;
|
||||
|
||||
#[async_trait]
|
||||
impl BackgroundTask for NotifyFinished {
|
||||
const TASK_NAME: &'static str = "notify_finished";
|
||||
|
||||
type AppData = NotifyFinishedContext;
|
||||
|
||||
type Error = ();
|
||||
|
||||
async fn run(&self, task: CurrentTask, context: Self::AppData) -> Result<(), ()> {
|
||||
// Notify the test that the task ran
|
||||
match context.notify_finished.lock().await.take() {
|
||||
None => println!("Cannot notify, already done that!"),
|
||||
Some(tx) => {
|
||||
tx.send(()).unwrap();
|
||||
println!("[{}] Notify finished did it's job!", task.id())
|
||||
}
|
||||
};
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
let (tx, rx) = tokio::sync::oneshot::channel();
|
||||
|
||||
let my_app_context = NotifyFinishedContext {
|
||||
tx: Arc::new(Mutex::new(Some(tx))),
|
||||
notify_finished: Arc::new(Mutex::new(Some(tx))),
|
||||
};
|
||||
|
||||
let (join_handle, queue) =
|
||||
WorkerPool::new(memory_store().await, move |_| my_app_context.clone())
|
||||
.register_task_type::<NotifyFinished>()
|
||||
.configure_queue("default", 1, RetentionMode::default())
|
||||
.start(async move {
|
||||
rx.await.unwrap();
|
||||
println!("Worker pool got notified to stop");
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
let memory_store = memory_store();
|
||||
|
||||
let join_handle = WorkerPool::new(memory_store.clone(), move || my_app_context.clone())
|
||||
.register_task_type::<NotifyFinished>()
|
||||
.configure_queue("default".into())
|
||||
.start(async move {
|
||||
rx.await.unwrap();
|
||||
println!("Worker pool got notified to stop");
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let queue = Queue::new(memory_store);
|
||||
// Notifies the worker pool to stop after the task is executed
|
||||
queue.enqueue(NotifyFinished).await.unwrap();
|
||||
|
||||
|
@ -347,6 +454,44 @@ mod tests {
|
|||
|
||||
#[tokio::test]
|
||||
async fn test_worker_pool_try_to_run_unknown_task() {
|
||||
#[derive(Clone)]
|
||||
struct NotifyUnknownRanContext {
|
||||
/// Notify that application should stop
|
||||
should_stop: Arc<Mutex<Option<tokio::sync::oneshot::Sender<()>>>>,
|
||||
|
||||
/// Used to mark if the unknown task ran
|
||||
unknown_task_ran: Arc<AtomicBool>,
|
||||
}
|
||||
|
||||
/// A task that notifies the test that it ran
|
||||
#[derive(serde::Serialize, serde::Deserialize)]
|
||||
struct NotifyStopDuringRun;
|
||||
|
||||
#[async_trait]
|
||||
impl BackgroundTask for NotifyStopDuringRun {
|
||||
const TASK_NAME: &'static str = "notify_finished";
|
||||
|
||||
type AppData = NotifyUnknownRanContext;
|
||||
|
||||
type Error = ();
|
||||
|
||||
async fn run(
|
||||
&self,
|
||||
task: CurrentTask,
|
||||
context: Self::AppData,
|
||||
) -> Result<(), Self::Error> {
|
||||
// Notify the test that the task ran
|
||||
match context.should_stop.lock().await.take() {
|
||||
None => println!("Cannot notify, already done that!"),
|
||||
Some(tx) => {
|
||||
tx.send(()).unwrap();
|
||||
println!("[{}] Notify finished did it's job!", task.id())
|
||||
}
|
||||
};
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, serde::Serialize, serde::Deserialize)]
|
||||
struct UnknownTask;
|
||||
|
||||
|
@ -354,46 +499,211 @@ mod tests {
|
|||
impl BackgroundTask for UnknownTask {
|
||||
const TASK_NAME: &'static str = "unknown_task";
|
||||
|
||||
type AppData = NotifyFinishedContext;
|
||||
type AppData = NotifyUnknownRanContext;
|
||||
|
||||
async fn run(
|
||||
&self,
|
||||
task: CurrentTask,
|
||||
_context: Self::AppData,
|
||||
) -> Result<(), anyhow::Error> {
|
||||
type Error = ();
|
||||
|
||||
async fn run(&self, task: CurrentTask, context: Self::AppData) -> Result<(), ()> {
|
||||
println!("[{}] Unknown task ran!", task.id());
|
||||
context.unknown_task_ran.store(true, Ordering::Relaxed);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
let (tx, rx) = tokio::sync::oneshot::channel();
|
||||
|
||||
let my_app_context = NotifyFinishedContext {
|
||||
tx: Arc::new(Mutex::new(Some(tx))),
|
||||
let my_app_context = NotifyUnknownRanContext {
|
||||
should_stop: Arc::new(Mutex::new(Some(tx))),
|
||||
unknown_task_ran: Arc::new(AtomicBool::new(false)),
|
||||
};
|
||||
|
||||
let task_store = memory_store().await;
|
||||
let task_store = memory_store();
|
||||
|
||||
let (join_handle, queue) = WorkerPool::new(task_store, move |_| my_app_context.clone())
|
||||
.register_task_type::<NotifyFinished>()
|
||||
.configure_queue("default", 1, RetentionMode::default())
|
||||
.start(async move {
|
||||
rx.await.unwrap();
|
||||
println!("Worker pool got notified to stop");
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
let join_handle = WorkerPool::new(task_store.clone(), {
|
||||
let my_app_context = my_app_context.clone();
|
||||
move || my_app_context.clone()
|
||||
})
|
||||
.register_task_type::<NotifyStopDuringRun>()
|
||||
.configure_queue("default".into())
|
||||
.start(async move {
|
||||
rx.await.unwrap();
|
||||
println!("Worker pool got notified to stop");
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let queue = Queue::new(task_store);
|
||||
// Enqueue a task that is not registered
|
||||
queue.enqueue(UnknownTask).await.unwrap();
|
||||
|
||||
// Notifies the worker pool to stop for this test
|
||||
queue.enqueue(NotifyFinished).await.unwrap();
|
||||
queue.enqueue(NotifyStopDuringRun).await.unwrap();
|
||||
|
||||
join_handle.await.unwrap();
|
||||
|
||||
assert!(
|
||||
!my_app_context.unknown_task_ran.load(Ordering::Relaxed),
|
||||
"Unknown task ran but it is not registered in the worker pool!"
|
||||
);
|
||||
}
|
||||
|
||||
async fn memory_store() -> MemoryTaskStore {
|
||||
#[tokio::test]
|
||||
async fn task_can_panic_and_not_affect_worker() {
|
||||
#[derive(Clone, serde::Serialize, serde::Deserialize)]
|
||||
struct BrokenTask;
|
||||
|
||||
#[async_trait]
|
||||
impl BackgroundTask for BrokenTask {
|
||||
const TASK_NAME: &'static str = "panic_me";
|
||||
type AppData = ();
|
||||
type Error = ();
|
||||
|
||||
async fn run(&self, _task: CurrentTask, _context: Self::AppData) -> Result<(), ()> {
|
||||
panic!("Oh no!");
|
||||
}
|
||||
}
|
||||
|
||||
let (notify_stop_worker_pool, should_stop) = tokio::sync::oneshot::channel();
|
||||
|
||||
let task_store = memory_store();
|
||||
|
||||
let worker_pool_finished = WorkerPool::new(task_store.clone(), || ())
|
||||
.register_task_type::<BrokenTask>()
|
||||
.configure_queue("default".into())
|
||||
.start(async move {
|
||||
should_stop.await.unwrap();
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let queue = Queue::new(task_store.clone());
|
||||
// Enqueue a task that will panic
|
||||
queue.enqueue(BrokenTask).await.unwrap();
|
||||
|
||||
notify_stop_worker_pool.send(()).unwrap();
|
||||
worker_pool_finished.await.unwrap();
|
||||
|
||||
let raw_task = task_store
|
||||
.tasks
|
||||
.lock()
|
||||
.await
|
||||
.first_entry()
|
||||
.unwrap()
|
||||
.remove();
|
||||
assert_eq!(
|
||||
serde_json::to_string(&raw_task.error_info.unwrap()).unwrap(),
|
||||
"{\"error\":\"Task panicked with: Oh no!\"}"
|
||||
);
|
||||
}
|
||||
|
||||
/// This test will make sure that the worker pool will only stop after all workers are done.
|
||||
/// We create a KeepAliveTask that will keep running until we notify it to stop.
|
||||
/// We stop the worker pool and make sure that the KeepAliveTask is still running.
|
||||
/// Then we notify the KeepAliveTask to stop and make sure that the worker pool stops.
|
||||
#[tokio::test]
|
||||
async fn tasks_only_stop_running_when_finished() {
|
||||
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
|
||||
enum PingPongGame {
|
||||
Ping,
|
||||
Pong,
|
||||
StopThisNow,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct PlayerContext {
|
||||
/// Used to communicate with the running task
|
||||
pong_tx: Arc<tokio::sync::mpsc::Sender<PingPongGame>>,
|
||||
ping_rx: Arc<Mutex<tokio::sync::mpsc::Receiver<PingPongGame>>>,
|
||||
}
|
||||
|
||||
/// Task that will respond to the ping pong game and keep alive as long as we need
|
||||
#[derive(Clone, serde::Serialize, serde::Deserialize)]
|
||||
struct KeepAliveTask;
|
||||
|
||||
#[async_trait]
|
||||
impl BackgroundTask for KeepAliveTask {
|
||||
const TASK_NAME: &'static str = "keep_alive_task";
|
||||
|
||||
type AppData = PlayerContext;
|
||||
|
||||
type Error = ();
|
||||
|
||||
async fn run(
|
||||
&self,
|
||||
_task: CurrentTask,
|
||||
context: Self::AppData,
|
||||
) -> Result<(), Self::Error> {
|
||||
loop {
|
||||
let msg = context.ping_rx.lock().await.recv().await.unwrap();
|
||||
match msg {
|
||||
PingPongGame::Ping => {
|
||||
println!("Pong!");
|
||||
context.pong_tx.send(PingPongGame::Pong).await.unwrap();
|
||||
}
|
||||
PingPongGame::Pong => {
|
||||
context.pong_tx.send(PingPongGame::Ping).await.unwrap();
|
||||
}
|
||||
PingPongGame::StopThisNow => {
|
||||
println!("Got stop signal, stopping the ping pong game now!");
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
let (notify_stop_worker_pool, should_stop) = tokio::sync::oneshot::channel();
|
||||
let (pong_tx, mut pong_rx) = tokio::sync::mpsc::channel(1);
|
||||
let (ping_tx, ping_rx) = tokio::sync::mpsc::channel(1);
|
||||
|
||||
let player_context = PlayerContext {
|
||||
pong_tx: Arc::new(pong_tx),
|
||||
ping_rx: Arc::new(Mutex::new(ping_rx)),
|
||||
};
|
||||
|
||||
let task_store = memory_store();
|
||||
|
||||
let worker_pool_finished = WorkerPool::new(task_store.clone(), {
|
||||
let player_context = player_context.clone();
|
||||
move || player_context.clone()
|
||||
})
|
||||
.register_task_type::<KeepAliveTask>()
|
||||
.configure_queue("default".into())
|
||||
.start(async move {
|
||||
should_stop.await.unwrap();
|
||||
println!("Worker pool got notified to stop");
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let queue = Queue::new(task_store);
|
||||
queue.enqueue(KeepAliveTask).await.unwrap();
|
||||
|
||||
// Make sure task is running
|
||||
println!("Ping!");
|
||||
ping_tx.send(PingPongGame::Ping).await.unwrap();
|
||||
assert_eq!(pong_rx.recv().await.unwrap(), PingPongGame::Pong);
|
||||
|
||||
// Notify to stop the worker pool
|
||||
notify_stop_worker_pool.send(()).unwrap();
|
||||
|
||||
// Make sure task is still running
|
||||
println!("Ping!");
|
||||
ping_tx.send(PingPongGame::Ping).await.unwrap();
|
||||
assert_eq!(pong_rx.recv().await.unwrap(), PingPongGame::Pong);
|
||||
|
||||
// is_none() means that the worker pool is still waiting for tasks to finish, which is what we want!
|
||||
assert!(
|
||||
worker_pool_finished.now_or_never().is_none(),
|
||||
"Worker pool finished before task stopped!"
|
||||
);
|
||||
|
||||
// Notify to stop the task, which will stop the worker pool
|
||||
ping_tx.send(PingPongGame::StopThisNow).await.unwrap();
|
||||
}
|
||||
|
||||
fn memory_store() -> MemoryTaskStore {
|
||||
MemoryTaskStore::default()
|
||||
}
|
||||
|
||||
|
@ -402,13 +712,15 @@ mod tests {
|
|||
async fn test_worker_pool_with_pg_store() {
|
||||
let my_app_context = ApplicationContext::new();
|
||||
|
||||
let (join_handle, _queue) =
|
||||
WorkerPool::new(pg_task_store().await, move |_| my_app_context.clone())
|
||||
.register_task_type::<GreetingTask>()
|
||||
.configure_queue(GreetingTask::QUEUE, 1, RetentionMode::RemoveDone)
|
||||
.start(futures::future::ready(()))
|
||||
.await
|
||||
.unwrap();
|
||||
let join_handle = WorkerPool::new(pg_task_store().await, move || my_app_context.clone())
|
||||
.register_task_type::<GreetingTask>()
|
||||
.configure_queue(
|
||||
QueueConfig::new(<GreetingTask as MyAppTask>::QUEUE)
|
||||
.retention_mode(RetentionMode::RemoveDone),
|
||||
)
|
||||
.start(futures::future::ready(()))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
join_handle.await.unwrap();
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue