diff --git a/crates/utils/src/rate_limit/mod.rs b/crates/utils/src/rate_limit/mod.rs index e2a155eb5..6027520f0 100644 --- a/crates/utils/src/rate_limit/mod.rs +++ b/crates/utils/src/rate_limit/mod.rs @@ -1,10 +1,14 @@ -use crate::{settings::structs::RateLimitConfig, utils::get_ip, IpAddr, LemmyError}; -use actix_web::dev::{Service, ServiceRequest, ServiceResponse, Transform}; +use crate::{settings::structs::RateLimitConfig, utils::get_ip, IpAddr}; +use actix_web::{ + dev::{Service, ServiceRequest, ServiceResponse, Transform}, + HttpResponse, +}; use futures::future::{ok, Ready}; use rate_limiter::{RateLimitType, RateLimiter}; use std::{ future::Future, pin::Pin, + rc::Rc, sync::Arc, task::{Context, Poll}, }; @@ -29,7 +33,7 @@ pub struct RateLimited { pub struct RateLimitedMiddleware { rate_limited: RateLimited, - service: S, + service: Rc, } impl RateLimit { @@ -63,78 +67,28 @@ impl RateLimit { } impl RateLimited { - pub async fn wrap( - self, - ip_addr: IpAddr, - fut: impl Future>, - ) -> Result - where - E: From, - { + /// Returns true if the request passed the rate limit, false if it failed and should be rejected. + pub async fn check(self, ip_addr: IpAddr) -> bool { // Does not need to be blocking because the RwLock in settings never held across await points, // and the operation here locks only long enough to clone let rate_limit = self.rate_limit_config; - // before - { - let mut limiter = self.rate_limiter.lock().await; + let mut limiter = self.rate_limiter.lock().await; - match self.type_ { - RateLimitType::Message => { - limiter.check_rate_limit_full( - self.type_, - &ip_addr, - rate_limit.message, - rate_limit.message_per_second, - )?; - - drop(limiter); - return fut.await; - } - RateLimitType::Post => { - limiter.check_rate_limit_full( - self.type_, - &ip_addr, - rate_limit.post, - rate_limit.post_per_second, - )?; - } - RateLimitType::Register => { - limiter.check_rate_limit_full( - self.type_, - &ip_addr, - rate_limit.register, - rate_limit.register_per_second, - )?; - } - RateLimitType::Image => { - limiter.check_rate_limit_full( - self.type_, - &ip_addr, - rate_limit.image, - rate_limit.image_per_second, - )?; - } - RateLimitType::Comment => { - limiter.check_rate_limit_full( - self.type_, - &ip_addr, - rate_limit.comment, - rate_limit.comment_per_second, - )?; - } - }; - } - - let res = fut.await; - - res + let (kind, interval) = match self.type_ { + RateLimitType::Message => (rate_limit.message, rate_limit.message_per_second), + RateLimitType::Post => (rate_limit.post, rate_limit.post_per_second), + RateLimitType::Register => (rate_limit.register, rate_limit.register_per_second), + RateLimitType::Image => (rate_limit.image, rate_limit.image_per_second), + RateLimitType::Comment => (rate_limit.comment, rate_limit.comment_per_second), + }; + limiter.check_rate_limit_full(self.type_, &ip_addr, kind, interval) } } impl Transform for RateLimited where - S: Service, + S: Service + 'static, S::Future: 'static, { type Response = S::Response; @@ -146,7 +100,7 @@ where fn new_transform(&self, service: S) -> Self::Future { ok(RateLimitedMiddleware { rate_limited: self.clone(), - service, + service: Rc::new(service), }) } } @@ -155,7 +109,7 @@ type FutResult = dyn Future>; impl Service for RateLimitedMiddleware where - S: Service, + S: Service + 'static, S::Future: 'static, { type Response = S::Response; @@ -169,11 +123,20 @@ where fn call(&self, req: ServiceRequest) -> Self::Future { let ip_addr = get_ip(&req.connection_info()); - let fut = self - .rate_limited - .clone() - .wrap(ip_addr, self.service.call(req)); + let rate_limited = self.rate_limited.clone(); + let service = self.service.clone(); - Box::pin(async move { fut.await.map_err(actix_web::Error::from) }) + Box::pin(async move { + if rate_limited.check(ip_addr).await { + service.call(req).await + } else { + let (http_req, _) = req.into_parts(); + // if rate limit was hit, respond with http 400 + Ok(ServiceResponse::new( + http_req, + HttpResponse::BadRequest().finish(), + )) + } + }) } } diff --git a/crates/utils/src/rate_limit/rate_limiter.rs b/crates/utils/src/rate_limit/rate_limiter.rs index ccc483ed7..31d91036e 100644 --- a/crates/utils/src/rate_limit/rate_limiter.rs +++ b/crates/utils/src/rate_limit/rate_limiter.rs @@ -1,11 +1,11 @@ -use crate::{IpAddr, LemmyError}; -use std::{collections::HashMap, time::SystemTime}; +use crate::IpAddr; +use std::{collections::HashMap, time::Instant}; use strum::IntoEnumIterator; use tracing::debug; #[derive(Debug, Clone)] struct RateLimitBucket { - last_checked: SystemTime, + last_checked: Instant, allowance: f64, } @@ -36,7 +36,7 @@ impl RateLimiter { bucket.insert( ip.clone(), RateLimitBucket { - last_checked: SystemTime::now(), + last_checked: Instant::now(), allowance: -2f64, }, ); @@ -46,6 +46,8 @@ impl RateLimiter { } /// Rate limiting Algorithm described here: https://stackoverflow.com/a/668327/1655478 + /// + /// Returns true if the request passed the rate limit, false if it failed and should be rejected. #[allow(clippy::float_cmp)] pub(super) fn check_rate_limit_full( &mut self, @@ -53,12 +55,12 @@ impl RateLimiter { ip: &IpAddr, rate: i32, per: i32, - ) -> Result<(), LemmyError> { + ) -> bool { self.insert_ip(ip); if let Some(bucket) = self.buckets.get_mut(&type_) { if let Some(rate_limit) = bucket.get_mut(ip) { - let current = SystemTime::now(); - let time_passed = current.duration_since(rate_limit.last_checked)?.as_secs() as f64; + let current = Instant::now(); + let time_passed = current.duration_since(rate_limit.last_checked).as_secs() as f64; // The initial value if rate_limit.allowance == -2f64 { @@ -79,25 +81,16 @@ impl RateLimiter { time_passed, rate_limit.allowance ); - Err(LemmyError::from_error_message( - anyhow::anyhow!( - "Too many requests. type: {}, IP: {}, {} per {} seconds", - type_.as_ref(), - ip, - rate, - per - ), - "too_many_requests", - )) + false } else { rate_limit.allowance -= 1.0; - Ok(()) + true } } else { - Ok(()) + true } } else { - Ok(()) + true } } } diff --git a/crates/websocket/src/chat_server.rs b/crates/websocket/src/chat_server.rs index 53274e386..1e95344b0 100644 --- a/crates/websocket/src/chat_server.rs +++ b/crates/websocket/src/chat_server.rs @@ -478,22 +478,33 @@ impl ChatServer { .as_str() .ok_or_else(|| LemmyError::from_message("missing op"))?; - if let Ok(user_operation_crud) = UserOperationCrud::from_str(op) { - let fut = (message_handler_crud)(context, msg.id, user_operation_crud.clone(), data); - match user_operation_crud { - UserOperationCrud::Register => rate_limiter.register().wrap(ip, fut).await, - UserOperationCrud::CreatePost => rate_limiter.post().wrap(ip, fut).await, - UserOperationCrud::CreateCommunity => rate_limiter.register().wrap(ip, fut).await, - UserOperationCrud::CreateComment => rate_limiter.comment().wrap(ip, fut).await, - _ => rate_limiter.message().wrap(ip, fut).await, - } + // check if api call passes the rate limit, and generate future for later execution + let (passed, fut) = if let Ok(user_operation_crud) = UserOperationCrud::from_str(op) { + let passed = match user_operation_crud { + UserOperationCrud::Register => rate_limiter.register().check(ip).await, + UserOperationCrud::CreatePost => rate_limiter.post().check(ip).await, + UserOperationCrud::CreateCommunity => rate_limiter.register().check(ip).await, + UserOperationCrud::CreateComment => rate_limiter.comment().check(ip).await, + _ => rate_limiter.message().check(ip).await, + }; + let fut = (message_handler_crud)(context, msg.id, user_operation_crud, data); + (passed, fut) } else { let user_operation = UserOperation::from_str(op)?; - let fut = (message_handler)(context, msg.id, user_operation.clone(), data); - match user_operation { - UserOperation::GetCaptcha => rate_limiter.post().wrap(ip, fut).await, - _ => rate_limiter.message().wrap(ip, fut).await, - } + let passed = match user_operation { + UserOperation::GetCaptcha => rate_limiter.post().check(ip).await, + _ => rate_limiter.message().check(ip).await, + }; + let fut = (message_handler)(context, msg.id, user_operation, data); + (passed, fut) + }; + + // if rate limit passed, execute api call future + if passed { + fut.await + } else { + // if rate limit was hit, respond with empty message + Ok("".to_string()) } } }