Store media uuid mappings, be mindful of locks

This commit is contained in:
asonix 2020-03-26 13:21:05 -05:00
parent d445177c69
commit f11043e57d
10 changed files with 311 additions and 135 deletions

View file

@ -0,0 +1,2 @@
-- This file should undo anything in `up.sql`
DROP TABLE media;

View file

@ -0,0 +1,10 @@
-- Your SQL goes here
CREATE TABLE media (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
media_id UUID UNIQUE NOT NULL,
url TEXT UNIQUE NOT NULL,
created_at TIMESTAMP NOT NULL,
updated_at TIMESTAMP NOT NULL DEFAULT NOW()
);
SELECT diesel_manage_updated_at('media');

View file

@ -92,11 +92,15 @@ impl ActorCache {
} }
pub async fn unfollower(&self, actor: &Actor) -> Result<Option<Uuid>, MyError> { pub async fn unfollower(&self, actor: &Actor) -> Result<Option<Uuid>, MyError> {
let conn = self.db.pool().get().await?; let row_opt = self
.db
let row_opt = conn .pool()
.get()
.await?
.query_opt( .query_opt(
"DELETE FROM actors WHERE actor_id = $1::TEXT RETURNING listener_id;", "DELETE FROM actors
WHERE actor_id = $1::TEXT
RETURNING listener_id;",
&[&actor.id.as_str()], &[&actor.id.as_str()],
) )
.await?; .await?;
@ -109,9 +113,14 @@ impl ActorCache {
let listener_id: Uuid = row.try_get(0)?; let listener_id: Uuid = row.try_get(0)?;
let row_opt = conn let row_opt = self
.db
.pool()
.get()
.await?
.query_opt( .query_opt(
"SELECT FROM actors WHERE listener_id = $1::UUID;", "SELECT FROM actors
WHERE listener_id = $1::UUID;",
&[&listener_id], &[&listener_id],
) )
.await?; .await?;
@ -124,9 +133,11 @@ impl ActorCache {
} }
async fn lookup(&self, id: &XsdAnyUri) -> Result<Option<Actor>, MyError> { async fn lookup(&self, id: &XsdAnyUri) -> Result<Option<Actor>, MyError> {
let conn = self.db.pool().get().await?; let row_opt = self
.db
let row_opt = conn .pool()
.get()
.await?
.query_opt( .query_opt(
"SELECT listeners.actor_id, actors.public_key, actors.public_key_id "SELECT listeners.actor_id, actors.public_key, actors.public_key_id
FROM listeners FROM listeners
@ -158,9 +169,11 @@ impl ActorCache {
} }
async fn save(&self, actor: Actor) -> Result<(), MyError> { async fn save(&self, actor: Actor) -> Result<(), MyError> {
let conn = self.db.pool().get().await?; let row_opt = self
.db
let row_opt = conn .pool()
.get()
.await?
.query_opt( .query_opt(
"SELECT id FROM listeners WHERE actor_id = $1::TEXT LIMIT 1;", "SELECT id FROM listeners WHERE actor_id = $1::TEXT LIMIT 1;",
&[&actor.inbox.as_str()], &[&actor.inbox.as_str()],
@ -175,12 +188,33 @@ impl ActorCache {
let listener_id: Uuid = row.try_get(0)?; let listener_id: Uuid = row.try_get(0)?;
conn.execute( self.db
"INSERT INTO actors (actor_id, public_key, public_key_id, listener_id, created_at, updated_at) .pool()
VALUES ($1::TEXT, $2::TEXT, $3::TEXT, $4::UUID, 'now', 'now') .get()
ON CONFLICT (actor_id) .await?
.execute(
"INSERT INTO actors (
actor_id,
public_key,
public_key_id,
listener_id,
created_at,
updated_at
) VALUES (
$1::TEXT,
$2::TEXT,
$3::TEXT,
$4::UUID,
'now',
'now'
) ON CONFLICT (actor_id)
DO UPDATE SET public_key = $2::TEXT;", DO UPDATE SET public_key = $2::TEXT;",
&[&actor.id.as_str(), &actor.public_key, &actor.public_key_id.as_str(), &listener_id], &[
&actor.id.as_str(),
&actor.public_key,
&actor.public_key_id.as_str(),
&listener_id,
],
) )
.await?; .await?;
Ok(()) Ok(())
@ -192,9 +226,11 @@ impl ActorCache {
public_key: &str, public_key: &str,
public_key_id: &XsdAnyUri, public_key_id: &XsdAnyUri,
) -> Result<(), MyError> { ) -> Result<(), MyError> {
let conn = self.db.pool().get().await?; self.db
.pool()
conn.execute( .get()
.await?
.execute(
"UPDATE actors "UPDATE actors
SET public_key = $2::TEXT, public_key_id = $3::TEXT SET public_key = $2::TEXT, public_key_id = $3::TEXT
WHERE actor_id = $1::TEXT;", WHERE actor_id = $1::TEXT;",
@ -223,9 +259,13 @@ impl ActorCache {
} }
async fn rehydrate(&self) -> Result<(), MyError> { async fn rehydrate(&self) -> Result<(), MyError> {
let conn = self.db.pool().get().await?; let rows = self
.db
let rows = conn.query("SELECT actor_id FROM actors;", &[]).await?; .pool()
.get()
.await?
.query("SELECT actor_id FROM actors;", &[])
.await?;
let actor_ids = rows let actor_ids = rows
.into_iter() .into_iter()

View file

@ -1,5 +1,7 @@
use crate::{db::Db, error::MyError};
use activitystreams::primitives::XsdAnyUri; use activitystreams::primitives::XsdAnyUri;
use bytes::Bytes; use bytes::Bytes;
use futures::join;
use lru::LruCache; use lru::LruCache;
use std::{collections::HashMap, sync::Arc, time::Duration}; use std::{collections::HashMap, sync::Arc, time::Duration};
use tokio::sync::{Mutex, RwLock}; use tokio::sync::{Mutex, RwLock};
@ -10,45 +12,154 @@ static MEDIA_DURATION: Duration = Duration::from_secs(60 * 60 * 24 * 2);
#[derive(Clone)] #[derive(Clone)]
pub struct Media { pub struct Media {
db: Db,
inverse: Arc<Mutex<HashMap<XsdAnyUri, Uuid>>>, inverse: Arc<Mutex<HashMap<XsdAnyUri, Uuid>>>,
url_cache: Arc<Mutex<LruCache<Uuid, XsdAnyUri>>>, url_cache: Arc<Mutex<LruCache<Uuid, XsdAnyUri>>>,
byte_cache: Arc<RwLock<TtlCache<Uuid, (String, Bytes)>>>, byte_cache: Arc<RwLock<TtlCache<Uuid, (String, Bytes)>>>,
} }
impl Media { impl Media {
pub fn new() -> Self { pub fn new(db: Db) -> Self {
Media { Media {
db,
inverse: Arc::new(Mutex::new(HashMap::new())), inverse: Arc::new(Mutex::new(HashMap::new())),
url_cache: Arc::new(Mutex::new(LruCache::new(128))), url_cache: Arc::new(Mutex::new(LruCache::new(128))),
byte_cache: Arc::new(RwLock::new(TtlCache::new(128))), byte_cache: Arc::new(RwLock::new(TtlCache::new(128))),
} }
} }
pub async fn get_uuid(&self, url: &XsdAnyUri) -> Option<Uuid> { pub async fn get_uuid(&self, url: &XsdAnyUri) -> Result<Option<Uuid>, MyError> {
let uuid = self.inverse.lock().await.get(url).cloned()?; let res = self.inverse.lock().await.get(url).cloned();
let uuid = match res {
Some(uuid) => uuid,
_ => {
let row_opt = self
.db
.pool()
.get()
.await?
.query_opt(
"SELECT media_id
FROM media
WHERE url = $1::TEXT
LIMIT 1;",
&[&url.as_str()],
)
.await?;
if let Some(row) = row_opt {
let uuid: Uuid = row.try_get(0)?;
self.inverse.lock().await.insert(url.clone(), uuid);
uuid
} else {
return Ok(None);
}
}
};
if self.url_cache.lock().await.contains(&uuid) { if self.url_cache.lock().await.contains(&uuid) {
return Some(uuid); return Ok(Some(uuid));
}
let row_opt = self
.db
.pool()
.get()
.await?
.query_opt(
"SELECT id
FROM media
WHERE
url = $1::TEXT
AND
media_id = $2::UUID
LIMIT 1;",
&[&url.as_str(), &uuid],
)
.await?;
if row_opt.is_some() {
self.url_cache.lock().await.put(uuid, url.clone());
return Ok(Some(uuid));
} }
self.inverse.lock().await.remove(url); self.inverse.lock().await.remove(url);
None Ok(None)
} }
pub async fn get_url(&self, uuid: Uuid) -> Option<XsdAnyUri> { pub async fn get_url(&self, uuid: Uuid) -> Result<Option<XsdAnyUri>, MyError> {
self.url_cache.lock().await.get(&uuid).cloned() match self.url_cache.lock().await.get(&uuid).cloned() {
Some(url) => return Ok(Some(url)),
_ => (),
}
let row_opt = self
.db
.pool()
.get()
.await?
.query_opt(
"SELECT url
FROM media
WHERE media_id = $1::UUID
LIMIT 1;",
&[&uuid],
)
.await?;
if let Some(row) = row_opt {
let url: String = row.try_get(0)?;
let url: XsdAnyUri = url.parse()?;
return Ok(Some(url));
}
Ok(None)
} }
pub async fn get_bytes(&self, uuid: Uuid) -> Option<(String, Bytes)> { pub async fn get_bytes(&self, uuid: Uuid) -> Option<(String, Bytes)> {
self.byte_cache.read().await.get(&uuid).cloned() self.byte_cache.read().await.get(&uuid).cloned()
} }
pub async fn store_url(&self, url: &XsdAnyUri) -> Uuid { pub async fn store_url(&self, url: &XsdAnyUri) -> Result<Uuid, MyError> {
let uuid = Uuid::new_v4(); let uuid = Uuid::new_v4();
let (_, _, res) = join!(
async {
self.inverse.lock().await.insert(url.clone(), uuid); self.inverse.lock().await.insert(url.clone(), uuid);
},
async {
self.url_cache.lock().await.put(uuid, url.clone()); self.url_cache.lock().await.put(uuid, url.clone());
uuid },
async {
self.db
.pool()
.get()
.await?
.execute(
"INSERT INTO media (
media_id,
url,
created_at,
updated_at
) VALUES (
$1::UUID,
$2::TEXT,
'now',
'now'
) ON CONFLICT (media_id)
DO UPDATE SET url = $2::TEXT;",
&[&uuid, &url.as_str()],
)
.await?;
Ok(()) as Result<(), MyError>
}
);
res?;
Ok(uuid)
} }
pub async fn store_bytes(&self, uuid: Uuid, content_type: String, bytes: Bytes) { pub async fn store_bytes(&self, uuid: Uuid, content_type: String, bytes: Bytes) {

View file

@ -118,9 +118,11 @@ impl NodeCache {
} }
async fn do_bust_by_id(&self, id: Uuid) -> Result<(), MyError> { async fn do_bust_by_id(&self, id: Uuid) -> Result<(), MyError> {
let conn = self.db.pool().get().await?; let row_opt = self
.db
let row_opt = conn .pool()
.get()
.await?
.query_opt( .query_opt(
"SELECT ls.actor_id "SELECT ls.actor_id
FROM listeners AS ls FROM listeners AS ls
@ -140,16 +142,17 @@ impl NodeCache {
let listener: String = row.try_get(0)?; let listener: String = row.try_get(0)?;
let listener: XsdAnyUri = listener.parse()?; let listener: XsdAnyUri = listener.parse()?;
let mut write_guard = self.nodes.write().await; self.nodes.write().await.remove(&listener);
write_guard.remove(&listener);
Ok(()) Ok(())
} }
async fn do_cache_by_id(&self, id: Uuid) -> Result<(), MyError> { async fn do_cache_by_id(&self, id: Uuid) -> Result<(), MyError> {
let conn = self.db.pool().get().await?; let row_opt = self
.db
let row_opt = conn .pool()
.get()
.await?
.query_opt( .query_opt(
"SELECT ls.actor_id, nd.nodeinfo, nd.instance, nd.contact "SELECT ls.actor_id, nd.nodeinfo, nd.instance, nd.contact
FROM nodes AS nd FROM nodes AS nd
@ -172,6 +175,7 @@ impl NodeCache {
let instance: Option<Json<Instance>> = row.try_get(2)?; let instance: Option<Json<Instance>> = row.try_get(2)?;
let contact: Option<Json<Contact>> = row.try_get(3)?; let contact: Option<Json<Contact>> = row.try_get(3)?;
{
let mut write_guard = self.nodes.write().await; let mut write_guard = self.nodes.write().await;
let node = write_guard let node = write_guard
.entry(listener.clone()) .entry(listener.clone())
@ -186,6 +190,7 @@ impl NodeCache {
if let Some(contact) = contact { if let Some(contact) = contact {
node.contact = Some(contact.0); node.contact = Some(contact.0);
} }
}
Ok(()) Ok(())
} }
@ -203,12 +208,15 @@ impl NodeCache {
return Ok(()); return Ok(());
} }
let node = {
let mut write_guard = self.nodes.write().await; let mut write_guard = self.nodes.write().await;
let node = write_guard let node = write_guard
.entry(listener.clone()) .entry(listener.clone())
.or_insert(Node::new(listener.clone())); .or_insert(Node::new(listener.clone()));
node.set_info(software, version, reg); node.set_info(software, version, reg);
self.save(listener, node).await?; node.clone()
};
self.save(listener, &node).await?;
Ok(()) Ok(())
} }
@ -227,12 +235,15 @@ impl NodeCache {
return Ok(()); return Ok(());
} }
let node = {
let mut write_guard = self.nodes.write().await; let mut write_guard = self.nodes.write().await;
let node = write_guard let node = write_guard
.entry(listener.clone()) .entry(listener.clone())
.or_insert(Node::new(listener.clone())); .or_insert(Node::new(listener.clone()));
node.set_instance(title, description, version, reg, requires_approval); node.set_instance(title, description, version, reg, requires_approval);
self.save(listener, node).await?; node.clone()
};
self.save(listener, &node).await?;
Ok(()) Ok(())
} }
@ -250,19 +261,24 @@ impl NodeCache {
return Ok(()); return Ok(());
} }
let node = {
let mut write_guard = self.nodes.write().await; let mut write_guard = self.nodes.write().await;
let node = write_guard let node = write_guard
.entry(listener.clone()) .entry(listener.clone())
.or_insert(Node::new(listener.clone())); .or_insert(Node::new(listener.clone()));
node.set_contact(username, display_name, url, avatar); node.set_contact(username, display_name, url, avatar);
self.save(listener, node).await?; node.clone()
};
self.save(listener, &node).await?;
Ok(()) Ok(())
} }
pub async fn save(&self, listener: &XsdAnyUri, node: &Node) -> Result<(), MyError> { pub async fn save(&self, listener: &XsdAnyUri, node: &Node) -> Result<(), MyError> {
let conn = self.db.pool().get().await?; let row_opt = self
.db
let row_opt = conn .pool()
.get()
.await?
.query_opt( .query_opt(
"SELECT id FROM listeners WHERE actor_id = $1::TEXT LIMIT 1;", "SELECT id FROM listeners WHERE actor_id = $1::TEXT LIMIT 1;",
&[&listener.as_str()], &[&listener.as_str()],
@ -275,7 +291,11 @@ impl NodeCache {
return Err(MyError::NotSubscribed(listener.as_str().to_owned())); return Err(MyError::NotSubscribed(listener.as_str().to_owned()));
}; };
conn.execute( self.db
.pool()
.get()
.await?
.execute(
"INSERT INTO nodes ( "INSERT INTO nodes (
listener_id, listener_id,
nodeinfo, nodeinfo,

View file

@ -47,34 +47,29 @@ impl State {
} }
pub async fn bust_whitelist(&self, whitelist: &str) { pub async fn bust_whitelist(&self, whitelist: &str) {
let mut write_guard = self.whitelists.write().await; self.whitelists.write().await.remove(whitelist);
write_guard.remove(whitelist);
} }
pub async fn bust_block(&self, block: &str) { pub async fn bust_block(&self, block: &str) {
let mut write_guard = self.blocks.write().await; self.blocks.write().await.remove(block);
write_guard.remove(block);
} }
pub async fn bust_listener(&self, inbox: &XsdAnyUri) { pub async fn bust_listener(&self, inbox: &XsdAnyUri) {
let mut write_guard = self.listeners.write().await; self.listeners.write().await.remove(inbox);
write_guard.remove(inbox);
} }
pub async fn listeners(&self) -> Vec<XsdAnyUri> { pub async fn listeners(&self) -> Vec<XsdAnyUri> {
let read_guard = self.listeners.read().await; self.listeners.read().await.iter().cloned().collect()
read_guard.iter().cloned().collect()
} }
pub async fn blocks(&self) -> Vec<String> { pub async fn blocks(&self) -> Vec<String> {
let read_guard = self.blocks.read().await; self.blocks.read().await.iter().cloned().collect()
read_guard.iter().cloned().collect()
} }
pub async fn listeners_without(&self, inbox: &XsdAnyUri, domain: &str) -> Vec<XsdAnyUri> { pub async fn listeners_without(&self, inbox: &XsdAnyUri, domain: &str) -> Vec<XsdAnyUri> {
let read_guard = self.listeners.read().await; self.listeners
.read()
read_guard .await
.iter() .iter()
.filter_map(|listener| { .filter_map(|listener| {
if let Some(dom) = listener.as_url().domain() { if let Some(dom) = listener.as_url().domain() {
@ -94,8 +89,7 @@ impl State {
} }
if let Some(host) = actor_id.as_url().host() { if let Some(host) = actor_id.as_url().host() {
let read_guard = self.whitelists.read().await; self.whitelists.read().await.contains(&host.to_string());
return read_guard.contains(&host.to_string());
} }
false false
@ -103,43 +97,34 @@ impl State {
pub async fn is_blocked(&self, actor_id: &XsdAnyUri) -> bool { pub async fn is_blocked(&self, actor_id: &XsdAnyUri) -> bool {
if let Some(host) = actor_id.as_url().host() { if let Some(host) = actor_id.as_url().host() {
let read_guard = self.blocks.read().await; self.blocks.read().await.contains(&host.to_string());
return read_guard.contains(&host.to_string());
} }
true true
} }
pub async fn is_listener(&self, actor_id: &XsdAnyUri) -> bool { pub async fn is_listener(&self, actor_id: &XsdAnyUri) -> bool {
let read_guard = self.listeners.read().await; self.listeners.read().await.contains(actor_id)
read_guard.contains(actor_id)
} }
pub async fn is_cached(&self, object_id: &XsdAnyUri) -> bool { pub async fn is_cached(&self, object_id: &XsdAnyUri) -> bool {
let cache = self.actor_id_cache.clone(); self.actor_id_cache.read().await.contains(object_id)
let read_guard = cache.read().await;
read_guard.contains(object_id)
} }
pub async fn cache(&self, object_id: XsdAnyUri, actor_id: XsdAnyUri) { pub async fn cache(&self, object_id: XsdAnyUri, actor_id: XsdAnyUri) {
let mut write_guard = self.actor_id_cache.write().await; self.actor_id_cache.write().await.put(object_id, actor_id);
write_guard.put(object_id, actor_id);
} }
pub async fn cache_block(&self, host: String) { pub async fn cache_block(&self, host: String) {
let mut write_guard = self.blocks.write().await; self.blocks.write().await.insert(host);
write_guard.insert(host);
} }
pub async fn cache_whitelist(&self, host: String) { pub async fn cache_whitelist(&self, host: String) {
let mut write_guard = self.whitelists.write().await; self.whitelists.write().await.insert(host);
write_guard.insert(host);
} }
pub async fn cache_listener(&self, listener: XsdAnyUri) { pub async fn cache_listener(&self, listener: XsdAnyUri) {
let mut write_guard = self.listeners.write().await; self.listeners.write().await.insert(listener);
write_guard.insert(listener);
} }
pub async fn rehydrate(&self, db: &Db) -> Result<(), MyError> { pub async fn rehydrate(&self, db: &Db) -> Result<(), MyError> {
@ -151,16 +136,13 @@ impl State {
join!( join!(
async move { async move {
let mut write_guard = self.listeners.write().await; *self.listeners.write().await = listeners;
*write_guard = listeners;
}, },
async move { async move {
let mut write_guard = self.whitelists.write().await; *self.whitelists.write().await = whitelists;
*write_guard = whitelists;
}, },
async move { async move {
let mut write_guard = self.blocks.write().await; *self.blocks.write().await = blocks;
*write_guard = blocks;
} }
); );

View file

@ -45,10 +45,10 @@ impl QueryInstance {
}; };
if let Some(mut contact) = instance.contact { if let Some(mut contact) = instance.contact {
if let Some(uuid) = state.media.get_uuid(&contact.avatar).await { if let Some(uuid) = state.media.get_uuid(&contact.avatar).await? {
contact.avatar = state.config.generate_url(UrlKind::Media(uuid)).parse()?; contact.avatar = state.config.generate_url(UrlKind::Media(uuid)).parse()?;
} else { } else {
let uuid = state.media.store_url(&contact.avatar).await; let uuid = state.media.store_url(&contact.avatar).await?;
contact.avatar = state.config.generate_url(UrlKind::Media(uuid)).parse()?; contact.avatar = state.config.generate_url(UrlKind::Media(uuid)).parse()?;
} }

View file

@ -68,7 +68,7 @@ async fn main() -> Result<(), anyhow::Error> {
return Ok(()); return Ok(());
} }
let media = Media::new(); let media = Media::new(db.clone());
let state = State::hydrate(config.clone(), &db).await?; let state = State::hydrate(config.clone(), &db).await?;
let actors = ActorCache::new(db.clone()); let actors = ActorCache::new(db.clone());
let job_server = create_server(db.clone()); let job_server = create_server(db.clone());

View file

@ -13,7 +13,7 @@ pub async fn route(
return Ok(HttpResponse::Ok().content_type(content_type).body(bytes)); return Ok(HttpResponse::Ok().content_type(content_type).body(bytes));
} }
if let Some(url) = media.get_url(uuid).await { if let Some(url) = media.get_url(uuid).await? {
let (content_type, bytes) = requests.fetch_bytes(url.as_str()).await?; let (content_type, bytes) = requests.fetch_bytes(url.as_str()).await?;
media media

View file

@ -43,6 +43,16 @@ table! {
} }
} }
table! {
media (id) {
id -> Uuid,
media_id -> Uuid,
url -> Text,
created_at -> Timestamp,
updated_at -> Timestamp,
}
}
table! { table! {
nodes (id) { nodes (id) {
id -> Uuid, id -> Uuid,
@ -82,6 +92,7 @@ allow_tables_to_appear_in_same_query!(
blocks, blocks,
jobs, jobs,
listeners, listeners,
media,
nodes, nodes,
settings, settings,
whitelists, whitelists,