Use single connection pool

This commit is contained in:
asonix 2020-03-19 17:19:05 -05:00
parent 790d0965fb
commit 3c5154d449
9 changed files with 123 additions and 297 deletions

1
Cargo.lock generated
View file

@ -387,6 +387,7 @@ dependencies = [
"http-signature-normalization-actix",
"log",
"lru",
"num_cpus",
"pretty_env_logger",
"rand",
"rsa",

View file

@ -20,6 +20,7 @@ futures = "0.3.4"
http-signature-normalization-actix = { version = "0.3.0-alpha.5", default-features = false, features = ["sha-2"] }
log = "0.4"
lru = "0.4.3"
num_cpus = "1.12"
pretty_env_logger = "0.4.0"
rand = "0.7"
rsa = "0.2"

102
src/db.rs
View file

@ -1,16 +1,80 @@
use crate::error::MyError;
use activitystreams::primitives::XsdAnyUri;
use anyhow::Error;
use bb8_postgres::tokio_postgres::{row::Row, Client};
use bb8_postgres::{
bb8,
tokio_postgres::{row::Row, Client, Config, NoTls},
PostgresConnectionManager,
};
use log::{info, warn};
use rsa::RSAPrivateKey;
use rsa_pem::KeyExt;
use std::collections::HashSet;
use std::{collections::HashSet, convert::TryInto};
#[derive(Clone, Debug, thiserror::Error)]
#[error("No host present in URI")]
pub struct HostError;
pub type Pool = bb8::Pool<PostgresConnectionManager<NoTls>>;
pub async fn listen(client: &Client) -> Result<(), Error> {
#[derive(Clone)]
pub struct Db {
pool: Pool,
}
impl Db {
pub async fn build(config: Config) -> Result<Self, MyError> {
let manager = PostgresConnectionManager::new(config, NoTls);
let pool = bb8::Pool::builder()
.max_size((num_cpus::get() * 4).try_into()?)
.build(manager)
.await?;
Ok(Db { pool })
}
pub async fn remove_listener(&self, inbox: XsdAnyUri) -> Result<(), MyError> {
let conn = self.pool.get().await?;
remove_listener(&conn, &inbox).await?;
Ok(())
}
pub async fn add_listener(&self, inbox: XsdAnyUri) -> Result<(), MyError> {
let conn = self.pool.get().await?;
add_listener(&conn, &inbox).await?;
Ok(())
}
pub async fn hydrate_blocks(&self) -> Result<HashSet<String>, MyError> {
let conn = self.pool.get().await?;
Ok(hydrate_blocks(&conn).await?)
}
pub async fn hydrate_whitelists(&self) -> Result<HashSet<String>, MyError> {
let conn = self.pool.get().await?;
Ok(hydrate_whitelists(&conn).await?)
}
pub async fn hydrate_listeners(&self) -> Result<HashSet<XsdAnyUri>, MyError> {
let conn = self.pool.get().await?;
Ok(hydrate_listeners(&conn).await?)
}
pub async fn hydrate_private_key(&self) -> Result<Option<RSAPrivateKey>, MyError> {
let conn = self.pool.get().await?;
Ok(hydrate_private_key(&conn).await?)
}
pub async fn update_private_key(&self, private_key: &RSAPrivateKey) -> Result<(), MyError> {
let conn = self.pool.get().await?;
Ok(update_private_key(&conn, private_key).await?)
}
}
pub async fn listen(client: &Client) -> Result<(), MyError> {
info!("LISTEN new_blocks;");
info!("LISTEN new_whitelists;");
info!("LISTEN new_listeners;");
@ -31,7 +95,7 @@ pub async fn listen(client: &Client) -> Result<(), Error> {
Ok(())
}
pub async fn hydrate_private_key(client: &Client) -> Result<Option<RSAPrivateKey>, Error> {
async fn hydrate_private_key(client: &Client) -> Result<Option<RSAPrivateKey>, MyError> {
info!("SELECT value FROM settings WHERE key = 'private_key'");
let rows = client
.query("SELECT value FROM settings WHERE key = 'private_key'", &[])
@ -45,7 +109,7 @@ pub async fn hydrate_private_key(client: &Client) -> Result<Option<RSAPrivateKey
Ok(None)
}
pub async fn update_private_key(client: &Client, key: &RSAPrivateKey) -> Result<(), Error> {
async fn update_private_key(client: &Client, key: &RSAPrivateKey) -> Result<(), MyError> {
let pem_pkcs8 = key.to_pem_pkcs8()?;
info!("INSERT INTO settings (key, value, created_at) VALUES ('private_key', $1::TEXT, 'now');");
@ -53,11 +117,11 @@ pub async fn update_private_key(client: &Client, key: &RSAPrivateKey) -> Result<
Ok(())
}
pub async fn add_block(client: &Client, block: &XsdAnyUri) -> Result<(), Error> {
async fn add_block(client: &Client, block: &XsdAnyUri) -> Result<(), MyError> {
let host = if let Some(host) = block.as_url().host() {
host
} else {
return Err(HostError.into());
return Err(MyError::Host(block.to_string()));
};
info!(
@ -74,11 +138,11 @@ pub async fn add_block(client: &Client, block: &XsdAnyUri) -> Result<(), Error>
Ok(())
}
pub async fn add_whitelist(client: &Client, whitelist: &XsdAnyUri) -> Result<(), Error> {
async fn add_whitelist(client: &Client, whitelist: &XsdAnyUri) -> Result<(), MyError> {
let host = if let Some(host) = whitelist.as_url().host() {
host
} else {
return Err(HostError.into());
return Err(MyError::Host(whitelist.to_string()));
};
info!(
@ -95,7 +159,7 @@ pub async fn add_whitelist(client: &Client, whitelist: &XsdAnyUri) -> Result<(),
Ok(())
}
pub async fn remove_listener(client: &Client, listener: &XsdAnyUri) -> Result<(), Error> {
async fn remove_listener(client: &Client, listener: &XsdAnyUri) -> Result<(), MyError> {
info!(
"DELETE FROM listeners WHERE actor_id = {};",
listener.as_str()
@ -110,7 +174,7 @@ pub async fn remove_listener(client: &Client, listener: &XsdAnyUri) -> Result<()
Ok(())
}
pub async fn add_listener(client: &Client, listener: &XsdAnyUri) -> Result<(), Error> {
async fn add_listener(client: &Client, listener: &XsdAnyUri) -> Result<(), MyError> {
info!(
"INSERT INTO listeners (actor_id, created_at) VALUES ($1::TEXT, 'now'); [{}]",
listener.as_str(),
@ -125,14 +189,14 @@ pub async fn add_listener(client: &Client, listener: &XsdAnyUri) -> Result<(), E
Ok(())
}
pub async fn hydrate_blocks(client: &Client) -> Result<HashSet<String>, Error> {
async fn hydrate_blocks(client: &Client) -> Result<HashSet<String>, MyError> {
info!("SELECT domain_name FROM blocks");
let rows = client.query("SELECT domain_name FROM blocks", &[]).await?;
parse_rows(rows)
}
pub async fn hydrate_whitelists(client: &Client) -> Result<HashSet<String>, Error> {
async fn hydrate_whitelists(client: &Client) -> Result<HashSet<String>, MyError> {
info!("SELECT domain_name FROM whitelists");
let rows = client
.query("SELECT domain_name FROM whitelists", &[])
@ -141,14 +205,14 @@ pub async fn hydrate_whitelists(client: &Client) -> Result<HashSet<String>, Erro
parse_rows(rows)
}
pub async fn hydrate_listeners(client: &Client) -> Result<HashSet<XsdAnyUri>, Error> {
async fn hydrate_listeners(client: &Client) -> Result<HashSet<XsdAnyUri>, MyError> {
info!("SELECT actor_id FROM listeners");
let rows = client.query("SELECT actor_id FROM listeners", &[]).await?;
parse_rows(rows)
}
fn parse_rows<T, E>(rows: Vec<Row>) -> Result<HashSet<T>, Error>
fn parse_rows<T, E>(rows: Vec<Row>) -> Result<HashSet<T>, MyError>
where
T: std::str::FromStr<Err = E> + Eq + std::hash::Hash,
E: std::fmt::Display,

View file

@ -1,164 +0,0 @@
use crate::{
db::{add_listener, remove_listener},
error::MyError,
label::ArbiterLabel,
};
use activitystreams::primitives::XsdAnyUri;
use actix::prelude::*;
use bb8_postgres::{bb8, tokio_postgres, PostgresConnectionManager};
use log::{error, info};
use tokio::sync::oneshot::{channel, Receiver};
#[derive(Clone)]
pub struct Db {
actor: Addr<DbActor>,
}
pub type Pool = bb8::Pool<PostgresConnectionManager<tokio_postgres::tls::NoTls>>;
pub enum DbActorState {
Waiting(tokio_postgres::Config),
Ready(Pool),
}
pub struct DbActor {
pool: DbActorState,
}
pub struct DbQuery<F>(pub F);
impl Db {
pub fn new(config: tokio_postgres::Config) -> Db {
let actor = Supervisor::start(|_| DbActor {
pool: DbActorState::new_empty(config),
});
Db { actor }
}
pub async fn execute_inline<T, F, Fut>(&self, f: F) -> Result<T, MyError>
where
T: Send + 'static,
F: FnOnce(Pool) -> Fut + Send + 'static,
Fut: Future<Output = T>,
{
Ok(self.actor.send(DbQuery(f)).await?.await?)
}
pub async fn remove_listener(&self, inbox: XsdAnyUri) -> Result<(), MyError> {
self.execute_inline(move |pool: Pool| {
let inbox = inbox.clone();
async move {
let conn = pool.get().await?;
remove_listener(&conn, &inbox).await
}
})
.await?
.map_err(MyError::from)
}
pub async fn add_listener(&self, inbox: XsdAnyUri) -> Result<(), MyError> {
self.execute_inline(move |pool: Pool| {
let inbox = inbox.clone();
async move {
let conn = pool.get().await?;
add_listener(&conn, &inbox).await
}
})
.await?
.map_err(MyError::from)
}
}
impl DbActorState {
pub fn new_empty(config: tokio_postgres::Config) -> Self {
DbActorState::Waiting(config)
}
pub async fn new(config: tokio_postgres::Config) -> Result<Self, tokio_postgres::error::Error> {
let manager = PostgresConnectionManager::new(config, tokio_postgres::tls::NoTls);
let pool = bb8::Pool::builder().max_size(8).build(manager).await?;
Ok(DbActorState::Ready(pool))
}
}
impl Actor for DbActor {
type Context = Context<Self>;
fn started(&mut self, ctx: &mut Self::Context) {
info!("Starting DB Actor in {}", ArbiterLabel::get());
match self.pool {
DbActorState::Waiting(ref config) => {
let fut =
DbActorState::new(config.clone())
.into_actor(self)
.map(|res, actor, ctx| {
match res {
Ok(pool) => {
info!("DB pool created in {}", ArbiterLabel::get());
actor.pool = pool;
}
Err(e) => {
error!(
"Error starting DB Actor in {}, {}",
ArbiterLabel::get(),
e
);
ctx.stop();
}
};
});
ctx.wait(fut);
}
_ => (),
};
}
}
impl Supervised for DbActor {}
impl<F, Fut, R> Handler<DbQuery<F>> for DbActor
where
F: FnOnce(Pool) -> Fut + 'static,
Fut: Future<Output = R>,
R: Send + 'static,
{
type Result = ResponseFuture<Receiver<R>>;
fn handle(&mut self, msg: DbQuery<F>, ctx: &mut Self::Context) -> Self::Result {
let (tx, rx) = channel();
let pool = match self.pool {
DbActorState::Ready(ref pool) => pool.clone(),
_ => {
error!("Tried to query DB before ready");
return Box::pin(async move { rx });
}
};
ctx.spawn(
async move {
let result = (msg.0)(pool).await;
let _ = tx.send(result);
}
.into_actor(self),
);
Box::pin(async move { rx })
}
}
impl<F, Fut, R> Message for DbQuery<F>
where
F: FnOnce(Pool) -> Fut,
Fut: Future<Output = R>,
R: Send + 'static,
{
type Result = Receiver<R>;
}

View file

@ -1,15 +1,13 @@
use activitystreams::primitives::XsdAnyUriError;
use actix::MailboxError;
use actix_web::{error::ResponseError, http::StatusCode, HttpResponse};
use log::error;
use rsa_pem::KeyError;
use std::{convert::Infallible, io::Error};
use tokio::sync::oneshot::error::RecvError;
#[derive(Debug, thiserror::Error)]
pub enum MyError {
#[error("Error in db, {0}")]
DbError(#[from] anyhow::Error),
DbError(#[from] bb8_postgres::tokio_postgres::error::Error),
#[error("Couldn't parse key, {0}")]
Key(#[from] KeyError),
@ -32,9 +30,6 @@ pub enum MyError {
#[error("Couldn't parse the signature header")]
HeaderValidation(#[from] actix_web::http::header::InvalidHeaderValue),
#[error("Failed to get output of db operation")]
Oneshot(#[from] RecvError),
#[error("Couldn't decode base64")]
Base64(#[from] base64::DecodeError),
@ -56,11 +51,14 @@ pub enum MyError {
#[error("Wrong ActivityPub kind, {0}")]
Kind(String),
#[error("The requested actor's mailbox is closed")]
MailboxClosed,
#[error("No host present in URI, {0}")]
Host(String),
#[error("The requested actor's mailbox has timed out")]
MailboxTimeout,
#[error("Too many CPUs, {0}")]
CpuCount(#[from] std::num::TryFromIntError),
#[error("Timed out while waiting on db pool")]
DbTimeout,
#[error("Invalid algorithm provided to verifier")]
Algorithm,
@ -104,6 +102,18 @@ impl ResponseError for MyError {
}
}
impl<T> From<bb8_postgres::bb8::RunError<T>> for MyError
where
T: Into<MyError>,
{
fn from(e: bb8_postgres::bb8::RunError<T>) -> Self {
match e {
bb8_postgres::bb8::RunError::User(e) => e.into(),
bb8_postgres::bb8::RunError::TimedOut => MyError::DbTimeout,
}
}
}
impl From<Infallible> for MyError {
fn from(i: Infallible) -> Self {
match i {}
@ -115,12 +125,3 @@ impl From<rsa::errors::Error> for MyError {
MyError::Rsa(e)
}
}
impl From<MailboxError> for MyError {
fn from(m: MailboxError) -> MyError {
match m {
MailboxError::Closed => MyError::MailboxClosed,
MailboxError::Timeout => MyError::MailboxTimeout,
}
}
}

View file

@ -1,7 +1,7 @@
use crate::{
accepted,
apub::{AcceptedActors, AcceptedObjects, ValidTypes},
db_actor::Db,
db::Db,
error::MyError,
requests::Requests,
state::{State, UrlKind},

View file

@ -1,35 +0,0 @@
use std::sync::{
atomic::{AtomicUsize, Ordering},
Arc,
};
#[derive(Clone, Debug)]
pub struct ArbiterLabelFactory(Arc<AtomicUsize>);
#[derive(Clone, Debug)]
pub struct ArbiterLabel(usize);
impl ArbiterLabelFactory {
pub fn new() -> Self {
ArbiterLabelFactory(Arc::new(AtomicUsize::new(0)))
}
pub fn set_label(&self) {
if !actix::Arbiter::contains_item::<ArbiterLabel>() {
let id = self.0.fetch_add(1, Ordering::SeqCst);
actix::Arbiter::set_item(ArbiterLabel(id));
}
}
}
impl ArbiterLabel {
pub fn get() -> ArbiterLabel {
actix::Arbiter::get_item(|label: &ArbiterLabel| label.clone())
}
}
impl std::fmt::Display for ArbiterLabel {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(f, "Arbiter #{}", self.0)
}
}

View file

@ -10,10 +10,8 @@ use sha2::{Digest, Sha256};
mod apub;
mod db;
mod db_actor;
mod error;
mod inbox;
mod label;
mod nodeinfo;
mod notify;
mod requests;
@ -23,9 +21,8 @@ mod webfinger;
use self::{
apub::PublicKey,
db_actor::Db,
db::Db,
error::MyError,
label::ArbiterLabelFactory,
state::{State, UrlKind},
verifier::MyVerify,
webfinger::RelayResolver,
@ -95,25 +92,18 @@ async fn main() -> Result<(), anyhow::Error> {
let use_whitelist = std::env::var("USE_WHITELIST").is_ok();
let use_https = std::env::var("USE_HTTPS").is_ok();
let arbiter_labeler = ArbiterLabelFactory::new();
let db = Db::build(pg_config.clone()).await?;
let db = Db::new(pg_config.clone());
arbiter_labeler.clone().set_label();
let state: State = db
.execute_inline(move |pool| State::hydrate(use_https, use_whitelist, hostname, pool))
.await??;
let state = State::hydrate(use_https, use_whitelist, hostname, &db).await?;
let _ = notify::NotifyHandler::start_handler(state.clone(), pg_config.clone());
HttpServer::new(move || {
arbiter_labeler.clone().set_label();
let state = state.clone();
let actor = Db::new(pg_config.clone());
App::new()
.wrap(Logger::default())
.data(actor)
.data(db.clone())
.data(state.clone())
.data(state.requests())
.service(web::resource("/").route(web::get().to(index)))

View file

@ -1,9 +1,7 @@
use crate::{apub::AcceptedActors, db_actor::Pool, requests::Requests};
use crate::{apub::AcceptedActors, db::Db, error::MyError, requests::Requests};
use activitystreams::primitives::XsdAnyUri;
use anyhow::Error;
use bb8_postgres::tokio_postgres::Client;
use futures::try_join;
use log::{error, info};
use log::info;
use lru::LruCache;
use rand::thread_rng;
use rsa::{RSAPrivateKey, RSAPublicKey};
@ -44,28 +42,21 @@ pub enum UrlKind {
Outbox,
}
#[derive(Clone, Debug, thiserror::Error)]
#[error("Error generating RSA key")]
pub struct RsaError;
impl Settings {
async fn hydrate(
client: &Client,
db: &Db,
use_https: bool,
whitelist_enabled: bool,
hostname: String,
) -> Result<Self, Error> {
let private_key = if let Some(key) = crate::db::hydrate_private_key(client).await? {
) -> Result<Self, MyError> {
let private_key = if let Some(key) = db.hydrate_private_key().await? {
key
} else {
info!("Generating new keys");
let mut rng = thread_rng();
let key = RSAPrivateKey::new(&mut rng, 4096).map_err(|e| {
error!("Error generating RSA key, {}", e);
RsaError
})?;
let key = RSAPrivateKey::new(&mut rng, 4096)?;
crate::db::update_private_key(client, &key).await?;
db.update_private_key(&key).await?;
key
};
@ -249,35 +240,12 @@ impl State {
use_https: bool,
whitelist_enabled: bool,
hostname: String,
pool: Pool,
) -> Result<Self, Error> {
let pool1 = pool.clone();
let pool2 = pool.clone();
let pool3 = pool.clone();
let f1 = async move {
let conn = pool.get().await?;
crate::db::hydrate_blocks(&conn).await
};
let f2 = async move {
let conn = pool1.get().await?;
crate::db::hydrate_whitelists(&conn).await
};
let f3 = async move {
let conn = pool2.get().await?;
crate::db::hydrate_listeners(&conn).await
};
let f4 = async move {
let conn = pool3.get().await?;
Settings::hydrate(&conn, use_https, whitelist_enabled, hostname).await
};
db: &Db,
) -> Result<Self, MyError> {
let f1 = db.hydrate_blocks();
let f2 = db.hydrate_whitelists();
let f3 = db.hydrate_listeners();
let f4 = Settings::hydrate(db, use_https, whitelist_enabled, hostname);
let (blocks, whitelists, listeners, settings) = try_join!(f1, f2, f3, f4)?;