diff --git a/src/main.rs b/src/main.rs index 36c5d27..7220dd9 100644 --- a/src/main.rs +++ b/src/main.rs @@ -22,7 +22,7 @@ use self::{ data::{ActorCache, Media, State}, db::Db, jobs::{create_server, create_workers}, - middleware::RelayResolver, + middleware::{DebugPayload, RelayResolver}, routes::{actor, inbox, index, nodeinfo, nodeinfo_meta, statics}, }; @@ -132,6 +132,7 @@ async fn main() -> Result<(), anyhow::Error> { .service(web::resource("/media/{path}").route(web::get().to(routes::media))) .service( web::resource("/inbox") + .wrap(DebugPayload(config.debug())) .wrap(config.digest_middleware()) .wrap(config.signature_middleware(state.requests(), actors.clone())) .route(web::post().to(inbox)), diff --git a/src/middleware/mod.rs b/src/middleware/mod.rs index 6e5b1f9..f910828 100644 --- a/src/middleware/mod.rs +++ b/src/middleware/mod.rs @@ -1,5 +1,7 @@ +mod payload; mod verifier; mod webfinger; +pub use payload::DebugPayload; pub use verifier::MyVerify; pub use webfinger::RelayResolver; diff --git a/src/middleware/payload.rs b/src/middleware/payload.rs new file mode 100644 index 0000000..3db41b4 --- /dev/null +++ b/src/middleware/payload.rs @@ -0,0 +1,98 @@ +use actix_web::{ + dev::{Payload, Service, ServiceRequest, Transform}, + http::StatusCode, + HttpMessage, HttpResponse, ResponseError, +}; +use bytes::BytesMut; +use futures::{ + future::{ok, LocalBoxFuture, Ready}, + stream::StreamExt, +}; +use log::info; +use std::task::{Context, Poll}; +use tokio::sync::mpsc::channel; + +#[derive(Clone, Debug)] +pub struct DebugPayload(pub bool); + +#[doc(hidden)] +#[derive(Clone, Debug)] +pub struct DebugPayloadMiddleware(bool, S); + +#[derive(Clone, Debug, thiserror::Error)] +#[error("Failed to read payload")] +pub struct DebugError; + +impl ResponseError for DebugError { + fn status_code(&self) -> StatusCode { + StatusCode::BAD_REQUEST + } + + fn error_response(&self) -> HttpResponse { + HttpResponse::new(self.status_code()) + } +} + +impl Transform for DebugPayload +where + S: Service, + S::Future: 'static, + S::Error: 'static, +{ + type Request = S::Request; + type Response = S::Response; + type Error = S::Error; + type InitError = (); + type Transform = DebugPayloadMiddleware; + type Future = Ready>; + + fn new_transform(&self, service: S) -> Self::Future { + ok(DebugPayloadMiddleware(self.0, service)) + } +} + +impl Service for DebugPayloadMiddleware +where + S: Service, + S::Future: 'static, + S::Error: 'static, +{ + type Request = S::Request; + type Response = S::Response; + type Error = S::Error; + type Future = LocalBoxFuture<'static, Result>; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.1.poll_ready(cx) + } + + fn call(&mut self, mut req: S::Request) -> Self::Future { + if self.0 { + let (mut tx, rx) = channel(1); + + let mut pl = req.take_payload(); + req.set_payload(Payload::Stream(Box::pin(rx))); + + let fut = self.1.call(req); + + return Box::pin(async move { + let mut bytes = BytesMut::new(); + + while let Some(res) = pl.next().await { + let b = res.map_err(|_| DebugError)?; + bytes.extend(b); + } + + info!("{}", String::from_utf8_lossy(bytes.as_ref())); + + tx.send(Ok(bytes.freeze())).await.map_err(|_| DebugError)?; + + fut.await + }); + } + + let fut = self.1.call(req); + + Box::pin(async move { fut.await }) + } +}