Compare commits

...

2 commits

Author SHA1 Message Date
98f0951d33
Update example 2023-03-22 10:51:33 +01:00
3285647117
Allow to schedule tasks directly with pg connection 2023-03-22 10:50:52 +01:00
5 changed files with 142 additions and 51 deletions

View file

@ -1,20 +1,33 @@
use async_trait::async_trait; use async_trait::async_trait;
use backie::{BackgroundTask, CurrentTask}; use backie::{BackgroundTask, CurrentTask, QueueConfig, RetentionMode};
use backie::{PgTaskStore, Queue, WorkerPool}; use backie::{PgTaskStore, Queue, WorkerPool};
use diesel_async::pg::AsyncPgConnection; use diesel_async::pg::AsyncPgConnection;
use diesel_async::pooled_connection::{bb8::Pool, AsyncDieselConnectionManager}; use diesel_async::pooled_connection::{bb8::Pool, AsyncDieselConnectionManager};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::time::Duration; use std::collections::BTreeMap;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::Mutex;
use tokio::task::JoinSet;
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub struct MyApplicationContext { pub struct MyApplicationContext {
app_name: String, app_name: String,
notify_finished: Arc<Mutex<Option<tokio::sync::oneshot::Sender<()>>>>,
} }
impl MyApplicationContext { impl MyApplicationContext {
pub fn new(app_name: &str) -> Self { pub fn new(app_name: &str, notify_finished: tokio::sync::oneshot::Sender<()>) -> Self {
Self { Self {
app_name: app_name.to_string(), app_name: app_name.to_string(),
notify_finished: Arc::new(Mutex::new(Some(notify_finished))),
}
}
pub async fn notify_finished(&self) {
let mut lock = self.notify_finished.lock().await;
if let Some(sender) = lock.take() {
sender.send(()).unwrap();
} }
} }
} }
@ -37,12 +50,6 @@ impl BackgroundTask for MyTask {
type Error = anyhow::Error; type Error = anyhow::Error;
async fn run(&self, task: CurrentTask, ctx: Self::AppData) -> Result<(), Self::Error> { async fn run(&self, task: CurrentTask, ctx: Self::AppData) -> Result<(), Self::Error> {
// let new_task = MyTask::new(self.number + 1);
// queue
// .insert_task(&new_task)
// .await
// .unwrap();
log::info!( log::info!(
"[{}] Hello from {}! the current number is {}", "[{}] Hello from {}! the current number is {}",
task.id(), task.id(),
@ -74,17 +81,6 @@ impl BackgroundTask for MyFailingTask {
type Error = anyhow::Error; type Error = anyhow::Error;
async fn run(&self, task: CurrentTask, _ctx: Self::AppData) -> Result<(), Self::Error> { async fn run(&self, task: CurrentTask, _ctx: Self::AppData) -> Result<(), Self::Error> {
// let new_task = MyFailingTask::new(self.number + 1);
// queue
// .insert_task(&new_task)
// .await
// .unwrap();
// task.id();
// task.keep_alive().await?;
// task.previous_error();
// task.retry_count();
log::info!("[{}] the current number is {}", task.id(), self.number); log::info!("[{}] the current number is {}", task.id(), self.number);
tokio::time::sleep(Duration::from_secs(3)).await; tokio::time::sleep(Duration::from_secs(3)).await;
@ -93,58 +89,124 @@ impl BackgroundTask for MyFailingTask {
} }
} }
#[derive(Serialize, Deserialize)]
struct EmptyTask {
pub idx: u64,
}
#[async_trait]
impl BackgroundTask for EmptyTask {
const TASK_NAME: &'static str = "empty_task";
const QUEUE: &'static str = "loaded_queue";
type AppData = MyApplicationContext;
type Error = anyhow::Error;
async fn run(&self, _task: CurrentTask, _ctx: Self::AppData) -> Result<(), Self::Error> {
Ok(())
}
}
#[derive(Serialize, Deserialize)]
struct FinalTask;
#[async_trait]
impl BackgroundTask for FinalTask {
const TASK_NAME: &'static str = "final_task";
const QUEUE: &'static str = "loaded_queue";
type AppData = MyApplicationContext;
type Error = anyhow::Error;
async fn run(&self, _task: CurrentTask, ctx: Self::AppData) -> Result<(), Self::Error> {
ctx.notify_finished().await;
Ok(())
}
}
#[tokio::main] #[tokio::main]
async fn main() { async fn main() -> anyhow::Result<()> {
env_logger::init(); env_logger::init();
let connection_url = "postgres://postgres:password@localhost/backie"; let connection_url = "postgres://postgres:password@localhost/backie";
log::info!("Starting..."); log::info!("Starting...");
let max_pool_size: u32 = 3;
let manager = AsyncDieselConnectionManager::<AsyncPgConnection>::new(connection_url); let manager = AsyncDieselConnectionManager::<AsyncPgConnection>::new(connection_url);
let pool = Pool::builder() let pool = Pool::builder()
.max_size(max_pool_size) .max_size(300)
.min_idle(Some(1)) .min_idle(Some(1))
.build(manager) .build(manager)
.await .await
.unwrap(); .unwrap();
log::info!("Pool created ..."); log::info!("Pool created ...");
let task_store = PgTaskStore::new(pool);
let (tx, mut rx) = tokio::sync::watch::channel(false); let (tx, mut rx) = tokio::sync::watch::channel(false);
let (notify_finished, wait_done) = tokio::sync::oneshot::channel();
// Some global application context I want to pass to my background tasks // Some global application context I want to pass to my background tasks
let my_app_context = MyApplicationContext::new("Backie Example App"); let my_app_context = MyApplicationContext::new("Backie Example App", notify_finished);
// queue.enqueue(task1).await.unwrap();
// queue.enqueue(task2).await.unwrap();
// queue.enqueue(task3).await.unwrap();
// Store all task to join them later
let mut tasks = JoinSet::new();
for i in 0..1_000 {
tasks.spawn({
let pool = pool.clone();
async move {
let mut connection = pool.get().await.unwrap();
let task = EmptyTask { idx: i };
task.enqueue(&mut connection).await.unwrap();
}
});
}
while let Some(result) = tasks.join_next().await {
let _ = result?;
}
(FinalTask {})
.enqueue(&mut pool.get().await.unwrap())
.await
.unwrap();
log::info!("Tasks created ...");
let started = Instant::now();
// Register the task types I want to use and start the worker pool // Register the task types I want to use and start the worker pool
let join_handle = let join_handle = WorkerPool::new(PgTaskStore::new(pool.clone()), move || my_app_context.clone())
WorkerPool::new(task_store.clone(), move || my_app_context.clone()) .register_task_type::<MyTask>()
.register_task_type::<MyTask>() .register_task_type::<MyFailingTask>()
.register_task_type::<MyFailingTask>() .register_task_type::<EmptyTask>()
.configure_queue("default".into()) .register_task_type::<FinalTask>()
.start(async move { .configure_queue("default".into())
let _ = rx.changed().await; .configure_queue(
}) QueueConfig::new("loaded_queue")
.await .pull_interval(Duration::from_millis(100))
.unwrap(); .retention_mode(RetentionMode::RemoveDone)
.num_workers(300),
)
.start(async move {
let _ = rx.changed().await;
})
.await
.unwrap();
log::info!("Workers started ..."); log::info!("Workers started ...");
let task1 = MyTask::new(0); wait_done.await.unwrap();
let task2 = MyTask::new(20_000); let elapsed = started.elapsed();
let task3 = MyFailingTask::new(50_000); println!("Ran 50k jobs in {} seconds", elapsed.as_secs());
let queue = Queue::new(task_store);
queue.enqueue(task1).await.unwrap();
queue.enqueue(task2).await.unwrap();
queue.enqueue(task3).await.unwrap();
log::info!("Tasks created ...");
// Wait for Ctrl+C // Wait for Ctrl+C
let _ = tokio::signal::ctrl_c().await; // let _ = tokio::signal::ctrl_c().await;
log::info!("Stopping ..."); log::info!("Stopping ...");
tx.send(true).unwrap(); tx.send(true).unwrap();
join_handle.await.unwrap(); join_handle.await.unwrap();
log::info!("Workers Stopped!"); log::info!("Workers Stopped!");
Ok(())
} }

View file

@ -7,7 +7,10 @@ use chrono::Duration;
use chrono::Utc; use chrono::Utc;
use diesel::prelude::*; use diesel::prelude::*;
use diesel::ExpressionMethods; use diesel::ExpressionMethods;
use diesel_async::{pg::AsyncPgConnection, RunQueryDsl}; use diesel::query_builder::{Query, QueryFragment, QueryId};
use diesel::sql_types::{HasSqlType, SingleValue};
use diesel_async::return_futures::GetResult;
use diesel_async::{pg::AsyncPgConnection, AsyncConnection, RunQueryDsl};
impl Task { impl Task {
pub(crate) async fn remove( pub(crate) async fn remove(

View file

@ -25,7 +25,7 @@ where
{ {
// TODO: Add option to specify the timeout of a task // TODO: Add option to specify the timeout of a task
self.task_store self.task_store
.create_task(NewTask::new(background_task, Duration::from_secs(10))?) .create_task(NewTask::with_timeout(background_task, Duration::from_secs(10))?)
.await?; .await?;
Ok(()) Ok(())
} }

View file

@ -1,5 +1,7 @@
use crate::errors::AsyncQueueError; use crate::errors::AsyncQueueError;
use crate::task::{NewTask, Task, TaskId, TaskState}; use crate::task::{NewTask, Task, TaskId, TaskState};
use crate::BackgroundTask;
use async_trait::async_trait;
use diesel::result::Error::QueryBuilderError; use diesel::result::Error::QueryBuilderError;
use diesel_async::scoped_futures::ScopedFutureExt; use diesel_async::scoped_futures::ScopedFutureExt;
use diesel_async::AsyncConnection; use diesel_async::AsyncConnection;
@ -17,6 +19,23 @@ impl PgTaskStore {
} }
} }
/// A trait that is used to enqueue tasks for the PostgreSQL backend.
#[async_trait::async_trait]
pub trait PgQueueTask {
async fn enqueue(self, connection: &mut AsyncPgConnection) -> Result<(), AsyncQueueError>;
}
impl<T> PgQueueTask for T
where
T: BackgroundTask,
{
async fn enqueue(self, connection: &mut AsyncPgConnection) -> Result<(), AsyncQueueError> {
let new_task = NewTask::new::<T>(self)?;
Task::insert(connection, new_task).await?;
Ok(())
}
}
#[async_trait::async_trait] #[async_trait::async_trait]
impl TaskStore for PgTaskStore { impl TaskStore for PgTaskStore {
async fn pull_next_task( async fn pull_next_task(

View file

@ -117,9 +117,9 @@ pub struct NewTask {
} }
impl NewTask { impl NewTask {
pub(crate) fn new<T>(background_task: T, timeout: Duration) -> Result<Self, serde_json::Error> pub(crate) fn with_timeout<T>(background_task: T, timeout: Duration) -> Result<Self, serde_json::Error>
where where
T: BackgroundTask, T: BackgroundTask,
{ {
let max_retries = background_task.max_retries(); let max_retries = background_task.max_retries();
let uniq_hash = background_task.uniq(); let uniq_hash = background_task.uniq();
@ -134,6 +134,13 @@ impl NewTask {
max_retries, max_retries,
}) })
} }
pub(crate) fn new<T>(background_task: T) -> Result<Self, serde_json::Error>
where
T: BackgroundTask,
{
Self::with_timeout(background_task, Duration::from_secs(120))
}
} }
#[cfg(test)] #[cfg(test)]