//! Route match predicates #![allow(non_snake_case)] use http; use http::{header, HttpTryFrom}; use httpmessage::HttpMessage; use httprequest::HttpRequest; use std::marker::PhantomData; /// Trait defines resource route predicate. /// Predicate can modify request object. It is also possible to /// to store extra attributes on request by using `Extensions` container, /// Extensions container available via `HttpRequest::extensions()` method. pub trait Predicate { /// Check if request matches predicate fn check(&self, &mut HttpRequest) -> bool; } /// Return predicate that matches if any of supplied predicate matches. /// /// ```rust /// # extern crate actix_web; /// use actix_web::{pred, App, HttpResponse}; /// /// fn main() { /// App::new().resource("/index.html", |r| { /// r.route() /// .filter(pred::Any(pred::Get()).or(pred::Post())) /// .f(|r| HttpResponse::MethodNotAllowed()) /// }); /// } /// ``` pub fn Any + 'static>(pred: P) -> AnyPredicate { AnyPredicate(vec![Box::new(pred)]) } /// Matches if any of supplied predicate matches. pub struct AnyPredicate(Vec>>); impl AnyPredicate { /// Add new predicate to list of predicates to check pub fn or + 'static>(mut self, pred: P) -> Self { self.0.push(Box::new(pred)); self } } impl Predicate for AnyPredicate { fn check(&self, req: &mut HttpRequest) -> bool { for p in &self.0 { if p.check(req) { return true; } } false } } /// Return predicate that matches if all of supplied predicate matches. /// /// ```rust /// # extern crate actix_web; /// use actix_web::{pred, App, HttpResponse}; /// /// fn main() { /// App::new().resource("/index.html", |r| { /// r.route() /// .filter( /// pred::All(pred::Get()) /// .and(pred::Header("content-type", "plain/text")), /// ) /// .f(|_| HttpResponse::MethodNotAllowed()) /// }); /// } /// ``` pub fn All + 'static>(pred: P) -> AllPredicate { AllPredicate(vec![Box::new(pred)]) } /// Matches if all of supplied predicate matches. pub struct AllPredicate(Vec>>); impl AllPredicate { /// Add new predicate to list of predicates to check pub fn and + 'static>(mut self, pred: P) -> Self { self.0.push(Box::new(pred)); self } } impl Predicate for AllPredicate { fn check(&self, req: &mut HttpRequest) -> bool { for p in &self.0 { if !p.check(req) { return false; } } true } } /// Return predicate that matches if supplied predicate does not match. pub fn Not + 'static>(pred: P) -> NotPredicate { NotPredicate(Box::new(pred)) } #[doc(hidden)] pub struct NotPredicate(Box>); impl Predicate for NotPredicate { fn check(&self, req: &mut HttpRequest) -> bool { !self.0.check(req) } } /// Http method predicate #[doc(hidden)] pub struct MethodPredicate(http::Method, PhantomData); impl Predicate for MethodPredicate { fn check(&self, req: &mut HttpRequest) -> bool { *req.method() == self.0 } } /// Predicate to match *GET* http method pub fn Get() -> MethodPredicate { MethodPredicate(http::Method::GET, PhantomData) } /// Predicate to match *POST* http method pub fn Post() -> MethodPredicate { MethodPredicate(http::Method::POST, PhantomData) } /// Predicate to match *PUT* http method pub fn Put() -> MethodPredicate { MethodPredicate(http::Method::PUT, PhantomData) } /// Predicate to match *DELETE* http method pub fn Delete() -> MethodPredicate { MethodPredicate(http::Method::DELETE, PhantomData) } /// Predicate to match *HEAD* http method pub fn Head() -> MethodPredicate { MethodPredicate(http::Method::HEAD, PhantomData) } /// Predicate to match *OPTIONS* http method pub fn Options() -> MethodPredicate { MethodPredicate(http::Method::OPTIONS, PhantomData) } /// Predicate to match *CONNECT* http method pub fn Connect() -> MethodPredicate { MethodPredicate(http::Method::CONNECT, PhantomData) } /// Predicate to match *PATCH* http method pub fn Patch() -> MethodPredicate { MethodPredicate(http::Method::PATCH, PhantomData) } /// Predicate to match *TRACE* http method pub fn Trace() -> MethodPredicate { MethodPredicate(http::Method::TRACE, PhantomData) } /// Predicate to match specified http method pub fn Method(method: http::Method) -> MethodPredicate { MethodPredicate(method, PhantomData) } /// Return predicate that matches if request contains specified header and /// value. pub fn Header( name: &'static str, value: &'static str, ) -> HeaderPredicate { HeaderPredicate( header::HeaderName::try_from(name).unwrap(), header::HeaderValue::from_static(value), PhantomData, ) } #[doc(hidden)] pub struct HeaderPredicate(header::HeaderName, header::HeaderValue, PhantomData); impl Predicate for HeaderPredicate { fn check(&self, req: &mut HttpRequest) -> bool { if let Some(val) = req.headers().get(&self.0) { return val == self.1; } false } } /// Return predicate that matches if request contains specified Host name. /// /// ```rust /// # extern crate actix_web; /// use actix_web::{pred, App, HttpResponse}; /// /// fn main() { /// App::new().resource("/index.html", |r| { /// r.route() /// .filter(pred::Host("www.rust-lang.org")) /// .f(|_| HttpResponse::MethodNotAllowed()) /// }); /// } /// ``` pub fn Host>(host: H) -> HostPredicate { HostPredicate(host.as_ref().to_string(), None, PhantomData) } #[doc(hidden)] pub struct HostPredicate(String, Option, PhantomData); impl HostPredicate { /// Set reuest scheme to match pub fn scheme>(&mut self, scheme: H) { self.1 = Some(scheme.as_ref().to_string()) } } impl Predicate for HostPredicate { fn check(&self, req: &mut HttpRequest) -> bool { let info = req.connection_info(); if let Some(ref scheme) = self.1 { self.0 == info.host() && scheme == info.scheme() } else { self.0 == info.host() } } } #[cfg(test)] mod tests { use super::*; use http::header::{self, HeaderMap}; use http::{Method, Uri, Version}; use std::str::FromStr; #[test] fn test_header() { let mut headers = HeaderMap::new(); headers.insert( header::TRANSFER_ENCODING, header::HeaderValue::from_static("chunked"), ); let mut req = HttpRequest::new( Method::GET, Uri::from_str("/").unwrap(), Version::HTTP_11, headers, None, ); let pred = Header("transfer-encoding", "chunked"); assert!(pred.check(&mut req)); let pred = Header("transfer-encoding", "other"); assert!(!pred.check(&mut req)); let pred = Header("content-type", "other"); assert!(!pred.check(&mut req)); } #[test] fn test_host() { let mut headers = HeaderMap::new(); headers.insert( header::HOST, header::HeaderValue::from_static("www.rust-lang.org"), ); let mut req = HttpRequest::new( Method::GET, Uri::from_str("/").unwrap(), Version::HTTP_11, headers, None, ); let pred = Host("www.rust-lang.org"); assert!(pred.check(&mut req)); let pred = Host("localhost"); assert!(!pred.check(&mut req)); } #[test] fn test_methods() { let mut req = HttpRequest::new( Method::GET, Uri::from_str("/").unwrap(), Version::HTTP_11, HeaderMap::new(), None, ); let mut req2 = HttpRequest::new( Method::POST, Uri::from_str("/").unwrap(), Version::HTTP_11, HeaderMap::new(), None, ); assert!(Get().check(&mut req)); assert!(!Get().check(&mut req2)); assert!(Post().check(&mut req2)); assert!(!Post().check(&mut req)); let mut r = HttpRequest::new( Method::PUT, Uri::from_str("/").unwrap(), Version::HTTP_11, HeaderMap::new(), None, ); assert!(Put().check(&mut r)); assert!(!Put().check(&mut req)); let mut r = HttpRequest::new( Method::DELETE, Uri::from_str("/").unwrap(), Version::HTTP_11, HeaderMap::new(), None, ); assert!(Delete().check(&mut r)); assert!(!Delete().check(&mut req)); let mut r = HttpRequest::new( Method::HEAD, Uri::from_str("/").unwrap(), Version::HTTP_11, HeaderMap::new(), None, ); assert!(Head().check(&mut r)); assert!(!Head().check(&mut req)); let mut r = HttpRequest::new( Method::OPTIONS, Uri::from_str("/").unwrap(), Version::HTTP_11, HeaderMap::new(), None, ); assert!(Options().check(&mut r)); assert!(!Options().check(&mut req)); let mut r = HttpRequest::new( Method::CONNECT, Uri::from_str("/").unwrap(), Version::HTTP_11, HeaderMap::new(), None, ); assert!(Connect().check(&mut r)); assert!(!Connect().check(&mut req)); let mut r = HttpRequest::new( Method::PATCH, Uri::from_str("/").unwrap(), Version::HTTP_11, HeaderMap::new(), None, ); assert!(Patch().check(&mut r)); assert!(!Patch().check(&mut req)); let mut r = HttpRequest::new( Method::TRACE, Uri::from_str("/").unwrap(), Version::HTTP_11, HeaderMap::new(), None, ); assert!(Trace().check(&mut r)); assert!(!Trace().check(&mut req)); } #[test] fn test_preds() { let mut r = HttpRequest::new( Method::TRACE, Uri::from_str("/").unwrap(), Version::HTTP_11, HeaderMap::new(), None, ); assert!(Not(Get()).check(&mut r)); assert!(!Not(Trace()).check(&mut r)); assert!(All(Trace()).and(Trace()).check(&mut r)); assert!(!All(Get()).and(Trace()).check(&mut r)); assert!(Any(Get()).or(Trace()).check(&mut r)); assert!(!Any(Get()).or(Get()).check(&mut r)); } }