mirror of
https://github.com/Diggsey/sqlxmq.git
synced 2025-01-02 20:38:45 +00:00
Initial commit
This commit is contained in:
commit
6681fe87de
9 changed files with 1022 additions and 0 deletions
1
.env
Normal file
1
.env
Normal file
|
@ -0,0 +1 @@
|
|||
DATABASE_URL=postgres://postgres:password@localhost/sqlxmq
|
2
.gitignore
vendored
Normal file
2
.gitignore
vendored
Normal file
|
@ -0,0 +1,2 @@
|
|||
/target
|
||||
Cargo.lock
|
27
Cargo.toml
Normal file
27
Cargo.toml
Normal file
|
@ -0,0 +1,27 @@
|
|||
[package]
|
||||
name = "sqlxmq"
|
||||
version = "0.1.0"
|
||||
authors = ["Diggory Blake <diggsey@googlemail.com>"]
|
||||
edition = "2018"
|
||||
|
||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
|
||||
[dependencies]
|
||||
sqlx = { version = "0.5.1", features = [
|
||||
"postgres",
|
||||
"runtime-async-std-native-tls",
|
||||
"chrono",
|
||||
"uuid",
|
||||
] }
|
||||
tokio = { version = "1.3.0", features = ["full"] }
|
||||
dotenv = "0.15.0"
|
||||
futures = "0.3.13"
|
||||
chrono = "0.4.19"
|
||||
uuid = { version = "0.8.2", features = ["v4"] }
|
||||
log = "0.4.14"
|
||||
serde_json = "1.0.64"
|
||||
serde = "1.0.124"
|
||||
|
||||
[dev-dependencies]
|
||||
dotenv = "0.15.0"
|
||||
pretty_env_logger = "0.4.0"
|
12
migrations/20210316025847_setup.down.sql
Normal file
12
migrations/20210316025847_setup.down.sql
Normal file
|
@ -0,0 +1,12 @@
|
|||
DROP FUNCTION mq_checkpoint;
|
||||
DROP FUNCTION mq_keep_alive;
|
||||
DROP FUNCTION mq_delete;
|
||||
DROP FUNCTION mq_commit;
|
||||
DROP FUNCTION mq_insert;
|
||||
DROP FUNCTION mq_poll;
|
||||
DROP FUNCTION mq_active_channels;
|
||||
DROP FUNCTION mq_latest_message;
|
||||
DROP TABLE mq_payloads;
|
||||
DROP TABLE mq_msgs;
|
||||
DROP FUNCTION mq_uuid_exists;
|
||||
DROP TYPE mq_new_t;
|
286
migrations/20210316025847_setup.up.sql
Normal file
286
migrations/20210316025847_setup.up.sql
Normal file
|
@ -0,0 +1,286 @@
|
|||
-- The UDT for creating messages
|
||||
CREATE TYPE mq_new_t AS (
|
||||
-- Unique message ID
|
||||
id UUID,
|
||||
-- Delay before message is processed
|
||||
delay INTERVAL,
|
||||
-- Number of retries if initial processing fails
|
||||
retries INT,
|
||||
-- Initial backoff between retries
|
||||
retry_backoff INTERVAL,
|
||||
-- Name of channel
|
||||
channel_name TEXT,
|
||||
-- Arguments to channel
|
||||
channel_args TEXT,
|
||||
-- Interval for two-phase commit (or NULL to disable two-phase commit)
|
||||
commit_interval INTERVAL,
|
||||
-- Whether this message should be processed in order with respect to other
|
||||
-- ordered messages.
|
||||
ordered BOOLEAN,
|
||||
-- Name of message
|
||||
name TEXT,
|
||||
-- JSON payload
|
||||
payload_json TEXT,
|
||||
-- Binary payload
|
||||
payload_bytes BYTEA
|
||||
);
|
||||
|
||||
-- Small, frequently updated table of messages
|
||||
CREATE TABLE mq_msgs (
|
||||
id UUID PRIMARY KEY,
|
||||
created_at TIMESTAMPTZ DEFAULT NOW(),
|
||||
attempt_at TIMESTAMPTZ DEFAULT NOW(),
|
||||
attempts INT NOT NULL DEFAULT 5,
|
||||
retry_backoff INTERVAL NOT NULL DEFAULT INTERVAL '1 second',
|
||||
channel_name TEXT NOT NULL,
|
||||
channel_args TEXT NOT NULL,
|
||||
commit_interval INTERVAL,
|
||||
after_message_id UUID DEFAULT uuid_nil() REFERENCES mq_msgs(id) ON DELETE SET DEFAULT
|
||||
);
|
||||
|
||||
-- Insert dummy message so that the 'nil' UUID can be referenced
|
||||
INSERT INTO mq_msgs (id, channel_name, channel_args, after_message_id) VALUES (uuid_nil(), '', '', NULL);
|
||||
|
||||
-- Internal helper function to check that a UUID is neither NULL nor NIL
|
||||
CREATE FUNCTION mq_uuid_exists(
|
||||
id UUID
|
||||
) RETURNS BOOLEAN AS $$
|
||||
SELECT id IS NOT NULL AND id != uuid_nil()
|
||||
$$ LANGUAGE SQL IMMUTABLE;
|
||||
|
||||
-- Index for polling
|
||||
CREATE INDEX ON mq_msgs(channel_name, channel_args, attempt_at) WHERE id != uuid_nil() AND NOT mq_uuid_exists(after_message_id);
|
||||
-- Index for adding messages
|
||||
CREATE INDEX ON mq_msgs(channel_name, channel_args, created_at, id) WHERE id != uuid_nil() AND after_message_id IS NOT NULL;
|
||||
|
||||
-- Index for ensuring strict message order
|
||||
CREATE UNIQUE INDEX ON mq_msgs(channel_name, channel_args, after_message_id);
|
||||
|
||||
|
||||
-- Large, less frequently updated table of message payloads
|
||||
CREATE TABLE mq_payloads(
|
||||
id UUID PRIMARY KEY,
|
||||
name TEXT NOT NULL,
|
||||
payload_json JSONB,
|
||||
payload_bytes BYTEA
|
||||
);
|
||||
|
||||
-- Internal helper function to return the most recently added message in a queue.
|
||||
CREATE FUNCTION mq_latest_message(from_channel_name TEXT, from_channel_args TEXT)
|
||||
RETURNS UUID AS $$
|
||||
SELECT COALESCE(
|
||||
(
|
||||
SELECT id FROM mq_msgs
|
||||
WHERE channel_name = from_channel_name
|
||||
AND channel_args = from_channel_args
|
||||
AND after_message_id IS NOT NULL
|
||||
AND id != uuid_nil()
|
||||
ORDER BY created_at DESC, id DESC
|
||||
LIMIT 1
|
||||
),
|
||||
uuid_nil()
|
||||
)
|
||||
$$ LANGUAGE SQL STABLE;
|
||||
|
||||
-- Internal helper function to randomly select a set of channels with "ready" messages.
|
||||
CREATE FUNCTION mq_active_channels(channel_names TEXT[], batch_size INT)
|
||||
RETURNS TABLE(name TEXT, args TEXT) AS $$
|
||||
SELECT channel_name, channel_args
|
||||
FROM mq_msgs
|
||||
WHERE id != uuid_nil()
|
||||
AND attempt_at <= NOW()
|
||||
AND (channel_names IS NULL OR channel_name = ANY(channel_names))
|
||||
AND NOT mq_uuid_exists(after_message_id)
|
||||
GROUP BY channel_name, channel_args
|
||||
ORDER BY RANDOM()
|
||||
LIMIT batch_size
|
||||
$$ LANGUAGE SQL STABLE;
|
||||
|
||||
-- Main entry-point for task runner: pulls a batch of messages from the queue.
|
||||
CREATE FUNCTION mq_poll(channel_names TEXT[], batch_size INT DEFAULT 1)
|
||||
RETURNS TABLE(
|
||||
id UUID,
|
||||
is_committed BOOLEAN,
|
||||
name TEXT,
|
||||
payload_json TEXT,
|
||||
payload_bytes BYTEA,
|
||||
retry_backoff INTERVAL,
|
||||
wait_time INTERVAL
|
||||
) AS $$
|
||||
BEGIN
|
||||
RETURN QUERY UPDATE mq_msgs
|
||||
SET
|
||||
attempt_at = CASE WHEN mq_msgs.attempts = 1 THEN NULL ELSE NOW() + mq_msgs.retry_backoff END,
|
||||
attempts = mq_msgs.attempts - 1,
|
||||
retry_backoff = mq_msgs.retry_backoff * 2
|
||||
FROM (
|
||||
SELECT
|
||||
msgs.id
|
||||
FROM mq_active_channels(channel_names, batch_size) AS active_channels
|
||||
INNER JOIN LATERAL (
|
||||
SELECT * FROM mq_msgs
|
||||
WHERE mq_msgs.id != uuid_nil()
|
||||
AND mq_msgs.attempt_at <= NOW()
|
||||
AND mq_msgs.channel_name = active_channels.name
|
||||
AND mq_msgs.channel_args = active_channels.args
|
||||
AND NOT mq_uuid_exists(mq_msgs.after_message_id)
|
||||
ORDER BY mq_msgs.attempt_at ASC
|
||||
LIMIT batch_size
|
||||
) AS msgs ON TRUE
|
||||
LIMIT batch_size
|
||||
) AS messages_to_update
|
||||
LEFT JOIN mq_payloads ON mq_payloads.id = messages_to_update.id
|
||||
WHERE mq_msgs.id = messages_to_update.id
|
||||
RETURNING
|
||||
mq_msgs.id,
|
||||
mq_msgs.commit_interval IS NULL,
|
||||
mq_payloads.name,
|
||||
mq_payloads.payload_json::TEXT,
|
||||
mq_payloads.payload_bytes,
|
||||
mq_msgs.retry_backoff / 2,
|
||||
interval '0' AS wait_time;
|
||||
|
||||
IF NOT FOUND THEN
|
||||
RETURN QUERY SELECT
|
||||
NULL::UUID,
|
||||
NULL::BOOLEAN,
|
||||
NULL::TEXT,
|
||||
NULL::TEXT,
|
||||
NULL::BYTEA,
|
||||
NULL::INTERVAL,
|
||||
MIN(mq_msgs.attempt_at) - NOW()
|
||||
FROM mq_msgs
|
||||
WHERE mq_msgs.id != uuid_nil()
|
||||
AND NOT mq_uuid_exists(mq_msgs.after_message_id)
|
||||
AND (channel_names IS NULL OR mq_msgs.channel_name = ANY(channel_names));
|
||||
END IF;
|
||||
END;
|
||||
$$ LANGUAGE plpgsql;
|
||||
|
||||
-- Creates new messages
|
||||
CREATE FUNCTION mq_insert(new_messages mq_new_t[])
|
||||
RETURNS VOID AS $$
|
||||
BEGIN
|
||||
PERFORM pg_notify(CONCAT('mq_', channel_name), '')
|
||||
FROM unnest(new_messages) AS new_msgs
|
||||
GROUP BY channel_name;
|
||||
|
||||
IF FOUND THEN
|
||||
PERFORM pg_notify('mq', '');
|
||||
END IF;
|
||||
|
||||
INSERT INTO mq_payloads (
|
||||
id,
|
||||
name,
|
||||
payload_json,
|
||||
payload_bytes
|
||||
) SELECT
|
||||
id,
|
||||
name,
|
||||
payload_json::JSONB,
|
||||
payload_bytes
|
||||
FROM UNNEST(new_messages);
|
||||
|
||||
INSERT INTO mq_msgs (
|
||||
id,
|
||||
attempt_at,
|
||||
attempts,
|
||||
retry_backoff,
|
||||
channel_name,
|
||||
channel_args,
|
||||
commit_interval,
|
||||
after_message_id
|
||||
)
|
||||
SELECT
|
||||
id,
|
||||
NOW() + delay + COALESCE(commit_interval, INTERVAL '0'),
|
||||
retries + 1,
|
||||
retry_backoff,
|
||||
channel_name,
|
||||
channel_args,
|
||||
commit_interval,
|
||||
CASE WHEN ordered
|
||||
THEN
|
||||
LAG(id, 1, mq_latest_message(channel_name, channel_args))
|
||||
OVER (PARTITION BY channel_name, channel_args, ordered ORDER BY id)
|
||||
ELSE
|
||||
NULL
|
||||
END
|
||||
FROM UNNEST(new_messages);
|
||||
END;
|
||||
$$ LANGUAGE plpgsql;
|
||||
|
||||
-- Commits messages previously created with a non-NULL commit interval.
|
||||
CREATE FUNCTION mq_commit(msg_ids UUID[])
|
||||
RETURNS VOID AS $$
|
||||
BEGIN
|
||||
UPDATE mq_msgs
|
||||
SET
|
||||
attempt_at = attempt_at - commit_interval,
|
||||
commit_interval = NULL
|
||||
WHERE id = ANY(msg_ids)
|
||||
AND commit_interval IS NOT NULL;
|
||||
END;
|
||||
$$ LANGUAGE plpgsql;
|
||||
|
||||
|
||||
-- Deletes messages from the queue. This occurs when a message has been
|
||||
-- processed, or when it expires without being processed.
|
||||
CREATE FUNCTION mq_delete(msg_ids UUID[])
|
||||
RETURNS VOID AS $$
|
||||
BEGIN
|
||||
PERFORM pg_notify(CONCAT('mq_', channel_name), '')
|
||||
FROM mq_msgs
|
||||
WHERE id = ANY(msg_ids)
|
||||
AND after_message_id = uuid_nil()
|
||||
GROUP BY channel_name;
|
||||
|
||||
IF FOUND THEN
|
||||
PERFORM pg_notify('mq', '');
|
||||
END IF;
|
||||
|
||||
DELETE FROM mq_msgs WHERE id = ANY(msg_ids);
|
||||
DELETE FROM mq_payloads WHERE id = ANY(msg_ids);
|
||||
END;
|
||||
$$ LANGUAGE plpgsql;
|
||||
|
||||
|
||||
-- Can be called during the initial commit interval, or when processing
|
||||
-- a message. Indicates that the caller is still active and will prevent either
|
||||
-- the commit interval elapsing or the message being retried for the specified
|
||||
-- interval.
|
||||
CREATE FUNCTION mq_keep_alive(msg_ids UUID[], duration INTERVAL)
|
||||
RETURNS VOID AS $$
|
||||
UPDATE mq_msgs
|
||||
SET
|
||||
attempt_at = NOW() + duration,
|
||||
commit_interval = commit_interval + ((NOW() + duration) - attempt_at)
|
||||
WHERE id = ANY(msg_ids)
|
||||
AND attempt_at < NOW() + duration;
|
||||
$$ LANGUAGE SQL;
|
||||
|
||||
|
||||
-- Called during lengthy processing of a message to checkpoint the progress.
|
||||
-- As well as behaving like `mq_keep_alive`, the message payload can be
|
||||
-- updated.
|
||||
CREATE FUNCTION mq_checkpoint(
|
||||
msg_id UUID,
|
||||
duration INTERVAL,
|
||||
new_payload_json TEXT,
|
||||
new_payload_bytes BYTEA,
|
||||
extra_retries INT
|
||||
)
|
||||
RETURNS VOID AS $$
|
||||
UPDATE mq_msgs
|
||||
SET
|
||||
attempt_at = GREATEST(attempt_at, NOW() + duration),
|
||||
attempts = attempts + COALESCE(extra_retries, 0)
|
||||
WHERE id = msg_id;
|
||||
|
||||
UPDATE mq_payloads
|
||||
SET
|
||||
payload_json = COALESCE(new_payload_json::JSONB, payload_json),
|
||||
payload_bytes = COALESCE(new_payload_bytes, payload_bytes)
|
||||
WHERE
|
||||
id = msg_id;
|
||||
$$ LANGUAGE SQL;
|
195
src/lib.rs
Normal file
195
src/lib.rs
Normal file
|
@ -0,0 +1,195 @@
|
|||
mod runner;
|
||||
mod spawn;
|
||||
mod utils;
|
||||
|
||||
pub use runner::*;
|
||||
pub use spawn::*;
|
||||
pub use utils::OwnedTask;
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
use std::env;
|
||||
use std::future::Future;
|
||||
use std::ops::Deref;
|
||||
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||
use std::sync::{Arc, Once};
|
||||
use std::time::Duration;
|
||||
|
||||
use futures::channel::mpsc;
|
||||
use futures::StreamExt;
|
||||
use sqlx::{Pool, Postgres};
|
||||
use tokio::sync::{Mutex, MutexGuard};
|
||||
use tokio::task;
|
||||
|
||||
struct TestGuard<T>(MutexGuard<'static, ()>, T);
|
||||
|
||||
impl<T> Deref for TestGuard<T> {
|
||||
type Target = T;
|
||||
|
||||
fn deref(&self) -> &T {
|
||||
&self.1
|
||||
}
|
||||
}
|
||||
|
||||
async fn test_pool() -> TestGuard<Pool<Postgres>> {
|
||||
static INIT_LOGGER: Once = Once::new();
|
||||
static TEST_MUTEX: Mutex<()> = Mutex::const_new(());
|
||||
|
||||
let guard = TEST_MUTEX.lock().await;
|
||||
|
||||
let _ = dotenv::dotenv();
|
||||
|
||||
INIT_LOGGER.call_once(|| pretty_env_logger::init());
|
||||
|
||||
let pool = Pool::connect(&env::var("DATABASE_URL").unwrap())
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
sqlx::query("TRUNCATE TABLE mq_payloads")
|
||||
.execute(&pool)
|
||||
.await
|
||||
.unwrap();
|
||||
sqlx::query("DELETE FROM mq_msgs WHERE id != uuid_nil()")
|
||||
.execute(&pool)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
TestGuard(guard, pool)
|
||||
}
|
||||
|
||||
async fn test_task_runner<F: Future + Send + 'static>(
|
||||
pool: &Pool<Postgres>,
|
||||
f: impl (Fn(CurrentTask) -> F) + Send + Sync + 'static,
|
||||
) -> (OwnedTask, Arc<AtomicUsize>)
|
||||
where
|
||||
F::Output: Send + 'static,
|
||||
{
|
||||
let counter = Arc::new(AtomicUsize::new(0));
|
||||
let counter2 = counter.clone();
|
||||
let runner = TaskRunnerOptions::new(pool, move |task| {
|
||||
counter2.fetch_add(1, Ordering::SeqCst);
|
||||
task::spawn(f(task));
|
||||
})
|
||||
.run()
|
||||
.await
|
||||
.unwrap();
|
||||
(runner, counter)
|
||||
}
|
||||
|
||||
async fn pause() {
|
||||
pause_ms(50).await;
|
||||
}
|
||||
|
||||
async fn pause_ms(ms: u64) {
|
||||
tokio::time::sleep(Duration::from_millis(ms)).await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn it_can_spawn_task() {
|
||||
let pool = &*test_pool().await;
|
||||
let (_runner, counter) =
|
||||
test_task_runner(&pool, |mut task| async move { task.complete().await }).await;
|
||||
|
||||
assert_eq!(counter.load(Ordering::SeqCst), 0);
|
||||
TaskBuilder::new("foo").spawn(pool).await.unwrap();
|
||||
pause().await;
|
||||
assert_eq!(counter.load(Ordering::SeqCst), 1);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn it_runs_tasks_in_order() {
|
||||
let pool = &*test_pool().await;
|
||||
let (tx, mut rx) = mpsc::unbounded();
|
||||
|
||||
let (_runner, counter) = test_task_runner(&pool, move |task| {
|
||||
let tx = tx.clone();
|
||||
async move {
|
||||
tx.unbounded_send(task).unwrap();
|
||||
}
|
||||
})
|
||||
.await;
|
||||
|
||||
assert_eq!(counter.load(Ordering::SeqCst), 0);
|
||||
TaskBuilder::new("foo")
|
||||
.set_ordered(true)
|
||||
.spawn(pool)
|
||||
.await
|
||||
.unwrap();
|
||||
TaskBuilder::new("bar")
|
||||
.set_ordered(true)
|
||||
.spawn(pool)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
pause().await;
|
||||
assert_eq!(counter.load(Ordering::SeqCst), 1);
|
||||
|
||||
let mut task = rx.next().await.unwrap();
|
||||
task.complete().await.unwrap();
|
||||
|
||||
pause().await;
|
||||
assert_eq!(counter.load(Ordering::SeqCst), 2);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn it_runs_tasks_in_parallel() {
|
||||
let pool = &*test_pool().await;
|
||||
let (tx, mut rx) = mpsc::unbounded();
|
||||
|
||||
let (_runner, counter) = test_task_runner(&pool, move |task| {
|
||||
let tx = tx.clone();
|
||||
async move {
|
||||
tx.unbounded_send(task).unwrap();
|
||||
}
|
||||
})
|
||||
.await;
|
||||
|
||||
assert_eq!(counter.load(Ordering::SeqCst), 0);
|
||||
TaskBuilder::new("foo").spawn(pool).await.unwrap();
|
||||
TaskBuilder::new("bar").spawn(pool).await.unwrap();
|
||||
|
||||
pause().await;
|
||||
assert_eq!(counter.load(Ordering::SeqCst), 2);
|
||||
|
||||
for _ in 0..2 {
|
||||
let mut task = rx.next().await.unwrap();
|
||||
task.complete().await.unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn it_retries_failed_tasks() {
|
||||
let pool = &*test_pool().await;
|
||||
let (_runner, counter) = test_task_runner(&pool, move |_| async {}).await;
|
||||
|
||||
let backoff = 100;
|
||||
|
||||
assert_eq!(counter.load(Ordering::SeqCst), 0);
|
||||
TaskBuilder::new("foo")
|
||||
.set_retry_backoff(Duration::from_millis(backoff))
|
||||
.set_retries(2)
|
||||
.spawn(pool)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// First attempt
|
||||
pause().await;
|
||||
assert_eq!(counter.load(Ordering::SeqCst), 1);
|
||||
|
||||
// Second attempt
|
||||
pause_ms(backoff).await;
|
||||
pause().await;
|
||||
assert_eq!(counter.load(Ordering::SeqCst), 2);
|
||||
|
||||
// Third attempt
|
||||
pause_ms(backoff * 2).await;
|
||||
pause().await;
|
||||
assert_eq!(counter.load(Ordering::SeqCst), 3);
|
||||
|
||||
// No more attempts
|
||||
pause_ms(backoff * 5).await;
|
||||
assert_eq!(counter.load(Ordering::SeqCst), 3);
|
||||
}
|
||||
}
|
356
src/runner.rs
Normal file
356
src/runner.rs
Normal file
|
@ -0,0 +1,356 @@
|
|||
use std::borrow::Cow;
|
||||
use std::fmt::Debug;
|
||||
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use sqlx::postgres::types::PgInterval;
|
||||
use sqlx::postgres::PgListener;
|
||||
use sqlx::{Pool, Postgres};
|
||||
use tokio::sync::Notify;
|
||||
use tokio::task;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::utils::{Opaque, OwnedTask};
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct TaskRunnerOptions {
|
||||
min_concurrency: usize,
|
||||
max_concurrency: usize,
|
||||
channel_names: Option<Vec<String>>,
|
||||
runner: Opaque<Arc<dyn Fn(CurrentTask) + Send + Sync + 'static>>,
|
||||
pool: Pool<Postgres>,
|
||||
keep_alive: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct TaskRunner {
|
||||
options: TaskRunnerOptions,
|
||||
running_tasks: AtomicUsize,
|
||||
notify: Notify,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Checkpoint<'a> {
|
||||
duration: Duration,
|
||||
extra_retries: usize,
|
||||
payload_json: Option<Cow<'a, str>>,
|
||||
payload_bytes: Option<&'a [u8]>,
|
||||
}
|
||||
|
||||
impl<'a> Checkpoint<'a> {
|
||||
pub fn new(duration: Duration) -> Self {
|
||||
Self {
|
||||
duration,
|
||||
extra_retries: 0,
|
||||
payload_json: None,
|
||||
payload_bytes: None,
|
||||
}
|
||||
}
|
||||
pub fn set_extra_retries(&mut self, extra_retries: usize) -> &mut Self {
|
||||
self.extra_retries = extra_retries;
|
||||
self
|
||||
}
|
||||
pub fn set_raw_json(&mut self, raw_json: &'a str) -> &mut Self {
|
||||
self.payload_json = Some(Cow::Borrowed(raw_json));
|
||||
self
|
||||
}
|
||||
pub fn set_raw_bytes(&mut self, raw_bytes: &'a [u8]) -> &mut Self {
|
||||
self.payload_bytes = Some(raw_bytes);
|
||||
self
|
||||
}
|
||||
pub fn set_json<T: Serialize>(&mut self, value: &T) -> Result<&mut Self, serde_json::Error> {
|
||||
let value = serde_json::to_string(value)?;
|
||||
self.payload_json = Some(Cow::Owned(value));
|
||||
Ok(self)
|
||||
}
|
||||
async fn execute<'b, E: sqlx::Executor<'b, Database = Postgres>>(
|
||||
&self,
|
||||
task_id: Uuid,
|
||||
executor: E,
|
||||
) -> Result<(), sqlx::Error> {
|
||||
sqlx::query("SELECT mq_checkpoint($1, $2, $3, $4, $5)")
|
||||
.bind(task_id)
|
||||
.bind(self.duration)
|
||||
.bind(self.payload_json.as_deref())
|
||||
.bind(self.payload_bytes)
|
||||
.bind(self.extra_retries as i32)
|
||||
.execute(executor)
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct CurrentTask {
|
||||
id: Uuid,
|
||||
name: String,
|
||||
payload_json: Option<String>,
|
||||
payload_bytes: Option<Vec<u8>>,
|
||||
task_runner: Arc<TaskRunner>,
|
||||
keep_alive: Option<OwnedTask>,
|
||||
}
|
||||
|
||||
impl CurrentTask {
|
||||
pub fn pool(&self) -> &Pool<Postgres> {
|
||||
&self.task_runner.options.pool
|
||||
}
|
||||
async fn delete(
|
||||
&self,
|
||||
executor: impl sqlx::Executor<'_, Database = Postgres>,
|
||||
) -> Result<(), sqlx::Error> {
|
||||
sqlx::query("SELECT mq_delete(ARRAY[$1])")
|
||||
.bind(self.id)
|
||||
.execute(executor)
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
pub async fn complete_with_transaction(
|
||||
&mut self,
|
||||
mut tx: sqlx::Transaction<'_, Postgres>,
|
||||
) -> Result<(), sqlx::Error> {
|
||||
self.delete(&mut tx).await?;
|
||||
tx.commit().await?;
|
||||
self.keep_alive = None;
|
||||
Ok(())
|
||||
}
|
||||
pub async fn complete(&mut self) -> Result<(), sqlx::Error> {
|
||||
self.delete(self.pool()).await?;
|
||||
self.keep_alive = None;
|
||||
Ok(())
|
||||
}
|
||||
pub async fn checkpoint_with_transaction(
|
||||
&mut self,
|
||||
mut tx: sqlx::Transaction<'_, Postgres>,
|
||||
checkpoint: &Checkpoint<'_>,
|
||||
) -> Result<(), sqlx::Error> {
|
||||
checkpoint.execute(self.id, &mut tx).await?;
|
||||
tx.commit().await?;
|
||||
Ok(())
|
||||
}
|
||||
pub async fn checkpoint(&mut self, checkpoint: &Checkpoint<'_>) -> Result<(), sqlx::Error> {
|
||||
checkpoint.execute(self.id, self.pool()).await?;
|
||||
Ok(())
|
||||
}
|
||||
pub async fn keep_alive(&mut self, duration: Duration) -> Result<(), sqlx::Error> {
|
||||
sqlx::query("SELECT mq_keep_alive(ARRAY[$1], $2)")
|
||||
.bind(self.id)
|
||||
.bind(duration)
|
||||
.execute(self.pool())
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
pub fn id(&self) -> Uuid {
|
||||
self.id
|
||||
}
|
||||
pub fn name(&self) -> &str {
|
||||
&self.name
|
||||
}
|
||||
pub fn json<'a, T: Deserialize<'a>>(&'a self) -> Result<Option<T>, serde_json::Error> {
|
||||
if let Some(payload_json) = &self.payload_json {
|
||||
serde_json::from_str(payload_json).map(Some)
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
pub fn raw_json(&self) -> Option<&str> {
|
||||
self.payload_json.as_deref()
|
||||
}
|
||||
pub fn raw_bytes(&self) -> Option<&[u8]> {
|
||||
self.payload_bytes.as_deref()
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for CurrentTask {
|
||||
fn drop(&mut self) {
|
||||
if self
|
||||
.task_runner
|
||||
.running_tasks
|
||||
.fetch_sub(1, Ordering::SeqCst)
|
||||
== self.task_runner.options.min_concurrency
|
||||
{
|
||||
self.task_runner.notify.notify_one();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl TaskRunnerOptions {
|
||||
pub fn new<F: Fn(CurrentTask) + Send + Sync + 'static>(pool: &Pool<Postgres>, f: F) -> Self {
|
||||
Self {
|
||||
min_concurrency: 16,
|
||||
max_concurrency: 32,
|
||||
channel_names: None,
|
||||
keep_alive: true,
|
||||
runner: Opaque(Arc::new(f)),
|
||||
pool: pool.clone(),
|
||||
}
|
||||
}
|
||||
pub async fn run(&self) -> Result<OwnedTask, sqlx::Error> {
|
||||
let options = self.clone();
|
||||
let task_runner = Arc::new(TaskRunner {
|
||||
options,
|
||||
running_tasks: AtomicUsize::new(0),
|
||||
notify: Notify::new(),
|
||||
});
|
||||
let listener_task = start_listener(task_runner.clone()).await?;
|
||||
Ok(OwnedTask(task::spawn(main_loop(
|
||||
task_runner,
|
||||
listener_task,
|
||||
))))
|
||||
}
|
||||
}
|
||||
|
||||
async fn start_listener(task_runner: Arc<TaskRunner>) -> Result<OwnedTask, sqlx::Error> {
|
||||
let mut listener = PgListener::connect_with(&task_runner.options.pool).await?;
|
||||
if let Some(channels) = &task_runner.options.channel_names {
|
||||
let names: Vec<String> = channels.iter().map(|c| format!("mq_{}", c)).collect();
|
||||
listener
|
||||
.listen_all(names.iter().map(|s| s.as_str()))
|
||||
.await?;
|
||||
} else {
|
||||
listener.listen("mq").await?;
|
||||
}
|
||||
Ok(OwnedTask(task::spawn(async move {
|
||||
while let Ok(_) = listener.recv().await {
|
||||
task_runner.notify.notify_one();
|
||||
}
|
||||
})))
|
||||
}
|
||||
|
||||
#[derive(sqlx::FromRow)]
|
||||
struct PolledMessage {
|
||||
id: Option<Uuid>,
|
||||
is_committed: Option<bool>,
|
||||
name: Option<String>,
|
||||
payload_json: Option<String>,
|
||||
payload_bytes: Option<Vec<u8>>,
|
||||
retry_backoff: Option<PgInterval>,
|
||||
wait_time: Option<PgInterval>,
|
||||
}
|
||||
|
||||
fn to_duration(interval: PgInterval) -> Duration {
|
||||
const SECONDS_PER_DAY: u64 = 24 * 60 * 60;
|
||||
if interval.microseconds < 0 || interval.days < 0 || interval.months < 0 {
|
||||
Duration::default()
|
||||
} else {
|
||||
let days = (interval.days as u64) + (interval.months as u64) * 30;
|
||||
Duration::from_micros(interval.microseconds as u64)
|
||||
+ Duration::from_secs(days * SECONDS_PER_DAY)
|
||||
}
|
||||
}
|
||||
|
||||
async fn poll_and_dispatch(
|
||||
task_runner: &Arc<TaskRunner>,
|
||||
batch_size: i32,
|
||||
) -> Result<Duration, sqlx::Error> {
|
||||
log::info!("Polling for messages");
|
||||
|
||||
let options = &task_runner.options;
|
||||
let messages = sqlx::query_as::<_, PolledMessage>("SELECT * FROM mq_poll($1, $2)")
|
||||
.bind(&options.channel_names)
|
||||
.bind(batch_size)
|
||||
.fetch_all(&options.pool)
|
||||
.await?;
|
||||
|
||||
let ids_to_delete: Vec<_> = messages
|
||||
.iter()
|
||||
.filter(|msg| msg.is_committed == Some(false))
|
||||
.filter_map(|msg| msg.id)
|
||||
.collect();
|
||||
|
||||
log::info!("Deleting {} messages", ids_to_delete.len());
|
||||
if !ids_to_delete.is_empty() {
|
||||
sqlx::query("SELECT mq_delete($1)")
|
||||
.bind(ids_to_delete)
|
||||
.execute(&options.pool)
|
||||
.await?;
|
||||
}
|
||||
|
||||
let wait_time = messages
|
||||
.iter()
|
||||
.filter_map(|msg| msg.wait_time.clone())
|
||||
.map(to_duration)
|
||||
.min()
|
||||
.unwrap_or(Duration::from_secs(60));
|
||||
|
||||
for msg in messages {
|
||||
if let PolledMessage {
|
||||
id: Some(id),
|
||||
is_committed: Some(true),
|
||||
name: Some(name),
|
||||
payload_json,
|
||||
payload_bytes,
|
||||
retry_backoff: Some(retry_backoff),
|
||||
..
|
||||
} = msg
|
||||
{
|
||||
let retry_backoff = to_duration(retry_backoff);
|
||||
let keep_alive = if options.keep_alive {
|
||||
Some(OwnedTask(task::spawn(keep_task_alive(
|
||||
id,
|
||||
options.pool.clone(),
|
||||
retry_backoff,
|
||||
))))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let current_task = CurrentTask {
|
||||
id,
|
||||
name,
|
||||
payload_json,
|
||||
payload_bytes,
|
||||
task_runner: task_runner.clone(),
|
||||
keep_alive,
|
||||
};
|
||||
task_runner.running_tasks.fetch_add(1, Ordering::SeqCst);
|
||||
(options.runner)(current_task);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(wait_time)
|
||||
}
|
||||
|
||||
async fn main_loop(task_runner: Arc<TaskRunner>, _listener_task: OwnedTask) {
|
||||
let options = &task_runner.options;
|
||||
let mut failures = 0;
|
||||
loop {
|
||||
let running_tasks = task_runner.running_tasks.load(Ordering::SeqCst);
|
||||
let duration = if running_tasks < options.min_concurrency {
|
||||
let batch_size = (options.max_concurrency - running_tasks) as i32;
|
||||
|
||||
match poll_and_dispatch(&task_runner, batch_size).await {
|
||||
Ok(duration) => {
|
||||
failures = 0;
|
||||
duration
|
||||
}
|
||||
Err(e) => {
|
||||
failures += 1;
|
||||
log::error!("Failed to poll for messages: {}", e);
|
||||
Duration::from_millis(50 << failures)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
Duration::from_secs(60)
|
||||
};
|
||||
|
||||
// Wait for us to be notified, or for the timeout to elapse
|
||||
let _ = tokio::time::timeout(duration, task_runner.notify.notified()).await;
|
||||
}
|
||||
}
|
||||
|
||||
async fn keep_task_alive(id: Uuid, pool: Pool<Postgres>, mut interval: Duration) {
|
||||
loop {
|
||||
tokio::time::sleep(interval / 2).await;
|
||||
interval *= 2;
|
||||
if let Err(e) = sqlx::query("SELECT mq_keep_alive(ARRAY[$1], $2)")
|
||||
.bind(id)
|
||||
.bind(interval)
|
||||
.execute(&pool)
|
||||
.await
|
||||
{
|
||||
log::error!("Failed to keep task {} alive: {}", id, e);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
106
src/spawn.rs
Normal file
106
src/spawn.rs
Normal file
|
@ -0,0 +1,106 @@
|
|||
use std::borrow::Cow;
|
||||
use std::fmt::Debug;
|
||||
use std::time::Duration;
|
||||
|
||||
use serde::Serialize;
|
||||
use sqlx::Postgres;
|
||||
use uuid::Uuid;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct TaskBuilder<'a> {
|
||||
id: Uuid,
|
||||
delay: Duration,
|
||||
channel_name: &'a str,
|
||||
channel_args: &'a str,
|
||||
retries: usize,
|
||||
retry_backoff: Duration,
|
||||
commit_interval: Option<Duration>,
|
||||
ordered: bool,
|
||||
name: &'a str,
|
||||
payload_json: Option<Cow<'a, str>>,
|
||||
payload_bytes: Option<&'a [u8]>,
|
||||
}
|
||||
|
||||
impl<'a> TaskBuilder<'a> {
|
||||
pub fn new(name: &'a str) -> Self {
|
||||
Self::new_with_id(Uuid::new_v4(), name)
|
||||
}
|
||||
pub fn new_with_id(id: Uuid, name: &'a str) -> Self {
|
||||
Self {
|
||||
id,
|
||||
delay: Duration::from_secs(0),
|
||||
channel_name: "",
|
||||
channel_args: "",
|
||||
retries: 4,
|
||||
retry_backoff: Duration::from_secs(1),
|
||||
commit_interval: None,
|
||||
ordered: false,
|
||||
name,
|
||||
payload_json: None,
|
||||
payload_bytes: None,
|
||||
}
|
||||
}
|
||||
pub fn set_channel_name(&mut self, channel_name: &'a str) -> &mut Self {
|
||||
self.channel_name = channel_name;
|
||||
self
|
||||
}
|
||||
pub fn set_channel_args(&mut self, channel_args: &'a str) -> &mut Self {
|
||||
self.channel_args = channel_args;
|
||||
self
|
||||
}
|
||||
pub fn set_retries(&mut self, retries: usize) -> &mut Self {
|
||||
self.retries = retries;
|
||||
self
|
||||
}
|
||||
pub fn set_retry_backoff(&mut self, retry_backoff: Duration) -> &mut Self {
|
||||
self.retry_backoff = retry_backoff;
|
||||
self
|
||||
}
|
||||
pub fn set_commit_interval(&mut self, commit_interval: Option<Duration>) -> &mut Self {
|
||||
self.commit_interval = commit_interval;
|
||||
self
|
||||
}
|
||||
pub fn set_ordered(&mut self, ordered: bool) -> &mut Self {
|
||||
self.ordered = ordered;
|
||||
self
|
||||
}
|
||||
pub fn set_delay(&mut self, delay: Duration) -> &mut Self {
|
||||
self.delay = delay;
|
||||
self
|
||||
}
|
||||
pub fn set_raw_json(&mut self, raw_json: &'a str) -> &mut Self {
|
||||
self.payload_json = Some(Cow::Borrowed(raw_json));
|
||||
self
|
||||
}
|
||||
pub fn set_raw_bytes(&mut self, raw_bytes: &'a [u8]) -> &mut Self {
|
||||
self.payload_bytes = Some(raw_bytes);
|
||||
self
|
||||
}
|
||||
pub fn set_json<T: Serialize>(&mut self, value: &T) -> Result<&mut Self, serde_json::Error> {
|
||||
let value = serde_json::to_string(value)?;
|
||||
self.payload_json = Some(Cow::Owned(value));
|
||||
Ok(self)
|
||||
}
|
||||
pub async fn spawn<'b, E: sqlx::Executor<'b, Database = Postgres>>(
|
||||
&self,
|
||||
executor: E,
|
||||
) -> Result<Uuid, sqlx::Error> {
|
||||
sqlx::query(
|
||||
"SELECT mq_insert(ARRAY[($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)::mq_new_t])",
|
||||
)
|
||||
.bind(self.id)
|
||||
.bind(self.delay)
|
||||
.bind(self.retries as i32)
|
||||
.bind(self.retry_backoff)
|
||||
.bind(self.channel_name)
|
||||
.bind(self.channel_args)
|
||||
.bind(self.commit_interval)
|
||||
.bind(self.ordered)
|
||||
.bind(self.name)
|
||||
.bind(self.payload_json.as_deref())
|
||||
.bind(self.payload_bytes)
|
||||
.execute(executor)
|
||||
.await?;
|
||||
Ok(self.id)
|
||||
}
|
||||
}
|
37
src/utils.rs
Normal file
37
src/utils.rs
Normal file
|
@ -0,0 +1,37 @@
|
|||
use std::any::Any;
|
||||
use std::fmt::{self, Debug};
|
||||
use std::ops::{Deref, DerefMut};
|
||||
|
||||
use tokio::task::JoinHandle;
|
||||
|
||||
#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
|
||||
pub struct Opaque<T: Any>(pub T);
|
||||
|
||||
impl<T: Any> Debug for Opaque<T> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
<dyn Any>::fmt(&self.0, f)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Any> Deref for Opaque<T> {
|
||||
type Target = T;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.0
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Any> DerefMut for Opaque<T> {
|
||||
fn deref_mut(&mut self) -> &mut Self::Target {
|
||||
&mut self.0
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct OwnedTask(pub JoinHandle<()>);
|
||||
|
||||
impl Drop for OwnedTask {
|
||||
fn drop(&mut self) {
|
||||
self.0.abort();
|
||||
}
|
||||
}
|
Loading…
Reference in a new issue