Async worker test (#46)

* preparing code to tests

* changing queue tests and starting worker tests with new struct AsyncQueueTest

* fix clippy and delete Debug trait

* fix warnings

* Get task by id in worker tests and AsyncQueueTest #[cfg(test)]

* Defaults and AsyncWorker Builder

* Testing :D

* Execute task specific type

* deleting Options and comment

* insert task is back git !!

* Test remove tasks and changing insert task auxiliar func
This commit is contained in:
Pmarquez 2022-07-27 17:05:05 +00:00 committed by GitHub
parent 8d0a23e2f9
commit 2d724c3776
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 435 additions and 116 deletions

View file

@ -1,9 +1,11 @@
use crate::asynk::async_runnable::Error as FangError;
use async_trait::async_trait; use async_trait::async_trait;
use bb8_postgres::bb8::Pool; use bb8_postgres::bb8::Pool;
use bb8_postgres::bb8::RunError; use bb8_postgres::bb8::RunError;
use bb8_postgres::tokio_postgres::row::Row; use bb8_postgres::tokio_postgres::row::Row;
use bb8_postgres::tokio_postgres::tls::MakeTlsConnect; #[cfg(test)]
use bb8_postgres::tokio_postgres::tls::TlsConnect; use bb8_postgres::tokio_postgres::tls::NoTls;
use bb8_postgres::tokio_postgres::tls::{MakeTlsConnect, TlsConnect};
use bb8_postgres::tokio_postgres::Socket; use bb8_postgres::tokio_postgres::Socket;
use bb8_postgres::tokio_postgres::Transaction; use bb8_postgres::tokio_postgres::Transaction;
use bb8_postgres::PostgresConnectionManager; use bb8_postgres::PostgresConnectionManager;
@ -21,8 +23,10 @@ const REMOVE_ALL_TASK_QUERY: &str = include_str!("queries/remove_all_tasks.sql")
const REMOVE_TASK_QUERY: &str = include_str!("queries/remove_task.sql"); const REMOVE_TASK_QUERY: &str = include_str!("queries/remove_task.sql");
const REMOVE_TASKS_TYPE_QUERY: &str = include_str!("queries/remove_tasks_type.sql"); const REMOVE_TASKS_TYPE_QUERY: &str = include_str!("queries/remove_tasks_type.sql");
const FETCH_TASK_TYPE_QUERY: &str = include_str!("queries/fetch_task_type.sql"); const FETCH_TASK_TYPE_QUERY: &str = include_str!("queries/fetch_task_type.sql");
#[cfg(test)]
const GET_TASK_BY_ID_QUERY: &str = include_str!("queries/get_task_by_id.sql");
const DEFAULT_TASK_TYPE: &str = "common"; pub const DEFAULT_TASK_TYPE: &str = "common";
#[derive(Debug, Eq, PartialEq, Clone, ToSql, FromSql)] #[derive(Debug, Eq, PartialEq, Clone, ToSql, FromSql)]
#[postgres(name = "fang_task_state")] #[postgres(name = "fang_task_state")]
@ -100,7 +104,14 @@ pub enum AsyncQueueError {
#[error("returned invalid result (expected {expected:?}, found {found:?})")] #[error("returned invalid result (expected {expected:?}, found {found:?})")]
ResultError { expected: u64, found: u64 }, ResultError { expected: u64, found: u64 },
} }
impl From<AsyncQueueError> for FangError {
fn from(error: AsyncQueueError) -> Self {
let message = format!("{:?}", error);
FangError {
description: message,
}
}
}
#[async_trait] #[async_trait]
pub trait AsyncQueueable { pub trait AsyncQueueable {
async fn fetch_and_touch_task( async fn fetch_and_touch_task(
@ -141,6 +152,98 @@ where
pool: Pool<PostgresConnectionManager<Tls>>, pool: Pool<PostgresConnectionManager<Tls>>,
} }
#[cfg(test)]
pub struct AsyncQueueTest<'a> {
pub transaction: Transaction<'a>,
}
#[cfg(test)]
impl<'a> AsyncQueueTest<'a> {
pub async fn get_task_by_id(&mut self, id: Uuid) -> Result<Task, AsyncQueueError> {
let row: Row = self
.transaction
.query_one(GET_TASK_BY_ID_QUERY, &[&id])
.await?;
let task = AsyncQueue::<NoTls>::row_to_task(row);
Ok(task)
}
}
#[cfg(test)]
#[async_trait]
impl AsyncQueueable for AsyncQueueTest<'_> {
async fn fetch_and_touch_task(
&mut self,
task_type: &Option<String>,
) -> Result<Option<Task>, AsyncQueueError> {
let transaction = &mut self.transaction;
let task = AsyncQueue::<NoTls>::fetch_and_touch_task_query(transaction, task_type).await?;
Ok(task)
}
async fn insert_task(
&mut self,
metadata: serde_json::Value,
task_type: &str,
) -> Result<Task, AsyncQueueError> {
let transaction = &mut self.transaction;
let task = AsyncQueue::<NoTls>::insert_task_query(transaction, metadata, task_type).await?;
Ok(task)
}
async fn remove_all_tasks(&mut self) -> Result<u64, AsyncQueueError> {
let transaction = &mut self.transaction;
let result = AsyncQueue::<NoTls>::remove_all_tasks_query(transaction).await?;
Ok(result)
}
async fn remove_task(&mut self, task: Task) -> Result<u64, AsyncQueueError> {
let transaction = &mut self.transaction;
let result = AsyncQueue::<NoTls>::remove_task_query(transaction, task).await?;
Ok(result)
}
async fn remove_tasks_type(&mut self, task_type: &str) -> Result<u64, AsyncQueueError> {
let transaction = &mut self.transaction;
let result = AsyncQueue::<NoTls>::remove_tasks_type_query(transaction, task_type).await?;
Ok(result)
}
async fn update_task_state(
&mut self,
task: Task,
state: FangTaskState,
) -> Result<Task, AsyncQueueError> {
let transaction = &mut self.transaction;
let task = AsyncQueue::<NoTls>::update_task_state_query(transaction, task, state).await?;
Ok(task)
}
async fn fail_task(
&mut self,
task: Task,
error_message: &str,
) -> Result<Task, AsyncQueueError> {
let transaction = &mut self.transaction;
let task = AsyncQueue::<NoTls>::fail_task_query(transaction, task, error_message).await?;
Ok(task)
}
}
impl<Tls> AsyncQueue<Tls> impl<Tls> AsyncQueue<Tls>
where where
Tls: MakeTlsConnect<Socket> + Clone + Send + Sync + 'static, Tls: MakeTlsConnect<Socket> + Clone + Send + Sync + 'static,
@ -404,7 +507,7 @@ where
#[cfg(test)] #[cfg(test)]
mod async_queue_tests { mod async_queue_tests {
use super::AsyncQueue; use super::AsyncQueueTest;
use super::AsyncQueueable; use super::AsyncQueueable;
use super::FangTaskState; use super::FangTaskState;
use super::Task; use super::Task;
@ -413,7 +516,6 @@ mod async_queue_tests {
use async_trait::async_trait; use async_trait::async_trait;
use bb8_postgres::bb8::Pool; use bb8_postgres::bb8::Pool;
use bb8_postgres::tokio_postgres::NoTls; use bb8_postgres::tokio_postgres::NoTls;
use bb8_postgres::tokio_postgres::Transaction;
use bb8_postgres::PostgresConnectionManager; use bb8_postgres::PostgresConnectionManager;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
@ -434,15 +536,11 @@ mod async_queue_tests {
async fn insert_task_creates_new_task() { async fn insert_task_creates_new_task() {
let pool = pool().await; let pool = pool().await;
let mut connection = pool.get().await.unwrap(); let mut connection = pool.get().await.unwrap();
let mut transaction = connection.transaction().await.unwrap(); let transaction = connection.transaction().await.unwrap();
let task = AsyncTask { number: 1 }; let mut test = AsyncQueueTest { transaction };
let metadata = serde_json::to_value(&task as &dyn AsyncRunnable).unwrap();
let task = let task = insert_task(&mut test, &AsyncTask { number: 1 }).await;
AsyncQueue::<NoTls>::insert_task_query(&mut transaction, metadata, &task.task_type())
.await
.unwrap();
let metadata = task.metadata.as_object().unwrap(); let metadata = task.metadata.as_object().unwrap();
let number = metadata["number"].as_u64(); let number = metadata["number"].as_u64();
@ -450,16 +548,18 @@ mod async_queue_tests {
assert_eq!(Some(1), number); assert_eq!(Some(1), number);
assert_eq!(Some("AsyncTask"), type_task); assert_eq!(Some("AsyncTask"), type_task);
transaction.rollback().await.unwrap(); test.transaction.rollback().await.unwrap();
} }
#[tokio::test] #[tokio::test]
async fn update_task_state_test() { async fn update_task_state_test() {
let pool = pool().await; let pool = pool().await;
let mut connection = pool.get().await.unwrap(); let mut connection = pool.get().await.unwrap();
let mut transaction = connection.transaction().await.unwrap(); let transaction = connection.transaction().await.unwrap();
let task = insert_task(&mut transaction, &AsyncTask { number: 1 }).await; let mut test = AsyncQueueTest { transaction };
let task = insert_task(&mut test, &AsyncTask { number: 1 }).await;
let metadata = task.metadata.as_object().unwrap(); let metadata = task.metadata.as_object().unwrap();
let number = metadata["number"].as_u64(); let number = metadata["number"].as_u64();
@ -469,27 +569,26 @@ mod async_queue_tests {
assert_eq!(Some(1), number); assert_eq!(Some(1), number);
assert_eq!(Some("AsyncTask"), type_task); assert_eq!(Some("AsyncTask"), type_task);
let finished_task = AsyncQueue::<NoTls>::update_task_state_query( let finished_task = test
&mut transaction, .update_task_state(task, FangTaskState::Finished)
task, .await
FangTaskState::Finished, .unwrap();
)
.await
.unwrap();
assert_eq!(id, finished_task.id); assert_eq!(id, finished_task.id);
assert_eq!(FangTaskState::Finished, finished_task.state); assert_eq!(FangTaskState::Finished, finished_task.state);
transaction.rollback().await.unwrap(); test.transaction.rollback().await.unwrap();
} }
#[tokio::test] #[tokio::test]
async fn failed_task_query_test() { async fn failed_task_query_test() {
let pool = pool().await; let pool = pool().await;
let mut connection = pool.get().await.unwrap(); let mut connection = pool.get().await.unwrap();
let mut transaction = connection.transaction().await.unwrap(); let transaction = connection.transaction().await.unwrap();
let task = insert_task(&mut transaction, &AsyncTask { number: 1 }).await; let mut test = AsyncQueueTest { transaction };
let task = insert_task(&mut test, &AsyncTask { number: 1 }).await;
let metadata = task.metadata.as_object().unwrap(); let metadata = task.metadata.as_object().unwrap();
let number = metadata["number"].as_u64(); let number = metadata["number"].as_u64();
@ -499,25 +598,24 @@ mod async_queue_tests {
assert_eq!(Some(1), number); assert_eq!(Some(1), number);
assert_eq!(Some("AsyncTask"), type_task); assert_eq!(Some("AsyncTask"), type_task);
let failed_task = let failed_task = test.fail_task(task, "Some error").await.unwrap();
AsyncQueue::<NoTls>::fail_task_query(&mut transaction, task, "Some error")
.await
.unwrap();
assert_eq!(id, failed_task.id); assert_eq!(id, failed_task.id);
assert_eq!(Some("Some error"), failed_task.error_message.as_deref()); assert_eq!(Some("Some error"), failed_task.error_message.as_deref());
assert_eq!(FangTaskState::Failed, failed_task.state); assert_eq!(FangTaskState::Failed, failed_task.state);
transaction.rollback().await.unwrap(); test.transaction.rollback().await.unwrap();
} }
#[tokio::test] #[tokio::test]
async fn remove_all_tasks_test() { async fn remove_all_tasks_test() {
let pool = pool().await; let pool = pool().await;
let mut connection = pool.get().await.unwrap(); let mut connection = pool.get().await.unwrap();
let mut transaction = connection.transaction().await.unwrap(); let transaction = connection.transaction().await.unwrap();
let task = insert_task(&mut transaction, &AsyncTask { number: 1 }).await; let mut test = AsyncQueueTest { transaction };
let task = insert_task(&mut test, &AsyncTask { number: 1 }).await;
let metadata = task.metadata.as_object().unwrap(); let metadata = task.metadata.as_object().unwrap();
let number = metadata["number"].as_u64(); let number = metadata["number"].as_u64();
@ -526,7 +624,7 @@ mod async_queue_tests {
assert_eq!(Some(1), number); assert_eq!(Some(1), number);
assert_eq!(Some("AsyncTask"), type_task); assert_eq!(Some("AsyncTask"), type_task);
let task = insert_task(&mut transaction, &AsyncTask { number: 2 }).await; let task = insert_task(&mut test, &AsyncTask { number: 2 }).await;
let metadata = task.metadata.as_object().unwrap(); let metadata = task.metadata.as_object().unwrap();
let number = metadata["number"].as_u64(); let number = metadata["number"].as_u64();
@ -535,21 +633,21 @@ mod async_queue_tests {
assert_eq!(Some(2), number); assert_eq!(Some(2), number);
assert_eq!(Some("AsyncTask"), type_task); assert_eq!(Some("AsyncTask"), type_task);
let result = AsyncQueue::<NoTls>::remove_all_tasks_query(&mut transaction) let result = test.remove_all_tasks().await.unwrap();
.await
.unwrap();
assert_eq!(2, result); assert_eq!(2, result);
transaction.rollback().await.unwrap(); test.transaction.rollback().await.unwrap();
} }
#[tokio::test] #[tokio::test]
async fn fetch_and_touch_test() { async fn fetch_and_touch_test() {
let pool = pool().await; let pool = pool().await;
let mut connection = pool.get().await.unwrap(); let mut connection = pool.get().await.unwrap();
let mut transaction = connection.transaction().await.unwrap(); let transaction = connection.transaction().await.unwrap();
let task = insert_task(&mut transaction, &AsyncTask { number: 1 }).await; let mut test = AsyncQueueTest { transaction };
let task = insert_task(&mut test, &AsyncTask { number: 1 }).await;
let metadata = task.metadata.as_object().unwrap(); let metadata = task.metadata.as_object().unwrap();
let number = metadata["number"].as_u64(); let number = metadata["number"].as_u64();
@ -558,7 +656,7 @@ mod async_queue_tests {
assert_eq!(Some(1), number); assert_eq!(Some(1), number);
assert_eq!(Some("AsyncTask"), type_task); assert_eq!(Some("AsyncTask"), type_task);
let task = insert_task(&mut transaction, &AsyncTask { number: 2 }).await; let task = insert_task(&mut test, &AsyncTask { number: 2 }).await;
let metadata = task.metadata.as_object().unwrap(); let metadata = task.metadata.as_object().unwrap();
let number = metadata["number"].as_u64(); let number = metadata["number"].as_u64();
@ -567,10 +665,7 @@ mod async_queue_tests {
assert_eq!(Some(2), number); assert_eq!(Some(2), number);
assert_eq!(Some("AsyncTask"), type_task); assert_eq!(Some("AsyncTask"), type_task);
let task = AsyncQueue::<NoTls>::fetch_and_touch_task_query(&mut transaction, &None) let task = test.fetch_and_touch_task(&None).await.unwrap().unwrap();
.await
.unwrap()
.unwrap();
let metadata = task.metadata.as_object().unwrap(); let metadata = task.metadata.as_object().unwrap();
let number = metadata["number"].as_u64(); let number = metadata["number"].as_u64();
@ -579,10 +674,7 @@ mod async_queue_tests {
assert_eq!(Some(1), number); assert_eq!(Some(1), number);
assert_eq!(Some("AsyncTask"), type_task); assert_eq!(Some("AsyncTask"), type_task);
let task = AsyncQueue::<NoTls>::fetch_and_touch_task_query(&mut transaction, &None) let task = test.fetch_and_touch_task(&None).await.unwrap().unwrap();
.await
.unwrap()
.unwrap();
let metadata = task.metadata.as_object().unwrap(); let metadata = task.metadata.as_object().unwrap();
let number = metadata["number"].as_u64(); let number = metadata["number"].as_u64();
let type_task = metadata["type"].as_str(); let type_task = metadata["type"].as_str();
@ -590,16 +682,18 @@ mod async_queue_tests {
assert_eq!(Some(2), number); assert_eq!(Some(2), number);
assert_eq!(Some("AsyncTask"), type_task); assert_eq!(Some("AsyncTask"), type_task);
transaction.rollback().await.unwrap(); test.transaction.rollback().await.unwrap();
} }
#[tokio::test] #[tokio::test]
async fn remove_tasks_type_test() { async fn remove_tasks_type_test() {
let pool = pool().await; let pool = pool().await;
let mut connection = pool.get().await.unwrap(); let mut connection = pool.get().await.unwrap();
let mut transaction = connection.transaction().await.unwrap(); let transaction = connection.transaction().await.unwrap();
let task = insert_task(&mut transaction, &AsyncTask { number: 1 }).await; let mut test = AsyncQueueTest { transaction };
let task = insert_task(&mut test, &AsyncTask { number: 1 }).await;
let metadata = task.metadata.as_object().unwrap(); let metadata = task.metadata.as_object().unwrap();
let number = metadata["number"].as_u64(); let number = metadata["number"].as_u64();
@ -608,7 +702,7 @@ mod async_queue_tests {
assert_eq!(Some(1), number); assert_eq!(Some(1), number);
assert_eq!(Some("AsyncTask"), type_task); assert_eq!(Some("AsyncTask"), type_task);
let task = insert_task(&mut transaction, &AsyncTask { number: 2 }).await; let task = insert_task(&mut test, &AsyncTask { number: 2 }).await;
let metadata = task.metadata.as_object().unwrap(); let metadata = task.metadata.as_object().unwrap();
let number = metadata["number"].as_u64(); let number = metadata["number"].as_u64();
@ -617,17 +711,18 @@ mod async_queue_tests {
assert_eq!(Some(2), number); assert_eq!(Some(2), number);
assert_eq!(Some("AsyncTask"), type_task); assert_eq!(Some("AsyncTask"), type_task);
let result = AsyncQueue::<NoTls>::remove_tasks_type_query(&mut transaction, "mytype") let result = test.remove_tasks_type("mytype").await.unwrap();
.await
.unwrap();
assert_eq!(0, result); assert_eq!(0, result);
let result = AsyncQueue::<NoTls>::remove_tasks_type_query(&mut transaction, "common") let result = test.remove_tasks_type("common").await.unwrap();
.await
.unwrap();
assert_eq!(2, result); assert_eq!(2, result);
transaction.rollback().await.unwrap(); test.transaction.rollback().await.unwrap();
}
async fn insert_task(test: &mut AsyncQueueTest<'_>, task: &dyn AsyncRunnable) -> Task {
let metadata = serde_json::to_value(task).unwrap();
test.insert_task(metadata, &task.task_type()).await.unwrap()
} }
async fn pool() -> Pool<PostgresConnectionManager<NoTls>> { async fn pool() -> Pool<PostgresConnectionManager<NoTls>> {
@ -639,12 +734,4 @@ mod async_queue_tests {
Pool::builder().build(pg_mgr).await.unwrap() Pool::builder().build(pg_mgr).await.unwrap()
} }
async fn insert_task(transaction: &mut Transaction<'_>, task: &dyn AsyncRunnable) -> Task {
let metadata = serde_json::to_value(task).unwrap();
AsyncQueue::<NoTls>::insert_task_query(transaction, metadata, &task.task_type())
.await
.unwrap()
}
} }

View file

@ -1,42 +1,27 @@
use crate::asynk::async_queue::AsyncQueue;
use crate::asynk::async_queue::AsyncQueueable; use crate::asynk::async_queue::AsyncQueueable;
use crate::asynk::async_queue::FangTaskState; use crate::asynk::async_queue::FangTaskState;
use crate::asynk::async_queue::Task; use crate::asynk::async_queue::Task;
use crate::asynk::async_queue::DEFAULT_TASK_TYPE;
use crate::asynk::async_runnable::AsyncRunnable; use crate::asynk::async_runnable::AsyncRunnable;
use crate::asynk::Error; use crate::asynk::Error;
use crate::{RetentionMode, SleepParams}; use crate::{RetentionMode, SleepParams};
use bb8_postgres::tokio_postgres::tls::MakeTlsConnect;
use bb8_postgres::tokio_postgres::tls::TlsConnect;
use bb8_postgres::tokio_postgres::Socket;
use log::error; use log::error;
use std::time::Duration; use std::time::Duration;
use typed_builder::TypedBuilder; use typed_builder::TypedBuilder;
#[derive(TypedBuilder, Debug)] #[derive(TypedBuilder)]
pub struct AsyncWorker<Tls> pub struct AsyncWorker<'a> {
where
Tls: MakeTlsConnect<Socket> + Clone + Send + Sync + 'static,
<Tls as MakeTlsConnect<Socket>>::Stream: Send + Sync,
<Tls as MakeTlsConnect<Socket>>::TlsConnect: Send,
<<Tls as MakeTlsConnect<Socket>>::TlsConnect as TlsConnect<Socket>>::Future: Send,
{
#[builder(setter(into))]
pub queue: AsyncQueue<Tls>,
#[builder(setter(into))]
pub task_type: Option<String>,
#[builder(setter(into))] #[builder(setter(into))]
pub queue: &'a mut dyn AsyncQueueable,
#[builder(default=DEFAULT_TASK_TYPE.to_string() , setter(into))]
pub task_type: String,
#[builder(default, setter(into))]
pub sleep_params: SleepParams, pub sleep_params: SleepParams,
#[builder(setter(into))] #[builder(default, setter(into))]
pub retention_mode: RetentionMode, pub retention_mode: RetentionMode,
} }
impl<Tls> AsyncWorker<Tls> impl<'a> AsyncWorker<'a> {
where pub async fn run(&mut self, task: Task) -> Result<(), Error> {
Tls: MakeTlsConnect<Socket> + Clone + Send + Sync + 'static,
<Tls as MakeTlsConnect<Socket>>::Stream: Send + Sync,
<Tls as MakeTlsConnect<Socket>>::TlsConnect: Send,
<<Tls as MakeTlsConnect<Socket>>::TlsConnect as TlsConnect<Socket>>::Future: Send,
{
pub async fn run(&mut self, task: Task) {
let result = self.execute_task(task).await; let result = self.execute_task(task).await;
self.finalize_task(result).await self.finalize_task(result).await
} }
@ -44,36 +29,44 @@ where
let actual_task: Box<dyn AsyncRunnable> = let actual_task: Box<dyn AsyncRunnable> =
serde_json::from_value(task.metadata.clone()).unwrap(); serde_json::from_value(task.metadata.clone()).unwrap();
let task_result = actual_task.run(&mut self.queue).await; let task_result = actual_task.run(self.queue).await;
match task_result { match task_result {
Ok(()) => Ok(task), Ok(()) => Ok(task),
Err(error) => Err((task, error.description)), Err(error) => Err((task, error.description)),
} }
} }
async fn finalize_task(&mut self, result: Result<Task, (Task, String)>) { async fn finalize_task(&mut self, result: Result<Task, (Task, String)>) -> Result<(), Error> {
match self.retention_mode { match self.retention_mode {
RetentionMode::KeepAll => { RetentionMode::KeepAll => match result {
match result {
Ok(task) => self
.queue
.update_task_state(task, FangTaskState::Finished)
.await
.unwrap(),
Err((task, error)) => self.queue.fail_task(task, &error).await.unwrap(),
};
}
RetentionMode::RemoveAll => {
match result {
Ok(task) => self.queue.remove_task(task).await.unwrap(),
Err((task, _error)) => self.queue.remove_task(task).await.unwrap(),
};
}
RetentionMode::RemoveFinished => match result {
Ok(task) => { Ok(task) => {
self.queue.remove_task(task).await.unwrap(); self.queue
.update_task_state(task, FangTaskState::Finished)
.await?;
Ok(())
} }
Err((task, error)) => { Err((task, error)) => {
self.queue.fail_task(task, &error).await.unwrap(); self.queue.fail_task(task, &error).await?;
Ok(())
}
},
RetentionMode::RemoveAll => match result {
Ok(task) => {
self.queue.remove_task(task).await?;
Ok(())
}
Err((task, _error)) => {
self.queue.remove_task(task).await?;
Ok(())
}
},
RetentionMode::RemoveFinished => match result {
Ok(task) => {
self.queue.remove_task(task).await?;
Ok(())
}
Err((task, error)) => {
self.queue.fail_task(task, &error).await?;
Ok(())
} }
}, },
} }
@ -87,12 +80,12 @@ where
loop { loop {
match self match self
.queue .queue
.fetch_and_touch_task(&self.task_type.clone()) .fetch_and_touch_task(&Some(self.task_type.clone()))
.await .await
{ {
Ok(Some(task)) => { Ok(Some(task)) => {
self.sleep_params.maybe_reset_sleep_period(); self.sleep_params.maybe_reset_sleep_period();
self.run(task).await; self.run(task).await?
} }
Ok(None) => { Ok(None) => {
self.sleep().await; self.sleep().await;
@ -106,4 +99,237 @@ where
}; };
} }
} }
#[cfg(test)]
pub async fn run_tasks_until_none(&mut self) -> Result<(), Error> {
loop {
match self
.queue
.fetch_and_touch_task(&Some(self.task_type.clone()))
.await
{
Ok(Some(task)) => {
self.sleep_params.maybe_reset_sleep_period();
self.run(task).await?
}
Ok(None) => {
return Ok(());
}
Err(error) => {
error!("Failed to fetch a task {:?}", error);
self.sleep().await;
}
};
}
}
}
#[cfg(test)]
mod async_worker_tests {
use super::AsyncWorker;
use crate::asynk::async_queue::AsyncQueueTest;
use crate::asynk::async_queue::AsyncQueueable;
use crate::asynk::async_queue::FangTaskState;
use crate::asynk::async_worker::Task;
use crate::asynk::AsyncRunnable;
use crate::asynk::Error;
use crate::RetentionMode;
use async_trait::async_trait;
use bb8_postgres::bb8::Pool;
use bb8_postgres::tokio_postgres::NoTls;
use bb8_postgres::PostgresConnectionManager;
use serde::{Deserialize, Serialize};
#[derive(Serialize, Deserialize)]
struct WorkerAsyncTask {
pub number: u16,
}
#[typetag::serde]
#[async_trait(?Send)]
impl AsyncRunnable for WorkerAsyncTask {
async fn run(&self, _queueable: &mut dyn AsyncQueueable) -> Result<(), Error> {
Ok(())
}
}
#[derive(Serialize, Deserialize)]
struct AsyncFailedTask {
pub number: u16,
}
#[typetag::serde]
#[async_trait(?Send)]
impl AsyncRunnable for AsyncFailedTask {
async fn run(&self, _queueable: &mut dyn AsyncQueueable) -> Result<(), Error> {
let message = format!("number {} is wrong :(", self.number);
Err(Error {
description: message,
})
}
}
#[derive(Serialize, Deserialize)]
struct AsyncTaskType1 {}
#[typetag::serde]
#[async_trait(?Send)]
impl AsyncRunnable for AsyncTaskType1 {
async fn run(&self, _queueable: &mut dyn AsyncQueueable) -> Result<(), Error> {
Ok(())
}
fn task_type(&self) -> String {
"type1".to_string()
}
}
#[derive(Serialize, Deserialize)]
struct AsyncTaskType2 {}
#[typetag::serde]
#[async_trait(?Send)]
impl AsyncRunnable for AsyncTaskType2 {
async fn run(&self, _queueable: &mut dyn AsyncQueueable) -> Result<(), Error> {
Ok(())
}
fn task_type(&self) -> String {
"type2".to_string()
}
}
#[tokio::test]
async fn execute_and_finishes_task() {
let pool = pool().await;
let mut connection = pool.get().await.unwrap();
let transaction = connection.transaction().await.unwrap();
let mut test = AsyncQueueTest { transaction };
let task = insert_task(&mut test, &WorkerAsyncTask { number: 1 }).await;
let id = task.id;
let mut worker = AsyncWorker::builder()
.queue(&mut test as &mut dyn AsyncQueueable)
.retention_mode(RetentionMode::KeepAll)
.build();
worker.run(task).await.unwrap();
let task_finished = test.get_task_by_id(id).await.unwrap();
assert_eq!(id, task_finished.id);
assert_eq!(FangTaskState::Finished, task_finished.state);
test.transaction.rollback().await.unwrap();
}
#[tokio::test]
async fn saves_error_for_failed_task() {
let pool = pool().await;
let mut connection = pool.get().await.unwrap();
let transaction = connection.transaction().await.unwrap();
let mut test = AsyncQueueTest { transaction };
let task = insert_task(&mut test, &AsyncFailedTask { number: 1 }).await;
let id = task.id;
let mut worker = AsyncWorker::builder()
.queue(&mut test as &mut dyn AsyncQueueable)
.retention_mode(RetentionMode::KeepAll)
.build();
worker.run(task).await.unwrap();
let task_finished = test.get_task_by_id(id).await.unwrap();
assert_eq!(id, task_finished.id);
assert_eq!(FangTaskState::Failed, task_finished.state);
assert_eq!(
"number 1 is wrong :(".to_string(),
task_finished.error_message.unwrap()
);
test.transaction.rollback().await.unwrap();
}
#[tokio::test]
async fn executes_task_only_of_specific_type() {
let pool = pool().await;
let mut connection = pool.get().await.unwrap();
let transaction = connection.transaction().await.unwrap();
let mut test = AsyncQueueTest { transaction };
let task1 = insert_task(&mut test, &AsyncTaskType1 {}).await;
let task12 = insert_task(&mut test, &AsyncTaskType1 {}).await;
let task2 = insert_task(&mut test, &AsyncTaskType2 {}).await;
let id1 = task1.id;
let id12 = task12.id;
let id2 = task2.id;
let mut worker = AsyncWorker::builder()
.queue(&mut test as &mut dyn AsyncQueueable)
.task_type("type1".to_string())
.retention_mode(RetentionMode::KeepAll)
.build();
worker.run_tasks_until_none().await.unwrap();
let task1 = test.get_task_by_id(id1).await.unwrap();
let task12 = test.get_task_by_id(id12).await.unwrap();
let task2 = test.get_task_by_id(id2).await.unwrap();
assert_eq!(id1, task1.id);
assert_eq!(id12, task12.id);
assert_eq!(id2, task2.id);
assert_eq!(FangTaskState::Finished, task1.state);
assert_eq!(FangTaskState::Finished, task12.state);
assert_eq!(FangTaskState::New, task2.state);
test.transaction.rollback().await.unwrap();
}
#[tokio::test]
async fn remove_when_finished() {
let pool = pool().await;
let mut connection = pool.get().await.unwrap();
let transaction = connection.transaction().await.unwrap();
let mut test = AsyncQueueTest { transaction };
let task1 = insert_task(&mut test, &AsyncTaskType1 {}).await;
let task12 = insert_task(&mut test, &AsyncTaskType1 {}).await;
let task2 = insert_task(&mut test, &AsyncTaskType2 {}).await;
let _id1 = task1.id;
let _id12 = task12.id;
let id2 = task2.id;
let mut worker = AsyncWorker::builder()
.queue(&mut test as &mut dyn AsyncQueueable)
.task_type("type1".to_string())
.build();
worker.run_tasks_until_none().await.unwrap();
let task = test
.fetch_and_touch_task(&Some("type1".to_string()))
.await
.unwrap();
assert_eq!(None, task);
let task2 = test
.fetch_and_touch_task(&Some("type2".to_string()))
.await
.unwrap()
.unwrap();
assert_eq!(id2, task2.id);
test.transaction.rollback().await.unwrap();
}
async fn insert_task(test: &mut AsyncQueueTest<'_>, task: &dyn AsyncRunnable) -> Task {
let metadata = serde_json::to_value(task).unwrap();
test.insert_task(metadata, &task.task_type()).await.unwrap()
}
async fn pool() -> Pool<PostgresConnectionManager<NoTls>> {
let pg_mgr = PostgresConnectionManager::new_from_stringlike(
"postgres://postgres:postgres@localhost/fang",
NoTls,
)
.unwrap();
Pool::builder().build(pg_mgr).await.unwrap()
}
} }

View file

@ -0,0 +1 @@
SELECT * FROM fang_tasks WHERE id = $1

View file

@ -5,6 +5,11 @@ pub enum RetentionMode {
RemoveAll, RemoveAll,
RemoveFinished, RemoveFinished,
} }
impl Default for RetentionMode {
fn default() -> Self {
RetentionMode::RemoveAll
}
}
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub struct SleepParams { pub struct SleepParams {
pub sleep_period: u64, pub sleep_period: u64,