Use single connection pool

This commit is contained in:
asonix 2020-03-19 17:19:05 -05:00
parent 790d0965fb
commit 3c5154d449
9 changed files with 123 additions and 297 deletions

1
Cargo.lock generated
View file

@ -387,6 +387,7 @@ dependencies = [
"http-signature-normalization-actix", "http-signature-normalization-actix",
"log", "log",
"lru", "lru",
"num_cpus",
"pretty_env_logger", "pretty_env_logger",
"rand", "rand",
"rsa", "rsa",

View file

@ -20,6 +20,7 @@ futures = "0.3.4"
http-signature-normalization-actix = { version = "0.3.0-alpha.5", default-features = false, features = ["sha-2"] } http-signature-normalization-actix = { version = "0.3.0-alpha.5", default-features = false, features = ["sha-2"] }
log = "0.4" log = "0.4"
lru = "0.4.3" lru = "0.4.3"
num_cpus = "1.12"
pretty_env_logger = "0.4.0" pretty_env_logger = "0.4.0"
rand = "0.7" rand = "0.7"
rsa = "0.2" rsa = "0.2"

102
src/db.rs
View file

@ -1,16 +1,80 @@
use crate::error::MyError;
use activitystreams::primitives::XsdAnyUri; use activitystreams::primitives::XsdAnyUri;
use anyhow::Error; use bb8_postgres::{
use bb8_postgres::tokio_postgres::{row::Row, Client}; bb8,
tokio_postgres::{row::Row, Client, Config, NoTls},
PostgresConnectionManager,
};
use log::{info, warn}; use log::{info, warn};
use rsa::RSAPrivateKey; use rsa::RSAPrivateKey;
use rsa_pem::KeyExt; use rsa_pem::KeyExt;
use std::collections::HashSet; use std::{collections::HashSet, convert::TryInto};
#[derive(Clone, Debug, thiserror::Error)] pub type Pool = bb8::Pool<PostgresConnectionManager<NoTls>>;
#[error("No host present in URI")]
pub struct HostError;
pub async fn listen(client: &Client) -> Result<(), Error> { #[derive(Clone)]
pub struct Db {
pool: Pool,
}
impl Db {
pub async fn build(config: Config) -> Result<Self, MyError> {
let manager = PostgresConnectionManager::new(config, NoTls);
let pool = bb8::Pool::builder()
.max_size((num_cpus::get() * 4).try_into()?)
.build(manager)
.await?;
Ok(Db { pool })
}
pub async fn remove_listener(&self, inbox: XsdAnyUri) -> Result<(), MyError> {
let conn = self.pool.get().await?;
remove_listener(&conn, &inbox).await?;
Ok(())
}
pub async fn add_listener(&self, inbox: XsdAnyUri) -> Result<(), MyError> {
let conn = self.pool.get().await?;
add_listener(&conn, &inbox).await?;
Ok(())
}
pub async fn hydrate_blocks(&self) -> Result<HashSet<String>, MyError> {
let conn = self.pool.get().await?;
Ok(hydrate_blocks(&conn).await?)
}
pub async fn hydrate_whitelists(&self) -> Result<HashSet<String>, MyError> {
let conn = self.pool.get().await?;
Ok(hydrate_whitelists(&conn).await?)
}
pub async fn hydrate_listeners(&self) -> Result<HashSet<XsdAnyUri>, MyError> {
let conn = self.pool.get().await?;
Ok(hydrate_listeners(&conn).await?)
}
pub async fn hydrate_private_key(&self) -> Result<Option<RSAPrivateKey>, MyError> {
let conn = self.pool.get().await?;
Ok(hydrate_private_key(&conn).await?)
}
pub async fn update_private_key(&self, private_key: &RSAPrivateKey) -> Result<(), MyError> {
let conn = self.pool.get().await?;
Ok(update_private_key(&conn, private_key).await?)
}
}
pub async fn listen(client: &Client) -> Result<(), MyError> {
info!("LISTEN new_blocks;"); info!("LISTEN new_blocks;");
info!("LISTEN new_whitelists;"); info!("LISTEN new_whitelists;");
info!("LISTEN new_listeners;"); info!("LISTEN new_listeners;");
@ -31,7 +95,7 @@ pub async fn listen(client: &Client) -> Result<(), Error> {
Ok(()) Ok(())
} }
pub async fn hydrate_private_key(client: &Client) -> Result<Option<RSAPrivateKey>, Error> { async fn hydrate_private_key(client: &Client) -> Result<Option<RSAPrivateKey>, MyError> {
info!("SELECT value FROM settings WHERE key = 'private_key'"); info!("SELECT value FROM settings WHERE key = 'private_key'");
let rows = client let rows = client
.query("SELECT value FROM settings WHERE key = 'private_key'", &[]) .query("SELECT value FROM settings WHERE key = 'private_key'", &[])
@ -45,7 +109,7 @@ pub async fn hydrate_private_key(client: &Client) -> Result<Option<RSAPrivateKey
Ok(None) Ok(None)
} }
pub async fn update_private_key(client: &Client, key: &RSAPrivateKey) -> Result<(), Error> { async fn update_private_key(client: &Client, key: &RSAPrivateKey) -> Result<(), MyError> {
let pem_pkcs8 = key.to_pem_pkcs8()?; let pem_pkcs8 = key.to_pem_pkcs8()?;
info!("INSERT INTO settings (key, value, created_at) VALUES ('private_key', $1::TEXT, 'now');"); info!("INSERT INTO settings (key, value, created_at) VALUES ('private_key', $1::TEXT, 'now');");
@ -53,11 +117,11 @@ pub async fn update_private_key(client: &Client, key: &RSAPrivateKey) -> Result<
Ok(()) Ok(())
} }
pub async fn add_block(client: &Client, block: &XsdAnyUri) -> Result<(), Error> { async fn add_block(client: &Client, block: &XsdAnyUri) -> Result<(), MyError> {
let host = if let Some(host) = block.as_url().host() { let host = if let Some(host) = block.as_url().host() {
host host
} else { } else {
return Err(HostError.into()); return Err(MyError::Host(block.to_string()));
}; };
info!( info!(
@ -74,11 +138,11 @@ pub async fn add_block(client: &Client, block: &XsdAnyUri) -> Result<(), Error>
Ok(()) Ok(())
} }
pub async fn add_whitelist(client: &Client, whitelist: &XsdAnyUri) -> Result<(), Error> { async fn add_whitelist(client: &Client, whitelist: &XsdAnyUri) -> Result<(), MyError> {
let host = if let Some(host) = whitelist.as_url().host() { let host = if let Some(host) = whitelist.as_url().host() {
host host
} else { } else {
return Err(HostError.into()); return Err(MyError::Host(whitelist.to_string()));
}; };
info!( info!(
@ -95,7 +159,7 @@ pub async fn add_whitelist(client: &Client, whitelist: &XsdAnyUri) -> Result<(),
Ok(()) Ok(())
} }
pub async fn remove_listener(client: &Client, listener: &XsdAnyUri) -> Result<(), Error> { async fn remove_listener(client: &Client, listener: &XsdAnyUri) -> Result<(), MyError> {
info!( info!(
"DELETE FROM listeners WHERE actor_id = {};", "DELETE FROM listeners WHERE actor_id = {};",
listener.as_str() listener.as_str()
@ -110,7 +174,7 @@ pub async fn remove_listener(client: &Client, listener: &XsdAnyUri) -> Result<()
Ok(()) Ok(())
} }
pub async fn add_listener(client: &Client, listener: &XsdAnyUri) -> Result<(), Error> { async fn add_listener(client: &Client, listener: &XsdAnyUri) -> Result<(), MyError> {
info!( info!(
"INSERT INTO listeners (actor_id, created_at) VALUES ($1::TEXT, 'now'); [{}]", "INSERT INTO listeners (actor_id, created_at) VALUES ($1::TEXT, 'now'); [{}]",
listener.as_str(), listener.as_str(),
@ -125,14 +189,14 @@ pub async fn add_listener(client: &Client, listener: &XsdAnyUri) -> Result<(), E
Ok(()) Ok(())
} }
pub async fn hydrate_blocks(client: &Client) -> Result<HashSet<String>, Error> { async fn hydrate_blocks(client: &Client) -> Result<HashSet<String>, MyError> {
info!("SELECT domain_name FROM blocks"); info!("SELECT domain_name FROM blocks");
let rows = client.query("SELECT domain_name FROM blocks", &[]).await?; let rows = client.query("SELECT domain_name FROM blocks", &[]).await?;
parse_rows(rows) parse_rows(rows)
} }
pub async fn hydrate_whitelists(client: &Client) -> Result<HashSet<String>, Error> { async fn hydrate_whitelists(client: &Client) -> Result<HashSet<String>, MyError> {
info!("SELECT domain_name FROM whitelists"); info!("SELECT domain_name FROM whitelists");
let rows = client let rows = client
.query("SELECT domain_name FROM whitelists", &[]) .query("SELECT domain_name FROM whitelists", &[])
@ -141,14 +205,14 @@ pub async fn hydrate_whitelists(client: &Client) -> Result<HashSet<String>, Erro
parse_rows(rows) parse_rows(rows)
} }
pub async fn hydrate_listeners(client: &Client) -> Result<HashSet<XsdAnyUri>, Error> { async fn hydrate_listeners(client: &Client) -> Result<HashSet<XsdAnyUri>, MyError> {
info!("SELECT actor_id FROM listeners"); info!("SELECT actor_id FROM listeners");
let rows = client.query("SELECT actor_id FROM listeners", &[]).await?; let rows = client.query("SELECT actor_id FROM listeners", &[]).await?;
parse_rows(rows) parse_rows(rows)
} }
fn parse_rows<T, E>(rows: Vec<Row>) -> Result<HashSet<T>, Error> fn parse_rows<T, E>(rows: Vec<Row>) -> Result<HashSet<T>, MyError>
where where
T: std::str::FromStr<Err = E> + Eq + std::hash::Hash, T: std::str::FromStr<Err = E> + Eq + std::hash::Hash,
E: std::fmt::Display, E: std::fmt::Display,

View file

@ -1,164 +0,0 @@
use crate::{
db::{add_listener, remove_listener},
error::MyError,
label::ArbiterLabel,
};
use activitystreams::primitives::XsdAnyUri;
use actix::prelude::*;
use bb8_postgres::{bb8, tokio_postgres, PostgresConnectionManager};
use log::{error, info};
use tokio::sync::oneshot::{channel, Receiver};
#[derive(Clone)]
pub struct Db {
actor: Addr<DbActor>,
}
pub type Pool = bb8::Pool<PostgresConnectionManager<tokio_postgres::tls::NoTls>>;
pub enum DbActorState {
Waiting(tokio_postgres::Config),
Ready(Pool),
}
pub struct DbActor {
pool: DbActorState,
}
pub struct DbQuery<F>(pub F);
impl Db {
pub fn new(config: tokio_postgres::Config) -> Db {
let actor = Supervisor::start(|_| DbActor {
pool: DbActorState::new_empty(config),
});
Db { actor }
}
pub async fn execute_inline<T, F, Fut>(&self, f: F) -> Result<T, MyError>
where
T: Send + 'static,
F: FnOnce(Pool) -> Fut + Send + 'static,
Fut: Future<Output = T>,
{
Ok(self.actor.send(DbQuery(f)).await?.await?)
}
pub async fn remove_listener(&self, inbox: XsdAnyUri) -> Result<(), MyError> {
self.execute_inline(move |pool: Pool| {
let inbox = inbox.clone();
async move {
let conn = pool.get().await?;
remove_listener(&conn, &inbox).await
}
})
.await?
.map_err(MyError::from)
}
pub async fn add_listener(&self, inbox: XsdAnyUri) -> Result<(), MyError> {
self.execute_inline(move |pool: Pool| {
let inbox = inbox.clone();
async move {
let conn = pool.get().await?;
add_listener(&conn, &inbox).await
}
})
.await?
.map_err(MyError::from)
}
}
impl DbActorState {
pub fn new_empty(config: tokio_postgres::Config) -> Self {
DbActorState::Waiting(config)
}
pub async fn new(config: tokio_postgres::Config) -> Result<Self, tokio_postgres::error::Error> {
let manager = PostgresConnectionManager::new(config, tokio_postgres::tls::NoTls);
let pool = bb8::Pool::builder().max_size(8).build(manager).await?;
Ok(DbActorState::Ready(pool))
}
}
impl Actor for DbActor {
type Context = Context<Self>;
fn started(&mut self, ctx: &mut Self::Context) {
info!("Starting DB Actor in {}", ArbiterLabel::get());
match self.pool {
DbActorState::Waiting(ref config) => {
let fut =
DbActorState::new(config.clone())
.into_actor(self)
.map(|res, actor, ctx| {
match res {
Ok(pool) => {
info!("DB pool created in {}", ArbiterLabel::get());
actor.pool = pool;
}
Err(e) => {
error!(
"Error starting DB Actor in {}, {}",
ArbiterLabel::get(),
e
);
ctx.stop();
}
};
});
ctx.wait(fut);
}
_ => (),
};
}
}
impl Supervised for DbActor {}
impl<F, Fut, R> Handler<DbQuery<F>> for DbActor
where
F: FnOnce(Pool) -> Fut + 'static,
Fut: Future<Output = R>,
R: Send + 'static,
{
type Result = ResponseFuture<Receiver<R>>;
fn handle(&mut self, msg: DbQuery<F>, ctx: &mut Self::Context) -> Self::Result {
let (tx, rx) = channel();
let pool = match self.pool {
DbActorState::Ready(ref pool) => pool.clone(),
_ => {
error!("Tried to query DB before ready");
return Box::pin(async move { rx });
}
};
ctx.spawn(
async move {
let result = (msg.0)(pool).await;
let _ = tx.send(result);
}
.into_actor(self),
);
Box::pin(async move { rx })
}
}
impl<F, Fut, R> Message for DbQuery<F>
where
F: FnOnce(Pool) -> Fut,
Fut: Future<Output = R>,
R: Send + 'static,
{
type Result = Receiver<R>;
}

View file

@ -1,15 +1,13 @@
use activitystreams::primitives::XsdAnyUriError; use activitystreams::primitives::XsdAnyUriError;
use actix::MailboxError;
use actix_web::{error::ResponseError, http::StatusCode, HttpResponse}; use actix_web::{error::ResponseError, http::StatusCode, HttpResponse};
use log::error; use log::error;
use rsa_pem::KeyError; use rsa_pem::KeyError;
use std::{convert::Infallible, io::Error}; use std::{convert::Infallible, io::Error};
use tokio::sync::oneshot::error::RecvError;
#[derive(Debug, thiserror::Error)] #[derive(Debug, thiserror::Error)]
pub enum MyError { pub enum MyError {
#[error("Error in db, {0}")] #[error("Error in db, {0}")]
DbError(#[from] anyhow::Error), DbError(#[from] bb8_postgres::tokio_postgres::error::Error),
#[error("Couldn't parse key, {0}")] #[error("Couldn't parse key, {0}")]
Key(#[from] KeyError), Key(#[from] KeyError),
@ -32,9 +30,6 @@ pub enum MyError {
#[error("Couldn't parse the signature header")] #[error("Couldn't parse the signature header")]
HeaderValidation(#[from] actix_web::http::header::InvalidHeaderValue), HeaderValidation(#[from] actix_web::http::header::InvalidHeaderValue),
#[error("Failed to get output of db operation")]
Oneshot(#[from] RecvError),
#[error("Couldn't decode base64")] #[error("Couldn't decode base64")]
Base64(#[from] base64::DecodeError), Base64(#[from] base64::DecodeError),
@ -56,11 +51,14 @@ pub enum MyError {
#[error("Wrong ActivityPub kind, {0}")] #[error("Wrong ActivityPub kind, {0}")]
Kind(String), Kind(String),
#[error("The requested actor's mailbox is closed")] #[error("No host present in URI, {0}")]
MailboxClosed, Host(String),
#[error("The requested actor's mailbox has timed out")] #[error("Too many CPUs, {0}")]
MailboxTimeout, CpuCount(#[from] std::num::TryFromIntError),
#[error("Timed out while waiting on db pool")]
DbTimeout,
#[error("Invalid algorithm provided to verifier")] #[error("Invalid algorithm provided to verifier")]
Algorithm, Algorithm,
@ -104,6 +102,18 @@ impl ResponseError for MyError {
} }
} }
impl<T> From<bb8_postgres::bb8::RunError<T>> for MyError
where
T: Into<MyError>,
{
fn from(e: bb8_postgres::bb8::RunError<T>) -> Self {
match e {
bb8_postgres::bb8::RunError::User(e) => e.into(),
bb8_postgres::bb8::RunError::TimedOut => MyError::DbTimeout,
}
}
}
impl From<Infallible> for MyError { impl From<Infallible> for MyError {
fn from(i: Infallible) -> Self { fn from(i: Infallible) -> Self {
match i {} match i {}
@ -115,12 +125,3 @@ impl From<rsa::errors::Error> for MyError {
MyError::Rsa(e) MyError::Rsa(e)
} }
} }
impl From<MailboxError> for MyError {
fn from(m: MailboxError) -> MyError {
match m {
MailboxError::Closed => MyError::MailboxClosed,
MailboxError::Timeout => MyError::MailboxTimeout,
}
}
}

View file

@ -1,7 +1,7 @@
use crate::{ use crate::{
accepted, accepted,
apub::{AcceptedActors, AcceptedObjects, ValidTypes}, apub::{AcceptedActors, AcceptedObjects, ValidTypes},
db_actor::Db, db::Db,
error::MyError, error::MyError,
requests::Requests, requests::Requests,
state::{State, UrlKind}, state::{State, UrlKind},

View file

@ -1,35 +0,0 @@
use std::sync::{
atomic::{AtomicUsize, Ordering},
Arc,
};
#[derive(Clone, Debug)]
pub struct ArbiterLabelFactory(Arc<AtomicUsize>);
#[derive(Clone, Debug)]
pub struct ArbiterLabel(usize);
impl ArbiterLabelFactory {
pub fn new() -> Self {
ArbiterLabelFactory(Arc::new(AtomicUsize::new(0)))
}
pub fn set_label(&self) {
if !actix::Arbiter::contains_item::<ArbiterLabel>() {
let id = self.0.fetch_add(1, Ordering::SeqCst);
actix::Arbiter::set_item(ArbiterLabel(id));
}
}
}
impl ArbiterLabel {
pub fn get() -> ArbiterLabel {
actix::Arbiter::get_item(|label: &ArbiterLabel| label.clone())
}
}
impl std::fmt::Display for ArbiterLabel {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(f, "Arbiter #{}", self.0)
}
}

View file

@ -10,10 +10,8 @@ use sha2::{Digest, Sha256};
mod apub; mod apub;
mod db; mod db;
mod db_actor;
mod error; mod error;
mod inbox; mod inbox;
mod label;
mod nodeinfo; mod nodeinfo;
mod notify; mod notify;
mod requests; mod requests;
@ -23,9 +21,8 @@ mod webfinger;
use self::{ use self::{
apub::PublicKey, apub::PublicKey,
db_actor::Db, db::Db,
error::MyError, error::MyError,
label::ArbiterLabelFactory,
state::{State, UrlKind}, state::{State, UrlKind},
verifier::MyVerify, verifier::MyVerify,
webfinger::RelayResolver, webfinger::RelayResolver,
@ -95,25 +92,18 @@ async fn main() -> Result<(), anyhow::Error> {
let use_whitelist = std::env::var("USE_WHITELIST").is_ok(); let use_whitelist = std::env::var("USE_WHITELIST").is_ok();
let use_https = std::env::var("USE_HTTPS").is_ok(); let use_https = std::env::var("USE_HTTPS").is_ok();
let arbiter_labeler = ArbiterLabelFactory::new(); let db = Db::build(pg_config.clone()).await?;
let db = Db::new(pg_config.clone()); let state = State::hydrate(use_https, use_whitelist, hostname, &db).await?;
arbiter_labeler.clone().set_label();
let state: State = db
.execute_inline(move |pool| State::hydrate(use_https, use_whitelist, hostname, pool))
.await??;
let _ = notify::NotifyHandler::start_handler(state.clone(), pg_config.clone()); let _ = notify::NotifyHandler::start_handler(state.clone(), pg_config.clone());
HttpServer::new(move || { HttpServer::new(move || {
arbiter_labeler.clone().set_label();
let state = state.clone(); let state = state.clone();
let actor = Db::new(pg_config.clone());
App::new() App::new()
.wrap(Logger::default()) .wrap(Logger::default())
.data(actor) .data(db.clone())
.data(state.clone()) .data(state.clone())
.data(state.requests()) .data(state.requests())
.service(web::resource("/").route(web::get().to(index))) .service(web::resource("/").route(web::get().to(index)))

View file

@ -1,9 +1,7 @@
use crate::{apub::AcceptedActors, db_actor::Pool, requests::Requests}; use crate::{apub::AcceptedActors, db::Db, error::MyError, requests::Requests};
use activitystreams::primitives::XsdAnyUri; use activitystreams::primitives::XsdAnyUri;
use anyhow::Error;
use bb8_postgres::tokio_postgres::Client;
use futures::try_join; use futures::try_join;
use log::{error, info}; use log::info;
use lru::LruCache; use lru::LruCache;
use rand::thread_rng; use rand::thread_rng;
use rsa::{RSAPrivateKey, RSAPublicKey}; use rsa::{RSAPrivateKey, RSAPublicKey};
@ -44,28 +42,21 @@ pub enum UrlKind {
Outbox, Outbox,
} }
#[derive(Clone, Debug, thiserror::Error)]
#[error("Error generating RSA key")]
pub struct RsaError;
impl Settings { impl Settings {
async fn hydrate( async fn hydrate(
client: &Client, db: &Db,
use_https: bool, use_https: bool,
whitelist_enabled: bool, whitelist_enabled: bool,
hostname: String, hostname: String,
) -> Result<Self, Error> { ) -> Result<Self, MyError> {
let private_key = if let Some(key) = crate::db::hydrate_private_key(client).await? { let private_key = if let Some(key) = db.hydrate_private_key().await? {
key key
} else { } else {
info!("Generating new keys"); info!("Generating new keys");
let mut rng = thread_rng(); let mut rng = thread_rng();
let key = RSAPrivateKey::new(&mut rng, 4096).map_err(|e| { let key = RSAPrivateKey::new(&mut rng, 4096)?;
error!("Error generating RSA key, {}", e);
RsaError
})?;
crate::db::update_private_key(client, &key).await?; db.update_private_key(&key).await?;
key key
}; };
@ -249,35 +240,12 @@ impl State {
use_https: bool, use_https: bool,
whitelist_enabled: bool, whitelist_enabled: bool,
hostname: String, hostname: String,
pool: Pool, db: &Db,
) -> Result<Self, Error> { ) -> Result<Self, MyError> {
let pool1 = pool.clone(); let f1 = db.hydrate_blocks();
let pool2 = pool.clone(); let f2 = db.hydrate_whitelists();
let pool3 = pool.clone(); let f3 = db.hydrate_listeners();
let f4 = Settings::hydrate(db, use_https, whitelist_enabled, hostname);
let f1 = async move {
let conn = pool.get().await?;
crate::db::hydrate_blocks(&conn).await
};
let f2 = async move {
let conn = pool1.get().await?;
crate::db::hydrate_whitelists(&conn).await
};
let f3 = async move {
let conn = pool2.get().await?;
crate::db::hydrate_listeners(&conn).await
};
let f4 = async move {
let conn = pool3.get().await?;
Settings::hydrate(&conn, use_https, whitelist_enabled, hostname).await
};
let (blocks, whitelists, listeners, settings) = try_join!(f1, f2, f3, f4)?; let (blocks, whitelists, listeners, settings) = try_join!(f1, f2, f3, f4)?;