diff --git a/actix-http/Cargo.toml b/actix-http/Cargo.toml index a12fed4b9..e46d07e73 100644 --- a/actix-http/Cargo.toml +++ b/actix-http/Cargo.toml @@ -54,6 +54,7 @@ bitflags = "1.2" bytes = "1" bytestring = "1" derive_more = "0.99.5" +dyn-clone = "1" encoding_rs = "0.8" futures-core = { version = "0.3.7", default-features = false, features = ["alloc"] } futures-util = { version = "0.3.7", default-features = false, features = ["alloc", "sink"] } diff --git a/actix-http/src/builder.rs b/actix-http/src/builder.rs index af5a377c6..04e8282e8 100644 --- a/actix-http/src/builder.rs +++ b/actix-http/src/builder.rs @@ -3,7 +3,15 @@ use std::{error::Error as StdError, fmt, marker::PhantomData, net, rc::Rc}; use actix_codec::Framed; use actix_service::{IntoServiceFactory, Service, ServiceFactory}; -use crate::{ConnectCallback, Extensions, Request, Response, body::{AnyBody, MessageBody}, config::{KeepAlive, ServiceConfig}, h1::{self, ExpectHandler, H1Service, UpgradeHandler}, h2::H2Service, service::HttpService}; +use crate::{ + body::{AnyBody, MessageBody}, + config::{KeepAlive, ServiceConfig}, + extensions::CloneableExtensions, + h1::{self, ExpectHandler, H1Service, UpgradeHandler}, + h2::H2Service, + service::HttpService, + ConnectCallback, Request, Response, +}; /// A HTTP service builder /// @@ -160,7 +168,7 @@ where /// and handlers. pub fn on_connect_ext(mut self, f: F) -> Self where - F: Fn(&T, &mut Extensions) + 'static, + F: Fn(&T, &mut CloneableExtensions) + 'static, { self.on_connect_ext = Some(Rc::new(f)); self diff --git a/actix-http/src/extensions.rs b/actix-http/src/extensions.rs index 1e02063ed..a3b9bb5b0 100644 --- a/actix-http/src/extensions.rs +++ b/actix-http/src/extensions.rs @@ -1,7 +1,6 @@ use std::{ any::{Any, TypeId}, fmt, mem, - rc::Rc, }; use ahash::AHashMap; @@ -13,7 +12,7 @@ use ahash::AHashMap; pub struct Extensions { /// Use FxHasher with a std HashMap with for faster /// lookups on the small `TypeId` (u64 equivalent) keys. - map: AHashMap>, + map: AHashMap>, } impl Extensions { @@ -39,8 +38,8 @@ impl Extensions { /// ``` pub fn insert(&mut self, val: T) -> Option { self.map - .insert(TypeId::of::(), Rc::new(val)) - .and_then(downcast_rc) + .insert(TypeId::of::(), Box::new(val)) + .and_then(downcast_owned) } /// Check if map contains an item of a given type. @@ -71,19 +70,19 @@ impl Extensions { .and_then(|boxed| boxed.downcast_ref()) } - // /// Get a mutable reference to an item of a given type. - // /// - // /// ``` - // /// # use actix_http::Extensions; - // /// let mut map = Extensions::new(); - // /// map.insert(1u32); - // /// assert_eq!(map.get_mut::(), Some(&mut 1u32)); - // /// ``` - // pub fn get_mut(&mut self) -> Option<&mut T> { - // self.map - // .get_mut(&TypeId::of::()) - // .and_then(|boxed| boxed.downcast_mut()) - // } + /// Get a mutable reference to an item of a given type. + /// + /// ``` + /// # use actix_http::Extensions; + /// let mut map = Extensions::new(); + /// map.insert(1u32); + /// assert_eq!(map.get_mut::(), Some(&mut 1u32)); + /// ``` + pub fn get_mut(&mut self) -> Option<&mut T> { + self.map + .get_mut(&TypeId::of::()) + .and_then(|boxed| boxed.downcast_mut()) + } /// Remove an item from the map of a given type. /// @@ -100,7 +99,7 @@ impl Extensions { /// assert!(!map.contains::()); /// ``` pub fn remove(&mut self) -> Option { - self.map.remove(&TypeId::of::()).and_then(downcast_rc) + self.map.remove(&TypeId::of::()).and_then(downcast_owned) } /// Clear the `Extensions` of all inserted extensions. @@ -131,9 +130,9 @@ impl Extensions { } /// Sets (or overrides) items from cloneable extensions map into this map. - pub(crate) fn clone_from(&mut self, other: &Self) { + pub(crate) fn clone_from(&mut self, other: &CloneableExtensions) { for (k, val) in &other.map { - self.map.insert(*k, Rc::clone(val)); + self.map.insert(*k, (**val).clone_to_any()); } } } @@ -148,48 +147,86 @@ fn downcast_owned(boxed: Box) -> Option { boxed.downcast().ok().map(|boxed| *boxed) } -fn downcast_rc(boxed: Rc) -> Option { - boxed - .downcast() - .ok() - .and_then(|boxed| Rc::try_unwrap(boxed).ok()) +// fn downcast_rc(boxed: Rc) -> Option { +// boxed +// .downcast() +// .ok() +// .and_then(|boxed| Rc::try_unwrap(boxed).ok()) +// } + +#[doc(hidden)] +pub trait CloneToAny { + /// Clone `self` into a new `Box` object. + fn clone_to_any(&self) -> Box; + + /// Clone `self` into a new `Box` object. + fn clone_to_clone_any(&self) -> Box; } -// /// A type map for request extensions. -// /// -// /// All entries into this map must be owned types (or static references). -// #[derive(Default)] -// pub struct CloneableExtensions { -// /// Use FxHasher with a std HashMap with for faster -// /// lookups on the small `TypeId` (u64 equivalent) keys. -// map: AHashMap>, -// } +impl CloneToAny for T { + #[inline] + fn clone_to_any(&self) -> Box { + Box::new(self.clone()) + } -// impl CloneableExtensions { -// pub(crate) fn priv_clone(&self) -> CloneableExtensions { -// Self { -// map: self.map.clone(), -// } -// } + #[inline] + fn clone_to_clone_any(&self) -> Box { + Box::new(self.clone()) + } +} -// /// Insert an item into the map. -// /// -// /// If an item of this type was already stored, it will be replaced and returned. -// /// -// /// ``` -// /// # use actix_http::Extensions; -// /// let mut map = Extensions::new(); -// /// assert_eq!(map.insert(""), None); -// /// assert_eq!(map.insert(1u32), None); -// /// assert_eq!(map.insert(2u32), Some(1u32)); -// /// assert_eq!(*map.get::().unwrap(), 2u32); -// /// ``` -// pub fn insert(&mut self, val: T) -> Option { -// self.map -// .insert(TypeId::of::(), Rc::new(val)) -// .and_then(downcast_rc) -// } -// } +/// An [`Any`] trait with an additional [`Clone`] requirement. +pub trait CloneAny: Any + CloneToAny {} +impl CloneAny for T {} + +impl Clone for Box { + fn clone(&self) -> Self { + (**self).clone_to_clone_any() + } +} + +trait UncheckedAnyExt { + #[inline] + unsafe fn downcast_unchecked(self: Box) -> Box { + Box::from_raw(Box::into_raw(self) as *mut T) + } +} + +impl UncheckedAnyExt for dyn CloneAny {} + +fn downcast_cloneable(boxed: Box) -> T { + *unsafe { UncheckedAnyExt::downcast_unchecked::(boxed) } +} + +/// A type map for request extensions. +/// +/// All entries into this map must be owned types (or static references). +#[derive(Default)] +pub struct CloneableExtensions { + /// Use FxHasher with a std HashMap with for faster + /// lookups on the small `TypeId` (u64 equivalent) keys. + map: AHashMap>, +} + +impl CloneableExtensions { + /// Insert an item into the map. + /// + /// If an item of this type was already stored, it will be replaced and returned. + /// + /// ``` + /// # use actix_http::Extensions; + /// let mut map = Extensions::new(); + /// assert_eq!(map.insert(""), None); + /// assert_eq!(map.insert(1u32), None); + /// assert_eq!(map.insert(2u32), Some(1u32)); + /// assert_eq!(*map.get::().unwrap(), 2u32); + /// ``` + pub fn insert(&mut self, val: T) -> Option { + self.map + .insert(TypeId::of::(), Box::new(val)) + .and_then(downcast_cloneable) + } +} #[cfg(test)] mod tests { diff --git a/actix-http/src/lib.rs b/actix-http/src/lib.rs index e0a7e70a9..32c41174c 100644 --- a/actix-http/src/lib.rs +++ b/actix-http/src/lib.rs @@ -55,7 +55,7 @@ pub mod ws; pub use self::builder::HttpServiceBuilder; pub use self::config::{KeepAlive, ServiceConfig}; pub use self::error::Error; -pub use self::extensions::Extensions; +pub use self::extensions::{CloneableExtensions, Extensions}; pub use self::header::ContentEncoding; pub use self::http_message::HttpMessage; pub use self::message::ConnectionType; @@ -98,13 +98,13 @@ pub enum Protocol { Http3, } -type ConnectCallback = dyn Fn(&IO, &mut Extensions); +type ConnectCallback = dyn Fn(&IO, &mut CloneableExtensions); /// Container for data that extract with ConnectCallback. /// /// # Implementation Details /// Uses Option to reduce necessary allocations when merging with request extensions. -pub(crate) struct OnConnectData(Option); +pub(crate) struct OnConnectData(Option); impl Default for OnConnectData { fn default() -> Self { @@ -119,7 +119,7 @@ impl OnConnectData { on_connect_ext: Option<&ConnectCallback>, ) -> Self { let ext = on_connect_ext.map(|handler| { - let mut extensions = Extensions::default(); + let mut extensions = CloneableExtensions::default(); handler(io, &mut extensions); extensions }); diff --git a/examples/on_connect.rs b/examples/on_connect.rs index 24ac86c6b..301770b43 100644 --- a/examples/on_connect.rs +++ b/examples/on_connect.rs @@ -6,7 +6,8 @@ use std::{any::Any, io, net::SocketAddr}; -use actix_web::{dev::Extensions, rt::net::TcpStream, web, App, HttpServer}; +use actix_web::{rt::net::TcpStream, web, App, HttpServer}; +use actix_http::CloneableExtensions; #[derive(Debug, Clone)] struct ConnectionInfo { @@ -22,7 +23,7 @@ async fn route_whoami(conn_info: web::ReqData) -> String { ) } -fn get_conn_info(connection: &dyn Any, data: &mut Extensions) { +fn get_conn_info(connection: &dyn Any, data: &mut CloneableExtensions) { if let Some(sock) = connection.downcast_ref::() { data.insert(ConnectionInfo { bind: sock.local_addr().unwrap(), diff --git a/src/server.rs b/src/server.rs index 194b49139..17ff4f657 100644 --- a/src/server.rs +++ b/src/server.rs @@ -8,7 +8,9 @@ use std::{ sync::{Arc, Mutex}, }; -use actix_http::{body::MessageBody, Extensions, HttpService, KeepAlive, Request, Response}; +use actix_http::{ + body::MessageBody, CloneableExtensions, HttpService, KeepAlive, Request, Response, +}; use actix_server::{Server, ServerBuilder}; use actix_service::{ map_config, IntoServiceFactory, Service, ServiceFactory, ServiceFactoryExt as _, @@ -65,7 +67,7 @@ where backlog: u32, sockets: Vec, builder: ServerBuilder, - on_connect_fn: Option>, + on_connect_fn: Option>, _phantom: PhantomData<(S, B)>, } @@ -115,7 +117,7 @@ where /// See `on_connect` example for additional details. pub fn on_connect(self, f: CB) -> HttpServer where - CB: Fn(&dyn Any, &mut Extensions) + Send + Sync + 'static, + CB: Fn(&dyn Any, &mut CloneableExtensions) + Send + Sync + 'static, { HttpServer { factory: self.factory,