diff --git a/awc/src/response.rs b/awc/src/response.rs index 40de3dc17..994ddb761 100644 --- a/awc/src/response.rs +++ b/awc/src/response.rs @@ -228,12 +228,13 @@ impl fmt::Debug for ClientResponse { } } +const DEFAULT_BODY_LIMIT: usize = 2 * 1024 * 1024; + /// Future that resolves to a complete HTTP message body. pub struct MessageBody { length: Option, - err: Option, timeout: ResponseTimeout, - fut: Option>, + body: Result, Option>, } impl MessageBody @@ -242,41 +243,38 @@ where { /// Create `MessageBody` for request. pub fn new(res: &mut ClientResponse) -> MessageBody { - let mut len = None; - if let Some(l) = res.headers().get(&header::CONTENT_LENGTH) { - if let Ok(s) = l.to_str() { - if let Ok(l) = s.parse::() { - len = Some(l) - } else { - return Self::err(PayloadError::UnknownLength); + let length = match res.headers().get(&header::CONTENT_LENGTH) { + Some(value) => { + let len = value.to_str().ok().and_then(|s| s.parse::().ok()); + + match len { + None => return Self::err(PayloadError::UnknownLength), + len => len, } - } else { - return Self::err(PayloadError::UnknownLength); } - } + None => None, + }; MessageBody { - length: len, - err: None, + length, timeout: std::mem::take(&mut res.timeout), - fut: Some(ReadBody::new(res.take_payload(), 262_144)), + body: Ok(ReadBody::new(res.take_payload(), DEFAULT_BODY_LIMIT)), } } - /// Change max size of payload. By default max size is 256kB + /// Change max size of payload. By default max size is 2048kB pub fn limit(mut self, limit: usize) -> Self { - if let Some(ref mut fut) = self.fut { - fut.limit = limit; + if let Ok(ref mut body) = self.body { + body.limit = limit; } self } fn err(e: PayloadError) -> Self { MessageBody { - fut: None, - err: Some(e), length: None, timeout: ResponseTimeout::default(), + body: Err(Some(e)), } } } @@ -290,19 +288,20 @@ where fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let this = self.get_mut(); - if let Some(err) = this.err.take() { - return Poll::Ready(Err(err)); - } + match this.body { + Err(ref mut err) => Poll::Ready(Err(err.take().unwrap())), + Ok(ref mut body) => { + if let Some(len) = this.length.take() { + if len > body.limit { + return Poll::Ready(Err(PayloadError::Overflow)); + } + } - if let Some(len) = this.length.take() { - if len > this.fut.as_ref().unwrap().limit { - return Poll::Ready(Err(PayloadError::Overflow)); + this.timeout.poll_timeout(cx)?; + + Pin::new(body).poll(cx) } } - - this.timeout.poll_timeout(cx)?; - - Pin::new(&mut this.fut.as_mut().unwrap()).poll(cx) } } @@ -415,7 +414,7 @@ impl ReadBody { fn new(stream: Payload, limit: usize) -> Self { Self { stream, - buf: BytesMut::with_capacity(std::cmp::min(limit, 32768)), + buf: BytesMut::new(), limit, } } @@ -430,20 +429,14 @@ where fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let this = self.get_mut(); - loop { - return match Pin::new(&mut this.stream).poll_next(cx)? { - Poll::Ready(Some(chunk)) => { - if (this.buf.len() + chunk.len()) > this.limit { - Poll::Ready(Err(PayloadError::Overflow)) - } else { - this.buf.extend_from_slice(&chunk); - continue; - } - } - Poll::Ready(None) => Poll::Ready(Ok(this.buf.split().freeze())), - Poll::Pending => Poll::Pending, - }; + while let Some(chunk) = ready!(Pin::new(&mut this.stream).poll_next(cx)?) { + if (this.buf.len() + chunk.len()) > this.limit { + return Poll::Ready(Err(PayloadError::Overflow)); + } + this.buf.extend_from_slice(&chunk); } + + Poll::Ready(Ok(this.buf.split().freeze())) } } @@ -462,7 +455,7 @@ mod tests { _ => unreachable!("error"), } - let mut req = TestResponse::with_header(header::CONTENT_LENGTH, "1000000").finish(); + let mut req = TestResponse::with_header(header::CONTENT_LENGTH, "10000000").finish(); match req.body().await.err().unwrap() { PayloadError::Overflow => {} _ => unreachable!("error"),