From 140b19e6e4b8a78eac112c6a6ea89f06b87d2a04 Mon Sep 17 00:00:00 2001 From: Pmarquez <48651252+pxp9@users.noreply.github.com> Date: Tue, 2 Aug 2022 14:32:58 +0000 Subject: [PATCH] Generic Worker Pool (#55) * generic async worker pool !! * cfg tests --- fang_examples/simple_async_worker/src/main.rs | 2 +- src/asynk/async_worker.rs | 99 +++++++++++++++++-- src/asynk/async_worker_pool.rs | 28 ++---- 3 files changed, 99 insertions(+), 30 deletions(-) diff --git a/fang_examples/simple_async_worker/src/main.rs b/fang_examples/simple_async_worker/src/main.rs index b4cd5f1..a228219 100644 --- a/fang_examples/simple_async_worker/src/main.rs +++ b/fang_examples/simple_async_worker/src/main.rs @@ -21,7 +21,7 @@ async fn main() { queue.connect(NoTls).await.unwrap(); log::info!("Queue connected..."); - let mut pool = AsyncWorkerPool::builder() + let mut pool: AsyncWorkerPool> = AsyncWorkerPool::builder() .number_of_workers(max_pool_size) .queue(queue.clone()) .build(); diff --git a/src/asynk/async_worker.rs b/src/asynk/async_worker.rs index 16f8e08..5c7a89f 100644 --- a/src/asynk/async_worker.rs +++ b/src/asynk/async_worker.rs @@ -10,9 +10,12 @@ use std::time::Duration; use typed_builder::TypedBuilder; #[derive(TypedBuilder)] -pub struct AsyncWorker<'a> { +pub struct AsyncWorker +where + AQueue: AsyncQueueable + Clone + Sync + 'static, +{ #[builder(setter(into))] - pub queue: &'a mut dyn AsyncQueueable, + pub queue: AQueue, #[builder(default=DEFAULT_TASK_TYPE.to_string(), setter(into))] pub task_type: String, #[builder(default, setter(into))] @@ -21,7 +24,10 @@ pub struct AsyncWorker<'a> { pub retention_mode: RetentionMode, } -impl<'a> AsyncWorker<'a> { +impl AsyncWorker +where + AQueue: AsyncQueueable + Clone + Sync + 'static, +{ pub async fn run(&mut self, task: Task) -> Result<(), Error> { let result = self.execute_task(task).await; self.finalize_task(result).await @@ -31,7 +37,7 @@ impl<'a> AsyncWorker<'a> { let actual_task: Box = serde_json::from_value(task.metadata.clone()).unwrap(); - let task_result = actual_task.run(self.queue).await; + let task_result = actual_task.run(&mut self.queue).await; match task_result { Ok(()) => Ok(task), Err(error) => Err((task, error.description)), @@ -104,8 +110,81 @@ impl<'a> AsyncWorker<'a> { }; } } +} - #[cfg(test)] +#[cfg(test)] +#[derive(TypedBuilder)] +pub struct AsyncWorkerTest<'a> { + #[builder(setter(into))] + pub queue: &'a mut dyn AsyncQueueable, + #[builder(default=DEFAULT_TASK_TYPE.to_string(), setter(into))] + pub task_type: String, + #[builder(default, setter(into))] + pub sleep_params: SleepParams, + #[builder(default, setter(into))] + pub retention_mode: RetentionMode, +} + +#[cfg(test)] +impl<'a> AsyncWorkerTest<'a> { + pub async fn run(&mut self, task: Task) -> Result<(), Error> { + let result = self.execute_task(task).await; + self.finalize_task(result).await + } + + async fn execute_task(&mut self, task: Task) -> Result { + let actual_task: Box = + serde_json::from_value(task.metadata.clone()).unwrap(); + + let task_result = actual_task.run(self.queue).await; + match task_result { + Ok(()) => Ok(task), + Err(error) => Err((task, error.description)), + } + } + + async fn finalize_task(&mut self, result: Result) -> Result<(), Error> { + match self.retention_mode { + RetentionMode::KeepAll => match result { + Ok(task) => { + self.queue + .update_task_state(task, FangTaskState::Finished) + .await?; + Ok(()) + } + Err((task, error)) => { + self.queue.fail_task(task, &error).await?; + Ok(()) + } + }, + RetentionMode::RemoveAll => match result { + Ok(task) => { + self.queue.remove_task(task).await?; + Ok(()) + } + Err((task, _error)) => { + self.queue.remove_task(task).await?; + Ok(()) + } + }, + RetentionMode::RemoveFinished => match result { + Ok(task) => { + self.queue.remove_task(task).await?; + Ok(()) + } + Err((task, error)) => { + self.queue.fail_task(task, &error).await?; + Ok(()) + } + }, + } + } + + pub async fn sleep(&mut self) { + self.sleep_params.maybe_increase_sleep_period(); + + tokio::time::sleep(Duration::from_secs(self.sleep_params.sleep_period)).await; + } pub async fn run_tasks_until_none(&mut self) -> Result<(), Error> { loop { match self @@ -132,7 +211,7 @@ impl<'a> AsyncWorker<'a> { #[cfg(test)] mod async_worker_tests { - use super::AsyncWorker; + use super::AsyncWorkerTest; use crate::asynk::async_queue::AsyncQueueTest; use crate::asynk::async_queue::AsyncQueueable; use crate::asynk::async_queue::FangTaskState; @@ -215,7 +294,7 @@ mod async_worker_tests { let task = insert_task(&mut test, &WorkerAsyncTask { number: 1 }).await; let id = task.id; - let mut worker = AsyncWorker::builder() + let mut worker = AsyncWorkerTest::builder() .queue(&mut test as &mut dyn AsyncQueueable) .retention_mode(RetentionMode::KeepAll) .build(); @@ -237,7 +316,7 @@ mod async_worker_tests { let task = insert_task(&mut test, &AsyncFailedTask { number: 1 }).await; let id = task.id; - let mut worker = AsyncWorker::builder() + let mut worker = AsyncWorkerTest::builder() .queue(&mut test as &mut dyn AsyncQueueable) .retention_mode(RetentionMode::KeepAll) .build(); @@ -269,7 +348,7 @@ mod async_worker_tests { let id12 = task12.id; let id2 = task2.id; - let mut worker = AsyncWorker::builder() + let mut worker = AsyncWorkerTest::builder() .queue(&mut test as &mut dyn AsyncQueueable) .task_type("type1".to_string()) .retention_mode(RetentionMode::KeepAll) @@ -304,7 +383,7 @@ mod async_worker_tests { let _id12 = task12.id; let id2 = task2.id; - let mut worker = AsyncWorker::builder() + let mut worker = AsyncWorkerTest::builder() .queue(&mut test as &mut dyn AsyncQueueable) .task_type("type1".to_string()) .build(); diff --git a/src/asynk/async_worker_pool.rs b/src/asynk/async_worker_pool.rs index df7c201..8a60622 100644 --- a/src/asynk/async_worker_pool.rs +++ b/src/asynk/async_worker_pool.rs @@ -1,26 +1,19 @@ -use crate::asynk::async_queue::AsyncQueue; use crate::asynk::async_queue::AsyncQueueable; use crate::asynk::async_worker::AsyncWorker; use crate::asynk::Error; use crate::{RetentionMode, SleepParams}; use async_recursion::async_recursion; -use bb8_postgres::tokio_postgres::tls::MakeTlsConnect; -use bb8_postgres::tokio_postgres::tls::TlsConnect; -use bb8_postgres::tokio_postgres::Socket; use log::error; use std::time::Duration; use typed_builder::TypedBuilder; #[derive(TypedBuilder, Clone)] -pub struct AsyncWorkerPool +pub struct AsyncWorkerPool where - Tls: MakeTlsConnect + Clone + Send + Sync + 'static, - >::Stream: Send + Sync, - >::TlsConnect: Send, - <>::TlsConnect as TlsConnect>::Future: Send, + AQueue: AsyncQueueable + Clone + Sync + 'static, { #[builder(setter(into))] - pub queue: AsyncQueue, + pub queue: AQueue, #[builder(default, setter(into))] pub sleep_params: SleepParams, #[builder(default, setter(into))] @@ -39,12 +32,9 @@ pub struct WorkerParams { pub task_type: Option, } -impl AsyncWorkerPool +impl AsyncWorkerPool where - Tls: MakeTlsConnect + Clone + Send + Sync + 'static, - >::Stream: Send + Sync, - >::TlsConnect: Send, - <>::TlsConnect as TlsConnect>::Future: Send, + AQueue: AsyncQueueable + Clone + Sync + 'static, { pub async fn start(&mut self) { for _idx in 0..self.number_of_workers { @@ -60,7 +50,7 @@ where #[async_recursion] pub async fn supervise_worker( - queue: AsyncQueue, + queue: AQueue, sleep_params: SleepParams, retention_mode: RetentionMode, ) -> Result<(), Error> { @@ -82,12 +72,12 @@ where } pub async fn run_worker( - mut queue: AsyncQueue, + queue: AQueue, sleep_params: SleepParams, retention_mode: RetentionMode, ) -> Result<(), Error> { - let mut worker = AsyncWorker::builder() - .queue(&mut queue as &mut dyn AsyncQueueable) + let mut worker: AsyncWorker = AsyncWorker::builder() + .queue(queue) .sleep_params(sleep_params) .retention_mode(retention_mode) .build();