//! For middleware documentation, see [`DefaultHeaders`]. use std::{ future::Future, marker::PhantomData, pin::Pin, rc::Rc, task::{Context, Poll}, }; use actix_http::error::HttpError; use actix_utils::future::{ready, Ready}; use futures_core::ready; use pin_project_lite::pin_project; use crate::{ dev::{Service, Transform}, http::header::{HeaderMap, HeaderName, HeaderValue, TryIntoHeaderPair, CONTENT_TYPE}, service::{ServiceRequest, ServiceResponse}, Error, }; /// Middleware for setting default response headers. /// /// Headers with the same key that are already set in a response will *not* be overwritten. /// /// # Examples /// ``` /// use actix_web::{web, http, middleware, App, HttpResponse}; /// /// let app = App::new() /// .wrap(middleware::DefaultHeaders::new().add(("X-Version", "0.2"))) /// .service( /// web::resource("/test") /// .route(web::get().to(|| HttpResponse::Ok())) /// .route(web::method(http::Method::HEAD).to(|| HttpResponse::MethodNotAllowed())) /// ); /// ``` #[derive(Debug, Clone, Default)] pub struct DefaultHeaders { inner: Rc, } #[derive(Debug, Default)] struct Inner { headers: HeaderMap, } impl DefaultHeaders { /// Constructs an empty `DefaultHeaders` middleware. #[inline] pub fn new() -> DefaultHeaders { DefaultHeaders::default() } /// Adds a header to the default set. /// /// # Panics /// Panics when resolved header name or value is invalid. #[allow(clippy::should_implement_trait)] pub fn add(mut self, header: impl TryIntoHeaderPair) -> Self { // standard header terminology `insert` or `append` for this method would make the behavior // of this middleware less obvious since it only adds the headers if they are not present match header.try_into_pair() { Ok((key, value)) => Rc::get_mut(&mut self.inner) .expect("All default headers must be added before cloning.") .headers .append(key, value), Err(err) => panic!("Invalid header: {}", err.into()), } self } #[doc(hidden)] #[deprecated( since = "4.0.0", note = "Prefer `.add((key, value))`. Will be removed in v5." )] pub fn header(self, key: K, value: V) -> Self where HeaderName: TryFrom, >::Error: Into, HeaderValue: TryFrom, >::Error: Into, { self.add(( HeaderName::try_from(key) .map_err(Into::into) .expect("Invalid header name"), HeaderValue::try_from(value) .map_err(Into::into) .expect("Invalid header value"), )) } /// Adds a default *Content-Type* header if response does not contain one. /// /// Default is `application/octet-stream`. pub fn add_content_type(self) -> Self { #[allow(clippy::declare_interior_mutable_const)] const HV_MIME: HeaderValue = HeaderValue::from_static("application/octet-stream"); self.add((CONTENT_TYPE, HV_MIME)) } } impl Transform for DefaultHeaders where S: Service, Error = Error>, S::Future: 'static, { type Response = ServiceResponse; type Error = Error; type Transform = DefaultHeadersMiddleware; type InitError = (); type Future = Ready>; fn new_transform(&self, service: S) -> Self::Future { ready(Ok(DefaultHeadersMiddleware { service, inner: Rc::clone(&self.inner), })) } } pub struct DefaultHeadersMiddleware { service: S, inner: Rc, } impl Service for DefaultHeadersMiddleware where S: Service, Error = Error>, S::Future: 'static, { type Response = ServiceResponse; type Error = Error; type Future = DefaultHeaderFuture; actix_service::forward_ready!(service); fn call(&self, req: ServiceRequest) -> Self::Future { let inner = self.inner.clone(); let fut = self.service.call(req); DefaultHeaderFuture { fut, inner, _body: PhantomData, } } } pin_project! { pub struct DefaultHeaderFuture, B> { #[pin] fut: S::Future, inner: Rc, _body: PhantomData, } } impl Future for DefaultHeaderFuture where S: Service, Error = Error>, { type Output = ::Output; #[allow(clippy::borrow_interior_mutable_const)] fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let this = self.project(); let mut res = ready!(this.fut.poll(cx))?; // set response headers for (key, value) in this.inner.headers.iter() { if !res.headers().contains_key(key) { res.headers_mut().insert(key.clone(), value.clone()); } } Poll::Ready(Ok(res)) } } #[cfg(test)] mod tests { use actix_service::IntoService; use actix_utils::future::ok; use super::*; use crate::{ test::{self, TestRequest}, HttpResponse, }; #[actix_rt::test] async fn adding_default_headers() { let mw = DefaultHeaders::new() .add(("X-TEST", "0001")) .add(("X-TEST-TWO", HeaderValue::from_static("123"))) .new_transform(test::ok_service()) .await .unwrap(); let req = TestRequest::default().to_srv_request(); let res = mw.call(req).await.unwrap(); assert_eq!(res.headers().get("x-test").unwrap(), "0001"); assert_eq!(res.headers().get("x-test-two").unwrap(), "123"); } #[actix_rt::test] async fn no_override_existing() { let req = TestRequest::default().to_srv_request(); let srv = |req: ServiceRequest| { ok(req.into_response( HttpResponse::Ok() .insert_header((CONTENT_TYPE, "0002")) .finish(), )) }; let mw = DefaultHeaders::new() .add((CONTENT_TYPE, "0001")) .new_transform(srv.into_service()) .await .unwrap(); let resp = mw.call(req).await.unwrap(); assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), "0002"); } #[actix_rt::test] async fn adding_content_type() { let mw = DefaultHeaders::new() .add_content_type() .new_transform(test::ok_service()) .await .unwrap(); let req = TestRequest::default().to_srv_request(); let resp = mw.call(req).await.unwrap(); assert_eq!( resp.headers().get(CONTENT_TYPE).unwrap(), "application/octet-stream" ); } #[test] #[should_panic] fn invalid_header_name() { DefaultHeaders::new().add((":", "hello")); } #[test] #[should_panic] fn invalid_header_value() { DefaultHeaders::new().add(("x-test", "\n")); } }