1
0
Fork 0
mirror of https://github.com/actix/actix-web.git synced 2025-01-04 22:38:44 +00:00

add PayloadConfig

This commit is contained in:
Nikolay Kim 2018-04-04 21:13:48 -07:00
parent 7be4b1f399
commit 800f711cc1
2 changed files with 90 additions and 9 deletions

View file

@ -1,6 +1,7 @@
use std::str; use std::str;
use std::ops::{Deref, DerefMut}; use std::ops::{Deref, DerefMut};
use mime::Mime;
use bytes::Bytes; use bytes::Bytes;
use serde_urlencoded; use serde_urlencoded;
use serde::de::{self, DeserializeOwned}; use serde::de::{self, DeserializeOwned};
@ -301,12 +302,20 @@ impl Default for FormConfig {
/// ``` /// ```
impl<S: 'static> FromRequest<S> for Bytes impl<S: 'static> FromRequest<S> for Bytes
{ {
type Config = (); type Config = PayloadConfig;
type Result = Box<Future<Item=Self, Error=Error>>; type Result = Either<FutureResult<Self, Error>,
Box<Future<Item=Self, Error=Error>>>;
#[inline] #[inline]
fn from_request(req: &HttpRequest<S>, _: &Self::Config) -> Self::Result { fn from_request(req: &HttpRequest<S>, cfg: &Self::Config) -> Self::Result {
Box::new(MessageBody::new(req.clone()).from_err()) // check content-type
if let Err(e) = cfg.check_mimetype(req) {
return Either::A(result(Err(e)));
}
Either::B(Box::new(MessageBody::new(req.clone())
.limit(cfg.limit)
.from_err()))
} }
} }
@ -328,12 +337,18 @@ impl<S: 'static> FromRequest<S> for Bytes
/// ``` /// ```
impl<S: 'static> FromRequest<S> for String impl<S: 'static> FromRequest<S> for String
{ {
type Config = (); type Config = PayloadConfig;
type Result = Either<FutureResult<String, Error>, type Result = Either<FutureResult<String, Error>,
Box<Future<Item=String, Error=Error>>>; Box<Future<Item=String, Error=Error>>>;
#[inline] #[inline]
fn from_request(req: &HttpRequest<S>, _: &Self::Config) -> Self::Result { fn from_request(req: &HttpRequest<S>, cfg: &Self::Config) -> Self::Result {
// check content-type
if let Err(e) = cfg.check_mimetype(req) {
return Either::A(result(Err(e)));
}
// check charset
let encoding = match req.encoding() { let encoding = match req.encoding() {
Err(_) => return Either::A( Err(_) => return Either::A(
result(Err(ErrorBadRequest("Unknown request charset")))), result(Err(ErrorBadRequest("Unknown request charset")))),
@ -342,6 +357,7 @@ impl<S: 'static> FromRequest<S> for String
Either::B(Box::new( Either::B(Box::new(
MessageBody::new(req.clone()) MessageBody::new(req.clone())
.limit(cfg.limit)
.from_err() .from_err()
.and_then(move |body| { .and_then(move |body| {
let enc: *const Encoding = encoding as *const Encoding; let enc: *const Encoding = encoding as *const Encoding;
@ -357,9 +373,57 @@ impl<S: 'static> FromRequest<S> for String
} }
} }
/// Payload configuration for request's payload.
pub struct PayloadConfig {
limit: usize,
mimetype: Option<Mime>,
}
impl PayloadConfig {
/// Change max size of payload. By default max size is 256Kb
pub fn limit(&mut self, limit: usize) -> &mut Self {
self.limit = limit;
self
}
/// Set required mime-type of the request. By default mime type is not enforced.
pub fn mimetype(&mut self, mt: Mime) -> &mut Self {
self.mimetype = Some(mt);
self
}
fn check_mimetype<S>(&self, req: &HttpRequest<S>) -> Result<(), Error> {
// check content-type
if let Some(ref mt) = self.mimetype {
match req.mime_type() {
Ok(Some(ref req_mt)) => {
if mt != req_mt {
return Err(ErrorBadRequest("Unexpected Content-Type"));
}
},
Ok(None) => {
return Err(ErrorBadRequest("Content-Type is expected"));
},
Err(err) => {
return Err(err.into());
},
}
}
Ok(())
}
}
impl Default for PayloadConfig {
fn default() -> Self {
PayloadConfig{limit: 262_144, mimetype: None}
}
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use mime;
use bytes::Bytes; use bytes::Bytes;
use futures::{Async, Future}; use futures::{Async, Future};
use http::header; use http::header;
@ -375,10 +439,11 @@ mod tests {
#[test] #[test]
fn test_bytes() { fn test_bytes() {
let cfg = PayloadConfig::default();
let mut req = TestRequest::with_header(header::CONTENT_LENGTH, "11").finish(); let mut req = TestRequest::with_header(header::CONTENT_LENGTH, "11").finish();
req.payload_mut().unread_data(Bytes::from_static(b"hello=world")); req.payload_mut().unread_data(Bytes::from_static(b"hello=world"));
match Bytes::from_request(&req, &()).poll().unwrap() { match Bytes::from_request(&req, &cfg).poll().unwrap() {
Async::Ready(s) => { Async::Ready(s) => {
assert_eq!(s, Bytes::from_static(b"hello=world")); assert_eq!(s, Bytes::from_static(b"hello=world"));
}, },
@ -388,10 +453,11 @@ mod tests {
#[test] #[test]
fn test_string() { fn test_string() {
let cfg = PayloadConfig::default();
let mut req = TestRequest::with_header(header::CONTENT_LENGTH, "11").finish(); let mut req = TestRequest::with_header(header::CONTENT_LENGTH, "11").finish();
req.payload_mut().unread_data(Bytes::from_static(b"hello=world")); req.payload_mut().unread_data(Bytes::from_static(b"hello=world"));
match String::from_request(&req, &()).poll().unwrap() { match String::from_request(&req, &cfg).poll().unwrap() {
Async::Ready(s) => { Async::Ready(s) => {
assert_eq!(s, "hello=world"); assert_eq!(s, "hello=world");
}, },
@ -417,6 +483,21 @@ mod tests {
} }
} }
#[test]
fn test_payload_config() {
let req = HttpRequest::default();
let mut cfg = PayloadConfig::default();
cfg.mimetype(mime::APPLICATION_JSON);
assert!(cfg.check_mimetype(&req).is_err());
let req = TestRequest::with_header(
header::CONTENT_TYPE, "application/x-www-form-urlencoded").finish();
assert!(cfg.check_mimetype(&req).is_err());
let req = TestRequest::with_header(header::CONTENT_TYPE, "application/json").finish();
assert!(cfg.check_mimetype(&req).is_ok());
}
#[derive(Deserialize)] #[derive(Deserialize)]
struct MyStruct { struct MyStruct {
key: String, key: String,

View file

@ -180,7 +180,7 @@ pub mod dev {
pub use json::{JsonBody, JsonConfig}; pub use json::{JsonBody, JsonConfig};
pub use info::ConnectionInfo; pub use info::ConnectionInfo;
pub use handler::{Handler, Reply}; pub use handler::{Handler, Reply};
pub use extractor::{FormConfig}; pub use extractor::{FormConfig, PayloadConfig};
pub use route::Route; pub use route::Route;
pub use router::{Router, Resource, ResourceType}; pub use router::{Router, Resource, ResourceType};
pub use resource::ResourceHandler; pub use resource::ResourceHandler;