diff --git a/CHANGELOG.md b/CHANGELOG.md index 5b2b75e..fb7be85 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,10 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/). ## [Unreleased] +### Added + +- Replace post attachments and other related objects when processing `Update(Note)` activity. + ### Changed - Use proof suites with prefix `Mitra`. diff --git a/src/activitypub/handlers/create.rs b/src/activitypub/handlers/create.rs index 5a8fc2e..c4bd330 100644 --- a/src/activitypub/handlers/create.rs +++ b/src/activitypub/handlers/create.rs @@ -131,7 +131,7 @@ fn is_gnu_social_link(author_id: &str, attachment: &Attachment) -> bool { } } -async fn get_object_attachments( +pub async fn get_object_attachments( config: &Config, db_client: &impl DatabaseClient, object: &Object, @@ -197,7 +197,7 @@ async fn get_object_attachments( Ok(attachments) } -async fn get_object_tags( +pub async fn get_object_tags( config: &Config, db_client: &impl DatabaseClient, object: &Object, diff --git a/src/activitypub/handlers/update.rs b/src/activitypub/handlers/update.rs index 4c73dc2..cb868f1 100644 --- a/src/activitypub/handlers/update.rs +++ b/src/activitypub/handlers/update.rs @@ -1,3 +1,5 @@ +use std::collections::HashMap; + use chrono::Utc; use serde::Deserialize; use serde_json::Value; @@ -7,7 +9,11 @@ use crate::activitypub::{ helpers::update_remote_profile, types::Actor, }, - handlers::create::get_object_content, + handlers::create::{ + get_object_attachments, + get_object_content, + get_object_tags, + }, types::Object, vocabulary::{NOTE, PERSON}, }; @@ -25,24 +31,48 @@ use crate::models::{ use super::HandlerResult; async fn handle_update_note( + config: &Config, db_client: &mut impl DatabaseClient, activity: Value, ) -> HandlerResult { let object: Object = serde_json::from_value(activity["object"].to_owned()) .map_err(|_| ValidationError("invalid object"))?; - let post_id = match get_post_by_remote_object_id( + let post = match get_post_by_remote_object_id( db_client, &object.id, ).await { - Ok(post) => post.id, + Ok(post) => post, // Ignore Update if post is not found locally Err(DatabaseError::NotFound(_)) => return Ok(None), Err(other_error) => return Err(other_error.into()), }; let content = get_object_content(&object)?; let updated_at = object.updated.unwrap_or(Utc::now()); - let post_data = PostUpdateData { content, updated_at }; - update_post(db_client, &post_id, post_data).await?; + let attachments = get_object_attachments( + config, + db_client, + &object, + &post.author, + ).await?; + if content.is_empty() && attachments.is_empty() { + return Err(ValidationError("post is empty").into()); + }; + let (mentions, hashtags, links, emojis) = get_object_tags( + config, + db_client, + &object, + &HashMap::new(), + ).await?; + let post_data = PostUpdateData { + content, + attachments, + mentions, + tags: hashtags, + links, + emojis, + updated_at, + }; + update_post(db_client, &post.id, post_data).await?; Ok(Some(NOTE)) } @@ -85,7 +115,7 @@ pub async fn handle_update( .ok_or(ValidationError("unknown object type"))?; match object_type { NOTE => { - handle_update_note(db_client, activity).await + handle_update_note(config, db_client, activity).await }, PERSON => { handle_update_person(config, db_client, activity).await diff --git a/src/models/posts/queries.rs b/src/models/posts/queries.rs index c6fa8d0..6109519 100644 --- a/src/models/posts/queries.rs +++ b/src/models/posts/queries.rs @@ -306,12 +306,13 @@ pub async fn create_post( } pub async fn update_post( - db_client: &impl DatabaseClient, + db_client: &mut impl DatabaseClient, post_id: &Uuid, post_data: PostUpdateData, ) -> Result<(), DatabaseError> { + let transaction = db_client.transaction().await?; // Reposts and immutable posts can't be updated - let updated_count = db_client.execute( + let maybe_row = transaction.query_opt( " UPDATE post SET @@ -320,6 +321,7 @@ pub async fn update_post( WHERE id = $3 AND repost_of_id IS NULL AND ipfs_cid IS NULL + RETURNING post ", &[ &post_data.content, @@ -327,9 +329,58 @@ pub async fn update_post( &post_id, ], ).await?; - if updated_count == 0 { - return Err(DatabaseError::NotFound("post")); - }; + let row = maybe_row.ok_or(DatabaseError::NotFound("post"))?; + let db_post: DbPost = row.try_get("post")?; + + // Delete and re-create related objects + transaction.execute( + "DELETE FROM media_attachment WHERE post_id = $1", + &[&db_post.id], + ).await?; + transaction.execute( + "DELETE FROM mention WHERE post_id = $1", + &[&db_post.id], + ).await?; + transaction.execute( + "DELETE FROM post_tag WHERE post_id = $1", + &[&db_post.id], + ).await?; + transaction.execute( + "DELETE FROM post_link WHERE source_id = $1", + &[&db_post.id], + ).await?; + transaction.execute( + "DELETE FROM post_emoji WHERE post_id = $1", + &[&db_post.id], + ).await?; + create_post_attachments( + &transaction, + &db_post.id, + &db_post.author_id, + post_data.attachments, + ).await?; + create_post_mentions( + &transaction, + &db_post.id, + post_data.mentions, + ).await?; + create_post_tags( + &transaction, + &db_post.id, + post_data.tags, + ).await?; + create_post_links( + &transaction, + &db_post.id, + post_data.links, + ).await?; + create_post_emojis( + &transaction, + &db_post.id, + post_data.emojis, + ).await?; + + transaction.commit().await?; Ok(()) } @@ -1323,6 +1374,7 @@ mod tests { let post_data = PostUpdateData { content: "test update".to_string(), updated_at: Utc::now(), + ..Default::default() }; update_post(db_client, &post.id, post_data).await.unwrap(); let post = get_post_by_id(db_client, &post.id).await.unwrap(); diff --git a/src/models/posts/types.rs b/src/models/posts/types.rs index bd36281..bcb8a40 100644 --- a/src/models/posts/types.rs +++ b/src/models/posts/types.rs @@ -276,7 +276,13 @@ impl PostCreateData { } } +#[cfg_attr(test, derive(Default))] pub struct PostUpdateData { pub content: String, + pub attachments: Vec, + pub mentions: Vec, + pub tags: Vec, + pub links: Vec, + pub emojis: Vec, pub updated_at: DateTime, }