Make workers go brrrr...

This commit is contained in:
Rafael Caricio 2023-03-04 20:46:09 +01:00
parent 0be173ef02
commit 18303be796
Signed by: rafaelcaricio
GPG key ID: 3C86DBCE8E93C947
10 changed files with 337 additions and 355 deletions

View file

@ -23,6 +23,12 @@ Background task processing library for Rust. It uses Postgres DB as a task queue
- Retries.
Tasks can be retried with a custom backoff mode
## Differences from original fang
- Supports only async processing
- Supports graceful shutdown
- The connection pool for the queue is provided by the user
## Installation
1. Add this to your Cargo.toml

View file

@ -1,4 +1,4 @@
use fang::queue::AsyncQueue;
use fang::queue::PgAsyncQueue;
use fang::queue::AsyncQueueable;
use fang::worker_pool::AsyncWorkerPool;
use fang::runnable::AsyncRunnable;
@ -25,13 +25,13 @@ async fn main() {
.await
.unwrap();
let mut queue = AsyncQueue::builder()
let mut queue = PgAsyncQueue::builder()
.pool(pool)
.build();
log::info!("Queue connected...");
let mut workers_pool: AsyncWorkerPool<AsyncQueue> = AsyncWorkerPool::builder()
let mut workers_pool: AsyncWorkerPool<PgAsyncQueue> = AsyncWorkerPool::builder()
.number_of_workers(10_u32)
.queue(queue.clone())
.build();

View file

@ -1,3 +1,4 @@
use serde_json::Error as SerdeError;
use thiserror::Error;
/// An error that can happen during executing of tasks
@ -7,6 +8,21 @@ pub struct FangError {
pub description: String,
}
impl From<AsyncQueueError> for FangError {
fn from(error: AsyncQueueError) -> Self {
let message = format!("{error:?}");
FangError {
description: message,
}
}
}
impl From<SerdeError> for FangError {
fn from(error: SerdeError) -> Self {
Self::from(AsyncQueueError::SerdeError(error))
}
}
/// List of error types that can occur while working with cron schedules.
#[derive(Debug, Error)]
pub enum CronError {

View file

@ -1,7 +1,7 @@
#![doc = include_str!("../README.md")]
use std::time::Duration;
use chrono::{DateTime, Utc};
use std::time::Duration;
use typed_builder::TypedBuilder;
/// Represents a schedule for scheduled tasks.
@ -80,12 +80,12 @@ impl Default for SleepParams {
}
}
pub mod errors;
pub mod fang_task_state;
mod queries;
pub mod queue;
pub mod runnable;
pub mod schema;
pub mod task;
pub mod queue;
mod queries;
pub mod errors;
pub mod runnable;
pub mod worker;
pub mod worker_pool;

View file

@ -1,28 +1,17 @@
use crate::runnable::AsyncRunnable;
use crate::errors::AsyncQueueError;
use crate::fang_task_state::FangTaskState;
use crate::runnable::AsyncRunnable;
use crate::schema::fang_tasks;
use crate::errors::CronError;
use crate::Scheduled::*;
use crate::task::{DEFAULT_TASK_TYPE, Task};
use async_trait::async_trait;
use crate::task::NewTask;
use crate::task::{Task, DEFAULT_TASK_TYPE};
use chrono::DateTime;
use chrono::Duration;
use chrono::Utc;
use cron::Schedule;
use diesel::prelude::*;
use diesel::result::Error::QueryBuilderError;
use diesel::ExpressionMethods;
use diesel_async::scoped_futures::ScopedFutureExt;
use diesel_async::AsyncConnection;
use diesel_async::{pg::AsyncPgConnection, pooled_connection::bb8::Pool, pooled_connection::bb8::PooledConnection, RunQueryDsl};
use diesel_async::pooled_connection::PoolableConnection;
use diesel_async::{pg::AsyncPgConnection, RunQueryDsl};
use sha2::{Digest, Sha256};
use std::str::FromStr;
use typed_builder::TypedBuilder;
use uuid::Uuid;
use crate::task::NewTask;
use crate::errors::AsyncQueueError;
impl Task {
pub async fn remove_all_scheduled_tasks(
@ -120,7 +109,10 @@ impl Task {
.limit(1)
.filter(fang_tasks::scheduled_at.le(Utc::now()))
.filter(fang_tasks::state.eq_any(vec![FangTaskState::New, FangTaskState::Retried]))
.filter(fang_tasks::task_type.eq(task_type.unwrap_or_else(|| DEFAULT_TASK_TYPE.to_string())))
.filter(
fang_tasks::task_type
.eq(task_type.unwrap_or_else(|| DEFAULT_TASK_TYPE.to_string())),
)
.for_update()
.skip_locked()
.get_result::<Task>(connection)

View file

@ -1,33 +1,24 @@
use crate::runnable::AsyncRunnable;
use crate::fang_task_state::FangTaskState;
use crate::schema::fang_tasks;
use crate::errors::AsyncQueueError;
use crate::errors::CronError;
use crate::fang_task_state::FangTaskState;
use crate::runnable::AsyncRunnable;
use crate::schema::fang_tasks;
use crate::task::Task;
use crate::Scheduled::*;
use crate::task::{DEFAULT_TASK_TYPE, Task};
use async_trait::async_trait;
use chrono::DateTime;
use chrono::Duration;
use chrono::Utc;
use crate::task::NewTask;
use cron::Schedule;
use diesel::result::Error::QueryBuilderError;
use diesel::ExpressionMethods;
use diesel_async::scoped_futures::ScopedFutureExt;
use diesel_async::AsyncConnection;
use diesel_async::{pg::AsyncPgConnection, pooled_connection::bb8::Pool, pooled_connection::AsyncDieselConnectionManager, RunQueryDsl};
use sha2::{Sha256};
use diesel_async::{pg::AsyncPgConnection, pooled_connection::bb8::Pool, RunQueryDsl};
use std::str::FromStr;
use diesel_async::pooled_connection::PoolableConnection;
use thiserror::Error;
use crate::errors::AsyncQueueError;
use typed_builder::TypedBuilder;
use uuid::Uuid;
/// This trait defines operations for an asynchronous queue.
/// The trait can be implemented for different storage backends.
/// For now, the trait is only implemented for PostgreSQL. More backends are planned to be implemented in the future.
#[async_trait]
pub trait AsyncQueueable: Send {
/// This method should retrieve one task of the `task_type` type. If `task_type` is `None` it will try to
@ -100,12 +91,12 @@ pub trait AsyncQueueable: Send {
/// ```
///
#[derive(TypedBuilder, Debug, Clone)]
pub struct AsyncQueue {
pub struct PgAsyncQueue {
pool: Pool<AsyncPgConnection>,
}
#[async_trait]
impl AsyncQueueable for AsyncQueue {
impl AsyncQueueable for PgAsyncQueue {
async fn find_task_by_id(&mut self, id: Uuid) -> Result<Task, AsyncQueueError> {
let mut connection = self
.pool
@ -131,12 +122,7 @@ impl AsyncQueueable for AsyncQueue {
return Ok(None);
};
match Task::update_task_state(
conn,
found_task,
FangTaskState::InProgress,
)
.await
match Task::update_task_state(conn, found_task, FangTaskState::InProgress).await
{
Ok(updated_task) => Ok(Some(updated_task)),
Err(err) => Err(err),
@ -281,8 +267,7 @@ impl AsyncQueueable for AsyncQueue {
.get()
.await
.map_err(|e| QueryBuilderError(e.into()))?;
let task =
Task::schedule_retry(&mut connection, task, backoff_seconds, error).await?;
let task = Task::schedule_retry(&mut connection, task, backoff_seconds, error).await?;
Ok(task)
}
}
@ -290,12 +275,12 @@ impl AsyncQueueable for AsyncQueue {
#[cfg(test)]
mod async_queue_tests {
use super::*;
use crate::schema::fang_tasks::task_type;
use crate::errors::FangError;
use crate::Scheduled;
use async_trait::async_trait;
use chrono::prelude::*;
use chrono::DateTime;
use chrono::Duration;
use chrono::Utc;
use diesel_async::pooled_connection::{bb8::Pool, AsyncDieselConnectionManager};
use diesel_async::AsyncPgConnection;
@ -353,7 +338,7 @@ mod async_queue_tests {
#[tokio::test]
async fn insert_task_creates_new_task() {
let pool = pool().await;
let mut test = AsyncQueue::builder().pool(pool).build();
let mut test = PgAsyncQueue::builder().pool(pool).build();
let task = insert_task(&mut test, &AsyncTask { number: 1 }).await;
@ -370,7 +355,7 @@ mod async_queue_tests {
#[tokio::test]
async fn update_task_state_test() {
let pool = pool().await;
let mut test = AsyncQueue::builder().pool(pool).build();
let mut test = PgAsyncQueue::builder().pool(pool).build();
let task = insert_task(&mut test, &AsyncTask { number: 1 }).await;
@ -396,7 +381,7 @@ mod async_queue_tests {
#[tokio::test]
async fn failed_task_query_test() {
let pool = pool().await;
let mut test = AsyncQueue::builder().pool(pool).build();
let mut test = PgAsyncQueue::builder().pool(pool).build();
let task = insert_task(&mut test, &AsyncTask { number: 1 }).await;
@ -420,7 +405,7 @@ mod async_queue_tests {
#[tokio::test]
async fn remove_all_tasks_test() {
let pool = pool().await;
let mut test = AsyncQueue::builder().pool(pool.into()).build();
let mut test = PgAsyncQueue::builder().pool(pool.into()).build();
let task = insert_task(&mut test, &AsyncTask { number: 1 }).await;
@ -447,7 +432,7 @@ mod async_queue_tests {
#[tokio::test]
async fn schedule_task_test() {
let pool = pool().await;
let mut test = AsyncQueue::builder().pool(pool).build();
let mut test = PgAsyncQueue::builder().pool(pool).build();
let datetime = (Utc::now() + Duration::seconds(7)).round_subsecs(0);
@ -472,7 +457,7 @@ mod async_queue_tests {
#[tokio::test]
async fn remove_all_scheduled_tasks_test() {
let pool = pool().await;
let mut test = AsyncQueue::builder().pool(pool).build();
let mut test = PgAsyncQueue::builder().pool(pool).build();
let datetime = (Utc::now() + Duration::seconds(7)).round_subsecs(0);
@ -499,7 +484,7 @@ mod async_queue_tests {
#[tokio::test]
async fn fetch_and_touch_test() {
let pool = pool().await;
let mut test = AsyncQueue::builder().pool(pool).build();
let mut test = PgAsyncQueue::builder().pool(pool).build();
let task = insert_task(&mut test, &AsyncTask { number: 1 }).await;
@ -519,11 +504,7 @@ mod async_queue_tests {
assert_eq!(Some(2), number);
assert_eq!(Some("AsyncTask"), type_task);
let task = test
.fetch_and_touch_task(None)
.await
.unwrap()
.unwrap();
let task = test.fetch_and_touch_task(None).await.unwrap().unwrap();
let metadata = task.metadata.as_object().unwrap();
let number = metadata["number"].as_u64();
@ -532,11 +513,7 @@ mod async_queue_tests {
assert_eq!(Some(1), number);
assert_eq!(Some("AsyncTask"), type_task);
let task = test
.fetch_and_touch_task(None)
.await
.unwrap()
.unwrap();
let task = test.fetch_and_touch_task(None).await.unwrap().unwrap();
let metadata = task.metadata.as_object().unwrap();
let number = metadata["number"].as_u64();
let type_task = metadata["type"].as_str();
@ -550,7 +527,7 @@ mod async_queue_tests {
#[tokio::test]
async fn remove_tasks_type_test() {
let pool = pool().await;
let mut test = AsyncQueue::builder().pool(pool).build();
let mut test = PgAsyncQueue::builder().pool(pool).build();
let task = insert_task(&mut test, &AsyncTask { number: 1 }).await;
@ -582,7 +559,7 @@ mod async_queue_tests {
#[tokio::test]
async fn remove_tasks_by_metadata() {
let pool = pool().await;
let mut test = AsyncQueue::builder().pool(pool).build();
let mut test = PgAsyncQueue::builder().pool(pool).build();
let task = insert_task(&mut test, &AsyncUniqTask { number: 1 }).await;
@ -617,7 +594,7 @@ mod async_queue_tests {
test.remove_all_tasks().await.unwrap();
}
async fn insert_task(test: &mut AsyncQueue, task: &dyn AsyncRunnable) -> Task {
async fn insert_task(test: &mut PgAsyncQueue, task: &dyn AsyncRunnable) -> Task {
test.insert_task(task).await.unwrap()
}

View file

@ -1,28 +1,10 @@
use crate::errors::AsyncQueueError;
use crate::queue::AsyncQueueable;
use crate::errors::FangError;
use crate::queue::AsyncQueueable;
use crate::Scheduled;
use async_trait::async_trait;
use serde_json::Error as SerdeError;
const COMMON_TYPE: &str = "common";
pub const RETRIES_NUMBER: i32 = 20;
impl From<AsyncQueueError> for FangError {
fn from(error: AsyncQueueError) -> Self {
let message = format!("{error:?}");
FangError {
description: message,
}
}
}
impl From<SerdeError> for FangError {
fn from(error: SerdeError) -> Self {
Self::from(AsyncQueueError::SerdeError(error))
}
}
/// Implement this trait to run your custom tasks.
#[typetag::serde(tag = "type")]
#[async_trait]

View file

@ -1,12 +1,8 @@
use crate::fang_task_state::FangTaskState;
use crate::schema::fang_tasks;
use chrono::DateTime;
use chrono::Duration;
use chrono::Utc;
use cron::Schedule;
use diesel::prelude::*;
use sha2::{Digest, Sha256};
use thiserror::Error;
use typed_builder::TypedBuilder;
use uuid::Uuid;

View file

@ -1,9 +1,9 @@
use crate::errors::FangError;
use crate::fang_task_state::FangTaskState;
use crate::queue::AsyncQueueable;
use crate::runnable::AsyncRunnable;
use crate::task::Task;
use crate::task::DEFAULT_TASK_TYPE;
use crate::runnable::AsyncRunnable;
use crate::fang_task_state::FangTaskState;
use crate::errors::FangError;
use crate::Scheduled::*;
use crate::{RetentionMode, SleepParams};
use log::error;
@ -92,7 +92,11 @@ where
pub(crate) async fn run_tasks(&mut self) -> Result<(), FangError> {
loop {
//fetch task
match self.queue.fetch_and_touch_task(Some(self.task_type.clone())).await {
match self
.queue
.fetch_and_touch_task(Some(self.task_type.clone()))
.await
{
Ok(Some(task)) => {
let actual_task: Box<dyn AsyncRunnable> =
serde_json::from_value(task.metadata.clone()).unwrap();
@ -118,19 +122,55 @@ where
};
}
}
#[cfg(test)]
pub async fn run_tasks_until_none(&mut self) -> Result<(), FangError> {
loop {
match self
.queue
.fetch_and_touch_task(Some(self.task_type.clone()))
.await
{
Ok(Some(task)) => {
let actual_task: Box<dyn AsyncRunnable> =
serde_json::from_value(task.metadata.clone()).unwrap();
// check if task is scheduled or not
if let Some(CronPattern(_)) = actual_task.cron() {
// program task
self.queue.schedule_task(&*actual_task).await?;
}
self.sleep_params.maybe_reset_sleep_period();
// run scheduled task
self.run(task, actual_task).await?;
}
Ok(None) => {
return Ok(());
}
Err(error) => {
error!("Failed to fetch a task {:?}", error);
self.sleep().await;
}
};
}
}
}
#[cfg(test)]
mod async_worker_tests {
use super::*;
use crate::queue::AsyncQueueable;
use crate::worker::Task;
use crate::errors::FangError;
use crate::queue::AsyncQueueable;
use crate::queue::PgAsyncQueue;
use crate::worker::Task;
use crate::RetentionMode;
use crate::Scheduled;
use async_trait::async_trait;
use chrono::Duration;
use chrono::Utc;
use diesel_async::pooled_connection::{bb8::Pool, AsyncDieselConnectionManager};
use diesel_async::AsyncPgConnection;
use serde::{Deserialize, Serialize};
#[derive(Serialize, Deserialize)]
@ -232,222 +272,217 @@ mod async_worker_tests {
}
}
// #[tokio::test]
// async fn execute_and_finishes_task() {
// let pool = pool().await;
// let mut connection = pool.get().await.unwrap();
// let transaction = connection.transaction().await.unwrap();
//
// let mut test = AsyncQueueTest::builder().transaction(transaction).build();
// let actual_task = WorkerAsyncTask { number: 1 };
//
// let task = insert_task(&mut test, &actual_task).await;
// let id = task.id;
//
// let mut worker = AsyncWorkerTest::builder()
// .queue(&mut test as &mut dyn AsyncQueueable)
// .retention_mode(RetentionMode::KeepAll)
// .build();
//
// worker.run(task, Box::new(actual_task)).await.unwrap();
// let task_finished = test.find_task_by_id(id).await.unwrap();
// assert_eq!(id, task_finished.id);
// assert_eq!(FangTaskState::Finished, task_finished.state);
// test.transaction.rollback().await.unwrap();
// }
//
// #[tokio::test]
// async fn schedule_task_test() {
// let pool = pool().await;
// let mut connection = pool.get().await.unwrap();
// let transaction = connection.transaction().await.unwrap();
//
// let mut test = AsyncQueueTest::builder().transaction(transaction).build();
//
// let actual_task = WorkerAsyncTaskSchedule { number: 1 };
//
// let task = test.schedule_task(&actual_task).await.unwrap();
//
// let id = task.id;
//
// let mut worker = AsyncWorkerTest::builder()
// .queue(&mut test as &mut dyn AsyncQueueable)
// .retention_mode(RetentionMode::KeepAll)
// .build();
//
// worker.run_tasks_until_none().await.unwrap();
//
// let task = worker.queue.find_task_by_id(id).await.unwrap();
//
// assert_eq!(id, task.id);
// assert_eq!(FangTaskState::New, task.state);
//
// tokio::time::sleep(core::time::Duration::from_secs(3)).await;
//
// worker.run_tasks_until_none().await.unwrap();
//
// let task = test.find_task_by_id(id).await.unwrap();
// assert_eq!(id, task.id);
// assert_eq!(FangTaskState::Finished, task.state);
// }
//
// #[tokio::test]
// async fn retries_task_test() {
// let pool = pool().await;
// let mut connection = pool.get().await.unwrap();
// let transaction = connection.transaction().await.unwrap();
//
// let mut test = AsyncQueueTest::builder().transaction(transaction).build();
//
// let actual_task = AsyncRetryTask {};
//
// let task = test.insert_task(&actual_task).await.unwrap();
//
// let id = task.id;
//
// let mut worker = AsyncWorkerTest::builder()
// .queue(&mut test as &mut dyn AsyncQueueable)
// .retention_mode(RetentionMode::KeepAll)
// .build();
//
// worker.run_tasks_until_none().await.unwrap();
//
// let task = worker.queue.find_task_by_id(id).await.unwrap();
//
// assert_eq!(id, task.id);
// assert_eq!(FangTaskState::Retried, task.state);
// assert_eq!(1, task.retries);
//
// tokio::time::sleep(core::time::Duration::from_secs(5)).await;
// worker.run_tasks_until_none().await.unwrap();
//
// let task = worker.queue.find_task_by_id(id).await.unwrap();
//
// assert_eq!(id, task.id);
// assert_eq!(FangTaskState::Retried, task.state);
// assert_eq!(2, task.retries);
//
// tokio::time::sleep(core::time::Duration::from_secs(10)).await;
// worker.run_tasks_until_none().await.unwrap();
//
// let task = test.find_task_by_id(id).await.unwrap();
// assert_eq!(id, task.id);
// assert_eq!(FangTaskState::Failed, task.state);
// assert_eq!("Failed".to_string(), task.error_message.unwrap());
// }
//
// #[tokio::test]
// async fn saves_error_for_failed_task() {
// let pool = pool().await;
// let mut connection = pool.get().await.unwrap();
// let transaction = connection.transaction().await.unwrap();
//
// let mut test = AsyncQueueTest::builder().transaction(transaction).build();
// let failed_task = AsyncFailedTask { number: 1 };
//
// let task = insert_task(&mut test, &failed_task).await;
// let id = task.id;
//
// let mut worker = AsyncWorkerTest::builder()
// .queue(&mut test as &mut dyn AsyncQueueable)
// .retention_mode(RetentionMode::KeepAll)
// .build();
//
// worker.run(task, Box::new(failed_task)).await.unwrap();
// let task_finished = test.find_task_by_id(id).await.unwrap();
//
// assert_eq!(id, task_finished.id);
// assert_eq!(FangTaskState::Failed, task_finished.state);
// assert_eq!(
// "number 1 is wrong :(".to_string(),
// task_finished.error_message.unwrap()
// );
// test.transaction.rollback().await.unwrap();
// }
//
// #[tokio::test]
// async fn executes_task_only_of_specific_type() {
// let pool = pool().await;
// let mut connection = pool.get().await.unwrap();
// let transaction = connection.transaction().await.unwrap();
//
// let mut test = AsyncQueueTest::builder().transaction(transaction).build();
//
// let task1 = insert_task(&mut test, &AsyncTaskType1 {}).await;
// let task12 = insert_task(&mut test, &AsyncTaskType1 {}).await;
// let task2 = insert_task(&mut test, &AsyncTaskType2 {}).await;
//
// let id1 = task1.id;
// let id12 = task12.id;
// let id2 = task2.id;
//
// let mut worker = AsyncWorkerTest::builder()
// .queue(&mut test as &mut dyn AsyncQueueable)
// .task_type("type1".to_string())
// .retention_mode(RetentionMode::KeepAll)
// .build();
//
// worker.run_tasks_until_none().await.unwrap();
// let task1 = test.find_task_by_id(id1).await.unwrap();
// let task12 = test.find_task_by_id(id12).await.unwrap();
// let task2 = test.find_task_by_id(id2).await.unwrap();
//
// assert_eq!(id1, task1.id);
// assert_eq!(id12, task12.id);
// assert_eq!(id2, task2.id);
// assert_eq!(FangTaskState::Finished, task1.state);
// assert_eq!(FangTaskState::Finished, task12.state);
// assert_eq!(FangTaskState::New, task2.state);
// test.transaction.rollback().await.unwrap();
// }
//
// #[tokio::test]
// async fn remove_when_finished() {
// let pool = pool().await;
// let mut connection = pool.get().await.unwrap();
// let transaction = connection.transaction().await.unwrap();
//
// let mut test = AsyncQueueTest::builder().transaction(transaction).build();
//
// let task1 = insert_task(&mut test, &AsyncTaskType1 {}).await;
// let task12 = insert_task(&mut test, &AsyncTaskType1 {}).await;
// let task2 = insert_task(&mut test, &AsyncTaskType2 {}).await;
//
// let _id1 = task1.id;
// let _id12 = task12.id;
// let id2 = task2.id;
//
// let mut worker = AsyncWorkerTest::builder()
// .queue(&mut test as &mut dyn AsyncQueueable)
// .task_type("type1".to_string())
// .build();
//
// worker.run_tasks_until_none().await.unwrap();
// let task = test
// .fetch_and_touch_task(Some("type1".to_string()))
// .await
// .unwrap();
// assert_eq!(None, task);
//
// let task2 = test
// .fetch_and_touch_task(Some("type2".to_string()))
// .await
// .unwrap()
// .unwrap();
// assert_eq!(id2, task2.id);
//
// test.transaction.rollback().await.unwrap();
// }
// async fn insert_task(test: &mut AsyncQueueTest<'_>, task: &dyn AsyncRunnable) -> Task {
// test.insert_task(task).await.unwrap()
// }
// async fn pool() -> Pool<PostgresConnectionManager<NoTls>> {
// let pg_mgr = PostgresConnectionManager::new_from_stringlike(
// "postgres://postgres:postgres@localhost/fang",
// NoTls,
// )
// .unwrap();
//
// Pool::builder().build(pg_mgr).await.unwrap()
// }
#[tokio::test]
async fn execute_and_finishes_task() {
let pool = pool().await;
let mut test = PgAsyncQueue::builder().pool(pool).build();
let actual_task = WorkerAsyncTask { number: 1 };
let task = insert_task(&mut test, &actual_task).await;
let id = task.id;
let mut worker = AsyncWorker::<PgAsyncQueue>::builder()
.queue(test.clone())
.retention_mode(RetentionMode::KeepAll)
.build();
worker.run(task, Box::new(actual_task)).await.unwrap();
let task_finished = test.find_task_by_id(id).await.unwrap();
assert_eq!(id, task_finished.id);
assert_eq!(FangTaskState::Finished, task_finished.state);
test.remove_all_tasks().await.unwrap();
}
#[tokio::test]
async fn schedule_task_test() {
let pool = pool().await;
let mut test = PgAsyncQueue::builder().pool(pool).build();
let actual_task = WorkerAsyncTaskSchedule { number: 1 };
let task = test.schedule_task(&actual_task).await.unwrap();
let id = task.id;
let mut worker = AsyncWorker::<PgAsyncQueue>::builder()
.queue(test.clone())
.retention_mode(RetentionMode::KeepAll)
.build();
worker.run_tasks_until_none().await.unwrap();
let task = worker.queue.find_task_by_id(id).await.unwrap();
assert_eq!(id, task.id);
assert_eq!(FangTaskState::New, task.state);
tokio::time::sleep(core::time::Duration::from_secs(3)).await;
worker.run_tasks_until_none().await.unwrap();
let task = test.find_task_by_id(id).await.unwrap();
assert_eq!(id, task.id);
assert_eq!(FangTaskState::Finished, task.state);
test.remove_all_tasks().await.unwrap();
}
#[tokio::test]
async fn retries_task_test() {
let pool = pool().await;
let mut test = PgAsyncQueue::builder().pool(pool).build();
let actual_task = AsyncRetryTask {};
let task = test.insert_task(&actual_task).await.unwrap();
let id = task.id;
let mut worker = AsyncWorker::<PgAsyncQueue>::builder()
.queue(test.clone())
.retention_mode(RetentionMode::KeepAll)
.build();
worker.run_tasks_until_none().await.unwrap();
let task = worker.queue.find_task_by_id(id).await.unwrap();
assert_eq!(id, task.id);
assert_eq!(FangTaskState::Retried, task.state);
assert_eq!(1, task.retries);
tokio::time::sleep(core::time::Duration::from_secs(5)).await;
worker.run_tasks_until_none().await.unwrap();
let task = worker.queue.find_task_by_id(id).await.unwrap();
assert_eq!(id, task.id);
assert_eq!(FangTaskState::Retried, task.state);
assert_eq!(2, task.retries);
tokio::time::sleep(core::time::Duration::from_secs(10)).await;
worker.run_tasks_until_none().await.unwrap();
let task = test.find_task_by_id(id).await.unwrap();
assert_eq!(id, task.id);
assert_eq!(FangTaskState::Failed, task.state);
assert_eq!("Failed".to_string(), task.error_message.unwrap());
test.remove_all_tasks().await.unwrap();
}
#[tokio::test]
async fn saves_error_for_failed_task() {
let pool = pool().await;
let mut test = PgAsyncQueue::builder().pool(pool).build();
let failed_task = AsyncFailedTask { number: 1 };
let task = insert_task(&mut test, &failed_task).await;
let id = task.id;
let mut worker = AsyncWorker::<PgAsyncQueue>::builder()
.queue(test.clone())
.retention_mode(RetentionMode::KeepAll)
.build();
worker.run(task, Box::new(failed_task)).await.unwrap();
let task_finished = test.find_task_by_id(id).await.unwrap();
assert_eq!(id, task_finished.id);
assert_eq!(FangTaskState::Failed, task_finished.state);
assert_eq!(
"number 1 is wrong :(".to_string(),
task_finished.error_message.unwrap()
);
test.remove_all_tasks().await.unwrap();
}
#[tokio::test]
async fn executes_task_only_of_specific_type() {
let pool = pool().await;
let mut test = PgAsyncQueue::builder().pool(pool).build();
let task1 = insert_task(&mut test, &AsyncTaskType1 {}).await;
let task12 = insert_task(&mut test, &AsyncTaskType1 {}).await;
let task2 = insert_task(&mut test, &AsyncTaskType2 {}).await;
let id1 = task1.id;
let id12 = task12.id;
let id2 = task2.id;
let mut worker = AsyncWorker::<PgAsyncQueue>::builder()
.queue(test.clone())
.task_type("type1".to_string())
.retention_mode(RetentionMode::KeepAll)
.build();
worker.run_tasks_until_none().await.unwrap();
let task1 = test.find_task_by_id(id1).await.unwrap();
let task12 = test.find_task_by_id(id12).await.unwrap();
let task2 = test.find_task_by_id(id2).await.unwrap();
assert_eq!(id1, task1.id);
assert_eq!(id12, task12.id);
assert_eq!(id2, task2.id);
assert_eq!(FangTaskState::Finished, task1.state);
assert_eq!(FangTaskState::Finished, task12.state);
assert_eq!(FangTaskState::New, task2.state);
test.remove_all_tasks().await.unwrap();
}
#[tokio::test]
async fn remove_when_finished() {
let pool = pool().await;
let mut test = PgAsyncQueue::builder().pool(pool).build();
let task1 = insert_task(&mut test, &AsyncTaskType1 {}).await;
let task12 = insert_task(&mut test, &AsyncTaskType1 {}).await;
let task2 = insert_task(&mut test, &AsyncTaskType2 {}).await;
let _id1 = task1.id;
let _id12 = task12.id;
let id2 = task2.id;
let mut worker = AsyncWorker::<PgAsyncQueue>::builder()
.queue(test.clone())
.task_type("type1".to_string())
.build();
worker.run_tasks_until_none().await.unwrap();
let task = test
.fetch_and_touch_task(Some("type1".to_string()))
.await
.unwrap();
assert_eq!(None, task);
let task2 = test
.fetch_and_touch_task(Some("type2".to_string()))
.await
.unwrap()
.unwrap();
assert_eq!(id2, task2.id);
test.remove_all_tasks().await.unwrap();
}
async fn insert_task(test: &mut PgAsyncQueue, task: &dyn AsyncRunnable) -> Task {
test.insert_task(task).await.unwrap()
}
async fn pool() -> Pool<AsyncPgConnection> {
let manager = AsyncDieselConnectionManager::<AsyncPgConnection>::new(
"postgres://postgres:password@localhost/fang",
);
Pool::builder()
.max_size(1)
.min_idle(Some(1))
.build(manager)
.await
.unwrap()
}
}

View file

@ -1,11 +1,9 @@
use crate::queue::AsyncQueueable;
use crate::task::DEFAULT_TASK_TYPE;
use crate::worker::AsyncWorker;
use crate::errors::FangError;
use crate::{RetentionMode, SleepParams};
use async_recursion::async_recursion;
use log::error;
use tokio::task::JoinHandle;
use typed_builder::TypedBuilder;
#[derive(TypedBuilder, Clone)]
@ -46,13 +44,19 @@ where
#[async_recursion]
async fn supervise_task(pool: AsyncWorkerPool<AQueue>, restarts: u64, worker_number: u32) {
let restarts = restarts + 1;
let join_handle = Self::spawn_worker(
pool.queue.clone(),
pool.sleep_params.clone(),
pool.retention_mode.clone(),
pool.task_type.clone(),
)
.await;
let inner_pool = pool.clone();
let join_handle = tokio::spawn(async move {
let mut worker: AsyncWorker<AQueue> = AsyncWorker::builder()
.queue(inner_pool.queue.clone())
.sleep_params(inner_pool.sleep_params.clone())
.retention_mode(inner_pool.retention_mode.clone())
.task_type(inner_pool.task_type.clone())
.build();
worker.run_tasks().await
});
if (join_handle.await).is_err() {
error!(
@ -62,30 +66,4 @@ where
Self::supervise_task(pool, restarts, worker_number).await;
}
}
async fn spawn_worker(
queue: AQueue,
sleep_params: SleepParams,
retention_mode: RetentionMode,
task_type: String,
) -> JoinHandle<Result<(), FangError>> {
tokio::spawn(async move {
Self::run_worker(queue, sleep_params, retention_mode, task_type).await
})
}
async fn run_worker(
queue: AQueue,
sleep_params: SleepParams,
retention_mode: RetentionMode,
task_type: String,
) -> Result<(), FangError> {
let mut worker: AsyncWorker<AQueue> = AsyncWorker::builder()
.queue(queue)
.sleep_params(sleep_params)
.retention_mode(retention_mode)
.task_type(task_type)
.build();
worker.run_tasks().await
}
}