diff --git a/Cargo.lock b/Cargo.lock index 285e113..6b7231a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -330,6 +330,15 @@ version = "1.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5d2e7343e7fc9de883d1b0341e0b13970f764c14101234857d2ddafa1cb1cac2" +[[package]] +name = "ahash" +version = "0.2.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6f33b5018f120946c1dcf279194f238a9f146725593ead1c08fa47ff22b0b5d3" +dependencies = [ + "const-random", +] + [[package]] name = "aho-corasick" version = "0.7.10" @@ -358,11 +367,13 @@ dependencies = [ "dotenv", "futures", "log", + "lru", "pretty_env_logger", "serde", "serde_json", "thiserror", "tokio", + "ttl_cache", ] [[package]] @@ -393,6 +404,12 @@ dependencies = [ "winapi 0.3.8", ] +[[package]] +name = "autocfg" +version = "0.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d49d90015b3c36167a20fe2810c5cd875ad504b39cff3d4eae7977e6b7c1cb2" + [[package]] name = "autocfg" version = "1.0.0" @@ -593,6 +610,26 @@ dependencies = [ "bitflags", ] +[[package]] +name = "const-random" +version = "0.1.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2f1af9ac737b2dd2d577701e59fd09ba34822f6f2ebdb30a7647405d9e55e16a" +dependencies = [ + "const-random-macro", + "proc-macro-hack", +] + +[[package]] +name = "const-random-macro" +version = "0.1.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "25e4c606eb459dd29f7c57b2e0879f2b6f14ee130918c2b78ccb58a9624e6c7a" +dependencies = [ + "getrandom", + "proc-macro-hack", +] + [[package]] name = "copyless" version = "0.1.4" @@ -624,7 +661,7 @@ version = "0.7.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c3c7c73a2d1e9fc0886a08b93e98eb643461230d5f1925e4036204d5f2e261a8" dependencies = [ - "autocfg", + "autocfg 1.0.0", "cfg-if", "lazy_static", ] @@ -948,6 +985,16 @@ dependencies = [ "tokio-util 0.2.0", ] +[[package]] +name = "hashbrown" +version = "0.6.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e6073d0ca812575946eb5f35ff68dbe519907b25c42530389ff946dc84c6ead" +dependencies = [ + "ahash", + "autocfg 0.1.7", +] + [[package]] name = "heck" version = "0.3.1" @@ -1030,7 +1077,7 @@ version = "1.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "076f042c5b7b98f31d205f1249267e12a6518c1481e9dae9764af19b707d2292" dependencies = [ - "autocfg", + "autocfg 1.0.0", ] [[package]] @@ -1112,6 +1159,15 @@ dependencies = [ "cfg-if", ] +[[package]] +name = "lru" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0609345ddee5badacf857d4f547e0e5a2e987db77085c24cd887f73573a04237" +dependencies = [ + "hashbrown", +] + [[package]] name = "lru-cache" version = "0.1.2" @@ -1225,7 +1281,7 @@ version = "0.1.42" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3f6ea62e9d81a77cd3ee9a2a5b9b609447857f3d358704331e4ef39eb247fcba" dependencies = [ - "autocfg", + "autocfg 1.0.0", "num-traits", ] @@ -1235,7 +1291,7 @@ version = "0.2.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c62be47e61d1842b9170f0fdeec8eba98e60e90e5446449a0545e5152acd7096" dependencies = [ - "autocfg", + "autocfg 1.0.0", ] [[package]] @@ -1274,7 +1330,7 @@ version = "0.9.54" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1024c0a59774200a555087a6da3f253a9095a5f344e353b212ac4c8b8e450986" dependencies = [ - "autocfg", + "autocfg 1.0.0", "cc", "libc", "pkg-config", @@ -2005,6 +2061,15 @@ dependencies = [ "trust-dns-proto", ] +[[package]] +name = "ttl_cache" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4189890526f0168710b6ee65ceaedf1460c48a14318ceec933cb26baa492096a" +dependencies = [ + "linked-hash-map", +] + [[package]] name = "typenum" version = "1.11.2" diff --git a/Cargo.toml b/Cargo.toml index f78b92e..2864410 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,8 +16,10 @@ bb8-postgres = "0.4.0" dotenv = "0.15.0" futures = "0.3.4" log = "0.4" +lru = "0.4.3" pretty_env_logger = "0.4.0" serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" thiserror = "1.0" tokio = { version = "0.2.13", features = ["sync"] } +ttl_cache = "0.5.1" diff --git a/src/cache.rs b/src/cache.rs deleted file mode 100644 index 8e93b15..0000000 --- a/src/cache.rs +++ /dev/null @@ -1,100 +0,0 @@ -use std::collections::{BTreeMap, HashMap, LinkedList}; - -pub struct WeightedCache -where - K: std::hash::Hash + Eq + Clone, -{ - size: usize, - capacity: usize, - forward: HashMap, - backward: BTreeMap>, -} - -#[derive(Copy, Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)] -pub struct Count(usize); - -impl WeightedCache -where - K: std::hash::Hash + Eq + Clone, -{ - /// Create a new Weighted Cache - /// - /// panics if capacity is 0 - pub fn new(capacity: usize) -> Self { - if capacity == 0 { - panic!("Cache Capacity must be > 0"); - } - - WeightedCache { - size: 0, - capacity, - forward: HashMap::new(), - backward: BTreeMap::new(), - } - } - - /// Gets a value from the weighted cache - pub fn get(&mut self, key: K) -> Option<&V> { - let (value, count) = self.forward.get_mut(&key)?; - - if let Some(v) = self.backward.get_mut(count) { - v.drain_filter(|item| item == &key); - } - - count.0 += 1; - - let entry = self.backward.entry(*count).or_insert(LinkedList::new()); - entry.push_back(key); - - Some(&*value) - } - - /// set a value in the weighted cache - pub fn insert(&mut self, key: K, value: V) -> Option { - if self.forward.contains_key(&key) { - return None; - } - - let ret = if self.size >= self.capacity { - self.remove_least() - } else { - None - }; - - let count = Count(1); - self.forward.insert(key.clone(), (value, count)); - let entry = self.backward.entry(count).or_insert(LinkedList::new()); - - entry.push_back(key); - self.size += 1; - - ret - } - - fn remove_least(&mut self) -> Option { - let items = self.backward.values_mut().next()?; - - let oldest = items.pop_front()?; - let length = items.len(); - drop(items); - - let (item, count) = self.forward.remove(&oldest)?; - - if length == 0 { - self.backward.remove(&count); - self.backward = self - .backward - .clone() - .into_iter() - .map(|(mut k, v)| { - k.0 -= count.0; - (k, v) - }) - .collect(); - } - - self.size -= 1; - - Some(item) - } -} diff --git a/src/inbox.rs b/src/inbox.rs index 46774a7..8b3a234 100644 --- a/src/inbox.rs +++ b/src/inbox.rs @@ -2,7 +2,6 @@ use activitystreams::primitives::XsdAnyUri; use actix::Addr; use actix_web::{client::Client, web, Responder}; use log::info; -use std::sync::Arc; use crate::{ apub::{AcceptedActors, AcceptedObjects, ValidTypes}, @@ -20,11 +19,10 @@ pub async fn inbox( client: web::Data, input: web::Json, ) -> Result { - let _state = state.into_inner(); let input = input.into_inner(); info!("Relaying {} for {}", input.object.id(), input.actor); - let actor = fetch_actor(client.into_inner(), &input.actor).await?; + let actor = fetch_actor(state, client, &input.actor).await?; info!("Actor, {:#?}", actor); match input.kind { @@ -38,8 +36,16 @@ pub async fn inbox( Ok("{}") } -async fn fetch_actor(client: Arc, actor_id: &XsdAnyUri) -> Result { - client +async fn fetch_actor( + state: web::Data, + client: web::Data, + actor_id: &XsdAnyUri, +) -> Result { + if let Some(actor) = state.get_actor(actor_id).await { + return Ok(actor); + } + + let actor: AcceptedActors = client .get(actor_id.as_ref()) .header("Accept", "application/activity+json") .send() @@ -47,7 +53,11 @@ async fn fetch_actor(client: Arc, actor_id: &XsdAnyUri) -> Result() { + let id = self.0.fetch_add(1, Ordering::SeqCst); + actix::Arbiter::set_item(ArbiterLabel(id)); + } } } diff --git a/src/main.rs b/src/main.rs index 88f1258..96cb81e 100644 --- a/src/main.rs +++ b/src/main.rs @@ -3,7 +3,6 @@ use actix_web::{client::Client, web, App, HttpServer, Responder}; use bb8_postgres::tokio_postgres; mod apub; -mod cache; mod db_actor; mod inbox; mod label; diff --git a/src/state.rs b/src/state.rs index 0ddc6dd..0b716a5 100644 --- a/src/state.rs +++ b/src/state.rs @@ -2,32 +2,49 @@ use activitystreams::primitives::XsdAnyUri; use anyhow::Error; use bb8_postgres::tokio_postgres::{row::Row, Client}; use futures::try_join; +use lru::LruCache; use std::{collections::HashSet, sync::Arc}; -use tokio::sync::{Mutex, RwLock}; +use tokio::sync::RwLock; +use ttl_cache::TtlCache; -use crate::{cache::WeightedCache, db_actor::Pool}; +use crate::{apub::AcceptedActors, db_actor::Pool}; #[derive(Clone)] pub struct State { - cache: Arc>>, + actor_cache: Arc>>, + actor_id_cache: Arc>>, blocks: Arc>>, whitelists: Arc>>, listeners: Arc>>, } impl State { - pub async fn is_cached(&self, object_id: XsdAnyUri) -> bool { - let cache = self.cache.clone(); + pub async fn get_actor(&self, actor_id: &XsdAnyUri) -> Option { + let cache = self.actor_cache.clone(); - let mut lock = cache.lock().await; - lock.get(object_id).is_some() + let read_guard = cache.read().await; + read_guard.get(actor_id).cloned() + } + + pub async fn cache_actor(&self, actor_id: XsdAnyUri, actor: AcceptedActors) { + let cache = self.actor_cache.clone(); + + let mut write_guard = cache.write().await; + write_guard.insert(actor_id, actor, std::time::Duration::from_secs(3600)); + } + + pub async fn is_cached(&self, object_id: &XsdAnyUri) -> bool { + let cache = self.actor_id_cache.clone(); + + let read_guard = cache.read().await; + read_guard.contains(object_id) } pub async fn cache(&self, object_id: XsdAnyUri, actor_id: XsdAnyUri) { - let cache = self.cache.clone(); + let cache = self.actor_id_cache.clone(); - let mut lock = cache.lock().await; - lock.insert(object_id, actor_id); + let mut write_guard = cache.write().await; + write_guard.put(object_id, actor_id); } pub async fn add_block(&self, client: &Client, block: XsdAnyUri) -> Result<(), Error> { @@ -103,7 +120,8 @@ impl State { let (blocks, whitelists, listeners) = try_join!(f1, f2, f3)?; Ok(State { - cache: Arc::new(Mutex::new(WeightedCache::new(1024 * 8))), + actor_cache: Arc::new(RwLock::new(TtlCache::new(1024 * 8))), + actor_id_cache: Arc::new(RwLock::new(LruCache::new(1024 * 8))), blocks: Arc::new(RwLock::new(blocks)), whitelists: Arc::new(RwLock::new(whitelists)), listeners: Arc::new(RwLock::new(listeners)),