1
0
Fork 0
mirror of https://github.com/actix/actix-web.git synced 2025-01-18 05:05:43 +00:00

add resource middleware on actix-web-codegen (#1467)

Co-authored-by: Yuki Okushi <huyuumi.dev@gmail.com>
This commit is contained in:
Quentin de Quelen 2020-05-07 11:31:12 +02:00 committed by GitHub
parent b521e9b221
commit 9164ed1f0c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 94 additions and 3 deletions

View file

@ -21,6 +21,7 @@
//! //!
//! - `"path"` - Raw literal string with path for which to register handle. Mandatory. //! - `"path"` - Raw literal string with path for which to register handle. Mandatory.
//! - `guard="function_name"` - Registers function as guard using `actix_web::guard::fn_guard` //! - `guard="function_name"` - Registers function as guard using `actix_web::guard::fn_guard`
//! - `wrap="Middleware"` - Registers a resource middleware.
//! //!
//! ## Notes //! ## Notes
//! //!
@ -54,6 +55,7 @@ use proc_macro::TokenStream;
/// ///
/// - `"path"` - Raw literal string with path for which to register handler. Mandatory. /// - `"path"` - Raw literal string with path for which to register handler. Mandatory.
/// - `guard="function_name"` - Registers function as guard using `actix_web::guard::fn_guard` /// - `guard="function_name"` - Registers function as guard using `actix_web::guard::fn_guard`
/// - `wrap="Middleware"` - Registers a resource middleware.
#[proc_macro_attribute] #[proc_macro_attribute]
pub fn get(args: TokenStream, input: TokenStream) -> TokenStream { pub fn get(args: TokenStream, input: TokenStream) -> TokenStream {
route::generate(args, input, route::GuardType::Get) route::generate(args, input, route::GuardType::Get)

View file

@ -56,12 +56,14 @@ impl ToTokens for GuardType {
struct Args { struct Args {
path: syn::LitStr, path: syn::LitStr,
guards: Vec<Ident>, guards: Vec<Ident>,
wrappers: Vec<syn::Type>,
} }
impl Args { impl Args {
fn new(args: AttributeArgs) -> syn::Result<Self> { fn new(args: AttributeArgs) -> syn::Result<Self> {
let mut path = None; let mut path = None;
let mut guards = Vec::new(); let mut guards = Vec::new();
let mut wrappers = Vec::new();
for arg in args { for arg in args {
match arg { match arg {
NestedMeta::Lit(syn::Lit::Str(lit)) => match path { NestedMeta::Lit(syn::Lit::Str(lit)) => match path {
@ -85,10 +87,19 @@ impl Args {
"Attribute guard expects literal string!", "Attribute guard expects literal string!",
)); ));
} }
} else if nv.path.is_ident("wrap") {
if let syn::Lit::Str(lit) = nv.lit {
wrappers.push(lit.parse()?);
} else {
return Err(syn::Error::new_spanned(
nv.lit,
"Attribute wrap expects type",
));
}
} else { } else {
return Err(syn::Error::new_spanned( return Err(syn::Error::new_spanned(
nv.path, nv.path,
"Unknown attribute key is specified. Allowed: guard.", "Unknown attribute key is specified. Allowed: guard and wrap",
)); ));
} }
} }
@ -100,6 +111,7 @@ impl Args {
Ok(Args { Ok(Args {
path: path.unwrap(), path: path.unwrap(),
guards, guards,
wrappers,
}) })
} }
} }
@ -184,7 +196,7 @@ impl ToTokens for Route {
name, name,
guard, guard,
ast, ast,
args: Args { path, guards }, args: Args { path, guards, wrappers },
resource_type, resource_type,
} = self; } = self;
let resource_name = name.to_string(); let resource_name = name.to_string();
@ -199,6 +211,7 @@ impl ToTokens for Route {
.name(#resource_name) .name(#resource_name)
.guard(actix_web::guard::#guard()) .guard(actix_web::guard::#guard())
#(.guard(actix_web::guard::fn_guard(#guards)))* #(.guard(actix_web::guard::fn_guard(#guards)))*
#(.wrap(#wrappers))*
.#resource_type(#name); .#resource_type(#name);
actix_web::dev::HttpServiceFactory::register(__resource, __config) actix_web::dev::HttpServiceFactory::register(__resource, __config)

View file

@ -1,6 +1,11 @@
use actix_web::{http, test, web::Path, App, HttpResponse, Responder}; use std::pin::Pin;
use std::task::{Context, Poll};
use actix_web::{http, test, web::Path, App, HttpResponse, Responder, Error};
use actix_web::dev::{Service, Transform, ServiceRequest, ServiceResponse};
use actix_web_codegen::{connect, delete, get, head, options, patch, post, put, trace}; use actix_web_codegen::{connect, delete, get, head, options, patch, post, put, trace};
use futures::{future, Future}; use futures::{future, Future};
use actix_web::http::header::{HeaderName, HeaderValue};
// Make sure that we can name function as 'config' // Make sure that we can name function as 'config'
#[get("/config")] #[get("/config")]
@ -73,6 +78,65 @@ async fn get_param_test(_: Path<String>) -> impl Responder {
HttpResponse::Ok() HttpResponse::Ok()
} }
pub struct ChangeStatusCode;
impl<S, B> Transform<S> for ChangeStatusCode
where
S: Service<Request = ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
S::Future: 'static,
B: 'static,
{
type Request = ServiceRequest;
type Response = ServiceResponse<B>;
type Error = Error;
type InitError = ();
type Transform = ChangeStatusCodeMiddleware<S>;
type Future = future::Ready<Result<Self::Transform, Self::InitError>>;
fn new_transform(&self, service: S) -> Self::Future {
future::ok(ChangeStatusCodeMiddleware { service })
}
}
pub struct ChangeStatusCodeMiddleware<S> {
service: S,
}
impl<S, B> Service for ChangeStatusCodeMiddleware<S>
where
S: Service<Request = ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
S::Future: 'static,
B: 'static,
{
type Request = ServiceRequest;
type Response = ServiceResponse<B>;
type Error = Error;
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>>>>;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.service.poll_ready(cx)
}
fn call(&mut self, req: ServiceRequest) -> Self::Future {
let fut = self.service.call(req);
Box::pin(async move {
let mut res = fut.await?;
let headers = res.headers_mut();
let header_name = HeaderName::from_lowercase(b"custom-header").unwrap();
let header_value = HeaderValue::from_str("hello").unwrap();
headers.insert(header_name, header_value);
Ok(res)
})
}
}
#[get("/test/wrap", wrap = "ChangeStatusCode")]
async fn get_wrap(_: Path<String>) -> impl Responder {
HttpResponse::Ok()
}
#[actix_rt::test] #[actix_rt::test]
async fn test_params() { async fn test_params() {
let srv = test::start(|| { let srv = test::start(|| {
@ -155,3 +219,15 @@ async fn test_auto_async() {
let response = request.send().await.unwrap(); let response = request.send().await.unwrap();
assert!(response.status().is_success()); assert!(response.status().is_success());
} }
#[actix_rt::test]
async fn test_wrap() {
let srv = test::start(|| {
App::new()
.service(get_wrap)
});
let request = srv.request(http::Method::GET, srv.url("/test/wrap"));
let response = request.send().await.unwrap();
assert!(response.headers().contains_key("custom-header"));
}