1
0
Fork 0
mirror of https://github.com/actix/actix-web.git synced 2024-06-02 13:29:24 +00:00
actix-web/src/middleware/errhandlers.rs

207 lines
6.3 KiB
Rust
Raw Normal View History

2019-03-24 18:49:26 +00:00
//! Custom handlers service for responses.
use std::rc::Rc;
2019-11-21 08:52:33 +00:00
use std::task::{Context, Poll};
use actix_service::{Service, Transform};
2020-05-18 02:47:20 +00:00
use futures_util::future::{ok, FutureExt, LocalBoxFuture, Ready};
2019-12-04 12:32:18 +00:00
use fxhash::FxHashMap;
use crate::dev::{ServiceRequest, ServiceResponse};
use crate::error::{Error, Result};
use crate::http::StatusCode;
/// Error handler response
pub enum ErrorHandlerResponse<B> {
/// New http response got generated
Response(ServiceResponse<B>),
/// Result is a future that resolves to a new http response
2019-11-21 08:52:33 +00:00
Future(LocalBoxFuture<'static, Result<ServiceResponse<B>, Error>>),
}
2019-07-17 09:48:37 +00:00
type ErrorHandler<B> = dyn Fn(ServiceResponse<B>) -> Result<ErrorHandlerResponse<B>>;
/// `Middleware` for allowing custom handlers for responses.
///
/// You can use `ErrorHandlers::handler()` method to register a custom error
/// handler for specific status code. You can modify existing response or
/// create completely new one.
///
/// ## Example
///
/// ```rust
2019-03-24 18:47:23 +00:00
/// use actix_web::middleware::errhandlers::{ErrorHandlers, ErrorHandlerResponse};
/// use actix_web::{web, http, dev, App, HttpRequest, HttpResponse, Result};
///
/// fn render_500<B>(mut res: dev::ServiceResponse<B>) -> Result<ErrorHandlerResponse<B>> {
/// res.response_mut()
/// .headers_mut()
/// .insert(http::header::CONTENT_TYPE, http::HeaderValue::from_static("Error"));
/// Ok(ErrorHandlerResponse::Response(res))
/// }
///
2019-11-21 08:52:33 +00:00
/// # fn main() {
/// let app = App::new()
/// .wrap(
/// ErrorHandlers::new()
/// .handler(http::StatusCode::INTERNAL_SERVER_ERROR, render_500),
/// )
/// .service(web::resource("/test")
/// .route(web::get().to(|| HttpResponse::Ok()))
/// .route(web::head().to(|| HttpResponse::MethodNotAllowed())
/// ));
/// # }
/// ```
pub struct ErrorHandlers<B> {
2019-12-04 12:32:18 +00:00
handlers: Rc<FxHashMap<StatusCode, Box<ErrorHandler<B>>>>,
}
impl<B> Default for ErrorHandlers<B> {
fn default() -> Self {
ErrorHandlers {
2019-12-04 12:32:18 +00:00
handlers: Rc::new(FxHashMap::default()),
}
}
}
impl<B> ErrorHandlers<B> {
/// Construct new `ErrorHandlers` instance
pub fn new() -> Self {
ErrorHandlers::default()
}
/// Register error handler for specified status code
pub fn handler<F>(mut self, status: StatusCode, handler: F) -> Self
where
F: Fn(ServiceResponse<B>) -> Result<ErrorHandlerResponse<B>> + 'static,
{
Rc::get_mut(&mut self.handlers)
.unwrap()
.insert(status, Box::new(handler));
self
}
}
impl<S, B> Transform<S, ServiceRequest> for ErrorHandlers<B>
where
S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
S::Future: 'static,
B: 'static,
{
type Response = ServiceResponse<B>;
type Error = Error;
type Transform = ErrorHandlersMiddleware<S, B>;
type InitError = ();
2019-11-21 08:52:33 +00:00
type Future = Ready<Result<Self::Transform, Self::InitError>>;
fn new_transform(&self, service: S) -> Self::Future {
ok(ErrorHandlersMiddleware {
service,
handlers: self.handlers.clone(),
})
}
}
2019-03-24 18:32:30 +00:00
#[doc(hidden)]
pub struct ErrorHandlersMiddleware<S, B> {
service: S,
2019-12-04 12:32:18 +00:00
handlers: Rc<FxHashMap<StatusCode, Box<ErrorHandler<B>>>>,
}
impl<S, B> Service<ServiceRequest> for ErrorHandlersMiddleware<S, B>
where
S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
S::Future: 'static,
B: 'static,
{
type Response = ServiceResponse<B>;
type Error = Error;
2019-11-21 08:52:33 +00:00
type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;
2019-12-07 18:46:51 +00:00
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
2019-11-21 08:52:33 +00:00
self.service.poll_ready(cx)
}
fn call(&mut self, req: ServiceRequest) -> Self::Future {
let handlers = self.handlers.clone();
2019-11-21 08:52:33 +00:00
let fut = self.service.call(req);
async move {
let res = fut.await?;
if let Some(handler) = handlers.get(&res.status()) {
match handler(res) {
2019-11-21 08:52:33 +00:00
Ok(ErrorHandlerResponse::Response(res)) => Ok(res),
Ok(ErrorHandlerResponse::Future(fut)) => fut.await,
Err(e) => Err(e),
}
} else {
2019-11-21 08:52:33 +00:00
Ok(res)
}
2019-11-21 08:52:33 +00:00
}
.boxed_local()
}
}
#[cfg(test)]
mod tests {
    use actix_service::IntoService;
    use futures_util::future::ok;

    use super::*;
    use crate::http::{header::CONTENT_TYPE, HeaderValue, StatusCode};
    use crate::test::{self, TestRequest};
    use crate::HttpResponse;

    /// Synchronous handler: tags the response with a marker `Content-Type`.
    #[allow(clippy::unnecessary_wraps)]
    fn render_500<B>(mut res: ServiceResponse<B>) -> Result<ErrorHandlerResponse<B>> {
        res.response_mut()
            .headers_mut()
            .insert(CONTENT_TYPE, HeaderValue::from_static("0001"));
        Ok(ErrorHandlerResponse::Response(res))
    }

    #[actix_rt::test]
    async fn test_handler() {
        let srv = |req: ServiceRequest| {
            ok(req.into_response(HttpResponse::InternalServerError().finish()))
        };

        let mut mw = ErrorHandlers::new()
            .handler(StatusCode::INTERNAL_SERVER_ERROR, render_500)
            .new_transform(srv.into_service())
            .await
            .unwrap();

        let resp =
            test::call_service(&mut mw, TestRequest::default().to_srv_request()).await;
        assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), "0001");
    }

    /// Deferred handler: same marker as `render_500`, but returned through a future.
    #[allow(clippy::unnecessary_wraps)]
    fn render_500_async<B: 'static>(
        mut res: ServiceResponse<B>,
    ) -> Result<ErrorHandlerResponse<B>> {
        res.response_mut()
            .headers_mut()
            .insert(CONTENT_TYPE, HeaderValue::from_static("0001"));
        Ok(ErrorHandlerResponse::Future(ok(res).boxed_local()))
    }

    #[actix_rt::test]
    async fn test_handler_async() {
        let srv = |req: ServiceRequest| {
            ok(req.into_response(HttpResponse::InternalServerError().finish()))
        };

        let mut mw = ErrorHandlers::new()
            .handler(StatusCode::INTERNAL_SERVER_ERROR, render_500_async)
            .new_transform(srv.into_service())
            .await
            .unwrap();

        let resp =
            test::call_service(&mut mw, TestRequest::default().to_srv_request()).await;
        assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), "0001");
    }

    /// A response whose status has no registered handler must pass through unchanged.
    #[actix_rt::test]
    async fn test_handler_not_triggered_for_other_status() {
        let srv = |req: ServiceRequest| ok(req.into_response(HttpResponse::Ok().finish()));

        let mut mw = ErrorHandlers::new()
            .handler(StatusCode::INTERNAL_SERVER_ERROR, render_500)
            .new_transform(srv.into_service())
            .await
            .unwrap();

        let resp =
            test::call_service(&mut mw, TestRequest::default().to_srv_request()).await;
        assert_eq!(resp.status(), StatusCode::OK);
        assert_ne!(
            resp.headers().get(CONTENT_TYPE),
            Some(&HeaderValue::from_static("0001"))
        );
    }

    /// Handler that always fails; used to check that handler errors reach the caller.
    fn render_err<B>(_res: ServiceResponse<B>) -> Result<ErrorHandlerResponse<B>> {
        Err(crate::error::ErrorInternalServerError("handler failed"))
    }

    #[actix_rt::test]
    async fn test_handler_error_propagates() {
        let srv = |req: ServiceRequest| {
            ok(req.into_response(HttpResponse::InternalServerError().finish()))
        };

        let mut mw = ErrorHandlers::new()
            .handler(StatusCode::INTERNAL_SERVER_ERROR, render_err)
            .new_transform(srv.into_service())
            .await
            .unwrap();

        let res = mw.call(TestRequest::default().to_srv_request()).await;
        assert!(res.is_err());
    }
}