//! Process and extract typed data from a multipart stream. use std::{ any::Any, collections::HashMap, future::{ready, Future}, sync::Arc, }; use actix_web::{dev, error::PayloadError, web, Error, FromRequest, HttpRequest}; use derive_more::{Deref, DerefMut}; use futures_core::future::LocalBoxFuture; use futures_util::{TryFutureExt as _, TryStreamExt as _}; use crate::{Field, Multipart, MultipartError}; pub mod bytes; pub mod json; #[cfg(feature = "tempfile")] pub mod tempfile; pub mod text; #[cfg(feature = "derive")] pub use actix_multipart_derive::MultipartForm; type FieldErrorHandler = Option Error + Send + Sync>>; /// Trait that data types to be used in a multipart form struct should implement. /// /// It represents an asynchronous handler that processes a multipart field to produce `Self`. pub trait FieldReader<'t>: Sized + Any { /// Future that resolves to a `Self`. type Future: Future>; /// The form will call this function to handle the field. fn read_field(req: &'t HttpRequest, field: Field, limits: &'t mut Limits) -> Self::Future; } /// Used to accumulate the state of the loaded fields. #[doc(hidden)] #[derive(Default, Deref, DerefMut)] pub struct State(pub HashMap>); /// Trait that the field collection types implement, i.e. `Vec`, `Option`, or `T` itself. #[doc(hidden)] pub trait FieldGroupReader<'t>: Sized + Any { type Future: Future>; /// The form will call this function for each matching field. fn handle_field( req: &'t HttpRequest, field: Field, limits: &'t mut Limits, state: &'t mut State, duplicate_field: DuplicateField, ) -> Self::Future; /// Construct `Self` from the group of processed fields. fn from_state(name: &str, state: &'t mut State) -> Result; } impl<'t, T> FieldGroupReader<'t> for Option where T: FieldReader<'t>, { type Future = LocalBoxFuture<'t, Result<(), MultipartError>>; fn handle_field( req: &'t HttpRequest, field: Field, limits: &'t mut Limits, state: &'t mut State, duplicate_field: DuplicateField, ) -> Self::Future { if state.contains_key(field.name()) { match duplicate_field { DuplicateField::Ignore => return Box::pin(ready(Ok(()))), DuplicateField::Deny => { return Box::pin(ready(Err(MultipartError::DuplicateField( field.name().to_owned(), )))) } DuplicateField::Replace => {} } } Box::pin(async move { let field_name = field.name().to_owned(); let t = T::read_field(req, field, limits).await?; state.insert(field_name, Box::new(t)); Ok(()) }) } fn from_state(name: &str, state: &'t mut State) -> Result { Ok(state.remove(name).map(|m| *m.downcast::().unwrap())) } } impl<'t, T> FieldGroupReader<'t> for Vec where T: FieldReader<'t>, { type Future = LocalBoxFuture<'t, Result<(), MultipartError>>; fn handle_field( req: &'t HttpRequest, field: Field, limits: &'t mut Limits, state: &'t mut State, _duplicate_field: DuplicateField, ) -> Self::Future { Box::pin(async move { // Note: Vec GroupReader always allows duplicates let field_name = field.name().to_owned(); let vec = state .entry(field_name) .or_insert_with(|| Box::>::default()) .downcast_mut::>() .unwrap(); let item = T::read_field(req, field, limits).await?; vec.push(item); Ok(()) }) } fn from_state(name: &str, state: &'t mut State) -> Result { Ok(state .remove(name) .map(|m| *m.downcast::>().unwrap()) .unwrap_or_default()) } } impl<'t, T> FieldGroupReader<'t> for T where T: FieldReader<'t>, { type Future = LocalBoxFuture<'t, Result<(), MultipartError>>; fn handle_field( req: &'t HttpRequest, field: Field, limits: &'t mut Limits, state: &'t mut State, duplicate_field: DuplicateField, ) -> Self::Future { if state.contains_key(field.name()) { match duplicate_field { DuplicateField::Ignore => return Box::pin(ready(Ok(()))), DuplicateField::Deny => { return Box::pin(ready(Err(MultipartError::DuplicateField( field.name().to_owned(), )))) } DuplicateField::Replace => {} } } Box::pin(async move { let field_name = field.name().to_owned(); let t = T::read_field(req, field, limits).await?; state.insert(field_name, Box::new(t)); Ok(()) }) } fn from_state(name: &str, state: &'t mut State) -> Result { state .remove(name) .map(|m| *m.downcast::().unwrap()) .ok_or_else(|| MultipartError::MissingField(name.to_owned())) } } /// Trait that allows a type to be used in the [`struct@MultipartForm`] extractor. /// /// You should use the [`macro@MultipartForm`] macro to derive this for your struct. pub trait MultipartCollect: Sized { /// An optional limit in bytes to be applied a given field name. Note this limit will be shared /// across all fields sharing the same name. fn limit(field_name: &str) -> Option; /// The extractor will call this function for each incoming field, the state can be updated /// with the processed field data. fn handle_field<'t>( req: &'t HttpRequest, field: Field, limits: &'t mut Limits, state: &'t mut State, ) -> LocalBoxFuture<'t, Result<(), MultipartError>>; /// Once all the fields have been processed and stored in the state, this is called /// to convert into the struct representation. fn from_state(state: State) -> Result; } #[doc(hidden)] pub enum DuplicateField { /// Additional fields are not processed. Ignore, /// An error will be raised. Deny, /// All fields will be processed, the last one will replace all previous. Replace, } /// Used to keep track of the remaining limits for the form and current field. pub struct Limits { pub total_limit_remaining: usize, pub memory_limit_remaining: usize, pub field_limit_remaining: Option, } impl Limits { pub fn new(total_limit: usize, memory_limit: usize) -> Self { Self { total_limit_remaining: total_limit, memory_limit_remaining: memory_limit, field_limit_remaining: None, } } /// This function should be called within a [`FieldReader`] when reading each chunk of a field /// to ensure that the form limits are not exceeded. /// /// # Arguments /// /// * `bytes` - The number of bytes being read from this chunk /// * `in_memory` - Whether to consume from the memory limits pub fn try_consume_limits( &mut self, bytes: usize, in_memory: bool, ) -> Result<(), MultipartError> { self.total_limit_remaining = self .total_limit_remaining .checked_sub(bytes) .ok_or(MultipartError::Payload(PayloadError::Overflow))?; if in_memory { self.memory_limit_remaining = self .memory_limit_remaining .checked_sub(bytes) .ok_or(MultipartError::Payload(PayloadError::Overflow))?; } if let Some(field_limit) = self.field_limit_remaining { self.field_limit_remaining = Some( field_limit .checked_sub(bytes) .ok_or(MultipartError::Payload(PayloadError::Overflow))?, ); } Ok(()) } } /// Typed `multipart/form-data` extractor. /// /// To extract typed data from a multipart stream, the inner type `T` must implement the /// [`MultipartCollect`] trait. You should use the [`macro@MultipartForm`] macro to derive this /// for your struct. /// /// Add a [`MultipartFormConfig`] to your app data to configure extraction. #[derive(Deref, DerefMut)] pub struct MultipartForm(pub T); impl MultipartForm { /// Unwrap into inner `T` value. pub fn into_inner(self) -> T { self.0 } } impl FromRequest for MultipartForm where T: MultipartCollect, { type Error = Error; type Future = LocalBoxFuture<'static, Result>; #[inline] fn from_request(req: &HttpRequest, payload: &mut dev::Payload) -> Self::Future { let mut payload = Multipart::new(req.headers(), payload.take()); let config = MultipartFormConfig::from_req(req); let mut limits = Limits::new(config.total_limit, config.memory_limit); let req = req.clone(); let req2 = req.clone(); let err_handler = config.err_handler.clone(); Box::pin( async move { let mut state = State::default(); // We need to ensure field limits are shared for all instances of this field name let mut field_limits = HashMap::>::new(); while let Some(field) = payload.try_next().await? { // Retrieve the limit for this field let entry = field_limits .entry(field.name().to_owned()) .or_insert_with(|| T::limit(field.name())); limits.field_limit_remaining = entry.to_owned(); T::handle_field(&req, field, &mut limits, &mut state).await?; // Update the stored limit *entry = limits.field_limit_remaining; } let inner = T::from_state(state)?; Ok(MultipartForm(inner)) } .map_err(move |err| { if let Some(handler) = err_handler { (*handler)(err, &req2) } else { err.into() } }), ) } } type MultipartFormErrorHandler = Option Error + Send + Sync>>; /// [`struct@MultipartForm`] extractor configuration. /// /// Add to your app data to have it picked up by [`struct@MultipartForm`] extractors. #[derive(Clone)] pub struct MultipartFormConfig { total_limit: usize, memory_limit: usize, err_handler: MultipartFormErrorHandler, } impl MultipartFormConfig { /// Sets maximum accepted payload size for the entire form. By default this limit is 50MiB. pub fn total_limit(mut self, total_limit: usize) -> Self { self.total_limit = total_limit; self } /// Sets maximum accepted data that will be read into memory. By default this limit is 2MiB. pub fn memory_limit(mut self, memory_limit: usize) -> Self { self.memory_limit = memory_limit; self } /// Sets custom error handler. pub fn error_handler(mut self, f: F) -> Self where F: Fn(MultipartError, &HttpRequest) -> Error + Send + Sync + 'static, { self.err_handler = Some(Arc::new(f)); self } /// Extracts payload config from app data. Check both `T` and `Data`, in that order, and fall /// back to the default payload config. fn from_req(req: &HttpRequest) -> &Self { req.app_data::() .or_else(|| req.app_data::>().map(|d| d.as_ref())) .unwrap_or(&DEFAULT_CONFIG) } } const DEFAULT_CONFIG: MultipartFormConfig = MultipartFormConfig { total_limit: 52_428_800, // 50 MiB memory_limit: 2_097_152, // 2 MiB err_handler: None, }; impl Default for MultipartFormConfig { fn default() -> Self { DEFAULT_CONFIG } } #[cfg(test)] mod tests { use actix_http::encoding::Decoder; use actix_multipart_rfc7578::client::multipart; use actix_test::TestServer; use actix_web::{dev::Payload, http::StatusCode, web, App, HttpResponse, Responder}; use awc::{Client, ClientResponse}; use super::MultipartForm; use crate::form::{bytes::Bytes, tempfile::TempFile, text::Text, MultipartFormConfig}; pub async fn send_form( srv: &TestServer, form: multipart::Form<'static>, uri: &'static str, ) -> ClientResponse> { Client::default() .post(srv.url(uri)) .content_type(form.content_type()) .send_body(multipart::Body::from(form)) .await .unwrap() } /// Test `Option` fields. #[derive(MultipartForm)] struct TestOptions { field1: Option>, field2: Option>, } async fn test_options_route(form: MultipartForm) -> impl Responder { assert!(form.field1.is_some()); assert!(form.field2.is_none()); HttpResponse::Ok().finish() } #[actix_rt::test] async fn test_options() { let srv = actix_test::start(|| App::new().route("/", web::post().to(test_options_route))); let mut form = multipart::Form::default(); form.add_text("field1", "value"); let response = send_form(&srv, form, "/").await; assert_eq!(response.status(), StatusCode::OK); } /// Test `Vec` fields. #[derive(MultipartForm)] struct TestVec { list1: Vec>, list2: Vec>, } async fn test_vec_route(form: MultipartForm) -> impl Responder { let form = form.into_inner(); let strings = form .list1 .into_iter() .map(|s| s.into_inner()) .collect::>(); assert_eq!(strings, vec!["value1", "value2", "value3"]); assert_eq!(form.list2.len(), 0); HttpResponse::Ok().finish() } #[actix_rt::test] async fn test_vec() { let srv = actix_test::start(|| App::new().route("/", web::post().to(test_vec_route))); let mut form = multipart::Form::default(); form.add_text("list1", "value1"); form.add_text("list1", "value2"); form.add_text("list1", "value3"); let response = send_form(&srv, form, "/").await; assert_eq!(response.status(), StatusCode::OK); } /// Test the `rename` field attribute. #[derive(MultipartForm)] struct TestFieldRenaming { #[multipart(rename = "renamed")] field1: Text, #[multipart(rename = "field1")] field2: Text, field3: Text, } async fn test_field_renaming_route(form: MultipartForm) -> impl Responder { assert_eq!(&*form.field1, "renamed"); assert_eq!(&*form.field2, "field1"); assert_eq!(&*form.field3, "field3"); HttpResponse::Ok().finish() } #[actix_rt::test] async fn test_field_renaming() { let srv = actix_test::start(|| App::new().route("/", web::post().to(test_field_renaming_route))); let mut form = multipart::Form::default(); form.add_text("renamed", "renamed"); form.add_text("field1", "field1"); form.add_text("field3", "field3"); let response = send_form(&srv, form, "/").await; assert_eq!(response.status(), StatusCode::OK); } /// Test the `deny_unknown_fields` struct attribute. #[derive(MultipartForm)] #[multipart(deny_unknown_fields)] struct TestDenyUnknown {} #[derive(MultipartForm)] struct TestAllowUnknown {} async fn test_deny_unknown_route(_: MultipartForm) -> impl Responder { HttpResponse::Ok().finish() } async fn test_allow_unknown_route(_: MultipartForm) -> impl Responder { HttpResponse::Ok().finish() } #[actix_rt::test] async fn test_deny_unknown() { let srv = actix_test::start(|| { App::new() .route("/deny", web::post().to(test_deny_unknown_route)) .route("/allow", web::post().to(test_allow_unknown_route)) }); let mut form = multipart::Form::default(); form.add_text("unknown", "value"); let response = send_form(&srv, form, "/deny").await; assert_eq!(response.status(), StatusCode::BAD_REQUEST); let mut form = multipart::Form::default(); form.add_text("unknown", "value"); let response = send_form(&srv, form, "/allow").await; assert_eq!(response.status(), StatusCode::OK); } /// Test the `duplicate_field` struct attribute. #[derive(MultipartForm)] #[multipart(duplicate_field = "deny")] struct TestDuplicateDeny { _field: Text, } #[derive(MultipartForm)] #[multipart(duplicate_field = "replace")] struct TestDuplicateReplace { field: Text, } #[derive(MultipartForm)] #[multipart(duplicate_field = "ignore")] struct TestDuplicateIgnore { field: Text, } async fn test_duplicate_deny_route(_: MultipartForm) -> impl Responder { HttpResponse::Ok().finish() } async fn test_duplicate_replace_route( form: MultipartForm, ) -> impl Responder { assert_eq!(&*form.field, "second_value"); HttpResponse::Ok().finish() } async fn test_duplicate_ignore_route( form: MultipartForm, ) -> impl Responder { assert_eq!(&*form.field, "first_value"); HttpResponse::Ok().finish() } #[actix_rt::test] async fn test_duplicate_field() { let srv = actix_test::start(|| { App::new() .route("/deny", web::post().to(test_duplicate_deny_route)) .route("/replace", web::post().to(test_duplicate_replace_route)) .route("/ignore", web::post().to(test_duplicate_ignore_route)) }); let mut form = multipart::Form::default(); form.add_text("_field", "first_value"); form.add_text("_field", "second_value"); let response = send_form(&srv, form, "/deny").await; assert_eq!(response.status(), StatusCode::BAD_REQUEST); let mut form = multipart::Form::default(); form.add_text("field", "first_value"); form.add_text("field", "second_value"); let response = send_form(&srv, form, "/replace").await; assert_eq!(response.status(), StatusCode::OK); let mut form = multipart::Form::default(); form.add_text("field", "first_value"); form.add_text("field", "second_value"); let response = send_form(&srv, form, "/ignore").await; assert_eq!(response.status(), StatusCode::OK); } /// Test the Limits. #[derive(MultipartForm)] struct TestMemoryUploadLimits { field: Bytes, } #[derive(MultipartForm)] struct TestFileUploadLimits { field: TempFile, } async fn test_upload_limits_memory( form: MultipartForm, ) -> impl Responder { assert!(!form.field.data.is_empty()); HttpResponse::Ok().finish() } async fn test_upload_limits_file(form: MultipartForm) -> impl Responder { assert!(form.field.size > 0); HttpResponse::Ok().finish() } #[actix_rt::test] async fn test_memory_limits() { let srv = actix_test::start(|| { App::new() .route("/text", web::post().to(test_upload_limits_memory)) .route("/file", web::post().to(test_upload_limits_file)) .app_data( MultipartFormConfig::default() .memory_limit(20) .total_limit(usize::MAX), ) }); // Exceeds the 20 byte memory limit let mut form = multipart::Form::default(); form.add_text("field", "this string is 28 bytes long"); let response = send_form(&srv, form, "/text").await; assert_eq!(response.status(), StatusCode::BAD_REQUEST); // Memory limit should not apply when the data is being streamed to disk let mut form = multipart::Form::default(); form.add_text("field", "this string is 28 bytes long"); let response = send_form(&srv, form, "/file").await; assert_eq!(response.status(), StatusCode::OK); } #[actix_rt::test] async fn test_total_limit() { let srv = actix_test::start(|| { App::new() .route("/text", web::post().to(test_upload_limits_memory)) .route("/file", web::post().to(test_upload_limits_file)) .app_data( MultipartFormConfig::default() .memory_limit(usize::MAX) .total_limit(20), ) }); // Within the 20 byte limit let mut form = multipart::Form::default(); form.add_text("field", "7 bytes"); let response = send_form(&srv, form, "/text").await; assert_eq!(response.status(), StatusCode::OK); // Exceeds the 20 byte overall limit let mut form = multipart::Form::default(); form.add_text("field", "this string is 28 bytes long"); let response = send_form(&srv, form, "/text").await; assert_eq!(response.status(), StatusCode::BAD_REQUEST); // Exceeds the 20 byte overall limit let mut form = multipart::Form::default(); form.add_text("field", "this string is 28 bytes long"); let response = send_form(&srv, form, "/file").await; assert_eq!(response.status(), StatusCode::BAD_REQUEST); } #[derive(MultipartForm)] struct TestFieldLevelLimits { #[multipart(limit = "30B")] field: Vec, } async fn test_field_level_limits_route( form: MultipartForm, ) -> impl Responder { assert!(!form.field.is_empty()); HttpResponse::Ok().finish() } #[actix_rt::test] async fn test_field_level_limits() { let srv = actix_test::start(|| { App::new() .route("/", web::post().to(test_field_level_limits_route)) .app_data( MultipartFormConfig::default() .memory_limit(usize::MAX) .total_limit(usize::MAX), ) }); // Within the 30 byte limit let mut form = multipart::Form::default(); form.add_text("field", "this string is 28 bytes long"); let response = send_form(&srv, form, "/").await; assert_eq!(response.status(), StatusCode::OK); // Exceeds the the 30 byte limit let mut form = multipart::Form::default(); form.add_text("field", "this string is more than 30 bytes long"); let response = send_form(&srv, form, "/").await; assert_eq!(response.status(), StatusCode::BAD_REQUEST); // Total of values (14 bytes) is within 30 byte limit for "field" let mut form = multipart::Form::default(); form.add_text("field", "7 bytes"); form.add_text("field", "7 bytes"); let response = send_form(&srv, form, "/").await; assert_eq!(response.status(), StatusCode::OK); // Total of values exceeds 30 byte limit for "field" let mut form = multipart::Form::default(); form.add_text("field", "this string is 28 bytes long"); form.add_text("field", "this string is 28 bytes long"); let response = send_form(&srv, form, "/").await; assert_eq!(response.status(), StatusCode::BAD_REQUEST); } }