diff --git a/src/requests.rs b/src/requests.rs index 077de5e..3dfc820 100644 --- a/src/requests.rs +++ b/src/requests.rs @@ -3,15 +3,23 @@ use activitystreams_new::primitives::XsdAnyUri; use actix_web::{client::Client, http::header::Date}; use bytes::Bytes; use http_signature_normalization_actix::prelude::*; -use log::{debug, info}; +use log::{debug, info, warn}; use rsa::{hash::Hashes, padding::PaddingScheme, RSAPrivateKey}; use sha2::{Digest, Sha256}; -use std::time::SystemTime; +use std::{ + cell::RefCell, + rc::Rc, + sync::atomic::{AtomicUsize, Ordering}, + time::SystemTime, +}; #[derive(Clone)] pub struct Requests { - client: Client, + client: Rc>, + consecutive_errors: Rc, + error_limit: usize, key_id: String, + user_agent: String, private_key: RSAPrivateKey, config: Config, } @@ -19,21 +27,43 @@ pub struct Requests { impl Requests { pub fn new(key_id: String, private_key: RSAPrivateKey, user_agent: String) -> Self { Requests { - client: Client::build().header("User-Agent", user_agent).finish(), + client: Rc::new(RefCell::new( + Client::build() + .header("User-Agent", user_agent.clone()) + .finish(), + )), + consecutive_errors: Rc::new(AtomicUsize::new(0)), + error_limit: 3, key_id, + user_agent, private_key, config: Config::default().dont_use_created_field(), } } + fn count_err(&self) { + let count = self.consecutive_errors.fetch_add(1, Ordering::Relaxed); + if count + 1 >= self.error_limit { + warn!("{} consecutive errors, rebuilding http client", count); + *self.client.borrow_mut() = Client::build() + .header("User-Agent", self.user_agent.clone()) + .finish(); + self.reset_err(); + } + } + + fn reset_err(&self) { + self.consecutive_errors.swap(0, Ordering::Relaxed); + } + pub async fn fetch(&self, url: &str) -> Result where T: serde::de::DeserializeOwned, { let signer = self.signer(); - let mut res = self - .client + let client: Client = self.client.borrow().clone(); + let res = client .get(url) .header("Accept", "application/activity+json") .set(Date(SystemTime::now().into())) @@ -44,8 +74,15 @@ impl Requests { ) .await? .send() - .await - .map_err(|e| MyError::SendRequest(url.to_string(), e.to_string()))?; + .await; + + if res.is_err() { + self.count_err(); + } + + let mut res = res.map_err(|e| MyError::SendRequest(url.to_string(), e.to_string()))?; + + self.reset_err(); if !res.status().is_success() { if let Ok(bytes) = res.body().await { @@ -68,8 +105,8 @@ impl Requests { info!("Fetching bytes for {}", url); let signer = self.signer(); - let mut res = self - .client + let client: Client = self.client.borrow().clone(); + let res = client .get(url) .header("Accept", "*/*") .set(Date(SystemTime::now().into())) @@ -80,8 +117,15 @@ impl Requests { ) .await? .send() - .await - .map_err(|e| MyError::SendRequest(url.to_string(), e.to_string()))?; + .await; + + if res.is_err() { + self.count_err(); + } + + let mut res = res.map_err(|e| MyError::SendRequest(url.to_string(), e.to_string()))?; + + self.reset_err(); let content_type = if let Some(content_type) = res.headers().get("content-type") { if let Ok(s) = content_type.to_str() { @@ -122,8 +166,8 @@ impl Requests { let signer = self.signer(); let item_string = serde_json::to_string(item)?; - let mut res = self - .client + let client: Client = self.client.borrow().clone(); + let res = client .post(inbox.as_str()) .header("Accept", "application/activity+json") .header("Content-Type", "application/activity+json") @@ -137,8 +181,15 @@ impl Requests { ) .await? .send() - .await - .map_err(|e| MyError::SendRequest(inbox.to_string(), e.to_string()))?; + .await; + + if res.is_err() { + self.count_err(); + } + + let mut res = res.map_err(|e| MyError::SendRequest(inbox.to_string(), e.to_string()))?; + + self.reset_err(); if !res.status().is_success() { if let Ok(bytes) = res.body().await {