Refactor can_view_post() into async function

This commit is contained in:
silverpill 2022-01-06 13:45:07 +00:00
parent 011951c129
commit 2d2ec704a2
2 changed files with 30 additions and 17 deletions

View file

@ -115,7 +115,7 @@ async fn get_status(
None => None, None => None,
}; };
let mut post = get_post_by_id(db_client, &status_id).await?; let mut post = get_post_by_id(db_client, &status_id).await?;
if !can_view_post(maybe_current_user.as_ref(), &post) { if !can_view_post(db_client, maybe_current_user.as_ref(), &post).await? {
return Err(HttpError::NotFoundError("post")); return Err(HttpError::NotFoundError("post"));
}; };
get_reposted_posts(db_client, vec![&mut post]).await?; get_reposted_posts(db_client, vec![&mut post]).await?;
@ -197,7 +197,7 @@ async fn favourite(
let db_client = &mut **get_database_client(&db_pool).await?; let db_client = &mut **get_database_client(&db_pool).await?;
let current_user = get_current_user(db_client, auth.token()).await?; let current_user = get_current_user(db_client, auth.token()).await?;
let mut post = get_post_by_id(db_client, &status_id).await?; let mut post = get_post_by_id(db_client, &status_id).await?;
if !can_view_post(Some(&current_user), &post) { if !can_view_post(db_client, Some(&current_user), &post).await? {
return Err(HttpError::NotFoundError("post")); return Err(HttpError::NotFoundError("post"));
}; };
let maybe_reaction_created = match create_reaction( let maybe_reaction_created = match create_reaction(
@ -242,9 +242,6 @@ async fn unfavourite(
let db_client = &mut **get_database_client(&db_pool).await?; let db_client = &mut **get_database_client(&db_pool).await?;
let current_user = get_current_user(db_client, auth.token()).await?; let current_user = get_current_user(db_client, auth.token()).await?;
let mut post = get_post_by_id(db_client, &status_id).await?; let mut post = get_post_by_id(db_client, &status_id).await?;
if !can_view_post(Some(&current_user), &post) {
return Err(HttpError::NotFoundError("post"));
};
let maybe_reaction_deleted = match delete_reaction( let maybe_reaction_deleted = match delete_reaction(
db_client, &current_user.id, &status_id, db_client, &current_user.id, &status_id,
).await { ).await {

View file

@ -59,8 +59,12 @@ pub async fn get_actions_for_posts(
Ok(()) Ok(())
} }
pub fn can_view_post(user: Option<&User>, post: &Post) -> bool { pub async fn can_view_post(
match post.visibility { _db_client: &impl GenericClient,
user: Option<&User>,
post: &Post,
) -> Result<bool, DatabaseError> {
let result = match post.visibility {
Visibility::Public => true, Visibility::Public => true,
Visibility::Direct => { Visibility::Direct => {
if let Some(user) = user { if let Some(user) = user {
@ -71,40 +75,52 @@ pub fn can_view_post(user: Option<&User>, post: &Post) -> bool {
false false
} }
}, },
} };
Ok(result)
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use serial_test::serial;
use crate::database::test_utils::create_test_database;
use super::*; use super::*;
#[test] #[tokio::test]
fn test_can_view_post_anonymous() { #[serial]
async fn test_can_view_post_anonymous() {
let post = Post { let post = Post {
visibility: Visibility::Public, visibility: Visibility::Public,
..Default::default() ..Default::default()
}; };
assert!(can_view_post(None, &post)); let db_client = &create_test_database().await;
let result = can_view_post(db_client, None, &post).await.unwrap();
assert_eq!(result, true);
} }
#[test] #[tokio::test]
fn test_can_view_post_direct() { #[serial]
async fn test_can_view_post_direct() {
let user = User::default(); let user = User::default();
let post = Post { let post = Post {
visibility: Visibility::Direct, visibility: Visibility::Direct,
..Default::default() ..Default::default()
}; };
assert!(!can_view_post(Some(&user), &post)); let db_client = &create_test_database().await;
let result = can_view_post(db_client, Some(&user), &post).await.unwrap();
assert_eq!(result, false);
} }
#[test] #[tokio::test]
fn test_can_view_post_direct_mentioned() { #[serial]
async fn test_can_view_post_direct_mentioned() {
let user = User::default(); let user = User::default();
let post = Post { let post = Post {
visibility: Visibility::Direct, visibility: Visibility::Direct,
mentions: vec![user.profile.clone()], mentions: vec![user.profile.clone()],
..Default::default() ..Default::default()
}; };
assert!(can_view_post(Some(&user), &post)); let db_client = &create_test_database().await;
let result = can_view_post(db_client, Some(&user), &post).await.unwrap();
assert_eq!(result, true);
} }
} }