diff --git a/src/models/users/queries.rs b/src/models/users/queries.rs index 9110946..41f45f0 100644 --- a/src/models/users/queries.rs +++ b/src/models/users/queries.rs @@ -60,6 +60,23 @@ pub async fn create_user( user_data: UserCreateData, ) -> Result { let transaction = db_client.transaction().await?; + // Prevent changes to actor_profile table + transaction.execute( + "LOCK TABLE actor_profile IN EXCLUSIVE MODE", + &[], + ).await?; + // Ensure there are no local accounts with a similar name + let maybe_row = transaction.query_opt( + " + SELECT 1 + FROM user_account JOIN actor_profile USING (id) + WHERE actor_profile.username ILIKE $1 + ", + &[&user_data.username], + ).await?; + if maybe_row.is_some() { + return Err(DatabaseError::AlreadyExists("user")); + }; // Use invite code if let Some(ref invite_code) = user_data.invite_code { let updated_count = transaction.execute( @@ -72,8 +89,8 @@ pub async fn create_user( ).await?; if updated_count == 0 { return Err(DatabaseError::NotFound("invite code")); - } - } + }; + }; // Create profile let profile_data = ProfileCreateData { username: user_data.username.clone(), @@ -245,3 +262,27 @@ pub async fn get_user_count( let count = row.try_get("count")?; Ok(count) } + +#[cfg(test)] +mod tests { + use serial_test::serial; + use crate::database::test_utils::create_test_database; + use super::*; + + #[tokio::test] + #[serial] + async fn test_create_user_impersonation_protection() { + let db_client = &mut create_test_database().await; + let user_data = UserCreateData { + username: "myname".to_string(), + ..Default::default() + }; + create_user(db_client, user_data).await.unwrap(); + let another_user_data = UserCreateData { + username: "myName".to_string(), + ..Default::default() + }; + let result = create_user(db_client, another_user_data).await; + assert!(matches!(result, Err(DatabaseError::AlreadyExists("user")))); + } +}