diff --git a/Cargo.toml b/Cargo.toml index 7e34edd..d5427dc 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -22,5 +22,7 @@ typetag = "0.2" log = "0.4" serde = { version = "1", features = ["derive"] } thiserror = "1.0" -bb8-postgres = "0.8" +bb8-postgres = {version = "0.8", features = ["with-serde_json-1"]} tokio = { version = "1.20", features = ["full"] } +async-trait = "0.1" +typed-builder = "0.10" diff --git a/src/asynk/async_queue.rs b/src/asynk/async_queue.rs index 8a9393a..4f3fcec 100644 --- a/src/asynk/async_queue.rs +++ b/src/asynk/async_queue.rs @@ -1,11 +1,131 @@ -use bb8_postgres::tokio_postgres::Client; +use crate::asynk::AsyncRunnable; +use bb8_postgres::bb8::Pool; +use bb8_postgres::bb8::RunError; +use bb8_postgres::tokio_postgres::tls::MakeTlsConnect; +use bb8_postgres::tokio_postgres::tls::TlsConnect; +use bb8_postgres::tokio_postgres::types::ToSql; +use bb8_postgres::tokio_postgres::Socket; +use bb8_postgres::PostgresConnectionManager; +use thiserror::Error; +use typed_builder::TypedBuilder; -pub struct AsyncQueue { - pg_client: Client, +#[derive(Error, Debug)] +pub enum AsyncQueueError { + #[error(transparent)] + PoolError(#[from] RunError), + #[error(transparent)] + PgError(#[from] bb8_postgres::tokio_postgres::Error), + #[error("returned invalid result (expected {expected:?}, found {found:?})")] + ResultError { expected: u64, found: u64 }, } -impl AsyncQueue { - pub fn new(pg_client: Client) -> Self { - Self { pg_client } +#[derive(Debug, TypedBuilder)] +pub struct AsyncQueue +where + Tls: MakeTlsConnect + Clone + Send + Sync + 'static, + >::Stream: Send + Sync, + >::TlsConnect: Send, + <>::TlsConnect as TlsConnect>::Future: Send, +{ + pool: Pool>, + #[builder(default = false)] + test: bool, +} + +const INSERT_TASK_QUERY: &str = include_str!("queries/insert_task.sql"); + +impl AsyncQueue +where + Tls: MakeTlsConnect + Clone + Send + Sync + 'static, + >::Stream: Send + Sync, + >::TlsConnect: Send, + <>::TlsConnect as TlsConnect>::Future: Send, +{ + pub fn new(pool: Pool>) -> Self { + AsyncQueue::builder().pool(pool).build() + } + + pub async fn insert_task(&mut self, task: &dyn AsyncRunnable) -> Result { + let json_task = serde_json::to_value(task).unwrap(); + let task_type = task.task_type(); + + self.execute_one(INSERT_TASK_QUERY, &[&json_task, &task_type]) + .await + } + + async fn execute_one( + &mut self, + query: &str, + params: &[&(dyn ToSql + Sync)], + ) -> Result { + let mut connection = self.pool.get().await?; + + let result = if self.test { + let transaction = connection.transaction().await?; + + let result = transaction.execute(query, params).await?; + + transaction.rollback().await?; + + result + } else { + connection.execute(query, params).await? + }; + + if result != 1 { + return Err(AsyncQueueError::ResultError { + expected: 1, + found: result, + }); + } + + Ok(result) + } +} + +#[cfg(test)] +mod async_queue_tests { + use super::AsyncQueue; + use crate::asynk::AsyncRunnable; + use crate::asynk::Error; + use async_trait::async_trait; + use bb8_postgres::bb8::Pool; + use bb8_postgres::tokio_postgres::Client; + use bb8_postgres::tokio_postgres::NoTls; + use bb8_postgres::PostgresConnectionManager; + use serde::{Deserialize, Serialize}; + + #[derive(Serialize, Deserialize)] + struct Job { + pub number: u16, + } + + #[typetag::serde] + #[async_trait] + impl AsyncRunnable for Job { + async fn run(&self, _connection: &Client) -> Result<(), Error> { + Ok(()) + } + } + + #[tokio::test] + async fn insert_task_creates_new_task() { + let mut queue = queue().await; + + let result = queue.insert_task(&Job { number: 1 }).await.unwrap(); + + assert_eq!(1, result); + } + + async fn queue() -> AsyncQueue { + let pg_mgr = PostgresConnectionManager::new_from_stringlike( + "postgres://postgres:postgres@localhost/fang", + NoTls, + ) + .unwrap(); + + let pool = Pool::builder().build(pg_mgr).await.unwrap(); + + AsyncQueue::builder().pool(pool).test(true).build() } } diff --git a/src/asynk/async_runnable.rs b/src/asynk/async_runnable.rs new file mode 100644 index 0000000..a933a8d --- /dev/null +++ b/src/asynk/async_runnable.rs @@ -0,0 +1,19 @@ +use async_trait::async_trait; +use bb8_postgres::tokio_postgres::Client; + +const COMMON_TYPE: &str = "common"; + +#[derive(Debug)] +pub struct Error { + pub description: String, +} + +#[typetag::serde(tag = "type")] +#[async_trait] +pub trait AsyncRunnable { + async fn run(&self, client: &Client) -> Result<(), Error>; + + fn task_type(&self) -> String { + COMMON_TYPE.to_string() + } +} diff --git a/src/asynk/mod.rs b/src/asynk/mod.rs index 3b8222c..171d637 100644 --- a/src/asynk/mod.rs +++ b/src/asynk/mod.rs @@ -1 +1,5 @@ pub mod async_queue; +pub mod async_runnable; + +pub use async_runnable::AsyncRunnable; +pub use async_runnable::Error; diff --git a/src/asynk/queries/insert_job.sql b/src/asynk/queries/insert_job.sql index 33aea7d..2bef71d 100644 --- a/src/asynk/queries/insert_job.sql +++ b/src/asynk/queries/insert_job.sql @@ -1 +1 @@ -INSERT INTO "fang_tasks" ("metadata", "created_at") VALUES ($1, $2) +INSERT INTO "fang_tasks" ("metadata", "task_type") VALUES ($1, $2) diff --git a/src/asynk/queries/insert_task.sql b/src/asynk/queries/insert_task.sql index 8981fcc..2bef71d 100644 --- a/src/asynk/queries/insert_task.sql +++ b/src/asynk/queries/insert_task.sql @@ -1 +1 @@ -INSERT INTO "fang_tasks" ("metadata", "error_message", "state", "task_type") VALUES ($1, $2, $3, $4) +INSERT INTO "fang_tasks" ("metadata", "task_type") VALUES ($1, $2)