Initial commit

Diggory Blake 2021-03-28 02:57:17 +01:00
commit 6681fe87de
9 changed files with 1022 additions and 0 deletions

1
.env Normal file

@@ -0,0 +1 @@
DATABASE_URL=postgres://postgres:password@localhost/sqlxmq

2
.gitignore vendored Normal file

@@ -0,0 +1,2 @@
/target
Cargo.lock

27
Cargo.toml Normal file

@@ -0,0 +1,27 @@
[package]
name = "sqlxmq"
version = "0.1.0"
authors = ["Diggory Blake <diggsey@googlemail.com>"]
edition = "2018"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
sqlx = { version = "0.5.1", features = [
"postgres",
"runtime-async-std-native-tls",
"chrono",
"uuid",
] }
tokio = { version = "1.3.0", features = ["full"] }
dotenv = "0.15.0"
futures = "0.3.13"
chrono = "0.4.19"
uuid = { version = "0.8.2", features = ["v4"] }
log = "0.4.14"
serde_json = "1.0.64"
serde = "1.0.124"
[dev-dependencies]
dotenv = "0.15.0"
pretty_env_logger = "0.4.0"


@@ -0,0 +1,12 @@
DROP FUNCTION mq_checkpoint;
DROP FUNCTION mq_keep_alive;
DROP FUNCTION mq_delete;
DROP FUNCTION mq_commit;
DROP FUNCTION mq_insert;
DROP FUNCTION mq_poll;
DROP FUNCTION mq_active_channels;
DROP FUNCTION mq_latest_message;
DROP TABLE mq_payloads;
DROP TABLE mq_msgs;
DROP FUNCTION mq_uuid_exists;
DROP TYPE mq_new_t;


@@ -0,0 +1,286 @@
-- The UDT for creating messages
CREATE TYPE mq_new_t AS (
-- Unique message ID
id UUID,
-- Delay before message is processed
delay INTERVAL,
-- Number of retries if initial processing fails
retries INT,
-- Initial backoff between retries
retry_backoff INTERVAL,
-- Name of channel
channel_name TEXT,
-- Arguments to channel
channel_args TEXT,
-- Interval for two-phase commit (or NULL to disable two-phase commit)
commit_interval INTERVAL,
-- Whether this message should be processed in order with respect to other
-- ordered messages.
ordered BOOLEAN,
-- Name of message
name TEXT,
-- JSON payload
payload_json TEXT,
-- Binary payload
payload_bytes BYTEA
);
-- Small, frequently updated table of messages
CREATE TABLE mq_msgs (
id UUID PRIMARY KEY,
created_at TIMESTAMPTZ DEFAULT NOW(),
attempt_at TIMESTAMPTZ DEFAULT NOW(),
attempts INT NOT NULL DEFAULT 5,
retry_backoff INTERVAL NOT NULL DEFAULT INTERVAL '1 second',
channel_name TEXT NOT NULL,
channel_args TEXT NOT NULL,
commit_interval INTERVAL,
after_message_id UUID DEFAULT uuid_nil() REFERENCES mq_msgs(id) ON DELETE SET DEFAULT
);
-- Insert dummy message so that the 'nil' UUID can be referenced
INSERT INTO mq_msgs (id, channel_name, channel_args, after_message_id) VALUES (uuid_nil(), '', '', NULL);
-- Internal helper function to check that a UUID is neither NULL nor NIL
CREATE FUNCTION mq_uuid_exists(
id UUID
) RETURNS BOOLEAN AS $$
SELECT id IS NOT NULL AND id != uuid_nil()
$$ LANGUAGE SQL IMMUTABLE;
-- Index for polling
CREATE INDEX ON mq_msgs(channel_name, channel_args, attempt_at) WHERE id != uuid_nil() AND NOT mq_uuid_exists(after_message_id);
-- Index for adding messages
CREATE INDEX ON mq_msgs(channel_name, channel_args, created_at, id) WHERE id != uuid_nil() AND after_message_id IS NOT NULL;
-- Index for ensuring strict message order
CREATE UNIQUE INDEX ON mq_msgs(channel_name, channel_args, after_message_id);
-- Large, less frequently updated table of message payloads
CREATE TABLE mq_payloads(
id UUID PRIMARY KEY,
name TEXT NOT NULL,
payload_json JSONB,
payload_bytes BYTEA
);
-- Internal helper function to return the most recently added message in a queue.
CREATE FUNCTION mq_latest_message(from_channel_name TEXT, from_channel_args TEXT)
RETURNS UUID AS $$
SELECT COALESCE(
(
SELECT id FROM mq_msgs
WHERE channel_name = from_channel_name
AND channel_args = from_channel_args
AND after_message_id IS NOT NULL
AND id != uuid_nil()
ORDER BY created_at DESC, id DESC
LIMIT 1
),
uuid_nil()
)
$$ LANGUAGE SQL STABLE;
-- Internal helper function to randomly select a set of channels with "ready" messages.
CREATE FUNCTION mq_active_channels(channel_names TEXT[], batch_size INT)
RETURNS TABLE(name TEXT, args TEXT) AS $$
SELECT channel_name, channel_args
FROM mq_msgs
WHERE id != uuid_nil()
AND attempt_at <= NOW()
AND (channel_names IS NULL OR channel_name = ANY(channel_names))
AND NOT mq_uuid_exists(after_message_id)
GROUP BY channel_name, channel_args
ORDER BY RANDOM()
LIMIT batch_size
$$ LANGUAGE SQL STABLE;
-- Main entry-point for task runner: pulls a batch of messages from the queue.
CREATE FUNCTION mq_poll(channel_names TEXT[], batch_size INT DEFAULT 1)
RETURNS TABLE(
id UUID,
is_committed BOOLEAN,
name TEXT,
payload_json TEXT,
payload_bytes BYTEA,
retry_backoff INTERVAL,
wait_time INTERVAL
) AS $$
BEGIN
RETURN QUERY UPDATE mq_msgs
SET
attempt_at = CASE WHEN mq_msgs.attempts = 1 THEN NULL ELSE NOW() + mq_msgs.retry_backoff END,
attempts = mq_msgs.attempts - 1,
retry_backoff = mq_msgs.retry_backoff * 2
FROM (
SELECT
msgs.id
FROM mq_active_channels(channel_names, batch_size) AS active_channels
INNER JOIN LATERAL (
SELECT * FROM mq_msgs
WHERE mq_msgs.id != uuid_nil()
AND mq_msgs.attempt_at <= NOW()
AND mq_msgs.channel_name = active_channels.name
AND mq_msgs.channel_args = active_channels.args
AND NOT mq_uuid_exists(mq_msgs.after_message_id)
ORDER BY mq_msgs.attempt_at ASC
LIMIT batch_size
) AS msgs ON TRUE
LIMIT batch_size
) AS messages_to_update
LEFT JOIN mq_payloads ON mq_payloads.id = messages_to_update.id
WHERE mq_msgs.id = messages_to_update.id
RETURNING
mq_msgs.id,
mq_msgs.commit_interval IS NULL,
mq_payloads.name,
mq_payloads.payload_json::TEXT,
mq_payloads.payload_bytes,
mq_msgs.retry_backoff / 2,
interval '0' AS wait_time;
IF NOT FOUND THEN
RETURN QUERY SELECT
NULL::UUID,
NULL::BOOLEAN,
NULL::TEXT,
NULL::TEXT,
NULL::BYTEA,
NULL::INTERVAL,
MIN(mq_msgs.attempt_at) - NOW()
FROM mq_msgs
WHERE mq_msgs.id != uuid_nil()
AND NOT mq_uuid_exists(mq_msgs.after_message_id)
AND (channel_names IS NULL OR mq_msgs.channel_name = ANY(channel_names));
END IF;
END;
$$ LANGUAGE plpgsql;
-- Creates new messages
CREATE FUNCTION mq_insert(new_messages mq_new_t[])
RETURNS VOID AS $$
BEGIN
PERFORM pg_notify(CONCAT('mq_', channel_name), '')
FROM unnest(new_messages) AS new_msgs
GROUP BY channel_name;
IF FOUND THEN
PERFORM pg_notify('mq', '');
END IF;
INSERT INTO mq_payloads (
id,
name,
payload_json,
payload_bytes
) SELECT
id,
name,
payload_json::JSONB,
payload_bytes
FROM UNNEST(new_messages);
INSERT INTO mq_msgs (
id,
attempt_at,
attempts,
retry_backoff,
channel_name,
channel_args,
commit_interval,
after_message_id
)
SELECT
id,
NOW() + delay + COALESCE(commit_interval, INTERVAL '0'),
retries + 1,
retry_backoff,
channel_name,
channel_args,
commit_interval,
CASE WHEN ordered
THEN
LAG(id, 1, mq_latest_message(channel_name, channel_args))
OVER (PARTITION BY channel_name, channel_args, ordered ORDER BY id)
ELSE
NULL
END
FROM UNNEST(new_messages);
END;
$$ LANGUAGE plpgsql;
-- Commits messages previously created with a non-NULL commit interval.
CREATE FUNCTION mq_commit(msg_ids UUID[])
RETURNS VOID AS $$
BEGIN
UPDATE mq_msgs
SET
attempt_at = attempt_at - commit_interval,
commit_interval = NULL
WHERE id = ANY(msg_ids)
AND commit_interval IS NOT NULL;
END;
$$ LANGUAGE plpgsql;
-- Deletes messages from the queue. This occurs when a message has been
-- processed, or when it expires without being processed.
CREATE FUNCTION mq_delete(msg_ids UUID[])
RETURNS VOID AS $$
BEGIN
PERFORM pg_notify(CONCAT('mq_', channel_name), '')
FROM mq_msgs
WHERE id = ANY(msg_ids)
AND after_message_id = uuid_nil()
GROUP BY channel_name;
IF FOUND THEN
PERFORM pg_notify('mq', '');
END IF;
DELETE FROM mq_msgs WHERE id = ANY(msg_ids);
DELETE FROM mq_payloads WHERE id = ANY(msg_ids);
END;
$$ LANGUAGE plpgsql;
-- Can be called during the initial commit interval, or when processing
-- a message. Indicates that the caller is still active and will prevent either
-- the commit interval elapsing or the message being retried for the specified
-- interval.
CREATE FUNCTION mq_keep_alive(msg_ids UUID[], duration INTERVAL)
RETURNS VOID AS $$
UPDATE mq_msgs
SET
attempt_at = NOW() + duration,
commit_interval = commit_interval + ((NOW() + duration) - attempt_at)
WHERE id = ANY(msg_ids)
AND attempt_at < NOW() + duration;
$$ LANGUAGE SQL;
-- Called during lengthy processing of a message to checkpoint the progress.
-- As well as behaving like `mq_keep_alive`, the message payload can be
-- updated.
CREATE FUNCTION mq_checkpoint(
msg_id UUID,
duration INTERVAL,
new_payload_json TEXT,
new_payload_bytes BYTEA,
extra_retries INT
)
RETURNS VOID AS $$
UPDATE mq_msgs
SET
attempt_at = GREATEST(attempt_at, NOW() + duration),
attempts = attempts + COALESCE(extra_retries, 0)
WHERE id = msg_id;
UPDATE mq_payloads
SET
payload_json = COALESCE(new_payload_json::JSONB, payload_json),
payload_bytes = COALESCE(new_payload_bytes, payload_bytes)
WHERE
id = msg_id;
$$ LANGUAGE SQL;
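
A hypothetical smoke test, sketched here for illustration only (it is not part of this commit): it drives the SQL API above directly through sqlx, assuming DATABASE_URL points at a database with this migration applied and an otherwise empty queue. The function name `smoke_test` and the payload are made up.

// Hypothetical sketch (not part of the commit): enqueue one message via
// mq_insert, claim it back with mq_poll, and acknowledge it with mq_delete.
use sqlx::{Pool, Postgres};
use uuid::Uuid;

async fn smoke_test(pool: &Pool<Postgres>) -> Result<(), sqlx::Error> {
    let id = Uuid::new_v4();

    // The row literal must match the field order of mq_new_t:
    // (id, delay, retries, retry_backoff, channel_name, channel_args,
    //  commit_interval, ordered, name, payload_json, payload_bytes).
    sqlx::query(
        "SELECT mq_insert(ARRAY[($1, INTERVAL '0', 4, INTERVAL '1 second', \
         '', '', NULL, FALSE, 'smoke_test', '{\"answer\": 42}', NULL)::mq_new_t])",
    )
    .bind(id)
    .execute(pool)
    .await?;

    // Poll with a NULL channel filter ("any channel"); assuming no other
    // pending messages, the row returned is the one inserted above.
    let row: (Option<Uuid>,) = sqlx::query_as("SELECT id FROM mq_poll(NULL, 1)")
        .fetch_one(pool)
        .await?;
    assert_eq!(row.0, Some(id));

    // Acknowledge the message so it is not retried after the backoff elapses.
    sqlx::query("SELECT mq_delete(ARRAY[$1])")
        .bind(id)
        .execute(pool)
        .await?;

    Ok(())
}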

195
src/lib.rs Normal file

@@ -0,0 +1,195 @@
mod runner;
mod spawn;
mod utils;
pub use runner::*;
pub use spawn::*;
pub use utils::OwnedTask;
#[cfg(test)]
mod tests {
use super::*;
use std::env;
use std::future::Future;
use std::ops::Deref;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, Once};
use std::time::Duration;
use futures::channel::mpsc;
use futures::StreamExt;
use sqlx::{Pool, Postgres};
use tokio::sync::{Mutex, MutexGuard};
use tokio::task;
struct TestGuard<T>(MutexGuard<'static, ()>, T);
impl<T> Deref for TestGuard<T> {
type Target = T;
fn deref(&self) -> &T {
&self.1
}
}
async fn test_pool() -> TestGuard<Pool<Postgres>> {
static INIT_LOGGER: Once = Once::new();
static TEST_MUTEX: Mutex<()> = Mutex::const_new(());
let guard = TEST_MUTEX.lock().await;
let _ = dotenv::dotenv();
INIT_LOGGER.call_once(|| pretty_env_logger::init());
let pool = Pool::connect(&env::var("DATABASE_URL").unwrap())
.await
.unwrap();
sqlx::query("TRUNCATE TABLE mq_payloads")
.execute(&pool)
.await
.unwrap();
sqlx::query("DELETE FROM mq_msgs WHERE id != uuid_nil()")
.execute(&pool)
.await
.unwrap();
TestGuard(guard, pool)
}
async fn test_task_runner<F: Future + Send + 'static>(
pool: &Pool<Postgres>,
f: impl (Fn(CurrentTask) -> F) + Send + Sync + 'static,
) -> (OwnedTask, Arc<AtomicUsize>)
where
F::Output: Send + 'static,
{
let counter = Arc::new(AtomicUsize::new(0));
let counter2 = counter.clone();
let runner = TaskRunnerOptions::new(pool, move |task| {
counter2.fetch_add(1, Ordering::SeqCst);
task::spawn(f(task));
})
.run()
.await
.unwrap();
(runner, counter)
}
async fn pause() {
pause_ms(50).await;
}
async fn pause_ms(ms: u64) {
tokio::time::sleep(Duration::from_millis(ms)).await;
}
#[tokio::test]
async fn it_can_spawn_task() {
let pool = &*test_pool().await;
let (_runner, counter) =
test_task_runner(&pool, |mut task| async move { task.complete().await }).await;
assert_eq!(counter.load(Ordering::SeqCst), 0);
TaskBuilder::new("foo").spawn(pool).await.unwrap();
pause().await;
assert_eq!(counter.load(Ordering::SeqCst), 1);
}
#[tokio::test]
async fn it_runs_tasks_in_order() {
let pool = &*test_pool().await;
let (tx, mut rx) = mpsc::unbounded();
let (_runner, counter) = test_task_runner(&pool, move |task| {
let tx = tx.clone();
async move {
tx.unbounded_send(task).unwrap();
}
})
.await;
assert_eq!(counter.load(Ordering::SeqCst), 0);
TaskBuilder::new("foo")
.set_ordered(true)
.spawn(pool)
.await
.unwrap();
TaskBuilder::new("bar")
.set_ordered(true)
.spawn(pool)
.await
.unwrap();
pause().await;
assert_eq!(counter.load(Ordering::SeqCst), 1);
let mut task = rx.next().await.unwrap();
task.complete().await.unwrap();
pause().await;
assert_eq!(counter.load(Ordering::SeqCst), 2);
}
#[tokio::test]
async fn it_runs_tasks_in_parallel() {
let pool = &*test_pool().await;
let (tx, mut rx) = mpsc::unbounded();
let (_runner, counter) = test_task_runner(&pool, move |task| {
let tx = tx.clone();
async move {
tx.unbounded_send(task).unwrap();
}
})
.await;
assert_eq!(counter.load(Ordering::SeqCst), 0);
TaskBuilder::new("foo").spawn(pool).await.unwrap();
TaskBuilder::new("bar").spawn(pool).await.unwrap();
pause().await;
assert_eq!(counter.load(Ordering::SeqCst), 2);
for _ in 0..2 {
let mut task = rx.next().await.unwrap();
task.complete().await.unwrap();
}
}
#[tokio::test]
async fn it_retries_failed_tasks() {
let pool = &*test_pool().await;
let (_runner, counter) = test_task_runner(&pool, move |_| async {}).await;
let backoff = 100;
assert_eq!(counter.load(Ordering::SeqCst), 0);
TaskBuilder::new("foo")
.set_retry_backoff(Duration::from_millis(backoff))
.set_retries(2)
.spawn(pool)
.await
.unwrap();
// First attempt
pause().await;
assert_eq!(counter.load(Ordering::SeqCst), 1);
// Second attempt
pause_ms(backoff).await;
pause().await;
assert_eq!(counter.load(Ordering::SeqCst), 2);
// Third attempt
pause_ms(backoff * 2).await;
pause().await;
assert_eq!(counter.load(Ordering::SeqCst), 3);
// No more attempts
pause_ms(backoff * 5).await;
assert_eq!(counter.load(Ordering::SeqCst), 3);
}
}

356
src/runner.rs Normal file

@@ -0,0 +1,356 @@
use std::borrow::Cow;
use std::fmt::Debug;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use std::time::Duration;
use serde::{Deserialize, Serialize};
use sqlx::postgres::types::PgInterval;
use sqlx::postgres::PgListener;
use sqlx::{Pool, Postgres};
use tokio::sync::Notify;
use tokio::task;
use uuid::Uuid;
use crate::utils::{Opaque, OwnedTask};
#[derive(Debug, Clone)]
pub struct TaskRunnerOptions {
min_concurrency: usize,
max_concurrency: usize,
channel_names: Option<Vec<String>>,
runner: Opaque<Arc<dyn Fn(CurrentTask) + Send + Sync + 'static>>,
pool: Pool<Postgres>,
keep_alive: bool,
}
#[derive(Debug)]
struct TaskRunner {
options: TaskRunnerOptions,
running_tasks: AtomicUsize,
notify: Notify,
}
#[derive(Debug, Clone)]
pub struct Checkpoint<'a> {
duration: Duration,
extra_retries: usize,
payload_json: Option<Cow<'a, str>>,
payload_bytes: Option<&'a [u8]>,
}
impl<'a> Checkpoint<'a> {
pub fn new(duration: Duration) -> Self {
Self {
duration,
extra_retries: 0,
payload_json: None,
payload_bytes: None,
}
}
pub fn set_extra_retries(&mut self, extra_retries: usize) -> &mut Self {
self.extra_retries = extra_retries;
self
}
pub fn set_raw_json(&mut self, raw_json: &'a str) -> &mut Self {
self.payload_json = Some(Cow::Borrowed(raw_json));
self
}
pub fn set_raw_bytes(&mut self, raw_bytes: &'a [u8]) -> &mut Self {
self.payload_bytes = Some(raw_bytes);
self
}
pub fn set_json<T: Serialize>(&mut self, value: &T) -> Result<&mut Self, serde_json::Error> {
let value = serde_json::to_string(value)?;
self.payload_json = Some(Cow::Owned(value));
Ok(self)
}
async fn execute<'b, E: sqlx::Executor<'b, Database = Postgres>>(
&self,
task_id: Uuid,
executor: E,
) -> Result<(), sqlx::Error> {
sqlx::query("SELECT mq_checkpoint($1, $2, $3, $4, $5)")
.bind(task_id)
.bind(self.duration)
.bind(self.payload_json.as_deref())
.bind(self.payload_bytes)
.bind(self.extra_retries as i32)
.execute(executor)
.await?;
Ok(())
}
}
#[derive(Debug)]
pub struct CurrentTask {
id: Uuid,
name: String,
payload_json: Option<String>,
payload_bytes: Option<Vec<u8>>,
task_runner: Arc<TaskRunner>,
keep_alive: Option<OwnedTask>,
}
impl CurrentTask {
pub fn pool(&self) -> &Pool<Postgres> {
&self.task_runner.options.pool
}
async fn delete(
&self,
executor: impl sqlx::Executor<'_, Database = Postgres>,
) -> Result<(), sqlx::Error> {
sqlx::query("SELECT mq_delete(ARRAY[$1])")
.bind(self.id)
.execute(executor)
.await?;
Ok(())
}
pub async fn complete_with_transaction(
&mut self,
mut tx: sqlx::Transaction<'_, Postgres>,
) -> Result<(), sqlx::Error> {
self.delete(&mut tx).await?;
tx.commit().await?;
self.keep_alive = None;
Ok(())
}
pub async fn complete(&mut self) -> Result<(), sqlx::Error> {
self.delete(self.pool()).await?;
self.keep_alive = None;
Ok(())
}
pub async fn checkpoint_with_transaction(
&mut self,
mut tx: sqlx::Transaction<'_, Postgres>,
checkpoint: &Checkpoint<'_>,
) -> Result<(), sqlx::Error> {
checkpoint.execute(self.id, &mut tx).await?;
tx.commit().await?;
Ok(())
}
pub async fn checkpoint(&mut self, checkpoint: &Checkpoint<'_>) -> Result<(), sqlx::Error> {
checkpoint.execute(self.id, self.pool()).await?;
Ok(())
}
pub async fn keep_alive(&mut self, duration: Duration) -> Result<(), sqlx::Error> {
sqlx::query("SELECT mq_keep_alive(ARRAY[$1], $2)")
.bind(self.id)
.bind(duration)
.execute(self.pool())
.await?;
Ok(())
}
pub fn id(&self) -> Uuid {
self.id
}
pub fn name(&self) -> &str {
&self.name
}
pub fn json<'a, T: Deserialize<'a>>(&'a self) -> Result<Option<T>, serde_json::Error> {
if let Some(payload_json) = &self.payload_json {
serde_json::from_str(payload_json).map(Some)
} else {
Ok(None)
}
}
pub fn raw_json(&self) -> Option<&str> {
self.payload_json.as_deref()
}
pub fn raw_bytes(&self) -> Option<&[u8]> {
self.payload_bytes.as_deref()
}
}
impl Drop for CurrentTask {
fn drop(&mut self) {
if self
.task_runner
.running_tasks
.fetch_sub(1, Ordering::SeqCst)
== self.task_runner.options.min_concurrency
{
self.task_runner.notify.notify_one();
}
}
}
impl TaskRunnerOptions {
pub fn new<F: Fn(CurrentTask) + Send + Sync + 'static>(pool: &Pool<Postgres>, f: F) -> Self {
Self {
min_concurrency: 16,
max_concurrency: 32,
channel_names: None,
keep_alive: true,
runner: Opaque(Arc::new(f)),
pool: pool.clone(),
}
}
pub async fn run(&self) -> Result<OwnedTask, sqlx::Error> {
let options = self.clone();
let task_runner = Arc::new(TaskRunner {
options,
running_tasks: AtomicUsize::new(0),
notify: Notify::new(),
});
let listener_task = start_listener(task_runner.clone()).await?;
Ok(OwnedTask(task::spawn(main_loop(
task_runner,
listener_task,
))))
}
}
async fn start_listener(task_runner: Arc<TaskRunner>) -> Result<OwnedTask, sqlx::Error> {
let mut listener = PgListener::connect_with(&task_runner.options.pool).await?;
if let Some(channels) = &task_runner.options.channel_names {
let names: Vec<String> = channels.iter().map(|c| format!("mq_{}", c)).collect();
listener
.listen_all(names.iter().map(|s| s.as_str()))
.await?;
} else {
listener.listen("mq").await?;
}
Ok(OwnedTask(task::spawn(async move {
while let Ok(_) = listener.recv().await {
task_runner.notify.notify_one();
}
})))
}
#[derive(sqlx::FromRow)]
struct PolledMessage {
id: Option<Uuid>,
is_committed: Option<bool>,
name: Option<String>,
payload_json: Option<String>,
payload_bytes: Option<Vec<u8>>,
retry_backoff: Option<PgInterval>,
wait_time: Option<PgInterval>,
}
fn to_duration(interval: PgInterval) -> Duration {
const SECONDS_PER_DAY: u64 = 24 * 60 * 60;
if interval.microseconds < 0 || interval.days < 0 || interval.months < 0 {
Duration::default()
} else {
let days = (interval.days as u64) + (interval.months as u64) * 30;
Duration::from_micros(interval.microseconds as u64)
+ Duration::from_secs(days * SECONDS_PER_DAY)
}
}
async fn poll_and_dispatch(
task_runner: &Arc<TaskRunner>,
batch_size: i32,
) -> Result<Duration, sqlx::Error> {
log::info!("Polling for messages");
let options = &task_runner.options;
let messages = sqlx::query_as::<_, PolledMessage>("SELECT * FROM mq_poll($1, $2)")
.bind(&options.channel_names)
.bind(batch_size)
.fetch_all(&options.pool)
.await?;
let ids_to_delete: Vec<_> = messages
.iter()
.filter(|msg| msg.is_committed == Some(false))
.filter_map(|msg| msg.id)
.collect();
log::info!("Deleting {} messages", ids_to_delete.len());
if !ids_to_delete.is_empty() {
sqlx::query("SELECT mq_delete($1)")
.bind(ids_to_delete)
.execute(&options.pool)
.await?;
}
let wait_time = messages
.iter()
.filter_map(|msg| msg.wait_time.clone())
.map(to_duration)
.min()
.unwrap_or(Duration::from_secs(60));
for msg in messages {
if let PolledMessage {
id: Some(id),
is_committed: Some(true),
name: Some(name),
payload_json,
payload_bytes,
retry_backoff: Some(retry_backoff),
..
} = msg
{
let retry_backoff = to_duration(retry_backoff);
let keep_alive = if options.keep_alive {
Some(OwnedTask(task::spawn(keep_task_alive(
id,
options.pool.clone(),
retry_backoff,
))))
} else {
None
};
let current_task = CurrentTask {
id,
name,
payload_json,
payload_bytes,
task_runner: task_runner.clone(),
keep_alive,
};
task_runner.running_tasks.fetch_add(1, Ordering::SeqCst);
(options.runner)(current_task);
}
}
Ok(wait_time)
}
async fn main_loop(task_runner: Arc<TaskRunner>, _listener_task: OwnedTask) {
let options = &task_runner.options;
let mut failures = 0;
loop {
let running_tasks = task_runner.running_tasks.load(Ordering::SeqCst);
let duration = if running_tasks < options.min_concurrency {
let batch_size = (options.max_concurrency - running_tasks) as i32;
match poll_and_dispatch(&task_runner, batch_size).await {
Ok(duration) => {
failures = 0;
duration
}
Err(e) => {
failures += 1;
log::error!("Failed to poll for messages: {}", e);
Duration::from_millis(50 << failures)
}
}
} else {
Duration::from_secs(60)
};
// Wait for us to be notified, or for the timeout to elapse
let _ = tokio::time::timeout(duration, task_runner.notify.notified()).await;
}
}
async fn keep_task_alive(id: Uuid, pool: Pool<Postgres>, mut interval: Duration) {
loop {
tokio::time::sleep(interval / 2).await;
interval *= 2;
if let Err(e) = sqlx::query("SELECT mq_keep_alive(ARRAY[$1], $2)")
.bind(id)
.bind(interval)
.execute(&pool)
.await
{
log::error!("Failed to keep task {} alive: {}", id, e);
break;
}
}
}
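
The Checkpoint and keep-alive machinery above is easiest to see from inside a handler. A hypothetical sketch (not part of this commit), assuming the crate is consumed as `sqlxmq` with the re-exports from lib.rs; the `Progress` type and the work loop are made up:

// Hypothetical sketch: a long-running handler that periodically checkpoints,
// persisting its progress and pushing back the retry deadline as it works.
use std::time::Duration;

use serde::{Deserialize, Serialize};
use sqlxmq::{Checkpoint, CurrentTask};

#[derive(Serialize, Deserialize)]
struct Progress {
    items_done: usize,
}

async fn long_task(mut task: CurrentTask) -> Result<(), Box<dyn std::error::Error>> {
    // Resume from the last checkpointed payload, if any.
    let mut progress: Progress = task.json()?.unwrap_or(Progress { items_done: 0 });

    while progress.items_done < 100 {
        // ... perform one unit of work ...
        progress.items_done += 1;

        // Re-arm the retry timer and persist the updated payload via mq_checkpoint.
        let mut checkpoint = Checkpoint::new(Duration::from_secs(30));
        checkpoint.set_json(&progress)?;
        task.checkpoint(&checkpoint).await?;
    }

    // Delete the message so it is not retried.
    task.complete().await?;
    Ok(())
}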

106
src/spawn.rs Normal file

@@ -0,0 +1,106 @@
use std::borrow::Cow;
use std::fmt::Debug;
use std::time::Duration;
use serde::Serialize;
use sqlx::Postgres;
use uuid::Uuid;
#[derive(Debug, Clone)]
pub struct TaskBuilder<'a> {
id: Uuid,
delay: Duration,
channel_name: &'a str,
channel_args: &'a str,
retries: usize,
retry_backoff: Duration,
commit_interval: Option<Duration>,
ordered: bool,
name: &'a str,
payload_json: Option<Cow<'a, str>>,
payload_bytes: Option<&'a [u8]>,
}
impl<'a> TaskBuilder<'a> {
pub fn new(name: &'a str) -> Self {
Self::new_with_id(Uuid::new_v4(), name)
}
pub fn new_with_id(id: Uuid, name: &'a str) -> Self {
Self {
id,
delay: Duration::from_secs(0),
channel_name: "",
channel_args: "",
retries: 4,
retry_backoff: Duration::from_secs(1),
commit_interval: None,
ordered: false,
name,
payload_json: None,
payload_bytes: None,
}
}
pub fn set_channel_name(&mut self, channel_name: &'a str) -> &mut Self {
self.channel_name = channel_name;
self
}
pub fn set_channel_args(&mut self, channel_args: &'a str) -> &mut Self {
self.channel_args = channel_args;
self
}
pub fn set_retries(&mut self, retries: usize) -> &mut Self {
self.retries = retries;
self
}
pub fn set_retry_backoff(&mut self, retry_backoff: Duration) -> &mut Self {
self.retry_backoff = retry_backoff;
self
}
pub fn set_commit_interval(&mut self, commit_interval: Option<Duration>) -> &mut Self {
self.commit_interval = commit_interval;
self
}
pub fn set_ordered(&mut self, ordered: bool) -> &mut Self {
self.ordered = ordered;
self
}
pub fn set_delay(&mut self, delay: Duration) -> &mut Self {
self.delay = delay;
self
}
pub fn set_raw_json(&mut self, raw_json: &'a str) -> &mut Self {
self.payload_json = Some(Cow::Borrowed(raw_json));
self
}
pub fn set_raw_bytes(&mut self, raw_bytes: &'a [u8]) -> &mut Self {
self.payload_bytes = Some(raw_bytes);
self
}
pub fn set_json<T: Serialize>(&mut self, value: &T) -> Result<&mut Self, serde_json::Error> {
let value = serde_json::to_string(value)?;
self.payload_json = Some(Cow::Owned(value));
Ok(self)
}
pub async fn spawn<'b, E: sqlx::Executor<'b, Database = Postgres>>(
&self,
executor: E,
) -> Result<Uuid, sqlx::Error> {
sqlx::query(
"SELECT mq_insert(ARRAY[($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)::mq_new_t])",
)
.bind(self.id)
.bind(self.delay)
.bind(self.retries as i32)
.bind(self.retry_backoff)
.bind(self.channel_name)
.bind(self.channel_args)
.bind(self.commit_interval)
.bind(self.ordered)
.bind(self.name)
.bind(self.payload_json.as_deref())
.bind(self.payload_bytes)
.execute(executor)
.await?;
Ok(self.id)
}
}
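
Because spawn is generic over any Postgres executor, a task can be enqueued atomically with other application writes. A hypothetical sketch (not part of this commit), assuming the crate is consumed as `sqlxmq`; the task name, channel, and payload are made up:

// Hypothetical sketch: spawn a task inside a transaction so the message only
// becomes visible to the runner if the surrounding transaction commits.
use std::time::Duration;

use sqlx::{Pool, Postgres};
use sqlxmq::TaskBuilder;

async fn enqueue_welcome_email(
    pool: &Pool<Postgres>,
    user_id: i64,
) -> Result<(), Box<dyn std::error::Error>> {
    let mut tx = pool.begin().await?;

    // ... insert the user row and other application data here ...

    TaskBuilder::new("send_welcome_email")
        .set_channel_name("email")
        .set_retries(5)
        .set_retry_backoff(Duration::from_secs(10))
        .set_json(&serde_json::json!({ "user_id": user_id }))?
        .spawn(&mut tx)
        .await?;

    tx.commit().await?;
    Ok(())
}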

37
src/utils.rs Normal file

@@ -0,0 +1,37 @@
use std::any::Any;
use std::fmt::{self, Debug};
use std::ops::{Deref, DerefMut};
use tokio::task::JoinHandle;
#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Opaque<T: Any>(pub T);
impl<T: Any> Debug for Opaque<T> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
<dyn Any>::fmt(&self.0, f)
}
}
impl<T: Any> Deref for Opaque<T> {
type Target = T;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl<T: Any> DerefMut for Opaque<T> {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.0
}
}
#[derive(Debug)]
pub struct OwnedTask(pub JoinHandle<()>);
impl Drop for OwnedTask {
fn drop(&mut self) {
self.0.abort();
}
}
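
Opaque exists so that structs holding values without a Debug impl (such as the `runner` closure in TaskRunnerOptions) can still derive Debug: the wrapper forwards to the standard Debug impl for `dyn Any` instead of requiring `T: Debug`, while Deref keeps the wrapped value usable. A hypothetical sketch (not part of this commit), written as if it lived inside the crate since Opaque is not re-exported from lib.rs; `Holder` and the closure are made up:

// Hypothetical sketch: deriving Debug on a struct that stores a closure,
// by wrapping the closure in Opaque. Assumes crate-internal access to Opaque.
use std::sync::Arc;

use crate::utils::Opaque;

#[derive(Debug)]
struct Holder {
    callback: Opaque<Arc<dyn Fn() + Send + Sync>>,
}

fn demo() {
    let holder = Holder {
        callback: Opaque(Arc::new(|| println!("called"))),
    };
    // Debug works even though closures have no Debug impl of their own.
    println!("{:?}", holder);
    // Deref coercion exposes the wrapped Arc without unwrapping the field.
    let _inner: &Arc<dyn Fn() + Send + Sync> = &holder.callback;
}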