mirror of
https://github.com/Diggsey/sqlxmq.git
synced 2024-11-24 17:21:00 +00:00
Add support for passing custom context via a JobRegistry
This commit is contained in:
parent
1cc2262cf8
commit
fc4e909d24
10 changed files with 263 additions and 144 deletions
6
.github/workflows/toolchain.yml
vendored
6
.github/workflows/toolchain.yml
vendored
|
@ -52,6 +52,8 @@ jobs:
|
|||
test:
|
||||
name: Test
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
RUST_BACKTRACE: "1"
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
- uses: actions-rs/toolchain@v1
|
||||
|
@ -67,10 +69,13 @@ jobs:
|
|||
- uses: actions-rs/cargo@v1
|
||||
with:
|
||||
command: test
|
||||
args: -- --nocapture
|
||||
|
||||
test_nightly:
|
||||
name: Test (Nightly)
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
RUST_BACKTRACE: "1"
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
- uses: actions-rs/toolchain@v1
|
||||
|
@ -86,3 +91,4 @@ jobs:
|
|||
- uses: actions-rs/cargo@v1
|
||||
with:
|
||||
command: test
|
||||
args: -- --nocapture
|
||||
|
|
4
.vscode/settings.json
vendored
Normal file
4
.vscode/settings.json
vendored
Normal file
|
@ -0,0 +1,4 @@
|
|||
{
|
||||
"rust-analyzer.checkOnSave.allFeatures": false,
|
||||
"rust-analyzer.cargo.allFeatures": false
|
||||
}
|
|
@ -1,6 +1,6 @@
|
|||
[package]
|
||||
name = "sqlxmq"
|
||||
version = "0.1.2"
|
||||
version = "0.2.0"
|
||||
authors = ["Diggory Blake <diggsey@googlemail.com>"]
|
||||
edition = "2018"
|
||||
license = "MIT OR Apache-2.0"
|
||||
|
@ -16,14 +16,15 @@ members = ["sqlxmq_macros", "sqlxmq_stress"]
|
|||
|
||||
[dependencies]
|
||||
sqlx = { version = "0.5.2", features = ["postgres", "chrono", "uuid"] }
|
||||
tokio = { version = "1.4.0", features = ["full"] }
|
||||
tokio = { version = "=1.8.0", features = ["full"] }
|
||||
dotenv = "0.15.0"
|
||||
chrono = "0.4.19"
|
||||
uuid = { version = "0.8.2", features = ["v4"] }
|
||||
log = "0.4.14"
|
||||
serde_json = "1.0.64"
|
||||
serde = "1.0.124"
|
||||
sqlxmq_macros = { version = "0.1", path = "sqlxmq_macros" }
|
||||
sqlxmq_macros = { version = "0.2.0", path = "sqlxmq_macros" }
|
||||
anymap2 = "0.13.0"
|
||||
|
||||
[features]
|
||||
default = ["runtime-tokio-native-tls"]
|
||||
|
|
|
@ -127,13 +127,17 @@ use sqlxmq::{job, CurrentJob};
|
|||
// Arguments to the `#[job]` attribute allow setting default job options.
|
||||
#[job(channel_name = "foo")]
|
||||
async fn example_job(
|
||||
// The first argument should always be the current job.
|
||||
mut current_job: CurrentJob,
|
||||
// Additional arguments are optional, but can be used to access context
|
||||
// provided via `JobRegistry::set_context`.
|
||||
message: &'static str,
|
||||
) -> sqlx::Result<()> {
|
||||
// Decode a JSON payload
|
||||
let who: Option<String> = current_job.json()?;
|
||||
|
||||
// Do some work
|
||||
println!("Hello, {}!", who.as_deref().unwrap_or("world"));
|
||||
println!("{}, {}!", message, who.as_deref().unwrap_or("world"));
|
||||
|
||||
// Mark the job as complete
|
||||
current_job.complete().await?;
|
||||
|
@ -160,6 +164,9 @@ async fn main() -> Result<(), Box<dyn Error>> {
|
|||
// Here is where you can configure the registry
|
||||
// registry.set_error_handler(...)
|
||||
|
||||
// And add context
|
||||
registry.set_context("Hello");
|
||||
|
||||
let runner = registry
|
||||
// Create a job runner using the connection pool.
|
||||
.runner(&pool)
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
[package]
|
||||
name = "sqlxmq_macros"
|
||||
version = "0.1.2"
|
||||
version = "0.2.0"
|
||||
authors = ["Diggory Blake <diggsey@googlemail.com>"]
|
||||
edition = "2018"
|
||||
license = "MIT OR Apache-2.0"
|
||||
|
|
|
@ -92,8 +92,12 @@ fn interpret_job_arg(options: &mut JobOptions, arg: NestedMeta) -> Result<()> {
|
|||
|
||||
/// Marks a function as being a background job.
|
||||
///
|
||||
/// The function must take a single `CurrentJob` argument, and should
|
||||
/// be async or return a future.
|
||||
/// The first argument to the function must have type `CurrentJob`.
|
||||
/// Additional arguments can be used to access context from the job
|
||||
/// registry. Context is accessed based on the type of the argument.
|
||||
/// Context arguments must be `Send + Sync + Clone + 'static`.
|
||||
///
|
||||
/// The function should be async or return a future.
|
||||
///
|
||||
/// The async result must be a `Result<(), E>` type, where `E` is convertible
|
||||
/// to a `Box<dyn Error + Send + Sync + 'static>`, which is the case for most
|
||||
|
@ -103,7 +107,7 @@ fn interpret_job_arg(options: &mut JobOptions, arg: NestedMeta) -> Result<()> {
|
|||
///
|
||||
/// # Name
|
||||
///
|
||||
/// ```
|
||||
/// ```ignore
|
||||
/// #[job("example")]
|
||||
/// #[job(name="example")]
|
||||
/// ```
|
||||
|
@ -115,7 +119,7 @@ fn interpret_job_arg(options: &mut JobOptions, arg: NestedMeta) -> Result<()> {
|
|||
///
|
||||
/// # Channel name
|
||||
///
|
||||
/// ```
|
||||
/// ```ignore
|
||||
/// #[job(channel_name="foo")]
|
||||
/// ```
|
||||
///
|
||||
|
@ -123,7 +127,7 @@ fn interpret_job_arg(options: &mut JobOptions, arg: NestedMeta) -> Result<()> {
|
|||
///
|
||||
/// # Retries
|
||||
///
|
||||
/// ```
|
||||
/// ```ignore
|
||||
/// #[job(retries = 3)]
|
||||
/// ```
|
||||
///
|
||||
|
@ -131,7 +135,7 @@ fn interpret_job_arg(options: &mut JobOptions, arg: NestedMeta) -> Result<()> {
|
|||
///
|
||||
/// # Retry backoff
|
||||
///
|
||||
/// ```
|
||||
/// ```ignore
|
||||
/// #[job(backoff_secs=1.5)]
|
||||
/// #[job(backoff_secs=2)]
|
||||
/// ```
|
||||
|
@ -140,7 +144,7 @@ fn interpret_job_arg(options: &mut JobOptions, arg: NestedMeta) -> Result<()> {
|
|||
///
|
||||
/// # Ordered
|
||||
///
|
||||
/// ```
|
||||
/// ```ignore
|
||||
/// #[job(ordered)]
|
||||
/// #[job(ordered=true)]
|
||||
/// #[job(ordered=false)]
|
||||
|
@ -150,7 +154,7 @@ fn interpret_job_arg(options: &mut JobOptions, arg: NestedMeta) -> Result<()> {
|
|||
///
|
||||
/// # Prototype
|
||||
///
|
||||
/// ```
|
||||
/// ```ignore
|
||||
/// fn my_proto<'a, 'b>(
|
||||
/// builder: &'a mut JobBuilder<'b>
|
||||
/// ) -> &'a mut JobBuilder<'b> {
|
||||
|
@ -170,7 +174,7 @@ fn interpret_job_arg(options: &mut JobOptions, arg: NestedMeta) -> Result<()> {
|
|||
/// prototype will always be applied first so that explicit options can override it.
|
||||
/// Each option can only be provided once in the attribute.
|
||||
///
|
||||
/// ```
|
||||
/// ```ignore
|
||||
/// #[job("my_job", proto(my_proto), retries=0, ordered)]
|
||||
/// ```
|
||||
///
|
||||
|
@ -223,6 +227,18 @@ pub fn job(attr: TokenStream, item: TokenStream) -> TokenStream {
|
|||
});
|
||||
}
|
||||
|
||||
let extract_ctx: Vec<_> = inner_fn
|
||||
.sig
|
||||
.inputs
|
||||
.iter()
|
||||
.skip(1)
|
||||
.map(|_| {
|
||||
quote! {
|
||||
registry.context()
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
let expanded = quote! {
|
||||
#(#errors)*
|
||||
#[allow(non_upper_case_globals)]
|
||||
|
@ -234,7 +250,7 @@ pub fn job(attr: TokenStream, item: TokenStream) -> TokenStream {
|
|||
builder #(#chain)*
|
||||
}),
|
||||
sqlxmq::hidden::RunFn(|registry, current_job| {
|
||||
registry.spawn_internal(#fq_name, inner(current_job));
|
||||
registry.spawn_internal(#fq_name, inner(current_job #(, #extract_ctx)*));
|
||||
}),
|
||||
)
|
||||
};
|
||||
|
|
48
src/lib.rs
48
src/lib.rs
|
@ -126,13 +126,17 @@
|
|||
//! // Arguments to the `#[job]` attribute allow setting default job options.
|
||||
//! #[job(channel_name = "foo")]
|
||||
//! async fn example_job(
|
||||
//! // The first argument should always be the current job.
|
||||
//! mut current_job: CurrentJob,
|
||||
//! // Additional arguments are optional, but can be used to access context
|
||||
//! // provided via [`JobRegistry::set_context`].
|
||||
//! message: &'static str,
|
||||
//! ) -> Result<(), Box<dyn Error + Send + Sync + 'static>> {
|
||||
//! // Decode a JSON payload
|
||||
//! let who: Option<String> = current_job.json()?;
|
||||
//!
|
||||
//! // Do some work
|
||||
//! println!("Hello, {}!", who.as_deref().unwrap_or("world"));
|
||||
//! println!("{}, {}!", message, who.as_deref().unwrap_or("world"));
|
||||
//!
|
||||
//! // Mark the job as complete
|
||||
//! current_job.complete().await?;
|
||||
|
@ -172,6 +176,9 @@
|
|||
//! // Here is where you can configure the registry
|
||||
//! // registry.set_error_handler(...)
|
||||
//!
|
||||
//! // And add context
|
||||
//! registry.set_context("Hello");
|
||||
//!
|
||||
//! let runner = registry
|
||||
//! // Create a job runner using the connection pool.
|
||||
//! .runner(&pool)
|
||||
|
@ -321,12 +328,22 @@ mod tests {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
#[job]
|
||||
async fn example_job_with_ctx(
|
||||
mut current_job: CurrentJob,
|
||||
ctx1: i32,
|
||||
ctx2: &'static str,
|
||||
) -> Result<(), Box<dyn Error + Send + Sync + 'static>> {
|
||||
assert_eq!(ctx1, 42);
|
||||
assert_eq!(ctx2, "Hello, world!");
|
||||
current_job.complete().await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn named_job_runner(pool: &Pool<Postgres>) -> OwnedHandle {
|
||||
JobRegistry::new(&[example_job1, example_job2])
|
||||
.runner(pool)
|
||||
.run()
|
||||
.await
|
||||
.unwrap()
|
||||
let mut registry = JobRegistry::new(&[example_job1, example_job2, example_job_with_ctx]);
|
||||
registry.set_context(42).set_context("Hello, world!");
|
||||
registry.runner(pool).run().await.unwrap()
|
||||
}
|
||||
|
||||
async fn pause() {
|
||||
|
@ -339,6 +356,7 @@ mod tests {
|
|||
|
||||
#[tokio::test]
|
||||
async fn it_can_spawn_job() {
|
||||
{
|
||||
let pool = &*test_pool().await;
|
||||
let (_runner, counter) =
|
||||
test_job_runner(&pool, |mut job| async move { job.complete().await }).await;
|
||||
|
@ -348,9 +366,12 @@ mod tests {
|
|||
pause().await;
|
||||
assert_eq!(counter.load(Ordering::SeqCst), 1);
|
||||
}
|
||||
pause().await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn it_runs_jobs_in_order() {
|
||||
{
|
||||
let pool = &*test_pool().await;
|
||||
let (tx, mut rx) = mpsc::unbounded();
|
||||
|
||||
|
@ -383,9 +404,12 @@ mod tests {
|
|||
pause().await;
|
||||
assert_eq!(counter.load(Ordering::SeqCst), 2);
|
||||
}
|
||||
pause().await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn it_runs_jobs_in_parallel() {
|
||||
{
|
||||
let pool = &*test_pool().await;
|
||||
let (tx, mut rx) = mpsc::unbounded();
|
||||
|
||||
|
@ -409,9 +433,12 @@ mod tests {
|
|||
job.complete().await.unwrap();
|
||||
}
|
||||
}
|
||||
pause().await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn it_retries_failed_jobs() {
|
||||
{
|
||||
let pool = &*test_pool().await;
|
||||
let (_runner, counter) = test_job_runner(&pool, move |_| async {}).await;
|
||||
|
||||
|
@ -443,9 +470,12 @@ mod tests {
|
|||
pause_ms(backoff * 5).await;
|
||||
assert_eq!(counter.load(Ordering::SeqCst), 3);
|
||||
}
|
||||
pause().await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn it_can_checkpoint_jobs() {
|
||||
{
|
||||
let pool = &*test_pool().await;
|
||||
let (_runner, counter) = test_job_runner(&pool, move |mut current_job| async move {
|
||||
let state: bool = current_job.json().unwrap().unwrap();
|
||||
|
@ -485,14 +515,20 @@ mod tests {
|
|||
pause_ms(backoff * 3).await;
|
||||
assert_eq!(counter.load(Ordering::SeqCst), 2);
|
||||
}
|
||||
pause().await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn it_can_use_registry() {
|
||||
{
|
||||
let pool = &*test_pool().await;
|
||||
let _runner = named_job_runner(pool).await;
|
||||
|
||||
example_job1.builder().spawn(pool).await.unwrap();
|
||||
example_job2.builder().spawn(pool).await.unwrap();
|
||||
example_job_with_ctx.builder().spawn(pool).await.unwrap();
|
||||
pause().await;
|
||||
}
|
||||
pause().await;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,9 +1,12 @@
|
|||
use std::any::type_name;
|
||||
use std::collections::HashMap;
|
||||
use std::error::Error;
|
||||
use std::fmt::Display;
|
||||
use std::future::Future;
|
||||
use std::sync::Arc;
|
||||
|
||||
use anymap2::any::CloneAnySendSync;
|
||||
use anymap2::Map;
|
||||
use sqlx::{Pool, Postgres};
|
||||
use uuid::Uuid;
|
||||
|
||||
|
@ -16,6 +19,7 @@ use crate::{JobBuilder, JobRunnerOptions};
|
|||
pub struct JobRegistry {
|
||||
error_handler: Arc<dyn Fn(&str, Box<dyn Error + Send + 'static>) + Send + Sync>,
|
||||
job_map: HashMap<&'static str, &'static NamedJob>,
|
||||
context: Map<dyn CloneAnySendSync + Send + Sync>,
|
||||
}
|
||||
|
||||
/// Error returned when a job is received whose name is not in the registry.
|
||||
|
@ -41,6 +45,7 @@ impl JobRegistry {
|
|||
Self {
|
||||
error_handler: Arc::new(Self::default_error_handler),
|
||||
job_map,
|
||||
context: Map::new(),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -53,6 +58,24 @@ impl JobRegistry {
|
|||
self
|
||||
}
|
||||
|
||||
/// Provide context for the jobs.
|
||||
pub fn set_context<C: Clone + Send + Sync + 'static>(&mut self, context: C) -> &mut Self {
|
||||
self.context.insert(context);
|
||||
self
|
||||
}
|
||||
|
||||
/// Access job context. Will panic if context with this type has not been provided.
|
||||
pub fn context<C: Clone + Send + Sync + 'static>(&self) -> C {
|
||||
if let Some(c) = self.context.get::<C>() {
|
||||
c.clone()
|
||||
} else {
|
||||
panic!(
|
||||
"No context of type `{}` has been provided.",
|
||||
type_name::<C>()
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/// Look-up a job by name.
|
||||
pub fn resolve_job(&self, name: &str) -> Option<&'static NamedJob> {
|
||||
self.job_map.get(name).copied()
|
||||
|
|
|
@ -123,6 +123,13 @@ impl CurrentJob {
|
|||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn stop_keep_alive(&mut self) {
|
||||
if let Some(keep_alive) = self.keep_alive.take() {
|
||||
keep_alive.stop().await;
|
||||
}
|
||||
}
|
||||
|
||||
/// Complete this job and commit the provided transaction at the same time.
|
||||
/// If the transaction cannot be committed, the job will not be completed.
|
||||
pub async fn complete_with_transaction(
|
||||
|
@ -131,13 +138,13 @@ impl CurrentJob {
|
|||
) -> Result<(), sqlx::Error> {
|
||||
self.delete(&mut tx).await?;
|
||||
tx.commit().await?;
|
||||
self.keep_alive = None;
|
||||
self.stop_keep_alive().await;
|
||||
Ok(())
|
||||
}
|
||||
/// Complete this job.
|
||||
pub async fn complete(&mut self) -> Result<(), sqlx::Error> {
|
||||
self.delete(self.pool()).await?;
|
||||
self.keep_alive = None;
|
||||
self.stop_keep_alive().await;
|
||||
Ok(())
|
||||
}
|
||||
/// Checkpoint this job and commit the provided transaction at the same time.
|
||||
|
@ -254,7 +261,7 @@ impl JobRunnerOptions {
|
|||
notify: Notify::new(),
|
||||
});
|
||||
let listener_task = start_listener(job_runner.clone()).await?;
|
||||
Ok(OwnedHandle(task::spawn(main_loop(
|
||||
Ok(OwnedHandle::new(task::spawn(main_loop(
|
||||
job_runner,
|
||||
listener_task,
|
||||
))))
|
||||
|
@ -271,7 +278,7 @@ async fn start_listener(job_runner: Arc<JobRunner>) -> Result<OwnedHandle, sqlx:
|
|||
} else {
|
||||
listener.listen("mq").await?;
|
||||
}
|
||||
Ok(OwnedHandle(task::spawn(async move {
|
||||
Ok(OwnedHandle::new(task::spawn(async move {
|
||||
let mut num_errors = 0;
|
||||
loop {
|
||||
if num_errors > 0 || listener.recv().await.is_ok() {
|
||||
|
@ -356,7 +363,7 @@ async fn poll_and_dispatch(
|
|||
{
|
||||
let retry_backoff = to_duration(retry_backoff);
|
||||
let keep_alive = if options.keep_alive {
|
||||
Some(OwnedHandle(task::spawn(keep_job_alive(
|
||||
Some(OwnedHandle::new(task::spawn(keep_job_alive(
|
||||
id,
|
||||
options.pool.clone(),
|
||||
retry_backoff,
|
||||
|
|
23
src/utils.rs
23
src/utils.rs
|
@ -31,10 +31,29 @@ impl<T: Any> DerefMut for Opaque<T> {
|
|||
/// the handle is dropped. Extract the inner join handle to prevent this
|
||||
/// behaviour.
|
||||
#[derive(Debug)]
|
||||
pub struct OwnedHandle(pub JoinHandle<()>);
|
||||
pub struct OwnedHandle(Option<JoinHandle<()>>);
|
||||
|
||||
impl OwnedHandle {
|
||||
/// Construct a new `OwnedHandle` from the provided `JoinHandle`
|
||||
pub fn new(inner: JoinHandle<()>) -> Self {
|
||||
Self(Some(inner))
|
||||
}
|
||||
/// Get back the original `JoinHandle`
|
||||
pub fn into_inner(mut self) -> JoinHandle<()> {
|
||||
self.0.take().expect("Only consumed once")
|
||||
}
|
||||
/// Stop the task and wait for it to finish.
|
||||
pub async fn stop(self) {
|
||||
let handle = self.into_inner();
|
||||
handle.abort();
|
||||
let _ = handle.await;
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for OwnedHandle {
|
||||
fn drop(&mut self) {
|
||||
self.0.abort();
|
||||
if let Some(handle) = self.0.take() {
|
||||
handle.abort();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue