From 210c9a5eb3be0995e9df097a1079e43f87909168 Mon Sep 17 00:00:00 2001 From: Rob Ede Date: Thu, 4 Jul 2024 04:53:10 +0100 Subject: [PATCH] refactor: multipart tweaks --- actix-multipart/src/error.rs | 10 +- actix-multipart/src/extractor.rs | 6 +- actix-multipart/src/field.rs | 16 +-- actix-multipart/src/lib.rs | 2 +- actix-multipart/src/payload.rs | 72 +++++++----- actix-multipart/src/server.rs | 195 +++++++++++++++++-------------- actix-multipart/src/test.rs | 3 +- 7 files changed, 169 insertions(+), 135 deletions(-) diff --git a/actix-multipart/src/error.rs b/actix-multipart/src/error.rs index 30ef63c1a..cdb608738 100644 --- a/actix-multipart/src/error.rs +++ b/actix-multipart/src/error.rs @@ -10,7 +10,7 @@ use derive_more::{Display, Error, From}; /// A set of errors that can occur during parsing multipart streams. #[derive(Debug, Display, From, Error)] #[non_exhaustive] -pub enum MultipartError { +pub enum Error { /// Could not find Content-Type header. #[display(fmt = "Could not find Content-Type header")] ContentTypeMissing, @@ -95,11 +95,11 @@ pub enum MultipartError { } /// Return `BadRequest` for `MultipartError`. -impl ResponseError for MultipartError { +impl ResponseError for Error { fn status_code(&self) -> StatusCode { match &self { - MultipartError::Field { source, .. } => source.as_response_error().status_code(), - MultipartError::ContentTypeIncompatible => StatusCode::UNSUPPORTED_MEDIA_TYPE, + Error::Field { source, .. } => source.as_response_error().status_code(), + Error::ContentTypeIncompatible => StatusCode::UNSUPPORTED_MEDIA_TYPE, _ => StatusCode::BAD_REQUEST, } } @@ -111,7 +111,7 @@ mod tests { #[test] fn test_multipart_error() { - let resp = MultipartError::BoundaryMissing.error_response(); + let resp = Error::BoundaryMissing.error_response(); assert_eq!(resp.status(), StatusCode::BAD_REQUEST); } } diff --git a/actix-multipart/src/extractor.rs b/actix-multipart/src/extractor.rs index ab98f887e..f7777100e 100644 --- a/actix-multipart/src/extractor.rs +++ b/actix-multipart/src/extractor.rs @@ -12,11 +12,11 @@ use crate::server::Multipart; /// # Examples /// /// ``` -/// use actix_web::{web, HttpResponse, Error}; +/// use actix_web::{web, HttpResponse}; /// use actix_multipart::Multipart; /// use futures_util::StreamExt as _; /// -/// async fn index(mut payload: Multipart) -> Result { +/// async fn index(mut payload: Multipart) -> actix_web::Result { /// // iterate over multipart stream /// while let Some(item) = payload.next().await { /// let mut field = item?; @@ -27,7 +27,7 @@ use crate::server::Multipart; /// } /// } /// -/// Ok(HttpResponse::Ok().into()) +/// Ok(HttpResponse::Ok().finish()) /// } /// ``` impl FromRequest for Multipart { diff --git a/actix-multipart/src/field.rs b/actix-multipart/src/field.rs index 86fbc8b2d..50660b5d3 100644 --- a/actix-multipart/src/field.rs +++ b/actix-multipart/src/field.rs @@ -15,7 +15,7 @@ use futures_core::stream::Stream; use mime::Mime; use crate::{ - error::MultipartError, + error::Error, payload::{PayloadBuffer, PayloadRef}, safety::Safety, }; @@ -106,7 +106,7 @@ impl Field { } impl Stream for Field { - type Item = Result; + type Item = Result; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let this = self.get_mut(); @@ -122,7 +122,7 @@ impl Stream for Field { buffer.poll_stream(cx)?; } else if !this.safety.is_clean() { // safety violation - return Poll::Ready(Some(Err(MultipartError::NotConsumed))); + return Poll::Ready(Some(Err(Error::NotConsumed))); } else { return Poll::Pending; } @@ -192,7 +192,7 @@ impl InnerField { pub(crate) fn read_len( payload: &mut PayloadBuffer, size: &mut u64, - ) -> Poll>> { + ) -> Poll>> { if *size == 0 { Poll::Ready(None) } else { @@ -208,7 +208,7 @@ impl InnerField { } None => { if payload.eof && (*size != 0) { - Poll::Ready(Some(Err(MultipartError::Incomplete))) + Poll::Ready(Some(Err(Error::Incomplete))) } else { Poll::Pending } @@ -223,13 +223,13 @@ impl InnerField { pub(crate) fn read_stream( payload: &mut PayloadBuffer, boundary: &str, - ) -> Poll>> { + ) -> Poll>> { let mut pos = 0; let len = payload.buf.len(); if len == 0 { return if payload.eof { - Poll::Ready(Some(Err(MultipartError::Incomplete))) + Poll::Ready(Some(Err(Error::Incomplete))) } else { Poll::Pending }; @@ -293,7 +293,7 @@ impl InnerField { } } - pub(crate) fn poll(&mut self, safety: &Safety) -> Poll>> { + pub(crate) fn poll(&mut self, safety: &Safety) -> Poll>> { if self.payload.is_none() { return Poll::Ready(None); } diff --git a/actix-multipart/src/lib.rs b/actix-multipart/src/lib.rs index d33f17097..744c27088 100644 --- a/actix-multipart/src/lib.rs +++ b/actix-multipart/src/lib.rs @@ -63,4 +63,4 @@ pub(crate) mod safety; mod server; pub mod test; -pub use self::{error::MultipartError, field::Field, server::Multipart}; +pub use self::{error::Error as MultipartError, field::Field, server::Multipart}; diff --git a/actix-multipart/src/payload.rs b/actix-multipart/src/payload.rs index a798f2c1a..ed5477997 100644 --- a/actix-multipart/src/payload.rs +++ b/actix-multipart/src/payload.rs @@ -1,6 +1,6 @@ use std::{ cell::{RefCell, RefMut}, - cmp, + cmp, mem, pin::Pin, rc::Rc, task::{Context, Poll}, @@ -12,7 +12,7 @@ use actix_web::{ }; use futures_core::stream::{LocalBoxStream, Stream}; -use crate::{error::MultipartError, safety::Safety}; +use crate::{error::Error, safety::Safety}; pub(crate) struct PayloadRef { payload: Rc>, @@ -21,7 +21,7 @@ pub(crate) struct PayloadRef { impl PayloadRef { pub(crate) fn new(payload: PayloadBuffer) -> PayloadRef { PayloadRef { - payload: Rc::new(payload.into()), + payload: Rc::new(RefCell::new(payload)), } } @@ -44,28 +44,33 @@ impl Clone for PayloadRef { /// Payload buffer. pub(crate) struct PayloadBuffer { - pub(crate) eof: bool, - pub(crate) buf: BytesMut, pub(crate) stream: LocalBoxStream<'static, Result>, + pub(crate) buf: BytesMut, + /// EOF flag. If true, no more payload reads will be attempted. + pub(crate) eof: bool, } impl PayloadBuffer { - /// Constructs new `PayloadBuffer` instance. + /// Constructs new payload buffer. pub(crate) fn new(stream: S) -> Self where S: Stream> + 'static, { PayloadBuffer { - eof: false, - buf: BytesMut::new(), stream: Box::pin(stream), + buf: BytesMut::with_capacity(1_024), // pre-allocate 1KiB + eof: false, } } pub(crate) fn poll_stream(&mut self, cx: &mut Context<'_>) -> Result<(), PayloadError> { loop { match Pin::new(&mut self.stream).poll_next(cx) { - Poll::Ready(Some(Ok(data))) => self.buf.extend_from_slice(&data), + Poll::Ready(Some(Ok(data))) => { + self.buf.extend_from_slice(&data); + // try to read more data + continue; + } Poll::Ready(Some(Err(err))) => return Err(err), Poll::Ready(None) => { self.eof = true; @@ -76,7 +81,7 @@ impl PayloadBuffer { } } - /// Read exact number of bytes. + /// Reads exact number of bytes. #[cfg(test)] pub(crate) fn read_exact(&mut self, size: usize) -> Option { if size <= self.buf.len() { @@ -86,46 +91,57 @@ impl PayloadBuffer { } } - pub(crate) fn read_max(&mut self, size: u64) -> Result, MultipartError> { + pub(crate) fn read_max(&mut self, size: u64) -> Result, Error> { if !self.buf.is_empty() { let size = cmp::min(self.buf.len() as u64, size) as usize; Ok(Some(self.buf.split_to(size).freeze())) } else if self.eof { - Err(MultipartError::Incomplete) + Err(Error::Incomplete) } else { Ok(None) } } - /// Read until specified ending. - pub(crate) fn read_until(&mut self, line: &[u8]) -> Result, MultipartError> { - let res = memchr::memmem::find(&self.buf, line) - .map(|idx| self.buf.split_to(idx + line.len()).freeze()); + /// Reads until specified ending. + /// + /// Returns: + /// + /// - `Ok(Some(chunk))` - `needle` is found, with chunk ending after needle + /// - `Err(Incomplete)` - `needle` is not found and we're at EOF + /// - `Ok(None)` - `needle` is not found otherwise + pub(crate) fn read_until(&mut self, needle: &[u8]) -> Result, Error> { + match memchr::memmem::find(&self.buf, needle) { + // buffer exhausted and EOF without finding needle + None if self.eof => Err(Error::Incomplete), - if res.is_none() && self.eof { - Err(MultipartError::Incomplete) - } else { - Ok(res) + // needle not yet found + None => Ok(None), + + // needle found, split chunk out of buf + Some(idx) => Ok(Some(self.buf.split_to(idx + needle.len()).freeze())), } } - /// Read bytes until new line delimiter. - pub(crate) fn readline(&mut self) -> Result, MultipartError> { + /// Reads bytes until new line delimiter. + #[inline] + pub(crate) fn readline(&mut self) -> Result, Error> { self.read_until(b"\n") } - /// Read bytes until new line delimiter or EOF. - pub(crate) fn readline_or_eof(&mut self) -> Result, MultipartError> { + /// Reads bytes until new line delimiter or until EOF. + #[inline] + pub(crate) fn readline_or_eof(&mut self) -> Result, Error> { match self.readline() { - Err(MultipartError::Incomplete) if self.eof => Ok(Some(self.buf.split().freeze())), + Err(Error::Incomplete) if self.eof => Ok(Some(self.buf.split().freeze())), line => line, } } - /// Put unprocessed data back to the buffer. + /// Puts unprocessed data back to the buffer. pub(crate) fn unprocessed(&mut self, data: Bytes) { - let buf = BytesMut::from(data.as_ref()); - let buf = std::mem::replace(&mut self.buf, buf); + // TODO: use BytesMut::from when it's released, see https://github.com/tokio-rs/bytes/pull/710 + let buf = BytesMut::from(&data[..]); + let buf = mem::replace(&mut self.buf, buf); self.buf.extend_from_slice(&buf); } } diff --git a/actix-multipart/src/server.rs b/actix-multipart/src/server.rs index d0ed5be59..dc6a9ecb7 100644 --- a/actix-multipart/src/server.rs +++ b/actix-multipart/src/server.rs @@ -18,7 +18,7 @@ use futures_core::stream::Stream; use mime::Mime; use crate::{ - error::MultipartError, + error::Error, field::InnerField, payload::{PayloadBuffer, PayloadRef}, safety::Safety, @@ -33,9 +33,15 @@ const MAX_HEADERS: usize = 32; /// implementation. `MultipartItem::Field` contains multipart field. `MultipartItem::Multipart` is /// used for nested multipart streams. pub struct Multipart { + flow: Flow, safety: Safety, - inner: Option, - error: Option, +} + +enum Flow { + InFlight(Inner), + + /// Error container is Some until an error is returned out of the flow. + Error(Option), } impl Multipart { @@ -59,24 +65,22 @@ impl Multipart { } /// Extract Content-Type and boundary info from headers. - pub(crate) fn find_ct_and_boundary( - headers: &HeaderMap, - ) -> Result<(Mime, String), MultipartError> { + pub(crate) fn find_ct_and_boundary(headers: &HeaderMap) -> Result<(Mime, String), Error> { let content_type = headers .get(&header::CONTENT_TYPE) - .ok_or(MultipartError::ContentTypeMissing)? + .ok_or(Error::ContentTypeMissing)? .to_str() .ok() .and_then(|content_type| content_type.parse::().ok()) - .ok_or(MultipartError::ContentTypeParse)?; + .ok_or(Error::ContentTypeParse)?; if content_type.type_() != mime::MULTIPART { - return Err(MultipartError::ContentTypeIncompatible); + return Err(Error::ContentTypeIncompatible); } let boundary = content_type .get_param(mime::BOUNDARY) - .ok_or(MultipartError::BoundaryMissing)? + .ok_or(Error::BoundaryMissing)? .as_str() .to_owned(); @@ -90,64 +94,57 @@ impl Multipart { { Multipart { safety: Safety::new(), - inner: Some(Inner { + flow: Flow::InFlight(Inner { payload: PayloadRef::new(PayloadBuffer::new(stream)), content_type: ct, boundary, state: State::FirstBoundary, item: Item::None, }), - error: None, } } /// Constructs a new multipart reader from given `MultipartError`. - pub(crate) fn from_error(err: MultipartError) -> Multipart { + pub(crate) fn from_error(err: Error) -> Multipart { Multipart { - error: Some(err), + flow: Flow::Error(Some(err)), safety: Safety::new(), - inner: None, } } /// Return requests parsed Content-Type or raise the stored error. - pub(crate) fn content_type_or_bail(&mut self) -> Result { - if let Some(err) = self.error.take() { - return Err(err); + pub(crate) fn content_type_or_bail(&mut self) -> Result { + match self.flow { + Flow::InFlight(ref inner) => Ok(inner.content_type.clone()), + Flow::Error(ref mut err) => Err(err + .take() + .expect("error should not be taken after it was returned")), } - - Ok(self - .inner - .as_ref() - // TODO: look into using enum instead of two options - .expect("multipart requests should have state") - .content_type - .clone()) } } impl Stream for Multipart { - type Item = Result; + type Item = Result; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let this = self.get_mut(); - match this.inner.as_mut() { - Some(inner) => { + match this.flow { + Flow::InFlight(ref mut inner) => { if let Some(mut buffer) = inner.payload.get_mut(&this.safety) { // check safety and poll read payload to buffer. buffer.poll_stream(cx)?; } else if !this.safety.is_clean() { // safety violation - return Poll::Ready(Some(Err(MultipartError::NotConsumed))); + return Poll::Ready(Some(Err(Error::NotConsumed))); } else { return Poll::Pending; } inner.poll(&this.safety, cx) } - None => Poll::Ready(Some(Err(this - .error + + Flow::Error(ref mut err) => Poll::Ready(Some(Err(err .take() .expect("Multipart polled after finish")))), } @@ -191,22 +188,21 @@ struct Inner { } impl Inner { - fn read_field_headers( - payload: &mut PayloadBuffer, - ) -> Result, MultipartError> { + fn read_field_headers(payload: &mut PayloadBuffer) -> Result, Error> { match payload.read_until(b"\r\n\r\n")? { None => { if payload.eof { - Err(MultipartError::Incomplete) + Err(Error::Incomplete) } else { Ok(None) } } + Some(bytes) => { let mut hdrs = [httparse::EMPTY_HEADER; MAX_HEADERS]; - match httparse::parse_headers(&bytes, &mut hdrs) { - Ok(httparse::Status::Complete((_, hdrs))) => { + match httparse::parse_headers(&bytes, &mut hdrs).map_err(ParseError::from)? { + httparse::Status::Complete((_, hdrs)) => { // convert headers let mut headers = HeaderMap::with_capacity(hdrs.len()); @@ -220,57 +216,84 @@ impl Inner { Ok(Some(headers)) } - Ok(httparse::Status::Partial) => Err(ParseError::Header.into()), - Err(err) => Err(ParseError::from(err).into()), + + httparse::Status::Partial => Err(ParseError::Header.into()), } } } } - fn read_boundary( - payload: &mut PayloadBuffer, - boundary: &str, - ) -> Result, MultipartError> { + /// Reads a field boundary from the payload buffer (and discards it). + /// + /// Reads "in-between" and "final" boundaries. E.g. for boundary = "foo": + /// + /// ```plain + /// --foo <-- in-between fields + /// --foo-- <-- end of request body, should be followed by EOF + /// ``` + /// + /// Returns: + /// + /// - `Ok(Some(true))` - final field boundary read (EOF) + /// - `Ok(Some(false))` - field boundary read + /// - `Ok(None)` - boundary not found, more data needs reading + /// - `Err(BoundaryMissing)` - multipart boundary is missing + fn read_boundary(payload: &mut PayloadBuffer, boundary: &str) -> Result, Error> { // TODO: need to read epilogue - match payload.readline_or_eof()? { - None => { - if payload.eof { - Ok(Some(true)) - } else { - Ok(None) - } - } - Some(chunk) => { - if chunk.len() < boundary.len() + 4 - || &chunk[..2] != b"--" - || &chunk[2..boundary.len() + 2] != boundary.as_bytes() - { - Err(MultipartError::BoundaryMissing) - } else if &chunk[boundary.len() + 2..] == b"\r\n" { - Ok(Some(false)) - } else if &chunk[boundary.len() + 2..boundary.len() + 4] == b"--" - && (chunk.len() == boundary.len() + 4 - || &chunk[boundary.len() + 4..] == b"\r\n") - { - Ok(Some(true)) - } else { - Err(MultipartError::BoundaryMissing) - } - } + let chunk = match payload.readline_or_eof()? { + // TODO: this might be okay as a let Some() else return Ok(None) + None => return Ok(payload.eof.then_some(true)), + Some(chunk) => chunk, + }; + + const BOUNDARY_MARKER: &[u8] = b"--"; + const LINE_BREAK: &[u8] = b"\r\n"; + + let boundary_len = boundary.len(); + + if chunk.len() < boundary_len + 2 + 2 + || !chunk.starts_with(BOUNDARY_MARKER) + || &chunk[2..boundary_len + 2] != boundary.as_bytes() + { + return Err(Error::BoundaryMissing); } + + // chunk facts: + // - long enough to contain boundary + 2 markers or 1 marker and line-break + // - starts with boundary marker + // - chunk contains correct boundary + + if &chunk[boundary_len + 2..] == LINE_BREAK { + // boundary is followed by line-break, indicating more fields to come + return Ok(Some(false)); + } + + // boundary is followed by marker + if &chunk[boundary_len + 2..boundary_len + 4] == BOUNDARY_MARKER + && ( + // chunk is exactly boundary len + 2 markers + chunk.len() == boundary_len + 2 + 2 + // final boundary is allowed to end with a line-break + || &chunk[boundary_len + 4..] == LINE_BREAK + ) + { + return Ok(Some(true)); + } + + Err(Error::BoundaryMissing) } fn skip_until_boundary( payload: &mut PayloadBuffer, boundary: &str, - ) -> Result, MultipartError> { + ) -> Result, Error> { let mut eof = false; loop { match payload.readline()? { Some(chunk) => { if chunk.is_empty() { - return Err(MultipartError::BoundaryMissing); + return Err(Error::BoundaryMissing); } if chunk.len() < boundary.len() { continue; @@ -292,7 +315,7 @@ impl Inner { } None => { return if payload.eof { - Err(MultipartError::Incomplete) + Err(Error::Incomplete) } else { Ok(None) }; @@ -302,11 +325,7 @@ impl Inner { Ok(Some(eof)) } - fn poll( - &mut self, - safety: &Safety, - cx: &Context<'_>, - ) -> Poll>> { + fn poll(&mut self, safety: &Safety, cx: &Context<'_>) -> Poll>> { if self.state == State::Eof { Poll::Ready(None) } else { @@ -338,6 +357,7 @@ impl Inner { // read until first boundary State::FirstBoundary => { match Inner::skip_until_boundary(&mut payload, &self.boundary)? { + None => return Poll::Pending, Some(eof) => { if eof { self.state = State::Eof; @@ -346,7 +366,6 @@ impl Inner { self.state = State::Headers; } } - None => return Poll::Pending, } } @@ -398,11 +417,11 @@ impl Inner { // type must be set as "form-data", and it must have a name parameter. let Some(cd) = &field_content_disposition else { - return Poll::Ready(Some(Err(MultipartError::ContentDispositionMissing))); + return Poll::Ready(Some(Err(Error::ContentDispositionMissing))); }; let Some(field_name) = cd.get_name() else { - return Poll::Ready(Some(Err(MultipartError::ContentDispositionNameMissing))); + return Poll::Ready(Some(Err(Error::ContentDispositionNameMissing))); }; Some(field_name.to_owned()) @@ -422,7 +441,7 @@ impl Inner { // nested multipart stream is not supported if let Some(mime) = &field_content_type { if mime.type_() == mime::MULTIPART { - return Poll::Ready(Some(Err(MultipartError::Nested))); + return Poll::Ready(Some(Err(Error::Nested))); } } @@ -475,7 +494,7 @@ mod tests { async fn test_boundary() { let headers = HeaderMap::new(); match Multipart::find_ct_and_boundary(&headers) { - Err(MultipartError::ContentTypeMissing) => {} + Err(Error::ContentTypeMissing) => {} _ => unreachable!("should not happen"), } @@ -486,7 +505,7 @@ mod tests { ); match Multipart::find_ct_and_boundary(&headers) { - Err(MultipartError::ContentTypeParse) => {} + Err(Error::ContentTypeParse) => {} _ => unreachable!("should not happen"), } @@ -496,7 +515,7 @@ mod tests { header::HeaderValue::from_static("multipart/mixed"), ); match Multipart::find_ct_and_boundary(&headers) { - Err(MultipartError::BoundaryMissing) => {} + Err(Error::BoundaryMissing) => {} _ => unreachable!("should not happen"), } @@ -831,7 +850,7 @@ mod tests { #[actix_rt::test] async fn test_multipart_from_error() { - let err = MultipartError::ContentTypeMissing; + let err = Error::ContentTypeMissing; let mut multipart = Multipart::from_error(err); assert!(multipart.next().await.unwrap().is_err()) } @@ -888,7 +907,7 @@ mod tests { res.expect_err( "according to RFC 7578, form-data fields require a content-disposition header" ), - MultipartError::ContentDispositionMissing + Error::ContentDispositionMissing ); } @@ -942,7 +961,7 @@ mod tests { let res = multipart.next().await.unwrap(); assert_matches!( res.expect_err("according to RFC 7578, form-data fields require a name attribute"), - MultipartError::ContentDispositionNameMissing + Error::ContentDispositionNameMissing ); } @@ -960,7 +979,7 @@ mod tests { // should fail immediately match field.next().await { - Some(Err(MultipartError::NotConsumed)) => {} + Some(Err(Error::NotConsumed)) => {} _ => panic!(), }; } diff --git a/actix-multipart/src/test.rs b/actix-multipart/src/test.rs index 956595355..7dec85f8e 100644 --- a/actix-multipart/src/test.rs +++ b/actix-multipart/src/test.rs @@ -25,8 +25,7 @@ const BOUNDARY_PREFIX: &str = "------------------------"; /// /// ``` /// use actix_multipart::test::create_form_data_payload_and_headers; -/// use actix_web::test::TestRequest; -/// use bytes::Bytes; +/// use actix_web::{test::TestRequest, web::Bytes}; /// use memchr::memmem::find; /// /// let (body, headers) = create_form_data_payload_and_headers(