//! Types for setting up Digest middleware verification use crate::{DefaultSpawner, Spawn}; use super::{DigestPart, DigestVerify}; use actix_web::{ body::MessageBody, dev::{Payload, Service, ServiceRequest, ServiceResponse, Transform}, error::PayloadError, http::{header::HeaderValue, StatusCode}, web, FromRequest, HttpMessage, HttpRequest, HttpResponse, ResponseError, }; use futures_core::{future::LocalBoxFuture, Stream}; use std::{ future::{ready, Ready}, pin::Pin, task::{Context, Poll}, }; use streem::IntoStreamer; use tokio::sync::{mpsc, oneshot}; use tracing::{debug, Span}; use tracing_error::SpanTrace; #[derive(Copy, Clone, Debug)] /// A type implementing FromRequest that can be used in route handler to guard for verified /// digests /// /// This is only required when the [`VerifyDigest`] middleware is set to optional pub struct DigestVerified; #[derive(Clone, Debug)] /// The VerifyDigest middleware /// /// ```rust,ignore /// let middleware = VerifyDigest::new(MyVerify::new()) /// .optional(); /// /// HttpServer::new(move || { /// App::new() /// .wrap(middleware.clone()) /// .route("/protected", web::post().to(|_: DigestVerified| "Verified Digest Header")) /// .route("/unprotected", web::post().to(|| "No verification required")) /// }) /// ``` pub struct VerifyDigest(Spawner, bool, T); #[doc(hidden)] pub struct VerifyMiddleware(S, Spawner, bool, T); #[derive(Debug, thiserror::Error)] #[error("Error verifying digest")] #[doc(hidden)] pub struct VerifyError { context: String, kind: VerifyErrorKind, } impl VerifyError { fn new(span: &Span, kind: VerifyErrorKind) -> Self { span.in_scope(|| VerifyError { context: SpanTrace::capture().to_string(), kind, }) } } #[derive(Debug, thiserror::Error)] enum VerifyErrorKind { #[error("Missing request extension")] Extension, #[error("Digest header missing")] MissingDigest, #[error("Digest header is empty")] Empty, #[error("Failed to verify digest")] Verify, #[error("Payload dropped. If this was unexpected, it could be that the payload isn't required in the route this middleware is guarding")] Dropped, } struct RxStream(mpsc::Receiver); impl Stream for RxStream { type Item = Result; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { Pin::new(&mut self.0).poll_recv(cx).map(|opt| opt.map(Ok)) } } impl VerifyDigest where T: DigestVerify + Clone, { /// Produce a new VerifyDigest with a user-provided [`Digestverify`] type pub fn new(verify_digest: T) -> Self { VerifyDigest(DefaultSpawner, true, verify_digest) } } impl VerifyDigest where T: DigestVerify + Clone, { /// Set the spawner used for verifying bytes in the request /// /// By default this value uses `actix_web::web::block` pub fn spawner(self, spawner: NewSpawner) -> VerifyDigest where NewSpawner: Spawn, { VerifyDigest(spawner, self.1, self.2) } /// Mark verifying the Digest as optional /// /// If a digest is present in the request, it will be verified, but it is not required to be /// present pub fn optional(self) -> Self { VerifyDigest(self.0, false, self.2) } } struct VerifiedReceiver { rx: Option>, } impl FromRequest for DigestVerified { type Error = VerifyError; type Future = LocalBoxFuture<'static, Result>; fn from_request(req: &HttpRequest, _: &mut Payload) -> Self::Future { let res = req .extensions_mut() .get_mut::() .and_then(|r| r.rx.take()) .ok_or_else(|| VerifyError::new(&Span::current(), VerifyErrorKind::Extension)); if res.is_err() { debug!("Failed to fetch DigestVerified from request"); } Box::pin(async move { res?.await .map_err(|_| VerifyError::new(&Span::current(), VerifyErrorKind::Dropped)) .map(|()| DigestVerified) }) } } impl Transform for VerifyDigest where T: DigestVerify + Clone + Send + 'static, S: Service, Error = actix_web::Error> + 'static, S::Error: 'static, B: MessageBody + 'static, Spawner: Spawn + Clone + 'static, { type Response = ServiceResponse; type Error = actix_web::Error; type Transform = VerifyMiddleware; type InitError = (); type Future = Ready>; fn new_transform(&self, service: S) -> Self::Future { ready(Ok(VerifyMiddleware( service, self.0.clone(), self.1, self.2.clone(), ))) } } impl Service for VerifyMiddleware where T: DigestVerify + Clone + Send + 'static, S: Service, Error = actix_web::Error> + 'static, S::Error: 'static, B: MessageBody + 'static, Spawner: Spawn + Clone + 'static, { type Response = ServiceResponse; type Error = actix_web::Error; type Future = LocalBoxFuture<'static, Result>; fn poll_ready(&self, cx: &mut Context) -> Poll> { self.0.poll_ready(cx) } fn call(&self, mut req: ServiceRequest) -> Self::Future { let span = tracing::info_span!( "Verify digest", digest.required = tracing::field::display(&self.2), ); if let Some(digest) = req.headers().get("Digest") { let vec = match parse_digest(digest) { Some(vec) => vec, None => { return Box::pin(ready(Err( VerifyError::new(&span, VerifyErrorKind::Empty).into() ))); } }; let payload = req.take_payload(); let spawner = self.1.clone(); let (tx, rx) = mpsc::channel(1); let (verify_tx, verify_rx) = oneshot::channel(); let f1 = span .in_scope(|| verify_payload(spawner, vec, self.3.clone(), payload, tx, verify_tx)); let payload: Pin> + 'static>> = Box::pin(RxStream(rx)); req.set_payload(payload.into()); req.extensions_mut().insert(VerifiedReceiver { rx: Some(verify_rx), }); let f2 = self.0.call(req); Box::pin(async move { let handle1 = actix_web::rt::spawn(f1); let handle2 = actix_web::rt::spawn(f2); handle1.await.expect("verify panic")?; handle2.await.expect("inner panic") }) } else if self.2 { Box::pin(ready(Err(VerifyError::new( &span, VerifyErrorKind::MissingDigest, ) .into()))) } else { Box::pin(self.0.call(req)) } } } #[tracing::instrument(name = "Verify Payload", skip(spawner, verify_digest, payload, tx))] async fn verify_payload( spawner: Spawner, vec: Vec, mut verify_digest: T, payload: Payload, tx: mpsc::Sender, verify_tx: oneshot::Sender<()>, ) -> Result<(), actix_web::Error> where T: DigestVerify + Clone + Send + 'static, Spawner: Spawn, { let mut payload = payload.into_streamer(); while let Some(bytes) = payload.try_next().await? { let bytes2 = bytes.clone(); verify_digest = spawner .spawn_blocking(move || { verify_digest.update(bytes2.as_ref()); Ok(verify_digest) as Result }) .await??; tx.send(bytes) .await .map_err(|_| VerifyError::new(&Span::current(), VerifyErrorKind::Dropped))?; } let verified = spawner .spawn_blocking(move || Ok(verify_digest.verify(&vec)) as Result<_, VerifyError>) .await??; if verified { if verify_tx.send(()).is_err() { debug!("handler dropped"); } Ok(()) } else { Err(VerifyError::new(&Span::current(), VerifyErrorKind::Verify).into()) } } fn parse_digest(h: &HeaderValue) -> Option> { let h = h.to_str().ok()?.split(';').next()?; let v: Vec<_> = h .split(',') .filter_map(|p| { let mut iter = p.splitn(2, '='); iter.next() .and_then(|alg| iter.next().map(|value| (alg, value))) }) .map(|(alg, value)| DigestPart { algorithm: alg.to_owned(), digest: value.to_owned(), }) .collect(); if v.is_empty() { None } else { Some(v) } } impl ResponseError for VerifyError { fn status_code(&self) -> StatusCode { StatusCode::BAD_REQUEST } fn error_response(&self) -> HttpResponse { HttpResponse::new(self.status_code()) } }