diff --git a/Cargo.lock b/Cargo.lock index 29110c9..aaf9f4f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -308,6 +308,7 @@ dependencies = [ "mime", "opentelemetry", "opentelemetry-otlp", + "pin-project-lite", "quanta", "rand", "rsa", diff --git a/Cargo.toml b/Cargo.toml index ac178a4..75a6552 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -46,6 +46,7 @@ metrics-util = "0.14.0" mime = "0.3.16" opentelemetry = { version = "0.18", features = ["rt-tokio"] } opentelemetry-otlp = "0.11" +pin-project-lite = "0.2.9" quanta = "0.10.1" rand = "0.8" rsa = "0.7" diff --git a/src/middleware/timings.rs b/src/middleware/timings.rs index 705e74e..a311a75 100644 --- a/src/middleware/timings.rs +++ b/src/middleware/timings.rs @@ -1,10 +1,10 @@ use actix_web::{ + body::MessageBody, dev::{Service, ServiceRequest, ServiceResponse, Transform}, http::StatusCode, }; -use futures_util::future::LocalBoxFuture; use std::{ - future::{ready, Ready}, + future::{ready, Future, Ready}, time::Instant, }; @@ -15,12 +15,30 @@ struct LogOnDrop { begin: Instant, path: String, method: String, - disarm: bool, + arm: bool, +} + +pin_project_lite::pin_project! { + pub(crate) struct TimingsFuture { + #[pin] + future: F, + + log_on_drop: Option, + } +} + +pin_project_lite::pin_project! { + pub(crate) struct TimingsBody { + #[pin] + body: B, + + log_on_drop: LogOnDrop, + } } impl Drop for LogOnDrop { fn drop(&mut self) { - if !self.disarm { + if self.arm { let duration = self.begin.elapsed(); metrics::histogram!("relay.request.complete", duration, "path" => self.path.clone(), "method" => self.method.clone()); } @@ -32,7 +50,7 @@ where S: Service, Error = actix_web::Error>, S::Future: 'static, { - type Response = S::Response; + type Response = ServiceResponse>; type Error = S::Error; type InitError = (); type Transform = TimingsMiddleware; @@ -48,9 +66,9 @@ where S: Service, Error = actix_web::Error>, S::Future: 'static, { - type Response = S::Response; + type Response = ServiceResponse>; type Error = S::Error; - type Future = LocalBoxFuture<'static, Result>; + type Future = TimingsFuture; fn poll_ready( &self, @@ -60,29 +78,66 @@ where } fn call(&self, req: ServiceRequest) -> Self::Future { - let mut logger = LogOnDrop { + let log_on_drop = LogOnDrop { begin: Instant::now(), path: req.path().to_string(), method: req.method().to_string(), - disarm: false, + arm: false, }; - let fut = self.0.call(req); - Box::pin(async move { - let res = fut.await; + let future = self.0.call(req); - let status = match &res { - Ok(res) => res.status(), - Err(e) => e.as_response_error().status_code(), - }; - if status == StatusCode::NOT_FOUND || status == StatusCode::METHOD_NOT_ALLOWED { - logger.disarm = true; - } - - // TODO: Drop after body write - drop(logger); - - res - }) + TimingsFuture { + future, + log_on_drop: Some(log_on_drop), + } + } +} + +impl Future for TimingsFuture +where + F: Future, actix_web::Error>>, +{ + type Output = Result>, actix_web::Error>; + + fn poll( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll { + let this = self.project(); + + let res = std::task::ready!(this.future.poll(cx)); + + let mut log_on_drop = this + .log_on_drop + .take() + .expect("TimingsFuture polled after completion"); + + let status = match &res { + Ok(res) => res.status(), + Err(e) => e.as_response_error().status_code(), + }; + + log_on_drop.arm = + status != StatusCode::NOT_FOUND && status != StatusCode::METHOD_NOT_ALLOWED; + + let res = res.map(|r| r.map_body(|_, body| TimingsBody { body, log_on_drop })); + + std::task::Poll::Ready(res) + } +} + +impl MessageBody for TimingsBody { + type Error = B::Error; + + fn size(&self) -> actix_web::body::BodySize { + self.body.size() + } + + fn poll_next( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll>> { + self.project().body.poll_next(cx) } }