diff --git a/awc/src/middleware/mod.rs b/awc/src/middleware/mod.rs index 330e3b7fe..df66ad70e 100644 --- a/awc/src/middleware/mod.rs +++ b/awc/src/middleware/mod.rs @@ -1,6 +1,8 @@ mod redirect; +mod retry; pub use self::redirect::Redirect; +pub use self::retry::Retry; use std::marker::PhantomData; diff --git a/awc/src/middleware/retry.rs b/awc/src/middleware/retry.rs new file mode 100644 index 000000000..27bb3a1c3 --- /dev/null +++ b/awc/src/middleware/retry.rs @@ -0,0 +1,472 @@ +use super::Transform; +use std::rc::Rc; +use actix_http::RequestHeadType; +use actix_http::http::{StatusCode, HeaderMap}; +use std::ops::Deref; +use crate::{ConnectRequest, ConnectResponse}; +use actix_service::Service; +use actix_http::client::SendRequestError; +use std::task::{Context, Poll}; +use crate::RequestHead; +use futures_core::future::LocalBoxFuture; +use actix_http::body::Body; + +pub struct Retry(Inner); + +struct Inner { + /// Number of retries. So each request will be tried [max_retry + 1] times + max_retry: u8, + policies: Vec, +} + +impl Retry { + pub fn new(retries: u8) -> Self { + Retry(Inner { + max_retry: retries, + policies: vec![], + }) + } + + /// Allows you to add a retry policy to the [`policies`] + /// It allows two types of policy: + /// - `Vec` and will retry if one of them is received + /// - `Fn(&ResponseHead) -> bool` and will retry when this function resolves to false + /// + /// # example + /// + ///``` + /// + /// // Creates a policy which will try each request a max of 5 times if any policies resolve to true + /// // i.e. + /// // if you receive a 401 or 501 status code + /// // or + /// // the response doesn't have a [`SOME_HEADER`] header + /// use awc::http::{StatusCode, HeaderMap}; + /// use awc::middleware::Retry; + /// + /// let retry_policies = Retry::new(5) + /// .policy(vec![StatusCode::INTERNAL_SERVER_ERROR, StatusCode::UNAUTHORIZED]) + /// .policy(|code: StatusCode, headers: &HeaderMap| { + /// return if headers.contains_key("SOME_HEADER") { + /// true + /// } else { + /// false + /// }; + /// }); + /// + /// // Creates awc client + /// let client = awc::Client::builder() + /// .wrap(retry_policies) + /// .finish(); + ///``` + pub fn policy(mut self, p: T) -> Self + where T: IntoRetryPolicy + { + self.0.policies.push(p.into_policy()); + self + } +} + +#[non_exhaustive] +pub enum RetryPolicy { + Status(Vec), + Custom(Box bool>), +} + +pub trait IntoRetryPolicy { + fn into_policy(self) -> RetryPolicy; +} + +impl IntoRetryPolicy for T + where T: for<'a> Fn(StatusCode, &'a HeaderMap) -> bool + 'static +{ + fn into_policy(self) -> RetryPolicy { + RetryPolicy::Custom(Box::new(self)) + } +} + +impl IntoRetryPolicy for Vec { + fn into_policy(self) -> RetryPolicy { + RetryPolicy::Status(self) + } +} + +impl Transform for Retry + where + S: Service + 'static, +{ + type Transform = RetryService; + + fn new_transform(self, service: S) -> Self::Transform { + RetryService { + max_retry: self.0.max_retry, + policies: self.0.policies.into_boxed_slice().into(), + connector: service.into(), + } + } +} + +#[derive(Clone)] +pub struct RetryService { + policies: Rc<[RetryPolicy]>, + max_retry: u8, + connector: Rc, +} + +impl Service for RetryService + where + S: Service + 'static, +{ + type Response = S::Response; + type Error = S::Error; + type Future = LocalBoxFuture<'static, Result>; + + fn poll_ready(&self, ctx: &mut Context<'_>) -> Poll> { + self.connector.poll_ready(ctx) + } + + fn call(&self, req: ConnectRequest) -> Self::Future { + let connector = self.connector.clone(); + let policies = self.policies.clone(); + let max_retry = self.max_retry; + + Box::pin(async move { + let mut tries = 0; + match req { + ConnectRequest::Client(head, body, addr) => { + match body { + Body::Bytes(b) => { + loop { + let h = clone_request_head_type(&head); + + match connector.call(ConnectRequest::Client(h, Body::Bytes(b.clone()), addr)).await + { + Ok(res) => { + // ConnectResponse + match &res { + ConnectResponse::Client(ref r) => { + if is_valid_response(policies.as_ref(), r.status(), r.headers()) { + return Ok(res); + } + + if tries == max_retry { + return Ok(res); + } + + tries += 1; + } + ConnectResponse::Tunnel(ref head, _) => { + if is_valid_response(policies.as_ref(), head.status, head.headers()) { + return Ok(res); + } + + if tries == max_retry { + return Ok(res); + } + + tries += 1; + } + } + } + // SendRequestError + Err(e) => { + if tries == max_retry { + log::debug!("Request max retry reached"); + return Err(e); + } + + tries += 1; + } + } + } + } + Body::Empty => { + loop { + let h = clone_request_head_type(&head); + + match connector.call(ConnectRequest::Client(h, Body::Empty, addr)).await + { + Ok(res) => { + // ConnectResponse + match &res { + ConnectResponse::Client(ref r) => { + if is_valid_response(policies.as_ref(), r.status(), r.headers()) { + return Ok(res); + } + + if tries == max_retry { + log::debug!("Request max retry reached"); + return Ok(res); + } + + tries += 1; + } + ConnectResponse::Tunnel(ref head, _) => { + if is_valid_response(policies.as_ref(), head.status, head.headers()) { + return Ok(res); + } else { + if tries == max_retry { + log::debug!("Request max retry reached"); + return Ok(res); + } + + tries += 1; + } + } + } + } + // SendRequestError + Err(e) => { + if tries == max_retry { + log::debug!("Request max retry reached"); + return Err(e); + } + + tries += 1; + } + } + } + } + _ => { + log::debug!("Non cloneable body type given - defaulting to `Body::None`"); + loop { + let h = clone_request_head_type(&head); + + match connector.call(ConnectRequest::Client(h, Body::None, addr)).await + { + Ok(res) => { + // ConnectResponse + match &res { + ConnectResponse::Client(ref r) => { + if is_valid_response(policies.as_ref(), r.status(), r.headers()) { + return Ok(res); + } + + if tries == max_retry { + log::debug!("Request max retry reached"); + return Ok(res); + } + + tries += 1; + } + ConnectResponse::Tunnel(ref head, _) => { + if is_valid_response(policies.as_ref(), head.status, head.headers()) { + return Ok(res); + } else { + if tries == max_retry { + log::debug!("Request max retry reached"); + return Ok(res); + } + + tries += 1; + } + } + } + } + // SendRequestError + Err(e) => { + if tries == max_retry { + log::debug!("Request max retry reached"); + return Err(e); + } + + tries += 1; + } + } + } + } + } + } + ConnectRequest::Tunnel(head, addr) => { + loop { + let h = clone_request_head(&head); + + match connector.call(ConnectRequest::Tunnel(h, addr)).await { + Ok(res) => { + match &res { + ConnectResponse::Client(r) => { + if is_valid_response(&policies, r.status(), r.headers()) { + return Ok(res) + } + + if tries == max_retry { + log::debug!("Request max retry reached"); + return Ok(res) + } + + tries += 1; + } + ConnectResponse::Tunnel(head, _) => { + if is_valid_response(&policies, head.status, head.headers()) { + return Ok(res) + } + + if tries == max_retry { + log::debug!("Request max retry reached"); + return Ok(res) + } + + tries += 1; + } + } + }, + Err(e) => { + if tries == max_retry { + log::debug!("Request max retry reached"); + return Err(e) + } + + tries += 1; + } + } + } + } + } + }) + } +} + +#[doc(hidden)] +/// Clones [RequestHeadType] except for the extensions (not required for this middleware) +fn clone_request_head_type(head_type: &RequestHeadType) -> RequestHeadType { + match head_type { + RequestHeadType::Owned(h) => { + let mut inner_head = RequestHead::default(); + inner_head.uri = h.uri.clone(); + inner_head.method = h.method.clone(); + inner_head.version = h.version; + inner_head.peer_addr = h.peer_addr; + inner_head.headers = h.headers.clone(); + + RequestHeadType::Owned(inner_head) + } + RequestHeadType::Rc(h, header_map) => { + RequestHeadType::Rc(h.clone(), header_map.clone()) + } + } +} + +#[doc(hidden)] +/// Clones [RequestHeadType] except for the extensions (not required for this middleware) +fn clone_request_head(head: &RequestHead) -> RequestHead { + let mut new_head = RequestHead::default(); + new_head.uri = head.uri.clone(); + new_head.method = head.method.clone(); + new_head.version = head.version; + new_head.headers = head.headers.clone(); + new_head.peer_addr = head.peer_addr; + + new_head +} + +#[doc(hidden)] +/// Checks whether the response matches the policies +fn is_valid_response(policies: &[RetryPolicy], status_code: StatusCode, headers: &HeaderMap) -> bool { + policies.iter().all(|policy| { + match policy { + RetryPolicy::Status(v) => { + // is valid if: + // - the list of status codes is empty + // or + // - the list doesn't contain the received status code + v.is_empty() || !v.contains(&status_code) + } + RetryPolicy::Custom(func) => { + // custom policy + (func.deref())(status_code, headers) + } + } + }) +} + +#[cfg(test)] +mod tests { + use actix_web::{web, App, Error, HttpResponse}; + + use super::*; + use crate::ClientBuilder; + + #[actix_rt::test] + async fn test_basic_policy() { + let client = ClientBuilder::new() + .disable_redirects() + .wrap(Retry::new(3) + .policy(vec![StatusCode::INTERNAL_SERVER_ERROR]) + ) + .finish(); + + let srv = actix_test::start(|| { + App::new() + .service(web::resource("/test").route(web::to(|| async { + Ok::<_, Error>( + HttpResponse::InternalServerError() + .finish(), + ) + }))) + }); + + let res = client.get(srv.url("/test")).send().await.unwrap(); + + assert_eq!(res.status().as_u16(), 500); + } + + #[actix_rt::test] + async fn test_header_policy() { + std::env::set_var("RUST_LOG", "RUST_LOG=debug,a=debug"); + env_logger::init(); + + let client = ClientBuilder::new() + .disable_redirects() + .wrap(Retry::new(3) + .policy(|code: StatusCode, headers: &HeaderMap| { + code.is_success() && headers.contains_key("SOME_HEADER") + }) + ) + .finish(); + + let srv = actix_test::start(|| { + App::new() + .service(web::resource("/test").route(web::to(|| async { + Ok::<_, Error>( + HttpResponse::Ok() + .insert_header(("SOME_HEADER", "test")) + .finish(), + ) + }))) + }); + + let res = client.get(srv.url("/test")).send().await.unwrap(); + + assert_eq!(res.status().as_u16(), 200); + } + + #[actix_rt::test] + async fn test_bad_header_policy() { + std::env::set_var("RUST_LOG", "RUST_LOG=debug,a=debug"); + env_logger::init(); + + let client = ClientBuilder::new() + .disable_redirects() + .wrap(Retry::new(3) + .policy(|code: StatusCode, headers: &HeaderMap| { + code.is_success() && headers.contains_key("WRONG_HEADER") + }) + ) + .finish(); + + let srv = actix_test::start(|| { + App::new() + .service(web::resource("/test").route(web::to(|| async { + Ok::<_, Error>( + HttpResponse::Ok() + .insert_header(("SOME_HEADER", "test")) + .finish(), + ) + }))) + }); + + let res = client.get(srv.url("/test")).send().await.unwrap(); + + assert_eq!(res.status().as_u16(), 200); + } +}