From 6681fe87defe43f48d07718a3eea85fa52e26ecb Mon Sep 17 00:00:00 2001
From: Diggory Blake
Date: Sun, 28 Mar 2021 02:57:17 +0100
Subject: [PATCH] Initial commit

---
 .env                                     |   1 +
 .gitignore                               |   2 +
 Cargo.toml                               |  27 ++
 migrations/20210316025847_setup.down.sql |  12 +
 migrations/20210316025847_setup.up.sql   | 286 ++++++++++++++++++
 src/lib.rs                               | 195 +++++++++++++
 src/runner.rs                            | 356 +++++++++++++++++++++++
 src/spawn.rs                             | 106 +++++++
 src/utils.rs                             |  37 +++
 9 files changed, 1022 insertions(+)
 create mode 100644 .env
 create mode 100644 .gitignore
 create mode 100644 Cargo.toml
 create mode 100644 migrations/20210316025847_setup.down.sql
 create mode 100644 migrations/20210316025847_setup.up.sql
 create mode 100644 src/lib.rs
 create mode 100644 src/runner.rs
 create mode 100644 src/spawn.rs
 create mode 100644 src/utils.rs

diff --git a/.env b/.env
new file mode 100644
index 0000000..be573bc
--- /dev/null
+++ b/.env
@@ -0,0 +1 @@
+DATABASE_URL=postgres://postgres:password@localhost/sqlxmq
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..96ef6c0
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,2 @@
+/target
+Cargo.lock
diff --git a/Cargo.toml b/Cargo.toml
new file mode 100644
index 0000000..f56a9bc
--- /dev/null
+++ b/Cargo.toml
@@ -0,0 +1,27 @@
+[package]
+name = "sqlxmq"
+version = "0.1.0"
+authors = ["Diggory Blake"]
+edition = "2018"
+
+# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
+
+[dependencies]
+sqlx = { version = "0.5.1", features = [
+    "postgres",
+    "runtime-async-std-native-tls",
+    "chrono",
+    "uuid",
+] }
+tokio = { version = "1.3.0", features = ["full"] }
+dotenv = "0.15.0"
+futures = "0.3.13"
+chrono = "0.4.19"
+uuid = { version = "0.8.2", features = ["v4"] }
+log = "0.4.14"
+serde_json = "1.0.64"
+serde = "1.0.124"
+
+[dev-dependencies]
+dotenv = "0.15.0"
+pretty_env_logger = "0.4.0"
diff --git a/migrations/20210316025847_setup.down.sql b/migrations/20210316025847_setup.down.sql
new file mode 100644
index 0000000..933f6d4
--- /dev/null
+++ b/migrations/20210316025847_setup.down.sql
@@ -0,0 +1,12 @@
+DROP FUNCTION mq_checkpoint;
+DROP FUNCTION mq_keep_alive;
+DROP FUNCTION mq_delete;
+DROP FUNCTION mq_commit;
+DROP FUNCTION mq_insert;
+DROP FUNCTION mq_poll;
+DROP FUNCTION mq_active_channels;
+DROP FUNCTION mq_latest_message;
+DROP TABLE mq_payloads;
+DROP TABLE mq_msgs;
+DROP FUNCTION mq_uuid_exists;
+DROP TYPE mq_new_t;
\ No newline at end of file
diff --git a/migrations/20210316025847_setup.up.sql b/migrations/20210316025847_setup.up.sql
new file mode 100644
index 0000000..1623269
--- /dev/null
+++ b/migrations/20210316025847_setup.up.sql
@@ -0,0 +1,286 @@
+-- The UDT for creating messages
+CREATE TYPE mq_new_t AS (
+    -- Unique message ID
+    id UUID,
+    -- Delay before message is processed
+    delay INTERVAL,
+    -- Number of retries if initial processing fails
+    retries INT,
+    -- Initial backoff between retries
+    retry_backoff INTERVAL,
+    -- Name of channel
+    channel_name TEXT,
+    -- Arguments to channel
+    channel_args TEXT,
+    -- Interval for two-phase commit (or NULL to disable two-phase commit)
+    commit_interval INTERVAL,
+    -- Whether this message should be processed in order with respect to other
+    -- ordered messages.
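+    -- For example (illustrative, following the schema below): two ordered
+    -- messages spawned on the same (channel_name, channel_args) pair are never
+    -- processed concurrently; the second only becomes pollable once the first
+    -- has been deleted, via the after_message_id chain on mq_msgs.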
+    ordered BOOLEAN,
+    -- Name of message
+    name TEXT,
+    -- JSON payload
+    payload_json TEXT,
+    -- Binary payload
+    payload_bytes BYTEA
+);
+
+-- Small, frequently updated table of messages
+CREATE TABLE mq_msgs (
+    id UUID PRIMARY KEY,
+    created_at TIMESTAMPTZ DEFAULT NOW(),
+    attempt_at TIMESTAMPTZ DEFAULT NOW(),
+    attempts INT NOT NULL DEFAULT 5,
+    retry_backoff INTERVAL NOT NULL DEFAULT INTERVAL '1 second',
+    channel_name TEXT NOT NULL,
+    channel_args TEXT NOT NULL,
+    commit_interval INTERVAL,
+    after_message_id UUID DEFAULT uuid_nil() REFERENCES mq_msgs(id) ON DELETE SET DEFAULT
+);
+
+-- Insert dummy message so that the 'nil' UUID can be referenced
+INSERT INTO mq_msgs (id, channel_name, channel_args, after_message_id) VALUES (uuid_nil(), '', '', NULL);
+
+-- Internal helper function to check that a UUID is neither NULL nor NIL
+CREATE FUNCTION mq_uuid_exists(
+    id UUID
+) RETURNS BOOLEAN AS $$
+    SELECT id IS NOT NULL AND id != uuid_nil()
+$$ LANGUAGE SQL IMMUTABLE;
+
+-- Index for polling
+CREATE INDEX ON mq_msgs(channel_name, channel_args, attempt_at) WHERE id != uuid_nil() AND NOT mq_uuid_exists(after_message_id);
+-- Index for adding messages
+CREATE INDEX ON mq_msgs(channel_name, channel_args, created_at, id) WHERE id != uuid_nil() AND after_message_id IS NOT NULL;
+
+-- Index for ensuring strict message order
+CREATE UNIQUE INDEX ON mq_msgs(channel_name, channel_args, after_message_id);
+
+
+-- Large, less frequently updated table of message payloads
+CREATE TABLE mq_payloads(
+    id UUID PRIMARY KEY,
+    name TEXT NOT NULL,
+    payload_json JSONB,
+    payload_bytes BYTEA
+);
+
+-- Internal helper function to return the most recently added message in a queue.
+CREATE FUNCTION mq_latest_message(from_channel_name TEXT, from_channel_args TEXT)
+RETURNS UUID AS $$
+    SELECT COALESCE(
+        (
+            SELECT id FROM mq_msgs
+            WHERE channel_name = from_channel_name
+            AND channel_args = from_channel_args
+            AND after_message_id IS NOT NULL
+            AND id != uuid_nil()
+            ORDER BY created_at DESC, id DESC
+            LIMIT 1
+        ),
+        uuid_nil()
+    )
+$$ LANGUAGE SQL STABLE;
+
+-- Internal helper function to randomly select a set of channels with "ready" messages.
+CREATE FUNCTION mq_active_channels(channel_names TEXT[], batch_size INT)
+RETURNS TABLE(name TEXT, args TEXT) AS $$
+    SELECT channel_name, channel_args
+    FROM mq_msgs
+    WHERE id != uuid_nil()
+    AND attempt_at <= NOW()
+    AND (channel_names IS NULL OR channel_name = ANY(channel_names))
+    AND NOT mq_uuid_exists(after_message_id)
+    GROUP BY channel_name, channel_args
+    ORDER BY RANDOM()
+    LIMIT batch_size
+$$ LANGUAGE SQL STABLE;
+
+-- Main entry-point for task runner: pulls a batch of messages from the queue.
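+-- Returns at most `batch_size` messages, each flagged as committed or not.
+-- If nothing is ready, a single all-NULL row is returned whose `wait_time`
+-- column reports how long until the next message is due. Illustrative call:
+--
+--     SELECT * FROM mq_poll(ARRAY['foo'], 10);
+--
+-- Passing NULL for `channel_names` polls every channel.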
+CREATE FUNCTION mq_poll(channel_names TEXT[], batch_size INT DEFAULT 1)
+RETURNS TABLE(
+    id UUID,
+    is_committed BOOLEAN,
+    name TEXT,
+    payload_json TEXT,
+    payload_bytes BYTEA,
+    retry_backoff INTERVAL,
+    wait_time INTERVAL
+) AS $$
+BEGIN
+    RETURN QUERY UPDATE mq_msgs
+    SET
+        attempt_at = CASE WHEN mq_msgs.attempts = 1 THEN NULL ELSE NOW() + mq_msgs.retry_backoff END,
+        attempts = mq_msgs.attempts - 1,
+        retry_backoff = mq_msgs.retry_backoff * 2
+    FROM (
+        SELECT
+            msgs.id
+        FROM mq_active_channels(channel_names, batch_size) AS active_channels
+        INNER JOIN LATERAL (
+            SELECT * FROM mq_msgs
+            WHERE mq_msgs.id != uuid_nil()
+            AND mq_msgs.attempt_at <= NOW()
+            AND mq_msgs.channel_name = active_channels.name
+            AND mq_msgs.channel_args = active_channels.args
+            AND NOT mq_uuid_exists(mq_msgs.after_message_id)
+            ORDER BY mq_msgs.attempt_at ASC
+            LIMIT batch_size
+        ) AS msgs ON TRUE
+        LIMIT batch_size
+    ) AS messages_to_update
+    LEFT JOIN mq_payloads ON mq_payloads.id = messages_to_update.id
+    WHERE mq_msgs.id = messages_to_update.id
+    RETURNING
+        mq_msgs.id,
+        mq_msgs.commit_interval IS NULL,
+        mq_payloads.name,
+        mq_payloads.payload_json::TEXT,
+        mq_payloads.payload_bytes,
+        mq_msgs.retry_backoff / 2,
+        interval '0' AS wait_time;
+
+    IF NOT FOUND THEN
+        RETURN QUERY SELECT
+            NULL::UUID,
+            NULL::BOOLEAN,
+            NULL::TEXT,
+            NULL::TEXT,
+            NULL::BYTEA,
+            NULL::INTERVAL,
+            MIN(mq_msgs.attempt_at) - NOW()
+        FROM mq_msgs
+        WHERE mq_msgs.id != uuid_nil()
+        AND NOT mq_uuid_exists(mq_msgs.after_message_id)
+        AND (channel_names IS NULL OR mq_msgs.channel_name = ANY(channel_names));
+    END IF;
+END;
+$$ LANGUAGE plpgsql;
+
+-- Creates new messages
+CREATE FUNCTION mq_insert(new_messages mq_new_t[])
+RETURNS VOID AS $$
+BEGIN
+    PERFORM pg_notify(CONCAT('mq_', channel_name), '')
+    FROM unnest(new_messages) AS new_msgs
+    GROUP BY channel_name;
+
+    IF FOUND THEN
+        PERFORM pg_notify('mq', '');
+    END IF;
+
+    INSERT INTO mq_payloads (
+        id,
+        name,
+        payload_json,
+        payload_bytes
+    ) SELECT
+        id,
+        name,
+        payload_json::JSONB,
+        payload_bytes
+    FROM UNNEST(new_messages);
+
+    INSERT INTO mq_msgs (
+        id,
+        attempt_at,
+        attempts,
+        retry_backoff,
+        channel_name,
+        channel_args,
+        commit_interval,
+        after_message_id
+    )
+    SELECT
+        id,
+        NOW() + delay + COALESCE(commit_interval, INTERVAL '0'),
+        retries + 1,
+        retry_backoff,
+        channel_name,
+        channel_args,
+        commit_interval,
+        CASE WHEN ordered
+        THEN
+            LAG(id, 1, mq_latest_message(channel_name, channel_args))
+            OVER (PARTITION BY channel_name, channel_args, ordered ORDER BY id)
+        ELSE
+            NULL
+        END
+    FROM UNNEST(new_messages);
+END;
+$$ LANGUAGE plpgsql;
+
+-- Commits messages previously created with a non-NULL commit interval.
+CREATE FUNCTION mq_commit(msg_ids UUID[])
+RETURNS VOID AS $$
+BEGIN
+    UPDATE mq_msgs
+    SET
+        attempt_at = attempt_at - commit_interval,
+        commit_interval = NULL
+    WHERE id = ANY(msg_ids)
+    AND commit_interval IS NOT NULL;
+END;
+$$ LANGUAGE plpgsql;
+
+
+-- Deletes messages from the queue. This occurs when a message has been
+-- processed, or when it expires without being processed.
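+-- Illustrative call (the UUID is a placeholder):
+--
+--     SELECT mq_delete(ARRAY['00000000-0000-0000-0000-000000000001'::UUID]);
+--
+-- Deleting the head of an ordered chain re-notifies the channel, since
+-- ON DELETE SET DEFAULT makes the successor message (if any) pollable.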
+CREATE FUNCTION mq_delete(msg_ids UUID[])
+RETURNS VOID AS $$
+BEGIN
+    PERFORM pg_notify(CONCAT('mq_', channel_name), '')
+    FROM mq_msgs
+    WHERE id = ANY(msg_ids)
+    AND after_message_id = uuid_nil()
+    GROUP BY channel_name;
+
+    IF FOUND THEN
+        PERFORM pg_notify('mq', '');
+    END IF;
+
+    DELETE FROM mq_msgs WHERE id = ANY(msg_ids);
+    DELETE FROM mq_payloads WHERE id = ANY(msg_ids);
+END;
+$$ LANGUAGE plpgsql;
+
+
+-- Can be called during the initial commit interval, or when processing
+-- a message. Indicates that the caller is still active and will prevent either
+-- the commit interval elapsing or the message being retried for the specified
+-- interval.
+CREATE FUNCTION mq_keep_alive(msg_ids UUID[], duration INTERVAL)
+RETURNS VOID AS $$
+    UPDATE mq_msgs
+    SET
+        attempt_at = NOW() + duration,
+        commit_interval = commit_interval + ((NOW() + duration) - attempt_at)
+    WHERE id = ANY(msg_ids)
+    AND attempt_at < NOW() + duration;
+$$ LANGUAGE SQL;
+
+
+-- Called during lengthy processing of a message to checkpoint the progress.
+-- As well as behaving like `mq_keep_alive`, the message payload can be
+-- updated.
+CREATE FUNCTION mq_checkpoint(
+    msg_id UUID,
+    duration INTERVAL,
+    new_payload_json TEXT,
+    new_payload_bytes BYTEA,
+    extra_retries INT
+)
+RETURNS VOID AS $$
+    UPDATE mq_msgs
+    SET
+        attempt_at = GREATEST(attempt_at, NOW() + duration),
+        attempts = attempts + COALESCE(extra_retries, 0)
+    WHERE id = msg_id;
+
+    UPDATE mq_payloads
+    SET
+        payload_json = COALESCE(new_payload_json::JSONB, payload_json),
+        payload_bytes = COALESCE(new_payload_bytes, payload_bytes)
+    WHERE
+        id = msg_id;
+$$ LANGUAGE SQL;
diff --git a/src/lib.rs b/src/lib.rs
new file mode 100644
index 0000000..cba97bc
--- /dev/null
+++ b/src/lib.rs
@@ -0,0 +1,195 @@
+mod runner;
+mod spawn;
+mod utils;
+
+pub use runner::*;
+pub use spawn::*;
+pub use utils::OwnedTask;
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    use std::env;
+    use std::future::Future;
+    use std::ops::Deref;
+    use std::sync::atomic::{AtomicUsize, Ordering};
+    use std::sync::{Arc, Once};
+    use std::time::Duration;
+
+    use futures::channel::mpsc;
+    use futures::StreamExt;
+    use sqlx::{Pool, Postgres};
+    use tokio::sync::{Mutex, MutexGuard};
+    use tokio::task;
+
+    struct TestGuard<T>(MutexGuard<'static, ()>, T);
+
+    impl<T> Deref for TestGuard<T> {
+        type Target = T;
+
+        fn deref(&self) -> &T {
+            &self.1
+        }
+    }
+
+    async fn test_pool() -> TestGuard<Pool<Postgres>> {
+        static INIT_LOGGER: Once = Once::new();
+        static TEST_MUTEX: Mutex<()> = Mutex::const_new(());
+
+        let guard = TEST_MUTEX.lock().await;
+
+        let _ = dotenv::dotenv();
+
+        INIT_LOGGER.call_once(|| pretty_env_logger::init());
+
+        let pool = Pool::connect(&env::var("DATABASE_URL").unwrap())
+            .await
+            .unwrap();
+
+        sqlx::query("TRUNCATE TABLE mq_payloads")
+            .execute(&pool)
+            .await
+            .unwrap();
+        sqlx::query("DELETE FROM mq_msgs WHERE id != uuid_nil()")
+            .execute(&pool)
+            .await
+            .unwrap();
+
+        TestGuard(guard, pool)
+    }
+
+    async fn test_task_runner<F: Future + Send + 'static>(
+        pool: &Pool<Postgres>,
+        f: impl (Fn(CurrentTask) -> F) + Send + Sync + 'static,
+    ) -> (OwnedTask, Arc<AtomicUsize>)
+    where
+        F::Output: Send + 'static,
+    {
+        let counter = Arc::new(AtomicUsize::new(0));
+        let counter2 = counter.clone();
+        let runner = TaskRunnerOptions::new(pool, move |task| {
+            counter2.fetch_add(1, Ordering::SeqCst);
+            task::spawn(f(task));
+        })
+        .run()
+        .await
+        .unwrap();
+        (runner, counter)
+    }
+
+    async fn pause() {
+        pause_ms(50).await;
+    }
+
+    async fn pause_ms(ms: u64) {
+        tokio::time::sleep(Duration::from_millis(ms)).await;
+    }
+
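+    // Note for the tests below: the counter returned by `test_task_runner`
+    // counts *dispatches* (the runner closure firing), not completions, so a
+    // task that is never completed still bumps it on every retry. A minimal
+    // usage sketch (mirroring `it_can_spawn_task`):
+    //
+    //     let (_runner, counter) = test_task_runner(&pool, |mut task| async move {
+    //         task.complete().await
+    //     })
+    //     .await;
+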
+    #[tokio::test]
+    async fn it_can_spawn_task() {
+        let pool = &*test_pool().await;
+        let (_runner, counter) =
+            test_task_runner(&pool, |mut task| async move { task.complete().await }).await;
+
+        assert_eq!(counter.load(Ordering::SeqCst), 0);
+        TaskBuilder::new("foo").spawn(pool).await.unwrap();
+        pause().await;
+        assert_eq!(counter.load(Ordering::SeqCst), 1);
+    }
+
+    #[tokio::test]
+    async fn it_runs_tasks_in_order() {
+        let pool = &*test_pool().await;
+        let (tx, mut rx) = mpsc::unbounded();
+
+        let (_runner, counter) = test_task_runner(&pool, move |task| {
+            let tx = tx.clone();
+            async move {
+                tx.unbounded_send(task).unwrap();
+            }
+        })
+        .await;
+
+        assert_eq!(counter.load(Ordering::SeqCst), 0);
+        TaskBuilder::new("foo")
+            .set_ordered(true)
+            .spawn(pool)
+            .await
+            .unwrap();
+        TaskBuilder::new("bar")
+            .set_ordered(true)
+            .spawn(pool)
+            .await
+            .unwrap();
+
+        pause().await;
+        assert_eq!(counter.load(Ordering::SeqCst), 1);
+
+        let mut task = rx.next().await.unwrap();
+        task.complete().await.unwrap();
+
+        pause().await;
+        assert_eq!(counter.load(Ordering::SeqCst), 2);
+    }
+
+    #[tokio::test]
+    async fn it_runs_tasks_in_parallel() {
+        let pool = &*test_pool().await;
+        let (tx, mut rx) = mpsc::unbounded();
+
+        let (_runner, counter) = test_task_runner(&pool, move |task| {
+            let tx = tx.clone();
+            async move {
+                tx.unbounded_send(task).unwrap();
+            }
+        })
+        .await;
+
+        assert_eq!(counter.load(Ordering::SeqCst), 0);
+        TaskBuilder::new("foo").spawn(pool).await.unwrap();
+        TaskBuilder::new("bar").spawn(pool).await.unwrap();
+
+        pause().await;
+        assert_eq!(counter.load(Ordering::SeqCst), 2);
+
+        for _ in 0..2 {
+            let mut task = rx.next().await.unwrap();
+            task.complete().await.unwrap();
+        }
+    }
+
+    #[tokio::test]
+    async fn it_retries_failed_tasks() {
+        let pool = &*test_pool().await;
+        let (_runner, counter) = test_task_runner(&pool, move |_| async {}).await;
+
+        let backoff = 100;
+
+        assert_eq!(counter.load(Ordering::SeqCst), 0);
+        TaskBuilder::new("foo")
+            .set_retry_backoff(Duration::from_millis(backoff))
+            .set_retries(2)
+            .spawn(pool)
+            .await
+            .unwrap();
+
+        // First attempt
+        pause().await;
+        assert_eq!(counter.load(Ordering::SeqCst), 1);
+
+        // Second attempt
+        pause_ms(backoff).await;
+        pause().await;
+        assert_eq!(counter.load(Ordering::SeqCst), 2);
+
+        // Third attempt
+        pause_ms(backoff * 2).await;
+        pause().await;
+        assert_eq!(counter.load(Ordering::SeqCst), 3);
+
+        // No more attempts
+        pause_ms(backoff * 5).await;
+        assert_eq!(counter.load(Ordering::SeqCst), 3);
+    }
+}
diff --git a/src/runner.rs b/src/runner.rs
new file mode 100644
index 0000000..550ad95
--- /dev/null
+++ b/src/runner.rs
@@ -0,0 +1,356 @@
+use std::borrow::Cow;
+use std::fmt::Debug;
+use std::sync::atomic::{AtomicUsize, Ordering};
+use std::sync::Arc;
+use std::time::Duration;
+
+use serde::{Deserialize, Serialize};
+use sqlx::postgres::types::PgInterval;
+use sqlx::postgres::PgListener;
+use sqlx::{Pool, Postgres};
+use tokio::sync::Notify;
+use tokio::task;
+use uuid::Uuid;
+
+use crate::utils::{Opaque, OwnedTask};
+
+#[derive(Debug, Clone)]
+pub struct TaskRunnerOptions {
+    min_concurrency: usize,
+    max_concurrency: usize,
+    channel_names: Option<Vec<String>>,
+    runner: Opaque<Arc<dyn Fn(CurrentTask) + Send + Sync + 'static>>,
+    pool: Pool<Postgres>,
+    keep_alive: bool,
+}
+
+#[derive(Debug)]
+struct TaskRunner {
+    options: TaskRunnerOptions,
+    running_tasks: AtomicUsize,
+    notify: Notify,
+}
+
+#[derive(Debug, Clone)]
+pub struct Checkpoint<'a> {
+    duration: Duration,
+    extra_retries: usize,
+    payload_json: Option<Cow<'a, str>>,
+    payload_bytes: Option<&'a [u8]>,
+}
+
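+// Illustrative use of `Checkpoint` from inside a long-running task handler
+// (`current_task` and `progress` are placeholders for this sketch, not names
+// defined in this file):
+//
+//     let mut cp = Checkpoint::new(Duration::from_secs(60));
+//     cp.set_extra_retries(1);
+//     cp.set_json(&progress)?;
+//     current_task.checkpoint(&cp).await?;
+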
+impl<'a> Checkpoint<'a> {
+    pub fn new(duration: Duration) -> Self {
+        Self {
+            duration,
+            extra_retries: 0,
+            payload_json: None,
+            payload_bytes: None,
+        }
+    }
+    pub fn set_extra_retries(&mut self, extra_retries: usize) -> &mut Self {
+        self.extra_retries = extra_retries;
+        self
+    }
+    pub fn set_raw_json(&mut self, raw_json: &'a str) -> &mut Self {
+        self.payload_json = Some(Cow::Borrowed(raw_json));
+        self
+    }
+    pub fn set_raw_bytes(&mut self, raw_bytes: &'a [u8]) -> &mut Self {
+        self.payload_bytes = Some(raw_bytes);
+        self
+    }
+    pub fn set_json<T: Serialize>(&mut self, value: &T) -> Result<&mut Self, serde_json::Error> {
+        let value = serde_json::to_string(value)?;
+        self.payload_json = Some(Cow::Owned(value));
+        Ok(self)
+    }
+    async fn execute<'b, E: sqlx::Executor<'b, Database = Postgres>>(
+        &self,
+        task_id: Uuid,
+        executor: E,
+    ) -> Result<(), sqlx::Error> {
+        sqlx::query("SELECT mq_checkpoint($1, $2, $3, $4, $5)")
+            .bind(task_id)
+            .bind(self.duration)
+            .bind(self.payload_json.as_deref())
+            .bind(self.payload_bytes)
+            .bind(self.extra_retries as i32)
+            .execute(executor)
+            .await?;
+        Ok(())
+    }
+}
+
+#[derive(Debug)]
+pub struct CurrentTask {
+    id: Uuid,
+    name: String,
+    payload_json: Option<String>,
+    payload_bytes: Option<Vec<u8>>,
+    task_runner: Arc<TaskRunner>,
+    keep_alive: Option<OwnedTask>,
+}
+
+impl CurrentTask {
+    pub fn pool(&self) -> &Pool<Postgres> {
+        &self.task_runner.options.pool
+    }
+    async fn delete(
+        &self,
+        executor: impl sqlx::Executor<'_, Database = Postgres>,
+    ) -> Result<(), sqlx::Error> {
+        sqlx::query("SELECT mq_delete(ARRAY[$1])")
+            .bind(self.id)
+            .execute(executor)
+            .await?;
+        Ok(())
+    }
+    pub async fn complete_with_transaction(
+        &mut self,
+        mut tx: sqlx::Transaction<'_, Postgres>,
+    ) -> Result<(), sqlx::Error> {
+        self.delete(&mut tx).await?;
+        tx.commit().await?;
+        self.keep_alive = None;
+        Ok(())
+    }
+    pub async fn complete(&mut self) -> Result<(), sqlx::Error> {
+        self.delete(self.pool()).await?;
+        self.keep_alive = None;
+        Ok(())
+    }
+    pub async fn checkpoint_with_transaction(
+        &mut self,
+        mut tx: sqlx::Transaction<'_, Postgres>,
+        checkpoint: &Checkpoint<'_>,
+    ) -> Result<(), sqlx::Error> {
+        checkpoint.execute(self.id, &mut tx).await?;
+        tx.commit().await?;
+        Ok(())
+    }
+    pub async fn checkpoint(&mut self, checkpoint: &Checkpoint<'_>) -> Result<(), sqlx::Error> {
+        checkpoint.execute(self.id, self.pool()).await?;
+        Ok(())
+    }
+    pub async fn keep_alive(&mut self, duration: Duration) -> Result<(), sqlx::Error> {
+        sqlx::query("SELECT mq_keep_alive(ARRAY[$1], $2)")
+            .bind(self.id)
+            .bind(duration)
+            .execute(self.pool())
+            .await?;
+        Ok(())
+    }
+    pub fn id(&self) -> Uuid {
+        self.id
+    }
+    pub fn name(&self) -> &str {
+        &self.name
+    }
+    pub fn json<'a, T: Deserialize<'a>>(&'a self) -> Result<Option<T>, serde_json::Error> {
+        if let Some(payload_json) = &self.payload_json {
+            serde_json::from_str(payload_json).map(Some)
+        } else {
+            Ok(None)
+        }
+    }
+    pub fn raw_json(&self) -> Option<&str> {
+        self.payload_json.as_deref()
+    }
+    pub fn raw_bytes(&self) -> Option<&[u8]> {
+        self.payload_bytes.as_deref()
+    }
+}
+
+impl Drop for CurrentTask {
+    fn drop(&mut self) {
+        if self
+            .task_runner
+            .running_tasks
+            .fetch_sub(1, Ordering::SeqCst)
+            == self.task_runner.options.min_concurrency
+        {
+            self.task_runner.notify.notify_one();
+        }
+    }
+}
+
+impl TaskRunnerOptions {
+    pub fn new<F: Fn(CurrentTask) + Send + Sync + 'static>(
+        pool: &Pool<Postgres>,
+        f: F,
+    ) -> Self {
+        Self {
+            min_concurrency: 16,
+            max_concurrency: 32,
+            channel_names: None,
+            keep_alive: true,
+            runner: Opaque(Arc::new(f)),
+            pool: pool.clone(),
+        }
+    }
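+    // Typical setup (illustrative; `pool` is assumed to be a connected
+    // `Pool<Postgres>` with the migrations applied):
+    //
+    //     let runner = TaskRunnerOptions::new(&pool, |task| {
+    //         tokio::task::spawn(async move {
+    //             // ... process `task`, then call `task.complete().await` ...
+    //         });
+    //     })
+    //     .run()
+    //     .await?;
+    //
+    // Dropping the returned `OwnedTask` aborts the runner.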
+    pub async fn run(&self) -> Result<OwnedTask, sqlx::Error> {
+        let options = self.clone();
+        let task_runner = Arc::new(TaskRunner {
+            options,
+            running_tasks: AtomicUsize::new(0),
+            notify: Notify::new(),
+        });
+        let listener_task = start_listener(task_runner.clone()).await?;
+        Ok(OwnedTask(task::spawn(main_loop(
+            task_runner,
+            listener_task,
+        ))))
+    }
+}
+
+async fn start_listener(task_runner: Arc<TaskRunner>) -> Result<OwnedTask, sqlx::Error> {
+    let mut listener = PgListener::connect_with(&task_runner.options.pool).await?;
+    if let Some(channels) = &task_runner.options.channel_names {
+        let names: Vec<String> = channels.iter().map(|c| format!("mq_{}", c)).collect();
+        listener
+            .listen_all(names.iter().map(|s| s.as_str()))
+            .await?;
+    } else {
+        listener.listen("mq").await?;
+    }
+    Ok(OwnedTask(task::spawn(async move {
+        while let Ok(_) = listener.recv().await {
+            task_runner.notify.notify_one();
+        }
+    })))
+}
+
+#[derive(sqlx::FromRow)]
+struct PolledMessage {
+    id: Option<Uuid>,
+    is_committed: Option<bool>,
+    name: Option<String>,
+    payload_json: Option<String>,
+    payload_bytes: Option<Vec<u8>>,
+    retry_backoff: Option<PgInterval>,
+    wait_time: Option<PgInterval>,
+}
+
+fn to_duration(interval: PgInterval) -> Duration {
+    const SECONDS_PER_DAY: u64 = 24 * 60 * 60;
+    if interval.microseconds < 0 || interval.days < 0 || interval.months < 0 {
+        Duration::default()
+    } else {
+        let days = (interval.days as u64) + (interval.months as u64) * 30;
+        Duration::from_micros(interval.microseconds as u64)
+            + Duration::from_secs(days * SECONDS_PER_DAY)
+    }
+}
+
+async fn poll_and_dispatch(
+    task_runner: &Arc<TaskRunner>,
+    batch_size: i32,
+) -> Result<Duration, sqlx::Error> {
+    log::info!("Polling for messages");
+
+    let options = &task_runner.options;
+    let messages = sqlx::query_as::<_, PolledMessage>("SELECT * FROM mq_poll($1, $2)")
+        .bind(&options.channel_names)
+        .bind(batch_size)
+        .fetch_all(&options.pool)
+        .await?;
+
+    let ids_to_delete: Vec<_> = messages
+        .iter()
+        .filter(|msg| msg.is_committed == Some(false))
+        .filter_map(|msg| msg.id)
+        .collect();
+
+    log::info!("Deleting {} messages", ids_to_delete.len());
+    if !ids_to_delete.is_empty() {
+        sqlx::query("SELECT mq_delete($1)")
+            .bind(ids_to_delete)
+            .execute(&options.pool)
+            .await?;
+    }
+
+    let wait_time = messages
+        .iter()
+        .filter_map(|msg| msg.wait_time.clone())
+        .map(to_duration)
+        .min()
+        .unwrap_or(Duration::from_secs(60));
+
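+    // Dispatch each fully-committed message: start a keep-alive task if
+    // enabled, bump the running count, and hand a `CurrentTask` to the
+    // user-supplied runner closure. Uncommitted rows were deleted above, and
+    // the all-NULL "no message" row is skipped by the pattern match below.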
+    for msg in messages {
+        if let PolledMessage {
+            id: Some(id),
+            is_committed: Some(true),
+            name: Some(name),
+            payload_json,
+            payload_bytes,
+            retry_backoff: Some(retry_backoff),
+            ..
+        } = msg
+        {
+            let retry_backoff = to_duration(retry_backoff);
+            let keep_alive = if options.keep_alive {
+                Some(OwnedTask(task::spawn(keep_task_alive(
+                    id,
+                    options.pool.clone(),
+                    retry_backoff,
+                ))))
+            } else {
+                None
+            };
+            let current_task = CurrentTask {
+                id,
+                name,
+                payload_json,
+                payload_bytes,
+                task_runner: task_runner.clone(),
+                keep_alive,
+            };
+            task_runner.running_tasks.fetch_add(1, Ordering::SeqCst);
+            (options.runner)(current_task);
+        }
+    }
+
+    Ok(wait_time)
+}
+
+async fn main_loop(task_runner: Arc<TaskRunner>, _listener_task: OwnedTask) {
+    let options = &task_runner.options;
+    let mut failures = 0;
+    loop {
+        let running_tasks = task_runner.running_tasks.load(Ordering::SeqCst);
+        let duration = if running_tasks < options.min_concurrency {
+            let batch_size = (options.max_concurrency - running_tasks) as i32;
+
+            match poll_and_dispatch(&task_runner, batch_size).await {
+                Ok(duration) => {
+                    failures = 0;
+                    duration
+                }
+                Err(e) => {
+                    failures += 1;
+                    log::error!("Failed to poll for messages: {}", e);
+                    Duration::from_millis(50 << failures)
+                }
+            }
+        } else {
+            Duration::from_secs(60)
+        };
+
+        // Wait for us to be notified, or for the timeout to elapse
+        let _ = tokio::time::timeout(duration, task_runner.notify.notified()).await;
+    }
+}
+
+async fn keep_task_alive(id: Uuid, pool: Pool<Postgres>, mut interval: Duration) {
+    loop {
+        tokio::time::sleep(interval / 2).await;
+        interval *= 2;
+        if let Err(e) = sqlx::query("SELECT mq_keep_alive(ARRAY[$1], $2)")
+            .bind(id)
+            .bind(interval)
+            .execute(&pool)
+            .await
+        {
+            log::error!("Failed to keep task {} alive: {}", id, e);
+            break;
+        }
+    }
+}
diff --git a/src/spawn.rs b/src/spawn.rs
new file mode 100644
index 0000000..da90905
--- /dev/null
+++ b/src/spawn.rs
@@ -0,0 +1,106 @@
+use std::borrow::Cow;
+use std::fmt::Debug;
+use std::time::Duration;
+
+use serde::Serialize;
+use sqlx::Postgres;
+use uuid::Uuid;
+
+#[derive(Debug, Clone)]
+pub struct TaskBuilder<'a> {
+    id: Uuid,
+    delay: Duration,
+    channel_name: &'a str,
+    channel_args: &'a str,
+    retries: usize,
+    retry_backoff: Duration,
+    commit_interval: Option<Duration>,
+    ordered: bool,
+    name: &'a str,
+    payload_json: Option<Cow<'a, str>>,
+    payload_bytes: Option<&'a [u8]>,
+}
+
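+// Illustrative spawn call (assumes a connected `Pool<Postgres>` named `pool`
+// and a serializable `payload` value; both are placeholders for this sketch):
+//
+//     let id = TaskBuilder::new("send_email")
+//         .set_channel_name("email")
+//         .set_retries(4)
+//         .set_json(&payload)?
+//         .spawn(&pool)
+//         .await?;
+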
+impl<'a> TaskBuilder<'a> {
+    pub fn new(name: &'a str) -> Self {
+        Self::new_with_id(Uuid::new_v4(), name)
+    }
+    pub fn new_with_id(id: Uuid, name: &'a str) -> Self {
+        Self {
+            id,
+            delay: Duration::from_secs(0),
+            channel_name: "",
+            channel_args: "",
+            retries: 4,
+            retry_backoff: Duration::from_secs(1),
+            commit_interval: None,
+            ordered: false,
+            name,
+            payload_json: None,
+            payload_bytes: None,
+        }
+    }
+    pub fn set_channel_name(&mut self, channel_name: &'a str) -> &mut Self {
+        self.channel_name = channel_name;
+        self
+    }
+    pub fn set_channel_args(&mut self, channel_args: &'a str) -> &mut Self {
+        self.channel_args = channel_args;
+        self
+    }
+    pub fn set_retries(&mut self, retries: usize) -> &mut Self {
+        self.retries = retries;
+        self
+    }
+    pub fn set_retry_backoff(&mut self, retry_backoff: Duration) -> &mut Self {
+        self.retry_backoff = retry_backoff;
+        self
+    }
+    pub fn set_commit_interval(&mut self, commit_interval: Option<Duration>) -> &mut Self {
+        self.commit_interval = commit_interval;
+        self
+    }
+    pub fn set_ordered(&mut self, ordered: bool) -> &mut Self {
+        self.ordered = ordered;
+        self
+    }
+    pub fn set_delay(&mut self, delay: Duration) -> &mut Self {
+        self.delay = delay;
+        self
+    }
+    pub fn set_raw_json(&mut self, raw_json: &'a str) -> &mut Self {
+        self.payload_json = Some(Cow::Borrowed(raw_json));
+        self
+    }
+    pub fn set_raw_bytes(&mut self, raw_bytes: &'a [u8]) -> &mut Self {
+        self.payload_bytes = Some(raw_bytes);
+        self
+    }
+    pub fn set_json<T: Serialize>(&mut self, value: &T) -> Result<&mut Self, serde_json::Error> {
+        let value = serde_json::to_string(value)?;
+        self.payload_json = Some(Cow::Owned(value));
+        Ok(self)
+    }
+    pub async fn spawn<'b, E: sqlx::Executor<'b, Database = Postgres>>(
+        &self,
+        executor: E,
+    ) -> Result<Uuid, sqlx::Error> {
+        sqlx::query(
+            "SELECT mq_insert(ARRAY[($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)::mq_new_t])",
+        )
+        .bind(self.id)
+        .bind(self.delay)
+        .bind(self.retries as i32)
+        .bind(self.retry_backoff)
+        .bind(self.channel_name)
+        .bind(self.channel_args)
+        .bind(self.commit_interval)
+        .bind(self.ordered)
+        .bind(self.name)
+        .bind(self.payload_json.as_deref())
+        .bind(self.payload_bytes)
+        .execute(executor)
+        .await?;
+        Ok(self.id)
+    }
+}
diff --git a/src/utils.rs b/src/utils.rs
new file mode 100644
index 0000000..8ab8cfc
--- /dev/null
+++ b/src/utils.rs
@@ -0,0 +1,37 @@
+use std::any::Any;
+use std::fmt::{self, Debug};
+use std::ops::{Deref, DerefMut};
+
+use tokio::task::JoinHandle;
+
+#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
+pub struct Opaque<T>(pub T);
+
+impl<T: Any> Debug for Opaque<T> {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        <dyn Any>::fmt(&self.0, f)
+    }
+}
+
+impl<T> Deref for Opaque<T> {
+    type Target = T;
+
+    fn deref(&self) -> &Self::Target {
+        &self.0
+    }
+}
+
+impl<T> DerefMut for Opaque<T> {
+    fn deref_mut(&mut self) -> &mut Self::Target {
+        &mut self.0
+    }
+}
+
+#[derive(Debug)]
+pub struct OwnedTask(pub JoinHandle<()>);
+
+impl Drop for OwnedTask {
+    fn drop(&mut self) {
+        self.0.abort();
+    }
+}