diff --git a/server/src/main.rs b/server/src/main.rs index c92770f2a..4e773ee57 100644 --- a/server/src/main.rs +++ b/server/src/main.rs @@ -34,7 +34,9 @@ async fn main() -> io::Result<()> { embedded_migrations::run(&conn).unwrap(); // Set up the rate limiter - let rate_limiter = RateLimit(Arc::new(Mutex::new(RateLimiter::default()))); + let rate_limiter = RateLimit { + rate_limiter: Arc::new(Mutex::new(RateLimiter::default())), + }; // Set up websocket server let server = ChatServer::startup(pool.clone(), rate_limiter.clone()).start(); diff --git a/server/src/rate_limit/mod.rs b/server/src/rate_limit/mod.rs index 9aeb11718..bb77db29c 100644 --- a/server/src/rate_limit/mod.rs +++ b/server/src/rate_limit/mod.rs @@ -18,12 +18,20 @@ use strum::IntoEnumIterator; use tokio::sync::Mutex; #[derive(Debug, Clone)] -pub struct RateLimit(pub Arc>); +pub struct RateLimit { + pub rate_limiter: Arc>, +} #[derive(Debug, Clone)] -pub struct RateLimited(Arc>, RateLimitType); +pub struct RateLimited { + rate_limiter: Arc>, + type_: RateLimitType, +} -pub struct RateLimitedMiddleware(RateLimited, S); +pub struct RateLimitedMiddleware { + rate_limited: RateLimited, + service: S, +} impl RateLimit { pub fn message(&self) -> RateLimited { @@ -39,7 +47,10 @@ impl RateLimit { } fn kind(&self, type_: RateLimitType) -> RateLimited { - RateLimited(self.0.clone(), type_) + RateLimited { + rate_limiter: self.rate_limiter.clone(), + type_, + } } } @@ -64,12 +75,12 @@ impl RateLimited { // before { - let mut limiter = self.0.lock().await; + let mut limiter = self.rate_limiter.lock().await; - match self.1 { + match self.type_ { RateLimitType::Message => { limiter.check_rate_limit_full( - self.1, + self.type_, &ip_addr, rate_limit.message, rate_limit.message_per_second, @@ -80,7 +91,7 @@ impl RateLimited { } RateLimitType::Post => { limiter.check_rate_limit_full( - self.1.clone(), + self.type_.clone(), &ip_addr, rate_limit.post, rate_limit.post_per_second, @@ -89,7 +100,7 @@ impl RateLimited { } RateLimitType::Register => { limiter.check_rate_limit_full( - self.1, + self.type_, &ip_addr, rate_limit.register, rate_limit.register_per_second, @@ -103,12 +114,12 @@ impl RateLimited { // after { - let mut limiter = self.0.lock().await; + let mut limiter = self.rate_limiter.lock().await; if res.is_ok() { - match self.1 { + match self.type_ { RateLimitType::Post => { limiter.check_rate_limit_full( - self.1, + self.type_, &ip_addr, rate_limit.post, rate_limit.post_per_second, @@ -117,7 +128,7 @@ impl RateLimited { } RateLimitType::Register => { limiter.check_rate_limit_full( - self.1, + self.type_, &ip_addr, rate_limit.register, rate_limit.register_per_second, @@ -146,7 +157,10 @@ where type Future = Ready>; fn new_transform(&self, service: S) -> Self::Future { - ok(RateLimitedMiddleware(self.clone(), service)) + ok(RateLimitedMiddleware { + rate_limited: self.clone(), + service, + }) } } @@ -163,7 +177,7 @@ where type Future = Pin>>; fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { - self.1.poll_ready(cx) + self.service.poll_ready(cx) } fn call(&mut self, req: S::Request) -> Self::Future { @@ -176,7 +190,10 @@ where .unwrap_or("127.0.0.1") .to_string(); - let fut = self.0.clone().wrap(ip_addr, self.1.call(req)); + let fut = self + .rate_limited + .clone() + .wrap(ip_addr, self.service.call(req)); Box::pin(async move { fut.await.map_err(actix_web::Error::from) }) }