1
0
Fork 0
mirror of https://github.com/actix/actix-web.git synced 2024-11-24 18:41:05 +00:00

actix-multipart: Feature: Add typed multipart form extractor (#2883)

Co-authored-by: Rob Ede <robjtede@icloud.com>
This commit is contained in:
Jacob Halsey 2023-02-26 03:26:06 +00:00 committed by GitHub
parent 358c1cf85b
commit d4b833ccf0
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
30 changed files with 2017 additions and 32 deletions

View file

@ -5,6 +5,7 @@ members = [
"actix-http-test",
"actix-http",
"actix-multipart",
"actix-multipart-derive",
"actix-router",
"actix-test",
"actix-web-actors",
@ -27,6 +28,7 @@ actix-files = { path = "actix-files" }
actix-http = { path = "actix-http" }
actix-http-test = { path = "actix-http-test" }
actix-multipart = { path = "actix-multipart" }
actix-multipart-derive = { path = "actix-multipart-derive" }
actix-router = { path = "actix-router" }
actix-test = { path = "actix-test" }
actix-web = { path = "actix-web" }

View file

@ -0,0 +1,26 @@
[package]
name = "actix-multipart-derive"
version = "0.5.0"
authors = ["Jacob Halsey <jacob@jhalsey.com>"]
description = "Multipart form derive macro for Actix Web"
keywords = ["http", "web", "framework", "async", "futures"]
homepage = "https://actix.rs"
repository = "https://github.com/actix/actix-web.git"
license = "MIT OR Apache-2.0"
edition = "2018"
[lib]
proc-macro = true
[dependencies]
darling = "0.14"
parse-size = "1"
proc-macro2 = "1"
quote = "1"
syn = "1"
[dev-dependencies]
actix-multipart = "0.5"
actix-web = "4"
rustversion = "1"
trybuild = "1"

View file

@ -0,0 +1 @@
../LICENSE-APACHE

View file

@ -0,0 +1 @@
../LICENSE-MIT

View file

@ -0,0 +1,3 @@
# actix-multipart-derive
> The derive macro implementation for actix-multipart.

View file

@ -0,0 +1,315 @@
//! Multipart form derive macro for Actix Web.
//!
//! See [`macro@MultipartForm`] for usage examples.
#![deny(rust_2018_idioms, nonstandard_style)]
#![warn(future_incompatible)]
#![doc(html_logo_url = "https://actix.rs/img/logo.png")]
#![doc(html_favicon_url = "https://actix.rs/favicon.ico")]
#![cfg_attr(docsrs, feature(doc_cfg))]
use std::{collections::HashSet, convert::TryFrom as _};
use darling::{FromDeriveInput, FromField, FromMeta};
use parse_size::parse_size;
use proc_macro::TokenStream;
use proc_macro2::Ident;
use quote::quote;
use syn::{parse_macro_input, Type};
#[derive(FromMeta)]
enum DuplicateField {
Ignore,
Deny,
Replace,
}
impl Default for DuplicateField {
fn default() -> Self {
Self::Ignore
}
}
#[derive(FromDeriveInput, Default)]
#[darling(attributes(multipart), default)]
struct MultipartFormAttrs {
deny_unknown_fields: bool,
duplicate_field: DuplicateField,
}
#[derive(FromField, Default)]
#[darling(attributes(multipart), default)]
struct FieldAttrs {
rename: Option<String>,
limit: Option<String>,
}
struct ParsedField<'t> {
serialization_name: String,
rust_name: &'t Ident,
limit: Option<usize>,
ty: &'t Type,
}
/// Implements `MultipartCollect` for a struct so that it can be used with the `MultipartForm`
/// extractor.
///
/// # Basic Use
///
/// Each field type should implement the `FieldReader` trait:
///
/// ```
/// use actix_multipart::form::{tempfile::TempFile, text::Text, MultipartForm};
///
/// #[derive(MultipartForm)]
/// struct ImageUpload {
/// description: Text<String>,
/// timestamp: Text<i64>,
/// image: TempFile,
/// }
/// ```
///
/// # Optional and List Fields
///
/// You can also use `Vec<T>` and `Option<T>` provided that `T: FieldReader`.
///
/// A [`Vec`] field corresponds to an upload with multiple parts under the [same field
/// name](https://www.rfc-editor.org/rfc/rfc7578#section-4.3).
///
/// ```
/// use actix_multipart::form::{tempfile::TempFile, text::Text, MultipartForm};
///
/// #[derive(MultipartForm)]
/// struct Form {
/// category: Option<Text<String>>,
/// files: Vec<TempFile>,
/// }
/// ```
///
/// # Field Renaming
///
/// You can use the `#[multipart(rename = "foo")]` attribute to receive a field by a different name.
///
/// ```
/// use actix_multipart::form::{tempfile::TempFile, MultipartForm};
///
/// #[derive(MultipartForm)]
/// struct Form {
/// #[multipart(rename = "files[]")]
/// files: Vec<TempFile>,
/// }
/// ```
///
/// # Field Limits
///
/// You can use the `#[multipart(limit = "<size>")]` attribute to set field level limits. The limit
/// string is parsed using [parse_size].
///
/// Note: the form is also subject to the global limits configured using `MultipartFormConfig`.
///
/// ```
/// use actix_multipart::form::{tempfile::TempFile, text::Text, MultipartForm};
///
/// #[derive(MultipartForm)]
/// struct Form {
/// #[multipart(limit = "2 KiB")]
/// description: Text<String>,
///
/// #[multipart(limit = "512 MiB")]
/// files: Vec<TempFile>,
/// }
/// ```
///
/// # Unknown Fields
///
/// By default fields with an unknown name are ignored. They can be rejected using the
/// `#[multipart(deny_unknown_fields)]` attribute:
///
/// ```
/// # use actix_multipart::form::MultipartForm;
/// #[derive(MultipartForm)]
/// #[multipart(deny_unknown_fields)]
/// struct Form { }
/// ```
///
/// # Duplicate Fields
///
/// The behaviour for when multiple fields with the same name are received can be changed using the
/// `#[multipart(duplicate_field = "<behavior>")]` attribute:
///
/// - "ignore": (default) Extra fields are ignored. I.e., the first one is persisted.
/// - "deny": A `MultipartError::UnsupportedField` error response is returned.
/// - "replace": Each field is processed, but only the last one is persisted.
///
/// Note that `Vec` fields will ignore this option.
///
/// ```
/// # use actix_multipart::form::MultipartForm;
/// #[derive(MultipartForm)]
/// #[multipart(duplicate_field = "deny")]
/// struct Form { }
/// ```
///
/// [parse_size]: https://docs.rs/parse-size/1/parse_size
#[proc_macro_derive(MultipartForm, attributes(multipart))]
pub fn impl_multipart_form(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
let input: syn::DeriveInput = parse_macro_input!(input);
let name = &input.ident;
let data_struct = match &input.data {
syn::Data::Struct(data_struct) => data_struct,
_ => {
return compile_err(syn::Error::new(
input.ident.span(),
"`MultipartForm` can only be derived for structs",
))
}
};
let fields = match &data_struct.fields {
syn::Fields::Named(fields_named) => fields_named,
_ => {
return compile_err(syn::Error::new(
input.ident.span(),
"`MultipartForm` can only be derived for a struct with named fields",
))
}
};
let attrs = match MultipartFormAttrs::from_derive_input(&input) {
Ok(attrs) => attrs,
Err(err) => return err.write_errors().into(),
};
// Parse the field attributes
let parsed = match fields
.named
.iter()
.map(|field| {
let rust_name = field.ident.as_ref().unwrap();
let attrs = FieldAttrs::from_field(field).map_err(|err| err.write_errors())?;
let serialization_name = attrs.rename.unwrap_or_else(|| rust_name.to_string());
let limit = match attrs.limit.map(|limit| match parse_size(&limit) {
Ok(size) => Ok(usize::try_from(size).unwrap()),
Err(err) => Err(syn::Error::new(
field.ident.as_ref().unwrap().span(),
format!("Could not parse size limit `{}`: {}", limit, err),
)),
}) {
Some(Err(err)) => return Err(compile_err(err)),
limit => limit.map(Result::unwrap),
};
Ok(ParsedField {
serialization_name,
rust_name,
limit,
ty: &field.ty,
})
})
.collect::<Result<Vec<_>, TokenStream>>()
{
Ok(attrs) => attrs,
Err(err) => return err,
};
// Check that field names are unique
let mut set = HashSet::new();
for field in &parsed {
if !set.insert(field.serialization_name.clone()) {
return compile_err(syn::Error::new(
field.rust_name.span(),
format!("Multiple fields named: `{}`", field.serialization_name),
));
}
}
// Return value when a field name is not supported by the form
let unknown_field_result = if attrs.deny_unknown_fields {
quote!(::std::result::Result::Err(
::actix_multipart::MultipartError::UnsupportedField(field.name().to_string())
))
} else {
quote!(::std::result::Result::Ok(()))
};
// Value for duplicate action
let duplicate_field = match attrs.duplicate_field {
DuplicateField::Ignore => quote!(::actix_multipart::form::DuplicateField::Ignore),
DuplicateField::Deny => quote!(::actix_multipart::form::DuplicateField::Deny),
DuplicateField::Replace => quote!(::actix_multipart::form::DuplicateField::Replace),
};
// limit() implementation
let mut limit_impl = quote!();
for field in &parsed {
let name = &field.serialization_name;
if let Some(value) = field.limit {
limit_impl.extend(quote!(
#name => ::std::option::Option::Some(#value),
));
}
}
// handle_field() implementation
let mut handle_field_impl = quote!();
for field in &parsed {
let name = &field.serialization_name;
let ty = &field.ty;
handle_field_impl.extend(quote!(
#name => ::std::boxed::Box::pin(
<#ty as ::actix_multipart::form::FieldGroupReader>::handle_field(req, field, limits, state, #duplicate_field)
),
));
}
// from_state() implementation
let mut from_state_impl = quote!();
for field in &parsed {
let name = &field.serialization_name;
let rust_name = &field.rust_name;
let ty = &field.ty;
from_state_impl.extend(quote!(
#rust_name: <#ty as ::actix_multipart::form::FieldGroupReader>::from_state(#name, &mut state)?,
));
}
let gen = quote! {
impl ::actix_multipart::form::MultipartCollect for #name {
fn limit(field_name: &str) -> ::std::option::Option<usize> {
match field_name {
#limit_impl
_ => None,
}
}
fn handle_field<'t>(
req: &'t ::actix_web::HttpRequest,
field: ::actix_multipart::Field,
limits: &'t mut ::actix_multipart::form::Limits,
state: &'t mut ::actix_multipart::form::State,
) -> ::std::pin::Pin<::std::boxed::Box<dyn ::std::future::Future<Output = ::std::result::Result<(), ::actix_multipart::MultipartError>> + 't>> {
match field.name() {
#handle_field_impl
_ => return ::std::boxed::Box::pin(::std::future::ready(#unknown_field_result)),
}
}
fn from_state(mut state: ::actix_multipart::form::State) -> ::std::result::Result<Self, ::actix_multipart::MultipartError> {
Ok(Self {
#from_state_impl
})
}
}
};
gen.into()
}
/// Transform a syn error into a token stream for returning.
fn compile_err(err: syn::Error) -> TokenStream {
TokenStream::from(err.to_compile_error())
}

View file

@ -0,0 +1,16 @@
#[rustversion::stable(1.59)] // MSRV
#[test]
fn compile_macros() {
let t = trybuild::TestCases::new();
t.pass("tests/trybuild/all-required.rs");
t.pass("tests/trybuild/optional-and-list.rs");
t.pass("tests/trybuild/rename.rs");
t.pass("tests/trybuild/deny-unknown.rs");
t.pass("tests/trybuild/deny-duplicates.rs");
t.compile_fail("tests/trybuild/deny-parse-fail.rs");
t.pass("tests/trybuild/size-limits.rs");
t.compile_fail("tests/trybuild/size-limit-parse-fail.rs");
}

View file

@ -0,0 +1,19 @@
use actix_web::{web, App, Responder};
use actix_multipart::form::{tempfile::TempFile, text::Text, MultipartForm};
#[derive(Debug, MultipartForm)]
struct ImageUpload {
description: Text<String>,
timestamp: Text<i64>,
image: TempFile,
}
async fn handler(_form: MultipartForm<ImageUpload>) -> impl Responder {
"Hello World!"
}
#[actix_web::main]
async fn main() {
App::new().default_service(web::to(handler));
}

View file

@ -0,0 +1,16 @@
use actix_web::{web, App, Responder};
use actix_multipart::form::MultipartForm;
#[derive(MultipartForm)]
#[multipart(duplicate_field = "deny")]
struct Form {}
async fn handler(_form: MultipartForm<Form>) -> impl Responder {
"Hello World!"
}
#[actix_web::main]
async fn main() {
App::new().default_service(web::to(handler));
}

View file

@ -0,0 +1,7 @@
use actix_multipart::form::MultipartForm;
#[derive(MultipartForm)]
#[multipart(duplicate_field = "no")]
struct Form {}
fn main() {}

View file

@ -0,0 +1,5 @@
error: Unknown literal value `no`
--> tests/trybuild/deny-parse-fail.rs:4:31
|
4 | #[multipart(duplicate_field = "no")]
| ^^^^

View file

@ -0,0 +1,16 @@
use actix_web::{web, App, Responder};
use actix_multipart::form::MultipartForm;
#[derive(MultipartForm)]
#[multipart(deny_unknown_fields)]
struct Form {}
async fn handler(_form: MultipartForm<Form>) -> impl Responder {
"Hello World!"
}
#[actix_web::main]
async fn main() {
App::new().default_service(web::to(handler));
}

View file

@ -0,0 +1,18 @@
use actix_web::{web, App, Responder};
use actix_multipart::form::{tempfile::TempFile, text::Text, MultipartForm};
#[derive(MultipartForm)]
struct Form {
category: Option<Text<String>>,
files: Vec<TempFile>,
}
async fn handler(_form: MultipartForm<Form>) -> impl Responder {
"Hello World!"
}
#[actix_web::main]
async fn main() {
App::new().default_service(web::to(handler));
}

View file

@ -0,0 +1,18 @@
use actix_web::{web, App, Responder};
use actix_multipart::form::{tempfile::TempFile, MultipartForm};
#[derive(MultipartForm)]
struct Form {
#[multipart(rename = "files[]")]
files: Vec<TempFile>,
}
async fn handler(_form: MultipartForm<Form>) -> impl Responder {
"Hello World!"
}
#[actix_web::main]
async fn main() {
App::new().default_service(web::to(handler));
}

View file

@ -0,0 +1,21 @@
use actix_multipart::form::{text::Text, MultipartForm};
#[derive(MultipartForm)]
struct Form {
#[multipart(limit = "2 bytes")]
description: Text<String>,
}
#[derive(MultipartForm)]
struct Form2 {
#[multipart(limit = "2 megabytes")]
description: Text<String>,
}
#[derive(MultipartForm)]
struct Form3 {
#[multipart(limit = "four meters")]
description: Text<String>,
}
fn main() {}

View file

@ -0,0 +1,17 @@
error: Could not parse size limit `2 bytes`: invalid digit found in string
--> tests/trybuild/size-limit-parse-fail.rs:6:5
|
6 | description: Text<String>,
| ^^^^^^^^^^^
error: Could not parse size limit `2 megabytes`: invalid digit found in string
--> tests/trybuild/size-limit-parse-fail.rs:12:5
|
12 | description: Text<String>,
| ^^^^^^^^^^^
error: Could not parse size limit `four meters`: invalid digit found in string
--> tests/trybuild/size-limit-parse-fail.rs:18:5
|
18 | description: Text<String>,
| ^^^^^^^^^^^

View file

@ -0,0 +1,21 @@
use actix_web::{web, App, Responder};
use actix_multipart::form::{tempfile::TempFile, text::Text, MultipartForm};
#[derive(MultipartForm)]
struct Form {
#[multipart(limit = "2 KiB")]
description: Text<String>,
#[multipart(limit = "512 MiB")]
files: Vec<TempFile>,
}
async fn handler(_form: MultipartForm<Form>) -> impl Responder {
"Hello World!"
}
#[actix_web::main]
async fn main() {
App::new().default_service(web::to(handler));
}

View file

@ -1,11 +1,14 @@
# Changes
## Unreleased - 2022-xx-xx
- Added `MultipartForm` typed data extractor. [#2883]
[#2883]: https://github.com/actix/actix-web/pull/2883
## 0.5.0 - 2023-01-21
- `Field::content_type()` now returns `Option<&mime::Mime>`. [#2885]
- Minimum supported Rust version (MSRV) is now 1.59 due to transitive `time` dependency.
- `Field::content_type()` now returns `Option<&mime::Mime>` [#2885]
[#2885]: https://github.com/actix/actix-web/pull/2885

View file

@ -1,7 +1,10 @@
[package]
name = "actix-multipart"
version = "0.5.0"
authors = ["Nikolay Kim <fafhrd91@gmail.com>"]
authors = [
"Nikolay Kim <fafhrd91@gmail.com>",
"Jacob Halsey <jacob@jhalsey.com>",
]
description = "Multipart form support for Actix Web"
keywords = ["http", "web", "framework", "async", "futures"]
homepage = "https://actix.rs"
@ -9,26 +12,46 @@ repository = "https://github.com/actix/actix-web.git"
license = "MIT OR Apache-2.0"
edition = "2018"
[package.metadata.docs.rs]
rustdoc-args = ["--cfg", "docsrs"]
all-features = true
[features]
default = ["tempfile", "derive"]
derive = ["actix-multipart-derive"]
tempfile = ["tempfile-dep", "tokio/fs"]
[lib]
name = "actix_multipart"
path = "src/lib.rs"
[dependencies]
actix-multipart-derive = { version = "=0.5.0", optional = true }
actix-utils = "3"
actix-web = { version = "4", default-features = false }
bytes = "1"
derive_more = "0.99.5"
futures-core = { version = "0.3.17", default-features = false, features = ["alloc"] }
futures-util = { version = "0.3.17", default-features = false, features = ["alloc"] }
httparse = "1.3"
local-waker = "0.1"
log = "0.4"
mime = "0.3"
memchr = "2.5"
mime = "0.3"
serde = "1"
serde_json = "1"
serde_plain = "1"
# TODO(MSRV 1.60): replace with dep: prefix
tempfile-dep = { package = "tempfile", version = "3.4", optional = true }
tokio = { version = "1.18.5", features = ["sync"] }
[dev-dependencies]
actix-rt = "2.2"
actix-http = "3"
actix-multipart-rfc7578 = "0.10"
actix-rt = "2.2"
actix-test = "0.1"
awc = "3"
futures-util = { version = "0.3.17", default-features = false, features = ["alloc"] }
tokio = { version = "1.18.5", features = ["sync"] }
tokio-stream = "0.1"

View file

@ -1,12 +1,15 @@
//! Error and Result module
use actix_web::error::{ParseError, PayloadError};
use actix_web::http::StatusCode;
use actix_web::ResponseError;
use actix_web::{
error::{ParseError, PayloadError},
http::StatusCode,
ResponseError,
};
use derive_more::{Display, Error, From};
/// A set of errors that can occur during parsing multipart streams
#[non_exhaustive]
/// A set of errors that can occur during parsing multipart streams.
#[derive(Debug, Display, From, Error)]
#[non_exhaustive]
pub enum MultipartError {
/// Content-Disposition header is not found or is not equal to "form-data".
///
@ -46,12 +49,41 @@ pub enum MultipartError {
/// Not consumed
#[display(fmt = "Multipart stream is not consumed")]
NotConsumed,
/// An error from a field handler in a form
#[display(
fmt = "An error occurred processing field `{}`: {}",
field_name,
source
)]
Field {
field_name: String,
source: actix_web::Error,
},
/// Duplicate field
#[display(fmt = "Duplicate field found for: `{}`", _0)]
#[from(ignore)]
DuplicateField(#[error(not(source))] String),
/// Missing field
#[display(fmt = "Field with name `{}` is required", _0)]
#[from(ignore)]
MissingField(#[error(not(source))] String),
/// Unknown field
#[display(fmt = "Unsupported field `{}`", _0)]
#[from(ignore)]
UnsupportedField(#[error(not(source))] String),
}
/// Return `BadRequest` for `MultipartError`
impl ResponseError for MultipartError {
fn status_code(&self) -> StatusCode {
StatusCode::BAD_REQUEST
match &self {
MultipartError::Field { source, .. } => source.as_response_error().status_code(),
_ => StatusCode::BAD_REQUEST,
}
}
}

View file

@ -9,8 +9,7 @@ use crate::server::Multipart;
///
/// Content-type: multipart/form-data;
///
/// ## Server example
///
/// # Examples
/// ```
/// use actix_web::{web, HttpResponse, Error};
/// use actix_multipart::Multipart;

View file

@ -0,0 +1,53 @@
//! Reads a field into memory.
use actix_web::HttpRequest;
use bytes::BytesMut;
use futures_core::future::LocalBoxFuture;
use futures_util::TryStreamExt as _;
use mime::Mime;
use crate::{
form::{FieldReader, Limits},
Field, MultipartError,
};
/// Read the field into memory.
#[derive(Debug)]
pub struct Bytes {
/// The data.
pub data: bytes::Bytes,
/// The value of the `Content-Type` header.
pub content_type: Option<Mime>,
/// The `filename` value in the `Content-Disposition` header.
pub file_name: Option<String>,
}
impl<'t> FieldReader<'t> for Bytes {
type Future = LocalBoxFuture<'t, Result<Self, MultipartError>>;
fn read_field(
_: &'t HttpRequest,
mut field: Field,
limits: &'t mut Limits,
) -> Self::Future {
Box::pin(async move {
let mut buf = BytesMut::with_capacity(131_072);
while let Some(chunk) = field.try_next().await? {
limits.try_consume_limits(chunk.len(), true)?;
buf.extend(chunk);
}
Ok(Bytes {
data: buf.freeze(),
content_type: field.content_type().map(ToOwned::to_owned),
file_name: field
.content_disposition()
.get_filename()
.map(str::to_owned),
})
})
}
}

View file

@ -0,0 +1,195 @@
//! Deserializes a field as JSON.
use std::sync::Arc;
use actix_web::{http::StatusCode, web, Error, HttpRequest, ResponseError};
use derive_more::{Deref, DerefMut, Display, Error};
use futures_core::future::LocalBoxFuture;
use serde::de::DeserializeOwned;
use crate::{
form::{bytes::Bytes, FieldReader, Limits},
Field, MultipartError,
};
use super::FieldErrorHandler;
/// Deserialize from JSON.
#[derive(Debug, Deref, DerefMut)]
pub struct Json<T: DeserializeOwned>(pub T);
impl<T: DeserializeOwned> Json<T> {
pub fn into_inner(self) -> T {
self.0
}
}
impl<'t, T> FieldReader<'t> for Json<T>
where
T: DeserializeOwned + 'static,
{
type Future = LocalBoxFuture<'t, Result<Self, MultipartError>>;
fn read_field(req: &'t HttpRequest, field: Field, limits: &'t mut Limits) -> Self::Future {
Box::pin(async move {
let config = JsonConfig::from_req(req);
let field_name = field.name().to_owned();
if config.validate_content_type {
let valid = if let Some(mime) = field.content_type() {
mime.subtype() == mime::JSON || mime.suffix() == Some(mime::JSON)
} else {
false
};
if !valid {
return Err(MultipartError::Field {
field_name,
source: config.map_error(req, JsonFieldError::ContentType),
});
}
}
let bytes = Bytes::read_field(req, field, limits).await?;
Ok(Json(serde_json::from_slice(bytes.data.as_ref()).map_err(
|err| MultipartError::Field {
field_name,
source: config.map_error(req, JsonFieldError::Deserialize(err)),
},
)?))
})
}
}
#[derive(Debug, Display, Error)]
#[non_exhaustive]
pub enum JsonFieldError {
/// Deserialize error.
#[display(fmt = "Json deserialize error: {}", _0)]
Deserialize(serde_json::Error),
/// Content type error.
#[display(fmt = "Content type error")]
ContentType,
}
impl ResponseError for JsonFieldError {
fn status_code(&self) -> StatusCode {
StatusCode::BAD_REQUEST
}
}
/// Configuration for the [`Json`] field reader.
#[derive(Clone)]
pub struct JsonConfig {
err_handler: FieldErrorHandler<JsonFieldError>,
validate_content_type: bool,
}
const DEFAULT_CONFIG: JsonConfig = JsonConfig {
err_handler: None,
validate_content_type: true,
};
impl JsonConfig {
pub fn error_handler<F>(mut self, f: F) -> Self
where
F: Fn(JsonFieldError, &HttpRequest) -> Error + Send + Sync + 'static,
{
self.err_handler = Some(Arc::new(f));
self
}
/// Extract payload config from app data. Check both `T` and `Data<T>`, in that order, and fall
/// back to the default payload config.
fn from_req(req: &HttpRequest) -> &Self {
req.app_data::<Self>()
.or_else(|| req.app_data::<web::Data<Self>>().map(|d| d.as_ref()))
.unwrap_or(&DEFAULT_CONFIG)
}
fn map_error(&self, req: &HttpRequest, err: JsonFieldError) -> Error {
if let Some(err_handler) = self.err_handler.as_ref() {
(*err_handler)(err, req)
} else {
err.into()
}
}
/// Sets whether or not the field must have a valid `Content-Type` header to be parsed.
pub fn validate_content_type(mut self, validate_content_type: bool) -> Self {
self.validate_content_type = validate_content_type;
self
}
}
impl Default for JsonConfig {
fn default() -> Self {
DEFAULT_CONFIG
}
}
#[cfg(test)]
mod tests {
use std::{collections::HashMap, io::Cursor};
use actix_multipart_rfc7578::client::multipart;
use actix_web::{http::StatusCode, web, App, HttpResponse, Responder};
use crate::form::{
json::{Json, JsonConfig},
tests::send_form,
MultipartForm,
};
#[derive(MultipartForm)]
struct JsonForm {
json: Json<HashMap<String, String>>,
}
async fn test_json_route(form: MultipartForm<JsonForm>) -> impl Responder {
let mut expected = HashMap::new();
expected.insert("key1".to_owned(), "value1".to_owned());
expected.insert("key2".to_owned(), "value2".to_owned());
assert_eq!(&*form.json, &expected);
HttpResponse::Ok().finish()
}
#[actix_rt::test]
async fn test_json_without_content_type() {
let srv = actix_test::start(|| {
App::new()
.route("/", web::post().to(test_json_route))
.app_data(JsonConfig::default().validate_content_type(false))
});
let mut form = multipart::Form::default();
form.add_text("json", "{\"key1\": \"value1\", \"key2\": \"value2\"}");
let response = send_form(&srv, form, "/").await;
assert_eq!(response.status(), StatusCode::OK);
}
#[actix_rt::test]
async fn test_content_type_validation() {
let srv = actix_test::start(|| {
App::new()
.route("/", web::post().to(test_json_route))
.app_data(JsonConfig::default().validate_content_type(true))
});
// Deny because wrong content type
let bytes = Cursor::new("{\"key1\": \"value1\", \"key2\": \"value2\"}");
let mut form = multipart::Form::default();
form.add_reader_file_with_mime("json", bytes, "", mime::APPLICATION_OCTET_STREAM);
let response = send_form(&srv, form, "/").await;
assert_eq!(response.status(), StatusCode::BAD_REQUEST);
// Allow because correct content type
let bytes = Cursor::new("{\"key1\": \"value1\", \"key2\": \"value2\"}");
let mut form = multipart::Form::default();
form.add_reader_file_with_mime("json", bytes, "", mime::APPLICATION_JSON);
let response = send_form(&srv, form, "/").await;
assert_eq!(response.status(), StatusCode::OK);
}
}

View file

@ -0,0 +1,744 @@
//! Process and extract typed data from a multipart stream.
use std::{
any::Any,
collections::HashMap,
future::{ready, Future},
sync::Arc,
};
use actix_web::{dev, error::PayloadError, web, Error, FromRequest, HttpRequest};
use derive_more::{Deref, DerefMut};
use futures_core::future::LocalBoxFuture;
use futures_util::{TryFutureExt as _, TryStreamExt as _};
use crate::{Field, Multipart, MultipartError};
pub mod bytes;
pub mod json;
#[cfg_attr(docsrs, doc(cfg(feature = "tempfile")))]
#[cfg(feature = "tempfile")]
pub mod tempfile;
pub mod text;
#[cfg_attr(docsrs, doc(cfg(feature = "derive")))]
#[cfg(feature = "derive")]
pub use actix_multipart_derive::MultipartForm;
type FieldErrorHandler<T> = Option<Arc<dyn Fn(T, &HttpRequest) -> Error + Send + Sync>>;
/// Trait that data types to be used in a multipart form struct should implement.
///
/// It represents an asynchronous handler that processes a multipart field to produce `Self`.
pub trait FieldReader<'t>: Sized + Any {
/// Future that resolves to a `Self`.
type Future: Future<Output = Result<Self, MultipartError>>;
/// The form will call this function to handle the field.
fn read_field(req: &'t HttpRequest, field: Field, limits: &'t mut Limits) -> Self::Future;
}
/// Used to accumulate the state of the loaded fields.
#[doc(hidden)]
#[derive(Default, Deref, DerefMut)]
pub struct State(pub HashMap<String, Box<dyn Any>>);
/// Trait that the field collection types implement, i.e. `Vec<T>`, `Option<T>`, or `T` itself.
#[doc(hidden)]
pub trait FieldGroupReader<'t>: Sized + Any {
type Future: Future<Output = Result<(), MultipartError>>;
/// The form will call this function for each matching field.
fn handle_field(
req: &'t HttpRequest,
field: Field,
limits: &'t mut Limits,
state: &'t mut State,
duplicate_field: DuplicateField,
) -> Self::Future;
/// Construct `Self` from the group of processed fields.
fn from_state(name: &str, state: &'t mut State) -> Result<Self, MultipartError>;
}
impl<'t, T> FieldGroupReader<'t> for Option<T>
where
T: FieldReader<'t>,
{
type Future = LocalBoxFuture<'t, Result<(), MultipartError>>;
fn handle_field(
req: &'t HttpRequest,
field: Field,
limits: &'t mut Limits,
state: &'t mut State,
duplicate_field: DuplicateField,
) -> Self::Future {
if state.contains_key(field.name()) {
match duplicate_field {
DuplicateField::Ignore => return Box::pin(ready(Ok(()))),
DuplicateField::Deny => {
return Box::pin(ready(Err(MultipartError::DuplicateField(
field.name().to_owned(),
))))
}
DuplicateField::Replace => {}
}
}
Box::pin(async move {
let field_name = field.name().to_owned();
let t = T::read_field(req, field, limits).await?;
state.insert(field_name, Box::new(t));
Ok(())
})
}
fn from_state(name: &str, state: &'t mut State) -> Result<Self, MultipartError> {
Ok(state.remove(name).map(|m| *m.downcast::<T>().unwrap()))
}
}
impl<'t, T> FieldGroupReader<'t> for Vec<T>
where
T: FieldReader<'t>,
{
type Future = LocalBoxFuture<'t, Result<(), MultipartError>>;
fn handle_field(
req: &'t HttpRequest,
field: Field,
limits: &'t mut Limits,
state: &'t mut State,
_duplicate_field: DuplicateField,
) -> Self::Future {
Box::pin(async move {
// Note: Vec GroupReader always allows duplicates
let field_name = field.name().to_owned();
let vec = state
.entry(field_name)
.or_insert_with(|| Box::<Vec<T>>::default())
.downcast_mut::<Vec<T>>()
.unwrap();
let item = T::read_field(req, field, limits).await?;
vec.push(item);
Ok(())
})
}
fn from_state(name: &str, state: &'t mut State) -> Result<Self, MultipartError> {
Ok(state
.remove(name)
.map(|m| *m.downcast::<Vec<T>>().unwrap())
.unwrap_or_default())
}
}
impl<'t, T> FieldGroupReader<'t> for T
where
T: FieldReader<'t>,
{
type Future = LocalBoxFuture<'t, Result<(), MultipartError>>;
fn handle_field(
req: &'t HttpRequest,
field: Field,
limits: &'t mut Limits,
state: &'t mut State,
duplicate_field: DuplicateField,
) -> Self::Future {
if state.contains_key(field.name()) {
match duplicate_field {
DuplicateField::Ignore => return Box::pin(ready(Ok(()))),
DuplicateField::Deny => {
return Box::pin(ready(Err(MultipartError::DuplicateField(
field.name().to_owned(),
))))
}
DuplicateField::Replace => {}
}
}
Box::pin(async move {
let field_name = field.name().to_owned();
let t = T::read_field(req, field, limits).await?;
state.insert(field_name, Box::new(t));
Ok(())
})
}
fn from_state(name: &str, state: &'t mut State) -> Result<Self, MultipartError> {
state
.remove(name)
.map(|m| *m.downcast::<T>().unwrap())
.ok_or_else(|| MultipartError::MissingField(name.to_owned()))
}
}
/// Trait that allows a type to be used in the [`struct@MultipartForm`] extractor.
///
/// You should use the [`macro@MultipartForm`] macro to derive this for your struct.
pub trait MultipartCollect: Sized {
/// An optional limit in bytes to be applied a given field name. Note this limit will be shared
/// across all fields sharing the same name.
fn limit(field_name: &str) -> Option<usize>;
/// The extractor will call this function for each incoming field, the state can be updated
/// with the processed field data.
fn handle_field<'t>(
req: &'t HttpRequest,
field: Field,
limits: &'t mut Limits,
state: &'t mut State,
) -> LocalBoxFuture<'t, Result<(), MultipartError>>;
/// Once all the fields have been processed and stored in the state, this is called
/// to convert into the struct representation.
fn from_state(state: State) -> Result<Self, MultipartError>;
}
#[doc(hidden)]
pub enum DuplicateField {
/// Additional fields are not processed.
Ignore,
/// An error will be raised.
Deny,
/// All fields will be processed, the last one will replace all previous.
Replace,
}
/// Used to keep track of the remaining limits for the form and current field.
pub struct Limits {
pub total_limit_remaining: usize,
pub memory_limit_remaining: usize,
pub field_limit_remaining: Option<usize>,
}
impl Limits {
pub fn new(total_limit: usize, memory_limit: usize) -> Self {
Self {
total_limit_remaining: total_limit,
memory_limit_remaining: memory_limit,
field_limit_remaining: None,
}
}
/// This function should be called within a [`FieldReader`] when reading each chunk of a field
/// to ensure that the form limits are not exceeded.
///
/// # Arguments
///
/// * `bytes` - The number of bytes being read from this chunk
/// * `in_memory` - Whether to consume from the memory limits
pub fn try_consume_limits(
&mut self,
bytes: usize,
in_memory: bool,
) -> Result<(), MultipartError> {
self.total_limit_remaining = self
.total_limit_remaining
.checked_sub(bytes)
.ok_or(MultipartError::Payload(PayloadError::Overflow))?;
if in_memory {
self.memory_limit_remaining = self
.memory_limit_remaining
.checked_sub(bytes)
.ok_or(MultipartError::Payload(PayloadError::Overflow))?;
}
if let Some(field_limit) = self.field_limit_remaining {
self.field_limit_remaining = Some(
field_limit
.checked_sub(bytes)
.ok_or(MultipartError::Payload(PayloadError::Overflow))?,
);
}
Ok(())
}
}
/// Typed `multipart/form-data` extractor.
///
/// To extract typed data from a multipart stream, the inner type `T` must implement the
/// [`MultipartCollect`] trait. You should use the [`macro@MultipartForm`] macro to derive this
/// for your struct.
///
/// Add a [`MultipartFormConfig`] to your app data to configure extraction.
#[derive(Deref, DerefMut)]
pub struct MultipartForm<T: MultipartCollect>(pub T);
impl<T: MultipartCollect> MultipartForm<T> {
/// Unwrap into inner `T` value.
pub fn into_inner(self) -> T {
self.0
}
}
impl<T> FromRequest for MultipartForm<T>
where
T: MultipartCollect,
{
type Error = Error;
type Future = LocalBoxFuture<'static, Result<Self, Self::Error>>;
#[inline]
fn from_request(req: &HttpRequest, payload: &mut dev::Payload) -> Self::Future {
let mut payload = Multipart::new(req.headers(), payload.take());
let config = MultipartFormConfig::from_req(req);
let mut limits = Limits::new(config.total_limit, config.memory_limit);
let req = req.clone();
let req2 = req.clone();
let err_handler = config.err_handler.clone();
Box::pin(
async move {
let mut state = State::default();
// We need to ensure field limits are shared for all instances of this field name
let mut field_limits = HashMap::<String, Option<usize>>::new();
while let Some(field) = payload.try_next().await? {
// Retrieve the limit for this field
let entry = field_limits
.entry(field.name().to_owned())
.or_insert_with(|| T::limit(field.name()));
limits.field_limit_remaining = entry.to_owned();
T::handle_field(&req, field, &mut limits, &mut state).await?;
// Update the stored limit
*entry = limits.field_limit_remaining;
}
let inner = T::from_state(state)?;
Ok(MultipartForm(inner))
}
.map_err(move |err| {
if let Some(handler) = err_handler {
(*handler)(err, &req2)
} else {
err.into()
}
}),
)
}
}
type MultipartFormErrorHandler =
Option<Arc<dyn Fn(MultipartError, &HttpRequest) -> Error + Send + Sync>>;
/// [`struct@MultipartForm`] extractor configuration.
///
/// Add to your app data to have it picked up by [`struct@MultipartForm`] extractors.
#[derive(Clone)]
pub struct MultipartFormConfig {
total_limit: usize,
memory_limit: usize,
err_handler: MultipartFormErrorHandler,
}
impl MultipartFormConfig {
/// Sets maximum accepted payload size for the entire form. By default this limit is 50MiB.
pub fn total_limit(mut self, total_limit: usize) -> Self {
self.total_limit = total_limit;
self
}
/// Sets maximum accepted data that will be read into memory. By default this limit is 2MiB.
pub fn memory_limit(mut self, memory_limit: usize) -> Self {
self.memory_limit = memory_limit;
self
}
/// Sets custom error handler.
pub fn error_handler<F>(mut self, f: F) -> Self
where
F: Fn(MultipartError, &HttpRequest) -> Error + Send + Sync + 'static,
{
self.err_handler = Some(Arc::new(f));
self
}
/// Extracts payload config from app data. Check both `T` and `Data<T>`, in that order, and fall
/// back to the default payload config.
fn from_req(req: &HttpRequest) -> &Self {
req.app_data::<Self>()
.or_else(|| req.app_data::<web::Data<Self>>().map(|d| d.as_ref()))
.unwrap_or(&DEFAULT_CONFIG)
}
}
const DEFAULT_CONFIG: MultipartFormConfig = MultipartFormConfig {
total_limit: 52_428_800, // 50 MiB
memory_limit: 2_097_152, // 2 MiB
err_handler: None,
};
impl Default for MultipartFormConfig {
fn default() -> Self {
DEFAULT_CONFIG
}
}
#[cfg(test)]
mod tests {
use actix_http::encoding::Decoder;
use actix_multipart_rfc7578::client::multipart;
use actix_test::TestServer;
use actix_web::{dev::Payload, http::StatusCode, web, App, HttpResponse, Responder};
use awc::{Client, ClientResponse};
use super::MultipartForm;
use crate::form::{bytes::Bytes, tempfile::TempFile, text::Text, MultipartFormConfig};
pub async fn send_form(
srv: &TestServer,
form: multipart::Form<'static>,
uri: &'static str,
) -> ClientResponse<Decoder<Payload>> {
Client::default()
.post(srv.url(uri))
.content_type(form.content_type())
.send_body(multipart::Body::from(form))
.await
.unwrap()
}
/// Test `Option` fields.
#[derive(MultipartForm)]
struct TestOptions {
field1: Option<Text<String>>,
field2: Option<Text<String>>,
}
async fn test_options_route(form: MultipartForm<TestOptions>) -> impl Responder {
assert!(form.field1.is_some());
assert!(form.field2.is_none());
HttpResponse::Ok().finish()
}
#[actix_rt::test]
async fn test_options() {
let srv =
actix_test::start(|| App::new().route("/", web::post().to(test_options_route)));
let mut form = multipart::Form::default();
form.add_text("field1", "value");
let response = send_form(&srv, form, "/").await;
assert_eq!(response.status(), StatusCode::OK);
}
/// Test `Vec` fields.
#[derive(MultipartForm)]
struct TestVec {
list1: Vec<Text<String>>,
list2: Vec<Text<String>>,
}
async fn test_vec_route(form: MultipartForm<TestVec>) -> impl Responder {
let form = form.into_inner();
let strings = form
.list1
.into_iter()
.map(|s| s.into_inner())
.collect::<Vec<_>>();
assert_eq!(strings, vec!["value1", "value2", "value3"]);
assert_eq!(form.list2.len(), 0);
HttpResponse::Ok().finish()
}
#[actix_rt::test]
async fn test_vec() {
let srv = actix_test::start(|| App::new().route("/", web::post().to(test_vec_route)));
let mut form = multipart::Form::default();
form.add_text("list1", "value1");
form.add_text("list1", "value2");
form.add_text("list1", "value3");
let response = send_form(&srv, form, "/").await;
assert_eq!(response.status(), StatusCode::OK);
}
/// Test the `rename` field attribute.
#[derive(MultipartForm)]
struct TestFieldRenaming {
#[multipart(rename = "renamed")]
field1: Text<String>,
#[multipart(rename = "field1")]
field2: Text<String>,
field3: Text<String>,
}
async fn test_field_renaming_route(
form: MultipartForm<TestFieldRenaming>,
) -> impl Responder {
assert_eq!(&*form.field1, "renamed");
assert_eq!(&*form.field2, "field1");
assert_eq!(&*form.field3, "field3");
HttpResponse::Ok().finish()
}
#[actix_rt::test]
async fn test_field_renaming() {
let srv = actix_test::start(|| {
App::new().route("/", web::post().to(test_field_renaming_route))
});
let mut form = multipart::Form::default();
form.add_text("renamed", "renamed");
form.add_text("field1", "field1");
form.add_text("field3", "field3");
let response = send_form(&srv, form, "/").await;
assert_eq!(response.status(), StatusCode::OK);
}
/// Test the `deny_unknown_fields` struct attribute.
#[derive(MultipartForm)]
#[multipart(deny_unknown_fields)]
struct TestDenyUnknown {}
#[derive(MultipartForm)]
struct TestAllowUnknown {}
async fn test_deny_unknown_route(_: MultipartForm<TestDenyUnknown>) -> impl Responder {
HttpResponse::Ok().finish()
}
async fn test_allow_unknown_route(_: MultipartForm<TestAllowUnknown>) -> impl Responder {
HttpResponse::Ok().finish()
}
#[actix_rt::test]
async fn test_deny_unknown() {
let srv = actix_test::start(|| {
App::new()
.route("/deny", web::post().to(test_deny_unknown_route))
.route("/allow", web::post().to(test_allow_unknown_route))
});
let mut form = multipart::Form::default();
form.add_text("unknown", "value");
let response = send_form(&srv, form, "/deny").await;
assert_eq!(response.status(), StatusCode::BAD_REQUEST);
let mut form = multipart::Form::default();
form.add_text("unknown", "value");
let response = send_form(&srv, form, "/allow").await;
assert_eq!(response.status(), StatusCode::OK);
}
/// Test the `duplicate_field` struct attribute.
#[derive(MultipartForm)]
#[multipart(duplicate_field = "deny")]
struct TestDuplicateDeny {
_field: Text<String>,
}
#[derive(MultipartForm)]
#[multipart(duplicate_field = "replace")]
struct TestDuplicateReplace {
field: Text<String>,
}
#[derive(MultipartForm)]
#[multipart(duplicate_field = "ignore")]
struct TestDuplicateIgnore {
field: Text<String>,
}
async fn test_duplicate_deny_route(_: MultipartForm<TestDuplicateDeny>) -> impl Responder {
HttpResponse::Ok().finish()
}
async fn test_duplicate_replace_route(
form: MultipartForm<TestDuplicateReplace>,
) -> impl Responder {
assert_eq!(&*form.field, "second_value");
HttpResponse::Ok().finish()
}
async fn test_duplicate_ignore_route(
form: MultipartForm<TestDuplicateIgnore>,
) -> impl Responder {
assert_eq!(&*form.field, "first_value");
HttpResponse::Ok().finish()
}
#[actix_rt::test]
async fn test_duplicate_field() {
let srv = actix_test::start(|| {
App::new()
.route("/deny", web::post().to(test_duplicate_deny_route))
.route("/replace", web::post().to(test_duplicate_replace_route))
.route("/ignore", web::post().to(test_duplicate_ignore_route))
});
let mut form = multipart::Form::default();
form.add_text("_field", "first_value");
form.add_text("_field", "second_value");
let response = send_form(&srv, form, "/deny").await;
assert_eq!(response.status(), StatusCode::BAD_REQUEST);
let mut form = multipart::Form::default();
form.add_text("field", "first_value");
form.add_text("field", "second_value");
let response = send_form(&srv, form, "/replace").await;
assert_eq!(response.status(), StatusCode::OK);
let mut form = multipart::Form::default();
form.add_text("field", "first_value");
form.add_text("field", "second_value");
let response = send_form(&srv, form, "/ignore").await;
assert_eq!(response.status(), StatusCode::OK);
}
/// Test the Limits.
#[derive(MultipartForm)]
struct TestMemoryUploadLimits {
field: Bytes,
}
#[derive(MultipartForm)]
struct TestFileUploadLimits {
field: TempFile,
}
async fn test_upload_limits_memory(
form: MultipartForm<TestMemoryUploadLimits>,
) -> impl Responder {
assert!(!form.field.data.is_empty());
HttpResponse::Ok().finish()
}
async fn test_upload_limits_file(
form: MultipartForm<TestFileUploadLimits>,
) -> impl Responder {
assert!(form.field.size > 0);
HttpResponse::Ok().finish()
}
#[actix_rt::test]
async fn test_memory_limits() {
let srv = actix_test::start(|| {
App::new()
.route("/text", web::post().to(test_upload_limits_memory))
.route("/file", web::post().to(test_upload_limits_file))
.app_data(
MultipartFormConfig::default()
.memory_limit(20)
.total_limit(usize::MAX),
)
});
// Exceeds the 20 byte memory limit
let mut form = multipart::Form::default();
form.add_text("field", "this string is 28 bytes long");
let response = send_form(&srv, form, "/text").await;
assert_eq!(response.status(), StatusCode::BAD_REQUEST);
// Memory limit should not apply when the data is being streamed to disk
let mut form = multipart::Form::default();
form.add_text("field", "this string is 28 bytes long");
let response = send_form(&srv, form, "/file").await;
assert_eq!(response.status(), StatusCode::OK);
}
#[actix_rt::test]
async fn test_total_limit() {
let srv = actix_test::start(|| {
App::new()
.route("/text", web::post().to(test_upload_limits_memory))
.route("/file", web::post().to(test_upload_limits_file))
.app_data(
MultipartFormConfig::default()
.memory_limit(usize::MAX)
.total_limit(20),
)
});
// Within the 20 byte limit
let mut form = multipart::Form::default();
form.add_text("field", "7 bytes");
let response = send_form(&srv, form, "/text").await;
assert_eq!(response.status(), StatusCode::OK);
// Exceeds the 20 byte overall limit
let mut form = multipart::Form::default();
form.add_text("field", "this string is 28 bytes long");
let response = send_form(&srv, form, "/text").await;
assert_eq!(response.status(), StatusCode::BAD_REQUEST);
// Exceeds the 20 byte overall limit
let mut form = multipart::Form::default();
form.add_text("field", "this string is 28 bytes long");
let response = send_form(&srv, form, "/file").await;
assert_eq!(response.status(), StatusCode::BAD_REQUEST);
}
#[derive(MultipartForm)]
struct TestFieldLevelLimits {
#[multipart(limit = "30B")]
field: Vec<Bytes>,
}
async fn test_field_level_limits_route(
form: MultipartForm<TestFieldLevelLimits>,
) -> impl Responder {
assert!(!form.field.is_empty());
HttpResponse::Ok().finish()
}
#[actix_rt::test]
async fn test_field_level_limits() {
let srv = actix_test::start(|| {
App::new()
.route("/", web::post().to(test_field_level_limits_route))
.app_data(
MultipartFormConfig::default()
.memory_limit(usize::MAX)
.total_limit(usize::MAX),
)
});
// Within the 30 byte limit
let mut form = multipart::Form::default();
form.add_text("field", "this string is 28 bytes long");
let response = send_form(&srv, form, "/").await;
assert_eq!(response.status(), StatusCode::OK);
// Exceeds the the 30 byte limit
let mut form = multipart::Form::default();
form.add_text("field", "this string is more than 30 bytes long");
let response = send_form(&srv, form, "/").await;
assert_eq!(response.status(), StatusCode::BAD_REQUEST);
// Total of values (14 bytes) is within 30 byte limit for "field"
let mut form = multipart::Form::default();
form.add_text("field", "7 bytes");
form.add_text("field", "7 bytes");
let response = send_form(&srv, form, "/").await;
assert_eq!(response.status(), StatusCode::OK);
// Total of values exceeds 30 byte limit for "field"
let mut form = multipart::Form::default();
form.add_text("field", "this string is 28 bytes long");
form.add_text("field", "this string is 28 bytes long");
let response = send_form(&srv, form, "/").await;
assert_eq!(response.status(), StatusCode::BAD_REQUEST);
}
}

View file

@ -0,0 +1,206 @@
//! Writes a field to a temporary file on disk.
use std::{
io,
path::{Path, PathBuf},
sync::Arc,
};
use actix_web::{http::StatusCode, web, Error, HttpRequest, ResponseError};
use derive_more::{Display, Error};
use futures_core::future::LocalBoxFuture;
use futures_util::TryStreamExt as _;
use mime::Mime;
use tempfile_dep::NamedTempFile;
use tokio::io::AsyncWriteExt;
use super::FieldErrorHandler;
use crate::{
form::{FieldReader, Limits},
Field, MultipartError,
};
/// Write the field to a temporary file on disk.
#[derive(Debug)]
pub struct TempFile {
/// The temporary file on disk.
pub file: NamedTempFile,
/// The value of the `content-type` header.
pub content_type: Option<Mime>,
/// The `filename` value in the `content-disposition` header.
pub file_name: Option<String>,
/// The size in bytes of the file.
pub size: usize,
}
impl<'t> FieldReader<'t> for TempFile {
type Future = LocalBoxFuture<'t, Result<Self, MultipartError>>;
fn read_field(
req: &'t HttpRequest,
mut field: Field,
limits: &'t mut Limits,
) -> Self::Future {
Box::pin(async move {
let config = TempFileConfig::from_req(req);
let field_name = field.name().to_owned();
let mut size = 0;
let file = config.create_tempfile().map_err(|err| {
config.map_error(req, &field_name, TempFileError::FileIo(err))
})?;
let mut file_async = tokio::fs::File::from_std(file.reopen().map_err(|err| {
config.map_error(req, &field_name, TempFileError::FileIo(err))
})?);
while let Some(chunk) = field.try_next().await? {
limits.try_consume_limits(chunk.len(), false)?;
size += chunk.len();
file_async.write_all(chunk.as_ref()).await.map_err(|err| {
config.map_error(req, &field_name, TempFileError::FileIo(err))
})?;
}
file_async.flush().await.map_err(|err| {
config.map_error(req, &field_name, TempFileError::FileIo(err))
})?;
Ok(TempFile {
file,
content_type: field.content_type().map(ToOwned::to_owned),
file_name: field
.content_disposition()
.get_filename()
.map(str::to_owned),
size,
})
})
}
}
#[derive(Debug, Display, Error)]
#[non_exhaustive]
pub enum TempFileError {
/// File I/O Error
#[display(fmt = "File I/O error: {}", _0)]
FileIo(std::io::Error),
}
impl ResponseError for TempFileError {
fn status_code(&self) -> StatusCode {
StatusCode::INTERNAL_SERVER_ERROR
}
}
/// Configuration for the [`TempFile`] field reader.
#[derive(Clone)]
pub struct TempFileConfig {
err_handler: FieldErrorHandler<TempFileError>,
directory: Option<PathBuf>,
}
impl TempFileConfig {
fn create_tempfile(&self) -> io::Result<NamedTempFile> {
if let Some(ref dir) = self.directory {
NamedTempFile::new_in(dir)
} else {
NamedTempFile::new()
}
}
}
impl TempFileConfig {
/// Sets custom error handler.
pub fn error_handler<F>(mut self, f: F) -> Self
where
F: Fn(TempFileError, &HttpRequest) -> Error + Send + Sync + 'static,
{
self.err_handler = Some(Arc::new(f));
self
}
/// Extracts payload config from app data. Check both `T` and `Data<T>`, in that order, and fall
/// back to the default payload config.
fn from_req(req: &HttpRequest) -> &Self {
req.app_data::<Self>()
.or_else(|| req.app_data::<web::Data<Self>>().map(|d| d.as_ref()))
.unwrap_or(&DEFAULT_CONFIG)
}
fn map_error(
&self,
req: &HttpRequest,
field_name: &str,
err: TempFileError,
) -> MultipartError {
let source = if let Some(ref err_handler) = self.err_handler {
(err_handler)(err, req)
} else {
err.into()
};
MultipartError::Field {
field_name: field_name.to_owned(),
source,
}
}
/// Sets the directory that temp files will be created in.
///
/// The default temporary file location is platform dependent.
pub fn directory(mut self, dir: impl AsRef<Path>) -> Self {
self.directory = Some(dir.as_ref().to_owned());
self
}
}
const DEFAULT_CONFIG: TempFileConfig = TempFileConfig {
err_handler: None,
directory: None,
};
impl Default for TempFileConfig {
fn default() -> Self {
DEFAULT_CONFIG
}
}
#[cfg(test)]
mod tests {
use std::io::{Cursor, Read};
use actix_multipart_rfc7578::client::multipart;
use actix_web::{http::StatusCode, web, App, HttpResponse, Responder};
use crate::form::{tempfile::TempFile, tests::send_form, MultipartForm};
#[derive(MultipartForm)]
struct FileForm {
file: TempFile,
}
async fn test_file_route(form: MultipartForm<FileForm>) -> impl Responder {
let mut form = form.into_inner();
let mut contents = String::new();
form.file.file.read_to_string(&mut contents).unwrap();
assert_eq!(contents, "Hello, world!");
assert_eq!(form.file.file_name.unwrap(), "testfile.txt");
assert_eq!(form.file.content_type.unwrap(), mime::TEXT_PLAIN);
HttpResponse::Ok().finish()
}
#[actix_rt::test]
async fn test_file_upload() {
let srv = actix_test::start(|| App::new().route("/", web::post().to(test_file_route)));
let mut form = multipart::Form::default();
let bytes = Cursor::new("Hello, world!");
form.add_reader_file_with_mime("file", bytes, "testfile.txt", mime::TEXT_PLAIN);
let response = send_form(&srv, form, "/").await;
assert_eq!(response.status(), StatusCode::OK);
}
}

View file

@ -0,0 +1,196 @@
//! Deserializes a field from plain text.
use std::{str, sync::Arc};
use actix_web::{http::StatusCode, web, Error, HttpRequest, ResponseError};
use derive_more::{Deref, DerefMut, Display, Error};
use futures_core::future::LocalBoxFuture;
use serde::de::DeserializeOwned;
use super::FieldErrorHandler;
use crate::{
form::{bytes::Bytes, FieldReader, Limits},
Field, MultipartError,
};
/// Deserialize from plain text.
///
/// Internally this uses [`serde_plain`] for deserialization, which supports primitive types
/// including strings, numbers, and simple enums.
#[derive(Debug, Deref, DerefMut)]
pub struct Text<T: DeserializeOwned>(pub T);
impl<T: DeserializeOwned> Text<T> {
/// Unwraps into inner value.
pub fn into_inner(self) -> T {
self.0
}
}
impl<'t, T> FieldReader<'t> for Text<T>
where
T: DeserializeOwned + 'static,
{
type Future = LocalBoxFuture<'t, Result<Self, MultipartError>>;
fn read_field(req: &'t HttpRequest, field: Field, limits: &'t mut Limits) -> Self::Future {
Box::pin(async move {
let config = TextConfig::from_req(req);
let field_name = field.name().to_owned();
if config.validate_content_type {
let valid = if let Some(mime) = field.content_type() {
mime.subtype() == mime::PLAIN || mime.suffix() == Some(mime::PLAIN)
} else {
// https://datatracker.ietf.org/doc/html/rfc7578#section-4.4
// content type defaults to text/plain, so None should be considered valid
true
};
if !valid {
return Err(MultipartError::Field {
field_name,
source: config.map_error(req, TextError::ContentType),
});
}
}
let bytes = Bytes::read_field(req, field, limits).await?;
let text = str::from_utf8(&bytes.data).map_err(|err| MultipartError::Field {
field_name: field_name.clone(),
source: config.map_error(req, TextError::Utf8Error(err)),
})?;
Ok(Text(serde_plain::from_str(text).map_err(|err| {
MultipartError::Field {
field_name,
source: config.map_error(req, TextError::Deserialize(err)),
}
})?))
})
}
}
#[derive(Debug, Display, Error)]
#[non_exhaustive]
pub enum TextError {
/// UTF-8 decoding error.
#[display(fmt = "UTF-8 decoding error: {}", _0)]
Utf8Error(str::Utf8Error),
/// Deserialize error.
#[display(fmt = "Plain text deserialize error: {}", _0)]
Deserialize(serde_plain::Error),
/// Content type error.
#[display(fmt = "Content type error")]
ContentType,
}
impl ResponseError for TextError {
fn status_code(&self) -> StatusCode {
StatusCode::BAD_REQUEST
}
}
/// Configuration for the [`Text`] field reader.
#[derive(Clone)]
pub struct TextConfig {
err_handler: FieldErrorHandler<TextError>,
validate_content_type: bool,
}
impl TextConfig {
/// Sets custom error handler.
pub fn error_handler<F>(mut self, f: F) -> Self
where
F: Fn(TextError, &HttpRequest) -> Error + Send + Sync + 'static,
{
self.err_handler = Some(Arc::new(f));
self
}
/// Extracts payload config from app data. Check both `T` and `Data<T>`, in that order, and fall
/// back to the default payload config.
fn from_req(req: &HttpRequest) -> &Self {
req.app_data::<Self>()
.or_else(|| req.app_data::<web::Data<Self>>().map(|d| d.as_ref()))
.unwrap_or(&DEFAULT_CONFIG)
}
fn map_error(&self, req: &HttpRequest, err: TextError) -> Error {
if let Some(ref err_handler) = self.err_handler {
(err_handler)(err, req)
} else {
err.into()
}
}
/// Sets whether or not the field must have a valid `Content-Type` header to be parsed.
///
/// Note that an empty `Content-Type` is also accepted, as the multipart specification defines
/// `text/plain` as the default for text fields.
pub fn validate_content_type(mut self, validate_content_type: bool) -> Self {
self.validate_content_type = validate_content_type;
self
}
}
const DEFAULT_CONFIG: TextConfig = TextConfig {
err_handler: None,
validate_content_type: true,
};
impl Default for TextConfig {
fn default() -> Self {
DEFAULT_CONFIG
}
}
#[cfg(test)]
mod tests {
use std::io::Cursor;
use actix_multipart_rfc7578::client::multipart;
use actix_web::{http::StatusCode, web, App, HttpResponse, Responder};
use crate::form::{
tests::send_form,
text::{Text, TextConfig},
MultipartForm,
};
#[derive(MultipartForm)]
struct TextForm {
number: Text<i32>,
}
async fn test_text_route(form: MultipartForm<TextForm>) -> impl Responder {
assert_eq!(*form.number, 1025);
HttpResponse::Ok().finish()
}
#[actix_rt::test]
async fn test_content_type_validation() {
let srv = actix_test::start(|| {
App::new()
.route("/", web::post().to(test_text_route))
.app_data(TextConfig::default().validate_content_type(true))
});
// Deny because wrong content type
let bytes = Cursor::new("1025");
let mut form = multipart::Form::default();
form.add_reader_file_with_mime("number", bytes, "", mime::APPLICATION_OCTET_STREAM);
let response = send_form(&srv, form, "/").await;
assert_eq!(response.status(), StatusCode::BAD_REQUEST);
// Allow because correct content type
let bytes = Cursor::new("1025");
let mut form = multipart::Form::default();
form.add_reader_file_with_mime("number", bytes, "", mime::TEXT_PLAIN);
let response = send_form(&srv, form, "/").await;
assert_eq!(response.status(), StatusCode::OK);
}
}

View file

@ -3,10 +3,17 @@
#![deny(rust_2018_idioms, nonstandard_style)]
#![warn(future_incompatible)]
#![allow(clippy::borrow_interior_mutable_const, clippy::uninlined_format_args)]
#![cfg_attr(docsrs, feature(doc_cfg))]
// This allows us to use the actix_multipart_derive within this crate's tests
#[cfg(test)]
extern crate self as actix_multipart;
mod error;
mod extractor;
mod server;
pub mod form;
pub use self::error::MultipartError;
pub use self::server::{Field, Multipart};

View file

@ -270,7 +270,9 @@ impl InnerMultipart {
match field.borrow_mut().poll(safety) {
Poll::Pending => return Poll::Pending,
Poll::Ready(Some(Ok(_))) => continue,
Poll::Ready(Some(Err(e))) => return Poll::Ready(Some(Err(e))),
Poll::Ready(Some(Err(err))) => {
return Poll::Ready(Some(Err(err)))
}
Poll::Ready(None) => true,
}
}
@ -658,7 +660,7 @@ impl InnerField {
match res {
Poll::Pending => return Poll::Pending,
Poll::Ready(Some(Ok(bytes))) => return Poll::Ready(Some(Ok(bytes))),
Poll::Ready(Some(Err(e))) => return Poll::Ready(Some(Err(e))),
Poll::Ready(Some(Err(err))) => return Poll::Ready(Some(Err(err))),
Poll::Ready(None) => self.eof = true,
}
}
@ -673,7 +675,7 @@ impl InnerField {
}
Poll::Ready(None)
}
Err(e) => Poll::Ready(Some(Err(e))),
Err(err) => Poll::Ready(Some(Err(err))),
}
} else {
Poll::Pending
@ -794,7 +796,7 @@ impl PayloadBuffer {
loop {
match Pin::new(&mut self.stream).poll_next(cx) {
Poll::Ready(Some(Ok(data))) => self.buf.extend_from_slice(&data),
Poll::Ready(Some(Err(e))) => return Err(e),
Poll::Ready(Some(Err(err))) => return Err(err),
Poll::Ready(None) => {
self.eof = true;
return Ok(());
@ -860,19 +862,22 @@ impl PayloadBuffer {
#[cfg(test)]
mod tests {
use super::*;
use std::time::Duration;
use actix_http::h1::Payload;
use actix_web::http::header::{DispositionParam, DispositionType};
use actix_web::rt;
use actix_web::test::TestRequest;
use actix_web::FromRequest;
use actix_http::h1;
use actix_web::{
http::header::{DispositionParam, DispositionType},
rt,
test::TestRequest,
FromRequest,
};
use bytes::Bytes;
use futures_util::{future::lazy, StreamExt as _};
use std::time::Duration;
use tokio::sync::mpsc;
use tokio_stream::wrappers::UnboundedReceiverStream;
use super::*;
#[actix_rt::test]
async fn test_boundary() {
let headers = HeaderMap::new();
@ -1119,7 +1124,7 @@ mod tests {
#[actix_rt::test]
async fn test_basic() {
let (_, payload) = Payload::create(false);
let (_, payload) = h1::Payload::create(false);
let mut payload = PayloadBuffer::new(payload);
assert_eq!(payload.buf.len(), 0);
@ -1129,7 +1134,7 @@ mod tests {
#[actix_rt::test]
async fn test_eof() {
let (mut sender, payload) = Payload::create(false);
let (mut sender, payload) = h1::Payload::create(false);
let mut payload = PayloadBuffer::new(payload);
assert_eq!(None, payload.read_max(4).unwrap());
@ -1145,7 +1150,7 @@ mod tests {
#[actix_rt::test]
async fn test_err() {
let (mut sender, payload) = Payload::create(false);
let (mut sender, payload) = h1::Payload::create(false);
let mut payload = PayloadBuffer::new(payload);
assert_eq!(None, payload.read_max(1).unwrap());
sender.set_error(PayloadError::Incomplete(None));
@ -1154,7 +1159,7 @@ mod tests {
#[actix_rt::test]
async fn test_readmax() {
let (mut sender, payload) = Payload::create(false);
let (mut sender, payload) = h1::Payload::create(false);
let mut payload = PayloadBuffer::new(payload);
sender.feed_data(Bytes::from("line1"));
@ -1171,7 +1176,7 @@ mod tests {
#[actix_rt::test]
async fn test_readexactly() {
let (mut sender, payload) = Payload::create(false);
let (mut sender, payload) = h1::Payload::create(false);
let mut payload = PayloadBuffer::new(payload);
assert_eq!(None, payload.read_exact(2));
@ -1189,7 +1194,7 @@ mod tests {
#[actix_rt::test]
async fn test_readuntil() {
let (mut sender, payload) = Payload::create(false);
let (mut sender, payload) = h1::Payload::create(false);
let mut payload = PayloadBuffer::new(payload);
assert_eq!(None, payload.read_until(b"ne").unwrap());
@ -1230,7 +1235,7 @@ mod tests {
#[actix_rt::test]
async fn test_multipart_payload_consumption() {
// with sample payload and HttpRequest with no headers
let (_, inner_payload) = Payload::create(false);
let (_, inner_payload) = h1::Payload::create(false);
let mut payload = actix_web::dev::Payload::from(inner_payload);
let req = TestRequest::default().to_http_request();

View file

@ -103,7 +103,7 @@ actix-test = { version = "0.1", features = ["openssl", "rustls"] }
awc = { version = "3", features = ["openssl"] }
brotli = "3.3.3"
const-str = "0.4"
const-str = "0.3"
criterion = { version = "0.4", features = ["html_reports"] }
env_logger = "0.9"
flate2 = "1.0.13"

View file

@ -99,7 +99,7 @@ actix-utils = "3"
actix-web = { version = "4", features = ["openssl"] }
brotli = "3.3.3"
const-str = "0.4"
const-str = "0.3"
env_logger = "0.9"
flate2 = "1.0.13"
futures-util = { version = "0.3.17", default-features = false }