diff --git a/src/activitypub/handlers/follow.rs b/src/activitypub/handlers/follow.rs index f91e7be..41b668e 100644 --- a/src/activitypub/handlers/follow.rs +++ b/src/activitypub/handlers/follow.rs @@ -12,8 +12,13 @@ use crate::activitypub::{ use crate::config::Config; use crate::database::DatabaseError; use crate::errors::ValidationError; -use crate::models::relationships::queries::follow; -use crate::models::users::queries::get_user_by_name; +use crate::models::{ + relationships::queries::{ + create_remote_follow_request_opt, + follow_request_accepted, + }, + users::queries::get_user_by_name, +}; use super::{HandlerError, HandlerResult}; #[derive(Deserialize)] @@ -45,7 +50,13 @@ pub async fn handle_follow( &activity.object, )?; let target_user = get_user_by_name(db_client, &target_username).await?; - match follow(db_client, &source_profile.id, &target_user.profile.id).await { + let follow_request = create_remote_follow_request_opt( + db_client, + &source_profile.id, + &target_user.id, + &activity.id, + ).await?; + match follow_request_accepted(db_client, &follow_request.id).await { Ok(_) => (), // Proceed even if relationship already exists Err(DatabaseError::AlreadyExists(_)) => (), diff --git a/src/models/relationships/queries.rs b/src/models/relationships/queries.rs index 2e68c91..7645c05 100644 --- a/src/models/relationships/queries.rs +++ b/src/models/relationships/queries.rs @@ -115,7 +115,7 @@ pub async fn unfollow( ).await?; let relationship_deleted = deleted_count > 0; // Delete follow request (for remote follows) - let follow_request_deleted = delete_follow_request( + let follow_request_deleted = delete_follow_request_opt( &transaction, source_id, target_id, @@ -135,6 +135,7 @@ pub async fn unfollow( Ok(follow_request_deleted) } +// Follow remote actor pub async fn create_follow_request( db_client: &impl GenericClient, source_id: &Uuid, @@ -160,6 +161,40 @@ pub async fn create_follow_request( Ok(request) } +// Save follow request from remote actor +pub async fn create_remote_follow_request_opt( + db_client: &impl GenericClient, + source_id: &Uuid, + target_id: &Uuid, + activity_id: &str, +) -> Result { + let request_id = new_uuid(); + let row = db_client.query_one( + " + INSERT INTO follow_request ( + id, + source_id, + target_id, + activity_id, + request_status + ) + VALUES ($1, $2, $3, $4, $5) + ON CONFLICT (source_id, target_id) + DO UPDATE SET activity_id = $4 + RETURNING follow_request + ", + &[ + &request_id, + &source_id, + &target_id, + &activity_id, + &FollowRequestStatus::Pending, + ], + ).await?; + let request = row.try_get("follow_request")?; + Ok(request) +} + pub async fn follow_request_accepted( db_client: &mut impl GenericClient, request_id: &Uuid, @@ -200,7 +235,7 @@ pub async fn follow_request_rejected( Ok(()) } -async fn delete_follow_request( +async fn delete_follow_request_opt( db_client: &impl GenericClient, source_id: &Uuid, target_id: &Uuid, @@ -557,4 +592,40 @@ mod tests { let following = get_following(db_client, &source.id).await.unwrap(); assert!(following.is_empty()); } + + #[tokio::test] + #[serial] + async fn test_followed_by_remote_profile() { + let db_client = &mut create_test_database().await; + let source_data = ProfileCreateData { + username: "follower".to_string(), + hostname: Some("example.org".to_string()), + actor_json: Some(Actor::default()), + ..Default::default() + }; + let source = create_profile(db_client, source_data).await.unwrap(); + let target_data = UserCreateData { + username: "test".to_string(), + ..Default::default() + }; + let target = create_user(db_client, target_data).await.unwrap(); + // Create follow request + let activity_id = "https://example.org/objects/123"; + let _follow_request = create_remote_follow_request_opt( + db_client, &source.id, &target.id, activity_id, + ).await.unwrap(); + // Repeat + let follow_request = create_remote_follow_request_opt( + db_client, &source.id, &target.id, activity_id, + ).await.unwrap(); + assert_eq!(follow_request.source_id, source.id); + assert_eq!(follow_request.target_id, target.id); + assert_eq!(follow_request.activity_id, Some(activity_id.to_string())); + assert_eq!(follow_request.request_status, FollowRequestStatus::Pending); + // Accept follow request + follow_request_accepted(db_client, &follow_request.id).await.unwrap(); + let follow_request = get_follow_request_by_id(db_client, &follow_request.id) + .await.unwrap(); + assert_eq!(follow_request.request_status, FollowRequestStatus::Accepted); + } }