//! For middleware documentation, see [`ErrorHandlers`]. use std::{ future::Future, pin::Pin, rc::Rc, task::{Context, Poll}, }; use actix_service::{Service, Transform}; use ahash::AHashMap; use futures_core::{future::LocalBoxFuture, ready}; use pin_project_lite::pin_project; use crate::{ dev::{ServiceRequest, ServiceResponse}, http::StatusCode, Error, Result, }; /// Return type for [`ErrorHandlers`] custom handlers. pub enum ErrorHandlerResponse { /// Immediate HTTP response. Response(ServiceResponse), /// A future that resolves to an HTTP response. Future(LocalBoxFuture<'static, Result, Error>>), } type ErrorHandler = dyn Fn(ServiceResponse) -> Result>; /// Middleware for registering custom status code based error handlers. /// /// Register handlers with the `ErrorHandlers::handler()` method to register a custom error handler /// for a given status code. Handlers can modify existing responses or create completely new ones. /// /// # Examples /// ``` /// use actix_web::middleware::{ErrorHandlers, ErrorHandlerResponse}; /// use actix_web::{web, http, dev, App, HttpRequest, HttpResponse, Result}; /// /// fn render_500(mut res: dev::ServiceResponse) -> Result> { /// res.response_mut() /// .headers_mut() /// .insert(http::header::CONTENT_TYPE, http::HeaderValue::from_static("Error")); /// Ok(ErrorHandlerResponse::Response(res)) /// } /// /// let app = App::new() /// .wrap( /// ErrorHandlers::new() /// .handler(http::StatusCode::INTERNAL_SERVER_ERROR, render_500), /// ) /// .service(web::resource("/test") /// .route(web::get().to(|| HttpResponse::Ok())) /// .route(web::head().to(|| HttpResponse::MethodNotAllowed()) /// )); /// ``` pub struct ErrorHandlers { handlers: Handlers, } type Handlers = Rc>>>; impl Default for ErrorHandlers { fn default() -> Self { ErrorHandlers { handlers: Rc::new(AHashMap::default()), } } } impl ErrorHandlers { /// Construct new `ErrorHandlers` instance. pub fn new() -> Self { ErrorHandlers::default() } /// Register error handler for specified status code. pub fn handler(mut self, status: StatusCode, handler: F) -> Self where F: Fn(ServiceResponse) -> Result> + 'static, { Rc::get_mut(&mut self.handlers) .unwrap() .insert(status, Box::new(handler)); self } } impl Transform for ErrorHandlers where S: Service, Error = Error> + 'static, S::Future: 'static, B: 'static, { type Response = ServiceResponse; type Error = Error; type Transform = ErrorHandlersMiddleware; type InitError = (); type Future = LocalBoxFuture<'static, Result>; fn new_transform(&self, service: S) -> Self::Future { let handlers = self.handlers.clone(); Box::pin(async move { Ok(ErrorHandlersMiddleware { service, handlers }) }) } } #[doc(hidden)] pub struct ErrorHandlersMiddleware { service: S, handlers: Handlers, } impl Service for ErrorHandlersMiddleware where S: Service, Error = Error>, S::Future: 'static, B: 'static, { type Response = ServiceResponse; type Error = Error; type Future = ErrorHandlersFuture; actix_service::forward_ready!(service); fn call(&self, req: ServiceRequest) -> Self::Future { let handlers = self.handlers.clone(); let fut = self.service.call(req); ErrorHandlersFuture::ServiceFuture { fut, handlers } } } pin_project! { #[project = ErrorHandlersProj] pub enum ErrorHandlersFuture where Fut: Future, { ServiceFuture { #[pin] fut: Fut, handlers: Handlers, }, HandlerFuture { fut: LocalBoxFuture<'static, Fut::Output>, }, } } impl Future for ErrorHandlersFuture where Fut: Future, Error>>, { type Output = Fut::Output; fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { match self.as_mut().project() { ErrorHandlersProj::ServiceFuture { fut, handlers } => { let res = ready!(fut.poll(cx))?; match handlers.get(&res.status()) { Some(handler) => match handler(res)? { ErrorHandlerResponse::Response(res) => Poll::Ready(Ok(res)), ErrorHandlerResponse::Future(fut) => { self.as_mut() .set(ErrorHandlersFuture::HandlerFuture { fut }); self.poll(cx) } }, None => Poll::Ready(Ok(res)), } } ErrorHandlersProj::HandlerFuture { fut } => fut.as_mut().poll(cx), } } } #[cfg(test)] mod tests { use actix_service::IntoService; use actix_utils::future::ok; use futures_util::future::FutureExt as _; use super::*; use crate::http::{header::CONTENT_TYPE, HeaderValue, StatusCode}; use crate::test::{self, TestRequest}; use crate::HttpResponse; #[allow(clippy::unnecessary_wraps)] fn render_500(mut res: ServiceResponse) -> Result> { res.response_mut() .headers_mut() .insert(CONTENT_TYPE, HeaderValue::from_static("0001")); Ok(ErrorHandlerResponse::Response(res)) } #[actix_rt::test] async fn test_handler() { let srv = |req: ServiceRequest| { ok(req.into_response(HttpResponse::InternalServerError().finish())) }; let mw = ErrorHandlers::new() .handler(StatusCode::INTERNAL_SERVER_ERROR, render_500) .new_transform(srv.into_service()) .await .unwrap(); let resp = test::call_service(&mw, TestRequest::default().to_srv_request()).await; assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), "0001"); } #[allow(clippy::unnecessary_wraps)] fn render_500_async( mut res: ServiceResponse, ) -> Result> { res.response_mut() .headers_mut() .insert(CONTENT_TYPE, HeaderValue::from_static("0001")); Ok(ErrorHandlerResponse::Future(ok(res).boxed_local())) } #[actix_rt::test] async fn test_handler_async() { let srv = |req: ServiceRequest| { ok(req.into_response(HttpResponse::InternalServerError().finish())) }; let mw = ErrorHandlers::new() .handler(StatusCode::INTERNAL_SERVER_ERROR, render_500_async) .new_transform(srv.into_service()) .await .unwrap(); let resp = test::call_service(&mw, TestRequest::default().to_srv_request()).await; assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), "0001"); } }