From 6ac4ac66b96aa73f694d8de7fde2e0a2cffa1af8 Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Thu, 21 Nov 2019 10:54:07 +0600 Subject: [PATCH] migrate actix-cors --- Cargo.toml | 2 +- actix-cors/Cargo.toml | 4 +- actix-cors/src/lib.rs | 722 ++++++++++++++++++++++-------------------- 3 files changed, 383 insertions(+), 345 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index b80cf3e6e..6827a6196 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -31,7 +31,7 @@ members = [ ".", "awc", "actix-http", - #"actix-cors", + "actix-cors", #"actix-files", #"actix-framed", #"actix-session", diff --git a/actix-cors/Cargo.toml b/actix-cors/Cargo.toml index 56b6fabd9..57aa5833a 100644 --- a/actix-cors/Cargo.toml +++ b/actix-cors/Cargo.toml @@ -17,7 +17,7 @@ name = "actix_cors" path = "src/lib.rs" [dependencies] -actix-web = "1.0.9" -actix-service = "0.4.0" +actix-web = "2.0.0-alpha.1" +actix-service = "1.0.0-alpha.1" derive_more = "0.15.0" futures = "0.3.1" diff --git a/actix-cors/src/lib.rs b/actix-cors/src/lib.rs index c76bae925..40f9fdf99 100644 --- a/actix-cors/src/lib.rs +++ b/actix-cors/src/lib.rs @@ -23,7 +23,8 @@ //! .allowed_methods(vec!["GET", "POST"]) //! .allowed_headers(vec![http::header::AUTHORIZATION, http::header::ACCEPT]) //! .allowed_header(http::header::CONTENT_TYPE) -//! .max_age(3600)) +//! .max_age(3600) +//! .finish()) //! .service( //! web::resource("/index.html") //! .route(web::get().to(index)) @@ -41,16 +42,16 @@ use std::collections::HashSet; use std::iter::FromIterator; use std::rc::Rc; +use std::task::{Context, Poll}; -use actix_service::{IntoTransform, Service, Transform}; +use actix_service::{Service, Transform}; use actix_web::dev::{RequestHead, ServiceRequest, ServiceResponse}; use actix_web::error::{Error, ResponseError, Result}; use actix_web::http::header::{self, HeaderName, HeaderValue}; use actix_web::http::{self, HttpTryFrom, Method, StatusCode, Uri}; use actix_web::HttpResponse; use derive_more::Display; -use futures::future::{ok, Either, Future, FutureResult}; -use futures::Poll; +use futures::future::{ok, Either, FutureExt, LocalBoxFuture, Ready}; /// A set of errors that can occur during processing CORS #[derive(Debug, Display)] @@ -456,25 +457,9 @@ impl Cors { } self } -} -fn cors<'a>( - parts: &'a mut Option, - err: &Option, -) -> Option<&'a mut Inner> { - if err.is_some() { - return None; - } - parts.as_mut() -} - -impl IntoTransform for Cors -where - S: Service, Error = Error>, - S::Future: 'static, - B: 'static, -{ - fn into_transform(self) -> CorsFactory { + /// Construct cors middleware + pub fn finish(self) -> CorsFactory { let mut slf = if !self.methods { self.allowed_methods(vec![ Method::GET, @@ -521,6 +506,16 @@ where } } +fn cors<'a>( + parts: &'a mut Option, + err: &Option, +) -> Option<&'a mut Inner> { + if err.is_some() { + return None; + } + parts.as_mut() +} + /// `Middleware` for Cross-origin resource sharing support /// /// The Cors struct contains the settings for CORS requests to be validated and @@ -540,7 +535,7 @@ where type Error = Error; type InitError = (); type Transform = CorsMiddleware; - type Future = FutureResult; + type Future = Ready>; fn new_transform(&self, service: S) -> Self::Future { ok(CorsMiddleware { @@ -682,12 +677,12 @@ where type Response = ServiceResponse; type Error = Error; type Future = Either< - FutureResult, - Either>>, + Ready>, + LocalBoxFuture<'static, Result>, >; - fn poll_ready(&mut self) -> Poll<(), Self::Error> { - self.service.poll_ready() + fn poll_ready(&mut self, cx: &mut Context) -> Poll> { + self.service.poll_ready(cx) } fn call(&mut self, req: ServiceRequest) -> Self::Future { @@ -698,7 +693,7 @@ where .and_then(|_| self.inner.validate_allowed_method(req.head())) .and_then(|_| self.inner.validate_allowed_headers(req.head())) { - return Either::A(ok(req.error_response(e))); + return Either::Left(ok(req.error_response(e))); } // allowed headers @@ -751,39 +746,50 @@ where .finish() .into_body(); - Either::A(ok(req.into_response(res))) - } else if req.headers().contains_key(&header::ORIGIN) { - // Only check requests with a origin header. - if let Err(e) = self.inner.validate_origin(req.head()) { - return Either::A(ok(req.error_response(e))); + Either::Left(ok(req.into_response(res))) + } else { + if req.headers().contains_key(&header::ORIGIN) { + // Only check requests with a origin header. + if let Err(e) = self.inner.validate_origin(req.head()) { + return Either::Left(ok(req.error_response(e))); + } } let inner = self.inner.clone(); + let has_origin = req.headers().contains_key(&header::ORIGIN); + let fut = self.service.call(req); - Either::B(Either::B(Box::new(self.service.call(req).and_then( - move |mut res| { - if let Some(origin) = - inner.access_control_allow_origin(res.request().head()) - { - res.headers_mut() - .insert(header::ACCESS_CONTROL_ALLOW_ORIGIN, origin.clone()); - }; + Either::Right( + async move { + let res = fut.await; - if let Some(ref expose) = inner.expose_hdrs { - res.headers_mut().insert( - header::ACCESS_CONTROL_EXPOSE_HEADERS, - HeaderValue::try_from(expose.as_str()).unwrap(), - ); - } - if inner.supports_credentials { - res.headers_mut().insert( - header::ACCESS_CONTROL_ALLOW_CREDENTIALS, - HeaderValue::from_static("true"), - ); - } - if inner.vary_header { - let value = - if let Some(hdr) = res.headers_mut().get(&header::VARY) { + if has_origin { + let mut res = res?; + if let Some(origin) = + inner.access_control_allow_origin(res.request().head()) + { + res.headers_mut().insert( + header::ACCESS_CONTROL_ALLOW_ORIGIN, + origin.clone(), + ); + }; + + if let Some(ref expose) = inner.expose_hdrs { + res.headers_mut().insert( + header::ACCESS_CONTROL_EXPOSE_HEADERS, + HeaderValue::try_from(expose.as_str()).unwrap(), + ); + } + if inner.supports_credentials { + res.headers_mut().insert( + header::ACCESS_CONTROL_ALLOW_CREDENTIALS, + HeaderValue::from_static("true"), + ); + } + if inner.vary_header { + let value = if let Some(hdr) = + res.headers_mut().get(&header::VARY) + { let mut val: Vec = Vec::with_capacity(hdr.as_bytes().len() + 8); val.extend(hdr.as_bytes()); @@ -792,159 +798,153 @@ where } else { HeaderValue::from_static("Origin") }; - res.headers_mut().insert(header::VARY, value); + res.headers_mut().insert(header::VARY, value); + } + Ok(res) + } else { + res } - Ok(res) - }, - )))) - } else { - Either::B(Either::A(self.service.call(req))) + } + .boxed_local(), + ) } } } #[cfg(test)] mod tests { - use actix_service::{IntoService, Transform}; + use actix_service::{service_fn2, Transform}; use actix_web::test::{self, block_on, TestRequest}; use super::*; - impl Cors { - fn finish(self, srv: F) -> CorsMiddleware - where - F: IntoService, - S: Service< - Request = ServiceRequest, - Response = ServiceResponse, - Error = Error, - > + 'static, - S::Future: 'static, - B: 'static, - { - block_on( - IntoTransform::::into_transform(self) - .new_transform(srv.into_service()), - ) - .unwrap() - } - } - #[test] #[should_panic(expected = "Credentials are allowed, but the Origin is set to")] fn cors_validates_illegal_allow_credentials() { - let _cors = Cors::new() - .supports_credentials() - .send_wildcard() - .finish(test::ok_service()); + let _cors = Cors::new().supports_credentials().send_wildcard().finish(); } #[test] fn validate_origin_allows_all_origins() { - let mut cors = Cors::new().finish(test::ok_service()); - let req = TestRequest::with_header("Origin", "https://www.example.com") - .to_srv_request(); + block_on(async { + let mut cors = Cors::new() + .finish() + .new_transform(test::ok_service()) + .await + .unwrap(); + let req = TestRequest::with_header("Origin", "https://www.example.com") + .to_srv_request(); - let resp = test::call_service(&mut cors, req); - assert_eq!(resp.status(), StatusCode::OK); + let resp = test::call_service(&mut cors, req).await; + assert_eq!(resp.status(), StatusCode::OK); + }) } #[test] fn default() { - let mut cors = - block_on(Cors::default().new_transform(test::ok_service())).unwrap(); - let req = TestRequest::with_header("Origin", "https://www.example.com") - .to_srv_request(); + block_on(async { + let mut cors = Cors::default() + .new_transform(test::ok_service()) + .await + .unwrap(); + let req = TestRequest::with_header("Origin", "https://www.example.com") + .to_srv_request(); - let resp = test::call_service(&mut cors, req); - assert_eq!(resp.status(), StatusCode::OK); + let resp = test::call_service(&mut cors, req).await; + assert_eq!(resp.status(), StatusCode::OK); + }) } #[test] fn test_preflight() { - let mut cors = Cors::new() - .send_wildcard() - .max_age(3600) - .allowed_methods(vec![Method::GET, Method::OPTIONS, Method::POST]) - .allowed_headers(vec![header::AUTHORIZATION, header::ACCEPT]) - .allowed_header(header::CONTENT_TYPE) - .finish(test::ok_service()); + block_on(async { + let mut cors = Cors::new() + .send_wildcard() + .max_age(3600) + .allowed_methods(vec![Method::GET, Method::OPTIONS, Method::POST]) + .allowed_headers(vec![header::AUTHORIZATION, header::ACCEPT]) + .allowed_header(header::CONTENT_TYPE) + .finish() + .new_transform(test::ok_service()) + .await + .unwrap(); - let req = TestRequest::with_header("Origin", "https://www.example.com") - .method(Method::OPTIONS) - .header(header::ACCESS_CONTROL_REQUEST_HEADERS, "X-Not-Allowed") - .to_srv_request(); + let req = TestRequest::with_header("Origin", "https://www.example.com") + .method(Method::OPTIONS) + .header(header::ACCESS_CONTROL_REQUEST_HEADERS, "X-Not-Allowed") + .to_srv_request(); - assert!(cors.inner.validate_allowed_method(req.head()).is_err()); - assert!(cors.inner.validate_allowed_headers(req.head()).is_err()); - let resp = test::call_service(&mut cors, req); - assert_eq!(resp.status(), StatusCode::BAD_REQUEST); + assert!(cors.inner.validate_allowed_method(req.head()).is_err()); + assert!(cors.inner.validate_allowed_headers(req.head()).is_err()); + let resp = test::call_service(&mut cors, req).await; + assert_eq!(resp.status(), StatusCode::BAD_REQUEST); - let req = TestRequest::with_header("Origin", "https://www.example.com") - .header(header::ACCESS_CONTROL_REQUEST_METHOD, "put") - .method(Method::OPTIONS) - .to_srv_request(); + let req = TestRequest::with_header("Origin", "https://www.example.com") + .header(header::ACCESS_CONTROL_REQUEST_METHOD, "put") + .method(Method::OPTIONS) + .to_srv_request(); - assert!(cors.inner.validate_allowed_method(req.head()).is_err()); - assert!(cors.inner.validate_allowed_headers(req.head()).is_ok()); + assert!(cors.inner.validate_allowed_method(req.head()).is_err()); + assert!(cors.inner.validate_allowed_headers(req.head()).is_ok()); - let req = TestRequest::with_header("Origin", "https://www.example.com") - .header(header::ACCESS_CONTROL_REQUEST_METHOD, "POST") - .header( - header::ACCESS_CONTROL_REQUEST_HEADERS, - "AUTHORIZATION,ACCEPT", - ) - .method(Method::OPTIONS) - .to_srv_request(); + let req = TestRequest::with_header("Origin", "https://www.example.com") + .header(header::ACCESS_CONTROL_REQUEST_METHOD, "POST") + .header( + header::ACCESS_CONTROL_REQUEST_HEADERS, + "AUTHORIZATION,ACCEPT", + ) + .method(Method::OPTIONS) + .to_srv_request(); - let resp = test::call_service(&mut cors, req); - assert_eq!( - &b"*"[..], - resp.headers() - .get(&header::ACCESS_CONTROL_ALLOW_ORIGIN) + let resp = test::call_service(&mut cors, req).await; + assert_eq!( + &b"*"[..], + resp.headers() + .get(&header::ACCESS_CONTROL_ALLOW_ORIGIN) + .unwrap() + .as_bytes() + ); + assert_eq!( + &b"3600"[..], + resp.headers() + .get(&header::ACCESS_CONTROL_MAX_AGE) + .unwrap() + .as_bytes() + ); + let hdr = resp + .headers() + .get(&header::ACCESS_CONTROL_ALLOW_HEADERS) .unwrap() - .as_bytes() - ); - assert_eq!( - &b"3600"[..], - resp.headers() - .get(&header::ACCESS_CONTROL_MAX_AGE) + .to_str() + .unwrap(); + assert!(hdr.contains("authorization")); + assert!(hdr.contains("accept")); + assert!(hdr.contains("content-type")); + + let methods = resp + .headers() + .get(header::ACCESS_CONTROL_ALLOW_METHODS) .unwrap() - .as_bytes() - ); - let hdr = resp - .headers() - .get(&header::ACCESS_CONTROL_ALLOW_HEADERS) - .unwrap() - .to_str() - .unwrap(); - assert!(hdr.contains("authorization")); - assert!(hdr.contains("accept")); - assert!(hdr.contains("content-type")); + .to_str() + .unwrap(); + assert!(methods.contains("POST")); + assert!(methods.contains("GET")); + assert!(methods.contains("OPTIONS")); - let methods = resp - .headers() - .get(header::ACCESS_CONTROL_ALLOW_METHODS) - .unwrap() - .to_str() - .unwrap(); - assert!(methods.contains("POST")); - assert!(methods.contains("GET")); - assert!(methods.contains("OPTIONS")); + Rc::get_mut(&mut cors.inner).unwrap().preflight = false; - Rc::get_mut(&mut cors.inner).unwrap().preflight = false; + let req = TestRequest::with_header("Origin", "https://www.example.com") + .header(header::ACCESS_CONTROL_REQUEST_METHOD, "POST") + .header( + header::ACCESS_CONTROL_REQUEST_HEADERS, + "AUTHORIZATION,ACCEPT", + ) + .method(Method::OPTIONS) + .to_srv_request(); - let req = TestRequest::with_header("Origin", "https://www.example.com") - .header(header::ACCESS_CONTROL_REQUEST_METHOD, "POST") - .header( - header::ACCESS_CONTROL_REQUEST_HEADERS, - "AUTHORIZATION,ACCEPT", - ) - .method(Method::OPTIONS) - .to_srv_request(); - - let resp = test::call_service(&mut cors, req); - assert_eq!(resp.status(), StatusCode::OK); + let resp = test::call_service(&mut cors, req).await; + assert_eq!(resp.status(), StatusCode::OK); + }) } // #[test] @@ -960,216 +960,254 @@ mod tests { #[test] #[should_panic(expected = "OriginNotAllowed")] fn test_validate_not_allowed_origin() { - let cors = Cors::new() - .allowed_origin("https://www.example.com") - .finish(test::ok_service()); + block_on(async { + let cors = Cors::new() + .allowed_origin("https://www.example.com") + .finish() + .new_transform(test::ok_service()) + .await + .unwrap(); - let req = TestRequest::with_header("Origin", "https://www.unknown.com") - .method(Method::GET) - .to_srv_request(); - cors.inner.validate_origin(req.head()).unwrap(); - cors.inner.validate_allowed_method(req.head()).unwrap(); - cors.inner.validate_allowed_headers(req.head()).unwrap(); + let req = TestRequest::with_header("Origin", "https://www.unknown.com") + .method(Method::GET) + .to_srv_request(); + cors.inner.validate_origin(req.head()).unwrap(); + cors.inner.validate_allowed_method(req.head()).unwrap(); + cors.inner.validate_allowed_headers(req.head()).unwrap(); + }) } #[test] fn test_validate_origin() { - let mut cors = Cors::new() - .allowed_origin("https://www.example.com") - .finish(test::ok_service()); + block_on(async { + let mut cors = Cors::new() + .allowed_origin("https://www.example.com") + .finish() + .new_transform(test::ok_service()) + .await + .unwrap(); - let req = TestRequest::with_header("Origin", "https://www.example.com") - .method(Method::GET) - .to_srv_request(); + let req = TestRequest::with_header("Origin", "https://www.example.com") + .method(Method::GET) + .to_srv_request(); - let resp = test::call_service(&mut cors, req); - assert_eq!(resp.status(), StatusCode::OK); + let resp = test::call_service(&mut cors, req).await; + assert_eq!(resp.status(), StatusCode::OK); + }) } #[test] fn test_no_origin_response() { - let mut cors = Cors::new().disable_preflight().finish(test::ok_service()); + block_on(async { + let mut cors = Cors::new() + .disable_preflight() + .finish() + .new_transform(test::ok_service()) + .await + .unwrap(); - let req = TestRequest::default().method(Method::GET).to_srv_request(); - let resp = test::call_service(&mut cors, req); - assert!(resp - .headers() - .get(header::ACCESS_CONTROL_ALLOW_ORIGIN) - .is_none()); - - let req = TestRequest::with_header("Origin", "https://www.example.com") - .method(Method::OPTIONS) - .to_srv_request(); - let resp = test::call_service(&mut cors, req); - assert_eq!( - &b"https://www.example.com"[..], - resp.headers() + let req = TestRequest::default().method(Method::GET).to_srv_request(); + let resp = test::call_service(&mut cors, req).await; + assert!(resp + .headers() .get(header::ACCESS_CONTROL_ALLOW_ORIGIN) - .unwrap() - .as_bytes() - ); + .is_none()); + + let req = TestRequest::with_header("Origin", "https://www.example.com") + .method(Method::OPTIONS) + .to_srv_request(); + let resp = test::call_service(&mut cors, req).await; + assert_eq!( + &b"https://www.example.com"[..], + resp.headers() + .get(header::ACCESS_CONTROL_ALLOW_ORIGIN) + .unwrap() + .as_bytes() + ); + }) } #[test] fn test_response() { - let exposed_headers = vec![header::AUTHORIZATION, header::ACCEPT]; - let mut cors = Cors::new() - .send_wildcard() - .disable_preflight() - .max_age(3600) - .allowed_methods(vec![Method::GET, Method::OPTIONS, Method::POST]) - .allowed_headers(exposed_headers.clone()) - .expose_headers(exposed_headers.clone()) - .allowed_header(header::CONTENT_TYPE) - .finish(test::ok_service()); + block_on(async { + let exposed_headers = vec![header::AUTHORIZATION, header::ACCEPT]; + let mut cors = Cors::new() + .send_wildcard() + .disable_preflight() + .max_age(3600) + .allowed_methods(vec![Method::GET, Method::OPTIONS, Method::POST]) + .allowed_headers(exposed_headers.clone()) + .expose_headers(exposed_headers.clone()) + .allowed_header(header::CONTENT_TYPE) + .finish() + .new_transform(test::ok_service()) + .await + .unwrap(); - let req = TestRequest::with_header("Origin", "https://www.example.com") - .method(Method::OPTIONS) - .to_srv_request(); + let req = TestRequest::with_header("Origin", "https://www.example.com") + .method(Method::OPTIONS) + .to_srv_request(); - let resp = test::call_service(&mut cors, req); - assert_eq!( - &b"*"[..], - resp.headers() + let resp = test::call_service(&mut cors, req).await; + assert_eq!( + &b"*"[..], + resp.headers() + .get(header::ACCESS_CONTROL_ALLOW_ORIGIN) + .unwrap() + .as_bytes() + ); + assert_eq!( + &b"Origin"[..], + resp.headers().get(header::VARY).unwrap().as_bytes() + ); + + { + let headers = resp + .headers() + .get(header::ACCESS_CONTROL_EXPOSE_HEADERS) + .unwrap() + .to_str() + .unwrap() + .split(',') + .map(|s| s.trim()) + .collect::>(); + + for h in exposed_headers { + assert!(headers.contains(&h.as_str())); + } + } + + let exposed_headers = vec![header::AUTHORIZATION, header::ACCEPT]; + let mut cors = Cors::new() + .send_wildcard() + .disable_preflight() + .max_age(3600) + .allowed_methods(vec![Method::GET, Method::OPTIONS, Method::POST]) + .allowed_headers(exposed_headers.clone()) + .expose_headers(exposed_headers.clone()) + .allowed_header(header::CONTENT_TYPE) + .finish() + .new_transform(service_fn2(|req: ServiceRequest| { + ok(req.into_response( + HttpResponse::Ok().header(header::VARY, "Accept").finish(), + )) + })) + .await + .unwrap(); + let req = TestRequest::with_header("Origin", "https://www.example.com") + .method(Method::OPTIONS) + .to_srv_request(); + let resp = test::call_service(&mut cors, req).await; + assert_eq!( + &b"Accept, Origin"[..], + resp.headers().get(header::VARY).unwrap().as_bytes() + ); + + let mut cors = Cors::new() + .disable_vary_header() + .allowed_origin("https://www.example.com") + .allowed_origin("https://www.google.com") + .finish() + .new_transform(test::ok_service()) + .await + .unwrap(); + + let req = TestRequest::with_header("Origin", "https://www.example.com") + .method(Method::OPTIONS) + .header(header::ACCESS_CONTROL_REQUEST_METHOD, "POST") + .to_srv_request(); + let resp = test::call_service(&mut cors, req).await; + + let origins_str = resp + .headers() .get(header::ACCESS_CONTROL_ALLOW_ORIGIN) .unwrap() - .as_bytes() - ); - assert_eq!( - &b"Origin"[..], - resp.headers().get(header::VARY).unwrap().as_bytes() - ); - - { - let headers = resp - .headers() - .get(header::ACCESS_CONTROL_EXPOSE_HEADERS) - .unwrap() .to_str() - .unwrap() - .split(',') - .map(|s| s.trim()) - .collect::>(); + .unwrap(); - for h in exposed_headers { - assert!(headers.contains(&h.as_str())); - } - } - - let exposed_headers = vec![header::AUTHORIZATION, header::ACCEPT]; - let mut cors = Cors::new() - .send_wildcard() - .disable_preflight() - .max_age(3600) - .allowed_methods(vec![Method::GET, Method::OPTIONS, Method::POST]) - .allowed_headers(exposed_headers.clone()) - .expose_headers(exposed_headers.clone()) - .allowed_header(header::CONTENT_TYPE) - .finish(|req: ServiceRequest| { - req.into_response( - HttpResponse::Ok().header(header::VARY, "Accept").finish(), - ) - }); - let req = TestRequest::with_header("Origin", "https://www.example.com") - .method(Method::OPTIONS) - .to_srv_request(); - let resp = test::call_service(&mut cors, req); - assert_eq!( - &b"Accept, Origin"[..], - resp.headers().get(header::VARY).unwrap().as_bytes() - ); - - let mut cors = Cors::new() - .disable_vary_header() - .allowed_origin("https://www.example.com") - .allowed_origin("https://www.google.com") - .finish(test::ok_service()); - - let req = TestRequest::with_header("Origin", "https://www.example.com") - .method(Method::OPTIONS) - .header(header::ACCESS_CONTROL_REQUEST_METHOD, "POST") - .to_srv_request(); - let resp = test::call_service(&mut cors, req); - - let origins_str = resp - .headers() - .get(header::ACCESS_CONTROL_ALLOW_ORIGIN) - .unwrap() - .to_str() - .unwrap(); - - assert_eq!("https://www.example.com", origins_str); + assert_eq!("https://www.example.com", origins_str); + }) } #[test] fn test_multiple_origins() { - let mut cors = Cors::new() - .allowed_origin("https://example.com") - .allowed_origin("https://example.org") - .allowed_methods(vec![Method::GET]) - .finish(test::ok_service()); + block_on(async { + let mut cors = Cors::new() + .allowed_origin("https://example.com") + .allowed_origin("https://example.org") + .allowed_methods(vec![Method::GET]) + .finish() + .new_transform(test::ok_service()) + .await + .unwrap(); - let req = TestRequest::with_header("Origin", "https://example.com") - .method(Method::GET) - .to_srv_request(); + let req = TestRequest::with_header("Origin", "https://example.com") + .method(Method::GET) + .to_srv_request(); - let resp = test::call_service(&mut cors, req); - assert_eq!( - &b"https://example.com"[..], - resp.headers() - .get(header::ACCESS_CONTROL_ALLOW_ORIGIN) - .unwrap() - .as_bytes() - ); + let resp = test::call_service(&mut cors, req).await; + assert_eq!( + &b"https://example.com"[..], + resp.headers() + .get(header::ACCESS_CONTROL_ALLOW_ORIGIN) + .unwrap() + .as_bytes() + ); - let req = TestRequest::with_header("Origin", "https://example.org") - .method(Method::GET) - .to_srv_request(); + let req = TestRequest::with_header("Origin", "https://example.org") + .method(Method::GET) + .to_srv_request(); - let resp = test::call_service(&mut cors, req); - assert_eq!( - &b"https://example.org"[..], - resp.headers() - .get(header::ACCESS_CONTROL_ALLOW_ORIGIN) - .unwrap() - .as_bytes() - ); + let resp = test::call_service(&mut cors, req).await; + assert_eq!( + &b"https://example.org"[..], + resp.headers() + .get(header::ACCESS_CONTROL_ALLOW_ORIGIN) + .unwrap() + .as_bytes() + ); + }) } #[test] fn test_multiple_origins_preflight() { - let mut cors = Cors::new() - .allowed_origin("https://example.com") - .allowed_origin("https://example.org") - .allowed_methods(vec![Method::GET]) - .finish(test::ok_service()); + block_on(async { + let mut cors = Cors::new() + .allowed_origin("https://example.com") + .allowed_origin("https://example.org") + .allowed_methods(vec![Method::GET]) + .finish() + .new_transform(test::ok_service()) + .await + .unwrap(); - let req = TestRequest::with_header("Origin", "https://example.com") - .header(header::ACCESS_CONTROL_REQUEST_METHOD, "GET") - .method(Method::OPTIONS) - .to_srv_request(); + let req = TestRequest::with_header("Origin", "https://example.com") + .header(header::ACCESS_CONTROL_REQUEST_METHOD, "GET") + .method(Method::OPTIONS) + .to_srv_request(); - let resp = test::call_service(&mut cors, req); - assert_eq!( - &b"https://example.com"[..], - resp.headers() - .get(header::ACCESS_CONTROL_ALLOW_ORIGIN) - .unwrap() - .as_bytes() - ); + let resp = test::call_service(&mut cors, req).await; + assert_eq!( + &b"https://example.com"[..], + resp.headers() + .get(header::ACCESS_CONTROL_ALLOW_ORIGIN) + .unwrap() + .as_bytes() + ); - let req = TestRequest::with_header("Origin", "https://example.org") - .header(header::ACCESS_CONTROL_REQUEST_METHOD, "GET") - .method(Method::OPTIONS) - .to_srv_request(); + let req = TestRequest::with_header("Origin", "https://example.org") + .header(header::ACCESS_CONTROL_REQUEST_METHOD, "GET") + .method(Method::OPTIONS) + .to_srv_request(); - let resp = test::call_service(&mut cors, req); - assert_eq!( - &b"https://example.org"[..], - resp.headers() - .get(header::ACCESS_CONTROL_ALLOW_ORIGIN) - .unwrap() - .as_bytes() - ); + let resp = test::call_service(&mut cors, req).await; + assert_eq!( + &b"https://example.org"[..], + resp.headers() + .get(header::ACCESS_CONTROL_ALLOW_ORIGIN) + .unwrap() + .as_bytes() + ); + }) } }