Add support for passing custom context via a JobRegistry

Diggory Blake 2021-07-17 18:45:06 +01:00
parent 1cc2262cf8
commit fc4e909d24
No known key found for this signature in database
GPG key ID: E6BDFA83146ABD40
10 changed files with 263 additions and 144 deletions

@@ -52,6 +52,8 @@ jobs:
   test:
     name: Test
     runs-on: ubuntu-latest
+    env:
+      RUST_BACKTRACE: "1"
     steps:
       - uses: actions/checkout@v2
       - uses: actions-rs/toolchain@v1
@@ -67,10 +69,13 @@ jobs:
       - uses: actions-rs/cargo@v1
         with:
           command: test
+          args: -- --nocapture

   test_nightly:
     name: Test (Nightly)
     runs-on: ubuntu-latest
+    env:
+      RUST_BACKTRACE: "1"
     steps:
       - uses: actions/checkout@v2
       - uses: actions-rs/toolchain@v1
@@ -86,3 +91,4 @@ jobs:
       - uses: actions-rs/cargo@v1
         with:
           command: test
+          args: -- --nocapture

.vscode/settings.json (new file)

@@ -0,0 +1,4 @@
+{
+    "rust-analyzer.checkOnSave.allFeatures": false,
+    "rust-analyzer.cargo.allFeatures": false
+}

@@ -1,6 +1,6 @@
 [package]
 name = "sqlxmq"
-version = "0.1.2"
+version = "0.2.0"
 authors = ["Diggory Blake <diggsey@googlemail.com>"]
 edition = "2018"
 license = "MIT OR Apache-2.0"
@@ -16,14 +16,15 @@ members = ["sqlxmq_macros", "sqlxmq_stress"]
 [dependencies]
 sqlx = { version = "0.5.2", features = ["postgres", "chrono", "uuid"] }
-tokio = { version = "1.4.0", features = ["full"] }
+tokio = { version = "=1.8.0", features = ["full"] }
 dotenv = "0.15.0"
 chrono = "0.4.19"
 uuid = { version = "0.8.2", features = ["v4"] }
 log = "0.4.14"
 serde_json = "1.0.64"
 serde = "1.0.124"
-sqlxmq_macros = { version = "0.1", path = "sqlxmq_macros" }
+sqlxmq_macros = { version = "0.2.0", path = "sqlxmq_macros" }
+anymap2 = "0.13.0"

 [features]
 default = ["runtime-tokio-native-tls"]

@@ -127,13 +127,17 @@ use sqlxmq::{job, CurrentJob};
 // Arguments to the `#[job]` attribute allow setting default job options.
 #[job(channel_name = "foo")]
 async fn example_job(
+    // The first argument should always be the current job.
     mut current_job: CurrentJob,
+    // Additional arguments are optional, but can be used to access context
+    // provided via `JobRegistry::set_context`.
+    message: &'static str,
 ) -> sqlx::Result<()> {
     // Decode a JSON payload
     let who: Option<String> = current_job.json()?;

     // Do some work
-    println!("Hello, {}!", who.as_deref().unwrap_or("world"));
+    println!("{}, {}!", message, who.as_deref().unwrap_or("world"));

     // Mark the job as complete
     current_job.complete().await?;
@@ -160,6 +164,9 @@ async fn main() -> Result<(), Box<dyn Error>> {
     // Here is where you can configure the registry
     // registry.set_error_handler(...)

+    // And add context
+    registry.set_context("Hello");
+
     let runner = registry
         // Create a job runner using the connection pool.
         .runner(&pool)

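Since context entries are looked up by type, any `Clone + Send + Sync + 'static` type can be provided, and each type can be registered once. A minimal sketch of wrapping shared state in a dedicated context type (the `AppContext` name and its field are hypothetical, not part of the crate):

    // Hypothetical context type; any `Clone + Send + Sync + 'static` type works.
    #[derive(Clone)]
    struct AppContext {
        greeting: &'static str,
    }

    #[job]
    async fn greeting_job(
        mut current_job: CurrentJob,
        // Resolved from the registry purely by its type; panics at job-start
        // time if no `AppContext` was set on the registry.
        ctx: AppContext,
    ) -> sqlx::Result<()> {
        println!("{}", ctx.greeting);
        current_job.complete().await?;
        Ok(())
    }

    // At startup, before building the runner:
    // registry.set_context(AppContext { greeting: "Hello" });

A newtype like this avoids clashes when two pieces of context would otherwise share a primitive type such as `String`.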
@@ -1,6 +1,6 @@
 [package]
 name = "sqlxmq_macros"
-version = "0.1.2"
+version = "0.2.0"
 authors = ["Diggory Blake <diggsey@googlemail.com>"]
 edition = "2018"
 license = "MIT OR Apache-2.0"

@@ -92,8 +92,12 @@ fn interpret_job_arg(options: &mut JobOptions, arg: NestedMeta) -> Result<()> {
 /// Marks a function as being a background job.
 ///
-/// The function must take a single `CurrentJob` argument, and should
-/// be async or return a future.
+/// The first argument to the function must have type `CurrentJob`.
+/// Additional arguments can be used to access context from the job
+/// registry. Context is accessed based on the type of the argument.
+/// Context arguments must be `Send + Sync + Clone + 'static`.
+///
+/// The function should be async or return a future.
 ///
 /// The async result must be a `Result<(), E>` type, where `E` is convertible
 /// to a `Box<dyn Error + Send + Sync + 'static>`, which is the case for most
@@ -103,7 +107,7 @@ fn interpret_job_arg(options: &mut JobOptions, arg: NestedMeta) -> Result<()> {
 ///
 /// # Name
 ///
-/// ```
+/// ```ignore
 /// #[job("example")]
 /// #[job(name="example")]
 /// ```
@@ -115,7 +119,7 @@ fn interpret_job_arg(options: &mut JobOptions, arg: NestedMeta) -> Result<()> {
 ///
 /// # Channel name
 ///
-/// ```
+/// ```ignore
 /// #[job(channel_name="foo")]
 /// ```
 ///
@@ -123,7 +127,7 @@ fn interpret_job_arg(options: &mut JobOptions, arg: NestedMeta) -> Result<()> {
 ///
 /// # Retries
 ///
-/// ```
+/// ```ignore
 /// #[job(retries = 3)]
 /// ```
 ///
@@ -131,7 +135,7 @@ fn interpret_job_arg(options: &mut JobOptions, arg: NestedMeta) -> Result<()> {
 ///
 /// # Retry backoff
 ///
-/// ```
+/// ```ignore
 /// #[job(backoff_secs=1.5)]
 /// #[job(backoff_secs=2)]
 /// ```
@@ -140,7 +144,7 @@ fn interpret_job_arg(options: &mut JobOptions, arg: NestedMeta) -> Result<()> {
 ///
 /// # Ordered
 ///
-/// ```
+/// ```ignore
 /// #[job(ordered)]
 /// #[job(ordered=true)]
 /// #[job(ordered=false)]
@@ -150,7 +154,7 @@ fn interpret_job_arg(options: &mut JobOptions, arg: NestedMeta) -> Result<()> {
 ///
 /// # Prototype
 ///
-/// ```
+/// ```ignore
 /// fn my_proto<'a, 'b>(
 ///     builder: &'a mut JobBuilder<'b>
 /// ) -> &'a mut JobBuilder<'b> {
@@ -170,7 +174,7 @@ fn interpret_job_arg(options: &mut JobOptions, arg: NestedMeta) -> Result<()> {
 /// prototype will always be applied first so that explicit options can override it.
 /// Each option can only be provided once in the attribute.
 ///
-/// ```
+/// ```ignore
 /// #[job("my_job", proto(my_proto), retries=0, ordered)]
 /// ```
 ///
@@ -223,6 +227,18 @@ pub fn job(attr: TokenStream, item: TokenStream) -> TokenStream {
         });
     }

+    let extract_ctx: Vec<_> = inner_fn
+        .sig
+        .inputs
+        .iter()
+        .skip(1)
+        .map(|_| {
+            quote! {
+                registry.context()
+            }
+        })
+        .collect();
+
     let expanded = quote! {
         #(#errors)*
         #[allow(non_upper_case_globals)]
@@ -234,7 +250,7 @@ pub fn job(attr: TokenStream, item: TokenStream) -> TokenStream {
                 builder #(#chain)*
             }),
             sqlxmq::hidden::RunFn(|registry, current_job| {
-                registry.spawn_internal(#fq_name, inner(current_job));
+                registry.spawn_internal(#fq_name, inner(current_job #(, #extract_ctx)*));
             }),
         )
     };

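For a job with one extra context parameter, the `RunFn` wiring that `extract_ctx` produces ends up roughly like the fragment below (a simplified sketch of the generated code with the interpolations already substituted; `inner` is the macro's internal copy of the annotated function):

    // One `registry.context()` call is emitted per context parameter; the
    // concrete context type is inferred from the job function's parameter
    // type, so a missing context entry panics when the job starts.
    sqlxmq::hidden::RunFn(|registry, current_job| {
        registry.spawn_internal("example_job", inner(current_job, registry.context()));
    })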
@@ -126,13 +126,17 @@
 //! // Arguments to the `#[job]` attribute allow setting default job options.
 //! #[job(channel_name = "foo")]
 //! async fn example_job(
+//!     // The first argument should always be the current job.
 //!     mut current_job: CurrentJob,
+//!     // Additional arguments are optional, but can be used to access context
+//!     // provided via [`JobRegistry::set_context`].
+//!     message: &'static str,
 //! ) -> Result<(), Box<dyn Error + Send + Sync + 'static>> {
 //!     // Decode a JSON payload
 //!     let who: Option<String> = current_job.json()?;
 //!
 //!     // Do some work
-//!     println!("Hello, {}!", who.as_deref().unwrap_or("world"));
+//!     println!("{}, {}!", message, who.as_deref().unwrap_or("world"));
 //!
 //!     // Mark the job as complete
 //!     current_job.complete().await?;
@@ -172,6 +176,9 @@
 //!     // Here is where you can configure the registry
 //!     // registry.set_error_handler(...)
 //!
+//!     // And add context
+//!     registry.set_context("Hello");
+//!
 //!     let runner = registry
 //!         // Create a job runner using the connection pool.
 //!         .runner(&pool)
@@ -321,12 +328,22 @@ mod tests {
         Ok(())
     }

+    #[job]
+    async fn example_job_with_ctx(
+        mut current_job: CurrentJob,
+        ctx1: i32,
+        ctx2: &'static str,
+    ) -> Result<(), Box<dyn Error + Send + Sync + 'static>> {
+        assert_eq!(ctx1, 42);
+        assert_eq!(ctx2, "Hello, world!");
+        current_job.complete().await?;
+        Ok(())
+    }
+
     async fn named_job_runner(pool: &Pool<Postgres>) -> OwnedHandle {
-        JobRegistry::new(&[example_job1, example_job2])
-            .runner(pool)
-            .run()
-            .await
-            .unwrap()
+        let mut registry = JobRegistry::new(&[example_job1, example_job2, example_job_with_ctx]);
+        registry.set_context(42).set_context("Hello, world!");
+        registry.runner(pool).run().await.unwrap()
     }

     async fn pause() {
@@ -339,160 +356,179 @@ mod tests {
     #[tokio::test]
     async fn it_can_spawn_job() {
+        {
             let pool = &*test_pool().await;
             let (_runner, counter) =
                 test_job_runner(&pool, |mut job| async move { job.complete().await }).await;

             assert_eq!(counter.load(Ordering::SeqCst), 0);
             JobBuilder::new("foo").spawn(pool).await.unwrap();
             pause().await;
             assert_eq!(counter.load(Ordering::SeqCst), 1);
+        }
+        pause().await;
     }

     #[tokio::test]
     async fn it_runs_jobs_in_order() {
+        {
             let pool = &*test_pool().await;
             let (tx, mut rx) = mpsc::unbounded();

             let (_runner, counter) = test_job_runner(&pool, move |job| {
                 let tx = tx.clone();
                 async move {
                     tx.unbounded_send(job).unwrap();
                 }
             })
             .await;

             assert_eq!(counter.load(Ordering::SeqCst), 0);
             JobBuilder::new("foo")
                 .set_ordered(true)
                 .spawn(pool)
                 .await
                 .unwrap();
             JobBuilder::new("bar")
                 .set_ordered(true)
                 .spawn(pool)
                 .await
                 .unwrap();

             pause().await;
             assert_eq!(counter.load(Ordering::SeqCst), 1);

             let mut job = rx.next().await.unwrap();
             job.complete().await.unwrap();

             pause().await;
             assert_eq!(counter.load(Ordering::SeqCst), 2);
+        }
+        pause().await;
     }

     #[tokio::test]
     async fn it_runs_jobs_in_parallel() {
+        {
             let pool = &*test_pool().await;
             let (tx, mut rx) = mpsc::unbounded();

             let (_runner, counter) = test_job_runner(&pool, move |job| {
                 let tx = tx.clone();
                 async move {
                     tx.unbounded_send(job).unwrap();
                 }
             })
             .await;

             assert_eq!(counter.load(Ordering::SeqCst), 0);
             JobBuilder::new("foo").spawn(pool).await.unwrap();
             JobBuilder::new("bar").spawn(pool).await.unwrap();

             pause().await;
             assert_eq!(counter.load(Ordering::SeqCst), 2);

             for _ in 0..2 {
                 let mut job = rx.next().await.unwrap();
                 job.complete().await.unwrap();
             }
+        }
+        pause().await;
     }

     #[tokio::test]
     async fn it_retries_failed_jobs() {
+        {
             let pool = &*test_pool().await;
             let (_runner, counter) = test_job_runner(&pool, move |_| async {}).await;

             let backoff = 500;

             assert_eq!(counter.load(Ordering::SeqCst), 0);
             JobBuilder::new("foo")
                 .set_retry_backoff(Duration::from_millis(backoff))
                 .set_retries(2)
                 .spawn(pool)
                 .await
                 .unwrap();

             // First attempt
             pause().await;
             assert_eq!(counter.load(Ordering::SeqCst), 1);

             // Second attempt
             pause_ms(backoff).await;
             pause().await;
             assert_eq!(counter.load(Ordering::SeqCst), 2);

             // Third attempt
             pause_ms(backoff * 2).await;
             pause().await;
             assert_eq!(counter.load(Ordering::SeqCst), 3);

             // No more attempts
             pause_ms(backoff * 5).await;
             assert_eq!(counter.load(Ordering::SeqCst), 3);
+        }
+        pause().await;
     }

     #[tokio::test]
     async fn it_can_checkpoint_jobs() {
+        {
             let pool = &*test_pool().await;
             let (_runner, counter) = test_job_runner(&pool, move |mut current_job| async move {
                 let state: bool = current_job.json().unwrap().unwrap();
                 if state {
                     current_job.complete().await.unwrap();
                 } else {
                     current_job
                         .checkpoint(Checkpoint::new().set_json(&true).unwrap())
                         .await
                         .unwrap();
                 }
             })
             .await;

             let backoff = 200;

             assert_eq!(counter.load(Ordering::SeqCst), 0);
             JobBuilder::new("foo")
                 .set_retry_backoff(Duration::from_millis(backoff))
                 .set_retries(5)
                 .set_json(&false)
                 .unwrap()
                 .spawn(pool)
                 .await
                 .unwrap();

             // First attempt
             pause().await;
             assert_eq!(counter.load(Ordering::SeqCst), 1);

             // Second attempt
             pause_ms(backoff).await;
             pause().await;
             assert_eq!(counter.load(Ordering::SeqCst), 2);

             // No more attempts
             pause_ms(backoff * 3).await;
             assert_eq!(counter.load(Ordering::SeqCst), 2);
+        }
+        pause().await;
     }

     #[tokio::test]
     async fn it_can_use_registry() {
+        {
             let pool = &*test_pool().await;
             let _runner = named_job_runner(pool).await;

             example_job1.builder().spawn(pool).await.unwrap();
             example_job2.builder().spawn(pool).await.unwrap();
+            example_job_with_ctx.builder().spawn(pool).await.unwrap();
             pause().await;
+        }
+        pause().await;
     }
 }

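As the tests above show, jobs are spawned through the generated statics. For completeness, a sketch of spawning `example_job` with a JSON payload so that the handler's `current_job.json()?` yields a value (assuming a `pool: Pool<Postgres>` is in scope):

    // Spawn `example_job` with a JSON payload; the handler's
    // `current_job.json()?` will then decode `Some("world".to_string())`.
    example_job
        .builder()
        .set_json(&"world")?
        .spawn(&pool)
        .await?;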
@@ -1,9 +1,12 @@
+use std::any::type_name;
 use std::collections::HashMap;
 use std::error::Error;
 use std::fmt::Display;
 use std::future::Future;
 use std::sync::Arc;

+use anymap2::any::CloneAnySendSync;
+use anymap2::Map;
 use sqlx::{Pool, Postgres};
 use uuid::Uuid;
@@ -16,6 +19,7 @@ use crate::{JobBuilder, JobRunnerOptions};
 pub struct JobRegistry {
     error_handler: Arc<dyn Fn(&str, Box<dyn Error + Send + 'static>) + Send + Sync>,
     job_map: HashMap<&'static str, &'static NamedJob>,
+    context: Map<dyn CloneAnySendSync + Send + Sync>,
 }

 /// Error returned when a job is received whose name is not in the registry.
@@ -41,6 +45,7 @@ impl JobRegistry {
         Self {
             error_handler: Arc::new(Self::default_error_handler),
             job_map,
+            context: Map::new(),
         }
     }
@@ -53,6 +58,24 @@ impl JobRegistry {
         self
     }

+    /// Provide context for the jobs.
+    pub fn set_context<C: Clone + Send + Sync + 'static>(&mut self, context: C) -> &mut Self {
+        self.context.insert(context);
+        self
+    }
+
+    /// Access job context. Will panic if context with this type has not been provided.
+    pub fn context<C: Clone + Send + Sync + 'static>(&self) -> C {
+        if let Some(c) = self.context.get::<C>() {
+            c.clone()
+        } else {
+            panic!(
+                "No context of type `{}` has been provided.",
+                type_name::<C>()
+            );
+        }
+    }
+
     /// Look-up a job by name.
     pub fn resolve_job(&self, name: &str) -> Option<&'static NamedJob> {
         self.job_map.get(name).copied()

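The context store is a type-keyed map, which is what makes the by-type lookup in `context::<C>()` work and why a second `set_context` call with the same type replaces the first. A self-contained sketch of the same pattern using `anymap2` directly:

    use anymap2::any::CloneAnySendSync;
    use anymap2::Map;

    fn main() {
        // The same storage shape as `JobRegistry`'s `context` field.
        let mut context: Map<dyn CloneAnySendSync + Send + Sync> = Map::new();
        context.insert(42i32);
        context.insert("Hello, world!");

        // Lookup is driven purely by the requested type.
        assert_eq!(context.get::<i32>().cloned(), Some(42));
        assert_eq!(context.get::<&'static str>().cloned(), Some("Hello, world!"));
    }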
@@ -123,6 +123,13 @@ impl CurrentJob {
             .await?;
         Ok(())
     }

+    async fn stop_keep_alive(&mut self) {
+        if let Some(keep_alive) = self.keep_alive.take() {
+            keep_alive.stop().await;
+        }
+    }
+
     /// Complete this job and commit the provided transaction at the same time.
     /// If the transaction cannot be committed, the job will not be completed.
     pub async fn complete_with_transaction(
@@ -131,13 +138,13 @@ impl CurrentJob {
     ) -> Result<(), sqlx::Error> {
         self.delete(&mut tx).await?;
         tx.commit().await?;
-        self.keep_alive = None;
+        self.stop_keep_alive().await;
         Ok(())
     }

     /// Complete this job.
     pub async fn complete(&mut self) -> Result<(), sqlx::Error> {
         self.delete(self.pool()).await?;
-        self.keep_alive = None;
+        self.stop_keep_alive().await;
         Ok(())
     }

     /// Checkpoint this job and commit the provided transaction at the same time.
@@ -254,7 +261,7 @@ impl JobRunnerOptions {
             notify: Notify::new(),
         });
         let listener_task = start_listener(job_runner.clone()).await?;
-        Ok(OwnedHandle(task::spawn(main_loop(
+        Ok(OwnedHandle::new(task::spawn(main_loop(
             job_runner,
             listener_task,
         ))))
@@ -271,7 +278,7 @@ async fn start_listener(job_runner: Arc<JobRunner>) -> Result<OwnedHandle, sqlx::Error> {
     } else {
         listener.listen("mq").await?;
     }
-    Ok(OwnedHandle(task::spawn(async move {
+    Ok(OwnedHandle::new(task::spawn(async move {
         let mut num_errors = 0;
         loop {
             if num_errors > 0 || listener.recv().await.is_ok() {
@@ -356,7 +363,7 @@ async fn poll_and_dispatch(
     {
         let retry_backoff = to_duration(retry_backoff);
         let keep_alive = if options.keep_alive {
-            Some(OwnedHandle(task::spawn(keep_job_alive(
+            Some(OwnedHandle::new(task::spawn(keep_job_alive(
                 id,
                 options.pool.clone(),
                 retry_backoff,

@@ -31,10 +31,29 @@ impl<T: Any> DerefMut for Opaque<T> {
 /// the handle is dropped. Extract the inner join handle to prevent this
 /// behaviour.
 #[derive(Debug)]
-pub struct OwnedHandle(pub JoinHandle<()>);
+pub struct OwnedHandle(Option<JoinHandle<()>>);
+
+impl OwnedHandle {
+    /// Construct a new `OwnedHandle` from the provided `JoinHandle`
+    pub fn new(inner: JoinHandle<()>) -> Self {
+        Self(Some(inner))
+    }
+
+    /// Get back the original `JoinHandle`
+    pub fn into_inner(mut self) -> JoinHandle<()> {
+        self.0.take().expect("Only consumed once")
+    }
+
+    /// Stop the task and wait for it to finish.
+    pub async fn stop(self) {
+        let handle = self.into_inner();
+        handle.abort();
+        let _ = handle.await;
+    }
+}

 impl Drop for OwnedHandle {
     fn drop(&mut self) {
-        self.0.abort();
+        if let Some(handle) = self.0.take() {
+            handle.abort();
+        }
     }
 }
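
A short usage sketch of the new API (assuming a tokio runtime): unlike dropping the handle, which merely requests an abort, `stop` awaits the task, so it has fully terminated before control returns. This is what `CurrentJob::stop_keep_alive` relies on to ensure the keep-alive task is gone before a job is reported complete.

    let handle = OwnedHandle::new(tokio::task::spawn(async {
        loop {
            tokio::time::sleep(std::time::Duration::from_secs(1)).await;
        }
    }));

    // Aborts the task, then awaits it; returns only once the task has finished.
    handle.stop().await;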