From 41f75bf848fee2ad7f81638d9af67cf780409b47 Mon Sep 17 00:00:00 2001 From: Astro Date: Thu, 12 Oct 2023 19:04:05 +0200 Subject: [PATCH] state: refactor --- src/config.rs | 2 +- src/main.rs | 70 +++++++++++++++++---------------------------------- src/relay.rs | 19 ++++++-------- src/state.rs | 36 ++++++++++++++++++++++++++ 4 files changed, 67 insertions(+), 60 deletions(-) create mode 100644 src/state.rs diff --git a/src/config.rs b/src/config.rs index 2306b24..5dfb18c 100644 --- a/src/config.rs +++ b/src/config.rs @@ -1,7 +1,7 @@ use serde::Deserialize; use sigh::{PrivateKey, PublicKey, Key}; -#[derive(Deserialize)] +#[derive(Clone, Deserialize)] pub struct Config { pub streams: Vec, pub db: String, diff --git a/src/main.rs b/src/main.rs index 572952d..fd4e32a 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,5 +1,5 @@ use axum::{ - extract::{FromRef, Path, Query}, + extract::{Path, Query}, http::StatusCode, response::{IntoResponse, Response}, routing::get, Json, Router, @@ -9,13 +9,13 @@ use metrics::increment_counter; use metrics_util::MetricKindMask; use metrics_exporter_prometheus::PrometheusBuilder; use serde_json::json; -use sigh::{PrivateKey, PublicKey}; use std::{net::SocketAddr, sync::Arc, time::Duration, collections::HashMap}; use std::{panic, process}; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; mod error; mod config; +mod state; mod actor; mod db; mod digest; @@ -26,22 +26,8 @@ mod relay; mod activitypub; mod endpoint; +use state::State; -#[derive(Clone)] -struct State { - database: db::Database, - client: Arc, - hostname: Arc, - priv_key: PrivateKey, - pub_key: PublicKey, -} - - -impl FromRef for Arc { - fn from_ref(state: &State) -> Arc { - state.client.clone() - } -} fn track_request(method: &'static str, controller: &'static str, result: &'static str) { increment_counter!("api_http_requests_total", "controller" => controller, "method" => method, "result" => result); @@ -315,37 +301,33 @@ async fn main() { .with(tracing_subscriber::fmt::layer()) .init(); - let config = config::Config::load( - &std::env::args().nth(1) - .expect("Call with config.yaml") - ); - let priv_key = config.priv_key(); - let pub_key = config.pub_key(); - let recorder = PrometheusBuilder::new() .add_global_label("application", env!("CARGO_PKG_NAME")) .idle_timeout(MetricKindMask::ALL, Some(Duration::from_secs(600))) .install_recorder() .unwrap(); + let config = config::Config::load( + &std::env::args().nth(1) + .expect("Call with config.yaml") + ); let database = db::Database::connect(&config.db).await; - let stream_rx = stream::spawn(config.streams.into_iter()); - let client = Arc::new( - reqwest::Client::builder() - .timeout(Duration::from_secs(5)) - .user_agent(concat!( - env!("CARGO_PKG_NAME"), - "/", - env!("CARGO_PKG_VERSION"), - )) - .pool_max_idle_per_host(1) - .pool_idle_timeout(Some(Duration::from_secs(5))) - .build() - .unwrap() - ); - let hostname = Arc::new(config.hostname.clone()); - relay::spawn(client.clone(), hostname.clone(), database.clone(), priv_key.clone(), stream_rx); + let client = reqwest::Client::builder() + .timeout(Duration::from_secs(5)) + .user_agent(concat!( + env!("CARGO_PKG_NAME"), + "/", + env!("CARGO_PKG_VERSION"), + )) + .pool_max_idle_per_host(1) + .pool_idle_timeout(Some(Duration::from_secs(5))) + .build() + .unwrap(); + let state = State::new(config.clone(), database, client); + + let stream_rx = stream::spawn(config.streams.clone().into_iter()); + relay::spawn(state.clone(), stream_rx); let app = Router::new() .route("/tag/:tag", get(get_tag_actor).post(post_tag_relay)) @@ -357,13 +339,7 @@ async fn main() { .route("/metrics", get(|| async move { recorder.render().into_response() })) - .with_state(State { - database, - client, - hostname, - priv_key, - pub_key, - }) + .with_state(state) .merge(SpaRouter::new("/", "static")); let addr = SocketAddr::from(([127, 0, 0, 1], config.listen_port)); diff --git a/src/relay.rs b/src/relay.rs index bbac752..141038b 100644 --- a/src/relay.rs +++ b/src/relay.rs @@ -7,7 +7,7 @@ use sigh::PrivateKey; use tokio::{ sync::mpsc::Receiver, }; -use crate::{db::Database, send, actor}; +use crate::{send, actor, state::State}; #[derive(Deserialize)] struct Post<'a> { @@ -143,14 +143,9 @@ fn spawn_worker(client: Arc) -> Sender { } pub fn spawn( - client: Arc, - hostname: Arc, - database: Database, - private_key: PrivateKey, + state: State, mut stream_rx: Receiver ) { - let private_key = Arc::new(private_key); - tokio::spawn(async move { let mut workers = HashMap::new(); @@ -175,13 +170,13 @@ pub fn spawn( let mut seen_actors = HashSet::new(); let mut seen_inboxes = HashSet::new(); let published = chrono::Utc::now().to_rfc3339_opts(chrono::SecondsFormat::Secs, true); - for actor in post.relay_targets(hostname.clone()) { + for actor in post.relay_targets(state.hostname.clone()) { if seen_actors.contains(&actor) { continue; } let actor_id = Arc::new(actor.uri()); - let announce_id = format!("https://{}/announce/{}", hostname, urlencoding::encode(&post_url)); + let announce_id = format!("https://{}/announce/{}", state.hostname, urlencoding::encode(&post_url)); let body = json!({ "@context": "https://www.w3.org/ns/activitystreams", "type": "Announce", @@ -196,7 +191,7 @@ pub fn spawn( serde_json::to_vec(&body) .unwrap() ); - for inbox in database.get_following_inboxes(&actor_id).await.unwrap() { + for inbox in state.database.get_following_inboxes(&actor_id).await.unwrap() { let Ok(inbox_url) = reqwest::Url::parse(&inbox) else { continue; }; // Avoid duplicate processing. @@ -212,14 +207,14 @@ pub fn spawn( // Lookup/create worker queue per inbox. let tx = workers.entry(inbox_url.host_str().unwrap_or("").to_string()) - .or_insert_with(|| spawn_worker(client.clone())); + .or_insert_with(|| spawn_worker(state.client.clone())); // Create queue item. let job = Job { post_url: post_url.clone(), actor_id: actor_id.clone(), body: body.clone(), key_id: actor.key_id(), - private_key: private_key.clone(), + private_key: state.priv_key.clone(), inbox_url, }; // Enqueue job for worker. diff --git a/src/state.rs b/src/state.rs new file mode 100644 index 0000000..08bb87a --- /dev/null +++ b/src/state.rs @@ -0,0 +1,36 @@ +use axum::{ + extract::FromRef, +}; +use sigh::{PrivateKey, PublicKey}; +use std::sync::Arc; +use crate::{config::Config, db::Database}; + +#[derive(Clone)] +pub struct State { + pub database: Database, + pub client: Arc, + pub hostname: Arc, + pub priv_key: Arc, + pub pub_key: Arc, +} + + +impl FromRef for Arc { + fn from_ref(state: &State) -> Arc { + state.client.clone() + } +} + +impl State { + pub fn new(config: Config, database: Database, client: reqwest::Client) -> Self { + let priv_key = Arc::new(config.priv_key()); + let pub_key = Arc::new(config.pub_key()); + State { + database, + client: Arc::new(client), + hostname: Arc::new(config.hostname), + priv_key, + pub_key, + } + } +}