Remove tokio by inlining digest hashing

This commit is contained in:
Aode (lion) 2021-12-08 12:13:15 -06:00
parent 8667701c74
commit e8f730c947
8 changed files with 49 additions and 65 deletions

View file

@ -38,10 +38,8 @@ http-signature-normalization = { version = "0.5.1", path = ".." }
sha2 = { version = "0.9", optional = true }
sha3 = { version = "0.9", optional = true }
thiserror = "1.0"
tokio = { version = "1", default-features = false, features = ["sync"] }
tracing = "0.1"
tracing-error = "0.2"
tracing-futures = "0.2"
[dev-dependencies]
actix-rt = "2.3.0"

View file

@ -36,7 +36,7 @@ async fn request(config: Config) -> Result<(), Box<dyn std::error::Error>> {
Ok(())
}
#[actix_rt::main]
#[actix_web::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
let env_filter = EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info"));

View file

@ -51,7 +51,7 @@ async fn index(
"Eyyyyup"
}
#[actix_rt::main]
#[actix_web::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
let env_filter = EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info"));

View file

@ -14,10 +14,10 @@ use futures_util::{
use std::{
future::{ready, Ready},
pin::Pin,
rc::Rc,
task::{Context, Poll},
};
use tokio::sync::mpsc;
use tracing::{debug, Span};
use tracing::{debug, Instrument, Span};
use tracing_error::SpanTrace;
#[derive(Copy, Clone, Debug)]
@ -44,7 +44,7 @@ pub struct DigestVerified;
pub struct VerifyDigest<T>(bool, T);
#[doc(hidden)]
pub struct VerifyMiddleware<T, S>(S, bool, T);
pub struct VerifyMiddleware<T, S>(Rc<S>, bool, T);
#[derive(Debug, thiserror::Error)]
#[error("Error verifying digest")]
@ -76,19 +76,6 @@ enum VerifyErrorKind {
#[error("Failed to verify digest")]
Verify,
#[error("Payload dropped. If this was unexpected, it could be that the payload isn't required in the route this middleware is guarding")]
Dropped,
}
struct RxStream<T>(mpsc::Receiver<T>);
impl<T> Stream for RxStream<T> {
type Item = T;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
Pin::new(&mut self.0).poll_recv(cx)
}
}
impl<T> VerifyDigest<T>
@ -142,7 +129,11 @@ where
type Future = Ready<Result<Self::Transform, Self::InitError>>;
fn new_transform(&self, service: S) -> Self::Future {
ready(Ok(VerifyMiddleware(service, self.0, self.1.clone())))
ready(Ok(VerifyMiddleware(
Rc::new(service),
self.0,
self.1.clone(),
)))
}
}
@ -167,53 +158,48 @@ where
digest.required = tracing::field::display(&self.1),
);
if let Some(digest) = req.headers().get("Digest") {
let vec = match parse_digest(digest) {
Some(vec) => vec,
None => {
return Box::pin(ready(Err(
VerifyError::new(&span, VerifyErrorKind::Empty).into()
)));
}
};
let payload = req.take_payload();
let verifier = self.2.clone();
let svc = Rc::clone(&self.0);
let required = self.1;
let (tx, rx) = mpsc::channel(1);
let f1 = span.in_scope(|| verify_payload(vec, self.2.clone(), payload, tx));
Box::pin(async move {
if let Some(digest) = req.headers().get("Digest") {
let vec = match parse_digest(digest) {
Some(vec) => vec,
None => {
return Err(VerifyError::new(&span, VerifyErrorKind::Empty).into());
}
};
let payload: Pin<Box<dyn Stream<Item = Result<web::Bytes, PayloadError>> + 'static>> =
Box::pin(RxStream(rx).map(Ok));
req.set_payload(payload.into());
req.extensions_mut().insert(DigestVerified);
let payload = req.take_payload();
let f2 = self.0.call(req);
let payload = verify_payload(vec, verifier, payload)
.instrument(span)
.await?;
Box::pin(async move {
let (_, res) = futures_util::future::join(f1, f2).await;
res
})
} else if self.1 {
Box::pin(ready(Err(VerifyError::new(
&span,
VerifyErrorKind::MissingDigest,
)
.into())))
} else {
Box::pin(self.0.call(req))
}
req.set_payload(payload);
req.extensions_mut().insert(DigestVerified);
svc.call(req).await
} else if required {
Err(VerifyError::new(&span, VerifyErrorKind::MissingDigest).into())
} else {
svc.call(req).await
}
})
}
}
#[tracing::instrument(name = "Verify Payload", skip(verify_digest, payload, tx))]
#[tracing::instrument(name = "Verify Payload", skip(verify_digest, payload))]
async fn verify_payload<T>(
vec: Vec<DigestPart>,
mut verify_digest: T,
mut payload: Payload,
tx: mpsc::Sender<web::Bytes>,
) -> Result<(), actix_web::Error>
) -> Result<Payload, actix_web::Error>
where
T: DigestVerify + Clone + Send + 'static,
{
let mut payload_vec = vec![];
while let Some(res) = payload.next().await {
let bytes = res?;
let bytes2 = bytes.clone();
@ -223,16 +209,17 @@ where
})
.await??;
tx.send(bytes)
.await
.map_err(|_| VerifyError::new(&Span::current(), VerifyErrorKind::Dropped))?;
payload_vec.push(Ok(bytes));
}
let verified =
web::block(move || Ok(verify_digest.verify(&vec)) as Result<_, VerifyError>).await??;
if verified {
Ok(())
let stream: Pin<Box<dyn Stream<Item = Result<web::Bytes, PayloadError>>>> =
Box::pin(futures_util::stream::iter(payload_vec));
Ok(Payload::Stream(stream))
} else {
Err(VerifyError::new(&Span::current(), VerifyErrorKind::Verify).into())
}

View file

@ -30,7 +30,7 @@ impl SignExt for ClientRequest {
Self: Sized,
{
Box::pin(async move {
let (d, v) = actix_rt::task::spawn_blocking(move || {
let (d, v) = actix_web::web::block(move || {
let d = digest.compute(v.as_ref());
Ok((d, v)) as Result<(String, V), E>
})
@ -67,7 +67,7 @@ impl SignExt for ClientRequest {
Self: Sized,
{
Box::pin(async move {
let (d, v) = actix_rt::task::spawn_blocking(move || {
let (d, v) = actix_web::web::block(move || {
let d = digest.compute(v.as_ref());
Ok((d, v)) as Result<(String, V), E>
})

View file

@ -50,7 +50,7 @@
//! "Eyyyyup"
//! }
//!
//! #[actix_rt::main]
//! #[actix_web::main]
//! async fn main() -> Result<(), Box<dyn std::error::Error>> {
//! let config = Config::default();
//!
@ -104,7 +104,7 @@
//! use http_signature_normalization_actix::prelude::*;
//! use sha2::{Digest, Sha256};
//!
//! #[actix_rt::main]
//! #[actix_web::main]
//! async fn main() -> Result<(), Box<dyn std::error::Error>> {
//! let config = Config::default();
//! let digest = Sha256::new();

View file

@ -12,9 +12,8 @@ use std::{
future::{ready, Ready},
task::{Context, Poll},
};
use tracing::{debug, Span};
use tracing::{debug, Instrument, Span};
use tracing_error::SpanTrace;
use tracing_futures::Instrument;
#[derive(Clone, Debug)]
/// A marker type that can be used to guard routes when the signature middleware is set to

View file

@ -92,7 +92,7 @@ where
let key_id = key_id.to_string();
let signed = actix_rt::task::spawn_blocking(move || unsigned.sign(key_id, f)).await.map_err(|_| BlockingError)??;
let signed = actix_web::web::block(move || unsigned.sign(key_id, f)).await.map_err(|_| BlockingError)??;
Ok(signed)
}