Fix poll after completion, misused compare_and_swap

This commit is contained in:
asonix 2023-08-14 21:17:57 -05:00
parent 26ca3a7195
commit 09236d731d
4 changed files with 154 additions and 73 deletions

View file

@ -278,29 +278,37 @@ where
+ Copy,
{
loop {
let (job_id, bytes) = repo.pop(queue).await?;
let fut = async {
let (job_id, bytes) = repo.pop(queue, worker_id).await?;
let span = tracing::info_span!("Running Job", worker_id = ?worker_id);
let span = tracing::info_span!("Running Job");
let guard = MetricsGuard::guard(worker_id, queue);
let guard = MetricsGuard::guard(worker_id, queue);
let res = span
.in_scope(|| {
heartbeat(
repo,
queue,
job_id,
(callback)(repo, store, config, bytes.as_ref()),
)
})
.instrument(span)
.await;
let res = span
.in_scope(|| {
heartbeat(
repo,
queue,
worker_id,
job_id,
(callback)(repo, store, config, bytes.as_ref()),
)
})
.instrument(span)
.await;
repo.complete_job(queue, job_id).await?;
repo.complete_job(queue, worker_id, job_id).await?;
res?;
res?;
guard.disarm();
guard.disarm();
Ok(()) as Result<(), Error>
};
fut.instrument(tracing::info_span!("tick", worker_id = %worker_id))
.await?;
}
}
@ -361,38 +369,52 @@ where
+ Copy,
{
loop {
let (job_id, bytes) = repo.pop(queue).await?;
let fut = async {
let (job_id, bytes) = repo.pop(queue, worker_id).await?;
let span = tracing::info_span!("Running Job", worker_id = ?worker_id);
let span = tracing::info_span!("Running Job");
let guard = MetricsGuard::guard(worker_id, queue);
let guard = MetricsGuard::guard(worker_id, queue);
let res = span
.in_scope(|| {
heartbeat(
repo,
queue,
job_id,
(callback)(repo, store, process_map, config, bytes.as_ref()),
)
})
.instrument(span)
.await;
let res = span
.in_scope(|| {
heartbeat(
repo,
queue,
worker_id,
job_id,
(callback)(repo, store, process_map, config, bytes.as_ref()),
)
})
.instrument(span)
.await;
repo.complete_job(queue, job_id).await?;
repo.complete_job(queue, worker_id, job_id).await?;
res?;
res?;
guard.disarm();
guard.disarm();
Ok(()) as Result<(), Error>
};
fut.instrument(tracing::info_span!("tick", worker_id = %worker_id))
.await?;
}
}
async fn heartbeat<R, Fut>(repo: &R, queue: &'static str, job_id: JobId, fut: Fut) -> Fut::Output
async fn heartbeat<R, Fut>(
repo: &R,
queue: &'static str,
worker_id: uuid::Uuid,
job_id: JobId,
fut: Fut,
) -> Fut::Output
where
R: QueueRepo,
Fut: std::future::Future,
{
let mut fut = std::pin::pin!(fut);
let mut fut =
std::pin::pin!(fut.instrument(tracing::info_span!("job-future", job_id = ?job_id)));
let mut interval = actix_rt::time::interval(Duration::from_secs(5));
@ -405,10 +427,12 @@ where
}
_ = interval.tick() => {
if hb.is_none() {
hb = Some(repo.heartbeat(queue, job_id));
hb = Some(repo.heartbeat(queue, worker_id, job_id));
}
}
opt = poll_opt(hb.as_mut()), if hb.is_some() => {
hb.take();
if let Some(Err(e)) = opt {
tracing::warn!("Failed heartbeat\n{}", format!("{e:?}"));
}
@ -423,6 +447,6 @@ where
{
match opt {
None => None,
Some(fut) => std::future::poll_fn(|cx| Pin::new(&mut *fut).poll(cx).map(Some)).await,
Some(fut) => Some(fut.await),
}
}

View file

@ -73,13 +73,8 @@ where
errors.push(e);
}
if !errors.is_empty() {
let span = tracing::error_span!("Error deleting files");
span.in_scope(|| {
for error in errors {
tracing::error!("{}", format!("{error}"));
}
});
for error in errors {
tracing::error!("{}", format!("{error:?}"));
}
Ok(())

View file

@ -296,11 +296,25 @@ impl JobId {
pub(crate) trait QueueRepo: BaseRepo {
async fn push(&self, queue: &'static str, job: Arc<[u8]>) -> Result<JobId, RepoError>;
async fn pop(&self, queue: &'static str) -> Result<(JobId, Arc<[u8]>), RepoError>;
async fn pop(
&self,
queue: &'static str,
worker_id: Uuid,
) -> Result<(JobId, Arc<[u8]>), RepoError>;
async fn heartbeat(&self, queue: &'static str, job_id: JobId) -> Result<(), RepoError>;
async fn heartbeat(
&self,
queue: &'static str,
worker_id: Uuid,
job_id: JobId,
) -> Result<(), RepoError>;
async fn complete_job(&self, queue: &'static str, job_id: JobId) -> Result<(), RepoError>;
async fn complete_job(
&self,
queue: &'static str,
worker_id: Uuid,
job_id: JobId,
) -> Result<(), RepoError>;
}
#[async_trait::async_trait(?Send)]
@ -312,16 +326,30 @@ where
T::push(self, queue, job).await
}
async fn pop(&self, queue: &'static str) -> Result<(JobId, Arc<[u8]>), RepoError> {
T::pop(self, queue).await
async fn pop(
&self,
queue: &'static str,
worker_id: Uuid,
) -> Result<(JobId, Arc<[u8]>), RepoError> {
T::pop(self, queue, worker_id).await
}
async fn heartbeat(&self, queue: &'static str, job_id: JobId) -> Result<(), RepoError> {
T::heartbeat(self, queue, job_id).await
async fn heartbeat(
&self,
queue: &'static str,
worker_id: Uuid,
job_id: JobId,
) -> Result<(), RepoError> {
T::heartbeat(self, queue, worker_id, job_id).await
}
async fn complete_job(&self, queue: &'static str, job_id: JobId) -> Result<(), RepoError> {
T::complete_job(self, queue, job_id).await
async fn complete_job(
&self,
queue: &'static str,
worker_id: Uuid,
job_id: JobId,
) -> Result<(), RepoError> {
T::complete_job(self, queue, worker_id, job_id).await
}
}

View file

@ -24,6 +24,7 @@ use std::{
};
use tokio::{sync::Notify, task::JoinHandle};
use url::Url;
use uuid::Uuid;
macro_rules! b {
($self:ident.$ident:ident, $expr:expr) => {{
@ -625,7 +626,7 @@ impl UploadRepo for SledRepo {
enum JobState {
Pending,
Running([u8; 8]),
Running([u8; 24]),
}
impl JobState {
@ -633,12 +634,26 @@ impl JobState {
Self::Pending
}
fn running() -> Self {
Self::Running(
time::OffsetDateTime::now_utc()
.unix_timestamp()
.to_be_bytes(),
)
fn running(worker_id: Uuid) -> Self {
let first_eight = time::OffsetDateTime::now_utc()
.unix_timestamp()
.to_be_bytes();
let next_sixteen = worker_id.into_bytes();
let mut bytes = [0u8; 24];
bytes[0..8]
.iter_mut()
.zip(&first_eight)
.for_each(|(dest, src)| *dest = *src);
bytes[8..24]
.iter_mut()
.zip(&next_sixteen)
.for_each(|(dest, src)| *dest = *src);
Self::Running(bytes)
}
fn as_bytes(&self) -> &[u8] {
@ -703,8 +718,12 @@ impl QueueRepo for SledRepo {
Ok(id)
}
#[tracing::instrument(skip(self))]
async fn pop(&self, queue_name: &'static str) -> Result<(JobId, Arc<[u8]>), RepoError> {
#[tracing::instrument(skip(self, worker_id), fields(job_id))]
async fn pop(
&self,
queue_name: &'static str,
worker_id: Uuid,
) -> Result<(JobId, Arc<[u8]>), RepoError> {
let metrics_guard = PopMetricsGuard::guard(queue_name);
let now = time::OffsetDateTime::now_utc();
@ -713,13 +732,15 @@ impl QueueRepo for SledRepo {
let queue = self.queue.clone();
let job_state = self.job_state.clone();
let span = tracing::Span::current();
let opt = actix_rt::task::spawn_blocking(move || {
let _guard = span.enter();
// Job IDs are generated with Uuid version 7 - defining their first bits as a
// timestamp. Scanning a prefix should give us jobs in the order they were queued.
for res in job_state.scan_prefix(queue_name) {
let (key, value) = res?;
if value.len() == 8 {
if value.len() > 8 {
let unix_timestamp =
i64::from_be_bytes(value[0..8].try_into().expect("Verified length"));
@ -734,13 +755,14 @@ impl QueueRepo for SledRepo {
}
}
let state = JobState::running();
let state = JobState::running(worker_id);
match job_state.compare_and_swap(&key, Some(value), Some(state.as_bytes())) {
Ok(_) => {
match job_state.compare_and_swap(&key, Some(value), Some(state.as_bytes()))? {
Ok(()) => {
// acquired job
}
Err(_) => {
tracing::debug!("Contested");
// someone else acquired job
continue;
}
@ -752,6 +774,8 @@ impl QueueRepo for SledRepo {
let job_id = JobId::from_bytes(id_bytes);
tracing::Span::current().record("job_id", &format!("{job_id:?}"));
let opt = queue
.get(&key)?
.map(|job_bytes| (job_id, Arc::from(job_bytes.to_vec())));
@ -790,18 +814,23 @@ impl QueueRepo for SledRepo {
}
}
#[tracing::instrument(skip(self))]
async fn heartbeat(&self, queue_name: &'static str, job_id: JobId) -> Result<(), RepoError> {
#[tracing::instrument(skip(self, worker_id))]
async fn heartbeat(
&self,
queue_name: &'static str,
worker_id: Uuid,
job_id: JobId,
) -> Result<(), RepoError> {
let key = job_key(queue_name, job_id);
let job_state = self.job_state.clone();
actix_rt::task::spawn_blocking(move || {
if let Some(state) = job_state.get(&key)? {
let new_state = JobState::running();
let new_state = JobState::running(worker_id);
match job_state.compare_and_swap(&key, Some(state), Some(new_state.as_bytes()))? {
Ok(_) => Ok(()),
Ok(()) => Ok(()),
Err(_) => Err(SledError::Conflict),
}
} else {
@ -814,8 +843,13 @@ impl QueueRepo for SledRepo {
Ok(())
}
#[tracing::instrument(skip(self))]
async fn complete_job(&self, queue_name: &'static str, job_id: JobId) -> Result<(), RepoError> {
#[tracing::instrument(skip(self, _worker_id))]
async fn complete_job(
&self,
queue_name: &'static str,
_worker_id: Uuid,
job_id: JobId,
) -> Result<(), RepoError> {
let key = job_key(queue_name, job_id);
let queue = self.queue.clone();