use crate::{ data::LastOnline, error::{Error, ErrorKind}, }; use activitystreams::iri_string::types::IriString; use actix_web::http::header::Date; use awc::{error::SendRequestError, Client, ClientResponse, Connector}; use base64::{engine::general_purpose::STANDARD, Engine}; use dashmap::DashMap; use http_signature_normalization_actix::{prelude::*, Canceled, Spawn}; use rand::thread_rng; use rsa::{ pkcs1v15::SigningKey, sha2::{Digest, Sha256}, signature::{RandomizedSigner, SignatureEncoding}, RsaPrivateKey, }; use std::{ panic::AssertUnwindSafe, sync::{ atomic::{AtomicBool, Ordering}, Arc, }, thread::JoinHandle, time::{Duration, SystemTime}, }; use tracing_awc::Tracing; const ONE_SECOND: u64 = 1; const ONE_MINUTE: u64 = 60 * ONE_SECOND; const ONE_HOUR: u64 = 60 * ONE_MINUTE; const ONE_DAY: u64 = 24 * ONE_HOUR; #[derive(Clone)] pub(crate) struct Breakers { inner: Arc>, } impl std::fmt::Debug for Breakers { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("Breakers").finish() } } impl Breakers { fn should_try(&self, url: &IriString) -> bool { if let Some(authority) = url.authority_str() { if let Some(breaker) = self.inner.get(authority) { breaker.should_try() } else { true } } else { false } } fn fail(&self, url: &IriString) { if let Some(authority) = url.authority_str() { let should_write = { if let Some(mut breaker) = self.inner.get_mut(authority) { breaker.fail(); if !breaker.should_try() { tracing::warn!("Failed breaker for {authority}"); } false } else { true } }; if should_write { let mut breaker = self.inner.entry(authority.to_owned()).or_default(); breaker.fail(); } } } fn succeed(&self, url: &IriString) { if let Some(authority) = url.authority_str() { let should_write = { if let Some(mut breaker) = self.inner.get_mut(authority) { breaker.succeed(); false } else { true } }; if should_write { let mut breaker = self.inner.entry(authority.to_owned()).or_default(); breaker.succeed(); } } } } impl Default for Breakers { fn default() -> Self { Breakers { inner: Arc::new(DashMap::new()), } } } #[derive(Debug)] struct Breaker { failures: usize, last_attempt: SystemTime, last_success: SystemTime, } impl Breaker { const FAILURE_WAIT: Duration = Duration::from_secs(ONE_DAY); const FAILURE_THRESHOLD: usize = 10; fn should_try(&self) -> bool { self.failures < Self::FAILURE_THRESHOLD || self.last_attempt + Self::FAILURE_WAIT < SystemTime::now() } fn fail(&mut self) { self.failures += 1; self.last_attempt = SystemTime::now(); } fn succeed(&mut self) { self.failures = 0; self.last_attempt = SystemTime::now(); self.last_success = SystemTime::now(); } } impl Default for Breaker { fn default() -> Self { let now = SystemTime::now(); Breaker { failures: 0, last_attempt: now, last_success: now, } } } #[derive(Clone)] pub(crate) struct Requests { pool_size: usize, client: Client, key_id: String, user_agent: String, private_key: RsaPrivateKey, config: Config, breakers: Breakers, last_online: Arc, } impl std::fmt::Debug for Requests { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("Requests") .field("pool_size", &self.pool_size) .field("key_id", &self.key_id) .field("user_agent", &self.user_agent) .field("config", &self.config) .field("breakers", &self.breakers) .finish() } } thread_local! { static CLIENT: std::cell::OnceCell = std::cell::OnceCell::new(); } pub(crate) fn build_client(user_agent: &str, pool_size: usize, timeout_seconds: u64) -> Client { CLIENT.with(|client| { client .get_or_init(|| { let connector = Connector::new().limit(pool_size); Client::builder() .connector(connector) .wrap(Tracing) .add_default_header(("User-Agent", user_agent.to_string())) .timeout(Duration::from_secs(timeout_seconds)) .finish() }) .clone() }) } impl Requests { pub(crate) fn new( key_id: String, private_key: RsaPrivateKey, user_agent: String, breakers: Breakers, last_online: Arc, pool_size: usize, timeout_seconds: u64, spawner: Spawner, ) -> Self { Requests { pool_size, client: build_client(&user_agent, pool_size, timeout_seconds), key_id, user_agent, private_key, config: Config::new().mastodon_compat().spawner(spawner), breakers, last_online, } } pub(crate) fn reset_breaker(&self, iri: &IriString) { self.breakers.succeed(iri); } async fn check_response( &self, parsed_url: &IriString, res: Result, ) -> Result { if res.is_err() { self.breakers.fail(&parsed_url); } let mut res = res.map_err(|e| ErrorKind::SendRequest(parsed_url.to_string(), e.to_string()))?; if res.status().is_server_error() { self.breakers.fail(&parsed_url); if let Ok(bytes) = res.body().await { if let Ok(s) = String::from_utf8(bytes.as_ref().to_vec()) { if !s.is_empty() { tracing::warn!("Response from {parsed_url}, {s}"); } } } return Err(ErrorKind::Status(parsed_url.to_string(), res.status()).into()); } self.last_online.mark_seen(&parsed_url); self.breakers.succeed(&parsed_url); Ok(res) } #[tracing::instrument(name = "Fetch Json", skip(self), fields(signing_string))] pub(crate) async fn fetch_json(&self, url: &IriString) -> Result where T: serde::de::DeserializeOwned, { self.do_fetch(url, "application/json").await } #[tracing::instrument(name = "Fetch Json", skip(self), fields(signing_string))] pub(crate) async fn fetch_json_msky(&self, url: &IriString) -> Result where T: serde::de::DeserializeOwned, { let mut res = self .do_deliver( url, &serde_json::json!({}), "application/json", "application/json", ) .await?; let body = res .body() .await .map_err(|e| ErrorKind::ReceiveResponse(url.to_string(), e.to_string()))?; Ok(serde_json::from_slice(body.as_ref())?) } #[tracing::instrument(name = "Fetch Activity+Json", skip(self), fields(signing_string))] pub(crate) async fn fetch(&self, url: &IriString) -> Result where T: serde::de::DeserializeOwned, { self.do_fetch(url, "application/activity+json").await } async fn do_fetch(&self, url: &IriString, accept: &str) -> Result where T: serde::de::DeserializeOwned, { let mut res = self.do_fetch_response(url, accept).await?; let body = res .body() .await .map_err(|e| ErrorKind::ReceiveResponse(url.to_string(), e.to_string()))?; Ok(serde_json::from_slice(body.as_ref())?) } #[tracing::instrument(name = "Fetch response", skip(self), fields(signing_string))] pub(crate) async fn fetch_response(&self, url: &IriString) -> Result { self.do_fetch_response(url, "*/*").await } pub(crate) async fn do_fetch_response( &self, url: &IriString, accept: &str, ) -> Result { if !self.breakers.should_try(url) { return Err(ErrorKind::Breaker.into()); } let signer = self.signer(); let span = tracing::Span::current(); let res = self .client .get(url.as_str()) .insert_header(("Accept", accept)) .insert_header(Date(SystemTime::now().into())) .no_decompress() .signature( self.config.clone(), self.key_id.clone(), move |signing_string| { span.record("signing_string", signing_string); span.in_scope(|| signer.sign(signing_string)) }, ) .await? .send() .await; let res = self.check_response(url, res).await?; Ok(res) } #[tracing::instrument( "Deliver to Inbox", skip_all, fields(inbox = inbox.to_string().as_str(), signing_string) )] pub(crate) async fn deliver(&self, inbox: &IriString, item: &T) -> Result<(), Error> where T: serde::ser::Serialize + std::fmt::Debug, { self.do_deliver( inbox, item, "application/activity+json", "application/activity+json", ) .await?; Ok(()) } async fn do_deliver( &self, inbox: &IriString, item: &T, content_type: &str, accept: &str, ) -> Result where T: serde::ser::Serialize + std::fmt::Debug, { if !self.breakers.should_try(&inbox) { return Err(ErrorKind::Breaker.into()); } let signer = self.signer(); let span = tracing::Span::current(); let item_string = serde_json::to_string(item)?; let (req, body) = self .client .post(inbox.as_str()) .insert_header(("Accept", accept)) .insert_header(("Content-Type", content_type)) .insert_header(Date(SystemTime::now().into())) .signature_with_digest( self.config.clone(), self.key_id.clone(), Sha256::new(), item_string, move |signing_string| { span.record("signing_string", signing_string); span.in_scope(|| signer.sign(signing_string)) }, ) .await? .split(); let res = req.send_body(body).await; let res = self.check_response(inbox, res).await?; Ok(res) } fn signer(&self) -> Signer { Signer { private_key: self.private_key.clone(), } } } struct Signer { private_key: RsaPrivateKey, } impl Signer { fn sign(&self, signing_string: &str) -> Result { let signing_key = SigningKey::::new(self.private_key.clone()); let signature = signing_key.try_sign_with_rng(&mut thread_rng(), signing_string.as_bytes())?; Ok(STANDARD.encode(signature.to_bytes().as_ref())) } } fn signature_thread( receiver: flume::Receiver>, shutdown: flume::Receiver<()>, ) { let stopping = AtomicBool::new(false); while !stopping.load(Ordering::Acquire) { flume::Selector::new() .recv(&receiver, |res| match res { Ok(f) => { let res = std::panic::catch_unwind(AssertUnwindSafe(move || { (f)(); })); if let Err(e) = res { tracing::warn!("Signature fn panicked: {e:?}"); } } Err(_) => { tracing::warn!("Receive error, stopping"); stopping.store(true, Ordering::Release) } }) .recv(&shutdown, |_| { tracing::warn!("Stopping"); stopping.store(true, Ordering::Release) }) .wait(); } } #[derive(Clone, Debug)] pub(crate) struct Spawner { sender: flume::Sender>, threads: Option>>>, shutdown: flume::Sender<()>, } impl Spawner { pub(crate) fn build() -> std::io::Result { let threads = std::thread::available_parallelism() .map(usize::from) .unwrap_or(1); let (sender, receiver) = flume::bounded(8); let (shutdown, shutdown_rx) = flume::bounded(threads); let threads = (0..threads) .map(|i| { let receiver = receiver.clone(); let shutdown_rx = shutdown_rx.clone(); std::thread::Builder::new() .name(format!("signature-thread-{i}")) .spawn(move || { signature_thread(receiver, shutdown_rx); }) }) .collect::, _>>()?; Ok(Spawner { sender, threads: Some(Arc::new(threads)), shutdown, }) } } impl Drop for Spawner { fn drop(&mut self) { if let Some(threads) = self.threads.take().and_then(Arc::into_inner) { for _ in &threads { let _ = self.shutdown.send(()); } for thread in threads { let _ = thread.join(); } } } } impl Spawn for Spawner { type Future = std::pin::Pin>>>; fn spawn_blocking(&self, func: Func) -> Self::Future where Func: FnOnce() -> Out + Send + 'static, Out: Send + 'static, { let sender = self.sender.clone(); Box::pin(async move { let (tx, rx) = flume::bounded(1); let _ = sender .send_async(Box::new(move || { if tx.send((func)()).is_err() { tracing::warn!("Requestor hung up"); } })) .await; rx.recv_async().await.map_err(|_| Canceled) }) } }