use crate::catch_unwind::CatchUnwindFuture; use crate::errors::{AsyncQueueError, BackieError}; use crate::runnable::BackgroundTask; use crate::store::TaskStore; use crate::task::{CurrentTask, Task, TaskState}; use crate::RetentionMode; use futures::future::FutureExt; use futures::select; use std::collections::BTreeMap; use std::future::Future; use std::pin::Pin; use std::sync::Arc; use std::time::Duration; pub type ExecuteTaskFn = Arc< dyn Fn( CurrentTask, serde_json::Value, AppData, ) -> Pin> + Send>> + Send + Sync, >; pub type StateFn = Arc AppData + Send + Sync>; #[derive(Debug, thiserror::Error)] pub enum TaskExecError { #[error("Task deserialization failed: {0}")] TaskDeserializationFailed(#[from] serde_json::Error), #[error("Task execution failed: {0}")] ExecutionFailed(#[from] anyhow::Error), #[error("Task panicked with: {0}")] Panicked(String), } pub(crate) fn runnable( task_info: CurrentTask, payload: serde_json::Value, app_context: BT::AppData, ) -> Pin> + Send>> where BT: BackgroundTask, { Box::pin(async move { let background_task: BT = serde_json::from_value(payload)?; background_task.run(task_info, app_context).await?; Ok(()) }) } /// Worker that executes tasks. pub struct Worker where AppData: Clone + Send + 'static, S: TaskStore, { store: Arc, queue_name: String, retention_mode: RetentionMode, pull_interval: Duration, task_registry: BTreeMap>, app_data_fn: StateFn, /// Notification for the worker to stop. shutdown: Option>, } impl Worker where AppData: Clone + Send + 'static, S: TaskStore, { pub(crate) fn new( store: Arc, queue_name: String, retention_mode: RetentionMode, pull_interval: Duration, task_registry: BTreeMap>, app_data_fn: StateFn, shutdown: Option>, ) -> Self { Self { store, queue_name, retention_mode, pull_interval, task_registry, app_data_fn, shutdown, } } pub(crate) async fn run_tasks(&mut self) -> Result<(), BackieError> { let registered_task_names = self.task_registry.keys().cloned().collect::>(); loop { // Check if has to stop before pulling next task if let Some(ref shutdown) = self.shutdown { if shutdown.has_changed()? { return Ok(()); } }; match self .store .pull_next_task(&self.queue_name, ®istered_task_names) .await? { Some(task) => { self.run(task).await?; } None => { // Listen to watchable future // All that until a max timeout match &mut self.shutdown { Some(recv) => { // Listen to watchable future // All that until a max timeout select! { _ = recv.changed().fuse() => { log::info!("Shutting down worker"); return Ok(()); } _ = tokio::time::sleep(self.pull_interval).fuse() => {} } } None => { tokio::time::sleep(self.pull_interval).await; } }; } }; } } async fn run(&self, task: Task) -> Result<(), BackieError> { let task_info = CurrentTask::new(&task); let runnable_task_caller = self .task_registry .get(&task.task_name) .ok_or_else(|| AsyncQueueError::TaskNotRegistered(task.task_name.clone()))?; // catch panics let result: Result<(), TaskExecError> = CatchUnwindFuture::create({ let task_payload = task.payload.clone(); let app_data = (self.app_data_fn)(); let runnable_task_caller = runnable_task_caller.clone(); async move { runnable_task_caller(task_info, task_payload, app_data).await } }) .await .and_then(|result| { result?; Ok(()) }); match &result { Ok(_) => self.finalize_task(task, result).await?, Err(error) => { if task.retries < task.max_retries { let backoff_seconds = 5; // TODO: runnable_task.backoff(task.retries as u32); log::debug!( "Task {} failed to run and will be retried in {} seconds", task.id, backoff_seconds ); let error_message = format!("{}", error); self.store .schedule_task_retry(task.id, backoff_seconds, &error_message) .await?; } else { log::debug!("Task {} failed and reached the maximum retries", task.id); self.finalize_task(task, result).await?; } } } Ok(()) } async fn finalize_task( &self, task: Task, result: Result<(), TaskExecError>, ) -> Result<(), BackieError> { match self.retention_mode { RetentionMode::KeepAll => match result { Ok(_) => { self.store.set_task_state(task.id, TaskState::Done).await?; log::debug!("Task {} done and kept in the database", task.id); } Err(error) => { log::debug!("Task {} failed and kept in the database", task.id); self.store .set_task_state(task.id, TaskState::Failed(format!("{}", error))) .await?; } }, RetentionMode::RemoveAll => { log::debug!("Task {} finalized and deleted from the database", task.id); self.store.remove_task(task.id).await?; } RetentionMode::RemoveDone => match result { Ok(_) => { log::debug!("Task {} done and deleted from the database", task.id); self.store.remove_task(task.id).await?; } Err(error) => { log::debug!("Task {} failed and kept in the database", task.id); self.store .set_task_state(task.id, TaskState::Failed(format!("{}", error))) .await?; } }, }; Ok(()) } } #[cfg(test)] mod async_worker_tests { use super::*; use async_trait::async_trait; use serde::{Deserialize, Serialize}; #[derive(thiserror::Error, Debug)] enum TaskError { #[error("Something went wrong")] SomethingWrong, #[error("{0}")] Custom(String), } #[derive(Serialize, Deserialize)] struct WorkerAsyncTask { pub number: u16, } #[async_trait] impl BackgroundTask for WorkerAsyncTask { const TASK_NAME: &'static str = "WorkerAsyncTask"; type AppData = (); async fn run(&self, _: CurrentTask, _: Self::AppData) -> Result<(), anyhow::Error> { Ok(()) } } #[derive(Serialize, Deserialize)] struct WorkerAsyncTaskSchedule { pub number: u16, } #[async_trait] impl BackgroundTask for WorkerAsyncTaskSchedule { const TASK_NAME: &'static str = "WorkerAsyncTaskSchedule"; type AppData = (); async fn run(&self, _task: CurrentTask, _data: Self::AppData) -> Result<(), anyhow::Error> { Ok(()) } // fn cron(&self) -> Option { // Some(Scheduled::ScheduleOnce(Utc::now() + Duration::seconds(1))) // } } #[derive(Serialize, Deserialize)] struct AsyncFailedTask { pub number: u16, } #[async_trait] impl BackgroundTask for AsyncFailedTask { const TASK_NAME: &'static str = "AsyncFailedTask"; type AppData = (); async fn run(&self, _task: CurrentTask, _data: Self::AppData) -> Result<(), anyhow::Error> { let message = format!("number {} is wrong :(", self.number); Err(TaskError::Custom(message).into()) } fn max_retries(&self) -> i32 { 0 } } #[derive(Serialize, Deserialize, Clone)] struct AsyncRetryTask {} #[async_trait] impl BackgroundTask for AsyncRetryTask { const TASK_NAME: &'static str = "AsyncRetryTask"; type AppData = (); async fn run(&self, _task: CurrentTask, _data: Self::AppData) -> Result<(), anyhow::Error> { Err(TaskError::SomethingWrong.into()) } } #[derive(Serialize, Deserialize)] struct AsyncTaskType1 {} #[async_trait] impl BackgroundTask for AsyncTaskType1 { const TASK_NAME: &'static str = "AsyncTaskType1"; type AppData = (); async fn run(&self, _task: CurrentTask, _data: Self::AppData) -> Result<(), anyhow::Error> { Ok(()) } } #[derive(Serialize, Deserialize)] struct AsyncTaskType2 {} #[async_trait] impl BackgroundTask for AsyncTaskType2 { const TASK_NAME: &'static str = "AsyncTaskType2"; type AppData = (); async fn run(&self, _task: CurrentTask, _data: Self::AppData) -> Result<(), anyhow::Error> { Ok(()) } } }