use chrono::{DateTime, Utc}; use uuid::Uuid; use mitra_utils::id::generate_ulid; use crate::attachments::{ queries::set_attachment_ipfs_cid, types::DbMediaAttachment, }; use crate::cleanup::{ find_orphaned_files, find_orphaned_ipfs_objects, DeletionQueue, }; use crate::database::{ catch_unique_violation, query_macro::query, DatabaseClient, DatabaseError, }; use crate::emojis::types::DbEmoji; use crate::notifications::queries::{ create_mention_notification, create_reply_notification, create_repost_notification, }; use crate::profiles::{ queries::update_post_count, types::DbActorProfile, }; use crate::relationships::types::RelationshipType; use super::types::{ DbPost, Post, PostCreateData, PostUpdateData, Visibility, }; async fn create_post_attachments( db_client: &impl DatabaseClient, post_id: &Uuid, author_id: &Uuid, attachments: Vec, ) -> Result, DatabaseError> { let attachments_rows = db_client.query( " UPDATE media_attachment SET post_id = $1 WHERE owner_id = $2 AND id = ANY($3) RETURNING media_attachment ", &[&post_id, &author_id, &attachments], ).await?; if attachments_rows.len() != attachments.len() { // Some attachments were not found return Err(DatabaseError::NotFound("attachment")); }; let mut attachments: Vec = attachments_rows.iter() .map(|row| row.try_get("media_attachment")) .collect::>()?; attachments.sort_by(|a, b| a.created_at.cmp(&b.created_at)); Ok(attachments) } async fn create_post_mentions( db_client: &impl DatabaseClient, post_id: &Uuid, mentions: Vec, ) -> Result, DatabaseError> { let mentions_rows = db_client.query( " INSERT INTO mention (post_id, profile_id) SELECT $1, actor_profile.id FROM actor_profile WHERE id = ANY($2) RETURNING ( SELECT actor_profile FROM actor_profile WHERE actor_profile.id = profile_id ) AS actor_profile ", &[&post_id, &mentions], ).await?; if mentions_rows.len() != mentions.len() { // Some profiles were not found return Err(DatabaseError::NotFound("profile")); }; let profiles = mentions_rows.iter() .map(|row| row.try_get("actor_profile")) .collect::>()?; Ok(profiles) } async fn create_post_tags( db_client: &impl DatabaseClient, post_id: &Uuid, tags: Vec, ) -> Result, DatabaseError> { db_client.execute( " INSERT INTO tag (tag_name) SELECT unnest($1::text[]) ON CONFLICT (tag_name) DO NOTHING ", &[&tags], ).await?; let tags_rows = db_client.query( " INSERT INTO post_tag (post_id, tag_id) SELECT $1, tag.id FROM tag WHERE tag_name = ANY($2) RETURNING (SELECT tag_name FROM tag WHERE tag.id = tag_id) ", &[&post_id, &tags], ).await?; if tags_rows.len() != tags.len() { return Err(DatabaseError::NotFound("tag")); }; let tags = tags_rows.iter() .map(|row| row.try_get("tag_name")) .collect::>()?; Ok(tags) } async fn create_post_links( db_client: &impl DatabaseClient, post_id: &Uuid, links: Vec, ) -> Result, DatabaseError> { let links_rows = db_client.query( " INSERT INTO post_link (source_id, target_id) SELECT $1, post.id FROM post WHERE id = ANY($2) RETURNING target_id ", &[&post_id, &links], ).await?; if links_rows.len() != links.len() { return Err(DatabaseError::NotFound("post")); }; let links = links_rows.iter() .map(|row| row.try_get("target_id")) .collect::>()?; Ok(links) } async fn create_post_emojis( db_client: &impl DatabaseClient, post_id: &Uuid, emojis: Vec, ) -> Result, DatabaseError> { let emojis_rows = db_client.query( " INSERT INTO post_emoji (post_id, emoji_id) SELECT $1, emoji.id FROM emoji WHERE id = ANY($2) RETURNING ( SELECT emoji FROM emoji WHERE emoji.id = emoji_id ) ", &[&post_id, &emojis], ).await?; if emojis_rows.len() != emojis.len() { return Err(DatabaseError::NotFound("emoji")); }; let emojis = emojis_rows.iter() .map(|row| row.try_get("emoji")) .collect::>()?; Ok(emojis) } pub async fn create_post( db_client: &mut impl DatabaseClient, author_id: &Uuid, post_data: PostCreateData, ) -> Result { let transaction = db_client.transaction().await?; let post_id = generate_ulid(); // Replying to reposts is not allowed // Reposting of other reposts or non-public posts is not allowed let insert_statement = format!( " INSERT INTO post ( id, author_id, content, in_reply_to_id, repost_of_id, visibility, is_sensitive, object_id, created_at ) SELECT $1, $2, $3, $4, $5, $6, $7, $8, $9 WHERE NOT EXISTS ( SELECT 1 FROM post WHERE post.id = $4 AND post.repost_of_id IS NOT NULL ) AND NOT EXISTS ( SELECT 1 FROM post WHERE post.id = $5 AND ( post.repost_of_id IS NOT NULL OR post.visibility != {visibility_public} ) ) RETURNING post ", visibility_public=i16::from(&Visibility::Public), ); let maybe_post_row = transaction.query_opt( &insert_statement, &[ &post_id, &author_id, &post_data.content, &post_data.in_reply_to_id, &post_data.repost_of_id, &post_data.visibility, &post_data.is_sensitive, &post_data.object_id, &post_data.created_at, ], ).await.map_err(catch_unique_violation("post"))?; // Return NotFound error if reply/repost is not allowed let post_row = maybe_post_row.ok_or(DatabaseError::NotFound("post"))?; let db_post: DbPost = post_row.try_get("post")?; // Create related objects let db_attachments = create_post_attachments( &transaction, &db_post.id, &db_post.author_id, post_data.attachments, ).await?; let db_mentions = create_post_mentions( &transaction, &db_post.id, post_data.mentions, ).await?; let db_tags = create_post_tags( &transaction, &db_post.id, post_data.tags, ).await?; let db_links = create_post_links( &transaction, &db_post.id, post_data.links, ).await?; let db_emojis = create_post_emojis( &transaction, &db_post.id, post_data.emojis, ).await?; // Update counters let author = update_post_count(&transaction, &db_post.author_id, 1).await?; let mut notified_users = vec![]; if let Some(in_reply_to_id) = &db_post.in_reply_to_id { update_reply_count(&transaction, in_reply_to_id, 1).await?; let in_reply_to_author = get_post_author(&transaction, in_reply_to_id).await?; if in_reply_to_author.is_local() && in_reply_to_author.id != db_post.author_id { create_reply_notification( &transaction, &db_post.author_id, &in_reply_to_author.id, &db_post.id, ).await?; notified_users.push(in_reply_to_author.id); }; }; if let Some(repost_of_id) = &db_post.repost_of_id { update_repost_count(&transaction, repost_of_id, 1).await?; let repost_of_author = get_post_author(&transaction, repost_of_id).await?; if repost_of_author.is_local() && repost_of_author.id != db_post.author_id && !notified_users.contains(&repost_of_author.id) { create_repost_notification( &transaction, &db_post.author_id, &repost_of_author.id, repost_of_id, ).await?; notified_users.push(repost_of_author.id); }; }; // Notify mentioned users for profile in db_mentions.iter() { if profile.is_local() && profile.id != db_post.author_id && // Don't send mention notification to the author of parent post // or to the author of reposted post !notified_users.contains(&profile.id) { create_mention_notification( &transaction, &db_post.author_id, &profile.id, &db_post.id, ).await?; }; }; // Construct post object let post = Post::new( db_post, author, db_attachments, db_mentions, db_tags, db_links, db_emojis, )?; transaction.commit().await?; Ok(post) } pub async fn update_post( db_client: &mut impl DatabaseClient, post_id: &Uuid, post_data: PostUpdateData, ) -> Result<(), DatabaseError> { let transaction = db_client.transaction().await?; // Reposts and immutable posts can't be updated let maybe_row = transaction.query_opt( " UPDATE post SET content = $1, is_sensitive = $2, updated_at = $3 WHERE id = $4 AND repost_of_id IS NULL AND ipfs_cid IS NULL RETURNING post ", &[ &post_data.content, &post_data.is_sensitive, &post_data.updated_at, &post_id, ], ).await?; let row = maybe_row.ok_or(DatabaseError::NotFound("post"))?; let db_post: DbPost = row.try_get("post")?; // Delete and re-create related objects transaction.execute( "DELETE FROM media_attachment WHERE post_id = $1", &[&db_post.id], ).await?; transaction.execute( "DELETE FROM mention WHERE post_id = $1", &[&db_post.id], ).await?; transaction.execute( "DELETE FROM post_tag WHERE post_id = $1", &[&db_post.id], ).await?; transaction.execute( "DELETE FROM post_link WHERE source_id = $1", &[&db_post.id], ).await?; transaction.execute( "DELETE FROM post_emoji WHERE post_id = $1", &[&db_post.id], ).await?; create_post_attachments( &transaction, &db_post.id, &db_post.author_id, post_data.attachments, ).await?; create_post_mentions( &transaction, &db_post.id, post_data.mentions, ).await?; create_post_tags( &transaction, &db_post.id, post_data.tags, ).await?; create_post_links( &transaction, &db_post.id, post_data.links, ).await?; create_post_emojis( &transaction, &db_post.id, post_data.emojis, ).await?; transaction.commit().await?; Ok(()) } pub const RELATED_ATTACHMENTS: &str = "ARRAY( SELECT media_attachment FROM media_attachment WHERE post_id = post.id ORDER BY media_attachment.created_at ) AS attachments"; pub const RELATED_MENTIONS: &str = "ARRAY( SELECT actor_profile FROM mention JOIN actor_profile ON mention.profile_id = actor_profile.id WHERE post_id = post.id ) AS mentions"; pub const RELATED_TAGS: &str = "ARRAY( SELECT tag.tag_name FROM tag JOIN post_tag ON post_tag.tag_id = tag.id WHERE post_tag.post_id = post.id ) AS tags"; pub const RELATED_LINKS: &str = "ARRAY( SELECT post_link.target_id FROM post_link WHERE post_link.source_id = post.id ) AS links"; pub const RELATED_EMOJIS: &str = "ARRAY( SELECT emoji FROM post_emoji JOIN emoji ON post_emoji.emoji_id = emoji.id WHERE post_emoji.post_id = post.id ) AS emojis"; fn build_visibility_filter() -> String { format!( "( post.author_id = $current_user_id OR post.visibility = {visibility_public} -- covers direct messages and subscribers-only posts OR EXISTS ( SELECT 1 FROM mention WHERE post_id = post.id AND profile_id = $current_user_id ) OR post.visibility = {visibility_followers} AND EXISTS ( SELECT 1 FROM relationship WHERE source_id = $current_user_id AND target_id = post.author_id AND relationship_type = {relationship_follow} ) )", visibility_public=i16::from(&Visibility::Public), visibility_followers=i16::from(&Visibility::Followers), relationship_follow=i16::from(&RelationshipType::Follow), ) } pub async fn get_home_timeline( db_client: &impl DatabaseClient, current_user_id: &Uuid, max_post_id: Option, limit: u16, ) -> Result, DatabaseError> { // Select posts from follows, subscriptions, // posts where current user is mentioned // and user's own posts. let statement = format!( " SELECT post, actor_profile, {related_attachments}, {related_mentions}, {related_tags}, {related_links}, {related_emojis} FROM post JOIN actor_profile ON post.author_id = actor_profile.id WHERE ( post.author_id = $current_user_id OR ( EXISTS ( SELECT 1 FROM relationship WHERE source_id = $current_user_id AND target_id = post.author_id AND relationship_type IN ({relationship_follow}, {relationship_subscription}) ) AND ( -- show posts post.repost_of_id IS NULL -- show reposts if they are not hidden OR NOT EXISTS ( SELECT 1 FROM relationship WHERE source_id = $current_user_id AND target_id = post.author_id AND relationship_type = {relationship_hide_reposts} ) -- show reposts of current user's posts OR EXISTS ( SELECT 1 FROM post AS repost_of WHERE repost_of.id = post.repost_of_id AND repost_of.author_id = $current_user_id ) ) AND ( -- show posts (top-level) post.in_reply_to_id IS NULL -- show replies if they are not hidden OR NOT EXISTS ( SELECT 1 FROM relationship WHERE source_id = $current_user_id AND target_id = post.author_id AND relationship_type = {relationship_hide_replies} ) -- show replies to current user's posts OR EXISTS ( SELECT 1 FROM post AS in_reply_to WHERE in_reply_to.id = post.in_reply_to_id AND in_reply_to.author_id = $current_user_id ) ) ) OR EXISTS ( SELECT 1 FROM mention WHERE post_id = post.id AND profile_id = $current_user_id ) ) AND {visibility_filter} AND ($max_post_id::uuid IS NULL OR post.id < $max_post_id) ORDER BY post.id DESC LIMIT $limit ", related_attachments=RELATED_ATTACHMENTS, related_mentions=RELATED_MENTIONS, related_tags=RELATED_TAGS, related_links=RELATED_LINKS, related_emojis=RELATED_EMOJIS, relationship_follow=i16::from(&RelationshipType::Follow), relationship_subscription=i16::from(&RelationshipType::Subscription), relationship_hide_reposts=i16::from(&RelationshipType::HideReposts), relationship_hide_replies=i16::from(&RelationshipType::HideReplies), visibility_filter=build_visibility_filter(), ); let limit: i64 = limit.into(); let query = query!( &statement, current_user_id=current_user_id, max_post_id=max_post_id, limit=limit, )?; let rows = db_client.query(query.sql(), query.parameters()).await?; let posts: Vec = rows.iter() .map(Post::try_from) .collect::>()?; Ok(posts) } pub async fn get_local_timeline( db_client: &impl DatabaseClient, current_user_id: &Uuid, max_post_id: Option, limit: u16, ) -> Result, DatabaseError> { let statement = format!( " SELECT post, actor_profile, {related_attachments}, {related_mentions}, {related_tags}, {related_links}, {related_emojis} FROM post JOIN actor_profile ON post.author_id = actor_profile.id WHERE actor_profile.actor_json IS NULL AND post.visibility = {visibility_public} AND ($max_post_id::uuid IS NULL OR post.id < $max_post_id) ORDER BY post.id DESC LIMIT $limit ", related_attachments=RELATED_ATTACHMENTS, related_mentions=RELATED_MENTIONS, related_tags=RELATED_TAGS, related_links=RELATED_LINKS, related_emojis=RELATED_EMOJIS, visibility_public=i16::from(&Visibility::Public), ); let limit: i64 = limit.into(); let query = query!( &statement, current_user_id=current_user_id, max_post_id=max_post_id, limit=limit, )?; let rows = db_client.query(query.sql(), query.parameters()).await?; let posts: Vec = rows.iter() .map(Post::try_from) .collect::>()?; Ok(posts) } pub async fn get_related_posts( db_client: &impl DatabaseClient, posts_ids: Vec, ) -> Result, DatabaseError> { let statement = format!( " SELECT post, actor_profile, {related_attachments}, {related_mentions}, {related_tags}, {related_links}, {related_emojis} FROM post JOIN actor_profile ON post.author_id = actor_profile.id WHERE post.id IN ( SELECT post.in_reply_to_id FROM post WHERE post.id = ANY($1) UNION ALL SELECT post.repost_of_id FROM post WHERE post.id = ANY($1) UNION ALL SELECT post_link.target_id FROM post_link WHERE post_link.source_id = ANY($1) UNION ALL SELECT post_link.target_id FROM post_link JOIN post ON (post.repost_of_id = post_link.source_id) WHERE post.id = ANY($1) ) ", related_attachments=RELATED_ATTACHMENTS, related_mentions=RELATED_MENTIONS, related_tags=RELATED_TAGS, related_links=RELATED_LINKS, related_emojis=RELATED_EMOJIS, ); let rows = db_client.query( &statement, &[&posts_ids], ).await?; let posts: Vec = rows.iter() .map(Post::try_from) .collect::>()?; Ok(posts) } pub async fn get_posts_by_author( db_client: &impl DatabaseClient, profile_id: &Uuid, current_user_id: Option<&Uuid>, include_replies: bool, include_reposts: bool, max_post_id: Option, limit: u16, ) -> Result, DatabaseError> { let mut condition = format!( "post.author_id = $profile_id AND {visibility_filter} AND ($max_post_id::uuid IS NULL OR post.id < $max_post_id)", visibility_filter=build_visibility_filter(), ); if !include_replies { condition.push_str(" AND post.in_reply_to_id IS NULL"); }; if !include_reposts { condition.push_str(" AND post.repost_of_id IS NULL"); }; let statement = format!( " SELECT post, actor_profile, {related_attachments}, {related_mentions}, {related_tags}, {related_links}, {related_emojis} FROM post JOIN actor_profile ON post.author_id = actor_profile.id WHERE {condition} ORDER BY post.created_at DESC LIMIT $limit ", related_attachments=RELATED_ATTACHMENTS, related_mentions=RELATED_MENTIONS, related_tags=RELATED_TAGS, related_links=RELATED_LINKS, related_emojis=RELATED_EMOJIS, condition=condition, ); let limit: i64 = limit.into(); let query = query!( &statement, profile_id=profile_id, current_user_id=current_user_id, max_post_id=max_post_id, limit=limit, )?; let rows = db_client.query(query.sql(), query.parameters()).await?; let posts: Vec = rows.iter() .map(Post::try_from) .collect::>()?; Ok(posts) } pub async fn get_posts_by_tag( db_client: &impl DatabaseClient, tag_name: &str, current_user_id: Option<&Uuid>, max_post_id: Option, limit: u16, ) -> Result, DatabaseError> { let tag_name = tag_name.to_lowercase(); let statement = format!( " SELECT post, actor_profile, {related_attachments}, {related_mentions}, {related_tags}, {related_links}, {related_emojis} FROM post JOIN actor_profile ON post.author_id = actor_profile.id WHERE EXISTS ( SELECT 1 FROM post_tag JOIN tag ON post_tag.tag_id = tag.id WHERE post_tag.post_id = post.id AND tag.tag_name = $tag_name ) AND {visibility_filter} AND ($max_post_id::uuid IS NULL OR post.id < $max_post_id) ORDER BY post.id DESC LIMIT $limit ", related_attachments=RELATED_ATTACHMENTS, related_mentions=RELATED_MENTIONS, related_tags=RELATED_TAGS, related_links=RELATED_LINKS, related_emojis=RELATED_EMOJIS, visibility_filter=build_visibility_filter(), ); let limit: i64 = limit.into(); let query = query!( &statement, tag_name=tag_name, current_user_id=current_user_id, max_post_id=max_post_id, limit=limit, )?; let rows = db_client.query(query.sql(), query.parameters()).await?; let posts: Vec = rows.iter() .map(Post::try_from) .collect::>()?; Ok(posts) } pub async fn get_post_by_id( db_client: &impl DatabaseClient, post_id: &Uuid, ) -> Result { let statement = format!( " SELECT post, actor_profile, {related_attachments}, {related_mentions}, {related_tags}, {related_links}, {related_emojis} FROM post JOIN actor_profile ON post.author_id = actor_profile.id WHERE post.id = $1 ", related_attachments=RELATED_ATTACHMENTS, related_mentions=RELATED_MENTIONS, related_tags=RELATED_TAGS, related_links=RELATED_LINKS, related_emojis=RELATED_EMOJIS, ); let maybe_row = db_client.query_opt( &statement, &[&post_id], ).await?; let post = match maybe_row { Some(row) => Post::try_from(&row)?, None => return Err(DatabaseError::NotFound("post")), }; Ok(post) } /// Given a post ID, finds all items in thread. /// Results are sorted by tree path. pub async fn get_thread( db_client: &impl DatabaseClient, post_id: &Uuid, current_user_id: Option<&Uuid>, ) -> Result, DatabaseError> { // TODO: limit recursion depth let statement = format!( " WITH RECURSIVE ancestors (id, in_reply_to_id) AS ( SELECT post.id, post.in_reply_to_id FROM post WHERE post.id = $post_id AND post.repost_of_id IS NULL AND {visibility_filter} UNION ALL SELECT post.id, post.in_reply_to_id FROM post JOIN ancestors ON post.id = ancestors.in_reply_to_id ), thread (id, path) AS ( SELECT ancestors.id, ARRAY[ancestors.id] FROM ancestors WHERE ancestors.in_reply_to_id IS NULL UNION SELECT post.id, array_append(thread.path, post.id) FROM post JOIN thread ON post.in_reply_to_id = thread.id ) SELECT post, actor_profile, {related_attachments}, {related_mentions}, {related_tags}, {related_links}, {related_emojis} FROM post JOIN thread ON post.id = thread.id JOIN actor_profile ON post.author_id = actor_profile.id WHERE {visibility_filter} ORDER BY thread.path ", related_attachments=RELATED_ATTACHMENTS, related_mentions=RELATED_MENTIONS, related_tags=RELATED_TAGS, related_links=RELATED_LINKS, related_emojis=RELATED_EMOJIS, visibility_filter=build_visibility_filter(), ); let query = query!( &statement, post_id=post_id, current_user_id=current_user_id, )?; let rows = db_client.query(query.sql(), query.parameters()).await?; let posts: Vec = rows.iter() .map(Post::try_from) .collect::>()?; if posts.is_empty() { return Err(DatabaseError::NotFound("post")); } Ok(posts) } pub async fn get_post_by_remote_object_id( db_client: &impl DatabaseClient, object_id: &str, ) -> Result { let statement = format!( " SELECT post, actor_profile, {related_attachments}, {related_mentions}, {related_tags}, {related_links}, {related_emojis} FROM post JOIN actor_profile ON post.author_id = actor_profile.id WHERE post.object_id = $1 ", related_attachments=RELATED_ATTACHMENTS, related_mentions=RELATED_MENTIONS, related_tags=RELATED_TAGS, related_links=RELATED_LINKS, related_emojis=RELATED_EMOJIS, ); let maybe_row = db_client.query_opt( &statement, &[&object_id], ).await?; let row = maybe_row.ok_or(DatabaseError::NotFound("post"))?; let post = Post::try_from(&row)?; Ok(post) } pub async fn get_post_by_ipfs_cid( db_client: &impl DatabaseClient, ipfs_cid: &str, ) -> Result { let statement = format!( " SELECT post, actor_profile, {related_attachments}, {related_mentions}, {related_tags}, {related_links}, {related_emojis} FROM post JOIN actor_profile ON post.author_id = actor_profile.id WHERE post.ipfs_cid = $1 ", related_attachments=RELATED_ATTACHMENTS, related_mentions=RELATED_MENTIONS, related_tags=RELATED_TAGS, related_links=RELATED_LINKS, related_emojis=RELATED_EMOJIS, ); let result = db_client.query_opt( &statement, &[&ipfs_cid], ).await?; let post = match result { Some(row) => Post::try_from(&row)?, None => return Err(DatabaseError::NotFound("post")), }; Ok(post) } pub async fn update_reply_count( db_client: &impl DatabaseClient, post_id: &Uuid, change: i32, ) -> Result<(), DatabaseError> { let updated_count = db_client.execute( " UPDATE post SET reply_count = reply_count + $1 WHERE id = $2 AND repost_of_id IS NULL ", &[&change, &post_id], ).await?; if updated_count == 0 { return Err(DatabaseError::NotFound("post")); } Ok(()) } pub async fn update_reaction_count( db_client: &impl DatabaseClient, post_id: &Uuid, change: i32, ) -> Result<(), DatabaseError> { let updated_count = db_client.execute( " UPDATE post SET reaction_count = reaction_count + $1 WHERE id = $2 AND repost_of_id IS NULL ", &[&change, &post_id], ).await?; if updated_count == 0 { return Err(DatabaseError::NotFound("post")); }; Ok(()) } pub async fn update_repost_count( db_client: &impl DatabaseClient, post_id: &Uuid, change: i32, ) -> Result<(), DatabaseError> { let updated_count = db_client.execute( " UPDATE post SET repost_count = repost_count + $1 WHERE id = $2 AND repost_of_id IS NULL ", &[&change, &post_id], ).await?; if updated_count == 0 { return Err(DatabaseError::NotFound("post")); }; Ok(()) } pub async fn set_post_ipfs_cid( db_client: &mut impl DatabaseClient, post_id: &Uuid, ipfs_cid: &str, attachments: Vec<(Uuid, String)>, ) -> Result<(), DatabaseError> { let transaction = db_client.transaction().await?; let updated_count = transaction.execute( " UPDATE post SET ipfs_cid = $1 WHERE id = $2 AND repost_of_id IS NULL AND ipfs_cid IS NULL ", &[&ipfs_cid, &post_id], ).await?; if updated_count == 0 { return Err(DatabaseError::NotFound("post")); }; for (attachment_id, cid) in attachments { set_attachment_ipfs_cid(&transaction, &attachment_id, &cid).await?; }; transaction.commit().await?; Ok(()) } pub async fn set_post_token_id( db_client: &impl DatabaseClient, post_id: &Uuid, token_id: i32, ) -> Result<(), DatabaseError> { let updated_count = db_client.execute( " UPDATE post SET token_id = $1 WHERE id = $2 AND repost_of_id IS NULL AND token_id IS NULL ", &[&token_id, &post_id], ).await?; if updated_count == 0 { return Err(DatabaseError::NotFound("post")); }; Ok(()) } pub async fn set_post_token_tx_id( db_client: &impl DatabaseClient, post_id: &Uuid, token_tx_id: &str, ) -> Result<(), DatabaseError> { let updated_count = db_client.execute( " UPDATE post SET token_tx_id = $1 WHERE id = $2 AND repost_of_id IS NULL ", &[&token_tx_id, &post_id], ).await?; if updated_count == 0 { return Err(DatabaseError::NotFound("post")); }; Ok(()) } pub async fn get_post_author( db_client: &impl DatabaseClient, post_id: &Uuid, ) -> Result { let maybe_row = db_client.query_opt( " SELECT actor_profile FROM post JOIN actor_profile ON post.author_id = actor_profile.id WHERE post.id = $1 ", &[&post_id], ).await?; let row = maybe_row.ok_or(DatabaseError::NotFound("post"))?; let author: DbActorProfile = row.try_get("actor_profile")?; Ok(author) } /// Finds reposts of given posts and returns their IDs pub async fn find_reposts_by_user( db_client: &impl DatabaseClient, user_id: &Uuid, posts_ids: &[Uuid], ) -> Result, DatabaseError> { let rows = db_client.query( " SELECT post.id FROM post WHERE post.author_id = $1 AND post.repost_of_id = ANY($2) ", &[&user_id, &posts_ids], ).await?; let reposts: Vec = rows.iter() .map(|row| row.try_get("id")) .collect::>()?; Ok(reposts) } /// Finds items reposted by user among given posts pub async fn find_reposted_by_user( db_client: &impl DatabaseClient, user_id: &Uuid, posts_ids: &[Uuid], ) -> Result, DatabaseError> { let rows = db_client.query( " SELECT post.id FROM post WHERE post.id = ANY($2) AND EXISTS ( SELECT 1 FROM post AS repost WHERE repost.author_id = $1 AND repost.repost_of_id = post.id ) ", &[&user_id, &posts_ids], ).await?; let reposted: Vec = rows.iter() .map(|row| row.try_get("id")) .collect::>()?; Ok(reposted) } pub async fn get_token_waitlist( db_client: &impl DatabaseClient, ) -> Result, DatabaseError> { let rows = db_client.query( " SELECT post.id FROM post WHERE token_tx_id IS NOT NULL AND token_id IS NULL ", &[], ).await?; let waitlist: Vec = rows.iter() .map(|row| row.try_get("id")) .collect::>()?; Ok(waitlist) } /// Finds all contexts (identified by top-level post) /// updated before the specified date /// that do not contain local posts, reposts, mentions, links or reactions. pub async fn find_extraneous_posts( db_client: &impl DatabaseClient, updated_before: &DateTime, ) -> Result, DatabaseError> { let rows = db_client.query( " WITH RECURSIVE context_post (id, post_id, created_at) AS ( SELECT post.id, post.id, post.created_at FROM post WHERE post.in_reply_to_id IS NULL AND post.repost_of_id IS NULL AND post.created_at < $1 UNION SELECT context_post.id, post.id, post.created_at FROM post JOIN context_post ON ( post.in_reply_to_id = context_post.post_id OR post.repost_of_id = context_post.post_id ) ) SELECT context.id FROM ( SELECT context_post.id, array_agg(context_post.post_id) AS posts, max(context_post.created_at) AS updated_at FROM context_post GROUP BY context_post.id ) AS context WHERE context.updated_at < $1 AND NOT EXISTS ( SELECT 1 FROM post JOIN actor_profile ON post.author_id = actor_profile.id WHERE post.id = ANY(context.posts) AND actor_profile.actor_json IS NULL ) AND NOT EXISTS ( SELECT 1 FROM mention JOIN actor_profile ON mention.profile_id = actor_profile.id WHERE mention.post_id = ANY(context.posts) AND actor_profile.actor_json IS NULL ) AND NOT EXISTS ( SELECT 1 FROM post_reaction JOIN actor_profile ON post_reaction.author_id = actor_profile.id WHERE post_reaction.post_id = ANY(context.posts) AND actor_profile.actor_json IS NULL ) AND NOT EXISTS ( SELECT 1 FROM post_link JOIN post ON post_link.source_id = post.id JOIN actor_profile ON post.author_id = actor_profile.id WHERE post_link.target_id = ANY(context.posts) AND actor_profile.actor_json IS NULL ) ", &[&updated_before], ).await?; let ids: Vec = rows.iter() .map(|row| row.try_get("id")) .collect::>()?; Ok(ids) } /// Deletes post from database and returns collection of orphaned objects. pub async fn delete_post( db_client: &mut impl DatabaseClient, post_id: &Uuid, ) -> Result { let transaction = db_client.transaction().await?; // Select all posts that will be deleted. // This includes given post, its descendants and reposts. let posts_rows = transaction.query( " WITH RECURSIVE context (post_id) AS ( SELECT post.id FROM post WHERE post.id = $1 UNION SELECT post.id FROM post JOIN context ON ( post.in_reply_to_id = context.post_id OR post.repost_of_id = context.post_id ) ) SELECT post_id FROM context ", &[&post_id], ).await?; let posts: Vec = posts_rows.iter() .map(|row| row.try_get("post_id")) .collect::>()?; // Get list of attached files let files_rows = transaction.query( " SELECT file_name FROM media_attachment WHERE post_id = ANY($1) ", &[&posts], ).await?; let files: Vec = files_rows.iter() .map(|row| row.try_get("file_name")) .collect::>()?; // Get list of linked IPFS objects let ipfs_objects_rows = transaction.query( " SELECT ipfs_cid FROM media_attachment WHERE post_id = ANY($1) AND ipfs_cid IS NOT NULL UNION ALL SELECT ipfs_cid FROM post WHERE id = ANY($1) AND ipfs_cid IS NOT NULL ", &[&posts], ).await?; let ipfs_objects: Vec = ipfs_objects_rows.iter() .map(|row| row.try_get("ipfs_cid")) .collect::>()?; // Update post counters transaction.execute( " UPDATE actor_profile SET post_count = post_count - post.count FROM ( SELECT post.author_id, count(*) FROM post WHERE post.id = ANY($1) GROUP BY post.author_id ) AS post WHERE actor_profile.id = post.author_id ", &[&posts], ).await?; // Delete post let maybe_post_row = transaction.query_opt( " DELETE FROM post WHERE id = $1 RETURNING post ", &[&post_id], ).await?; let post_row = maybe_post_row.ok_or(DatabaseError::NotFound("post"))?; let db_post: DbPost = post_row.try_get("post")?; // Update counters if let Some(parent_id) = &db_post.in_reply_to_id { update_reply_count(&transaction, parent_id, -1).await?; } if let Some(repost_of_id) = &db_post.repost_of_id { update_repost_count(&transaction, repost_of_id, -1).await?; }; let orphaned_files = find_orphaned_files(&transaction, files).await?; let orphaned_ipfs_objects = find_orphaned_ipfs_objects(&transaction, ipfs_objects).await?; transaction.commit().await?; Ok(DeletionQueue { files: orphaned_files, ipfs_objects: orphaned_ipfs_objects, }) } pub async fn get_local_post_count( db_client: &impl DatabaseClient, ) -> Result { let row = db_client.query_one( " SELECT count(post) FROM post JOIN user_account ON (post.author_id = user_account.id) WHERE post.in_reply_to_id IS NULL AND post.repost_of_id IS NULL ", &[], ).await?; let count = row.try_get("count")?; Ok(count) } #[cfg(test)] mod tests { use chrono::Duration; use serial_test::serial; use crate::database::test_utils::create_test_database; use crate::profiles::{ queries::create_profile, types::ProfileCreateData, }; use crate::relationships::queries::{ follow, hide_reposts, subscribe, }; use crate::users::{ queries::create_user, types::UserCreateData, }; use super::*; #[tokio::test] #[serial] async fn test_create_post() { let db_client = &mut create_test_database().await; let author_data = ProfileCreateData { username: "test".to_string(), ..Default::default() }; let author = create_profile(db_client, author_data).await.unwrap(); let post_data = PostCreateData { content: "test post".to_string(), ..Default::default() }; let post = create_post(db_client, &author.id, post_data).await.unwrap(); assert_eq!(post.content, "test post"); assert_eq!(post.author.id, author.id); assert_eq!(post.attachments.is_empty(), true); assert_eq!(post.mentions.is_empty(), true); assert_eq!(post.tags.is_empty(), true); assert_eq!(post.links.is_empty(), true); assert_eq!(post.updated_at, None); } #[tokio::test] #[serial] async fn test_create_post_with_link() { let db_client = &mut create_test_database().await; let author_data = ProfileCreateData { username: "test".to_string(), ..Default::default() }; let author = create_profile(db_client, author_data).await.unwrap(); let post_data_1 = PostCreateData::default(); let post_1 = create_post(db_client, &author.id, post_data_1).await.unwrap(); let post_data_2 = PostCreateData { links: vec![post_1.id], ..Default::default() }; let post_2 = create_post(db_client, &author.id, post_data_2).await.unwrap(); assert_eq!(post_2.links, vec![post_1.id]); } #[tokio::test] #[serial] async fn test_update_post() { let db_client = &mut create_test_database().await; let user_data = UserCreateData { username: "test".to_string(), password_hash: Some("test".to_string()), ..Default::default() }; let user = create_user(db_client, user_data).await.unwrap(); let post_data = PostCreateData { content: "test post".to_string(), ..Default::default() }; let post = create_post(db_client, &user.id, post_data).await.unwrap(); let post_data = PostUpdateData { content: "test update".to_string(), updated_at: Utc::now(), ..Default::default() }; update_post(db_client, &post.id, post_data).await.unwrap(); let post = get_post_by_id(db_client, &post.id).await.unwrap(); assert_eq!(post.content, "test update"); assert_eq!(post.updated_at.is_some(), true); } #[tokio::test] #[serial] async fn test_delete_post() { let db_client = &mut create_test_database().await; let user_data = UserCreateData { username: "test".to_string(), password_hash: Some("test".to_string()), ..Default::default() }; let user = create_user(db_client, user_data).await.unwrap(); let post_data = PostCreateData { content: "test post".to_string(), ..Default::default() }; let post = create_post(db_client, &user.id, post_data).await.unwrap(); let deletion_queue = delete_post(db_client, &post.id).await.unwrap(); assert_eq!(deletion_queue.files.len(), 0); assert_eq!(deletion_queue.ipfs_objects.len(), 0); } #[tokio::test] #[serial] async fn test_home_timeline() { let db_client = &mut create_test_database().await; let current_user_data = UserCreateData { username: "test".to_string(), password_hash: Some("test".to_string()), ..Default::default() }; let current_user = create_user(db_client, current_user_data).await.unwrap(); // Current user's post let post_data_1 = PostCreateData { content: "my post".to_string(), ..Default::default() }; let post_1 = create_post(db_client, ¤t_user.id, post_data_1).await.unwrap(); // Current user's direct message let post_data_2 = PostCreateData { content: "my post".to_string(), visibility: Visibility::Direct, ..Default::default() }; let post_2 = create_post(db_client, ¤t_user.id, post_data_2).await.unwrap(); // Another user's public post let user_data_1 = UserCreateData { username: "another-user".to_string(), password_hash: Some("test".to_string()), ..Default::default() }; let user_1 = create_user(db_client, user_data_1).await.unwrap(); let post_data_3 = PostCreateData { content: "test post".to_string(), ..Default::default() }; let post_3 = create_post(db_client, &user_1.id, post_data_3).await.unwrap(); // Direct message from another user to current user let post_data_4 = PostCreateData { content: "test post".to_string(), visibility: Visibility::Direct, mentions: vec![current_user.id], ..Default::default() }; let post_4 = create_post(db_client, &user_1.id, post_data_4).await.unwrap(); // Followers-only post from another user let post_data_5 = PostCreateData { content: "followers only".to_string(), visibility: Visibility::Followers, ..Default::default() }; let post_5 = create_post(db_client, &user_1.id, post_data_5).await.unwrap(); // Followed user's public post let user_data_2 = UserCreateData { username: "followed".to_string(), password_hash: Some("test".to_string()), ..Default::default() }; let user_2 = create_user(db_client, user_data_2).await.unwrap(); follow(db_client, ¤t_user.id, &user_2.id).await.unwrap(); let post_data_6 = PostCreateData { content: "test post".to_string(), ..Default::default() }; let post_6 = create_post(db_client, &user_2.id, post_data_6).await.unwrap(); // Followed user's repost let post_data_7 = PostCreateData { repost_of_id: Some(post_3.id), ..Default::default() }; let post_7 = create_post(db_client, &user_2.id, post_data_7).await.unwrap(); // Direct message from followed user sent to another user let post_data_8 = PostCreateData { content: "test post".to_string(), visibility: Visibility::Direct, mentions: vec![user_1.id], ..Default::default() }; let post_8 = create_post(db_client, &user_2.id, post_data_8).await.unwrap(); // Followers-only post from followed user let post_data_9 = PostCreateData { content: "followers only".to_string(), visibility: Visibility::Followers, ..Default::default() }; let post_9 = create_post(db_client, &user_2.id, post_data_9).await.unwrap(); // Subscribers-only post by followed user let post_data_10 = PostCreateData { content: "subscribers only".to_string(), visibility: Visibility::Subscribers, ..Default::default() }; let post_10 = create_post(db_client, &user_2.id, post_data_10).await.unwrap(); // Subscribers-only post by subscription (without mention) let user_data_3 = UserCreateData { username: "subscription".to_string(), password_hash: Some("test".to_string()), ..Default::default() }; let user_3 = create_user(db_client, user_data_3).await.unwrap(); subscribe(db_client, ¤t_user.id, &user_3.id).await.unwrap(); let post_data_11 = PostCreateData { content: "subscribers only".to_string(), visibility: Visibility::Subscribers, ..Default::default() }; let post_11 = create_post(db_client, &user_3.id, post_data_11).await.unwrap(); // Subscribers-only post by subscription (with mention) let post_data_12 = PostCreateData { content: "subscribers only".to_string(), visibility: Visibility::Subscribers, mentions: vec![current_user.id], ..Default::default() }; let post_12 = create_post(db_client, &user_3.id, post_data_12).await.unwrap(); // Repost from followed user if hiding reposts let user_data_4 = UserCreateData { username: "hide reposts".to_string(), password_hash: Some("test".to_string()), ..Default::default() }; let user_4 = create_user(db_client, user_data_4).await.unwrap(); follow(db_client, ¤t_user.id, &user_4.id).await.unwrap(); hide_reposts(db_client, ¤t_user.id, &user_4.id).await.unwrap(); let post_data_13 = PostCreateData { repost_of_id: Some(post_3.id), ..Default::default() }; let post_13 = create_post(db_client, &user_4.id, post_data_13).await.unwrap(); let timeline = get_home_timeline(db_client, ¤t_user.id, None, 20).await.unwrap(); assert_eq!(timeline.len(), 7); assert_eq!(timeline.iter().any(|post| post.id == post_1.id), true); assert_eq!(timeline.iter().any(|post| post.id == post_2.id), true); assert_eq!(timeline.iter().any(|post| post.id == post_3.id), false); assert_eq!(timeline.iter().any(|post| post.id == post_4.id), true); assert_eq!(timeline.iter().any(|post| post.id == post_5.id), false); assert_eq!(timeline.iter().any(|post| post.id == post_6.id), true); assert_eq!(timeline.iter().any(|post| post.id == post_7.id), true); assert_eq!(timeline.iter().any(|post| post.id == post_8.id), false); assert_eq!(timeline.iter().any(|post| post.id == post_9.id), true); assert_eq!(timeline.iter().any(|post| post.id == post_10.id), false); assert_eq!(timeline.iter().any(|post| post.id == post_11.id), false); assert_eq!(timeline.iter().any(|post| post.id == post_12.id), true); assert_eq!(timeline.iter().any(|post| post.id == post_13.id), false); } #[tokio::test] #[serial] async fn test_profile_timeline_public() { let db_client = &mut create_test_database().await; let user_data = UserCreateData { username: "test".to_string(), password_hash: Some("test".to_string()), ..Default::default() }; let user = create_user(db_client, user_data).await.unwrap(); // Public post let post_data_1 = PostCreateData { content: "my post".to_string(), ..Default::default() }; let post_1 = create_post(db_client, &user.id, post_data_1).await.unwrap(); // Followers only post let post_data_2 = PostCreateData { content: "my post".to_string(), visibility: Visibility::Followers, ..Default::default() }; let post_2 = create_post(db_client, &user.id, post_data_2).await.unwrap(); // Subscribers only post let post_data_3 = PostCreateData { content: "my post".to_string(), visibility: Visibility::Subscribers, ..Default::default() }; let post_3 = create_post(db_client, &user.id, post_data_3).await.unwrap(); // Direct message let post_data_4 = PostCreateData { content: "my post".to_string(), visibility: Visibility::Direct, ..Default::default() }; let post_4 = create_post(db_client, &user.id, post_data_4).await.unwrap(); // Reply let reply_data = PostCreateData { content: "my reply".to_string(), in_reply_to_id: Some(post_1.id.clone()), ..Default::default() }; let reply = create_post(db_client, &user.id, reply_data).await.unwrap(); // Repost let repost_data = PostCreateData { repost_of_id: Some(reply.id.clone()), ..Default::default() }; let repost = create_post(db_client, &user.id, repost_data).await.unwrap(); // Anonymous viewer let timeline = get_posts_by_author( db_client, &user.id, None, false, true, None, 10 ).await.unwrap(); assert_eq!(timeline.len(), 2); assert_eq!(timeline.iter().any(|post| post.id == post_1.id), true); assert_eq!(timeline.iter().any(|post| post.id == post_2.id), false); assert_eq!(timeline.iter().any(|post| post.id == post_3.id), false); assert_eq!(timeline.iter().any(|post| post.id == post_4.id), false); assert_eq!(timeline.iter().any(|post| post.id == reply.id), false); assert_eq!(timeline.iter().any(|post| post.id == repost.id), true); } #[tokio::test] #[serial] async fn test_get_thread() { let db_client = &mut create_test_database().await; let user_data = UserCreateData { username: "test".to_string(), password_hash: Some("test".to_string()), ..Default::default() }; let user = create_user(db_client, user_data).await.unwrap(); let post_data_1 = PostCreateData { content: "my post".to_string(), ..Default::default() }; let post_1 = create_post(db_client, &user.id, post_data_1).await.unwrap(); let post_data_2 = PostCreateData { content: "my reply".to_string(), in_reply_to_id: Some(post_1.id.clone()), ..Default::default() }; let post_2 = create_post(db_client, &user.id, post_data_2).await.unwrap(); let thread = get_thread( db_client, &post_2.id, Some(&user.id), ).await.unwrap(); assert_eq!(thread[0].id, post_1.id); assert_eq!(thread[1].id, post_2.id); } #[tokio::test] #[serial] async fn test_find_extraneous_posts() { let db_client = &mut create_test_database().await; let author_data = ProfileCreateData { username: "test".to_string(), ..Default::default() }; let author = create_profile(db_client, author_data).await.unwrap(); let post_data = PostCreateData { content: "test post".to_string(), ..Default::default() }; create_post(db_client, &author.id, post_data).await.unwrap(); let updated_before = Utc::now() - Duration::days(1); let result = find_extraneous_posts( db_client, &updated_before, ).await.unwrap(); assert!(result.is_empty()); } }