use std::{ future::Future, mem, pin::Pin, task::{Context, Poll}, }; use actix_http::{error::PayloadError, header, HttpMessage}; use bytes::Bytes; use futures_core::Stream; use pin_project_lite::pin_project; use super::{read_body::ReadBody, ResponseTimeout, DEFAULT_BODY_LIMIT}; use crate::ClientResponse; pin_project! { /// A `Future` that reads a body stream, resolving as [`Bytes`]. /// /// # Errors /// `Future` implementation returns error if: /// - content type is not `application/json`; /// - content length is greater than [limit](JsonBody::limit) (default: 2 MiB). pub struct ResponseBody { #[pin] body: Option>, length: Option, timeout: ResponseTimeout, err: Option, } } #[deprecated(since = "3.0.0", note = "Renamed to `ResponseBody`.")] pub type MessageBody = ResponseBody; impl ResponseBody where S: Stream>, { /// Creates a body stream reader from a response by taking its payload. pub fn new(res: &mut ClientResponse) -> ResponseBody { let length = match res.headers().get(&header::CONTENT_LENGTH) { Some(value) => { let len = value.to_str().ok().and_then(|s| s.parse::().ok()); match len { None => return Self::err(PayloadError::UnknownLength), len => len, } } None => None, }; ResponseBody { body: Some(ReadBody::new(res.take_payload(), DEFAULT_BODY_LIMIT)), length, timeout: mem::take(&mut res.timeout), err: None, } } /// Change max size limit of payload. /// /// The default limit is 2 MiB. pub fn limit(mut self, limit: usize) -> Self { if let Some(ref mut body) = self.body { body.limit = limit; } self } fn err(err: PayloadError) -> Self { ResponseBody { body: None, length: None, timeout: ResponseTimeout::default(), err: Some(err), } } } impl Future for ResponseBody where S: Stream>, { type Output = Result; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let this = self.project(); if let Some(err) = this.err.take() { return Poll::Ready(Err(err)); } if let Some(len) = this.length.take() { let body = Option::as_ref(&this.body).unwrap(); if len > body.limit { return Poll::Ready(Err(PayloadError::Overflow)); } } this.timeout.poll_timeout(cx)?; this.body.as_pin_mut().unwrap().poll(cx) } } #[cfg(test)] mod tests { use static_assertions::assert_impl_all; use super::*; use crate::test::TestResponse; assert_impl_all!(ResponseBody<()>: Unpin); #[actix_rt::test] async fn read_body() { let mut req = TestResponse::with_header((header::CONTENT_LENGTH, "xxxx")).finish(); match req.body().await.err().unwrap() { PayloadError::UnknownLength => {} _ => unreachable!("error"), } let mut req = TestResponse::with_header((header::CONTENT_LENGTH, "10000000")).finish(); match req.body().await.err().unwrap() { PayloadError::Overflow => {} _ => unreachable!("error"), } let mut req = TestResponse::default() .set_payload(Bytes::from_static(b"test")) .finish(); assert_eq!(req.body().await.ok().unwrap(), Bytes::from_static(b"test")); let mut req = TestResponse::default() .set_payload(Bytes::from_static(b"11111111111111")) .finish(); match req.body().limit(5).await.err().unwrap() { PayloadError::Overflow => {} _ => unreachable!("error"), } } }