Compare commits


18 commits

Author SHA1 Message Date
asonix cda92e7523 Update flake 2024-06-23 13:57:40 -05:00
asonix 43b03a176c Don't fail publish on clippy warnings (unfixable without ructe release) 2024-06-23 13:57:28 -05:00
asonix a465d1ae5b Allow versions to be unused 2024-06-23 13:56:37 -05:00
asonix 4fa7674a35 Move cargo config to config.toml 2024-06-23 13:55:10 -05:00
asonix 8c14d613f7 Prepare v0.3.114 2024-06-23 13:45:10 -05:00
asonix aff2431681 Update dependencies (minor & point) 2024-06-23 13:42:26 -05:00
asonix 5aa97212b3 Impose limits on the size of downloaded content from foreign servers 2024-06-23 13:35:24 -05:00
asonix 97567cf598 Prepare v0.3.113 2024-05-01 15:45:53 -05:00
asonix 4c663f399e Update dependencies (minor & point) 2024-05-01 15:43:53 -05:00
asonix 8a3256f52a Avoid deadlock of iterating over tree while transacting on that tree 2024-05-01 15:43:08 -05:00
asonix 13a2653fe8 Remove prerelease flag 2024-04-23 14:00:04 -05:00
asonix 8dd9a86d22 Use match_pattern rather than path for metrics differentiation 2024-04-21 11:44:16 -05:00
asonix 5c0c0591dd Prepare 0.3.112 2024-04-14 22:47:38 -05:00
asonix 04ca4e5401 Stable async-cpupool 2024-04-14 19:53:31 -05:00
asonix 1de1d76506 prerelease 2024-04-13 13:57:12 -05:00
asonix dd9225bb89 Prepare v0.3.111 2024-04-07 11:53:24 -05:00
asonix b577730836 Fix build 2024-04-07 11:40:57 -05:00
asonix 21883c168b BROKEN! Start collecting more metrics about various sizes 2024-04-07 11:04:03 -05:00
18 changed files with 564 additions and 378 deletions


@@ -21,7 +21,8 @@ jobs:
       -
         name: Clippy
         run: |
-          cargo clippy --no-default-features -- -D warnings
+          # cargo clippy --no-default-features -- -D warnings
+          cargo clippy --no-default-features

   tests:
     runs-on: docker

Cargo.lock (generated, 621 changed lines)

File diff suppressed because it is too large.


@@ -1,7 +1,7 @@
 [package]
 name = "ap-relay"
 description = "A simple activitypub relay"
-version = "0.3.110"
+version = "0.3.114"
 authors = ["asonix <asonix@asonix.dog>"]
 license = "AGPL-3.0"
 readme = "README.md"
@@ -29,7 +29,7 @@ actix-webfinger = { version = "0.5.0", default-features = false }
 activitystreams = "0.7.0-alpha.25"
 activitystreams-ext = "0.1.0-alpha.3"
 ammonia = "4.0.0"
-async-cpupool = "0.2.0"
+async-cpupool = "0.2.2"
 bcrypt = "0.15"
 base64 = "0.22"
 clap = { version = "4.0.0", features = ["derive"] }
@@ -38,7 +38,7 @@ config = { version = "0.14.0", default-features = false, features = ["toml", "js
 console-subscriber = { version = "0.2", optional = true }
 dashmap = "5.1.0"
 dotenv = "0.15.0"
-flume = "0.11.0"
+futures-core = "0.3.30"
 lru = "0.12.0"
 metrics = "0.22.0"
 metrics-exporter-prometheus = { version = "0.13.0", default-features = false, features = [


@@ -20,11 +20,11 @@
     },
     "nixpkgs": {
       "locked": {
-        "lastModified": 1711163522,
-        "narHash": "sha256-YN/Ciidm+A0fmJPWlHBGvVkcarYWSC+s3NTPk/P+q3c=",
+        "lastModified": 1719075281,
+        "narHash": "sha256-CyyxvOwFf12I91PBWz43iGT1kjsf5oi6ax7CrvaMyAo=",
         "owner": "NixOS",
         "repo": "nixpkgs",
-        "rev": "44d0940ea560dee511026a53f0e2e2cde489b4d4",
+        "rev": "a71e967ef3694799d0c418c98332f7ff4cc5f6af",
         "type": "github"
       },
       "original": {


@@ -5,7 +5,7 @@
 rustPlatform.buildRustPackage {
   pname = "relay";
-  version = "0.3.110";
+  version = "0.3.114";

   src = ./.;

   cargoLock.lockFile = ./Cargo.lock;


@@ -15,6 +15,10 @@ const MINUTES: u64 = 60 * SECONDS;
 const HOURS: u64 = 60 * MINUTES;
 const DAYS: u64 = 24 * HOURS;

+pub(crate) fn recordable(len: usize) -> u32 {
+    ((len as u64) % u64::from(u32::MAX)) as u32
+}
+
 type DistributionMap = BTreeMap<Vec<(String, String)>, Summary>;

 #[derive(Clone)]
@@ -299,7 +303,14 @@ impl Inner {
                 for sample in samples {
                     entry.add(*sample);
                 }
-            })
+            });
+
+            let mut total_len = 0;
+            for dist_map in d.values() {
+                total_len += dist_map.len();
+            }
+
+            metrics::gauge!("relay.collector.distributions.size").set(recordable(total_len));
         }

         let d = self.distributions.read().unwrap().clone();
@@ -358,6 +369,7 @@ impl MemoryCollector {
     ) {
         let mut d = self.inner.descriptions.write().unwrap();
         d.entry(key.as_str().to_owned()).or_insert(description);
+        metrics::gauge!("relay.collector.descriptions.size").set(recordable(d.len()));
     }

     pub(crate) fn install(&self) -> Result<(), SetRecorderError<Self>> {
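The recordable helper added at the top of this file fits a collection length into gauge range by wrapping at u32::MAX rather than saturating. A minimal standalone sketch of that behavior (the main harness is illustrative, not part of the diff; the wrap cases assume a 64-bit target):

fn recordable(len: usize) -> u32 {
    ((len as u64) % u64::from(u32::MAX)) as u32
}

fn main() {
    assert_eq!(recordable(42), 42); // ordinary lengths pass through unchanged
    assert_eq!(recordable(u32::MAX as usize), 0); // wraps to zero exactly at u32::MAX
    assert_eq!(recordable(u32::MAX as usize + 7), 7); // and keeps wrapping modulo u32::MAX
}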


@@ -9,10 +9,10 @@ pub(crate) struct LastOnline {
 impl LastOnline {
     pub(crate) fn mark_seen(&self, iri: &IriStr) {
         if let Some(authority) = iri.authority_str() {
-            self.domains
-                .lock()
-                .unwrap()
-                .insert(authority.to_string(), OffsetDateTime::now_utc());
+            let mut guard = self.domains.lock().unwrap();
+            guard.insert(authority.to_string(), OffsetDateTime::now_utc());
+            metrics::gauge!("relay.last-online.size",)
+                .set(crate::collector::recordable(guard.len()));
         }
     }


@@ -73,7 +73,9 @@ impl State {
     }

     pub(crate) fn cache(&self, object_id: IriString, actor_id: IriString) {
-        self.object_cache.write().unwrap().put(object_id, actor_id);
+        let mut guard = self.object_cache.write().unwrap();
+        guard.put(object_id, actor_id);
+        metrics::gauge!("relay.object-cache.size").set(crate::collector::recordable(guard.len()));
     }

     pub(crate) fn is_connected(&self, iri: &IriString) -> bool {

src/db.rs (142 changed lines)

@@ -7,7 +7,7 @@ use rsa::{
     pkcs8::{DecodePrivateKey, EncodePrivateKey},
     RsaPrivateKey,
 };
-use sled::{Batch, Tree};
+use sled::{transaction::TransactionError, Batch, Transactional, Tree};
 use std::{
     collections::{BTreeMap, HashMap},
     sync::{
@@ -283,10 +283,15 @@ impl Db {
     pub(crate) async fn check_health(&self) -> Result<(), Error> {
         let next = self.inner.healthz_counter.fetch_add(1, Ordering::Relaxed);
         self.unblock(move |inner| {
-            inner
+            let res = inner
                 .healthz
                 .insert("healthz", &next.to_be_bytes()[..])
-                .map_err(Error::from)
+                .map_err(Error::from);
+
+            metrics::gauge!("relay.db.healthz.size")
+                .set(crate::collector::recordable(inner.healthz.len()));
+
+            res
         })
         .await?;
         self.inner.healthz.flush_async().await?;
@@ -349,6 +354,9 @@ impl Db {
                     .actor_id_info
                     .insert(actor_id.as_str().as_bytes(), vec)?;

+                metrics::gauge!("relay.db.actor-id-info.size")
+                    .set(crate::collector::recordable(inner.actor_id_info.len()));
+
                 Ok(())
             })
             .await
@@ -383,6 +391,9 @@ impl Db {
                     .actor_id_instance
                     .insert(actor_id.as_str().as_bytes(), vec)?;

+                metrics::gauge!("relay.db.actor-id-instance.size")
+                    .set(crate::collector::recordable(inner.actor_id_instance.len()));
+
                 Ok(())
             })
             .await
@@ -417,6 +428,9 @@ impl Db {
                     .actor_id_contact
                     .insert(actor_id.as_str().as_bytes(), vec)?;

+                metrics::gauge!("relay.db.actor-id-contact.size")
+                    .set(crate::collector::recordable(inner.actor_id_contact.len()));
+
                 Ok(())
             })
             .await
@@ -447,6 +461,12 @@ impl Db {
             inner
                 .media_url_media_id
                 .insert(url.as_str().as_bytes(), id.as_bytes())?;

+            metrics::gauge!("relay.db.media-id-media-url.size")
+                .set(crate::collector::recordable(inner.media_id_media_url.len()));
+            metrics::gauge!("relay.db.media-url-media-id.size")
+                .set(crate::collector::recordable(inner.media_url_media_id.len()));
+
             Ok(())
         })
         .await
@@ -538,6 +558,14 @@ impl Db {
             inner
                 .actor_id_actor
                 .insert(actor.id.as_str().as_bytes(), vec)?;

+            metrics::gauge!("relay.db.public-key-actor-id.size").set(crate::collector::recordable(
+                inner.public_key_id_actor_id.len(),
+            ));
+            metrics::gauge!("relay.db.actor-id-actor.size").set(crate::collector::recordable(
+                inner.public_key_id_actor_id.len(),
+            ));
+
             Ok(())
         })
         .await
@@ -550,6 +578,10 @@ impl Db {
                 .connected_actor_ids
                 .remove(actor_id.as_str().as_bytes())?;

+            metrics::gauge!("relay.db.connected-actor-ids.size").set(crate::collector::recordable(
+                inner.connected_actor_ids.len(),
+            ));
+
             Ok(())
         })
         .await
@@ -562,6 +594,10 @@ impl Db {
                 .connected_actor_ids
                 .insert(actor_id.as_str().as_bytes(), actor_id.as_str().as_bytes())?;

+            metrics::gauge!("relay.db.connected-actor-ids.size").set(crate::collector::recordable(
+                inner.connected_actor_ids.len(),
+            ));
+
             Ok(())
         })
         .await
@@ -569,30 +605,64 @@ impl Db {
     pub(crate) async fn add_blocks(&self, domains: Vec<String>) -> Result<(), Error> {
         self.unblock(move |inner| {
-            for connected in inner.connected_by_domain(&domains) {
-                inner
-                    .connected_actor_ids
-                    .remove(connected.as_str().as_bytes())?;
-            }
+            let connected_by_domain = inner.connected_by_domain(&domains).collect::<Vec<_>>();

-            for authority in &domains {
-                inner
-                    .blocked_domains
-                    .insert(domain_key(authority), authority.as_bytes())?;
-                inner.allowed_domains.remove(domain_key(authority))?;
-            }
+            let res = (
+                &inner.connected_actor_ids,
+                &inner.blocked_domains,
+                &inner.allowed_domains,
+            )
+                .transaction(|(connected, blocked, allowed)| {
+                    let mut connected_batch = Batch::default();
+                    let mut blocked_batch = Batch::default();
+                    let mut allowed_batch = Batch::default();

-            Ok(())
+                    for connected in &connected_by_domain {
+                        connected_batch.remove(connected.as_str().as_bytes());
+                    }
+
+                    for authority in &domains {
+                        blocked_batch
+                            .insert(domain_key(authority).as_bytes(), authority.as_bytes());
+                        allowed_batch.remove(domain_key(authority).as_bytes());
+                    }
+
+                    connected.apply_batch(&connected_batch)?;
+                    blocked.apply_batch(&blocked_batch)?;
+                    allowed.apply_batch(&allowed_batch)?;
+
+                    Ok(())
+                });
+
+            metrics::gauge!("relay.db.connected-actor-ids.size").set(crate::collector::recordable(
+                inner.connected_actor_ids.len(),
+            ));
+            metrics::gauge!("relay.db.blocked-domains.size")
+                .set(crate::collector::recordable(inner.blocked_domains.len()));
+            metrics::gauge!("relay.db.allowed-domains.size")
+                .set(crate::collector::recordable(inner.allowed_domains.len()));
+
+            match res {
+                Ok(()) => Ok(()),
+                Err(TransactionError::Abort(e) | TransactionError::Storage(e)) => Err(e.into()),
+            }
         })
         .await
     }

     pub(crate) async fn remove_blocks(&self, domains: Vec<String>) -> Result<(), Error> {
         self.unblock(move |inner| {
+            let mut blocked_batch = Batch::default();
+
             for authority in &domains {
-                inner.blocked_domains.remove(domain_key(authority))?;
+                blocked_batch.remove(domain_key(authority).as_bytes());
             }

+            inner.blocked_domains.apply_batch(blocked_batch)?;
+
+            metrics::gauge!("relay.db.blocked-domains.size")
+                .set(crate::collector::recordable(inner.blocked_domains.len()));
+
             Ok(())
         })
         .await
@@ -600,12 +670,17 @@ impl Db {
     pub(crate) async fn add_allows(&self, domains: Vec<String>) -> Result<(), Error> {
         self.unblock(move |inner| {
+            let mut allowed_batch = Batch::default();
+
             for authority in &domains {
-                inner
-                    .allowed_domains
-                    .insert(domain_key(authority), authority.as_bytes())?;
+                allowed_batch.insert(domain_key(authority).as_bytes(), authority.as_bytes());
             }

+            inner.allowed_domains.apply_batch(allowed_batch)?;
+
+            metrics::gauge!("relay.db.allowed-domains.size")
+                .set(crate::collector::recordable(inner.allowed_domains.len()));
+
             Ok(())
         })
         .await
@@ -614,17 +689,32 @@ impl Db {
     pub(crate) async fn remove_allows(&self, domains: Vec<String>) -> Result<(), Error> {
         self.unblock(move |inner| {
             if inner.restricted_mode {
-                for connected in inner.connected_by_domain(&domains) {
-                    inner
-                        .connected_actor_ids
-                        .remove(connected.as_str().as_bytes())?;
+                let connected_by_domain = inner.connected_by_domain(&domains).collect::<Vec<_>>();
+
+                let mut connected_batch = Batch::default();
+
+                for connected in &connected_by_domain {
+                    connected_batch.remove(connected.as_str().as_bytes());
                 }
+
+                inner.connected_actor_ids.apply_batch(connected_batch)?;
+
+                metrics::gauge!("relay.db.connected-actor-ids.size").set(
+                    crate::collector::recordable(inner.connected_actor_ids.len()),
+                );
             }

+            let mut allowed_batch = Batch::default();
+
             for authority in &domains {
-                inner.allowed_domains.remove(domain_key(authority))?;
+                allowed_batch.remove(domain_key(authority).as_bytes());
             }

+            inner.allowed_domains.apply_batch(allowed_batch)?;
+
+            metrics::gauge!("relay.db.allowed-domains.size")
+                .set(crate::collector::recordable(inner.allowed_domains.len()));
+
             Ok(())
         })
         .await
@@ -665,6 +755,10 @@ impl Db {
             inner
                 .settings
                 .insert("private-key".as_bytes(), pem_pkcs8.as_bytes())?;

+            metrics::gauge!("relay.db.settings.size")
+                .set(crate::collector::recordable(inner.settings.len()));
+
             Ok(())
         })
         .await
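The add_blocks rewrite above is what the commit "Avoid deadlock of iterating over tree while transacting on that tree" refers to: the iterator over connected actors is collected into a Vec first, and all three trees are then updated atomically through sled's Transactional impl for tuples of trees. A minimal sketch of that pattern, assuming sled 0.34; the tree names, keys, and main harness here are hypothetical:

use sled::{transaction::TransactionError, Batch, Transactional};

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let db = sled::Config::new().temporary(true).open()?;
    let blocked = db.open_tree("blocked")?;
    let allowed = db.open_tree("allowed")?;

    // Materialize anything derived from tree iterators *before* transacting;
    // iterating a tree while transacting on it is the deadlock being avoided.
    let domains = vec!["example.com".to_string()];

    let res: Result<(), TransactionError<sled::Error>> =
        (&blocked, &allowed).transaction(|(blocked, allowed)| {
            let mut blocked_batch = Batch::default();
            let mut allowed_batch = Batch::default();

            for domain in &domains {
                blocked_batch.insert(domain.as_bytes(), domain.as_bytes());
                allowed_batch.remove(domain.as_bytes());
            }

            // Both batches commit together or not at all.
            blocked.apply_batch(&blocked_batch)?;
            allowed.apply_batch(&allowed_batch)?;
            Ok(())
        });

    match res {
        Ok(()) => Ok(()),
        Err(TransactionError::Abort(e) | TransactionError::Storage(e)) => Err(e.into()),
    }
}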


@@ -123,6 +123,9 @@ pub(crate) enum ErrorKind {
     #[error("Couldn't sign request")]
     SignRequest,

+    #[error("Response body from server exceeded limits")]
+    BodyTooLarge,
+
     #[error("Couldn't make request")]
     Reqwest(#[from] reqwest::Error),


@@ -40,7 +40,12 @@ fn debug_object(activity: &serde_json::Value) -> &serde_json::Value {
     object
 }

+pub(crate) fn build_storage() -> MetricsStorage<Storage<TokioTimer>> {
+    MetricsStorage::wrap(Storage::new(TokioTimer))
+}
+
 pub(crate) fn create_workers(
+    storage: MetricsStorage<Storage<TokioTimer>>,
     state: State,
     actors: ActorCache,
     media: MediaCache,
@@ -48,18 +53,15 @@ pub(crate) fn create_workers(
 ) -> std::io::Result<JobServer> {
     let deliver_concurrency = config.deliver_concurrency();

-    let queue_handle = WorkerConfig::new(
-        MetricsStorage::wrap(Storage::new(TokioTimer)),
-        move |queue_handle| {
-            JobState::new(
-                state.clone(),
-                actors.clone(),
-                JobServer::new(queue_handle),
-                media.clone(),
-                config.clone(),
-            )
-        },
-    )
+    let queue_handle = WorkerConfig::new(storage, move |queue_handle| {
+        JobState::new(
+            state.clone(),
+            actors.clone(),
+            JobServer::new(queue_handle),
+            media.clone(),
+            config.clone(),
+        )
+    })
     .register::<Deliver>()
     .register::<DeliverMany>()
     .register::<QueryNodeinfo>()


@@ -156,7 +156,7 @@ struct Link {
 #[serde(untagged)]
 enum MaybeSupported<T> {
     Supported(T),
-    Unsupported(String),
+    Unsupported(#[allow(unused)] String),
 }

 impl<T> MaybeSupported<T> {
@@ -165,8 +165,8 @@ impl<T> MaybeSupported<T> {
     }
 }

-struct SupportedVersion(String);
-struct SupportedNodeinfo(String);
+struct SupportedVersion(#[allow(unused)] String);
+struct SupportedNodeinfo(#[allow(unused)] String);

 static SUPPORTED_VERSIONS: &str = "2.";
 static SUPPORTED_NODEINFO: &str = "http://nodeinfo.diaspora.software/ns/schema/2.";


@@ -38,6 +38,7 @@ mod middleware;
 mod requests;
 mod routes;
 mod spawner;
+mod stream;
 mod telegram;

 use crate::config::UrlKind;
@@ -321,10 +322,16 @@ async fn server_main(
     let sign_spawner2 = sign_spawner.clone();
     let verify_spawner2 = verify_spawner.clone();
     let config2 = config.clone();
+    let job_store = jobs::build_storage();
     let server = HttpServer::new(move || {
-        let job_server =
-            create_workers(state.clone(), actors.clone(), media.clone(), config.clone())
-                .expect("Failed to create job server");
+        let job_server = create_workers(
+            job_store.clone(),
+            state.clone(),
+            actors.clone(),
+            media.clone(),
+            config.clone(),
+        )
+        .expect("Failed to create job server");

         let app = App::new()
             .app_data(web::Data::new(db.clone()))


@@ -80,7 +80,7 @@ where
     fn call(&self, req: ServiceRequest) -> Self::Future {
         let log_on_drop = LogOnDrop {
             begin: Instant::now(),
-            path: req.path().to_string(),
+            path: format!("{:?}", req.match_pattern()),
             method: req.method().to_string(),
             arm: false,
         };
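This is the "Use match_pattern rather than path for metrics differentiation" commit: match_pattern() returns the matched route template as an Option<String>, so the Debug formatting produces one label per route rather than one per concrete URL. A small illustration of the formatting, using a hypothetical route template:

// match_pattern() on a ServiceRequest returns Option<String>; format!("{:?}")
// turns that into a small, bounded set of label values.
fn label(match_pattern: Option<String>) -> String {
    format!("{:?}", match_pattern)
}

fn main() {
    // One label per route template, covering /media/1234 and /media/5678 alike:
    assert_eq!(label(Some("/media/{path}".into())), "Some(\"/media/{path}\")");
    // Requests that match no route collapse into a single "None" label:
    assert_eq!(label(None), "None");
}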


@@ -2,6 +2,7 @@ use crate::{
     data::LastOnline,
     error::{Error, ErrorKind},
     spawner::Spawner,
+    stream::{aggregate, limit_stream},
 };
 use activitystreams::iri_string::types::IriString;
 use actix_web::http::header::Date;
@@ -24,6 +25,9 @@ const ONE_MINUTE: u64 = 60 * ONE_SECOND;
 const ONE_HOUR: u64 = 60 * ONE_MINUTE;
 const ONE_DAY: u64 = 24 * ONE_HOUR;

+// 20 KB
+const JSON_SIZE_LIMIT: usize = 20 * 1024;
+
 #[derive(Debug)]
 pub(crate) enum BreakerStrategy {
     // Requires a successful response
@@ -262,7 +266,7 @@ impl Requests {
     where
         T: serde::de::DeserializeOwned,
     {
-        let body = self
+        let stream = self
             .do_deliver(
                 url,
                 &serde_json::json!({}),
@@ -271,8 +275,9 @@ impl Requests {
                 strategy,
             )
             .await?
-            .bytes()
-            .await?;
+            .bytes_stream();
+
+        let body = aggregate(limit_stream(stream, JSON_SIZE_LIMIT)).await?;

         Ok(serde_json::from_slice(&body)?)
     }
@@ -299,11 +304,12 @@ impl Requests {
     where
         T: serde::de::DeserializeOwned,
     {
-        let body = self
-            .do_fetch_response(url, accept, strategy)
-            .await?
-            .bytes()
-            .await?;
+        let stream = self
+            .do_fetch_response(url, accept, strategy)
+            .await?
+            .bytes_stream();
+
+        let body = aggregate(limit_stream(stream, JSON_SIZE_LIMIT)).await?;

         Ok(serde_json::from_slice(&body)?)
     }


@@ -2,10 +2,14 @@ use crate::{
     data::MediaCache,
     error::Error,
     requests::{BreakerStrategy, Requests},
+    stream::limit_stream,
 };
 use actix_web::{body::BodyStream, web, HttpResponse};
 use uuid::Uuid;

+// 16 MB
+const IMAGE_SIZE_LIMIT: usize = 16 * 1024 * 1024;
+
 #[tracing::instrument(name = "Media", skip(media, requests))]
 pub(crate) async fn route(
     media: web::Data<MediaCache>,
@@ -25,7 +29,10 @@ pub(crate) async fn route(
             response.insert_header((name.clone(), value.clone()));
         }

-        return Ok(response.body(BodyStream::new(res.bytes_stream())));
+        return Ok(response.body(BodyStream::new(limit_stream(
+            res.bytes_stream(),
+            IMAGE_SIZE_LIMIT,
+        ))));
     }

     Ok(HttpResponse::NotFound().finish())

src/stream.rs (new file, 59 lines)

@@ -0,0 +1,59 @@
+use crate::error::{Error, ErrorKind};
+use actix_web::web::{Bytes, BytesMut};
+use futures_core::Stream;
+use streem::IntoStreamer;
+
+pub(crate) fn limit_stream<'a, S>(
+    input: S,
+    limit: usize,
+) -> impl Stream<Item = Result<Bytes, Error>> + Send + 'a
+where
+    S: Stream<Item = reqwest::Result<Bytes>> + Send + 'a,
+{
+    streem::try_from_fn(move |yielder| async move {
+        let stream = std::pin::pin!(input);
+        let mut stream = stream.into_streamer();
+
+        let mut count = 0;
+
+        while let Some(bytes) = stream.try_next().await? {
+            count += bytes.len();
+
+            if count > limit {
+                return Err(ErrorKind::BodyTooLarge.into());
+            }
+
+            yielder.yield_ok(bytes).await;
+        }
+
+        Ok(())
+    })
+}
+
+pub(crate) async fn aggregate<S>(input: S) -> Result<Bytes, Error>
+where
+    S: Stream<Item = Result<Bytes, Error>>,
+{
+    let stream = std::pin::pin!(input);
+    let mut streamer = stream.into_streamer();
+
+    let mut buf = Vec::new();
+
+    while let Some(bytes) = streamer.try_next().await? {
+        buf.push(bytes);
+    }
+
+    if buf.len() == 1 {
+        return Ok(buf.pop().expect("buf has exactly one element"));
+    }
+
+    let total_size: usize = buf.iter().map(|b| b.len()).sum();
+
+    let mut bytes_mut = BytesMut::with_capacity(total_size);
+
+    for bytes in &buf {
+        bytes_mut.extend_from_slice(&bytes);
+    }
+
+    Ok(bytes_mut.freeze())
+}
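Together these two helpers form the pattern the "Impose limits on the size of downloaded content from foreign servers" commit applies in src/requests.rs: wrap a reqwest byte stream in limit_stream so it aborts with BodyTooLarge once the cap is exceeded, then aggregate the capped stream into a single Bytes buffer. A crate-internal sketch of that composition, assuming a reqwest::Response in hand (the function name here is hypothetical):

use actix_web::web::Bytes;
use crate::error::Error;
use crate::stream::{aggregate, limit_stream};

// 20 KB, matching JSON_SIZE_LIMIT in src/requests.rs.
const JSON_SIZE_LIMIT: usize = 20 * 1024;

// Buffer at most JSON_SIZE_LIMIT bytes of a response body; anything larger
// fails with ErrorKind::BodyTooLarge instead of exhausting memory.
async fn read_limited_body(res: reqwest::Response) -> Result<Bytes, Error> {
    aggregate(limit_stream(res.bytes_stream(), JSON_SIZE_LIMIT)).await
}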