diff --git a/src/asynk/async_queue.rs b/src/asynk/async_queue.rs index c837379..fc16953 100644 --- a/src/asynk/async_queue.rs +++ b/src/asynk/async_queue.rs @@ -30,6 +30,7 @@ const REMOVE_ALL_TASK_QUERY: &str = include_str!("queries/remove_all_tasks.sql") const REMOVE_ALL_SCHEDULED_TASK_QUERY: &str = include_str!("queries/remove_all_scheduled_tasks.sql"); const REMOVE_TASK_QUERY: &str = include_str!("queries/remove_task.sql"); +const REMOVE_TASK_BY_METADATA_QUERY: &str = include_str!("queries/remove_task_by_metadata.sql"); const REMOVE_TASKS_TYPE_QUERY: &str = include_str!("queries/remove_tasks_type.sql"); const FETCH_TASK_TYPE_QUERY: &str = include_str!("queries/fetch_task_type.sql"); const FIND_TASK_BY_UNIQ_HASH_QUERY: &str = include_str!("queries/find_task_by_uniq_hash.sql"); @@ -96,6 +97,8 @@ pub enum AsyncQueueError { NotConnectedError, #[error("Can not convert `std::time::Duration` to `chrono::Duration`")] TimeError, + #[error("Can not perform this operation if task is not uniq, please check its definition in impl AsyncRunnable")] + TaskNotUniqError, } impl From for AsyncQueueError { @@ -117,7 +120,12 @@ pub trait AsyncQueueable: Send { async fn remove_all_scheduled_tasks(&mut self) -> Result; - async fn remove_task(&mut self, task: Task) -> Result; + async fn remove_task(&mut self, id: Uuid) -> Result; + + async fn remove_task_by_metadata( + &mut self, + task: &dyn AsyncRunnable, + ) -> Result; async fn remove_tasks_type(&mut self, task_type: &str) -> Result; @@ -259,10 +267,23 @@ impl AsyncQueueable for AsyncQueueTest<'_> { AsyncQueue::::remove_all_scheduled_tasks_query(transaction).await } - async fn remove_task(&mut self, task: Task) -> Result { + async fn remove_task(&mut self, id: Uuid) -> Result { let transaction = &mut self.transaction; - AsyncQueue::::remove_task_query(transaction, task).await + AsyncQueue::::remove_task_query(transaction, id).await + } + + async fn remove_task_by_metadata( + &mut self, + task: &dyn AsyncRunnable, + ) -> Result { + if task.uniq() { + let transaction = &mut self.transaction; + + AsyncQueue::::remove_task_by_metadata_query(transaction, task).await + } else { + Err(AsyncQueueError::TaskNotUniqError) + } } async fn remove_tasks_type(&mut self, task_type: &str) -> Result { @@ -338,9 +359,26 @@ where async fn remove_task_query( transaction: &mut Transaction<'_>, - task: Task, + id: Uuid, ) -> Result { - Self::execute_query(transaction, REMOVE_TASK_QUERY, &[&task.id], Some(1)).await + Self::execute_query(transaction, REMOVE_TASK_QUERY, &[&id], Some(1)).await + } + + async fn remove_task_by_metadata_query( + transaction: &mut Transaction<'_>, + task: &dyn AsyncRunnable, + ) -> Result { + let metadata = serde_json::to_value(task)?; + + let uniq_hash = Self::calculate_hash(metadata.to_string()); + + Self::execute_query( + transaction, + REMOVE_TASK_BY_METADATA_QUERY, + &[&uniq_hash], + None, + ) + .await } async fn remove_tasks_type_query( @@ -671,18 +709,37 @@ where Ok(result) } - async fn remove_task(&mut self, task: Task) -> Result { + async fn remove_task(&mut self, id: Uuid) -> Result { self.check_if_connection()?; let mut connection = self.pool.as_ref().unwrap().get().await?; let mut transaction = connection.transaction().await?; - let result = Self::remove_task_query(&mut transaction, task).await?; + let result = Self::remove_task_query(&mut transaction, id).await?; transaction.commit().await?; Ok(result) } + async fn remove_task_by_metadata( + &mut self, + task: &dyn AsyncRunnable, + ) -> Result { + if task.uniq() { + self.check_if_connection()?; + let mut connection = self.pool.as_ref().unwrap().get().await?; + let mut transaction = connection.transaction().await?; + + let result = Self::remove_task_by_metadata_query(&mut transaction, task).await?; + + transaction.commit().await?; + + Ok(result) + } else { + Err(AsyncQueueError::TaskNotUniqError) + } + } + async fn remove_tasks_type(&mut self, task_type: &str) -> Result { self.check_if_connection()?; let mut connection = self.pool.as_ref().unwrap().get().await?; @@ -758,6 +815,23 @@ mod async_queue_tests { } } + #[derive(Serialize, Deserialize)] + struct AsyncUniqTask { + pub number: u16, + } + + #[typetag::serde] + #[async_trait] + impl AsyncRunnable for AsyncUniqTask { + async fn run(&self, _queueable: &mut dyn AsyncQueueable) -> Result<(), FangError> { + Ok(()) + } + + fn uniq(&self) -> bool { + true + } + } + #[derive(Serialize, Deserialize)] struct AsyncTaskSchedule { pub number: u16, @@ -1019,6 +1093,47 @@ mod async_queue_tests { test.transaction.rollback().await.unwrap(); } + #[tokio::test] + async fn remove_tasks_by_metadata() { + let pool = pool().await; + let mut connection = pool.get().await.unwrap(); + let transaction = connection.transaction().await.unwrap(); + + let mut test = AsyncQueueTest::builder().transaction(transaction).build(); + + let task = insert_task(&mut test, &AsyncUniqTask { number: 1 }).await; + + let metadata = task.metadata.as_object().unwrap(); + let number = metadata["number"].as_u64(); + let type_task = metadata["type"].as_str(); + + assert_eq!(Some(1), number); + assert_eq!(Some("AsyncUniqTask"), type_task); + + let task = insert_task(&mut test, &AsyncUniqTask { number: 2 }).await; + + let metadata = task.metadata.as_object().unwrap(); + let number = metadata["number"].as_u64(); + let type_task = metadata["type"].as_str(); + + assert_eq!(Some(2), number); + assert_eq!(Some("AsyncUniqTask"), type_task); + + let result = test + .remove_task_by_metadata(&AsyncUniqTask { number: 0 }) + .await + .unwrap(); + assert_eq!(0, result); + + let result = test + .remove_task_by_metadata(&AsyncUniqTask { number: 1 }) + .await + .unwrap(); + assert_eq!(1, result); + + test.transaction.rollback().await.unwrap(); + } + async fn insert_task(test: &mut AsyncQueueTest<'_>, task: &dyn AsyncRunnable) -> Task { test.insert_task(task).await.unwrap() } diff --git a/src/asynk/async_worker.rs b/src/asynk/async_worker.rs index 6dbd7e8..e362dcf 100644 --- a/src/asynk/async_worker.rs +++ b/src/asynk/async_worker.rs @@ -68,17 +68,17 @@ where }, RetentionMode::RemoveAll => match result { Ok(task) => { - self.queue.remove_task(task).await?; + self.queue.remove_task(task.id).await?; Ok(()) } Err((task, _error)) => { - self.queue.remove_task(task).await?; + self.queue.remove_task(task.id).await?; Ok(()) } }, RetentionMode::RemoveFinished => match result { Ok(task) => { - self.queue.remove_task(task).await?; + self.queue.remove_task(task.id).await?; Ok(()) } Err((task, error)) => { @@ -185,17 +185,17 @@ impl<'a> AsyncWorkerTest<'a> { }, RetentionMode::RemoveAll => match result { Ok(task) => { - self.queue.remove_task(task).await?; + self.queue.remove_task(task.id).await?; Ok(()) } Err((task, _error)) => { - self.queue.remove_task(task).await?; + self.queue.remove_task(task.id).await?; Ok(()) } }, RetentionMode::RemoveFinished => match result { Ok(task) => { - self.queue.remove_task(task).await?; + self.queue.remove_task(task.id).await?; Ok(()) } Err((task, error)) => { diff --git a/src/asynk/queries/remove_task_by_metadata.sql b/src/asynk/queries/remove_task_by_metadata.sql new file mode 100644 index 0000000..94324e2 --- /dev/null +++ b/src/asynk/queries/remove_task_by_metadata.sql @@ -0,0 +1 @@ +DELETE FROM "fang_tasks" WHERE uniq_hash = $1 diff --git a/src/blocking/queue.rs b/src/blocking/queue.rs index 909bae0..47060fd 100644 --- a/src/blocking/queue.rs +++ b/src/blocking/queue.rs @@ -71,6 +71,8 @@ pub enum QueueError { PoolError(#[from] PoolError), #[error(transparent)] CronError(#[from] CronError), + #[error("Can not perform this operation if task is not uniq, please check its definition in impl Runnable")] + TaskNotUniqError, } impl From for QueueError { @@ -92,6 +94,10 @@ pub trait Queueable { fn remove_task(&self, id: Uuid) -> Result; + /// To use this function task has to be uniq. uniq() has to return true. + /// If task is not uniq this function will not do anything. + fn remove_task_by_metadata(&self, task: &dyn Runnable) -> Result; + fn find_task_by_id(&self, id: Uuid) -> Option; fn update_task_state(&self, task: &Task, state: FangTaskState) -> Result; @@ -149,6 +155,18 @@ impl Queueable for Queue { Self::remove_task_query(&mut connection, id) } + /// To use this function task has to be uniq. uniq() has to return true. + /// If task is not uniq this function will not do anything. + fn remove_task_by_metadata(&self, task: &dyn Runnable) -> Result { + if task.uniq() { + let mut connection = self.get_connection()?; + + Self::remove_task_by_metadata_query(&mut connection, task) + } else { + Err(QueueError::TaskNotUniqError) + } + } + fn update_task_state(&self, task: &Task, state: FangTaskState) -> Result { let mut connection = self.get_connection()?; @@ -304,6 +322,19 @@ impl Queue { Ok(diesel::delete(query).execute(connection)?) } + pub fn remove_task_by_metadata_query( + connection: &mut PgConnection, + task: &dyn Runnable, + ) -> Result { + let metadata = serde_json::to_value(task).unwrap(); + + let uniq_hash = Self::calculate_hash(metadata.to_string()); + + let query = fang_tasks::table.filter(fang_tasks::uniq_hash.eq(uniq_hash)); + + Ok(diesel::delete(query).execute(connection)?) + } + pub fn remove_task_query(connection: &mut PgConnection, id: Uuid) -> Result { let query = fang_tasks::table.filter(fang_tasks::id.eq(id)); @@ -775,4 +806,35 @@ mod queue_tests { Ok(()) }); } + + #[test] + fn remove_task_by_metadata() { + let m_task1 = PepeTask { number: 10 }; + let m_task2 = PepeTask { number: 11 }; + let m_task3 = AyratTask { number: 10 }; + + let pool = Queue::connection_pool(5); + + let queue = Queue::builder().connection_pool(pool).build(); + + let mut queue_pooled_connection = queue.connection_pool.get().unwrap(); + + queue_pooled_connection.test_transaction::<(), Error, _>(|conn| { + let task1 = Queue::insert_query(conn, &m_task1, Utc::now()).unwrap(); + let task2 = Queue::insert_query(conn, &m_task2, Utc::now()).unwrap(); + let task3 = Queue::insert_query(conn, &m_task3, Utc::now()).unwrap(); + + assert!(Queue::find_task_by_id_query(conn, task1.id).is_some()); + assert!(Queue::find_task_by_id_query(conn, task2.id).is_some()); + assert!(Queue::find_task_by_id_query(conn, task3.id).is_some()); + + Queue::remove_task_by_metadata_query(conn, &m_task1).unwrap(); + + assert!(Queue::find_task_by_id_query(conn, task1.id).is_none()); + assert!(Queue::find_task_by_id_query(conn, task2.id).is_some()); + assert!(Queue::find_task_by_id_query(conn, task3.id).is_some()); + + Ok(()) + }); + } }