From 2d724c377628a66f29edaa66002482980e80e978 Mon Sep 17 00:00:00 2001 From: Pmarquez <48651252+pxp9@users.noreply.github.com> Date: Wed, 27 Jul 2022 17:05:05 +0000 Subject: [PATCH] Async worker test (#46) * preparing code to tests * changing queue tests and starting worker tests with new struct AsyncQueueTest * fix clippy and delete Debug trait * fix warnings * Get task by id in worker tests and AsyncQueueTest #[cfg(test)] * Defaults and AsyncWorker Builder * Testing :D * Execute task specific type * deleting Options and comment * insert task is back git !! * Test remove tasks and changing insert task auxiliar func --- src/asynk/async_queue.rs | 223 +++++++++++++------ src/asynk/async_worker.rs | 322 +++++++++++++++++++++++---- src/asynk/queries/get_task_by_id.sql | 1 + src/lib.rs | 5 + 4 files changed, 435 insertions(+), 116 deletions(-) create mode 100644 src/asynk/queries/get_task_by_id.sql diff --git a/src/asynk/async_queue.rs b/src/asynk/async_queue.rs index 628ffd8..5b90ca9 100644 --- a/src/asynk/async_queue.rs +++ b/src/asynk/async_queue.rs @@ -1,9 +1,11 @@ +use crate::asynk::async_runnable::Error as FangError; use async_trait::async_trait; use bb8_postgres::bb8::Pool; use bb8_postgres::bb8::RunError; use bb8_postgres::tokio_postgres::row::Row; -use bb8_postgres::tokio_postgres::tls::MakeTlsConnect; -use bb8_postgres::tokio_postgres::tls::TlsConnect; +#[cfg(test)] +use bb8_postgres::tokio_postgres::tls::NoTls; +use bb8_postgres::tokio_postgres::tls::{MakeTlsConnect, TlsConnect}; use bb8_postgres::tokio_postgres::Socket; use bb8_postgres::tokio_postgres::Transaction; use bb8_postgres::PostgresConnectionManager; @@ -21,8 +23,10 @@ const REMOVE_ALL_TASK_QUERY: &str = include_str!("queries/remove_all_tasks.sql") const REMOVE_TASK_QUERY: &str = include_str!("queries/remove_task.sql"); const REMOVE_TASKS_TYPE_QUERY: &str = include_str!("queries/remove_tasks_type.sql"); const FETCH_TASK_TYPE_QUERY: &str = include_str!("queries/fetch_task_type.sql"); +#[cfg(test)] +const GET_TASK_BY_ID_QUERY: &str = include_str!("queries/get_task_by_id.sql"); -const DEFAULT_TASK_TYPE: &str = "common"; +pub const DEFAULT_TASK_TYPE: &str = "common"; #[derive(Debug, Eq, PartialEq, Clone, ToSql, FromSql)] #[postgres(name = "fang_task_state")] @@ -100,7 +104,14 @@ pub enum AsyncQueueError { #[error("returned invalid result (expected {expected:?}, found {found:?})")] ResultError { expected: u64, found: u64 }, } - +impl From for FangError { + fn from(error: AsyncQueueError) -> Self { + let message = format!("{:?}", error); + FangError { + description: message, + } + } +} #[async_trait] pub trait AsyncQueueable { async fn fetch_and_touch_task( @@ -141,6 +152,98 @@ where pool: Pool>, } +#[cfg(test)] +pub struct AsyncQueueTest<'a> { + pub transaction: Transaction<'a>, +} + +#[cfg(test)] +impl<'a> AsyncQueueTest<'a> { + pub async fn get_task_by_id(&mut self, id: Uuid) -> Result { + let row: Row = self + .transaction + .query_one(GET_TASK_BY_ID_QUERY, &[&id]) + .await?; + + let task = AsyncQueue::::row_to_task(row); + Ok(task) + } +} + +#[cfg(test)] +#[async_trait] +impl AsyncQueueable for AsyncQueueTest<'_> { + async fn fetch_and_touch_task( + &mut self, + task_type: &Option, + ) -> Result, AsyncQueueError> { + let transaction = &mut self.transaction; + + let task = AsyncQueue::::fetch_and_touch_task_query(transaction, task_type).await?; + + Ok(task) + } + + async fn insert_task( + &mut self, + metadata: serde_json::Value, + task_type: &str, + ) -> Result { + let transaction = &mut self.transaction; + + let task = AsyncQueue::::insert_task_query(transaction, metadata, task_type).await?; + + Ok(task) + } + + async fn remove_all_tasks(&mut self) -> Result { + let transaction = &mut self.transaction; + + let result = AsyncQueue::::remove_all_tasks_query(transaction).await?; + + Ok(result) + } + + async fn remove_task(&mut self, task: Task) -> Result { + let transaction = &mut self.transaction; + + let result = AsyncQueue::::remove_task_query(transaction, task).await?; + + Ok(result) + } + + async fn remove_tasks_type(&mut self, task_type: &str) -> Result { + let transaction = &mut self.transaction; + + let result = AsyncQueue::::remove_tasks_type_query(transaction, task_type).await?; + + Ok(result) + } + + async fn update_task_state( + &mut self, + task: Task, + state: FangTaskState, + ) -> Result { + let transaction = &mut self.transaction; + + let task = AsyncQueue::::update_task_state_query(transaction, task, state).await?; + + Ok(task) + } + + async fn fail_task( + &mut self, + task: Task, + error_message: &str, + ) -> Result { + let transaction = &mut self.transaction; + + let task = AsyncQueue::::fail_task_query(transaction, task, error_message).await?; + + Ok(task) + } +} impl AsyncQueue where Tls: MakeTlsConnect + Clone + Send + Sync + 'static, @@ -404,7 +507,7 @@ where #[cfg(test)] mod async_queue_tests { - use super::AsyncQueue; + use super::AsyncQueueTest; use super::AsyncQueueable; use super::FangTaskState; use super::Task; @@ -413,7 +516,6 @@ mod async_queue_tests { use async_trait::async_trait; use bb8_postgres::bb8::Pool; use bb8_postgres::tokio_postgres::NoTls; - use bb8_postgres::tokio_postgres::Transaction; use bb8_postgres::PostgresConnectionManager; use serde::{Deserialize, Serialize}; @@ -434,15 +536,11 @@ mod async_queue_tests { async fn insert_task_creates_new_task() { let pool = pool().await; let mut connection = pool.get().await.unwrap(); - let mut transaction = connection.transaction().await.unwrap(); + let transaction = connection.transaction().await.unwrap(); - let task = AsyncTask { number: 1 }; - let metadata = serde_json::to_value(&task as &dyn AsyncRunnable).unwrap(); + let mut test = AsyncQueueTest { transaction }; - let task = - AsyncQueue::::insert_task_query(&mut transaction, metadata, &task.task_type()) - .await - .unwrap(); + let task = insert_task(&mut test, &AsyncTask { number: 1 }).await; let metadata = task.metadata.as_object().unwrap(); let number = metadata["number"].as_u64(); @@ -450,16 +548,18 @@ mod async_queue_tests { assert_eq!(Some(1), number); assert_eq!(Some("AsyncTask"), type_task); - transaction.rollback().await.unwrap(); + test.transaction.rollback().await.unwrap(); } #[tokio::test] async fn update_task_state_test() { let pool = pool().await; let mut connection = pool.get().await.unwrap(); - let mut transaction = connection.transaction().await.unwrap(); + let transaction = connection.transaction().await.unwrap(); - let task = insert_task(&mut transaction, &AsyncTask { number: 1 }).await; + let mut test = AsyncQueueTest { transaction }; + + let task = insert_task(&mut test, &AsyncTask { number: 1 }).await; let metadata = task.metadata.as_object().unwrap(); let number = metadata["number"].as_u64(); @@ -469,27 +569,26 @@ mod async_queue_tests { assert_eq!(Some(1), number); assert_eq!(Some("AsyncTask"), type_task); - let finished_task = AsyncQueue::::update_task_state_query( - &mut transaction, - task, - FangTaskState::Finished, - ) - .await - .unwrap(); + let finished_task = test + .update_task_state(task, FangTaskState::Finished) + .await + .unwrap(); assert_eq!(id, finished_task.id); assert_eq!(FangTaskState::Finished, finished_task.state); - transaction.rollback().await.unwrap(); + test.transaction.rollback().await.unwrap(); } #[tokio::test] async fn failed_task_query_test() { let pool = pool().await; let mut connection = pool.get().await.unwrap(); - let mut transaction = connection.transaction().await.unwrap(); + let transaction = connection.transaction().await.unwrap(); - let task = insert_task(&mut transaction, &AsyncTask { number: 1 }).await; + let mut test = AsyncQueueTest { transaction }; + + let task = insert_task(&mut test, &AsyncTask { number: 1 }).await; let metadata = task.metadata.as_object().unwrap(); let number = metadata["number"].as_u64(); @@ -499,25 +598,24 @@ mod async_queue_tests { assert_eq!(Some(1), number); assert_eq!(Some("AsyncTask"), type_task); - let failed_task = - AsyncQueue::::fail_task_query(&mut transaction, task, "Some error") - .await - .unwrap(); + let failed_task = test.fail_task(task, "Some error").await.unwrap(); assert_eq!(id, failed_task.id); assert_eq!(Some("Some error"), failed_task.error_message.as_deref()); assert_eq!(FangTaskState::Failed, failed_task.state); - transaction.rollback().await.unwrap(); + test.transaction.rollback().await.unwrap(); } #[tokio::test] async fn remove_all_tasks_test() { let pool = pool().await; let mut connection = pool.get().await.unwrap(); - let mut transaction = connection.transaction().await.unwrap(); + let transaction = connection.transaction().await.unwrap(); - let task = insert_task(&mut transaction, &AsyncTask { number: 1 }).await; + let mut test = AsyncQueueTest { transaction }; + + let task = insert_task(&mut test, &AsyncTask { number: 1 }).await; let metadata = task.metadata.as_object().unwrap(); let number = metadata["number"].as_u64(); @@ -526,7 +624,7 @@ mod async_queue_tests { assert_eq!(Some(1), number); assert_eq!(Some("AsyncTask"), type_task); - let task = insert_task(&mut transaction, &AsyncTask { number: 2 }).await; + let task = insert_task(&mut test, &AsyncTask { number: 2 }).await; let metadata = task.metadata.as_object().unwrap(); let number = metadata["number"].as_u64(); @@ -535,21 +633,21 @@ mod async_queue_tests { assert_eq!(Some(2), number); assert_eq!(Some("AsyncTask"), type_task); - let result = AsyncQueue::::remove_all_tasks_query(&mut transaction) - .await - .unwrap(); + let result = test.remove_all_tasks().await.unwrap(); assert_eq!(2, result); - transaction.rollback().await.unwrap(); + test.transaction.rollback().await.unwrap(); } #[tokio::test] async fn fetch_and_touch_test() { let pool = pool().await; let mut connection = pool.get().await.unwrap(); - let mut transaction = connection.transaction().await.unwrap(); + let transaction = connection.transaction().await.unwrap(); - let task = insert_task(&mut transaction, &AsyncTask { number: 1 }).await; + let mut test = AsyncQueueTest { transaction }; + + let task = insert_task(&mut test, &AsyncTask { number: 1 }).await; let metadata = task.metadata.as_object().unwrap(); let number = metadata["number"].as_u64(); @@ -558,7 +656,7 @@ mod async_queue_tests { assert_eq!(Some(1), number); assert_eq!(Some("AsyncTask"), type_task); - let task = insert_task(&mut transaction, &AsyncTask { number: 2 }).await; + let task = insert_task(&mut test, &AsyncTask { number: 2 }).await; let metadata = task.metadata.as_object().unwrap(); let number = metadata["number"].as_u64(); @@ -567,10 +665,7 @@ mod async_queue_tests { assert_eq!(Some(2), number); assert_eq!(Some("AsyncTask"), type_task); - let task = AsyncQueue::::fetch_and_touch_task_query(&mut transaction, &None) - .await - .unwrap() - .unwrap(); + let task = test.fetch_and_touch_task(&None).await.unwrap().unwrap(); let metadata = task.metadata.as_object().unwrap(); let number = metadata["number"].as_u64(); @@ -579,10 +674,7 @@ mod async_queue_tests { assert_eq!(Some(1), number); assert_eq!(Some("AsyncTask"), type_task); - let task = AsyncQueue::::fetch_and_touch_task_query(&mut transaction, &None) - .await - .unwrap() - .unwrap(); + let task = test.fetch_and_touch_task(&None).await.unwrap().unwrap(); let metadata = task.metadata.as_object().unwrap(); let number = metadata["number"].as_u64(); let type_task = metadata["type"].as_str(); @@ -590,16 +682,18 @@ mod async_queue_tests { assert_eq!(Some(2), number); assert_eq!(Some("AsyncTask"), type_task); - transaction.rollback().await.unwrap(); + test.transaction.rollback().await.unwrap(); } #[tokio::test] async fn remove_tasks_type_test() { let pool = pool().await; let mut connection = pool.get().await.unwrap(); - let mut transaction = connection.transaction().await.unwrap(); + let transaction = connection.transaction().await.unwrap(); - let task = insert_task(&mut transaction, &AsyncTask { number: 1 }).await; + let mut test = AsyncQueueTest { transaction }; + + let task = insert_task(&mut test, &AsyncTask { number: 1 }).await; let metadata = task.metadata.as_object().unwrap(); let number = metadata["number"].as_u64(); @@ -608,7 +702,7 @@ mod async_queue_tests { assert_eq!(Some(1), number); assert_eq!(Some("AsyncTask"), type_task); - let task = insert_task(&mut transaction, &AsyncTask { number: 2 }).await; + let task = insert_task(&mut test, &AsyncTask { number: 2 }).await; let metadata = task.metadata.as_object().unwrap(); let number = metadata["number"].as_u64(); @@ -617,17 +711,18 @@ mod async_queue_tests { assert_eq!(Some(2), number); assert_eq!(Some("AsyncTask"), type_task); - let result = AsyncQueue::::remove_tasks_type_query(&mut transaction, "mytype") - .await - .unwrap(); + let result = test.remove_tasks_type("mytype").await.unwrap(); assert_eq!(0, result); - let result = AsyncQueue::::remove_tasks_type_query(&mut transaction, "common") - .await - .unwrap(); + let result = test.remove_tasks_type("common").await.unwrap(); assert_eq!(2, result); - transaction.rollback().await.unwrap(); + test.transaction.rollback().await.unwrap(); + } + + async fn insert_task(test: &mut AsyncQueueTest<'_>, task: &dyn AsyncRunnable) -> Task { + let metadata = serde_json::to_value(task).unwrap(); + test.insert_task(metadata, &task.task_type()).await.unwrap() } async fn pool() -> Pool> { @@ -639,12 +734,4 @@ mod async_queue_tests { Pool::builder().build(pg_mgr).await.unwrap() } - - async fn insert_task(transaction: &mut Transaction<'_>, task: &dyn AsyncRunnable) -> Task { - let metadata = serde_json::to_value(task).unwrap(); - - AsyncQueue::::insert_task_query(transaction, metadata, &task.task_type()) - .await - .unwrap() - } } diff --git a/src/asynk/async_worker.rs b/src/asynk/async_worker.rs index fa4c1cc..d151b55 100644 --- a/src/asynk/async_worker.rs +++ b/src/asynk/async_worker.rs @@ -1,42 +1,27 @@ -use crate::asynk::async_queue::AsyncQueue; use crate::asynk::async_queue::AsyncQueueable; use crate::asynk::async_queue::FangTaskState; use crate::asynk::async_queue::Task; +use crate::asynk::async_queue::DEFAULT_TASK_TYPE; use crate::asynk::async_runnable::AsyncRunnable; use crate::asynk::Error; use crate::{RetentionMode, SleepParams}; -use bb8_postgres::tokio_postgres::tls::MakeTlsConnect; -use bb8_postgres::tokio_postgres::tls::TlsConnect; -use bb8_postgres::tokio_postgres::Socket; use log::error; use std::time::Duration; use typed_builder::TypedBuilder; -#[derive(TypedBuilder, Debug)] -pub struct AsyncWorker -where - Tls: MakeTlsConnect + Clone + Send + Sync + 'static, - >::Stream: Send + Sync, - >::TlsConnect: Send, - <>::TlsConnect as TlsConnect>::Future: Send, -{ - #[builder(setter(into))] - pub queue: AsyncQueue, - #[builder(setter(into))] - pub task_type: Option, +#[derive(TypedBuilder)] +pub struct AsyncWorker<'a> { #[builder(setter(into))] + pub queue: &'a mut dyn AsyncQueueable, + #[builder(default=DEFAULT_TASK_TYPE.to_string() , setter(into))] + pub task_type: String, + #[builder(default, setter(into))] pub sleep_params: SleepParams, - #[builder(setter(into))] + #[builder(default, setter(into))] pub retention_mode: RetentionMode, } -impl AsyncWorker -where - Tls: MakeTlsConnect + Clone + Send + Sync + 'static, - >::Stream: Send + Sync, - >::TlsConnect: Send, - <>::TlsConnect as TlsConnect>::Future: Send, -{ - pub async fn run(&mut self, task: Task) { +impl<'a> AsyncWorker<'a> { + pub async fn run(&mut self, task: Task) -> Result<(), Error> { let result = self.execute_task(task).await; self.finalize_task(result).await } @@ -44,36 +29,44 @@ where let actual_task: Box = serde_json::from_value(task.metadata.clone()).unwrap(); - let task_result = actual_task.run(&mut self.queue).await; + let task_result = actual_task.run(self.queue).await; match task_result { Ok(()) => Ok(task), Err(error) => Err((task, error.description)), } } - async fn finalize_task(&mut self, result: Result) { + async fn finalize_task(&mut self, result: Result) -> Result<(), Error> { match self.retention_mode { - RetentionMode::KeepAll => { - match result { - Ok(task) => self - .queue - .update_task_state(task, FangTaskState::Finished) - .await - .unwrap(), - Err((task, error)) => self.queue.fail_task(task, &error).await.unwrap(), - }; - } - RetentionMode::RemoveAll => { - match result { - Ok(task) => self.queue.remove_task(task).await.unwrap(), - Err((task, _error)) => self.queue.remove_task(task).await.unwrap(), - }; - } - RetentionMode::RemoveFinished => match result { + RetentionMode::KeepAll => match result { Ok(task) => { - self.queue.remove_task(task).await.unwrap(); + self.queue + .update_task_state(task, FangTaskState::Finished) + .await?; + Ok(()) } Err((task, error)) => { - self.queue.fail_task(task, &error).await.unwrap(); + self.queue.fail_task(task, &error).await?; + Ok(()) + } + }, + RetentionMode::RemoveAll => match result { + Ok(task) => { + self.queue.remove_task(task).await?; + Ok(()) + } + Err((task, _error)) => { + self.queue.remove_task(task).await?; + Ok(()) + } + }, + RetentionMode::RemoveFinished => match result { + Ok(task) => { + self.queue.remove_task(task).await?; + Ok(()) + } + Err((task, error)) => { + self.queue.fail_task(task, &error).await?; + Ok(()) } }, } @@ -87,12 +80,12 @@ where loop { match self .queue - .fetch_and_touch_task(&self.task_type.clone()) + .fetch_and_touch_task(&Some(self.task_type.clone())) .await { Ok(Some(task)) => { self.sleep_params.maybe_reset_sleep_period(); - self.run(task).await; + self.run(task).await? } Ok(None) => { self.sleep().await; @@ -106,4 +99,237 @@ where }; } } + #[cfg(test)] + pub async fn run_tasks_until_none(&mut self) -> Result<(), Error> { + loop { + match self + .queue + .fetch_and_touch_task(&Some(self.task_type.clone())) + .await + { + Ok(Some(task)) => { + self.sleep_params.maybe_reset_sleep_period(); + self.run(task).await? + } + Ok(None) => { + return Ok(()); + } + Err(error) => { + error!("Failed to fetch a task {:?}", error); + + self.sleep().await; + } + }; + } + } +} + +#[cfg(test)] +mod async_worker_tests { + use super::AsyncWorker; + use crate::asynk::async_queue::AsyncQueueTest; + use crate::asynk::async_queue::AsyncQueueable; + use crate::asynk::async_queue::FangTaskState; + use crate::asynk::async_worker::Task; + use crate::asynk::AsyncRunnable; + use crate::asynk::Error; + use crate::RetentionMode; + use async_trait::async_trait; + use bb8_postgres::bb8::Pool; + use bb8_postgres::tokio_postgres::NoTls; + use bb8_postgres::PostgresConnectionManager; + use serde::{Deserialize, Serialize}; + + #[derive(Serialize, Deserialize)] + struct WorkerAsyncTask { + pub number: u16, + } + + #[typetag::serde] + #[async_trait(?Send)] + impl AsyncRunnable for WorkerAsyncTask { + async fn run(&self, _queueable: &mut dyn AsyncQueueable) -> Result<(), Error> { + Ok(()) + } + } + #[derive(Serialize, Deserialize)] + struct AsyncFailedTask { + pub number: u16, + } + + #[typetag::serde] + #[async_trait(?Send)] + impl AsyncRunnable for AsyncFailedTask { + async fn run(&self, _queueable: &mut dyn AsyncQueueable) -> Result<(), Error> { + let message = format!("number {} is wrong :(", self.number); + + Err(Error { + description: message, + }) + } + } + + #[derive(Serialize, Deserialize)] + struct AsyncTaskType1 {} + + #[typetag::serde] + #[async_trait(?Send)] + impl AsyncRunnable for AsyncTaskType1 { + async fn run(&self, _queueable: &mut dyn AsyncQueueable) -> Result<(), Error> { + Ok(()) + } + + fn task_type(&self) -> String { + "type1".to_string() + } + } + + #[derive(Serialize, Deserialize)] + struct AsyncTaskType2 {} + + #[typetag::serde] + #[async_trait(?Send)] + impl AsyncRunnable for AsyncTaskType2 { + async fn run(&self, _queueable: &mut dyn AsyncQueueable) -> Result<(), Error> { + Ok(()) + } + + fn task_type(&self) -> String { + "type2".to_string() + } + } + #[tokio::test] + async fn execute_and_finishes_task() { + let pool = pool().await; + let mut connection = pool.get().await.unwrap(); + let transaction = connection.transaction().await.unwrap(); + + let mut test = AsyncQueueTest { transaction }; + + let task = insert_task(&mut test, &WorkerAsyncTask { number: 1 }).await; + let id = task.id; + + let mut worker = AsyncWorker::builder() + .queue(&mut test as &mut dyn AsyncQueueable) + .retention_mode(RetentionMode::KeepAll) + .build(); + + worker.run(task).await.unwrap(); + let task_finished = test.get_task_by_id(id).await.unwrap(); + assert_eq!(id, task_finished.id); + assert_eq!(FangTaskState::Finished, task_finished.state); + test.transaction.rollback().await.unwrap(); + } + #[tokio::test] + async fn saves_error_for_failed_task() { + let pool = pool().await; + let mut connection = pool.get().await.unwrap(); + let transaction = connection.transaction().await.unwrap(); + + let mut test = AsyncQueueTest { transaction }; + + let task = insert_task(&mut test, &AsyncFailedTask { number: 1 }).await; + let id = task.id; + + let mut worker = AsyncWorker::builder() + .queue(&mut test as &mut dyn AsyncQueueable) + .retention_mode(RetentionMode::KeepAll) + .build(); + + worker.run(task).await.unwrap(); + let task_finished = test.get_task_by_id(id).await.unwrap(); + + assert_eq!(id, task_finished.id); + assert_eq!(FangTaskState::Failed, task_finished.state); + assert_eq!( + "number 1 is wrong :(".to_string(), + task_finished.error_message.unwrap() + ); + test.transaction.rollback().await.unwrap(); + } + #[tokio::test] + async fn executes_task_only_of_specific_type() { + let pool = pool().await; + let mut connection = pool.get().await.unwrap(); + let transaction = connection.transaction().await.unwrap(); + + let mut test = AsyncQueueTest { transaction }; + + let task1 = insert_task(&mut test, &AsyncTaskType1 {}).await; + let task12 = insert_task(&mut test, &AsyncTaskType1 {}).await; + let task2 = insert_task(&mut test, &AsyncTaskType2 {}).await; + + let id1 = task1.id; + let id12 = task12.id; + let id2 = task2.id; + + let mut worker = AsyncWorker::builder() + .queue(&mut test as &mut dyn AsyncQueueable) + .task_type("type1".to_string()) + .retention_mode(RetentionMode::KeepAll) + .build(); + + worker.run_tasks_until_none().await.unwrap(); + let task1 = test.get_task_by_id(id1).await.unwrap(); + let task12 = test.get_task_by_id(id12).await.unwrap(); + let task2 = test.get_task_by_id(id2).await.unwrap(); + + assert_eq!(id1, task1.id); + assert_eq!(id12, task12.id); + assert_eq!(id2, task2.id); + assert_eq!(FangTaskState::Finished, task1.state); + assert_eq!(FangTaskState::Finished, task12.state); + assert_eq!(FangTaskState::New, task2.state); + test.transaction.rollback().await.unwrap(); + } + #[tokio::test] + async fn remove_when_finished() { + let pool = pool().await; + let mut connection = pool.get().await.unwrap(); + let transaction = connection.transaction().await.unwrap(); + + let mut test = AsyncQueueTest { transaction }; + + let task1 = insert_task(&mut test, &AsyncTaskType1 {}).await; + let task12 = insert_task(&mut test, &AsyncTaskType1 {}).await; + let task2 = insert_task(&mut test, &AsyncTaskType2 {}).await; + + let _id1 = task1.id; + let _id12 = task12.id; + let id2 = task2.id; + + let mut worker = AsyncWorker::builder() + .queue(&mut test as &mut dyn AsyncQueueable) + .task_type("type1".to_string()) + .build(); + + worker.run_tasks_until_none().await.unwrap(); + let task = test + .fetch_and_touch_task(&Some("type1".to_string())) + .await + .unwrap(); + assert_eq!(None, task); + + let task2 = test + .fetch_and_touch_task(&Some("type2".to_string())) + .await + .unwrap() + .unwrap(); + assert_eq!(id2, task2.id); + + test.transaction.rollback().await.unwrap(); + } + async fn insert_task(test: &mut AsyncQueueTest<'_>, task: &dyn AsyncRunnable) -> Task { + let metadata = serde_json::to_value(task).unwrap(); + test.insert_task(metadata, &task.task_type()).await.unwrap() + } + async fn pool() -> Pool> { + let pg_mgr = PostgresConnectionManager::new_from_stringlike( + "postgres://postgres:postgres@localhost/fang", + NoTls, + ) + .unwrap(); + + Pool::builder().build(pg_mgr).await.unwrap() + } } diff --git a/src/asynk/queries/get_task_by_id.sql b/src/asynk/queries/get_task_by_id.sql new file mode 100644 index 0000000..608166f --- /dev/null +++ b/src/asynk/queries/get_task_by_id.sql @@ -0,0 +1 @@ +SELECT * FROM fang_tasks WHERE id = $1 diff --git a/src/lib.rs b/src/lib.rs index d4fc169..c9d36f7 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -5,6 +5,11 @@ pub enum RetentionMode { RemoveAll, RemoveFinished, } +impl Default for RetentionMode { + fn default() -> Self { + RetentionMode::RemoveAll + } +} #[derive(Clone, Debug)] pub struct SleepParams { pub sleep_period: u64,