,
+
flags: Flags,
}
diff --git a/actix-http/src/requests/mod.rs b/actix-http/src/requests/mod.rs
index fc35da65a..4a27818a5 100644
--- a/actix-http/src/requests/mod.rs
+++ b/actix-http/src/requests/mod.rs
@@ -3,5 +3,7 @@
mod head;
mod request;
-pub use self::head::{RequestHead, RequestHeadType};
-pub use self::request::Request;
+pub use self::{
+ head::{RequestHead, RequestHeadType},
+ request::Request,
+};
diff --git a/actix-http/src/requests/request.rs b/actix-http/src/requests/request.rs
index ac358e8df..6a267a7a6 100644
--- a/actix-http/src/requests/request.rs
+++ b/actix-http/src/requests/request.rs
@@ -10,8 +10,7 @@ use std::{
use http::{header, Method, Uri, Version};
use crate::{
- header::HeaderMap, BoxedPayloadStream, Extensions, HttpMessage, Message, Payload,
- RequestHead,
+ header::HeaderMap, BoxedPayloadStream, Extensions, HttpMessage, Message, Payload, RequestHead,
};
/// An HTTP request.
@@ -174,7 +173,7 @@ impl Request
{
/// Peer address is the directly connected peer's socket address. If a proxy is used in front of
/// the Actix Web server, then it would be address of this proxy.
///
- /// Will only return None when called in unit tests.
+ /// Will only return None when called in unit tests unless set manually.
#[inline]
pub fn peer_addr(&self) -> Option {
self.head().peer_addr
@@ -234,7 +233,6 @@ impl fmt::Debug for Request
{
#[cfg(test)]
mod tests {
use super::*;
- use std::convert::TryFrom;
#[test]
fn test_basics() {
diff --git a/actix-http/src/responses/builder.rs b/actix-http/src/responses/builder.rs
index 063af92da..91c69ba54 100644
--- a/actix-http/src/responses/builder.rs
+++ b/actix-http/src/responses/builder.rs
@@ -93,7 +93,7 @@ impl ResponseBuilder {
Ok((key, value)) => {
parts.headers.insert(key, value);
}
- Err(e) => self.err = Some(e.into()),
+ Err(err) => self.err = Some(err.into()),
};
}
@@ -119,7 +119,7 @@ impl ResponseBuilder {
if let Some(parts) = self.inner() {
match header.try_into_pair() {
Ok((key, value)) => parts.headers.append(key, value),
- Err(e) => self.err = Some(e.into()),
+ Err(err) => self.err = Some(err.into()),
};
}
@@ -193,7 +193,7 @@ impl ResponseBuilder {
Ok(value) => {
parts.headers.insert(header::CONTENT_TYPE, value);
}
- Err(e) => self.err = Some(e.into()),
+ Err(err) => self.err = Some(err.into()),
};
}
self
diff --git a/actix-http/src/responses/mod.rs b/actix-http/src/responses/mod.rs
index 899232b9f..d99628232 100644
--- a/actix-http/src/responses/mod.rs
+++ b/actix-http/src/responses/mod.rs
@@ -5,7 +5,5 @@ mod head;
#[allow(clippy::module_inception)]
mod response;
-pub use self::builder::ResponseBuilder;
pub(crate) use self::head::BoxedResponseHead;
-pub use self::head::ResponseHead;
-pub use self::response::Response;
+pub use self::{builder::ResponseBuilder, head::ResponseHead, response::Response};
diff --git a/actix-http/src/service.rs b/actix-http/src/service.rs
index 62128f3ec..a58be93c7 100644
--- a/actix-http/src/service.rs
+++ b/actix-http/src/service.rs
@@ -30,9 +30,9 @@ use crate::{
///
/// # Automatic HTTP Version Selection
/// There are two ways to select the HTTP version of an incoming connection:
-/// - One is to rely on the ALPN information that is provided when using a TLS (HTTPS); both
-/// versions are supported automatically when using either of the `.rustls()` or `.openssl()`
-/// finalizing methods.
+/// - One is to rely on the ALPN information that is provided when using TLS (HTTPS); both versions
+/// are supported automatically when using either of the `.rustls()` or `.openssl()` finalizing
+/// methods.
/// - The other is to read the first few bytes of the TCP stream. This is the only viable approach
/// for supporting H2C, which allows the HTTP/2 protocol to work over plaintext connections. Use
/// the `.tcp_auto_h2c()` finalizing method to enable this behavior.
@@ -200,13 +200,8 @@ where
/// The resulting service only supports HTTP/1.x.
pub fn tcp(
self,
- ) -> impl ServiceFactory<
- TcpStream,
- Config = (),
- Response = (),
- Error = DispatchError,
- InitError = (),
- > {
+ ) -> impl ServiceFactory
+ {
fn_service(|io: TcpStream| async {
let peer_addr = io.peer_addr().ok();
Ok((io, Protocol::Http1, peer_addr))
@@ -217,16 +212,10 @@ where
/// Creates TCP stream service from HTTP service that automatically selects HTTP/1.x or HTTP/2
/// on plaintext connections.
#[cfg(feature = "http2")]
- #[cfg_attr(docsrs, doc(cfg(feature = "http2")))]
pub fn tcp_auto_h2c(
self,
- ) -> impl ServiceFactory<
- TcpStream,
- Config = (),
- Response = (),
- Error = DispatchError,
- InitError = (),
- > {
+ ) -> impl ServiceFactory
+ {
fn_service(move |io: TcpStream| async move {
// subset of HTTP/2 preface defined by RFC 9113 §3.4
// this subset was chosen to maximize likelihood that peeking only once will allow us to
@@ -252,14 +241,25 @@ where
}
/// Configuration options used when accepting TLS connection.
-#[cfg(any(feature = "openssl", feature = "rustls"))]
-#[cfg_attr(docsrs, doc(cfg(any(feature = "openssl", feature = "rustls"))))]
+#[cfg(any(
+ feature = "openssl",
+ feature = "rustls-0_20",
+ feature = "rustls-0_21",
+ feature = "rustls-0_22",
+ feature = "rustls-0_23",
+))]
#[derive(Debug, Default)]
pub struct TlsAcceptorConfig {
pub(crate) handshake_timeout: Option,
}
-#[cfg(any(feature = "openssl", feature = "rustls"))]
+#[cfg(any(
+ feature = "openssl",
+ feature = "rustls-0_20",
+ feature = "rustls-0_21",
+ feature = "rustls-0_22",
+ feature = "rustls-0_23",
+))]
impl TlsAcceptorConfig {
/// Set TLS handshake timeout duration.
pub fn handshake_timeout(self, dur: std::time::Duration) -> Self {
@@ -309,7 +309,6 @@ mod openssl {
U::InitError: fmt::Debug,
{
/// Create OpenSSL based service.
- #[cfg_attr(docsrs, doc(cfg(feature = "openssl")))]
pub fn openssl(
self,
acceptor: SslAcceptor,
@@ -324,7 +323,6 @@ mod openssl {
}
/// Create OpenSSL based service with custom TLS acceptor configuration.
- #[cfg_attr(docsrs, doc(cfg(feature = "openssl")))]
pub fn openssl_with_config(
self,
acceptor: SslAcceptor,
@@ -366,13 +364,13 @@ mod openssl {
}
}
-#[cfg(feature = "rustls")]
-mod rustls {
+#[cfg(feature = "rustls-0_20")]
+mod rustls_0_20 {
use std::io;
use actix_service::ServiceFactoryExt as _;
use actix_tls::accept::{
- rustls::{reexports::ServerConfig, Acceptor, TlsStream},
+ rustls_0_20::{reexports::ServerConfig, Acceptor, TlsStream},
TlsError,
};
@@ -403,8 +401,7 @@ mod rustls {
U::Error: fmt::Display + Into>,
U::InitError: fmt::Debug,
{
- /// Create Rustls based service.
- #[cfg_attr(docsrs, doc(cfg(feature = "rustls")))]
+ /// Create Rustls v0.20 based service.
pub fn rustls(
self,
config: ServerConfig,
@@ -418,8 +415,7 @@ mod rustls {
self.rustls_with_config(config, TlsAcceptorConfig::default())
}
- /// Create Rustls based service with custom TLS acceptor configuration.
- #[cfg_attr(docsrs, doc(cfg(feature = "rustls")))]
+ /// Create Rustls v0.20 based service with custom TLS acceptor configuration.
pub fn rustls_with_config(
self,
mut config: ServerConfig,
@@ -464,6 +460,294 @@ mod rustls {
}
}
+#[cfg(feature = "rustls-0_21")]
+mod rustls_0_21 {
+ use std::io;
+
+ use actix_service::ServiceFactoryExt as _;
+ use actix_tls::accept::{
+ rustls_0_21::{reexports::ServerConfig, Acceptor, TlsStream},
+ TlsError,
+ };
+
+ use super::*;
+
+ impl HttpService, S, B, X, U>
+ where
+ S: ServiceFactory,
+ S::Future: 'static,
+ S::Error: Into> + 'static,
+ S::InitError: fmt::Debug,
+ S::Response: Into> + 'static,
+ >::Future: 'static,
+
+ B: MessageBody + 'static,
+
+ X: ServiceFactory,
+ X::Future: 'static,
+ X::Error: Into>,
+ X::InitError: fmt::Debug,
+
+ U: ServiceFactory<
+ (Request, Framed, h1::Codec>),
+ Config = (),
+ Response = (),
+ >,
+ U::Future: 'static,
+ U::Error: fmt::Display + Into>,
+ U::InitError: fmt::Debug,
+ {
+ /// Create Rustls v0.21 based service.
+ pub fn rustls_021(
+ self,
+ config: ServerConfig,
+ ) -> impl ServiceFactory<
+ TcpStream,
+ Config = (),
+ Response = (),
+ Error = TlsError,
+ InitError = (),
+ > {
+ self.rustls_021_with_config(config, TlsAcceptorConfig::default())
+ }
+
+ /// Create Rustls v0.21 based service with custom TLS acceptor configuration.
+ pub fn rustls_021_with_config(
+ self,
+ mut config: ServerConfig,
+ tls_acceptor_config: TlsAcceptorConfig,
+ ) -> impl ServiceFactory<
+ TcpStream,
+ Config = (),
+ Response = (),
+ Error = TlsError,
+ InitError = (),
+ > {
+ let mut protos = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
+ protos.extend_from_slice(&config.alpn_protocols);
+ config.alpn_protocols = protos;
+
+ let mut acceptor = Acceptor::new(config);
+
+ if let Some(handshake_timeout) = tls_acceptor_config.handshake_timeout {
+ acceptor.set_handshake_timeout(handshake_timeout);
+ }
+
+ acceptor
+ .map_init_err(|_| {
+ unreachable!("TLS acceptor service factory does not error on init")
+ })
+ .map_err(TlsError::into_service_error)
+ .and_then(|io: TlsStream| async {
+ let proto = if let Some(protos) = io.get_ref().1.alpn_protocol() {
+ if protos.windows(2).any(|window| window == b"h2") {
+ Protocol::Http2
+ } else {
+ Protocol::Http1
+ }
+ } else {
+ Protocol::Http1
+ };
+ let peer_addr = io.get_ref().0.peer_addr().ok();
+ Ok((io, proto, peer_addr))
+ })
+ .and_then(self.map_err(TlsError::Service))
+ }
+ }
+}
+
+#[cfg(feature = "rustls-0_22")]
+mod rustls_0_22 {
+ use std::io;
+
+ use actix_service::ServiceFactoryExt as _;
+ use actix_tls::accept::{
+ rustls_0_22::{reexports::ServerConfig, Acceptor, TlsStream},
+ TlsError,
+ };
+
+ use super::*;
+
+ impl HttpService, S, B, X, U>
+ where
+ S: ServiceFactory,
+ S::Future: 'static,
+ S::Error: Into> + 'static,
+ S::InitError: fmt::Debug,
+ S::Response: Into> + 'static,
+ >::Future: 'static,
+
+ B: MessageBody + 'static,
+
+ X: ServiceFactory,
+ X::Future: 'static,
+ X::Error: Into>,
+ X::InitError: fmt::Debug,
+
+ U: ServiceFactory<
+ (Request, Framed, h1::Codec>),
+ Config = (),
+ Response = (),
+ >,
+ U::Future: 'static,
+ U::Error: fmt::Display + Into>,
+ U::InitError: fmt::Debug,
+ {
+ /// Create Rustls v0.22 based service.
+ pub fn rustls_0_22(
+ self,
+ config: ServerConfig,
+ ) -> impl ServiceFactory<
+ TcpStream,
+ Config = (),
+ Response = (),
+ Error = TlsError,
+ InitError = (),
+ > {
+ self.rustls_0_22_with_config(config, TlsAcceptorConfig::default())
+ }
+
+ /// Create Rustls v0.22 based service with custom TLS acceptor configuration.
+ pub fn rustls_0_22_with_config(
+ self,
+ mut config: ServerConfig,
+ tls_acceptor_config: TlsAcceptorConfig,
+ ) -> impl ServiceFactory<
+ TcpStream,
+ Config = (),
+ Response = (),
+ Error = TlsError,
+ InitError = (),
+ > {
+ let mut protos = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
+ protos.extend_from_slice(&config.alpn_protocols);
+ config.alpn_protocols = protos;
+
+ let mut acceptor = Acceptor::new(config);
+
+ if let Some(handshake_timeout) = tls_acceptor_config.handshake_timeout {
+ acceptor.set_handshake_timeout(handshake_timeout);
+ }
+
+ acceptor
+ .map_init_err(|_| {
+ unreachable!("TLS acceptor service factory does not error on init")
+ })
+ .map_err(TlsError::into_service_error)
+ .and_then(|io: TlsStream| async {
+ let proto = if let Some(protos) = io.get_ref().1.alpn_protocol() {
+ if protos.windows(2).any(|window| window == b"h2") {
+ Protocol::Http2
+ } else {
+ Protocol::Http1
+ }
+ } else {
+ Protocol::Http1
+ };
+ let peer_addr = io.get_ref().0.peer_addr().ok();
+ Ok((io, proto, peer_addr))
+ })
+ .and_then(self.map_err(TlsError::Service))
+ }
+ }
+}
+
+#[cfg(feature = "rustls-0_23")]
+mod rustls_0_23 {
+ use std::io;
+
+ use actix_service::ServiceFactoryExt as _;
+ use actix_tls::accept::{
+ rustls_0_23::{reexports::ServerConfig, Acceptor, TlsStream},
+ TlsError,
+ };
+
+ use super::*;
+
+ impl HttpService, S, B, X, U>
+ where
+ S: ServiceFactory,
+ S::Future: 'static,
+ S::Error: Into> + 'static,
+ S::InitError: fmt::Debug,
+ S::Response: Into> + 'static,
+ >::Future: 'static,
+
+ B: MessageBody + 'static,
+
+ X: ServiceFactory,
+ X::Future: 'static,
+ X::Error: Into>,
+ X::InitError: fmt::Debug,
+
+ U: ServiceFactory<
+ (Request, Framed, h1::Codec>),
+ Config = (),
+ Response = (),
+ >,
+ U::Future: 'static,
+ U::Error: fmt::Display + Into>,
+ U::InitError: fmt::Debug,
+ {
+ /// Create Rustls v0.23 based service.
+ pub fn rustls_0_23(
+ self,
+ config: ServerConfig,
+ ) -> impl ServiceFactory<
+ TcpStream,
+ Config = (),
+ Response = (),
+ Error = TlsError,
+ InitError = (),
+ > {
+ self.rustls_0_23_with_config(config, TlsAcceptorConfig::default())
+ }
+
+ /// Create Rustls v0.23 based service with custom TLS acceptor configuration.
+ pub fn rustls_0_23_with_config(
+ self,
+ mut config: ServerConfig,
+ tls_acceptor_config: TlsAcceptorConfig,
+ ) -> impl ServiceFactory<
+ TcpStream,
+ Config = (),
+ Response = (),
+ Error = TlsError,
+ InitError = (),
+ > {
+ let mut protos = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
+ protos.extend_from_slice(&config.alpn_protocols);
+ config.alpn_protocols = protos;
+
+ let mut acceptor = Acceptor::new(config);
+
+ if let Some(handshake_timeout) = tls_acceptor_config.handshake_timeout {
+ acceptor.set_handshake_timeout(handshake_timeout);
+ }
+
+ acceptor
+ .map_init_err(|_| {
+ unreachable!("TLS acceptor service factory does not error on init")
+ })
+ .map_err(TlsError::into_service_error)
+ .and_then(|io: TlsStream| async {
+ let proto = if let Some(protos) = io.get_ref().1.alpn_protocol() {
+ if protos.windows(2).any(|window| window == b"h2") {
+ Protocol::Http2
+ } else {
+ Protocol::Http1
+ }
+ } else {
+ Protocol::Http1
+ };
+ let peer_addr = io.get_ref().0.peer_addr().ok();
+ Ok((io, proto, peer_addr))
+ })
+ .and_then(self.map_err(TlsError::Service))
+ }
+ }
+}
+
impl ServiceFactory<(T, Protocol, Option)>
for HttpService
where
@@ -569,10 +853,7 @@ where
}
}
- pub(super) fn _poll_ready(
- &self,
- cx: &mut Context<'_>,
- ) -> Poll>> {
+ pub(super) fn _poll_ready(&self, cx: &mut Context<'_>) -> Poll>> {
ready!(self.flow.expect.poll_ready(cx).map_err(Into::into))?;
ready!(self.flow.service.poll_ready(cx).map_err(Into::into))?;
@@ -631,10 +912,7 @@ where
})
}
- fn call(
- &self,
- (io, proto, peer_addr): (T, Protocol, Option),
- ) -> Self::Future {
+ fn call(&self, (io, proto, peer_addr): (T, Protocol, Option)) -> Self::Future {
let conn_data = OnConnectData::from_io(&io, self.on_connect_ext.as_deref());
match proto {
diff --git a/actix-http/src/ws/codec.rs b/actix-http/src/ws/codec.rs
index 6a149f9a4..ad487e400 100644
--- a/actix-http/src/ws/codec.rs
+++ b/actix-http/src/ws/codec.rs
@@ -74,6 +74,7 @@ pub struct Codec {
}
bitflags! {
+ #[derive(Debug, Clone, Copy)]
struct Flags: u8 {
const SERVER = 0b0000_0001;
const CONTINUATION = 0b0000_0010;
@@ -295,7 +296,7 @@ impl Decoder for Codec {
}
}
Ok(None) => Ok(None),
- Err(e) => Err(e),
+ Err(err) => Err(err),
}
}
}
diff --git a/actix-http/src/ws/dispatcher.rs b/actix-http/src/ws/dispatcher.rs
index 396f1e86c..1354d5ae1 100644
--- a/actix-http/src/ws/dispatcher.rs
+++ b/actix-http/src/ws/dispatcher.rs
@@ -70,15 +70,14 @@ mod inner {
task::{Context, Poll},
};
+ use actix_codec::Framed;
use actix_service::{IntoService, Service};
use futures_core::stream::Stream;
use local_channel::mpsc;
use pin_project_lite::pin_project;
- use tracing::debug;
-
- use actix_codec::Framed;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_util::codec::{Decoder, Encoder};
+ use tracing::debug;
use crate::{body::BoxBody, Response};
@@ -413,9 +412,7 @@ mod inner {
}
State::Error(_) => {
// flush write buffer
- if !this.framed.is_write_buf_empty()
- && this.framed.flush(cx).is_pending()
- {
+ if !this.framed.is_write_buf_empty() && this.framed.flush(cx).is_pending() {
return Poll::Pending;
}
Poll::Ready(Err(this.state.take_error()))
diff --git a/actix-http/src/ws/frame.rs b/actix-http/src/ws/frame.rs
index c7e0427ea..35b3f8e66 100644
--- a/actix-http/src/ws/frame.rs
+++ b/actix-http/src/ws/frame.rs
@@ -1,4 +1,4 @@
-use std::convert::TryFrom;
+use std::cmp::min;
use bytes::{Buf, BufMut, BytesMut};
use tracing::debug;
@@ -96,6 +96,10 @@ impl Parser {
// not enough data
if src.len() < idx + length {
+ let min_length = min(length, max_size);
+ if src.capacity() < idx + min_length {
+ src.reserve(idx + min_length - src.capacity());
+ }
return Ok(None);
}
@@ -174,14 +178,14 @@ impl Parser {
};
if payload_len < 126 {
- dst.reserve(p_len + 2 + if mask { 4 } else { 0 });
+ dst.reserve(p_len + 2);
dst.put_slice(&[one, two | payload_len as u8]);
} else if payload_len <= 65_535 {
- dst.reserve(p_len + 4 + if mask { 4 } else { 0 });
+ dst.reserve(p_len + 4);
dst.put_slice(&[one, two | 126]);
dst.put_u16(payload_len as u16);
} else {
- dst.reserve(p_len + 10 + if mask { 4 } else { 0 });
+ dst.reserve(p_len + 10);
dst.put_slice(&[one, two | 127]);
dst.put_u64(payload_len as u64);
};
@@ -217,9 +221,10 @@ impl Parser {
#[cfg(test)]
mod tests {
- use super::*;
use bytes::Bytes;
+ use super::*;
+
struct F {
finished: bool,
opcode: OpCode,
diff --git a/actix-http/src/ws/mask.rs b/actix-http/src/ws/mask.rs
index be72e5631..115a8cf9b 100644
--- a/actix-http/src/ws/mask.rs
+++ b/actix-http/src/ws/mask.rs
@@ -50,7 +50,7 @@ mod tests {
#[test]
fn test_apply_mask() {
let mask = [0x6d, 0xb6, 0xb2, 0x80];
- let unmasked = vec![
+ let unmasked = [
0xf3, 0x00, 0x01, 0x02, 0x03, 0x80, 0x81, 0x82, 0xff, 0xfe, 0x00, 0x17, 0x74, 0xf9,
0x12, 0x03,
];
diff --git a/actix-http/src/ws/mod.rs b/actix-http/src/ws/mod.rs
index 75d4ca628..3ed53b70a 100644
--- a/actix-http/src/ws/mod.rs
+++ b/actix-http/src/ws/mod.rs
@@ -8,8 +8,7 @@ use std::io;
use derive_more::{Display, Error, From};
use http::{header, Method, StatusCode};
-use crate::body::BoxBody;
-use crate::{header::HeaderValue, RequestHead, Response, ResponseBuilder};
+use crate::{body::BoxBody, header::HeaderValue, RequestHead, Response, ResponseBuilder};
mod codec;
mod dispatcher;
@@ -17,48 +16,50 @@ mod frame;
mod mask;
mod proto;
-pub use self::codec::{Codec, Frame, Item, Message};
-pub use self::dispatcher::Dispatcher;
-pub use self::frame::Parser;
-pub use self::proto::{hash_key, CloseCode, CloseReason, OpCode};
+pub use self::{
+ codec::{Codec, Frame, Item, Message},
+ dispatcher::Dispatcher,
+ frame::Parser,
+ proto::{hash_key, CloseCode, CloseReason, OpCode},
+};
/// WebSocket protocol errors.
#[derive(Debug, Display, Error, From)]
pub enum ProtocolError {
/// Received an unmasked frame from client.
- #[display(fmt = "Received an unmasked frame from client.")]
+ #[display(fmt = "received an unmasked frame from client")]
UnmaskedFrame,
/// Received a masked frame from server.
- #[display(fmt = "Received a masked frame from server.")]
+ #[display(fmt = "received a masked frame from server")]
MaskedFrame,
/// Encountered invalid opcode.
- #[display(fmt = "Invalid opcode: {}.", _0)]
+ #[display(fmt = "invalid opcode ({})", _0)]
InvalidOpcode(#[error(not(source))] u8),
/// Invalid control frame length
- #[display(fmt = "Invalid control frame length: {}.", _0)]
+ #[display(fmt = "invalid control frame length ({})", _0)]
InvalidLength(#[error(not(source))] usize),
/// Bad opcode.
- #[display(fmt = "Bad opcode.")]
+ #[display(fmt = "bad opcode")]
BadOpCode,
/// A payload reached size limit.
- #[display(fmt = "A payload reached size limit.")]
+ #[display(fmt = "payload reached size limit")]
Overflow,
- /// Continuation is not started.
- #[display(fmt = "Continuation is not started.")]
+ /// Continuation has not started.
+ #[display(fmt = "continuation has not started")]
ContinuationNotStarted,
/// Received new continuation but it is already started.
- #[display(fmt = "Received new continuation but it is already started.")]
+ #[display(fmt = "received new continuation but it has already started")]
ContinuationStarted,
/// Unknown continuation fragment.
- #[display(fmt = "Unknown continuation fragment: {}.", _0)]
+ #[display(fmt = "unknown continuation fragment: {}", _0)]
ContinuationFragment(#[error(not(source))] OpCode),
/// I/O error.
@@ -70,27 +71,27 @@ pub enum ProtocolError {
#[derive(Debug, Clone, Copy, PartialEq, Eq, Display, Error)]
pub enum HandshakeError {
/// Only get method is allowed.
- #[display(fmt = "Method not allowed.")]
+ #[display(fmt = "method not allowed")]
GetMethodRequired,
/// Upgrade header if not set to WebSocket.
- #[display(fmt = "WebSocket upgrade is expected.")]
+ #[display(fmt = "WebSocket upgrade is expected")]
NoWebsocketUpgrade,
/// Connection header is not set to upgrade.
- #[display(fmt = "Connection upgrade is expected.")]
+ #[display(fmt = "connection upgrade is expected")]
NoConnectionUpgrade,
/// WebSocket version header is not set.
- #[display(fmt = "WebSocket version header is required.")]
+ #[display(fmt = "WebSocket version header is required")]
NoVersionHeader,
/// Unsupported WebSocket version.
- #[display(fmt = "Unsupported WebSocket version.")]
+ #[display(fmt = "unsupported WebSocket version")]
UnsupportedVersion,
/// WebSocket key is not set or wrong.
- #[display(fmt = "Unknown websocket key.")]
+ #[display(fmt = "unknown WebSocket key")]
BadWebsocketKey,
}
@@ -219,10 +220,8 @@ pub fn handshake_response(req: &RequestHead) -> ResponseBuilder {
#[cfg(test)]
mod tests {
- use crate::{header, Method};
-
use super::*;
- use crate::test::TestRequest;
+ use crate::{header, test::TestRequest};
#[test]
fn test_handshake() {
diff --git a/actix-http/src/ws/proto.rs b/actix-http/src/ws/proto.rs
index 7222168b7..27815eaf2 100644
--- a/actix-http/src/ws/proto.rs
+++ b/actix-http/src/ws/proto.rs
@@ -1,8 +1,6 @@
-use std::{
- convert::{From, Into},
- fmt,
-};
+use std::fmt;
+use base64::prelude::*;
use tracing::error;
/// Operation codes defined in [RFC 6455 §11.8].
@@ -244,7 +242,7 @@ pub fn hash_key(key: &[u8]) -> [u8; 28] {
};
let mut hash_b64 = [0; 28];
- let n = base64::encode_config_slice(hash, base64::STANDARD, &mut hash_b64);
+ let n = BASE64_STANDARD.encode_slice(hash, &mut hash_b64).unwrap();
assert_eq!(n, 28);
hash_b64
diff --git a/actix-http/tests/test_openssl.rs b/actix-http/tests/test_openssl.rs
index 7464bee4e..4dd22b585 100644
--- a/actix-http/tests/test_openssl.rs
+++ b/actix-http/tests/test_openssl.rs
@@ -1,5 +1,4 @@
#![cfg(feature = "openssl")]
-#![allow(clippy::uninlined_format_args)]
extern crate tls_openssl as openssl;
@@ -43,9 +42,11 @@ where
}
fn tls_config() -> SslAcceptor {
- let cert = rcgen::generate_simple_self_signed(vec!["localhost".to_owned()]).unwrap();
- let cert_file = cert.serialize_pem().unwrap();
- let key_file = cert.serialize_private_key_pem();
+ let rcgen::CertifiedKey { cert, key_pair } =
+ rcgen::generate_simple_self_signed(["localhost".to_owned()]).unwrap();
+ let cert_file = cert.pem();
+ let key_file = key_pair.serialize_pem();
+
let cert = X509::from_pem(cert_file.as_bytes()).unwrap();
let key = PKey::private_key_from_pem(key_file.as_bytes()).unwrap();
@@ -321,8 +322,7 @@ async fn h2_body_length() {
let mut srv = test_server(move || {
HttpService::build()
.h2(|_| async {
- let body =
- once(async { Ok::<_, Infallible>(Bytes::from_static(STR.as_ref())) });
+ let body = once(async { Ok::<_, Infallible>(Bytes::from_static(STR.as_ref())) });
Ok::<_, Infallible>(
Response::ok().set_body(SizedStream::new(STR.len() as u64, body)),
diff --git a/actix-http/tests/test_rustls.rs b/actix-http/tests/test_rustls.rs
index 0b8197a69..3ca0d94c2 100644
--- a/actix-http/tests/test_rustls.rs
+++ b/actix-http/tests/test_rustls.rs
@@ -1,10 +1,9 @@
-#![cfg(feature = "rustls")]
-#![allow(clippy::uninlined_format_args)]
+#![cfg(feature = "rustls-0_23")]
-extern crate tls_rustls as rustls;
+extern crate tls_rustls_023 as rustls;
use std::{
- convert::{Infallible, TryFrom},
+ convert::Infallible,
io::{self, BufReader, Write},
net::{SocketAddr, TcpStream as StdTcpStream},
sync::Arc,
@@ -21,13 +20,13 @@ use actix_http::{
use actix_http_test::test_server;
use actix_rt::pin;
use actix_service::{fn_factory_with_config, fn_service};
-use actix_tls::connect::rustls::webpki_roots_cert_store;
+use actix_tls::connect::rustls_0_23::webpki_roots_cert_store;
use actix_utils::future::{err, ok, poll_fn};
use bytes::{Bytes, BytesMut};
use derive_more::{Display, Error};
use futures_core::{ready, Stream};
use futures_util::stream::once;
-use rustls::{Certificate, PrivateKey, ServerConfig as RustlsServerConfig, ServerName};
+use rustls::{pki_types::ServerName, ServerConfig as RustlsServerConfig};
use rustls_pemfile::{certs, pkcs8_private_keys};
async fn load_body(stream: S) -> Result
@@ -53,24 +52,25 @@ where
}
fn tls_config() -> RustlsServerConfig {
- let cert = rcgen::generate_simple_self_signed(vec!["localhost".to_owned()]).unwrap();
- let cert_file = cert.serialize_pem().unwrap();
- let key_file = cert.serialize_private_key_pem();
+ let rcgen::CertifiedKey { cert, key_pair } =
+ rcgen::generate_simple_self_signed(["localhost".to_owned()]).unwrap();
+ let cert_file = cert.pem();
+ let key_file = key_pair.serialize_pem();
let cert_file = &mut BufReader::new(cert_file.as_bytes());
let key_file = &mut BufReader::new(key_file.as_bytes());
- let cert_chain = certs(cert_file)
- .unwrap()
- .into_iter()
- .map(Certificate)
- .collect();
- let mut keys = pkcs8_private_keys(key_file).unwrap();
+ let cert_chain = certs(cert_file).collect::, _>>().unwrap();
+ let mut keys = pkcs8_private_keys(key_file)
+ .collect::, _>>()
+ .unwrap();
let mut config = RustlsServerConfig::builder()
- .with_safe_defaults()
.with_no_client_auth()
- .with_single_cert(cert_chain, PrivateKey(keys.remove(0)))
+ .with_single_cert(
+ cert_chain,
+ rustls::pki_types::PrivateKeyDer::Pkcs8(keys.remove(0)),
+ )
.unwrap();
config.alpn_protocols.push(HTTP1_1_ALPN_PROTOCOL.to_vec());
@@ -84,17 +84,14 @@ pub fn get_negotiated_alpn_protocol(
client_alpn_protocol: &[u8],
) -> Option> {
let mut config = rustls::ClientConfig::builder()
- .with_safe_defaults()
.with_root_certificates(webpki_roots_cert_store())
.with_no_client_auth();
config.alpn_protocols.push(client_alpn_protocol.to_vec());
- let mut sess = rustls::ClientConnection::new(
- Arc::new(config),
- ServerName::try_from("localhost").unwrap(),
- )
- .unwrap();
+ let mut sess =
+ rustls::ClientConnection::new(Arc::new(config), ServerName::try_from("localhost").unwrap())
+ .unwrap();
let mut sock = StdTcpStream::connect(addr).unwrap();
let mut stream = rustls::Stream::new(&mut sess, &mut sock);
@@ -112,7 +109,7 @@ async fn h1() -> io::Result<()> {
let srv = test_server(move || {
HttpService::build()
.h1(|_| ok::<_, Error>(Response::ok()))
- .rustls(tls_config())
+ .rustls_0_23(tls_config())
})
.await;
@@ -126,7 +123,7 @@ async fn h2() -> io::Result<()> {
let srv = test_server(move || {
HttpService::build()
.h2(|_| ok::<_, Error>(Response::ok()))
- .rustls(tls_config())
+ .rustls_0_23(tls_config())
})
.await;
@@ -144,7 +141,7 @@ async fn h1_1() -> io::Result<()> {
assert_eq!(req.version(), Version::HTTP_11);
ok::<_, Error>(Response::ok())
})
- .rustls(tls_config())
+ .rustls_0_23(tls_config())
})
.await;
@@ -162,7 +159,7 @@ async fn h2_1() -> io::Result<()> {
assert_eq!(req.version(), Version::HTTP_2);
ok::<_, Error>(Response::ok())
})
- .rustls_with_config(
+ .rustls_0_23_with_config(
tls_config(),
TlsAcceptorConfig::default().handshake_timeout(Duration::from_secs(5)),
)
@@ -183,7 +180,7 @@ async fn h2_body1() -> io::Result<()> {
let body = load_body(req.take_payload()).await?;
Ok::<_, Error>(Response::ok().set_body(body))
})
- .rustls(tls_config())
+ .rustls_0_23(tls_config())
})
.await;
@@ -209,7 +206,7 @@ async fn h2_content_length() {
];
ok::<_, Infallible>(Response::new(statuses[indx]))
})
- .rustls(tls_config())
+ .rustls_0_23(tls_config())
})
.await;
@@ -281,7 +278,7 @@ async fn h2_headers() {
}
ok::<_, Infallible>(config.body(data.clone()))
})
- .rustls(tls_config())
+ .rustls_0_23(tls_config())
})
.await;
@@ -320,7 +317,7 @@ async fn h2_body2() {
let mut srv = test_server(move || {
HttpService::build()
.h2(|_| ok::<_, Infallible>(Response::ok().set_body(STR)))
- .rustls(tls_config())
+ .rustls_0_23(tls_config())
})
.await;
@@ -337,7 +334,7 @@ async fn h2_head_empty() {
let mut srv = test_server(move || {
HttpService::build()
.finish(|_| ok::<_, Infallible>(Response::ok().set_body(STR)))
- .rustls(tls_config())
+ .rustls_0_23(tls_config())
})
.await;
@@ -363,7 +360,7 @@ async fn h2_head_binary() {
let mut srv = test_server(move || {
HttpService::build()
.h2(|_| ok::<_, Infallible>(Response::ok().set_body(STR)))
- .rustls(tls_config())
+ .rustls_0_23(tls_config())
})
.await;
@@ -388,7 +385,7 @@ async fn h2_head_binary2() {
let srv = test_server(move || {
HttpService::build()
.h2(|_| ok::<_, Infallible>(Response::ok().set_body(STR)))
- .rustls(tls_config())
+ .rustls_0_23(tls_config())
})
.await;
@@ -414,7 +411,7 @@ async fn h2_body_length() {
Response::ok().set_body(SizedStream::new(STR.len() as u64, body)),
)
})
- .rustls(tls_config())
+ .rustls_0_23(tls_config())
})
.await;
@@ -438,7 +435,7 @@ async fn h2_body_chunked_explicit() {
.body(BodyStream::new(body)),
)
})
- .rustls(tls_config())
+ .rustls_0_23(tls_config())
})
.await;
@@ -467,7 +464,7 @@ async fn h2_response_http_error_handling() {
)
}))
}))
- .rustls(tls_config())
+ .rustls_0_23(tls_config())
})
.await;
@@ -497,7 +494,7 @@ async fn h2_service_error() {
let mut srv = test_server(move || {
HttpService::build()
.h2(|_| err::, _>(BadRequest))
- .rustls(tls_config())
+ .rustls_0_23(tls_config())
})
.await;
@@ -514,7 +511,7 @@ async fn h1_service_error() {
let mut srv = test_server(move || {
HttpService::build()
.h1(|_| err::, _>(BadRequest))
- .rustls(tls_config())
+ .rustls_0_23(tls_config())
})
.await;
@@ -537,7 +534,7 @@ async fn alpn_h1() -> io::Result<()> {
config.alpn_protocols.push(CUSTOM_ALPN_PROTOCOL.to_vec());
HttpService::build()
.h1(|_| ok::<_, Error>(Response::ok()))
- .rustls(config)
+ .rustls_0_23(config)
})
.await;
@@ -559,7 +556,7 @@ async fn alpn_h2() -> io::Result<()> {
config.alpn_protocols.push(CUSTOM_ALPN_PROTOCOL.to_vec());
HttpService::build()
.h2(|_| ok::<_, Error>(Response::ok()))
- .rustls(config)
+ .rustls_0_23(config)
})
.await;
@@ -585,7 +582,7 @@ async fn alpn_h2_1() -> io::Result<()> {
config.alpn_protocols.push(CUSTOM_ALPN_PROTOCOL.to_vec());
HttpService::build()
.finish(|_| ok::<_, Error>(Response::ok()))
- .rustls(config)
+ .rustls_0_23(config)
})
.await;
diff --git a/actix-http/tests/test_server.rs b/actix-http/tests/test_server.rs
index 2efb336ae..4ba64a53c 100644
--- a/actix-http/tests/test_server.rs
+++ b/actix-http/tests/test_server.rs
@@ -1,5 +1,3 @@
-#![allow(clippy::uninlined_format_args)]
-
use std::{
convert::Infallible,
io::{Read, Write},
@@ -139,7 +137,7 @@ async fn expect_continue_h1() {
#[actix_rt::test]
async fn chunked_payload() {
- let chunk_sizes = vec![32768, 32, 32768];
+ let chunk_sizes = [32768, 32, 32768];
let total_size: usize = chunk_sizes.iter().sum();
let mut srv = test_server(|| {
@@ -149,7 +147,7 @@ async fn chunked_payload() {
.take_payload()
.map(|res| match res {
Ok(pl) => pl,
- Err(e) => panic!("Error reading payload: {}", e),
+ Err(err) => panic!("Error reading payload: {err}"),
})
.fold(0usize, |acc, chunk| ready(acc + chunk.len()))
.map(|req_size| {
@@ -166,8 +164,7 @@ async fn chunked_payload() {
for chunk_size in chunk_sizes.iter() {
let mut bytes = Vec::new();
- let random_bytes: Vec =
- (0..*chunk_size).map(|_| rand::random::()).collect();
+ let random_bytes: Vec = (0..*chunk_size).map(|_| rand::random::()).collect();
bytes.extend(format!("{:X}\r\n", chunk_size).as_bytes());
bytes.extend(&random_bytes[..]);
@@ -352,8 +349,7 @@ async fn http10_keepalive() {
.await;
let mut stream = net::TcpStream::connect(srv.addr()).unwrap();
- let _ =
- stream.write_all(b"GET /test/tests/test HTTP/1.0\r\nconnection: keep-alive\r\n\r\n");
+ let _ = stream.write_all(b"GET /test/tests/test HTTP/1.0\r\nconnection: keep-alive\r\n\r\n");
let mut data = vec![0; 1024];
let _ = stream.read(&mut data);
assert_eq!(&data[..17], b"HTTP/1.0 200 OK\r\n");
@@ -404,7 +400,7 @@ async fn content_length() {
let mut srv = test_server(|| {
HttpService::build()
.h1(|req: Request| {
- let indx: usize = req.uri().path()[1..].parse().unwrap();
+ let idx: usize = req.uri().path()[1..].parse().unwrap();
let statuses = [
StatusCode::NO_CONTENT,
StatusCode::CONTINUE,
@@ -413,7 +409,7 @@ async fn content_length() {
StatusCode::OK,
StatusCode::NOT_FOUND,
];
- ok::<_, Infallible>(Response::new(statuses[indx]))
+ ok::<_, Infallible>(Response::new(statuses[idx]))
})
.tcp()
})
@@ -795,8 +791,9 @@ async fn not_modified_spec_h1() {
.map_into_boxed_body(),
// with no content-length
- "/body" => Response::with_body(StatusCode::NOT_MODIFIED, "1234")
- .map_into_boxed_body(),
+ "/body" => {
+ Response::with_body(StatusCode::NOT_MODIFIED, "1234").map_into_boxed_body()
+ }
// with manual content-length header and specific None body
"/cl-none" => {
diff --git a/actix-http/tests/test_ws.rs b/actix-http/tests/test_ws.rs
index a9c1acd33..9a78074c4 100644
--- a/actix-http/tests/test_ws.rs
+++ b/actix-http/tests/test_ws.rs
@@ -1,5 +1,3 @@
-#![allow(clippy::uninlined_format_args)]
-
use std::{
cell::Cell,
convert::Infallible,
@@ -39,13 +37,13 @@ impl WsService {
#[derive(Debug, Display, Error, From)]
enum WsServiceError {
- #[display(fmt = "http error")]
+ #[display(fmt = "HTTP error")]
Http(actix_http::Error),
- #[display(fmt = "ws handshake error")]
+ #[display(fmt = "WS handshake error")]
Ws(actix_http::ws::HandshakeError),
- #[display(fmt = "io error")]
+ #[display(fmt = "I/O error")]
Io(std::io::Error),
#[display(fmt = "dispatcher error")]
diff --git a/actix-multipart-derive/CHANGES.md b/actix-multipart-derive/CHANGES.md
new file mode 100644
index 000000000..1b44ba4b7
--- /dev/null
+++ b/actix-multipart-derive/CHANGES.md
@@ -0,0 +1,14 @@
+# Changes
+
+## Unreleased
+
+- Minimum supported Rust version (MSRV) is now 1.72.
+
+## 0.6.1
+
+- Update `syn` dependency to `2`.
+- Minimum supported Rust version (MSRV) is now 1.68 due to transitive `time` dependency.
+
+## 0.6.0
+
+- Add `MultipartForm` derive macro.
diff --git a/actix-multipart-derive/Cargo.toml b/actix-multipart-derive/Cargo.toml
new file mode 100644
index 000000000..e978864a3
--- /dev/null
+++ b/actix-multipart-derive/Cargo.toml
@@ -0,0 +1,31 @@
+[package]
+name = "actix-multipart-derive"
+version = "0.6.1"
+authors = ["Jacob Halsey "]
+description = "Multipart form derive macro for Actix Web"
+keywords = ["http", "web", "framework", "async", "futures"]
+homepage.workspace = true
+repository.workspace = true
+license.workspace = true
+edition.workspace = true
+rust-version.workspace = true
+
+[package.metadata.docs.rs]
+rustdoc-args = ["--cfg", "docsrs"]
+all-features = true
+
+[lib]
+proc-macro = true
+
+[dependencies]
+darling = "0.20"
+parse-size = "1"
+proc-macro2 = "1"
+quote = "1"
+syn = "2"
+
+[dev-dependencies]
+actix-multipart = "0.6"
+actix-web = "4"
+rustversion = "1"
+trybuild = "1"
diff --git a/actix-multipart-derive/LICENSE-APACHE b/actix-multipart-derive/LICENSE-APACHE
new file mode 120000
index 000000000..965b606f3
--- /dev/null
+++ b/actix-multipart-derive/LICENSE-APACHE
@@ -0,0 +1 @@
+../LICENSE-APACHE
\ No newline at end of file
diff --git a/actix-multipart-derive/LICENSE-MIT b/actix-multipart-derive/LICENSE-MIT
new file mode 120000
index 000000000..76219eb72
--- /dev/null
+++ b/actix-multipart-derive/LICENSE-MIT
@@ -0,0 +1 @@
+../LICENSE-MIT
\ No newline at end of file
diff --git a/actix-multipart-derive/README.md b/actix-multipart-derive/README.md
new file mode 100644
index 000000000..ec0afffdd
--- /dev/null
+++ b/actix-multipart-derive/README.md
@@ -0,0 +1,16 @@
+# `actix-multipart-derive`
+
+> The derive macro implementation for actix-multipart-derive.
+
+
+
+[![crates.io](https://img.shields.io/crates/v/actix-multipart-derive?label=latest)](https://crates.io/crates/actix-multipart-derive)
+[![Documentation](https://docs.rs/actix-multipart-derive/badge.svg?version=0.6.1)](https://docs.rs/actix-multipart-derive/0.6.1)
+![Version](https://img.shields.io/badge/rustc-1.72+-ab6000.svg)
+![MIT or Apache 2.0 licensed](https://img.shields.io/crates/l/actix-multipart-derive.svg)
+
+[![dependency status](https://deps.rs/crate/actix-multipart-derive/0.6.1/status.svg)](https://deps.rs/crate/actix-multipart-derive/0.6.1)
+[![Download](https://img.shields.io/crates/d/actix-multipart-derive.svg)](https://crates.io/crates/actix-multipart-derive)
+[![Chat on Discord](https://img.shields.io/discord/771444961383153695?label=chat&logo=discord)](https://discord.gg/NWpN5mmg3x)
+
+
diff --git a/actix-multipart-derive/src/lib.rs b/actix-multipart-derive/src/lib.rs
new file mode 100644
index 000000000..9552ad2d9
--- /dev/null
+++ b/actix-multipart-derive/src/lib.rs
@@ -0,0 +1,315 @@
+//! Multipart form derive macro for Actix Web.
+//!
+//! See [`macro@MultipartForm`] for usage examples.
+
+#![deny(rust_2018_idioms, nonstandard_style)]
+#![warn(future_incompatible)]
+#![doc(html_logo_url = "https://actix.rs/img/logo.png")]
+#![doc(html_favicon_url = "https://actix.rs/favicon.ico")]
+#![cfg_attr(docsrs, feature(doc_auto_cfg))]
+
+use std::collections::HashSet;
+
+use darling::{FromDeriveInput, FromField, FromMeta};
+use parse_size::parse_size;
+use proc_macro::TokenStream;
+use proc_macro2::Ident;
+use quote::quote;
+use syn::{parse_macro_input, Type};
+
+#[derive(FromMeta)]
+enum DuplicateField {
+ Ignore,
+ Deny,
+ Replace,
+}
+
+impl Default for DuplicateField {
+ fn default() -> Self {
+ Self::Ignore
+ }
+}
+
+#[derive(FromDeriveInput, Default)]
+#[darling(attributes(multipart), default)]
+struct MultipartFormAttrs {
+ deny_unknown_fields: bool,
+ duplicate_field: DuplicateField,
+}
+
+#[derive(FromField, Default)]
+#[darling(attributes(multipart), default)]
+struct FieldAttrs {
+ rename: Option,
+ limit: Option,
+}
+
+struct ParsedField<'t> {
+ serialization_name: String,
+ rust_name: &'t Ident,
+ limit: Option,
+ ty: &'t Type,
+}
+
+/// Implements `MultipartCollect` for a struct so that it can be used with the `MultipartForm`
+/// extractor.
+///
+/// # Basic Use
+///
+/// Each field type should implement the `FieldReader` trait:
+///
+/// ```
+/// use actix_multipart::form::{tempfile::TempFile, text::Text, MultipartForm};
+///
+/// #[derive(MultipartForm)]
+/// struct ImageUpload {
+/// description: Text,
+/// timestamp: Text,
+/// image: TempFile,
+/// }
+/// ```
+///
+/// # Optional and List Fields
+///
+/// You can also use `Vec` and `Option` provided that `T: FieldReader`.
+///
+/// A [`Vec`] field corresponds to an upload with multiple parts under the [same field
+/// name](https://www.rfc-editor.org/rfc/rfc7578#section-4.3).
+///
+/// ```
+/// use actix_multipart::form::{tempfile::TempFile, text::Text, MultipartForm};
+///
+/// #[derive(MultipartForm)]
+/// struct Form {
+/// category: Option>,
+/// files: Vec,
+/// }
+/// ```
+///
+/// # Field Renaming
+///
+/// You can use the `#[multipart(rename = "foo")]` attribute to receive a field by a different name.
+///
+/// ```
+/// use actix_multipart::form::{tempfile::TempFile, MultipartForm};
+///
+/// #[derive(MultipartForm)]
+/// struct Form {
+/// #[multipart(rename = "files[]")]
+/// files: Vec,
+/// }
+/// ```
+///
+/// # Field Limits
+///
+/// You can use the `#[multipart(limit = "")]` attribute to set field level limits. The limit
+/// string is parsed using [parse_size].
+///
+/// Note: the form is also subject to the global limits configured using `MultipartFormConfig`.
+///
+/// ```
+/// use actix_multipart::form::{tempfile::TempFile, text::Text, MultipartForm};
+///
+/// #[derive(MultipartForm)]
+/// struct Form {
+/// #[multipart(limit = "2 KiB")]
+/// description: Text,
+///
+/// #[multipart(limit = "512 MiB")]
+/// files: Vec,
+/// }
+/// ```
+///
+/// # Unknown Fields
+///
+/// By default fields with an unknown name are ignored. They can be rejected using the
+/// `#[multipart(deny_unknown_fields)]` attribute:
+///
+/// ```
+/// # use actix_multipart::form::MultipartForm;
+/// #[derive(MultipartForm)]
+/// #[multipart(deny_unknown_fields)]
+/// struct Form { }
+/// ```
+///
+/// # Duplicate Fields
+///
+/// The behaviour for when multiple fields with the same name are received can be changed using the
+/// `#[multipart(duplicate_field = "")]` attribute:
+///
+/// - "ignore": (default) Extra fields are ignored. I.e., the first one is persisted.
+/// - "deny": A `MultipartError::UnsupportedField` error response is returned.
+/// - "replace": Each field is processed, but only the last one is persisted.
+///
+/// Note that `Vec` fields will ignore this option.
+///
+/// ```
+/// # use actix_multipart::form::MultipartForm;
+/// #[derive(MultipartForm)]
+/// #[multipart(duplicate_field = "deny")]
+/// struct Form { }
+/// ```
+///
+/// [parse_size]: https://docs.rs/parse-size/1/parse_size
+#[proc_macro_derive(MultipartForm, attributes(multipart))]
+pub fn impl_multipart_form(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
+ let input: syn::DeriveInput = parse_macro_input!(input);
+
+ let name = &input.ident;
+
+ let data_struct = match &input.data {
+ syn::Data::Struct(data_struct) => data_struct,
+ _ => {
+ return compile_err(syn::Error::new(
+ input.ident.span(),
+ "`MultipartForm` can only be derived for structs",
+ ))
+ }
+ };
+
+ let fields = match &data_struct.fields {
+ syn::Fields::Named(fields_named) => fields_named,
+ _ => {
+ return compile_err(syn::Error::new(
+ input.ident.span(),
+ "`MultipartForm` can only be derived for a struct with named fields",
+ ))
+ }
+ };
+
+ let attrs = match MultipartFormAttrs::from_derive_input(&input) {
+ Ok(attrs) => attrs,
+ Err(err) => return err.write_errors().into(),
+ };
+
+ // Parse the field attributes
+ let parsed = match fields
+ .named
+ .iter()
+ .map(|field| {
+ let rust_name = field.ident.as_ref().unwrap();
+ let attrs = FieldAttrs::from_field(field).map_err(|err| err.write_errors())?;
+ let serialization_name = attrs.rename.unwrap_or_else(|| rust_name.to_string());
+
+ let limit = match attrs.limit.map(|limit| match parse_size(&limit) {
+ Ok(size) => Ok(usize::try_from(size).unwrap()),
+ Err(err) => Err(syn::Error::new(
+ field.ident.as_ref().unwrap().span(),
+ format!("Could not parse size limit `{}`: {}", limit, err),
+ )),
+ }) {
+ Some(Err(err)) => return Err(compile_err(err)),
+ limit => limit.map(Result::unwrap),
+ };
+
+ Ok(ParsedField {
+ serialization_name,
+ rust_name,
+ limit,
+ ty: &field.ty,
+ })
+ })
+ .collect::, TokenStream>>()
+ {
+ Ok(attrs) => attrs,
+ Err(err) => return err,
+ };
+
+ // Check that field names are unique
+ let mut set = HashSet::new();
+ for field in &parsed {
+ if !set.insert(field.serialization_name.clone()) {
+ return compile_err(syn::Error::new(
+ field.rust_name.span(),
+ format!("Multiple fields named: `{}`", field.serialization_name),
+ ));
+ }
+ }
+
+ // Return value when a field name is not supported by the form
+ let unknown_field_result = if attrs.deny_unknown_fields {
+ quote!(::std::result::Result::Err(
+ ::actix_multipart::MultipartError::UnsupportedField(field.name().to_string())
+ ))
+ } else {
+ quote!(::std::result::Result::Ok(()))
+ };
+
+ // Value for duplicate action
+ let duplicate_field = match attrs.duplicate_field {
+ DuplicateField::Ignore => quote!(::actix_multipart::form::DuplicateField::Ignore),
+ DuplicateField::Deny => quote!(::actix_multipart::form::DuplicateField::Deny),
+ DuplicateField::Replace => quote!(::actix_multipart::form::DuplicateField::Replace),
+ };
+
+ // limit() implementation
+ let mut limit_impl = quote!();
+ for field in &parsed {
+ let name = &field.serialization_name;
+ if let Some(value) = field.limit {
+ limit_impl.extend(quote!(
+ #name => ::std::option::Option::Some(#value),
+ ));
+ }
+ }
+
+ // handle_field() implementation
+ let mut handle_field_impl = quote!();
+ for field in &parsed {
+ let name = &field.serialization_name;
+ let ty = &field.ty;
+
+ handle_field_impl.extend(quote!(
+ #name => ::std::boxed::Box::pin(
+ <#ty as ::actix_multipart::form::FieldGroupReader>::handle_field(req, field, limits, state, #duplicate_field)
+ ),
+ ));
+ }
+
+ // from_state() implementation
+ let mut from_state_impl = quote!();
+ for field in &parsed {
+ let name = &field.serialization_name;
+ let rust_name = &field.rust_name;
+ let ty = &field.ty;
+ from_state_impl.extend(quote!(
+ #rust_name: <#ty as ::actix_multipart::form::FieldGroupReader>::from_state(#name, &mut state)?,
+ ));
+ }
+
+ let gen = quote! {
+ impl ::actix_multipart::form::MultipartCollect for #name {
+ fn limit(field_name: &str) -> ::std::option::Option {
+ match field_name {
+ #limit_impl
+ _ => None,
+ }
+ }
+
+ fn handle_field<'t>(
+ req: &'t ::actix_web::HttpRequest,
+ field: ::actix_multipart::Field,
+ limits: &'t mut ::actix_multipart::form::Limits,
+ state: &'t mut ::actix_multipart::form::State,
+ ) -> ::std::pin::Pin<::std::boxed::Box> + 't>> {
+ match field.name() {
+ #handle_field_impl
+ _ => return ::std::boxed::Box::pin(::std::future::ready(#unknown_field_result)),
+ }
+ }
+
+ fn from_state(mut state: ::actix_multipart::form::State) -> ::std::result::Result {
+ Ok(Self {
+ #from_state_impl
+ })
+ }
+
+ }
+ };
+ gen.into()
+}
+
+/// Transform a syn error into a token stream for returning.
+fn compile_err(err: syn::Error) -> TokenStream {
+ TokenStream::from(err.to_compile_error())
+}
diff --git a/actix-multipart-derive/tests/trybuild.rs b/actix-multipart-derive/tests/trybuild.rs
new file mode 100644
index 000000000..6b25d78df
--- /dev/null
+++ b/actix-multipart-derive/tests/trybuild.rs
@@ -0,0 +1,16 @@
+#[rustversion::stable(1.72)] // MSRV
+#[test]
+fn compile_macros() {
+ let t = trybuild::TestCases::new();
+
+ t.pass("tests/trybuild/all-required.rs");
+ t.pass("tests/trybuild/optional-and-list.rs");
+ t.pass("tests/trybuild/rename.rs");
+ t.pass("tests/trybuild/deny-unknown.rs");
+
+ t.pass("tests/trybuild/deny-duplicates.rs");
+ t.compile_fail("tests/trybuild/deny-parse-fail.rs");
+
+ t.pass("tests/trybuild/size-limits.rs");
+ t.compile_fail("tests/trybuild/size-limit-parse-fail.rs");
+}
diff --git a/actix-multipart-derive/tests/trybuild/all-required.rs b/actix-multipart-derive/tests/trybuild/all-required.rs
new file mode 100644
index 000000000..1b4a824d9
--- /dev/null
+++ b/actix-multipart-derive/tests/trybuild/all-required.rs
@@ -0,0 +1,19 @@
+use actix_web::{web, App, Responder};
+
+use actix_multipart::form::{tempfile::TempFile, text::Text, MultipartForm};
+
+#[derive(Debug, MultipartForm)]
+struct ImageUpload {
+ description: Text,
+ timestamp: Text,
+ image: TempFile,
+}
+
+async fn handler(_form: MultipartForm) -> impl Responder {
+ "Hello World!"
+}
+
+#[actix_web::main]
+async fn main() {
+ App::new().default_service(web::to(handler));
+}
diff --git a/actix-multipart-derive/tests/trybuild/deny-duplicates.rs b/actix-multipart-derive/tests/trybuild/deny-duplicates.rs
new file mode 100644
index 000000000..9fcc1506c
--- /dev/null
+++ b/actix-multipart-derive/tests/trybuild/deny-duplicates.rs
@@ -0,0 +1,16 @@
+use actix_web::{web, App, Responder};
+
+use actix_multipart::form::MultipartForm;
+
+#[derive(MultipartForm)]
+#[multipart(duplicate_field = "deny")]
+struct Form {}
+
+async fn handler(_form: MultipartForm