Refactor to not re-process PEM data on each request

This commit is contained in:
cetra3 2023-06-16 19:55:36 +09:30
parent 4126c77394
commit 76c9cc2683
6 changed files with 150 additions and 64 deletions

View file

@ -39,7 +39,12 @@ futures-core = { version = "0.3.28", default-features = false }
pin-project-lite = "0.2.9"
activitystreams-kinds = "0.3.0"
regex = { version = "1.8.4", default-features = false, features = ["std"] }
tokio = { version = "1.21.2", features = ["sync", "rt", "time"] }
tokio = { version = "1.21.2", features = [
"sync",
"rt",
"rt-multi-thread",
"time",
] }
# Actix-web
actix-web = { version = "4.3.1", default-features = false, optional = true }

View file

@ -12,12 +12,15 @@ use crate::{
};
use anyhow::anyhow;
use bytes::Bytes;
use futures_core::Future;
use http::{header::HeaderName, HeaderMap, HeaderValue};
use httpdate::fmt_http_date;
use itertools::Itertools;
use openssl::pkey::{PKey, Private};
use reqwest::Request;
use reqwest_middleware::ClientWithMiddleware;
use serde::{Deserialize, Serialize};
use serde::Serialize;
use std::{
fmt::{Debug, Display},
sync::{
@ -56,10 +59,17 @@ where
let config = &data.config;
let actor_id = activity.actor();
let activity_id = activity.id();
let activity_serialized = serde_json::to_string_pretty(&activity)?;
let private_key = actor
let activity_serialized: Bytes = serde_json::to_vec(&activity)?.into();
let private_key_pem = actor
.private_key_pem()
.expect("Actor for sending activity has private key");
.ok_or_else(|| anyhow!("Actor {actor_id} does not contain a private key for signing"))?;
// Parsing the PEM data is a fairly expensive blocking call; we don't want to tie up other tasks while it runs
let private_key = tokio::task::block_in_place(|| {
PKey::private_key_from_pem(private_key_pem.as_bytes())
.map_err(|err| anyhow!("Could not create private key from PEM data:{err}"))
})?;
let inboxes: Vec<Url> = inboxes
.into_iter()
.unique()
@ -106,17 +116,17 @@ where
Ok(())
}
#[derive(Clone, Debug, Deserialize, Serialize)]
#[derive(Clone, Debug)]
struct SendActivityTask {
actor_id: Url,
activity_id: Url,
activity: String,
activity: Bytes,
inbox: Url,
private_key: String,
private_key: PKey<Private>,
http_signature_compat: bool,
}
async fn do_send(
async fn sign_and_send(
task: &SendActivityTask,
client: &ClientWithMiddleware,
timeout: Duration,
@ -134,6 +144,15 @@ async fn do_send(
task.http_signature_compat,
)
.await?;
send(task, client, request).await
}
async fn send(
task: &SendActivityTask,
client: &ClientWithMiddleware,
request: Request,
) -> Result<(), anyhow::Error> {
let response = client.execute(request).await;
match response {
@ -220,14 +239,13 @@ struct Stats {
completed_last_hour: AtomicUsize,
}
/// We need to retry activity sending in case the target instances is temporarily unreachable.
/// In this case, the task is stored and resent when the instance is hopefully back up. This
/// list shows the retry intervals, and which events of the target instance can be covered:
/// - 60s (one minute, service restart)
/// - 60min (one hour, instance maintenance)
/// - 60h (2.5 days, major incident with rebuild from backup)
const MAX_RETRIES: usize = 3;
const BACKOFF: usize = 60;
/// Retry policy handed to each worker and consumed by the `retry` helper.
///
/// After a failed attempt the helper sleeps for `backoff` raised to the power of the
/// current retry count (in seconds) before trying again, and gives up once `retries`
/// retries have been made.
#[derive(Clone, Copy)]
struct RetryStrategy {
/// Base of the backoff in seconds; the n-th retry waits `backoff.pow(n)` seconds
backoff: usize,
/// Maximum number of retries after the initial attempt
retries: usize,
}
/// A tokio spawned worker which is responsible for submitting requests to federated servers
async fn worker(
@ -235,15 +253,13 @@ async fn worker(
timeout: Duration,
mut receiver: UnboundedReceiver<SendActivityTask>,
stats: Arc<Stats>,
strategy: RetryStrategy,
) {
while let Some(message) = receiver.recv().await {
// Update our counters as we're now "running" and not "pending"
stats.pending.fetch_sub(1, Ordering::Relaxed);
stats.running.fetch_add(1, Ordering::Relaxed);
// This will use the retry helper method below, with an exponential backoff
// If the task is sleeping, tokio will use work-stealing to keep it busy with something else
let outcome = retry(|| do_send(&message, &client, timeout), MAX_RETRIES, BACKOFF).await;
let outcome = retry(|| sign_and_send(&message, &client, timeout), strategy).await;
// "Running" has finished, check the outcome
stats.running.fetch_sub(1, Ordering::Relaxed);
@ -252,7 +268,6 @@ async fn worker(
Ok(_) => {
stats.completed_last_hour.fetch_add(1, Ordering::Relaxed);
}
// We might want to do something here
Err(_err) => {
stats.dead_last_hour.fetch_add(1, Ordering::Relaxed);
}
@ -261,7 +276,12 @@ async fn worker(
}
impl ActivityQueue {
fn new(client: ClientWithMiddleware, timeout: Duration, worker_count: usize) -> Self {
fn new(
client: ClientWithMiddleware,
worker_count: usize,
timeout: Duration,
strategy: RetryStrategy,
) -> Self {
// Keep a vec of senders to send our messages to
let mut senders = Vec::with_capacity(worker_count);
let mut handles = Vec::with_capacity(worker_count);
@ -287,6 +307,7 @@ impl ActivityQueue {
timeout,
receiver,
stats.clone(),
strategy,
)));
senders.push(sender);
}
@ -344,15 +365,30 @@ pub(crate) fn create_activity_queue(
worker_count > 0,
"worker count needs to be greater than zero"
);
/// We need to retry activity sending in case the target instance is temporarily unreachable.
/// In this case, the task is stored and resent when the instance is hopefully back up. This
/// list shows the retry intervals, and which events of the target instance can be covered:
/// - 60s (one minute, service restart)
/// - 60min (one hour, instance maintenance)
/// - 60h (2.5 days, major incident with rebuild from backup)
const MAX_RETRIES: usize = 3;
const BACKOFF: usize = 60;
ActivityQueue::new(client, request_timeout, worker_count)
ActivityQueue::new(
client,
worker_count,
request_timeout,
RetryStrategy {
backoff: BACKOFF,
retries: MAX_RETRIES,
},
)
}
/// Retries a future action factory function up to `strategy.retries` times, sleeping `strategy.backoff` raised to the retry count (in seconds) between tries
async fn retry<T, E: Display, F: Future<Output = Result<T, E>>, A: FnMut() -> F>(
mut action: A,
amount: usize,
sleep_seconds: usize,
strategy: RetryStrategy,
) -> Result<T, E> {
let mut count = 0;
@ -360,12 +396,13 @@ async fn retry<T, E: Display, F: Future<Output = Result<T, E>>, A: FnMut() -> F>
match action().await {
Ok(val) => return Ok(val),
Err(err) => {
if count < amount {
if count < strategy.retries {
count += 1;
warn!("{err}");
let sleep_amt = sleep_seconds.pow(count as u32) as u64;
tokio::time::sleep(Duration::from_secs(sleep_amt)).await;
let sleep_amt = strategy.backoff.pow(count as u32) as u64;
let sleep_dur = Duration::from_secs(sleep_amt);
warn!("{err}. Sleeping for {sleep_dur:?} and trying again");
tokio::time::sleep(sleep_dur).await;
continue;
} else {
return Err(err);
@ -377,23 +414,41 @@ async fn retry<T, E: Display, F: Future<Output = Result<T, E>>, A: FnMut() -> F>
#[cfg(test)]
mod tests {
use axum::extract::State;
use bytes::Bytes;
use std::{thread::available_parallelism, time::Instant};
use http::StatusCode;
use std::time::Instant;
use crate::http_signatures::generate_actor_keypair;
use super::*;
#[allow(unused)]
/// Test handler that deliberately fails some requests so the retry path gets exercised.
///
/// Logs the incoming headers and body length, bumps the shared request counter and
/// answers with a 500 whenever the previous counter value is a multiple of 20
/// (i.e. the 1st, 21st, 41st, ... request); every other request succeeds.
async fn dodgy_handler(
    State(request_counter): State<Arc<AtomicUsize>>,
    headers: HeaderMap,
    body: Bytes,
) -> Result<(), StatusCode> {
    debug!("Headers:{:?}", headers);
    debug!("Body len:{}", body.len());

    // `fetch_add` hands back the value *before* the increment
    let seen_before = request_counter.fetch_add(1, Ordering::Relaxed);
    match seen_before % 20 {
        0 => Err(StatusCode::INTERNAL_SERVER_ERROR),
        _ => Ok(()),
    }
}
async fn test_server() {
use axum::{routing::post, Router};
let app = Router::new().route(
"/",
post(|headers: HeaderMap, body: Bytes| async move {
debug!("Headers:{:?}", headers);
debug!("Body len:{}", body.len());
}),
);
// run it with hyper on localhost:3000
// We should break every now and then ;)
let state = Arc::new(AtomicUsize::new(0));
let app = Router::new()
.route("/", post(dodgy_handler))
.with_state(state);
axum::Server::bind(&"0.0.0.0:8001".parse().unwrap())
.serve(app.into_make_service())
.await
@ -403,8 +458,8 @@ mod tests {
#[tokio::test(flavor = "multi_thread")]
// Queues `num_messages` messages and then asserts that the workers run them
async fn test_activity_queue_workers() {
let num_workers = available_parallelism().unwrap().get();
let num_messages: usize = 10_000;
let num_workers = 64;
let num_messages: usize = 100;
tokio::spawn(test_server());
@ -419,10 +474,14 @@ mod tests {
.init();
*/
let activity_queue = create_activity_queue(
let activity_queue = ActivityQueue::new(
reqwest::Client::default().into(),
num_workers,
Duration::from_secs(10),
RetryStrategy {
backoff: 1,
retries: 3,
},
);
let keypair = generate_actor_keypair().unwrap();
@ -432,7 +491,7 @@ mod tests {
activity_id: "http://localhost:8001/activity".parse().unwrap(),
activity: "{}".into(),
inbox: "http://localhost:8001".parse().unwrap(),
private_key: keypair.private_key.clone(),
private_key: keypair.private_key().unwrap(),
http_signature_compat: true,
};

View file

@ -70,7 +70,7 @@ mod test {
let (body, incoming_request, config) = setup_receive_test().await;
receive_activity::<Follow, DbUser, DbConnection>(
incoming_request.to_http_request(),
body.into(),
body,
&config.to_request_data(),
)
.await
@ -99,7 +99,7 @@ mod test {
let incoming_request = incoming_request.uri("/wrong");
let err = receive_activity::<Follow, DbUser, DbConnection>(
incoming_request.to_http_request(),
body.into(),
body,
&config.to_request_data(),
)
.await
@ -110,7 +110,7 @@ mod test {
assert_eq!(e, &Error::ActivitySignatureInvalid)
}
async fn setup_receive_test() -> (String, TestRequest, FederationConfig<DbConnection>) {
async fn setup_receive_test() -> (Bytes, TestRequest, FederationConfig<DbConnection>) {
let inbox = "https://example.com/inbox";
let headers = generate_request_headers(&Url::parse(inbox).unwrap());
let request_builder = ClientWithMiddleware::from(Client::default())
@ -122,12 +122,12 @@ mod test {
kind: Default::default(),
id: "http://localhost:123/1".try_into().unwrap(),
};
let body = serde_json::to_string(&activity).unwrap();
let body: Bytes = serde_json::to_vec(&activity).unwrap().into();
let outgoing_request = sign_request(
request_builder,
&activity.actor.into_inner(),
body.to_string(),
DB_USER_KEYPAIR.private_key.clone(),
body.clone(),
DB_USER_KEYPAIR.private_key().unwrap(),
false,
)
.await

View file

@ -24,6 +24,7 @@ use crate::{
use async_trait::async_trait;
use derive_builder::Builder;
use dyn_clone::{clone_trait_object, DynClone};
use openssl::pkey::{PKey, Private};
use reqwest_middleware::ClientWithMiddleware;
use serde::de::DeserializeOwned;
use std::{
@ -54,7 +55,8 @@ pub struct FederationConfig<T: Clone> {
/// HTTP client used for all outgoing requests. Middleware can be used to add functionality
/// like log tracing or retry of failed requests.
pub(crate) client: ClientWithMiddleware,
/// Number of worker threads for sending outgoing activities
/// Number of worker tasks for sending outgoing activities
/// Should be a multiple of the number of threads
#[builder(default = "64")]
pub(crate) worker_count: usize,
/// Run library in debug mode. This allows usage of http and localhost urls. It also sends
@ -79,7 +81,7 @@ pub struct FederationConfig<T: Clone> {
/// This can be used to implement secure mode federation.
/// <https://docs.joinmastodon.org/spec/activitypub/#secure-mode>
#[builder(default = "None", setter(custom))]
pub(crate) signed_fetch_actor: Option<Arc<(Url, String)>>,
pub(crate) signed_fetch_actor: Option<Arc<(Url, PKey<Private>)>>,
/// Queue for sending outgoing activities. Only optional to make builder work, it's always
/// present once constructed.
#[builder(setter(skip))]
@ -180,7 +182,10 @@ impl<T: Clone> FederationConfigBuilder<T> {
let private_key_pem = actor
.private_key_pem()
.expect("actor does not have a private key to sign with");
self.signed_fetch_actor = Some(Some(Arc::new((actor.id(), private_key_pem))));
let private_key = PKey::private_key_from_pem(private_key_pem.as_bytes())
.expect("Could not decode PEM data");
self.signed_fetch_actor = Some(Some(Arc::new((actor.id(), private_key))));
self
}

View file

@ -9,6 +9,7 @@ use crate::{
reqwest_shim::ResponseExt,
FEDERATION_CONTENT_TYPE,
};
use bytes::Bytes;
use http::StatusCode;
use serde::de::DeserializeOwned;
use std::sync::atomic::Ordering;
@ -57,7 +58,7 @@ pub async fn fetch_object_http<T: Clone, Kind: DeserializeOwned>(
let req = sign_request(
req,
actor_id,
String::new(),
Bytes::new(),
private_key_pem.clone(),
data.config.http_signature_compat,
)

View file

@ -13,12 +13,13 @@ use crate::{
traits::{Actor, Object},
};
use base64::{engine::general_purpose::STANDARD as Base64, Engine};
use bytes::Bytes;
use http::{header::HeaderName, uri::PathAndQuery, HeaderValue, Method, Uri};
use http_signature_normalization_reqwest::prelude::{Config, SignExt};
use once_cell::sync::Lazy;
use openssl::{
hash::MessageDigest,
pkey::PKey,
pkey::{PKey, Private},
rsa::Rsa,
sign::{Signer, Verifier},
};
@ -26,7 +27,7 @@ use reqwest::Request;
use reqwest_middleware::RequestBuilder;
use serde::Deserialize;
use sha2::{Digest, Sha256};
use std::{collections::BTreeMap, fmt::Debug, io::ErrorKind};
use std::{collections::BTreeMap, fmt::Debug, io::ErrorKind, time::Duration};
use tracing::debug;
use url::Url;
@ -39,6 +40,14 @@ pub struct Keypair {
pub public_key: String,
}
impl Keypair {
    /// Parses the PEM-encoded `private_key` string into an openssl private key.
    ///
    /// Only compiled for tests. Returns an error if openssl cannot decode the PEM data.
    #[cfg(test)]
    pub(crate) fn private_key(&self) -> Result<PKey<Private>, anyhow::Error> {
        let pem_bytes = self.private_key.as_bytes();
        PKey::private_key_from_pem(pem_bytes).map_err(anyhow::Error::from)
    }
}
/// Generate a random asymmetric keypair for ActivityPub HTTP signatures.
pub fn generate_actor_keypair() -> Result<Keypair, std::io::Error> {
let rsa = Rsa::generate(2048)?;
@ -58,17 +67,25 @@ pub fn generate_actor_keypair() -> Result<Keypair, std::io::Error> {
})
}
/// Sets the amount of time that a signed request is valid. Currently 5 minutes.
/// Mastodon & friends have ~1 hour expiry from creation if it's not set in the header
pub(crate) const EXPIRES_AFTER: Duration = Duration::from_secs(300);
/// Creates an HTTP post request to `inbox_url`, with the given `client` and `headers`, and
/// `activity` as request body. The request is signed with `private_key` and then sent.
pub(crate) async fn sign_request(
request_builder: RequestBuilder,
actor_id: &Url,
activity: String,
private_key: String,
activity: Bytes,
private_key: PKey<Private>,
http_signature_compat: bool,
) -> Result<Request, anyhow::Error> {
static CONFIG: Lazy<Config> = Lazy::new(Config::new);
static CONFIG_COMPAT: Lazy<Config> = Lazy::new(|| Config::new().mastodon_compat());
static CONFIG: Lazy<Config> = Lazy::new(|| Config::new().set_expiration(EXPIRES_AFTER));
static CONFIG_COMPAT: Lazy<Config> = Lazy::new(|| {
Config::new()
.mastodon_compat()
.set_expiration(EXPIRES_AFTER)
});
let key_id = main_key_id(actor_id);
let sig_conf = match http_signature_compat {
@ -80,9 +97,8 @@ pub(crate) async fn sign_request(
sig_conf.clone(),
key_id,
Sha256::new(),
activity.clone(),
activity,
move |signing_string| {
let private_key = PKey::private_key_from_pem(private_key.as_bytes())?;
let mut signer = Signer::new(MessageDigest::sha256(), &private_key)?;
signer.update(signing_string.as_bytes())?;
@ -275,7 +291,7 @@ pub mod test {
request_builder,
&ACTOR_ID,
"my activity".into(),
test_keypair().private_key,
PKey::private_key_from_pem(test_keypair().private_key.as_bytes()).unwrap(),
// set this to prevent created/expires headers to be generated and inserted
// automatically from current time
true,
@ -310,8 +326,8 @@ pub mod test {
let request = sign_request(
request_builder,
&ACTOR_ID,
"my activity".to_string(),
test_keypair().private_key,
"my activity".to_string().into(),
PKey::private_key_from_pem(test_keypair().private_key.as_bytes()).unwrap(),
false,
)
.await