diff --git a/docs/openapi.yaml b/docs/openapi.yaml index 26921e9..452fc0a 100644 --- a/docs/openapi.yaml +++ b/docs/openapi.yaml @@ -651,6 +651,12 @@ components: requested: description: Do you have a pending follow request for this user? type: boolean + subscription_to: + description: Are you sending subscription payments to this user? + type: boolean + subscription_from: + description: Are you receiving subscription payments from this user? + type: boolean Signature: type: object properties: diff --git a/src/models/relationships/queries.rs b/src/models/relationships/queries.rs index 25608b3..726c59f 100644 --- a/src/models/relationships/queries.rs +++ b/src/models/relationships/queries.rs @@ -75,6 +75,13 @@ pub async fn get_relationship( relationship_map.requested = true; }; }, + RelationshipType::Subscription => { + if relationship.is_direct(source_id, target_id)? { + relationship_map.subscription_to = true; + } else { + relationship_map.subscription_from = true; + }; + }, }; }; Ok(relationship_map) @@ -309,19 +316,53 @@ pub async fn get_following( Ok(profiles) } +pub async fn subscribe( + db_client: &impl GenericClient, + source_id: &Uuid, + target_id: &Uuid, +) -> Result<(), DatabaseError> { + db_client.execute( + " + INSERT INTO relationship (source_id, target_id, relationship_type) + VALUES ($1, $2, $3) + ", + &[&source_id, &target_id, &RelationshipType::Subscription], + ).await.map_err(catch_unique_violation("relationship"))?; + Ok(()) +} + +pub async fn unsubscribe( + db_client: &impl GenericClient, + source_id: &Uuid, + target_id: &Uuid, +) -> Result<(), DatabaseError> { + let deleted_count = db_client.execute( + " + DELETE FROM relationship + WHERE + source_id = $1 AND target_id = $2 + AND relationship_type = $3 + ", + &[&source_id, &target_id, &RelationshipType::Subscription], + ).await?; + if deleted_count == 0 { + return Err(DatabaseError::NotFound("relationship")); + }; + Ok(()) +} + #[cfg(test)] mod tests { use serial_test::serial; use crate::database::test_utils::create_test_database; use crate::models::relationships::queries::follow; use crate::models::users::queries::create_user; - use crate::models::users::types::UserCreateData; + use crate::models::users::types::{User, UserCreateData}; use super::*; - #[tokio::test] - #[serial] - async fn test_get_relationship() { - let db_client = &mut create_test_database().await; + async fn create_users(db_client: &mut impl GenericClient) + -> Result<(User, User), DatabaseError> + { let user_data_1 = UserCreateData { username: "user".to_string(), ..Default::default() @@ -332,16 +373,25 @@ mod tests { ..Default::default() }; let user_2 = create_user(db_client, user_data_2).await.unwrap(); + Ok((user_1, user_2)) + } + + #[tokio::test] + #[serial] + async fn test_follow_unfollow() { + let db_client = &mut create_test_database().await; + let (user_1, user_2) = create_users(db_client).await.unwrap(); // Initial state let relationship = get_relationship(db_client, &user_1.id, &user_2.id).await.unwrap(); assert_eq!(relationship.id, user_2.id); assert_eq!(relationship.following, false); assert_eq!(relationship.followed_by, false); assert_eq!(relationship.requested, false); + assert_eq!(relationship.subscription_to, false); + assert_eq!(relationship.subscription_from, false); // Follow request let follow_request = create_follow_request(db_client, &user_1.id, &user_2.id).await.unwrap(); let relationship = get_relationship(db_client, &user_1.id, &user_2.id).await.unwrap(); - assert_eq!(relationship.id, user_2.id); assert_eq!(relationship.following, false); assert_eq!(relationship.followed_by, false); assert_eq!(relationship.requested, true); @@ -349,9 +399,31 @@ mod tests { follow_request_accepted(db_client, &follow_request.id).await.unwrap(); follow(db_client, &user_2.id, &user_1.id).await.unwrap(); let relationship = get_relationship(db_client, &user_1.id, &user_2.id).await.unwrap(); - assert_eq!(relationship.id, user_2.id); assert_eq!(relationship.following, true); assert_eq!(relationship.followed_by, true); assert_eq!(relationship.requested, false); + // Unfollow + unfollow(db_client, &user_1.id, &user_2.id).await.unwrap(); + let relationship = get_relationship(db_client, &user_1.id, &user_2.id).await.unwrap(); + assert_eq!(relationship.following, false); + assert_eq!(relationship.followed_by, true); + assert_eq!(relationship.requested, false); + } + + #[tokio::test] + #[serial] + async fn test_subscribe_unsubscribe() { + let db_client = &mut create_test_database().await; + let (user_1, user_2) = create_users(db_client).await.unwrap(); + + subscribe(db_client, &user_1.id, &user_2.id).await.unwrap(); + let relationship = get_relationship(db_client, &user_1.id, &user_2.id).await.unwrap(); + assert_eq!(relationship.subscription_to, true); + assert_eq!(relationship.subscription_from, false); + + unsubscribe(db_client, &user_1.id, &user_2.id).await.unwrap(); + let relationship = get_relationship(db_client, &user_1.id, &user_2.id).await.unwrap(); + assert_eq!(relationship.subscription_to, false); + assert_eq!(relationship.subscription_from, false); } } diff --git a/src/models/relationships/types.rs b/src/models/relationships/types.rs index 38870f6..bf1bc2d 100644 --- a/src/models/relationships/types.rs +++ b/src/models/relationships/types.rs @@ -12,6 +12,7 @@ use crate::errors::ConversionError; pub enum RelationshipType { Follow, FollowRequest, + Subscription, } impl From<&RelationshipType> for i16 { @@ -19,6 +20,7 @@ impl From<&RelationshipType> for i16 { match value { RelationshipType::Follow => 1, RelationshipType::FollowRequest => 2, + RelationshipType::Subscription => 3, } } } @@ -30,6 +32,7 @@ impl TryFrom for RelationshipType { let relationship_type = match value { 1 => Self::Follow, 2 => Self::FollowRequest, + 3 => Self::Subscription, _ => return Err(ConversionError), }; Ok(relationship_type) @@ -77,10 +80,12 @@ impl TryFrom<&Row> for DbRelationship { #[derive(Default, Serialize)] pub struct RelationshipMap { - pub id: Uuid, + pub id: Uuid, // target ID pub following: bool, pub followed_by: bool, pub requested: bool, + pub subscription_to: bool, + pub subscription_from: bool, } #[derive(Debug)]