From be28a0bd6d404e883e6bef6111a7d9606bab39d6 Mon Sep 17 00:00:00 2001 From: Rob Ede Date: Sat, 10 Aug 2024 01:41:27 +0100 Subject: [PATCH] feat: add from_fn middleware (#3447) --- actix-web/CHANGES.md | 1 + actix-web/examples/middleware_from_fn.rs | 127 +++++++++ actix-web/src/middleware/from_fn.rs | 349 +++++++++++++++++++++++ actix-web/src/middleware/mod.rs | 61 ++-- 4 files changed, 521 insertions(+), 17 deletions(-) create mode 100644 actix-web/examples/middleware_from_fn.rs create mode 100644 actix-web/src/middleware/from_fn.rs diff --git a/actix-web/CHANGES.md b/actix-web/CHANGES.md index 10a0e8038..d26859a36 100644 --- a/actix-web/CHANGES.md +++ b/actix-web/CHANGES.md @@ -4,6 +4,7 @@ ### Added +- Add `middleware::from_fn()` helper. - Add `web::ThinData` extractor. ## 4.8.0 diff --git a/actix-web/examples/middleware_from_fn.rs b/actix-web/examples/middleware_from_fn.rs new file mode 100644 index 000000000..da92ef05b --- /dev/null +++ b/actix-web/examples/middleware_from_fn.rs @@ -0,0 +1,127 @@ +//! Shows a couple of ways to use the `from_fn` middleware. + +use std::{collections::HashMap, io, rc::Rc, time::Duration}; + +use actix_web::{ + body::MessageBody, + dev::{Service, ServiceRequest, ServiceResponse, Transform}, + http::header::{self, HeaderValue, Range}, + middleware::{from_fn, Logger, Next}, + web::{self, Header, Query}, + App, Error, HttpResponse, HttpServer, +}; + +async fn noop(req: ServiceRequest, next: Next) -> Result, Error> { + next.call(req).await +} + +async fn print_range_header( + range_header: Option>, + req: ServiceRequest, + next: Next, +) -> Result, Error> { + if let Some(Header(range)) = range_header { + println!("Range: {range}"); + } else { + println!("No Range header"); + } + + next.call(req).await +} + +async fn mutate_body_type( + req: ServiceRequest, + next: Next, +) -> Result, Error> { + let res = next.call(req).await?; + Ok(res.map_into_left_body::<()>()) +} + +async fn mutate_body_type_with_extractors( + string_body: String, + query: Query>, + req: ServiceRequest, + next: Next, +) -> Result, Error> { + println!("body is: {string_body}"); + println!("query string: {query:?}"); + + let res = next.call(req).await?; + + Ok(res.map_body(move |_, _| string_body)) +} + +async fn timeout_10secs( + req: ServiceRequest, + next: Next, +) -> Result, Error> { + match tokio::time::timeout(Duration::from_secs(10), next.call(req)).await { + Ok(res) => res, + Err(_err) => Err(actix_web::error::ErrorRequestTimeout("")), + } +} + +struct MyMw(bool); + +impl MyMw { + async fn mw_cb( + &self, + req: ServiceRequest, + next: Next, + ) -> Result, Error> { + let mut res = match self.0 { + true => req.into_response("short-circuited").map_into_right_body(), + false => next.call(req).await?.map_into_left_body(), + }; + + res.headers_mut() + .insert(header::WARNING, HeaderValue::from_static("42")); + + Ok(res) + } + + pub fn into_middleware( + self, + ) -> impl Transform< + S, + ServiceRequest, + Response = ServiceResponse, + Error = Error, + InitError = (), + > + where + S: Service, Error = Error> + 'static, + B: MessageBody + 'static, + { + let this = Rc::new(self); + from_fn(move |req, next| { + let this = Rc::clone(&this); + async move { Self::mw_cb(&this, req, next).await } + }) + } +} + +#[actix_web::main] +async fn main() -> io::Result<()> { + env_logger::init_from_env(env_logger::Env::new().default_filter_or("info")); + + let bind = ("127.0.0.1", 8080); + log::info!("staring server at http://{}:{}", &bind.0, &bind.1); + + HttpServer::new(|| { + App::new() + .wrap(from_fn(noop)) + .wrap(from_fn(print_range_header)) + .wrap(from_fn(mutate_body_type)) + .wrap(from_fn(mutate_body_type_with_extractors)) + .wrap(from_fn(timeout_10secs)) + // switch bool to true to observe early response + .wrap(MyMw(false).into_middleware()) + .wrap(Logger::default()) + .default_service(web::to(HttpResponse::Ok)) + }) + .workers(1) + .bind(bind)? + .run() + .await +} diff --git a/actix-web/src/middleware/from_fn.rs b/actix-web/src/middleware/from_fn.rs new file mode 100644 index 000000000..608833319 --- /dev/null +++ b/actix-web/src/middleware/from_fn.rs @@ -0,0 +1,349 @@ +use std::{future::Future, marker::PhantomData, rc::Rc}; + +use actix_service::boxed::{self, BoxFuture, RcService}; +use actix_utils::future::{ready, Ready}; +use futures_core::future::LocalBoxFuture; + +use crate::{ + body::MessageBody, + dev::{forward_ready, Service, ServiceRequest, ServiceResponse, Transform}, + Error, FromRequest, +}; + +/// Wraps an async function to be used as a middleware. +/// +/// # Examples +/// +/// The wrapped function should have the following form: +/// +/// ``` +/// # use actix_web::{ +/// # App, Error, +/// # body::MessageBody, +/// # dev::{ServiceRequest, ServiceResponse, Service as _}, +/// # }; +/// use actix_web::middleware::{self, Next}; +/// +/// async fn my_mw( +/// req: ServiceRequest, +/// next: Next, +/// ) -> Result, Error> { +/// // pre-processing +/// next.call(req).await +/// // post-processing +/// } +/// # App::new().wrap(middleware::from_fn(my_mw)); +/// ``` +/// +/// Then use in an app builder like this: +/// +/// ``` +/// use actix_web::{ +/// App, Error, +/// dev::{ServiceRequest, ServiceResponse, Service as _}, +/// }; +/// use actix_web::middleware::from_fn; +/// # use actix_web::middleware::Next; +/// # async fn my_mw(req: ServiceRequest, next: Next) -> Result, Error> { +/// # next.call(req).await +/// # } +/// +/// App::new() +/// .wrap(from_fn(my_mw)) +/// # ; +/// ``` +/// +/// It is also possible to write a middleware that automatically uses extractors, similar to request +/// handlers, by declaring them as the first parameters. As usual, **take care with extractors that +/// consume the body stream**, since handlers will no longer be able to read it again without +/// putting the body "back" into the request object within your middleware. +/// +/// ``` +/// # use std::collections::HashMap; +/// # use actix_web::{ +/// # App, Error, +/// # body::MessageBody, +/// # dev::{ServiceRequest, ServiceResponse}, +/// # http::header::{Accept, Date}, +/// # web::{Header, Query}, +/// # }; +/// use actix_web::middleware::Next; +/// +/// async fn my_extracting_mw( +/// accept: Header, +/// query: Query>, +/// req: ServiceRequest, +/// next: Next, +/// ) -> Result, Error> { +/// // pre-processing +/// next.call(req).await +/// // post-processing +/// } +/// # App::new().wrap(actix_web::middleware::from_fn(my_extracting_mw)); +pub fn from_fn(mw_fn: F) -> MiddlewareFn { + MiddlewareFn { + mw_fn: Rc::new(mw_fn), + _phantom: PhantomData, + } +} + +/// Middleware transform for [`from_fn`]. +#[allow(missing_debug_implementations)] +pub struct MiddlewareFn { + mw_fn: Rc, + _phantom: PhantomData, +} + +impl Transform for MiddlewareFn +where + S: Service, Error = Error> + 'static, + F: Fn(ServiceRequest, Next) -> Fut + 'static, + Fut: Future, Error>>, + B2: MessageBody, +{ + type Response = ServiceResponse; + type Error = Error; + type Transform = MiddlewareFnService; + type InitError = (); + type Future = Ready>; + + fn new_transform(&self, service: S) -> Self::Future { + ready(Ok(MiddlewareFnService { + service: boxed::rc_service(service), + mw_fn: Rc::clone(&self.mw_fn), + _phantom: PhantomData, + })) + } +} + +/// Middleware service for [`from_fn`]. +#[allow(missing_debug_implementations)] +pub struct MiddlewareFnService { + service: RcService, Error>, + mw_fn: Rc, + _phantom: PhantomData<(B, Es)>, +} + +impl Service for MiddlewareFnService +where + F: Fn(ServiceRequest, Next) -> Fut, + Fut: Future, Error>>, + B2: MessageBody, +{ + type Response = ServiceResponse; + type Error = Error; + type Future = Fut; + + forward_ready!(service); + + fn call(&self, req: ServiceRequest) -> Self::Future { + (self.mw_fn)( + req, + Next:: { + service: Rc::clone(&self.service), + }, + ) + } +} + +macro_rules! impl_middleware_fn_service { + ($($ext_type:ident),*) => { + impl Transform for MiddlewareFn + where + S: Service, Error = Error> + 'static, + F: Fn($($ext_type),*, ServiceRequest, Next) -> Fut + 'static, + $($ext_type: FromRequest + 'static,)* + Fut: Future, Error>> + 'static, + B: MessageBody + 'static, + B2: MessageBody + 'static, + { + type Response = ServiceResponse; + type Error = Error; + type Transform = MiddlewareFnService; + type InitError = (); + type Future = Ready>; + + fn new_transform(&self, service: S) -> Self::Future { + ready(Ok(MiddlewareFnService { + service: boxed::rc_service(service), + mw_fn: Rc::clone(&self.mw_fn), + _phantom: PhantomData, + })) + } + } + + impl Service + for MiddlewareFnService + where + F: Fn( + $($ext_type),*, + ServiceRequest, + Next + ) -> Fut + 'static, + $($ext_type: FromRequest + 'static,)* + Fut: Future, Error>> + 'static, + B2: MessageBody + 'static, + { + type Response = ServiceResponse; + type Error = Error; + type Future = LocalBoxFuture<'static, Result>; + + forward_ready!(service); + + #[allow(nonstandard_style)] + fn call(&self, mut req: ServiceRequest) -> Self::Future { + let mw_fn = Rc::clone(&self.mw_fn); + let service = Rc::clone(&self.service); + + Box::pin(async move { + let ($($ext_type,)*) = req.extract::<($($ext_type,)*)>().await?; + + (mw_fn)($($ext_type),*, req, Next:: { service }).await + }) + } + } + }; +} + +impl_middleware_fn_service!(E1); +impl_middleware_fn_service!(E1, E2); +impl_middleware_fn_service!(E1, E2, E3); +impl_middleware_fn_service!(E1, E2, E3, E4); +impl_middleware_fn_service!(E1, E2, E3, E4, E5); +impl_middleware_fn_service!(E1, E2, E3, E4, E5, E6); +impl_middleware_fn_service!(E1, E2, E3, E4, E5, E6, E7); +impl_middleware_fn_service!(E1, E2, E3, E4, E5, E6, E7, E8); +impl_middleware_fn_service!(E1, E2, E3, E4, E5, E6, E7, E8, E9); + +/// Wraps the "next" service in the middleware chain. +#[allow(missing_debug_implementations)] +pub struct Next { + service: RcService, Error>, +} + +impl Next { + /// Equivalent to `Service::call(self, req)`. + pub fn call(&self, req: ServiceRequest) -> >::Future { + Service::call(self, req) + } +} + +impl Service for Next { + type Response = ServiceResponse; + type Error = Error; + type Future = BoxFuture>; + + forward_ready!(service); + + fn call(&self, req: ServiceRequest) -> Self::Future { + self.service.call(req) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + http::header::{self, HeaderValue}, + middleware::{Compat, Logger}, + test, web, App, HttpResponse, + }; + + async fn noop(req: ServiceRequest, next: Next) -> Result, Error> { + next.call(req).await + } + + async fn add_res_header( + req: ServiceRequest, + next: Next, + ) -> Result, Error> { + let mut res = next.call(req).await?; + res.headers_mut() + .insert(header::WARNING, HeaderValue::from_static("42")); + Ok(res) + } + + async fn mutate_body_type( + req: ServiceRequest, + next: Next, + ) -> Result, Error> { + let res = next.call(req).await?; + Ok(res.map_into_left_body::<()>()) + } + + struct MyMw(bool); + + impl MyMw { + async fn mw_cb( + &self, + req: ServiceRequest, + next: Next, + ) -> Result, Error> { + let mut res = match self.0 { + true => req.into_response("short-circuited").map_into_right_body(), + false => next.call(req).await?.map_into_left_body(), + }; + res.headers_mut() + .insert(header::WARNING, HeaderValue::from_static("42")); + Ok(res) + } + + pub fn into_middleware( + self, + ) -> impl Transform< + S, + ServiceRequest, + Response = ServiceResponse, + Error = Error, + InitError = (), + > + where + S: Service, Error = Error> + 'static, + B: MessageBody + 'static, + { + let this = Rc::new(self); + from_fn(move |req, next| { + let this = Rc::clone(&this); + async move { Self::mw_cb(&this, req, next).await } + }) + } + } + + #[actix_rt::test] + async fn compat_compat() { + let _ = App::new().wrap(Compat::new(from_fn(noop))); + let _ = App::new().wrap(Compat::new(from_fn(mutate_body_type))); + } + + #[actix_rt::test] + async fn permits_different_in_and_out_body_types() { + let app = test::init_service( + App::new() + .wrap(from_fn(mutate_body_type)) + .wrap(from_fn(add_res_header)) + .wrap(Logger::default()) + .wrap(from_fn(noop)) + .default_service(web::to(HttpResponse::NotFound)), + ) + .await; + + let req = test::TestRequest::default().to_request(); + let res = test::call_service(&app, req).await; + assert!(res.headers().contains_key(header::WARNING)); + } + + #[actix_rt::test] + async fn closure_capture_and_return_from_fn() { + let app = test::init_service( + App::new() + .wrap(Logger::default()) + .wrap(MyMw(true).into_middleware()) + .wrap(Logger::default()), + ) + .await; + + let req = test::TestRequest::default().to_request(); + let res = test::call_service(&app, req).await; + assert!(res.headers().contains_key(header::WARNING)); + } +} diff --git a/actix-web/src/middleware/mod.rs b/actix-web/src/middleware/mod.rs index 79c94658b..4b5b3e896 100644 --- a/actix-web/src/middleware/mod.rs +++ b/actix-web/src/middleware/mod.rs @@ -15,10 +15,47 @@ //! - Access external services (e.g., [sessions](https://docs.rs/actix-session), etc.) //! //! Middleware is registered for each [`App`], [`Scope`](crate::Scope), or -//! [`Resource`](crate::Resource) and executed in opposite order as registration. In general, a -//! middleware is a pair of types that implements the [`Service`] trait and [`Transform`] trait, -//! respectively. The [`new_transform`] and [`call`] methods must return a [`Future`], though it -//! can often be [an immediately-ready one](actix_utils::future::Ready). +//! [`Resource`](crate::Resource) and executed in opposite order as registration. +//! +//! # Simple Middleware +//! +//! In many cases, you can model your middleware as an async function via the [`from_fn()`] helper +//! that provides a natural interface for implementing your desired behaviors. +//! +//! ``` +//! # use actix_web::{ +//! # App, Error, +//! # body::MessageBody, +//! # dev::{ServiceRequest, ServiceResponse, Service as _}, +//! # }; +//! use actix_web::middleware::{self, Next}; +//! +//! async fn my_mw( +//! req: ServiceRequest, +//! next: Next, +//! ) -> Result, Error> { +//! // pre-processing +//! +//! // invoke the wrapped middleware or service +//! let res = next.call(req).await?; +//! +//! // post-processing +//! +//! Ok(res) +//! } +//! +//! App::new() +//! .wrap(middleware::from_fn(my_mw)); +//! ``` +//! +//! ## Complex Middleware +//! +//! In the more general ase, a middleware is a pair of types that implements the [`Service`] trait +//! and [`Transform`] trait, respectively. The [`new_transform`] and [`call`] methods must return a +//! [`Future`], though it can often be [an immediately-ready one](actix_utils::future::Ready). +//! +//! All the built-in middleware use this pattern with pairs of builder (`Transform`) + +//! implementation (`Service`) types. //! //! # Ordering //! @@ -196,18 +233,6 @@ //! # } //! ``` //! -//! # Simpler Middleware -//! -//! In many cases, you _can_ actually use an async function via a helper that will provide a more -//! natural flow for your behavior. -//! -//! The experimental `actix_web_lab` crate provides a [`from_fn`][lab_from_fn] utility which allows -//! an async fn to be wrapped and used in the same way as other middleware. See the -//! [`from_fn`][lab_from_fn] docs for more info and examples of it's use. -//! -//! While [`from_fn`][lab_from_fn] is experimental currently, it's likely this helper will graduate -//! to Actix Web in some form, so feedback is appreciated. -//! //! [`Future`]: std::future::Future //! [`App`]: crate::App //! [`FromRequest`]: crate::FromRequest @@ -215,7 +240,7 @@ //! [`Transform`]: crate::dev::Transform //! [`call`]: crate::dev::Service::call() //! [`new_transform`]: crate::dev::Transform::new_transform() -//! [lab_from_fn]: https://docs.rs/actix-web-lab/latest/actix_web_lab/middleware/fn.from_fn.html +//! [`from_fn`]: crate mod compat; #[cfg(feature = "__compress")] @@ -223,6 +248,7 @@ mod compress; mod condition; mod default_headers; mod err_handlers; +mod from_fn; mod identity; mod logger; mod normalize; @@ -234,6 +260,7 @@ pub use self::{ condition::Condition, default_headers::DefaultHeaders, err_handlers::{ErrorHandlerResponse, ErrorHandlers}, + from_fn::{from_fn, Next}, identity::Identity, logger::Logger, normalize::{NormalizePath, TrailingSlash},