1
0
Fork 0
mirror of https://github.com/actix/actix-web.git synced 2024-11-28 20:41:48 +00:00

refactor codegen route guards

This commit is contained in:
Rob Ede 2023-02-06 17:06:47 +00:00
parent 65c0545a7a
commit 359d5d5c80
No known key found for this signature in database
GPG key ID: 97C636207D3EF933
3 changed files with 168 additions and 122 deletions

View file

@ -6,11 +6,11 @@ use proc_macro2::{Span, TokenStream as TokenStream2};
use quote::{quote, ToTokens, TokenStreamExt}; use quote::{quote, ToTokens, TokenStreamExt};
use syn::{parse_macro_input, AttributeArgs, Ident, LitStr, Meta, NestedMeta, Path}; use syn::{parse_macro_input, AttributeArgs, Ident, LitStr, Meta, NestedMeta, Path};
macro_rules! method_type { macro_rules! standard_method_type {
( (
$($variant:ident, $upper:ident, $lower:ident,)+ $($variant:ident, $upper:ident, $lower:ident,)+
) => { ) => {
#[derive(Debug, PartialEq, Eq, Hash)] #[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum MethodType { pub enum MethodType {
$( $(
$variant, $variant,
@ -27,13 +27,7 @@ macro_rules! method_type {
fn parse(method: &str) -> Result<Self, String> { fn parse(method: &str) -> Result<Self, String> {
match method { match method {
$(stringify!($upper) => Ok(Self::$variant),)+ $(stringify!($upper) => Ok(Self::$variant),)+
_ => { _ => Err(format!("HTTP method must be uppercase: `{}`", method)),
if method.chars().all(|c| c.is_ascii_uppercase()) {
Ok(Self::Method)
} else {
Err(format!("HTTP method must be uppercase: `{}`", method))
}
},
} }
} }
@ -47,13 +41,7 @@ macro_rules! method_type {
}; };
} }
#[derive(Eq, Hash, PartialEq)] standard_method_type! {
struct MethodTypeExt {
method: MethodType,
custom_method: Option<LitStr>,
}
method_type! {
Get, GET, get, Get, GET, get,
Post, POST, post, Post, POST, post,
Put, PUT, put, Put, PUT, put,
@ -63,7 +51,15 @@ method_type! {
Options, OPTIONS, options, Options, OPTIONS, options,
Trace, TRACE, trace, Trace, TRACE, trace,
Patch, PATCH, patch, Patch, PATCH, patch,
Method, METHOD, method, }
impl TryFrom<&syn::LitStr> for MethodType {
type Error = syn::Error;
fn try_from(value: &syn::LitStr) -> Result<Self, Self::Error> {
Self::parse(value.value().as_str())
.map_err(|message| syn::Error::new_spanned(value, message))
}
} }
impl ToTokens for MethodType { impl ToTokens for MethodType {
@ -73,27 +69,107 @@ impl ToTokens for MethodType {
} }
} }
impl ToTokens for MethodTypeExt { #[derive(Debug, Clone, PartialEq, Eq, Hash)]
fn to_tokens(&self, stream: &mut TokenStream2) { enum MethodTypeExt {
match self.method { Standard(MethodType),
MethodType::Method => { Custom(LitStr),
let ident = Ident::new( }
self.custom_method.as_ref().unwrap().value().as_str(),
Span::call_site(), impl MethodTypeExt {
); /// Returns a single method guard token stream.
stream.append(ident); fn to_tokens_single_guard(&self) -> TokenStream2 {
match self {
MethodTypeExt::Standard(method) => {
quote! {
.guard(::actix_web::guard::#method())
}
}
MethodTypeExt::Custom(lit) => {
quote! {
.guard(::actix_web::guard::Method(
::actix_web::http::Method::from_bytes(#lit.as_bytes()).unwrap()
))
}
}
}
}
/// Returns a multi-method guard chain token stream.
fn to_tokens_multi_guard(&self, or_chain: Vec<impl ToTokens>) -> TokenStream2 {
debug_assert!(
or_chain.len() > 0,
"empty or_chain passed to multi-guard constructor"
);
match self {
MethodTypeExt::Standard(method) => {
quote! {
.guard(
::actix_web::guard::Any(::actix_web::guard::#method())
#(#or_chain)*
)
}
}
MethodTypeExt::Custom(lit) => {
quote! {
.guard(
::actix_web::guard::Any(
::actix_web::guard::Method(
::actix_web::http::Method::from_bytes(#lit.as_bytes()).unwrap()
)
)
#(#or_chain)*
)
}
}
}
}
/// Returns a token stream containing the `.or` chain to be passed in to
/// [`MethodTypeExt::to_tokens_multi_guard()`].
fn to_tokens_multi_guard_or_chain(&self) -> TokenStream2 {
match self {
MethodTypeExt::Standard(method_type) => {
quote! {
.or(::actix_web::guard::#method_type())
}
}
MethodTypeExt::Custom(lit) => {
quote! {
.or(
::actix_web::guard::Method(
::actix_web::http::Method::from_bytes(#lit.as_bytes()).unwrap()
)
)
}
} }
_ => self.method.to_tokens(stream),
} }
} }
} }
impl TryFrom<&syn::LitStr> for MethodType { impl ToTokens for MethodTypeExt {
fn to_tokens(&self, stream: &mut TokenStream2) {
match self {
MethodTypeExt::Custom(lit_str) => {
let ident = Ident::new(lit_str.value().as_str(), Span::call_site());
stream.append(ident);
}
MethodTypeExt::Standard(method) => method.to_tokens(stream),
}
}
}
impl TryFrom<&syn::LitStr> for MethodTypeExt {
type Error = syn::Error; type Error = syn::Error;
fn try_from(value: &syn::LitStr) -> Result<Self, Self::Error> { fn try_from(value: &syn::LitStr) -> Result<Self, Self::Error> {
Self::parse(value.value().as_str()) match MethodType::try_from(value) {
.map_err(|message| syn::Error::new_spanned(value, message)) Ok(method) => Ok(MethodTypeExt::Standard(method)),
Err(_) if value.value().chars().all(|c| c.is_ascii_uppercase()) => {
Ok(MethodTypeExt::Custom(value.clone()))
}
Err(err) => Err(err),
}
} }
} }
@ -127,12 +203,7 @@ impl Args {
let is_route_macro = method.is_none(); let is_route_macro = method.is_none();
if let Some(method) = method { if let Some(method) = method {
methods.insert({ methods.insert(MethodTypeExt::Standard(method));
MethodTypeExt {
method,
custom_method: None,
}
});
} }
for arg in args { for arg in args {
@ -149,6 +220,7 @@ impl Args {
)); ));
} }
}, },
NestedMeta::Meta(syn::Meta::NameValue(nv)) => { NestedMeta::Meta(syn::Meta::NameValue(nv)) => {
if nv.path.is_ident("name") { if nv.path.is_ident("name") {
if let syn::Lit::Str(lit) = nv.lit { if let syn::Lit::Str(lit) = nv.lit {
@ -184,23 +256,10 @@ impl Args {
"HTTP method forbidden here. To handle multiple methods, use `route` instead", "HTTP method forbidden here. To handle multiple methods, use `route` instead",
)); ));
} else if let syn::Lit::Str(ref lit) = nv.lit { } else if let syn::Lit::Str(ref lit) = nv.lit {
let method = MethodType::try_from(lit)?; if !methods.insert(MethodTypeExt::try_from(lit)?) {
if !methods.insert({
if method == MethodType::Method {
MethodTypeExt {
method,
custom_method: Some(lit.clone()),
}
} else {
MethodTypeExt {
method,
custom_method: None,
}
}
}) {
return Err(syn::Error::new_spanned( return Err(syn::Error::new_spanned(
&nv.lit, &nv.lit,
&format!( format!(
"HTTP method defined more than once: `{}`", "HTTP method defined more than once: `{}`",
lit.value() lit.value()
), ),
@ -219,11 +278,13 @@ impl Args {
)); ));
} }
} }
arg => { arg => {
return Err(syn::Error::new_spanned(arg, "Unknown attribute.")); return Err(syn::Error::new_spanned(arg, "Unknown attribute."));
} }
} }
} }
Ok(Args { Ok(Args {
path: path.unwrap(), path: path.unwrap(),
resource_name, resource_name,
@ -344,60 +405,22 @@ impl ToTokens for Route {
.map_or_else(|| name.to_string(), LitStr::value); .map_or_else(|| name.to_string(), LitStr::value);
let method_guards = { let method_guards = {
debug_assert!(methods.len() > 0, "Args::methods should not be empty");
let mut others = methods.iter(); let mut others = methods.iter();
let first = others.next().unwrap(); let first = others.next().unwrap();
let first_method = &first.method;
if methods.len() > 1 { if methods.len() > 1 {
let mut mult_method_guards: Vec<TokenStream2> = Vec::new(); let other_method_guards = others
for method_ext in methods { .map(|method_ext| method_ext.to_tokens_multi_guard_or_chain())
let method_type = &method_ext.method; .collect();
let custom_method = &method_ext.custom_method;
match custom_method { first.to_tokens_multi_guard(other_method_guards)
Some(lit) => {
mult_method_guards.push(quote! {
.or(::actix_web::guard::#method_type(::actix_web::http::Method::from_bytes(#lit.as_bytes()).unwrap()))
});
}
None => {
mult_method_guards.push(quote! {
.or(::actix_web::guard::#method_type())
});
}
}
}
match &first.custom_method {
Some(lit) => {
quote! {
.guard(
::actix_web::guard::Any(::actix_web::guard::#first_method(::actix_web::http::Method::from_bytes(#lit.as_bytes()).unwrap()))
#(#mult_method_guards)*
)
}
}
None => {
quote! {
.guard(
::actix_web::guard::Any(::actix_web::guard::#first_method())
#(#mult_method_guards)*
)
}
}
}
} else { } else {
match &first.custom_method { first.to_tokens_single_guard()
Some(lit) => {
quote! {
.guard(::actix_web::guard::#first_method(::actix_web::http::Method::from_bytes(#lit.as_bytes()).unwrap()))
}
}
None => {
quote! {
.guard(::actix_web::guard::#first_method())
}
}
}
} }
}; };
quote! { quote! {
let __resource = ::actix_web::Resource::new(#path) let __resource = ::actix_web::Resource::new(#path)
.name(#resource_name) .name(#resource_name)

View file

@ -86,6 +86,11 @@ async fn get_param_test(_: web::Path<String>) -> impl Responder {
HttpResponse::Ok() HttpResponse::Ok()
} }
#[route("/hello", method = "HELLO")]
async fn custom_route_test() -> impl Responder {
HttpResponse::Ok()
}
#[route( #[route(
"/multi", "/multi",
method = "GET", method = "GET",

View file

@ -1,19 +1,37 @@
use actix_web_codegen::*;
use actix_web::http::Method;
use std::str::FromStr; use std::str::FromStr;
#[route("/", method="UNEXPECTED")] use actix_web::http::Method;
use actix_web_codegen::route;
#[route("/single", method = "CUSTOM")]
async fn index() -> String { async fn index() -> String {
"Hello World!".to_owned() "Hello Single!".to_owned()
}
#[route("/multi", method = "GET", method = "CUSTOM")]
async fn custom() -> String {
"Hello Multi!".to_owned()
} }
#[actix_web::main] #[actix_web::main]
async fn main() { async fn main() {
use actix_web::App; use actix_web::App;
let srv = actix_test::start(|| App::new().service(index)); let srv = actix_test::start(|| App::new().service(index).service(custom));
let request = srv.request(Method::from_str("UNEXPECTED").unwrap(), srv.url("/")); let request = srv.request(Method::GET, srv.url("/"));
let response = request.send().await.unwrap();
assert!(response.status().is_client_error());
let request = srv.request(Method::from_str("CUSTOM").unwrap(), srv.url("/single"));
let response = request.send().await.unwrap();
assert!(response.status().is_success());
let request = srv.request(Method::GET, srv.url("/multi"));
let response = request.send().await.unwrap();
assert!(response.status().is_success());
let request = srv.request(Method::from_str("CUSTOM").unwrap(), srv.url("/multi"));
let response = request.send().await.unwrap(); let response = request.send().await.unwrap();
assert!(response.status().is_success()); assert!(response.status().is_success());
} }