diff --git a/actix-web-codegen/src/lib.rs b/actix-web-codegen/src/lib.rs index b724eb797..39a8a6464 100644 --- a/actix-web-codegen/src/lib.rs +++ b/actix-web-codegen/src/lib.rs @@ -21,6 +21,7 @@ //! //! - `"path"` - Raw literal string with path for which to register handle. Mandatory. //! - `guard="function_name"` - Registers function as guard using `actix_web::guard::fn_guard` +//! - `wrap="Middleware"` - Registers a resource middleware. //! //! ## Notes //! @@ -54,6 +55,7 @@ use proc_macro::TokenStream; /// /// - `"path"` - Raw literal string with path for which to register handler. Mandatory. /// - `guard="function_name"` - Registers function as guard using `actix_web::guard::fn_guard` +/// - `wrap="Middleware"` - Registers a resource middleware. #[proc_macro_attribute] pub fn get(args: TokenStream, input: TokenStream) -> TokenStream { route::generate(args, input, route::GuardType::Get) diff --git a/actix-web-codegen/src/route.rs b/actix-web-codegen/src/route.rs index 3e6f9c979..7e3d43f1d 100644 --- a/actix-web-codegen/src/route.rs +++ b/actix-web-codegen/src/route.rs @@ -56,12 +56,14 @@ impl ToTokens for GuardType { struct Args { path: syn::LitStr, guards: Vec, + wrappers: Vec, } impl Args { fn new(args: AttributeArgs) -> syn::Result { let mut path = None; let mut guards = Vec::new(); + let mut wrappers = Vec::new(); for arg in args { match arg { NestedMeta::Lit(syn::Lit::Str(lit)) => match path { @@ -85,10 +87,19 @@ impl Args { "Attribute guard expects literal string!", )); } + } else if nv.path.is_ident("wrap") { + if let syn::Lit::Str(lit) = nv.lit { + wrappers.push(lit.parse()?); + } else { + return Err(syn::Error::new_spanned( + nv.lit, + "Attribute wrap expects type", + )); + } } else { return Err(syn::Error::new_spanned( nv.path, - "Unknown attribute key is specified. Allowed: guard.", + "Unknown attribute key is specified. Allowed: guard and wrap", )); } } @@ -100,6 +111,7 @@ impl Args { Ok(Args { path: path.unwrap(), guards, + wrappers, }) } } @@ -184,7 +196,7 @@ impl ToTokens for Route { name, guard, ast, - args: Args { path, guards }, + args: Args { path, guards, wrappers }, resource_type, } = self; let resource_name = name.to_string(); @@ -199,6 +211,7 @@ impl ToTokens for Route { .name(#resource_name) .guard(actix_web::guard::#guard()) #(.guard(actix_web::guard::fn_guard(#guards)))* + #(.wrap(#wrappers))* .#resource_type(#name); actix_web::dev::HttpServiceFactory::register(__resource, __config) diff --git a/actix-web-codegen/tests/test_macro.rs b/actix-web-codegen/tests/test_macro.rs index ffb50c11e..8264a7fd7 100644 --- a/actix-web-codegen/tests/test_macro.rs +++ b/actix-web-codegen/tests/test_macro.rs @@ -1,6 +1,11 @@ -use actix_web::{http, test, web::Path, App, HttpResponse, Responder}; +use std::pin::Pin; +use std::task::{Context, Poll}; + +use actix_web::{http, test, web::Path, App, HttpResponse, Responder, Error}; +use actix_web::dev::{Service, Transform, ServiceRequest, ServiceResponse}; use actix_web_codegen::{connect, delete, get, head, options, patch, post, put, trace}; use futures::{future, Future}; +use actix_web::http::header::{HeaderName, HeaderValue}; // Make sure that we can name function as 'config' #[get("/config")] @@ -73,6 +78,65 @@ async fn get_param_test(_: Path) -> impl Responder { HttpResponse::Ok() } +pub struct ChangeStatusCode; + +impl Transform for ChangeStatusCode +where + S: Service, Error = Error>, + S::Future: 'static, + B: 'static, +{ + type Request = ServiceRequest; + type Response = ServiceResponse; + type Error = Error; + type InitError = (); + type Transform = ChangeStatusCodeMiddleware; + type Future = future::Ready>; + + fn new_transform(&self, service: S) -> Self::Future { + future::ok(ChangeStatusCodeMiddleware { service }) + } +} + +pub struct ChangeStatusCodeMiddleware { + service: S, +} + +impl Service for ChangeStatusCodeMiddleware +where + S: Service, Error = Error>, + S::Future: 'static, + B: 'static, +{ + type Request = ServiceRequest; + type Response = ServiceResponse; + type Error = Error; + type Future = Pin>>>; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.service.poll_ready(cx) + } + + fn call(&mut self, req: ServiceRequest) -> Self::Future { + + let fut = self.service.call(req); + + Box::pin(async move { + let mut res = fut.await?; + let headers = res.headers_mut(); + let header_name = HeaderName::from_lowercase(b"custom-header").unwrap(); + let header_value = HeaderValue::from_str("hello").unwrap(); + headers.insert(header_name, header_value); + Ok(res) + }) + } +} + +#[get("/test/wrap", wrap = "ChangeStatusCode")] +async fn get_wrap(_: Path) -> impl Responder { + HttpResponse::Ok() +} + #[actix_rt::test] async fn test_params() { let srv = test::start(|| { @@ -155,3 +219,15 @@ async fn test_auto_async() { let response = request.send().await.unwrap(); assert!(response.status().is_success()); } + +#[actix_rt::test] +async fn test_wrap() { + let srv = test::start(|| { + App::new() + .service(get_wrap) + }); + + let request = srv.request(http::Method::GET, srv.url("/test/wrap")); + let response = request.send().await.unwrap(); + assert!(response.headers().contains_key("custom-header")); +}