From e8f730c94798bcdc39558b887205a6d2e85268c5 Mon Sep 17 00:00:00 2001 From: "Aode (lion)" Date: Wed, 8 Dec 2021 12:13:15 -0600 Subject: [PATCH] Remove tokio by inlining digest hashing --- http-signature-normalization-actix/Cargo.toml | 2 - .../examples/client.rs | 2 +- .../examples/server.rs | 2 +- .../src/digest/middleware.rs | 95 ++++++++----------- .../src/digest/sign.rs | 4 +- http-signature-normalization-actix/src/lib.rs | 4 +- .../src/middleware.rs | 3 +- .../src/sign.rs | 2 +- 8 files changed, 49 insertions(+), 65 deletions(-) diff --git a/http-signature-normalization-actix/Cargo.toml b/http-signature-normalization-actix/Cargo.toml index 4a6e7e0..8fb75b7 100644 --- a/http-signature-normalization-actix/Cargo.toml +++ b/http-signature-normalization-actix/Cargo.toml @@ -38,10 +38,8 @@ http-signature-normalization = { version = "0.5.1", path = ".." } sha2 = { version = "0.9", optional = true } sha3 = { version = "0.9", optional = true } thiserror = "1.0" -tokio = { version = "1", default-features = false, features = ["sync"] } tracing = "0.1" tracing-error = "0.2" -tracing-futures = "0.2" [dev-dependencies] actix-rt = "2.3.0" diff --git a/http-signature-normalization-actix/examples/client.rs b/http-signature-normalization-actix/examples/client.rs index bb6d101..1cc1f2a 100644 --- a/http-signature-normalization-actix/examples/client.rs +++ b/http-signature-normalization-actix/examples/client.rs @@ -36,7 +36,7 @@ async fn request(config: Config) -> Result<(), Box> { Ok(()) } -#[actix_rt::main] +#[actix_web::main] async fn main() -> Result<(), Box> { let env_filter = EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info")); diff --git a/http-signature-normalization-actix/examples/server.rs b/http-signature-normalization-actix/examples/server.rs index 907cc1a..025757a 100644 --- a/http-signature-normalization-actix/examples/server.rs +++ b/http-signature-normalization-actix/examples/server.rs @@ -51,7 +51,7 @@ async fn index( "Eyyyyup" } -#[actix_rt::main] +#[actix_web::main] async fn main() -> Result<(), Box> { let env_filter = EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info")); diff --git a/http-signature-normalization-actix/src/digest/middleware.rs b/http-signature-normalization-actix/src/digest/middleware.rs index ddb1a0b..9949472 100644 --- a/http-signature-normalization-actix/src/digest/middleware.rs +++ b/http-signature-normalization-actix/src/digest/middleware.rs @@ -14,10 +14,10 @@ use futures_util::{ use std::{ future::{ready, Ready}, pin::Pin, + rc::Rc, task::{Context, Poll}, }; -use tokio::sync::mpsc; -use tracing::{debug, Span}; +use tracing::{debug, Instrument, Span}; use tracing_error::SpanTrace; #[derive(Copy, Clone, Debug)] @@ -44,7 +44,7 @@ pub struct DigestVerified; pub struct VerifyDigest(bool, T); #[doc(hidden)] -pub struct VerifyMiddleware(S, bool, T); +pub struct VerifyMiddleware(Rc, bool, T); #[derive(Debug, thiserror::Error)] #[error("Error verifying digest")] @@ -76,19 +76,6 @@ enum VerifyErrorKind { #[error("Failed to verify digest")] Verify, - - #[error("Payload dropped. If this was unexpected, it could be that the payload isn't required in the route this middleware is guarding")] - Dropped, -} - -struct RxStream(mpsc::Receiver); - -impl Stream for RxStream { - type Item = T; - - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - Pin::new(&mut self.0).poll_recv(cx) - } } impl VerifyDigest @@ -142,7 +129,11 @@ where type Future = Ready>; fn new_transform(&self, service: S) -> Self::Future { - ready(Ok(VerifyMiddleware(service, self.0, self.1.clone()))) + ready(Ok(VerifyMiddleware( + Rc::new(service), + self.0, + self.1.clone(), + ))) } } @@ -167,53 +158,48 @@ where digest.required = tracing::field::display(&self.1), ); - if let Some(digest) = req.headers().get("Digest") { - let vec = match parse_digest(digest) { - Some(vec) => vec, - None => { - return Box::pin(ready(Err( - VerifyError::new(&span, VerifyErrorKind::Empty).into() - ))); - } - }; - let payload = req.take_payload(); + let verifier = self.2.clone(); + let svc = Rc::clone(&self.0); + let required = self.1; - let (tx, rx) = mpsc::channel(1); - let f1 = span.in_scope(|| verify_payload(vec, self.2.clone(), payload, tx)); + Box::pin(async move { + if let Some(digest) = req.headers().get("Digest") { + let vec = match parse_digest(digest) { + Some(vec) => vec, + None => { + return Err(VerifyError::new(&span, VerifyErrorKind::Empty).into()); + } + }; - let payload: Pin> + 'static>> = - Box::pin(RxStream(rx).map(Ok)); - req.set_payload(payload.into()); - req.extensions_mut().insert(DigestVerified); + let payload = req.take_payload(); - let f2 = self.0.call(req); + let payload = verify_payload(vec, verifier, payload) + .instrument(span) + .await?; - Box::pin(async move { - let (_, res) = futures_util::future::join(f1, f2).await; - res - }) - } else if self.1 { - Box::pin(ready(Err(VerifyError::new( - &span, - VerifyErrorKind::MissingDigest, - ) - .into()))) - } else { - Box::pin(self.0.call(req)) - } + req.set_payload(payload); + req.extensions_mut().insert(DigestVerified); + + svc.call(req).await + } else if required { + Err(VerifyError::new(&span, VerifyErrorKind::MissingDigest).into()) + } else { + svc.call(req).await + } + }) } } -#[tracing::instrument(name = "Verify Payload", skip(verify_digest, payload, tx))] +#[tracing::instrument(name = "Verify Payload", skip(verify_digest, payload))] async fn verify_payload( vec: Vec, mut verify_digest: T, mut payload: Payload, - tx: mpsc::Sender, -) -> Result<(), actix_web::Error> +) -> Result where T: DigestVerify + Clone + Send + 'static, { + let mut payload_vec = vec![]; while let Some(res) = payload.next().await { let bytes = res?; let bytes2 = bytes.clone(); @@ -223,16 +209,17 @@ where }) .await??; - tx.send(bytes) - .await - .map_err(|_| VerifyError::new(&Span::current(), VerifyErrorKind::Dropped))?; + payload_vec.push(Ok(bytes)); } let verified = web::block(move || Ok(verify_digest.verify(&vec)) as Result<_, VerifyError>).await??; if verified { - Ok(()) + let stream: Pin>>> = + Box::pin(futures_util::stream::iter(payload_vec)); + + Ok(Payload::Stream(stream)) } else { Err(VerifyError::new(&Span::current(), VerifyErrorKind::Verify).into()) } diff --git a/http-signature-normalization-actix/src/digest/sign.rs b/http-signature-normalization-actix/src/digest/sign.rs index c3385d3..a202b7e 100644 --- a/http-signature-normalization-actix/src/digest/sign.rs +++ b/http-signature-normalization-actix/src/digest/sign.rs @@ -30,7 +30,7 @@ impl SignExt for ClientRequest { Self: Sized, { Box::pin(async move { - let (d, v) = actix_rt::task::spawn_blocking(move || { + let (d, v) = actix_web::web::block(move || { let d = digest.compute(v.as_ref()); Ok((d, v)) as Result<(String, V), E> }) @@ -67,7 +67,7 @@ impl SignExt for ClientRequest { Self: Sized, { Box::pin(async move { - let (d, v) = actix_rt::task::spawn_blocking(move || { + let (d, v) = actix_web::web::block(move || { let d = digest.compute(v.as_ref()); Ok((d, v)) as Result<(String, V), E> }) diff --git a/http-signature-normalization-actix/src/lib.rs b/http-signature-normalization-actix/src/lib.rs index 81e8491..259445a 100644 --- a/http-signature-normalization-actix/src/lib.rs +++ b/http-signature-normalization-actix/src/lib.rs @@ -50,7 +50,7 @@ //! "Eyyyyup" //! } //! -//! #[actix_rt::main] +//! #[actix_web::main] //! async fn main() -> Result<(), Box> { //! let config = Config::default(); //! @@ -104,7 +104,7 @@ //! use http_signature_normalization_actix::prelude::*; //! use sha2::{Digest, Sha256}; //! -//! #[actix_rt::main] +//! #[actix_web::main] //! async fn main() -> Result<(), Box> { //! let config = Config::default(); //! let digest = Sha256::new(); diff --git a/http-signature-normalization-actix/src/middleware.rs b/http-signature-normalization-actix/src/middleware.rs index 5680405..601979f 100644 --- a/http-signature-normalization-actix/src/middleware.rs +++ b/http-signature-normalization-actix/src/middleware.rs @@ -12,9 +12,8 @@ use std::{ future::{ready, Ready}, task::{Context, Poll}, }; -use tracing::{debug, Span}; +use tracing::{debug, Instrument, Span}; use tracing_error::SpanTrace; -use tracing_futures::Instrument; #[derive(Clone, Debug)] /// A marker type that can be used to guard routes when the signature middleware is set to diff --git a/http-signature-normalization-actix/src/sign.rs b/http-signature-normalization-actix/src/sign.rs index f9e65f5..0ac8c70 100644 --- a/http-signature-normalization-actix/src/sign.rs +++ b/http-signature-normalization-actix/src/sign.rs @@ -92,7 +92,7 @@ where let key_id = key_id.to_string(); - let signed = actix_rt::task::spawn_blocking(move || unsigned.sign(key_id, f)).await.map_err(|_| BlockingError)??; + let signed = actix_web::web::block(move || unsigned.sign(key_id, f)).await.map_err(|_| BlockingError)??; Ok(signed) }