From be358db422055a7245495e887540735f38414196 Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Mon, 9 Apr 2018 14:20:12 -0700 Subject: [PATCH] CorsBuilder::finish() panics on any configuration error --- src/middleware/cors.rs | 110 +++++++++++++++++++++-------------------- 1 file changed, 56 insertions(+), 54 deletions(-) diff --git a/src/middleware/cors.rs b/src/middleware/cors.rs index 28c5c7898..65f39d7b4 100644 --- a/src/middleware/cors.rs +++ b/src/middleware/cors.rs @@ -34,7 +34,7 @@ //! .allowed_headers(vec![http::header::AUTHORIZATION, http::header::ACCEPT]) //! .allowed_header(http::header::CONTENT_TYPE) //! .max_age(3600) -//! .finish().expect("Can not create CORS middleware") +//! .finish() //! .register(r); // <- Register CORS middleware //! r.method(http::Method::GET).f(|_| HttpResponse::Ok()); //! r.method(http::Method::HEAD).f(|_| HttpResponse::MethodNotAllowed()); @@ -47,6 +47,7 @@ //! Cors middleware automatically handle *OPTIONS* preflight request. use std::collections::HashSet; use std::iter::FromIterator; +use std::rc::Rc; use http::{self, Method, HttpTryFrom, Uri, StatusCode}; use http::header::{self, HeaderName, HeaderValue}; @@ -91,19 +92,6 @@ pub enum CorsError { HeadersNotAllowed, } -/// A set of errors that can occur during building CORS middleware -#[derive(Debug, Fail)] -pub enum CorsBuilderError { - #[fail(display="Parse error: {}", _0)] - ParseError(http::Error), - /// Credentials are allowed, but the Origin is set to "*". This is not allowed by W3C - /// - /// This is a misconfiguration. Check the documentation for `Cors`. - #[fail(display="Credentials are allowed, but the Origin is set to \"*\"")] - CredentialsWithWildcardOrigin, -} - - impl ResponseError for CorsError { fn error_response(&self) -> HttpResponse { @@ -155,7 +143,12 @@ impl AllOrSome { /// /// The Cors struct contains the settings for CORS requests to be validated and /// for responses to be generated. +#[derive(Clone)] pub struct Cors { + inner: Rc, +} + +struct Inner { methods: HashSet, origins: AllOrSome>, origins_str: Option, @@ -170,7 +163,7 @@ pub struct Cors { impl Default for Cors { fn default() -> Cors { - Cors { + let inner = Inner { origins: AllOrSome::default(), origins_str: None, methods: HashSet::from_iter( @@ -184,14 +177,15 @@ impl Default for Cors { send_wildcard: false, supports_credentials: false, vary_header: true, - } + }; + Cors{inner: Rc::new(inner)} } } impl Cors { pub fn build() -> CorsBuilder { CorsBuilder { - cors: Some(Cors { + cors: Some(Inner { origins: AllOrSome::All, origins_str: None, methods: HashSet::new(), @@ -223,7 +217,7 @@ impl Cors { fn validate_origin(&self, req: &mut HttpRequest) -> Result<(), CorsError> { if let Some(hdr) = req.headers().get(header::ORIGIN) { if let Ok(origin) = hdr.to_str() { - return match self.origins { + return match self.inner.origins { AllOrSome::All => Ok(()), AllOrSome::Some(ref allowed_origins) => { allowed_origins @@ -235,7 +229,7 @@ impl Cors { } Err(CorsError::BadOrigin) } else { - return match self.origins { + return match self.inner.origins { AllOrSome::All => Ok(()), _ => Err(CorsError::MissingOrigin) } @@ -246,7 +240,7 @@ impl Cors { if let Some(hdr) = req.headers().get(header::ACCESS_CONTROL_REQUEST_METHOD) { if let Ok(meth) = hdr.to_str() { if let Ok(method) = Method::try_from(meth) { - return self.methods.get(&method) + return self.inner.methods.get(&method) .and_then(|_| Some(())) .ok_or_else(|| CorsError::MethodNotAllowed); } @@ -258,7 +252,7 @@ impl Cors { } fn validate_allowed_headers(&self, req: &mut HttpRequest) -> Result<(), CorsError> { - match self.headers { + match self.inner.headers { AllOrSome::All => Ok(()), AllOrSome::Some(ref allowed_headers) => { if let Some(hdr) = req.headers().get(header::ACCESS_CONTROL_REQUEST_HEADERS) { @@ -288,13 +282,13 @@ impl Cors { impl Middleware for Cors { fn start(&self, req: &mut HttpRequest) -> Result { - if self.preflight && Method::OPTIONS == *req.method() { + if self.inner.preflight && Method::OPTIONS == *req.method() { self.validate_origin(req)?; self.validate_allowed_method(req)?; self.validate_allowed_headers(req)?; // allowed headers - let headers = if let Some(headers) = self.headers.as_ref() { + let headers = if let Some(headers) = self.inner.headers.as_ref() { Some(HeaderValue::try_from(&headers.iter().fold( String::new(), |s, v| s + "," + v.as_str()).as_str()[1..]).unwrap()) } else if let Some(hdr) = req.headers().get(header::ACCESS_CONTROL_REQUEST_HEADERS) { @@ -305,13 +299,13 @@ impl Middleware for Cors { Ok(Started::Response( HttpResponse::Ok() - .if_some(self.max_age.as_ref(), |max_age, resp| { + .if_some(self.inner.max_age.as_ref(), |max_age, resp| { let _ = resp.header( header::ACCESS_CONTROL_MAX_AGE, format!("{}", max_age).as_str());}) .if_some(headers, |headers, resp| { let _ = resp.header(header::ACCESS_CONTROL_ALLOW_HEADERS, headers); }) - .if_true(self.origins.is_all(), |resp| { - if self.send_wildcard { + .if_true(self.inner.origins.is_all(), |resp| { + if self.inner.send_wildcard { resp.header(header::ACCESS_CONTROL_ALLOW_ORIGIN, "*"); } else { let origin = req.headers().get(header::ORIGIN).unwrap(); @@ -319,17 +313,17 @@ impl Middleware for Cors { header::ACCESS_CONTROL_ALLOW_ORIGIN, origin.clone()); } }) - .if_true(self.origins.is_some(), |resp| { + .if_true(self.inner.origins.is_some(), |resp| { resp.header( header::ACCESS_CONTROL_ALLOW_ORIGIN, - self.origins_str.as_ref().unwrap().clone()); + self.inner.origins_str.as_ref().unwrap().clone()); }) - .if_true(self.supports_credentials, |resp| { + .if_true(self.inner.supports_credentials, |resp| { resp.header(header::ACCESS_CONTROL_ALLOW_CREDENTIALS, "true"); }) .header( header::ACCESS_CONTROL_ALLOW_METHODS, - &self.methods.iter().fold( + &self.inner.methods.iter().fold( String::new(), |s, v| s + "," + v.as_str()).as_str()[1..]) .finish())) } else { @@ -340,9 +334,9 @@ impl Middleware for Cors { } fn response(&self, req: &mut HttpRequest, mut resp: HttpResponse) -> Result { - match self.origins { + match self.inner.origins { AllOrSome::All => { - if self.send_wildcard { + if self.inner.send_wildcard { resp.headers_mut().insert( header::ACCESS_CONTROL_ALLOW_ORIGIN, HeaderValue::from_static("*")); } else if let Some(origin) = req.headers().get(header::ORIGIN) { @@ -353,20 +347,20 @@ impl Middleware for Cors { AllOrSome::Some(_) => { resp.headers_mut().insert( header::ACCESS_CONTROL_ALLOW_ORIGIN, - self.origins_str.as_ref().unwrap().clone()); + self.inner.origins_str.as_ref().unwrap().clone()); } } - if let Some(ref expose) = self.expose_hdrs { + if let Some(ref expose) = self.inner.expose_hdrs { resp.headers_mut().insert( header::ACCESS_CONTROL_EXPOSE_HEADERS, HeaderValue::try_from(expose.as_str()).unwrap()); } - if self.supports_credentials { + if self.inner.supports_credentials { resp.headers_mut().insert( header::ACCESS_CONTROL_ALLOW_CREDENTIALS, HeaderValue::from_static("true")); } - if self.vary_header { + if self.inner.vary_header { let value = if let Some(hdr) = resp.headers_mut().get(header::VARY) { let mut val: Vec = Vec::with_capacity(hdr.as_bytes().len() + 8); val.extend(hdr.as_bytes()); @@ -404,17 +398,19 @@ impl Middleware for Cors { /// .allowed_headers(vec![header::AUTHORIZATION, header::ACCEPT]) /// .allowed_header(header::CONTENT_TYPE) /// .max_age(3600) -/// .finish().unwrap(); +/// .finish(); /// # } /// ``` pub struct CorsBuilder { - cors: Option, + cors: Option, methods: bool, error: Option, expose_hdrs: HashSet, } -fn cors<'a>(parts: &'a mut Option, err: &Option) -> Option<&'a mut Cors> { +fn cors<'a>(parts: &'a mut Option, err: &Option) + -> Option<&'a mut Inner> +{ if err.is_some() { return None } @@ -437,6 +433,8 @@ impl CorsBuilder { /// [Resource Processing Model](https://www.w3.org/TR/cors/#resource-processing-model). /// /// Defaults to `All`. + /// + /// Builder panics if supplied origin is not valid uri. pub fn allowed_origin(&mut self, origin: &str) -> &mut CorsBuilder { if let Some(cors) = cors(&mut self.cors, &self.error) { match Uri::try_from(origin) { @@ -602,6 +600,9 @@ impl CorsBuilder { /// and `send_wildcards` set to `true`. /// /// Defaults to `false`. + /// + /// Builder panics if credentials are allowed, but the Origin is set to "*". + /// This is not allowed by W3C pub fn supports_credentials(&mut self) -> &mut CorsBuilder { if let Some(cors) = cors(&mut self.cors, &self.error) { cors.supports_credentials = true @@ -641,7 +642,9 @@ impl CorsBuilder { } /// Finishes building and returns the built `Cors` instance. - pub fn finish(&mut self) -> Result { + /// + /// This method panics in case of any configuration error. + pub fn finish(&mut self) -> Cors { if !self.methods { self.allowed_methods(vec![Method::GET, Method::HEAD, Method::POST, Method::OPTIONS, Method::PUT, @@ -649,13 +652,13 @@ impl CorsBuilder { } if let Some(e) = self.error.take() { - return Err(CorsBuilderError::ParseError(e)) + panic!("{}", e); } let mut cors = self.cors.take().expect("cannot reuse CorsBuilder"); if cors.supports_credentials && cors.send_wildcard && cors.origins.is_all() { - return Err(CorsBuilderError::CredentialsWithWildcardOrigin) + panic!("Credentials are allowed, but the Origin is set to \"*\""); } if let AllOrSome::Some(ref origins) = cors.origins { @@ -668,7 +671,7 @@ impl CorsBuilder { self.expose_hdrs.iter().fold( String::new(), |s, v| s + v.as_str())[1..].to_owned()); } - Ok(cors) + Cors{inner: Rc::new(cors)} } } @@ -702,13 +705,12 @@ mod tests { } #[test] - #[should_panic(expected = "CredentialsWithWildcardOrigin")] + #[should_panic(expected = "Credentials are allowed, but the Origin is set to")] fn cors_validates_illegal_allow_credentials() { Cors::build() .supports_credentials() .send_wildcard() - .finish() - .unwrap(); + .finish(); } #[test] @@ -728,7 +730,7 @@ mod tests { .allowed_methods(vec![Method::GET, Method::OPTIONS, Method::POST]) .allowed_headers(vec![header::AUTHORIZATION, header::ACCEPT]) .allowed_header(header::CONTENT_TYPE) - .finish().unwrap(); + .finish(); let mut req = TestRequest::with_header( "Origin", "https://www.example.com") @@ -764,7 +766,7 @@ mod tests { // &b"POST,GET,OPTIONS"[..], // resp.headers().get(header::ACCESS_CONTROL_ALLOW_METHODS).unwrap().as_bytes()); - cors.preflight = false; + Rc::get_mut(&mut cors.inner).unwrap().preflight = false; assert!(cors.start(&mut req).unwrap().is_done()); } @@ -772,7 +774,7 @@ mod tests { #[should_panic(expected = "MissingOrigin")] fn test_validate_missing_origin() { let cors = Cors::build() - .allowed_origin("https://www.example.com").finish().unwrap(); + .allowed_origin("https://www.example.com").finish(); let mut req = HttpRequest::default(); cors.start(&mut req).unwrap(); @@ -782,7 +784,7 @@ mod tests { #[should_panic(expected = "OriginNotAllowed")] fn test_validate_not_allowed_origin() { let cors = Cors::build() - .allowed_origin("https://www.example.com").finish().unwrap(); + .allowed_origin("https://www.example.com").finish(); let mut req = TestRequest::with_header("Origin", "https://www.unknown.com") .method(Method::GET) @@ -793,7 +795,7 @@ mod tests { #[test] fn test_validate_origin() { let cors = Cors::build() - .allowed_origin("https://www.example.com").finish().unwrap(); + .allowed_origin("https://www.example.com").finish(); let mut req = TestRequest::with_header("Origin", "https://www.example.com") .method(Method::GET) @@ -804,7 +806,7 @@ mod tests { #[test] fn test_no_origin_response() { - let cors = Cors::build().finish().unwrap(); + let cors = Cors::build().finish(); let mut req = TestRequest::default().method(Method::GET).finish(); let resp: HttpResponse = HttpResponse::Ok().into(); @@ -830,7 +832,7 @@ mod tests { .allowed_methods(vec![Method::GET, Method::OPTIONS, Method::POST]) .allowed_headers(vec![header::AUTHORIZATION, header::ACCEPT]) .allowed_header(header::CONTENT_TYPE) - .finish().unwrap(); + .finish(); let mut req = TestRequest::with_header( "Origin", "https://www.example.com") @@ -857,7 +859,7 @@ mod tests { let cors = Cors::build() .disable_vary_header() .allowed_origin("https://www.example.com") - .finish().unwrap(); + .finish(); let resp: HttpResponse = HttpResponse::Ok().into(); let resp = cors.response(&mut req, resp).unwrap().response(); assert_eq!(