diff --git a/src/middleware/cors.rs b/src/middleware/cors.rs index a7b0110f8..0db17cec4 100644 --- a/src/middleware/cors.rs +++ b/src/middleware/cors.rs @@ -19,16 +19,16 @@ //! //! ```rust //! # extern crate actix_web; -//! use actix_web::{http, App, HttpRequest, HttpResponse}; //! use actix_web::middleware::cors::Cors; +//! use actix_web::{http, App, HttpRequest, HttpResponse}; //! //! fn index(mut req: HttpRequest) -> &'static str { -//! "Hello world" +//! "Hello world" //! } //! //! fn main() { -//! let app = App::new() -//! .configure(|app| Cors::for_app(app) // <- Construct CORS middleware builder +//! let app = App::new().configure(|app| { +//! Cors::for_app(app) // <- Construct CORS middleware builder //! .allowed_origin("https://www.rust-lang.org/") //! .allowed_methods(vec!["GET", "POST"]) //! .allowed_headers(vec![http::header::AUTHORIZATION, http::header::ACCEPT]) @@ -38,7 +38,8 @@ //! r.method(http::Method::GET).f(|_| HttpResponse::Ok()); //! r.method(http::Method::HEAD).f(|_| HttpResponse::MethodNotAllowed()); //! }) -//! .register()); +//! .register() +//! }); //! } //! ``` //! In this example custom *CORS* middleware get registered for "/index.html" @@ -232,18 +233,20 @@ impl Cors { /// /// ```rust /// # extern crate actix_web; - /// use actix_web::{http, App, HttpResponse}; /// use actix_web::middleware::cors::Cors; + /// use actix_web::{http, App, HttpResponse}; /// /// fn main() { - /// let app = App::new() - /// .configure(|app| Cors::for_app(app) // <- Construct CORS builder + /// let app = App::new().configure( + /// |app| { + /// Cors::for_app(app) // <- Construct CORS builder /// .allowed_origin("https://www.rust-lang.org/") /// .resource("/resource", |r| { // register resource /// r.method(http::Method::GET).f(|_| HttpResponse::Ok()); /// }) - /// .register() // construct CORS and return application instance - /// ); + /// .register() + /// }, // construct CORS and return application instance + /// ); /// } /// ``` pub fn for_app(app: App) -> CorsBuilder { @@ -420,7 +423,10 @@ impl Middleware for Cors { .finish(), )) } else { - self.validate_origin(req)?; + // Only check requests with a origin header. + if req.headers().contains_key(header::ORIGIN) { + self.validate_origin(req)?; + } Ok(Started::Done) } @@ -491,8 +497,8 @@ impl Middleware for Cors { /// ```rust /// # extern crate http; /// # extern crate actix_web; -/// use http::header; /// use actix_web::middleware::cors; +/// use http::header; /// /// # fn main() { /// let cors = cors::Cors::build() @@ -764,12 +770,13 @@ impl CorsBuilder { /// /// ```rust /// # extern crate actix_web; - /// use actix_web::{http, App, HttpResponse}; /// use actix_web::middleware::cors::Cors; + /// use actix_web::{http, App, HttpResponse}; /// /// fn main() { - /// let app = App::new() - /// .configure(|app| Cors::for_app(app) // <- Construct CORS builder + /// let app = App::new().configure( + /// |app| { + /// Cors::for_app(app) // <- Construct CORS builder /// .allowed_origin("https://www.rust-lang.org/") /// .allowed_methods(vec!["GET", "POST"]) /// .allowed_header(http::header::CONTENT_TYPE) @@ -781,8 +788,9 @@ impl CorsBuilder { /// r.method(http::Method::HEAD) /// .f(|_| HttpResponse::MethodNotAllowed()); /// }) - /// .register() // construct CORS and return application instance - /// ); + /// .register() + /// }, // construct CORS and return application instance + /// ); /// } /// ``` pub fn resource(&mut self, path: &str, f: F) -> &mut CorsBuilder @@ -1001,16 +1009,15 @@ mod tests { assert!(cors.start(&mut req).unwrap().is_done()); } - #[test] - #[should_panic(expected = "MissingOrigin")] - fn test_validate_missing_origin() { - let cors = Cors::build() - .allowed_origin("https://www.example.com") - .finish(); - - let mut req = HttpRequest::default(); - cors.start(&mut req).unwrap(); - } + // #[test] + // #[should_panic(expected = "MissingOrigin")] + // fn test_validate_missing_origin() { + // let mut cors = Cors::build() + // .allowed_origin("https://www.example.com") + // .finish(); + // let mut req = HttpRequest::default(); + // cors.start(&mut req).unwrap(); + // } #[test] #[should_panic(expected = "OriginNotAllowed")] @@ -1127,10 +1134,19 @@ mod tests { }) }); - let request = srv.get().uri(srv.url("/test")).finish().unwrap(); + let request = srv + .get() + .uri(srv.url("/test")) + .header("ORIGIN", "https://www.example2.com") + .finish() + .unwrap(); let response = srv.execute(request.send()).unwrap(); assert_eq!(response.status(), StatusCode::BAD_REQUEST); + let request = srv.get().uri(srv.url("/test")).finish().unwrap(); + let response = srv.execute(request.send()).unwrap(); + assert_eq!(response.status(), StatusCode::OK); + let request = srv .get() .uri(srv.url("/test"))