Support generic backend to store tasks

This commit is contained in:
Rafael Caricio 2023-03-11 17:49:23 +01:00
parent fd92b25190
commit 894f928c01
Signed by: rafaelcaricio
GPG key ID: 3C86DBCE8E93C947
8 changed files with 251 additions and 252 deletions

View file

@ -31,3 +31,6 @@ diesel = { version = "2.0", features = ["postgres", "serde_json", "chrono", "uui
diesel-derive-newtype = "2.0.0-rc.0"
diesel-async = { version = "0.2", features = ["postgres", "bb8"] }
tokio = { version = "1.25", features = ["rt", "time", "macros"] }
[dev-dependencies]
itertools = "0.10"

View file

@ -1,22 +1,7 @@
//#![warn(missing_docs)]
#![forbid(unsafe_code)]
#![doc = include_str!("../README.md")]
use chrono::{DateTime, Utc};
/// Represents a schedule for scheduled tasks.
///
/// It's used in [`BackgroundTask::cron`].
#[derive(Debug, Clone)]
pub enum Scheduled {
/// A cron pattern for a periodic task
///
/// For example, `Scheduled::CronPattern("0/20 * * * * * *")`
CronPattern(String),
/// A datetime for a scheduled task that will be executed once
///
/// For example, `Scheduled::ScheduleOnce(chrono::Utc::now() + chrono::Duration::seconds(7))`
ScheduleOnce(DateTime<Utc>),
}
/// All possible options for retaining tasks in the db after their execution.
///
/// The default mode is [`RetentionMode::RemoveAll`]

View file

@ -44,9 +44,6 @@ impl Task {
) -> Result<Task, AsyncQueueError> {
use crate::schema::backie_tasks::dsl;
let now = Utc::now();
let scheduled_at = now + Duration::seconds(backoff_seconds as i64);
let error = serde_json::json!({
"error": error_message,
});
@ -55,7 +52,8 @@ impl Task {
.set((
backie_tasks::error_info.eq(Some(error)),
backie_tasks::retries.eq(dsl::retries + 1),
backie_tasks::scheduled_at.eq(scheduled_at),
backie_tasks::scheduled_at
.eq(Utc::now() + Duration::seconds(backoff_seconds as i64)),
backie_tasks::running_at.eq::<Option<DateTime<Utc>>>(None),
))
.get_result::<Task>(connection)

View file

@ -1,17 +1,23 @@
use crate::errors::BackieError;
use crate::runnable::BackgroundTask;
use crate::store::{PgTaskStore, TaskStore};
use crate::task::{NewTask, TaskHash};
use crate::store::TaskStore;
use crate::task::NewTask;
use std::sync::Arc;
use std::time::Duration;
#[derive(Clone)]
pub struct Queue {
task_store: Arc<PgTaskStore>,
pub struct Queue<S>
where
S: TaskStore,
{
task_store: Arc<S>,
}
impl Queue {
pub(crate) fn new(task_store: Arc<PgTaskStore>) -> Self {
impl<S> Queue<S>
where
S: TaskStore,
{
pub(crate) fn new(task_store: Arc<S>) -> Self {
Queue { task_store }
}
@ -25,200 +31,3 @@ impl Queue {
Ok(())
}
}
#[cfg(test)]
mod async_queue_tests {
    use super::*;
    use crate::CurrentTask;
    use async_trait::async_trait;
    use serde::{Deserialize, Serialize};

    /// Minimal task type used to exercise queue plumbing; the run body is a no-op.
    #[derive(Serialize, Deserialize)]
    struct AsyncTask {
        pub number: u16,
    }

    #[async_trait]
    impl BackgroundTask for AsyncTask {
        // Was "AsyncUniqTask" (copy-paste error), which collided with the name
        // declared by `AsyncUniqTask` below; task names must be distinct since
        // the worker registry is keyed by them.
        const TASK_NAME: &'static str = "AsyncTask";
        type AppData = ();

        async fn run(&self, _task: CurrentTask, _: Self::AppData) -> Result<(), anyhow::Error> {
            Ok(())
        }
    }

    /// Task type whose instances are deduplicated via a [`TaskHash`].
    #[derive(Serialize, Deserialize)]
    struct AsyncUniqTask {
        pub number: u16,
    }

    #[async_trait]
    impl BackgroundTask for AsyncUniqTask {
        const TASK_NAME: &'static str = "AsyncUniqTask";
        type AppData = ();

        async fn run(&self, _task: CurrentTask, _: Self::AppData) -> Result<(), anyhow::Error> {
            Ok(())
        }

        // Uniqueness hash derived from the serialized payload, so two tasks
        // with the same `number` collapse into one pending row.
        fn uniq(&self) -> Option<TaskHash> {
            TaskHash::default_for_task(self).ok()
        }
    }

    /// Task type intended for one-shot scheduling at `datetime`.
    #[derive(Serialize, Deserialize)]
    struct AsyncTaskSchedule {
        pub number: u16,
        pub datetime: String,
    }

    #[async_trait]
    impl BackgroundTask for AsyncTaskSchedule {
        // Was "AsyncUniqTask" (copy-paste error); see note on AsyncTask above.
        const TASK_NAME: &'static str = "AsyncTaskSchedule";
        type AppData = ();

        async fn run(&self, _task: CurrentTask, _: Self::AppData) -> Result<(), anyhow::Error> {
            Ok(())
        }

        // fn cron(&self) -> Option<Scheduled> {
        //     let datetime = self.datetime.parse::<DateTime<Utc>>().ok()?;
        //     Some(Scheduled::ScheduleOnce(datetime))
        // }
    }

    // NOTE(review): a large set of commented-out PgTaskStore integration tests
    // (insert / set-state / failed-query / remove-all / pull-next) previously
    // lived here as dead code and has been removed. If that coverage is still
    // wanted, reintroduce them as real `#[ignore]`d integration tests that run
    // against a live database.
}

View file

@ -38,6 +38,7 @@ impl TaskStore for PgTaskStore {
})
.await
}
async fn create_task(&self, new_task: NewTask) -> Result<Task, AsyncQueueError> {
let mut connection = self
.pool
@ -46,6 +47,7 @@ impl TaskStore for PgTaskStore {
.map_err(|e| QueryBuilderError(e.into()))?;
Task::insert(&mut connection, new_task).await
}
async fn set_task_state(&self, id: TaskId, state: TaskState) -> Result<(), AsyncQueueError> {
let mut connection = self
.pool
@ -53,14 +55,17 @@ impl TaskStore for PgTaskStore {
.await
.map_err(|e| QueryBuilderError(e.into()))?;
match state {
TaskState::Done => Task::set_done(&mut connection, id).await?,
TaskState::Failed(error_msg) => {
Task::fail_with_message(&mut connection, id, &error_msg).await?
TaskState::Done => {
Task::set_done(&mut connection, id).await?;
}
_ => return Ok(()),
TaskState::Failed(error_msg) => {
Task::fail_with_message(&mut connection, id, &error_msg).await?;
}
_ => (),
};
Ok(())
}
async fn remove_task(&self, id: TaskId) -> Result<u64, AsyncQueueError> {
let mut connection = self
.pool
@ -70,6 +75,7 @@ impl TaskStore for PgTaskStore {
let result = Task::remove(&mut connection, id).await?;
Ok(result)
}
async fn schedule_task_retry(
&self,
id: TaskId,
@ -86,8 +92,111 @@ impl TaskStore for PgTaskStore {
}
}
#[cfg(test)]
pub mod test_store {
    use super::*;
    use std::collections::BTreeMap;
    use std::sync::Arc;
    use tokio::sync::Mutex;

    /// In-memory [`TaskStore`] backed by a mutex-guarded `BTreeMap`, for unit
    /// tests that must not depend on a running Postgres instance.
    #[derive(Clone)]
    pub struct MemoryTaskStore {
        tasks: Arc<Mutex<BTreeMap<TaskId, Task>>>,
    }

    impl MemoryTaskStore {
        /// Creates an empty store.
        pub fn new() -> Self {
            Self {
                tasks: Arc::new(Mutex::new(BTreeMap::new())),
            }
        }
    }

    #[async_trait::async_trait]
    impl TaskStore for MemoryTaskStore {
        async fn pull_next_task(&self, queue_name: &str) -> Result<Option<Task>, AsyncQueueError> {
            let mut tasks = self.tasks.lock().await;
            // Oldest ready task in the requested queue wins (FIFO by creation time).
            let candidate = tasks
                .values_mut()
                .filter(|task| task.queue_name == queue_name && task.state() == TaskState::Ready)
                .min_by_key(|task| task.created_at);
            Ok(candidate.map(|task| {
                // Mark it as picked up so a subsequent pull won't return it again.
                task.running_at = Some(chrono::Utc::now());
                task.clone()
            }))
        }

        async fn create_task(&self, new_task: NewTask) -> Result<Task, AsyncQueueError> {
            let task = Task::from(new_task);
            self.tasks.lock().await.insert(task.id, task.clone());
            Ok(task)
        }

        async fn set_task_state(
            &self,
            id: TaskId,
            state: TaskState,
        ) -> Result<(), AsyncQueueError> {
            let mut tasks = self.tasks.lock().await;
            // Test-only store: a missing id is a bug in the test itself, so
            // unwrap here is deliberate.
            let task = tasks.get_mut(&id).unwrap();
            match state {
                TaskState::Done => task.done_at = Some(chrono::Utc::now()),
                TaskState::Failed(error_msg) => {
                    // Record the failure payload the same way the SQL backend does.
                    task.error_info = Some(serde_json::json!({ "error": error_msg }));
                    task.done_at = Some(chrono::Utc::now());
                }
                // Other states carry no terminal timestamp to record.
                _ => {}
            }
            Ok(())
        }

        async fn remove_task(&self, id: TaskId) -> Result<u64, AsyncQueueError> {
            // Report the number of removed rows, mimicking the database backend.
            let removed = self.tasks.lock().await.remove(&id);
            Ok(if removed.is_some() { 1 } else { 0 })
        }

        async fn schedule_task_retry(
            &self,
            id: TaskId,
            backoff_seconds: u32,
            error: &str,
        ) -> Result<Task, AsyncQueueError> {
            let mut tasks = self.tasks.lock().await;
            let task = tasks.get_mut(&id).unwrap();
            // Record the failure, release the task, and push it into the future
            // by the requested backoff.
            task.error_info = Some(serde_json::json!({ "error": error }));
            task.running_at = None;
            task.retries += 1;
            task.scheduled_at =
                chrono::Utc::now() + chrono::Duration::seconds(backoff_seconds as i64);
            Ok(task.clone())
        }
    }
}
#[async_trait::async_trait]
pub trait TaskStore {
pub trait TaskStore: Clone + Send + Sync + 'static {
async fn pull_next_task(&self, queue_name: &str) -> Result<Option<Task>, AsyncQueueError>;
async fn create_task(&self, new_task: NewTask) -> Result<Task, AsyncQueueError>;
async fn set_task_state(&self, id: TaskId, state: TaskState) -> Result<(), AsyncQueueError>;

View file

@ -28,7 +28,7 @@ pub enum TaskState {
Done,
}
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, DieselNewType, Serialize)]
#[derive(Clone, Copy, Debug, Ord, PartialOrd, Hash, PartialEq, Eq, DieselNewType, Serialize)]
pub struct TaskId(Uuid);
impl Display for TaskId {
@ -144,6 +144,27 @@ impl NewTask {
}
}
#[cfg(test)]
impl From<NewTask> for Task {
    /// Builds a `Task` the way an insert would, for use by in-memory test stores.
    fn from(new_task: NewTask) -> Self {
        // Capture a single timestamp so `created_at` and `scheduled_at` are
        // exactly equal; calling `Utc::now()` twice made them differ by a few
        // nanoseconds, unlike a real row inserted with one `now()`.
        let now = Utc::now();
        Self {
            id: TaskId(Uuid::new_v4()),
            task_name: new_task.task_name,
            queue_name: new_task.queue_name,
            uniq_hash: new_task.uniq_hash,
            payload: new_task.payload,
            timeout_msecs: new_task.timeout_msecs,
            created_at: now,
            scheduled_at: now,
            running_at: None,
            done_at: None,
            error_info: None,
            retries: 0,
            max_retries: new_task.max_retries,
        }
    }
}
pub struct CurrentTask {
id: TaskId,
retries: i32,

View file

@ -2,7 +2,7 @@ use crate::errors::{AsyncQueueError, BackieError};
use crate::runnable::BackgroundTask;
use crate::store::TaskStore;
use crate::task::{CurrentTask, Task, TaskState};
use crate::{PgTaskStore, RetentionMode};
use crate::RetentionMode;
use futures::future::FutureExt;
use futures::select;
use std::collections::BTreeMap;
@ -48,11 +48,12 @@ where
}
/// Worker that executes tasks.
pub struct Worker<AppData>
pub struct Worker<AppData, S>
where
AppData: Clone + Send + 'static,
S: TaskStore,
{
store: Arc<PgTaskStore>,
store: Arc<S>,
queue_name: String,
@ -66,12 +67,13 @@ where
shutdown: Option<tokio::sync::watch::Receiver<()>>,
}
impl<AppData> Worker<AppData>
impl<AppData, S> Worker<AppData, S>
where
AppData: Clone + Send + 'static,
S: TaskStore,
{
pub(crate) fn new(
store: Arc<PgTaskStore>,
store: Arc<S>,
queue_name: String,
retention_mode: RetentionMode,
task_registry: BTreeMap<String, ExecuteTaskFn<AppData>>,

View file

@ -1,25 +1,26 @@
use crate::errors::BackieError;
use crate::queue::Queue;
use crate::runnable::BackgroundTask;
use crate::store::TaskStore;
use crate::worker::{runnable, ExecuteTaskFn};
use crate::worker::{StateFn, Worker};
use crate::{BackgroundTask, PgTaskStore, RetentionMode};
use crate::RetentionMode;
use std::collections::BTreeMap;
use std::future::Future;
use std::sync::Arc;
use tokio::task::JoinHandle;
pub type AppDataFn<AppData> = Arc<dyn Fn(Queue) -> AppData + Send + Sync>;
#[derive(Clone)]
pub struct WorkerPool<AppData>
pub struct WorkerPool<AppData, S>
where
AppData: Clone + Send + 'static,
S: TaskStore,
{
/// Storage of tasks.
queue_store: Arc<PgTaskStore>, // TODO: make this generic/dynamic referenced
task_store: Arc<S>,
/// Queue used to spawn tasks.
queue: Queue,
queue: Queue<S>,
/// Make possible to load the application data.
///
@ -38,14 +39,15 @@ where
worker_queues: BTreeMap<String, (RetentionMode, u32)>,
}
impl<AppData> WorkerPool<AppData>
impl<AppData, S> WorkerPool<AppData, S>
where
AppData: Clone + Send + 'static,
S: TaskStore,
{
/// Create a new worker pool.
pub fn new<A>(task_store: PgTaskStore, application_data_fn: A) -> Self
pub fn new<A>(task_store: S, application_data_fn: A) -> Self
where
A: Fn(Queue) -> AppData + Send + Sync + 'static,
A: Fn(Queue<S>) -> AppData + Send + Sync + 'static,
{
let queue_store = Arc::new(task_store);
let queue = Queue::new(queue_store.clone());
@ -54,7 +56,7 @@ where
move || application_data_fn(queue.clone())
};
Self {
queue_store,
task_store: queue_store,
queue,
application_data_fn: Arc::new(application_data_fn),
task_registry: BTreeMap::new(),
@ -91,7 +93,7 @@ where
pub async fn start<F>(
self,
graceful_shutdown: F,
) -> Result<(JoinHandle<()>, Queue), BackieError>
) -> Result<(JoinHandle<()>, Queue<S>), BackieError>
where
F: Future<Output = ()> + Send + 'static,
{
@ -107,8 +109,8 @@ where
// Spawn all individual workers per queue
for (queue_name, (retention_mode, num_workers)) in self.worker_queues.iter() {
for idx in 0..*num_workers {
let mut worker: Worker<AppData> = Worker::new(
self.queue_store.clone(),
let mut worker: Worker<AppData, S> = Worker::new(
self.task_store.clone(),
queue_name.clone(),
retention_mode.clone(),
self.task_registry.clone(),
@ -143,7 +145,10 @@ where
#[cfg(test)]
mod tests {
use super::*;
use crate::store::test_store::MemoryTaskStore;
use crate::store::PgTaskStore;
use crate::task::CurrentTask;
use anyhow::Error;
use async_trait::async_trait;
use diesel_async::pooled_connection::{bb8::Pool, AsyncDieselConnectionManager};
use diesel_async::AsyncPgConnection;
@ -191,11 +196,32 @@ mod tests {
}
}
/// Unit-less task used to exercise routing to a non-default queue.
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
struct OtherTask;

#[async_trait]
impl BackgroundTask for OtherTask {
    const TASK_NAME: &'static str = "other_task";
    const QUEUE: &'static str = "other_queue";
    type AppData = ApplicationContext;

    async fn run(&self, task: CurrentTask, context: Self::AppData) -> Result<(), Error> {
        // Pull the pieces out first, then emit the same message as before.
        let task_id = task.id();
        let app_name = context.get_app_name();
        println!("[{}] Other task with {}!", task_id, app_name);
        Ok(())
    }
}
#[tokio::test]
async fn validate_all_registered_tasks_queues_are_configured() {
let my_app_context = ApplicationContext::new();
let result = WorkerPool::new(task_store().await, move |_| my_app_context.clone())
let result = WorkerPool::new(memory_store().await, move |_| my_app_context.clone())
.register_task_type::<GreetingTask>()
.start(futures::future::ready(()))
.await;
@ -210,11 +236,11 @@ mod tests {
}
#[tokio::test]
async fn test_worker_pool() {
async fn test_worker_pool_with_task() {
let my_app_context = ApplicationContext::new();
let (join_handle, queue) =
WorkerPool::new(task_store().await, move |_| my_app_context.clone())
WorkerPool::new(memory_store().await, move |_| my_app_context.clone())
.register_task_type::<GreetingTask>()
.configure_queue(GreetingTask::QUEUE, 1, RetentionMode::RemoveDone)
.start(futures::future::ready(()))
@ -231,7 +257,53 @@ mod tests {
join_handle.await.unwrap();
}
async fn task_store() -> PgTaskStore {
// End-to-end check that one pool can serve two task types on two queues:
// OtherTask declares QUEUE = "other_queue"; GreetingTask presumably runs on
// "default" — TODO confirm against GreetingTask's definition.
#[tokio::test]
async fn test_worker_pool_with_multiple_task_types() {
let my_app_context = ApplicationContext::new();
// Register both task types and give each queue a single worker.
let (join_handle, queue) =
WorkerPool::new(memory_store().await, move |_| my_app_context.clone())
.register_task_type::<GreetingTask>()
.register_task_type::<OtherTask>()
.configure_queue("default", 1, RetentionMode::default())
.configure_queue("other_queue", 1, RetentionMode::default())
.start(futures::future::ready(()))
.await
.unwrap();
// Enqueue one task of each type; each should be routed to its own queue.
queue
.enqueue(GreetingTask {
person: "Rafael".to_string(),
})
.await
.unwrap();
queue.enqueue(OtherTask).await.unwrap();
// The shutdown future passed to start() is already completed; awaiting the
// handle ensures the pool winds down cleanly.
join_handle.await.unwrap();
}
// Factory for the in-memory store used by most tests. Declared async so it is
// call-compatible (`memory_store().await`) with the Postgres factory
// `pg_task_store()` below, even though construction itself does not await.
async fn memory_store() -> MemoryTaskStore {
MemoryTaskStore::new()
}
// Smoke test against the real Postgres-backed store. Marked #[ignore] because
// pg_task_store() requires DATABASE_URL to point at a live database.
#[tokio::test]
#[ignore]
async fn test_worker_pool_with_pg_store() {
let my_app_context = ApplicationContext::new();
// Single worker on GreetingTask's queue; done tasks are removed from the db.
let (join_handle, _queue) =
WorkerPool::new(pg_task_store().await, move |_| my_app_context.clone())
.register_task_type::<GreetingTask>()
.configure_queue(GreetingTask::QUEUE, 1, RetentionMode::RemoveDone)
.start(futures::future::ready(()))
.await
.unwrap();
join_handle.await.unwrap();
}
async fn pg_task_store() -> PgTaskStore {
let manager = AsyncDieselConnectionManager::<AsyncPgConnection>::new(
option_env!("DATABASE_URL").expect("DATABASE_URL must be set"),
);