Support generic backend to store tasks
This commit is contained in:
parent
fd92b25190
commit
894f928c01
8 changed files with 251 additions and 252 deletions
|
@ -31,3 +31,6 @@ diesel = { version = "2.0", features = ["postgres", "serde_json", "chrono", "uui
|
|||
diesel-derive-newtype = "2.0.0-rc.0"
|
||||
diesel-async = { version = "0.2", features = ["postgres", "bb8"] }
|
||||
tokio = { version = "1.25", features = ["rt", "time", "macros"] }
|
||||
|
||||
[dev-dependencies]
|
||||
itertools = "0.10"
|
||||
|
|
19
src/lib.rs
19
src/lib.rs
|
@ -1,22 +1,7 @@
|
|||
//#![warn(missing_docs)]
|
||||
#![forbid(unsafe_code)]
|
||||
#![doc = include_str!("../README.md")]
|
||||
|
||||
use chrono::{DateTime, Utc};
|
||||
|
||||
/// Represents a schedule for scheduled tasks.
|
||||
///
|
||||
/// It's used in the [`BackgroundTask::cron`]
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum Scheduled {
|
||||
/// A cron pattern for a periodic task
|
||||
///
|
||||
/// For example, `Scheduled::CronPattern("0/20 * * * * * *")`
|
||||
CronPattern(String),
|
||||
/// A datetime for a scheduled task that will be executed once
|
||||
///
|
||||
/// For example, `Scheduled::ScheduleOnce(chrono::Utc::now() + std::time::Duration::seconds(7i64))`
|
||||
ScheduleOnce(DateTime<Utc>),
|
||||
}
|
||||
|
||||
/// All possible options for retaining tasks in the db after their execution.
|
||||
///
|
||||
/// The default mode is [`RetentionMode::RemoveAll`]
|
||||
|
|
|
@ -44,9 +44,6 @@ impl Task {
|
|||
) -> Result<Task, AsyncQueueError> {
|
||||
use crate::schema::backie_tasks::dsl;
|
||||
|
||||
let now = Utc::now();
|
||||
let scheduled_at = now + Duration::seconds(backoff_seconds as i64);
|
||||
|
||||
let error = serde_json::json!({
|
||||
"error": error_message,
|
||||
});
|
||||
|
@ -55,7 +52,8 @@ impl Task {
|
|||
.set((
|
||||
backie_tasks::error_info.eq(Some(error)),
|
||||
backie_tasks::retries.eq(dsl::retries + 1),
|
||||
backie_tasks::scheduled_at.eq(scheduled_at),
|
||||
backie_tasks::scheduled_at
|
||||
.eq(Utc::now() + Duration::seconds(backoff_seconds as i64)),
|
||||
backie_tasks::running_at.eq::<Option<DateTime<Utc>>>(None),
|
||||
))
|
||||
.get_result::<Task>(connection)
|
||||
|
|
215
src/queue.rs
215
src/queue.rs
|
@ -1,17 +1,23 @@
|
|||
use crate::errors::BackieError;
|
||||
use crate::runnable::BackgroundTask;
|
||||
use crate::store::{PgTaskStore, TaskStore};
|
||||
use crate::task::{NewTask, TaskHash};
|
||||
use crate::store::TaskStore;
|
||||
use crate::task::NewTask;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct Queue {
|
||||
task_store: Arc<PgTaskStore>,
|
||||
pub struct Queue<S>
|
||||
where
|
||||
S: TaskStore,
|
||||
{
|
||||
task_store: Arc<S>,
|
||||
}
|
||||
|
||||
impl Queue {
|
||||
pub(crate) fn new(task_store: Arc<PgTaskStore>) -> Self {
|
||||
impl<S> Queue<S>
|
||||
where
|
||||
S: TaskStore,
|
||||
{
|
||||
pub(crate) fn new(task_store: Arc<S>) -> Self {
|
||||
Queue { task_store }
|
||||
}
|
||||
|
||||
|
@ -25,200 +31,3 @@ impl Queue {
|
|||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod async_queue_tests {
|
||||
use super::*;
|
||||
use crate::CurrentTask;
|
||||
use async_trait::async_trait;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
struct AsyncTask {
|
||||
pub number: u16,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl BackgroundTask for AsyncTask {
|
||||
const TASK_NAME: &'static str = "AsyncUniqTask";
|
||||
type AppData = ();
|
||||
|
||||
async fn run(&self, _task: CurrentTask, _: Self::AppData) -> Result<(), anyhow::Error> {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
struct AsyncUniqTask {
|
||||
pub number: u16,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl BackgroundTask for AsyncUniqTask {
|
||||
const TASK_NAME: &'static str = "AsyncUniqTask";
|
||||
type AppData = ();
|
||||
|
||||
async fn run(&self, _task: CurrentTask, _: Self::AppData) -> Result<(), anyhow::Error> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn uniq(&self) -> Option<TaskHash> {
|
||||
TaskHash::default_for_task(self).ok()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
struct AsyncTaskSchedule {
|
||||
pub number: u16,
|
||||
pub datetime: String,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl BackgroundTask for AsyncTaskSchedule {
|
||||
const TASK_NAME: &'static str = "AsyncUniqTask";
|
||||
type AppData = ();
|
||||
|
||||
async fn run(&self, _task: CurrentTask, _: Self::AppData) -> Result<(), anyhow::Error> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// fn cron(&self) -> Option<Scheduled> {
|
||||
// let datetime = self.datetime.parse::<DateTime<Utc>>().ok()?;
|
||||
// Some(Scheduled::ScheduleOnce(datetime))
|
||||
// }
|
||||
}
|
||||
|
||||
// #[tokio::test]
|
||||
// async fn insert_task_creates_new_task() {
|
||||
// let pool = pool().await;
|
||||
// let mut queue = PgTaskStore::new(pool);
|
||||
//
|
||||
// let task = queue.create_task(AsyncTask { number: 1 }).await.unwrap();
|
||||
//
|
||||
// let metadata = task.payload.as_object().unwrap();
|
||||
// let number = metadata["number"].as_u64();
|
||||
// let type_task = metadata["type"].as_str();
|
||||
//
|
||||
// assert_eq!(Some(1), number);
|
||||
// assert_eq!(Some("AsyncTask"), type_task);
|
||||
//
|
||||
// queue.remove_all_tasks().await.unwrap();
|
||||
// }
|
||||
//
|
||||
// #[tokio::test]
|
||||
// async fn update_task_state_test() {
|
||||
// let pool = pool().await;
|
||||
// let mut test = PgTaskStore::new(pool);
|
||||
//
|
||||
// let task = test.create_task(&AsyncTask { number: 1 }).await.unwrap();
|
||||
//
|
||||
// let metadata = task.payload.as_object().unwrap();
|
||||
// let number = metadata["number"].as_u64();
|
||||
// let type_task = metadata["type"].as_str();
|
||||
// let id = task.id;
|
||||
//
|
||||
// assert_eq!(Some(1), number);
|
||||
// assert_eq!(Some("AsyncTask"), type_task);
|
||||
//
|
||||
// let finished_task = test.set_task_state(task.id, TaskState::Done).await.unwrap();
|
||||
//
|
||||
// assert_eq!(id, finished_task.id);
|
||||
// assert_eq!(TaskState::Done, finished_task.state());
|
||||
//
|
||||
// test.remove_all_tasks().await.unwrap();
|
||||
// }
|
||||
//
|
||||
// #[tokio::test]
|
||||
// async fn failed_task_query_test() {
|
||||
// let pool = pool().await;
|
||||
// let mut test = PgTaskStore::new(pool);
|
||||
//
|
||||
// let task = test.create_task(&AsyncTask { number: 1 }).await.unwrap();
|
||||
//
|
||||
// let metadata = task.payload.as_object().unwrap();
|
||||
// let number = metadata["number"].as_u64();
|
||||
// let type_task = metadata["type"].as_str();
|
||||
// let id = task.id;
|
||||
//
|
||||
// assert_eq!(Some(1), number);
|
||||
// assert_eq!(Some("AsyncTask"), type_task);
|
||||
//
|
||||
// let failed_task = test.set_task_state(task.id, TaskState::Failed("Some error".to_string())).await.unwrap();
|
||||
//
|
||||
// assert_eq!(id, failed_task.id);
|
||||
// assert_eq!(Some("Some error"), failed_task.error_message.as_deref());
|
||||
// assert_eq!(TaskState::Failed, failed_task.state());
|
||||
//
|
||||
// test.remove_all_tasks().await.unwrap();
|
||||
// }
|
||||
//
|
||||
// #[tokio::test]
|
||||
// async fn remove_all_tasks_test() {
|
||||
// let pool = pool().await;
|
||||
// let mut test = PgTaskStore::new(pool);
|
||||
//
|
||||
// let task = test.create_task(&AsyncTask { number: 1 }).await.unwrap();
|
||||
//
|
||||
// let metadata = task.payload.as_object().unwrap();
|
||||
// let number = metadata["number"].as_u64();
|
||||
// let type_task = metadata["type"].as_str();
|
||||
//
|
||||
// assert_eq!(Some(1), number);
|
||||
// assert_eq!(Some("AsyncTask"), type_task);
|
||||
//
|
||||
// let task = test.create_task(&AsyncTask { number: 2 }).await.unwrap();
|
||||
//
|
||||
// let metadata = task.payload.as_object().unwrap();
|
||||
// let number = metadata["number"].as_u64();
|
||||
// let type_task = metadata["type"].as_str();
|
||||
//
|
||||
// assert_eq!(Some(2), number);
|
||||
// assert_eq!(Some("AsyncTask"), type_task);
|
||||
//
|
||||
// let result = test.remove_all_tasks().await.unwrap();
|
||||
// assert_eq!(2, result);
|
||||
// }
|
||||
//
|
||||
// #[tokio::test]
|
||||
// async fn pull_next_task_test() {
|
||||
// let pool = pool().await;
|
||||
// let mut queue = PgTaskStore::new(pool);
|
||||
//
|
||||
// let task = queue.create_task(&AsyncTask { number: 1 }).await.unwrap();
|
||||
//
|
||||
// let metadata = task.payload.as_object().unwrap();
|
||||
// let number = metadata["number"].as_u64();
|
||||
// let type_task = metadata["type"].as_str();
|
||||
//
|
||||
// assert_eq!(Some(1), number);
|
||||
// assert_eq!(Some("AsyncTask"), type_task);
|
||||
//
|
||||
// let task = queue.create_task(&AsyncTask { number: 2 }).await.unwrap();
|
||||
//
|
||||
// let metadata = task.payload.as_object().unwrap();
|
||||
// let number = metadata["number"].as_u64();
|
||||
// let type_task = metadata["type"].as_str();
|
||||
//
|
||||
// assert_eq!(Some(2), number);
|
||||
// assert_eq!(Some("AsyncTask"), type_task);
|
||||
//
|
||||
// let task = queue.pull_next_task(None).await.unwrap().unwrap();
|
||||
//
|
||||
// let metadata = task.payload.as_object().unwrap();
|
||||
// let number = metadata["number"].as_u64();
|
||||
// let type_task = metadata["type"].as_str();
|
||||
//
|
||||
// assert_eq!(Some(1), number);
|
||||
// assert_eq!(Some("AsyncTask"), type_task);
|
||||
//
|
||||
// let task = queue.pull_next_task(None).await.unwrap().unwrap();
|
||||
// let metadata = task.payload.as_object().unwrap();
|
||||
// let number = metadata["number"].as_u64();
|
||||
// let type_task = metadata["type"].as_str();
|
||||
//
|
||||
// assert_eq!(Some(2), number);
|
||||
// assert_eq!(Some("AsyncTask"), type_task);
|
||||
//
|
||||
// queue.remove_all_tasks().await.unwrap();
|
||||
// }
|
||||
}
|
||||
|
|
119
src/store.rs
119
src/store.rs
|
@ -38,6 +38,7 @@ impl TaskStore for PgTaskStore {
|
|||
})
|
||||
.await
|
||||
}
|
||||
|
||||
async fn create_task(&self, new_task: NewTask) -> Result<Task, AsyncQueueError> {
|
||||
let mut connection = self
|
||||
.pool
|
||||
|
@ -46,6 +47,7 @@ impl TaskStore for PgTaskStore {
|
|||
.map_err(|e| QueryBuilderError(e.into()))?;
|
||||
Task::insert(&mut connection, new_task).await
|
||||
}
|
||||
|
||||
async fn set_task_state(&self, id: TaskId, state: TaskState) -> Result<(), AsyncQueueError> {
|
||||
let mut connection = self
|
||||
.pool
|
||||
|
@ -53,14 +55,17 @@ impl TaskStore for PgTaskStore {
|
|||
.await
|
||||
.map_err(|e| QueryBuilderError(e.into()))?;
|
||||
match state {
|
||||
TaskState::Done => Task::set_done(&mut connection, id).await?,
|
||||
TaskState::Failed(error_msg) => {
|
||||
Task::fail_with_message(&mut connection, id, &error_msg).await?
|
||||
TaskState::Done => {
|
||||
Task::set_done(&mut connection, id).await?;
|
||||
}
|
||||
_ => return Ok(()),
|
||||
TaskState::Failed(error_msg) => {
|
||||
Task::fail_with_message(&mut connection, id, &error_msg).await?;
|
||||
}
|
||||
_ => (),
|
||||
};
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn remove_task(&self, id: TaskId) -> Result<u64, AsyncQueueError> {
|
||||
let mut connection = self
|
||||
.pool
|
||||
|
@ -70,6 +75,7 @@ impl TaskStore for PgTaskStore {
|
|||
let result = Task::remove(&mut connection, id).await?;
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
async fn schedule_task_retry(
|
||||
&self,
|
||||
id: TaskId,
|
||||
|
@ -86,8 +92,111 @@ impl TaskStore for PgTaskStore {
|
|||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub mod test_store {
|
||||
use super::*;
|
||||
use itertools::Itertools;
|
||||
use std::collections::BTreeMap;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::Mutex;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct MemoryTaskStore {
|
||||
tasks: Arc<Mutex<BTreeMap<TaskId, Task>>>,
|
||||
}
|
||||
|
||||
impl MemoryTaskStore {
|
||||
pub fn new() -> Self {
|
||||
MemoryTaskStore {
|
||||
tasks: Arc::new(Mutex::new(BTreeMap::new())),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl TaskStore for MemoryTaskStore {
|
||||
async fn pull_next_task(&self, queue_name: &str) -> Result<Option<Task>, AsyncQueueError> {
|
||||
let mut tasks = self.tasks.lock().await;
|
||||
let mut next_task = None;
|
||||
for (_, task) in tasks
|
||||
.iter_mut()
|
||||
.sorted_by(|a, b| a.1.created_at.cmp(&b.1.created_at))
|
||||
{
|
||||
if task.queue_name == queue_name && task.state() == TaskState::Ready {
|
||||
task.running_at = Some(chrono::Utc::now());
|
||||
next_task = Some(task.clone());
|
||||
break;
|
||||
}
|
||||
}
|
||||
Ok(next_task)
|
||||
}
|
||||
|
||||
async fn create_task(&self, new_task: NewTask) -> Result<Task, AsyncQueueError> {
|
||||
let mut tasks = self.tasks.lock().await;
|
||||
let task = Task::from(new_task);
|
||||
tasks.insert(task.id, task.clone());
|
||||
Ok(task)
|
||||
}
|
||||
|
||||
async fn set_task_state(
|
||||
&self,
|
||||
id: TaskId,
|
||||
state: TaskState,
|
||||
) -> Result<(), AsyncQueueError> {
|
||||
let mut tasks = self.tasks.lock().await;
|
||||
let task = tasks.get_mut(&id).unwrap();
|
||||
|
||||
use TaskState::*;
|
||||
match state {
|
||||
Done => task.done_at = Some(chrono::Utc::now()),
|
||||
Failed(error_msg) => {
|
||||
let error_payload = serde_json::json!({
|
||||
"error": error_msg,
|
||||
});
|
||||
task.error_info = Some(error_payload);
|
||||
task.done_at = Some(chrono::Utc::now());
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn remove_task(&self, id: TaskId) -> Result<u64, AsyncQueueError> {
|
||||
let mut tasks = self.tasks.lock().await;
|
||||
let res = tasks.remove(&id);
|
||||
if res.is_some() {
|
||||
Ok(1)
|
||||
} else {
|
||||
Ok(0)
|
||||
}
|
||||
}
|
||||
|
||||
async fn schedule_task_retry(
|
||||
&self,
|
||||
id: TaskId,
|
||||
backoff_seconds: u32,
|
||||
error: &str,
|
||||
) -> Result<Task, AsyncQueueError> {
|
||||
let mut tasks = self.tasks.lock().await;
|
||||
let task = tasks.get_mut(&id).unwrap();
|
||||
|
||||
let error_payload = serde_json::json!({
|
||||
"error": error,
|
||||
});
|
||||
task.error_info = Some(error_payload);
|
||||
task.running_at = None;
|
||||
task.retries += 1;
|
||||
task.scheduled_at =
|
||||
chrono::Utc::now() + chrono::Duration::seconds(backoff_seconds as i64);
|
||||
|
||||
Ok(task.clone())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
pub trait TaskStore {
|
||||
pub trait TaskStore: Clone + Send + Sync + 'static {
|
||||
async fn pull_next_task(&self, queue_name: &str) -> Result<Option<Task>, AsyncQueueError>;
|
||||
async fn create_task(&self, new_task: NewTask) -> Result<Task, AsyncQueueError>;
|
||||
async fn set_task_state(&self, id: TaskId, state: TaskState) -> Result<(), AsyncQueueError>;
|
||||
|
|
23
src/task.rs
23
src/task.rs
|
@ -28,7 +28,7 @@ pub enum TaskState {
|
|||
Done,
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, DieselNewType, Serialize)]
|
||||
#[derive(Clone, Copy, Debug, Ord, PartialOrd, Hash, PartialEq, Eq, DieselNewType, Serialize)]
|
||||
pub struct TaskId(Uuid);
|
||||
|
||||
impl Display for TaskId {
|
||||
|
@ -144,6 +144,27 @@ impl NewTask {
|
|||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
impl From<NewTask> for Task {
|
||||
fn from(new_task: NewTask) -> Self {
|
||||
Self {
|
||||
id: TaskId(Uuid::new_v4()),
|
||||
task_name: new_task.task_name,
|
||||
queue_name: new_task.queue_name,
|
||||
uniq_hash: new_task.uniq_hash,
|
||||
payload: new_task.payload,
|
||||
timeout_msecs: new_task.timeout_msecs,
|
||||
created_at: Utc::now(),
|
||||
scheduled_at: Utc::now(),
|
||||
running_at: None,
|
||||
done_at: None,
|
||||
error_info: None,
|
||||
retries: 0,
|
||||
max_retries: new_task.max_retries,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct CurrentTask {
|
||||
id: TaskId,
|
||||
retries: i32,
|
||||
|
|
|
@ -2,7 +2,7 @@ use crate::errors::{AsyncQueueError, BackieError};
|
|||
use crate::runnable::BackgroundTask;
|
||||
use crate::store::TaskStore;
|
||||
use crate::task::{CurrentTask, Task, TaskState};
|
||||
use crate::{PgTaskStore, RetentionMode};
|
||||
use crate::RetentionMode;
|
||||
use futures::future::FutureExt;
|
||||
use futures::select;
|
||||
use std::collections::BTreeMap;
|
||||
|
@ -48,11 +48,12 @@ where
|
|||
}
|
||||
|
||||
/// Worker that executes tasks.
|
||||
pub struct Worker<AppData>
|
||||
pub struct Worker<AppData, S>
|
||||
where
|
||||
AppData: Clone + Send + 'static,
|
||||
S: TaskStore,
|
||||
{
|
||||
store: Arc<PgTaskStore>,
|
||||
store: Arc<S>,
|
||||
|
||||
queue_name: String,
|
||||
|
||||
|
@ -66,12 +67,13 @@ where
|
|||
shutdown: Option<tokio::sync::watch::Receiver<()>>,
|
||||
}
|
||||
|
||||
impl<AppData> Worker<AppData>
|
||||
impl<AppData, S> Worker<AppData, S>
|
||||
where
|
||||
AppData: Clone + Send + 'static,
|
||||
S: TaskStore,
|
||||
{
|
||||
pub(crate) fn new(
|
||||
store: Arc<PgTaskStore>,
|
||||
store: Arc<S>,
|
||||
queue_name: String,
|
||||
retention_mode: RetentionMode,
|
||||
task_registry: BTreeMap<String, ExecuteTaskFn<AppData>>,
|
||||
|
|
|
@ -1,25 +1,26 @@
|
|||
use crate::errors::BackieError;
|
||||
use crate::queue::Queue;
|
||||
use crate::runnable::BackgroundTask;
|
||||
use crate::store::TaskStore;
|
||||
use crate::worker::{runnable, ExecuteTaskFn};
|
||||
use crate::worker::{StateFn, Worker};
|
||||
use crate::{BackgroundTask, PgTaskStore, RetentionMode};
|
||||
use crate::RetentionMode;
|
||||
use std::collections::BTreeMap;
|
||||
use std::future::Future;
|
||||
use std::sync::Arc;
|
||||
use tokio::task::JoinHandle;
|
||||
|
||||
pub type AppDataFn<AppData> = Arc<dyn Fn(Queue) -> AppData + Send + Sync>;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct WorkerPool<AppData>
|
||||
pub struct WorkerPool<AppData, S>
|
||||
where
|
||||
AppData: Clone + Send + 'static,
|
||||
S: TaskStore,
|
||||
{
|
||||
/// Storage of tasks.
|
||||
queue_store: Arc<PgTaskStore>, // TODO: make this generic/dynamic referenced
|
||||
task_store: Arc<S>,
|
||||
|
||||
/// Queue used to spawn tasks.
|
||||
queue: Queue,
|
||||
queue: Queue<S>,
|
||||
|
||||
/// Make possible to load the application data.
|
||||
///
|
||||
|
@ -38,14 +39,15 @@ where
|
|||
worker_queues: BTreeMap<String, (RetentionMode, u32)>,
|
||||
}
|
||||
|
||||
impl<AppData> WorkerPool<AppData>
|
||||
impl<AppData, S> WorkerPool<AppData, S>
|
||||
where
|
||||
AppData: Clone + Send + 'static,
|
||||
S: TaskStore,
|
||||
{
|
||||
/// Create a new worker pool.
|
||||
pub fn new<A>(task_store: PgTaskStore, application_data_fn: A) -> Self
|
||||
pub fn new<A>(task_store: S, application_data_fn: A) -> Self
|
||||
where
|
||||
A: Fn(Queue) -> AppData + Send + Sync + 'static,
|
||||
A: Fn(Queue<S>) -> AppData + Send + Sync + 'static,
|
||||
{
|
||||
let queue_store = Arc::new(task_store);
|
||||
let queue = Queue::new(queue_store.clone());
|
||||
|
@ -54,7 +56,7 @@ where
|
|||
move || application_data_fn(queue.clone())
|
||||
};
|
||||
Self {
|
||||
queue_store,
|
||||
task_store: queue_store,
|
||||
queue,
|
||||
application_data_fn: Arc::new(application_data_fn),
|
||||
task_registry: BTreeMap::new(),
|
||||
|
@ -91,7 +93,7 @@ where
|
|||
pub async fn start<F>(
|
||||
self,
|
||||
graceful_shutdown: F,
|
||||
) -> Result<(JoinHandle<()>, Queue), BackieError>
|
||||
) -> Result<(JoinHandle<()>, Queue<S>), BackieError>
|
||||
where
|
||||
F: Future<Output = ()> + Send + 'static,
|
||||
{
|
||||
|
@ -107,8 +109,8 @@ where
|
|||
// Spawn all individual workers per queue
|
||||
for (queue_name, (retention_mode, num_workers)) in self.worker_queues.iter() {
|
||||
for idx in 0..*num_workers {
|
||||
let mut worker: Worker<AppData> = Worker::new(
|
||||
self.queue_store.clone(),
|
||||
let mut worker: Worker<AppData, S> = Worker::new(
|
||||
self.task_store.clone(),
|
||||
queue_name.clone(),
|
||||
retention_mode.clone(),
|
||||
self.task_registry.clone(),
|
||||
|
@ -143,7 +145,10 @@ where
|
|||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::store::test_store::MemoryTaskStore;
|
||||
use crate::store::PgTaskStore;
|
||||
use crate::task::CurrentTask;
|
||||
use anyhow::Error;
|
||||
use async_trait::async_trait;
|
||||
use diesel_async::pooled_connection::{bb8::Pool, AsyncDieselConnectionManager};
|
||||
use diesel_async::AsyncPgConnection;
|
||||
|
@ -191,11 +196,32 @@ mod tests {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
|
||||
struct OtherTask;
|
||||
|
||||
#[async_trait]
|
||||
impl BackgroundTask for OtherTask {
|
||||
const TASK_NAME: &'static str = "other_task";
|
||||
|
||||
const QUEUE: &'static str = "other_queue";
|
||||
|
||||
type AppData = ApplicationContext;
|
||||
|
||||
async fn run(&self, task: CurrentTask, context: Self::AppData) -> Result<(), Error> {
|
||||
println!(
|
||||
"[{}] Other task with {}!",
|
||||
task.id(),
|
||||
context.get_app_name()
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn validate_all_registered_tasks_queues_are_configured() {
|
||||
let my_app_context = ApplicationContext::new();
|
||||
|
||||
let result = WorkerPool::new(task_store().await, move |_| my_app_context.clone())
|
||||
let result = WorkerPool::new(memory_store().await, move |_| my_app_context.clone())
|
||||
.register_task_type::<GreetingTask>()
|
||||
.start(futures::future::ready(()))
|
||||
.await;
|
||||
|
@ -210,11 +236,11 @@ mod tests {
|
|||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_worker_pool() {
|
||||
async fn test_worker_pool_with_task() {
|
||||
let my_app_context = ApplicationContext::new();
|
||||
|
||||
let (join_handle, queue) =
|
||||
WorkerPool::new(task_store().await, move |_| my_app_context.clone())
|
||||
WorkerPool::new(memory_store().await, move |_| my_app_context.clone())
|
||||
.register_task_type::<GreetingTask>()
|
||||
.configure_queue(GreetingTask::QUEUE, 1, RetentionMode::RemoveDone)
|
||||
.start(futures::future::ready(()))
|
||||
|
@ -231,7 +257,53 @@ mod tests {
|
|||
join_handle.await.unwrap();
|
||||
}
|
||||
|
||||
async fn task_store() -> PgTaskStore {
|
||||
#[tokio::test]
|
||||
async fn test_worker_pool_with_multiple_task_types() {
|
||||
let my_app_context = ApplicationContext::new();
|
||||
|
||||
let (join_handle, queue) =
|
||||
WorkerPool::new(memory_store().await, move |_| my_app_context.clone())
|
||||
.register_task_type::<GreetingTask>()
|
||||
.register_task_type::<OtherTask>()
|
||||
.configure_queue("default", 1, RetentionMode::default())
|
||||
.configure_queue("other_queue", 1, RetentionMode::default())
|
||||
.start(futures::future::ready(()))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
queue
|
||||
.enqueue(GreetingTask {
|
||||
person: "Rafael".to_string(),
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
queue.enqueue(OtherTask).await.unwrap();
|
||||
|
||||
join_handle.await.unwrap();
|
||||
}
|
||||
|
||||
async fn memory_store() -> MemoryTaskStore {
|
||||
MemoryTaskStore::new()
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[ignore]
|
||||
async fn test_worker_pool_with_pg_store() {
|
||||
let my_app_context = ApplicationContext::new();
|
||||
|
||||
let (join_handle, _queue) =
|
||||
WorkerPool::new(pg_task_store().await, move |_| my_app_context.clone())
|
||||
.register_task_type::<GreetingTask>()
|
||||
.configure_queue(GreetingTask::QUEUE, 1, RetentionMode::RemoveDone)
|
||||
.start(futures::future::ready(()))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
join_handle.await.unwrap();
|
||||
}
|
||||
|
||||
async fn pg_task_store() -> PgTaskStore {
|
||||
let manager = AsyncDieselConnectionManager::<AsyncPgConnection>::new(
|
||||
option_env!("DATABASE_URL").expect("DATABASE_URL must be set"),
|
||||
);
|
||||
|
|
Loading…
Reference in a new issue