diff --git a/src/mastodon_api/statuses/views.rs b/src/mastodon_api/statuses/views.rs index 4951a11..f10804d 100644 --- a/src/mastodon_api/statuses/views.rs +++ b/src/mastodon_api/statuses/views.rs @@ -115,7 +115,7 @@ async fn get_status( None => None, }; let mut post = get_post_by_id(db_client, &status_id).await?; - if !can_view_post(maybe_current_user.as_ref(), &post) { + if !can_view_post(db_client, maybe_current_user.as_ref(), &post).await? { return Err(HttpError::NotFoundError("post")); }; get_reposted_posts(db_client, vec![&mut post]).await?; @@ -197,7 +197,7 @@ async fn favourite( let db_client = &mut **get_database_client(&db_pool).await?; let current_user = get_current_user(db_client, auth.token()).await?; let mut post = get_post_by_id(db_client, &status_id).await?; - if !can_view_post(Some(¤t_user), &post) { + if !can_view_post(db_client, Some(¤t_user), &post).await? { return Err(HttpError::NotFoundError("post")); }; let maybe_reaction_created = match create_reaction( @@ -242,9 +242,6 @@ async fn unfavourite( let db_client = &mut **get_database_client(&db_pool).await?; let current_user = get_current_user(db_client, auth.token()).await?; let mut post = get_post_by_id(db_client, &status_id).await?; - if !can_view_post(Some(¤t_user), &post) { - return Err(HttpError::NotFoundError("post")); - }; let maybe_reaction_deleted = match delete_reaction( db_client, ¤t_user.id, &status_id, ).await { diff --git a/src/models/posts/helpers.rs b/src/models/posts/helpers.rs index a24bcf3..f8c0d31 100644 --- a/src/models/posts/helpers.rs +++ b/src/models/posts/helpers.rs @@ -59,8 +59,12 @@ pub async fn get_actions_for_posts( Ok(()) } -pub fn can_view_post(user: Option<&User>, post: &Post) -> bool { - match post.visibility { +pub async fn can_view_post( + _db_client: &impl GenericClient, + user: Option<&User>, + post: &Post, +) -> Result { + let result = match post.visibility { Visibility::Public => true, Visibility::Direct => { if let Some(user) = user { @@ -71,40 +75,52 @@ pub fn can_view_post(user: Option<&User>, post: &Post) -> bool { false } }, - } + }; + Ok(result) } #[cfg(test)] mod tests { + use serial_test::serial; + use crate::database::test_utils::create_test_database; use super::*; - #[test] - fn test_can_view_post_anonymous() { + #[tokio::test] + #[serial] + async fn test_can_view_post_anonymous() { let post = Post { visibility: Visibility::Public, ..Default::default() }; - assert!(can_view_post(None, &post)); + let db_client = &create_test_database().await; + let result = can_view_post(db_client, None, &post).await.unwrap(); + assert_eq!(result, true); } - #[test] - fn test_can_view_post_direct() { + #[tokio::test] + #[serial] + async fn test_can_view_post_direct() { let user = User::default(); let post = Post { visibility: Visibility::Direct, ..Default::default() }; - assert!(!can_view_post(Some(&user), &post)); + let db_client = &create_test_database().await; + let result = can_view_post(db_client, Some(&user), &post).await.unwrap(); + assert_eq!(result, false); } - #[test] - fn test_can_view_post_direct_mentioned() { + #[tokio::test] + #[serial] + async fn test_can_view_post_direct_mentioned() { let user = User::default(); let post = Post { visibility: Visibility::Direct, mentions: vec![user.profile.clone()], ..Default::default() }; - assert!(can_view_post(Some(&user), &post)); + let db_client = &create_test_database().await; + let result = can_view_post(db_client, Some(&user), &post).await.unwrap(); + assert_eq!(result, true); } }