backie/src/worker.rs

338 lines
10 KiB
Rust
Raw Permalink Normal View History

2023-03-12 17:33:00 +00:00
use crate::catch_unwind::CatchUnwindFuture;
use crate::errors::{AsyncQueueError, BackieError};
use crate::runnable::BackgroundTask;
2023-03-11 15:38:32 +00:00
use crate::store::TaskStore;
use crate::task::{CurrentTask, Task, TaskState};
2023-03-11 16:49:23 +00:00
use crate::RetentionMode;
2023-03-09 15:59:45 +00:00
use futures::future::FutureExt;
use futures::select;
use std::collections::BTreeMap;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::time::Duration;
/// Type-erased executor stored in a worker's task registry.
///
/// It is called with the [`CurrentTask`] describing the task being run, the task's JSON
/// payload and the application data, and returns a boxed future that resolves to the
/// outcome of the execution.
pub type ExecuteTaskFn<AppData> = Arc<
    dyn Fn(
            CurrentTask,
            serde_json::Value,
            AppData,
        ) -> Pin<Box<dyn Future<Output = Result<(), TaskExecError>> + Send>>
        + Sync
        + Send,
>;

/// Factory for the application data handed to a task; the worker in this module calls it
/// once for every task it executes.
pub type StateFn<AppData> = Arc<dyn Fn() -> AppData + Sync + Send>;
2023-03-12 17:33:00 +00:00
#[derive(Debug, thiserror::Error)]
pub enum TaskExecError {
2023-03-12 17:33:00 +00:00
#[error("Task deserialization failed: {0}")]
TaskDeserializationFailed(#[from] serde_json::Error),
#[error("Task execution failed: {0}")]
ExecutionFailed(#[from] anyhow::Error),
2023-03-12 17:33:00 +00:00
#[error("Task panicked with: {0}")]
Panicked(String),
}
pub(crate) fn runnable<BT>(
task_info: CurrentTask,
payload: serde_json::Value,
app_context: BT::AppData,
) -> Pin<Box<dyn Future<Output = Result<(), TaskExecError>> + Send>>
where
BT: BackgroundTask,
{
Box::pin(async move {
let background_task: BT = serde_json::from_value(payload)?;
background_task.run(task_info, app_context).await?;
Ok(())
})
}
2023-03-04 18:07:17 +00:00
/// Worker that executes tasks.
2023-03-11 16:49:23 +00:00
pub struct Worker<AppData, S>
2023-03-04 18:07:17 +00:00
where
AppData: Clone + Send + 'static,
2023-03-11 16:49:23 +00:00
S: TaskStore,
2023-03-04 18:07:17 +00:00
{
2023-03-11 16:49:23 +00:00
store: Arc<S>,
queue_name: String,
retention_mode: RetentionMode,
pull_interval: Duration,
task_registry: BTreeMap<String, ExecuteTaskFn<AppData>>,
app_data_fn: StateFn<AppData>,
2023-03-09 15:59:45 +00:00
/// Notification for the worker to stop.
shutdown: Option<tokio::sync::watch::Receiver<()>>,
2023-03-04 18:07:17 +00:00
}
2023-03-11 16:49:23 +00:00
impl<AppData, S> Worker<AppData, S>
2023-03-04 18:07:17 +00:00
where
AppData: Clone + Send + 'static,
2023-03-11 16:49:23 +00:00
S: TaskStore,
2023-03-04 18:07:17 +00:00
{
pub(crate) fn new(
2023-03-11 16:49:23 +00:00
store: Arc<S>,
queue_name: String,
retention_mode: RetentionMode,
pull_interval: Duration,
task_registry: BTreeMap<String, ExecuteTaskFn<AppData>>,
app_data_fn: StateFn<AppData>,
shutdown: Option<tokio::sync::watch::Receiver<()>>,
) -> Self {
Self {
store,
queue_name,
retention_mode,
pull_interval,
task_registry,
app_data_fn,
shutdown,
}
}
2023-03-09 18:12:50 +00:00
pub(crate) async fn run_tasks(&mut self) -> Result<(), BackieError> {
2023-03-11 21:33:25 +00:00
let registered_task_names = self.task_registry.keys().cloned().collect::<Vec<_>>();
2023-03-09 18:12:50 +00:00
loop {
// Check if has to stop before pulling next task
if let Some(ref shutdown) = self.shutdown {
if shutdown.has_changed()? {
return Ok(());
}
};
2023-03-11 21:22:25 +00:00
match self
.store
.pull_next_task(&self.queue_name, &registered_task_names)
.await?
{
2023-03-09 18:12:50 +00:00
Some(task) => {
self.run(task).await?;
2023-03-09 18:12:50 +00:00
}
None => {
// Listen to watchable future
// All that until a max timeout
match &mut self.shutdown {
Some(recv) => {
// Listen to watchable future
// All that until a max timeout
select! {
_ = recv.changed().fuse() => {
log::info!("Shutting down worker");
return Ok(());
}
_ = tokio::time::sleep(self.pull_interval).fuse() => {}
2023-03-09 18:12:50 +00:00
}
}
None => {
tokio::time::sleep(self.pull_interval).await;
2023-03-09 18:12:50 +00:00
}
};
}
};
}
}
async fn run(&self, task: Task) -> Result<(), BackieError> {
let task_info = CurrentTask::new(&task);
let runnable_task_caller = self
.task_registry
.get(&task.task_name)
.ok_or_else(|| AsyncQueueError::TaskNotRegistered(task.task_name.clone()))?;
2023-03-09 18:12:50 +00:00
2023-03-12 17:33:00 +00:00
// catch panics
let result: Result<(), TaskExecError> = CatchUnwindFuture::create({
let task_payload = task.payload.clone();
let app_data = (self.app_data_fn)();
let runnable_task_caller = runnable_task_caller.clone();
async move { runnable_task_caller(task_info, task_payload, app_data).await }
})
.await
.and_then(|result| {
result?;
Ok(())
});
match &result {
Ok(_) => self.finalize_task(task, result).await?,
Err(error) => {
if task.retries < task.max_retries {
2023-03-11 15:38:32 +00:00
let backoff_seconds = 5; // TODO: runnable_task.backoff(task.retries as u32);
2023-03-04 18:07:17 +00:00
log::debug!(
"Task {} failed to run and will be retried in {} seconds",
task.id,
backoff_seconds
);
2023-03-12 17:33:00 +00:00
let error_message = format!("{}", error);
2023-03-12 17:33:00 +00:00
self.store
.schedule_task_retry(task.id, backoff_seconds, &error_message)
2023-03-04 18:07:17 +00:00
.await?;
} else {
log::debug!("Task {} failed and reached the maximum retries", task.id);
self.finalize_task(task, result).await?;
2023-03-04 18:07:17 +00:00
}
}
}
Ok(())
}
async fn finalize_task(
&self,
2023-03-04 18:07:17 +00:00
task: Task,
result: Result<(), TaskExecError>,
) -> Result<(), BackieError> {
2023-03-04 18:07:17 +00:00
match self.retention_mode {
RetentionMode::KeepAll => match result {
Ok(_) => {
self.store.set_task_state(task.id, TaskState::Done).await?;
log::debug!("Task {} done and kept in the database", task.id);
2023-03-04 18:07:17 +00:00
}
Err(error) => {
log::debug!("Task {} failed and kept in the database", task.id);
self.store
.set_task_state(task.id, TaskState::Failed(format!("{}", error)))
2023-03-07 16:52:26 +00:00
.await?;
2023-03-04 18:07:17 +00:00
}
},
RetentionMode::RemoveAll => {
log::debug!("Task {} finalized and deleted from the database", task.id);
self.store.remove_task(task.id).await?;
2023-03-04 18:07:17 +00:00
}
RetentionMode::RemoveDone => match result {
2023-03-04 18:07:17 +00:00
Ok(_) => {
log::debug!("Task {} done and deleted from the database", task.id);
self.store.remove_task(task.id).await?;
2023-03-04 18:07:17 +00:00
}
Err(error) => {
log::debug!("Task {} failed and kept in the database", task.id);
self.store
.set_task_state(task.id, TaskState::Failed(format!("{}", error)))
2023-03-07 16:52:26 +00:00
.await?;
2023-03-04 18:07:17 +00:00
}
},
};
Ok(())
}
}
#[cfg(test)]
mod async_worker_tests {
    use super::*;
    use async_trait::async_trait;
    use serde::{Deserialize, Serialize};

    /// Errors produced by the deliberately failing test tasks.
    #[derive(Debug, thiserror::Error)]
    enum TaskError {
        #[error("Something went wrong")]
        SomethingWrong,
        #[error("{0}")]
        Custom(String),
    }

    /// Task that always succeeds.
    #[derive(Serialize, Deserialize)]
    struct WorkerAsyncTask {
        pub number: u16,
    }

    #[async_trait]
    impl BackgroundTask for WorkerAsyncTask {
        const TASK_NAME: &'static str = "WorkerAsyncTask";
        type AppData = ();

        async fn run(&self, _task: CurrentTask, _data: Self::AppData) -> Result<(), anyhow::Error> {
            Ok(())
        }
    }

    /// Succeeding task intended for scheduling tests.
    #[derive(Serialize, Deserialize)]
    struct WorkerAsyncTaskSchedule {
        pub number: u16,
    }

    #[async_trait]
    impl BackgroundTask for WorkerAsyncTaskSchedule {
        const TASK_NAME: &'static str = "WorkerAsyncTaskSchedule";
        type AppData = ();

        async fn run(&self, _task: CurrentTask, _data: Self::AppData) -> Result<(), anyhow::Error> {
            Ok(())
        }
    }

    /// Task that always fails and allows no retries.
    #[derive(Serialize, Deserialize)]
    struct AsyncFailedTask {
        pub number: u16,
    }

    #[async_trait]
    impl BackgroundTask for AsyncFailedTask {
        const TASK_NAME: &'static str = "AsyncFailedTask";
        type AppData = ();

        async fn run(&self, _task: CurrentTask, _data: Self::AppData) -> Result<(), anyhow::Error> {
            let failure = TaskError::Custom(format!("number {} is wrong :(", self.number));
            Err(failure.into())
        }

        fn max_retries(&self) -> i32 {
            0
        }
    }

    /// Task that always fails, using the default retry settings.
    #[derive(Serialize, Deserialize, Clone)]
    struct AsyncRetryTask {}

    #[async_trait]
    impl BackgroundTask for AsyncRetryTask {
        const TASK_NAME: &'static str = "AsyncRetryTask";
        type AppData = ();

        async fn run(&self, _task: CurrentTask, _data: Self::AppData) -> Result<(), anyhow::Error> {
            Err(TaskError::SomethingWrong.into())
        }
    }

    /// First of two distinct succeeding task types.
    #[derive(Serialize, Deserialize)]
    struct AsyncTaskType1 {}

    #[async_trait]
    impl BackgroundTask for AsyncTaskType1 {
        const TASK_NAME: &'static str = "AsyncTaskType1";
        type AppData = ();

        async fn run(&self, _task: CurrentTask, _data: Self::AppData) -> Result<(), anyhow::Error> {
            Ok(())
        }
    }

    /// Second of two distinct succeeding task types.
    #[derive(Serialize, Deserialize)]
    struct AsyncTaskType2 {}

    #[async_trait]
    impl BackgroundTask for AsyncTaskType2 {
        const TASK_NAME: &'static str = "AsyncTaskType2";
        type AppData = ();

        async fn run(&self, _task: CurrentTask, _data: Self::AppData) -> Result<(), anyhow::Error> {
            Ok(())
        }
    }
}