diff --git a/fang_examples/simple_async_worker/src/main.rs b/fang_examples/simple_async_worker/src/main.rs index 7adb9da..b4cd5f1 100644 --- a/fang_examples/simple_async_worker/src/main.rs +++ b/fang_examples/simple_async_worker/src/main.rs @@ -11,16 +11,19 @@ async fn main() { env_logger::init(); log::info!("Starting..."); + let max_pool_size: u32 = 2; + let mut queue = AsyncQueue::builder() + .uri("postgres://postgres:postgres@localhost/fang") + .max_pool_size(max_pool_size) + .duplicated_tasks(true) + .build(); - let mut queue = AsyncQueue::connect("postgres://postgres:postgres@localhost/fang", NoTls, true) - .await - .unwrap(); - + queue.connect(NoTls).await.unwrap(); log::info!("Queue connected..."); let mut pool = AsyncWorkerPool::builder() + .number_of_workers(max_pool_size) .queue(queue.clone()) - .number_of_workers(2 as u16) .build(); log::info!("Pool created ..."); diff --git a/src/asynk/async_queue.rs b/src/asynk/async_queue.rs index a6b4791..6c59704 100644 --- a/src/asynk/async_queue.rs +++ b/src/asynk/async_queue.rs @@ -100,6 +100,10 @@ pub enum AsyncQueueError { SerdeError(#[from] serde_json::Error), #[error("returned invalid result (expected {expected:?}, found {found:?})")] ResultError { expected: u64, found: u64 }, + #[error( + "AsyncQueue is not connected :( , call connect() method first and then perform operations" + )] + NotConnectedError, } impl From for FangError { @@ -159,10 +163,16 @@ where >::TlsConnect: Send, <>::TlsConnect as TlsConnect>::Future: Send, { + #[builder(default=None, setter(skip))] + pool: Option>>, #[builder(setter(into))] - pool: Pool>, + uri: String, + #[builder(setter(into))] + max_pool_size: u32, #[builder(default = false, setter(into))] duplicated_tasks: bool, + #[builder(default = false, setter(skip))] + connected: bool, } #[cfg(test)] @@ -330,20 +340,25 @@ where >::TlsConnect: Send, <>::TlsConnect as TlsConnect>::Future: Send, { - pub async fn connect( - uri: impl ToString, - tls: Tls, - duplicated_tasks: bool, - ) -> Result { - let manager = PostgresConnectionManager::new_from_stringlike(uri, tls)?; - let pool = Pool::builder().build(manager).await?; - - Ok(Self { - pool, - duplicated_tasks, - }) + pub fn check_if_connection(&self) -> Result<(), AsyncQueueError> { + if self.connected { + Ok(()) + } else { + Err(AsyncQueueError::NotConnectedError) + } } + pub async fn connect(&mut self, tls: Tls) -> Result<(), AsyncQueueError> { + let manager = PostgresConnectionManager::new_from_stringlike(self.uri.clone(), tls)?; + let pool = Pool::builder() + .max_size(self.max_pool_size) + .build(manager) + .await?; + + self.pool = Some(pool); + self.connected = true; + Ok(()) + } async fn remove_all_tasks_query( transaction: &mut Transaction<'_>, ) -> Result { @@ -599,7 +614,8 @@ where &mut self, task_type: Option, ) -> Result, AsyncQueueError> { - let mut connection = self.pool.get().await?; + self.check_if_connection()?; + let mut connection = self.pool.as_ref().unwrap().get().await?; let mut transaction = connection.transaction().await?; let task = Self::fetch_and_touch_task_query(&mut transaction, task_type).await?; @@ -610,7 +626,8 @@ where } async fn insert_task(&mut self, task: &dyn AsyncRunnable) -> Result { - let mut connection = self.pool.get().await?; + self.check_if_connection()?; + let mut connection = self.pool.as_ref().unwrap().get().await?; let mut transaction = connection.transaction().await?; let metadata = serde_json::to_value(task)?; @@ -633,7 +650,8 @@ where timestamp: DateTime, period: i32, ) -> Result { - let mut connection = self.pool.get().await?; + self.check_if_connection()?; + let mut connection = self.pool.as_ref().unwrap().get().await?; let mut transaction = connection.transaction().await?; let metadata = serde_json::to_value(task)?; @@ -650,7 +668,8 @@ where &mut self, periodic_task: PeriodicTask, ) -> Result { - let mut connection = self.pool.get().await?; + self.check_if_connection()?; + let mut connection = self.pool.as_ref().unwrap().get().await?; let mut transaction = connection.transaction().await?; let periodic_task = Self::schedule_next_task_query(&mut transaction, periodic_task).await?; @@ -664,7 +683,8 @@ where &mut self, error_margin_seconds: i64, ) -> Result>, AsyncQueueError> { - let mut connection = self.pool.get().await?; + self.check_if_connection()?; + let mut connection = self.pool.as_ref().unwrap().get().await?; let mut transaction = connection.transaction().await?; let periodic_task = @@ -676,7 +696,8 @@ where } async fn remove_all_tasks(&mut self) -> Result { - let mut connection = self.pool.get().await?; + self.check_if_connection()?; + let mut connection = self.pool.as_ref().unwrap().get().await?; let mut transaction = connection.transaction().await?; let result = Self::remove_all_tasks_query(&mut transaction).await?; @@ -687,7 +708,8 @@ where } async fn remove_task(&mut self, task: Task) -> Result { - let mut connection = self.pool.get().await?; + self.check_if_connection()?; + let mut connection = self.pool.as_ref().unwrap().get().await?; let mut transaction = connection.transaction().await?; let result = Self::remove_task_query(&mut transaction, task).await?; @@ -698,7 +720,8 @@ where } async fn remove_tasks_type(&mut self, task_type: &str) -> Result { - let mut connection = self.pool.get().await?; + self.check_if_connection()?; + let mut connection = self.pool.as_ref().unwrap().get().await?; let mut transaction = connection.transaction().await?; let result = Self::remove_tasks_type_query(&mut transaction, task_type).await?; @@ -713,7 +736,8 @@ where task: Task, state: FangTaskState, ) -> Result { - let mut connection = self.pool.get().await?; + self.check_if_connection()?; + let mut connection = self.pool.as_ref().unwrap().get().await?; let mut transaction = connection.transaction().await?; let task = Self::update_task_state_query(&mut transaction, task, state).await?; @@ -727,7 +751,8 @@ where task: Task, error_message: &str, ) -> Result { - let mut connection = self.pool.get().await?; + self.check_if_connection()?; + let mut connection = self.pool.as_ref().unwrap().get().await?; let mut transaction = connection.transaction().await?; let task = Self::fail_task_query(&mut transaction, task, error_message).await?; diff --git a/src/asynk/async_worker_pool.rs b/src/asynk/async_worker_pool.rs index 9ee1139..df7c201 100644 --- a/src/asynk/async_worker_pool.rs +++ b/src/asynk/async_worker_pool.rs @@ -26,7 +26,7 @@ where #[builder(default, setter(into))] pub retention_mode: RetentionMode, #[builder(setter(into))] - pub number_of_workers: u16, + pub number_of_workers: u32, } #[derive(TypedBuilder, Clone)] @@ -47,7 +47,7 @@ where <>::TlsConnect as TlsConnect>::Future: Send, { pub async fn start(&mut self) { - for _idx in 1..self.number_of_workers + 1 { + for _idx in 0..self.number_of_workers { let queue = self.queue.clone(); let sleep_params = self.sleep_params.clone(); let retention_mode = self.retention_mode.clone();