diff --git a/actix/Cargo.toml b/actix/Cargo.toml index 6124d4e..7049243 100644 --- a/actix/Cargo.toml +++ b/actix/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "http-signature-normalization-actix" description = "An HTTP Signatures library that leaves the signing to you" -version = "0.10.2" +version = "0.10.3" authors = ["asonix "] license = "AGPL-3.0" readme = "README.md" diff --git a/actix/src/digest/middleware.rs b/actix/src/digest/middleware.rs index 2ef8d01..071f628 100644 --- a/actix/src/digest/middleware.rs +++ b/actix/src/digest/middleware.rs @@ -17,7 +17,7 @@ use std::{ task::{Context, Poll}, }; use streem::IntoStreamer; -use tokio::sync::mpsc; +use tokio::sync::{mpsc, oneshot}; use tracing::{debug, Span}; use tracing_error::SpanTrace; @@ -125,22 +125,30 @@ where } } +struct VerifiedReceiver { + rx: Option>, +} + impl FromRequest for DigestVerified { type Error = VerifyError; - type Future = Ready>; + type Future = LocalBoxFuture<'static, Result>; fn from_request(req: &HttpRequest, _: &mut Payload) -> Self::Future { let res = req - .extensions() - .get::() - .copied() + .extensions_mut() + .get_mut::() + .and_then(|r| r.rx.take()) .ok_or_else(|| VerifyError::new(&Span::current(), VerifyErrorKind::Extension)); if res.is_err() { debug!("Failed to fetch DigestVerified from request"); } - ready(res) + Box::pin(async move { + res?.await + .map_err(|_| VerifyError::new(&Span::current(), VerifyErrorKind::Dropped)) + .map(|()| DigestVerified) + }) } } @@ -203,12 +211,16 @@ where let spawner = self.1.clone(); let (tx, rx) = mpsc::channel(1); - let f1 = span.in_scope(|| verify_payload(spawner, vec, self.3.clone(), payload, tx)); + let (verify_tx, verify_rx) = oneshot::channel(); + let f1 = span + .in_scope(|| verify_payload(spawner, vec, self.3.clone(), payload, tx, verify_tx)); let payload: Pin> + 'static>> = Box::pin(RxStream(rx)); req.set_payload(payload.into()); - req.extensions_mut().insert(DigestVerified); + req.extensions_mut().insert(VerifiedReceiver { + rx: Some(verify_rx), + }); let f2 = self.0.call(req); @@ -238,6 +250,7 @@ async fn verify_payload( mut verify_digest: T, payload: Payload, tx: mpsc::Sender, + verify_tx: oneshot::Sender<()>, ) -> Result<(), actix_web::Error> where T: DigestVerify + Clone + Send + 'static, @@ -264,6 +277,9 @@ where .await??; if verified { + if verify_tx.send(()).is_err() { + debug!("handler dropped"); + } Ok(()) } else { Err(VerifyError::new(&Span::current(), VerifyErrorKind::Verify).into())