435 lines
12 KiB
Rust
435 lines
12 KiB
Rust
use serde_json::Value as JsonValue;
|
|
use uuid::Uuid;
|
|
|
|
use mitra_utils::{currencies::Currency, did::Did, did_pkh::DidPkh};
|
|
|
|
use crate::database::{catch_unique_violation, DatabaseClient, DatabaseError};
|
|
use crate::profiles::{
|
|
queries::create_profile,
|
|
types::{DbActorProfile, ProfileCreateData},
|
|
};
|
|
|
|
use super::types::{
|
|
ClientConfig, DbClientConfig, DbInviteCode, DbUser, Role, User, UserCreateData,
|
|
};
|
|
use super::utils::generate_invite_code;
|
|
|
|
pub async fn create_invite_code(
|
|
db_client: &impl DatabaseClient,
|
|
note: Option<&str>,
|
|
) -> Result<String, DatabaseError> {
|
|
let invite_code = generate_invite_code();
|
|
db_client
|
|
.execute(
|
|
"
|
|
INSERT INTO user_invite_code (code, note)
|
|
VALUES ($1, $2)
|
|
",
|
|
&[&invite_code, ¬e],
|
|
)
|
|
.await?;
|
|
Ok(invite_code)
|
|
}
|
|
|
|
pub async fn get_invite_codes(
|
|
db_client: &impl DatabaseClient,
|
|
) -> Result<Vec<DbInviteCode>, DatabaseError> {
|
|
let rows = db_client
|
|
.query(
|
|
"
|
|
SELECT user_invite_code
|
|
FROM user_invite_code
|
|
WHERE used = FALSE
|
|
",
|
|
&[],
|
|
)
|
|
.await?;
|
|
let codes = rows
|
|
.iter()
|
|
.map(|row| row.try_get("user_invite_code"))
|
|
.collect::<Result<_, _>>()?;
|
|
Ok(codes)
|
|
}
|
|
|
|
pub async fn is_valid_invite_code(
|
|
db_client: &impl DatabaseClient,
|
|
invite_code: &str,
|
|
) -> Result<bool, DatabaseError> {
|
|
let maybe_row = db_client
|
|
.query_opt(
|
|
"
|
|
SELECT 1 FROM user_invite_code
|
|
WHERE code = $1 AND used = FALSE
|
|
",
|
|
&[&invite_code],
|
|
)
|
|
.await?;
|
|
Ok(maybe_row.is_some())
|
|
}
|
|
|
|
pub async fn create_user(
|
|
db_client: &mut impl DatabaseClient,
|
|
user_data: UserCreateData,
|
|
) -> Result<User, DatabaseError> {
|
|
assert!(user_data.password_hash.is_some() || user_data.wallet_address.is_some());
|
|
let mut transaction = db_client.transaction().await?;
|
|
// Prevent changes to actor_profile table
|
|
transaction
|
|
.execute("LOCK TABLE actor_profile IN EXCLUSIVE MODE", &[])
|
|
.await?;
|
|
// Ensure there are no local accounts with a similar name
|
|
let maybe_row = transaction
|
|
.query_opt(
|
|
"
|
|
SELECT 1
|
|
FROM user_account JOIN actor_profile USING (id)
|
|
WHERE actor_profile.username ILIKE $1
|
|
LIMIT 1
|
|
",
|
|
&[&user_data.username],
|
|
)
|
|
.await?;
|
|
if maybe_row.is_some() {
|
|
return Err(DatabaseError::AlreadyExists("user"));
|
|
};
|
|
// Use invite code
|
|
if let Some(ref invite_code) = user_data.invite_code {
|
|
let updated_count = transaction
|
|
.execute(
|
|
"
|
|
UPDATE user_invite_code
|
|
SET used = TRUE
|
|
WHERE code = $1 AND used = FALSE
|
|
",
|
|
&[&invite_code],
|
|
)
|
|
.await?;
|
|
if updated_count == 0 {
|
|
return Err(DatabaseError::NotFound("invite code"));
|
|
};
|
|
};
|
|
// Create profile
|
|
let profile_data = ProfileCreateData {
|
|
username: user_data.username.clone(),
|
|
hostname: None,
|
|
display_name: None,
|
|
bio: None,
|
|
avatar: None,
|
|
banner: None,
|
|
manually_approves_followers: false,
|
|
identity_proofs: vec![],
|
|
payment_options: vec![],
|
|
extra_fields: vec![],
|
|
aliases: vec![],
|
|
emojis: vec![],
|
|
actor_json: None,
|
|
};
|
|
let profile = create_profile(&mut transaction, profile_data).await?;
|
|
// Create user
|
|
let row = transaction
|
|
.query_one(
|
|
"
|
|
INSERT INTO user_account (
|
|
id,
|
|
wallet_address,
|
|
password_hash,
|
|
private_key,
|
|
invite_code,
|
|
user_role
|
|
)
|
|
VALUES ($1, $2, $3, $4, $5, $6)
|
|
RETURNING user_account
|
|
",
|
|
&[
|
|
&profile.id,
|
|
&user_data.wallet_address,
|
|
&user_data.password_hash,
|
|
&user_data.private_key_pem,
|
|
&user_data.invite_code,
|
|
&user_data.role,
|
|
],
|
|
)
|
|
.await
|
|
.map_err(catch_unique_violation("user"))?;
|
|
let db_user: DbUser = row.try_get("user_account")?;
|
|
let user = User::new(db_user, profile);
|
|
transaction.commit().await?;
|
|
Ok(user)
|
|
}
|
|
|
|
pub async fn set_user_password(
|
|
db_client: &impl DatabaseClient,
|
|
user_id: &Uuid,
|
|
password_hash: String,
|
|
) -> Result<(), DatabaseError> {
|
|
let updated_count = db_client
|
|
.execute(
|
|
"
|
|
UPDATE user_account SET password_hash = $1
|
|
WHERE id = $2
|
|
",
|
|
&[&password_hash, &user_id],
|
|
)
|
|
.await?;
|
|
if updated_count == 0 {
|
|
return Err(DatabaseError::NotFound("user"));
|
|
};
|
|
Ok(())
|
|
}
|
|
|
|
pub async fn set_user_role(
|
|
db_client: &impl DatabaseClient,
|
|
user_id: &Uuid,
|
|
role: Role,
|
|
) -> Result<(), DatabaseError> {
|
|
let updated_count = db_client
|
|
.execute(
|
|
"
|
|
UPDATE user_account SET user_role = $1
|
|
WHERE id = $2
|
|
",
|
|
&[&role, &user_id],
|
|
)
|
|
.await?;
|
|
if updated_count == 0 {
|
|
return Err(DatabaseError::NotFound("user"));
|
|
};
|
|
Ok(())
|
|
}
|
|
|
|
pub async fn update_client_config(
|
|
db_client: &impl DatabaseClient,
|
|
user_id: &Uuid,
|
|
client_name: &str,
|
|
client_config_value: &JsonValue,
|
|
) -> Result<ClientConfig, DatabaseError> {
|
|
let maybe_row = db_client
|
|
.query_opt(
|
|
"
|
|
UPDATE user_account
|
|
SET client_config = jsonb_set(client_config, ARRAY[$1], $2, true)
|
|
WHERE id = $3
|
|
RETURNING client_config
|
|
",
|
|
&[&client_name, &client_config_value, &user_id],
|
|
)
|
|
.await?;
|
|
let row = maybe_row.ok_or(DatabaseError::NotFound("user"))?;
|
|
let client_config: DbClientConfig = row.try_get("client_config")?;
|
|
Ok(client_config.into_inner())
|
|
}
|
|
|
|
pub async fn get_user_by_id(
|
|
db_client: &impl DatabaseClient,
|
|
user_id: &Uuid,
|
|
) -> Result<User, DatabaseError> {
|
|
let maybe_row = db_client
|
|
.query_opt(
|
|
"
|
|
SELECT user_account, actor_profile
|
|
FROM user_account JOIN actor_profile USING (id)
|
|
WHERE id = $1
|
|
",
|
|
&[&user_id],
|
|
)
|
|
.await?;
|
|
let row = maybe_row.ok_or(DatabaseError::NotFound("user"))?;
|
|
let db_user: DbUser = row.try_get("user_account")?;
|
|
let db_profile: DbActorProfile = row.try_get("actor_profile")?;
|
|
let user = User::new(db_user, db_profile);
|
|
Ok(user)
|
|
}
|
|
|
|
pub async fn get_user_by_name(
|
|
db_client: &impl DatabaseClient,
|
|
username: &str,
|
|
) -> Result<User, DatabaseError> {
|
|
let maybe_row = db_client
|
|
.query_opt(
|
|
"
|
|
SELECT user_account, actor_profile
|
|
FROM user_account JOIN actor_profile USING (id)
|
|
WHERE actor_profile.username = $1
|
|
",
|
|
&[&username],
|
|
)
|
|
.await?;
|
|
let row = maybe_row.ok_or(DatabaseError::NotFound("user"))?;
|
|
let db_user: DbUser = row.try_get("user_account")?;
|
|
let db_profile: DbActorProfile = row.try_get("actor_profile")?;
|
|
let user = User::new(db_user, db_profile);
|
|
Ok(user)
|
|
}
|
|
|
|
pub async fn is_registered_user(
|
|
db_client: &impl DatabaseClient,
|
|
username: &str,
|
|
) -> Result<bool, DatabaseError> {
|
|
let maybe_row = db_client
|
|
.query_opt(
|
|
"
|
|
SELECT 1 FROM user_account JOIN actor_profile USING (id)
|
|
WHERE actor_profile.username = $1
|
|
",
|
|
&[&username],
|
|
)
|
|
.await?;
|
|
Ok(maybe_row.is_some())
|
|
}
|
|
|
|
pub async fn get_user_by_login_address(
|
|
db_client: &impl DatabaseClient,
|
|
wallet_address: &str,
|
|
) -> Result<User, DatabaseError> {
|
|
let maybe_row = db_client
|
|
.query_opt(
|
|
"
|
|
SELECT user_account, actor_profile
|
|
FROM user_account JOIN actor_profile USING (id)
|
|
WHERE wallet_address = $1
|
|
",
|
|
&[&wallet_address],
|
|
)
|
|
.await?;
|
|
let row = maybe_row.ok_or(DatabaseError::NotFound("user"))?;
|
|
let db_user: DbUser = row.try_get("user_account")?;
|
|
let db_profile: DbActorProfile = row.try_get("actor_profile")?;
|
|
let user = User::new(db_user, db_profile);
|
|
Ok(user)
|
|
}
|
|
|
|
pub async fn get_user_by_did(
|
|
db_client: &impl DatabaseClient,
|
|
did: &Did,
|
|
) -> Result<User, DatabaseError> {
|
|
// DIDs must be locally unique
|
|
let maybe_row = db_client
|
|
.query_opt(
|
|
"
|
|
SELECT user_account, actor_profile
|
|
FROM user_account JOIN actor_profile USING (id)
|
|
WHERE
|
|
EXISTS (
|
|
SELECT 1
|
|
FROM jsonb_array_elements(actor_profile.identity_proofs) AS proof
|
|
WHERE proof ->> 'issuer' = $1
|
|
)
|
|
",
|
|
&[&did.to_string()],
|
|
)
|
|
.await?;
|
|
let row = maybe_row.ok_or(DatabaseError::NotFound("user"))?;
|
|
let db_user: DbUser = row.try_get("user_account")?;
|
|
let db_profile: DbActorProfile = row.try_get("actor_profile")?;
|
|
let user = User::new(db_user, db_profile);
|
|
Ok(user)
|
|
}
|
|
|
|
pub async fn get_user_by_public_wallet_address(
|
|
db_client: &impl DatabaseClient,
|
|
currency: &Currency,
|
|
wallet_address: &str,
|
|
) -> Result<User, DatabaseError> {
|
|
let did_pkh = DidPkh::from_address(currency, wallet_address);
|
|
let did = Did::Pkh(did_pkh);
|
|
get_user_by_did(db_client, &did).await
|
|
}
|
|
|
|
/// Returns the number of rows in `user_account`, i.e. the number of local
/// user accounts.
///
/// # Errors
///
/// Returns [`DatabaseError`] if the count query fails or its result cannot
/// be decoded.
pub async fn get_user_count(db_client: &impl DatabaseClient) -> Result<i64, DatabaseError> {
    let statement = "SELECT count(user_account) FROM user_account";
    let row = db_client.query_one(statement, &[]).await?;
    let total: i64 = row.try_get("count")?;
    Ok(total)
}
|
|
|
|
// Integration tests against a freshly created test database. `#[serial]`
// keeps them from running concurrently on the shared database.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::database::test_utils::create_test_database;
    use crate::users::types::Role;
    use serde_json::json;
    use serial_test::serial;

    #[tokio::test]
    #[serial]
    async fn test_create_invite_code() {
        let db_client = &mut create_test_database().await;
        let code = create_invite_code(db_client, Some("test")).await.unwrap();
        assert_eq!(code.len(), 32);
    }

    #[tokio::test]
    #[serial]
    async fn test_create_user() {
        let db_client = &mut create_test_database().await;
        let user_data = UserCreateData {
            username: "myname".to_string(),
            password_hash: Some("test".to_string()),
            ..Default::default()
        };
        let user = create_user(db_client, user_data).await.unwrap();
        assert_eq!(user.profile.username, "myname");
        assert_eq!(user.role, Role::NormalUser);
    }

    #[tokio::test]
    #[serial]
    async fn test_create_user_impersonation_protection() {
        let db_client = &mut create_test_database().await;
        let user_data = UserCreateData {
            username: "myname".to_string(),
            password_hash: Some("test".to_string()),
            ..Default::default()
        };
        create_user(db_client, user_data).await.unwrap();
        // Same name with different letter case must be rejected
        let another_user_data = UserCreateData {
            username: "myName".to_string(),
            password_hash: Some("test".to_string()),
            ..Default::default()
        };
        let result = create_user(db_client, another_user_data).await;
        assert!(matches!(result, Err(DatabaseError::AlreadyExists("user"))));
    }

    #[tokio::test]
    #[serial]
    async fn test_set_user_role() {
        let db_client = &mut create_test_database().await;
        let user_data = UserCreateData {
            username: "test".to_string(),
            password_hash: Some("test".to_string()),
            ..Default::default()
        };
        let user = create_user(db_client, user_data).await.unwrap();
        assert_eq!(user.role, Role::NormalUser);
        set_user_role(db_client, &user.id, Role::ReadOnlyUser)
            .await
            .unwrap();
        let user = get_user_by_id(db_client, &user.id).await.unwrap();
        assert_eq!(user.role, Role::ReadOnlyUser);
    }

    #[tokio::test]
    #[serial]
    async fn test_update_client_config() {
        let db_client = &mut create_test_database().await;
        let user_data = UserCreateData {
            username: "test".to_string(),
            password_hash: Some("test".to_string()),
            ..Default::default()
        };
        let user = create_user(db_client, user_data).await.unwrap();
        assert!(user.client_config.is_empty());
        let client_name = "test";
        let client_config_value = json!({"a": 1});
        let client_config =
            update_client_config(db_client, &user.id, client_name, &client_config_value)
                .await
                .unwrap();
        assert_eq!(
            client_config.get(client_name).unwrap(),
            &client_config_value,
        );
    }
}
|