state: refactor

This commit is contained in:
Astro 2023-10-12 19:04:05 +02:00
parent ff9654d4f0
commit 41f75bf848
4 changed files with 67 additions and 60 deletions

View file

@ -1,7 +1,7 @@
use serde::Deserialize;
use sigh::{PrivateKey, PublicKey, Key};
#[derive(Deserialize)]
#[derive(Clone, Deserialize)]
pub struct Config {
pub streams: Vec<String>,
pub db: String,

View file

@ -1,5 +1,5 @@
use axum::{
extract::{FromRef, Path, Query},
extract::{Path, Query},
http::StatusCode,
response::{IntoResponse, Response},
routing::get, Json, Router,
@ -9,13 +9,13 @@ use metrics::increment_counter;
use metrics_util::MetricKindMask;
use metrics_exporter_prometheus::PrometheusBuilder;
use serde_json::json;
use sigh::{PrivateKey, PublicKey};
use std::{net::SocketAddr, sync::Arc, time::Duration, collections::HashMap};
use std::{panic, process};
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
mod error;
mod config;
mod state;
mod actor;
mod db;
mod digest;
@ -26,22 +26,8 @@ mod relay;
mod activitypub;
mod endpoint;
use state::State;
#[derive(Clone)]
struct State {
database: db::Database,
client: Arc<reqwest::Client>,
hostname: Arc<String>,
priv_key: PrivateKey,
pub_key: PublicKey,
}
impl FromRef<State> for Arc<reqwest::Client> {
fn from_ref(state: &State) -> Arc<reqwest::Client> {
state.client.clone()
}
}
fn track_request(method: &'static str, controller: &'static str, result: &'static str) {
increment_counter!("api_http_requests_total", "controller" => controller, "method" => method, "result" => result);
@ -315,24 +301,19 @@ async fn main() {
.with(tracing_subscriber::fmt::layer())
.init();
let config = config::Config::load(
&std::env::args().nth(1)
.expect("Call with config.yaml")
);
let priv_key = config.priv_key();
let pub_key = config.pub_key();
let recorder = PrometheusBuilder::new()
.add_global_label("application", env!("CARGO_PKG_NAME"))
.idle_timeout(MetricKindMask::ALL, Some(Duration::from_secs(600)))
.install_recorder()
.unwrap();
let config = config::Config::load(
&std::env::args().nth(1)
.expect("Call with config.yaml")
);
let database = db::Database::connect(&config.db).await;
let stream_rx = stream::spawn(config.streams.into_iter());
let client = Arc::new(
reqwest::Client::builder()
let client = reqwest::Client::builder()
.timeout(Duration::from_secs(5))
.user_agent(concat!(
env!("CARGO_PKG_NAME"),
@ -342,10 +323,11 @@ async fn main() {
.pool_max_idle_per_host(1)
.pool_idle_timeout(Some(Duration::from_secs(5)))
.build()
.unwrap()
);
let hostname = Arc::new(config.hostname.clone());
relay::spawn(client.clone(), hostname.clone(), database.clone(), priv_key.clone(), stream_rx);
.unwrap();
let state = State::new(config.clone(), database, client);
let stream_rx = stream::spawn(config.streams.clone().into_iter());
relay::spawn(state.clone(), stream_rx);
let app = Router::new()
.route("/tag/:tag", get(get_tag_actor).post(post_tag_relay))
@ -357,13 +339,7 @@ async fn main() {
.route("/metrics", get(|| async move {
recorder.render().into_response()
}))
.with_state(State {
database,
client,
hostname,
priv_key,
pub_key,
})
.with_state(state)
.merge(SpaRouter::new("/", "static"));
let addr = SocketAddr::from(([127, 0, 0, 1], config.listen_port));

View file

@ -7,7 +7,7 @@ use sigh::PrivateKey;
use tokio::{
sync::mpsc::Receiver,
};
use crate::{db::Database, send, actor};
use crate::{send, actor, state::State};
#[derive(Deserialize)]
struct Post<'a> {
@ -143,14 +143,9 @@ fn spawn_worker(client: Arc<reqwest::Client>) -> Sender<Job> {
}
pub fn spawn(
client: Arc<reqwest::Client>,
hostname: Arc<String>,
database: Database,
private_key: PrivateKey,
state: State,
mut stream_rx: Receiver<String>
) {
let private_key = Arc::new(private_key);
tokio::spawn(async move {
let mut workers = HashMap::new();
@ -175,13 +170,13 @@ pub fn spawn(
let mut seen_actors = HashSet::new();
let mut seen_inboxes = HashSet::new();
let published = chrono::Utc::now().to_rfc3339_opts(chrono::SecondsFormat::Secs, true);
for actor in post.relay_targets(hostname.clone()) {
for actor in post.relay_targets(state.hostname.clone()) {
if seen_actors.contains(&actor) {
continue;
}
let actor_id = Arc::new(actor.uri());
let announce_id = format!("https://{}/announce/{}", hostname, urlencoding::encode(&post_url));
let announce_id = format!("https://{}/announce/{}", state.hostname, urlencoding::encode(&post_url));
let body = json!({
"@context": "https://www.w3.org/ns/activitystreams",
"type": "Announce",
@ -196,7 +191,7 @@ pub fn spawn(
serde_json::to_vec(&body)
.unwrap()
);
for inbox in database.get_following_inboxes(&actor_id).await.unwrap() {
for inbox in state.database.get_following_inboxes(&actor_id).await.unwrap() {
let Ok(inbox_url) = reqwest::Url::parse(&inbox) else { continue; };
// Avoid duplicate processing.
@ -212,14 +207,14 @@ pub fn spawn(
// Lookup/create worker queue per inbox.
let tx = workers.entry(inbox_url.host_str().unwrap_or("").to_string())
.or_insert_with(|| spawn_worker(client.clone()));
.or_insert_with(|| spawn_worker(state.client.clone()));
// Create queue item.
let job = Job {
post_url: post_url.clone(),
actor_id: actor_id.clone(),
body: body.clone(),
key_id: actor.key_id(),
private_key: private_key.clone(),
private_key: state.priv_key.clone(),
inbox_url,
};
// Enqueue job for worker.

36
src/state.rs Normal file
View file

@ -0,0 +1,36 @@
use axum::{
extract::FromRef,
};
use sigh::{PrivateKey, PublicKey};
use std::sync::Arc;
use crate::{config::Config, db::Database};
#[derive(Clone)]
pub struct State {
pub database: Database,
pub client: Arc<reqwest::Client>,
pub hostname: Arc<String>,
pub priv_key: Arc<PrivateKey>,
pub pub_key: Arc<PublicKey>,
}
impl FromRef<State> for Arc<reqwest::Client> {
fn from_ref(state: &State) -> Arc<reqwest::Client> {
state.client.clone()
}
}
impl State {
pub fn new(config: Config, database: Database, client: reqwest::Client) -> Self {
let priv_key = Arc::new(config.priv_key());
let pub_key = Arc::new(config.pub_key());
State {
database,
client: Arc::new(client),
hostname: Arc::new(config.hostname),
priv_key,
pub_key,
}
}
}