1
0
Fork 0
mirror of https://github.com/actix/actix-web.git synced 2024-11-20 08:31:09 +00:00

remove boxed futures on Json extract type (#1832)

This commit is contained in:
fakeshadow 2020-12-17 07:34:33 +08:00 committed by GitHub
parent 1a361273e7
commit 97f615c245
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -1,14 +1,16 @@
//! Json extractor/responder
use std::future::Future;
use std::marker::PhantomData;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use std::{fmt, ops};
use bytes::BytesMut;
use futures_util::future::{err, ok, FutureExt, LocalBoxFuture, Ready};
use futures_util::StreamExt;
use futures_util::future::{ready, Ready};
use futures_util::ready;
use futures_util::stream::Stream;
use serde::de::DeserializeOwned;
use serde::Serialize;
@ -127,12 +129,12 @@ impl<T: Serialize> Responder for Json<T> {
fn respond_to(self, _: &HttpRequest) -> Self::Future {
let body = match serde_json::to_string(&self.0) {
Ok(body) => body,
Err(e) => return err(e.into()),
Err(e) => return ready(Err(e.into())),
};
ok(Response::build(StatusCode::OK)
ready(Ok(Response::build(StatusCode::OK)
.content_type("application/json")
.body(body))
.body(body)))
}
}
@ -173,37 +175,64 @@ where
T: DeserializeOwned + 'static,
{
type Error = Error;
type Future = LocalBoxFuture<'static, Result<Self, Error>>;
type Future = JsonExtractFut<T>;
type Config = JsonConfig;
#[inline]
fn from_request(req: &HttpRequest, payload: &mut Payload) -> Self::Future {
let req2 = req.clone();
let config = JsonConfig::from_req(req);
let limit = config.limit;
let ctype = config.content_type.clone();
let err_handler = config.err_handler.clone();
JsonBody::new(req, payload, ctype)
.limit(limit)
.map(move |res| match res {
Err(e) => {
log::debug!(
"Failed to deserialize Json from payload. \
Request path: {}",
req2.path()
);
JsonExtractFut {
req: Some(req.clone()),
fut: JsonBody::new(req, payload, ctype).limit(limit),
err_handler,
}
}
}
if let Some(err) = err_handler {
Err((*err)(e, &req2))
} else {
Err(e.into())
}
type JsonErrorHandler =
Option<Arc<dyn Fn(JsonPayloadError, &HttpRequest) -> Error + Send + Sync>>;
pub struct JsonExtractFut<T> {
req: Option<HttpRequest>,
fut: JsonBody<T>,
err_handler: JsonErrorHandler,
}
impl<T> Future for JsonExtractFut<T>
where
T: DeserializeOwned + 'static,
{
type Output = Result<Json<T>, Error>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.get_mut();
let res = ready!(Pin::new(&mut this.fut).poll(cx));
let res = match res {
Err(e) => {
let req = this.req.take().unwrap();
log::debug!(
"Failed to deserialize Json from payload. \
Request path: {}",
req.path()
);
if let Some(err) = this.err_handler.as_ref() {
Err((*err)(e, &req))
} else {
Err(e.into())
}
Ok(data) => Ok(Json(data)),
})
.boxed_local()
}
Ok(data) => Ok(Json(data)),
};
Poll::Ready(res)
}
}
@ -248,8 +277,7 @@ where
#[derive(Clone)]
pub struct JsonConfig {
limit: usize,
err_handler:
Option<Arc<dyn Fn(JsonPayloadError, &HttpRequest) -> Error + Send + Sync>>,
err_handler: JsonErrorHandler,
content_type: Option<Arc<dyn Fn(mime::Mime) -> bool + Send + Sync>>,
}
@ -308,17 +336,22 @@ impl Default for JsonConfig {
/// * content type is not `application/json`
/// (unless specified in [`JsonConfig`])
/// * content length is greater than 256k
pub struct JsonBody<U> {
limit: usize,
length: Option<usize>,
#[cfg(feature = "compress")]
stream: Option<Decompress<Payload>>,
#[cfg(not(feature = "compress"))]
stream: Option<Payload>,
err: Option<JsonPayloadError>,
fut: Option<LocalBoxFuture<'static, Result<U, JsonPayloadError>>>,
pub enum JsonBody<U> {
Error(Option<JsonPayloadError>),
Body {
limit: usize,
length: Option<usize>,
#[cfg(feature = "compress")]
payload: Decompress<Payload>,
#[cfg(not(feature = "compress"))]
payload: Payload,
buf: BytesMut,
_res: PhantomData<U>,
},
}
impl<U> Unpin for JsonBody<U> {}
impl<U> JsonBody<U>
where
U: DeserializeOwned + 'static,
@ -340,39 +373,58 @@ where
};
if !json {
return JsonBody {
limit: 262_144,
length: None,
stream: None,
fut: None,
err: Some(JsonPayloadError::ContentType),
};
return JsonBody::Error(Some(JsonPayloadError::ContentType));
}
let len = req
let length = req
.headers()
.get(&CONTENT_LENGTH)
.and_then(|l| l.to_str().ok())
.and_then(|s| s.parse::<usize>().ok());
// Notice the content_length is not checked against limit of json config here.
// As the internal usage always call JsonBody::limit after JsonBody::new.
// And limit check to return an error variant of JsonBody happens there.
#[cfg(feature = "compress")]
let payload = Decompress::from_headers(payload.take(), req.headers());
#[cfg(not(feature = "compress"))]
let payload = payload.take();
JsonBody {
JsonBody::Body {
limit: 262_144,
length: len,
stream: Some(payload),
fut: None,
err: None,
length,
payload,
buf: BytesMut::with_capacity(8192),
_res: PhantomData,
}
}
/// Change max size of payload. By default max size is 256Kb
pub fn limit(mut self, limit: usize) -> Self {
self.limit = limit;
self
pub fn limit(self, limit: usize) -> Self {
match self {
JsonBody::Body {
length,
payload,
buf,
..
} => {
if let Some(len) = length {
if len > limit {
return JsonBody::Error(Some(JsonPayloadError::Overflow));
}
}
JsonBody::Body {
limit,
length,
payload,
buf,
_res: PhantomData,
}
}
JsonBody::Error(e) => JsonBody::Error(e),
}
}
}
@ -382,41 +434,34 @@ where
{
type Output = Result<U, JsonPayloadError>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
if let Some(ref mut fut) = self.fut {
return Pin::new(fut).poll(cx);
}
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.get_mut();
if let Some(err) = self.err.take() {
return Poll::Ready(Err(err));
}
let limit = self.limit;
if let Some(len) = self.length.take() {
if len > limit {
return Poll::Ready(Err(JsonPayloadError::Overflow));
}
}
let mut stream = self.stream.take().unwrap();
self.fut = Some(
async move {
let mut body = BytesMut::with_capacity(8192);
while let Some(item) = stream.next().await {
let chunk = item?;
if (body.len() + chunk.len()) > limit {
return Err(JsonPayloadError::Overflow);
} else {
body.extend_from_slice(&chunk);
match this {
JsonBody::Body {
limit,
buf,
payload,
..
} => loop {
let res = ready!(Pin::new(&mut *payload).poll_next(cx));
match res {
Some(chunk) => {
let chunk = chunk?;
if (buf.len() + chunk.len()) > *limit {
return Poll::Ready(Err(JsonPayloadError::Overflow));
} else {
buf.extend_from_slice(&chunk);
}
}
None => {
let json = serde_json::from_slice::<U>(&buf)?;
return Poll::Ready(Ok(json));
}
}
Ok(serde_json::from_slice::<U>(&body)?)
}
.boxed_local(),
);
self.poll(cx)
},
JsonBody::Error(e) => Poll::Ready(Err(e.take().unwrap())),
}
}
}