1
0
Fork 0
mirror of https://github.com/actix/actix-web.git synced 2025-01-03 05:48:45 +00:00

add resource middleware on actix-web-codegen (#1467)

Co-authored-by: Yuki Okushi <huyuumi.dev@gmail.com>
This commit is contained in:
Quentin de Quelen 2020-05-07 11:31:12 +02:00 committed by GitHub
parent b521e9b221
commit 9164ed1f0c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 94 additions and 3 deletions

View file

@ -21,6 +21,7 @@
//!
//! - `"path"` - Raw literal string with path for which to register handle. Mandatory.
//! - `guard="function_name"` - Registers function as guard using `actix_web::guard::fn_guard`
//! - `wrap="Middleware"` - Registers a resource middleware.
//!
//! ## Notes
//!
@ -54,6 +55,7 @@ use proc_macro::TokenStream;
///
/// - `"path"` - Raw literal string with path for which to register handler. Mandatory.
/// - `guard="function_name"` - Registers function as guard using `actix_web::guard::fn_guard`
/// - `wrap="Middleware"` - Registers a resource middleware.
#[proc_macro_attribute]
pub fn get(args: TokenStream, input: TokenStream) -> TokenStream {
route::generate(args, input, route::GuardType::Get)

View file

@ -56,12 +56,14 @@ impl ToTokens for GuardType {
struct Args {
path: syn::LitStr,
guards: Vec<Ident>,
wrappers: Vec<syn::Type>,
}
impl Args {
fn new(args: AttributeArgs) -> syn::Result<Self> {
let mut path = None;
let mut guards = Vec::new();
let mut wrappers = Vec::new();
for arg in args {
match arg {
NestedMeta::Lit(syn::Lit::Str(lit)) => match path {
@ -85,10 +87,19 @@ impl Args {
"Attribute guard expects literal string!",
));
}
} else if nv.path.is_ident("wrap") {
if let syn::Lit::Str(lit) = nv.lit {
wrappers.push(lit.parse()?);
} else {
return Err(syn::Error::new_spanned(
nv.lit,
"Attribute wrap expects type",
));
}
} else {
return Err(syn::Error::new_spanned(
nv.path,
"Unknown attribute key is specified. Allowed: guard.",
"Unknown attribute key is specified. Allowed: guard and wrap",
));
}
}
@ -100,6 +111,7 @@ impl Args {
Ok(Args {
path: path.unwrap(),
guards,
wrappers,
})
}
}
@ -184,7 +196,7 @@ impl ToTokens for Route {
name,
guard,
ast,
args: Args { path, guards },
args: Args { path, guards, wrappers },
resource_type,
} = self;
let resource_name = name.to_string();
@ -199,6 +211,7 @@ impl ToTokens for Route {
.name(#resource_name)
.guard(actix_web::guard::#guard())
#(.guard(actix_web::guard::fn_guard(#guards)))*
#(.wrap(#wrappers))*
.#resource_type(#name);
actix_web::dev::HttpServiceFactory::register(__resource, __config)

View file

@ -1,6 +1,11 @@
use actix_web::{http, test, web::Path, App, HttpResponse, Responder};
use std::pin::Pin;
use std::task::{Context, Poll};
use actix_web::{http, test, web::Path, App, HttpResponse, Responder, Error};
use actix_web::dev::{Service, Transform, ServiceRequest, ServiceResponse};
use actix_web_codegen::{connect, delete, get, head, options, patch, post, put, trace};
use futures::{future, Future};
use actix_web::http::header::{HeaderName, HeaderValue};
// Make sure that we can name function as 'config'
#[get("/config")]
@ -73,6 +78,65 @@ async fn get_param_test(_: Path<String>) -> impl Responder {
HttpResponse::Ok()
}
pub struct ChangeStatusCode;
impl<S, B> Transform<S> for ChangeStatusCode
where
S: Service<Request = ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
S::Future: 'static,
B: 'static,
{
type Request = ServiceRequest;
type Response = ServiceResponse<B>;
type Error = Error;
type InitError = ();
type Transform = ChangeStatusCodeMiddleware<S>;
type Future = future::Ready<Result<Self::Transform, Self::InitError>>;
fn new_transform(&self, service: S) -> Self::Future {
future::ok(ChangeStatusCodeMiddleware { service })
}
}
pub struct ChangeStatusCodeMiddleware<S> {
service: S,
}
impl<S, B> Service for ChangeStatusCodeMiddleware<S>
where
S: Service<Request = ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
S::Future: 'static,
B: 'static,
{
type Request = ServiceRequest;
type Response = ServiceResponse<B>;
type Error = Error;
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>>>>;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.service.poll_ready(cx)
}
fn call(&mut self, req: ServiceRequest) -> Self::Future {
let fut = self.service.call(req);
Box::pin(async move {
let mut res = fut.await?;
let headers = res.headers_mut();
let header_name = HeaderName::from_lowercase(b"custom-header").unwrap();
let header_value = HeaderValue::from_str("hello").unwrap();
headers.insert(header_name, header_value);
Ok(res)
})
}
}
#[get("/test/wrap", wrap = "ChangeStatusCode")]
async fn get_wrap(_: Path<String>) -> impl Responder {
HttpResponse::Ok()
}
#[actix_rt::test]
async fn test_params() {
let srv = test::start(|| {
@ -155,3 +219,15 @@ async fn test_auto_async() {
let response = request.send().await.unwrap();
assert!(response.status().is_success());
}
#[actix_rt::test]
async fn test_wrap() {
let srv = test::start(|| {
App::new()
.service(get_wrap)
});
let request = srv.request(http::Method::GET, srv.url("/test/wrap"));
let response = request.send().await.unwrap();
assert!(response.headers().contains_key("custom-header"));
}