implement simple moka cache for all dereference() calls

This commit is contained in:
phiresky 2023-07-01 02:42:02 +02:00
parent b64f4a8f3f
commit f9548ce788
12 changed files with 115 additions and 27 deletions

View file

@ -56,6 +56,7 @@ axum = { version = "0.6.18", features = [
], default-features = false, optional = true }
tower = { version = "0.4.13", optional = true }
hyper = { version = "0.14", optional = true }
moka = { version = "0.11.2", features = ["future"] }
[features]
default = ["actix-web", "axum"]

View file

@ -11,25 +11,46 @@ It is sometimes necessary to fetch from a URL, but we don't know the exact type
# use activitypub_federation::traits::tests::DbConnection;
# use activitypub_federation::config::Data;
# use url::Url;
# use std::sync::Arc;
# use activitypub_federation::traits::tests::{Person, Note};
#[derive(Clone)]
pub enum SearchableDbObjects {
User(DbUser),
Post(DbPost)
}
#[derive(Deserialize, Serialize)]
#[derive(Deserialize, Serialize, Debug)]
#[serde(untagged)]
pub enum SearchableObjects {
Person(Person),
Note(Note)
}
#[derive(Debug, Clone)]
pub struct Error(pub(crate) Arc<anyhow::Error>);
impl std::fmt::Display for Error {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> std::fmt::Result {
std::fmt::Display::fmt(&self.0, f)
}
}
impl<T> From<T> for Error
where
T: Into<anyhow::Error>,
{
fn from(t: T) -> Self {
Error(Arc::new(t.into()))
}
}
#[async_trait::async_trait]
impl Object for SearchableDbObjects {
type DataType = DbConnection;
type Kind = SearchableObjects;
type Error = anyhow::Error;
type Error = Error;
async fn read_from_id(
object_id: Url,
@ -62,7 +83,7 @@ impl Object for SearchableDbObjects {
}
#[tokio::main]
async fn main() -> Result<(), anyhow::Error> {
async fn main() -> Result<(), Error> {
# let config = FederationConfig::builder().domain("example.com").app_data(DbConnection).build().await.unwrap();
# let data = config.to_request_data();
let query = "https://example.com/id/413";

View file

@ -1,8 +1,8 @@
use std::fmt::{Display, Formatter};
use std::{fmt::{Display, Formatter}, sync::Arc};
/// Necessary because of this issue: https://github.com/actix/actix-web/issues/1711
#[derive(Debug)]
pub struct Error(pub(crate) anyhow::Error);
#[derive(Debug, Clone)]
pub struct Error(pub(crate) Arc<anyhow::Error>);
impl Display for Error {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
@ -15,6 +15,6 @@ where
T: Into<anyhow::Error>,
{
fn from(t: T) -> Self {
Error(t.into())
Error(Arc::new(t.into()))
}
}

View file

@ -1,8 +1,8 @@
use std::fmt::{Display, Formatter};
use std::{fmt::{Display, Formatter}, sync::Arc};
/// Necessary because of this issue: https://github.com/actix/actix-web/issues/1711
#[derive(Debug)]
pub struct Error(pub(crate) anyhow::Error);
#[derive(Debug, Clone)]
pub struct Error(pub(crate) Arc<anyhow::Error>);
impl Display for Error {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
@ -15,6 +15,6 @@ where
T: Into<anyhow::Error>,
{
fn from(t: T) -> Self {
Error(t.into())
Error(Arc::new(t.into()))
}
}

View file

@ -22,13 +22,13 @@ pub async fn receive_activity<Activity, ActorT, Datatype>(
) -> Result<HttpResponse, <Activity as ActivityHandler>::Error>
where
Activity: ActivityHandler<DataType = Datatype> + DeserializeOwned + Send + 'static,
ActorT: Object<DataType = Datatype> + Actor + Send + 'static,
ActorT: Object<DataType = Datatype> + Actor + Clone + Sync + Send + 'static,
for<'de2> <ActorT as Object>::Kind: serde::Deserialize<'de2>,
<Activity as ActivityHandler>::Error: From<anyhow::Error>
+ From<Error>
+ From<<ActorT as Object>::Error>
+ From<serde_json::Error>,
<ActorT as Object>::Error: From<Error> + From<anyhow::Error>,
<ActorT as Object>::Error: From<Error> + From<anyhow::Error> + Clone + Send + Sync,
Datatype: Clone,
{
verify_body_hash(request.headers().get("Digest"), &body)?;

View file

@ -21,8 +21,8 @@ pub async fn signing_actor<A>(
data: &Data<<A as Object>::DataType>,
) -> Result<A, <A as Object>::Error>
where
A: Object + Actor,
<A as Object>::Error: From<Error> + From<anyhow::Error>,
A: Object + Actor + Sync + Send + Clone,
<A as Object>::Error: From<Error> + From<anyhow::Error> + Sync + Clone + Send,
for<'de2> <A as Object>::Kind: Deserialize<'de2>,
{
verify_body_hash(request.headers().get("Digest"), &body.unwrap_or_default())?;

View file

@ -27,13 +27,13 @@ pub async fn receive_activity<Activity, ActorT, Datatype>(
) -> Result<(), <Activity as ActivityHandler>::Error>
where
Activity: ActivityHandler<DataType = Datatype> + DeserializeOwned + Send + 'static,
ActorT: Object<DataType = Datatype> + Actor + Send + 'static,
ActorT: Object<DataType = Datatype> + Actor + Send + Sync + Clone + 'static,
for<'de2> <ActorT as Object>::Kind: serde::Deserialize<'de2>,
<Activity as ActivityHandler>::Error: From<anyhow::Error>
+ From<Error>
+ From<<ActorT as Object>::Error>
+ From<serde_json::Error>,
<ActorT as Object>::Error: From<Error> + From<anyhow::Error>,
<ActorT as Object>::Error: From<Error> + From<anyhow::Error> + Clone + Send + Sync,
Datatype: Clone,
{
verify_body_hash(activity_data.headers.get("Digest"), &activity_data.body)?;

View file

@ -31,7 +31,6 @@ pub enum Error {
#[error(transparent)]
Other(#[from] anyhow::Error),
}
impl Error {
pub(crate) fn other<T>(error: T) -> Self
where

View file

@ -1,8 +1,11 @@
use crate::{config::Data, error::Error, fetch::fetch_object_http, traits::Object};
use anyhow::anyhow;
use chrono::{Duration as ChronoDuration, NaiveDateTime, Utc};
use moka::future::Cache;
use once_cell::sync::Lazy;
use serde::{Deserialize, Serialize};
use std::{
cell::OnceCell,
fmt::{Debug, Display, Formatter},
marker::PhantomData,
str::FromStr,
@ -60,6 +63,41 @@ where
Kind: Object,
for<'de2> <Kind as Object>::Kind: serde::Deserialize<'de2>;
impl<Kind> ObjectId<Kind>
where
Kind: Object + Send + Sync + Clone + 'static,
for<'de2> <Kind as Object>::Kind: serde::Deserialize<'de2>,
<Kind as Object>::Error: Clone + Send + Sync,
{
/// This creates a cache for every monomorphization of ObjectId (so for every type of object)
const CACHE: OnceCell<Cache<Url, Result<Kind, Kind::Error>>> = OnceCell::new();
/// Fetches an activitypub object, either from local database (if possible), or over http, retrieving from cache if possible
pub async fn dereference(
&self,
data: &Data<<Kind as Object>::DataType>,
) -> Result<Kind, <Kind as Object>::Error>
where
<Kind as Object>::Error: From<Error> + From<anyhow::Error>,
{
let cache = Self::CACHE;
let cache = cache.get_or_init(|| {
Cache::builder()
.max_capacity(Kind::cache_max_capacity())
.time_to_live(Kind::cache_time_to_live())
.build()
});
// The get_with method ensures that the dereference_inner method is only called once even if the dereference method is called twice simultaneously.
// From the docs: "This method guarantees that concurrent calls on the same not-existing key are coalesced into one evaluation of the init future. Only one of the calls evaluates its future, and other calls wait for that future to resolve."
// Considerations: should an error result be stored in the cache as well? Right now: yes. Otherwise try_get_with should be used.
cache
.get_with(*self.0.clone(), async {
self.dereference_uncached(data).await
})
.await
}
}
impl<Kind> ObjectId<Kind>
where
Kind: Object + Send + 'static,
@ -85,7 +123,7 @@ where
}
/// Fetches an activitypub object, either from local database (if possible), or over http.
pub async fn dereference(
pub async fn dereference_uncached(
&self,
data: &Data<<Kind as Object>::DataType>,
) -> Result<Kind, <Kind as Object>::Error>

View file

@ -5,7 +5,7 @@ use crate::{
traits::{Actor, Object},
FEDERATION_CONTENT_TYPE,
};
use anyhow::anyhow;
use anyhow::{anyhow, Context};
use itertools::Itertools;
use regex::Regex;
use serde::{Deserialize, Serialize};
@ -22,10 +22,10 @@ pub async fn webfinger_resolve_actor<T: Clone, Kind>(
data: &Data<T>,
) -> Result<Kind, <Kind as Object>::Error>
where
Kind: Object + Actor + Send + 'static + Object<DataType = T>,
Kind: Object + Actor + Send + 'static + Object<DataType = T> + Clone + Sync,
for<'de2> <Kind as Object>::Kind: serde::Deserialize<'de2>,
<Kind as Object>::Error:
From<crate::error::Error> + From<anyhow::Error> + From<url::ParseError> + Send + Sync,
From<crate::error::Error> + From<anyhow::Error> + Send + Sync + Clone,
{
let (_, domain) = identifier
.splitn(2, '@')
@ -36,7 +36,7 @@ where
format!("{protocol}://{domain}/.well-known/webfinger?resource=acct:{identifier}");
debug!("Fetching webfinger url: {}", &fetch_url);
let res: Webfinger = fetch_object_http(&Url::parse(&fetch_url)?, data).await?;
let res: Webfinger = fetch_object_http(&Url::parse(&fetch_url).context("parsing url")?, data).await?;
debug_assert_eq!(res.subject, format!("acct:{identifier}"));
let links: Vec<Url> = res

View file

@ -151,8 +151,8 @@ pub(crate) async fn signing_actor<'a, A, H>(
data: &Data<<A as Object>::DataType>,
) -> Result<A, <A as Object>::Error>
where
A: Object + Actor,
<A as Object>::Error: From<Error> + From<anyhow::Error>,
A: Object + Actor + Clone + Sync,
<A as Object>::Error: From<Error> + From<anyhow::Error> + Clone + Send + Sync,
for<'de2> <A as Object>::Kind: Deserialize<'de2>,
H: IntoIterator<Item = (&'a HeaderName, &'a HeaderValue)>,
{

View file

@ -4,7 +4,7 @@ use crate::{config::Data, protocol::public_key::PublicKey};
use async_trait::async_trait;
use chrono::NaiveDateTime;
use serde::Deserialize;
use std::{fmt::Debug, ops::Deref};
use std::{ops::Deref, time::Duration};
use url::Url;
/// Helper for converting between database structs and federated protocol structs.
@ -102,6 +102,15 @@ pub trait Object: Sized {
/// Error type returned by handler methods
type Error;
/// Defines how many of this type of object should be cached
fn cache_max_capacity() -> u64 {
1000
}
/// Defines how long objects of this type should live in the in-memory cache
fn cache_time_to_live() -> Duration {
Duration::from_secs(10)
}
/// Returns the last time this object was updated.
///
/// If this returns `Some` and the value is too long ago, the object is refetched from the
@ -337,6 +346,7 @@ pub trait Collection: Sized {
#[doc(hidden)]
#[allow(clippy::unwrap_used)]
pub mod tests {
use super::*;
use crate::{
fetch::object_id::ObjectId,
@ -400,11 +410,30 @@ pub mod tests {
local: false,
});
#[derive(Debug, Clone)]
pub struct ClonableError(String);
impl std::fmt::Display for ClonableError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
self.0.fmt(f)
}
}
impl std::error::Error for ClonableError {}
impl From<anyhow::Error> for ClonableError {
fn from(value: anyhow::Error) -> Self {
ClonableError(format!("{:?}", value))
}
}
impl From<crate::error::Error> for ClonableError {
fn from(value: crate::error::Error) -> Self {
ClonableError(format!("{:?}", value))
}
}
#[async_trait]
impl Object for DbUser {
type DataType = DbConnection;
type Kind = Person;
type Error = Error;
type Error = ClonableError;
async fn read_from_id(
_object_id: Url,