Replace spawned tasks with inline payload stream processing

This commit is contained in:
asonix 2023-09-10 13:20:35 -04:00
parent 6e0a6fa3a2
commit 369a1e8a96

View file

@ -1,6 +1,6 @@
//! Types for setting up Digest middleware verification //! Types for setting up Digest middleware verification
use crate::{DefaultSpawner, Spawn}; use crate::{Canceled, DefaultSpawner, Spawn};
use super::{DigestPart, DigestVerify}; use super::{DigestPart, DigestVerify};
use actix_web::{ use actix_web::{
@ -16,7 +16,7 @@ use std::{
pin::Pin, pin::Pin,
task::{Context, Poll}, task::{Context, Poll},
}; };
use streem::IntoStreamer; use streem::{from_fn::Yielder, IntoStreamer};
use tokio::sync::{mpsc, oneshot}; use tokio::sync::{mpsc, oneshot};
use tracing::{debug, Span}; use tracing::{debug, Span};
use tracing_error::SpanTrace; use tracing_error::SpanTrace;
@ -207,30 +207,23 @@ where
))); )));
} }
}; };
let payload = req.take_payload();
let spawner = self.1.clone(); let spawner = self.1.clone();
let digest = self.3.clone();
let (tx, rx) = mpsc::channel(1);
let (verify_tx, verify_rx) = oneshot::channel(); let (verify_tx, verify_rx) = oneshot::channel();
let f1 = span
.in_scope(|| verify_payload(spawner, vec, self.3.clone(), payload, tx, verify_tx));
let payload = req.take_payload();
let payload: Pin<Box<dyn Stream<Item = Result<web::Bytes, PayloadError>> + 'static>> = let payload: Pin<Box<dyn Stream<Item = Result<web::Bytes, PayloadError>> + 'static>> =
Box::pin(RxStream(rx)); Box::pin(streem::try_from_fn(|yielder| async move {
verify_payload(yielder, spawner, vec, digest, payload, verify_tx).await
}));
req.set_payload(payload.into()); req.set_payload(payload.into());
req.extensions_mut().insert(VerifiedReceiver { req.extensions_mut().insert(VerifiedReceiver {
rx: Some(verify_rx), rx: Some(verify_rx),
}); });
let f2 = self.0.call(req); Box::pin(self.0.call(req))
Box::pin(async move {
let handle1 = actix_web::rt::spawn(f1);
let handle2 = actix_web::rt::spawn(f2);
handle1.await.expect("verify panic")?;
handle2.await.expect("inner panic")
})
} else if self.2 { } else if self.2 {
Box::pin(ready(Err(VerifyError::new( Box::pin(ready(Err(VerifyError::new(
&span, &span,
@ -243,46 +236,79 @@ where
} }
} }
#[tracing::instrument(name = "Verify Payload", skip(spawner, verify_digest, payload, tx))] fn canceled_error(error: Canceled) -> PayloadError {
PayloadError::Io(std::io::Error::new(std::io::ErrorKind::Other, error))
}
fn verified_error(error: VerifyError) -> PayloadError {
PayloadError::Io(std::io::Error::new(std::io::ErrorKind::Other, error))
}
async fn verify_payload<T, Spawner>( async fn verify_payload<T, Spawner>(
yielder: Yielder<Result<web::Bytes, PayloadError>>,
spawner: Spawner, spawner: Spawner,
vec: Vec<DigestPart>, vec: Vec<DigestPart>,
mut verify_digest: T, mut verify_digest: T,
payload: Payload, payload: Payload,
tx: mpsc::Sender<web::Bytes>,
verify_tx: oneshot::Sender<()>, verify_tx: oneshot::Sender<()>,
) -> Result<(), actix_web::Error> ) -> Result<(), PayloadError>
where where
T: DigestVerify + Clone + Send + 'static, T: DigestVerify + Clone + Send + 'static,
Spawner: Spawn, Spawner: Spawn,
{ {
let mut payload = payload.into_streamer(); let mut payload = payload.into_streamer();
while let Some(bytes) = payload.try_next().await? { let mut error = None;
let bytes2 = bytes.clone();
verify_digest = spawner
.spawn_blocking(move || {
verify_digest.update(bytes2.as_ref());
Ok(verify_digest) as Result<T, VerifyError>
})
.await??;
tx.send(bytes) while let Some(bytes) = payload.try_next().await? {
.await if error.is_none() {
.map_err(|_| VerifyError::new(&Span::current(), VerifyErrorKind::Dropped))?; let bytes2 = bytes.clone();
let mut verify_digest2 = verify_digest.clone();
let task = spawner.spawn_blocking(move || {
verify_digest2.update(bytes2.as_ref());
Ok(verify_digest2) as Result<T, VerifyError>
});
yielder.yield_ok(bytes).await;
match task.await {
Ok(Ok(digest)) => {
verify_digest = digest;
}
Ok(Err(e)) => {
error = Some(verified_error(e));
}
Err(e) => {
error = Some(canceled_error(e));
}
}
} else {
yielder.yield_ok(bytes).await;
}
}
if let Some(error) = error {
return Err(error);
} }
let verified = spawner let verified = spawner
.spawn_blocking(move || Ok(verify_digest.verify(&vec)) as Result<_, VerifyError>) .spawn_blocking(move || Ok(verify_digest.verify(&vec)) as Result<_, VerifyError>)
.await??; .await
.map_err(canceled_error)?
.map_err(verified_error)?;
if verified { if verified {
if verify_tx.send(()).is_err() { if verify_tx.send(()).is_err() {
debug!("handler dropped"); debug!("handler dropped");
} }
Ok(()) Ok(())
} else { } else {
Err(VerifyError::new(&Span::current(), VerifyErrorKind::Verify).into()) Err(verified_error(VerifyError::new(
&Span::current(),
VerifyErrorKind::Verify,
)))
} }
} }