Update cache from postgres notifications

- Add actor that listens to the notification stream
- Move queries into db module
This commit is contained in:
asonix 2020-03-16 12:56:26 -05:00
parent 8d157ccbc6
commit b7bf001287
7 changed files with 466 additions and 135 deletions

View file

@ -0,0 +1,8 @@
-- This file should undo anything in `up.sql`
DROP TRIGGER IF EXISTS whitelists_notify ON whitelists;
DROP TRIGGER IF EXISTS blocks_notify ON blocks;
DROP TRIGGER IF EXISTS listeners_notify ON listeners;
DROP FUNCTION IF EXISTS invoke_whitelists_trigger();
DROP FUNCTION IF EXISTS invoke_blocks_trigger();
DROP FUNCTION IF EXISTS invoke_listeners_trigger();

View file

@ -0,0 +1,99 @@
-- Your SQL goes here
CREATE OR REPLACE FUNCTION invoke_listeners_trigger ()
RETURNS TRIGGER
LANGUAGE plpgsql
AS $$
DECLARE
rec RECORD;
channel TEXT;
payload TEXT;
BEGIN
case TG_OP
WHEN 'INSERT' THEN
rec := NEW;
channel := 'new_listeners';
payload := NEW.actor_id;
WHEN 'DELETE' THEN
rec := OLD;
channel := 'rm_listeners';
payload := OLD.actor_id;
ELSE
RAISE EXCEPTION 'Unknown TG_OP: "%". Should not occur!', TG_OP;
END CASE;
PERFORM pg_notify(channel, payload::TEXT);
RETURN rec;
END;
$$;
CREATE OR REPLACE FUNCTION invoke_blocks_trigger ()
RETURNS TRIGGER
LANGUAGE plpgsql
AS $$
DECLARE
rec RECORD;
channel TEXT;
payload TEXT;
BEGIN
case TG_OP
WHEN 'INSERT' THEN
rec := NEW;
channel := 'new_blocks';
payload := NEW.domain_name;
WHEN 'DELETE' THEN
rec := OLD;
channel := 'rm_blocks';
payload := OLD.domain_name;
ELSE
RAISE EXCEPTION 'Unknown TG_OP: "%". Should not occur!', TG_OP;
END CASE;
PERFORM pg_notify(channel, payload::TEXT);
RETURN NULL;
END;
$$;
CREATE OR REPLACE FUNCTION invoke_whitelists_trigger ()
RETURNS TRIGGER
LANGUAGE plpgsql
AS $$
DECLARE
rec RECORD;
channel TEXT;
payload TEXT;
BEGIN
case TG_OP
WHEN 'INSERT' THEN
rec := NEW;
channel := 'new_whitelists';
payload := NEW.domain_name;
WHEN 'DELETE' THEN
rec := OLD;
channel := 'rm_whitelists';
payload := OLD.domain_name;
ELSE
RAISE EXCEPTION 'Unknown TG_OP: "%". Should not occur!', TG_OP;
END CASE;
PERFORM pg_notify(channel, payload::TEXT);
RETURN rec;
END;
$$;
CREATE TRIGGER listeners_notify
AFTER INSERT OR UPDATE OR DELETE
ON listeners
FOR EACH ROW
EXECUTE PROCEDURE invoke_listeners_trigger();
CREATE TRIGGER blocks_notify
AFTER INSERT OR UPDATE OR DELETE
ON blocks
FOR EACH ROW
EXECUTE PROCEDURE invoke_blocks_trigger();
CREATE TRIGGER whitelists_notify
AFTER INSERT OR UPDATE OR DELETE
ON whitelists
FOR EACH ROW
EXECUTE PROCEDURE invoke_whitelists_trigger();

158
src/db.rs Normal file
View file

@ -0,0 +1,158 @@
use activitystreams::primitives::XsdAnyUri;
use anyhow::Error;
use bb8_postgres::tokio_postgres::{row::Row, Client};
use log::info;
use rsa::RSAPrivateKey;
use rsa_pem::KeyExt;
use std::collections::HashSet;
#[derive(Clone, Debug, thiserror::Error)]
#[error("No host present in URI")]
pub struct HostError;
pub async fn listen(client: &Client) -> Result<(), Error> {
client
.batch_execute(
"LISTEN new_blocks;
LISTEN new_whitelists;
LISTEN new_listeners;
LISTEN rm_blocks;
LISTEN rm_whitelists;
LISTEN rm_listeners;",
)
.await?;
Ok(())
}
pub async fn hydrate_private_key(client: &Client) -> Result<Option<RSAPrivateKey>, Error> {
info!("SELECT value FROM settings WHERE key = 'private_key'");
let rows = client
.query("SELECT value FROM settings WHERE key = 'private_key'", &[])
.await?;
if let Some(row) = rows.into_iter().next() {
let key_str: String = row.get(0);
return Ok(Some(KeyExt::from_pem_pkcs8(&key_str)?));
}
Ok(None)
}
pub async fn update_private_key(client: &Client, key: &RSAPrivateKey) -> Result<(), Error> {
let pem_pkcs8 = key.to_pem_pkcs8()?;
info!("INSERT INTO settings (key, value, created_at) VALUES ('private_key', $1::TEXT, 'now');");
client.execute("INSERT INTO settings (key, value, created_at) VALUES ('private_key', $1::TEXT, 'now');", &[&pem_pkcs8]).await?;
Ok(())
}
pub async fn add_block(client: &Client, block: &XsdAnyUri) -> Result<(), Error> {
let host = if let Some(host) = block.as_url().host() {
host
} else {
return Err(HostError.into());
};
info!(
"INSERT INTO blocks (domain_name, created_at) VALUES ($1::TEXT, 'now'); [{}]",
host.to_string()
);
client
.execute(
"INSERT INTO blocks (domain_name, created_at) VALUES ($1::TEXT, 'now');",
&[&host.to_string()],
)
.await?;
Ok(())
}
pub async fn add_whitelist(client: &Client, whitelist: &XsdAnyUri) -> Result<(), Error> {
let host = if let Some(host) = whitelist.as_url().host() {
host
} else {
return Err(HostError.into());
};
info!(
"INSERT INTO whitelists (domain_name, created_at) VALUES ($1::TEXT, 'now'); [{}]",
host.to_string()
);
client
.execute(
"INSERT INTO whitelists (domain_name, created_at) VALUES ($1::TEXT, 'now');",
&[&host.to_string()],
)
.await?;
Ok(())
}
pub async fn remove_listener(client: &Client, listener: &XsdAnyUri) -> Result<(), Error> {
info!(
"DELETE FROM listeners WHERE actor_id = {};",
listener.as_str()
);
client
.execute(
"DELETE FROM listeners WHERE actor_id = $1::TEXT;",
&[&listener.as_str()],
)
.await?;
Ok(())
}
pub async fn add_listener(client: &Client, listener: &XsdAnyUri) -> Result<(), Error> {
info!(
"INSERT INTO listeners (actor_id, created_at) VALUES ($1::TEXT, 'now'); [{}]",
listener.as_str(),
);
client
.execute(
"INSERT INTO listeners (actor_id, created_at) VALUES ($1::TEXT, 'now');",
&[&listener.as_str()],
)
.await?;
Ok(())
}
pub async fn hydrate_blocks(client: &Client) -> Result<HashSet<String>, Error> {
info!("SELECT domain_name FROM blocks");
let rows = client.query("SELECT domain_name FROM blocks", &[]).await?;
parse_rows(rows)
}
pub async fn hydrate_whitelists(client: &Client) -> Result<HashSet<String>, Error> {
info!("SELECT domain_name FROM whitelists");
let rows = client
.query("SELECT domain_name FROM whitelists", &[])
.await?;
parse_rows(rows)
}
pub async fn hydrate_listeners(client: &Client) -> Result<HashSet<XsdAnyUri>, Error> {
info!("SELECT actor_id FROM listeners");
let rows = client.query("SELECT actor_id FROM listeners", &[]).await?;
parse_rows(rows)
}
fn parse_rows<T>(rows: Vec<Row>) -> Result<HashSet<T>, Error>
where
T: std::str::FromStr + Eq + std::hash::Hash,
{
let hs = rows
.into_iter()
.filter_map(move |row| {
let s: String = row.try_get(0).ok()?;
s.parse().ok()
})
.collect();
Ok(hs)
}

View file

@ -63,14 +63,15 @@ async fn handle_undo(
let inbox = actor.inbox().to_owned();
let state2 = state.clone().into_inner();
db_actor.do_send(DbQuery(move |pool: Pool| {
let inbox = inbox.clone();
async move {
let conn = pool.get().await?;
state2.remove_listener(&conn, &inbox).await.map_err(|e| {
crate::db::remove_listener(&conn, &inbox)
.await
.map_err(|e| {
error!("Error removing listener, {}", e);
e
})
@ -181,17 +182,14 @@ async fn handle_follow(
}
if !is_listener {
let state = state.clone().into_inner();
let inbox = actor.inbox().to_owned();
db_actor.do_send(DbQuery(move |pool: Pool| {
let inbox = inbox.clone();
let state = state.clone();
async move {
let conn = pool.get().await?;
state.add_listener(&conn, inbox).await.map_err(|e| {
crate::db::add_listener(&conn, &inbox).await.map_err(|e| {
error!("Error adding listener, {}", e);
e
})

View file

@ -6,10 +6,12 @@ use rsa_pem::KeyExt;
use sha2::{Digest, Sha256};
mod apub;
mod db;
mod db_actor;
mod error;
mod inbox;
mod label;
mod notify;
mod state;
mod verifier;
mod webfinger;
@ -62,7 +64,7 @@ async fn actor_route(state: web::Data<State>) -> Result<impl Responder, MyError>
#[actix_rt::main]
async fn main() -> Result<(), anyhow::Error> {
dotenv::dotenv().ok();
std::env::set_var("RUST_LOG", "info");
std::env::set_var("RUST_LOG", "debug");
pretty_env_logger::init();
let pg_config: tokio_postgres::Config = std::env::var("DATABASE_URL")?.parse()?;
@ -82,6 +84,8 @@ async fn main() -> Result<(), anyhow::Error> {
.await?
.await??;
let _ = notify::NotifyHandler::start_handler(state.clone(), pg_config.clone());
HttpServer::new(move || {
let actor = DbActor::new(pg_config.clone());
arbiter_labeler.clone().set_label();

161
src/notify.rs Normal file
View file

@ -0,0 +1,161 @@
use crate::state::State;
use activitystreams::primitives::XsdAnyUri;
use actix::prelude::*;
use bb8_postgres::tokio_postgres::{tls::NoTls, AsyncMessage, Client, Config, Notification};
use futures::{
future::ready,
stream::{poll_fn, StreamExt},
};
use log::{debug, error, info};
use tokio::sync::mpsc;
#[derive(Message)]
#[rtype(result = "()")]
pub struct Notify(Notification);
pub struct NotifyHandler {
client: Option<Client>,
state: State,
config: Config,
}
impl NotifyHandler {
fn new(state: State, config: Config) -> Self {
NotifyHandler {
state,
config,
client: None,
}
}
pub fn start_handler(state: State, config: Config) -> Addr<Self> {
Supervisor::start(|_| Self::new(state, config))
}
}
impl Actor for NotifyHandler {
type Context = Context<Self>;
fn started(&mut self, ctx: &mut Self::Context) {
let config = self.config.clone();
let fut = async move {
let (client, mut conn) = match config.connect(NoTls).await {
Ok((client, conn)) => (client, conn),
Err(e) => {
error!("Error establishing DB Connection, {}", e);
return Err(());
}
};
let mut stream = poll_fn(move |cx| conn.poll_message(cx)).filter_map(|m| match m {
Ok(AsyncMessage::Notification(n)) => {
debug!("Handling Notification, {:?}", n);
ready(Some(Notify(n)))
}
Ok(AsyncMessage::Notice(e)) => {
debug!("Handling Notice, {:?}", e);
ready(None)
}
Err(e) => {
debug!("Handling Error, {:?}", e);
ready(None)
}
_ => {
debug!("Handling rest");
ready(None)
}
});
let (mut tx, rx) = mpsc::channel(256);
Arbiter::spawn(async move {
debug!("Spawned stream handler");
while let Some(n) = stream.next().await {
match tx.send(n).await {
Err(e) => error!("Error forwarding notification, {}", e),
_ => (),
};
}
debug!("Stream handler ended");
});
Ok((client, rx))
};
let fut = fut.into_actor(self).map(|res, actor, ctx| match res {
Ok((client, stream)) => {
Self::add_stream(stream, ctx);
let f = async move {
match crate::db::listen(&client).await {
Err(e) => {
error!("Error listening, {}", e);
Err(())
}
Ok(_) => Ok(client),
}
};
ctx.wait(f.into_actor(actor).map(|res, actor, ctx| match res {
Ok(client) => {
actor.client = Some(client);
}
Err(_) => {
ctx.stop();
}
}));
}
Err(_) => {
ctx.stop();
}
});
ctx.wait(fut);
info!("Listener starting");
}
}
impl StreamHandler<Notify> for NotifyHandler {
fn handle(&mut self, Notify(notif): Notify, ctx: &mut Self::Context) {
let state = self.state.clone();
info!("Handling notification in {}", notif.channel());
let fut = async move {
match notif.channel() {
"new_blocks" => {
debug!("Caching block of {}", notif.payload());
state.cache_block(notif.payload().to_owned()).await;
}
"new_whitelists" => {
debug!("Caching whitelist of {}", notif.payload());
state.cache_whitelist(notif.payload().to_owned()).await;
}
"new_listeners" => {
if let Ok(uri) = notif.payload().parse::<XsdAnyUri>() {
debug!("Caching listener {}", uri);
state.cache_listener(uri).await;
}
}
"rm_blocks" => {
debug!("Busting block cache for {}", notif.payload());
state.bust_block(notif.payload()).await;
}
"rm_whitelists" => {
debug!("Busting whitelist cache for {}", notif.payload());
state.bust_whitelist(notif.payload()).await;
}
"rm_listeners" => {
if let Ok(uri) = notif.payload().parse::<XsdAnyUri>() {
debug!("Busting listener cache for {}", uri);
state.bust_listener(&uri).await;
}
}
_ => (),
}
};
ctx.spawn(fut.into_actor(self));
}
}
impl Supervised for NotifyHandler {}

View file

@ -1,12 +1,11 @@
use activitystreams::primitives::XsdAnyUri;
use anyhow::Error;
use bb8_postgres::tokio_postgres::{row::Row, Client};
use bb8_postgres::tokio_postgres::Client;
use futures::try_join;
use log::{error, info};
use lru::LruCache;
use rand::thread_rng;
use rsa::{RSAPrivateKey, RSAPublicKey};
use rsa_pem::KeyExt;
use std::{collections::HashSet, sync::Arc};
use tokio::sync::RwLock;
use ttl_cache::TtlCache;
@ -42,10 +41,6 @@ pub enum UrlKind {
MainKey,
}
#[derive(Clone, Debug, thiserror::Error)]
#[error("No host present in URI")]
pub struct HostError;
#[derive(Clone, Debug, thiserror::Error)]
#[error("Error generating RSA key")]
pub struct RsaError;
@ -57,14 +52,8 @@ impl Settings {
whitelist_enabled: bool,
hostname: String,
) -> Result<Self, Error> {
info!("SELECT value FROM settings WHERE key = 'private_key'");
let rows = client
.query("SELECT value FROM settings WHERE key = 'private_key'", &[])
.await?;
let private_key = if let Some(row) = rows.into_iter().next() {
let key_str: String = row.get(0);
KeyExt::from_pem_pkcs8(&key_str)?
let private_key = if let Some(key) = crate::db::hydrate_private_key(client).await? {
key
} else {
info!("Generating new keys");
let mut rng = thread_rng();
@ -72,10 +61,9 @@ impl Settings {
error!("Error generating RSA key, {}", e);
RsaError
})?;
let pem_pkcs8 = key.to_pem_pkcs8()?;
info!("INSERT INTO settings (key, value, created_at) VALUES ('private_key', $1::TEXT, 'now');");
client.execute("INSERT INTO settings (key, value, created_at) VALUES ('private_key', $1::TEXT, 'now');", &[&pem_pkcs8]).await?;
crate::db::update_private_key(client, &key).await?;
key
};
@ -131,21 +119,25 @@ impl State {
self.settings.sign(bytes)
}
pub async fn remove_listener(&self, client: &Client, inbox: &XsdAnyUri) -> Result<(), Error> {
let hs = self.listeners.clone();
pub async fn bust_whitelist(&self, whitelist: &str) {
let hs = self.whitelists.clone();
info!("DELETE FROM listeners WHERE actor_id = {};", inbox.as_str());
client
.execute(
"DELETE FROM listeners WHERE actor_id = $1::TEXT;",
&[&inbox.as_str()],
)
.await?;
let mut write_guard = hs.write().await;
write_guard.remove(whitelist);
}
pub async fn bust_block(&self, block: &str) {
let hs = self.blocks.clone();
let mut write_guard = hs.write().await;
write_guard.remove(block);
}
pub async fn bust_listener(&self, inbox: &XsdAnyUri) {
let hs = self.listeners.clone();
let mut write_guard = hs.write().await;
write_guard.remove(inbox);
Ok(())
}
pub async fn listeners_without(&self, inbox: &XsdAnyUri, domain: &str) -> Vec<XsdAnyUri> {
@ -228,76 +220,25 @@ impl State {
write_guard.put(object_id, actor_id);
}
pub async fn add_block(&self, client: &Client, block: XsdAnyUri) -> Result<(), Error> {
pub async fn cache_block(&self, host: String) {
let blocks = self.blocks.clone();
let host = if let Some(host) = block.as_url().host() {
host
} else {
return Err(HostError.into());
};
info!(
"INSERT INTO blocks (domain_name, created_at) VALUES ($1::TEXT, 'now'); [{}]",
host.to_string()
);
client
.execute(
"INSERT INTO blocks (domain_name, created_at) VALUES ($1::TEXT, 'now');",
&[&host.to_string()],
)
.await?;
let mut write_guard = blocks.write().await;
write_guard.insert(host.to_string());
Ok(())
write_guard.insert(host);
}
pub async fn add_whitelist(&self, client: &Client, whitelist: XsdAnyUri) -> Result<(), Error> {
pub async fn cache_whitelist(&self, host: String) {
let whitelists = self.whitelists.clone();
let host = if let Some(host) = whitelist.as_url().host() {
host
} else {
return Err(HostError.into());
};
info!(
"INSERT INTO whitelists (domain_name, created_at) VALUES ($1::TEXT, 'now'); [{}]",
host.to_string()
);
client
.execute(
"INSERT INTO whitelists (domain_name, created_at) VALUES ($1::TEXT, 'now');",
&[&host.to_string()],
)
.await?;
let mut write_guard = whitelists.write().await;
write_guard.insert(host.to_string());
Ok(())
write_guard.insert(host);
}
pub async fn add_listener(&self, client: &Client, listener: XsdAnyUri) -> Result<(), Error> {
pub async fn cache_listener(&self, listener: XsdAnyUri) {
let listeners = self.listeners.clone();
info!(
"INSERT INTO listeners (actor_id, created_at) VALUES ($1::TEXT, 'now'); [{}]",
listener.as_str(),
);
client
.execute(
"INSERT INTO listeners (actor_id, created_at) VALUES ($1::TEXT, 'now');",
&[&listener.as_str()],
)
.await?;
let mut write_guard = listeners.write().await;
write_guard.insert(listener);
Ok(())
}
pub async fn hydrate(
@ -313,19 +254,19 @@ impl State {
let f1 = async move {
let conn = pool.get().await?;
hydrate_blocks(&conn).await
crate::db::hydrate_blocks(&conn).await
};
let f2 = async move {
let conn = pool1.get().await?;
hydrate_whitelists(&conn).await
crate::db::hydrate_whitelists(&conn).await
};
let f3 = async move {
let conn = pool2.get().await?;
hydrate_listeners(&conn).await
crate::db::hydrate_listeners(&conn).await
};
let f4 = async move {
@ -346,41 +287,3 @@ impl State {
})
}
}
pub async fn hydrate_blocks(client: &Client) -> Result<HashSet<String>, Error> {
info!("SELECT domain_name FROM blocks");
let rows = client.query("SELECT domain_name FROM blocks", &[]).await?;
parse_rows(rows)
}
pub async fn hydrate_whitelists(client: &Client) -> Result<HashSet<String>, Error> {
info!("SELECT domain_name FROM whitelists");
let rows = client
.query("SELECT domain_name FROM whitelists", &[])
.await?;
parse_rows(rows)
}
pub async fn hydrate_listeners(client: &Client) -> Result<HashSet<XsdAnyUri>, Error> {
info!("SELECT actor_id FROM listeners");
let rows = client.query("SELECT actor_id FROM listeners", &[]).await?;
parse_rows(rows)
}
pub fn parse_rows<T>(rows: Vec<Row>) -> Result<HashSet<T>, Error>
where
T: std::str::FromStr + Eq + std::hash::Hash,
{
let hs = rows
.into_iter()
.filter_map(move |row| {
let s: String = row.try_get(0).ok()?;
s.parse().ok()
})
.collect();
Ok(hs)
}