diff --git a/crates/federate/Cargo.toml b/crates/federate/Cargo.toml index d456854c4..0c3baf348 100644 --- a/crates/federate/Cargo.toml +++ b/crates/federate/Cargo.toml @@ -19,6 +19,7 @@ lemmy_api_common.workspace = true lemmy_apub.workspace = true lemmy_db_schema = { workspace = true, features = ["full"] } lemmy_db_views_actor.workspace = true +lemmy_utils.workspace = true activitypub_federation.workspace = true anyhow.workspace = true @@ -33,6 +34,3 @@ tokio = { workspace = true, features = ["full"] } tracing.workspace = true moka.workspace = true tokio-util = "0.7.10" - -[dev-dependencies] -lemmy_utils.workspace = true \ No newline at end of file diff --git a/crates/federate/src/lib.rs b/crates/federate/src/lib.rs index d95551e34..091b819a5 100644 --- a/crates/federate/src/lib.rs +++ b/crates/federate/src/lib.rs @@ -7,9 +7,11 @@ use lemmy_db_schema::{ source::{federation_queue_state::FederationQueueState, instance::Instance}, utils::{ActualDbPool, DbPool}, }; +use lemmy_utils::error::LemmyResult; use std::{collections::HashMap, time::Duration}; use tokio::{ - sync::mpsc::{unbounded_channel, UnboundedReceiver}, + sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender}, + task::JoinHandle, time::sleep, }; use tokio_util::sync::CancellationToken; @@ -32,112 +34,116 @@ pub struct Opts { pub process_index: i32, } -async fn start_stop_federation_workers( +pub struct SendManager { opts: Opts, - pool: ActualDbPool, - federation_config: FederationConfig, - cancel: CancellationToken, -) -> anyhow::Result<()> { - let mut workers = HashMap::::new(); + workers: HashMap, + context: FederationConfig, + stats_sender: UnboundedSender<(String, FederationQueueState)>, + exit_print: JoinHandle<()>, +} - let (stats_sender, stats_receiver) = unbounded_channel(); - let exit_print = tokio::spawn(receive_print_stats(pool.clone(), stats_receiver)); - let pool2 = &mut DbPool::Pool(&pool); - let process_index = opts.process_index - 1; - let local_domain = federation_config.settings().get_hostname_without_port()?; - info!( - "Starting federation workers for process count {} and index {}", - opts.process_count, process_index - ); - loop { - let mut total_count = 0; - let mut dead_count = 0; - let mut disallowed_count = 0; - for (instance, allowed, is_dead) in - Instance::read_federated_with_blocked_and_dead(pool2).await? - { - if instance.domain == local_domain { - continue; - } - if instance.id.inner() % opts.process_count != process_index { - continue; - } - total_count += 1; - if !allowed { - disallowed_count += 1; - } - if is_dead { - dead_count += 1; - } - let should_federate = allowed && !is_dead; - if should_federate { - if workers.contains_key(&instance.id) { - // worker already running +impl SendManager { + pub fn new(opts: Opts, context: FederationConfig) -> Self { + let (stats_sender, stats_receiver) = unbounded_channel(); + Self { + opts, + workers: HashMap::new(), + stats_sender, + exit_print: tokio::spawn(receive_print_stats( + context.inner_pool().clone(), + stats_receiver, + )), + context, + } + } + + pub fn run(mut self) -> CancellableTask { + let task = CancellableTask::spawn(WORKER_EXIT_TIMEOUT, move |cancel| async move { + self.do_loop(cancel).await.unwrap(); + self.cancel().await.unwrap(); + }); + task + } + + async fn do_loop(&mut self, cancel: CancellationToken) -> LemmyResult<()> { + let process_index = self.opts.process_index - 1; + info!( + "Starting federation workers for process count {} and index {}", + self.opts.process_count, process_index + ); + let local_domain = self.context.settings().get_hostname_without_port()?; + let mut pool = self.context.pool(); + loop { + let mut total_count = 0; + let mut dead_count = 0; + let mut disallowed_count = 0; + for (instance, allowed, is_dead) in + Instance::read_federated_with_blocked_and_dead(&mut pool).await? + { + if instance.domain == local_domain { continue; } - // create new worker - let config = federation_config.clone(); - let stats_sender = stats_sender.clone(); - let pool = pool.clone(); - workers.insert( - instance.id, - CancellableTask::spawn(WORKER_EXIT_TIMEOUT, move |stop| { - let instance = instance.clone(); - let req_data = config.clone().to_request_data(); - let stats_sender = stats_sender.clone(); - let pool = pool.clone(); - async move { - InstanceWorker::init_and_loop( - instance, - req_data, - &mut DbPool::Pool(&pool), - stop, - stats_sender, - ) - .await + if instance.id.inner() % self.opts.process_count != process_index { + continue; + } + total_count += 1; + if !allowed { + disallowed_count += 1; + } + if is_dead { + dead_count += 1; + } + let should_federate = allowed && !is_dead; + if should_federate { + if self.workers.contains_key(&instance.id) { + // worker already running + continue; + } + // create new worker + let instance = instance.clone(); + let req_data = self.context.to_request_data(); + let stats_sender = self.stats_sender.clone(); + self.workers.insert( + instance.id, + CancellableTask::spawn(WORKER_EXIT_TIMEOUT, move |stop| async move { + InstanceWorker::init_and_loop(instance, req_data, stop, stats_sender).await + }), + ); + } else if !should_federate { + if let Some(worker) = self.workers.remove(&instance.id) { + if let Err(e) = worker.cancel().await { + tracing::error!("error stopping worker: {e}"); } - }), - ); - } else if !should_federate { - if let Some(worker) = workers.remove(&instance.id) { - if let Err(e) = worker.cancel().await { - tracing::error!("error stopping worker: {e}"); } } } - } - let worker_count = workers.len(); - tracing::info!("Federating to {worker_count}/{total_count} instances ({dead_count} dead, {disallowed_count} disallowed)"); - tokio::select! { - () = sleep(INSTANCES_RECHECK_DELAY) => {}, - _ = cancel.cancelled() => { break; } + let worker_count = self.workers.len(); + tracing::info!("Federating to {worker_count}/{total_count} instances ({dead_count} dead, {disallowed_count} disallowed)"); + tokio::select! { + () = sleep(INSTANCES_RECHECK_DELAY) => {}, + _ = cancel.cancelled() => { return Ok(()) } + } } } - drop(stats_sender); - tracing::warn!( - "Waiting for {} workers ({:.2?} max)", - workers.len(), - WORKER_EXIT_TIMEOUT - ); - // the cancel futures need to be awaited concurrently for the shutdown processes to be triggered concurrently - futures::future::join_all(workers.into_values().map(util::CancellableTask::cancel)).await; - exit_print.await?; - Ok(()) -} -/// starts and stops federation workers depending on which instances are on db -/// await the returned future to stop/cancel all workers gracefully -pub fn start_stop_federation_workers_cancellable( - opts: Opts, - pool: ActualDbPool, - config: FederationConfig, -) -> CancellableTask { - CancellableTask::spawn(WORKER_EXIT_TIMEOUT, move |stop| { - let opts = opts.clone(); - let pool = pool.clone(); - let config = config.clone(); - async move { start_stop_federation_workers(opts, pool, config, stop).await } - }) + pub async fn cancel(self) -> LemmyResult<()> { + drop(self.stats_sender); + tracing::warn!( + "Waiting for {} workers ({:.2?} max)", + self.workers.len(), + WORKER_EXIT_TIMEOUT + ); + // the cancel futures need to be awaited concurrently for the shutdown processes to be triggered concurrently + futures::future::join_all( + self + .workers + .into_values() + .map(util::CancellableTask::cancel), + ) + .await; + self.exit_print.await?; + Ok(()) + } } /// every 60s, print the state for every instance. exits if the receiver is done (all senders dropped) @@ -208,15 +214,12 @@ async fn print_stats(pool: &mut DbPool<'_>, stats: &HashMap LemmyResult<()> { // initialization let context = LemmyContext::init_test_context().await; let pool = &mut context.pool(); - let actual_pool = build_db_pool_for_tests().await; let opts = Opts { process_count: 1, process_index: 1, @@ -233,7 +236,15 @@ mod test { Instance::read_or_create(pool, "gamma.com".to_string()).await?, ]; - let task = start_stop_federation_workers_cancellable(opts, actual_pool, federation_config); + // start it and wait a moment + let task = SendManager::new(opts, federation_config); + task.run(); + sleep(Duration::from_secs(1)); + + // check that correct number of instance workers was started + // TODO: need to wrap in Arc or something similar + // TODO: test with different `opts`, dead/blocked instances etc + assert_eq!(3, task.workers.len()); // cleanup for i in instances { diff --git a/crates/federate/src/util.rs b/crates/federate/src/util.rs index b057ab0fe..c7f020e1d 100644 --- a/crates/federate/src/util.rs +++ b/crates/federate/src/util.rs @@ -56,23 +56,16 @@ impl CancellableTask { /// spawn a task but with graceful shutdown pub fn spawn( timeout: Duration, - task: impl Fn(CancellationToken) -> F + Send + 'static, + task: impl FnOnce(CancellationToken) -> F + Send + 'static, ) -> CancellableTask where F: Future + Send + 'static, + R: Send + 'static, { let stop = CancellationToken::new(); let stop2 = stop.clone(); - let task: JoinHandle<()> = tokio::spawn(async move { - loop { - let res = task(stop2.clone()).await; - if stop2.is_cancelled() { - return; - } else { - tracing::warn!("task exited, restarting: {res:?}"); - } - } - }); + // TODO: need to print error + let task: JoinHandle = tokio::spawn(task(stop2.clone())); let abort = task.abort_handle(); CancellableTask { f: Box::pin(async move { diff --git a/crates/federate/src/worker.rs b/crates/federate/src/worker.rs index f6701a8d1..2cc8d087e 100644 --- a/crates/federate/src/worker.rs +++ b/crates/federate/src/worker.rs @@ -22,7 +22,7 @@ use lemmy_db_schema::{ instance::{Instance, InstanceForm}, site::Site, }, - utils::{naive_now, DbPool}, + utils::naive_now, }; use lemmy_db_views_actor::structs::CommunityFollowerView; use once_cell::sync::Lazy; @@ -80,11 +80,11 @@ impl InstanceWorker { pub(crate) async fn init_and_loop( instance: Instance, context: Data, - pool: &mut DbPool<'_>, // in theory there's a ref to the pool in context, but i couldn't get that to work wrt lifetimes stop: CancellationToken, stats_sender: UnboundedSender<(String, FederationQueueState)>, ) -> Result<(), anyhow::Error> { - let state = FederationQueueState::load(pool, instance.id).await?; + let mut pool = context.pool(); + let state = FederationQueueState::load(&mut pool, instance.id).await?; let mut worker = InstanceWorker { instance, site_loaded: false, @@ -98,31 +98,28 @@ impl InstanceWorker { state, last_state_insert: Utc.timestamp_nanos(0), }; - worker.loop_until_stopped(pool).await + worker.loop_until_stopped().await } /// loop fetch new activities from db and send them to the inboxes of the given instances /// this worker only returns if (a) there is an internal error or (b) the cancellation token is cancelled (graceful exit) - pub(crate) async fn loop_until_stopped( - &mut self, - pool: &mut DbPool<'_>, - ) -> Result<(), anyhow::Error> { + pub(crate) async fn loop_until_stopped(&mut self) -> Result<(), anyhow::Error> { debug!("Starting federation worker for {}", self.instance.domain); let save_state_every = chrono::Duration::from_std(SAVE_STATE_EVERY_TIME).expect("not negative"); - self.update_communities(pool).await?; + self.update_communities().await?; self.initial_fail_sleep().await?; while !self.stop.is_cancelled() { - self.loop_batch(pool).await?; + self.loop_batch().await?; if self.stop.is_cancelled() { break; } if (Utc::now() - self.last_state_insert) > save_state_every { - self.save_and_send_state(pool).await?; + self.save_and_send_state().await?; } - self.update_communities(pool).await?; + self.update_communities().await?; } // final update of state in db - self.save_and_send_state(pool).await?; + self.save_and_send_state().await?; Ok(()) } @@ -147,8 +144,8 @@ impl InstanceWorker { Ok(()) } /// send out a batch of CHECK_SAVE_STATE_EVERY_IT activities - async fn loop_batch(&mut self, pool: &mut DbPool<'_>) -> Result<()> { - let latest_id = get_latest_activity_id(pool).await?; + async fn loop_batch(&mut self) -> Result<()> { + let latest_id = get_latest_activity_id(&mut self.context.pool()).await?; let mut id = if let Some(id) = self.state.last_successful_id { id } else { @@ -156,7 +153,7 @@ impl InstanceWorker { // skip all past activities: self.state.last_successful_id = Some(latest_id); // save here to ensure it's not read as 0 again later if no activities have happened - self.save_and_send_state(pool).await?; + self.save_and_send_state().await?; latest_id }; if id >= latest_id { @@ -174,7 +171,7 @@ impl InstanceWorker { { id = ActivityId(id.0 + 1); processed_activities += 1; - let Some(ele) = get_activity_cached(pool, id) + let Some(ele) = get_activity_cached(&mut self.context.pool(), id) .await .context("failed reading activity from db")? else { @@ -182,7 +179,7 @@ impl InstanceWorker { self.state.last_successful_id = Some(id); continue; }; - if let Err(e) = self.send_retry_loop(pool, &ele.0, &ele.1).await { + if let Err(e) = self.send_retry_loop(&ele.0, &ele.1).await { warn!( "sending {} errored internally, skipping activity: {:?}", ele.0.ap_id, e @@ -203,12 +200,11 @@ impl InstanceWorker { // and will return an error if an internal error occurred (send errors cause an infinite loop) async fn send_retry_loop( &mut self, - pool: &mut DbPool<'_>, activity: &SentActivity, object: &SharedInboxActivities, ) -> Result<()> { let inbox_urls = self - .get_inbox_urls(pool, activity) + .get_inbox_urls(activity) .await .context("failed figuring out inbox urls")?; if inbox_urls.is_empty() { @@ -220,7 +216,7 @@ impl InstanceWorker { let Some(actor_apub_id) = &activity.actor_apub_id else { return Ok(()); // activity was inserted before persistent queue was activated }; - let actor = get_actor_cached(pool, activity.actor_type, actor_apub_id) + let actor = get_actor_cached(&mut self.context.pool(), activity.actor_type, actor_apub_id) .await .context("failed getting actor instance (was it marked deleted / removed?)")?; @@ -239,7 +235,7 @@ impl InstanceWorker { "{}: retrying {:?} attempt {} with delay {retry_delay:.2?}. ({e})", self.instance.domain, activity.id, self.state.fail_count ); - self.save_and_send_state(pool).await?; + self.save_and_send_state().await?; tokio::select! { () = sleep(retry_delay) => {}, () = self.stop.cancelled() => { @@ -258,7 +254,7 @@ impl InstanceWorker { .domain(self.instance.domain.clone()) .updated(Some(naive_now())) .build(); - Instance::update(pool, self.instance.id, form).await?; + Instance::update(&mut self.context.pool(), self.instance.id, form).await?; } } Ok(()) @@ -268,16 +264,12 @@ impl InstanceWorker { /// most often this will return 0 values (if instance doesn't care about the activity) /// or 1 value (the shared inbox) /// > 1 values only happens for non-lemmy software - async fn get_inbox_urls( - &mut self, - pool: &mut DbPool<'_>, - activity: &SentActivity, - ) -> Result> { + async fn get_inbox_urls(&mut self, activity: &SentActivity) -> Result> { let mut inbox_urls: HashSet = HashSet::new(); if activity.send_all_instances { if !self.site_loaded { - self.site = Site::read_from_instance_id(pool, self.instance.id).await?; + self.site = Site::read_from_instance_id(&mut self.context.pool(), self.instance.id).await?; self.site_loaded = true; } if let Some(site) = &self.site { @@ -301,22 +293,18 @@ impl InstanceWorker { Ok(inbox_urls) } - async fn update_communities(&mut self, pool: &mut DbPool<'_>) -> Result<()> { + async fn update_communities(&mut self) -> Result<()> { if (Utc::now() - self.last_full_communities_fetch) > *FOLLOW_REMOVALS_RECHECK_DELAY { // process removals every hour (self.followed_communities, self.last_full_communities_fetch) = self - .get_communities(pool, self.instance.id, Utc.timestamp_nanos(0)) + .get_communities(self.instance.id, Utc.timestamp_nanos(0)) .await?; self.last_incremental_communities_fetch = self.last_full_communities_fetch; } if (Utc::now() - self.last_incremental_communities_fetch) > *FOLLOW_ADDITIONS_RECHECK_DELAY { // process additions every minute let (news, time) = self - .get_communities( - pool, - self.instance.id, - self.last_incremental_communities_fetch, - ) + .get_communities(self.instance.id, self.last_incremental_communities_fetch) .await?; self.followed_communities.extend(news); self.last_incremental_communities_fetch = time; @@ -327,26 +315,29 @@ impl InstanceWorker { /// get a list of local communities with the remote inboxes on the given instance that cares about them async fn get_communities( &mut self, - pool: &mut DbPool<'_>, instance_id: InstanceId, last_fetch: DateTime, ) -> Result<(HashMap>, DateTime)> { let new_last_fetch = Utc::now() - chrono::TimeDelta::try_seconds(10).expect("TimeDelta out of bounds"); // update to time before fetch to ensure overlap. subtract 10s to ensure overlap even if published date is not exact Ok(( - CommunityFollowerView::get_instance_followed_community_inboxes(pool, instance_id, last_fetch) - .await? - .into_iter() - .fold(HashMap::new(), |mut map, (c, u)| { - map.entry(c).or_default().insert(u.into()); - map - }), + CommunityFollowerView::get_instance_followed_community_inboxes( + &mut self.context.pool(), + instance_id, + last_fetch, + ) + .await? + .into_iter() + .fold(HashMap::new(), |mut map, (c, u)| { + map.entry(c).or_default().insert(u.into()); + map + }), new_last_fetch, )) } - async fn save_and_send_state(&mut self, pool: &mut DbPool<'_>) -> Result<()> { + async fn save_and_send_state(&mut self) -> Result<()> { self.last_state_insert = Utc::now(); - FederationQueueState::upsert(pool, &self.state).await?; + FederationQueueState::upsert(&mut self.context.pool(), &self.state).await?; self .stats_sender .send((self.instance.domain.clone(), self.state.clone()))?; diff --git a/src/lib.rs b/src/lib.rs index 61e8abd13..c3be5a191 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -41,7 +41,7 @@ use lemmy_apub::{ FEDERATION_HTTP_FETCH_LIMIT, }; use lemmy_db_schema::{source::secret::Secret, utils::build_db_pool}; -use lemmy_federate::{start_stop_federation_workers_cancellable, Opts}; +use lemmy_federate::{Opts, SendManager}; use lemmy_routes::{feeds, images, nodeinfo, webfinger}; use lemmy_utils::{ error::LemmyResult, @@ -206,14 +206,14 @@ pub async fn start_lemmy_server(args: CmdArgs) -> LemmyResult<()> { None }; let federate = (!args.disable_activity_sending).then(|| { - start_stop_federation_workers_cancellable( + let task = SendManager::new( Opts { process_index: args.federate_process_index, process_count: args.federate_process_count, }, - pool.clone(), - federation_config.clone(), - ) + federation_config, + ); + task.run() }); let mut interrupt = tokio::signal::unix::signal(SignalKind::interrupt())?; let mut terminate = tokio::signal::unix::signal(SignalKind::terminate())?;