diff --git a/jobs-core/src/catch_unwind.rs b/jobs-core/src/catch_unwind.rs index af4f428..69d752c 100644 --- a/jobs-core/src/catch_unwind.rs +++ b/jobs-core/src/catch_unwind.rs @@ -1,21 +1,19 @@ use std::{ future::Future, + panic::AssertUnwindSafe, pin::Pin, - sync::Mutex, task::{Context, Poll}, }; pub(crate) struct CatchUnwindFuture { - future: Mutex, + future: F, } pub(crate) fn catch_unwind(future: F) -> CatchUnwindFuture where F: Future + Unpin, { - CatchUnwindFuture { - future: Mutex::new(future), - } + CatchUnwindFuture { future } } impl Future for CatchUnwindFuture @@ -25,13 +23,12 @@ where type Output = std::thread::Result; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let future = &self.future; + let future = &mut self.get_mut().future; let waker = cx.waker().clone(); - let res = std::panic::catch_unwind(|| { + let res = std::panic::catch_unwind(AssertUnwindSafe(|| { let mut context = Context::from_waker(&waker); - let mut guard = future.lock().unwrap(); - Pin::new(&mut *guard).poll(&mut context) - }); + Pin::new(future).poll(&mut context) + })); match res { Ok(poll) => poll.map(Ok), diff --git a/jobs-core/src/processor_map.rs b/jobs-core/src/processor_map.rs index fae76f5..432844a 100644 --- a/jobs-core/src/processor_map.rs +++ b/jobs-core/src/processor_map.rs @@ -1,6 +1,9 @@ use crate::{catch_unwind::catch_unwind, Job, JobError, JobInfo, ReturnJobInfo}; use serde_json::Value; -use std::{collections::HashMap, future::Future, pin::Pin, sync::Arc, time::Instant}; +use std::{ + collections::HashMap, future::Future, panic::AssertUnwindSafe, pin::Pin, sync::Arc, + time::Instant, +}; use tracing::Span; use tracing_futures::Instrument; use uuid::Uuid; @@ -165,13 +168,7 @@ where let start = Instant::now(); - let state_mtx = std::sync::Mutex::new(state); - let process_mtx = std::sync::Mutex::new(process_fn); - - let res = match std::panic::catch_unwind(|| { - let state = state_mtx.lock().unwrap().clone(); - (process_mtx.lock().unwrap())(args, state) - }) { + let res = match std::panic::catch_unwind(AssertUnwindSafe(|| (process_fn)(args, state))) { Ok(fut) => catch_unwind(fut).await, Err(e) => Err(e), }; diff --git a/jobs-core/src/storage.rs b/jobs-core/src/storage.rs index 0aa5271..58216b5 100644 --- a/jobs-core/src/storage.rs +++ b/jobs-core/src/storage.rs @@ -138,10 +138,11 @@ pub trait Storage: Clone + Send { /// A default, in-memory implementation of a storage mechanism pub mod memory_storage { use super::JobInfo; - use event_listener::Event; + use event_listener::{Event, EventListener}; use std::{ collections::HashMap, convert::Infallible, + future::Future, sync::Arc, sync::Mutex, time::{Duration, SystemTime}, @@ -154,7 +155,7 @@ pub mod memory_storage { /// Race a future against the clock, returning an empty tuple if the clock wins async fn timeout(&self, duration: Duration, future: F) -> Result where - F: std::future::Future + Send + Sync; + F: Future + Send + Sync; } #[derive(Clone)] @@ -186,6 +187,97 @@ pub mod memory_storage { timer, } } + + fn contains_job(&self, uuid: &Uuid) -> bool { + self.inner.lock().unwrap().jobs.contains_key(uuid) + } + + fn insert_job(&self, job: JobInfo) { + self.inner.lock().unwrap().jobs.insert(job.id(), job); + } + + fn get_job(&self, id: &Uuid) -> Option { + self.inner.lock().unwrap().jobs.get(id).cloned() + } + + fn try_deque(&self, queue: &str, now: SystemTime) -> Option { + let mut inner = self.inner.lock().unwrap(); + + let j = inner.job_queues.iter().find_map(|(k, v)| { + if v == queue { + let job = inner.jobs.get(k)?; + + if job.is_pending(now) && job.is_ready(now) && job.is_in_queue(queue) { + return Some(job.clone()); + } + } + + None + }); + + if let Some(job) = j { + inner.job_queues.remove(&job.id()); + return Some(job); + } + + None + } + + fn listener(&self, queue: &str, now: SystemTime) -> (Duration, EventListener) { + let mut inner = self.inner.lock().unwrap(); + + let duration = + inner + .job_queues + .iter() + .fold(Duration::from_secs(5), |duration, (id, v_queue)| { + if v_queue == queue { + if let Some(job) = inner.jobs.get(id) { + if let Some(ready_at) = job.next_queue() { + let job_eta = ready_at + .duration_since(now) + .unwrap_or(Duration::from_secs(0)); + + if job_eta < duration { + return job_eta; + } + } + } + } + + duration + }); + + let listener = inner.queues.entry(queue.to_string()).or_default().listen(); + + (duration, listener) + } + + fn queue_and_notify(&self, queue: &str, id: Uuid) { + let mut inner = self.inner.lock().unwrap(); + + inner.job_queues.insert(id, queue.to_owned()); + + inner.queues.entry(queue.to_string()).or_default().notify(1); + } + + fn mark_running(&self, job_id: Uuid, worker_id: Uuid) { + let mut inner = self.inner.lock().unwrap(); + + inner.worker_ids.insert(job_id, worker_id); + inner.worker_ids_inverse.insert(worker_id, job_id); + } + + fn purge_job(&self, job_id: Uuid) { + let mut inner = self.inner.lock().unwrap(); + + inner.jobs.remove(&job_id); + inner.job_queues.remove(&job_id); + + if let Some(worker_id) = inner.worker_ids.remove(&job_id) { + inner.worker_ids_inverse.remove(&worker_id); + } + } } #[async_trait::async_trait] @@ -195,7 +287,7 @@ pub mod memory_storage { async fn generate_id(&self) -> Result { let uuid = loop { let uuid = Uuid::new_v4(); - if !self.inner.lock().unwrap().jobs.contains_key(&uuid) { + if !self.contains_job(&uuid) { break uuid; } }; @@ -204,99 +296,44 @@ pub mod memory_storage { } async fn save_job(&self, job: JobInfo) -> Result<(), Self::Error> { - self.inner.lock().unwrap().jobs.insert(job.id(), job); + self.insert_job(job); Ok(()) } async fn fetch_job(&self, id: Uuid) -> Result, Self::Error> { - let j = self.inner.lock().unwrap().jobs.get(&id).cloned(); - - Ok(j) + Ok(self.get_job(&id)) } async fn fetch_job_from_queue(&self, queue: &str) -> Result { loop { - let listener = { - let mut inner = self.inner.lock().unwrap(); - let now = SystemTime::now(); + let now = SystemTime::now(); - let j = inner.job_queues.iter().find_map(|(k, v)| { - if v == queue { - let job = inner.jobs.get(k)?; + if let Some(job) = self.try_deque(queue, now) { + return Ok(job); + } - if job.is_pending(now) && job.is_ready(now) && job.is_in_queue(queue) { - return Some(job.clone()); - } - } + let (duration, listener) = self.listener(queue, now); - None - }); - - let duration = if let Some(j) = j { - if inner.job_queues.remove(&j.id()).is_some() { - return Ok(j); - } else { - continue; - } - } else { - inner.job_queues.iter().fold( - Duration::from_secs(5), - |duration, (id, v_queue)| { - if v_queue == queue { - if let Some(job) = inner.jobs.get(id) { - if let Some(ready_at) = job.next_queue() { - let job_eta = ready_at - .duration_since(now) - .unwrap_or(Duration::from_secs(0)); - - if job_eta < duration { - return job_eta; - } - } - } - } - - duration - }, - ) - }; - - self.timer.timeout( - duration, - inner.queues.entry(queue.to_string()).or_default().listen(), - ) - }; - - let _ = listener.await; + let _ = self.timer.timeout(duration, listener).await; } } async fn queue_job(&self, queue: &str, id: Uuid) -> Result<(), Self::Error> { - let mut inner = self.inner.lock().unwrap(); - - inner.job_queues.insert(id, queue.to_owned()); - - inner.queues.entry(queue.to_string()).or_default().notify(1); + self.queue_and_notify(queue, id); Ok(()) } async fn run_job(&self, id: Uuid, worker_id: Uuid) -> Result<(), Self::Error> { - let mut inner = self.inner.lock().unwrap(); + self.mark_running(id, worker_id); - inner.worker_ids.insert(id, worker_id); - inner.worker_ids_inverse.insert(worker_id, id); Ok(()) } async fn delete_job(&self, id: Uuid) -> Result<(), Self::Error> { - let mut inner = self.inner.lock().unwrap(); - inner.jobs.remove(&id); - inner.job_queues.remove(&id); - if let Some(worker_id) = inner.worker_ids.remove(&id) { - inner.worker_ids_inverse.remove(&worker_id); - } + self.purge_job(id); + Ok(()) } }