diff --git a/Cargo.lock b/Cargo.lock index 4b6796d..352b606 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -38,6 +38,12 @@ dependencies = [ "memchr", ] +[[package]] +name = "allocator-api2" +version = "0.2.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0942ffc6dcaadf03badf6e6a2d0228460359d5e34b57ccdc720b7382dfbd5ec5" + [[package]] name = "android-tzdata" version = "0.1.1" @@ -242,6 +248,7 @@ dependencies = [ "http", "http_digest_headers", "httpdate", + "lru", "metrics", "metrics-exporter-prometheus", "metrics-util", @@ -671,6 +678,10 @@ name = "hashbrown" version = "0.14.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f93e7192158dbcda357bdec5fb5788eebf8bbac027f3f33e719d29135ae84156" +dependencies = [ + "ahash", + "allocator-api2", +] [[package]] name = "heck" @@ -901,6 +912,15 @@ version = "0.4.20" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b5e6163cb8c49088c2c36f57875e58ccd8c87c7427f7fbd50ea6710b2f3f2e8f" +[[package]] +name = "lru" +version = "0.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1efa59af2ddfad1854ae27d75009d538d0998b4b2fd47083e743ac1a10e46c60" +dependencies = [ + "hashbrown 0.14.2", +] + [[package]] name = "mach2" version = "0.4.1" diff --git a/Cargo.toml b/Cargo.toml index 9f1b7df..11f9c2e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -32,3 +32,4 @@ deunicode = "1.3" urlencoding = "2" httpdate = "1" redis = { version = "0.23", features = ["tokio-comp", "connection-manager"] } +lru = "0.12" diff --git a/src/endpoint.rs b/src/endpoint.rs index 5450bad..90a72ea 100644 --- a/src/endpoint.rs +++ b/src/endpoint.rs @@ -1,4 +1,8 @@ -use std::sync::Arc; +use std::{ + collections::HashMap, + sync::Arc, + time::Instant, +}; use axum::{ async_trait, @@ -7,15 +11,94 @@ use axum::{ http::{header::CONTENT_TYPE, Request, StatusCode}, BoxError, }; -use http_digest_headers::{DigestHeader}; - +use futures::Future; +use http_digest_headers::DigestHeader; use sigh::{Signature, PublicKey, Key, PrivateKey}; +use lru::LruCache; +use tokio::sync::{Mutex, oneshot}; use crate::fetch::authorized_fetch; use crate::activitypub::Actor; use crate::error::Error; + +#[allow(clippy::type_complexity)] +#[derive(Clone)] +pub struct ActorCache { + cache: Arc, Error>>>>, + queues: Arc, Error>>>>>>, +} + +impl Default for ActorCache { + fn default() -> Self { + ActorCache { + cache: Arc::new(Mutex::new( + LruCache::new(std::num::NonZeroUsize::new(64).unwrap()) + )), + queues: Arc::new(Mutex::new(HashMap::new())), + } + } +} + +impl ActorCache { + pub async fn get(&self, k: &str, f: F) -> Result, Error> + where + F: (FnOnce() -> R) + Send + 'static, + R: Future> + Send, + { + let begin = Instant::now(); + + let mut lru = self.cache.lock().await; + if let Some(v) = lru.get(k) { + return v.clone(); + } + drop(lru); + + let (tx, rx) = oneshot::channel(); + let mut new = false; + let mut queues = self.queues.lock().await; + let queue = queues.entry(k.to_string()) + .or_insert_with(|| { + new = true; + Vec::with_capacity(1) + }); + queue.push(tx); + drop(queues); + + if new { + let k = k.to_string(); + let cache = self.cache.clone(); + let queues = self.queues.clone(); + tokio::spawn(async move { + let result = f().await + .map(Arc::new); + + let mut lru = cache.lock().await; + lru.put(k.clone(), result.clone()); + drop(lru); + + let mut queues = queues.lock().await; + let queue = queues.remove(&k) + .expect("queues.remove"); + let queue_len = queue.len(); + let mut notified = 0usize; + for tx in queue.into_iter() { + if let Ok(()) = tx.send(result.clone()) { + notified += 1; + } + } + + let end = Instant::now(); + tracing::info!("Notified {notified}/{queue_len} endpoint verifications for actor {k} in {:?}", end - begin); + }); + } + + rx.await.unwrap() + } +} + + const SIGNATURE_HEADERS_REQUIRED: &[&str] = &[ "(request-target)", "host", "date", @@ -49,8 +132,8 @@ where } else { return Err((StatusCode::UNSUPPORTED_MEDIA_TYPE, "No content-type".to_string())); }; - if ! content_type.starts_with("application/json") && - ! (content_type.starts_with("application/") && content_type.ends_with("+json")) + if ! (content_type.starts_with("application/json") || + (content_type.starts_with("application/") && content_type.ends_with("+json"))) { return Err((StatusCode::UNSUPPORTED_MEDIA_TYPE, "Invalid content-type".to_string())); } @@ -105,12 +188,20 @@ impl<'a> Endpoint<'a> { pub async fn remote_actor( &self, client: &reqwest::Client, - key_id: &str, - private_key: &PrivateKey, - ) -> Result { - let remote_actor: Actor = serde_json::from_value( - authorized_fetch(client, &self.remote_actor_uri, key_id, private_key).await? - )?; + cache: &ActorCache, + key_id: String, + private_key: Arc, + ) -> Result, Error> { + let client = client.clone(); + let url = self.remote_actor_uri.clone(); + let remote_actor = cache.get(&self.remote_actor_uri, || async move { + tracing::info!("GET actor {}", url); + let actor: Actor = serde_json::from_value( + authorized_fetch(&client, &url, &key_id, &private_key).await? + )?; + Ok(actor) + }).await?; + let public_key = PublicKey::from_pem(remote_actor.public_key.pem.as_bytes())?; if ! (self.signature.verify(&public_key)?) { return Err(Error::SignatureFail); diff --git a/src/error.rs b/src/error.rs index 6980939..639d83d 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,19 +1,45 @@ -#[derive(Debug, thiserror::Error)] +use std::sync::Arc; + +#[derive(Clone, Debug, thiserror::Error)] pub enum Error { #[error("HTTP Digest generation error")] Digest, #[error("JSON encoding error")] - Json(#[from] serde_json::Error), + Json(#[from] Arc), #[error("Signature error")] - Signature(#[from] sigh::Error), + Signature(#[from] Arc), #[error("Signature verification failure")] SignatureFail, #[error("HTTP request error")] - HttpReq(#[from] http::Error), + HttpReq(#[from] Arc), #[error("HTTP client error")] - Http(#[from] reqwest::Error), + Http(#[from] Arc), #[error("Invalid URI")] InvalidUri, #[error("Error response from remote")] Response(String), } + +impl From for Error { + fn from(e: serde_json::Error) -> Self { + Error::Json(Arc::new(e)) + } +} + +impl From for Error { + fn from(e: reqwest::Error) -> Self { + Error::Http(Arc::new(e)) + } +} + +impl From for Error { + fn from(e: sigh::Error) -> Self { + Error::Signature(Arc::new(e)) + } +} + +impl From for Error { + fn from(e: http::Error) -> Self { + Error::HttpReq(Arc::new(e)) + } +} diff --git a/src/main.rs b/src/main.rs index 340e7ca..f2e812b 100644 --- a/src/main.rs +++ b/src/main.rs @@ -133,7 +133,7 @@ async fn post_relay( endpoint: endpoint::Endpoint<'_>, target: actor::Actor ) -> Response { - let remote_actor = match endpoint.remote_actor(&state.client, &target.key_id(), &state.priv_key).await { + let remote_actor = match endpoint.remote_actor(&state.client, &state.actor_cache, target.key_id(), state.priv_key.clone()).await { Ok(remote_actor) => remote_actor, Err(e) => { track_request("POST", "relay", "bad_actor"); diff --git a/src/send.rs b/src/send.rs index c94b919..2e2aeb4 100644 --- a/src/send.rs +++ b/src/send.rs @@ -16,8 +16,7 @@ pub async fn send( body: &T, ) -> Result<(), Error> { let body = Arc::new( - serde_json::to_vec(body) - .map_err(Error::Json)? + serde_json::to_vec(body)? ); send_raw(client, uri, key_id, private_key, body).await } @@ -41,8 +40,7 @@ pub async fn send_raw( .header("content-type", "application/activity+json") .header("date", httpdate::fmt_http_date(SystemTime::now())) .header("digest", digest_header) - .body(body.as_ref().clone()) - .map_err(Error::HttpReq)?; + .body(body.as_ref().clone())?; let t1 = Instant::now(); SigningConfig::new(RsaSha256, private_key, key_id) .sign(&mut req)?; diff --git a/src/state.rs b/src/state.rs index 5559725..577b625 100644 --- a/src/state.rs +++ b/src/state.rs @@ -3,13 +3,14 @@ use axum::{ }; use sigh::{PrivateKey, PublicKey}; use std::sync::Arc; -use crate::{config::Config, db::Database}; +use crate::{config::Config, db::Database, endpoint::ActorCache}; #[derive(Clone)] pub struct State { pub database: Database, pub redis: Option<(redis::aio::ConnectionManager, Arc)>, pub client: Arc, + pub actor_cache: ActorCache, pub hostname: Arc, pub priv_key: Arc, pub pub_key: Arc, @@ -30,6 +31,7 @@ impl State { database, redis: redis.map(|(connection, in_topic)| (connection, Arc::new(in_topic))), client: Arc::new(client), + actor_cache: Default::default(), hostname: Arc::new(config.hostname), priv_key, pub_key,