1
0
Fork 0
mirror of https://github.com/actix/actix-web.git synced 2024-11-26 03:21:08 +00:00

refactor: multipart tweaks

This commit is contained in:
Rob Ede 2024-07-04 04:53:10 +01:00
parent 00c185f617
commit 210c9a5eb3
No known key found for this signature in database
GPG key ID: 97C636207D3EF933
7 changed files with 169 additions and 135 deletions

View file

@ -10,7 +10,7 @@ use derive_more::{Display, Error, From};
/// A set of errors that can occur during parsing multipart streams. /// A set of errors that can occur during parsing multipart streams.
#[derive(Debug, Display, From, Error)] #[derive(Debug, Display, From, Error)]
#[non_exhaustive] #[non_exhaustive]
pub enum MultipartError { pub enum Error {
/// Could not find Content-Type header. /// Could not find Content-Type header.
#[display(fmt = "Could not find Content-Type header")] #[display(fmt = "Could not find Content-Type header")]
ContentTypeMissing, ContentTypeMissing,
@ -95,11 +95,11 @@ pub enum MultipartError {
} }
/// Return `BadRequest` for `MultipartError`. /// Return `BadRequest` for `MultipartError`.
impl ResponseError for MultipartError { impl ResponseError for Error {
fn status_code(&self) -> StatusCode { fn status_code(&self) -> StatusCode {
match &self { match &self {
MultipartError::Field { source, .. } => source.as_response_error().status_code(), Error::Field { source, .. } => source.as_response_error().status_code(),
MultipartError::ContentTypeIncompatible => StatusCode::UNSUPPORTED_MEDIA_TYPE, Error::ContentTypeIncompatible => StatusCode::UNSUPPORTED_MEDIA_TYPE,
_ => StatusCode::BAD_REQUEST, _ => StatusCode::BAD_REQUEST,
} }
} }
@ -111,7 +111,7 @@ mod tests {
#[test] #[test]
fn test_multipart_error() { fn test_multipart_error() {
let resp = MultipartError::BoundaryMissing.error_response(); let resp = Error::BoundaryMissing.error_response();
assert_eq!(resp.status(), StatusCode::BAD_REQUEST); assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
} }
} }

View file

@ -12,11 +12,11 @@ use crate::server::Multipart;
/// # Examples /// # Examples
/// ///
/// ``` /// ```
/// use actix_web::{web, HttpResponse, Error}; /// use actix_web::{web, HttpResponse};
/// use actix_multipart::Multipart; /// use actix_multipart::Multipart;
/// use futures_util::StreamExt as _; /// use futures_util::StreamExt as _;
/// ///
/// async fn index(mut payload: Multipart) -> Result<HttpResponse, Error> { /// async fn index(mut payload: Multipart) -> actix_web::Result<HttpResponse> {
/// // iterate over multipart stream /// // iterate over multipart stream
/// while let Some(item) = payload.next().await { /// while let Some(item) = payload.next().await {
/// let mut field = item?; /// let mut field = item?;
@ -27,7 +27,7 @@ use crate::server::Multipart;
/// } /// }
/// } /// }
/// ///
/// Ok(HttpResponse::Ok().into()) /// Ok(HttpResponse::Ok().finish())
/// } /// }
/// ``` /// ```
impl FromRequest for Multipart { impl FromRequest for Multipart {

View file

@ -15,7 +15,7 @@ use futures_core::stream::Stream;
use mime::Mime; use mime::Mime;
use crate::{ use crate::{
error::MultipartError, error::Error,
payload::{PayloadBuffer, PayloadRef}, payload::{PayloadBuffer, PayloadRef},
safety::Safety, safety::Safety,
}; };
@ -106,7 +106,7 @@ impl Field {
} }
impl Stream for Field { impl Stream for Field {
type Item = Result<Bytes, MultipartError>; type Item = Result<Bytes, Error>;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> { fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let this = self.get_mut(); let this = self.get_mut();
@ -122,7 +122,7 @@ impl Stream for Field {
buffer.poll_stream(cx)?; buffer.poll_stream(cx)?;
} else if !this.safety.is_clean() { } else if !this.safety.is_clean() {
// safety violation // safety violation
return Poll::Ready(Some(Err(MultipartError::NotConsumed))); return Poll::Ready(Some(Err(Error::NotConsumed)));
} else { } else {
return Poll::Pending; return Poll::Pending;
} }
@ -192,7 +192,7 @@ impl InnerField {
pub(crate) fn read_len( pub(crate) fn read_len(
payload: &mut PayloadBuffer, payload: &mut PayloadBuffer,
size: &mut u64, size: &mut u64,
) -> Poll<Option<Result<Bytes, MultipartError>>> { ) -> Poll<Option<Result<Bytes, Error>>> {
if *size == 0 { if *size == 0 {
Poll::Ready(None) Poll::Ready(None)
} else { } else {
@ -208,7 +208,7 @@ impl InnerField {
} }
None => { None => {
if payload.eof && (*size != 0) { if payload.eof && (*size != 0) {
Poll::Ready(Some(Err(MultipartError::Incomplete))) Poll::Ready(Some(Err(Error::Incomplete)))
} else { } else {
Poll::Pending Poll::Pending
} }
@ -223,13 +223,13 @@ impl InnerField {
pub(crate) fn read_stream( pub(crate) fn read_stream(
payload: &mut PayloadBuffer, payload: &mut PayloadBuffer,
boundary: &str, boundary: &str,
) -> Poll<Option<Result<Bytes, MultipartError>>> { ) -> Poll<Option<Result<Bytes, Error>>> {
let mut pos = 0; let mut pos = 0;
let len = payload.buf.len(); let len = payload.buf.len();
if len == 0 { if len == 0 {
return if payload.eof { return if payload.eof {
Poll::Ready(Some(Err(MultipartError::Incomplete))) Poll::Ready(Some(Err(Error::Incomplete)))
} else { } else {
Poll::Pending Poll::Pending
}; };
@ -293,7 +293,7 @@ impl InnerField {
} }
} }
pub(crate) fn poll(&mut self, safety: &Safety) -> Poll<Option<Result<Bytes, MultipartError>>> { pub(crate) fn poll(&mut self, safety: &Safety) -> Poll<Option<Result<Bytes, Error>>> {
if self.payload.is_none() { if self.payload.is_none() {
return Poll::Ready(None); return Poll::Ready(None);
} }

View file

@ -63,4 +63,4 @@ pub(crate) mod safety;
mod server; mod server;
pub mod test; pub mod test;
pub use self::{error::MultipartError, field::Field, server::Multipart}; pub use self::{error::Error as MultipartError, field::Field, server::Multipart};

View file

@ -1,6 +1,6 @@
use std::{ use std::{
cell::{RefCell, RefMut}, cell::{RefCell, RefMut},
cmp, cmp, mem,
pin::Pin, pin::Pin,
rc::Rc, rc::Rc,
task::{Context, Poll}, task::{Context, Poll},
@ -12,7 +12,7 @@ use actix_web::{
}; };
use futures_core::stream::{LocalBoxStream, Stream}; use futures_core::stream::{LocalBoxStream, Stream};
use crate::{error::MultipartError, safety::Safety}; use crate::{error::Error, safety::Safety};
pub(crate) struct PayloadRef { pub(crate) struct PayloadRef {
payload: Rc<RefCell<PayloadBuffer>>, payload: Rc<RefCell<PayloadBuffer>>,
@ -21,7 +21,7 @@ pub(crate) struct PayloadRef {
impl PayloadRef { impl PayloadRef {
pub(crate) fn new(payload: PayloadBuffer) -> PayloadRef { pub(crate) fn new(payload: PayloadBuffer) -> PayloadRef {
PayloadRef { PayloadRef {
payload: Rc::new(payload.into()), payload: Rc::new(RefCell::new(payload)),
} }
} }
@ -44,28 +44,33 @@ impl Clone for PayloadRef {
/// Payload buffer. /// Payload buffer.
pub(crate) struct PayloadBuffer { pub(crate) struct PayloadBuffer {
pub(crate) eof: bool,
pub(crate) buf: BytesMut,
pub(crate) stream: LocalBoxStream<'static, Result<Bytes, PayloadError>>, pub(crate) stream: LocalBoxStream<'static, Result<Bytes, PayloadError>>,
pub(crate) buf: BytesMut,
/// EOF flag. If true, no more payload reads will be attempted.
pub(crate) eof: bool,
} }
impl PayloadBuffer { impl PayloadBuffer {
/// Constructs new `PayloadBuffer` instance. /// Constructs new payload buffer.
pub(crate) fn new<S>(stream: S) -> Self pub(crate) fn new<S>(stream: S) -> Self
where where
S: Stream<Item = Result<Bytes, PayloadError>> + 'static, S: Stream<Item = Result<Bytes, PayloadError>> + 'static,
{ {
PayloadBuffer { PayloadBuffer {
eof: false,
buf: BytesMut::new(),
stream: Box::pin(stream), stream: Box::pin(stream),
buf: BytesMut::with_capacity(1_024), // pre-allocate 1KiB
eof: false,
} }
} }
pub(crate) fn poll_stream(&mut self, cx: &mut Context<'_>) -> Result<(), PayloadError> { pub(crate) fn poll_stream(&mut self, cx: &mut Context<'_>) -> Result<(), PayloadError> {
loop { loop {
match Pin::new(&mut self.stream).poll_next(cx) { match Pin::new(&mut self.stream).poll_next(cx) {
Poll::Ready(Some(Ok(data))) => self.buf.extend_from_slice(&data), Poll::Ready(Some(Ok(data))) => {
self.buf.extend_from_slice(&data);
// try to read more data
continue;
}
Poll::Ready(Some(Err(err))) => return Err(err), Poll::Ready(Some(Err(err))) => return Err(err),
Poll::Ready(None) => { Poll::Ready(None) => {
self.eof = true; self.eof = true;
@ -76,7 +81,7 @@ impl PayloadBuffer {
} }
} }
/// Read exact number of bytes. /// Reads exact number of bytes.
#[cfg(test)] #[cfg(test)]
pub(crate) fn read_exact(&mut self, size: usize) -> Option<Bytes> { pub(crate) fn read_exact(&mut self, size: usize) -> Option<Bytes> {
if size <= self.buf.len() { if size <= self.buf.len() {
@ -86,46 +91,57 @@ impl PayloadBuffer {
} }
} }
pub(crate) fn read_max(&mut self, size: u64) -> Result<Option<Bytes>, MultipartError> { pub(crate) fn read_max(&mut self, size: u64) -> Result<Option<Bytes>, Error> {
if !self.buf.is_empty() { if !self.buf.is_empty() {
let size = cmp::min(self.buf.len() as u64, size) as usize; let size = cmp::min(self.buf.len() as u64, size) as usize;
Ok(Some(self.buf.split_to(size).freeze())) Ok(Some(self.buf.split_to(size).freeze()))
} else if self.eof { } else if self.eof {
Err(MultipartError::Incomplete) Err(Error::Incomplete)
} else { } else {
Ok(None) Ok(None)
} }
} }
/// Read until specified ending. /// Reads until specified ending.
pub(crate) fn read_until(&mut self, line: &[u8]) -> Result<Option<Bytes>, MultipartError> { ///
let res = memchr::memmem::find(&self.buf, line) /// Returns:
.map(|idx| self.buf.split_to(idx + line.len()).freeze()); ///
/// - `Ok(Some(chunk))` - `needle` is found, with chunk ending after needle
/// - `Err(Incomplete)` - `needle` is not found and we're at EOF
/// - `Ok(None)` - `needle` is not found otherwise
pub(crate) fn read_until(&mut self, needle: &[u8]) -> Result<Option<Bytes>, Error> {
match memchr::memmem::find(&self.buf, needle) {
// buffer exhausted and EOF without finding needle
None if self.eof => Err(Error::Incomplete),
if res.is_none() && self.eof { // needle not yet found
Err(MultipartError::Incomplete) None => Ok(None),
} else {
Ok(res) // needle found, split chunk out of buf
Some(idx) => Ok(Some(self.buf.split_to(idx + needle.len()).freeze())),
} }
} }
/// Read bytes until new line delimiter. /// Reads bytes until new line delimiter.
pub(crate) fn readline(&mut self) -> Result<Option<Bytes>, MultipartError> { #[inline]
pub(crate) fn readline(&mut self) -> Result<Option<Bytes>, Error> {
self.read_until(b"\n") self.read_until(b"\n")
} }
/// Read bytes until new line delimiter or EOF. /// Reads bytes until new line delimiter or until EOF.
pub(crate) fn readline_or_eof(&mut self) -> Result<Option<Bytes>, MultipartError> { #[inline]
pub(crate) fn readline_or_eof(&mut self) -> Result<Option<Bytes>, Error> {
match self.readline() { match self.readline() {
Err(MultipartError::Incomplete) if self.eof => Ok(Some(self.buf.split().freeze())), Err(Error::Incomplete) if self.eof => Ok(Some(self.buf.split().freeze())),
line => line, line => line,
} }
} }
/// Put unprocessed data back to the buffer. /// Puts unprocessed data back to the buffer.
pub(crate) fn unprocessed(&mut self, data: Bytes) { pub(crate) fn unprocessed(&mut self, data: Bytes) {
let buf = BytesMut::from(data.as_ref()); // TODO: use BytesMut::from when it's released, see https://github.com/tokio-rs/bytes/pull/710
let buf = std::mem::replace(&mut self.buf, buf); let buf = BytesMut::from(&data[..]);
let buf = mem::replace(&mut self.buf, buf);
self.buf.extend_from_slice(&buf); self.buf.extend_from_slice(&buf);
} }
} }

View file

@ -18,7 +18,7 @@ use futures_core::stream::Stream;
use mime::Mime; use mime::Mime;
use crate::{ use crate::{
error::MultipartError, error::Error,
field::InnerField, field::InnerField,
payload::{PayloadBuffer, PayloadRef}, payload::{PayloadBuffer, PayloadRef},
safety::Safety, safety::Safety,
@ -33,9 +33,15 @@ const MAX_HEADERS: usize = 32;
/// implementation. `MultipartItem::Field` contains multipart field. `MultipartItem::Multipart` is /// implementation. `MultipartItem::Field` contains multipart field. `MultipartItem::Multipart` is
/// used for nested multipart streams. /// used for nested multipart streams.
pub struct Multipart { pub struct Multipart {
flow: Flow,
safety: Safety, safety: Safety,
inner: Option<Inner>, }
error: Option<MultipartError>,
enum Flow {
InFlight(Inner),
/// Error container is Some until an error is returned out of the flow.
Error(Option<Error>),
} }
impl Multipart { impl Multipart {
@ -59,24 +65,22 @@ impl Multipart {
} }
/// Extract Content-Type and boundary info from headers. /// Extract Content-Type and boundary info from headers.
pub(crate) fn find_ct_and_boundary( pub(crate) fn find_ct_and_boundary(headers: &HeaderMap) -> Result<(Mime, String), Error> {
headers: &HeaderMap,
) -> Result<(Mime, String), MultipartError> {
let content_type = headers let content_type = headers
.get(&header::CONTENT_TYPE) .get(&header::CONTENT_TYPE)
.ok_or(MultipartError::ContentTypeMissing)? .ok_or(Error::ContentTypeMissing)?
.to_str() .to_str()
.ok() .ok()
.and_then(|content_type| content_type.parse::<Mime>().ok()) .and_then(|content_type| content_type.parse::<Mime>().ok())
.ok_or(MultipartError::ContentTypeParse)?; .ok_or(Error::ContentTypeParse)?;
if content_type.type_() != mime::MULTIPART { if content_type.type_() != mime::MULTIPART {
return Err(MultipartError::ContentTypeIncompatible); return Err(Error::ContentTypeIncompatible);
} }
let boundary = content_type let boundary = content_type
.get_param(mime::BOUNDARY) .get_param(mime::BOUNDARY)
.ok_or(MultipartError::BoundaryMissing)? .ok_or(Error::BoundaryMissing)?
.as_str() .as_str()
.to_owned(); .to_owned();
@ -90,64 +94,57 @@ impl Multipart {
{ {
Multipart { Multipart {
safety: Safety::new(), safety: Safety::new(),
inner: Some(Inner { flow: Flow::InFlight(Inner {
payload: PayloadRef::new(PayloadBuffer::new(stream)), payload: PayloadRef::new(PayloadBuffer::new(stream)),
content_type: ct, content_type: ct,
boundary, boundary,
state: State::FirstBoundary, state: State::FirstBoundary,
item: Item::None, item: Item::None,
}), }),
error: None,
} }
} }
/// Constructs a new multipart reader from given `MultipartError`. /// Constructs a new multipart reader from given `MultipartError`.
pub(crate) fn from_error(err: MultipartError) -> Multipart { pub(crate) fn from_error(err: Error) -> Multipart {
Multipart { Multipart {
error: Some(err), flow: Flow::Error(Some(err)),
safety: Safety::new(), safety: Safety::new(),
inner: None,
} }
} }
/// Return requests parsed Content-Type or raise the stored error. /// Return requests parsed Content-Type or raise the stored error.
pub(crate) fn content_type_or_bail(&mut self) -> Result<mime::Mime, MultipartError> { pub(crate) fn content_type_or_bail(&mut self) -> Result<mime::Mime, Error> {
if let Some(err) = self.error.take() { match self.flow {
return Err(err); Flow::InFlight(ref inner) => Ok(inner.content_type.clone()),
Flow::Error(ref mut err) => Err(err
.take()
.expect("error should not be taken after it was returned")),
} }
Ok(self
.inner
.as_ref()
// TODO: look into using enum instead of two options
.expect("multipart requests should have state")
.content_type
.clone())
} }
} }
impl Stream for Multipart { impl Stream for Multipart {
type Item = Result<Field, MultipartError>; type Item = Result<Field, Error>;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> { fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let this = self.get_mut(); let this = self.get_mut();
match this.inner.as_mut() { match this.flow {
Some(inner) => { Flow::InFlight(ref mut inner) => {
if let Some(mut buffer) = inner.payload.get_mut(&this.safety) { if let Some(mut buffer) = inner.payload.get_mut(&this.safety) {
// check safety and poll read payload to buffer. // check safety and poll read payload to buffer.
buffer.poll_stream(cx)?; buffer.poll_stream(cx)?;
} else if !this.safety.is_clean() { } else if !this.safety.is_clean() {
// safety violation // safety violation
return Poll::Ready(Some(Err(MultipartError::NotConsumed))); return Poll::Ready(Some(Err(Error::NotConsumed)));
} else { } else {
return Poll::Pending; return Poll::Pending;
} }
inner.poll(&this.safety, cx) inner.poll(&this.safety, cx)
} }
None => Poll::Ready(Some(Err(this
.error Flow::Error(ref mut err) => Poll::Ready(Some(Err(err
.take() .take()
.expect("Multipart polled after finish")))), .expect("Multipart polled after finish")))),
} }
@ -191,22 +188,21 @@ struct Inner {
} }
impl Inner { impl Inner {
fn read_field_headers( fn read_field_headers(payload: &mut PayloadBuffer) -> Result<Option<HeaderMap>, Error> {
payload: &mut PayloadBuffer,
) -> Result<Option<HeaderMap>, MultipartError> {
match payload.read_until(b"\r\n\r\n")? { match payload.read_until(b"\r\n\r\n")? {
None => { None => {
if payload.eof { if payload.eof {
Err(MultipartError::Incomplete) Err(Error::Incomplete)
} else { } else {
Ok(None) Ok(None)
} }
} }
Some(bytes) => { Some(bytes) => {
let mut hdrs = [httparse::EMPTY_HEADER; MAX_HEADERS]; let mut hdrs = [httparse::EMPTY_HEADER; MAX_HEADERS];
match httparse::parse_headers(&bytes, &mut hdrs) { match httparse::parse_headers(&bytes, &mut hdrs).map_err(ParseError::from)? {
Ok(httparse::Status::Complete((_, hdrs))) => { httparse::Status::Complete((_, hdrs)) => {
// convert headers // convert headers
let mut headers = HeaderMap::with_capacity(hdrs.len()); let mut headers = HeaderMap::with_capacity(hdrs.len());
@ -220,57 +216,84 @@ impl Inner {
Ok(Some(headers)) Ok(Some(headers))
} }
Ok(httparse::Status::Partial) => Err(ParseError::Header.into()),
Err(err) => Err(ParseError::from(err).into()), httparse::Status::Partial => Err(ParseError::Header.into()),
} }
} }
} }
} }
fn read_boundary( /// Reads a field boundary from the payload buffer (and discards it).
payload: &mut PayloadBuffer, ///
boundary: &str, /// Reads "in-between" and "final" boundaries. E.g. for boundary = "foo":
) -> Result<Option<bool>, MultipartError> { ///
/// ```plain
/// --foo <-- in-between fields
/// --foo-- <-- end of request body, should be followed by EOF
/// ```
///
/// Returns:
///
/// - `Ok(Some(true))` - final field boundary read (EOF)
/// - `Ok(Some(false))` - field boundary read
/// - `Ok(None)` - boundary not found, more data needs reading
/// - `Err(BoundaryMissing)` - multipart boundary is missing
fn read_boundary(payload: &mut PayloadBuffer, boundary: &str) -> Result<Option<bool>, Error> {
// TODO: need to read epilogue // TODO: need to read epilogue
match payload.readline_or_eof()? { let chunk = match payload.readline_or_eof()? {
None => { // TODO: this might be okay as a let Some() else return Ok(None)
if payload.eof { None => return Ok(payload.eof.then_some(true)),
Ok(Some(true)) Some(chunk) => chunk,
} else { };
Ok(None)
} const BOUNDARY_MARKER: &[u8] = b"--";
} const LINE_BREAK: &[u8] = b"\r\n";
Some(chunk) => {
if chunk.len() < boundary.len() + 4 let boundary_len = boundary.len();
|| &chunk[..2] != b"--"
|| &chunk[2..boundary.len() + 2] != boundary.as_bytes() if chunk.len() < boundary_len + 2 + 2
|| !chunk.starts_with(BOUNDARY_MARKER)
|| &chunk[2..boundary_len + 2] != boundary.as_bytes()
{ {
Err(MultipartError::BoundaryMissing) return Err(Error::BoundaryMissing);
} else if &chunk[boundary.len() + 2..] == b"\r\n" { }
Ok(Some(false))
} else if &chunk[boundary.len() + 2..boundary.len() + 4] == b"--" // chunk facts:
&& (chunk.len() == boundary.len() + 4 // - long enough to contain boundary + 2 markers or 1 marker and line-break
|| &chunk[boundary.len() + 4..] == b"\r\n") // - starts with boundary marker
// - chunk contains correct boundary
if &chunk[boundary_len + 2..] == LINE_BREAK {
// boundary is followed by line-break, indicating more fields to come
return Ok(Some(false));
}
// boundary is followed by marker
if &chunk[boundary_len + 2..boundary_len + 4] == BOUNDARY_MARKER
&& (
// chunk is exactly boundary len + 2 markers
chunk.len() == boundary_len + 2 + 2
// final boundary is allowed to end with a line-break
|| &chunk[boundary_len + 4..] == LINE_BREAK
)
{ {
Ok(Some(true)) return Ok(Some(true));
} else {
Err(MultipartError::BoundaryMissing)
}
}
} }
Err(Error::BoundaryMissing)
} }
fn skip_until_boundary( fn skip_until_boundary(
payload: &mut PayloadBuffer, payload: &mut PayloadBuffer,
boundary: &str, boundary: &str,
) -> Result<Option<bool>, MultipartError> { ) -> Result<Option<bool>, Error> {
let mut eof = false; let mut eof = false;
loop { loop {
match payload.readline()? { match payload.readline()? {
Some(chunk) => { Some(chunk) => {
if chunk.is_empty() { if chunk.is_empty() {
return Err(MultipartError::BoundaryMissing); return Err(Error::BoundaryMissing);
} }
if chunk.len() < boundary.len() { if chunk.len() < boundary.len() {
continue; continue;
@ -292,7 +315,7 @@ impl Inner {
} }
None => { None => {
return if payload.eof { return if payload.eof {
Err(MultipartError::Incomplete) Err(Error::Incomplete)
} else { } else {
Ok(None) Ok(None)
}; };
@ -302,11 +325,7 @@ impl Inner {
Ok(Some(eof)) Ok(Some(eof))
} }
fn poll( fn poll(&mut self, safety: &Safety, cx: &Context<'_>) -> Poll<Option<Result<Field, Error>>> {
&mut self,
safety: &Safety,
cx: &Context<'_>,
) -> Poll<Option<Result<Field, MultipartError>>> {
if self.state == State::Eof { if self.state == State::Eof {
Poll::Ready(None) Poll::Ready(None)
} else { } else {
@ -338,6 +357,7 @@ impl Inner {
// read until first boundary // read until first boundary
State::FirstBoundary => { State::FirstBoundary => {
match Inner::skip_until_boundary(&mut payload, &self.boundary)? { match Inner::skip_until_boundary(&mut payload, &self.boundary)? {
None => return Poll::Pending,
Some(eof) => { Some(eof) => {
if eof { if eof {
self.state = State::Eof; self.state = State::Eof;
@ -346,7 +366,6 @@ impl Inner {
self.state = State::Headers; self.state = State::Headers;
} }
} }
None => return Poll::Pending,
} }
} }
@ -398,11 +417,11 @@ impl Inner {
// type must be set as "form-data", and it must have a name parameter. // type must be set as "form-data", and it must have a name parameter.
let Some(cd) = &field_content_disposition else { let Some(cd) = &field_content_disposition else {
return Poll::Ready(Some(Err(MultipartError::ContentDispositionMissing))); return Poll::Ready(Some(Err(Error::ContentDispositionMissing)));
}; };
let Some(field_name) = cd.get_name() else { let Some(field_name) = cd.get_name() else {
return Poll::Ready(Some(Err(MultipartError::ContentDispositionNameMissing))); return Poll::Ready(Some(Err(Error::ContentDispositionNameMissing)));
}; };
Some(field_name.to_owned()) Some(field_name.to_owned())
@ -422,7 +441,7 @@ impl Inner {
// nested multipart stream is not supported // nested multipart stream is not supported
if let Some(mime) = &field_content_type { if let Some(mime) = &field_content_type {
if mime.type_() == mime::MULTIPART { if mime.type_() == mime::MULTIPART {
return Poll::Ready(Some(Err(MultipartError::Nested))); return Poll::Ready(Some(Err(Error::Nested)));
} }
} }
@ -475,7 +494,7 @@ mod tests {
async fn test_boundary() { async fn test_boundary() {
let headers = HeaderMap::new(); let headers = HeaderMap::new();
match Multipart::find_ct_and_boundary(&headers) { match Multipart::find_ct_and_boundary(&headers) {
Err(MultipartError::ContentTypeMissing) => {} Err(Error::ContentTypeMissing) => {}
_ => unreachable!("should not happen"), _ => unreachable!("should not happen"),
} }
@ -486,7 +505,7 @@ mod tests {
); );
match Multipart::find_ct_and_boundary(&headers) { match Multipart::find_ct_and_boundary(&headers) {
Err(MultipartError::ContentTypeParse) => {} Err(Error::ContentTypeParse) => {}
_ => unreachable!("should not happen"), _ => unreachable!("should not happen"),
} }
@ -496,7 +515,7 @@ mod tests {
header::HeaderValue::from_static("multipart/mixed"), header::HeaderValue::from_static("multipart/mixed"),
); );
match Multipart::find_ct_and_boundary(&headers) { match Multipart::find_ct_and_boundary(&headers) {
Err(MultipartError::BoundaryMissing) => {} Err(Error::BoundaryMissing) => {}
_ => unreachable!("should not happen"), _ => unreachable!("should not happen"),
} }
@ -831,7 +850,7 @@ mod tests {
#[actix_rt::test] #[actix_rt::test]
async fn test_multipart_from_error() { async fn test_multipart_from_error() {
let err = MultipartError::ContentTypeMissing; let err = Error::ContentTypeMissing;
let mut multipart = Multipart::from_error(err); let mut multipart = Multipart::from_error(err);
assert!(multipart.next().await.unwrap().is_err()) assert!(multipart.next().await.unwrap().is_err())
} }
@ -888,7 +907,7 @@ mod tests {
res.expect_err( res.expect_err(
"according to RFC 7578, form-data fields require a content-disposition header" "according to RFC 7578, form-data fields require a content-disposition header"
), ),
MultipartError::ContentDispositionMissing Error::ContentDispositionMissing
); );
} }
@ -942,7 +961,7 @@ mod tests {
let res = multipart.next().await.unwrap(); let res = multipart.next().await.unwrap();
assert_matches!( assert_matches!(
res.expect_err("according to RFC 7578, form-data fields require a name attribute"), res.expect_err("according to RFC 7578, form-data fields require a name attribute"),
MultipartError::ContentDispositionNameMissing Error::ContentDispositionNameMissing
); );
} }
@ -960,7 +979,7 @@ mod tests {
// should fail immediately // should fail immediately
match field.next().await { match field.next().await {
Some(Err(MultipartError::NotConsumed)) => {} Some(Err(Error::NotConsumed)) => {}
_ => panic!(), _ => panic!(),
}; };
} }

View file

@ -25,8 +25,7 @@ const BOUNDARY_PREFIX: &str = "------------------------";
/// ///
/// ``` /// ```
/// use actix_multipart::test::create_form_data_payload_and_headers; /// use actix_multipart::test::create_form_data_payload_and_headers;
/// use actix_web::test::TestRequest; /// use actix_web::{test::TestRequest, web::Bytes};
/// use bytes::Bytes;
/// use memchr::memmem::find; /// use memchr::memmem::find;
/// ///
/// let (body, headers) = create_form_data_payload_and_headers( /// let (body, headers) = create_form_data_payload_and_headers(