diff --git a/src/activitypub/fetcher/helpers.rs b/src/activitypub/fetcher/helpers.rs index c441afd..3a8e714 100644 --- a/src/activitypub/fetcher/helpers.rs +++ b/src/activitypub/fetcher/helpers.rs @@ -59,7 +59,7 @@ pub async fn get_or_import_profile_by_actor_id( err })?; log::info!("fetched profile {}", profile_data.acct); - let profile = create_profile(db_client, &profile_data).await?; + let profile = create_profile(db_client, profile_data).await?; profile }, Err(other_error) => return Err(other_error.into()), @@ -90,6 +90,6 @@ pub async fn import_profile_by_actor_address( }; log::info!("fetched profile {}", profile_data.acct); profile_data.clean()?; - let profile = create_profile(db_client, &profile_data).await?; + let profile = create_profile(db_client, profile_data).await?; Ok(profile) } diff --git a/src/models/posts/queries.rs b/src/models/posts/queries.rs index 40a2f64..18c5523 100644 --- a/src/models/posts/queries.rs +++ b/src/models/posts/queries.rs @@ -205,13 +205,29 @@ pub const RELATED_TAGS: &str = WHERE post_tag.post_id = post.id ) AS tags"; +fn build_visibility_filter() -> String { + format!( + "( + post.visibility = {visibility_public} + OR $current_user_id::uuid IS NULL + OR post.author_id = $current_user_id + OR EXISTS ( + SELECT 1 FROM mention + WHERE post_id = post.id AND profile_id = $current_user_id + ) + )", + visibility_public=i16::from(&Visibility::Public), + ) +} + pub async fn get_home_timeline( db_client: &impl GenericClient, current_user_id: &Uuid, max_post_id: Option, limit: i64, ) -> Result, DatabaseError> { - // Select posts from follows + own posts. + // Select posts from follows, posts where current user is mentioned + // and user's own posts. // Exclude direct messages where current user is not mentioned. let statement = format!( " @@ -229,15 +245,12 @@ pub async fn get_home_timeline( SELECT 1 FROM relationship WHERE source_id = $current_user_id AND target_id = post.author_id ) - ) - AND ( - post.visibility = {visibility_public} - OR post.author_id = $current_user_id OR EXISTS ( SELECT 1 FROM mention WHERE post_id = post.id AND profile_id = $current_user_id ) ) + AND {visibility_filter} AND ($max_post_id::uuid IS NULL OR post.id < $max_post_id) ORDER BY post.id DESC LIMIT $limit @@ -245,7 +258,7 @@ pub async fn get_home_timeline( related_attachments=RELATED_ATTACHMENTS, related_mentions=RELATED_MENTIONS, related_tags=RELATED_TAGS, - visibility_public=i16::from(&Visibility::Public), + visibility_filter=build_visibility_filter(), ); let query = query!( &statement, @@ -297,24 +310,15 @@ pub async fn get_posts_by_author( max_post_id: Option, limit: i64, ) -> Result, DatabaseError> { - let mut condition = "post.author_id = $profile_id - AND ($max_post_id::uuid IS NULL OR post.id < $max_post_id)".to_string(); + let mut condition = format!( + "post.author_id = $profile_id + AND {visibility_filter} + AND ($max_post_id::uuid IS NULL OR post.id < $max_post_id)", + visibility_filter=build_visibility_filter(), + ); if !include_replies { condition.push_str(" AND post.in_reply_to_id IS NULL"); }; - let visibility_filter = format!( - " AND ( - post.visibility = {visibility_public} - OR $current_user_id::uuid IS NULL - OR post.author_id = $current_user_id - OR EXISTS ( - SELECT 1 FROM mention - WHERE post_id = post.id AND profile_id = $current_user_id - ) - )", - visibility_public=i16::from(&Visibility::Public), - ); - condition.push_str(&visibility_filter); let statement = format!( " SELECT @@ -365,19 +369,11 @@ pub async fn get_posts_by_tag( FROM post JOIN actor_profile ON post.author_id = actor_profile.id WHERE - ( - post.visibility = {visibility_public} - OR $current_user_id::uuid IS NULL - OR post.author_id = $current_user_id - OR EXISTS ( - SELECT 1 FROM mention - WHERE post_id = post.id AND profile_id = $current_user_id - ) - ) - AND EXISTS ( + EXISTS ( SELECT 1 FROM post_tag JOIN tag ON post_tag.tag_id = tag.id WHERE post_tag.post_id = post.id AND tag.tag_name = $tag_name ) + AND {visibility_filter} AND ($max_post_id::uuid IS NULL OR post.id < $max_post_id) ORDER BY post.id DESC LIMIT $limit @@ -385,7 +381,7 @@ pub async fn get_posts_by_tag( related_attachments=RELATED_ATTACHMENTS, related_mentions=RELATED_MENTIONS, related_tags=RELATED_TAGS, - visibility_public=i16::from(&Visibility::Public), + visibility_filter=build_visibility_filter(), ); let query = query!( &statement, @@ -438,25 +434,13 @@ pub async fn get_thread( post_id: &Uuid, current_user_id: Option<&Uuid>, ) -> Result, DatabaseError> { - let condition = format!( - " - post.visibility = {visibility_public} - OR $current_user_id::uuid IS NULL - OR post.author_id = $current_user_id - OR EXISTS ( - SELECT 1 FROM mention - WHERE post_id = post.id AND profile_id = $current_user_id - ) - ", - visibility_public=i16::from(&Visibility::Public), - ); // TODO: limit recursion depth let statement = format!( " WITH RECURSIVE ancestors (id, in_reply_to_id) AS ( SELECT post.id, post.in_reply_to_id FROM post - WHERE post.id = $post_id AND ({condition}) + WHERE post.id = $post_id AND {visibility_filter} UNION ALL SELECT post.id, post.in_reply_to_id FROM post JOIN ancestors ON post.id = ancestors.in_reply_to_id @@ -476,13 +460,13 @@ pub async fn get_thread( FROM post JOIN context ON post.id = context.id JOIN actor_profile ON post.author_id = actor_profile.id - WHERE {condition} + WHERE {visibility_filter} ORDER BY context.path ", related_attachments=RELATED_ATTACHMENTS, related_mentions=RELATED_MENTIONS, related_tags=RELATED_TAGS, - condition=condition, + visibility_filter=build_visibility_filter(), ); let query = query!( &statement, @@ -779,3 +763,106 @@ pub async fn delete_post( ipfs_objects: orphaned_ipfs_objects, }) } + + +#[cfg(test)] +mod tests { + use serial_test::serial; + use crate::database::test_utils::create_test_database; + use crate::models::profiles::queries::create_profile; + use crate::models::profiles::types::ProfileCreateData; + use crate::models::relationships::queries::follow; + use crate::models::users::queries::create_user; + use crate::models::users::types::UserCreateData; + use super::*; + + #[tokio::test] + #[serial] + async fn test_create_post() { + let db_client = &mut create_test_database().await; + let profile_data = ProfileCreateData { + username: "test".to_string(), + ..Default::default() + }; + let profile = create_profile(db_client, profile_data).await.unwrap(); + let post_data = PostCreateData { + content: "test post".to_string(), + ..Default::default() + }; + let post = create_post(db_client, &profile.id, post_data).await.unwrap(); + assert_eq!(post.content, "test post"); + assert_eq!(post.author.id, profile.id); + } + + #[tokio::test] + #[serial] + async fn test_home_timeline() { + let db_client = &mut create_test_database().await; + let current_user_data = UserCreateData { + username: "test".to_string(), + ..Default::default() + }; + let current_user = create_user(db_client, current_user_data).await.unwrap(); + // Current user's post + let post_data_1 = PostCreateData { + content: "my post".to_string(), + ..Default::default() + }; + let post_1 = create_post(db_client, ¤t_user.id, post_data_1).await.unwrap(); + // Current user's direct message + let post_data_2 = PostCreateData { + content: "my post".to_string(), + visibility: Visibility::Direct, + ..Default::default() + }; + let post_2 = create_post(db_client, ¤t_user.id, post_data_2).await.unwrap(); + // Another user + let user_data_1 = UserCreateData { + username: "another-user".to_string(), + ..Default::default() + }; + let user_1 = create_user(db_client, user_data_1).await.unwrap(); + let post_data_3 = PostCreateData { + content: "test post".to_string(), + ..Default::default() + }; + let post_3 = create_post(db_client, &user_1.id, post_data_3).await.unwrap(); + // Direct message from another user to current user + let post_data_4 = PostCreateData { + content: "test post".to_string(), + visibility: Visibility::Direct, + mentions: vec![current_user.id], + ..Default::default() + }; + let post_4 = create_post(db_client, &user_1.id, post_data_4).await.unwrap(); + // Followed + let user_data_2 = UserCreateData { + username: "followed".to_string(), + ..Default::default() + }; + let user_2 = create_user(db_client, user_data_2).await.unwrap(); + follow(db_client, ¤t_user.id, &user_2.id).await.unwrap(); + let post_data_5 = PostCreateData { + content: "test post".to_string(), + ..Default::default() + }; + let post_5 = create_post(db_client, &user_2.id, post_data_5).await.unwrap(); + // Direct message from followed user sent to another user + let post_data_6 = PostCreateData { + content: "test post".to_string(), + visibility: Visibility::Direct, + mentions: vec![user_1.id], + ..Default::default() + }; + let post_6 = create_post(db_client, &user_2.id, post_data_6).await.unwrap(); + + let timeline = get_home_timeline(db_client, ¤t_user.id, None, 10).await.unwrap(); + assert_eq!(timeline.len(), 4); + assert_eq!(timeline.iter().any(|post| post.id == post_1.id), true); + assert_eq!(timeline.iter().any(|post| post.id == post_2.id), true); + assert_eq!(timeline.iter().any(|post| post.id == post_3.id), false); + assert_eq!(timeline.iter().any(|post| post.id == post_4.id), true); + assert_eq!(timeline.iter().any(|post| post.id == post_5.id), true); + assert_eq!(timeline.iter().any(|post| post.id == post_6.id), false); + } +} diff --git a/src/models/profiles/queries.rs b/src/models/profiles/queries.rs index e39554f..36586c6 100644 --- a/src/models/profiles/queries.rs +++ b/src/models/profiles/queries.rs @@ -19,7 +19,7 @@ use super::types::{ /// Create new profile using given Client or Transaction. pub async fn create_profile( db_client: &impl GenericClient, - profile_data: &ProfileCreateData, + profile_data: ProfileCreateData, ) -> Result { let profile_id = new_uuid(); let extra_fields = ExtraFields(profile_data.extra_fields.clone()); @@ -398,8 +398,8 @@ mod tests { ..Default::default() }; let db_client = create_test_database().await; - let profile = create_profile(&db_client, &profile_data).await.unwrap(); - assert_eq!(profile.username, profile_data.username); + let profile = create_profile(&db_client, profile_data).await.unwrap(); + assert_eq!(profile.username, "test"); } #[tokio::test] @@ -407,7 +407,7 @@ mod tests { async fn test_delete_profile() { let profile_data = ProfileCreateData::default(); let mut db_client = create_test_database().await; - let profile = create_profile(&db_client, &profile_data).await.unwrap(); + let profile = create_profile(&db_client, profile_data).await.unwrap(); let deletion_queue = delete_profile(&mut db_client, &profile.id).await.unwrap(); assert_eq!(deletion_queue.files.len(), 0); assert_eq!(deletion_queue.ipfs_objects.len(), 0); diff --git a/src/models/users/queries.rs b/src/models/users/queries.rs index c6b2577..fef20bd 100644 --- a/src/models/users/queries.rs +++ b/src/models/users/queries.rs @@ -83,7 +83,7 @@ pub async fn create_user( extra_fields: vec![], actor_json: None, }; - let profile = create_profile(&transaction, &profile_data).await?; + let profile = create_profile(&transaction, profile_data).await?; // Create user let row = transaction.query_one( " diff --git a/src/models/users/types.rs b/src/models/users/types.rs index 6eec234..6792c74 100644 --- a/src/models/users/types.rs +++ b/src/models/users/types.rs @@ -45,6 +45,7 @@ impl User { } } +#[cfg_attr(test, derive(Default))] pub struct UserCreateData { pub username: String, pub password_hash: String,