From 8c431650e258555252429ebcc2eacfc4cdc40760 Mon Sep 17 00:00:00 2001 From: Astro Date: Sat, 28 Oct 2023 01:29:18 +0200 Subject: [PATCH] endpoint: refactor actor_cache into separate file --- src/actor_cache.rs | 88 ++++++++++++++++++++++++++++++++++++++++++++++ src/endpoint.rs | 88 ++-------------------------------------------- src/main.rs | 1 + src/state.rs | 2 +- 4 files changed, 92 insertions(+), 87 deletions(-) create mode 100644 src/actor_cache.rs diff --git a/src/actor_cache.rs b/src/actor_cache.rs new file mode 100644 index 0000000..efdb292 --- /dev/null +++ b/src/actor_cache.rs @@ -0,0 +1,88 @@ +use std::{ + collections::HashMap, + sync::Arc, + time::Instant, +}; + +use futures::Future; +use lru::LruCache; +use tokio::sync::{Mutex, oneshot}; + +use crate::activitypub::Actor; +use crate::error::Error; + + +#[allow(clippy::type_complexity)] +#[derive(Clone)] +pub struct ActorCache { + cache: Arc, Error>>>>, + queues: Arc, Error>>>>>>, +} + +impl Default for ActorCache { + fn default() -> Self { + ActorCache { + cache: Arc::new(Mutex::new( + LruCache::new(std::num::NonZeroUsize::new(64).unwrap()) + )), + queues: Arc::new(Mutex::new(HashMap::new())), + } + } +} + +impl ActorCache { + pub async fn get(&self, k: &str, f: F) -> Result, Error> + where + F: (FnOnce() -> R) + Send + 'static, + R: Future> + Send, + { + let begin = Instant::now(); + + let mut lru = self.cache.lock().await; + if let Some(v) = lru.get(k) { + return v.clone(); + } + drop(lru); + + let (tx, rx) = oneshot::channel(); + let mut new = false; + let mut queues = self.queues.lock().await; + let queue = queues.entry(k.to_string()) + .or_insert_with(|| { + new = true; + Vec::with_capacity(1) + }); + queue.push(tx); + drop(queues); + + if new { + let k = k.to_string(); + let cache = self.cache.clone(); + let queues = self.queues.clone(); + tokio::spawn(async move { + let result = f().await + .map(Arc::new); + + let mut lru = cache.lock().await; + lru.put(k.clone(), result.clone()); + drop(lru); + + let mut queues = queues.lock().await; + let queue = queues.remove(&k) + .expect("queues.remove"); + let queue_len = queue.len(); + let mut notified = 0usize; + for tx in queue.into_iter() { + if let Ok(()) = tx.send(result.clone()) { + notified += 1; + } + } + + let end = Instant::now(); + tracing::info!("Notified {notified}/{queue_len} endpoint verifications for actor {k} in {:?}", end - begin); + }); + } + + rx.await.unwrap() + } +} diff --git a/src/endpoint.rs b/src/endpoint.rs index 90a72ea..348513e 100644 --- a/src/endpoint.rs +++ b/src/endpoint.rs @@ -1,8 +1,4 @@ -use std::{ - collections::HashMap, - sync::Arc, - time::Instant, -}; +use std::sync::Arc; use axum::{ async_trait, @@ -10,93 +6,13 @@ use axum::{ extract::{FromRef, FromRequest}, http::{header::CONTENT_TYPE, Request, StatusCode}, BoxError, }; - -use futures::Future; use http_digest_headers::DigestHeader; use sigh::{Signature, PublicKey, Key, PrivateKey}; -use lru::LruCache; -use tokio::sync::{Mutex, oneshot}; - use crate::fetch::authorized_fetch; use crate::activitypub::Actor; use crate::error::Error; - - -#[allow(clippy::type_complexity)] -#[derive(Clone)] -pub struct ActorCache { - cache: Arc, Error>>>>, - queues: Arc, Error>>>>>>, -} - -impl Default for ActorCache { - fn default() -> Self { - ActorCache { - cache: Arc::new(Mutex::new( - LruCache::new(std::num::NonZeroUsize::new(64).unwrap()) - )), - queues: Arc::new(Mutex::new(HashMap::new())), - } - } -} - -impl ActorCache { - pub async fn get(&self, k: &str, f: F) -> Result, Error> - where - F: (FnOnce() -> R) + Send + 'static, - R: Future> + Send, - { - let begin = Instant::now(); - - let mut lru = self.cache.lock().await; - if let Some(v) = lru.get(k) { - return v.clone(); - } - drop(lru); - - let (tx, rx) = oneshot::channel(); - let mut new = false; - let mut queues = self.queues.lock().await; - let queue = queues.entry(k.to_string()) - .or_insert_with(|| { - new = true; - Vec::with_capacity(1) - }); - queue.push(tx); - drop(queues); - - if new { - let k = k.to_string(); - let cache = self.cache.clone(); - let queues = self.queues.clone(); - tokio::spawn(async move { - let result = f().await - .map(Arc::new); - - let mut lru = cache.lock().await; - lru.put(k.clone(), result.clone()); - drop(lru); - - let mut queues = queues.lock().await; - let queue = queues.remove(&k) - .expect("queues.remove"); - let queue_len = queue.len(); - let mut notified = 0usize; - for tx in queue.into_iter() { - if let Ok(()) = tx.send(result.clone()) { - notified += 1; - } - } - - let end = Instant::now(); - tracing::info!("Notified {notified}/{queue_len} endpoint verifications for actor {k} in {:?}", end - begin); - }); - } - - rx.await.unwrap() - } -} +use crate::actor_cache::ActorCache; const SIGNATURE_HEADERS_REQUIRED: &[&str] = &[ diff --git a/src/main.rs b/src/main.rs index f2e812b..0f1b3a6 100644 --- a/src/main.rs +++ b/src/main.rs @@ -25,6 +25,7 @@ mod send; mod stream; mod relay; mod activitypub; +mod actor_cache; mod endpoint; use state::State; diff --git a/src/state.rs b/src/state.rs index 577b625..f12ddfc 100644 --- a/src/state.rs +++ b/src/state.rs @@ -3,7 +3,7 @@ use axum::{ }; use sigh::{PrivateKey, PublicKey}; use std::sync::Arc; -use crate::{config::Config, db::Database, endpoint::ActorCache}; +use crate::{config::Config, db::Database, actor_cache::ActorCache}; #[derive(Clone)] pub struct State {