diff --git a/.gitignore b/.gitignore index 181f3ef..7785759 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,5 @@ /target /.idea /Cargo.lock +perf.data* +flamegraph.svg \ No newline at end of file diff --git a/Cargo.toml b/Cargo.toml index 9f4d2e1..37471ab 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,26 +23,37 @@ openssl = "0.10.54" once_cell = "1.18.0" http = "0.2.9" sha2 = "0.10.6" -background-jobs = "0.13.0" thiserror = "1.0.40" derive_builder = "0.12.0" itertools = "0.10.5" dyn-clone = "1.0.11" enum_delegate = "0.2.0" httpdate = "1.0.2" -http-signature-normalization-reqwest = { version = "0.8.0", default-features = false, features = ["sha-2", "middleware"] } +http-signature-normalization-reqwest = { version = "0.8.0", default-features = false, features = [ + "sha-2", + "middleware", +] } http-signature-normalization = "0.7.0" bytes = "1.4.0" futures-core = { version = "0.3.28", default-features = false } pin-project-lite = "0.2.9" activitystreams-kinds = "0.3.0" regex = { version = "1.8.4", default-features = false, features = ["std"] } +tokio = { version = "1.21.2", features = [ + "sync", + "rt", + "rt-multi-thread", + "time", +] } # Actix-web actix-web = { version = "4.3.1", default-features = false, optional = true } # Axum -axum = { version = "0.6.18", features = ["json", "headers"], default-features = false, optional = true } +axum = { version = "0.6.18", features = [ + "json", + "headers", +], default-features = false, optional = true } tower = { version = "0.4.13", optional = true } hyper = { version = "0.14", optional = true } displaydoc = "0.2.4" @@ -56,9 +67,13 @@ axum = ["dep:axum", "dep:tower", "dep:hyper"] rand = "0.8.5" env_logger = "0.10.0" tower-http = { version = "0.4.0", features = ["map-request-body", "util"] } -axum = { version = "0.6.18", features = ["http1", "tokio", "query"], default-features = false } +axum = { version = "0.6.18", features = [ + "http1", + "tokio", + "query", +], default-features = false } axum-macros = "0.3.7" -actix-rt = "2.8.0" +tokio = { version = "1.21.2", features = ["full"] } [profile.dev] strip = "symbols" diff --git a/docs/05_configuration.md b/docs/05_configuration.md index 4c67cc6..315bc66 100644 --- a/docs/05_configuration.md +++ b/docs/05_configuration.md @@ -5,12 +5,13 @@ Next we need to do some configuration. Most importantly we need to specify the d ``` # use activitypub_federation::config::FederationConfig; # let db_connection = (); -# let _ = actix_rt::System::new(); +# tokio::runtime::Runtime::new().unwrap().block_on(async { let config = FederationConfig::builder() .domain("example.com") .app_data(db_connection) - .build()?; + .build().await?; # Ok::<(), anyhow::Error>(()) +# }).unwrap() ``` `debug` is necessary to test federation with http and localhost URLs, but it should never be used in production. The `worker_count` value can be adjusted depending on the instance size. A lower value saves resources on a small instance, while a higher value is necessary on larger instances to keep up with send jobs. `url_verifier` can be used to implement a domain blacklist. \ No newline at end of file diff --git a/docs/06_http_endpoints_axum.md b/docs/06_http_endpoints_axum.md index d2eea49..8ebbcc8 100644 --- a/docs/06_http_endpoints_axum.md +++ b/docs/06_http_endpoints_axum.md @@ -22,12 +22,12 @@ The next step is to allow other servers to fetch our actors and objects. For thi # use http::HeaderMap; # async fn generate_user_html(_: String, _: Data) -> axum::response::Response { todo!() } -#[actix_rt::main] +#[tokio::main] async fn main() -> Result<(), Error> { let data = FederationConfig::builder() .domain("example.com") .app_data(DbConnection) - .build()?; + .build().await?; let app = axum::Router::new() .route("/user/:name", get(http_get_user)) diff --git a/docs/07_fetching_data.md b/docs/07_fetching_data.md index caab27f..05eb7ed 100644 --- a/docs/07_fetching_data.md +++ b/docs/07_fetching_data.md @@ -7,18 +7,17 @@ After setting up our structs, implementing traits and initializing configuration # use activitypub_federation::traits::tests::DbUser; # use activitypub_federation::config::FederationConfig; # let db_connection = activitypub_federation::traits::tests::DbConnection; -# let _ = actix_rt::System::new(); -# actix_rt::Runtime::new().unwrap().block_on(async { +# tokio::runtime::Runtime::new().unwrap().block_on(async { let config = FederationConfig::builder() .domain("example.com") .app_data(db_connection) - .build()?; + .build().await?; let user_id = ObjectId::::parse("https://mastodon.social/@LemmyDev")?; let data = config.to_request_data(); let user = user_id.dereference(&data).await; assert!(user.is_ok()); # Ok::<(), anyhow::Error>(()) -}).unwrap() +# }).unwrap() ``` `dereference` retrieves the object JSON at the given URL, and uses serde to convert it to `Person`. It then calls your method `Object::from_json` which inserts it in the database and returns a `DbUser` struct. `request_data` contains the federation config as well as a counter of outgoing HTTP requests. If this counter exceeds the configured maximum, further requests are aborted in order to avoid recursive fetching which could allow for a denial of service attack. @@ -32,9 +31,8 @@ We can similarly dereference a user over webfinger with the following method. It # use activitypub_federation::fetch::webfinger::webfinger_resolve_actor; # use activitypub_federation::traits::tests::DbUser; # let db_connection = DbConnection; -# let _ = actix_rt::System::new(); -# actix_rt::Runtime::new().unwrap().block_on(async { -# let config = FederationConfig::builder().domain("example.com").app_data(db_connection).build()?; +# tokio::runtime::Runtime::new().unwrap().block_on(async { +# let config = FederationConfig::builder().domain("example.com").app_data(db_connection).build().await?; # let data = config.to_request_data(); let user: DbUser = webfinger_resolve_actor("nutomic@lemmy.ml", &data).await?; # Ok::<(), anyhow::Error>(()) diff --git a/docs/09_sending_activities.md b/docs/09_sending_activities.md index a7262ea..649ec17 100644 --- a/docs/09_sending_activities.md +++ b/docs/09_sending_activities.md @@ -9,13 +9,12 @@ To send an activity we need to initialize our previously defined struct, and pic # use activitypub_federation::traits::Actor; # use activitypub_federation::fetch::object_id::ObjectId; # use activitypub_federation::traits::tests::{DB_USER, DbConnection, Follow}; -# let _ = actix_rt::System::new(); -# actix_rt::Runtime::new().unwrap().block_on(async { +# tokio::runtime::Runtime::new().unwrap().block_on(async { # let db_connection = DbConnection; # let config = FederationConfig::builder() # .domain("example.com") # .app_data(db_connection) -# .build()?; +# .build().await?; # let data = config.to_request_data(); # let sender = DB_USER.clone(); # let recipient = DB_USER.clone(); diff --git a/docs/10_fetching_objects_with_unknown_type.md b/docs/10_fetching_objects_with_unknown_type.md index ffabb7a..596c381 100644 --- a/docs/10_fetching_objects_with_unknown_type.md +++ b/docs/10_fetching_objects_with_unknown_type.md @@ -61,9 +61,9 @@ impl Object for SearchableDbObjects { } } -#[actix_rt::main] +#[tokio::main] async fn main() -> Result<(), anyhow::Error> { - # let config = FederationConfig::builder().domain("example.com").app_data(DbConnection).build().unwrap(); + # let config = FederationConfig::builder().domain("example.com").app_data(DbConnection).build().await.unwrap(); # let data = config.to_request_data(); let query = "https://example.com/id/413"; let query_result = ObjectId::::parse(query)? diff --git a/examples/live_federation/main.rs b/examples/live_federation/main.rs index 134c93b..4326226 100644 --- a/examples/live_federation/main.rs +++ b/examples/live_federation/main.rs @@ -28,7 +28,7 @@ const DOMAIN: &str = "example.com"; const LOCAL_USER_NAME: &str = "alison"; const BIND_ADDRESS: &str = "localhost:8003"; -#[actix_rt::main] +#[tokio::main] async fn main() -> Result<(), Error> { env_logger::builder() .filter_level(LevelFilter::Warn) @@ -47,7 +47,8 @@ async fn main() -> Result<(), Error> { let config = FederationConfig::builder() .domain(DOMAIN) .app_data(database) - .build()?; + .build() + .await?; info!("Listen with HTTP server on {BIND_ADDRESS}"); let config = config.clone(); diff --git a/examples/local_federation/actix_web/http.rs b/examples/local_federation/actix_web/http.rs index c8a3e04..12a750f 100644 --- a/examples/local_federation/actix_web/http.rs +++ b/examples/local_federation/actix_web/http.rs @@ -30,7 +30,7 @@ pub fn listen(config: &FederationConfig) -> Result<(), Error> { }) .bind(hostname)? .run(); - actix_rt::spawn(server); + tokio::spawn(server); Ok(()) } diff --git a/examples/local_federation/axum/http.rs b/examples/local_federation/axum/http.rs index c67df08..3202117 100644 --- a/examples/local_federation/axum/http.rs +++ b/examples/local_federation/axum/http.rs @@ -41,7 +41,7 @@ pub fn listen(config: &FederationConfig) -> Result<(), Error> { .expect("Failed to lookup domain name"); let server = axum::Server::bind(&addr).serve(app.into_make_service()); - actix_rt::spawn(server); + tokio::spawn(server); Ok(()) } diff --git a/examples/local_federation/instance.rs b/examples/local_federation/instance.rs index 2c64f5f..51311e3 100644 --- a/examples/local_federation/instance.rs +++ b/examples/local_federation/instance.rs @@ -11,7 +11,7 @@ use std::{ }; use url::Url; -pub fn new_instance( +pub async fn new_instance( hostname: &str, name: String, ) -> Result, Error> { @@ -29,7 +29,8 @@ pub fn new_instance( .signed_fetch_actor(&system_user) .app_data(database) .debug(true) - .build()?; + .build() + .await?; Ok(config) } diff --git a/examples/local_federation/main.rs b/examples/local_federation/main.rs index 68f7324..1597668 100644 --- a/examples/local_federation/main.rs +++ b/examples/local_federation/main.rs @@ -17,7 +17,7 @@ mod instance; mod objects; mod utils; -#[actix_rt::main] +#[tokio::main] async fn main() -> Result<(), Error> { env_logger::builder() .filter_level(LevelFilter::Warn) @@ -32,8 +32,8 @@ async fn main() -> Result<(), Error> { .map(|arg| Webserver::from_str(&arg).unwrap()) .unwrap_or(Webserver::Axum); - let alpha = new_instance("localhost:8001", "alpha".to_string())?; - let beta = new_instance("localhost:8002", "beta".to_string())?; + let alpha = new_instance("localhost:8001", "alpha".to_string()).await?; + let beta = new_instance("localhost:8002", "beta".to_string()).await?; listen(&alpha, &webserver)?; listen(&beta, &webserver)?; info!("Local instances started"); diff --git a/src/activity_queue.rs b/src/activity_queue.rs index abf5f4c..f470ed9 100644 --- a/src/activity_queue.rs +++ b/src/activity_queue.rs @@ -11,25 +11,28 @@ use crate::{ FEDERATION_CONTENT_TYPE, }; use anyhow::anyhow; -use background_jobs::{ - memory_storage::{ActixTimer, Storage}, - ActixJob, - Backoff, - Manager, - MaxRetries, - WorkerConfig, -}; + +use bytes::Bytes; +use futures_core::Future; use http::{header::HeaderName, HeaderMap, HeaderValue}; use httpdate::fmt_http_date; use itertools::Itertools; +use openssl::pkey::{PKey, Private}; +use reqwest::Request; use reqwest_middleware::ClientWithMiddleware; -use serde::{Deserialize, Serialize}; +use serde::Serialize; use std::{ - fmt::Debug, - future::Future, - pin::Pin, + fmt::{Debug, Display}, + sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, + }, time::{Duration, SystemTime}, }; +use tokio::{ + sync::mpsc::{unbounded_channel, UnboundedSender}, + task::{JoinHandle, JoinSet}, +}; use tracing::{debug, info, warn}; use url::Url; @@ -56,10 +59,19 @@ where let config = &data.config; let actor_id = activity.actor(); let activity_id = activity.id(); - let activity_serialized = serde_json::to_string_pretty(&activity)?; - let private_key = actor + let activity_serialized: Bytes = serde_json::to_vec(&activity)?.into(); + let private_key_pem = actor .private_key_pem() - .expect("Actor for sending activity has private key"); + .ok_or_else(|| anyhow!("Actor {actor_id} does not contain a private key for signing"))?; + + // This is a mostly expensive blocking call, we don't want to tie up other tasks while this is happening + let private_key = tokio::task::spawn_blocking(move || { + PKey::private_key_from_pem(private_key_pem.as_bytes()) + .map_err(|err| anyhow!("Could not create private key from PEM data:{err}")) + }) + .await + .map_err(|err| anyhow!("Error joining:{err}"))??; + let inboxes: Vec = inboxes .into_iter() .unique() @@ -84,25 +96,32 @@ where private_key: private_key.clone(), http_signature_compat: config.http_signature_compat, }; + + // Don't use the activity queue if this is in debug mode, send and wait directly if config.debug { - let res = do_send(message, &config.client, config.request_timeout).await; - // Don't fail on error, as we intentionally do some invalid actions in tests, to verify that - // they are rejected on the receiving side. These errors shouldn't bubble up to make the API - // call fail. This matches the behaviour in production. - if let Err(e) = res { - warn!("{}", e); + if let Err(err) = sign_and_send( + &message, + &config.client, + config.request_timeout, + Default::default(), + ) + .await + { + warn!("{err}"); } } else { activity_queue.queue(message).await?; - let stats = activity_queue.get_stats().await?; + let stats = activity_queue.get_stats(); + let running = stats.running.load(Ordering::Relaxed); let stats_fmt = format!( - "Activity queue stats: pending: {}, running: {}, dead (this hour): {}, complete (this hour): {}", - stats.pending, - stats.running, - stats.dead.this_hour(), - stats.complete.this_hour() - ); - if stats.running as u64 == config.worker_count { + "Activity queue stats: pending: {}, running: {}, retries: {}, dead: {}, complete: {}", + stats.pending.load(Ordering::Relaxed), + running, + stats.retries.load(Ordering::Relaxed), + stats.dead_last_hour.load(Ordering::Relaxed), + stats.completed_last_hour.load(Ordering::Relaxed), + ); + if running == config.worker_count { warn!("Reached max number of send activity workers ({}). Consider increasing worker count to avoid federation delays", config.worker_count); warn!(stats_fmt); } else { @@ -114,56 +133,66 @@ where Ok(()) } -#[derive(Clone, Debug, Deserialize, Serialize)] +#[derive(Clone, Debug)] struct SendActivityTask { actor_id: Url, activity_id: Url, - activity: String, + activity: Bytes, inbox: Url, - private_key: String, + private_key: PKey, http_signature_compat: bool, } -impl ActixJob for SendActivityTask { - type State = QueueState; - type Future = Pin>>>; - const NAME: &'static str = "SendActivityTask"; - - const MAX_RETRIES: MaxRetries = MaxRetries::Count(3); - /// This gives the following retry intervals: - /// - 60s (one minute, for service restart) - /// - 60min (one hour, for instance maintenance) - /// - 60h (2.5 days, for major incident with rebuild from backup) - const BACKOFF: Backoff = Backoff::Exponential(60); - - fn run(self, state: Self::State) -> Self::Future { - Box::pin(async move { do_send(self, &state.client, state.timeout).await }) - } -} - -async fn do_send( - task: SendActivityTask, +async fn sign_and_send( + task: &SendActivityTask, client: &ClientWithMiddleware, timeout: Duration, + retry_strategy: RetryStrategy, ) -> Result<(), anyhow::Error> { - debug!("Sending {} to {}", task.activity_id, task.inbox); + debug!( + "Sending {} to {}, contents:\n {}", + task.activity_id, + task.inbox, + serde_json::from_slice::(&task.activity)? + ); let request_builder = client .post(task.inbox.to_string()) .timeout(timeout) .headers(generate_request_headers(&task.inbox)); let request = sign_request( request_builder, - task.actor_id, - task.activity, - task.private_key, + &task.actor_id, + task.activity.clone(), + task.private_key.clone(), task.http_signature_compat, ) .await?; + + retry( + || { + send( + task, + client, + request + .try_clone() + .expect("The body of the request is not cloneable"), + ) + }, + retry_strategy, + ) + .await +} + +async fn send( + task: &SendActivityTask, + client: &ClientWithMiddleware, + request: Request, +) -> Result<(), anyhow::Error> { let response = client.execute(request).await; match response { Ok(o) if o.status().is_success() => { - info!( + debug!( "Activity {} delivered successfully to {}", task.activity_id, task.inbox ); @@ -171,7 +200,7 @@ async fn do_send( } Ok(o) if o.status().is_client_error() => { let text = o.text_limited().await.map_err(Error::other)?; - info!( + debug!( "Activity {} was rejected by {}, aborting: {}", task.activity_id, task.inbox, text, ); @@ -189,7 +218,7 @@ async fn do_send( )) } Err(e) => { - info!( + debug!( "Unable to connect to {}, aborting task {}: {}", task.inbox, task.activity_id, e ); @@ -220,27 +249,401 @@ pub(crate) fn generate_request_headers(inbox_url: &Url) -> HeaderMap { headers } -pub(crate) fn create_activity_queue( - client: ClientWithMiddleware, - worker_count: u64, - request_timeout: Duration, - debug: bool, -) -> Manager { - // queue is not used in debug mod, so dont create any workers to avoid log spam - let worker_count = if debug { 0 } else { worker_count }; - - // Configure and start our workers - WorkerConfig::new_managed(Storage::new(ActixTimer), move |_| QueueState { - client: client.clone(), - timeout: request_timeout, - }) - .register::() - .set_worker_count("default", worker_count) - .start() +/// A simple activity queue which spawns tokio workers to send out requests +/// When creating a queue, it will spawn a task per worker thread +/// Uses an unbounded mpsc queue for communication (i.e, all messages are in memory) +pub(crate) struct ActivityQueue { + // Stats shared between the queue and workers + stats: Arc, + sender: UnboundedSender, + sender_task: JoinHandle<()>, + retry_sender_task: JoinHandle<()>, } -#[derive(Clone)] -struct QueueState { +/// Simple stat counter to show where we're up to with sending messages +/// This is a lock-free way to share things between tasks +/// When reading these values it's possible (but extremely unlikely) to get stale data if a worker task is in the middle of transitioning +#[derive(Default)] +struct Stats { + pending: AtomicUsize, + running: AtomicUsize, + retries: AtomicUsize, + dead_last_hour: AtomicUsize, + completed_last_hour: AtomicUsize, +} + +#[derive(Clone, Copy, Default)] +struct RetryStrategy { + /// Amount of time in seconds to back off + backoff: usize, + /// Amount of times to retry + retries: usize, + /// If this particular request has already been retried, you can add an offset here to increment the count to start + offset: usize, + /// Number of seconds to sleep before trying + initial_sleep: usize, +} + +/// A tokio spawned worker which is responsible for submitting requests to federated servers +/// This will retry up to one time with the same signature, and if it fails, will move it to the retry queue. +/// We need to retry activity sending in case the target instances is temporarily unreachable. +/// In this case, the task is stored and resent when the instance is hopefully back up. This +/// list shows the retry intervals, and which events of the target instance can be covered: +/// - 60s (one minute, service restart) -- happens in the worker w/ same signature +/// - 60min (one hour, instance maintenance) --- happens in the retry worker +/// - 60h (2.5 days, major incident with rebuild from backup) --- happens in the retry worker +async fn worker( client: ClientWithMiddleware, timeout: Duration, + message: SendActivityTask, + retry_queue: UnboundedSender, + stats: Arc, + strategy: RetryStrategy, +) { + stats.pending.fetch_sub(1, Ordering::Relaxed); + stats.running.fetch_add(1, Ordering::Relaxed); + + let outcome = sign_and_send(&message, &client, timeout, strategy).await; + + // "Running" has finished, check the outcome + stats.running.fetch_sub(1, Ordering::Relaxed); + + match outcome { + Ok(_) => { + stats.completed_last_hour.fetch_add(1, Ordering::Relaxed); + } + Err(_err) => { + stats.retries.fetch_add(1, Ordering::Relaxed); + warn!( + "Sending activity {} to {} to the retry queue to be tried again later", + message.activity_id, message.inbox + ); + // Send to the retry queue. Ignoring whether it succeeds or not + retry_queue.send(message).ok(); + } + } +} + +async fn retry_worker( + client: ClientWithMiddleware, + timeout: Duration, + message: SendActivityTask, + stats: Arc, + strategy: RetryStrategy, +) { + // Because the times are pretty extravagant between retries, we have to re-sign each time + let outcome = retry( + || { + sign_and_send( + &message, + &client, + timeout, + RetryStrategy { + backoff: 0, + retries: 0, + offset: 0, + initial_sleep: 0, + }, + ) + }, + strategy, + ) + .await; + + stats.retries.fetch_sub(1, Ordering::Relaxed); + + match outcome { + Ok(_) => { + stats.completed_last_hour.fetch_add(1, Ordering::Relaxed); + } + Err(_err) => { + stats.dead_last_hour.fetch_add(1, Ordering::Relaxed); + } + } +} + +impl ActivityQueue { + fn new( + client: ClientWithMiddleware, + worker_count: usize, + retry_count: usize, + timeout: Duration, + backoff: usize, // This should be 60 seconds by default or 1 second in tests + ) -> Self { + let stats: Arc = Default::default(); + + // This task clears the dead/completed stats every hour + let hour_stats = stats.clone(); + tokio::spawn(async move { + let duration = Duration::from_secs(3600); + loop { + tokio::time::sleep(duration).await; + hour_stats.completed_last_hour.store(0, Ordering::Relaxed); + hour_stats.dead_last_hour.store(0, Ordering::Relaxed); + } + }); + + let (retry_sender, mut retry_receiver) = unbounded_channel(); + let retry_stats = stats.clone(); + let retry_client = client.clone(); + + // The "fast path" retry + // The backoff should be < 5 mins for this to work otherwise signatures may expire + // This strategy is the one that is used with the *same* signature + let strategy = RetryStrategy { + backoff, + retries: 1, + offset: 0, + initial_sleep: 0, + }; + + // The "retry path" strategy + // After the fast path fails, a task will sleep up to backoff ^ 2 and then retry again + let retry_strategy = RetryStrategy { + backoff, + retries: 3, + offset: 2, + initial_sleep: backoff.pow(2), // wait 60 mins before even trying + }; + + let retry_sender_task = tokio::spawn(async move { + let mut join_set = JoinSet::new(); + + while let Some(message) = retry_receiver.recv().await { + // If we're over the limit of retries, wait for them to finish before spawning + while join_set.len() >= retry_count { + join_set.join_next().await; + } + + join_set.spawn(retry_worker( + retry_client.clone(), + timeout, + message, + retry_stats.clone(), + retry_strategy, + )); + } + + while !join_set.is_empty() { + join_set.join_next().await; + } + }); + + let (sender, mut receiver) = unbounded_channel(); + + let sender_stats = stats.clone(); + + let sender_task = tokio::spawn(async move { + let mut join_set = JoinSet::new(); + + while let Some(message) = receiver.recv().await { + // If we're over the limit of workers, wait for them to finish before spawning + while join_set.len() >= worker_count { + join_set.join_next().await; + } + + join_set.spawn(worker( + client.clone(), + timeout, + message, + retry_sender.clone(), + sender_stats.clone(), + strategy, + )); + } + + drop(retry_sender); + + while !join_set.is_empty() { + join_set.join_next().await; + } + }); + + Self { + stats, + sender, + sender_task, + retry_sender_task, + } + } + + async fn queue(&self, message: SendActivityTask) -> Result<(), anyhow::Error> { + self.stats.pending.fetch_add(1, Ordering::Relaxed); + self.sender.send(message)?; + + Ok(()) + } + + fn get_stats(&self) -> &Stats { + &self.stats + } + + #[allow(unused)] + // Drops all the senders and shuts down the workers + async fn shutdown(self, wait_for_retries: bool) -> Result, anyhow::Error> { + drop(self.sender); + + self.sender_task.await?; + + if wait_for_retries { + self.retry_sender_task.await?; + } + + Ok(self.stats) + } +} + +/// Creates an activity queue using tokio spawned tasks +/// Note: requires a tokio runtime +pub(crate) fn create_activity_queue( + client: ClientWithMiddleware, + worker_count: usize, + retry_count: usize, + request_timeout: Duration, +) -> ActivityQueue { + assert!( + worker_count > 0, + "worker count needs to be greater than zero" + ); + + ActivityQueue::new(client, worker_count, retry_count, request_timeout, 60) +} + +/// Retries a future action factory function up to `amount` times with an exponential backoff timer between tries +async fn retry>, A: FnMut() -> F>( + mut action: A, + strategy: RetryStrategy, +) -> Result { + let mut count = strategy.offset; + + // Do an initial sleep if it's called for + if strategy.initial_sleep > 0 { + let sleep_dur = Duration::from_secs(strategy.initial_sleep as u64); + tokio::time::sleep(sleep_dur).await; + } + + loop { + match action().await { + Ok(val) => return Ok(val), + Err(err) => { + if count < strategy.retries { + count += 1; + + let sleep_amt = strategy.backoff.pow(count as u32) as u64; + let sleep_dur = Duration::from_secs(sleep_amt); + warn!("{err:?}. Sleeping for {sleep_dur:?} and trying again"); + tokio::time::sleep(sleep_dur).await; + continue; + } else { + return Err(err); + } + } + } + } +} + +#[cfg(test)] +mod tests { + use axum::extract::State; + use bytes::Bytes; + use http::StatusCode; + use std::time::Instant; + + use crate::http_signatures::generate_actor_keypair; + + use super::*; + + #[allow(unused)] + // This will periodically send back internal errors to test the retry + async fn dodgy_handler( + State(state): State>, + headers: HeaderMap, + body: Bytes, + ) -> Result<(), StatusCode> { + debug!("Headers:{:?}", headers); + debug!("Body len:{}", body.len()); + + if state.fetch_add(1, Ordering::Relaxed) % 20 == 0 { + return Err(StatusCode::INTERNAL_SERVER_ERROR); + } + Ok(()) + } + + async fn test_server() { + use axum::{routing::post, Router}; + + // We should break every now and then ;) + let state = Arc::new(AtomicUsize::new(0)); + + let app = Router::new() + .route("/", post(dodgy_handler)) + .with_state(state); + + axum::Server::bind(&"0.0.0.0:8001".parse().unwrap()) + .serve(app.into_make_service()) + .await + .unwrap(); + } + + #[tokio::test(flavor = "multi_thread")] + // Queues 100 messages and then asserts that the worker runs them + async fn test_activity_queue_workers() { + let num_workers = 64; + let num_messages: usize = 100; + + tokio::spawn(test_server()); + + /* + // uncomment for debug logs & stats + use tracing::log::LevelFilter; + + env_logger::builder() + .filter_level(LevelFilter::Warn) + .filter_module("activitypub_federation", LevelFilter::Info) + .format_timestamp(None) + .init(); + + */ + + let activity_queue = ActivityQueue::new( + reqwest::Client::default().into(), + num_workers, + num_workers, + Duration::from_secs(10), + 1, + ); + + let keypair = generate_actor_keypair().unwrap(); + + let message = SendActivityTask { + actor_id: "http://localhost:8001".parse().unwrap(), + activity_id: "http://localhost:8001/activity".parse().unwrap(), + activity: "{}".into(), + inbox: "http://localhost:8001".parse().unwrap(), + private_key: keypair.private_key().unwrap(), + http_signature_compat: true, + }; + + let start = Instant::now(); + + for _ in 0..num_messages { + activity_queue.queue(message.clone()).await.unwrap(); + } + + info!("Queue Sent: {:?}", start.elapsed()); + + let stats = activity_queue.shutdown(true).await.unwrap(); + + info!( + "Queue Finished. Num msgs: {}, Time {:?}, msg/s: {:0.0}", + num_messages, + start.elapsed(), + num_messages as f64 / start.elapsed().as_secs_f64() + ); + + assert_eq!( + stats.completed_last_hour.load(Ordering::Relaxed), + num_messages + ); + } } diff --git a/src/actix_web/inbox.rs b/src/actix_web/inbox.rs index bb21c46..e4d83e9 100644 --- a/src/actix_web/inbox.rs +++ b/src/actix_web/inbox.rs @@ -65,19 +65,19 @@ mod test { use reqwest_middleware::ClientWithMiddleware; use url::Url; - #[actix_rt::test] + #[tokio::test] async fn test_receive_activity() { let (body, incoming_request, config) = setup_receive_test().await; receive_activity::( incoming_request.to_http_request(), - body.into(), + body, &config.to_request_data(), ) .await .unwrap(); } - #[actix_rt::test] + #[tokio::test] async fn test_receive_activity_invalid_body_signature() { let (_, incoming_request, config) = setup_receive_test().await; let err = receive_activity::( @@ -93,13 +93,13 @@ mod test { assert_eq!(e, &Error::ActivityBodyDigestInvalid) } - #[actix_rt::test] + #[tokio::test] async fn test_receive_activity_invalid_path() { let (body, incoming_request, config) = setup_receive_test().await; let incoming_request = incoming_request.uri("/wrong"); let err = receive_activity::( incoming_request.to_http_request(), - body.into(), + body, &config.to_request_data(), ) .await @@ -110,7 +110,7 @@ mod test { assert_eq!(e, &Error::ActivitySignatureInvalid) } - async fn setup_receive_test() -> (String, TestRequest, FederationConfig) { + async fn setup_receive_test() -> (Bytes, TestRequest, FederationConfig) { let inbox = "https://example.com/inbox"; let headers = generate_request_headers(&Url::parse(inbox).unwrap()); let request_builder = ClientWithMiddleware::from(Client::default()) @@ -122,12 +122,12 @@ mod test { kind: Default::default(), id: "http://localhost:123/1".try_into().unwrap(), }; - let body = serde_json::to_string(&activity).unwrap(); + let body: Bytes = serde_json::to_vec(&activity).unwrap().into(); let outgoing_request = sign_request( request_builder, - activity.actor.into_inner(), - body.to_string(), - DB_USER_KEYPAIR.private_key.clone(), + &activity.actor.into_inner(), + body.clone(), + DB_USER_KEYPAIR.private_key().unwrap(), false, ) .await @@ -142,6 +142,7 @@ mod test { .app_data(DbConnection) .debug(true) .build() + .await .unwrap(); (body, incoming_request, config) } diff --git a/src/config.rs b/src/config.rs index 0982194..3cd0ce8 100644 --- a/src/config.rs +++ b/src/config.rs @@ -4,26 +4,27 @@ //! //! ``` //! # use activitypub_federation::config::FederationConfig; -//! # let _ = actix_rt::System::new(); +//! # tokio::runtime::Runtime::new().unwrap().block_on(async { //! let settings = FederationConfig::builder() //! .domain("example.com") //! .app_data(()) //! .http_fetch_limit(50) //! .worker_count(16) -//! .build()?; +//! .build().await?; //! # Ok::<(), anyhow::Error>(()) +//! # }).unwrap() //! ``` use crate::{ - activity_queue::create_activity_queue, + activity_queue::{create_activity_queue, ActivityQueue}, error::Error, protocol::verification::verify_domains_match, traits::{ActivityHandler, Actor}, }; use async_trait::async_trait; -use background_jobs::Manager; use derive_builder::Builder; use dyn_clone::{clone_trait_object, DynClone}; +use openssl::pkey::{PKey, Private}; use reqwest_middleware::ClientWithMiddleware; use serde::de::DeserializeOwned; use std::{ @@ -54,9 +55,14 @@ pub struct FederationConfig { /// HTTP client used for all outgoing requests. Middleware can be used to add functionality /// like log tracing or retry of failed requests. pub(crate) client: ClientWithMiddleware, - /// Number of worker threads for sending outgoing activities - #[builder(default = "64")] - pub(crate) worker_count: u64, + /// Number of tasks that can be in-flight concurrently. + /// Tasks are retried once after a minute, then put into the retry queue + #[builder(default = "1024")] + pub(crate) worker_count: usize, + /// Number of concurrent tasks that are being retried in-flight concurrently. + /// Tasks are retried after an hour, then again in 60 hours. + #[builder(default = "128")] + pub(crate) retry_count: usize, /// Run library in debug mode. This allows usage of http and localhost urls. It also sends /// outgoing activities synchronously, not in background thread. This helps to make tests /// more consistent. Do not use for production. @@ -79,11 +85,11 @@ pub struct FederationConfig { /// This can be used to implement secure mode federation. /// #[builder(default = "None", setter(custom))] - pub(crate) signed_fetch_actor: Option>, + pub(crate) signed_fetch_actor: Option)>>, /// Queue for sending outgoing activities. Only optional to make builder work, its always /// present once constructed. #[builder(setter(skip))] - pub(crate) activity_queue: Option>, + pub(crate) activity_queue: Option>, } impl FederationConfig { @@ -180,7 +186,10 @@ impl FederationConfigBuilder { let private_key_pem = actor .private_key_pem() .expect("actor does not have a private key to sign with"); - self.signed_fetch_actor = Some(Some(Arc::new((actor.id(), private_key_pem)))); + + let private_key = PKey::private_key_from_pem(private_key_pem.as_bytes()) + .expect("Could not decode PEM data"); + self.signed_fetch_actor = Some(Some(Arc::new((actor.id(), private_key)))); self } @@ -188,13 +197,14 @@ impl FederationConfigBuilder { /// /// Values which are not explicitly specified use the defaults. Also initializes the /// queue for outgoing activities, which is stored internally in the config struct. - pub fn build(&mut self) -> Result, FederationConfigBuilderError> { + /// Requires a tokio runtime for the background queue. + pub async fn build(&mut self) -> Result, FederationConfigBuilderError> { let mut config = self.partial_build()?; let queue = create_activity_queue( config.client.clone(), config.worker_count, + config.retry_count, config.request_timeout, - config.debug, ); config.activity_queue = Some(Arc::new(queue)); Ok(config) diff --git a/src/fetch/collection_id.rs b/src/fetch/collection_id.rs index 48e3867..a45419c 100644 --- a/src/fetch/collection_id.rs +++ b/src/fetch/collection_id.rs @@ -30,7 +30,7 @@ where /// Fetches collection over HTTP /// - /// Unlike [ObjectId::fetch](crate::fetch::object_id::ObjectId::fetch) this method doesn't do + /// Unlike [ObjectId::dereference](crate::fetch::object_id::ObjectId::dereference) this method doesn't do /// any caching. pub async fn dereference( &self, diff --git a/src/fetch/mod.rs b/src/fetch/mod.rs index 769f302..7a9734c 100644 --- a/src/fetch/mod.rs +++ b/src/fetch/mod.rs @@ -9,6 +9,7 @@ use crate::{ reqwest_shim::ResponseExt, FEDERATION_CONTENT_TYPE, }; +use bytes::Bytes; use http::StatusCode; use serde::de::DeserializeOwned; use std::sync::atomic::Ordering; @@ -56,8 +57,8 @@ pub async fn fetch_object_http( let res = if let Some((actor_id, private_key_pem)) = config.signed_fetch_actor.as_deref() { let req = sign_request( req, - actor_id.clone(), - String::new(), + actor_id, + Bytes::new(), private_key_pem.clone(), data.config.http_signature_compat, ) diff --git a/src/fetch/object_id.rs b/src/fetch/object_id.rs index 1155493..b210f11 100644 --- a/src/fetch/object_id.rs +++ b/src/fetch/object_id.rs @@ -36,13 +36,12 @@ where /// # use activitypub_federation::config::FederationConfig; /// # use activitypub_federation::error::Error::NotFound; /// # use activitypub_federation::traits::tests::{DbConnection, DbUser}; -/// # let _ = actix_rt::System::new(); -/// # actix_rt::Runtime::new().unwrap().block_on(async { +/// # tokio::runtime::Runtime::new().unwrap().block_on(async { /// # let db_connection = DbConnection; /// let config = FederationConfig::builder() /// .domain("example.com") /// .app_data(db_connection) -/// .build()?; +/// .build().await?; /// let request_data = config.to_request_data(); /// let object_id = ObjectId::::parse("https://lemmy.ml/u/nutomic")?; /// // Attempt to fetch object from local database or fall back to remote server diff --git a/src/fetch/webfinger.rs b/src/fetch/webfinger.rs index 936485a..183bd46 100644 --- a/src/fetch/webfinger.rs +++ b/src/fetch/webfinger.rs @@ -199,12 +199,13 @@ mod tests { traits::tests::{DbConnection, DbUser}, }; - #[actix_rt::test] + #[tokio::test] async fn test_webfinger() { let config = FederationConfig::builder() .domain("example.com") .app_data(DbConnection) .build() + .await .unwrap(); let data = config.to_request_data(); let res = diff --git a/src/http_signatures.rs b/src/http_signatures.rs index 988533f..dc2bc23 100644 --- a/src/http_signatures.rs +++ b/src/http_signatures.rs @@ -13,12 +13,13 @@ use crate::{ traits::{Actor, Object}, }; use base64::{engine::general_purpose::STANDARD as Base64, Engine}; +use bytes::Bytes; use http::{header::HeaderName, uri::PathAndQuery, HeaderValue, Method, Uri}; use http_signature_normalization_reqwest::prelude::{Config, SignExt}; use once_cell::sync::Lazy; use openssl::{ hash::MessageDigest, - pkey::PKey, + pkey::{PKey, Private}, rsa::Rsa, sign::{Signer, Verifier}, }; @@ -26,7 +27,7 @@ use reqwest::Request; use reqwest_middleware::RequestBuilder; use serde::Deserialize; use sha2::{Digest, Sha256}; -use std::{collections::BTreeMap, fmt::Debug, io::ErrorKind}; +use std::{collections::BTreeMap, fmt::Debug, io::ErrorKind, time::Duration}; use tracing::debug; use url::Url; @@ -39,6 +40,14 @@ pub struct Keypair { pub public_key: String, } +impl Keypair { + /// Helper method to turn this into an openssl private key + #[cfg(test)] + pub(crate) fn private_key(&self) -> Result, anyhow::Error> { + Ok(PKey::private_key_from_pem(self.private_key.as_bytes())?) + } +} + /// Generate a random asymmetric keypair for ActivityPub HTTP signatures. pub fn generate_actor_keypair() -> Result { let rsa = Rsa::generate(2048)?; @@ -58,19 +67,27 @@ pub fn generate_actor_keypair() -> Result { }) } +/// Sets the amount of time that a signed request is valid. Currenlty 5 minutes +/// Mastodon & friends have ~1 hour expiry from creation if it's not set in the header +pub(crate) const EXPIRES_AFTER: Duration = Duration::from_secs(300); + /// Creates an HTTP post request to `inbox_url`, with the given `client` and `headers`, and /// `activity` as request body. The request is signed with `private_key` and then sent. pub(crate) async fn sign_request( request_builder: RequestBuilder, - actor_id: Url, - activity: String, - private_key: String, + actor_id: &Url, + activity: Bytes, + private_key: PKey, http_signature_compat: bool, ) -> Result { - static CONFIG: Lazy = Lazy::new(Config::new); - static CONFIG_COMPAT: Lazy = Lazy::new(|| Config::new().mastodon_compat()); + static CONFIG: Lazy = Lazy::new(|| Config::new().set_expiration(EXPIRES_AFTER)); + static CONFIG_COMPAT: Lazy = Lazy::new(|| { + Config::new() + .mastodon_compat() + .set_expiration(EXPIRES_AFTER) + }); - let key_id = main_key_id(&actor_id); + let key_id = main_key_id(actor_id); let sig_conf = match http_signature_compat { false => CONFIG.clone(), true => CONFIG_COMPAT.clone(), @@ -82,7 +99,6 @@ pub(crate) async fn sign_request( Sha256::new(), activity, move |signing_string| { - let private_key = PKey::private_key_from_pem(private_key.as_bytes())?; let mut signer = Signer::new(MessageDigest::sha256(), &private_key)?; signer.update(signing_string.as_bytes())?; @@ -259,7 +275,7 @@ pub mod test { static INBOX_URL: Lazy = Lazy::new(|| Url::parse("https://example.com/u/alice/inbox").unwrap()); - #[actix_rt::test] + #[tokio::test] async fn test_sign() { let mut headers = generate_request_headers(&INBOX_URL); // use hardcoded date in order to test against hardcoded signature @@ -273,9 +289,9 @@ pub mod test { .headers(headers); let request = sign_request( request_builder, - ACTOR_ID.clone(), - "my activity".to_string(), - test_keypair().private_key, + &ACTOR_ID, + "my activity".into(), + PKey::private_key_from_pem(test_keypair().private_key.as_bytes()).unwrap(), // set this to prevent created/expires headers to be generated and inserted // automatically from current time true, @@ -301,7 +317,7 @@ pub mod test { assert_eq!(signature, expected_signature); } - #[actix_rt::test] + #[tokio::test] async fn test_verify() { let headers = generate_request_headers(&INBOX_URL); let request_builder = ClientWithMiddleware::from(Client::new()) @@ -309,9 +325,9 @@ pub mod test { .headers(headers); let request = sign_request( request_builder, - ACTOR_ID.clone(), - "my activity".to_string(), - test_keypair().private_key, + &ACTOR_ID, + "my activity".to_string().into(), + PKey::private_key_from_pem(test_keypair().private_key.as_bytes()).unwrap(), false, ) .await