Add support for passing custom context via a JobRegistry

This commit is contained in:
Diggory Blake 2021-07-17 18:45:06 +01:00
parent 1cc2262cf8
commit fc4e909d24
No known key found for this signature in database
GPG key ID: E6BDFA83146ABD40
10 changed files with 263 additions and 144 deletions

View file

@ -52,6 +52,8 @@ jobs:
test:
name: Test
runs-on: ubuntu-latest
env:
RUST_BACKTRACE: "1"
steps:
- uses: actions/checkout@v2
- uses: actions-rs/toolchain@v1
@ -67,10 +69,13 @@ jobs:
- uses: actions-rs/cargo@v1
with:
command: test
args: -- --nocapture
test_nightly:
name: Test (Nightly)
runs-on: ubuntu-latest
env:
RUST_BACKTRACE: "1"
steps:
- uses: actions/checkout@v2
- uses: actions-rs/toolchain@v1
@ -86,3 +91,4 @@ jobs:
- uses: actions-rs/cargo@v1
with:
command: test
args: -- --nocapture

4
.vscode/settings.json vendored Normal file
View file

@ -0,0 +1,4 @@
{
"rust-analyzer.checkOnSave.allFeatures": false,
"rust-analyzer.cargo.allFeatures": false
}

View file

@ -1,6 +1,6 @@
[package]
name = "sqlxmq"
version = "0.1.2"
version = "0.2.0"
authors = ["Diggory Blake <diggsey@googlemail.com>"]
edition = "2018"
license = "MIT OR Apache-2.0"
@ -16,14 +16,15 @@ members = ["sqlxmq_macros", "sqlxmq_stress"]
[dependencies]
sqlx = { version = "0.5.2", features = ["postgres", "chrono", "uuid"] }
tokio = { version = "1.4.0", features = ["full"] }
tokio = { version = "=1.8.0", features = ["full"] }
dotenv = "0.15.0"
chrono = "0.4.19"
uuid = { version = "0.8.2", features = ["v4"] }
log = "0.4.14"
serde_json = "1.0.64"
serde = "1.0.124"
sqlxmq_macros = { version = "0.1", path = "sqlxmq_macros" }
sqlxmq_macros = { version = "0.2.0", path = "sqlxmq_macros" }
anymap2 = "0.13.0"
[features]
default = ["runtime-tokio-native-tls"]

View file

@ -127,13 +127,17 @@ use sqlxmq::{job, CurrentJob};
// Arguments to the `#[job]` attribute allow setting default job options.
#[job(channel_name = "foo")]
async fn example_job(
// The first argument should always be the current job.
mut current_job: CurrentJob,
// Additional arguments are optional, but can be used to access context
// provided via `JobRegistry::set_context`.
message: &'static str,
) -> sqlx::Result<()> {
// Decode a JSON payload
let who: Option<String> = current_job.json()?;
// Do some work
println!("Hello, {}!", who.as_deref().unwrap_or("world"));
println!("{}, {}!", message, who.as_deref().unwrap_or("world"));
// Mark the job as complete
current_job.complete().await?;
@ -160,6 +164,9 @@ async fn main() -> Result<(), Box<dyn Error>> {
// Here is where you can configure the registry
// registry.set_error_handler(...)
// And add context
registry.set_context("Hello");
let runner = registry
// Create a job runner using the connection pool.
.runner(&pool)

View file

@ -1,6 +1,6 @@
[package]
name = "sqlxmq_macros"
version = "0.1.2"
version = "0.2.0"
authors = ["Diggory Blake <diggsey@googlemail.com>"]
edition = "2018"
license = "MIT OR Apache-2.0"

View file

@ -92,8 +92,12 @@ fn interpret_job_arg(options: &mut JobOptions, arg: NestedMeta) -> Result<()> {
/// Marks a function as being a background job.
///
/// The function must take a single `CurrentJob` argument, and should
/// be async or return a future.
/// The first argument to the function must have type `CurrentJob`.
/// Additional arguments can be used to access context from the job
/// registry. Context is accessed based on the type of the argument.
/// Context arguments must be `Send + Sync + Clone + 'static`.
///
/// The function should be async or return a future.
///
/// The async result must be a `Result<(), E>` type, where `E` is convertible
/// to a `Box<dyn Error + Send + Sync + 'static>`, which is the case for most
@ -103,7 +107,7 @@ fn interpret_job_arg(options: &mut JobOptions, arg: NestedMeta) -> Result<()> {
///
/// # Name
///
/// ```
/// ```ignore
/// #[job("example")]
/// #[job(name="example")]
/// ```
@ -115,7 +119,7 @@ fn interpret_job_arg(options: &mut JobOptions, arg: NestedMeta) -> Result<()> {
///
/// # Channel name
///
/// ```
/// ```ignore
/// #[job(channel_name="foo")]
/// ```
///
@ -123,7 +127,7 @@ fn interpret_job_arg(options: &mut JobOptions, arg: NestedMeta) -> Result<()> {
///
/// # Retries
///
/// ```
/// ```ignore
/// #[job(retries = 3)]
/// ```
///
@ -131,7 +135,7 @@ fn interpret_job_arg(options: &mut JobOptions, arg: NestedMeta) -> Result<()> {
///
/// # Retry backoff
///
/// ```
/// ```ignore
/// #[job(backoff_secs=1.5)]
/// #[job(backoff_secs=2)]
/// ```
@ -140,7 +144,7 @@ fn interpret_job_arg(options: &mut JobOptions, arg: NestedMeta) -> Result<()> {
///
/// # Ordered
///
/// ```
/// ```ignore
/// #[job(ordered)]
/// #[job(ordered=true)]
/// #[job(ordered=false)]
@ -150,7 +154,7 @@ fn interpret_job_arg(options: &mut JobOptions, arg: NestedMeta) -> Result<()> {
///
/// # Prototype
///
/// ```
/// ```ignore
/// fn my_proto<'a, 'b>(
/// builder: &'a mut JobBuilder<'b>
/// ) -> &'a mut JobBuilder<'b> {
@ -170,7 +174,7 @@ fn interpret_job_arg(options: &mut JobOptions, arg: NestedMeta) -> Result<()> {
/// prototype will always be applied first so that explicit options can override it.
/// Each option can only be provided once in the attribute.
///
/// ```
/// ```ignore
/// #[job("my_job", proto(my_proto), retries=0, ordered)]
/// ```
///
@ -223,6 +227,18 @@ pub fn job(attr: TokenStream, item: TokenStream) -> TokenStream {
});
}
let extract_ctx: Vec<_> = inner_fn
.sig
.inputs
.iter()
.skip(1)
.map(|_| {
quote! {
registry.context()
}
})
.collect();
let expanded = quote! {
#(#errors)*
#[allow(non_upper_case_globals)]
@ -234,7 +250,7 @@ pub fn job(attr: TokenStream, item: TokenStream) -> TokenStream {
builder #(#chain)*
}),
sqlxmq::hidden::RunFn(|registry, current_job| {
registry.spawn_internal(#fq_name, inner(current_job));
registry.spawn_internal(#fq_name, inner(current_job #(, #extract_ctx)*));
}),
)
};

View file

@ -126,13 +126,17 @@
//! // Arguments to the `#[job]` attribute allow setting default job options.
//! #[job(channel_name = "foo")]
//! async fn example_job(
//! // The first argument should always be the current job.
//! mut current_job: CurrentJob,
//! // Additional arguments are optional, but can be used to access context
//! // provided via [`JobRegistry::set_context`].
//! message: &'static str,
//! ) -> Result<(), Box<dyn Error + Send + Sync + 'static>> {
//! // Decode a JSON payload
//! let who: Option<String> = current_job.json()?;
//!
//! // Do some work
//! println!("Hello, {}!", who.as_deref().unwrap_or("world"));
//! println!("{}, {}!", message, who.as_deref().unwrap_or("world"));
//!
//! // Mark the job as complete
//! current_job.complete().await?;
@ -172,6 +176,9 @@
//! // Here is where you can configure the registry
//! // registry.set_error_handler(...)
//!
//! // And add context
//! registry.set_context("Hello");
//!
//! let runner = registry
//! // Create a job runner using the connection pool.
//! .runner(&pool)
@ -321,12 +328,22 @@ mod tests {
Ok(())
}
#[job]
async fn example_job_with_ctx(
mut current_job: CurrentJob,
ctx1: i32,
ctx2: &'static str,
) -> Result<(), Box<dyn Error + Send + Sync + 'static>> {
assert_eq!(ctx1, 42);
assert_eq!(ctx2, "Hello, world!");
current_job.complete().await?;
Ok(())
}
async fn named_job_runner(pool: &Pool<Postgres>) -> OwnedHandle {
JobRegistry::new(&[example_job1, example_job2])
.runner(pool)
.run()
.await
.unwrap()
let mut registry = JobRegistry::new(&[example_job1, example_job2, example_job_with_ctx]);
registry.set_context(42).set_context("Hello, world!");
registry.runner(pool).run().await.unwrap()
}
async fn pause() {
@ -339,6 +356,7 @@ mod tests {
#[tokio::test]
async fn it_can_spawn_job() {
{
let pool = &*test_pool().await;
let (_runner, counter) =
test_job_runner(&pool, |mut job| async move { job.complete().await }).await;
@ -348,9 +366,12 @@ mod tests {
pause().await;
assert_eq!(counter.load(Ordering::SeqCst), 1);
}
pause().await;
}
#[tokio::test]
async fn it_runs_jobs_in_order() {
{
let pool = &*test_pool().await;
let (tx, mut rx) = mpsc::unbounded();
@ -383,9 +404,12 @@ mod tests {
pause().await;
assert_eq!(counter.load(Ordering::SeqCst), 2);
}
pause().await;
}
#[tokio::test]
async fn it_runs_jobs_in_parallel() {
{
let pool = &*test_pool().await;
let (tx, mut rx) = mpsc::unbounded();
@ -409,9 +433,12 @@ mod tests {
job.complete().await.unwrap();
}
}
pause().await;
}
#[tokio::test]
async fn it_retries_failed_jobs() {
{
let pool = &*test_pool().await;
let (_runner, counter) = test_job_runner(&pool, move |_| async {}).await;
@ -443,9 +470,12 @@ mod tests {
pause_ms(backoff * 5).await;
assert_eq!(counter.load(Ordering::SeqCst), 3);
}
pause().await;
}
#[tokio::test]
async fn it_can_checkpoint_jobs() {
{
let pool = &*test_pool().await;
let (_runner, counter) = test_job_runner(&pool, move |mut current_job| async move {
let state: bool = current_job.json().unwrap().unwrap();
@ -485,14 +515,20 @@ mod tests {
pause_ms(backoff * 3).await;
assert_eq!(counter.load(Ordering::SeqCst), 2);
}
pause().await;
}
#[tokio::test]
async fn it_can_use_registry() {
{
let pool = &*test_pool().await;
let _runner = named_job_runner(pool).await;
example_job1.builder().spawn(pool).await.unwrap();
example_job2.builder().spawn(pool).await.unwrap();
example_job_with_ctx.builder().spawn(pool).await.unwrap();
pause().await;
}
pause().await;
}
}

View file

@ -1,9 +1,12 @@
use std::any::type_name;
use std::collections::HashMap;
use std::error::Error;
use std::fmt::Display;
use std::future::Future;
use std::sync::Arc;
use anymap2::any::CloneAnySendSync;
use anymap2::Map;
use sqlx::{Pool, Postgres};
use uuid::Uuid;
@ -16,6 +19,7 @@ use crate::{JobBuilder, JobRunnerOptions};
pub struct JobRegistry {
error_handler: Arc<dyn Fn(&str, Box<dyn Error + Send + 'static>) + Send + Sync>,
job_map: HashMap<&'static str, &'static NamedJob>,
context: Map<dyn CloneAnySendSync + Send + Sync>,
}
/// Error returned when a job is received whose name is not in the registry.
@ -41,6 +45,7 @@ impl JobRegistry {
Self {
error_handler: Arc::new(Self::default_error_handler),
job_map,
context: Map::new(),
}
}
@ -53,6 +58,24 @@ impl JobRegistry {
self
}
/// Provide context for the jobs.
pub fn set_context<C: Clone + Send + Sync + 'static>(&mut self, context: C) -> &mut Self {
self.context.insert(context);
self
}
/// Access job context. Will panic if context with this type has not been provided.
pub fn context<C: Clone + Send + Sync + 'static>(&self) -> C {
if let Some(c) = self.context.get::<C>() {
c.clone()
} else {
panic!(
"No context of type `{}` has been provided.",
type_name::<C>()
);
}
}
/// Look-up a job by name.
pub fn resolve_job(&self, name: &str) -> Option<&'static NamedJob> {
self.job_map.get(name).copied()

View file

@ -123,6 +123,13 @@ impl CurrentJob {
.await?;
Ok(())
}
async fn stop_keep_alive(&mut self) {
if let Some(keep_alive) = self.keep_alive.take() {
keep_alive.stop().await;
}
}
/// Complete this job and commit the provided transaction at the same time.
/// If the transaction cannot be committed, the job will not be completed.
pub async fn complete_with_transaction(
@ -131,13 +138,13 @@ impl CurrentJob {
) -> Result<(), sqlx::Error> {
self.delete(&mut tx).await?;
tx.commit().await?;
self.keep_alive = None;
self.stop_keep_alive().await;
Ok(())
}
/// Complete this job.
pub async fn complete(&mut self) -> Result<(), sqlx::Error> {
self.delete(self.pool()).await?;
self.keep_alive = None;
self.stop_keep_alive().await;
Ok(())
}
/// Checkpoint this job and commit the provided transaction at the same time.
@ -254,7 +261,7 @@ impl JobRunnerOptions {
notify: Notify::new(),
});
let listener_task = start_listener(job_runner.clone()).await?;
Ok(OwnedHandle(task::spawn(main_loop(
Ok(OwnedHandle::new(task::spawn(main_loop(
job_runner,
listener_task,
))))
@ -271,7 +278,7 @@ async fn start_listener(job_runner: Arc<JobRunner>) -> Result<OwnedHandle, sqlx:
} else {
listener.listen("mq").await?;
}
Ok(OwnedHandle(task::spawn(async move {
Ok(OwnedHandle::new(task::spawn(async move {
let mut num_errors = 0;
loop {
if num_errors > 0 || listener.recv().await.is_ok() {
@ -356,7 +363,7 @@ async fn poll_and_dispatch(
{
let retry_backoff = to_duration(retry_backoff);
let keep_alive = if options.keep_alive {
Some(OwnedHandle(task::spawn(keep_job_alive(
Some(OwnedHandle::new(task::spawn(keep_job_alive(
id,
options.pool.clone(),
retry_backoff,

View file

@ -31,10 +31,29 @@ impl<T: Any> DerefMut for Opaque<T> {
/// the handle is dropped. Extract the inner join handle to prevent this
/// behaviour.
#[derive(Debug)]
pub struct OwnedHandle(pub JoinHandle<()>);
pub struct OwnedHandle(Option<JoinHandle<()>>);
impl OwnedHandle {
/// Construct a new `OwnedHandle` from the provided `JoinHandle`
pub fn new(inner: JoinHandle<()>) -> Self {
Self(Some(inner))
}
/// Get back the original `JoinHandle`
pub fn into_inner(mut self) -> JoinHandle<()> {
self.0.take().expect("Only consumed once")
}
/// Stop the task and wait for it to finish.
pub async fn stop(self) {
let handle = self.into_inner();
handle.abort();
let _ = handle.await;
}
}
impl Drop for OwnedHandle {
fn drop(&mut self) {
self.0.abort();
if let Some(handle) = self.0.take() {
handle.abort();
}
}
}