From d186a5434c4274b5ce14a9de4f60f029acb7a633 Mon Sep 17 00:00:00 2001 From: Ayrat Badykov Date: Sun, 17 Jul 2022 22:34:55 +0300 Subject: [PATCH] allow to use transaction to be able to rollback in tests --- src/asynk/async_queue.rs | 49 ++++++++++++++++++++++------------------ 1 file changed, 27 insertions(+), 22 deletions(-) diff --git a/src/asynk/async_queue.rs b/src/asynk/async_queue.rs index 4f3fcec..89d01f0 100644 --- a/src/asynk/async_queue.rs +++ b/src/asynk/async_queue.rs @@ -5,11 +5,12 @@ use bb8_postgres::tokio_postgres::tls::MakeTlsConnect; use bb8_postgres::tokio_postgres::tls::TlsConnect; use bb8_postgres::tokio_postgres::types::ToSql; use bb8_postgres::tokio_postgres::Socket; +use bb8_postgres::tokio_postgres::Transaction; use bb8_postgres::PostgresConnectionManager; use thiserror::Error; use typed_builder::TypedBuilder; -#[derive(Error, Debug)] +#[derive(Debug, Error)] pub enum AsyncQueueError { #[error(transparent)] PoolError(#[from] RunError), @@ -17,24 +18,27 @@ pub enum AsyncQueueError { PgError(#[from] bb8_postgres::tokio_postgres::Error), #[error("returned invalid result (expected {expected:?}, found {found:?})")] ResultError { expected: u64, found: u64 }, + #[error("Queue doesn't have a connection")] + PoolAndTransactionEmpty, } -#[derive(Debug, TypedBuilder)] -pub struct AsyncQueue +#[derive(TypedBuilder)] +pub struct AsyncQueue<'a, Tls> where Tls: MakeTlsConnect + Clone + Send + Sync + 'static, >::Stream: Send + Sync, >::TlsConnect: Send, <>::TlsConnect as TlsConnect>::Future: Send, { - pool: Pool>, - #[builder(default = false)] - test: bool, + #[builder(default, setter(into))] + pool: Option>>, + #[builder(default, setter(into))] + transaction: Option>, } const INSERT_TASK_QUERY: &str = include_str!("queries/insert_task.sql"); -impl AsyncQueue +impl<'a, Tls> AsyncQueue<'a, Tls> where Tls: MakeTlsConnect + Clone + Send + Sync + 'static, >::Stream: Send + Sync, @@ -45,6 +49,10 @@ where AsyncQueue::builder().pool(pool).build() } + pub fn new_with_transaction(transaction: Transaction<'a>) -> Self { + AsyncQueue::builder().transaction(transaction).build() + } + pub async fn insert_task(&mut self, task: &dyn AsyncRunnable) -> Result { let json_task = serde_json::to_value(task).unwrap(); let task_type = task.task_type(); @@ -58,18 +66,14 @@ where query: &str, params: &[&(dyn ToSql + Sync)], ) -> Result { - let mut connection = self.pool.get().await?; + let result = if let Some(pool) = &self.pool { + let connection = pool.get().await?; - let result = if self.test { - let transaction = connection.transaction().await?; - - let result = transaction.execute(query, params).await?; - - transaction.rollback().await?; - - result - } else { connection.execute(query, params).await? + } else if let Some(transaction) = &self.transaction { + transaction.execute(query, params).await? + } else { + return Err(AsyncQueueError::PoolAndTransactionEmpty); }; if result != 1 { @@ -110,22 +114,23 @@ mod async_queue_tests { #[tokio::test] async fn insert_task_creates_new_task() { - let mut queue = queue().await; + let pool = pool().await; + let mut connection = pool.get().await.unwrap(); + let transaction = connection.transaction().await.unwrap(); + let mut queue = AsyncQueue::::new_with_transaction(transaction); let result = queue.insert_task(&Job { number: 1 }).await.unwrap(); assert_eq!(1, result); } - async fn queue() -> AsyncQueue { + async fn pool() -> Pool> { let pg_mgr = PostgresConnectionManager::new_from_stringlike( "postgres://postgres:postgres@localhost/fang", NoTls, ) .unwrap(); - let pool = Pool::builder().build(pg_mgr).await.unwrap(); - - AsyncQueue::builder().pool(pool).test(true).build() + Pool::builder().build(pg_mgr).await.unwrap() } }