Remove actix-rt and replace with tokio tasks (#42)

* Remove `actix-rt` and replace with tokio tasks

* Include activity queue test

* Use older `Arc` method

* Refactor to not re-process PEM data on each request

* Add retry queue and spawn tokio tasks directly

* Fix doc error

* Remove semaphore and use join set for backpressure

* Fix debug issue with multiple mailboxes
This commit is contained in:
cetra3 2023-06-20 19:24:14 +09:30 committed by GitHub
parent 6ac6e2d90e
commit c356265cf4
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
20 changed files with 601 additions and 153 deletions

2
.gitignore vendored
View file

@ -1,3 +1,5 @@
/target
/.idea
/Cargo.lock
perf.data*
flamegraph.svg

View file

@ -23,26 +23,37 @@ openssl = "0.10.54"
once_cell = "1.18.0"
http = "0.2.9"
sha2 = "0.10.6"
background-jobs = "0.13.0"
thiserror = "1.0.40"
derive_builder = "0.12.0"
itertools = "0.10.5"
dyn-clone = "1.0.11"
enum_delegate = "0.2.0"
httpdate = "1.0.2"
http-signature-normalization-reqwest = { version = "0.8.0", default-features = false, features = ["sha-2", "middleware"] }
http-signature-normalization-reqwest = { version = "0.8.0", default-features = false, features = [
"sha-2",
"middleware",
] }
http-signature-normalization = "0.7.0"
bytes = "1.4.0"
futures-core = { version = "0.3.28", default-features = false }
pin-project-lite = "0.2.9"
activitystreams-kinds = "0.3.0"
regex = { version = "1.8.4", default-features = false, features = ["std"] }
tokio = { version = "1.21.2", features = [
"sync",
"rt",
"rt-multi-thread",
"time",
] }
# Actix-web
actix-web = { version = "4.3.1", default-features = false, optional = true }
# Axum
axum = { version = "0.6.18", features = ["json", "headers"], default-features = false, optional = true }
axum = { version = "0.6.18", features = [
"json",
"headers",
], default-features = false, optional = true }
tower = { version = "0.4.13", optional = true }
hyper = { version = "0.14", optional = true }
displaydoc = "0.2.4"
@ -56,9 +67,13 @@ axum = ["dep:axum", "dep:tower", "dep:hyper"]
rand = "0.8.5"
env_logger = "0.10.0"
tower-http = { version = "0.4.0", features = ["map-request-body", "util"] }
axum = { version = "0.6.18", features = ["http1", "tokio", "query"], default-features = false }
axum = { version = "0.6.18", features = [
"http1",
"tokio",
"query",
], default-features = false }
axum-macros = "0.3.7"
actix-rt = "2.8.0"
tokio = { version = "1.21.2", features = ["full"] }
[profile.dev]
strip = "symbols"

View file

@ -5,12 +5,13 @@ Next we need to do some configuration. Most importantly we need to specify the d
```
# use activitypub_federation::config::FederationConfig;
# let db_connection = ();
# let _ = actix_rt::System::new();
# tokio::runtime::Runtime::new().unwrap().block_on(async {
let config = FederationConfig::builder()
.domain("example.com")
.app_data(db_connection)
.build()?;
.build().await?;
# Ok::<(), anyhow::Error>(())
# }).unwrap()
```
`debug` is necessary to test federation with http and localhost URLs, but it should never be used in production. The `worker_count` value can be adjusted depending on the instance size. A lower value saves resources on a small instance, while a higher value is necessary on larger instances to keep up with send jobs. `url_verifier` can be used to implement a domain blacklist.

View file

@ -22,12 +22,12 @@ The next step is to allow other servers to fetch our actors and objects. For thi
# use http::HeaderMap;
# async fn generate_user_html(_: String, _: Data<DbConnection>) -> axum::response::Response { todo!() }
#[actix_rt::main]
#[tokio::main]
async fn main() -> Result<(), Error> {
let data = FederationConfig::builder()
.domain("example.com")
.app_data(DbConnection)
.build()?;
.build().await?;
let app = axum::Router::new()
.route("/user/:name", get(http_get_user))

View file

@ -7,18 +7,17 @@ After setting up our structs, implementing traits and initializing configuration
# use activitypub_federation::traits::tests::DbUser;
# use activitypub_federation::config::FederationConfig;
# let db_connection = activitypub_federation::traits::tests::DbConnection;
# let _ = actix_rt::System::new();
# actix_rt::Runtime::new().unwrap().block_on(async {
# tokio::runtime::Runtime::new().unwrap().block_on(async {
let config = FederationConfig::builder()
.domain("example.com")
.app_data(db_connection)
.build()?;
.build().await?;
let user_id = ObjectId::<DbUser>::parse("https://mastodon.social/@LemmyDev")?;
let data = config.to_request_data();
let user = user_id.dereference(&data).await;
assert!(user.is_ok());
# Ok::<(), anyhow::Error>(())
}).unwrap()
# }).unwrap()
```
`dereference` retrieves the object JSON at the given URL, and uses serde to convert it to `Person`. It then calls your method `Object::from_json` which inserts it in the database and returns a `DbUser` struct. `request_data` contains the federation config as well as a counter of outgoing HTTP requests. If this counter exceeds the configured maximum, further requests are aborted in order to avoid recursive fetching which could allow for a denial of service attack.
@ -32,9 +31,8 @@ We can similarly dereference a user over webfinger with the following method. It
# use activitypub_federation::fetch::webfinger::webfinger_resolve_actor;
# use activitypub_federation::traits::tests::DbUser;
# let db_connection = DbConnection;
# let _ = actix_rt::System::new();
# actix_rt::Runtime::new().unwrap().block_on(async {
# let config = FederationConfig::builder().domain("example.com").app_data(db_connection).build()?;
# tokio::runtime::Runtime::new().unwrap().block_on(async {
# let config = FederationConfig::builder().domain("example.com").app_data(db_connection).build().await?;
# let data = config.to_request_data();
let user: DbUser = webfinger_resolve_actor("nutomic@lemmy.ml", &data).await?;
# Ok::<(), anyhow::Error>(())

View file

@ -9,13 +9,12 @@ To send an activity we need to initialize our previously defined struct, and pic
# use activitypub_federation::traits::Actor;
# use activitypub_federation::fetch::object_id::ObjectId;
# use activitypub_federation::traits::tests::{DB_USER, DbConnection, Follow};
# let _ = actix_rt::System::new();
# actix_rt::Runtime::new().unwrap().block_on(async {
# tokio::runtime::Runtime::new().unwrap().block_on(async {
# let db_connection = DbConnection;
# let config = FederationConfig::builder()
# .domain("example.com")
# .app_data(db_connection)
# .build()?;
# .build().await?;
# let data = config.to_request_data();
# let sender = DB_USER.clone();
# let recipient = DB_USER.clone();

View file

@ -61,9 +61,9 @@ impl Object for SearchableDbObjects {
}
}
#[actix_rt::main]
#[tokio::main]
async fn main() -> Result<(), anyhow::Error> {
# let config = FederationConfig::builder().domain("example.com").app_data(DbConnection).build().unwrap();
# let config = FederationConfig::builder().domain("example.com").app_data(DbConnection).build().await.unwrap();
# let data = config.to_request_data();
let query = "https://example.com/id/413";
let query_result = ObjectId::<SearchableDbObjects>::parse(query)?

View file

@ -28,7 +28,7 @@ const DOMAIN: &str = "example.com";
const LOCAL_USER_NAME: &str = "alison";
const BIND_ADDRESS: &str = "localhost:8003";
#[actix_rt::main]
#[tokio::main]
async fn main() -> Result<(), Error> {
env_logger::builder()
.filter_level(LevelFilter::Warn)
@ -47,7 +47,8 @@ async fn main() -> Result<(), Error> {
let config = FederationConfig::builder()
.domain(DOMAIN)
.app_data(database)
.build()?;
.build()
.await?;
info!("Listen with HTTP server on {BIND_ADDRESS}");
let config = config.clone();

View file

@ -30,7 +30,7 @@ pub fn listen(config: &FederationConfig<DatabaseHandle>) -> Result<(), Error> {
})
.bind(hostname)?
.run();
actix_rt::spawn(server);
tokio::spawn(server);
Ok(())
}

View file

@ -41,7 +41,7 @@ pub fn listen(config: &FederationConfig<DatabaseHandle>) -> Result<(), Error> {
.expect("Failed to lookup domain name");
let server = axum::Server::bind(&addr).serve(app.into_make_service());
actix_rt::spawn(server);
tokio::spawn(server);
Ok(())
}

View file

@ -11,7 +11,7 @@ use std::{
};
use url::Url;
pub fn new_instance(
pub async fn new_instance(
hostname: &str,
name: String,
) -> Result<FederationConfig<DatabaseHandle>, Error> {
@ -29,7 +29,8 @@ pub fn new_instance(
.signed_fetch_actor(&system_user)
.app_data(database)
.debug(true)
.build()?;
.build()
.await?;
Ok(config)
}

View file

@ -17,7 +17,7 @@ mod instance;
mod objects;
mod utils;
#[actix_rt::main]
#[tokio::main]
async fn main() -> Result<(), Error> {
env_logger::builder()
.filter_level(LevelFilter::Warn)
@ -32,8 +32,8 @@ async fn main() -> Result<(), Error> {
.map(|arg| Webserver::from_str(&arg).unwrap())
.unwrap_or(Webserver::Axum);
let alpha = new_instance("localhost:8001", "alpha".to_string())?;
let beta = new_instance("localhost:8002", "beta".to_string())?;
let alpha = new_instance("localhost:8001", "alpha".to_string()).await?;
let beta = new_instance("localhost:8002", "beta".to_string()).await?;
listen(&alpha, &webserver)?;
listen(&beta, &webserver)?;
info!("Local instances started");

View file

@ -11,25 +11,28 @@ use crate::{
FEDERATION_CONTENT_TYPE,
};
use anyhow::anyhow;
use background_jobs::{
memory_storage::{ActixTimer, Storage},
ActixJob,
Backoff,
Manager,
MaxRetries,
WorkerConfig,
};
use bytes::Bytes;
use futures_core::Future;
use http::{header::HeaderName, HeaderMap, HeaderValue};
use httpdate::fmt_http_date;
use itertools::Itertools;
use openssl::pkey::{PKey, Private};
use reqwest::Request;
use reqwest_middleware::ClientWithMiddleware;
use serde::{Deserialize, Serialize};
use serde::Serialize;
use std::{
fmt::Debug,
future::Future,
pin::Pin,
fmt::{Debug, Display},
sync::{
atomic::{AtomicUsize, Ordering},
Arc,
},
time::{Duration, SystemTime},
};
use tokio::{
sync::mpsc::{unbounded_channel, UnboundedSender},
task::{JoinHandle, JoinSet},
};
use tracing::{debug, info, warn};
use url::Url;
@ -56,10 +59,19 @@ where
let config = &data.config;
let actor_id = activity.actor();
let activity_id = activity.id();
let activity_serialized = serde_json::to_string_pretty(&activity)?;
let private_key = actor
let activity_serialized: Bytes = serde_json::to_vec(&activity)?.into();
let private_key_pem = actor
.private_key_pem()
.expect("Actor for sending activity has private key");
.ok_or_else(|| anyhow!("Actor {actor_id} does not contain a private key for signing"))?;
// This is a mostly expensive blocking call, we don't want to tie up other tasks while this is happening
let private_key = tokio::task::spawn_blocking(move || {
PKey::private_key_from_pem(private_key_pem.as_bytes())
.map_err(|err| anyhow!("Could not create private key from PEM data:{err}"))
})
.await
.map_err(|err| anyhow!("Error joining:{err}"))??;
let inboxes: Vec<Url> = inboxes
.into_iter()
.unique()
@ -84,25 +96,32 @@ where
private_key: private_key.clone(),
http_signature_compat: config.http_signature_compat,
};
// Don't use the activity queue if this is in debug mode, send and wait directly
if config.debug {
let res = do_send(message, &config.client, config.request_timeout).await;
// Don't fail on error, as we intentionally do some invalid actions in tests, to verify that
// they are rejected on the receiving side. These errors shouldn't bubble up to make the API
// call fail. This matches the behaviour in production.
if let Err(e) = res {
warn!("{}", e);
if let Err(err) = sign_and_send(
&message,
&config.client,
config.request_timeout,
Default::default(),
)
.await
{
warn!("{err}");
}
} else {
activity_queue.queue(message).await?;
let stats = activity_queue.get_stats().await?;
let stats = activity_queue.get_stats();
let running = stats.running.load(Ordering::Relaxed);
let stats_fmt = format!(
"Activity queue stats: pending: {}, running: {}, dead (this hour): {}, complete (this hour): {}",
stats.pending,
stats.running,
stats.dead.this_hour(),
stats.complete.this_hour()
"Activity queue stats: pending: {}, running: {}, retries: {}, dead: {}, complete: {}",
stats.pending.load(Ordering::Relaxed),
running,
stats.retries.load(Ordering::Relaxed),
stats.dead_last_hour.load(Ordering::Relaxed),
stats.completed_last_hour.load(Ordering::Relaxed),
);
if stats.running as u64 == config.worker_count {
if running == config.worker_count {
warn!("Reached max number of send activity workers ({}). Consider increasing worker count to avoid federation delays", config.worker_count);
warn!(stats_fmt);
} else {
@ -114,56 +133,66 @@ where
Ok(())
}
#[derive(Clone, Debug, Deserialize, Serialize)]
#[derive(Clone, Debug)]
struct SendActivityTask {
actor_id: Url,
activity_id: Url,
activity: String,
activity: Bytes,
inbox: Url,
private_key: String,
private_key: PKey<Private>,
http_signature_compat: bool,
}
impl ActixJob for SendActivityTask {
type State = QueueState;
type Future = Pin<Box<dyn Future<Output = Result<(), anyhow::Error>>>>;
const NAME: &'static str = "SendActivityTask";
const MAX_RETRIES: MaxRetries = MaxRetries::Count(3);
/// This gives the following retry intervals:
/// - 60s (one minute, for service restart)
/// - 60min (one hour, for instance maintenance)
/// - 60h (2.5 days, for major incident with rebuild from backup)
const BACKOFF: Backoff = Backoff::Exponential(60);
fn run(self, state: Self::State) -> Self::Future {
Box::pin(async move { do_send(self, &state.client, state.timeout).await })
}
}
async fn do_send(
task: SendActivityTask,
async fn sign_and_send(
task: &SendActivityTask,
client: &ClientWithMiddleware,
timeout: Duration,
retry_strategy: RetryStrategy,
) -> Result<(), anyhow::Error> {
debug!("Sending {} to {}", task.activity_id, task.inbox);
debug!(
"Sending {} to {}, contents:\n {}",
task.activity_id,
task.inbox,
serde_json::from_slice::<serde_json::Value>(&task.activity)?
);
let request_builder = client
.post(task.inbox.to_string())
.timeout(timeout)
.headers(generate_request_headers(&task.inbox));
let request = sign_request(
request_builder,
task.actor_id,
task.activity,
task.private_key,
&task.actor_id,
task.activity.clone(),
task.private_key.clone(),
task.http_signature_compat,
)
.await?;
retry(
|| {
send(
task,
client,
request
.try_clone()
.expect("The body of the request is not cloneable"),
)
},
retry_strategy,
)
.await
}
async fn send(
task: &SendActivityTask,
client: &ClientWithMiddleware,
request: Request,
) -> Result<(), anyhow::Error> {
let response = client.execute(request).await;
match response {
Ok(o) if o.status().is_success() => {
info!(
debug!(
"Activity {} delivered successfully to {}",
task.activity_id, task.inbox
);
@ -171,7 +200,7 @@ async fn do_send(
}
Ok(o) if o.status().is_client_error() => {
let text = o.text_limited().await.map_err(Error::other)?;
info!(
debug!(
"Activity {} was rejected by {}, aborting: {}",
task.activity_id, task.inbox, text,
);
@ -189,7 +218,7 @@ async fn do_send(
))
}
Err(e) => {
info!(
debug!(
"Unable to connect to {}, aborting task {}: {}",
task.inbox, task.activity_id, e
);
@ -220,27 +249,401 @@ pub(crate) fn generate_request_headers(inbox_url: &Url) -> HeaderMap {
headers
}
pub(crate) fn create_activity_queue(
client: ClientWithMiddleware,
worker_count: u64,
request_timeout: Duration,
debug: bool,
) -> Manager {
// queue is not used in debug mod, so dont create any workers to avoid log spam
let worker_count = if debug { 0 } else { worker_count };
// Configure and start our workers
WorkerConfig::new_managed(Storage::new(ActixTimer), move |_| QueueState {
client: client.clone(),
timeout: request_timeout,
})
.register::<SendActivityTask>()
.set_worker_count("default", worker_count)
.start()
/// A simple activity queue which spawns tokio workers to send out requests
/// When creating a queue, it will spawn a task per worker thread
/// Uses an unbounded mpsc queue for communication (i.e, all messages are in memory)
pub(crate) struct ActivityQueue {
// Stats shared between the queue and workers
stats: Arc<Stats>,
sender: UnboundedSender<SendActivityTask>,
sender_task: JoinHandle<()>,
retry_sender_task: JoinHandle<()>,
}
#[derive(Clone)]
struct QueueState {
/// Simple stat counter to show where we're up to with sending messages
/// This is a lock-free way to share things between tasks
/// When reading these values it's possible (but extremely unlikely) to get stale data if a worker task is in the middle of transitioning
#[derive(Default)]
struct Stats {
pending: AtomicUsize,
running: AtomicUsize,
retries: AtomicUsize,
dead_last_hour: AtomicUsize,
completed_last_hour: AtomicUsize,
}
#[derive(Clone, Copy, Default)]
struct RetryStrategy {
/// Amount of time in seconds to back off
backoff: usize,
/// Amount of times to retry
retries: usize,
/// If this particular request has already been retried, you can add an offset here to increment the count to start
offset: usize,
/// Number of seconds to sleep before trying
initial_sleep: usize,
}
/// A tokio spawned worker which is responsible for submitting requests to federated servers
/// This will retry up to one time with the same signature, and if it fails, will move it to the retry queue.
/// We need to retry activity sending in case the target instances is temporarily unreachable.
/// In this case, the task is stored and resent when the instance is hopefully back up. This
/// list shows the retry intervals, and which events of the target instance can be covered:
/// - 60s (one minute, service restart) -- happens in the worker w/ same signature
/// - 60min (one hour, instance maintenance) --- happens in the retry worker
/// - 60h (2.5 days, major incident with rebuild from backup) --- happens in the retry worker
async fn worker(
client: ClientWithMiddleware,
timeout: Duration,
message: SendActivityTask,
retry_queue: UnboundedSender<SendActivityTask>,
stats: Arc<Stats>,
strategy: RetryStrategy,
) {
stats.pending.fetch_sub(1, Ordering::Relaxed);
stats.running.fetch_add(1, Ordering::Relaxed);
let outcome = sign_and_send(&message, &client, timeout, strategy).await;
// "Running" has finished, check the outcome
stats.running.fetch_sub(1, Ordering::Relaxed);
match outcome {
Ok(_) => {
stats.completed_last_hour.fetch_add(1, Ordering::Relaxed);
}
Err(_err) => {
stats.retries.fetch_add(1, Ordering::Relaxed);
warn!(
"Sending activity {} to {} to the retry queue to be tried again later",
message.activity_id, message.inbox
);
// Send to the retry queue. Ignoring whether it succeeds or not
retry_queue.send(message).ok();
}
}
}
async fn retry_worker(
client: ClientWithMiddleware,
timeout: Duration,
message: SendActivityTask,
stats: Arc<Stats>,
strategy: RetryStrategy,
) {
// Because the times are pretty extravagant between retries, we have to re-sign each time
let outcome = retry(
|| {
sign_and_send(
&message,
&client,
timeout,
RetryStrategy {
backoff: 0,
retries: 0,
offset: 0,
initial_sleep: 0,
},
)
},
strategy,
)
.await;
stats.retries.fetch_sub(1, Ordering::Relaxed);
match outcome {
Ok(_) => {
stats.completed_last_hour.fetch_add(1, Ordering::Relaxed);
}
Err(_err) => {
stats.dead_last_hour.fetch_add(1, Ordering::Relaxed);
}
}
}
impl ActivityQueue {
fn new(
client: ClientWithMiddleware,
worker_count: usize,
retry_count: usize,
timeout: Duration,
backoff: usize, // This should be 60 seconds by default or 1 second in tests
) -> Self {
let stats: Arc<Stats> = Default::default();
// This task clears the dead/completed stats every hour
let hour_stats = stats.clone();
tokio::spawn(async move {
let duration = Duration::from_secs(3600);
loop {
tokio::time::sleep(duration).await;
hour_stats.completed_last_hour.store(0, Ordering::Relaxed);
hour_stats.dead_last_hour.store(0, Ordering::Relaxed);
}
});
let (retry_sender, mut retry_receiver) = unbounded_channel();
let retry_stats = stats.clone();
let retry_client = client.clone();
// The "fast path" retry
// The backoff should be < 5 mins for this to work otherwise signatures may expire
// This strategy is the one that is used with the *same* signature
let strategy = RetryStrategy {
backoff,
retries: 1,
offset: 0,
initial_sleep: 0,
};
// The "retry path" strategy
// After the fast path fails, a task will sleep up to backoff ^ 2 and then retry again
let retry_strategy = RetryStrategy {
backoff,
retries: 3,
offset: 2,
initial_sleep: backoff.pow(2), // wait 60 mins before even trying
};
let retry_sender_task = tokio::spawn(async move {
let mut join_set = JoinSet::new();
while let Some(message) = retry_receiver.recv().await {
// If we're over the limit of retries, wait for them to finish before spawning
while join_set.len() >= retry_count {
join_set.join_next().await;
}
join_set.spawn(retry_worker(
retry_client.clone(),
timeout,
message,
retry_stats.clone(),
retry_strategy,
));
}
while !join_set.is_empty() {
join_set.join_next().await;
}
});
let (sender, mut receiver) = unbounded_channel();
let sender_stats = stats.clone();
let sender_task = tokio::spawn(async move {
let mut join_set = JoinSet::new();
while let Some(message) = receiver.recv().await {
// If we're over the limit of workers, wait for them to finish before spawning
while join_set.len() >= worker_count {
join_set.join_next().await;
}
join_set.spawn(worker(
client.clone(),
timeout,
message,
retry_sender.clone(),
sender_stats.clone(),
strategy,
));
}
drop(retry_sender);
while !join_set.is_empty() {
join_set.join_next().await;
}
});
Self {
stats,
sender,
sender_task,
retry_sender_task,
}
}
async fn queue(&self, message: SendActivityTask) -> Result<(), anyhow::Error> {
self.stats.pending.fetch_add(1, Ordering::Relaxed);
self.sender.send(message)?;
Ok(())
}
fn get_stats(&self) -> &Stats {
&self.stats
}
#[allow(unused)]
// Drops all the senders and shuts down the workers
async fn shutdown(self, wait_for_retries: bool) -> Result<Arc<Stats>, anyhow::Error> {
drop(self.sender);
self.sender_task.await?;
if wait_for_retries {
self.retry_sender_task.await?;
}
Ok(self.stats)
}
}
/// Creates an activity queue using tokio spawned tasks
/// Note: requires a tokio runtime
pub(crate) fn create_activity_queue(
client: ClientWithMiddleware,
worker_count: usize,
retry_count: usize,
request_timeout: Duration,
) -> ActivityQueue {
assert!(
worker_count > 0,
"worker count needs to be greater than zero"
);
ActivityQueue::new(client, worker_count, retry_count, request_timeout, 60)
}
/// Retries a future action factory function up to `amount` times with an exponential backoff timer between tries
async fn retry<T, E: Display + Debug, F: Future<Output = Result<T, E>>, A: FnMut() -> F>(
mut action: A,
strategy: RetryStrategy,
) -> Result<T, E> {
let mut count = strategy.offset;
// Do an initial sleep if it's called for
if strategy.initial_sleep > 0 {
let sleep_dur = Duration::from_secs(strategy.initial_sleep as u64);
tokio::time::sleep(sleep_dur).await;
}
loop {
match action().await {
Ok(val) => return Ok(val),
Err(err) => {
if count < strategy.retries {
count += 1;
let sleep_amt = strategy.backoff.pow(count as u32) as u64;
let sleep_dur = Duration::from_secs(sleep_amt);
warn!("{err:?}. Sleeping for {sleep_dur:?} and trying again");
tokio::time::sleep(sleep_dur).await;
continue;
} else {
return Err(err);
}
}
}
}
}
#[cfg(test)]
mod tests {
use axum::extract::State;
use bytes::Bytes;
use http::StatusCode;
use std::time::Instant;
use crate::http_signatures::generate_actor_keypair;
use super::*;
#[allow(unused)]
// This will periodically send back internal errors to test the retry
async fn dodgy_handler(
State(state): State<Arc<AtomicUsize>>,
headers: HeaderMap,
body: Bytes,
) -> Result<(), StatusCode> {
debug!("Headers:{:?}", headers);
debug!("Body len:{}", body.len());
if state.fetch_add(1, Ordering::Relaxed) % 20 == 0 {
return Err(StatusCode::INTERNAL_SERVER_ERROR);
}
Ok(())
}
async fn test_server() {
use axum::{routing::post, Router};
// We should break every now and then ;)
let state = Arc::new(AtomicUsize::new(0));
let app = Router::new()
.route("/", post(dodgy_handler))
.with_state(state);
axum::Server::bind(&"0.0.0.0:8001".parse().unwrap())
.serve(app.into_make_service())
.await
.unwrap();
}
#[tokio::test(flavor = "multi_thread")]
// Queues 100 messages and then asserts that the worker runs them
async fn test_activity_queue_workers() {
let num_workers = 64;
let num_messages: usize = 100;
tokio::spawn(test_server());
/*
// uncomment for debug logs & stats
use tracing::log::LevelFilter;
env_logger::builder()
.filter_level(LevelFilter::Warn)
.filter_module("activitypub_federation", LevelFilter::Info)
.format_timestamp(None)
.init();
*/
let activity_queue = ActivityQueue::new(
reqwest::Client::default().into(),
num_workers,
num_workers,
Duration::from_secs(10),
1,
);
let keypair = generate_actor_keypair().unwrap();
let message = SendActivityTask {
actor_id: "http://localhost:8001".parse().unwrap(),
activity_id: "http://localhost:8001/activity".parse().unwrap(),
activity: "{}".into(),
inbox: "http://localhost:8001".parse().unwrap(),
private_key: keypair.private_key().unwrap(),
http_signature_compat: true,
};
let start = Instant::now();
for _ in 0..num_messages {
activity_queue.queue(message.clone()).await.unwrap();
}
info!("Queue Sent: {:?}", start.elapsed());
let stats = activity_queue.shutdown(true).await.unwrap();
info!(
"Queue Finished. Num msgs: {}, Time {:?}, msg/s: {:0.0}",
num_messages,
start.elapsed(),
num_messages as f64 / start.elapsed().as_secs_f64()
);
assert_eq!(
stats.completed_last_hour.load(Ordering::Relaxed),
num_messages
);
}
}

View file

@ -65,19 +65,19 @@ mod test {
use reqwest_middleware::ClientWithMiddleware;
use url::Url;
#[actix_rt::test]
#[tokio::test]
async fn test_receive_activity() {
let (body, incoming_request, config) = setup_receive_test().await;
receive_activity::<Follow, DbUser, DbConnection>(
incoming_request.to_http_request(),
body.into(),
body,
&config.to_request_data(),
)
.await
.unwrap();
}
#[actix_rt::test]
#[tokio::test]
async fn test_receive_activity_invalid_body_signature() {
let (_, incoming_request, config) = setup_receive_test().await;
let err = receive_activity::<Follow, DbUser, DbConnection>(
@ -93,13 +93,13 @@ mod test {
assert_eq!(e, &Error::ActivityBodyDigestInvalid)
}
#[actix_rt::test]
#[tokio::test]
async fn test_receive_activity_invalid_path() {
let (body, incoming_request, config) = setup_receive_test().await;
let incoming_request = incoming_request.uri("/wrong");
let err = receive_activity::<Follow, DbUser, DbConnection>(
incoming_request.to_http_request(),
body.into(),
body,
&config.to_request_data(),
)
.await
@ -110,7 +110,7 @@ mod test {
assert_eq!(e, &Error::ActivitySignatureInvalid)
}
async fn setup_receive_test() -> (String, TestRequest, FederationConfig<DbConnection>) {
async fn setup_receive_test() -> (Bytes, TestRequest, FederationConfig<DbConnection>) {
let inbox = "https://example.com/inbox";
let headers = generate_request_headers(&Url::parse(inbox).unwrap());
let request_builder = ClientWithMiddleware::from(Client::default())
@ -122,12 +122,12 @@ mod test {
kind: Default::default(),
id: "http://localhost:123/1".try_into().unwrap(),
};
let body = serde_json::to_string(&activity).unwrap();
let body: Bytes = serde_json::to_vec(&activity).unwrap().into();
let outgoing_request = sign_request(
request_builder,
activity.actor.into_inner(),
body.to_string(),
DB_USER_KEYPAIR.private_key.clone(),
&activity.actor.into_inner(),
body.clone(),
DB_USER_KEYPAIR.private_key().unwrap(),
false,
)
.await
@ -142,6 +142,7 @@ mod test {
.app_data(DbConnection)
.debug(true)
.build()
.await
.unwrap();
(body, incoming_request, config)
}

View file

@ -4,26 +4,27 @@
//!
//! ```
//! # use activitypub_federation::config::FederationConfig;
//! # let _ = actix_rt::System::new();
//! # tokio::runtime::Runtime::new().unwrap().block_on(async {
//! let settings = FederationConfig::builder()
//! .domain("example.com")
//! .app_data(())
//! .http_fetch_limit(50)
//! .worker_count(16)
//! .build()?;
//! .build().await?;
//! # Ok::<(), anyhow::Error>(())
//! # }).unwrap()
//! ```
use crate::{
activity_queue::create_activity_queue,
activity_queue::{create_activity_queue, ActivityQueue},
error::Error,
protocol::verification::verify_domains_match,
traits::{ActivityHandler, Actor},
};
use async_trait::async_trait;
use background_jobs::Manager;
use derive_builder::Builder;
use dyn_clone::{clone_trait_object, DynClone};
use openssl::pkey::{PKey, Private};
use reqwest_middleware::ClientWithMiddleware;
use serde::de::DeserializeOwned;
use std::{
@ -54,9 +55,14 @@ pub struct FederationConfig<T: Clone> {
/// HTTP client used for all outgoing requests. Middleware can be used to add functionality
/// like log tracing or retry of failed requests.
pub(crate) client: ClientWithMiddleware,
/// Number of worker threads for sending outgoing activities
#[builder(default = "64")]
pub(crate) worker_count: u64,
/// Number of tasks that can be in-flight concurrently.
/// Tasks are retried once after a minute, then put into the retry queue
#[builder(default = "1024")]
pub(crate) worker_count: usize,
/// Number of concurrent tasks that are being retried in-flight concurrently.
/// Tasks are retried after an hour, then again in 60 hours.
#[builder(default = "128")]
pub(crate) retry_count: usize,
/// Run library in debug mode. This allows usage of http and localhost urls. It also sends
/// outgoing activities synchronously, not in background thread. This helps to make tests
/// more consistent. Do not use for production.
@ -79,11 +85,11 @@ pub struct FederationConfig<T: Clone> {
/// This can be used to implement secure mode federation.
/// <https://docs.joinmastodon.org/spec/activitypub/#secure-mode>
#[builder(default = "None", setter(custom))]
pub(crate) signed_fetch_actor: Option<Arc<(Url, String)>>,
pub(crate) signed_fetch_actor: Option<Arc<(Url, PKey<Private>)>>,
/// Queue for sending outgoing activities. Only optional to make builder work, its always
/// present once constructed.
#[builder(setter(skip))]
pub(crate) activity_queue: Option<Arc<Manager>>,
pub(crate) activity_queue: Option<Arc<ActivityQueue>>,
}
impl<T: Clone> FederationConfig<T> {
@ -180,7 +186,10 @@ impl<T: Clone> FederationConfigBuilder<T> {
let private_key_pem = actor
.private_key_pem()
.expect("actor does not have a private key to sign with");
self.signed_fetch_actor = Some(Some(Arc::new((actor.id(), private_key_pem))));
let private_key = PKey::private_key_from_pem(private_key_pem.as_bytes())
.expect("Could not decode PEM data");
self.signed_fetch_actor = Some(Some(Arc::new((actor.id(), private_key))));
self
}
@ -188,13 +197,14 @@ impl<T: Clone> FederationConfigBuilder<T> {
///
/// Values which are not explicitly specified use the defaults. Also initializes the
/// queue for outgoing activities, which is stored internally in the config struct.
pub fn build(&mut self) -> Result<FederationConfig<T>, FederationConfigBuilderError> {
/// Requires a tokio runtime for the background queue.
pub async fn build(&mut self) -> Result<FederationConfig<T>, FederationConfigBuilderError> {
let mut config = self.partial_build()?;
let queue = create_activity_queue(
config.client.clone(),
config.worker_count,
config.retry_count,
config.request_timeout,
config.debug,
);
config.activity_queue = Some(Arc::new(queue));
Ok(config)

View file

@ -30,7 +30,7 @@ where
/// Fetches collection over HTTP
///
/// Unlike [ObjectId::fetch](crate::fetch::object_id::ObjectId::fetch) this method doesn't do
/// Unlike [ObjectId::dereference](crate::fetch::object_id::ObjectId::dereference) this method doesn't do
/// any caching.
pub async fn dereference(
&self,

View file

@ -9,6 +9,7 @@ use crate::{
reqwest_shim::ResponseExt,
FEDERATION_CONTENT_TYPE,
};
use bytes::Bytes;
use http::StatusCode;
use serde::de::DeserializeOwned;
use std::sync::atomic::Ordering;
@ -56,8 +57,8 @@ pub async fn fetch_object_http<T: Clone, Kind: DeserializeOwned>(
let res = if let Some((actor_id, private_key_pem)) = config.signed_fetch_actor.as_deref() {
let req = sign_request(
req,
actor_id.clone(),
String::new(),
actor_id,
Bytes::new(),
private_key_pem.clone(),
data.config.http_signature_compat,
)

View file

@ -36,13 +36,12 @@ where
/// # use activitypub_federation::config::FederationConfig;
/// # use activitypub_federation::error::Error::NotFound;
/// # use activitypub_federation::traits::tests::{DbConnection, DbUser};
/// # let _ = actix_rt::System::new();
/// # actix_rt::Runtime::new().unwrap().block_on(async {
/// # tokio::runtime::Runtime::new().unwrap().block_on(async {
/// # let db_connection = DbConnection;
/// let config = FederationConfig::builder()
/// .domain("example.com")
/// .app_data(db_connection)
/// .build()?;
/// .build().await?;
/// let request_data = config.to_request_data();
/// let object_id = ObjectId::<DbUser>::parse("https://lemmy.ml/u/nutomic")?;
/// // Attempt to fetch object from local database or fall back to remote server

View file

@ -199,12 +199,13 @@ mod tests {
traits::tests::{DbConnection, DbUser},
};
#[actix_rt::test]
#[tokio::test]
async fn test_webfinger() {
let config = FederationConfig::builder()
.domain("example.com")
.app_data(DbConnection)
.build()
.await
.unwrap();
let data = config.to_request_data();
let res =

View file

@ -13,12 +13,13 @@ use crate::{
traits::{Actor, Object},
};
use base64::{engine::general_purpose::STANDARD as Base64, Engine};
use bytes::Bytes;
use http::{header::HeaderName, uri::PathAndQuery, HeaderValue, Method, Uri};
use http_signature_normalization_reqwest::prelude::{Config, SignExt};
use once_cell::sync::Lazy;
use openssl::{
hash::MessageDigest,
pkey::PKey,
pkey::{PKey, Private},
rsa::Rsa,
sign::{Signer, Verifier},
};
@ -26,7 +27,7 @@ use reqwest::Request;
use reqwest_middleware::RequestBuilder;
use serde::Deserialize;
use sha2::{Digest, Sha256};
use std::{collections::BTreeMap, fmt::Debug, io::ErrorKind};
use std::{collections::BTreeMap, fmt::Debug, io::ErrorKind, time::Duration};
use tracing::debug;
use url::Url;
@ -39,6 +40,14 @@ pub struct Keypair {
pub public_key: String,
}
impl Keypair {
/// Helper method to turn this into an openssl private key
#[cfg(test)]
pub(crate) fn private_key(&self) -> Result<PKey<Private>, anyhow::Error> {
Ok(PKey::private_key_from_pem(self.private_key.as_bytes())?)
}
}
/// Generate a random asymmetric keypair for ActivityPub HTTP signatures.
pub fn generate_actor_keypair() -> Result<Keypair, std::io::Error> {
let rsa = Rsa::generate(2048)?;
@ -58,19 +67,27 @@ pub fn generate_actor_keypair() -> Result<Keypair, std::io::Error> {
})
}
/// Sets the amount of time that a signed request is valid. Currenlty 5 minutes
/// Mastodon & friends have ~1 hour expiry from creation if it's not set in the header
pub(crate) const EXPIRES_AFTER: Duration = Duration::from_secs(300);
/// Creates an HTTP post request to `inbox_url`, with the given `client` and `headers`, and
/// `activity` as request body. The request is signed with `private_key` and then sent.
pub(crate) async fn sign_request(
request_builder: RequestBuilder,
actor_id: Url,
activity: String,
private_key: String,
actor_id: &Url,
activity: Bytes,
private_key: PKey<Private>,
http_signature_compat: bool,
) -> Result<Request, anyhow::Error> {
static CONFIG: Lazy<Config> = Lazy::new(Config::new);
static CONFIG_COMPAT: Lazy<Config> = Lazy::new(|| Config::new().mastodon_compat());
static CONFIG: Lazy<Config> = Lazy::new(|| Config::new().set_expiration(EXPIRES_AFTER));
static CONFIG_COMPAT: Lazy<Config> = Lazy::new(|| {
Config::new()
.mastodon_compat()
.set_expiration(EXPIRES_AFTER)
});
let key_id = main_key_id(&actor_id);
let key_id = main_key_id(actor_id);
let sig_conf = match http_signature_compat {
false => CONFIG.clone(),
true => CONFIG_COMPAT.clone(),
@ -82,7 +99,6 @@ pub(crate) async fn sign_request(
Sha256::new(),
activity,
move |signing_string| {
let private_key = PKey::private_key_from_pem(private_key.as_bytes())?;
let mut signer = Signer::new(MessageDigest::sha256(), &private_key)?;
signer.update(signing_string.as_bytes())?;
@ -259,7 +275,7 @@ pub mod test {
static INBOX_URL: Lazy<Url> =
Lazy::new(|| Url::parse("https://example.com/u/alice/inbox").unwrap());
#[actix_rt::test]
#[tokio::test]
async fn test_sign() {
let mut headers = generate_request_headers(&INBOX_URL);
// use hardcoded date in order to test against hardcoded signature
@ -273,9 +289,9 @@ pub mod test {
.headers(headers);
let request = sign_request(
request_builder,
ACTOR_ID.clone(),
"my activity".to_string(),
test_keypair().private_key,
&ACTOR_ID,
"my activity".into(),
PKey::private_key_from_pem(test_keypair().private_key.as_bytes()).unwrap(),
// set this to prevent created/expires headers to be generated and inserted
// automatically from current time
true,
@ -301,7 +317,7 @@ pub mod test {
assert_eq!(signature, expected_signature);
}
#[actix_rt::test]
#[tokio::test]
async fn test_verify() {
let headers = generate_request_headers(&INBOX_URL);
let request_builder = ClientWithMiddleware::from(Client::new())
@ -309,9 +325,9 @@ pub mod test {
.headers(headers);
let request = sign_request(
request_builder,
ACTOR_ID.clone(),
"my activity".to_string(),
test_keypair().private_key,
&ACTOR_ID,
"my activity".to_string().into(),
PKey::private_key_from_pem(test_keypair().private_key.as_bytes()).unwrap(),
false,
)
.await