fedimovies/fedimovies-models/src/users/queries.rs
Rafael Caricio 7281340423
Some checks failed
ci/woodpecker/push/woodpecker Pipeline failed
Apply clippy suggestions
2023-04-28 00:01:24 +02:00

436 lines
12 KiB
Rust

use serde_json::Value as JsonValue;
use uuid::Uuid;
use fedimovies_utils::{currencies::Currency, did::Did, did_pkh::DidPkh};
use crate::database::{catch_unique_violation, DatabaseClient, DatabaseError};
use crate::profiles::{
queries::create_profile,
types::{DbActorProfile, ProfileCreateData},
};
use super::types::{
ClientConfig, DbClientConfig, DbInviteCode, DbUser, Role, User, UserCreateData,
};
use super::utils::generate_invite_code;
pub async fn create_invite_code(
db_client: &impl DatabaseClient,
note: Option<&str>,
) -> Result<String, DatabaseError> {
let invite_code = generate_invite_code();
db_client
.execute(
"
INSERT INTO user_invite_code (code, note)
VALUES ($1, $2)
",
&[&invite_code, &note],
)
.await?;
Ok(invite_code)
}
pub async fn get_invite_codes(
db_client: &impl DatabaseClient,
) -> Result<Vec<DbInviteCode>, DatabaseError> {
let rows = db_client
.query(
"
SELECT user_invite_code
FROM user_invite_code
WHERE used = FALSE
",
&[],
)
.await?;
let codes = rows
.iter()
.map(|row| row.try_get("user_invite_code"))
.collect::<Result<_, _>>()?;
Ok(codes)
}
pub async fn is_valid_invite_code(
db_client: &impl DatabaseClient,
invite_code: &str,
) -> Result<bool, DatabaseError> {
let maybe_row = db_client
.query_opt(
"
SELECT 1 FROM user_invite_code
WHERE code = $1 AND used = FALSE
",
&[&invite_code],
)
.await?;
Ok(maybe_row.is_some())
}
pub async fn create_user(
db_client: &mut impl DatabaseClient,
user_data: UserCreateData,
) -> Result<User, DatabaseError> {
assert!(user_data.password_hash.is_some() || user_data.wallet_address.is_some());
let mut transaction = db_client.transaction().await?;
// Prevent changes to actor_profile table
transaction
.execute("LOCK TABLE actor_profile IN EXCLUSIVE MODE", &[])
.await?;
// Ensure there are no local accounts with a similar name
let maybe_row = transaction
.query_opt(
"
SELECT 1
FROM user_account JOIN actor_profile USING (id)
WHERE actor_profile.username ILIKE $1
LIMIT 1
",
&[&user_data.username],
)
.await?;
if maybe_row.is_some() {
return Err(DatabaseError::AlreadyExists("user"));
};
// Use invite code
if let Some(ref invite_code) = user_data.invite_code {
let updated_count = transaction
.execute(
"
UPDATE user_invite_code
SET used = TRUE
WHERE code = $1 AND used = FALSE
",
&[&invite_code],
)
.await?;
if updated_count == 0 {
return Err(DatabaseError::NotFound("invite code"));
};
};
// Create profile
let profile_data = ProfileCreateData {
username: user_data.username.clone(),
hostname: None,
display_name: None,
bio: None,
avatar: None,
banner: None,
manually_approves_followers: false,
identity_proofs: vec![],
payment_options: vec![],
extra_fields: vec![],
aliases: vec![],
emojis: vec![],
actor_json: None,
};
let profile = create_profile(&mut transaction, profile_data).await?;
// Create user
let row = transaction
.query_one(
"
INSERT INTO user_account (
id,
wallet_address,
password_hash,
private_key,
invite_code,
user_role
)
VALUES ($1, $2, $3, $4, $5, $6)
RETURNING user_account
",
&[
&profile.id,
&user_data.wallet_address,
&user_data.password_hash,
&user_data.private_key_pem,
&user_data.invite_code,
&user_data.role,
],
)
.await
.map_err(catch_unique_violation("user"))?;
let db_user: DbUser = row.try_get("user_account")?;
let user = User::new(db_user, profile);
transaction.commit().await?;
Ok(user)
}
pub async fn set_user_password(
db_client: &impl DatabaseClient,
user_id: &Uuid,
password_hash: String,
) -> Result<(), DatabaseError> {
let updated_count = db_client
.execute(
"
UPDATE user_account SET password_hash = $1
WHERE id = $2
",
&[&password_hash, &user_id],
)
.await?;
if updated_count == 0 {
return Err(DatabaseError::NotFound("user"));
};
Ok(())
}
pub async fn set_user_role(
db_client: &impl DatabaseClient,
user_id: &Uuid,
role: Role,
) -> Result<(), DatabaseError> {
let updated_count = db_client
.execute(
"
UPDATE user_account SET user_role = $1
WHERE id = $2
",
&[&role, &user_id],
)
.await?;
if updated_count == 0 {
return Err(DatabaseError::NotFound("user"));
};
Ok(())
}
pub async fn update_client_config(
db_client: &impl DatabaseClient,
user_id: &Uuid,
client_name: &str,
client_config_value: &JsonValue,
) -> Result<ClientConfig, DatabaseError> {
let maybe_row = db_client
.query_opt(
"
UPDATE user_account
SET client_config = jsonb_set(client_config, ARRAY[$1], $2, true)
WHERE id = $3
RETURNING client_config
",
&[&client_name, &client_config_value, &user_id],
)
.await?;
let row = maybe_row.ok_or(DatabaseError::NotFound("user"))?;
let client_config: DbClientConfig = row.try_get("client_config")?;
Ok(client_config.into_inner())
}
pub async fn get_user_by_id(
db_client: &impl DatabaseClient,
user_id: &Uuid,
) -> Result<User, DatabaseError> {
let maybe_row = db_client
.query_opt(
"
SELECT user_account, actor_profile
FROM user_account JOIN actor_profile USING (id)
WHERE id = $1
",
&[&user_id],
)
.await?;
let row = maybe_row.ok_or(DatabaseError::NotFound("user"))?;
let db_user: DbUser = row.try_get("user_account")?;
let db_profile: DbActorProfile = row.try_get("actor_profile")?;
let user = User::new(db_user, db_profile);
Ok(user)
}
pub async fn get_user_by_name(
db_client: &impl DatabaseClient,
username: &str,
) -> Result<User, DatabaseError> {
let maybe_row = db_client
.query_opt(
"
SELECT user_account, actor_profile
FROM user_account JOIN actor_profile USING (id)
WHERE lower(actor_profile.username) = lower($1)
",
&[&username],
)
.await?;
let row = maybe_row.ok_or(DatabaseError::NotFound("user"))?;
let db_user: DbUser = row.try_get("user_account")?;
let db_profile: DbActorProfile = row.try_get("actor_profile")?;
let user = User::new(db_user, db_profile);
Ok(user)
}
pub async fn is_registered_user(
db_client: &impl DatabaseClient,
username: &str,
) -> Result<bool, DatabaseError> {
let maybe_row = db_client
.query_opt(
"
SELECT 1 FROM user_account JOIN actor_profile USING (id)
WHERE actor_profile.username = $1
",
&[&username],
)
.await?;
Ok(maybe_row.is_some())
}
pub async fn get_user_by_login_address(
db_client: &impl DatabaseClient,
wallet_address: &str,
) -> Result<User, DatabaseError> {
let maybe_row = db_client
.query_opt(
"
SELECT user_account, actor_profile
FROM user_account JOIN actor_profile USING (id)
WHERE wallet_address = $1
",
&[&wallet_address],
)
.await?;
let row = maybe_row.ok_or(DatabaseError::NotFound("user"))?;
let db_user: DbUser = row.try_get("user_account")?;
let db_profile: DbActorProfile = row.try_get("actor_profile")?;
let user = User::new(db_user, db_profile);
Ok(user)
}
pub async fn get_user_by_did(
db_client: &impl DatabaseClient,
did: &Did,
) -> Result<User, DatabaseError> {
// DIDs must be locally unique
let maybe_row = db_client
.query_opt(
"
SELECT user_account, actor_profile
FROM user_account JOIN actor_profile USING (id)
WHERE
EXISTS (
SELECT 1
FROM jsonb_array_elements(actor_profile.identity_proofs) AS proof
WHERE proof ->> 'issuer' = $1
)
",
&[&did.to_string()],
)
.await?;
let row = maybe_row.ok_or(DatabaseError::NotFound("user"))?;
let db_user: DbUser = row.try_get("user_account")?;
let db_profile: DbActorProfile = row.try_get("actor_profile")?;
let user = User::new(db_user, db_profile);
Ok(user)
}
pub async fn get_user_by_public_wallet_address(
db_client: &impl DatabaseClient,
currency: &Currency,
wallet_address: &str,
) -> Result<User, DatabaseError> {
let did_pkh = DidPkh::from_address(currency, wallet_address);
let did = Did::Pkh(did_pkh);
get_user_by_did(db_client, &did).await
}
pub async fn get_user_count(db_client: &impl DatabaseClient) -> Result<i64, DatabaseError> {
let row = db_client
.query_one("SELECT count(user_account) FROM user_account", &[])
.await?;
let count = row.try_get("count")?;
Ok(count)
}
#[cfg(test)]
mod tests {
use super::*;
use crate::database::test_utils::create_test_database;
use crate::users::types::Role;
use serde_json::json;
use serial_test::serial;
#[tokio::test]
#[serial]
async fn test_create_invite_code() {
let db_client = &mut create_test_database().await;
let code = create_invite_code(db_client, Some("test")).await.unwrap();
assert_eq!(code.len(), 32);
}
#[tokio::test]
#[serial]
async fn test_create_user() {
let db_client = &mut create_test_database().await;
let user_data = UserCreateData {
username: "myname".to_string(),
password_hash: Some("test".to_string()),
..Default::default()
};
let user = create_user(db_client, user_data).await.unwrap();
assert_eq!(user.profile.username, "myname");
assert_eq!(user.role, Role::NormalUser);
}
#[tokio::test]
#[serial]
async fn test_create_user_impersonation_protection() {
let db_client = &mut create_test_database().await;
let user_data = UserCreateData {
username: "myname".to_string(),
password_hash: Some("test".to_string()),
..Default::default()
};
create_user(db_client, user_data).await.unwrap();
let another_user_data = UserCreateData {
username: "myName".to_string(),
password_hash: Some("test".to_string()),
..Default::default()
};
let result = create_user(db_client, another_user_data).await;
assert!(matches!(result, Err(DatabaseError::AlreadyExists("user"))));
}
#[tokio::test]
#[serial]
async fn test_set_user_role() {
let db_client = &mut create_test_database().await;
let user_data = UserCreateData {
username: "test".to_string(),
password_hash: Some("test".to_string()),
..Default::default()
};
let user = create_user(db_client, user_data).await.unwrap();
assert_eq!(user.role, Role::NormalUser);
set_user_role(db_client, &user.id, Role::ReadOnlyUser)
.await
.unwrap();
let user = get_user_by_id(db_client, &user.id).await.unwrap();
assert_eq!(user.role, Role::ReadOnlyUser);
}
#[tokio::test]
#[serial]
async fn test_update_client_config() {
let db_client = &mut create_test_database().await;
let user_data = UserCreateData {
username: "test".to_string(),
password_hash: Some("test".to_string()),
..Default::default()
};
let user = create_user(db_client, user_data).await.unwrap();
assert!(user.client_config.is_empty());
let client_name = "test";
let client_config_value = json!({"a": 1});
let client_config =
update_client_config(db_client, &user.id, client_name, &client_config_value)
.await
.unwrap();
assert_eq!(
client_config.get(client_name).unwrap(),
&client_config_value,
);
}
}