From fc4e909d241bd6b5f588472ae91a7f20e809d2ec Mon Sep 17 00:00:00 2001 From: Diggory Blake Date: Sat, 17 Jul 2021 18:45:06 +0100 Subject: [PATCH] Add support for passing custom context via a JobRegistry --- .github/workflows/toolchain.yml | 6 + .vscode/settings.json | 4 + Cargo.toml | 7 +- README.md | 9 +- sqlxmq_macros/Cargo.toml | 2 +- sqlxmq_macros/src/lib.rs | 36 ++-- src/lib.rs | 280 ++++++++++++++++++-------------- src/registry.rs | 23 +++ src/runner.rs | 17 +- src/utils.rs | 23 ++- 10 files changed, 263 insertions(+), 144 deletions(-) create mode 100644 .vscode/settings.json diff --git a/.github/workflows/toolchain.yml b/.github/workflows/toolchain.yml index 6c5c7ff..f198e7f 100644 --- a/.github/workflows/toolchain.yml +++ b/.github/workflows/toolchain.yml @@ -52,6 +52,8 @@ jobs: test: name: Test runs-on: ubuntu-latest + env: + RUST_BACKTRACE: "1" steps: - uses: actions/checkout@v2 - uses: actions-rs/toolchain@v1 @@ -67,10 +69,13 @@ jobs: - uses: actions-rs/cargo@v1 with: command: test + args: -- --nocapture test_nightly: name: Test (Nightly) runs-on: ubuntu-latest + env: + RUST_BACKTRACE: "1" steps: - uses: actions/checkout@v2 - uses: actions-rs/toolchain@v1 @@ -86,3 +91,4 @@ jobs: - uses: actions-rs/cargo@v1 with: command: test + args: -- --nocapture diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 0000000..50ddaa6 --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,4 @@ +{ + "rust-analyzer.checkOnSave.allFeatures": false, + "rust-analyzer.cargo.allFeatures": false +} \ No newline at end of file diff --git a/Cargo.toml b/Cargo.toml index abe6b22..bedf1ac 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "sqlxmq" -version = "0.1.2" +version = "0.2.0" authors = ["Diggory Blake "] edition = "2018" license = "MIT OR Apache-2.0" @@ -16,14 +16,15 @@ members = ["sqlxmq_macros", "sqlxmq_stress"] [dependencies] sqlx = { version = "0.5.2", features = ["postgres", "chrono", "uuid"] } -tokio = { version = "1.4.0", features = ["full"] } +tokio = { version = "=1.8.0", features = ["full"] } dotenv = "0.15.0" chrono = "0.4.19" uuid = { version = "0.8.2", features = ["v4"] } log = "0.4.14" serde_json = "1.0.64" serde = "1.0.124" -sqlxmq_macros = { version = "0.1", path = "sqlxmq_macros" } +sqlxmq_macros = { version = "0.2.0", path = "sqlxmq_macros" } +anymap2 = "0.13.0" [features] default = ["runtime-tokio-native-tls"] diff --git a/README.md b/README.md index cf7d70d..2bc7d55 100644 --- a/README.md +++ b/README.md @@ -127,13 +127,17 @@ use sqlxmq::{job, CurrentJob}; // Arguments to the `#[job]` attribute allow setting default job options. #[job(channel_name = "foo")] async fn example_job( + // The first argument should always be the current job. mut current_job: CurrentJob, + // Additional arguments are optional, but can be used to access context + // provided via `JobRegistry::set_context`. + message: &'static str, ) -> sqlx::Result<()> { // Decode a JSON payload let who: Option = current_job.json()?; // Do some work - println!("Hello, {}!", who.as_deref().unwrap_or("world")); + println!("{}, {}!", message, who.as_deref().unwrap_or("world")); // Mark the job as complete current_job.complete().await?; @@ -160,6 +164,9 @@ async fn main() -> Result<(), Box> { // Here is where you can configure the registry // registry.set_error_handler(...) + // And add context + registry.set_context("Hello"); + let runner = registry // Create a job runner using the connection pool. .runner(&pool) diff --git a/sqlxmq_macros/Cargo.toml b/sqlxmq_macros/Cargo.toml index 5a079d0..6db7c48 100644 --- a/sqlxmq_macros/Cargo.toml +++ b/sqlxmq_macros/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "sqlxmq_macros" -version = "0.1.2" +version = "0.2.0" authors = ["Diggory Blake "] edition = "2018" license = "MIT OR Apache-2.0" diff --git a/sqlxmq_macros/src/lib.rs b/sqlxmq_macros/src/lib.rs index 5f0af87..6a55e4c 100644 --- a/sqlxmq_macros/src/lib.rs +++ b/sqlxmq_macros/src/lib.rs @@ -92,8 +92,12 @@ fn interpret_job_arg(options: &mut JobOptions, arg: NestedMeta) -> Result<()> { /// Marks a function as being a background job. /// -/// The function must take a single `CurrentJob` argument, and should -/// be async or return a future. +/// The first argument to the function must have type `CurrentJob`. +/// Additional arguments can be used to access context from the job +/// registry. Context is accessed based on the type of the argument. +/// Context arguments must be `Send + Sync + Clone + 'static`. +/// +/// The function should be async or return a future. /// /// The async result must be a `Result<(), E>` type, where `E` is convertible /// to a `Box`, which is the case for most @@ -103,7 +107,7 @@ fn interpret_job_arg(options: &mut JobOptions, arg: NestedMeta) -> Result<()> { /// /// # Name /// -/// ``` +/// ```ignore /// #[job("example")] /// #[job(name="example")] /// ``` @@ -115,7 +119,7 @@ fn interpret_job_arg(options: &mut JobOptions, arg: NestedMeta) -> Result<()> { /// /// # Channel name /// -/// ``` +/// ```ignore /// #[job(channel_name="foo")] /// ``` /// @@ -123,7 +127,7 @@ fn interpret_job_arg(options: &mut JobOptions, arg: NestedMeta) -> Result<()> { /// /// # Retries /// -/// ``` +/// ```ignore /// #[job(retries = 3)] /// ``` /// @@ -131,7 +135,7 @@ fn interpret_job_arg(options: &mut JobOptions, arg: NestedMeta) -> Result<()> { /// /// # Retry backoff /// -/// ``` +/// ```ignore /// #[job(backoff_secs=1.5)] /// #[job(backoff_secs=2)] /// ``` @@ -140,7 +144,7 @@ fn interpret_job_arg(options: &mut JobOptions, arg: NestedMeta) -> Result<()> { /// /// # Ordered /// -/// ``` +/// ```ignore /// #[job(ordered)] /// #[job(ordered=true)] /// #[job(ordered=false)] @@ -150,7 +154,7 @@ fn interpret_job_arg(options: &mut JobOptions, arg: NestedMeta) -> Result<()> { /// /// # Prototype /// -/// ``` +/// ```ignore /// fn my_proto<'a, 'b>( /// builder: &'a mut JobBuilder<'b> /// ) -> &'a mut JobBuilder<'b> { @@ -170,7 +174,7 @@ fn interpret_job_arg(options: &mut JobOptions, arg: NestedMeta) -> Result<()> { /// prototype will always be applied first so that explicit options can override it. /// Each option can only be provided once in the attribute. /// -/// ``` +/// ```ignore /// #[job("my_job", proto(my_proto), retries=0, ordered)] /// ``` /// @@ -223,6 +227,18 @@ pub fn job(attr: TokenStream, item: TokenStream) -> TokenStream { }); } + let extract_ctx: Vec<_> = inner_fn + .sig + .inputs + .iter() + .skip(1) + .map(|_| { + quote! { + registry.context() + } + }) + .collect(); + let expanded = quote! { #(#errors)* #[allow(non_upper_case_globals)] @@ -234,7 +250,7 @@ pub fn job(attr: TokenStream, item: TokenStream) -> TokenStream { builder #(#chain)* }), sqlxmq::hidden::RunFn(|registry, current_job| { - registry.spawn_internal(#fq_name, inner(current_job)); + registry.spawn_internal(#fq_name, inner(current_job #(, #extract_ctx)*)); }), ) }; diff --git a/src/lib.rs b/src/lib.rs index 6309b9f..3867955 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -126,13 +126,17 @@ //! // Arguments to the `#[job]` attribute allow setting default job options. //! #[job(channel_name = "foo")] //! async fn example_job( +//! // The first argument should always be the current job. //! mut current_job: CurrentJob, +//! // Additional arguments are optional, but can be used to access context +//! // provided via [`JobRegistry::set_context`]. +//! message: &'static str, //! ) -> Result<(), Box> { //! // Decode a JSON payload //! let who: Option = current_job.json()?; //! //! // Do some work -//! println!("Hello, {}!", who.as_deref().unwrap_or("world")); +//! println!("{}, {}!", message, who.as_deref().unwrap_or("world")); //! //! // Mark the job as complete //! current_job.complete().await?; @@ -172,6 +176,9 @@ //! // Here is where you can configure the registry //! // registry.set_error_handler(...) //! +//! // And add context +//! registry.set_context("Hello"); +//! //! let runner = registry //! // Create a job runner using the connection pool. //! .runner(&pool) @@ -321,12 +328,22 @@ mod tests { Ok(()) } + #[job] + async fn example_job_with_ctx( + mut current_job: CurrentJob, + ctx1: i32, + ctx2: &'static str, + ) -> Result<(), Box> { + assert_eq!(ctx1, 42); + assert_eq!(ctx2, "Hello, world!"); + current_job.complete().await?; + Ok(()) + } + async fn named_job_runner(pool: &Pool) -> OwnedHandle { - JobRegistry::new(&[example_job1, example_job2]) - .runner(pool) - .run() - .await - .unwrap() + let mut registry = JobRegistry::new(&[example_job1, example_job2, example_job_with_ctx]); + registry.set_context(42).set_context("Hello, world!"); + registry.runner(pool).run().await.unwrap() } async fn pause() { @@ -339,160 +356,179 @@ mod tests { #[tokio::test] async fn it_can_spawn_job() { - let pool = &*test_pool().await; - let (_runner, counter) = - test_job_runner(&pool, |mut job| async move { job.complete().await }).await; + { + let pool = &*test_pool().await; + let (_runner, counter) = + test_job_runner(&pool, |mut job| async move { job.complete().await }).await; - assert_eq!(counter.load(Ordering::SeqCst), 0); - JobBuilder::new("foo").spawn(pool).await.unwrap(); + assert_eq!(counter.load(Ordering::SeqCst), 0); + JobBuilder::new("foo").spawn(pool).await.unwrap(); + pause().await; + assert_eq!(counter.load(Ordering::SeqCst), 1); + } pause().await; - assert_eq!(counter.load(Ordering::SeqCst), 1); } #[tokio::test] async fn it_runs_jobs_in_order() { - let pool = &*test_pool().await; - let (tx, mut rx) = mpsc::unbounded(); + { + let pool = &*test_pool().await; + let (tx, mut rx) = mpsc::unbounded(); - let (_runner, counter) = test_job_runner(&pool, move |job| { - let tx = tx.clone(); - async move { - tx.unbounded_send(job).unwrap(); - } - }) - .await; + let (_runner, counter) = test_job_runner(&pool, move |job| { + let tx = tx.clone(); + async move { + tx.unbounded_send(job).unwrap(); + } + }) + .await; - assert_eq!(counter.load(Ordering::SeqCst), 0); - JobBuilder::new("foo") - .set_ordered(true) - .spawn(pool) - .await - .unwrap(); - JobBuilder::new("bar") - .set_ordered(true) - .spawn(pool) - .await - .unwrap(); + assert_eq!(counter.load(Ordering::SeqCst), 0); + JobBuilder::new("foo") + .set_ordered(true) + .spawn(pool) + .await + .unwrap(); + JobBuilder::new("bar") + .set_ordered(true) + .spawn(pool) + .await + .unwrap(); + pause().await; + assert_eq!(counter.load(Ordering::SeqCst), 1); + + let mut job = rx.next().await.unwrap(); + job.complete().await.unwrap(); + + pause().await; + assert_eq!(counter.load(Ordering::SeqCst), 2); + } pause().await; - assert_eq!(counter.load(Ordering::SeqCst), 1); - - let mut job = rx.next().await.unwrap(); - job.complete().await.unwrap(); - - pause().await; - assert_eq!(counter.load(Ordering::SeqCst), 2); } #[tokio::test] async fn it_runs_jobs_in_parallel() { - let pool = &*test_pool().await; - let (tx, mut rx) = mpsc::unbounded(); + { + let pool = &*test_pool().await; + let (tx, mut rx) = mpsc::unbounded(); - let (_runner, counter) = test_job_runner(&pool, move |job| { - let tx = tx.clone(); - async move { - tx.unbounded_send(job).unwrap(); + let (_runner, counter) = test_job_runner(&pool, move |job| { + let tx = tx.clone(); + async move { + tx.unbounded_send(job).unwrap(); + } + }) + .await; + + assert_eq!(counter.load(Ordering::SeqCst), 0); + JobBuilder::new("foo").spawn(pool).await.unwrap(); + JobBuilder::new("bar").spawn(pool).await.unwrap(); + + pause().await; + assert_eq!(counter.load(Ordering::SeqCst), 2); + + for _ in 0..2 { + let mut job = rx.next().await.unwrap(); + job.complete().await.unwrap(); } - }) - .await; - - assert_eq!(counter.load(Ordering::SeqCst), 0); - JobBuilder::new("foo").spawn(pool).await.unwrap(); - JobBuilder::new("bar").spawn(pool).await.unwrap(); - - pause().await; - assert_eq!(counter.load(Ordering::SeqCst), 2); - - for _ in 0..2 { - let mut job = rx.next().await.unwrap(); - job.complete().await.unwrap(); } + pause().await; } #[tokio::test] async fn it_retries_failed_jobs() { - let pool = &*test_pool().await; - let (_runner, counter) = test_job_runner(&pool, move |_| async {}).await; + { + let pool = &*test_pool().await; + let (_runner, counter) = test_job_runner(&pool, move |_| async {}).await; - let backoff = 500; + let backoff = 500; - assert_eq!(counter.load(Ordering::SeqCst), 0); - JobBuilder::new("foo") - .set_retry_backoff(Duration::from_millis(backoff)) - .set_retries(2) - .spawn(pool) - .await - .unwrap(); + assert_eq!(counter.load(Ordering::SeqCst), 0); + JobBuilder::new("foo") + .set_retry_backoff(Duration::from_millis(backoff)) + .set_retries(2) + .spawn(pool) + .await + .unwrap(); - // First attempt + // First attempt + pause().await; + assert_eq!(counter.load(Ordering::SeqCst), 1); + + // Second attempt + pause_ms(backoff).await; + pause().await; + assert_eq!(counter.load(Ordering::SeqCst), 2); + + // Third attempt + pause_ms(backoff * 2).await; + pause().await; + assert_eq!(counter.load(Ordering::SeqCst), 3); + + // No more attempts + pause_ms(backoff * 5).await; + assert_eq!(counter.load(Ordering::SeqCst), 3); + } pause().await; - assert_eq!(counter.load(Ordering::SeqCst), 1); - - // Second attempt - pause_ms(backoff).await; - pause().await; - assert_eq!(counter.load(Ordering::SeqCst), 2); - - // Third attempt - pause_ms(backoff * 2).await; - pause().await; - assert_eq!(counter.load(Ordering::SeqCst), 3); - - // No more attempts - pause_ms(backoff * 5).await; - assert_eq!(counter.load(Ordering::SeqCst), 3); } #[tokio::test] async fn it_can_checkpoint_jobs() { - let pool = &*test_pool().await; - let (_runner, counter) = test_job_runner(&pool, move |mut current_job| async move { - let state: bool = current_job.json().unwrap().unwrap(); - if state { - current_job.complete().await.unwrap(); - } else { - current_job - .checkpoint(Checkpoint::new().set_json(&true).unwrap()) - .await - .unwrap(); - } - }) - .await; + { + let pool = &*test_pool().await; + let (_runner, counter) = test_job_runner(&pool, move |mut current_job| async move { + let state: bool = current_job.json().unwrap().unwrap(); + if state { + current_job.complete().await.unwrap(); + } else { + current_job + .checkpoint(Checkpoint::new().set_json(&true).unwrap()) + .await + .unwrap(); + } + }) + .await; - let backoff = 200; + let backoff = 200; - assert_eq!(counter.load(Ordering::SeqCst), 0); - JobBuilder::new("foo") - .set_retry_backoff(Duration::from_millis(backoff)) - .set_retries(5) - .set_json(&false) - .unwrap() - .spawn(pool) - .await - .unwrap(); + assert_eq!(counter.load(Ordering::SeqCst), 0); + JobBuilder::new("foo") + .set_retry_backoff(Duration::from_millis(backoff)) + .set_retries(5) + .set_json(&false) + .unwrap() + .spawn(pool) + .await + .unwrap(); - // First attempt + // First attempt + pause().await; + assert_eq!(counter.load(Ordering::SeqCst), 1); + + // Second attempt + pause_ms(backoff).await; + pause().await; + assert_eq!(counter.load(Ordering::SeqCst), 2); + + // No more attempts + pause_ms(backoff * 3).await; + assert_eq!(counter.load(Ordering::SeqCst), 2); + } pause().await; - assert_eq!(counter.load(Ordering::SeqCst), 1); - - // Second attempt - pause_ms(backoff).await; - pause().await; - assert_eq!(counter.load(Ordering::SeqCst), 2); - - // No more attempts - pause_ms(backoff * 3).await; - assert_eq!(counter.load(Ordering::SeqCst), 2); } #[tokio::test] async fn it_can_use_registry() { - let pool = &*test_pool().await; - let _runner = named_job_runner(pool).await; + { + let pool = &*test_pool().await; + let _runner = named_job_runner(pool).await; - example_job1.builder().spawn(pool).await.unwrap(); - example_job2.builder().spawn(pool).await.unwrap(); + example_job1.builder().spawn(pool).await.unwrap(); + example_job2.builder().spawn(pool).await.unwrap(); + example_job_with_ctx.builder().spawn(pool).await.unwrap(); + pause().await; + } pause().await; } } diff --git a/src/registry.rs b/src/registry.rs index 01d9aab..3fe6948 100644 --- a/src/registry.rs +++ b/src/registry.rs @@ -1,9 +1,12 @@ +use std::any::type_name; use std::collections::HashMap; use std::error::Error; use std::fmt::Display; use std::future::Future; use std::sync::Arc; +use anymap2::any::CloneAnySendSync; +use anymap2::Map; use sqlx::{Pool, Postgres}; use uuid::Uuid; @@ -16,6 +19,7 @@ use crate::{JobBuilder, JobRunnerOptions}; pub struct JobRegistry { error_handler: Arc) + Send + Sync>, job_map: HashMap<&'static str, &'static NamedJob>, + context: Map, } /// Error returned when a job is received whose name is not in the registry. @@ -41,6 +45,7 @@ impl JobRegistry { Self { error_handler: Arc::new(Self::default_error_handler), job_map, + context: Map::new(), } } @@ -53,6 +58,24 @@ impl JobRegistry { self } + /// Provide context for the jobs. + pub fn set_context(&mut self, context: C) -> &mut Self { + self.context.insert(context); + self + } + + /// Access job context. Will panic if context with this type has not been provided. + pub fn context(&self) -> C { + if let Some(c) = self.context.get::() { + c.clone() + } else { + panic!( + "No context of type `{}` has been provided.", + type_name::() + ); + } + } + /// Look-up a job by name. pub fn resolve_job(&self, name: &str) -> Option<&'static NamedJob> { self.job_map.get(name).copied() diff --git a/src/runner.rs b/src/runner.rs index ea00740..1442b84 100644 --- a/src/runner.rs +++ b/src/runner.rs @@ -123,6 +123,13 @@ impl CurrentJob { .await?; Ok(()) } + + async fn stop_keep_alive(&mut self) { + if let Some(keep_alive) = self.keep_alive.take() { + keep_alive.stop().await; + } + } + /// Complete this job and commit the provided transaction at the same time. /// If the transaction cannot be committed, the job will not be completed. pub async fn complete_with_transaction( @@ -131,13 +138,13 @@ impl CurrentJob { ) -> Result<(), sqlx::Error> { self.delete(&mut tx).await?; tx.commit().await?; - self.keep_alive = None; + self.stop_keep_alive().await; Ok(()) } /// Complete this job. pub async fn complete(&mut self) -> Result<(), sqlx::Error> { self.delete(self.pool()).await?; - self.keep_alive = None; + self.stop_keep_alive().await; Ok(()) } /// Checkpoint this job and commit the provided transaction at the same time. @@ -254,7 +261,7 @@ impl JobRunnerOptions { notify: Notify::new(), }); let listener_task = start_listener(job_runner.clone()).await?; - Ok(OwnedHandle(task::spawn(main_loop( + Ok(OwnedHandle::new(task::spawn(main_loop( job_runner, listener_task, )))) @@ -271,7 +278,7 @@ async fn start_listener(job_runner: Arc) -> Result 0 || listener.recv().await.is_ok() { @@ -356,7 +363,7 @@ async fn poll_and_dispatch( { let retry_backoff = to_duration(retry_backoff); let keep_alive = if options.keep_alive { - Some(OwnedHandle(task::spawn(keep_job_alive( + Some(OwnedHandle::new(task::spawn(keep_job_alive( id, options.pool.clone(), retry_backoff, diff --git a/src/utils.rs b/src/utils.rs index c656777..b23d659 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -31,10 +31,29 @@ impl DerefMut for Opaque { /// the handle is dropped. Extract the inner join handle to prevent this /// behaviour. #[derive(Debug)] -pub struct OwnedHandle(pub JoinHandle<()>); +pub struct OwnedHandle(Option>); + +impl OwnedHandle { + /// Construct a new `OwnedHandle` from the provided `JoinHandle` + pub fn new(inner: JoinHandle<()>) -> Self { + Self(Some(inner)) + } + /// Get back the original `JoinHandle` + pub fn into_inner(mut self) -> JoinHandle<()> { + self.0.take().expect("Only consumed once") + } + /// Stop the task and wait for it to finish. + pub async fn stop(self) { + let handle = self.into_inner(); + handle.abort(); + let _ = handle.await; + } +} impl Drop for OwnedHandle { fn drop(&mut self) { - self.0.abort(); + if let Some(handle) = self.0.take() { + handle.abort(); + } } }