diff --git a/actix-multipart/src/form/mod.rs b/actix-multipart/src/form/mod.rs index 451b103fd..6fbdfa1a1 100644 --- a/actix-multipart/src/form/mod.rs +++ b/actix-multipart/src/form/mod.rs @@ -33,6 +33,14 @@ pub trait FieldReader<'t>: Sized + Any { type Future: Future>; /// The form will call this function to handle the field. + /// + /// # Panics + /// + /// When reading the `field` payload using its `Stream` implementation, polling (manually or via + /// `next()`/`try_next()`) may panic after the payload is exhausted. If this is a problem for + /// your implementation of this method, you should [`fuse()`] the `Field` first. + /// + /// [`fuse()`]: https://docs.rs/futures-util/0.3/futures_util/stream/trait.StreamExt.html#method.fuse fn read_field(req: &'t HttpRequest, field: Field, limits: &'t mut Limits) -> Self::Future; } @@ -396,11 +404,20 @@ mod tests { use actix_http::encoding::Decoder; use actix_multipart_rfc7578::client::multipart; use actix_test::TestServer; - use actix_web::{dev::Payload, http::StatusCode, web, App, HttpResponse, Responder}; + use actix_web::{ + dev::Payload, http::StatusCode, web, App, HttpRequest, HttpResponse, Resource, Responder, + }; use awc::{Client, ClientResponse}; + use futures_core::future::LocalBoxFuture; + use futures_util::TryStreamExt as _; use super::MultipartForm; - use crate::form::{bytes::Bytes, tempfile::TempFile, text::Text, MultipartFormConfig}; + use crate::{ + form::{ + bytes::Bytes, tempfile::TempFile, text::Text, FieldReader, Limits, MultipartFormConfig, + }, + Field, MultipartError, + }; pub async fn send_form( srv: &TestServer, @@ -734,4 +751,49 @@ mod tests { let response = send_form(&srv, form, "/").await; assert_eq!(response.status(), StatusCode::BAD_REQUEST); } + + #[should_panic(expected = "called `Result::unwrap()` on an `Err` value: Connect(Disconnected)")] + #[actix_web::test] + async fn field_try_next_panic() { + #[derive(Debug)] + struct NullSink; + + impl<'t> FieldReader<'t> for NullSink { + type Future = LocalBoxFuture<'t, Result>; + + fn read_field( + _: &'t HttpRequest, + mut field: Field, + _limits: &'t mut Limits, + ) -> Self::Future { + Box::pin(async move { + // exhaust field stream + while let Some(_chunk) = field.try_next().await? {} + + // poll again, crash + let _post = field.try_next().await; + + Ok(Self) + }) + } + } + + #[allow(dead_code)] + #[derive(MultipartForm)] + struct NullSinkForm { + foo: NullSink, + } + + async fn null_sink(_form: MultipartForm) -> impl Responder { + "unreachable" + } + + let srv = actix_test::start(|| App::new().service(Resource::new("/").post(null_sink))); + + let mut form = multipart::Form::default(); + form.add_text("foo", "data is not important to this test"); + + // panics with Err(Connect(Disconnected)) due to form NullSink panic + let _res = send_form(&srv, form, "/").await; + } } diff --git a/actix-multipart/src/server.rs b/actix-multipart/src/server.rs index d0f833318..0256aa7bf 100644 --- a/actix-multipart/src/server.rs +++ b/actix-multipart/src/server.rs @@ -465,7 +465,12 @@ impl Stream for Field { fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let this = self.get_mut(); let mut inner = this.inner.borrow_mut(); - if let Some(mut buffer) = inner.payload.as_ref().unwrap().get_mut(&this.safety) { + if let Some(mut buffer) = inner + .payload + .as_ref() + .expect("Field should not be polled after completion") + .get_mut(&this.safety) + { // check safety and poll read payload to buffer. buffer.poll_stream(cx)?; } else if !this.safety.is_clean() { @@ -496,6 +501,7 @@ impl fmt::Debug for Field { } struct InnerField { + /// Payload is initialized as Some and is `take`n when the field stream finishes. payload: Option, boundary: String, eof: bool, @@ -643,7 +649,12 @@ impl InnerField { return Poll::Ready(None); } - let result = if let Some(mut payload) = self.payload.as_ref().unwrap().get_mut(s) { + let result = if let Some(mut payload) = self + .payload + .as_ref() + .expect("Field should not be polled after completion") + .get_mut(s) + { if !self.eof { let res = if let Some(ref mut len) = self.length { InnerField::read_len(&mut payload, len) @@ -674,8 +685,10 @@ impl InnerField { }; if let Poll::Ready(None) = result { - self.payload.take(); + // drop payload buffer and make future un-poll-able + let _ = self.payload.take(); } + result } }