diff --git a/Cargo.lock b/Cargo.lock index 0b1f846..4373e80 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -399,6 +399,7 @@ dependencies = [ "console-subscriber", "dashmap", "dotenv", + "futures-core", "http-signature-normalization-actix", "http-signature-normalization-reqwest", "lru", diff --git a/Cargo.toml b/Cargo.toml index 44a0f80..93fe4f2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -38,6 +38,7 @@ config = { version = "0.14.0", default-features = false, features = ["toml", "js console-subscriber = { version = "0.2", optional = true } dashmap = "5.1.0" dotenv = "0.15.0" +futures-core = "0.3.30" lru = "0.12.0" metrics = "0.22.0" metrics-exporter-prometheus = { version = "0.13.0", default-features = false, features = [ diff --git a/src/error.rs b/src/error.rs index 05d5d1e..2a469bf 100644 --- a/src/error.rs +++ b/src/error.rs @@ -123,6 +123,9 @@ pub(crate) enum ErrorKind { #[error("Couldn't sign request")] SignRequest, + #[error("Response body from server exceeded limits")] + BodyTooLarge, + #[error("Couldn't make request")] Reqwest(#[from] reqwest::Error), diff --git a/src/main.rs b/src/main.rs index 67c4379..a871f0a 100644 --- a/src/main.rs +++ b/src/main.rs @@ -38,6 +38,7 @@ mod middleware; mod requests; mod routes; mod spawner; +mod stream; mod telegram; use crate::config::UrlKind; diff --git a/src/requests.rs b/src/requests.rs index 60cc9af..63faa3c 100644 --- a/src/requests.rs +++ b/src/requests.rs @@ -2,6 +2,7 @@ use crate::{ data::LastOnline, error::{Error, ErrorKind}, spawner::Spawner, + stream::{aggregate, limit_stream}, }; use activitystreams::iri_string::types::IriString; use actix_web::http::header::Date; @@ -24,6 +25,9 @@ const ONE_MINUTE: u64 = 60 * ONE_SECOND; const ONE_HOUR: u64 = 60 * ONE_MINUTE; const ONE_DAY: u64 = 24 * ONE_HOUR; +// 20 KB +const JSON_SIZE_LIMIT: usize = 20 * 1024; + #[derive(Debug)] pub(crate) enum BreakerStrategy { // Requires a successful response @@ -262,7 +266,7 @@ impl Requests { where T: serde::de::DeserializeOwned, { - let body = self + let stream = self .do_deliver( url, &serde_json::json!({}), @@ -271,8 +275,9 @@ impl Requests { strategy, ) .await? - .bytes() - .await?; + .bytes_stream(); + + let body = aggregate(limit_stream(stream, JSON_SIZE_LIMIT)).await?; Ok(serde_json::from_slice(&body)?) } @@ -299,11 +304,12 @@ impl Requests { where T: serde::de::DeserializeOwned, { - let body = self + let stream = self .do_fetch_response(url, accept, strategy) .await? - .bytes() - .await?; + .bytes_stream(); + + let body = aggregate(limit_stream(stream, JSON_SIZE_LIMIT)).await?; Ok(serde_json::from_slice(&body)?) } diff --git a/src/routes/media.rs b/src/routes/media.rs index 4c9b260..686ab22 100644 --- a/src/routes/media.rs +++ b/src/routes/media.rs @@ -2,10 +2,14 @@ use crate::{ data::MediaCache, error::Error, requests::{BreakerStrategy, Requests}, + stream::limit_stream, }; use actix_web::{body::BodyStream, web, HttpResponse}; use uuid::Uuid; +// 16 MB +const IMAGE_SIZE_LIMIT: usize = 16 * 1024 * 1024; + #[tracing::instrument(name = "Media", skip(media, requests))] pub(crate) async fn route( media: web::Data, @@ -25,7 +29,10 @@ pub(crate) async fn route( response.insert_header((name.clone(), value.clone())); } - return Ok(response.body(BodyStream::new(res.bytes_stream()))); + return Ok(response.body(BodyStream::new(limit_stream( + res.bytes_stream(), + IMAGE_SIZE_LIMIT, + )))); } Ok(HttpResponse::NotFound().finish()) diff --git a/src/stream.rs b/src/stream.rs new file mode 100644 index 0000000..530b06e --- /dev/null +++ b/src/stream.rs @@ -0,0 +1,59 @@ +use crate::error::{Error, ErrorKind}; +use actix_web::web::{Bytes, BytesMut}; +use futures_core::Stream; +use streem::IntoStreamer; + +pub(crate) fn limit_stream<'a, S>( + input: S, + limit: usize, +) -> impl Stream> + Send + 'a +where + S: Stream> + Send + 'a, +{ + streem::try_from_fn(move |yielder| async move { + let stream = std::pin::pin!(input); + let mut stream = stream.into_streamer(); + + let mut count = 0; + + while let Some(bytes) = stream.try_next().await? { + count += bytes.len(); + + if count > limit { + return Err(ErrorKind::BodyTooLarge.into()); + } + + yielder.yield_ok(bytes).await; + } + + Ok(()) + }) +} + +pub(crate) async fn aggregate(input: S) -> Result +where + S: Stream>, +{ + let stream = std::pin::pin!(input); + let mut streamer = stream.into_streamer(); + + let mut buf = Vec::new(); + + while let Some(bytes) = streamer.try_next().await? { + buf.push(bytes); + } + + if buf.len() == 1 { + return Ok(buf.pop().expect("buf has exactly one element")); + } + + let total_size: usize = buf.iter().map(|b| b.len()).sum(); + + let mut bytes_mut = BytesMut::with_capacity(total_size); + + for bytes in &buf { + bytes_mut.extend_from_slice(&bytes); + } + + Ok(bytes_mut.freeze()) +}