diff --git a/src/activitypub/views.rs b/src/activitypub/views.rs index 5f39547..89fb283 100644 --- a/src/activitypub/views.rs +++ b/src/activitypub/views.rs @@ -13,7 +13,8 @@ use crate::config::Config; use crate::database::{Pool, get_database_client}; use crate::errors::HttpError; use crate::frontend::{get_post_page_url, get_profile_page_url}; -use crate::models::posts::queries::{get_posts_by_author, get_thread}; +use crate::models::posts::helpers::{add_related_posts, can_view_post}; +use crate::models::posts::queries::{get_post_by_id, get_posts_by_author}; use crate::models::users::queries::get_user_by_name; use super::actors::types::{get_local_actor, get_instance_actor}; use super::builders::create_note::{build_note, build_create_note}; @@ -129,7 +130,7 @@ async fn outbox( let db_client = &**get_database_client(&db_pool).await?; let user = get_user_by_name(db_client, &username).await?; // Posts are ordered by creation date - let posts = get_posts_by_author( + let mut posts = get_posts_by_author( db_client, &user.id, None, // include only public posts @@ -138,12 +139,11 @@ async fn outbox( None, COLLECTION_PAGE_SIZE, ).await?; + add_related_posts(db_client, posts.iter_mut().collect()).await?; let activities: Vec<_> = posts.iter().filter_map(|post| { if post.in_reply_to_id.is_some() || post.repost_of_id.is_some() { return None; }; - // Replies are not included so post.in_reply_to - // does not need to be populated let activity = build_create_note( &instance.host(), &instance.url(), @@ -295,11 +295,10 @@ pub async fn object_view( let internal_object_id = internal_object_id.into_inner(); // Try to find local post by ID, // return 404 if not found, or not public, or it is a repost - let thread = get_thread(db_client, &internal_object_id, None).await?; - let mut post = thread.iter() - .find(|post| post.id == internal_object_id && post.author.is_local()) - .ok_or(HttpError::NotFoundError("post"))? - .clone(); + let mut post = get_post_by_id(db_client, &internal_object_id).await?; + if !post.author.is_local() || !can_view_post(db_client, None, &post).await? { + return Err(HttpError::NotFoundError("post")); + }; if !is_activitypub_request(request.headers()) { let page_url = get_post_page_url(&config.instance_url(), &post.id); let response = HttpResponse::Found() @@ -307,17 +306,7 @@ pub async fn object_view( .finish(); return Ok(response); }; - post.in_reply_to = match post.in_reply_to_id { - Some(in_reply_to_id) => { - let in_reply_to = thread.iter() - .find(|post| post.id == in_reply_to_id) - // Parent post must be present in thread - .ok_or(HttpError::InternalError)? - .clone(); - Some(Box::new(in_reply_to)) - }, - None => None, - }; + add_related_posts(db_client, vec![&mut post]).await?; let object = build_note( &config.instance().host(), &config.instance().url(), diff --git a/src/models/posts/helpers.rs b/src/models/posts/helpers.rs index 0d04630..1aac3b2 100644 --- a/src/models/posts/helpers.rs +++ b/src/models/posts/helpers.rs @@ -23,6 +23,10 @@ pub async fn add_related_posts( Ok(post) }; for post in posts { + if let Some(ref in_reply_to_id) = post.in_reply_to_id { + let in_reply_to = get_post(in_reply_to_id)?; + post.in_reply_to = Some(Box::new(in_reply_to)); + }; if let Some(ref repost_of_id) = post.repost_of_id { let mut repost_of = get_post(repost_of_id)?; for linked_id in repost_of.links.iter() { @@ -119,11 +123,39 @@ mod tests { use serial_test::serial; use tokio_postgres::Client; use crate::database::test_utils::create_test_database; + use crate::models::posts::queries::create_post; + use crate::models::posts::types::PostCreateData; use crate::models::relationships::queries::{follow, subscribe}; use crate::models::users::queries::create_user; use crate::models::users::types::UserCreateData; use super::*; + #[tokio::test] + #[serial] + async fn test_add_related_posts() { + let db_client = &mut create_test_database().await; + let author_data = UserCreateData { + username: "test".to_string(), + ..Default::default() + }; + let author = create_user(db_client, author_data).await.unwrap(); + let post_data = PostCreateData { + content: "post".to_string(), + ..Default::default() + }; + let post = create_post(db_client, &author.id, post_data).await.unwrap(); + let reply_data = PostCreateData { + content: "reply".to_string(), + in_reply_to_id: Some(post.id.clone()), + ..Default::default() + }; + let mut reply = create_post(db_client, &author.id, reply_data).await.unwrap(); + add_related_posts(db_client, vec![&mut reply]).await.unwrap(); + assert_eq!(reply.in_reply_to.unwrap().id, post.id); + assert_eq!(reply.repost_of.is_none(), true); + assert_eq!(reply.linked.is_empty(), true); + } + #[tokio::test] #[serial] async fn test_can_view_post_anonymous() { diff --git a/src/models/posts/queries.rs b/src/models/posts/queries.rs index 8a7fe10..59e676f 100644 --- a/src/models/posts/queries.rs +++ b/src/models/posts/queries.rs @@ -461,6 +461,9 @@ pub async fn get_related_posts( FROM post JOIN actor_profile ON post.author_id = actor_profile.id WHERE post.id IN ( + SELECT post.in_reply_to_id + FROM post WHERE post.id = ANY($1) + UNION ALL SELECT post.repost_of_id FROM post WHERE post.id = ANY($1) UNION ALL