diff --git a/src/models/relationships/queries.rs b/src/models/relationships/queries.rs index 7426971..bfe96ae 100644 --- a/src/models/relationships/queries.rs +++ b/src/models/relationships/queries.rs @@ -38,12 +38,12 @@ pub async fn get_relationships( EXISTS ( SELECT 1 FROM follow_request WHERE source_id = $1 AND target_id = actor_profile.id - AND request_status = 1 + AND request_status = $3 ) AS requested FROM actor_profile WHERE actor_profile.id = ANY($2) ", - &[&source_id, &target_ids], + &[&source_id, &target_ids, &FollowRequestStatus::Pending], ).await?; let relationships = rows.iter() .map(Relationship::try_from) @@ -71,12 +71,12 @@ pub async fn get_relationship( EXISTS ( SELECT 1 FROM follow_request WHERE source_id = $1 AND target_id = actor_profile.id - AND request_status = 1 + AND request_status = $3 ) AS requested FROM actor_profile WHERE actor_profile.id = $2 ", - &[&source_id, &target_id], + &[&source_id, &target_id, &FollowRequestStatus::Pending], ).await?; let row = maybe_row.ok_or(DatabaseError::NotFound("profile"))?; let relationship = Relationship::try_from(&row)?; @@ -307,3 +307,50 @@ pub async fn get_following( .collect::>()?; Ok(profiles) } + +#[cfg(test)] +mod tests { + use serial_test::serial; + use crate::database::test_utils::create_test_database; + use crate::models::relationships::queries::follow; + use crate::models::users::queries::create_user; + use crate::models::users::types::UserCreateData; + use super::*; + + #[tokio::test] + #[serial] + async fn test_get_relationship() { + let db_client = &mut create_test_database().await; + let user_data_1 = UserCreateData { + username: "user".to_string(), + ..Default::default() + }; + let user_1 = create_user(db_client, user_data_1).await.unwrap(); + let user_data_2 = UserCreateData { + username: "another-user".to_string(), + ..Default::default() + }; + let user_2 = create_user(db_client, user_data_2).await.unwrap(); + // Initial state + let relationship = get_relationship(db_client, &user_1.id, &user_2.id).await.unwrap(); + assert_eq!(relationship.id, user_2.id); + assert_eq!(relationship.following, false); + assert_eq!(relationship.followed_by, false); + assert_eq!(relationship.requested, false); + // Follow request + let follow_request = create_follow_request(db_client, &user_1.id, &user_2.id).await.unwrap(); + let relationship = get_relationship(db_client, &user_1.id, &user_2.id).await.unwrap(); + assert_eq!(relationship.id, user_2.id); + assert_eq!(relationship.following, false); + assert_eq!(relationship.followed_by, false); + assert_eq!(relationship.requested, true); + // Mutual follow + follow_request_accepted(db_client, &follow_request.id).await.unwrap(); + follow(db_client, &user_2.id, &user_1.id).await.unwrap(); + let relationship = get_relationship(db_client, &user_1.id, &user_2.id).await.unwrap(); + assert_eq!(relationship.id, user_2.id); + assert_eq!(relationship.following, true); + assert_eq!(relationship.followed_by, true); + assert_eq!(relationship.requested, false); + } +}