diff --git a/examples/live_federation/objects/post.rs b/examples/live_federation/objects/post.rs index 44012b1..9a08b9d 100644 --- a/examples/live_federation/objects/post.rs +++ b/examples/live_federation/objects/post.rs @@ -1,6 +1,9 @@ use crate::{ - activities::create_post::CreatePost, database::DatabaseHandle, error::Error, - generate_object_id, objects::person::DbUser, + activities::create_post::CreatePost, + database::DatabaseHandle, + error::Error, + generate_object_id, + objects::person::DbUser, }; use activitypub_federation::{ config::Data, diff --git a/examples/local_federation/activities/accept.rs b/examples/local_federation/activities/accept.rs index 9305213..c18945f 100644 --- a/examples/local_federation/activities/accept.rs +++ b/examples/local_federation/activities/accept.rs @@ -1,6 +1,9 @@ use crate::{activities::follow::Follow, instance::DatabaseHandle, objects::person::DbUser}; use activitypub_federation::{ - config::Data, fetch::object_id::ObjectId, kinds::activity::AcceptType, traits::ActivityHandler, + config::Data, + fetch::object_id::ObjectId, + kinds::activity::AcceptType, + traits::ActivityHandler, }; use serde::{Deserialize, Serialize}; use url::Url; diff --git a/examples/local_federation/activities/follow.rs b/examples/local_federation/activities/follow.rs index 51e8ee1..865a618 100644 --- a/examples/local_federation/activities/follow.rs +++ b/examples/local_federation/activities/follow.rs @@ -1,5 +1,7 @@ use crate::{ - activities::accept::Accept, generate_object_id, instance::DatabaseHandle, + activities::accept::Accept, + generate_object_id, + instance::DatabaseHandle, objects::person::DbUser, }; use activitypub_federation::{ diff --git a/examples/local_federation/axum/http.rs b/examples/local_federation/axum/http.rs index fd3fba9..3202117 100644 --- a/examples/local_federation/axum/http.rs +++ b/examples/local_federation/axum/http.rs @@ -17,7 +17,8 @@ use axum::{ extract::{Path, Query}, response::IntoResponse, routing::{get, post}, - Json, Router, + Json, + Router, }; use axum_macros::debug_handler; use serde::Deserialize; diff --git a/src/activity_queue.rs b/src/activity_queue.rs index 2c60662..5307a7c 100644 --- a/src/activity_queue.rs +++ b/src/activity_queue.rs @@ -12,6 +12,7 @@ use crate::{ }; use anyhow::{anyhow, Context}; +use crate::rate_limit::InstanceRatelimit; use bytes::Bytes; use futures_core::Future; use http::{header::HeaderName, HeaderMap, HeaderValue}; @@ -26,6 +27,7 @@ use std::{ sync::{ atomic::{AtomicUsize, Ordering}, Arc, + Mutex, }, time::{Duration, SystemTime}, }; @@ -104,6 +106,7 @@ where &config.client, config.request_timeout, Default::default(), + activity_queue.failure_rate_limit_hourly.clone(), ) .await { @@ -144,11 +147,23 @@ struct SendActivityTask { } async fn sign_and_send( + // TODO: this should only take a single struct as param task: &SendActivityTask, client: &ClientWithMiddleware, timeout: Duration, retry_strategy: RetryStrategy, + failure_rate_limit_hourly: Arc>>, ) -> Result<(), anyhow::Error> { + // Do nothing if there have been too many errors from this domain recently + { + // TODO: handle locking inside of InstanceRateLimit? + // TODO: need wrapper url type which returns domain as String + let mut lock = failure_rate_limit_hourly.lock().unwrap(); + let check = lock.check(task.inbox.domain().unwrap()); + if !check { + return Ok(()); + } + } debug!( "Sending {} to {}, contents:\n {}", task.activity_id, @@ -177,6 +192,7 @@ async fn sign_and_send( request .try_clone() .expect("The body of the request is not cloneable"), + failure_rate_limit_hourly.clone(), ) }, retry_strategy, @@ -188,10 +204,11 @@ async fn send( task: &SendActivityTask, client: &ClientWithMiddleware, request: Request, + failure_rate_limit_hourly: Arc>>, ) -> Result<(), anyhow::Error> { let response = client.execute(request).await; - match response { + let res = match response { Ok(o) if o.status().is_success() => { debug!( "Activity {} delivered successfully to {}", @@ -224,7 +241,12 @@ async fn send( task.inbox, e )), + }; + if res.is_err() { + let mut lock = failure_rate_limit_hourly.lock().unwrap(); + lock.log(task.inbox.domain().unwrap()); } + res } pub(crate) fn generate_request_headers(inbox_url: &Url) -> HeaderMap { @@ -258,6 +280,7 @@ pub(crate) struct ActivityQueue { sender: UnboundedSender, sender_task: JoinHandle<()>, retry_sender_task: JoinHandle<()>, + failure_rate_limit_hourly: Arc>>, } /// Simple stat counter to show where we're up to with sending messages @@ -478,6 +501,9 @@ impl ActivityQueue { sender, sender_task, retry_sender_task, + failure_rate_limit_hourly: Arc::new(Mutex::new(InstanceRatelimit::new( + Duration::from_secs(60 * 60), + ))), } } diff --git a/src/actix_web/middleware.rs b/src/actix_web/middleware.rs index 2bb562b..afa0117 100644 --- a/src/actix_web/middleware.rs +++ b/src/actix_web/middleware.rs @@ -1,7 +1,10 @@ use crate::config::{Data, FederationConfig, FederationMiddleware}; use actix_web::{ dev::{forward_ready, Payload, Service, ServiceRequest, ServiceResponse, Transform}, - Error, FromRequest, HttpMessage, HttpRequest, + Error, + FromRequest, + HttpMessage, + HttpRequest, }; use std::future::{ready, Ready}; diff --git a/src/fetch/mod.rs b/src/fetch/mod.rs index 5faadcd..7a9734c 100644 --- a/src/fetch/mod.rs +++ b/src/fetch/mod.rs @@ -3,7 +3,10 @@ #![doc = include_str!("../../docs/07_fetching_data.md")] use crate::{ - config::Data, error::Error, http_signatures::sign_request, reqwest_shim::ResponseExt, + config::Data, + error::Error, + http_signatures::sign_request, + reqwest_shim::ResponseExt, FEDERATION_CONTENT_TYPE, }; use bytes::Bytes; diff --git a/src/lib.rs b/src/lib.rs index 2031c98..db405b7 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -20,6 +20,7 @@ pub mod error; pub mod fetch; pub mod http_signatures; pub mod protocol; +mod rate_limit; pub(crate) mod reqwest_shim; pub mod traits; diff --git a/src/rate_limit.rs b/src/rate_limit.rs new file mode 100644 index 0000000..6f06e0c --- /dev/null +++ b/src/rate_limit.rs @@ -0,0 +1,95 @@ +use std::{ + collections::HashMap, + time::{Duration, Instant}, +}; +use std::ops::Sub; + +pub struct InstanceRatelimit { + period: Duration, + data: HashMap>, +} + +impl InstanceRatelimit { + pub fn new(period: Duration) -> Self { + InstanceRatelimit { + period, + data: HashMap::new(), + } + } + + fn domain_limiter(&mut self, domain: &str) -> &mut RateLimiter { + // TODO: inefficient, we only need String when inserting new entry which is rare + let domain = domain.to_string(); + self.data.entry(domain).or_insert_with(|| RateLimiter::new(self.period)) + } + + pub fn check(&mut self, domain: &str) -> bool { + self.domain_limiter(domain).check() + } + + pub fn log(&mut self, domain: &str) { + self.domain_limiter(domain).log() + } +} + +// TODO: check lemmy rate limiting code +struct RateLimiter { + period: Duration, + /// Using limit + 1 for greater than check + /// TODO: check if this is necessary or not + readings: [Option; LIMIT + 1], +} + +impl RateLimiter { + pub fn new(period: Duration) -> RateLimiter { + RateLimiter { + period, + readings: [None; LIMIT + 1], + } + } + + /// Count amount of entries less than `period` time before now and check against limit. + /// Return true if it is less. + fn check(&self) -> bool { + let now = Instant::now(); + let count = self.readings.iter() + .filter(|r| r.is_some()) + // TODO: check if gt/lt is correct + .filter(|r| r.unwrap() < now.sub(self.period)) + .count(); + count > LIMIT + } + + pub fn log(&mut self) { + let now = Instant::now(); + // TODO: replace all items older than `period` with None, insert Some(now) + } +} + +#[cfg(test)] +pub mod test { + use std::thread::sleep; + use std::time::Duration; + use crate::rate_limit::RateLimiter; + + #[test] + fn test_limiting() { + let mut limiter = RateLimiter::<1>::new(Duration::from_secs(1)); + assert_eq!(limiter.check(), true); + limiter.log(); + assert_eq!(limiter.check(), true); + limiter.log(); + assert_eq!(limiter.check(), false); + } + + #[test] + fn test_expiration() { + let mut limiter = RateLimiter::<1>::new(Duration::from_secs(1)); + assert_eq!(limiter.check(), true); + limiter.log(); + assert_eq!(limiter.check(), false); + sleep(Duration::from_secs(1)); + assert_eq!(limiter.check(), true); + + } +} \ No newline at end of file