//! For middleware documentation, see [`Condition`]. use std::{ future::Future, pin::Pin, task::{Context, Poll}, }; use futures_core::{future::LocalBoxFuture, ready}; use futures_util::FutureExt as _; use pin_project_lite::pin_project; use crate::{ body::EitherBody, dev::{Service, ServiceResponse, Transform}, }; /// Middleware for conditionally enabling other middleware. /// /// # Examples /// ``` /// use actix_web::middleware::{Condition, NormalizePath}; /// use actix_web::App; /// /// let enable_normalize = std::env::var("NORMALIZE_PATH").is_ok(); /// let app = App::new() /// .wrap(Condition::new(enable_normalize, NormalizePath::default())); /// ``` pub struct Condition { transformer: T, enable: bool, } impl Condition { pub fn new(enable: bool, transformer: T) -> Self { Self { transformer, enable, } } } impl Transform for Condition where S: Service, Error = Err> + 'static, T: Transform, Error = Err>, T::Future: 'static, T::InitError: 'static, T::Transform: 'static, { type Response = ServiceResponse>; type Error = Err; type Transform = ConditionMiddleware; type InitError = T::InitError; type Future = LocalBoxFuture<'static, Result>; fn new_transform(&self, service: S) -> Self::Future { if self.enable { let fut = self.transformer.new_transform(service); async move { let wrapped_svc = fut.await?; Ok(ConditionMiddleware::Enable(wrapped_svc)) } .boxed_local() } else { async move { Ok(ConditionMiddleware::Disable(service)) }.boxed_local() } } } pub enum ConditionMiddleware { Enable(E), Disable(D), } impl Service for ConditionMiddleware where E: Service, Error = Err>, D: Service, Error = Err>, { type Response = ServiceResponse>; type Error = Err; type Future = ConditionMiddlewareFuture; fn poll_ready(&self, cx: &mut Context<'_>) -> Poll> { match self { ConditionMiddleware::Enable(service) => service.poll_ready(cx), ConditionMiddleware::Disable(service) => service.poll_ready(cx), } } fn call(&self, req: Req) -> Self::Future { match self { ConditionMiddleware::Enable(service) => ConditionMiddlewareFuture::Enabled { fut: service.call(req), }, ConditionMiddleware::Disable(service) => ConditionMiddlewareFuture::Disabled { fut: service.call(req), }, } } } pin_project! { #[doc(hidden)] #[project = ConditionProj] pub enum ConditionMiddlewareFuture { Enabled { #[pin] fut: E, }, Disabled { #[pin] fut: D, }, } } impl Future for ConditionMiddlewareFuture where E: Future, Err>>, D: Future, Err>>, { type Output = Result>, Err>; #[inline] fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let res = match self.project() { ConditionProj::Enabled { fut } => ready!(fut.poll(cx))?.map_into_left_body(), ConditionProj::Disabled { fut } => ready!(fut.poll(cx))?.map_into_right_body(), }; Poll::Ready(Ok(res)) } } #[cfg(test)] mod tests { use actix_service::IntoService as _; use super::*; use crate::{ body::BoxBody, dev::ServiceRequest, error::Result, http::{ header::{HeaderValue, CONTENT_TYPE}, StatusCode, }, middleware::{self, ErrorHandlerResponse, ErrorHandlers}, test::{self, TestRequest}, web::Bytes, HttpResponse, }; #[allow(clippy::unnecessary_wraps)] fn render_500(mut res: ServiceResponse) -> Result> { res.response_mut() .headers_mut() .insert(CONTENT_TYPE, HeaderValue::from_static("0001")); Ok(ErrorHandlerResponse::Response(res.map_into_left_body())) } #[test] fn compat_with_builtin_middleware() { let _ = Condition::new(true, middleware::Compat::noop()); let _ = Condition::new(true, middleware::Logger::default()); let _ = Condition::new(true, middleware::Compress::default()); let _ = Condition::new(true, middleware::NormalizePath::trim()); let _ = Condition::new(true, middleware::DefaultHeaders::new()); let _ = Condition::new(true, middleware::ErrorHandlers::::new()); let _ = Condition::new(true, middleware::ErrorHandlers::::new()); } #[actix_rt::test] async fn test_handler_enabled() { let srv = |req: ServiceRequest| async move { let resp = HttpResponse::InternalServerError().message_body(String::new())?; Ok(req.into_response(resp)) }; let mw = ErrorHandlers::new().handler(StatusCode::INTERNAL_SERVER_ERROR, render_500); let mw = Condition::new(true, mw) .new_transform(srv.into_service()) .await .unwrap(); let resp: ServiceResponse, String>> = test::call_service(&mw, TestRequest::default().to_srv_request()).await; assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), "0001"); } #[actix_rt::test] async fn test_handler_disabled() { let srv = |req: ServiceRequest| async move { let resp = HttpResponse::InternalServerError().message_body(String::new())?; Ok(req.into_response(resp)) }; let mw = ErrorHandlers::new().handler(StatusCode::INTERNAL_SERVER_ERROR, render_500); let mw = Condition::new(false, mw) .new_transform(srv.into_service()) .await .unwrap(); let resp: ServiceResponse, String>> = test::call_service(&mw, TestRequest::default().to_srv_request()).await; assert_eq!(resp.headers().get(CONTENT_TYPE), None); } }