diff --git a/Cargo.toml b/Cargo.toml index 26b5b91b2..65e3c6ae8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -5,6 +5,7 @@ members = [ "actix-http-test", "actix-http", "actix-multipart", + "actix-multipart-derive", "actix-router", "actix-test", "actix-web-actors", @@ -27,6 +28,7 @@ actix-files = { path = "actix-files" } actix-http = { path = "actix-http" } actix-http-test = { path = "actix-http-test" } actix-multipart = { path = "actix-multipart" } +actix-multipart-derive = { path = "actix-multipart-derive" } actix-router = { path = "actix-router" } actix-test = { path = "actix-test" } actix-web = { path = "actix-web" } diff --git a/actix-multipart-derive/Cargo.toml b/actix-multipart-derive/Cargo.toml new file mode 100644 index 000000000..4a30898b4 --- /dev/null +++ b/actix-multipart-derive/Cargo.toml @@ -0,0 +1,26 @@ +[package] +name = "actix-multipart-derive" +version = "0.5.0" +authors = ["Jacob Halsey "] +description = "Multipart form derive macro for Actix Web" +keywords = ["http", "web", "framework", "async", "futures"] +homepage = "https://actix.rs" +repository = "https://github.com/actix/actix-web.git" +license = "MIT OR Apache-2.0" +edition = "2018" + +[lib] +proc-macro = true + +[dependencies] +darling = "0.14" +parse-size = "1" +proc-macro2 = "1" +quote = "1" +syn = "1" + +[dev-dependencies] +actix-multipart = "0.5" +actix-web = "4" +rustversion = "1" +trybuild = "1" diff --git a/actix-multipart-derive/LICENSE-APACHE b/actix-multipart-derive/LICENSE-APACHE new file mode 120000 index 000000000..965b606f3 --- /dev/null +++ b/actix-multipart-derive/LICENSE-APACHE @@ -0,0 +1 @@ +../LICENSE-APACHE \ No newline at end of file diff --git a/actix-multipart-derive/LICENSE-MIT b/actix-multipart-derive/LICENSE-MIT new file mode 120000 index 000000000..76219eb72 --- /dev/null +++ b/actix-multipart-derive/LICENSE-MIT @@ -0,0 +1 @@ +../LICENSE-MIT \ No newline at end of file diff --git a/actix-multipart-derive/README.md b/actix-multipart-derive/README.md new file mode 
100644 index 000000000..95f80bc79 --- /dev/null +++ b/actix-multipart-derive/README.md @@ -0,0 +1,3 @@ +# actix-multipart-derive + +> The derive macro implementation for actix-multipart. diff --git a/actix-multipart-derive/src/lib.rs b/actix-multipart-derive/src/lib.rs new file mode 100644 index 000000000..9b6ecbae6 --- /dev/null +++ b/actix-multipart-derive/src/lib.rs @@ -0,0 +1,315 @@ +//! Multipart form derive macro for Actix Web. +//! +//! See [`macro@MultipartForm`] for usage examples. + +#![deny(rust_2018_idioms, nonstandard_style)] +#![warn(future_incompatible)] +#![doc(html_logo_url = "https://actix.rs/img/logo.png")] +#![doc(html_favicon_url = "https://actix.rs/favicon.ico")] +#![cfg_attr(docsrs, feature(doc_cfg))] + +use std::{collections::HashSet, convert::TryFrom as _}; + +use darling::{FromDeriveInput, FromField, FromMeta}; +use parse_size::parse_size; +use proc_macro::TokenStream; +use proc_macro2::Ident; +use quote::quote; +use syn::{parse_macro_input, Type}; + +#[derive(FromMeta)] +enum DuplicateField { + Ignore, + Deny, + Replace, +} + +impl Default for DuplicateField { + fn default() -> Self { + Self::Ignore + } +} + +#[derive(FromDeriveInput, Default)] +#[darling(attributes(multipart), default)] +struct MultipartFormAttrs { + deny_unknown_fields: bool, + duplicate_field: DuplicateField, +} + +#[derive(FromField, Default)] +#[darling(attributes(multipart), default)] +struct FieldAttrs { + rename: Option, + limit: Option, +} + +struct ParsedField<'t> { + serialization_name: String, + rust_name: &'t Ident, + limit: Option, + ty: &'t Type, +} + +/// Implements `MultipartCollect` for a struct so that it can be used with the `MultipartForm` +/// extractor. 
+/// +/// # Basic Use +/// +/// Each field type should implement the `FieldReader` trait: +/// +/// ``` +/// use actix_multipart::form::{tempfile::TempFile, text::Text, MultipartForm}; +/// +/// #[derive(MultipartForm)] +/// struct ImageUpload { +/// description: Text<String>, +/// timestamp: Text<i64>, +/// image: TempFile, +/// } +/// ``` +/// +/// # Optional and List Fields +/// +/// You can also use `Vec<T>` and `Option<T>` provided that `T: FieldReader`. +/// +/// A [`Vec`] field corresponds to an upload with multiple parts under the [same field +/// name](https://www.rfc-editor.org/rfc/rfc7578#section-4.3). +/// +/// ``` +/// use actix_multipart::form::{tempfile::TempFile, text::Text, MultipartForm}; +/// +/// #[derive(MultipartForm)] +/// struct Form { +/// category: Option<Text<String>>, +/// files: Vec<TempFile>, +/// } +/// ``` +/// +/// # Field Renaming +/// +/// You can use the `#[multipart(rename = "foo")]` attribute to receive a field by a different name. +/// +/// ``` +/// use actix_multipart::form::{tempfile::TempFile, MultipartForm}; +/// +/// #[derive(MultipartForm)] +/// struct Form { +/// #[multipart(rename = "files[]")] +/// files: Vec<TempFile>, +/// } +/// ``` +/// +/// # Field Limits +/// +/// You can use the `#[multipart(limit = "<size>")]` attribute to set field level limits. The limit +/// string is parsed using [parse_size]. +/// +/// Note: the form is also subject to the global limits configured using `MultipartFormConfig`. +/// +/// ``` +/// use actix_multipart::form::{tempfile::TempFile, text::Text, MultipartForm}; +/// +/// #[derive(MultipartForm)] +/// struct Form { +/// #[multipart(limit = "2 KiB")] +/// description: Text<String>, +/// +/// #[multipart(limit = "512 MiB")] +/// files: Vec<TempFile>, +/// } +/// ``` +/// +/// # Unknown Fields +/// +/// By default fields with an unknown name are ignored. 
They can be rejected using the +/// `#[multipart(deny_unknown_fields)]` attribute: +/// +/// ``` +/// # use actix_multipart::form::MultipartForm; +/// #[derive(MultipartForm)] +/// #[multipart(deny_unknown_fields)] +/// struct Form { } +/// ``` +/// +/// # Duplicate Fields +/// +/// The behaviour for when multiple fields with the same name are received can be changed using the +/// `#[multipart(duplicate_field = "")]` attribute: +/// +/// - "ignore": (default) Extra fields are ignored. I.e., the first one is persisted. +/// - "deny": A `MultipartError::UnsupportedField` error response is returned. +/// - "replace": Each field is processed, but only the last one is persisted. +/// +/// Note that `Vec` fields will ignore this option. +/// +/// ``` +/// # use actix_multipart::form::MultipartForm; +/// #[derive(MultipartForm)] +/// #[multipart(duplicate_field = "deny")] +/// struct Form { } +/// ``` +/// +/// [parse_size]: https://docs.rs/parse-size/1/parse_size +#[proc_macro_derive(MultipartForm, attributes(multipart))] +pub fn impl_multipart_form(input: proc_macro::TokenStream) -> proc_macro::TokenStream { + let input: syn::DeriveInput = parse_macro_input!(input); + + let name = &input.ident; + + let data_struct = match &input.data { + syn::Data::Struct(data_struct) => data_struct, + _ => { + return compile_err(syn::Error::new( + input.ident.span(), + "`MultipartForm` can only be derived for structs", + )) + } + }; + + let fields = match &data_struct.fields { + syn::Fields::Named(fields_named) => fields_named, + _ => { + return compile_err(syn::Error::new( + input.ident.span(), + "`MultipartForm` can only be derived for a struct with named fields", + )) + } + }; + + let attrs = match MultipartFormAttrs::from_derive_input(&input) { + Ok(attrs) => attrs, + Err(err) => return err.write_errors().into(), + }; + + // Parse the field attributes + let parsed = match fields + .named + .iter() + .map(|field| { + let rust_name = field.ident.as_ref().unwrap(); + let attrs = 
FieldAttrs::from_field(field).map_err(|err| err.write_errors())?; + let serialization_name = attrs.rename.unwrap_or_else(|| rust_name.to_string()); + + let limit = match attrs.limit.map(|limit| match parse_size(&limit) { + Ok(size) => Ok(usize::try_from(size).unwrap()), + Err(err) => Err(syn::Error::new( + field.ident.as_ref().unwrap().span(), + format!("Could not parse size limit `{}`: {}", limit, err), + )), + }) { + Some(Err(err)) => return Err(compile_err(err)), + limit => limit.map(Result::unwrap), + }; + + Ok(ParsedField { + serialization_name, + rust_name, + limit, + ty: &field.ty, + }) + }) + .collect::, TokenStream>>() + { + Ok(attrs) => attrs, + Err(err) => return err, + }; + + // Check that field names are unique + let mut set = HashSet::new(); + for field in &parsed { + if !set.insert(field.serialization_name.clone()) { + return compile_err(syn::Error::new( + field.rust_name.span(), + format!("Multiple fields named: `{}`", field.serialization_name), + )); + } + } + + // Return value when a field name is not supported by the form + let unknown_field_result = if attrs.deny_unknown_fields { + quote!(::std::result::Result::Err( + ::actix_multipart::MultipartError::UnsupportedField(field.name().to_string()) + )) + } else { + quote!(::std::result::Result::Ok(())) + }; + + // Value for duplicate action + let duplicate_field = match attrs.duplicate_field { + DuplicateField::Ignore => quote!(::actix_multipart::form::DuplicateField::Ignore), + DuplicateField::Deny => quote!(::actix_multipart::form::DuplicateField::Deny), + DuplicateField::Replace => quote!(::actix_multipart::form::DuplicateField::Replace), + }; + + // limit() implementation + let mut limit_impl = quote!(); + for field in &parsed { + let name = &field.serialization_name; + if let Some(value) = field.limit { + limit_impl.extend(quote!( + #name => ::std::option::Option::Some(#value), + )); + } + } + + // handle_field() implementation + let mut handle_field_impl = quote!(); + for field in &parsed { + 
let name = &field.serialization_name; + let ty = &field.ty; + + handle_field_impl.extend(quote!( + #name => ::std::boxed::Box::pin( + <#ty as ::actix_multipart::form::FieldGroupReader>::handle_field(req, field, limits, state, #duplicate_field) + ), + )); + } + + // from_state() implementation + let mut from_state_impl = quote!(); + for field in &parsed { + let name = &field.serialization_name; + let rust_name = &field.rust_name; + let ty = &field.ty; + from_state_impl.extend(quote!( + #rust_name: <#ty as ::actix_multipart::form::FieldGroupReader>::from_state(#name, &mut state)?, + )); + } + + let gen = quote! { + impl ::actix_multipart::form::MultipartCollect for #name { + fn limit(field_name: &str) -> ::std::option::Option { + match field_name { + #limit_impl + _ => None, + } + } + + fn handle_field<'t>( + req: &'t ::actix_web::HttpRequest, + field: ::actix_multipart::Field, + limits: &'t mut ::actix_multipart::form::Limits, + state: &'t mut ::actix_multipart::form::State, + ) -> ::std::pin::Pin<::std::boxed::Box> + 't>> { + match field.name() { + #handle_field_impl + _ => return ::std::boxed::Box::pin(::std::future::ready(#unknown_field_result)), + } + } + + fn from_state(mut state: ::actix_multipart::form::State) -> ::std::result::Result { + Ok(Self { + #from_state_impl + }) + } + + } + }; + gen.into() +} + +/// Transform a syn error into a token stream for returning. 
+fn compile_err(err: syn::Error) -> TokenStream { + TokenStream::from(err.to_compile_error()) +} diff --git a/actix-multipart-derive/tests/trybuild.rs b/actix-multipart-derive/tests/trybuild.rs new file mode 100644 index 000000000..7b9f14ed7 --- /dev/null +++ b/actix-multipart-derive/tests/trybuild.rs @@ -0,0 +1,16 @@ +#[rustversion::stable(1.59)] // MSRV +#[test] +fn compile_macros() { + let t = trybuild::TestCases::new(); + + t.pass("tests/trybuild/all-required.rs"); + t.pass("tests/trybuild/optional-and-list.rs"); + t.pass("tests/trybuild/rename.rs"); + t.pass("tests/trybuild/deny-unknown.rs"); + + t.pass("tests/trybuild/deny-duplicates.rs"); + t.compile_fail("tests/trybuild/deny-parse-fail.rs"); + + t.pass("tests/trybuild/size-limits.rs"); + t.compile_fail("tests/trybuild/size-limit-parse-fail.rs"); +} diff --git a/actix-multipart-derive/tests/trybuild/all-required.rs b/actix-multipart-derive/tests/trybuild/all-required.rs new file mode 100644 index 000000000..1b4a824d9 --- /dev/null +++ b/actix-multipart-derive/tests/trybuild/all-required.rs @@ -0,0 +1,19 @@ +use actix_web::{web, App, Responder}; + +use actix_multipart::form::{tempfile::TempFile, text::Text, MultipartForm}; + +#[derive(Debug, MultipartForm)] +struct ImageUpload { + description: Text, + timestamp: Text, + image: TempFile, +} + +async fn handler(_form: MultipartForm) -> impl Responder { + "Hello World!" +} + +#[actix_web::main] +async fn main() { + App::new().default_service(web::to(handler)); +} diff --git a/actix-multipart-derive/tests/trybuild/deny-duplicates.rs b/actix-multipart-derive/tests/trybuild/deny-duplicates.rs new file mode 100644 index 000000000..9fcc1506c --- /dev/null +++ b/actix-multipart-derive/tests/trybuild/deny-duplicates.rs @@ -0,0 +1,16 @@ +use actix_web::{web, App, Responder}; + +use actix_multipart::form::MultipartForm; + +#[derive(MultipartForm)] +#[multipart(duplicate_field = "deny")] +struct Form {} + +async fn handler(_form: MultipartForm
) -> impl Responder { + "Hello World!" +} + +#[actix_web::main] +async fn main() { + App::new().default_service(web::to(handler)); +} diff --git a/actix-multipart-derive/tests/trybuild/deny-parse-fail.rs b/actix-multipart-derive/tests/trybuild/deny-parse-fail.rs new file mode 100644 index 000000000..5ea566fb0 --- /dev/null +++ b/actix-multipart-derive/tests/trybuild/deny-parse-fail.rs @@ -0,0 +1,7 @@ +use actix_multipart::form::MultipartForm; + +#[derive(MultipartForm)] +#[multipart(duplicate_field = "no")] +struct Form {} + +fn main() {} diff --git a/actix-multipart-derive/tests/trybuild/deny-parse-fail.stderr b/actix-multipart-derive/tests/trybuild/deny-parse-fail.stderr new file mode 100644 index 000000000..d25e43525 --- /dev/null +++ b/actix-multipart-derive/tests/trybuild/deny-parse-fail.stderr @@ -0,0 +1,5 @@ +error: Unknown literal value `no` + --> tests/trybuild/deny-parse-fail.rs:4:31 + | +4 | #[multipart(duplicate_field = "no")] + | ^^^^ diff --git a/actix-multipart-derive/tests/trybuild/deny-unknown.rs b/actix-multipart-derive/tests/trybuild/deny-unknown.rs new file mode 100644 index 000000000..e03460624 --- /dev/null +++ b/actix-multipart-derive/tests/trybuild/deny-unknown.rs @@ -0,0 +1,16 @@ +use actix_web::{web, App, Responder}; + +use actix_multipart::form::MultipartForm; + +#[derive(MultipartForm)] +#[multipart(deny_unknown_fields)] +struct Form {} + +async fn handler(_form: MultipartForm) -> impl Responder { + "Hello World!" 
+} + +#[actix_web::main] +async fn main() { + App::new().default_service(web::to(handler)); +} diff --git a/actix-multipart-derive/tests/trybuild/optional-and-list.rs b/actix-multipart-derive/tests/trybuild/optional-and-list.rs new file mode 100644 index 000000000..deef3de59 --- /dev/null +++ b/actix-multipart-derive/tests/trybuild/optional-and-list.rs @@ -0,0 +1,18 @@ +use actix_web::{web, App, Responder}; + +use actix_multipart::form::{tempfile::TempFile, text::Text, MultipartForm}; + +#[derive(MultipartForm)] +struct Form { + category: Option>, + files: Vec, +} + +async fn handler(_form: MultipartForm) -> impl Responder { + "Hello World!" +} + +#[actix_web::main] +async fn main() { + App::new().default_service(web::to(handler)); +} diff --git a/actix-multipart-derive/tests/trybuild/rename.rs b/actix-multipart-derive/tests/trybuild/rename.rs new file mode 100644 index 000000000..1f66bbb43 --- /dev/null +++ b/actix-multipart-derive/tests/trybuild/rename.rs @@ -0,0 +1,18 @@ +use actix_web::{web, App, Responder}; + +use actix_multipart::form::{tempfile::TempFile, MultipartForm}; + +#[derive(MultipartForm)] +struct Form { + #[multipart(rename = "files[]")] + files: Vec, +} + +async fn handler(_form: MultipartForm) -> impl Responder { + "Hello World!" 
+} + +#[actix_web::main] +async fn main() { + App::new().default_service(web::to(handler)); +} diff --git a/actix-multipart-derive/tests/trybuild/size-limit-parse-fail.rs b/actix-multipart-derive/tests/trybuild/size-limit-parse-fail.rs new file mode 100644 index 000000000..c3d495317 --- /dev/null +++ b/actix-multipart-derive/tests/trybuild/size-limit-parse-fail.rs @@ -0,0 +1,21 @@ +use actix_multipart::form::{text::Text, MultipartForm}; + +#[derive(MultipartForm)] +struct Form { + #[multipart(limit = "2 bytes")] + description: Text, +} + +#[derive(MultipartForm)] +struct Form2 { + #[multipart(limit = "2 megabytes")] + description: Text, +} + +#[derive(MultipartForm)] +struct Form3 { + #[multipart(limit = "four meters")] + description: Text, +} + +fn main() {} diff --git a/actix-multipart-derive/tests/trybuild/size-limit-parse-fail.stderr b/actix-multipart-derive/tests/trybuild/size-limit-parse-fail.stderr new file mode 100644 index 000000000..fc02a78c4 --- /dev/null +++ b/actix-multipart-derive/tests/trybuild/size-limit-parse-fail.stderr @@ -0,0 +1,17 @@ +error: Could not parse size limit `2 bytes`: invalid digit found in string + --> tests/trybuild/size-limit-parse-fail.rs:6:5 + | +6 | description: Text, + | ^^^^^^^^^^^ + +error: Could not parse size limit `2 megabytes`: invalid digit found in string + --> tests/trybuild/size-limit-parse-fail.rs:12:5 + | +12 | description: Text, + | ^^^^^^^^^^^ + +error: Could not parse size limit `four meters`: invalid digit found in string + --> tests/trybuild/size-limit-parse-fail.rs:18:5 + | +18 | description: Text, + | ^^^^^^^^^^^ diff --git a/actix-multipart-derive/tests/trybuild/size-limits.rs b/actix-multipart-derive/tests/trybuild/size-limits.rs new file mode 100644 index 000000000..92c3d0db5 --- /dev/null +++ b/actix-multipart-derive/tests/trybuild/size-limits.rs @@ -0,0 +1,21 @@ +use actix_web::{web, App, Responder}; + +use actix_multipart::form::{tempfile::TempFile, text::Text, MultipartForm}; + 
+#[derive(MultipartForm)] +struct Form { + #[multipart(limit = "2 KiB")] + description: Text, + + #[multipart(limit = "512 MiB")] + files: Vec, +} + +async fn handler(_form: MultipartForm) -> impl Responder { + "Hello World!" +} + +#[actix_web::main] +async fn main() { + App::new().default_service(web::to(handler)); +} diff --git a/actix-multipart/CHANGES.md b/actix-multipart/CHANGES.md index ed117d7d3..53134db4f 100644 --- a/actix-multipart/CHANGES.md +++ b/actix-multipart/CHANGES.md @@ -1,11 +1,14 @@ # Changes ## Unreleased - 2022-xx-xx +- Added `MultipartForm` typed data extractor. [#2883] + +[#2883]: https://github.com/actix/actix-web/pull/2883 ## 0.5.0 - 2023-01-21 +- `Field::content_type()` now returns `Option<&mime::Mime>`. [#2885] - Minimum supported Rust version (MSRV) is now 1.59 due to transitive `time` dependency. -- `Field::content_type()` now returns `Option<&mime::Mime>` [#2885] [#2885]: https://github.com/actix/actix-web/pull/2885 diff --git a/actix-multipart/Cargo.toml b/actix-multipart/Cargo.toml index 8f768563c..2a14be007 100644 --- a/actix-multipart/Cargo.toml +++ b/actix-multipart/Cargo.toml @@ -1,7 +1,10 @@ [package] name = "actix-multipart" version = "0.5.0" -authors = ["Nikolay Kim "] +authors = [ + "Nikolay Kim ", + "Jacob Halsey ", +] description = "Multipart form support for Actix Web" keywords = ["http", "web", "framework", "async", "futures"] homepage = "https://actix.rs" @@ -9,26 +12,46 @@ repository = "https://github.com/actix/actix-web.git" license = "MIT OR Apache-2.0" edition = "2018" +[package.metadata.docs.rs] +rustdoc-args = ["--cfg", "docsrs"] +all-features = true + +[features] +default = ["tempfile", "derive"] +derive = ["actix-multipart-derive"] +tempfile = ["tempfile-dep", "tokio/fs"] + [lib] name = "actix_multipart" path = "src/lib.rs" [dependencies] +actix-multipart-derive = { version = "=0.5.0", optional = true } actix-utils = "3" actix-web = { version = "4", default-features = false } bytes = "1" derive_more = "0.99.5" 
futures-core = { version = "0.3.17", default-features = false, features = ["alloc"] } +futures-util = { version = "0.3.17", default-features = false, features = ["alloc"] } httparse = "1.3" local-waker = "0.1" log = "0.4" -mime = "0.3" memchr = "2.5" +mime = "0.3" +serde = "1" +serde_json = "1" +serde_plain = "1" +# TODO(MSRV 1.60): replace with dep: prefix +tempfile-dep = { package = "tempfile", version = "3.4", optional = true } +tokio = { version = "1.18.5", features = ["sync"] } [dev-dependencies] -actix-rt = "2.2" actix-http = "3" +actix-multipart-rfc7578 = "0.10" +actix-rt = "2.2" +actix-test = "0.1" +awc = "3" futures-util = { version = "0.3.17", default-features = false, features = ["alloc"] } tokio = { version = "1.18.5", features = ["sync"] } tokio-stream = "0.1" diff --git a/actix-multipart/src/error.rs b/actix-multipart/src/error.rs index 7d0da35e0..77b5a559f 100644 --- a/actix-multipart/src/error.rs +++ b/actix-multipart/src/error.rs @@ -1,12 +1,15 @@ //! Error and Result module -use actix_web::error::{ParseError, PayloadError}; -use actix_web::http::StatusCode; -use actix_web::ResponseError; + +use actix_web::{ + error::{ParseError, PayloadError}, + http::StatusCode, + ResponseError, +}; use derive_more::{Display, Error, From}; -/// A set of errors that can occur during parsing multipart streams -#[non_exhaustive] +/// A set of errors that can occur during parsing multipart streams. #[derive(Debug, Display, From, Error)] +#[non_exhaustive] pub enum MultipartError { /// Content-Disposition header is not found or is not equal to "form-data". 
/// @@ -46,12 +49,41 @@ pub enum MultipartError { /// Not consumed #[display(fmt = "Multipart stream is not consumed")] NotConsumed, + + /// An error from a field handler in a form + #[display( + fmt = "An error occurred processing field `{}`: {}", + field_name, + source + )] + Field { + field_name: String, + source: actix_web::Error, + }, + + /// Duplicate field + #[display(fmt = "Duplicate field found for: `{}`", _0)] + #[from(ignore)] + DuplicateField(#[error(not(source))] String), + + /// Missing field + #[display(fmt = "Field with name `{}` is required", _0)] + #[from(ignore)] + MissingField(#[error(not(source))] String), + + /// Unknown field + #[display(fmt = "Unsupported field `{}`", _0)] + #[from(ignore)] + UnsupportedField(#[error(not(source))] String), } /// Return `BadRequest` for `MultipartError` impl ResponseError for MultipartError { fn status_code(&self) -> StatusCode { - StatusCode::BAD_REQUEST + match &self { + MultipartError::Field { source, .. } => source.as_response_error().status_code(), + _ => StatusCode::BAD_REQUEST, + } } } diff --git a/actix-multipart/src/extractor.rs b/actix-multipart/src/extractor.rs index d45c4869c..56ed69ae4 100644 --- a/actix-multipart/src/extractor.rs +++ b/actix-multipart/src/extractor.rs @@ -9,8 +9,7 @@ use crate::server::Multipart; /// /// Content-type: multipart/form-data; /// -/// ## Server example -/// +/// # Examples /// ``` /// use actix_web::{web, HttpResponse, Error}; /// use actix_multipart::Multipart; diff --git a/actix-multipart/src/form/bytes.rs b/actix-multipart/src/form/bytes.rs new file mode 100644 index 000000000..7d64fffce --- /dev/null +++ b/actix-multipart/src/form/bytes.rs @@ -0,0 +1,53 @@ +//! Reads a field into memory. + +use actix_web::HttpRequest; +use bytes::BytesMut; +use futures_core::future::LocalBoxFuture; +use futures_util::TryStreamExt as _; +use mime::Mime; + +use crate::{ + form::{FieldReader, Limits}, + Field, MultipartError, +}; + +/// Read the field into memory. 
+#[derive(Debug)] +pub struct Bytes { + /// The data. + pub data: bytes::Bytes, + + /// The value of the `Content-Type` header. + pub content_type: Option, + + /// The `filename` value in the `Content-Disposition` header. + pub file_name: Option, +} + +impl<'t> FieldReader<'t> for Bytes { + type Future = LocalBoxFuture<'t, Result>; + + fn read_field( + _: &'t HttpRequest, + mut field: Field, + limits: &'t mut Limits, + ) -> Self::Future { + Box::pin(async move { + let mut buf = BytesMut::with_capacity(131_072); + + while let Some(chunk) = field.try_next().await? { + limits.try_consume_limits(chunk.len(), true)?; + buf.extend(chunk); + } + + Ok(Bytes { + data: buf.freeze(), + content_type: field.content_type().map(ToOwned::to_owned), + file_name: field + .content_disposition() + .get_filename() + .map(str::to_owned), + }) + }) + } +} diff --git a/actix-multipart/src/form/json.rs b/actix-multipart/src/form/json.rs new file mode 100644 index 000000000..9951eaaaf --- /dev/null +++ b/actix-multipart/src/form/json.rs @@ -0,0 +1,195 @@ +//! Deserializes a field as JSON. + +use std::sync::Arc; + +use actix_web::{http::StatusCode, web, Error, HttpRequest, ResponseError}; +use derive_more::{Deref, DerefMut, Display, Error}; +use futures_core::future::LocalBoxFuture; +use serde::de::DeserializeOwned; + +use crate::{ + form::{bytes::Bytes, FieldReader, Limits}, + Field, MultipartError, +}; + +use super::FieldErrorHandler; + +/// Deserialize from JSON. 
+#[derive(Debug, Deref, DerefMut)] +pub struct Json(pub T); + +impl Json { + pub fn into_inner(self) -> T { + self.0 + } +} + +impl<'t, T> FieldReader<'t> for Json +where + T: DeserializeOwned + 'static, +{ + type Future = LocalBoxFuture<'t, Result>; + + fn read_field(req: &'t HttpRequest, field: Field, limits: &'t mut Limits) -> Self::Future { + Box::pin(async move { + let config = JsonConfig::from_req(req); + let field_name = field.name().to_owned(); + + if config.validate_content_type { + let valid = if let Some(mime) = field.content_type() { + mime.subtype() == mime::JSON || mime.suffix() == Some(mime::JSON) + } else { + false + }; + + if !valid { + return Err(MultipartError::Field { + field_name, + source: config.map_error(req, JsonFieldError::ContentType), + }); + } + } + + let bytes = Bytes::read_field(req, field, limits).await?; + + Ok(Json(serde_json::from_slice(bytes.data.as_ref()).map_err( + |err| MultipartError::Field { + field_name, + source: config.map_error(req, JsonFieldError::Deserialize(err)), + }, + )?)) + }) + } +} + +#[derive(Debug, Display, Error)] +#[non_exhaustive] +pub enum JsonFieldError { + /// Deserialize error. + #[display(fmt = "Json deserialize error: {}", _0)] + Deserialize(serde_json::Error), + + /// Content type error. + #[display(fmt = "Content type error")] + ContentType, +} + +impl ResponseError for JsonFieldError { + fn status_code(&self) -> StatusCode { + StatusCode::BAD_REQUEST + } +} + +/// Configuration for the [`Json`] field reader. +#[derive(Clone)] +pub struct JsonConfig { + err_handler: FieldErrorHandler, + validate_content_type: bool, +} + +const DEFAULT_CONFIG: JsonConfig = JsonConfig { + err_handler: None, + validate_content_type: true, +}; + +impl JsonConfig { + pub fn error_handler(mut self, f: F) -> Self + where + F: Fn(JsonFieldError, &HttpRequest) -> Error + Send + Sync + 'static, + { + self.err_handler = Some(Arc::new(f)); + self + } + + /// Extract payload config from app data. 
Check both `T` and `Data`, in that order, and fall + /// back to the default payload config. + fn from_req(req: &HttpRequest) -> &Self { + req.app_data::() + .or_else(|| req.app_data::>().map(|d| d.as_ref())) + .unwrap_or(&DEFAULT_CONFIG) + } + + fn map_error(&self, req: &HttpRequest, err: JsonFieldError) -> Error { + if let Some(err_handler) = self.err_handler.as_ref() { + (*err_handler)(err, req) + } else { + err.into() + } + } + + /// Sets whether or not the field must have a valid `Content-Type` header to be parsed. + pub fn validate_content_type(mut self, validate_content_type: bool) -> Self { + self.validate_content_type = validate_content_type; + self + } +} + +impl Default for JsonConfig { + fn default() -> Self { + DEFAULT_CONFIG + } +} + +#[cfg(test)] +mod tests { + use std::{collections::HashMap, io::Cursor}; + + use actix_multipart_rfc7578::client::multipart; + use actix_web::{http::StatusCode, web, App, HttpResponse, Responder}; + + use crate::form::{ + json::{Json, JsonConfig}, + tests::send_form, + MultipartForm, + }; + + #[derive(MultipartForm)] + struct JsonForm { + json: Json>, + } + + async fn test_json_route(form: MultipartForm) -> impl Responder { + let mut expected = HashMap::new(); + expected.insert("key1".to_owned(), "value1".to_owned()); + expected.insert("key2".to_owned(), "value2".to_owned()); + assert_eq!(&*form.json, &expected); + HttpResponse::Ok().finish() + } + + #[actix_rt::test] + async fn test_json_without_content_type() { + let srv = actix_test::start(|| { + App::new() + .route("/", web::post().to(test_json_route)) + .app_data(JsonConfig::default().validate_content_type(false)) + }); + + let mut form = multipart::Form::default(); + form.add_text("json", "{\"key1\": \"value1\", \"key2\": \"value2\"}"); + let response = send_form(&srv, form, "/").await; + assert_eq!(response.status(), StatusCode::OK); + } + + #[actix_rt::test] + async fn test_content_type_validation() { + let srv = actix_test::start(|| { + App::new() + .route("/", 
web::post().to(test_json_route)) + .app_data(JsonConfig::default().validate_content_type(true)) + }); + + // Deny because wrong content type + let bytes = Cursor::new("{\"key1\": \"value1\", \"key2\": \"value2\"}"); + let mut form = multipart::Form::default(); + form.add_reader_file_with_mime("json", bytes, "", mime::APPLICATION_OCTET_STREAM); + let response = send_form(&srv, form, "/").await; + assert_eq!(response.status(), StatusCode::BAD_REQUEST); + + // Allow because correct content type + let bytes = Cursor::new("{\"key1\": \"value1\", \"key2\": \"value2\"}"); + let mut form = multipart::Form::default(); + form.add_reader_file_with_mime("json", bytes, "", mime::APPLICATION_JSON); + let response = send_form(&srv, form, "/").await; + assert_eq!(response.status(), StatusCode::OK); + } +} diff --git a/actix-multipart/src/form/mod.rs b/actix-multipart/src/form/mod.rs new file mode 100644 index 000000000..b0285d97e --- /dev/null +++ b/actix-multipart/src/form/mod.rs @@ -0,0 +1,744 @@ +//! Process and extract typed data from a multipart stream. + +use std::{ + any::Any, + collections::HashMap, + future::{ready, Future}, + sync::Arc, +}; + +use actix_web::{dev, error::PayloadError, web, Error, FromRequest, HttpRequest}; +use derive_more::{Deref, DerefMut}; +use futures_core::future::LocalBoxFuture; +use futures_util::{TryFutureExt as _, TryStreamExt as _}; + +use crate::{Field, Multipart, MultipartError}; + +pub mod bytes; +pub mod json; +#[cfg_attr(docsrs, doc(cfg(feature = "tempfile")))] +#[cfg(feature = "tempfile")] +pub mod tempfile; +pub mod text; + +#[cfg_attr(docsrs, doc(cfg(feature = "derive")))] +#[cfg(feature = "derive")] +pub use actix_multipart_derive::MultipartForm; + +type FieldErrorHandler = Option Error + Send + Sync>>; + +/// Trait that data types to be used in a multipart form struct should implement. +/// +/// It represents an asynchronous handler that processes a multipart field to produce `Self`. 
+pub trait FieldReader<'t>: Sized + Any { + /// Future that resolves to a `Self`. + type Future: Future>; + + /// The form will call this function to handle the field. + fn read_field(req: &'t HttpRequest, field: Field, limits: &'t mut Limits) -> Self::Future; +} + +/// Used to accumulate the state of the loaded fields. +#[doc(hidden)] +#[derive(Default, Deref, DerefMut)] +pub struct State(pub HashMap>); + +/// Trait that the field collection types implement, i.e. `Vec`, `Option`, or `T` itself. +#[doc(hidden)] +pub trait FieldGroupReader<'t>: Sized + Any { + type Future: Future>; + + /// The form will call this function for each matching field. + fn handle_field( + req: &'t HttpRequest, + field: Field, + limits: &'t mut Limits, + state: &'t mut State, + duplicate_field: DuplicateField, + ) -> Self::Future; + + /// Construct `Self` from the group of processed fields. + fn from_state(name: &str, state: &'t mut State) -> Result; +} + +impl<'t, T> FieldGroupReader<'t> for Option +where + T: FieldReader<'t>, +{ + type Future = LocalBoxFuture<'t, Result<(), MultipartError>>; + + fn handle_field( + req: &'t HttpRequest, + field: Field, + limits: &'t mut Limits, + state: &'t mut State, + duplicate_field: DuplicateField, + ) -> Self::Future { + if state.contains_key(field.name()) { + match duplicate_field { + DuplicateField::Ignore => return Box::pin(ready(Ok(()))), + + DuplicateField::Deny => { + return Box::pin(ready(Err(MultipartError::DuplicateField( + field.name().to_owned(), + )))) + } + + DuplicateField::Replace => {} + } + } + + Box::pin(async move { + let field_name = field.name().to_owned(); + let t = T::read_field(req, field, limits).await?; + state.insert(field_name, Box::new(t)); + Ok(()) + }) + } + + fn from_state(name: &str, state: &'t mut State) -> Result { + Ok(state.remove(name).map(|m| *m.downcast::().unwrap())) + } +} + +impl<'t, T> FieldGroupReader<'t> for Vec +where + T: FieldReader<'t>, +{ + type Future = LocalBoxFuture<'t, Result<(), 
MultipartError>>; + + fn handle_field( + req: &'t HttpRequest, + field: Field, + limits: &'t mut Limits, + state: &'t mut State, + _duplicate_field: DuplicateField, + ) -> Self::Future { + Box::pin(async move { + // Note: Vec GroupReader always allows duplicates + + let field_name = field.name().to_owned(); + + let vec = state + .entry(field_name) + .or_insert_with(|| Box::>::default()) + .downcast_mut::>() + .unwrap(); + + let item = T::read_field(req, field, limits).await?; + vec.push(item); + + Ok(()) + }) + } + + fn from_state(name: &str, state: &'t mut State) -> Result { + Ok(state + .remove(name) + .map(|m| *m.downcast::>().unwrap()) + .unwrap_or_default()) + } +} + +impl<'t, T> FieldGroupReader<'t> for T +where + T: FieldReader<'t>, +{ + type Future = LocalBoxFuture<'t, Result<(), MultipartError>>; + + fn handle_field( + req: &'t HttpRequest, + field: Field, + limits: &'t mut Limits, + state: &'t mut State, + duplicate_field: DuplicateField, + ) -> Self::Future { + if state.contains_key(field.name()) { + match duplicate_field { + DuplicateField::Ignore => return Box::pin(ready(Ok(()))), + + DuplicateField::Deny => { + return Box::pin(ready(Err(MultipartError::DuplicateField( + field.name().to_owned(), + )))) + } + + DuplicateField::Replace => {} + } + } + + Box::pin(async move { + let field_name = field.name().to_owned(); + let t = T::read_field(req, field, limits).await?; + state.insert(field_name, Box::new(t)); + Ok(()) + }) + } + + fn from_state(name: &str, state: &'t mut State) -> Result { + state + .remove(name) + .map(|m| *m.downcast::().unwrap()) + .ok_or_else(|| MultipartError::MissingField(name.to_owned())) + } +} + +/// Trait that allows a type to be used in the [`struct@MultipartForm`] extractor. +/// +/// You should use the [`macro@MultipartForm`] macro to derive this for your struct. +pub trait MultipartCollect: Sized { + /// An optional limit in bytes to be applied a given field name. 
Note this limit will be shared + /// across all fields sharing the same name. + fn limit(field_name: &str) -> Option; + + /// The extractor will call this function for each incoming field, the state can be updated + /// with the processed field data. + fn handle_field<'t>( + req: &'t HttpRequest, + field: Field, + limits: &'t mut Limits, + state: &'t mut State, + ) -> LocalBoxFuture<'t, Result<(), MultipartError>>; + + /// Once all the fields have been processed and stored in the state, this is called + /// to convert into the struct representation. + fn from_state(state: State) -> Result; +} + +#[doc(hidden)] +pub enum DuplicateField { + /// Additional fields are not processed. + Ignore, + + /// An error will be raised. + Deny, + + /// All fields will be processed, the last one will replace all previous. + Replace, +} + +/// Used to keep track of the remaining limits for the form and current field. +pub struct Limits { + pub total_limit_remaining: usize, + pub memory_limit_remaining: usize, + pub field_limit_remaining: Option, +} + +impl Limits { + pub fn new(total_limit: usize, memory_limit: usize) -> Self { + Self { + total_limit_remaining: total_limit, + memory_limit_remaining: memory_limit, + field_limit_remaining: None, + } + } + + /// This function should be called within a [`FieldReader`] when reading each chunk of a field + /// to ensure that the form limits are not exceeded. 
+ /// + /// # Arguments + /// + /// * `bytes` - The number of bytes being read from this chunk + /// * `in_memory` - Whether to consume from the memory limits + pub fn try_consume_limits( + &mut self, + bytes: usize, + in_memory: bool, + ) -> Result<(), MultipartError> { + self.total_limit_remaining = self + .total_limit_remaining + .checked_sub(bytes) + .ok_or(MultipartError::Payload(PayloadError::Overflow))?; + + if in_memory { + self.memory_limit_remaining = self + .memory_limit_remaining + .checked_sub(bytes) + .ok_or(MultipartError::Payload(PayloadError::Overflow))?; + } + + if let Some(field_limit) = self.field_limit_remaining { + self.field_limit_remaining = Some( + field_limit + .checked_sub(bytes) + .ok_or(MultipartError::Payload(PayloadError::Overflow))?, + ); + } + + Ok(()) + } +} + +/// Typed `multipart/form-data` extractor. +/// +/// To extract typed data from a multipart stream, the inner type `T` must implement the +/// [`MultipartCollect`] trait. You should use the [`macro@MultipartForm`] macro to derive this +/// for your struct. +/// +/// Add a [`MultipartFormConfig`] to your app data to configure extraction. +#[derive(Deref, DerefMut)] +pub struct MultipartForm(pub T); + +impl MultipartForm { + /// Unwrap into inner `T` value. 
+ pub fn into_inner(self) -> T { + self.0 + } +} + +impl FromRequest for MultipartForm +where + T: MultipartCollect, +{ + type Error = Error; + type Future = LocalBoxFuture<'static, Result>; + + #[inline] + fn from_request(req: &HttpRequest, payload: &mut dev::Payload) -> Self::Future { + let mut payload = Multipart::new(req.headers(), payload.take()); + + let config = MultipartFormConfig::from_req(req); + let mut limits = Limits::new(config.total_limit, config.memory_limit); + + let req = req.clone(); + let req2 = req.clone(); + let err_handler = config.err_handler.clone(); + + Box::pin( + async move { + let mut state = State::default(); + // We need to ensure field limits are shared for all instances of this field name + let mut field_limits = HashMap::>::new(); + + while let Some(field) = payload.try_next().await? { + // Retrieve the limit for this field + let entry = field_limits + .entry(field.name().to_owned()) + .or_insert_with(|| T::limit(field.name())); + limits.field_limit_remaining = entry.to_owned(); + + T::handle_field(&req, field, &mut limits, &mut state).await?; + + // Update the stored limit + *entry = limits.field_limit_remaining; + } + let inner = T::from_state(state)?; + Ok(MultipartForm(inner)) + } + .map_err(move |err| { + if let Some(handler) = err_handler { + (*handler)(err, &req2) + } else { + err.into() + } + }), + ) + } +} + +type MultipartFormErrorHandler = + Option Error + Send + Sync>>; + +/// [`struct@MultipartForm`] extractor configuration. +/// +/// Add to your app data to have it picked up by [`struct@MultipartForm`] extractors. +#[derive(Clone)] +pub struct MultipartFormConfig { + total_limit: usize, + memory_limit: usize, + err_handler: MultipartFormErrorHandler, +} + +impl MultipartFormConfig { + /// Sets maximum accepted payload size for the entire form. By default this limit is 50MiB. 
+ pub fn total_limit(mut self, total_limit: usize) -> Self { + self.total_limit = total_limit; + self + } + + /// Sets maximum accepted data that will be read into memory. By default this limit is 2MiB. + pub fn memory_limit(mut self, memory_limit: usize) -> Self { + self.memory_limit = memory_limit; + self + } + + /// Sets custom error handler. + pub fn error_handler(mut self, f: F) -> Self + where + F: Fn(MultipartError, &HttpRequest) -> Error + Send + Sync + 'static, + { + self.err_handler = Some(Arc::new(f)); + self + } + + /// Extracts payload config from app data. Check both `T` and `Data`, in that order, and fall + /// back to the default payload config. + fn from_req(req: &HttpRequest) -> &Self { + req.app_data::() + .or_else(|| req.app_data::>().map(|d| d.as_ref())) + .unwrap_or(&DEFAULT_CONFIG) + } +} + +const DEFAULT_CONFIG: MultipartFormConfig = MultipartFormConfig { + total_limit: 52_428_800, // 50 MiB + memory_limit: 2_097_152, // 2 MiB + err_handler: None, +}; + +impl Default for MultipartFormConfig { + fn default() -> Self { + DEFAULT_CONFIG + } +} + +#[cfg(test)] +mod tests { + use actix_http::encoding::Decoder; + use actix_multipart_rfc7578::client::multipart; + use actix_test::TestServer; + use actix_web::{dev::Payload, http::StatusCode, web, App, HttpResponse, Responder}; + use awc::{Client, ClientResponse}; + + use super::MultipartForm; + use crate::form::{bytes::Bytes, tempfile::TempFile, text::Text, MultipartFormConfig}; + + pub async fn send_form( + srv: &TestServer, + form: multipart::Form<'static>, + uri: &'static str, + ) -> ClientResponse> { + Client::default() + .post(srv.url(uri)) + .content_type(form.content_type()) + .send_body(multipart::Body::from(form)) + .await + .unwrap() + } + + /// Test `Option` fields. 
+ #[derive(MultipartForm)] + struct TestOptions { + field1: Option>, + field2: Option>, + } + + async fn test_options_route(form: MultipartForm) -> impl Responder { + assert!(form.field1.is_some()); + assert!(form.field2.is_none()); + HttpResponse::Ok().finish() + } + + #[actix_rt::test] + async fn test_options() { + let srv = + actix_test::start(|| App::new().route("/", web::post().to(test_options_route))); + + let mut form = multipart::Form::default(); + form.add_text("field1", "value"); + + let response = send_form(&srv, form, "/").await; + assert_eq!(response.status(), StatusCode::OK); + } + + /// Test `Vec` fields. + #[derive(MultipartForm)] + struct TestVec { + list1: Vec>, + list2: Vec>, + } + + async fn test_vec_route(form: MultipartForm) -> impl Responder { + let form = form.into_inner(); + let strings = form + .list1 + .into_iter() + .map(|s| s.into_inner()) + .collect::>(); + assert_eq!(strings, vec!["value1", "value2", "value3"]); + assert_eq!(form.list2.len(), 0); + HttpResponse::Ok().finish() + } + + #[actix_rt::test] + async fn test_vec() { + let srv = actix_test::start(|| App::new().route("/", web::post().to(test_vec_route))); + + let mut form = multipart::Form::default(); + form.add_text("list1", "value1"); + form.add_text("list1", "value2"); + form.add_text("list1", "value3"); + + let response = send_form(&srv, form, "/").await; + assert_eq!(response.status(), StatusCode::OK); + } + + /// Test the `rename` field attribute. 
+ #[derive(MultipartForm)] + struct TestFieldRenaming { + #[multipart(rename = "renamed")] + field1: Text, + #[multipart(rename = "field1")] + field2: Text, + field3: Text, + } + + async fn test_field_renaming_route( + form: MultipartForm, + ) -> impl Responder { + assert_eq!(&*form.field1, "renamed"); + assert_eq!(&*form.field2, "field1"); + assert_eq!(&*form.field3, "field3"); + HttpResponse::Ok().finish() + } + + #[actix_rt::test] + async fn test_field_renaming() { + let srv = actix_test::start(|| { + App::new().route("/", web::post().to(test_field_renaming_route)) + }); + + let mut form = multipart::Form::default(); + form.add_text("renamed", "renamed"); + form.add_text("field1", "field1"); + form.add_text("field3", "field3"); + + let response = send_form(&srv, form, "/").await; + assert_eq!(response.status(), StatusCode::OK); + } + + /// Test the `deny_unknown_fields` struct attribute. + #[derive(MultipartForm)] + #[multipart(deny_unknown_fields)] + struct TestDenyUnknown {} + + #[derive(MultipartForm)] + struct TestAllowUnknown {} + + async fn test_deny_unknown_route(_: MultipartForm) -> impl Responder { + HttpResponse::Ok().finish() + } + + async fn test_allow_unknown_route(_: MultipartForm) -> impl Responder { + HttpResponse::Ok().finish() + } + + #[actix_rt::test] + async fn test_deny_unknown() { + let srv = actix_test::start(|| { + App::new() + .route("/deny", web::post().to(test_deny_unknown_route)) + .route("/allow", web::post().to(test_allow_unknown_route)) + }); + + let mut form = multipart::Form::default(); + form.add_text("unknown", "value"); + let response = send_form(&srv, form, "/deny").await; + assert_eq!(response.status(), StatusCode::BAD_REQUEST); + + let mut form = multipart::Form::default(); + form.add_text("unknown", "value"); + let response = send_form(&srv, form, "/allow").await; + assert_eq!(response.status(), StatusCode::OK); + } + + /// Test the `duplicate_field` struct attribute. 
+ #[derive(MultipartForm)] + #[multipart(duplicate_field = "deny")] + struct TestDuplicateDeny { + _field: Text, + } + + #[derive(MultipartForm)] + #[multipart(duplicate_field = "replace")] + struct TestDuplicateReplace { + field: Text, + } + + #[derive(MultipartForm)] + #[multipart(duplicate_field = "ignore")] + struct TestDuplicateIgnore { + field: Text, + } + + async fn test_duplicate_deny_route(_: MultipartForm) -> impl Responder { + HttpResponse::Ok().finish() + } + + async fn test_duplicate_replace_route( + form: MultipartForm, + ) -> impl Responder { + assert_eq!(&*form.field, "second_value"); + HttpResponse::Ok().finish() + } + + async fn test_duplicate_ignore_route( + form: MultipartForm, + ) -> impl Responder { + assert_eq!(&*form.field, "first_value"); + HttpResponse::Ok().finish() + } + + #[actix_rt::test] + async fn test_duplicate_field() { + let srv = actix_test::start(|| { + App::new() + .route("/deny", web::post().to(test_duplicate_deny_route)) + .route("/replace", web::post().to(test_duplicate_replace_route)) + .route("/ignore", web::post().to(test_duplicate_ignore_route)) + }); + + let mut form = multipart::Form::default(); + form.add_text("_field", "first_value"); + form.add_text("_field", "second_value"); + let response = send_form(&srv, form, "/deny").await; + assert_eq!(response.status(), StatusCode::BAD_REQUEST); + + let mut form = multipart::Form::default(); + form.add_text("field", "first_value"); + form.add_text("field", "second_value"); + let response = send_form(&srv, form, "/replace").await; + assert_eq!(response.status(), StatusCode::OK); + + let mut form = multipart::Form::default(); + form.add_text("field", "first_value"); + form.add_text("field", "second_value"); + let response = send_form(&srv, form, "/ignore").await; + assert_eq!(response.status(), StatusCode::OK); + } + + /// Test the Limits. 
+ #[derive(MultipartForm)] + struct TestMemoryUploadLimits { + field: Bytes, + } + + #[derive(MultipartForm)] + struct TestFileUploadLimits { + field: TempFile, + } + + async fn test_upload_limits_memory( + form: MultipartForm, + ) -> impl Responder { + assert!(!form.field.data.is_empty()); + HttpResponse::Ok().finish() + } + + async fn test_upload_limits_file( + form: MultipartForm, + ) -> impl Responder { + assert!(form.field.size > 0); + HttpResponse::Ok().finish() + } + + #[actix_rt::test] + async fn test_memory_limits() { + let srv = actix_test::start(|| { + App::new() + .route("/text", web::post().to(test_upload_limits_memory)) + .route("/file", web::post().to(test_upload_limits_file)) + .app_data( + MultipartFormConfig::default() + .memory_limit(20) + .total_limit(usize::MAX), + ) + }); + + // Exceeds the 20 byte memory limit + let mut form = multipart::Form::default(); + form.add_text("field", "this string is 28 bytes long"); + let response = send_form(&srv, form, "/text").await; + assert_eq!(response.status(), StatusCode::BAD_REQUEST); + + // Memory limit should not apply when the data is being streamed to disk + let mut form = multipart::Form::default(); + form.add_text("field", "this string is 28 bytes long"); + let response = send_form(&srv, form, "/file").await; + assert_eq!(response.status(), StatusCode::OK); + } + + #[actix_rt::test] + async fn test_total_limit() { + let srv = actix_test::start(|| { + App::new() + .route("/text", web::post().to(test_upload_limits_memory)) + .route("/file", web::post().to(test_upload_limits_file)) + .app_data( + MultipartFormConfig::default() + .memory_limit(usize::MAX) + .total_limit(20), + ) + }); + + // Within the 20 byte limit + let mut form = multipart::Form::default(); + form.add_text("field", "7 bytes"); + let response = send_form(&srv, form, "/text").await; + assert_eq!(response.status(), StatusCode::OK); + + // Exceeds the 20 byte overall limit + let mut form = multipart::Form::default(); + 
form.add_text("field", "this string is 28 bytes long"); + let response = send_form(&srv, form, "/text").await; + assert_eq!(response.status(), StatusCode::BAD_REQUEST); + + // Exceeds the 20 byte overall limit + let mut form = multipart::Form::default(); + form.add_text("field", "this string is 28 bytes long"); + let response = send_form(&srv, form, "/file").await; + assert_eq!(response.status(), StatusCode::BAD_REQUEST); + } + + #[derive(MultipartForm)] + struct TestFieldLevelLimits { + #[multipart(limit = "30B")] + field: Vec, + } + + async fn test_field_level_limits_route( + form: MultipartForm, + ) -> impl Responder { + assert!(!form.field.is_empty()); + HttpResponse::Ok().finish() + } + + #[actix_rt::test] + async fn test_field_level_limits() { + let srv = actix_test::start(|| { + App::new() + .route("/", web::post().to(test_field_level_limits_route)) + .app_data( + MultipartFormConfig::default() + .memory_limit(usize::MAX) + .total_limit(usize::MAX), + ) + }); + + // Within the 30 byte limit + let mut form = multipart::Form::default(); + form.add_text("field", "this string is 28 bytes long"); + let response = send_form(&srv, form, "/").await; + assert_eq!(response.status(), StatusCode::OK); + + // Exceeds the the 30 byte limit + let mut form = multipart::Form::default(); + form.add_text("field", "this string is more than 30 bytes long"); + let response = send_form(&srv, form, "/").await; + assert_eq!(response.status(), StatusCode::BAD_REQUEST); + + // Total of values (14 bytes) is within 30 byte limit for "field" + let mut form = multipart::Form::default(); + form.add_text("field", "7 bytes"); + form.add_text("field", "7 bytes"); + let response = send_form(&srv, form, "/").await; + assert_eq!(response.status(), StatusCode::OK); + + // Total of values exceeds 30 byte limit for "field" + let mut form = multipart::Form::default(); + form.add_text("field", "this string is 28 bytes long"); + form.add_text("field", "this string is 28 bytes long"); + let response = 
send_form(&srv, form, "/").await; + assert_eq!(response.status(), StatusCode::BAD_REQUEST); + } +} diff --git a/actix-multipart/src/form/tempfile.rs b/actix-multipart/src/form/tempfile.rs new file mode 100644 index 000000000..3c637e717 --- /dev/null +++ b/actix-multipart/src/form/tempfile.rs @@ -0,0 +1,206 @@ +//! Writes a field to a temporary file on disk. + +use std::{ + io, + path::{Path, PathBuf}, + sync::Arc, +}; + +use actix_web::{http::StatusCode, web, Error, HttpRequest, ResponseError}; +use derive_more::{Display, Error}; +use futures_core::future::LocalBoxFuture; +use futures_util::TryStreamExt as _; +use mime::Mime; +use tempfile_dep::NamedTempFile; +use tokio::io::AsyncWriteExt; + +use super::FieldErrorHandler; +use crate::{ + form::{FieldReader, Limits}, + Field, MultipartError, +}; + +/// Write the field to a temporary file on disk. +#[derive(Debug)] +pub struct TempFile { + /// The temporary file on disk. + pub file: NamedTempFile, + + /// The value of the `content-type` header. + pub content_type: Option, + + /// The `filename` value in the `content-disposition` header. + pub file_name: Option, + + /// The size in bytes of the file. + pub size: usize, +} + +impl<'t> FieldReader<'t> for TempFile { + type Future = LocalBoxFuture<'t, Result>; + + fn read_field( + req: &'t HttpRequest, + mut field: Field, + limits: &'t mut Limits, + ) -> Self::Future { + Box::pin(async move { + let config = TempFileConfig::from_req(req); + let field_name = field.name().to_owned(); + let mut size = 0; + + let file = config.create_tempfile().map_err(|err| { + config.map_error(req, &field_name, TempFileError::FileIo(err)) + })?; + + let mut file_async = tokio::fs::File::from_std(file.reopen().map_err(|err| { + config.map_error(req, &field_name, TempFileError::FileIo(err)) + })?); + + while let Some(chunk) = field.try_next().await? 
{ + limits.try_consume_limits(chunk.len(), false)?; + size += chunk.len(); + file_async.write_all(chunk.as_ref()).await.map_err(|err| { + config.map_error(req, &field_name, TempFileError::FileIo(err)) + })?; + } + + file_async.flush().await.map_err(|err| { + config.map_error(req, &field_name, TempFileError::FileIo(err)) + })?; + + Ok(TempFile { + file, + content_type: field.content_type().map(ToOwned::to_owned), + file_name: field + .content_disposition() + .get_filename() + .map(str::to_owned), + size, + }) + }) + } +} + +#[derive(Debug, Display, Error)] +#[non_exhaustive] +pub enum TempFileError { + /// File I/O Error + #[display(fmt = "File I/O error: {}", _0)] + FileIo(std::io::Error), +} + +impl ResponseError for TempFileError { + fn status_code(&self) -> StatusCode { + StatusCode::INTERNAL_SERVER_ERROR + } +} + +/// Configuration for the [`TempFile`] field reader. +#[derive(Clone)] +pub struct TempFileConfig { + err_handler: FieldErrorHandler, + directory: Option, +} + +impl TempFileConfig { + fn create_tempfile(&self) -> io::Result { + if let Some(ref dir) = self.directory { + NamedTempFile::new_in(dir) + } else { + NamedTempFile::new() + } + } +} + +impl TempFileConfig { + /// Sets custom error handler. + pub fn error_handler(mut self, f: F) -> Self + where + F: Fn(TempFileError, &HttpRequest) -> Error + Send + Sync + 'static, + { + self.err_handler = Some(Arc::new(f)); + self + } + + /// Extracts payload config from app data. Check both `T` and `Data`, in that order, and fall + /// back to the default payload config. 
+ fn from_req(req: &HttpRequest) -> &Self { + req.app_data::() + .or_else(|| req.app_data::>().map(|d| d.as_ref())) + .unwrap_or(&DEFAULT_CONFIG) + } + + fn map_error( + &self, + req: &HttpRequest, + field_name: &str, + err: TempFileError, + ) -> MultipartError { + let source = if let Some(ref err_handler) = self.err_handler { + (err_handler)(err, req) + } else { + err.into() + }; + + MultipartError::Field { + field_name: field_name.to_owned(), + source, + } + } + + /// Sets the directory that temp files will be created in. + /// + /// The default temporary file location is platform dependent. + pub fn directory(mut self, dir: impl AsRef) -> Self { + self.directory = Some(dir.as_ref().to_owned()); + self + } +} + +const DEFAULT_CONFIG: TempFileConfig = TempFileConfig { + err_handler: None, + directory: None, +}; + +impl Default for TempFileConfig { + fn default() -> Self { + DEFAULT_CONFIG + } +} + +#[cfg(test)] +mod tests { + use std::io::{Cursor, Read}; + + use actix_multipart_rfc7578::client::multipart; + use actix_web::{http::StatusCode, web, App, HttpResponse, Responder}; + + use crate::form::{tempfile::TempFile, tests::send_form, MultipartForm}; + + #[derive(MultipartForm)] + struct FileForm { + file: TempFile, + } + + async fn test_file_route(form: MultipartForm) -> impl Responder { + let mut form = form.into_inner(); + let mut contents = String::new(); + form.file.file.read_to_string(&mut contents).unwrap(); + assert_eq!(contents, "Hello, world!"); + assert_eq!(form.file.file_name.unwrap(), "testfile.txt"); + assert_eq!(form.file.content_type.unwrap(), mime::TEXT_PLAIN); + HttpResponse::Ok().finish() + } + + #[actix_rt::test] + async fn test_file_upload() { + let srv = actix_test::start(|| App::new().route("/", web::post().to(test_file_route))); + + let mut form = multipart::Form::default(); + let bytes = Cursor::new("Hello, world!"); + form.add_reader_file_with_mime("file", bytes, "testfile.txt", mime::TEXT_PLAIN); + let response = send_form(&srv, form, 
"/").await; + assert_eq!(response.status(), StatusCode::OK); + } +} diff --git a/actix-multipart/src/form/text.rs b/actix-multipart/src/form/text.rs new file mode 100644 index 000000000..83e211524 --- /dev/null +++ b/actix-multipart/src/form/text.rs @@ -0,0 +1,196 @@ +//! Deserializes a field from plain text. + +use std::{str, sync::Arc}; + +use actix_web::{http::StatusCode, web, Error, HttpRequest, ResponseError}; +use derive_more::{Deref, DerefMut, Display, Error}; +use futures_core::future::LocalBoxFuture; +use serde::de::DeserializeOwned; + +use super::FieldErrorHandler; +use crate::{ + form::{bytes::Bytes, FieldReader, Limits}, + Field, MultipartError, +}; + +/// Deserialize from plain text. +/// +/// Internally this uses [`serde_plain`] for deserialization, which supports primitive types +/// including strings, numbers, and simple enums. +#[derive(Debug, Deref, DerefMut)] +pub struct Text(pub T); + +impl Text { + /// Unwraps into inner value. + pub fn into_inner(self) -> T { + self.0 + } +} + +impl<'t, T> FieldReader<'t> for Text +where + T: DeserializeOwned + 'static, +{ + type Future = LocalBoxFuture<'t, Result>; + + fn read_field(req: &'t HttpRequest, field: Field, limits: &'t mut Limits) -> Self::Future { + Box::pin(async move { + let config = TextConfig::from_req(req); + let field_name = field.name().to_owned(); + + if config.validate_content_type { + let valid = if let Some(mime) = field.content_type() { + mime.subtype() == mime::PLAIN || mime.suffix() == Some(mime::PLAIN) + } else { + // https://datatracker.ietf.org/doc/html/rfc7578#section-4.4 + // content type defaults to text/plain, so None should be considered valid + true + }; + + if !valid { + return Err(MultipartError::Field { + field_name, + source: config.map_error(req, TextError::ContentType), + }); + } + } + + let bytes = Bytes::read_field(req, field, limits).await?; + + let text = str::from_utf8(&bytes.data).map_err(|err| MultipartError::Field { + field_name: field_name.clone(), + source: 
config.map_error(req, TextError::Utf8Error(err)), + })?; + + Ok(Text(serde_plain::from_str(text).map_err(|err| { + MultipartError::Field { + field_name, + source: config.map_error(req, TextError::Deserialize(err)), + } + })?)) + }) + } +} + +#[derive(Debug, Display, Error)] +#[non_exhaustive] +pub enum TextError { + /// UTF-8 decoding error. + #[display(fmt = "UTF-8 decoding error: {}", _0)] + Utf8Error(str::Utf8Error), + + /// Deserialize error. + #[display(fmt = "Plain text deserialize error: {}", _0)] + Deserialize(serde_plain::Error), + + /// Content type error. + #[display(fmt = "Content type error")] + ContentType, +} + +impl ResponseError for TextError { + fn status_code(&self) -> StatusCode { + StatusCode::BAD_REQUEST + } +} + +/// Configuration for the [`Text`] field reader. +#[derive(Clone)] +pub struct TextConfig { + err_handler: FieldErrorHandler, + validate_content_type: bool, +} + +impl TextConfig { + /// Sets custom error handler. + pub fn error_handler(mut self, f: F) -> Self + where + F: Fn(TextError, &HttpRequest) -> Error + Send + Sync + 'static, + { + self.err_handler = Some(Arc::new(f)); + self + } + + /// Extracts payload config from app data. Check both `T` and `Data`, in that order, and fall + /// back to the default payload config. + fn from_req(req: &HttpRequest) -> &Self { + req.app_data::() + .or_else(|| req.app_data::>().map(|d| d.as_ref())) + .unwrap_or(&DEFAULT_CONFIG) + } + + fn map_error(&self, req: &HttpRequest, err: TextError) -> Error { + if let Some(ref err_handler) = self.err_handler { + (err_handler)(err, req) + } else { + err.into() + } + } + + /// Sets whether or not the field must have a valid `Content-Type` header to be parsed. + /// + /// Note that an empty `Content-Type` is also accepted, as the multipart specification defines + /// `text/plain` as the default for text fields. 
+ pub fn validate_content_type(mut self, validate_content_type: bool) -> Self { + self.validate_content_type = validate_content_type; + self + } +} + +const DEFAULT_CONFIG: TextConfig = TextConfig { + err_handler: None, + validate_content_type: true, +}; + +impl Default for TextConfig { + fn default() -> Self { + DEFAULT_CONFIG + } +} + +#[cfg(test)] +mod tests { + use std::io::Cursor; + + use actix_multipart_rfc7578::client::multipart; + use actix_web::{http::StatusCode, web, App, HttpResponse, Responder}; + + use crate::form::{ + tests::send_form, + text::{Text, TextConfig}, + MultipartForm, + }; + + #[derive(MultipartForm)] + struct TextForm { + number: Text, + } + + async fn test_text_route(form: MultipartForm) -> impl Responder { + assert_eq!(*form.number, 1025); + HttpResponse::Ok().finish() + } + + #[actix_rt::test] + async fn test_content_type_validation() { + let srv = actix_test::start(|| { + App::new() + .route("/", web::post().to(test_text_route)) + .app_data(TextConfig::default().validate_content_type(true)) + }); + + // Deny because wrong content type + let bytes = Cursor::new("1025"); + let mut form = multipart::Form::default(); + form.add_reader_file_with_mime("number", bytes, "", mime::APPLICATION_OCTET_STREAM); + let response = send_form(&srv, form, "/").await; + assert_eq!(response.status(), StatusCode::BAD_REQUEST); + + // Allow because correct content type + let bytes = Cursor::new("1025"); + let mut form = multipart::Form::default(); + form.add_reader_file_with_mime("number", bytes, "", mime::TEXT_PLAIN); + let response = send_form(&srv, form, "/").await; + assert_eq!(response.status(), StatusCode::OK); + } +} diff --git a/actix-multipart/src/lib.rs b/actix-multipart/src/lib.rs index 37d03db49..c8fba77d0 100644 --- a/actix-multipart/src/lib.rs +++ b/actix-multipart/src/lib.rs @@ -3,10 +3,17 @@ #![deny(rust_2018_idioms, nonstandard_style)] #![warn(future_incompatible)] #![allow(clippy::borrow_interior_mutable_const, 
clippy::uninlined_format_args)] +#![cfg_attr(docsrs, feature(doc_cfg))] + +// This allows us to use the actix_multipart_derive within this crate's tests +#[cfg(test)] +extern crate self as actix_multipart; mod error; mod extractor; mod server; +pub mod form; + pub use self::error::MultipartError; pub use self::server::{Field, Multipart}; diff --git a/actix-multipart/src/server.rs b/actix-multipart/src/server.rs index 9e0becd5c..6726bc9d3 100644 --- a/actix-multipart/src/server.rs +++ b/actix-multipart/src/server.rs @@ -270,7 +270,9 @@ impl InnerMultipart { match field.borrow_mut().poll(safety) { Poll::Pending => return Poll::Pending, Poll::Ready(Some(Ok(_))) => continue, - Poll::Ready(Some(Err(e))) => return Poll::Ready(Some(Err(e))), + Poll::Ready(Some(Err(err))) => { + return Poll::Ready(Some(Err(err))) + } Poll::Ready(None) => true, } } @@ -658,7 +660,7 @@ impl InnerField { match res { Poll::Pending => return Poll::Pending, Poll::Ready(Some(Ok(bytes))) => return Poll::Ready(Some(Ok(bytes))), - Poll::Ready(Some(Err(e))) => return Poll::Ready(Some(Err(e))), + Poll::Ready(Some(Err(err))) => return Poll::Ready(Some(Err(err))), Poll::Ready(None) => self.eof = true, } } @@ -673,7 +675,7 @@ impl InnerField { } Poll::Ready(None) } - Err(e) => Poll::Ready(Some(Err(e))), + Err(err) => Poll::Ready(Some(Err(err))), } } else { Poll::Pending @@ -794,7 +796,7 @@ impl PayloadBuffer { loop { match Pin::new(&mut self.stream).poll_next(cx) { Poll::Ready(Some(Ok(data))) => self.buf.extend_from_slice(&data), - Poll::Ready(Some(Err(e))) => return Err(e), + Poll::Ready(Some(Err(err))) => return Err(err), Poll::Ready(None) => { self.eof = true; return Ok(()); @@ -860,19 +862,22 @@ impl PayloadBuffer { #[cfg(test)] mod tests { - use super::*; + use std::time::Duration; - use actix_http::h1::Payload; - use actix_web::http::header::{DispositionParam, DispositionType}; - use actix_web::rt; - use actix_web::test::TestRequest; - use actix_web::FromRequest; + use actix_http::h1; + use 
actix_web::{ + http::header::{DispositionParam, DispositionType}, + rt, + test::TestRequest, + FromRequest, + }; use bytes::Bytes; use futures_util::{future::lazy, StreamExt as _}; - use std::time::Duration; use tokio::sync::mpsc; use tokio_stream::wrappers::UnboundedReceiverStream; + use super::*; + #[actix_rt::test] async fn test_boundary() { let headers = HeaderMap::new(); @@ -1119,7 +1124,7 @@ mod tests { #[actix_rt::test] async fn test_basic() { - let (_, payload) = Payload::create(false); + let (_, payload) = h1::Payload::create(false); let mut payload = PayloadBuffer::new(payload); assert_eq!(payload.buf.len(), 0); @@ -1129,7 +1134,7 @@ mod tests { #[actix_rt::test] async fn test_eof() { - let (mut sender, payload) = Payload::create(false); + let (mut sender, payload) = h1::Payload::create(false); let mut payload = PayloadBuffer::new(payload); assert_eq!(None, payload.read_max(4).unwrap()); @@ -1145,7 +1150,7 @@ mod tests { #[actix_rt::test] async fn test_err() { - let (mut sender, payload) = Payload::create(false); + let (mut sender, payload) = h1::Payload::create(false); let mut payload = PayloadBuffer::new(payload); assert_eq!(None, payload.read_max(1).unwrap()); sender.set_error(PayloadError::Incomplete(None)); @@ -1154,7 +1159,7 @@ mod tests { #[actix_rt::test] async fn test_readmax() { - let (mut sender, payload) = Payload::create(false); + let (mut sender, payload) = h1::Payload::create(false); let mut payload = PayloadBuffer::new(payload); sender.feed_data(Bytes::from("line1")); @@ -1171,7 +1176,7 @@ mod tests { #[actix_rt::test] async fn test_readexactly() { - let (mut sender, payload) = Payload::create(false); + let (mut sender, payload) = h1::Payload::create(false); let mut payload = PayloadBuffer::new(payload); assert_eq!(None, payload.read_exact(2)); @@ -1189,7 +1194,7 @@ mod tests { #[actix_rt::test] async fn test_readuntil() { - let (mut sender, payload) = Payload::create(false); + let (mut sender, payload) = h1::Payload::create(false); let 
mut payload = PayloadBuffer::new(payload); assert_eq!(None, payload.read_until(b"ne").unwrap()); @@ -1230,7 +1235,7 @@ mod tests { #[actix_rt::test] async fn test_multipart_payload_consumption() { // with sample payload and HttpRequest with no headers - let (_, inner_payload) = Payload::create(false); + let (_, inner_payload) = h1::Payload::create(false); let mut payload = actix_web::dev::Payload::from(inner_payload); let req = TestRequest::default().to_http_request(); diff --git a/actix-web/Cargo.toml b/actix-web/Cargo.toml index 44755035c..6cb86bbdd 100644 --- a/actix-web/Cargo.toml +++ b/actix-web/Cargo.toml @@ -103,7 +103,7 @@ actix-test = { version = "0.1", features = ["openssl", "rustls"] } awc = { version = "3", features = ["openssl"] } brotli = "3.3.3" -const-str = "0.4" +const-str = "0.3" criterion = { version = "0.4", features = ["html_reports"] } env_logger = "0.9" flate2 = "1.0.13" diff --git a/awc/Cargo.toml b/awc/Cargo.toml index a69a07d67..00c3c87c5 100644 --- a/awc/Cargo.toml +++ b/awc/Cargo.toml @@ -99,7 +99,7 @@ actix-utils = "3" actix-web = { version = "4", features = ["openssl"] } brotli = "3.3.3" -const-str = "0.4" +const-str = "0.3" env_logger = "0.9" flate2 = "1.0.13" futures-util = { version = "0.3.17", default-features = false }