Apply cargo fmt

This commit is contained in:
Rafael Caricio 2023-04-24 17:35:32 +02:00
parent b4ff7abbc1
commit 47529ff703
Signed by: rafaelcaricio
GPG key ID: 3C86DBCE8E93C947
165 changed files with 3779 additions and 5132 deletions

View file

@ -22,7 +22,7 @@ async fn main() {
log::info!("config loaded from {}", config.config_path);
for warning in config_warnings {
log::warn!("{}", warning);
};
}
let db_config = config.database_url.parse().unwrap();
let db_client = &mut create_database_client(&db_config).await;
@ -39,19 +39,37 @@ async fn main() {
SubCommand::DeleteProfile(cmd) => cmd.execute(&config, db_client).await.unwrap(),
SubCommand::DeletePost(cmd) => cmd.execute(&config, db_client).await.unwrap(),
SubCommand::DeleteEmoji(cmd) => cmd.execute(&config, db_client).await.unwrap(),
SubCommand::DeleteExtraneousPosts(cmd) => cmd.execute(&config, db_client).await.unwrap(),
SubCommand::DeleteUnusedAttachments(cmd) => cmd.execute(&config, db_client).await.unwrap(),
SubCommand::DeleteOrphanedFiles(cmd) => cmd.execute(&config, db_client).await.unwrap(),
SubCommand::DeleteEmptyProfiles(cmd) => cmd.execute(&config, db_client).await.unwrap(),
SubCommand::PruneRemoteEmojis(cmd) => cmd.execute(&config, db_client).await.unwrap(),
SubCommand::ListUnreachableActors(cmd) => cmd.execute(&config, db_client).await.unwrap(),
SubCommand::DeleteExtraneousPosts(cmd) => {
cmd.execute(&config, db_client).await.unwrap()
}
SubCommand::DeleteUnusedAttachments(cmd) => {
cmd.execute(&config, db_client).await.unwrap()
}
SubCommand::DeleteOrphanedFiles(cmd) => {
cmd.execute(&config, db_client).await.unwrap()
}
SubCommand::DeleteEmptyProfiles(cmd) => {
cmd.execute(&config, db_client).await.unwrap()
}
SubCommand::PruneRemoteEmojis(cmd) => {
cmd.execute(&config, db_client).await.unwrap()
}
SubCommand::ListUnreachableActors(cmd) => {
cmd.execute(&config, db_client).await.unwrap()
}
SubCommand::ImportEmoji(cmd) => cmd.execute(&config, db_client).await.unwrap(),
SubCommand::UpdateCurrentBlock(cmd) => cmd.execute(&config, db_client).await.unwrap(),
SubCommand::ResetSubscriptions(cmd) => cmd.execute(&config, db_client).await.unwrap(),
SubCommand::UpdateCurrentBlock(cmd) => {
cmd.execute(&config, db_client).await.unwrap()
}
SubCommand::ResetSubscriptions(cmd) => {
cmd.execute(&config, db_client).await.unwrap()
}
SubCommand::CreateMoneroWallet(cmd) => cmd.execute(&config).await.unwrap(),
SubCommand::CheckExpiredInvoice(cmd) => cmd.execute(&config, db_client).await.unwrap(),
SubCommand::CheckExpiredInvoice(cmd) => {
cmd.execute(&config, db_client).await.unwrap()
}
_ => unreachable!(),
};
},
}
};
}

View file

@ -1,6 +1,6 @@
use std::path::PathBuf;
use log::{Level as LogLevel};
use log::Level as LogLevel;
use rsa::RsaPrivateKey;
use serde::Deserialize;
use url::Url;
@ -14,9 +14,13 @@ use super::registration::RegistrationConfig;
use super::retention::RetentionConfig;
use super::REEF_VERSION;
fn default_log_level() -> LogLevel { LogLevel::Info }
fn default_log_level() -> LogLevel {
LogLevel::Info
}
fn default_login_message() -> String { "What?!".to_string() }
fn default_login_message() -> String {
"What?!".to_string()
}
#[derive(Clone, Deserialize)]
pub struct Config {
@ -95,9 +99,8 @@ impl Config {
onion_proxy_url: self.federation.onion_proxy_url.clone(),
i2p_proxy_url: self.federation.i2p_proxy_url.clone(),
// Private instance doesn't send activities and sign requests
is_private:
!self.federation.enabled ||
matches!(self.environment, Environment::Development),
is_private: !self.federation.enabled
|| matches!(self.environment, Environment::Development),
fetcher_timeout: self.federation.fetcher_timeout,
deliverer_timeout: self.federation.deliverer_timeout,
}
@ -139,8 +142,8 @@ impl Instance {
pub fn agent(&self) -> String {
format!(
"Reef {version}; {instance_url}",
version= REEF_VERSION,
instance_url=self.url(),
version = REEF_VERSION,
instance_url = self.url(),
)
}
}
@ -164,8 +167,8 @@ impl Instance {
#[cfg(test)]
mod tests {
use mitra_utils::crypto_rsa::generate_weak_rsa_key;
use super::*;
use mitra_utils::crypto_rsa::generate_weak_rsa_key;
#[test]
fn test_instance_url_https_dns() {

View file

@ -10,9 +10,13 @@ pub enum Environment {
impl Default for Environment {
#[cfg(feature = "production")]
fn default() -> Self { Self::Production }
fn default() -> Self {
Self::Production
}
#[cfg(not(feature = "production"))]
fn default() -> Self { Self::Development }
fn default() -> Self {
Self::Development
}
}
impl FromStr for Environment {

View file

@ -1,9 +1,15 @@
use serde::Deserialize;
fn default_federation_enabled() -> bool { true }
fn default_federation_enabled() -> bool {
true
}
const fn default_fetcher_timeout() -> u64 { 300 }
const fn default_deliverer_timeout() -> u64 { 30 }
const fn default_fetcher_timeout() -> u64 {
300
}
const fn default_deliverer_timeout() -> u64 {
30
}
#[derive(Clone, Deserialize)]
pub struct FederationConfig {

View file

@ -1,19 +1,17 @@
use regex::Regex;
use serde::{
Deserialize,
Deserializer,
de::{Error as DeserializerError},
};
use super::ConfigError;
use regex::Regex;
use serde::{de::Error as DeserializerError, Deserialize, Deserializer};
const FILE_SIZE_RE: &str = r#"^(?i)(?P<size>\d+)(?P<unit>[kmg]?)b?$"#;
fn parse_file_size(value: &str) -> Result<usize, ConfigError> {
let file_size_re = Regex::new(FILE_SIZE_RE)
.expect("regexp should be valid");
let caps = file_size_re.captures(value)
let file_size_re = Regex::new(FILE_SIZE_RE).expect("regexp should be valid");
let caps = file_size_re
.captures(value)
.ok_or(ConfigError("invalid file size"))?;
let size: usize = caps["size"].to_string().parse()
let size: usize = caps["size"]
.to_string()
.parse()
.map_err(|_| ConfigError("invalid file size"))?;
let unit = caps["unit"].to_string().to_lowercase();
let multiplier = match unit.as_str() {
@ -26,31 +24,33 @@ fn parse_file_size(value: &str) -> Result<usize, ConfigError> {
Ok(size * multiplier)
}
fn deserialize_file_size<'de, D>(
deserializer: D,
) -> Result<usize, D::Error>
where D: Deserializer<'de>
fn deserialize_file_size<'de, D>(deserializer: D) -> Result<usize, D::Error>
where
D: Deserializer<'de>,
{
let file_size_str = String::deserialize(deserializer)?;
let file_size = parse_file_size(&file_size_str)
.map_err(DeserializerError::custom)?;
let file_size = parse_file_size(&file_size_str).map_err(DeserializerError::custom)?;
Ok(file_size)
}
const fn default_file_size_limit() -> usize { 20_000_000 } // 20 MB
const fn default_emoji_size_limit() -> usize { 500_000 } // 500 kB
const fn default_file_size_limit() -> usize {
20_000_000
} // 20 MB
const fn default_emoji_size_limit() -> usize {
500_000
} // 500 kB
#[derive(Clone, Deserialize)]
pub struct MediaLimits {
#[serde(
default = "default_file_size_limit",
deserialize_with = "deserialize_file_size",
deserialize_with = "deserialize_file_size"
)]
pub file_size_limit: usize,
#[serde(
default = "default_emoji_size_limit",
deserialize_with = "deserialize_file_size",
deserialize_with = "deserialize_file_size"
)]
pub emoji_size_limit: usize,
}
@ -64,7 +64,9 @@ impl Default for MediaLimits {
}
}
const fn default_post_character_limit() -> usize { 2000 }
const fn default_post_character_limit() -> usize {
2000
}
#[derive(Clone, Deserialize)]
pub struct PostLimits {

View file

@ -5,11 +5,7 @@ use std::str::FromStr;
use rsa::RsaPrivateKey;
use mitra_utils::{
crypto_rsa::{
deserialize_private_key,
generate_rsa_key,
serialize_private_key,
},
crypto_rsa::{deserialize_private_key, generate_rsa_key, serialize_private_key},
files::{set_file_permissions, write_file},
};
@ -30,9 +26,9 @@ const DEFAULT_CONFIG_PATH: &str = "config.yaml";
fn parse_env() -> EnvConfig {
dotenv::from_filename(".env.local").ok();
dotenv::dotenv().ok();
let config_path = std::env::var("CONFIG_PATH")
.unwrap_or(DEFAULT_CONFIG_PATH.to_string());
let environment = std::env::var("ENVIRONMENT").ok()
let config_path = std::env::var("CONFIG_PATH").unwrap_or(DEFAULT_CONFIG_PATH.to_string());
let environment = std::env::var("ENVIRONMENT")
.ok()
.map(|val| Environment::from_str(&val).expect("invalid environment type"));
EnvConfig {
config_path,
@ -45,8 +41,7 @@ extern "C" {
}
fn check_directory_owner(path: &Path) -> () {
let metadata = std::fs::metadata(path)
.expect("can't read file metadata");
let metadata = std::fs::metadata(path).expect("can't read file metadata");
let owner_uid = metadata.uid();
let current_uid = unsafe { geteuid() };
if owner_uid != current_uid {
@ -63,16 +58,15 @@ fn check_directory_owner(path: &Path) -> () {
fn read_instance_rsa_key(storage_dir: &Path) -> RsaPrivateKey {
let private_key_path = storage_dir.join("instance_rsa_key");
if private_key_path.exists() {
let private_key_str = std::fs::read_to_string(&private_key_path)
.expect("failed to read instance RSA key");
let private_key = deserialize_private_key(&private_key_str)
.expect("failed to read instance RSA key");
let private_key_str =
std::fs::read_to_string(&private_key_path).expect("failed to read instance RSA key");
let private_key =
deserialize_private_key(&private_key_str).expect("failed to read instance RSA key");
private_key
} else {
let private_key = generate_rsa_key()
.expect("failed to generate RSA key");
let private_key_str = serialize_private_key(&private_key)
.expect("failed to serialize RSA key");
let private_key = generate_rsa_key().expect("failed to generate RSA key");
let private_key_str =
serialize_private_key(&private_key).expect("failed to serialize RSA key");
write_file(private_key_str.as_bytes(), &private_key_path)
.expect("failed to write instance RSA key");
set_file_permissions(&private_key_path, 0o600)
@ -83,10 +77,9 @@ fn read_instance_rsa_key(storage_dir: &Path) -> RsaPrivateKey {
pub fn parse_config() -> (Config, Vec<&'static str>) {
let env = parse_env();
let config_yaml = std::fs::read_to_string(&env.config_path)
.expect("failed to load config file");
let mut config = serde_yaml::from_str::<Config>(&config_yaml)
.expect("invalid yaml data");
let config_yaml =
std::fs::read_to_string(&env.config_path).expect("failed to load config file");
let mut config = serde_yaml::from_str::<Config>(&config_yaml).expect("invalid yaml data");
let mut warnings = vec![];
// Set parameters from environment
@ -109,7 +102,8 @@ pub fn parse_config() -> (Config, Vec<&'static str>) {
// Migrations
if let Some(registrations_open) = config.registrations_open {
// Change type if 'registrations_open' parameter is used
warnings.push("'registrations_open' setting is deprecated, use 'registration.type' instead");
warnings
.push("'registrations_open' setting is deprecated, use 'registration.type' instead");
if registrations_open {
config.registration.registration_type = RegistrationType::Open;
} else {

View file

@ -1,8 +1,4 @@
use serde::{
Deserialize,
Deserializer,
de::Error as DeserializerError,
};
use serde::{de::Error as DeserializerError, Deserialize, Deserializer};
#[derive(Clone, PartialEq)]
pub enum RegistrationType {
@ -11,12 +7,15 @@ pub enum RegistrationType {
}
impl Default for RegistrationType {
fn default() -> Self { Self::Invite }
fn default() -> Self {
Self::Invite
}
}
impl<'de> Deserialize<'de> for RegistrationType {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where D: Deserializer<'de>
where
D: Deserializer<'de>,
{
let registration_type_str = String::deserialize(deserializer)?;
let registration_type = match registration_type_str.as_str() {
@ -35,12 +34,15 @@ pub enum DefaultRole {
}
impl Default for DefaultRole {
fn default() -> Self { Self::NormalUser }
fn default() -> Self {
Self::NormalUser
}
}
impl<'de> Deserialize<'de> for DefaultRole {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where D: Deserializer<'de>
where
D: Deserializer<'de>,
{
let role_str = String::deserialize(deserializer)?;
let role = match role_str.as_str() {

View file

@ -3,11 +3,7 @@ use uuid::Uuid;
use mitra_utils::id::generate_ulid;
use crate::cleanup::{
find_orphaned_files,
find_orphaned_ipfs_objects,
DeletionQueue,
};
use crate::cleanup::{find_orphaned_files, find_orphaned_ipfs_objects, DeletionQueue};
use crate::database::{DatabaseClient, DatabaseError};
use super::types::DbMediaAttachment;
@ -20,10 +16,10 @@ pub async fn create_attachment(
media_type: Option<String>,
) -> Result<DbMediaAttachment, DatabaseError> {
let attachment_id = generate_ulid();
let file_size: i32 = file_size.try_into()
.expect("value should be within bounds");
let inserted_row = db_client.query_one(
"
let file_size: i32 = file_size.try_into().expect("value should be within bounds");
let inserted_row = db_client
.query_one(
"
INSERT INTO media_attachment (
id,
owner_id,
@ -34,14 +30,15 @@ pub async fn create_attachment(
VALUES ($1, $2, $3, $4, $5)
RETURNING media_attachment
",
&[
&attachment_id,
&owner_id,
&file_name,
&file_size,
&media_type,
],
).await?;
&[
&attachment_id,
&owner_id,
&file_name,
&file_size,
&media_type,
],
)
.await?;
let db_attachment: DbMediaAttachment = inserted_row.try_get("media_attachment")?;
Ok(db_attachment)
}
@ -51,15 +48,17 @@ pub async fn set_attachment_ipfs_cid(
attachment_id: &Uuid,
ipfs_cid: &str,
) -> Result<DbMediaAttachment, DatabaseError> {
let maybe_row = db_client.query_opt(
"
let maybe_row = db_client
.query_opt(
"
UPDATE media_attachment
SET ipfs_cid = $1
WHERE id = $2 AND ipfs_cid IS NULL
RETURNING media_attachment
",
&[&ipfs_cid, &attachment_id],
).await?;
&[&ipfs_cid, &attachment_id],
)
.await?;
let row = maybe_row.ok_or(DatabaseError::NotFound("attachment"))?;
let db_attachment = row.try_get("media_attachment")?;
Ok(db_attachment)
@ -69,14 +68,16 @@ pub async fn delete_unused_attachments(
db_client: &impl DatabaseClient,
created_before: &DateTime<Utc>,
) -> Result<DeletionQueue, DatabaseError> {
let rows = db_client.query(
"
let rows = db_client
.query(
"
DELETE FROM media_attachment
WHERE post_id IS NULL AND created_at < $1
RETURNING file_name, ipfs_cid
",
&[&created_before],
).await?;
&[&created_before],
)
.await?;
let mut files = vec![];
let mut ipfs_objects = vec![];
for row in rows {
@ -85,7 +86,7 @@ pub async fn delete_unused_attachments(
if let Some(ipfs_cid) = row.try_get("ipfs_cid")? {
ipfs_objects.push(ipfs_cid);
};
};
}
let orphaned_files = find_orphaned_files(db_client, files).await?;
let orphaned_ipfs_objects = find_orphaned_ipfs_objects(db_client, ipfs_objects).await?;
Ok(DeletionQueue {
@ -96,13 +97,10 @@ pub async fn delete_unused_attachments(
#[cfg(test)]
mod tests {
use serial_test::serial;
use crate::database::test_utils::create_test_database;
use crate::profiles::{
queries::create_profile,
types::ProfileCreateData,
};
use super::*;
use crate::database::test_utils::create_test_database;
use crate::profiles::{queries::create_profile, types::ProfileCreateData};
use serial_test::serial;
#[tokio::test]
#[serial]
@ -122,7 +120,9 @@ mod tests {
file_name.to_string(),
file_size,
Some(media_type.to_string()),
).await.unwrap();
)
.await
.unwrap();
assert_eq!(attachment.owner_id, profile.id);
assert_eq!(attachment.file_name, file_name);
assert_eq!(attachment.file_size.unwrap(), file_size as i32);

View file

@ -35,7 +35,7 @@ impl AttachmentType {
} else {
Self::Unknown
}
},
}
None => Self::Unknown,
}
}

View file

@ -2,8 +2,8 @@ use chrono::{DateTime, Utc};
use serde_json::Value;
use uuid::Uuid;
use crate::database::{DatabaseClient, DatabaseError};
use super::types::{DbBackgroundJob, JobStatus, JobType};
use crate::database::{DatabaseClient, DatabaseError};
pub async fn enqueue_job(
db_client: &impl DatabaseClient,
@ -12,8 +12,9 @@ pub async fn enqueue_job(
scheduled_for: &DateTime<Utc>,
) -> Result<(), DatabaseError> {
let job_id = Uuid::new_v4();
db_client.execute(
"
db_client
.execute(
"
INSERT INTO background_job (
id,
job_type,
@ -22,8 +23,9 @@ pub async fn enqueue_job(
)
VALUES ($1, $2, $3, $4)
",
&[&job_id, &job_type, &job_data, &scheduled_for],
).await?;
&[&job_id, &job_type, &job_data, &scheduled_for],
)
.await?;
Ok(())
}
@ -35,8 +37,9 @@ pub async fn get_job_batch(
) -> Result<Vec<DbBackgroundJob>, DatabaseError> {
// https://github.com/sfackler/rust-postgres/issues/60
let job_timeout_pg = format!("{}S", job_timeout); // interval
let rows = db_client.query(
"
let rows = db_client
.query(
"
UPDATE background_job
SET
job_status = $1,
@ -62,15 +65,17 @@ pub async fn get_job_batch(
)
RETURNING background_job
",
&[
&JobStatus::Running,
&job_type,
&JobStatus::Queued,
&i64::from(batch_size),
&job_timeout_pg,
],
).await?;
let jobs = rows.iter()
&[
&JobStatus::Running,
&job_type,
&JobStatus::Queued,
&i64::from(batch_size),
&job_timeout_pg,
],
)
.await?;
let jobs = rows
.iter()
.map(|row| row.try_get("background_job"))
.collect::<Result<_, _>>()?;
Ok(jobs)
@ -80,13 +85,15 @@ pub async fn delete_job_from_queue(
db_client: &impl DatabaseClient,
job_id: &Uuid,
) -> Result<(), DatabaseError> {
let deleted_count = db_client.execute(
"
let deleted_count = db_client
.execute(
"
DELETE FROM background_job
WHERE id = $1
",
&[&job_id],
).await?;
&[&job_id],
)
.await?;
if deleted_count == 0 {
return Err(DatabaseError::NotFound("background job"));
};
@ -95,10 +102,10 @@ pub async fn delete_job_from_queue(
#[cfg(test)]
mod tests {
use super::*;
use crate::database::test_utils::create_test_database;
use serde_json::json;
use serial_test::serial;
use crate::database::test_utils::create_test_database;
use super::*;
#[tokio::test]
#[serial]
@ -111,7 +118,9 @@ mod tests {
"failure_count": 0,
});
let scheduled_for = Utc::now();
enqueue_job(db_client, &job_type, &job_data, &scheduled_for).await.unwrap();
enqueue_job(db_client, &job_type, &job_data, &scheduled_for)
.await
.unwrap();
let batch_1 = get_job_batch(db_client, &job_type, 10, 3600).await.unwrap();
assert_eq!(batch_1.len(), 1);

View file

@ -1,6 +1,6 @@
use chrono::{DateTime, Utc};
use serde_json::Value;
use postgres_types::FromSql;
use serde_json::Value;
use uuid::Uuid;
use crate::database::{

View file

@ -9,8 +9,9 @@ pub async fn find_orphaned_files(
db_client: &impl DatabaseClient,
files: Vec<String>,
) -> Result<Vec<String>, DatabaseError> {
let rows = db_client.query(
"
let rows = db_client
.query(
"
SELECT DISTINCT fname
FROM unnest($1::text[]) AS fname
WHERE
@ -27,9 +28,11 @@ pub async fn find_orphaned_files(
WHERE image ->> 'file_name' = fname
)
",
&[&files],
).await?;
let orphaned_files = rows.iter()
&[&files],
)
.await?;
let orphaned_files = rows
.iter()
.map(|row| row.try_get("fname"))
.collect::<Result<_, _>>()?;
Ok(orphaned_files)
@ -39,8 +42,9 @@ pub async fn find_orphaned_ipfs_objects(
db_client: &impl DatabaseClient,
ipfs_objects: Vec<String>,
) -> Result<Vec<String>, DatabaseError> {
let rows = db_client.query(
"
let rows = db_client
.query(
"
SELECT DISTINCT cid
FROM unnest($1::text[]) AS cid
WHERE
@ -51,9 +55,11 @@ pub async fn find_orphaned_ipfs_objects(
SELECT 1 FROM post WHERE ipfs_cid = cid
)
",
&[&ipfs_objects],
).await?;
let orphaned_ipfs_objects = rows.iter()
&[&ipfs_objects],
)
.await?;
let orphaned_ipfs_objects = rows
.iter()
.map(|row| row.try_get("cid"))
.collect::<Result<_, _>>()?;
Ok(orphaned_ipfs_objects)

View file

@ -12,7 +12,7 @@ macro_rules! int_enum_from_sql {
postgres_types::accepts!(INT2);
}
}
};
}
macro_rules! int_enum_to_sql {
@ -31,7 +31,7 @@ macro_rules! int_enum_to_sql {
postgres_types::accepts!(INT2);
postgres_types::to_sql_checked!();
}
}
};
}
pub(crate) use {int_enum_from_sql, int_enum_to_sql};

View file

@ -14,7 +14,7 @@ macro_rules! json_from_sql {
postgres_types::accepts!(JSON, JSONB);
}
}
};
}
/// Implements ToSql trait for any serializable type
@ -33,7 +33,7 @@ macro_rules! json_to_sql {
postgres_types::accepts!(JSON, JSONB);
postgres_types::to_sql_checked!();
}
}
};
}
pub(crate) use {json_from_sql, json_to_sql};

View file

@ -8,7 +8,8 @@ mod embedded {
pub async fn apply_migrations(db_client: &mut Client) {
let migration_report = embedded::migrations::runner()
.run_async(db_client)
.await.unwrap();
.await
.unwrap();
for migration in migration_report.applied_migrations() {
log::info!(

View file

@ -1,4 +1,4 @@
use tokio_postgres::config::{Config as DatabaseConfig};
use tokio_postgres::config::Config as DatabaseConfig;
use tokio_postgres::error::{Error as PgError, SqlState};
pub mod int_enum;
@ -10,7 +10,7 @@ pub mod query_macro;
pub mod test_utils;
pub type DbPool = deadpool_postgres::Pool;
pub use tokio_postgres::{GenericClient as DatabaseClient};
pub use tokio_postgres::GenericClient as DatabaseClient;
#[derive(thiserror::Error, Debug)]
#[error("database type error")]
@ -37,11 +37,8 @@ pub enum DatabaseError {
AlreadyExists(&'static str), // object type
}
pub async fn create_database_client(db_config: &DatabaseConfig)
-> tokio_postgres::Client
{
let (client, connection) = db_config.connect(tokio_postgres::NoTls)
.await.unwrap();
pub async fn create_database_client(db_config: &DatabaseConfig) -> tokio_postgres::Client {
let (client, connection) = db_config.connect(tokio_postgres::NoTls).await.unwrap();
tokio::spawn(async move {
if let Err(err) = connection.await {
log::error!("connection error: {}", err);
@ -55,21 +52,22 @@ pub fn create_pool(database_url: &str, pool_size: usize) -> DbPool {
database_url.parse().expect("invalid database URL"),
tokio_postgres::NoTls,
);
DbPool::builder(manager).max_size(pool_size).build().unwrap()
DbPool::builder(manager)
.max_size(pool_size)
.build()
.unwrap()
}
pub async fn get_database_client(db_pool: &DbPool)
-> Result<deadpool_postgres::Client, DatabaseError>
{
pub async fn get_database_client(
db_pool: &DbPool,
) -> Result<deadpool_postgres::Client, DatabaseError> {
// Returns wrapped client
// https://github.com/bikeshedder/deadpool/issues/56
let client = db_pool.get().await?;
Ok(client)
}
pub fn catch_unique_violation(
object_type: &'static str,
) -> impl Fn(PgError) -> DatabaseError {
pub fn catch_unique_violation(object_type: &'static str) -> impl Fn(PgError) -> DatabaseError {
move |err| {
if let Some(code) = err.code() {
if code == &SqlState::UNIQUE_VIOLATION {

View file

@ -1,31 +1,30 @@
use tokio_postgres::Client;
use tokio_postgres::config::Config;
use super::create_database_client;
use super::migrate::apply_migrations;
use tokio_postgres::config::Config;
use tokio_postgres::Client;
const DEFAULT_CONNECTION_URL: &str = "postgres://mitra:mitra@127.0.0.1:55432/mitra-test";
pub async fn create_test_database() -> Client {
let connection_url = std::env::var("TEST_DATABASE_URL")
.unwrap_or(DEFAULT_CONNECTION_URL.to_string());
let mut db_config: Config = connection_url.parse()
let connection_url =
std::env::var("TEST_DATABASE_URL").unwrap_or(DEFAULT_CONNECTION_URL.to_string());
let mut db_config: Config = connection_url
.parse()
.expect("invalid database connection URL");
let db_name = db_config.get_dbname()
let db_name = db_config
.get_dbname()
.expect("database name not specified")
.to_string();
// Create connection without database name
db_config.dbname("");
let db_client = create_database_client(&db_config).await;
let drop_db_statement = format!(
"DROP DATABASE IF EXISTS {db_name:?}",
db_name=db_name,
);
let drop_db_statement = format!("DROP DATABASE IF EXISTS {db_name:?}", db_name = db_name,);
db_client.execute(&drop_db_statement, &[]).await.unwrap();
let create_db_statement = format!(
"CREATE DATABASE {db_name:?} WITH OWNER={owner:?};",
db_name=db_name,
owner=db_config.get_user().unwrap(),
db_name = db_name,
owner = db_config.get_user().unwrap(),
);
db_client.execute(&create_db_statement, &[]).await.unwrap();

View file

@ -1,10 +1,7 @@
use crate::database::{DatabaseClient, DatabaseError};
use super::queries::{get_emoji_by_name_and_hostname, get_local_emoji_by_name};
use super::types::DbEmoji;
use super::queries::{
get_local_emoji_by_name,
get_emoji_by_name_and_hostname,
};
pub async fn get_emoji_by_name(
db_client: &impl DatabaseClient,

View file

@ -4,11 +4,7 @@ use uuid::Uuid;
use mitra_utils::id::generate_ulid;
use crate::cleanup::{find_orphaned_files, DeletionQueue};
use crate::database::{
catch_unique_violation,
DatabaseClient,
DatabaseError,
};
use crate::database::{catch_unique_violation, DatabaseClient, DatabaseError};
use crate::instances::queries::create_instance;
use crate::profiles::queries::update_emoji_caches;
@ -26,8 +22,9 @@ pub async fn create_emoji(
if let Some(hostname) = hostname {
create_instance(db_client, hostname).await?;
};
let row = db_client.query_one(
"
let row = db_client
.query_one(
"
INSERT INTO emoji (
id,
emoji_name,
@ -39,15 +36,17 @@ pub async fn create_emoji(
VALUES ($1, $2, $3, $4, $5, $6)
RETURNING emoji
",
&[
&emoji_id,
&emoji_name,
&hostname,
&image,
&object_id,
&updated_at,
],
).await.map_err(catch_unique_violation("emoji"))?;
&[
&emoji_id,
&emoji_name,
&hostname,
&image,
&object_id,
&updated_at,
],
)
.await
.map_err(catch_unique_violation("emoji"))?;
let emoji = row.try_get("emoji")?;
Ok(emoji)
}
@ -58,8 +57,9 @@ pub async fn update_emoji(
image: EmojiImage,
updated_at: &DateTime<Utc>,
) -> Result<DbEmoji, DatabaseError> {
let row = db_client.query_one(
"
let row = db_client
.query_one(
"
UPDATE emoji
SET
image = $1,
@ -67,12 +67,9 @@ pub async fn update_emoji(
WHERE id = $3
RETURNING emoji
",
&[
&image,
&updated_at,
&emoji_id,
],
).await?;
&[&image, &updated_at, &emoji_id],
)
.await?;
let emoji: DbEmoji = row.try_get("emoji")?;
update_emoji_caches(db_client, &emoji.id).await?;
Ok(emoji)
@ -82,14 +79,16 @@ pub async fn get_local_emoji_by_name(
db_client: &impl DatabaseClient,
emoji_name: &str,
) -> Result<DbEmoji, DatabaseError> {
let maybe_row = db_client.query_opt(
"
let maybe_row = db_client
.query_opt(
"
SELECT emoji
FROM emoji
WHERE hostname IS NULL AND emoji_name = $1
",
&[&emoji_name],
).await?;
&[&emoji_name],
)
.await?;
let row = maybe_row.ok_or(DatabaseError::NotFound("emoji"))?;
let emoji = row.try_get("emoji")?;
Ok(emoji)
@ -99,15 +98,18 @@ pub async fn get_local_emojis_by_names(
db_client: &impl DatabaseClient,
names: &[String],
) -> Result<Vec<DbEmoji>, DatabaseError> {
let rows = db_client.query(
"
let rows = db_client
.query(
"
SELECT emoji
FROM emoji
WHERE hostname IS NULL AND emoji_name = ANY($1)
",
&[&names],
).await?;
let emojis = rows.iter()
&[&names],
)
.await?;
let emojis = rows
.iter()
.map(|row| row.try_get("emoji"))
.collect::<Result<_, _>>()?;
Ok(emojis)
@ -116,15 +118,18 @@ pub async fn get_local_emojis_by_names(
pub async fn get_local_emojis(
db_client: &impl DatabaseClient,
) -> Result<Vec<DbEmoji>, DatabaseError> {
let rows = db_client.query(
"
let rows = db_client
.query(
"
SELECT emoji
FROM emoji
WHERE hostname IS NULL
",
&[],
).await?;
let emojis = rows.iter()
&[],
)
.await?;
let emojis = rows
.iter()
.map(|row| row.try_get("emoji"))
.collect::<Result<_, _>>()?;
Ok(emojis)
@ -135,13 +140,15 @@ pub async fn get_emoji_by_name_and_hostname(
emoji_name: &str,
hostname: &str,
) -> Result<DbEmoji, DatabaseError> {
let maybe_row = db_client.query_opt(
"
let maybe_row = db_client
.query_opt(
"
SELECT emoji
FROM emoji WHERE emoji_name = $1 AND hostname = $2
",
&[&emoji_name, &hostname],
).await?;
&[&emoji_name, &hostname],
)
.await?;
let row = maybe_row.ok_or(DatabaseError::NotFound("emoji"))?;
let emoji = row.try_get("emoji")?;
Ok(emoji)
@ -151,13 +158,15 @@ pub async fn get_emoji_by_remote_object_id(
db_client: &impl DatabaseClient,
object_id: &str,
) -> Result<DbEmoji, DatabaseError> {
let maybe_row = db_client.query_opt(
"
let maybe_row = db_client
.query_opt(
"
SELECT emoji
FROM emoji WHERE object_id = $1
",
&[&object_id],
).await?;
&[&object_id],
)
.await?;
let row = maybe_row.ok_or(DatabaseError::NotFound("emoji"))?;
let emoji = row.try_get("emoji")?;
Ok(emoji)
@ -167,20 +176,19 @@ pub async fn delete_emoji(
db_client: &impl DatabaseClient,
emoji_id: &Uuid,
) -> Result<DeletionQueue, DatabaseError> {
let maybe_row = db_client.query_opt(
"
let maybe_row = db_client
.query_opt(
"
DELETE FROM emoji WHERE id = $1
RETURNING emoji
",
&[&emoji_id],
).await?;
&[&emoji_id],
)
.await?;
let row = maybe_row.ok_or(DatabaseError::NotFound("emoji"))?;
let emoji: DbEmoji = row.try_get("emoji")?;
update_emoji_caches(db_client, &emoji.id).await?;
let orphaned_files = find_orphaned_files(
db_client,
vec![emoji.image.file_name],
).await?;
let orphaned_files = find_orphaned_files(db_client, vec![emoji.image.file_name]).await?;
Ok(DeletionQueue {
files: orphaned_files,
ipfs_objects: vec![],
@ -190,8 +198,9 @@ pub async fn delete_emoji(
pub async fn find_unused_remote_emojis(
db_client: &impl DatabaseClient,
) -> Result<Vec<Uuid>, DatabaseError> {
let rows = db_client.query(
"
let rows = db_client
.query(
"
SELECT emoji.id
FROM emoji
WHERE
@ -207,9 +216,11 @@ pub async fn find_unused_remote_emojis(
WHERE profile_emoji.emoji_id = emoji.id
)
",
&[],
).await?;
let ids: Vec<Uuid> = rows.iter()
&[],
)
.await?;
let ids: Vec<Uuid> = rows
.iter()
.map(|row| row.try_get("id"))
.collect::<Result<_, _>>()?;
Ok(ids)
@ -217,9 +228,9 @@ pub async fn find_unused_remote_emojis(
#[cfg(test)]
mod tests {
use serial_test::serial;
use crate::database::test_utils::create_test_database;
use super::*;
use crate::database::test_utils::create_test_database;
use serial_test::serial;
#[tokio::test]
#[serial]
@ -241,11 +252,12 @@ mod tests {
image,
Some(object_id),
&updated_at,
).await.unwrap();
let emoji = get_emoji_by_remote_object_id(
db_client,
object_id,
).await.unwrap();
)
.await
.unwrap();
let emoji = get_emoji_by_remote_object_id(db_client, object_id)
.await
.unwrap();
assert_eq!(emoji.id, emoji_id);
assert_eq!(emoji.emoji_name, emoji_name);
assert_eq!(emoji.hostname, Some(hostname.to_string()));
@ -256,20 +268,12 @@ mod tests {
async fn test_update_emoji() {
let db_client = &create_test_database().await;
let image = EmojiImage::default();
let emoji = create_emoji(
db_client,
"test",
None,
image.clone(),
None,
&Utc::now(),
).await.unwrap();
let updated_emoji = update_emoji(
db_client,
&emoji.id,
image,
&Utc::now(),
).await.unwrap();
let emoji = create_emoji(db_client, "test", None, image.clone(), None, &Utc::now())
.await
.unwrap();
let updated_emoji = update_emoji(db_client, &emoji.id, image, &Utc::now())
.await
.unwrap();
assert_ne!(updated_emoji.updated_at, emoji.updated_at);
}
@ -278,14 +282,9 @@ mod tests {
async fn test_delete_emoji() {
let db_client = &create_test_database().await;
let image = EmojiImage::default();
let emoji = create_emoji(
db_client,
"test",
None,
image,
None,
&Utc::now(),
).await.unwrap();
let emoji = create_emoji(db_client, "test", None, image, None, &Utc::now())
.await
.unwrap();
let deletion_queue = delete_emoji(db_client, &emoji.id).await.unwrap();
assert_eq!(deletion_queue.files.len(), 1);
assert_eq!(deletion_queue.ipfs_objects.len(), 0);

View file

@ -6,7 +6,9 @@ use uuid::Uuid;
use crate::database::json_macro::{json_from_sql, json_to_sql};
// Migration
fn default_emoji_file_size() -> usize { 250 * 1000 }
fn default_emoji_file_size() -> usize {
250 * 1000
}
#[derive(Clone, Debug, Deserialize, Serialize)]
#[cfg_attr(feature = "test-utils", derive(Default))]

View file

@ -4,38 +4,38 @@ pub async fn create_instance(
db_client: &impl DatabaseClient,
hostname: &str,
) -> Result<(), DatabaseError> {
db_client.execute(
"
db_client
.execute(
"
INSERT INTO instance VALUES ($1)
ON CONFLICT DO NOTHING
",
&[&hostname],
).await?;
&[&hostname],
)
.await?;
Ok(())
}
pub async fn get_peers(
db_client: &impl DatabaseClient,
) -> Result<Vec<String>, DatabaseError> {
let rows = db_client.query(
"
pub async fn get_peers(db_client: &impl DatabaseClient) -> Result<Vec<String>, DatabaseError> {
let rows = db_client
.query(
"
SELECT instance.hostname FROM instance
",
&[],
).await?;
let peers = rows.iter()
&[],
)
.await?;
let peers = rows
.iter()
.map(|row| row.try_get("hostname"))
.collect::<Result<_, _>>()?;
Ok(peers)
}
pub async fn get_peer_count(
db_client: &impl DatabaseClient,
) -> Result<i64, DatabaseError> {
let row = db_client.query_one(
"SELECT count(instance) FROM instance",
&[],
).await?;
pub async fn get_peer_count(db_client: &impl DatabaseClient) -> Result<i64, DatabaseError> {
let row = db_client
.query_one("SELECT count(instance) FROM instance", &[])
.await?;
let count = row.try_get("count")?;
Ok(count)
}

View file

@ -1,15 +1,8 @@
use uuid::Uuid;
use mitra_utils::{
caip2::ChainId,
id::generate_ulid,
};
use mitra_utils::{caip2::ChainId, id::generate_ulid};
use crate::database::{
catch_unique_violation,
DatabaseClient,
DatabaseError,
};
use crate::database::{catch_unique_violation, DatabaseClient, DatabaseError};
use super::types::{DbChainId, DbInvoice, InvoiceStatus};
@ -22,8 +15,9 @@ pub async fn create_invoice(
amount: i64,
) -> Result<DbInvoice, DatabaseError> {
let invoice_id = generate_ulid();
let row = db_client.query_one(
"
let row = db_client
.query_one(
"
INSERT INTO invoice (
id,
sender_id,
@ -35,15 +29,17 @@ pub async fn create_invoice(
VALUES ($1, $2, $3, $4, $5, $6)
RETURNING invoice
",
&[
&invoice_id,
&sender_id,
&recipient_id,
&DbChainId::new(chain_id),
&payment_address,
&amount,
],
).await.map_err(catch_unique_violation("invoice"))?;
&[
&invoice_id,
&sender_id,
&recipient_id,
&DbChainId::new(chain_id),
&payment_address,
&amount,
],
)
.await
.map_err(catch_unique_violation("invoice"))?;
let invoice = row.try_get("invoice")?;
Ok(invoice)
}
@ -52,13 +48,15 @@ pub async fn get_invoice_by_id(
db_client: &impl DatabaseClient,
invoice_id: &Uuid,
) -> Result<DbInvoice, DatabaseError> {
let maybe_row = db_client.query_opt(
"
let maybe_row = db_client
.query_opt(
"
SELECT invoice
FROM invoice WHERE id = $1
",
&[&invoice_id],
).await?;
&[&invoice_id],
)
.await?;
let row = maybe_row.ok_or(DatabaseError::NotFound("invoice"))?;
let invoice = row.try_get("invoice")?;
Ok(invoice)
@ -69,13 +67,15 @@ pub async fn get_invoice_by_address(
chain_id: &ChainId,
payment_address: &str,
) -> Result<DbInvoice, DatabaseError> {
let maybe_row = db_client.query_opt(
"
let maybe_row = db_client
.query_opt(
"
SELECT invoice
FROM invoice WHERE chain_id = $1 AND payment_address = $2
",
&[&DbChainId::new(chain_id), &payment_address],
).await?;
&[&DbChainId::new(chain_id), &payment_address],
)
.await?;
let row = maybe_row.ok_or(DatabaseError::NotFound("invoice"))?;
let invoice = row.try_get("invoice")?;
Ok(invoice)
@ -86,14 +86,17 @@ pub async fn get_invoices_by_status(
chain_id: &ChainId,
status: InvoiceStatus,
) -> Result<Vec<DbInvoice>, DatabaseError> {
let rows = db_client.query(
"
let rows = db_client
.query(
"
SELECT invoice
FROM invoice WHERE chain_id = $1 AND invoice_status = $2
",
&[&DbChainId::new(chain_id), &status],
).await?;
let invoices = rows.iter()
&[&DbChainId::new(chain_id), &status],
)
.await?;
let invoices = rows
.iter()
.map(|row| row.try_get("invoice"))
.collect::<Result<_, _>>()?;
Ok(invoices)
@ -104,13 +107,15 @@ pub async fn set_invoice_status(
invoice_id: &Uuid,
status: InvoiceStatus,
) -> Result<(), DatabaseError> {
let updated_count = db_client.execute(
"
let updated_count = db_client
.execute(
"
UPDATE invoice SET invoice_status = $2
WHERE id = $1
",
&[&invoice_id, &status],
).await?;
&[&invoice_id, &status],
)
.await?;
if updated_count == 0 {
return Err(DatabaseError::NotFound("invoice"));
};
@ -119,17 +124,11 @@ pub async fn set_invoice_status(
#[cfg(test)]
mod tests {
use serial_test::serial;
use crate::database::test_utils::create_test_database;
use crate::profiles::{
queries::create_profile,
types::ProfileCreateData,
};
use crate::users::{
queries::create_user,
types::UserCreateData,
};
use super::*;
use crate::database::test_utils::create_test_database;
use crate::profiles::{queries::create_profile, types::ProfileCreateData};
use crate::users::{queries::create_user, types::UserCreateData};
use serial_test::serial;
#[tokio::test]
#[serial]
@ -159,7 +158,9 @@ mod tests {
&chain_id,
payment_address,
amount,
).await.unwrap();
)
.await
.unwrap();
assert_eq!(invoice.sender_id, sender.id);
assert_eq!(invoice.recipient_id, recipient.id);
assert_eq!(invoice.chain_id.into_inner(), chain_id);

View file

@ -1,14 +1,6 @@
use chrono::{DateTime, Utc};
use postgres_protocol::types::{text_from_sql, text_to_sql};
use postgres_types::{
accepts,
private::BytesMut,
to_sql_checked,
FromSql,
IsNull,
ToSql,
Type,
};
use postgres_types::{accepts, private::BytesMut, to_sql_checked, FromSql, IsNull, ToSql, Type};
use uuid::Uuid;
use mitra_utils::caip2::ChainId;
@ -44,10 +36,7 @@ impl PartialEq<ChainId> for DbChainId {
}
impl<'a> FromSql<'a> for DbChainId {
fn from_sql(
_: &Type,
raw: &'a [u8],
) -> Result<Self, Box<dyn std::error::Error + Sync + Send>> {
fn from_sql(_: &Type, raw: &'a [u8]) -> Result<Self, Box<dyn std::error::Error + Sync + Send>> {
let value_str = text_from_sql(raw)?;
let value: ChainId = value_str.parse()?;
Ok(Self(value))

View file

@ -1,7 +1,7 @@
pub mod attachments;
pub mod background_jobs;
pub mod database;
pub mod cleanup;
pub mod database;
pub mod emojis;
pub mod instances;
pub mod invoices;

View file

@ -1,7 +1,7 @@
use uuid::Uuid;
use crate::database::{DatabaseClient, DatabaseError};
use super::types::{DbTimelineMarker, Timeline};
use crate::database::{DatabaseClient, DatabaseError};
pub async fn create_or_update_marker(
db_client: &impl DatabaseClient,
@ -9,16 +9,18 @@ pub async fn create_or_update_marker(
timeline: Timeline,
last_read_id: String,
) -> Result<DbTimelineMarker, DatabaseError> {
let row = db_client.query_one(
"
let row = db_client
.query_one(
"
INSERT INTO timeline_marker (user_id, timeline, last_read_id)
VALUES ($1, $2, $3)
ON CONFLICT (user_id, timeline) DO UPDATE
SET last_read_id = $3, updated_at = now()
RETURNING timeline_marker
",
&[&user_id, &timeline, &last_read_id],
).await?;
&[&user_id, &timeline, &last_read_id],
)
.await?;
let marker = row.try_get("timeline_marker")?;
Ok(marker)
}
@ -28,14 +30,16 @@ pub async fn get_marker_opt(
user_id: &Uuid,
timeline: Timeline,
) -> Result<Option<DbTimelineMarker>, DatabaseError> {
let maybe_row = db_client.query_opt(
"
let maybe_row = db_client
.query_opt(
"
SELECT timeline_marker
FROM timeline_marker
WHERE user_id = $1 AND timeline = $2
",
&[&user_id, &timeline],
).await?;
&[&user_id, &timeline],
)
.await?;
let maybe_marker = match maybe_row {
Some(row) => row.try_get("timeline_marker")?,
None => None,

View file

@ -3,13 +3,7 @@ use uuid::Uuid;
use crate::database::{DatabaseClient, DatabaseError};
use crate::posts::{
helpers::{add_related_posts, add_user_actions},
queries::{
RELATED_ATTACHMENTS,
RELATED_EMOJIS,
RELATED_LINKS,
RELATED_MENTIONS,
RELATED_TAGS,
},
queries::{RELATED_ATTACHMENTS, RELATED_EMOJIS, RELATED_LINKS, RELATED_MENTIONS, RELATED_TAGS},
};
use super::types::{EventType, Notification};
@ -21,8 +15,9 @@ async fn create_notification(
post_id: Option<&Uuid>,
event_type: EventType,
) -> Result<(), DatabaseError> {
db_client.execute(
"
db_client
.execute(
"
INSERT INTO notification (
sender_id,
recipient_id,
@ -31,8 +26,9 @@ async fn create_notification(
)
VALUES ($1, $2, $3, $4)
",
&[&sender_id, &recipient_id, &post_id, &event_type],
).await?;
&[&sender_id, &recipient_id, &post_id, &event_type],
)
.await?;
Ok(())
}
@ -41,10 +37,7 @@ pub async fn create_follow_notification(
sender_id: &Uuid,
recipient_id: &Uuid,
) -> Result<(), DatabaseError> {
create_notification(
db_client, sender_id, recipient_id, None,
EventType::Follow,
).await
create_notification(db_client, sender_id, recipient_id, None, EventType::Follow).await
}
pub async fn create_reply_notification(
@ -54,9 +47,13 @@ pub async fn create_reply_notification(
post_id: &Uuid,
) -> Result<(), DatabaseError> {
create_notification(
db_client, sender_id, recipient_id, Some(post_id),
db_client,
sender_id,
recipient_id,
Some(post_id),
EventType::Reply,
).await
)
.await
}
pub async fn create_reaction_notification(
@ -66,9 +63,13 @@ pub async fn create_reaction_notification(
post_id: &Uuid,
) -> Result<(), DatabaseError> {
create_notification(
db_client, sender_id, recipient_id, Some(post_id),
db_client,
sender_id,
recipient_id,
Some(post_id),
EventType::Reaction,
).await
)
.await
}
pub async fn create_mention_notification(
@ -78,9 +79,13 @@ pub async fn create_mention_notification(
post_id: &Uuid,
) -> Result<(), DatabaseError> {
create_notification(
db_client, sender_id, recipient_id, Some(post_id),
db_client,
sender_id,
recipient_id,
Some(post_id),
EventType::Mention,
).await
)
.await
}
pub async fn create_repost_notification(
@ -90,9 +95,13 @@ pub async fn create_repost_notification(
post_id: &Uuid,
) -> Result<(), DatabaseError> {
create_notification(
db_client, sender_id, recipient_id, Some(post_id),
db_client,
sender_id,
recipient_id,
Some(post_id),
EventType::Repost,
).await
)
.await
}
pub async fn create_subscription_notification(
@ -101,9 +110,13 @@ pub async fn create_subscription_notification(
recipient_id: &Uuid,
) -> Result<(), DatabaseError> {
create_notification(
db_client, sender_id, recipient_id, None,
db_client,
sender_id,
recipient_id,
None,
EventType::Subscription,
).await
)
.await
}
pub async fn create_subscription_expiration_notification(
@ -112,9 +125,13 @@ pub async fn create_subscription_expiration_notification(
recipient_id: &Uuid,
) -> Result<(), DatabaseError> {
create_notification(
db_client, sender_id, recipient_id, None,
db_client,
sender_id,
recipient_id,
None,
EventType::SubscriptionExpiration,
).await
)
.await
}
pub async fn create_move_notification(
@ -122,10 +139,7 @@ pub async fn create_move_notification(
sender_id: &Uuid,
recipient_id: &Uuid,
) -> Result<(), DatabaseError> {
create_notification(
db_client, sender_id, recipient_id, None,
EventType::Move,
).await
create_notification(db_client, sender_id, recipient_id, None, EventType::Move).await
}
pub async fn get_notifications(
@ -156,31 +170,35 @@ pub async fn get_notifications(
ORDER BY notification.id DESC
LIMIT $3
",
related_attachments=RELATED_ATTACHMENTS,
related_mentions=RELATED_MENTIONS,
related_tags=RELATED_TAGS,
related_links=RELATED_LINKS,
related_emojis=RELATED_EMOJIS,
related_attachments = RELATED_ATTACHMENTS,
related_mentions = RELATED_MENTIONS,
related_tags = RELATED_TAGS,
related_links = RELATED_LINKS,
related_emojis = RELATED_EMOJIS,
);
let rows = db_client.query(
&statement,
&[&recipient_id, &max_id, &i64::from(limit)],
).await?;
let mut notifications: Vec<Notification> = rows.iter()
let rows = db_client
.query(&statement, &[&recipient_id, &max_id, &i64::from(limit)])
.await?;
let mut notifications: Vec<Notification> = rows
.iter()
.map(Notification::try_from)
.collect::<Result<_, _>>()?;
add_related_posts(
db_client,
notifications.iter_mut()
notifications
.iter_mut()
.filter_map(|item| item.post.as_mut())
.collect(),
).await?;
)
.await?;
add_user_actions(
db_client,
recipient_id,
notifications.iter_mut()
notifications
.iter_mut()
.filter_map(|item| item.post.as_mut())
.collect(),
).await?;
)
.await?;
Ok(notifications)
}

View file

@ -6,8 +6,7 @@ use uuid::Uuid;
use crate::attachments::types::DbMediaAttachment;
use crate::database::{
int_enum::{int_enum_from_sql, int_enum_to_sql},
DatabaseError,
DatabaseTypeError,
DatabaseError, DatabaseTypeError,
};
use crate::emojis::types::DbEmoji;
use crate::posts::types::{DbPost, Post};
@ -89,7 +88,6 @@ pub struct Notification {
}
impl TryFrom<&Row> for Notification {
type Error = DatabaseError;
fn try_from(row: &Row) -> Result<Self, Self::Error> {
@ -114,7 +112,7 @@ impl TryFrom<&Row> for Notification {
db_emojis,
)?;
Some(post)
},
}
None => None,
};
let notification = Self {

View file

@ -1,2 +1,2 @@
pub mod types;
pub mod queries;
pub mod types;

View file

@ -1,11 +1,7 @@
use chrono::{DateTime, Utc};
use uuid::Uuid;
use crate::database::{
catch_unique_violation,
DatabaseClient,
DatabaseError,
};
use crate::database::{catch_unique_violation, DatabaseClient, DatabaseError};
use crate::profiles::types::DbActorProfile;
use crate::users::types::{DbUser, User};
@ -15,8 +11,9 @@ pub async fn create_oauth_app(
db_client: &impl DatabaseClient,
app_data: DbOauthAppData,
) -> Result<DbOauthApp, DatabaseError> {
let row = db_client.query_one(
"
let row = db_client
.query_one(
"
INSERT INTO oauth_application (
app_name,
website,
@ -28,15 +25,17 @@ pub async fn create_oauth_app(
VALUES ($1, $2, $3, $4, $5, $6)
RETURNING oauth_application
",
&[
&app_data.app_name,
&app_data.website,
&app_data.scopes,
&app_data.redirect_uri,
&app_data.client_id,
&app_data.client_secret,
],
).await.map_err(catch_unique_violation("oauth_application"))?;
&[
&app_data.app_name,
&app_data.website,
&app_data.scopes,
&app_data.redirect_uri,
&app_data.client_id,
&app_data.client_secret,
],
)
.await
.map_err(catch_unique_violation("oauth_application"))?;
let app = row.try_get("oauth_application")?;
Ok(app)
}
@ -45,14 +44,16 @@ pub async fn get_oauth_app_by_client_id(
db_client: &impl DatabaseClient,
client_id: &Uuid,
) -> Result<DbOauthApp, DatabaseError> {
let maybe_row = db_client.query_opt(
"
let maybe_row = db_client
.query_opt(
"
SELECT oauth_application
FROM oauth_application
WHERE client_id = $1
",
&[&client_id],
).await?;
&[&client_id],
)
.await?;
let row = maybe_row.ok_or(DatabaseError::NotFound("oauth application"))?;
let app = row.try_get("oauth_application")?;
Ok(app)
@ -67,8 +68,9 @@ pub async fn create_oauth_authorization(
created_at: &DateTime<Utc>,
expires_at: &DateTime<Utc>,
) -> Result<(), DatabaseError> {
db_client.execute(
"
db_client
.execute(
"
INSERT INTO oauth_authorization (
code,
user_id,
@ -79,15 +81,16 @@ pub async fn create_oauth_authorization(
)
VALUES ($1, $2, $3, $4, $5, $6)
",
&[
&authorization_code,
&user_id,
&application_id,
&scopes,
&created_at,
&expires_at,
],
).await?;
&[
&authorization_code,
&user_id,
&application_id,
&scopes,
&created_at,
&expires_at,
],
)
.await?;
Ok(())
}
@ -95,8 +98,9 @@ pub async fn get_user_by_authorization_code(
db_client: &impl DatabaseClient,
authorization_code: &str,
) -> Result<User, DatabaseError> {
let maybe_row = db_client.query_opt(
"
let maybe_row = db_client
.query_opt(
"
SELECT user_account, actor_profile
FROM oauth_authorization
JOIN user_account ON oauth_authorization.user_id = user_account.id
@ -105,8 +109,9 @@ pub async fn get_user_by_authorization_code(
oauth_authorization.code = $1
AND oauth_authorization.expires_at > CURRENT_TIMESTAMP
",
&[&authorization_code],
).await?;
&[&authorization_code],
)
.await?;
let row = maybe_row.ok_or(DatabaseError::NotFound("authorization"))?;
let db_user: DbUser = row.try_get("user_account")?;
let db_profile: DbActorProfile = row.try_get("actor_profile")?;
@ -121,13 +126,15 @@ pub async fn save_oauth_token(
created_at: &DateTime<Utc>,
expires_at: &DateTime<Utc>,
) -> Result<(), DatabaseError> {
db_client.execute(
"
db_client
.execute(
"
INSERT INTO oauth_token (owner_id, token, created_at, expires_at)
VALUES ($1, $2, $3, $4)
",
&[&owner_id, &token, &created_at, &expires_at],
).await?;
&[&owner_id, &token, &created_at, &expires_at],
)
.await?;
Ok(())
}
@ -137,24 +144,25 @@ pub async fn delete_oauth_token(
token: &str,
) -> Result<(), DatabaseError> {
let transaction = db_client.transaction().await?;
let maybe_row = transaction.query_opt(
"
let maybe_row = transaction
.query_opt(
"
SELECT owner_id FROM oauth_token
WHERE token = $1
FOR UPDATE
",
&[&token],
).await?;
&[&token],
)
.await?;
if let Some(row) = maybe_row {
let owner_id: Uuid = row.try_get("owner_id")?;
if owner_id != *current_user_id {
// Return error if token is owned by a different user
return Err(DatabaseError::NotFound("token"));
} else {
transaction.execute(
"DELETE FROM oauth_token WHERE token = $1",
&[&token],
).await?;
transaction
.execute("DELETE FROM oauth_token WHERE token = $1", &[&token])
.await?;
};
};
transaction.commit().await?;
@ -165,10 +173,9 @@ pub async fn delete_oauth_tokens(
db_client: &impl DatabaseClient,
owner_id: &Uuid,
) -> Result<(), DatabaseError> {
db_client.execute(
"DELETE FROM oauth_token WHERE owner_id = $1",
&[&owner_id],
).await?;
db_client
.execute("DELETE FROM oauth_token WHERE owner_id = $1", &[&owner_id])
.await?;
Ok(())
}
@ -176,8 +183,9 @@ pub async fn get_user_by_oauth_token(
db_client: &impl DatabaseClient,
access_token: &str,
) -> Result<User, DatabaseError> {
let maybe_row = db_client.query_opt(
"
let maybe_row = db_client
.query_opt(
"
SELECT user_account, actor_profile
FROM oauth_token
JOIN user_account ON oauth_token.owner_id = user_account.id
@ -186,8 +194,9 @@ pub async fn get_user_by_oauth_token(
oauth_token.token = $1
AND oauth_token.expires_at > CURRENT_TIMESTAMP
",
&[&access_token],
).await?;
&[&access_token],
)
.await?;
let row = maybe_row.ok_or(DatabaseError::NotFound("user"))?;
let db_user: DbUser = row.try_get("user_account")?;
let db_profile: DbActorProfile = row.try_get("actor_profile")?;
@ -197,13 +206,10 @@ pub async fn get_user_by_oauth_token(
#[cfg(test)]
mod tests {
use serial_test::serial;
use crate::database::test_utils::create_test_database;
use crate::users::{
queries::create_user,
types::UserCreateData,
};
use super::*;
use crate::database::test_utils::create_test_database;
use crate::users::{queries::create_user, types::UserCreateData};
use serial_test::serial;
#[tokio::test]
#[serial]
@ -240,7 +246,9 @@ mod tests {
"read write",
&Utc::now(),
&Utc::now(),
).await.unwrap();
)
.await
.unwrap();
}
#[tokio::test]
@ -254,17 +262,11 @@ mod tests {
};
let user = create_user(db_client, user_data).await.unwrap();
let token = "test-token";
save_oauth_token(
db_client,
&user.id,
token,
&Utc::now(),
&Utc::now(),
).await.unwrap();
delete_oauth_token(
db_client,
&user.id,
token,
).await.unwrap();
save_oauth_token(db_client, &user.id, token, &Utc::now(), &Utc::now())
.await
.unwrap();
delete_oauth_token(db_client, &user.id, token)
.await
.unwrap();
}
}

View file

@ -2,17 +2,10 @@ use uuid::Uuid;
use crate::database::{DatabaseClient, DatabaseError};
use crate::reactions::queries::find_favourited_by_user;
use crate::relationships::{
queries::has_relationship,
types::RelationshipType,
};
use crate::relationships::{queries::has_relationship, types::RelationshipType};
use crate::users::types::{Permission, User};
use super::queries::{
get_post_by_id,
get_related_posts,
find_reposted_by_user,
};
use super::queries::{find_reposted_by_user, get_post_by_id, get_related_posts};
use super::types::{Post, PostActions, Visibility};
pub async fn add_related_posts(
@ -22,7 +15,8 @@ pub async fn add_related_posts(
let posts_ids = posts.iter().map(|post| post.id).collect();
let related = get_related_posts(db_client, posts_ids).await?;
let get_post = |post_id: &Uuid| -> Result<Post, DatabaseError> {
let post = related.iter()
let post = related
.iter()
.find(|post| post.id == *post_id)
.ok_or(DatabaseError::NotFound("post"))?
.clone();
@ -38,14 +32,14 @@ pub async fn add_related_posts(
for linked_id in repost_of.links.iter() {
let linked = get_post(linked_id)?;
repost_of.linked.push(linked);
};
}
post.repost_of = Some(Box::new(repost_of));
};
for linked_id in post.links.iter() {
let linked = get_post(linked_id)?;
post.linked.push(linked);
};
};
}
}
Ok(())
}
@ -54,12 +48,14 @@ pub async fn add_user_actions(
user_id: &Uuid,
posts: Vec<&mut Post>,
) -> Result<(), DatabaseError> {
let posts_ids: Vec<Uuid> = posts.iter()
let posts_ids: Vec<Uuid> = posts
.iter()
.map(|post| post.id)
.chain(
posts.iter()
posts
.iter()
.filter_map(|post| post.repost_of.as_ref())
.map(|post| post.id)
.map(|post| post.id),
)
.collect();
let favourites = find_favourited_by_user(db_client, user_id, &posts_ids).await?;
@ -87,7 +83,9 @@ pub async fn can_view_post(
post: &Post,
) -> Result<bool, DatabaseError> {
let is_mentioned = |user: &User| {
post.mentions.iter().any(|profile| profile.id == user.profile.id)
post.mentions
.iter()
.any(|profile| profile.id == user.profile.id)
};
let result = match post.visibility {
Visibility::Public => true,
@ -98,7 +96,7 @@ pub async fn can_view_post(
} else {
false
}
},
}
Visibility::Followers => {
if let Some(user) = user {
let is_following = has_relationship(
@ -106,12 +104,13 @@ pub async fn can_view_post(
&user.id,
&post.author.id,
RelationshipType::Follow,
).await?;
)
.await?;
is_following || is_mentioned(user)
} else {
false
}
},
}
Visibility::Subscribers => {
if let Some(user) = user {
// Can view only if mentioned
@ -119,14 +118,12 @@ pub async fn can_view_post(
} else {
false
}
},
}
};
Ok(result)
}
pub fn can_create_post(
user: &User,
) -> bool {
pub fn can_create_post(user: &User) -> bool {
user.role.has_permission(Permission::CreatePost)
}
@ -143,19 +140,16 @@ pub async fn get_local_post_by_id(
#[cfg(test)]
mod tests {
use serial_test::serial;
use tokio_postgres::Client;
use super::*;
use crate::database::test_utils::create_test_database;
use crate::posts::{
queries::create_post,
types::PostCreateData,
};
use crate::posts::{queries::create_post, types::PostCreateData};
use crate::relationships::queries::{follow, subscribe};
use crate::users::{
queries::create_user,
types::{Role, User, UserCreateData},
};
use super::*;
use serial_test::serial;
use tokio_postgres::Client;
async fn create_test_user(db_client: &mut Client, username: &str) -> User {
let user_data = UserCreateData {
@ -181,8 +175,12 @@ mod tests {
in_reply_to_id: Some(post.id.clone()),
..Default::default()
};
let mut reply = create_post(db_client, &author.id, reply_data).await.unwrap();
add_related_posts(db_client, vec![&mut reply]).await.unwrap();
let mut reply = create_post(db_client, &author.id, reply_data)
.await
.unwrap();
add_related_posts(db_client, vec![&mut reply])
.await
.unwrap();
assert_eq!(reply.in_reply_to.unwrap().id, post.id);
assert_eq!(reply.repost_of.is_none(), true);
assert_eq!(reply.linked.is_empty(), true);
@ -253,7 +251,9 @@ mod tests {
visibility: Visibility::Followers,
..Default::default()
};
let result = can_view_post(db_client, Some(&follower), &post).await.unwrap();
let result = can_view_post(db_client, Some(&follower), &post)
.await
.unwrap();
assert_eq!(result, true);
}
@ -265,23 +265,26 @@ mod tests {
let follower = create_test_user(db_client, "follower").await;
follow(db_client, &follower.id, &author.id).await.unwrap();
let subscriber = create_test_user(db_client, "subscriber").await;
subscribe(db_client, &subscriber.id, &author.id).await.unwrap();
subscribe(db_client, &subscriber.id, &author.id)
.await
.unwrap();
let post = Post {
author: author.profile,
visibility: Visibility::Subscribers,
mentions: vec![subscriber.profile.clone()],
..Default::default()
};
assert_eq!(can_view_post(db_client, None, &post).await.unwrap(), false,);
assert_eq!(
can_view_post(db_client, None, &post).await.unwrap(),
can_view_post(db_client, Some(&follower), &post)
.await
.unwrap(),
false,
);
assert_eq!(
can_view_post(db_client, Some(&follower), &post).await.unwrap(),
false,
);
assert_eq!(
can_view_post(db_client, Some(&subscriber), &post).await.unwrap(),
can_view_post(db_client, Some(&subscriber), &post)
.await
.unwrap(),
true,
);
}

File diff suppressed because it is too large Load diff

View file

@ -6,8 +6,7 @@ use uuid::Uuid;
use crate::attachments::types::DbMediaAttachment;
use crate::database::{
int_enum::{int_enum_from_sql, int_enum_to_sql},
DatabaseError,
DatabaseTypeError,
DatabaseError, DatabaseTypeError,
};
use crate::emojis::types::DbEmoji;
use crate::profiles::types::DbActorProfile;
@ -21,7 +20,9 @@ pub enum Visibility {
}
impl Default for Visibility {
fn default() -> Self { Self::Public }
fn default() -> Self {
Self::Public
}
}
impl From<&Visibility> for i16 {
@ -130,19 +131,19 @@ impl Post {
if db_author.is_local() != db_post.object_id.is_none() {
return Err(DatabaseTypeError);
};
if db_post.repost_of_id.is_some() && (
db_post.content.len() != 0 ||
db_post.is_sensitive ||
db_post.in_reply_to_id.is_some() ||
db_post.ipfs_cid.is_some() ||
db_post.token_id.is_some() ||
db_post.token_tx_id.is_some() ||
!db_attachments.is_empty() ||
!db_mentions.is_empty() ||
!db_tags.is_empty() ||
!db_links.is_empty() ||
!db_emojis.is_empty()
) {
if db_post.repost_of_id.is_some()
&& (db_post.content.len() != 0
|| db_post.is_sensitive
|| db_post.in_reply_to_id.is_some()
|| db_post.ipfs_cid.is_some()
|| db_post.token_id.is_some()
|| db_post.token_tx_id.is_some()
|| !db_attachments.is_empty()
|| !db_mentions.is_empty()
|| !db_tags.is_empty()
|| !db_links.is_empty()
|| !db_emojis.is_empty())
{
return Err(DatabaseTypeError);
};
let post = Self {
@ -218,7 +219,6 @@ impl Default for Post {
}
impl TryFrom<&Row> for Post {
type Error = DatabaseError;
fn try_from(row: &Row) -> Result<Self, Self::Error> {
@ -259,10 +259,7 @@ pub struct PostCreateData {
}
impl PostCreateData {
pub fn repost(
repost_of_id: Uuid,
object_id: Option<String>,
) -> Self {
pub fn repost(repost_of_id: Uuid, object_id: Option<String>) -> Self {
Self {
repost_of_id: Some(repost_of_id),
object_id: object_id,

View file

@ -1,9 +1,6 @@
use crate::database::{DatabaseClient, DatabaseError};
use super::queries::{
get_profile_by_remote_actor_id,
search_profiles_by_did_only,
};
use super::queries::{get_profile_by_remote_actor_id, search_profiles_by_did_only};
use super::types::DbActorProfile;
pub async fn find_declared_aliases(
@ -12,17 +9,14 @@ pub async fn find_declared_aliases(
) -> Result<Vec<DbActorProfile>, DatabaseError> {
let mut results = vec![];
for actor_id in profile.aliases.clone().into_actor_ids() {
let alias = match get_profile_by_remote_actor_id(
db_client,
&actor_id,
).await {
let alias = match get_profile_by_remote_actor_id(db_client, &actor_id).await {
Ok(profile) => profile,
// Ignore unknown profiles
Err(DatabaseError::NotFound(_)) => continue,
Err(other_error) => return Err(other_error),
};
results.push(alias);
};
}
Ok(results)
}
@ -32,16 +26,13 @@ pub async fn find_verified_aliases(
) -> Result<Vec<DbActorProfile>, DatabaseError> {
let mut results = vec![];
for identity_proof in profile.identity_proofs.inner() {
let aliases = search_profiles_by_did_only(
db_client,
&identity_proof.issuer,
).await?;
let aliases = search_profiles_by_did_only(db_client, &identity_proof.issuer).await?;
for alias in aliases {
if alias.id == profile.id {
continue;
};
results.push(alias);
};
};
}
}
Ok(results)
}

View file

@ -1,35 +1,16 @@
use chrono::{DateTime, Utc};
use uuid::Uuid;
use mitra_utils::{
currencies::Currency,
did::Did,
did_pkh::DidPkh,
id::generate_ulid,
};
use mitra_utils::{currencies::Currency, did::Did, did_pkh::DidPkh, id::generate_ulid};
use crate::cleanup::{
find_orphaned_files,
find_orphaned_ipfs_objects,
DeletionQueue,
};
use crate::database::{
catch_unique_violation,
query_macro::query,
DatabaseClient,
DatabaseError,
};
use crate::cleanup::{find_orphaned_files, find_orphaned_ipfs_objects, DeletionQueue};
use crate::database::{catch_unique_violation, query_macro::query, DatabaseClient, DatabaseError};
use crate::emojis::types::DbEmoji;
use crate::instances::queries::create_instance;
use crate::relationships::types::RelationshipType;
use super::types::{
Aliases,
DbActorProfile,
ExtraFields,
IdentityProofs,
PaymentOptions,
ProfileCreateData,
Aliases, DbActorProfile, ExtraFields, IdentityProofs, PaymentOptions, ProfileCreateData,
ProfileUpdateData,
};
@ -38,8 +19,9 @@ async fn create_profile_emojis(
profile_id: &Uuid,
emojis: Vec<Uuid>,
) -> Result<Vec<DbEmoji>, DatabaseError> {
let emojis_rows = db_client.query(
"
let emojis_rows = db_client
.query(
"
INSERT INTO profile_emoji (profile_id, emoji_id)
SELECT $1, emoji.id FROM emoji WHERE id = ANY($2)
RETURNING (
@ -47,12 +29,14 @@ async fn create_profile_emojis(
WHERE emoji.id = emoji_id
)
",
&[&profile_id, &emojis],
).await?;
&[&profile_id, &emojis],
)
.await?;
if emojis_rows.len() != emojis.len() {
return Err(DatabaseError::NotFound("emoji"));
};
let emojis = emojis_rows.iter()
let emojis = emojis_rows
.iter()
.map(|row| row.try_get("emoji"))
.collect::<Result<_, _>>()?;
Ok(emojis)
@ -62,8 +46,9 @@ async fn update_emoji_cache(
db_client: &impl DatabaseClient,
profile_id: &Uuid,
) -> Result<DbActorProfile, DatabaseError> {
let maybe_row = db_client.query_opt(
"
let maybe_row = db_client
.query_opt(
"
WITH profile_emojis AS (
SELECT
actor_profile.id AS profile_id,
@ -83,8 +68,9 @@ async fn update_emoji_cache(
WHERE actor_profile.id = profile_emojis.profile_id
RETURNING actor_profile
",
&[&profile_id],
).await?;
&[&profile_id],
)
.await?;
let row = maybe_row.ok_or(DatabaseError::NotFound("profile"))?;
let profile: DbActorProfile = row.try_get("actor_profile")?;
Ok(profile)
@ -94,8 +80,9 @@ pub async fn update_emoji_caches(
db_client: &impl DatabaseClient,
emoji_id: &Uuid,
) -> Result<(), DatabaseError> {
db_client.execute(
"
db_client
.execute(
"
WITH profile_emojis AS (
SELECT
actor_profile.id AS profile_id,
@ -115,8 +102,9 @@ pub async fn update_emoji_caches(
FROM profile_emojis
WHERE actor_profile.id = profile_emojis.profile_id
",
&[&emoji_id],
).await?;
&[&emoji_id],
)
.await?;
Ok(())
}
@ -130,8 +118,9 @@ pub async fn create_profile(
if let Some(ref hostname) = profile_data.hostname {
create_instance(&transaction, hostname).await?;
};
transaction.execute(
"
transaction
.execute(
"
INSERT INTO actor_profile (
id,
username,
@ -151,30 +140,28 @@ pub async fn create_profile(
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14)
RETURNING actor_profile
",
&[
&profile_id,
&profile_data.username,
&profile_data.hostname,
&profile_data.display_name,
&profile_data.bio,
&profile_data.bio,
&profile_data.avatar,
&profile_data.banner,
&profile_data.manually_approves_followers,
&IdentityProofs(profile_data.identity_proofs),
&PaymentOptions(profile_data.payment_options),
&ExtraFields(profile_data.extra_fields),
&Aliases::new(profile_data.aliases),
&profile_data.actor_json,
],
).await.map_err(catch_unique_violation("profile"))?;
&[
&profile_id,
&profile_data.username,
&profile_data.hostname,
&profile_data.display_name,
&profile_data.bio,
&profile_data.bio,
&profile_data.avatar,
&profile_data.banner,
&profile_data.manually_approves_followers,
&IdentityProofs(profile_data.identity_proofs),
&PaymentOptions(profile_data.payment_options),
&ExtraFields(profile_data.extra_fields),
&Aliases::new(profile_data.aliases),
&profile_data.actor_json,
],
)
.await
.map_err(catch_unique_violation("profile"))?;
// Create related objects
create_profile_emojis(
&transaction,
&profile_id,
profile_data.emojis,
).await?;
create_profile_emojis(&transaction, &profile_id, profile_data.emojis).await?;
let profile = update_emoji_cache(&transaction, &profile_id).await?;
transaction.commit().await?;
@ -187,8 +174,9 @@ pub async fn update_profile(
profile_data: ProfileUpdateData,
) -> Result<DbActorProfile, DatabaseError> {
let transaction = db_client.transaction().await?;
transaction.execute(
"
transaction
.execute(
"
UPDATE actor_profile
SET
display_name = $1,
@ -206,32 +194,31 @@ pub async fn update_profile(
WHERE id = $12
RETURNING actor_profile
",
&[
&profile_data.display_name,
&profile_data.bio,
&profile_data.bio_source,
&profile_data.avatar,
&profile_data.banner,
&profile_data.manually_approves_followers,
&IdentityProofs(profile_data.identity_proofs),
&PaymentOptions(profile_data.payment_options),
&ExtraFields(profile_data.extra_fields),
&Aliases::new(profile_data.aliases),
&profile_data.actor_json,
&profile_id,
],
).await?;
&[
&profile_data.display_name,
&profile_data.bio,
&profile_data.bio_source,
&profile_data.avatar,
&profile_data.banner,
&profile_data.manually_approves_followers,
&IdentityProofs(profile_data.identity_proofs),
&PaymentOptions(profile_data.payment_options),
&ExtraFields(profile_data.extra_fields),
&Aliases::new(profile_data.aliases),
&profile_data.actor_json,
&profile_id,
],
)
.await?;
// Delete and re-create related objects
transaction.execute(
"DELETE FROM profile_emoji WHERE profile_id = $1",
&[profile_id],
).await?;
create_profile_emojis(
&transaction,
profile_id,
profile_data.emojis,
).await?;
transaction
.execute(
"DELETE FROM profile_emoji WHERE profile_id = $1",
&[profile_id],
)
.await?;
create_profile_emojis(&transaction, profile_id, profile_data.emojis).await?;
let profile = update_emoji_cache(&transaction, profile_id).await?;
transaction.commit().await?;
@ -242,14 +229,16 @@ pub async fn get_profile_by_id(
db_client: &impl DatabaseClient,
profile_id: &Uuid,
) -> Result<DbActorProfile, DatabaseError> {
let result = db_client.query_opt(
"
let result = db_client
.query_opt(
"
SELECT actor_profile
FROM actor_profile
WHERE id = $1
",
&[&profile_id],
).await?;
&[&profile_id],
)
.await?;
let profile = match result {
Some(row) => row.try_get("actor_profile")?,
None => return Err(DatabaseError::NotFound("profile")),
@ -261,14 +250,16 @@ pub async fn get_profile_by_remote_actor_id(
db_client: &impl DatabaseClient,
actor_id: &str,
) -> Result<DbActorProfile, DatabaseError> {
let maybe_row = db_client.query_opt(
"
let maybe_row = db_client
.query_opt(
"
SELECT actor_profile
FROM actor_profile
WHERE actor_id = $1
",
&[&actor_id],
).await?;
&[&actor_id],
)
.await?;
let row = maybe_row.ok_or(DatabaseError::NotFound("profile"))?;
let profile: DbActorProfile = row.try_get("actor_profile")?;
profile.check_remote()?;
@ -279,14 +270,16 @@ pub async fn get_profile_by_acct(
db_client: &impl DatabaseClient,
acct: &str,
) -> Result<DbActorProfile, DatabaseError> {
let result = db_client.query_opt(
"
let result = db_client
.query_opt(
"
SELECT actor_profile
FROM actor_profile
WHERE actor_profile.acct = $1
",
&[&acct],
).await?;
&[&acct],
)
.await?;
let profile = match result {
Some(row) => row.try_get("actor_profile")?,
None => return Err(DatabaseError::NotFound("profile")),
@ -300,7 +293,11 @@ pub async fn get_profiles(
offset: u16,
limit: u16,
) -> Result<Vec<DbActorProfile>, DatabaseError> {
let condition = if only_local { "WHERE actor_id IS NULL" } else { "" };
let condition = if only_local {
"WHERE actor_id IS NULL"
} else {
""
};
let statement = format!(
"
SELECT actor_profile
@ -309,13 +306,13 @@ pub async fn get_profiles(
ORDER BY username
LIMIT $1 OFFSET $2
",
condition=condition,
condition = condition,
);
let rows = db_client.query(
&statement,
&[&i64::from(limit), &i64::from(offset)],
).await?;
let profiles = rows.iter()
let rows = db_client
.query(&statement, &[&i64::from(limit), &i64::from(offset)])
.await?;
let profiles = rows
.iter()
.map(|row| row.try_get("actor_profile"))
.collect::<Result<_, _>>()?;
Ok(profiles)
@ -325,15 +322,18 @@ pub async fn get_profiles_by_accts(
db_client: &impl DatabaseClient,
accts: Vec<String>,
) -> Result<Vec<DbActorProfile>, DatabaseError> {
let rows = db_client.query(
"
let rows = db_client
.query(
"
SELECT actor_profile
FROM actor_profile
WHERE acct = ANY($1)
",
&[&accts],
).await?;
let profiles = rows.iter()
&[&accts],
)
.await?;
let profiles = rows
.iter()
.map(|row| row.try_get("actor_profile"))
.collect::<Result<_, _>>()?;
Ok(profiles)
@ -347,8 +347,9 @@ pub async fn delete_profile(
let transaction = db_client.transaction().await?;
// Select all posts authored by given actor,
// their descendants and reposts.
let posts_rows = transaction.query(
"
let posts_rows = transaction
.query(
"
WITH RECURSIVE context (post_id) AS (
SELECT post.id FROM post
WHERE post.author_id = $1
@ -361,14 +362,17 @@ pub async fn delete_profile(
)
SELECT post_id FROM context
",
&[&profile_id],
).await?;
let posts: Vec<Uuid> = posts_rows.iter()
&[&profile_id],
)
.await?;
let posts: Vec<Uuid> = posts_rows
.iter()
.map(|row| row.try_get("post_id"))
.collect::<Result<_, _>>()?;
// Get list of media files
let files_rows = transaction.query(
"
let files_rows = transaction
.query(
"
SELECT unnest(array_remove(
ARRAY[
avatar ->> 'file_name',
@ -381,14 +385,17 @@ pub async fn delete_profile(
SELECT file_name
FROM media_attachment WHERE post_id = ANY($2)
",
&[&profile_id, &posts],
).await?;
let files: Vec<String> = files_rows.iter()
&[&profile_id, &posts],
)
.await?;
let files: Vec<String> = files_rows
.iter()
.map(|row| row.try_get("file_name"))
.collect::<Result<_, _>>()?;
// Get list of IPFS objects
let ipfs_objects_rows = transaction.query(
"
let ipfs_objects_rows = transaction
.query(
"
SELECT ipfs_cid
FROM media_attachment
WHERE post_id = ANY($1) AND ipfs_cid IS NOT NULL
@ -397,14 +404,17 @@ pub async fn delete_profile(
FROM post
WHERE id = ANY($1) AND ipfs_cid IS NOT NULL
",
&[&posts],
).await?;
let ipfs_objects: Vec<String> = ipfs_objects_rows.iter()
&[&posts],
)
.await?;
let ipfs_objects: Vec<String> = ipfs_objects_rows
.iter()
.map(|row| row.try_get("ipfs_cid"))
.collect::<Result<_, _>>()?;
// Update post counters
transaction.execute(
"
transaction
.execute(
"
UPDATE actor_profile
SET post_count = post_count - post.count
FROM (
@ -414,11 +424,13 @@ pub async fn delete_profile(
) AS post
WHERE actor_profile.id = post.author_id
",
&[&posts],
).await?;
&[&posts],
)
.await?;
// Update counters
transaction.execute(
"
transaction
.execute(
"
UPDATE actor_profile
SET follower_count = follower_count - 1
FROM relationship
@ -427,10 +439,12 @@ pub async fn delete_profile(
AND relationship.target_id = actor_profile.id
AND relationship.relationship_type = $2
",
&[&profile_id, &RelationshipType::Follow],
).await?;
transaction.execute(
"
&[&profile_id, &RelationshipType::Follow],
)
.await?;
transaction
.execute(
"
UPDATE actor_profile
SET following_count = following_count - 1
FROM relationship
@ -439,10 +453,12 @@ pub async fn delete_profile(
AND relationship.target_id = $1
AND relationship.relationship_type = $2
",
&[&profile_id, &RelationshipType::Follow],
).await?;
transaction.execute(
"
&[&profile_id, &RelationshipType::Follow],
)
.await?;
transaction
.execute(
"
UPDATE actor_profile
SET subscriber_count = subscriber_count - 1
FROM relationship
@ -451,10 +467,12 @@ pub async fn delete_profile(
AND relationship.target_id = actor_profile.id
AND relationship.relationship_type = $2
",
&[&profile_id, &RelationshipType::Subscription],
).await?;
transaction.execute(
"
&[&profile_id, &RelationshipType::Subscription],
)
.await?;
transaction
.execute(
"
UPDATE post
SET reply_count = reply_count - reply.count
FROM (
@ -464,10 +482,12 @@ pub async fn delete_profile(
) AS reply
WHERE post.id = reply.in_reply_to_id
",
&[&profile_id],
).await?;
transaction.execute(
"
&[&profile_id],
)
.await?;
transaction
.execute(
"
UPDATE post
SET reaction_count = reaction_count - 1
FROM post_reaction
@ -475,10 +495,12 @@ pub async fn delete_profile(
post_reaction.post_id = post.id
AND post_reaction.author_id = $1
",
&[&profile_id],
).await?;
transaction.execute(
"
&[&profile_id],
)
.await?;
transaction
.execute(
"
UPDATE post
SET repost_count = post.repost_count - 1
FROM post AS repost
@ -486,16 +508,19 @@ pub async fn delete_profile(
repost.repost_of_id = post.id
AND repost.author_id = $1
",
&[&profile_id],
).await?;
&[&profile_id],
)
.await?;
// Delete profile
let deleted_count = transaction.execute(
"
let deleted_count = transaction
.execute(
"
DELETE FROM actor_profile WHERE id = $1
RETURNING actor_profile
",
&[&profile_id],
).await?;
&[&profile_id],
)
.await?;
if deleted_count == 0 {
return Err(DatabaseError::NotFound("profile"));
}
@ -518,22 +543,25 @@ pub async fn search_profiles(
Some(hostname) => {
// Search for exact actor address
format!("{}@{}", username, hostname)
},
}
None => {
// Fuzzy search for username
format!("%{}%", username)
},
}
};
let rows = db_client.query(
"
let rows = db_client
.query(
"
SELECT actor_profile
FROM actor_profile
WHERE acct ILIKE $1
LIMIT $2
",
&[&db_search_query, &i64::from(limit)],
).await?;
let profiles: Vec<DbActorProfile> = rows.iter()
&[&db_search_query, &i64::from(limit)],
)
.await?;
let profiles: Vec<DbActorProfile> = rows
.iter()
.map(|row| row.try_get("actor_profile"))
.collect::<Result<_, _>>()?;
Ok(profiles)
@ -543,8 +571,9 @@ pub async fn search_profiles_by_did_only(
db_client: &impl DatabaseClient,
did: &Did,
) -> Result<Vec<DbActorProfile>, DatabaseError> {
let rows = db_client.query(
"
let rows = db_client
.query(
"
SELECT actor_profile
FROM actor_profile
WHERE
@ -554,9 +583,11 @@ pub async fn search_profiles_by_did_only(
WHERE proof ->> 'issuer' = $1
)
",
&[&did.to_string()],
).await?;
let profiles: Vec<DbActorProfile> = rows.iter()
&[&did.to_string()],
)
.await?;
let profiles: Vec<DbActorProfile> = rows
.iter()
.map(|row| row.try_get("actor_profile"))
.collect::<Result<_, _>>()?;
Ok(profiles)
@ -569,10 +600,9 @@ pub async fn search_profiles_by_did(
) -> Result<Vec<DbActorProfile>, DatabaseError> {
let verified = search_profiles_by_did_only(db_client, did).await?;
let maybe_currency_address = match did {
Did::Pkh(did_pkh) => {
did_pkh.currency()
.map(|currency| (currency, did_pkh.address.clone()))
},
Did::Pkh(did_pkh) => did_pkh
.currency()
.map(|currency| (currency, did_pkh.address.clone())),
_ => None,
};
let unverified = if let Some((currency, address)) = maybe_currency_address {
@ -597,16 +627,13 @@ pub async fn search_profiles_by_did(
AND field ->> 'value' {value_op} $field_value
)
",
value_op=value_op,
value_op = value_op,
);
let field_name = currency.field_name();
let query = query!(
&statement,
field_name=field_name,
field_value=address,
)?;
let query = query!(&statement, field_name = field_name, field_value = address,)?;
let rows = db_client.query(query.sql(), query.parameters()).await?;
let unverified = rows.iter()
let unverified = rows
.iter()
.map(|row| row.try_get("actor_profile"))
.collect::<Result<Vec<DbActorProfile>, _>>()?
.into_iter()
@ -641,15 +668,17 @@ pub async fn update_follower_count(
profile_id: &Uuid,
change: i32,
) -> Result<DbActorProfile, DatabaseError> {
let maybe_row = db_client.query_opt(
"
let maybe_row = db_client
.query_opt(
"
UPDATE actor_profile
SET follower_count = follower_count + $1
WHERE id = $2
RETURNING actor_profile
",
&[&change, &profile_id],
).await?;
&[&change, &profile_id],
)
.await?;
let row = maybe_row.ok_or(DatabaseError::NotFound("profile"))?;
let profile = row.try_get("actor_profile")?;
Ok(profile)
@ -660,15 +689,17 @@ pub async fn update_following_count(
profile_id: &Uuid,
change: i32,
) -> Result<DbActorProfile, DatabaseError> {
let maybe_row = db_client.query_opt(
"
let maybe_row = db_client
.query_opt(
"
UPDATE actor_profile
SET following_count = following_count + $1
WHERE id = $2
RETURNING actor_profile
",
&[&change, &profile_id],
).await?;
&[&change, &profile_id],
)
.await?;
let row = maybe_row.ok_or(DatabaseError::NotFound("profile"))?;
let profile = row.try_get("actor_profile")?;
Ok(profile)
@ -679,15 +710,17 @@ pub async fn update_subscriber_count(
profile_id: &Uuid,
change: i32,
) -> Result<DbActorProfile, DatabaseError> {
let maybe_row = db_client.query_opt(
"
let maybe_row = db_client
.query_opt(
"
UPDATE actor_profile
SET subscriber_count = subscriber_count + $1
WHERE id = $2
RETURNING actor_profile
",
&[&change, &profile_id],
).await?;
&[&change, &profile_id],
)
.await?;
let row = maybe_row.ok_or(DatabaseError::NotFound("profile"))?;
let profile = row.try_get("actor_profile")?;
Ok(profile)
@ -698,15 +731,17 @@ pub async fn update_post_count(
profile_id: &Uuid,
change: i32,
) -> Result<DbActorProfile, DatabaseError> {
let maybe_row = db_client.query_opt(
"
let maybe_row = db_client
.query_opt(
"
UPDATE actor_profile
SET post_count = post_count + $1
WHERE id = $2
RETURNING actor_profile
",
&[&change, &profile_id],
).await?;
&[&change, &profile_id],
)
.await?;
let row = maybe_row.ok_or(DatabaseError::NotFound("profile"))?;
let profile = row.try_get("actor_profile")?;
Ok(profile)
@ -720,24 +755,28 @@ pub async fn set_reachability_status(
) -> Result<(), DatabaseError> {
if !is_reachable {
// Don't update profile if unreachable_since is already set
db_client.execute(
"
db_client
.execute(
"
UPDATE actor_profile
SET unreachable_since = CURRENT_TIMESTAMP
WHERE actor_id = $1 AND unreachable_since IS NULL
",
&[&actor_id],
).await?;
&[&actor_id],
)
.await?;
} else {
// Remove status (if set)
db_client.execute(
"
db_client
.execute(
"
UPDATE actor_profile
SET unreachable_since = NULL
WHERE actor_id = $1
",
&[&actor_id],
).await?;
&[&actor_id],
)
.await?;
};
Ok(())
}
@ -746,16 +785,19 @@ pub async fn find_unreachable(
db_client: &impl DatabaseClient,
unreachable_since: &DateTime<Utc>,
) -> Result<Vec<DbActorProfile>, DatabaseError> {
let rows = db_client.query(
"
let rows = db_client
.query(
"
SELECT actor_profile
FROM actor_profile
WHERE unreachable_since < $1
ORDER BY hostname, username
",
&[&unreachable_since],
).await?;
let profiles = rows.iter()
&[&unreachable_since],
)
.await?;
let profiles = rows
.iter()
.map(|row| row.try_get("actor_profile"))
.collect::<Result<_, _>>()?;
Ok(profiles)
@ -768,8 +810,9 @@ pub async fn find_empty_profiles(
db_client: &impl DatabaseClient,
updated_before: &DateTime<Utc>,
) -> Result<Vec<Uuid>, DatabaseError> {
let rows = db_client.query(
"
let rows = db_client
.query(
"
SELECT actor_profile.id
FROM actor_profile
WHERE
@ -816,9 +859,11 @@ pub async fn find_empty_profiles(
WHERE sender_id = actor_profile.id
)
",
&[&updated_before],
).await?;
let ids: Vec<Uuid> = rows.iter()
&[&updated_before],
)
.await?;
let ids: Vec<Uuid> = rows
.iter()
.map(|row| row.try_get("id"))
.collect::<Result<_, _>>()?;
Ok(ids)
@ -826,30 +871,21 @@ pub async fn find_empty_profiles(
#[cfg(test)]
mod tests {
use serial_test::serial;
use super::*;
use crate::database::test_utils::create_test_database;
use crate::emojis::{
queries::create_emoji,
types::EmojiImage,
};
use crate::emojis::{queries::create_emoji, types::EmojiImage};
use crate::profiles::{
queries::create_profile,
types::{
DbActor,
ExtraField,
IdentityProof,
IdentityProofType,
ProfileCreateData,
},
types::{DbActor, ExtraField, IdentityProof, IdentityProofType, ProfileCreateData},
};
use crate::users::{
queries::create_user,
types::UserCreateData,
};
use super::*;
use crate::users::{queries::create_user, types::UserCreateData};
use serial_test::serial;
fn create_test_actor(actor_id: &str) -> DbActor {
DbActor { id: actor_id.to_string(), ..Default::default() }
DbActor {
id: actor_id.to_string(),
..Default::default()
}
}
#[tokio::test]
@ -883,10 +919,7 @@ mod tests {
assert_eq!(profile.username, "test");
assert_eq!(profile.hostname.unwrap(), "example.com");
assert_eq!(profile.acct, "test@example.com");
assert_eq!(
profile.actor_id.unwrap(),
"https://example.com/users/test",
);
assert_eq!(profile.actor_id.unwrap(), "https://example.com/users/test",);
}
#[tokio::test]
@ -894,14 +927,9 @@ mod tests {
async fn test_create_profile_with_emoji() {
let db_client = &mut create_test_database().await;
let image = EmojiImage::default();
let emoji = create_emoji(
db_client,
"testemoji",
None,
image,
None,
&Utc::now(),
).await.unwrap();
let emoji = create_emoji(db_client, "testemoji", None, image, None, &Utc::now())
.await
.unwrap();
let profile_data = ProfileCreateData {
username: "test".to_string(),
emojis: vec![emoji.id.clone()],
@ -931,7 +959,10 @@ mod tests {
actor_json: Some(create_test_actor(actor_id)),
..Default::default()
};
let error = create_profile(db_client, profile_data_2).await.err().unwrap();
let error = create_profile(db_client, profile_data_2)
.await
.err()
.unwrap();
assert_eq!(error.to_string(), "profile already exists");
}
@ -947,11 +978,9 @@ mod tests {
let mut profile_data = ProfileUpdateData::from(&profile);
let bio = "test bio";
profile_data.bio = Some(bio.to_string());
let profile_updated = update_profile(
db_client,
&profile.id,
profile_data,
).await.unwrap();
let profile_updated = update_profile(db_client, &profile.id, profile_data)
.await
.unwrap();
assert_eq!(profile_updated.username, profile.username);
assert_eq!(profile_updated.bio.unwrap(), bio);
assert!(profile_updated.updated_at != profile.updated_at);
@ -980,8 +1009,10 @@ mod tests {
..Default::default()
};
let _user = create_user(db_client, user_data).await.unwrap();
let profiles = search_profiles_by_wallet_address(
db_client, &ETHEREUM, wallet_address, false).await.unwrap();
let profiles =
search_profiles_by_wallet_address(db_client, &ETHEREUM, wallet_address, false)
.await
.unwrap();
// Login address must not be exposed
assert_eq!(profiles.len(), 0);
@ -1001,8 +1032,9 @@ mod tests {
..Default::default()
};
let profile = create_profile(db_client, profile_data).await.unwrap();
let profiles = search_profiles_by_wallet_address(
db_client, &ETHEREUM, "0x1234abcd", false).await.unwrap();
let profiles = search_profiles_by_wallet_address(db_client, &ETHEREUM, "0x1234abcd", false)
.await
.unwrap();
assert_eq!(profiles.len(), 1);
assert_eq!(profiles[0].id, profile.id);
@ -1022,8 +1054,9 @@ mod tests {
..Default::default()
};
let profile = create_profile(db_client, profile_data).await.unwrap();
let profiles = search_profiles_by_wallet_address(
db_client, &ETHEREUM, "0x1234abcd", false).await.unwrap();
let profiles = search_profiles_by_wallet_address(db_client, &ETHEREUM, "0x1234abcd", false)
.await
.unwrap();
assert_eq!(profiles.len(), 1);
assert_eq!(profiles[0].id, profile.id);
@ -1041,7 +1074,9 @@ mod tests {
..Default::default()
};
let profile = create_profile(db_client, profile_data).await.unwrap();
set_reachability_status(db_client, actor_id, false).await.unwrap();
set_reachability_status(db_client, actor_id, false)
.await
.unwrap();
let profile = get_profile_by_id(db_client, &profile.id).await.unwrap();
assert_eq!(profile.unreachable_since.is_some(), true);
}
@ -1051,7 +1086,9 @@ mod tests {
async fn test_find_empty_profiles() {
let db_client = &mut create_test_database().await;
let updated_before = Utc::now();
let profiles = find_empty_profiles(db_client, &updated_before).await.unwrap();
let profiles = find_empty_profiles(db_client, &updated_before)
.await
.unwrap();
assert_eq!(profiles.is_empty(), true);
}
}

View file

@ -1,17 +1,12 @@
use chrono::{DateTime, Duration, Utc};
use postgres_types::FromSql;
use serde::{
Deserialize, Deserializer, Serialize, Serializer,
de::Error as DeserializerError,
ser::SerializeMap,
__private::ser::FlatMapSerializer,
Deserialize, Deserializer, Serialize, Serializer, __private::ser::FlatMapSerializer,
de::Error as DeserializerError, ser::SerializeMap,
};
use uuid::Uuid;
use mitra_utils::{
caip2::ChainId,
did::Did,
};
use mitra_utils::{caip2::ChainId, did::Did};
use crate::database::{
json_macro::{json_from_sql, json_to_sql},
@ -27,11 +22,7 @@ pub struct ProfileImage {
}
impl ProfileImage {
pub fn new(
file_name: String,
file_size: usize,
media_type: Option<String>,
) -> Self {
pub fn new(file_name: String, file_size: usize, media_type: Option<String>) -> Self {
Self {
file_name,
file_size: Some(file_size),
@ -73,16 +64,19 @@ impl TryFrom<i16> for IdentityProofType {
impl<'de> Deserialize<'de> for IdentityProofType {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where D: Deserializer<'de>
where
D: Deserializer<'de>,
{
i16::deserialize(deserializer)?
.try_into().map_err(DeserializerError::custom)
.try_into()
.map_err(DeserializerError::custom)
}
}
impl Serialize for IdentityProofType {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where S: Serializer
where
S: Serializer,
{
serializer.serialize_i16(self.into())
}
@ -194,30 +188,31 @@ impl PaymentOption {
// Workaround: https://stackoverflow.com/a/65576570
impl<'de> Deserialize<'de> for PaymentOption {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where D: Deserializer<'de>
where
D: Deserializer<'de>,
{
let value = serde_json::Value::deserialize(deserializer)?;
let payment_type = value.get("payment_type")
let payment_type = value
.get("payment_type")
.and_then(serde_json::Value::as_u64)
.and_then(|val| i16::try_from(val).ok())
.and_then(|val| PaymentType::try_from(val).ok())
.ok_or(DeserializerError::custom("invalid payment type"))?;
let payment_option = match payment_type {
PaymentType::Link => {
let link = PaymentLink::deserialize(value)
.map_err(DeserializerError::custom)?;
let link = PaymentLink::deserialize(value).map_err(DeserializerError::custom)?;
Self::Link(link)
},
}
PaymentType::EthereumSubscription => {
let payment_info = EthereumSubscription::deserialize(value)
.map_err(DeserializerError::custom)?;
let payment_info =
EthereumSubscription::deserialize(value).map_err(DeserializerError::custom)?;
Self::EthereumSubscription(payment_info)
},
}
PaymentType::MoneroSubscription => {
let payment_info = MoneroSubscription::deserialize(value)
.map_err(DeserializerError::custom)?;
let payment_info =
MoneroSubscription::deserialize(value).map_err(DeserializerError::custom)?;
Self::MoneroSubscription(payment_info)
},
}
};
Ok(payment_option)
}
@ -225,7 +220,8 @@ impl<'de> Deserialize<'de> for PaymentOption {
impl Serialize for PaymentOption {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where S: Serializer,
where
S: Serializer,
{
let mut map = serializer.serialize_map(None)?;
let payment_type = self.payment_type();
@ -235,10 +231,10 @@ impl Serialize for PaymentOption {
Self::Link(link) => link.serialize(FlatMapSerializer(&mut map))?,
Self::EthereumSubscription(payment_info) => {
payment_info.serialize(FlatMapSerializer(&mut map))?
},
}
Self::MoneroSubscription(payment_info) => {
payment_info.serialize(FlatMapSerializer(&mut map))?
},
}
};
map.end()
}
@ -267,7 +263,8 @@ impl PaymentOptions {
/// of the given type.
pub fn any(&self, payment_type: PaymentType) -> bool {
let Self(payment_options) = self;
payment_options.iter()
payment_options
.iter()
.any(|option| option.payment_type() == payment_type)
}
}
@ -306,7 +303,8 @@ pub struct Aliases(Vec<Alias>);
impl Aliases {
pub fn new(actor_ids: Vec<String>) -> Self {
// Not signed
let aliases = actor_ids.into_iter()
let aliases = actor_ids
.into_iter()
.map(|actor_id| Alias { id: actor_id })
.collect();
Self(aliases)
@ -369,7 +367,7 @@ pub struct DbActorProfile {
pub username: String,
pub hostname: Option<String>,
pub display_name: Option<String>,
pub bio: Option<String>, // html
pub bio: Option<String>, // html
pub bio_source: Option<String>, // plaintext or markdown
pub avatar: Option<ProfileImage>,
pub banner: Option<ProfileImage>,
@ -491,16 +489,16 @@ impl ProfileUpdateData {
/// Adds new identity proof
/// or replaces the existing one if it has the same issuer.
pub fn add_identity_proof(&mut self, proof: IdentityProof) -> () {
self.identity_proofs.retain(|item| item.issuer != proof.issuer);
self.identity_proofs
.retain(|item| item.issuer != proof.issuer);
self.identity_proofs.push(proof);
}
/// Adds new payment option
/// or replaces the existing one if it has the same type.
pub fn add_payment_option(&mut self, option: PaymentOption) -> () {
self.payment_options.retain(|item| {
item.payment_type() != option.payment_type()
});
self.payment_options
.retain(|item| item.payment_type() != option.payment_type());
self.payment_options.push(option);
}
}
@ -519,7 +517,10 @@ impl From<&DbActorProfile> for ProfileUpdateData {
payment_options: profile.payment_options.into_inner(),
extra_fields: profile.extra_fields.into_inner(),
aliases: profile.aliases.into_actor_ids(),
emojis: profile.emojis.into_inner().into_iter()
emojis: profile
.emojis
.into_inner()
.into_iter()
.map(|emoji| emoji.id)
.collect(),
actor_json: profile.actor_json,
@ -539,7 +540,10 @@ mod tests {
Did::Pkh(ref did_pkh) => did_pkh,
_ => panic!("unexpected did method"),
};
assert_eq!(did_pkh.address, "0xb9c5714089478a327f09197987f16f9e5d936e8a");
assert_eq!(
did_pkh.address,
"0xb9c5714089478a327f09197987f16f9e5d936e8a"
);
let serialized = serde_json::to_string(&proof).unwrap();
assert_eq!(serialized, json_data);
}

View file

@ -1,31 +1,25 @@
use serde::{
de::DeserializeOwned,
Serialize,
};
use serde::{de::DeserializeOwned, Serialize};
use serde_json::Value as JsonValue;
use crate::database::{
DatabaseClient,
DatabaseError,
DatabaseTypeError,
};
use crate::database::{DatabaseClient, DatabaseError, DatabaseTypeError};
pub async fn set_internal_property(
db_client: &impl DatabaseClient,
name: &str,
value: &impl Serialize,
) -> Result<(), DatabaseError> {
let value_json = serde_json::to_value(value)
.map_err(|_| DatabaseTypeError)?;
db_client.execute(
"
let value_json = serde_json::to_value(value).map_err(|_| DatabaseTypeError)?;
db_client
.execute(
"
INSERT INTO internal_property (property_name, property_value)
VALUES ($1, $2)
ON CONFLICT (property_name) DO UPDATE
SET property_value = $2
",
&[&name, &value_json],
).await?;
&[&name, &value_json],
)
.await?;
Ok(())
}
@ -33,21 +27,22 @@ pub async fn get_internal_property<T: DeserializeOwned>(
db_client: &impl DatabaseClient,
name: &str,
) -> Result<Option<T>, DatabaseError> {
let maybe_row = db_client.query_opt(
"
let maybe_row = db_client
.query_opt(
"
SELECT property_value
FROM internal_property
WHERE property_name = $1
",
&[&name],
).await?;
&[&name],
)
.await?;
let maybe_value = match maybe_row {
Some(row) => {
let value_json: JsonValue = row.try_get("property_value")?;
let value: T = serde_json::from_value(value_json)
.map_err(|_| DatabaseTypeError)?;
let value: T = serde_json::from_value(value_json).map_err(|_| DatabaseTypeError)?;
Some(value)
},
}
None => None,
};
Ok(maybe_value)
@ -55,9 +50,9 @@ pub async fn get_internal_property<T: DeserializeOwned>(
#[cfg(test)]
mod tests {
use serial_test::serial;
use crate::database::test_utils::create_test_database;
use super::*;
use crate::database::test_utils::create_test_database;
use serial_test::serial;
#[tokio::test]
#[serial]
@ -65,9 +60,13 @@ mod tests {
let db_client = &create_test_database().await;
let name = "myproperty";
let value = 100;
set_internal_property(db_client, name, &value).await.unwrap();
let db_value: u32 = get_internal_property(db_client, name).await
.unwrap().unwrap_or_default();
set_internal_property(db_client, name, &value)
.await
.unwrap();
let db_value: u32 = get_internal_property(db_client, name)
.await
.unwrap()
.unwrap_or_default();
assert_eq!(db_value, value);
}
}

View file

@ -2,16 +2,9 @@ use uuid::Uuid;
use mitra_utils::id::generate_ulid;
use crate::database::{
catch_unique_violation,
DatabaseClient,
DatabaseError,
};
use crate::database::{catch_unique_violation, DatabaseClient, DatabaseError};
use crate::notifications::queries::create_reaction_notification;
use crate::posts::queries::{
update_reaction_count,
get_post_author,
};
use crate::posts::queries::{get_post_author, update_reaction_count};
use super::types::DbReaction;
@ -24,8 +17,9 @@ pub async fn create_reaction(
let transaction = db_client.transaction().await?;
let reaction_id = generate_ulid();
// Reactions to reposts are not allowed
let maybe_row = transaction.query_opt(
"
let maybe_row = transaction
.query_opt(
"
INSERT INTO post_reaction (id, author_id, post_id, activity_id)
SELECT $1, $2, $3, $4
WHERE NOT EXISTS (
@ -34,19 +28,16 @@ pub async fn create_reaction(
)
RETURNING post_reaction
",
&[&reaction_id, &author_id, &post_id, &activity_id],
).await.map_err(catch_unique_violation("reaction"))?;
&[&reaction_id, &author_id, &post_id, &activity_id],
)
.await
.map_err(catch_unique_violation("reaction"))?;
let row = maybe_row.ok_or(DatabaseError::NotFound("post"))?;
let reaction: DbReaction = row.try_get("post_reaction")?;
update_reaction_count(&transaction, post_id, 1).await?;
let post_author = get_post_author(&transaction, post_id).await?;
if post_author.is_local() && post_author.id != *author_id {
create_reaction_notification(
&transaction,
author_id,
&post_author.id,
post_id,
).await?;
create_reaction_notification(&transaction, author_id, &post_author.id, post_id).await?;
};
transaction.commit().await?;
Ok(reaction)
@ -56,14 +47,16 @@ pub async fn get_reaction_by_remote_activity_id(
db_client: &impl DatabaseClient,
activity_id: &str,
) -> Result<DbReaction, DatabaseError> {
let maybe_row = db_client.query_opt(
"
let maybe_row = db_client
.query_opt(
"
SELECT post_reaction
FROM post_reaction
WHERE activity_id = $1
",
&[&activity_id],
).await?;
&[&activity_id],
)
.await?;
let row = maybe_row.ok_or(DatabaseError::NotFound("reaction"))?;
let reaction = row.try_get("post_reaction")?;
Ok(reaction)
@ -75,14 +68,16 @@ pub async fn delete_reaction(
post_id: &Uuid,
) -> Result<Uuid, DatabaseError> {
let transaction = db_client.transaction().await?;
let maybe_row = transaction.query_opt(
"
let maybe_row = transaction
.query_opt(
"
DELETE FROM post_reaction
WHERE author_id = $1 AND post_id = $2
RETURNING post_reaction.id
",
&[&author_id, &post_id],
).await?;
&[&author_id, &post_id],
)
.await?;
let row = maybe_row.ok_or(DatabaseError::NotFound("reaction"))?;
let reaction_id = row.try_get("id")?;
update_reaction_count(&transaction, post_id, -1).await?;
@ -96,15 +91,18 @@ pub async fn find_favourited_by_user(
user_id: &Uuid,
posts_ids: &[Uuid],
) -> Result<Vec<Uuid>, DatabaseError> {
let rows = db_client.query(
"
let rows = db_client
.query(
"
SELECT post_id
FROM post_reaction
WHERE author_id = $1 AND post_id = ANY($2)
",
&[&user_id, &posts_ids],
).await?;
let favourites: Vec<Uuid> = rows.iter()
&[&user_id, &posts_ids],
)
.await?;
let favourites: Vec<Uuid> = rows
.iter()
.map(|row| row.try_get("post_id"))
.collect::<Result<_, _>>()?;
Ok(favourites)

View file

@ -2,27 +2,15 @@ use uuid::Uuid;
use mitra_utils::id::generate_ulid;
use crate::database::{
catch_unique_violation,
DatabaseClient,
DatabaseError,
};
use crate::database::{catch_unique_violation, DatabaseClient, DatabaseError};
use crate::notifications::queries::create_follow_notification;
use crate::profiles::{
queries::{
update_follower_count,
update_following_count,
update_subscriber_count,
},
queries::{update_follower_count, update_following_count, update_subscriber_count},
types::DbActorProfile,
};
use super::types::{
DbFollowRequest,
DbRelationship,
FollowRequestStatus,
RelatedActorProfile,
RelationshipType,
DbFollowRequest, DbRelationship, FollowRequestStatus, RelatedActorProfile, RelationshipType,
};
pub async fn get_relationships(
@ -30,8 +18,9 @@ pub async fn get_relationships(
source_id: &Uuid,
target_id: &Uuid,
) -> Result<Vec<DbRelationship>, DatabaseError> {
let rows = db_client.query(
"
let rows = db_client
.query(
"
SELECT source_id, target_id, relationship_type
FROM relationship
WHERE
@ -45,14 +34,16 @@ pub async fn get_relationships(
source_id = $1 AND target_id = $2
AND request_status = $3
",
&[
&source_id,
&target_id,
&FollowRequestStatus::Pending,
&RelationshipType::FollowRequest,
],
).await?;
let relationships = rows.iter()
&[
&source_id,
&target_id,
&FollowRequestStatus::Pending,
&RelationshipType::FollowRequest,
],
)
.await?;
let relationships = rows
.iter()
.map(DbRelationship::try_from)
.collect::<Result<_, _>>()?;
Ok(relationships)
@ -64,20 +55,18 @@ pub async fn has_relationship(
target_id: &Uuid,
relationship_type: RelationshipType,
) -> Result<bool, DatabaseError> {
let maybe_row = db_client.query_opt(
"
let maybe_row = db_client
.query_opt(
"
SELECT 1
FROM relationship
WHERE
source_id = $1 AND target_id = $2
AND relationship_type = $3
",
&[
&source_id,
&target_id,
&relationship_type,
],
).await?;
&[&source_id, &target_id, &relationship_type],
)
.await?;
Ok(maybe_row.is_some())
}
@ -87,13 +76,16 @@ pub async fn follow(
target_id: &Uuid,
) -> Result<(), DatabaseError> {
let transaction = db_client.transaction().await?;
transaction.execute(
"
transaction
.execute(
"
INSERT INTO relationship (source_id, target_id, relationship_type)
VALUES ($1, $2, $3)
",
&[&source_id, &target_id, &RelationshipType::Follow],
).await.map_err(catch_unique_violation("relationship"))?;
&[&source_id, &target_id, &RelationshipType::Follow],
)
.await
.map_err(catch_unique_violation("relationship"))?;
let target_profile = update_follower_count(&transaction, target_id, 1).await?;
update_following_count(&transaction, source_id, 1).await?;
if target_profile.is_local() {
@ -109,22 +101,21 @@ pub async fn unfollow(
target_id: &Uuid,
) -> Result<Option<Uuid>, DatabaseError> {
let transaction = db_client.transaction().await?;
let deleted_count = transaction.execute(
"
let deleted_count = transaction
.execute(
"
DELETE FROM relationship
WHERE
source_id = $1 AND target_id = $2
AND relationship_type = $3
",
&[&source_id, &target_id, &RelationshipType::Follow],
).await?;
&[&source_id, &target_id, &RelationshipType::Follow],
)
.await?;
let relationship_deleted = deleted_count > 0;
// Delete follow request (for remote follows)
let follow_request_deleted = delete_follow_request_opt(
&transaction,
source_id,
target_id,
).await?;
let follow_request_deleted =
delete_follow_request_opt(&transaction, source_id, target_id).await?;
if !relationship_deleted && follow_request_deleted.is_none() {
return Err(DatabaseError::NotFound("relationship"));
};
@ -147,21 +138,24 @@ pub async fn create_follow_request(
target_id: &Uuid,
) -> Result<DbFollowRequest, DatabaseError> {
let request_id = generate_ulid();
let row = db_client.query_one(
"
let row = db_client
.query_one(
"
INSERT INTO follow_request (
id, source_id, target_id, request_status
)
VALUES ($1, $2, $3, $4)
RETURNING follow_request
",
&[
&request_id,
&source_id,
&target_id,
&FollowRequestStatus::Pending,
],
).await.map_err(catch_unique_violation("follow request"))?;
&[
&request_id,
&source_id,
&target_id,
&FollowRequestStatus::Pending,
],
)
.await
.map_err(catch_unique_violation("follow request"))?;
let request = row.try_get("follow_request")?;
Ok(request)
}
@ -174,8 +168,9 @@ pub async fn create_remote_follow_request_opt(
activity_id: &str,
) -> Result<DbFollowRequest, DatabaseError> {
let request_id = generate_ulid();
let row = db_client.query_one(
"
let row = db_client
.query_one(
"
INSERT INTO follow_request (
id,
source_id,
@ -188,14 +183,15 @@ pub async fn create_remote_follow_request_opt(
DO UPDATE SET activity_id = $4
RETURNING follow_request
",
&[
&request_id,
&source_id,
&target_id,
&activity_id,
&FollowRequestStatus::Pending,
],
).await?;
&[
&request_id,
&source_id,
&target_id,
&activity_id,
&FollowRequestStatus::Pending,
],
)
.await?;
let request = row.try_get("follow_request")?;
Ok(request)
}
@ -205,15 +201,17 @@ pub async fn follow_request_accepted(
request_id: &Uuid,
) -> Result<(), DatabaseError> {
let mut transaction = db_client.transaction().await?;
let maybe_row = transaction.query_opt(
"
let maybe_row = transaction
.query_opt(
"
UPDATE follow_request
SET request_status = $1
WHERE id = $2
RETURNING source_id, target_id
",
&[&FollowRequestStatus::Accepted, &request_id],
).await?;
&[&FollowRequestStatus::Accepted, &request_id],
)
.await?;
let row = maybe_row.ok_or(DatabaseError::NotFound("follow request"))?;
let source_id: Uuid = row.try_get("source_id")?;
let target_id: Uuid = row.try_get("target_id")?;
@ -226,14 +224,16 @@ pub async fn follow_request_rejected(
db_client: &impl DatabaseClient,
request_id: &Uuid,
) -> Result<(), DatabaseError> {
let updated_count = db_client.execute(
"
let updated_count = db_client
.execute(
"
UPDATE follow_request
SET request_status = $1
WHERE id = $2
",
&[&FollowRequestStatus::Rejected, &request_id],
).await?;
&[&FollowRequestStatus::Rejected, &request_id],
)
.await?;
if updated_count == 0 {
return Err(DatabaseError::NotFound("follow request"));
}
@ -245,33 +245,39 @@ async fn delete_follow_request_opt(
source_id: &Uuid,
target_id: &Uuid,
) -> Result<Option<Uuid>, DatabaseError> {
let maybe_row = db_client.query_opt(
"
let maybe_row = db_client
.query_opt(
"
DELETE FROM follow_request
WHERE source_id = $1 AND target_id = $2
RETURNING id
",
&[&source_id, &target_id],
).await?;
&[&source_id, &target_id],
)
.await?;
let maybe_request_id = if let Some(row) = maybe_row {
let request_id: Uuid = row.try_get("id")?;
Some(request_id)
} else { None };
} else {
None
};
Ok(maybe_request_id)
}
pub async fn get_follow_request_by_id(
db_client: &impl DatabaseClient,
db_client: &impl DatabaseClient,
request_id: &Uuid,
) -> Result<DbFollowRequest, DatabaseError> {
let maybe_row = db_client.query_opt(
"
let maybe_row = db_client
.query_opt(
"
SELECT follow_request
FROM follow_request
WHERE id = $1
",
&[&request_id],
).await?;
&[&request_id],
)
.await?;
let row = maybe_row.ok_or(DatabaseError::NotFound("follow request"))?;
let request = row.try_get("follow_request")?;
Ok(request)
@ -281,14 +287,16 @@ pub async fn get_follow_request_by_activity_id(
db_client: &impl DatabaseClient,
activity_id: &str,
) -> Result<DbFollowRequest, DatabaseError> {
let maybe_row = db_client.query_opt(
"
let maybe_row = db_client
.query_opt(
"
SELECT follow_request
FROM follow_request
WHERE activity_id = $1
",
&[&activity_id],
).await?;
&[&activity_id],
)
.await?;
let row = maybe_row.ok_or(DatabaseError::NotFound("follow request"))?;
let request = row.try_get("follow_request")?;
Ok(request)
@ -298,8 +306,9 @@ pub async fn get_followers(
db_client: &impl DatabaseClient,
profile_id: &Uuid,
) -> Result<Vec<DbActorProfile>, DatabaseError> {
let rows = db_client.query(
"
let rows = db_client
.query(
"
SELECT actor_profile
FROM actor_profile
JOIN relationship
@ -308,9 +317,11 @@ pub async fn get_followers(
relationship.target_id = $1
AND relationship.relationship_type = $2
",
&[&profile_id, &RelationshipType::Follow],
).await?;
let profiles = rows.iter()
&[&profile_id, &RelationshipType::Follow],
)
.await?;
let profiles = rows
.iter()
.map(|row| row.try_get("actor_profile"))
.collect::<Result<_, _>>()?;
Ok(profiles)
@ -322,8 +333,9 @@ pub async fn get_followers_paginated(
max_relationship_id: Option<i32>,
limit: u16,
) -> Result<Vec<RelatedActorProfile>, DatabaseError> {
let rows = db_client.query(
"
let rows = db_client
.query(
"
SELECT relationship.id, actor_profile
FROM actor_profile
JOIN relationship
@ -335,14 +347,16 @@ pub async fn get_followers_paginated(
ORDER BY relationship.id DESC
LIMIT $4
",
&[
&profile_id,
&RelationshipType::Follow,
&max_relationship_id,
&i64::from(limit),
],
).await?;
let related_profiles = rows.iter()
&[
&profile_id,
&RelationshipType::Follow,
&max_relationship_id,
&i64::from(limit),
],
)
.await?;
let related_profiles = rows
.iter()
.map(RelatedActorProfile::try_from)
.collect::<Result<_, _>>()?;
Ok(related_profiles)
@ -352,8 +366,9 @@ pub async fn has_local_followers(
db_client: &impl DatabaseClient,
actor_id: &str,
) -> Result<bool, DatabaseError> {
let maybe_row = db_client.query_opt(
"
let maybe_row = db_client
.query_opt(
"
SELECT 1
FROM relationship
JOIN actor_profile ON (relationship.target_id = actor_profile.id)
@ -362,8 +377,9 @@ pub async fn has_local_followers(
AND relationship_type = $2
LIMIT 1
",
&[&actor_id, &RelationshipType::Follow]
).await?;
&[&actor_id, &RelationshipType::Follow],
)
.await?;
Ok(maybe_row.is_some())
}
@ -371,8 +387,9 @@ pub async fn get_following(
db_client: &impl DatabaseClient,
profile_id: &Uuid,
) -> Result<Vec<DbActorProfile>, DatabaseError> {
let rows = db_client.query(
"
let rows = db_client
.query(
"
SELECT actor_profile
FROM actor_profile
JOIN relationship
@ -381,9 +398,11 @@ pub async fn get_following(
relationship.source_id = $1
AND relationship.relationship_type = $2
",
&[&profile_id, &RelationshipType::Follow],
).await?;
let profiles = rows.iter()
&[&profile_id, &RelationshipType::Follow],
)
.await?;
let profiles = rows
.iter()
.map(|row| row.try_get("actor_profile"))
.collect::<Result<_, _>>()?;
Ok(profiles)
@ -395,8 +414,9 @@ pub async fn get_following_paginated(
max_relationship_id: Option<i32>,
limit: u16,
) -> Result<Vec<RelatedActorProfile>, DatabaseError> {
let rows = db_client.query(
"
let rows = db_client
.query(
"
SELECT relationship.id, actor_profile
FROM actor_profile
JOIN relationship
@ -408,14 +428,16 @@ pub async fn get_following_paginated(
ORDER BY relationship.id DESC
LIMIT $4
",
&[
&profile_id,
&RelationshipType::Follow,
&max_relationship_id,
&i64::from(limit),
],
).await?;
let related_profiles = rows.iter()
&[
&profile_id,
&RelationshipType::Follow,
&max_relationship_id,
&i64::from(limit),
],
)
.await?;
let related_profiles = rows
.iter()
.map(RelatedActorProfile::try_from)
.collect::<Result<_, _>>()?;
Ok(related_profiles)
@ -427,13 +449,16 @@ pub async fn subscribe(
target_id: &Uuid,
) -> Result<(), DatabaseError> {
let transaction = db_client.transaction().await?;
transaction.execute(
"
transaction
.execute(
"
INSERT INTO relationship (source_id, target_id, relationship_type)
VALUES ($1, $2, $3)
",
&[&source_id, &target_id, &RelationshipType::Subscription],
).await.map_err(catch_unique_violation("relationship"))?;
&[&source_id, &target_id, &RelationshipType::Subscription],
)
.await
.map_err(catch_unique_violation("relationship"))?;
update_subscriber_count(&transaction, target_id, 1).await?;
transaction.commit().await?;
Ok(())
@ -445,14 +470,16 @@ pub async fn subscribe_opt(
target_id: &Uuid,
) -> Result<(), DatabaseError> {
let transaction = db_client.transaction().await?;
let inserted_count = transaction.execute(
"
let inserted_count = transaction
.execute(
"
INSERT INTO relationship (source_id, target_id, relationship_type)
VALUES ($1, $2, $3)
ON CONFLICT (source_id, target_id, relationship_type) DO NOTHING
",
&[&source_id, &target_id, &RelationshipType::Subscription],
).await?;
&[&source_id, &target_id, &RelationshipType::Subscription],
)
.await?;
if inserted_count > 0 {
update_subscriber_count(&transaction, target_id, 1).await?;
};
@ -466,15 +493,17 @@ pub async fn unsubscribe(
target_id: &Uuid,
) -> Result<(), DatabaseError> {
let transaction = db_client.transaction().await?;
let deleted_count = transaction.execute(
"
let deleted_count = transaction
.execute(
"
DELETE FROM relationship
WHERE
source_id = $1 AND target_id = $2
AND relationship_type = $3
",
&[&source_id, &target_id, &RelationshipType::Subscription],
).await?;
&[&source_id, &target_id, &RelationshipType::Subscription],
)
.await?;
if deleted_count == 0 {
return Err(DatabaseError::NotFound("relationship"));
};
@ -487,8 +516,9 @@ pub async fn get_subscribers(
db_client: &impl DatabaseClient,
profile_id: &Uuid,
) -> Result<Vec<DbActorProfile>, DatabaseError> {
let rows = db_client.query(
"
let rows = db_client
.query(
"
SELECT actor_profile
FROM actor_profile
JOIN relationship
@ -498,9 +528,11 @@ pub async fn get_subscribers(
AND relationship.relationship_type = $2
ORDER BY relationship.id DESC
",
&[&profile_id, &RelationshipType::Subscription],
).await?;
let profiles = rows.iter()
&[&profile_id, &RelationshipType::Subscription],
)
.await?;
let profiles = rows
.iter()
.map(|row| row.try_get("actor_profile"))
.collect::<Result<_, _>>()?;
Ok(profiles)
@ -511,14 +543,16 @@ pub async fn hide_reposts(
source_id: &Uuid,
target_id: &Uuid,
) -> Result<(), DatabaseError> {
db_client.execute(
"
db_client
.execute(
"
INSERT INTO relationship (source_id, target_id, relationship_type)
VALUES ($1, $2, $3)
ON CONFLICT (source_id, target_id, relationship_type) DO NOTHING
",
&[&source_id, &target_id, &RelationshipType::HideReposts],
).await?;
&[&source_id, &target_id, &RelationshipType::HideReposts],
)
.await?;
Ok(())
}
@ -528,15 +562,17 @@ pub async fn show_reposts(
target_id: &Uuid,
) -> Result<(), DatabaseError> {
// Does not return NotFound error
db_client.execute(
"
db_client
.execute(
"
DELETE FROM relationship
WHERE
source_id = $1 AND target_id = $2
AND relationship_type = $3
",
&[&source_id, &target_id, &RelationshipType::HideReposts],
).await?;
&[&source_id, &target_id, &RelationshipType::HideReposts],
)
.await?;
Ok(())
}
@ -545,14 +581,16 @@ pub async fn hide_replies(
source_id: &Uuid,
target_id: &Uuid,
) -> Result<(), DatabaseError> {
db_client.execute(
"
db_client
.execute(
"
INSERT INTO relationship (source_id, target_id, relationship_type)
VALUES ($1, $2, $3)
ON CONFLICT (source_id, target_id, relationship_type) DO NOTHING
",
&[&source_id, &target_id, &RelationshipType::HideReplies],
).await?;
&[&source_id, &target_id, &RelationshipType::HideReplies],
)
.await?;
Ok(())
}
@ -562,34 +600,30 @@ pub async fn show_replies(
target_id: &Uuid,
) -> Result<(), DatabaseError> {
// Does not return NotFound error
db_client.execute(
"
db_client
.execute(
"
DELETE FROM relationship
WHERE
source_id = $1 AND target_id = $2
AND relationship_type = $3
",
&[&source_id, &target_id, &RelationshipType::HideReplies],
).await?;
&[&source_id, &target_id, &RelationshipType::HideReplies],
)
.await?;
Ok(())
}
#[cfg(test)]
mod tests {
use serial_test::serial;
use crate::database::{
test_utils::create_test_database,
DatabaseError,
};
use super::*;
use crate::database::{test_utils::create_test_database, DatabaseError};
use crate::profiles::{
queries::create_profile,
types::{DbActor, ProfileCreateData},
};
use crate::users::{
queries::create_user,
types::UserCreateData,
};
use super::*;
use crate::users::{queries::create_user, types::UserCreateData};
use serial_test::serial;
#[tokio::test]
#[serial]
@ -614,7 +648,8 @@ mod tests {
let target = create_profile(db_client, target_data).await.unwrap();
// Create follow request
let follow_request = create_follow_request(db_client, &source.id, &target.id)
.await.unwrap();
.await
.unwrap();
assert_eq!(follow_request.source_id, source.id);
assert_eq!(follow_request.target_id, target.id);
assert_eq!(follow_request.activity_id, None);
@ -622,22 +657,27 @@ mod tests {
let following = get_following(db_client, &source.id).await.unwrap();
assert!(following.is_empty());
// Accept follow request
follow_request_accepted(db_client, &follow_request.id).await.unwrap();
follow_request_accepted(db_client, &follow_request.id)
.await
.unwrap();
let follow_request = get_follow_request_by_id(db_client, &follow_request.id)
.await.unwrap();
.await
.unwrap();
assert_eq!(follow_request.request_status, FollowRequestStatus::Accepted);
let following = get_following(db_client, &source.id).await.unwrap();
assert_eq!(following[0].id, target.id);
let target_has_followers =
has_local_followers(db_client, target_actor_id).await.unwrap();
let target_has_followers = has_local_followers(db_client, target_actor_id)
.await
.unwrap();
assert_eq!(target_has_followers, true);
// Unfollow
let follow_request_id = unfollow(db_client, &source.id, &target.id)
.await.unwrap().unwrap();
.await
.unwrap()
.unwrap();
assert_eq!(follow_request_id, follow_request.id);
let follow_request_result =
get_follow_request_by_id(db_client, &follow_request_id).await;
let follow_request_result = get_follow_request_by_id(db_client, &follow_request_id).await;
assert!(matches!(
follow_request_result,
Err(DatabaseError::NotFound("follow request")),
@ -665,21 +705,26 @@ mod tests {
let target = create_user(db_client, target_data).await.unwrap();
// Create follow request
let activity_id = "https://example.org/objects/123";
let _follow_request = create_remote_follow_request_opt(
db_client, &source.id, &target.id, activity_id,
).await.unwrap();
let _follow_request =
create_remote_follow_request_opt(db_client, &source.id, &target.id, activity_id)
.await
.unwrap();
// Repeat
let follow_request = create_remote_follow_request_opt(
db_client, &source.id, &target.id, activity_id,
).await.unwrap();
let follow_request =
create_remote_follow_request_opt(db_client, &source.id, &target.id, activity_id)
.await
.unwrap();
assert_eq!(follow_request.source_id, source.id);
assert_eq!(follow_request.target_id, target.id);
assert_eq!(follow_request.activity_id, Some(activity_id.to_string()));
assert_eq!(follow_request.request_status, FollowRequestStatus::Pending);
// Accept follow request
follow_request_accepted(db_client, &follow_request.id).await.unwrap();
follow_request_accepted(db_client, &follow_request.id)
.await
.unwrap();
let follow_request = get_follow_request_by_id(db_client, &follow_request.id)
.await.unwrap();
.await
.unwrap();
assert_eq!(follow_request.request_status, FollowRequestStatus::Accepted);
}
}

View file

@ -5,8 +5,7 @@ use uuid::Uuid;
use crate::{
database::{
int_enum::{int_enum_from_sql, int_enum_to_sql},
DatabaseError,
DatabaseTypeError,
DatabaseError, DatabaseTypeError,
},
profiles::types::DbActorProfile,
};
@ -58,11 +57,7 @@ pub struct DbRelationship {
}
impl DbRelationship {
pub fn is_direct(
&self,
source_id: &Uuid,
target_id: &Uuid,
) -> Result<bool, DatabaseTypeError> {
pub fn is_direct(&self, source_id: &Uuid, target_id: &Uuid) -> Result<bool, DatabaseTypeError> {
if &self.source_id == source_id && &self.target_id == target_id {
Ok(true)
} else if &self.source_id == target_id && &self.target_id == source_id {
@ -74,7 +69,6 @@ impl DbRelationship {
}
impl TryFrom<&Row> for DbRelationship {
type Error = tokio_postgres::Error;
fn try_from(row: &Row) -> Result<Self, Self::Error> {
@ -97,7 +91,7 @@ pub enum FollowRequestStatus {
impl From<&FollowRequestStatus> for i16 {
fn from(value: &FollowRequestStatus) -> i16 {
match value {
FollowRequestStatus::Pending => 1,
FollowRequestStatus::Pending => 1,
FollowRequestStatus::Accepted => 2,
FollowRequestStatus::Rejected => 3,
}
@ -137,12 +131,14 @@ pub struct RelatedActorProfile {
}
impl TryFrom<&Row> for RelatedActorProfile {
type Error = DatabaseError;
fn try_from(row: &Row) -> Result<Self, Self::Error> {
let relationship_id = row.try_get("id")?;
let profile = row.try_get("actor_profile")?;
Ok(Self { relationship_id, profile })
Ok(Self {
relationship_id,
profile,
})
}
}

View file

@ -3,11 +3,7 @@ use uuid::Uuid;
use mitra_utils::caip2::ChainId;
use crate::database::{
catch_unique_violation,
DatabaseClient,
DatabaseError,
};
use crate::database::{catch_unique_violation, DatabaseClient, DatabaseError};
use crate::invoices::types::DbChainId;
use crate::profiles::types::PaymentType;
use crate::relationships::{
@ -28,8 +24,9 @@ pub async fn create_subscription(
) -> Result<(), DatabaseError> {
assert!(chain_id.is_ethereum() == sender_address.is_some());
let mut transaction = db_client.transaction().await?;
transaction.execute(
"
transaction
.execute(
"
INSERT INTO subscription (
sender_id,
sender_address,
@ -40,15 +37,17 @@ pub async fn create_subscription(
)
VALUES ($1, $2, $3, $4, $5, $6)
",
&[
&sender_id,
&sender_address,
&recipient_id,
&DbChainId::new(chain_id),
&expires_at,
&updated_at,
],
).await.map_err(catch_unique_violation("subscription"))?;
&[
&sender_id,
&sender_address,
&recipient_id,
&DbChainId::new(chain_id),
&expires_at,
&updated_at,
],
)
.await
.map_err(catch_unique_violation("subscription"))?;
subscribe(&mut transaction, sender_id, recipient_id).await?;
transaction.commit().await?;
Ok(())
@ -61,8 +60,9 @@ pub async fn update_subscription(
updated_at: &DateTime<Utc>,
) -> Result<(), DatabaseError> {
let mut transaction = db_client.transaction().await?;
let maybe_row = transaction.query_opt(
"
let maybe_row = transaction
.query_opt(
"
UPDATE subscription
SET
expires_at = $2,
@ -70,12 +70,9 @@ pub async fn update_subscription(
WHERE id = $1
RETURNING sender_id, recipient_id
",
&[
&subscription_id,
&expires_at,
&updated_at,
],
).await?;
&[&subscription_id, &expires_at, &updated_at],
)
.await?;
let row = maybe_row.ok_or(DatabaseError::NotFound("subscription"))?;
let sender_id: Uuid = row.try_get("sender_id")?;
let recipient_id: Uuid = row.try_get("recipient_id")?;
@ -91,14 +88,16 @@ pub async fn get_subscription_by_participants(
sender_id: &Uuid,
recipient_id: &Uuid,
) -> Result<DbSubscription, DatabaseError> {
let maybe_row = db_client.query_opt(
"
let maybe_row = db_client
.query_opt(
"
SELECT subscription
FROM subscription
WHERE sender_id = $1 AND recipient_id = $2
",
&[sender_id, recipient_id],
).await?;
&[sender_id, recipient_id],
)
.await?;
let row = maybe_row.ok_or(DatabaseError::NotFound("subscription"))?;
let subscription: DbSubscription = row.try_get("subscription")?;
Ok(subscription)
@ -107,8 +106,9 @@ pub async fn get_subscription_by_participants(
pub async fn get_expired_subscriptions(
db_client: &impl DatabaseClient,
) -> Result<Vec<DbSubscription>, DatabaseError> {
let rows = db_client.query(
"
let rows = db_client
.query(
"
SELECT subscription
FROM subscription
JOIN relationship
@ -119,9 +119,11 @@ pub async fn get_expired_subscriptions(
)
WHERE subscription.expires_at <= CURRENT_TIMESTAMP
",
&[&RelationshipType::Subscription],
).await?;
let subscriptions = rows.iter()
&[&RelationshipType::Subscription],
)
.await?;
let subscriptions = rows
.iter()
.map(|row| row.try_get("subscription"))
.collect::<Result<_, _>>()?;
Ok(subscriptions)
@ -133,8 +135,9 @@ pub async fn get_incoming_subscriptions(
max_subscription_id: Option<i32>,
limit: u16,
) -> Result<Vec<Subscription>, DatabaseError> {
let rows = db_client.query(
"
let rows = db_client
.query(
"
SELECT subscription, actor_profile AS sender
FROM actor_profile
JOIN subscription
@ -145,9 +148,11 @@ pub async fn get_incoming_subscriptions(
ORDER BY subscription.id DESC
LIMIT $3
",
&[&recipient_id, &max_subscription_id, &i64::from(limit)],
).await?;
let subscriptions = rows.iter()
&[&recipient_id, &max_subscription_id, &i64::from(limit)],
)
.await?;
let subscriptions = rows
.iter()
.map(Subscription::try_from)
.collect::<Result<_, _>>()?;
Ok(subscriptions)
@ -161,8 +166,9 @@ pub async fn reset_subscriptions(
if ethereum_contract_replaced {
// Ethereum subscription configuration is stored in contract.
// If contract is replaced, payment option needs to be deleted.
transaction.execute(
"
transaction
.execute(
"
UPDATE actor_profile
SET payment_options = '[]'
WHERE
@ -174,19 +180,22 @@ pub async fn reset_subscriptions(
WHERE CAST(option ->> 'payment_type' AS SMALLINT) = $1
)
",
&[&i16::from(&PaymentType::EthereumSubscription)],
).await?;
&[&i16::from(&PaymentType::EthereumSubscription)],
)
.await?;
};
transaction.execute(
"
transaction
.execute(
"
DELETE FROM relationship
WHERE relationship_type = $1
",
&[&RelationshipType::Subscription],
).await?;
transaction.execute(
"UPDATE actor_profile SET subscriber_count = 0", &[],
).await?;
&[&RelationshipType::Subscription],
)
.await?;
transaction
.execute("UPDATE actor_profile SET subscriber_count = 0", &[])
.await?;
transaction.execute("DELETE FROM subscription", &[]).await?;
transaction.commit().await?;
Ok(())
@ -194,21 +203,12 @@ pub async fn reset_subscriptions(
#[cfg(test)]
mod tests {
use serial_test::serial;
use crate::database::test_utils::create_test_database;
use crate::profiles::{
queries::create_profile,
types::ProfileCreateData,
};
use crate::relationships::{
queries::has_relationship,
types::RelationshipType,
};
use crate::users::{
queries::create_user,
types::UserCreateData,
};
use super::*;
use crate::database::test_utils::create_test_database;
use crate::profiles::{queries::create_profile, types::ProfileCreateData};
use crate::relationships::{queries::has_relationship, types::RelationshipType};
use crate::users::{queries::create_user, types::UserCreateData};
use serial_test::serial;
#[tokio::test]
#[serial]
@ -237,14 +237,18 @@ mod tests {
&chain_id,
&expires_at,
&updated_at,
).await.unwrap();
)
.await
.unwrap();
let is_subscribed = has_relationship(
db_client,
&sender.id,
&recipient.id,
RelationshipType::Subscription,
).await.unwrap();
)
.await
.unwrap();
assert_eq!(is_subscribed, true);
}
}

View file

@ -27,7 +27,6 @@ pub struct Subscription {
}
impl TryFrom<&Row> for Subscription {
type Error = DatabaseError;
fn try_from(row: &Row) -> Result<Self, Self::Error> {

View file

@ -6,16 +6,19 @@ pub async fn search_tags(
limit: u16,
) -> Result<Vec<String>, DatabaseError> {
let db_search_query = format!("%{}%", search_query);
let rows = db_client.query(
"
let rows = db_client
.query(
"
SELECT tag_name
FROM tag
WHERE tag_name ILIKE $1
LIMIT $2
",
&[&db_search_query, &i64::from(limit)],
).await?;
let tags: Vec<String> = rows.iter()
&[&db_search_query, &i64::from(limit)],
)
.await?;
let tags: Vec<String> = rows
.iter()
.map(|row| row.try_get("tag_name"))
.collect::<Result<_, _>>()?;
Ok(tags)

View file

@ -1,30 +1,16 @@
use serde_json::{Value as JsonValue};
use serde_json::Value as JsonValue;
use uuid::Uuid;
use mitra_utils::{
currencies::Currency,
did::Did,
did_pkh::DidPkh,
};
use mitra_utils::{currencies::Currency, did::Did, did_pkh::DidPkh};
use crate::database::{
catch_unique_violation,
DatabaseClient,
DatabaseError,
};
use crate::database::{catch_unique_violation, DatabaseClient, DatabaseError};
use crate::profiles::{
queries::create_profile,
types::{DbActorProfile, ProfileCreateData},
};
use super::types::{
ClientConfig,
DbClientConfig,
DbInviteCode,
DbUser,
Role,
User,
UserCreateData,
ClientConfig, DbClientConfig, DbInviteCode, DbUser, Role, User, UserCreateData,
};
use super::utils::generate_invite_code;
@ -33,28 +19,33 @@ pub async fn create_invite_code(
note: Option<&str>,
) -> Result<String, DatabaseError> {
let invite_code = generate_invite_code();
db_client.execute(
"
db_client
.execute(
"
INSERT INTO user_invite_code (code, note)
VALUES ($1, $2)
",
&[&invite_code, &note],
).await?;
&[&invite_code, &note],
)
.await?;
Ok(invite_code)
}
pub async fn get_invite_codes(
db_client: &impl DatabaseClient,
) -> Result<Vec<DbInviteCode>, DatabaseError> {
let rows = db_client.query(
"
let rows = db_client
.query(
"
SELECT user_invite_code
FROM user_invite_code
WHERE used = FALSE
",
&[],
).await?;
let codes = rows.iter()
&[],
)
.await?;
let codes = rows
.iter()
.map(|row| row.try_get("user_invite_code"))
.collect::<Result<_, _>>()?;
Ok(codes)
@ -64,13 +55,15 @@ pub async fn is_valid_invite_code(
db_client: &impl DatabaseClient,
invite_code: &str,
) -> Result<bool, DatabaseError> {
let maybe_row = db_client.query_opt(
"
let maybe_row = db_client
.query_opt(
"
SELECT 1 FROM user_invite_code
WHERE code = $1 AND used = FALSE
",
&[&invite_code],
).await?;
&[&invite_code],
)
.await?;
Ok(maybe_row.is_some())
}
@ -78,37 +71,39 @@ pub async fn create_user(
db_client: &mut impl DatabaseClient,
user_data: UserCreateData,
) -> Result<User, DatabaseError> {
assert!(user_data.password_hash.is_some() ||
user_data.wallet_address.is_some());
assert!(user_data.password_hash.is_some() || user_data.wallet_address.is_some());
let mut transaction = db_client.transaction().await?;
// Prevent changes to actor_profile table
transaction.execute(
"LOCK TABLE actor_profile IN EXCLUSIVE MODE",
&[],
).await?;
transaction
.execute("LOCK TABLE actor_profile IN EXCLUSIVE MODE", &[])
.await?;
// Ensure there are no local accounts with a similar name
let maybe_row = transaction.query_opt(
"
let maybe_row = transaction
.query_opt(
"
SELECT 1
FROM user_account JOIN actor_profile USING (id)
WHERE actor_profile.username ILIKE $1
LIMIT 1
",
&[&user_data.username],
).await?;
&[&user_data.username],
)
.await?;
if maybe_row.is_some() {
return Err(DatabaseError::AlreadyExists("user"));
};
// Use invite code
if let Some(ref invite_code) = user_data.invite_code {
let updated_count = transaction.execute(
"
let updated_count = transaction
.execute(
"
UPDATE user_invite_code
SET used = TRUE
WHERE code = $1 AND used = FALSE
",
&[&invite_code],
).await?;
&[&invite_code],
)
.await?;
if updated_count == 0 {
return Err(DatabaseError::NotFound("invite code"));
};
@ -131,8 +126,9 @@ pub async fn create_user(
};
let profile = create_profile(&mut transaction, profile_data).await?;
// Create user
let row = transaction.query_one(
"
let row = transaction
.query_one(
"
INSERT INTO user_account (
id,
wallet_address,
@ -144,15 +140,17 @@ pub async fn create_user(
VALUES ($1, $2, $3, $4, $5, $6)
RETURNING user_account
",
&[
&profile.id,
&user_data.wallet_address,
&user_data.password_hash,
&user_data.private_key_pem,
&user_data.invite_code,
&user_data.role,
],
).await.map_err(catch_unique_violation("user"))?;
&[
&profile.id,
&user_data.wallet_address,
&user_data.password_hash,
&user_data.private_key_pem,
&user_data.invite_code,
&user_data.role,
],
)
.await
.map_err(catch_unique_violation("user"))?;
let db_user: DbUser = row.try_get("user_account")?;
let user = User::new(db_user, profile);
transaction.commit().await?;
@ -164,13 +162,15 @@ pub async fn set_user_password(
user_id: &Uuid,
password_hash: String,
) -> Result<(), DatabaseError> {
let updated_count = db_client.execute(
"
let updated_count = db_client
.execute(
"
UPDATE user_account SET password_hash = $1
WHERE id = $2
",
&[&password_hash, &user_id],
).await?;
&[&password_hash, &user_id],
)
.await?;
if updated_count == 0 {
return Err(DatabaseError::NotFound("user"));
};
@ -182,13 +182,15 @@ pub async fn set_user_role(
user_id: &Uuid,
role: Role,
) -> Result<(), DatabaseError> {
let updated_count = db_client.execute(
"
let updated_count = db_client
.execute(
"
UPDATE user_account SET user_role = $1
WHERE id = $2
",
&[&role, &user_id],
).await?;
&[&role, &user_id],
)
.await?;
if updated_count == 0 {
return Err(DatabaseError::NotFound("user"));
};
@ -201,15 +203,17 @@ pub async fn update_client_config(
client_name: &str,
client_config_value: &JsonValue,
) -> Result<ClientConfig, DatabaseError> {
let maybe_row = db_client.query_opt(
"
let maybe_row = db_client
.query_opt(
"
UPDATE user_account
SET client_config = jsonb_set(client_config, ARRAY[$1], $2, true)
WHERE id = $3
RETURNING client_config
",
&[&client_name, &client_config_value, &user_id],
).await?;
&[&client_name, &client_config_value, &user_id],
)
.await?;
let row = maybe_row.ok_or(DatabaseError::NotFound("user"))?;
let client_config: DbClientConfig = row.try_get("client_config")?;
Ok(client_config.into_inner())
@ -219,14 +223,16 @@ pub async fn get_user_by_id(
db_client: &impl DatabaseClient,
user_id: &Uuid,
) -> Result<User, DatabaseError> {
let maybe_row = db_client.query_opt(
"
let maybe_row = db_client
.query_opt(
"
SELECT user_account, actor_profile
FROM user_account JOIN actor_profile USING (id)
WHERE id = $1
",
&[&user_id],
).await?;
&[&user_id],
)
.await?;
let row = maybe_row.ok_or(DatabaseError::NotFound("user"))?;
let db_user: DbUser = row.try_get("user_account")?;
let db_profile: DbActorProfile = row.try_get("actor_profile")?;
@ -238,14 +244,16 @@ pub async fn get_user_by_name(
db_client: &impl DatabaseClient,
username: &str,
) -> Result<User, DatabaseError> {
let maybe_row = db_client.query_opt(
"
let maybe_row = db_client
.query_opt(
"
SELECT user_account, actor_profile
FROM user_account JOIN actor_profile USING (id)
WHERE actor_profile.username = $1
",
&[&username],
).await?;
&[&username],
)
.await?;
let row = maybe_row.ok_or(DatabaseError::NotFound("user"))?;
let db_user: DbUser = row.try_get("user_account")?;
let db_profile: DbActorProfile = row.try_get("actor_profile")?;
@ -257,13 +265,15 @@ pub async fn is_registered_user(
db_client: &impl DatabaseClient,
username: &str,
) -> Result<bool, DatabaseError> {
let maybe_row = db_client.query_opt(
"
let maybe_row = db_client
.query_opt(
"
SELECT 1 FROM user_account JOIN actor_profile USING (id)
WHERE actor_profile.username = $1
",
&[&username],
).await?;
&[&username],
)
.await?;
Ok(maybe_row.is_some())
}
@ -271,14 +281,16 @@ pub async fn get_user_by_login_address(
db_client: &impl DatabaseClient,
wallet_address: &str,
) -> Result<User, DatabaseError> {
let maybe_row = db_client.query_opt(
"
let maybe_row = db_client
.query_opt(
"
SELECT user_account, actor_profile
FROM user_account JOIN actor_profile USING (id)
WHERE wallet_address = $1
",
&[&wallet_address],
).await?;
&[&wallet_address],
)
.await?;
let row = maybe_row.ok_or(DatabaseError::NotFound("user"))?;
let db_user: DbUser = row.try_get("user_account")?;
let db_profile: DbActorProfile = row.try_get("actor_profile")?;
@ -291,8 +303,9 @@ pub async fn get_user_by_did(
did: &Did,
) -> Result<User, DatabaseError> {
// DIDs must be locally unique
let maybe_row = db_client.query_opt(
"
let maybe_row = db_client
.query_opt(
"
SELECT user_account, actor_profile
FROM user_account JOIN actor_profile USING (id)
WHERE
@ -302,8 +315,9 @@ pub async fn get_user_by_did(
WHERE proof ->> 'issuer' = $1
)
",
&[&did.to_string()],
).await?;
&[&did.to_string()],
)
.await?;
let row = maybe_row.ok_or(DatabaseError::NotFound("user"))?;
let db_user: DbUser = row.try_get("user_account")?;
let db_profile: DbActorProfile = row.try_get("actor_profile")?;
@ -321,24 +335,21 @@ pub async fn get_user_by_public_wallet_address(
get_user_by_did(db_client, &did).await
}
pub async fn get_user_count(
db_client: &impl DatabaseClient,
) -> Result<i64, DatabaseError> {
let row = db_client.query_one(
"SELECT count(user_account) FROM user_account",
&[],
).await?;
pub async fn get_user_count(db_client: &impl DatabaseClient) -> Result<i64, DatabaseError> {
let row = db_client
.query_one("SELECT count(user_account) FROM user_account", &[])
.await?;
let count = row.try_get("count")?;
Ok(count)
}
#[cfg(test)]
mod tests {
use serde_json::json;
use serial_test::serial;
use super::*;
use crate::database::test_utils::create_test_database;
use crate::users::types::Role;
use super::*;
use serde_json::json;
use serial_test::serial;
#[tokio::test]
#[serial]
@ -392,7 +403,9 @@ mod tests {
};
let user = create_user(db_client, user_data).await.unwrap();
assert_eq!(user.role, Role::NormalUser);
set_user_role(db_client, &user.id, Role::ReadOnlyUser).await.unwrap();
set_user_role(db_client, &user.id, Role::ReadOnlyUser)
.await
.unwrap();
let user = get_user_by_id(db_client, &user.id).await.unwrap();
assert_eq!(user.role, Role::ReadOnlyUser);
}
@ -410,12 +423,10 @@ mod tests {
assert_eq!(user.client_config.is_empty(), true);
let client_name = "test";
let client_config_value = json!({"a": 1});
let client_config = update_client_config(
db_client,
&user.id,
client_name,
&client_config_value,
).await.unwrap();
let client_config =
update_client_config(db_client, &user.id, client_name, &client_config_value)
.await
.unwrap();
assert_eq!(
client_config.get(client_name).unwrap(),
&client_config_value,

View file

@ -3,13 +3,10 @@ use std::collections::HashMap;
use chrono::{DateTime, Utc};
use postgres_types::FromSql;
use serde::Deserialize;
use serde_json::{Value as JsonValue};
use serde_json::Value as JsonValue;
use uuid::Uuid;
use mitra_utils::{
currencies::Currency,
did::Did,
};
use mitra_utils::{currencies::Currency, did::Did};
use crate::database::{
int_enum::{int_enum_from_sql, int_enum_to_sql},
@ -46,7 +43,9 @@ pub enum Role {
}
impl Default for Role {
fn default() -> Self { Self::NormalUser }
fn default() -> Self {
Self::NormalUser
}
}
impl Role {
@ -65,9 +64,7 @@ impl Role {
Permission::DeleteAnyProfile,
Permission::ManageSubscriptionOptions,
],
Self::ReadOnlyUser => vec![
Permission::CreateFollowRequest,
],
Self::ReadOnlyUser => vec![Permission::CreateFollowRequest],
}
}
@ -147,10 +144,7 @@ pub struct User {
}
impl User {
pub fn new(
db_user: DbUser,
db_profile: DbActorProfile,
) -> Self {
pub fn new(db_user: DbUser, db_profile: DbActorProfile) -> Self {
assert_eq!(db_user.id, db_profile.id);
Self {
id: db_user.id,
@ -177,7 +171,7 @@ impl User {
return Some(did_pkh.address);
};
};
};
}
None
}
}

View file

@ -3,13 +3,7 @@ use std::fmt;
use std::str::FromStr;
use regex::Regex;
use serde::{
Deserialize,
Deserializer,
Serialize,
Serializer,
de::Error as DeserializerError,
};
use serde::{de::Error as DeserializerError, Deserialize, Deserializer, Serialize, Serializer};
use super::currencies::Currency;
@ -52,7 +46,9 @@ impl ChainId {
if !self.is_ethereum() {
return Err(ChainIdError("namespace is not eip155"));
};
let chain_id: u32 = self.reference.parse()
let chain_id: u32 = self
.reference
.parse()
.map_err(|_| ChainIdError("invalid EIP-155 chain ID"))?;
Ok(chain_id)
}
@ -76,7 +72,8 @@ impl FromStr for ChainId {
fn from_str(value: &str) -> Result<Self, Self::Err> {
let caip2_re = Regex::new(CAIP2_RE).unwrap();
let caps = caip2_re.captures(value)
let caps = caip2_re
.captures(value)
.ok_or(ChainIdError("invalid chain ID"))?;
let chain_id = Self {
namespace: caps["namespace"].to_string(),
@ -94,7 +91,8 @@ impl fmt::Display for ChainId {
impl Serialize for ChainId {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where S: Serializer
where
S: Serializer,
{
serializer.serialize_str(&self.to_string())
}
@ -102,10 +100,12 @@ impl Serialize for ChainId {
impl<'de> Deserialize<'de> for ChainId {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where D: Deserializer<'de>
where
D: Deserializer<'de>,
{
String::deserialize(deserializer)?
.parse().map_err(DeserializerError::custom)
.parse()
.map_err(DeserializerError::custom)
}
}

View file

@ -5,9 +5,7 @@ use serde::Serialize;
pub struct CanonicalizationError(#[from] serde_json::Error);
/// JCS: https://www.rfc-editor.org/rfc/rfc8785
pub fn canonicalize_object(
object: &impl Serialize,
) -> Result<String, CanonicalizationError> {
pub fn canonicalize_object(object: &impl Serialize) -> Result<String, CanonicalizationError> {
let object_str = serde_jcs::to_string(object)?;
Ok(object_str)
}

View file

@ -1,5 +1,5 @@
use rsa::{Hash, PaddingScheme, PublicKey, RsaPrivateKey, RsaPublicKey};
use rsa::pkcs8::{FromPrivateKey, FromPublicKey, ToPrivateKey, ToPublicKey};
use rsa::{Hash, PaddingScheme, PublicKey, RsaPrivateKey, RsaPublicKey};
use sha2::{Digest, Sha256};
pub fn generate_rsa_key() -> Result<RsaPrivateKey, rsa::errors::Error> {
@ -16,32 +16,24 @@ pub fn generate_weak_rsa_key() -> Result<RsaPrivateKey, rsa::errors::Error> {
RsaPrivateKey::new(&mut rng, bits)
}
pub fn serialize_private_key(
private_key: &RsaPrivateKey,
) -> Result<String, rsa::pkcs8::Error> {
pub fn serialize_private_key(private_key: &RsaPrivateKey) -> Result<String, rsa::pkcs8::Error> {
private_key.to_pkcs8_pem().map(|val| val.to_string())
}
pub fn deserialize_private_key(
private_key_pem: &str,
) -> Result<RsaPrivateKey, rsa::pkcs8::Error> {
pub fn deserialize_private_key(private_key_pem: &str) -> Result<RsaPrivateKey, rsa::pkcs8::Error> {
RsaPrivateKey::from_pkcs8_pem(private_key_pem)
}
pub fn get_public_key_pem(
private_key: &RsaPrivateKey,
) -> Result<String, rsa::pkcs8::Error> {
pub fn get_public_key_pem(private_key: &RsaPrivateKey) -> Result<String, rsa::pkcs8::Error> {
let public_key = RsaPublicKey::from(private_key);
public_key.to_public_key_pem()
}
pub fn deserialize_public_key(
public_key_pem: &str,
) -> Result<RsaPublicKey, rsa::pkcs8::Error> {
pub fn deserialize_public_key(public_key_pem: &str) -> Result<RsaPublicKey, rsa::pkcs8::Error> {
// rsa package can't decode PEM string with non-standard wrap width,
// so the input should be normalized first
let parsed_pem = pem::parse(public_key_pem.trim().as_bytes())
.map_err(|_| rsa::pkcs8::Error::Pem)?;
let parsed_pem =
pem::parse(public_key_pem.trim().as_bytes()).map_err(|_| rsa::pkcs8::Error::Pem)?;
let normalized_pem = pem::encode(&parsed_pem);
RsaPublicKey::from_public_key_pem(&normalized_pem)
}
@ -70,11 +62,7 @@ pub fn verify_rsa_sha256_signature(
) -> bool {
let digest = Sha256::digest(message.as_bytes());
let padding = PaddingScheme::new_pkcs1v15_sign(Some(Hash::SHA2_256));
let is_valid = public_key.verify(
padding,
&digest,
signature,
).is_ok();
let is_valid = public_key.verify(padding, &digest, signature).is_ok();
is_valid
}
@ -102,17 +90,10 @@ YsFtrgWDQ/s8k86sNBU+Ce2GOL7seh46kyAWgJeohh4Rcrr23rftHbvxOcRM8VzYuCeb1DgVhPGtA0xU
fn test_verify_rsa_signature() {
let private_key = generate_weak_rsa_key().unwrap();
let message = "test".to_string();
let signature = create_rsa_sha256_signature(
&private_key,
&message,
).unwrap();
let signature = create_rsa_sha256_signature(&private_key, &message).unwrap();
let public_key = RsaPublicKey::from(&private_key);
let is_valid = verify_rsa_sha256_signature(
&public_key,
&message,
&signature,
);
let is_valid = verify_rsa_sha256_signature(&public_key, &message, &signature);
assert_eq!(is_valid, true);
}
}

View file

@ -9,7 +9,8 @@ impl Currency {
match self {
Self::Ethereum => "ETH",
Self::Monero => "XMR",
}.to_string()
}
.to_string()
}
pub fn field_name(&self) -> String {

View file

@ -1,8 +1,7 @@
use chrono::{DateTime, Duration, NaiveDateTime, Utc};
pub fn get_min_datetime() -> DateTime<Utc> {
let native = NaiveDateTime::from_timestamp_opt(0, 0)
.expect("0 should be a valid argument");
let native = NaiveDateTime::from_timestamp_opt(0, 0).expect("0 should be a valid argument");
DateTime::from_utc(native, Utc)
}

View file

@ -3,10 +3,7 @@ use std::fmt;
use std::str::FromStr;
use regex::Regex;
use serde::{
Deserialize, Deserializer, Serialize, Serializer,
de::Error as DeserializerError,
};
use serde::{de::Error as DeserializerError, Deserialize, Deserializer, Serialize, Serializer};
use super::did_key::DidKey;
use super::did_pkh::DidPkh;
@ -33,11 +30,11 @@ impl FromStr for Did {
"key" => {
let did_key = DidKey::from_str(value)?;
Self::Key(did_key)
},
}
"pkh" => {
let did_pkh = DidPkh::from_str(value)?;
Self::Pkh(did_pkh)
},
}
_ => return Err(DidParseError),
};
Ok(did)
@ -56,7 +53,8 @@ impl fmt::Display for Did {
impl<'de> Deserialize<'de> for Did {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where D: Deserializer<'de>
where
D: Deserializer<'de>,
{
let did_str: String = Deserialize::deserialize(deserializer)?;
did_str.parse().map_err(DeserializerError::custom)
@ -65,7 +63,8 @@ impl<'de> Deserialize<'de> for Did {
impl Serialize for Did {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where S: Serializer
where
S: Serializer,
{
let did_str = self.to_string();
serializer.serialize_str(&did_str)

View file

@ -6,10 +6,7 @@ use regex::Regex;
use super::{
did::DidParseError,
multibase::{
decode_multibase_base58btc,
encode_multibase_base58btc,
},
multibase::{decode_multibase_base58btc, encode_multibase_base58btc},
};
const DID_KEY_RE: &str = r"did:key:(?P<key>z[a-km-zA-HJ-NP-Z1-9]+)";
@ -34,10 +31,7 @@ impl DidKey {
}
pub fn from_ed25519_key(key: [u8; 32]) -> Self {
let prefixed_key = [
MULTICODEC_ED25519_PREFIX.to_vec(),
key.to_vec(),
].concat();
let prefixed_key = [MULTICODEC_ED25519_PREFIX.to_vec(), key.to_vec()].concat();
Self { key: prefixed_key }
}
@ -62,8 +56,7 @@ impl FromStr for DidKey {
fn from_str(value: &str) -> Result<Self, Self::Err> {
let did_key_re = Regex::new(DID_KEY_RE).unwrap();
let caps = did_key_re.captures(value).ok_or(DidParseError)?;
let key = decode_multibase_base58btc(&caps["key"])
.map_err(|_| DidParseError)?;
let key = decode_multibase_base58btc(&caps["key"]).map_err(|_| DidParseError)?;
let did_key = Self { key };
Ok(did_key)
}

View file

@ -4,11 +4,7 @@ use std::str::FromStr;
use regex::Regex;
use super::{
caip2::ChainId,
currencies::Currency,
did::DidParseError,
};
use super::{caip2::ChainId, currencies::Currency, did::DidParseError};
// https://github.com/ChainAgnostic/CAIPs/blob/master/CAIPs/caip-10.md#syntax
const DID_PKH_RE: &str = r"did:pkh:(?P<network>[-a-z0-9]{3,8}):(?P<chain>[-a-zA-Z0-9]{1,32}):(?P<address>[a-zA-Z0-9]{1,64})";
@ -38,9 +34,7 @@ impl fmt::Display for DidPkh {
fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
let did_str = format!(
"did:pkh:{}:{}:{}",
self.chain_id.namespace,
self.chain_id.reference,
self.address,
self.chain_id.namespace, self.chain_id.reference, self.address,
);
write!(formatter, "{}", did_str)
}

View file

@ -1,10 +1,6 @@
use std::fs::{
set_permissions,
File,
Permissions,
};
use std::io::Error;
use std::fs::{set_permissions, File, Permissions};
use std::io::prelude::*;
use std::io::Error;
use std::os::unix::fs::PermissionsExt;
use std::path::Path;
@ -19,11 +15,9 @@ pub fn get_media_type_extension(media_type: &str) -> Option<&'static str> {
match media_type {
// Override extension provided by mime_guess
"image/jpeg" => Some("jpg"),
_ => {
get_mime_extensions_str(media_type)
.and_then(|extensions| extensions.first())
.copied()
}
_ => get_mime_extensions_str(media_type)
.and_then(|extensions| extensions.first())
.copied(),
}
}
@ -45,13 +39,7 @@ mod tests {
#[test]
fn test_get_media_type_extension() {
assert_eq!(
get_media_type_extension("image/png"),
Some("png"),
);
assert_eq!(
get_media_type_extension("image/jpeg"),
Some("jpg"),
);
assert_eq!(get_media_type_extension("image/png"), Some("png"),);
assert_eq!(get_media_type_extension("image/jpeg"), Some("jpg"),);
}
}

View file

@ -3,7 +3,7 @@ use std::iter::FromIterator;
use ammonia::Builder;
pub use ammonia::{clean_text as escape_html};
pub use ammonia::clean_text as escape_html;
pub fn clean_html(
unsafe_html: &str,
@ -12,7 +12,7 @@ pub fn clean_html(
let mut builder = Builder::default();
for (tag, classes) in allowed_classes.iter() {
builder.add_allowed_classes(tag, classes);
};
}
let safe_html = builder
// Remove src from external images to prevent tracking
.set_tag_attribute_value("img", "src", "")
@ -28,15 +28,11 @@ pub fn clean_html_strict(
allowed_tags: &[&str],
allowed_classes: Vec<(&'static str, Vec<&'static str>)>,
) -> String {
let allowed_tags =
HashSet::from_iter(allowed_tags.iter().copied());
let allowed_tags = HashSet::from_iter(allowed_tags.iter().copied());
let mut allowed_classes_map = HashMap::new();
for (tag, classes) in allowed_classes {
allowed_classes_map.insert(
tag,
HashSet::from_iter(classes.into_iter()),
);
};
allowed_classes_map.insert(tag, HashSet::from_iter(classes.into_iter()));
}
let safe_html = Builder::default()
.tags(allowed_tags)
.allowed_classes(allowed_classes_map)
@ -47,9 +43,7 @@ pub fn clean_html_strict(
}
pub fn clean_html_all(html: &str) -> String {
let text = Builder::empty()
.clean(html)
.to_string();
let text = Builder::empty().clean(html).to_string();
text
}
@ -69,10 +63,7 @@ mod tests {
);
let safe_html = clean_html(
unsafe_html,
vec![
("a", vec!["mention", "u-url"]),
("span", vec!["h-card"]),
],
vec![("a", vec!["mention", "u-url"]), ("span", vec!["h-card"])],
);
assert_eq!(safe_html, expected_safe_html);
}
@ -83,12 +74,12 @@ mod tests {
let safe_html = clean_html_strict(
unsafe_html,
&["a", "br", "code", "p", "span"],
vec![
("a", vec!["mention", "u-url"]),
("span", vec!["h-card"]),
],
vec![("a", vec!["mention", "u-url"]), ("span", vec!["h-card"])],
);
assert_eq!(
safe_html,
r#"<p><span class="h-card"><a href="https://example.com/user" class="u-url mention" rel="noopener">@<span>user</span></a></span> test bold with <a href="https://example.com" rel="noopener">link</a> and <code>code</code></p>"#
);
assert_eq!(safe_html, r#"<p><span class="h-card"><a href="https://example.com/user" class="u-url mention" rel="noopener">@<span>user</span></a></span> test bold with <a href="https://example.com" rel="noopener">link</a> and <code>code</code></p>"#);
}
#[test]

View file

@ -2,14 +2,9 @@ use std::cell::RefCell;
use comrak::{
arena_tree::Node,
format_commonmark,
format_html,
format_commonmark, format_html,
nodes::{Ast, AstNode, ListType, NodeValue},
parse_document,
Arena,
ComrakOptions,
ComrakExtensionOptions,
ComrakParseOptions,
parse_document, Arena, ComrakExtensionOptions, ComrakOptions, ComrakParseOptions,
ComrakRenderOptions,
};
@ -37,16 +32,14 @@ fn build_comrak_options() -> ComrakOptions {
}
}
fn iter_nodes<'a, F>(
node: &'a AstNode<'a>,
func: &F,
) -> Result<(), MarkdownError>
where F: Fn(&'a AstNode<'a>) -> Result<(), MarkdownError>
fn iter_nodes<'a, F>(node: &'a AstNode<'a>, func: &F) -> Result<(), MarkdownError>
where
F: Fn(&'a AstNode<'a>) -> Result<(), MarkdownError>,
{
func(node)?;
for child in node.children() {
iter_nodes(child, func)?;
};
}
Ok(())
}
@ -80,15 +73,13 @@ fn replace_with_markdown<'a>(
let markdown = node_to_markdown(node, options)?;
for child in node.children() {
child.detach();
};
}
let text = NodeValue::Text(markdown);
replace_node_value(node, text);
Ok(())
}
fn fix_microsyntaxes<'a>(
node: &'a AstNode<'a>,
) -> Result<(), MarkdownError> {
fn fix_microsyntaxes<'a>(node: &'a AstNode<'a>) -> Result<(), MarkdownError> {
if let Some(prev) = node.previous_sibling() {
if let NodeValue::Text(ref prev_text) = prev.data.borrow().value {
// Remove autolink if mention or object link syntax is found
@ -100,7 +91,7 @@ fn fix_microsyntaxes<'a>(
if let NodeValue::Text(child_text) = child_value {
link_text.push_str(child_text);
};
};
}
let text = NodeValue::Text(link_text);
replace_node_value(node, text);
};
@ -138,11 +129,7 @@ fn fix_linebreaks(html: &str) -> String {
pub fn markdown_lite_to_html(text: &str) -> Result<String, MarkdownError> {
let options = build_comrak_options();
let arena = Arena::new();
let root = parse_document(
&arena,
text,
&options,
);
let root = parse_document(&arena, text, &options);
// Re-render blockquotes, headings, HRs, images and lists
// Headings: poorly degrade on Pleroma
@ -160,12 +147,12 @@ pub fn markdown_lite_to_html(text: &str) -> Result<String, MarkdownError> {
};
for child in node.children() {
child.detach();
};
}
let text = NodeValue::Text(markdown);
let text_node = arena.alloc(create_node(text));
node.append(text_node);
replace_node_value(node, NodeValue::Paragraph);
},
}
NodeValue::Image(_) => replace_with_markdown(node, &options)?,
NodeValue::List(_) => {
// Replace list and list item nodes
@ -176,11 +163,10 @@ pub fn markdown_lite_to_html(text: &str) -> Result<String, MarkdownError> {
for paragraph in list_item.children() {
for content_node in paragraph.children() {
contents.push(content_node);
};
}
paragraph.detach();
};
let mut list_prefix_markdown =
node_to_markdown(list_item, &options)?;
}
let mut list_prefix_markdown = node_to_markdown(list_item, &options)?;
if let NodeValue::Item(item) = list_item.data.borrow().value {
if item.list_type == ListType::Ordered {
// Preserve numbering in ordered lists
@ -200,14 +186,14 @@ pub fn markdown_lite_to_html(text: &str) -> Result<String, MarkdownError> {
replacements.push(list_prefix_node);
for content_node in contents {
replacements.push(content_node);
};
}
list_item.detach();
};
}
for child_node in replacements {
node.append(child_node);
};
}
replace_node_value(node, NodeValue::Paragraph);
},
}
NodeValue::Link(_) => fix_microsyntaxes(node)?,
_ => (),
};
@ -224,20 +210,15 @@ pub fn markdown_lite_to_html(text: &str) -> Result<String, MarkdownError> {
pub fn markdown_basic_to_html(text: &str) -> Result<String, MarkdownError> {
let options = build_comrak_options();
let arena = Arena::new();
let root = parse_document(
&arena,
text,
&options,
);
let root = parse_document(&arena, text, &options);
iter_nodes(root, &|node| {
let node_value = node.data.borrow().value.clone();
match node_value {
NodeValue::Document |
NodeValue::Text(_) |
NodeValue::SoftBreak |
NodeValue::LineBreak
=> (),
NodeValue::Document
| NodeValue::Text(_)
| NodeValue::SoftBreak
| NodeValue::LineBreak => (),
NodeValue::Link(_) => fix_microsyntaxes(node)?,
NodeValue::Paragraph => {
if node.next_sibling().is_some() {
@ -248,13 +229,12 @@ pub fn markdown_basic_to_html(text: &str) -> Result<String, MarkdownError> {
let last_child_value = &last_child.data.borrow().value;
if !matches!(last_child_value, NodeValue::LineBreak) {
let line_break = NodeValue::LineBreak;
let line_break_node =
arena.alloc(create_node(line_break));
let line_break_node = arena.alloc(create_node(line_break));
node.append(line_break_node);
};
};
};
},
}
_ => replace_with_markdown(node, &options)?,
};
Ok(())
@ -340,9 +320,6 @@ mod tests {
fn test_markdown_to_html() {
let text = "# heading\n\ntest";
let html = markdown_to_html(text);
assert_eq!(
html,
"<h1>heading</h1>\n<p>test</p>\n",
);
assert_eq!(html, "<h1>heading</h1>\n<p>test</p>\n",);
}
}

View file

@ -12,10 +12,10 @@ pub enum MultibaseError {
/// Decodes multibase base58 (bitcoin) value
/// https://github.com/multiformats/multibase
pub fn decode_multibase_base58btc(value: &str)
-> Result<Vec<u8>, MultibaseError>
{
let base = value.chars().next()
pub fn decode_multibase_base58btc(value: &str) -> Result<Vec<u8>, MultibaseError> {
let base = value
.chars()
.next()
.ok_or(MultibaseError::InvalidBaseString)?;
// z == base58btc
// https://github.com/multiformats/multibase#multibase-table

View file

@ -7,10 +7,7 @@ pub fn hash_password(password: &str) -> Result<String, argon2::Error> {
argon2::hash_encoded(password.as_bytes(), &salt, &config)
}
pub fn verify_password(
password_hash: &str,
password: &str,
) -> Result<bool, argon2::Error> {
pub fn verify_password(password_hash: &str, password: &str) -> Result<bool, argon2::Error> {
argon2::verify_encoded(password_hash, password.as_bytes())
}

View file

@ -2,10 +2,7 @@ use std::net::{Ipv4Addr, Ipv6Addr};
use url::{Host, ParseError, Url};
pub fn get_hostname(url: &str) -> Result<String, ParseError> {
let hostname = match Url::parse(url)?
.host()
.ok_or(ParseError::EmptyHost)?
{
let hostname = match Url::parse(url)?.host().ok_or(ParseError::EmptyHost)? {
Host::Domain(domain) => domain.to_string(),
Host::Ipv4(addr) => addr.to_string(),
Host::Ipv6(addr) => addr.to_string(),
@ -35,10 +32,7 @@ pub fn guess_protocol(hostname: &str) -> &'static str {
}
pub fn normalize_url(url: &str) -> Result<Url, url::ParseError> {
let normalized_url = if
url.starts_with("http://") ||
url.starts_with("https://")
{
let normalized_url = if url.starts_with("http://") || url.starts_with("https://") {
url.to_string()
} else {
// Add scheme
@ -49,11 +43,7 @@ pub fn normalize_url(url: &str) -> Result<Url, url::ParseError> {
url
};
let url_scheme = guess_protocol(hostname);
format!(
"{}://{}",
url_scheme,
url,
)
format!("{}://{}", url_scheme, url,)
};
let url = Url::parse(&normalized_url)?;
url.host().ok_or(ParseError::EmptyHost)?; // validates URL
@ -82,7 +72,10 @@ mod tests {
fn test_get_hostname_tor() {
let url = "http://2gzyxa5ihm7nsggfxnu52rck2vv4rvmdlkiu3zzui5du4xyclen53wid.onion/objects/1";
let hostname = get_hostname(url).unwrap();
assert_eq!(hostname, "2gzyxa5ihm7nsggfxnu52rck2vv4rvmdlkiu3zzui5du4xyclen53wid.onion");
assert_eq!(
hostname,
"2gzyxa5ihm7nsggfxnu52rck2vv4rvmdlkiu3zzui5du4xyclen53wid.onion"
);
}
#[test]
@ -101,28 +94,19 @@ mod tests {
#[test]
fn test_guess_protocol() {
assert_eq!(
guess_protocol("example.org"),
"https",
);
assert_eq!(guess_protocol("example.org"), "https",);
assert_eq!(
guess_protocol("2gzyxa5ihm7nsggfxnu52rck2vv4rvmdlkiu3zzui5du4xyclen53wid.onion"),
"http",
);
assert_eq!(
guess_protocol("zzz.i2p"),
"http",
);
assert_eq!(guess_protocol("zzz.i2p"), "http",);
// Yggdrasil
assert_eq!(
guess_protocol("319:3cf0:dd1d:47b9:20c:29ff:fe2c:39be"),
"http",
);
// localhost
assert_eq!(
guess_protocol("127.0.0.1"),
"http",
);
assert_eq!(guess_protocol("127.0.0.1"), "http",);
}
#[test]

View file

@ -1,36 +1,20 @@
use mitra_models::profiles::types::{
ExtraField,
IdentityProof,
IdentityProofType,
PaymentLink,
PaymentOption,
ExtraField, IdentityProof, IdentityProofType, PaymentLink, PaymentOption,
};
use mitra_utils::did::Did;
use crate::activitypub::vocabulary::{
IDENTITY_PROOF,
LINK,
PROPERTY_VALUE,
};
use crate::activitypub::vocabulary::{IDENTITY_PROOF, LINK, PROPERTY_VALUE};
use crate::errors::ValidationError;
use crate::identity::{
claims::create_identity_claim,
minisign::{
parse_minisign_signature,
verify_minisign_signature,
},
};
use crate::json_signatures::proofs::{
PROOF_TYPE_ID_EIP191,
PROOF_TYPE_ID_MINISIGN,
minisign::{parse_minisign_signature, verify_minisign_signature},
};
use crate::json_signatures::proofs::{PROOF_TYPE_ID_EIP191, PROOF_TYPE_ID_MINISIGN};
use crate::web_client::urls::get_subscription_page_url;
use super::types::ActorAttachment;
pub fn attach_identity_proof(
proof: IdentityProof,
) -> ActorAttachment {
pub fn attach_identity_proof(proof: IdentityProof) -> ActorAttachment {
let proof_type_str = match proof.proof_type {
IdentityProofType::LegacyEip191IdentityProof => PROOF_TYPE_ID_EIP191,
IdentityProofType::LegacyMinisignIdentityProof => PROOF_TYPE_ID_MINISIGN,
@ -52,18 +36,24 @@ pub fn parse_identity_proof(
if attachment.object_type != IDENTITY_PROOF {
return Err(ValidationError("invalid attachment type"));
};
let proof_type_str = attachment.signature_algorithm.as_ref()
let proof_type_str = attachment
.signature_algorithm
.as_ref()
.ok_or(ValidationError("missing proof type"))?;
let proof_type = match proof_type_str.as_str() {
PROOF_TYPE_ID_EIP191 => IdentityProofType::LegacyEip191IdentityProof,
PROOF_TYPE_ID_MINISIGN => IdentityProofType::LegacyMinisignIdentityProof,
_ => return Err(ValidationError("unsupported proof type")),
};
let did = attachment.name.parse::<Did>()
let did = attachment
.name
.parse::<Did>()
.map_err(|_| ValidationError("invalid DID"))?;
let message = create_identity_claim(actor_id, &did)
.map_err(|_| ValidationError("invalid claim"))?;
let signature = attachment.signature_value.as_ref()
let message =
create_identity_claim(actor_id, &did).map_err(|_| ValidationError("invalid claim"))?;
let signature = attachment
.signature_value
.as_ref()
.ok_or(ValidationError("missing signature"))?;
match did {
Did::Key(ref did_key) => {
@ -72,15 +62,12 @@ pub fn parse_identity_proof(
};
let signature_bin = parse_minisign_signature(signature)
.map_err(|_| ValidationError("invalid signature encoding"))?;
verify_minisign_signature(
did_key,
&message,
&signature_bin,
).map_err(|_| ValidationError("invalid identity proof"))?;
},
verify_minisign_signature(did_key, &message, &signature_bin)
.map_err(|_| ValidationError("invalid identity proof"))?;
}
Did::Pkh(ref _did_pkh) => {
return Err(ValidationError("incorrect proof type"));
},
}
};
let proof = IdentityProof {
issuer: did,
@ -102,12 +89,12 @@ pub fn attach_payment_option(
let name = "EthereumSubscription".to_string();
let href = get_subscription_page_url(instance_url, username);
(name, href)
},
}
PaymentOption::MoneroSubscription(_) => {
let name = "MoneroSubscription".to_string();
let href = get_subscription_page_url(instance_url, username);
(name, href)
},
}
};
ActorAttachment {
object_type: LINK.to_string(),
@ -125,7 +112,9 @@ pub fn parse_payment_option(
if attachment.object_type != LINK {
return Err(ValidationError("invalid attachment type"));
};
let href = attachment.href.as_ref()
let href = attachment
.href
.as_ref()
.ok_or(ValidationError("href attribute is required"))?
.to_string();
let payment_option = PaymentOption::Link(PaymentLink {
@ -135,9 +124,7 @@ pub fn parse_payment_option(
Ok(payment_option)
}
pub fn attach_extra_field(
field: ExtraField,
) -> ActorAttachment {
pub fn attach_extra_field(field: ExtraField) -> ActorAttachment {
ActorAttachment {
object_type: PROPERTY_VALUE.to_string(),
name: field.name,
@ -148,13 +135,13 @@ pub fn attach_extra_field(
}
}
pub fn parse_extra_field(
attachment: &ActorAttachment,
) -> Result<ExtraField, ValidationError> {
pub fn parse_extra_field(attachment: &ActorAttachment) -> Result<ExtraField, ValidationError> {
if attachment.object_type != PROPERTY_VALUE {
return Err(ValidationError("invalid attachment type"));
};
let property_value = attachment.value.as_ref()
let property_value = attachment
.value
.as_ref()
.ok_or(ValidationError("missing property value"))?;
let field = ExtraField {
name: attachment.name.clone(),
@ -166,10 +153,8 @@ pub fn parse_extra_field(
#[cfg(test)]
mod tests {
use mitra_utils::{
caip2::ChainId,
};
use super::*;
use mitra_utils::caip2::ChainId;
const INSTANCE_URL: &str = "https://example.com";
@ -191,15 +176,9 @@ mod tests {
#[test]
fn test_payment_option() {
let username = "testuser";
let payment_option =
PaymentOption::ethereum_subscription(ChainId::ethereum_mainnet());
let subscription_page_url =
"https://example.com/@testuser/subscription";
let attachment = attach_payment_option(
INSTANCE_URL,
username,
payment_option,
);
let payment_option = PaymentOption::ethereum_subscription(ChainId::ethereum_mainnet());
let subscription_page_url = "https://example.com/@testuser/subscription";
let attachment = attach_payment_option(INSTANCE_URL, username, payment_option);
assert_eq!(attachment.object_type, LINK);
assert_eq!(attachment.name, "EthereumSubscription");
assert_eq!(attachment.href.as_deref().unwrap(), subscription_page_url);

View file

@ -6,12 +6,7 @@ use mitra_config::Instance;
use mitra_models::{
database::DatabaseClient,
profiles::queries::{create_profile, update_profile},
profiles::types::{
DbActorProfile,
ProfileImage,
ProfileCreateData,
ProfileUpdateData,
},
profiles::types::{DbActorProfile, ProfileCreateData, ProfileImage, ProfileUpdateData},
};
use crate::activitypub::{
@ -44,19 +39,17 @@ async fn fetch_actor_images(
icon.media_type.as_deref(),
ACTOR_IMAGE_MAX_SIZE,
media_dir,
).await {
)
.await
{
Ok((file_name, file_size, maybe_media_type)) => {
let image = ProfileImage::new(
file_name,
file_size,
maybe_media_type,
);
let image = ProfileImage::new(file_name, file_size, maybe_media_type);
Some(image)
},
}
Err(error) => {
log::warn!("failed to fetch avatar ({})", error);
default_avatar
},
}
}
} else {
None
@ -68,19 +61,17 @@ async fn fetch_actor_images(
image.media_type.as_deref(),
ACTOR_IMAGE_MAX_SIZE,
media_dir,
).await {
)
.await
{
Ok((file_name, file_size, maybe_media_type)) => {
let image = ProfileImage::new(
file_name,
file_size,
maybe_media_type,
);
let image = ProfileImage::new(file_name, file_size, maybe_media_type);
Some(image)
},
}
Err(error) => {
log::warn!("failed to fetch banner ({})", error);
default_banner
},
}
}
} else {
None
@ -90,24 +81,24 @@ async fn fetch_actor_images(
fn parse_aliases(actor: &Actor) -> Vec<String> {
// Aliases reported by server (not signed)
actor.also_known_as.as_ref()
.and_then(|value| {
match parse_array(value) {
Ok(array) => {
let mut aliases = vec![];
for actor_id in array {
if validate_object_id(&actor_id).is_err() {
log::warn!("invalid alias: {}", actor_id);
continue;
};
aliases.push(actor_id);
actor
.also_known_as
.as_ref()
.and_then(|value| match parse_array(value) {
Ok(array) => {
let mut aliases = vec![];
for actor_id in array {
if validate_object_id(&actor_id).is_err() {
log::warn!("invalid alias: {}", actor_id);
continue;
};
Some(aliases)
},
Err(_) => {
log::warn!("invalid alias list: {}", value);
None
},
aliases.push(actor_id);
}
Some(aliases)
}
Err(_) => {
log::warn!("invalid alias list: {}", value);
None
}
})
.unwrap_or_default()
@ -127,23 +118,18 @@ async fn parse_tags(
log::warn!("too many emojis");
continue;
};
match handle_emoji(
db_client,
instance,
storage,
tag_value,
).await? {
match handle_emoji(db_client, instance, storage, tag_value).await? {
Some(emoji) => {
if !emojis.contains(&emoji.id) {
emojis.push(emoji.id);
};
},
}
None => continue,
};
} else {
log::warn!("skipping actor tag of type {}", tag_type);
};
};
}
Ok(emojis)
}
@ -157,22 +143,11 @@ pub async fn create_remote_profile(
if actor_address.hostname == instance.hostname() {
return Err(HandlerError::LocalObject);
};
let (maybe_avatar, maybe_banner) = fetch_actor_images(
instance,
&actor,
&storage.media_dir,
None,
None,
).await;
let (identity_proofs, payment_options, extra_fields) =
actor.parse_attachments();
let (maybe_avatar, maybe_banner) =
fetch_actor_images(instance, &actor, &storage.media_dir, None, None).await;
let (identity_proofs, payment_options, extra_fields) = actor.parse_attachments();
let aliases = parse_aliases(&actor);
let emojis = parse_tags(
db_client,
instance,
storage,
&actor,
).await?;
let emojis = parse_tags(db_client, instance, storage, &actor).await?;
let mut profile_data = ProfileCreateData {
username: actor.preferred_username.clone(),
hostname: Some(actor_address.hostname),
@ -203,11 +178,7 @@ pub async fn update_remote_profile(
) -> Result<DbActorProfile, HandlerError> {
let actor_old = profile.actor_json.ok_or(HandlerError::LocalObject)?;
if actor_old.id != actor.id {
log::warn!(
"actor ID changed from {} to {}",
actor_old.id,
actor.id,
);
log::warn!("actor ID changed from {} to {}", actor_old.id, actor.id,);
};
if actor_old.public_key.public_key_pem != actor.public_key.public_key_pem {
log::warn!(
@ -222,16 +193,11 @@ pub async fn update_remote_profile(
&storage.media_dir,
profile.avatar,
profile.banner,
).await;
let (identity_proofs, payment_options, extra_fields) =
actor.parse_attachments();
)
.await;
let (identity_proofs, payment_options, extra_fields) = actor.parse_attachments();
let aliases = parse_aliases(&actor);
let emojis = parse_tags(
db_client,
instance,
storage,
&actor,
).await?;
let emojis = parse_tags(db_client, instance, storage, &actor).await?;
let mut profile_data = ProfileUpdateData {
display_name: actor.name.clone(),
bio: actor.summary.clone(),

View file

@ -1,22 +1,11 @@
use std::collections::HashMap;
use serde::{
Deserialize,
Deserializer,
Serialize,
de::{Error as DeserializerError},
};
use serde::{de::Error as DeserializerError, Deserialize, Deserializer, Serialize};
use serde_json::{json, Value};
use mitra_config::Instance;
use mitra_models::{
profiles::types::{
DbActor,
DbActorPublicKey,
ExtraField,
IdentityProof,
PaymentOption,
},
profiles::types::{DbActor, DbActorPublicKey, ExtraField, IdentityProof, PaymentOption},
users::types::User,
};
use mitra_utils::{
@ -26,17 +15,10 @@ use mitra_utils::{
use crate::activitypub::{
constants::{
AP_CONTEXT,
MASTODON_CONTEXT,
MITRA_CONTEXT,
SCHEMA_ORG_CONTEXT,
W3ID_SECURITY_CONTEXT,
AP_CONTEXT, MASTODON_CONTEXT, MITRA_CONTEXT, SCHEMA_ORG_CONTEXT, W3ID_SECURITY_CONTEXT,
},
identifiers::{
local_actor_id,
local_actor_key_id,
local_instance_actor_id,
LocalActorCollection,
local_actor_id, local_actor_key_id, local_instance_actor_id, LocalActorCollection,
},
types::deserialize_value_array,
vocabulary::{IDENTITY_PROOF, IMAGE, LINK, PERSON, PROPERTY_VALUE, SERVICE},
@ -46,12 +28,8 @@ use crate::media::get_file_url;
use crate::webfinger::types::ActorAddress;
use super::attachments::{
attach_extra_field,
attach_identity_proof,
attach_payment_option,
parse_extra_field,
parse_identity_proof,
parse_payment_option,
attach_extra_field, attach_identity_proof, attach_payment_option, parse_extra_field,
parse_identity_proof, parse_payment_option,
};
#[derive(Deserialize, Serialize)]
@ -94,21 +72,17 @@ pub struct ActorAttachment {
}
// Some implementations use empty object instead of null
fn deserialize_image_opt<'de, D>(
deserializer: D,
) -> Result<Option<ActorImage>, D::Error>
where D: Deserializer<'de>
fn deserialize_image_opt<'de, D>(deserializer: D) -> Result<Option<ActorImage>, D::Error>
where
D: Deserializer<'de>,
{
let maybe_value: Option<Value> = Option::deserialize(deserializer)?;
let maybe_image = if let Some(value) = maybe_value {
let is_empty_object = value.as_object()
.map(|map| map.is_empty())
.unwrap_or(false);
let is_empty_object = value.as_object().map(|map| map.is_empty()).unwrap_or(false);
if is_empty_object {
None
} else {
let image = ActorImage::deserialize(value)
.map_err(DeserializerError::custom)?;
let image = ActorImage::deserialize(value).map_err(DeserializerError::custom)?;
Some(image)
}
} else {
@ -149,7 +123,7 @@ pub struct Actor {
#[serde(
default,
deserialize_with = "deserialize_image_opt",
skip_serializing_if = "Option::is_none",
skip_serializing_if = "Option::is_none"
)]
pub icon: Option<ActorImage>,
@ -165,7 +139,7 @@ pub struct Actor {
#[serde(
default,
deserialize_with = "deserialize_value_array",
skip_serializing_if = "Vec::is_empty",
skip_serializing_if = "Vec::is_empty"
)]
pub attachment: Vec<Value>,
@ -175,7 +149,7 @@ pub struct Actor {
#[serde(
default,
deserialize_with = "deserialize_value_array",
skip_serializing_if = "Vec::is_empty",
skip_serializing_if = "Vec::is_empty"
)]
pub tag: Vec<Value>,
@ -184,11 +158,8 @@ pub struct Actor {
}
impl Actor {
pub fn address(
&self,
) -> Result<ActorAddress, ValidationError> {
let hostname = get_hostname(&self.id)
.map_err(|_| ValidationError("invalid actor ID"))?;
pub fn address(&self) -> Result<ActorAddress, ValidationError> {
let hostname = get_hostname(&self.id).map_err(|_| ValidationError("invalid actor ID"))?;
let actor_address = ActorAddress {
username: self.preferred_username.clone(),
hostname: hostname,
@ -213,11 +184,7 @@ impl Actor {
}
}
pub fn parse_attachments(&self) -> (
Vec<IdentityProof>,
Vec<PaymentOption>,
Vec<ExtraField>,
) {
pub fn parse_attachments(&self) -> (Vec<IdentityProof>, Vec<PaymentOption>, Vec<ExtraField>) {
let mut identity_proofs = vec![];
let mut payment_options = vec![];
let mut extra_fields = vec![];
@ -229,17 +196,13 @@ impl Actor {
);
};
for attachment_value in self.attachment.iter() {
let attachment_type =
attachment_value["type"].as_str().unwrap_or("Unknown");
let attachment_type = attachment_value["type"].as_str().unwrap_or("Unknown");
let attachment = match serde_json::from_value(attachment_value.clone()) {
Ok(attachment) => attachment,
Err(_) => {
log_error(
attachment_type,
ValidationError("invalid attachment"),
);
log_error(attachment_type, ValidationError("invalid attachment"));
continue;
},
}
};
match attachment_type {
IDENTITY_PROOF => {
@ -247,27 +210,27 @@ impl Actor {
Ok(proof) => identity_proofs.push(proof),
Err(error) => log_error(attachment_type, error),
};
},
}
LINK => {
match parse_payment_option(&attachment) {
Ok(option) => payment_options.push(option),
Err(error) => log_error(attachment_type, error),
};
},
}
PROPERTY_VALUE => {
match parse_extra_field(&attachment) {
Ok(field) => extra_fields.push(field),
Err(error) => log_error(attachment_type, error),
};
},
}
_ => {
log_error(
attachment_type,
ValidationError("unsupported attachment type"),
);
},
}
};
};
}
(identity_proofs, payment_options, extra_fields)
}
}
@ -295,10 +258,7 @@ fn build_actor_context() -> (
)
}
pub fn get_local_actor(
user: &User,
instance_url: &str,
) -> Result<Actor, ActorKeyError> {
pub fn get_local_actor(user: &User, instance_url: &str) -> Result<Actor, ActorKeyError> {
let username = &user.profile.username;
let actor_id = local_actor_id(instance_url, username);
let inbox = LocalActorCollection::Inbox.of(&actor_id);
@ -322,7 +282,7 @@ pub fn get_local_actor(
media_type: image.media_type.clone(),
};
Some(actor_image)
},
}
None => None,
};
let banner = match &user.profile.banner {
@ -333,32 +293,29 @@ pub fn get_local_actor(
media_type: image.media_type.clone(),
};
Some(actor_image)
},
}
None => None,
};
let mut attachments = vec![];
for proof in user.profile.identity_proofs.clone().into_inner() {
let attachment = attach_identity_proof(proof);
let attachment_value = serde_json::to_value(attachment)
.expect("attachment should be serializable");
let attachment_value =
serde_json::to_value(attachment).expect("attachment should be serializable");
attachments.push(attachment_value);
};
}
for payment_option in user.profile.payment_options.clone().into_inner() {
let attachment = attach_payment_option(
instance_url,
&user.profile.username,
payment_option,
);
let attachment_value = serde_json::to_value(attachment)
.expect("attachment should be serializable");
let attachment =
attach_payment_option(instance_url, &user.profile.username, payment_option);
let attachment_value =
serde_json::to_value(attachment).expect("attachment should be serializable");
attachments.push(attachment_value);
};
}
for field in user.profile.extra_fields.clone().into_inner() {
let attachment = attach_extra_field(field);
let attachment_value = serde_json::to_value(attachment)
.expect("attachment should be serializable");
let attachment_value =
serde_json::to_value(attachment).expect("attachment should be serializable");
attachments.push(attachment_value);
};
}
let aliases = user.profile.aliases.clone().into_actor_ids();
let actor = Actor {
context: Some(json!(build_actor_context())),
@ -384,9 +341,7 @@ pub fn get_local_actor(
Ok(actor)
}
pub fn get_instance_actor(
instance: &Instance,
) -> Result<Actor, ActorKeyError> {
pub fn get_instance_actor(instance: &Instance) -> Result<Actor, ActorKeyError> {
let actor_id = local_instance_actor_id(&instance.url());
let actor_inbox = LocalActorCollection::Inbox.of(&actor_id);
let actor_outbox = LocalActorCollection::Outbox.of(&actor_id);
@ -422,12 +377,9 @@ pub fn get_instance_actor(
#[cfg(test)]
mod tests {
use mitra_models::profiles::types::DbActorProfile;
use mitra_utils::crypto_rsa::{
generate_weak_rsa_key,
serialize_private_key,
};
use super::*;
use mitra_models::profiles::types::DbActorProfile;
use mitra_utils::crypto_rsa::{generate_weak_rsa_key, serialize_private_key};
const INSTANCE_HOSTNAME: &str = "example.com";
const INSTANCE_URL: &str = "https://example.com";

View file

@ -7,24 +7,17 @@ use mitra_models::{
profiles::queries::get_profile_by_remote_actor_id,
profiles::types::DbActorProfile,
};
use mitra_utils::{
crypto_rsa::deserialize_public_key,
did::Did,
};
use mitra_utils::{crypto_rsa::deserialize_public_key, did::Did};
use crate::http_signatures::verify::{
parse_http_signature,
verify_http_signature,
parse_http_signature, verify_http_signature,
HttpSignatureVerificationError as HttpSignatureError,
};
use crate::json_signatures::{
proofs::ProofType,
verify::{
get_json_signature,
verify_ed25519_json_signature,
verify_eip191_json_signature,
verify_rsa_json_signature,
JsonSignatureVerificationError as JsonSignatureError,
get_json_signature, verify_ed25519_json_signature, verify_eip191_json_signature,
verify_rsa_json_signature, JsonSignatureVerificationError as JsonSignatureError,
JsonSigner,
},
};
@ -93,12 +86,14 @@ async fn get_signer(
&config.instance(),
&MediaStorage::from(config),
signer_id,
).await {
)
.await
{
Ok(profile) => profile,
Err(HandlerError::DatabaseError(error)) => return Err(error.into()),
Err(other_error) => {
return Err(AuthenticationError::ImportError(other_error.to_string()));
},
}
}
};
Ok(signer)
@ -111,24 +106,22 @@ pub async fn verify_signed_request(
request: &HttpRequest,
no_fetch: bool,
) -> Result<DbActorProfile, AuthenticationError> {
let signature_data = match parse_http_signature(
request.method(),
request.uri(),
request.headers(),
) {
Ok(signature_data) => signature_data,
Err(HttpSignatureError::NoSignature) => {
return Err(AuthenticationError::NoHttpSignature);
},
Err(other_error) => return Err(other_error.into()),
};
let signature_data =
match parse_http_signature(request.method(), request.uri(), request.headers()) {
Ok(signature_data) => signature_data,
Err(HttpSignatureError::NoSignature) => {
return Err(AuthenticationError::NoHttpSignature);
}
Err(other_error) => return Err(other_error.into()),
};
let signer_id = key_id_to_actor_id(&signature_data.key_id)?;
let signer = get_signer(config, db_client, &signer_id, no_fetch).await?;
let signer_actor = signer.actor_json.as_ref()
let signer_actor = signer
.actor_json
.as_ref()
.expect("request should be signed by remote actor");
let signer_key =
deserialize_public_key(&signer_actor.public_key.public_key_pem)?;
let signer_key = deserialize_public_key(&signer_actor.public_key.public_key_pem)?;
verify_http_signature(&signature_data, &signer_key)?;
@ -146,13 +139,14 @@ pub async fn verify_signed_activity(
Ok(signature_data) => signature_data,
Err(JsonSignatureError::NoProof) => {
return Err(AuthenticationError::NoJsonSignature);
},
}
Err(other_error) => return Err(other_error.into()),
};
// Signed activities must have `actor` property, to avoid situations
// where signer is identified by DID but there is no matching
// identity proof in the local database.
let actor_id = activity["actor"].as_str()
let actor_id = activity["actor"]
.as_str()
.ok_or(AuthenticationError::ActorError("unknown actor"))?;
let actor_profile = get_signer(config, db_client, actor_id, no_fetch).await?;
@ -165,12 +159,13 @@ pub async fn verify_signed_activity(
if signer_id != actor_id {
return Err(AuthenticationError::UnexpectedSigner);
};
let signer_actor = actor_profile.actor_json.as_ref()
let signer_actor = actor_profile
.actor_json
.as_ref()
.expect("activity should be signed by remote actor");
let signer_key =
deserialize_public_key(&signer_actor.public_key.public_key_pem)?;
let signer_key = deserialize_public_key(&signer_actor.public_key.public_key_pem)?;
verify_rsa_json_signature(&signature_data, &signer_key)?;
},
}
JsonSigner::Did(did) => {
if !actor_profile.identity_proofs.any(&did) {
return Err(AuthenticationError::UnexpectedSigner);
@ -186,7 +181,7 @@ pub async fn verify_signed_activity(
&signature_data.message,
&signature_data.signature,
)?;
},
}
ProofType::JcsEip191Signature => {
let did_pkh = match did {
Did::Pkh(did_pkh) => did_pkh,
@ -197,10 +192,10 @@ pub async fn verify_signed_activity(
&signature_data.message,
&signature_data.signature,
)?;
},
}
_ => return Err(AuthenticationError::InvalidJsonSignatureType),
};
},
}
};
// Signer is actor
Ok(actor_profile)

View file

@ -61,12 +61,7 @@ pub fn prepare_accept_follow(
follow_activity_id,
);
let recipients = vec![source_actor.clone()];
OutgoingActivity::new(
instance,
sender,
activity,
recipients,
)
OutgoingActivity::new(instance, sender, activity, recipients)
}
#[cfg(test)]
@ -83,12 +78,7 @@ mod tests {
};
let follow_activity_id = "https://test.remote/objects/999";
let follower_id = "https://test.remote/users/123";
let activity = build_accept_follow(
INSTANCE_URL,
&target,
follower_id,
follow_activity_id,
);
let activity = build_accept_follow(INSTANCE_URL, &target, follower_id, follow_activity_id);
assert_eq!(activity.id.starts_with(INSTANCE_URL), true);
assert_eq!(activity.activity_type, "Accept");

View file

@ -1,10 +1,7 @@
use serde::Serialize;
use mitra_config::Instance;
use mitra_models::{
profiles::types::DbActor,
users::types::User,
};
use mitra_models::{profiles::types::DbActor, users::types::User};
use mitra_utils::id::generate_ulid;
use crate::activitypub::{
@ -67,12 +64,7 @@ pub fn prepare_update_collection(
remove,
);
let recipients = vec![person.clone()];
OutgoingActivity::new(
instance,
sender,
activity,
recipients,
)
OutgoingActivity::new(instance, sender, activity, recipients)
}
pub fn prepare_add_person(
@ -95,13 +87,8 @@ mod tests {
let sender_username = "local";
let person_id = "https://test.remote/actor/test";
let collection = LocalActorCollection::Subscribers;
let activity = build_update_collection(
INSTANCE_URL,
sender_username,
person_id,
collection,
false,
);
let activity =
build_update_collection(INSTANCE_URL, sender_username, person_id, collection, false);
assert_eq!(activity.activity_type, "Add");
assert_eq!(

View file

@ -14,11 +14,7 @@ use crate::activitypub::{
constants::AP_PUBLIC,
deliverer::OutgoingActivity,
identifiers::{
local_actor_followers,
local_actor_id,
local_object_id,
post_object_id,
profile_actor_id,
local_actor_followers, local_actor_id, local_object_id, post_object_id, profile_actor_id,
},
types::{build_default_context, Context},
vocabulary::ANNOUNCE,
@ -41,12 +37,11 @@ pub struct Announce {
cc: Vec<String>,
}
pub fn build_announce(
instance_url: &str,
repost: &Post,
) -> Announce {
pub fn build_announce(instance_url: &str, repost: &Post) -> Announce {
let actor_id = local_actor_id(instance_url, &repost.author.username);
let post = repost.repost_of.as_ref()
let post = repost
.repost_of
.as_ref()
.expect("repost_of field should be populated");
let object_id = post_object_id(instance_url, post);
let activity_id = local_object_id(instance_url, &repost.id);
@ -76,7 +71,7 @@ pub async fn get_announce_recipients(
if let Some(remote_actor) = profile.actor_json {
recipients.push(remote_actor);
};
};
}
let primary_recipient = profile_actor_id(instance_url, &post.author);
if let Some(remote_actor) = post.author.actor_json.as_ref() {
recipients.push(remote_actor.clone());
@ -91,30 +86,21 @@ pub async fn prepare_announce(
repost: &Post,
) -> Result<OutgoingActivity, DatabaseError> {
assert_eq!(sender.id, repost.author.id);
let post = repost.repost_of.as_ref()
let post = repost
.repost_of
.as_ref()
.expect("repost_of field should be populated");
let (recipients, _) = get_announce_recipients(
db_client,
&instance.url(),
sender,
post,
).await?;
let activity = build_announce(
&instance.url(),
repost,
);
let (recipients, _) = get_announce_recipients(db_client, &instance.url(), sender, post).await?;
let activity = build_announce(&instance.url(), repost);
Ok(OutgoingActivity::new(
instance,
sender,
activity,
recipients,
instance, sender, activity, recipients,
))
}
#[cfg(test)]
mod tests {
use mitra_models::profiles::types::DbActorProfile;
use super::*;
use mitra_models::profiles::types::DbActorProfile;
const INSTANCE_URL: &str = "https://example.com";
@ -145,18 +131,12 @@ mod tests {
repost_of: Some(Box::new(post)),
..Default::default()
};
let activity = build_announce(
INSTANCE_URL,
&repost,
);
let activity = build_announce(INSTANCE_URL, &repost);
assert_eq!(
activity.id,
format!("{}/objects/{}", INSTANCE_URL, repost.id),
);
assert_eq!(
activity.actor,
format!("{}/users/announcer", INSTANCE_URL),
);
assert_eq!(activity.actor, format!("{}/users/announcer", INSTANCE_URL),);
assert_eq!(activity.object, post_id);
assert_eq!(activity.to, vec![AP_PUBLIC, post_author_id]);
}

View file

@ -16,23 +16,11 @@ use crate::activitypub::{
constants::{AP_MEDIA_TYPE, AP_PUBLIC},
deliverer::OutgoingActivity,
identifiers::{
local_actor_id,
local_actor_followers,
local_actor_subscribers,
local_emoji_id,
local_object_id,
local_tag_collection,
post_object_id,
profile_actor_id,
local_actor_followers, local_actor_id, local_actor_subscribers, local_emoji_id,
local_object_id, local_tag_collection, post_object_id, profile_actor_id,
},
types::{
build_default_context,
Attachment,
Context,
EmojiTag,
EmojiTagImage,
LinkTag,
SimpleTag,
build_default_context, Attachment, Context, EmojiTag, EmojiTagImage, LinkTag, SimpleTag,
},
vocabulary::*,
};
@ -96,50 +84,45 @@ pub fn build_emoji_tag(instance_url: &str, emoji: &DbEmoji) -> EmojiTag {
}
}
pub fn build_note(
instance_hostname: &str,
instance_url: &str,
post: &Post,
) -> Note {
pub fn build_note(instance_hostname: &str, instance_url: &str, post: &Post) -> Note {
let object_id = local_object_id(instance_url, &post.id);
let actor_id = local_actor_id(instance_url, &post.author.username);
let attachments: Vec<Attachment> = post.attachments.iter().map(|db_item| {
let url = get_file_url(instance_url, &db_item.file_name);
let media_type = db_item.media_type.clone();
Attachment {
name: None,
attachment_type: DOCUMENT.to_string(),
media_type,
url: Some(url),
}
}).collect();
let attachments: Vec<Attachment> = post
.attachments
.iter()
.map(|db_item| {
let url = get_file_url(instance_url, &db_item.file_name);
let media_type = db_item.media_type.clone();
Attachment {
name: None,
attachment_type: DOCUMENT.to_string(),
media_type,
url: Some(url),
}
})
.collect();
let mut primary_audience = vec![];
let mut secondary_audience = vec![];
let followers_collection_id =
local_actor_followers(instance_url, &post.author.username);
let subscribers_collection_id =
local_actor_subscribers(instance_url, &post.author.username);
let followers_collection_id = local_actor_followers(instance_url, &post.author.username);
let subscribers_collection_id = local_actor_subscribers(instance_url, &post.author.username);
match post.visibility {
Visibility::Public => {
primary_audience.push(AP_PUBLIC.to_string());
secondary_audience.push(followers_collection_id);
},
}
Visibility::Followers => {
primary_audience.push(followers_collection_id);
},
}
Visibility::Subscribers => {
primary_audience.push(subscribers_collection_id);
},
}
Visibility::Direct => (),
};
let mut tags = vec![];
for profile in &post.mentions {
let actor_address = ActorAddress::from_profile(
instance_hostname,
profile,
);
let actor_address = ActorAddress::from_profile(instance_hostname, profile);
let tag_name = format!("@{}", actor_address);
let actor_id = profile_actor_id(instance_url, profile);
if !primary_audience.contains(&actor_id) {
@ -151,7 +134,7 @@ pub fn build_note(
href: actor_id,
};
tags.push(Tag::SimpleTag(tag));
};
}
for tag_name in &post.tags {
let tag_href = local_tag_collection(instance_url, tag_name);
let tag = SimpleTag {
@ -160,7 +143,7 @@ pub fn build_note(
href: tag_href,
};
tags.push(Tag::SimpleTag(tag));
};
}
assert_eq!(post.links.len(), post.linked.len());
for linked in &post.linked {
@ -168,7 +151,7 @@ pub fn build_note(
// https://codeberg.org/fediverse/fep/src/branch/main/feps/fep-e232.md
let link_href = post_object_id(instance_url, linked);
let tag = LinkTag {
name: None, // no microsyntax
name: None, // no microsyntax
tag_type: LINK.to_string(),
href: link_href,
media_type: AP_MEDIA_TYPE.to_string(),
@ -176,29 +159,30 @@ pub fn build_note(
if cfg!(feature = "fep-e232") {
tags.push(Tag::LinkTag(tag));
};
};
let maybe_quote_url = post.linked.get(0)
}
let maybe_quote_url = post
.linked
.get(0)
.map(|linked| post_object_id(instance_url, linked));
for emoji in &post.emojis {
let tag = build_emoji_tag(instance_url, emoji);
tags.push(Tag::EmojiTag(tag));
};
}
let in_reply_to_object_id = match post.in_reply_to_id {
Some(in_reply_to_id) => {
let in_reply_to = post.in_reply_to.as_ref()
let in_reply_to = post
.in_reply_to
.as_ref()
.expect("in_reply_to should be populated");
assert_eq!(in_reply_to.id, in_reply_to_id);
let in_reply_to_actor_id = profile_actor_id(
instance_url,
&in_reply_to.author,
);
let in_reply_to_actor_id = profile_actor_id(instance_url, &in_reply_to.author);
if !primary_audience.contains(&in_reply_to_actor_id) {
primary_audience.push(in_reply_to_actor_id);
};
Some(post_object_id(instance_url, in_reply_to))
},
}
None => None,
};
Note {
@ -234,11 +218,7 @@ pub struct CreateNote {
cc: Vec<String>,
}
pub fn build_create_note(
instance_hostname: &str,
instance_url: &str,
post: &Post,
) -> CreateNote {
pub fn build_create_note(instance_hostname: &str, instance_url: &str, post: &Post) -> CreateNote {
let object = build_note(instance_hostname, instance_url, post);
let primary_audience = object.to.clone();
let secondary_audience = object.cc.clone();
@ -264,11 +244,11 @@ pub async fn get_note_recipients(
Visibility::Public | Visibility::Followers => {
let followers = get_followers(db_client, &current_user.id).await?;
audience.extend(followers);
},
}
Visibility::Subscribers => {
let subscribers = get_subscribers(db_client, &current_user.id).await?;
audience.extend(subscribers);
},
}
Visibility::Direct => (),
};
if let Some(in_reply_to_id) = post.in_reply_to_id {
@ -283,7 +263,7 @@ pub async fn get_note_recipients(
if let Some(remote_actor) = profile.actor_json {
recipients.push(remote_actor);
};
};
}
Ok(recipients)
}
@ -294,25 +274,18 @@ pub async fn prepare_create_note(
post: &Post,
) -> Result<OutgoingActivity, DatabaseError> {
assert_eq!(author.id, post.author.id);
let activity = build_create_note(
&instance.hostname(),
&instance.url(),
post,
);
let activity = build_create_note(&instance.hostname(), &instance.url(), post);
let recipients = get_note_recipients(db_client, author, post).await?;
Ok(OutgoingActivity::new(
instance,
author,
activity,
recipients,
instance, author, activity, recipients,
))
}
#[cfg(test)]
mod tests {
use serde_json::json;
use mitra_models::profiles::types::DbActorProfile;
use super::*;
use mitra_models::profiles::types::DbActorProfile;
use serde_json::json;
const INSTANCE_HOSTNAME: &str = "example.com";
const INSTANCE_URL: &str = "https://example.com";
@ -326,11 +299,14 @@ mod tests {
};
let tag = Tag::SimpleTag(simple_tag);
let value = serde_json::to_value(tag).unwrap();
assert_eq!(value, json!({
"type": "Hashtag",
"href": "https://example.org/tags/test",
"name": "#test",
}));
assert_eq!(
value,
json!({
"type": "Hashtag",
"href": "https://example.org/tags/test",
"name": "#test",
})
);
}
#[test]
@ -357,10 +333,7 @@ mod tests {
};
let note = build_note(INSTANCE_HOSTNAME, INSTANCE_URL, &post);
assert_eq!(
note.id,
format!("{}/objects/{}", INSTANCE_URL, post.id),
);
assert_eq!(note.id, format!("{}/objects/{}", INSTANCE_URL, post.id),);
assert_eq!(note.attachment.len(), 0);
assert_eq!(
note.attributed_to,
@ -369,9 +342,10 @@ mod tests {
assert_eq!(note.in_reply_to.is_none(), true);
assert_eq!(note.content, post.content);
assert_eq!(note.to, vec![AP_PUBLIC]);
assert_eq!(note.cc, vec![
local_actor_followers(INSTANCE_URL, "author"),
]);
assert_eq!(
note.cc,
vec![local_actor_followers(INSTANCE_URL, "author"),]
);
assert_eq!(note.tag.len(), 1);
let tag = match note.tag[0] {
Tag::SimpleTag(ref tag) => tag,
@ -389,9 +363,10 @@ mod tests {
};
let note = build_note(INSTANCE_HOSTNAME, INSTANCE_URL, &post);
assert_eq!(note.to, vec![
local_actor_followers(INSTANCE_URL, &post.author.username),
]);
assert_eq!(
note.to,
vec![local_actor_followers(INSTANCE_URL, &post.author.username),]
);
assert_eq!(note.cc.is_empty(), true);
}
@ -415,10 +390,13 @@ mod tests {
};
let note = build_note(INSTANCE_HOSTNAME, INSTANCE_URL, &post);
assert_eq!(note.to, vec![
local_actor_subscribers(INSTANCE_URL, &post.author.username),
subscriber_id.to_string(),
]);
assert_eq!(
note.to,
vec![
local_actor_subscribers(INSTANCE_URL, &post.author.username),
subscriber_id.to_string(),
]
);
assert_eq!(note.cc.is_empty(), true);
}
@ -460,10 +438,13 @@ mod tests {
note.in_reply_to.unwrap(),
format!("{}/objects/{}", INSTANCE_URL, parent.id),
);
assert_eq!(note.to, vec![
AP_PUBLIC.to_string(),
local_actor_id(INSTANCE_URL, &parent.author.username),
]);
assert_eq!(
note.to,
vec![
AP_PUBLIC.to_string(),
local_actor_id(INSTANCE_URL, &parent.author.username),
]
);
}
#[test]
@ -496,10 +477,7 @@ mod tests {
};
let note = build_note(INSTANCE_HOSTNAME, INSTANCE_URL, &post);
assert_eq!(
note.in_reply_to.unwrap(),
parent.object_id.unwrap(),
);
assert_eq!(note.in_reply_to.unwrap(), parent.object_id.unwrap(),);
let tags = note.tag;
assert_eq!(tags.len(), 1);
let tag = match tags[0] {
@ -518,12 +496,11 @@ mod tests {
username: author_username.to_string(),
..Default::default()
};
let post = Post { author, ..Default::default() };
let activity = build_create_note(
INSTANCE_HOSTNAME,
INSTANCE_URL,
&post,
);
let post = Post {
author,
..Default::default()
};
let activity = build_create_note(INSTANCE_HOSTNAME, INSTANCE_URL, &post);
assert_eq!(
activity.id,

View file

@ -15,11 +15,7 @@ use crate::activitypub::{
vocabulary::{DELETE, NOTE, TOMBSTONE},
};
use super::create_note::{
build_note,
get_note_recipients,
Note,
};
use super::create_note::{build_note, get_note_recipients, Note};
#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
@ -48,20 +44,12 @@ struct DeleteNote {
cc: Vec<String>,
}
fn build_delete_note(
instance_hostname: &str,
instance_url: &str,
post: &Post,
) -> DeleteNote {
fn build_delete_note(instance_hostname: &str, instance_url: &str, post: &Post) -> DeleteNote {
assert!(post.is_local());
let object_id = local_object_id(instance_url, &post.id);
let activity_id = format!("{}/delete", object_id);
let actor_id = local_actor_id(instance_url, &post.author.username);
let Note { to, cc, .. } = build_note(
instance_hostname,
instance_url,
post,
);
let Note { to, cc, .. } = build_note(instance_hostname, instance_url, post);
DeleteNote {
context: build_default_context(),
activity_type: DELETE.to_string(),
@ -86,28 +74,18 @@ pub async fn prepare_delete_note(
assert_eq!(author.id, post.author.id);
let mut post = post.clone();
add_related_posts(db_client, vec![&mut post]).await?;
let activity = build_delete_note(
&instance.hostname(),
&instance.url(),
&post,
);
let activity = build_delete_note(&instance.hostname(), &instance.url(), &post);
let recipients = get_note_recipients(db_client, author, &post).await?;
Ok(OutgoingActivity::new(
instance,
author,
activity,
recipients,
instance, author, activity, recipients,
))
}
#[cfg(test)]
mod tests {
use mitra_models::profiles::types::DbActorProfile;
use crate::activitypub::{
constants::AP_PUBLIC,
identifiers::local_actor_followers,
};
use super::*;
use crate::activitypub::{constants::AP_PUBLIC, identifiers::local_actor_followers};
use mitra_models::profiles::types::DbActorProfile;
const INSTANCE_HOSTNAME: &str = "example.com";
const INSTANCE_URL: &str = "https://example.com";
@ -118,12 +96,11 @@ mod tests {
username: "author".to_string(),
..Default::default()
};
let post = Post { author, ..Default::default() };
let activity = build_delete_note(
INSTANCE_HOSTNAME,
INSTANCE_URL,
&post,
);
let post = Post {
author,
..Default::default()
};
let activity = build_delete_note(INSTANCE_HOSTNAME, INSTANCE_URL, &post);
assert_eq!(
activity.id,

View file

@ -32,10 +32,7 @@ struct DeletePerson {
to: Vec<String>,
}
fn build_delete_person(
instance_url: &str,
user: &User,
) -> DeletePerson {
fn build_delete_person(instance_url: &str, user: &User) -> DeletePerson {
let actor_id = local_actor_id(instance_url, &user.profile.username);
let activity_id = format!("{}/delete", actor_id);
DeletePerson {
@ -59,7 +56,7 @@ async fn get_delete_person_recipients(
if let Some(remote_actor) = profile.actor_json {
recipients.push(remote_actor);
};
};
}
Ok(recipients)
}
@ -70,18 +67,13 @@ pub async fn prepare_delete_person(
) -> Result<OutgoingActivity, DatabaseError> {
let activity = build_delete_person(&instance.url(), user);
let recipients = get_delete_person_recipients(db_client, &user.id).await?;
Ok(OutgoingActivity::new(
instance,
user,
activity,
recipients,
))
Ok(OutgoingActivity::new(instance, user, activity, recipients))
}
#[cfg(test)]
mod tests {
use mitra_models::profiles::types::DbActorProfile;
use super::*;
use mitra_models::profiles::types::DbActorProfile;
const INSTANCE_URL: &str = "https://example.com";
@ -100,10 +92,7 @@ mod tests {
format!("{}/users/testuser/delete", INSTANCE_URL),
);
assert_eq!(activity.actor, activity.object);
assert_eq!(
activity.object,
format!("{}/users/testuser", INSTANCE_URL),
);
assert_eq!(activity.object, format!("{}/users/testuser", INSTANCE_URL),);
assert_eq!(activity.to, vec![AP_PUBLIC]);
}
}

View file

@ -62,12 +62,7 @@ pub fn prepare_follow(
follow_request_id,
);
let recipients = vec![target_actor.clone()];
OutgoingActivity::new(
instance,
sender,
activity,
recipients,
)
OutgoingActivity::new(instance, sender, activity, recipients)
}
pub async fn follow_or_create_request(
@ -78,19 +73,12 @@ pub async fn follow_or_create_request(
) -> Result<(), DatabaseError> {
if let Some(ref remote_actor) = target_profile.actor_json {
// Create follow request if target is remote
match create_follow_request(
db_client,
&current_user.id,
&target_profile.id,
).await {
match create_follow_request(db_client, &current_user.id, &target_profile.id).await {
Ok(follow_request) => {
prepare_follow(
instance,
current_user,
remote_actor,
&follow_request.id,
).enqueue(db_client).await?;
},
prepare_follow(instance, current_user, remote_actor, &follow_request.id)
.enqueue(db_client)
.await?;
}
Err(DatabaseError::AlreadyExists(_)) => (), // already following
Err(other_error) => return Err(other_error),
};
@ -106,8 +94,8 @@ pub async fn follow_or_create_request(
#[cfg(test)]
mod tests {
use mitra_utils::id::generate_ulid;
use super::*;
use mitra_utils::id::generate_ulid;
const INSTANCE_URL: &str = "https://example.com";
@ -119,12 +107,7 @@ mod tests {
};
let follow_request_id = generate_ulid();
let target_actor_id = "https://test.remote/actor/test";
let activity = build_follow(
INSTANCE_URL,
&follower,
target_actor_id,
&follow_request_id,
);
let activity = build_follow(INSTANCE_URL, &follower, target_actor_id, &follow_request_id);
assert_eq!(
activity.id,

View file

@ -12,12 +12,7 @@ use mitra_models::{
use crate::activitypub::{
constants::AP_PUBLIC,
deliverer::OutgoingActivity,
identifiers::{
local_actor_id,
local_object_id,
post_object_id,
profile_actor_id,
},
identifiers::{local_actor_id, local_object_id, post_object_id, profile_actor_id},
types::{build_default_context, Context},
vocabulary::LIKE,
};
@ -60,8 +55,7 @@ fn build_like(
) -> Like {
let activity_id = local_object_id(instance_url, reaction_id);
let actor_id = local_actor_id(instance_url, &actor_profile.username);
let (primary_audience, secondary_audience) =
get_like_audience(post_author_id, post_visibility);
let (primary_audience, secondary_audience) = get_like_audience(post_author_id, post_visibility);
Like {
context: build_default_context(),
activity_type: LIKE.to_string(),
@ -92,11 +86,7 @@ pub async fn prepare_like(
post: &Post,
reaction_id: &Uuid,
) -> Result<OutgoingActivity, DatabaseError> {
let recipients = get_like_recipients(
db_client,
&instance.url(),
post,
).await?;
let recipients = get_like_recipients(db_client, &instance.url(), post).await?;
let object_id = post_object_id(&instance.url(), post);
let post_author_id = profile_actor_id(&instance.url(), &post.author);
let activity = build_like(
@ -108,17 +98,14 @@ pub async fn prepare_like(
&post.visibility,
);
Ok(OutgoingActivity::new(
instance,
sender,
activity,
recipients,
instance, sender, activity, recipients,
))
}
#[cfg(test)]
mod tests {
use mitra_utils::id::generate_ulid;
use super::*;
use mitra_utils::id::generate_ulid;
const INSTANCE_URL: &str = "https://example.com";

View file

@ -2,10 +2,7 @@ use serde::Serialize;
use uuid::Uuid;
use mitra_config::Instance;
use mitra_models::{
profiles::types::DbActor,
users::types::User,
};
use mitra_models::{profiles::types::DbActor, users::types::User};
use mitra_utils::id::generate_ulid;
use crate::activitypub::{
@ -38,7 +35,8 @@ pub fn build_move_person(
followers: &[String],
maybe_internal_activity_id: Option<&Uuid>,
) -> MovePerson {
let internal_activity_id = maybe_internal_activity_id.copied()
let internal_activity_id = maybe_internal_activity_id
.copied()
.unwrap_or(generate_ulid());
let activity_id = local_object_id(instance_url, &internal_activity_id);
let actor_id = local_actor_id(instance_url, &sender.profile.username);
@ -60,9 +58,7 @@ pub fn prepare_move_person(
followers: Vec<DbActor>,
maybe_internal_activity_id: Option<&Uuid>,
) -> OutgoingActivity {
let followers_ids: Vec<String> = followers.iter()
.map(|actor| actor.id.clone())
.collect();
let followers_ids: Vec<String> = followers.iter().map(|actor| actor.id.clone()).collect();
let activity = build_move_person(
&instance.url(),
sender,
@ -70,19 +66,14 @@ pub fn prepare_move_person(
&followers_ids,
maybe_internal_activity_id,
);
OutgoingActivity::new(
instance,
sender,
activity,
followers,
)
OutgoingActivity::new(instance, sender, activity, followers)
}
#[cfg(test)]
mod tests {
use mitra_utils::id::generate_ulid;
use mitra_models::profiles::types::DbActorProfile;
use super::*;
use mitra_models::profiles::types::DbActorProfile;
use mitra_utils::id::generate_ulid;
const INSTANCE_URL: &str = "https://example.com";

View file

@ -1,13 +1,7 @@
use mitra_config::Instance;
use mitra_models::{
profiles::types::DbActor,
users::types::User,
};
use mitra_models::{profiles::types::DbActor, users::types::User};
use crate::activitypub::{
deliverer::OutgoingActivity,
identifiers::LocalActorCollection,
};
use crate::activitypub::{deliverer::OutgoingActivity, identifiers::LocalActorCollection};
use super::add_person::prepare_update_collection;

View file

@ -9,14 +9,14 @@ use mitra_models::{
users::types::User,
};
use super::announce::get_announce_recipients;
use crate::activitypub::{
constants::AP_PUBLIC,
deliverer::OutgoingActivity,
identifiers::{local_actor_id, local_actor_followers, local_object_id},
identifiers::{local_actor_followers, local_actor_id, local_object_id},
types::{build_default_context, Context},
vocabulary::UNDO,
};
use super::announce::get_announce_recipients;
#[derive(Serialize)]
struct UndoAnnounce {
@ -43,13 +43,8 @@ fn build_undo_announce(
let object_id = local_object_id(instance_url, repost_id);
let activity_id = format!("{}/undo", object_id);
let actor_id = local_actor_id(instance_url, &actor_profile.username);
let primary_audience = vec![
AP_PUBLIC.to_string(),
recipient_id.to_string(),
];
let secondary_audience = vec![
local_actor_followers(instance_url, &actor_profile.username),
];
let primary_audience = vec![AP_PUBLIC.to_string(), recipient_id.to_string()];
let secondary_audience = vec![local_actor_followers(instance_url, &actor_profile.username)];
UndoAnnounce {
context: build_default_context(),
activity_type: UNDO.to_string(),
@ -69,12 +64,8 @@ pub async fn prepare_undo_announce(
repost_id: &Uuid,
) -> Result<OutgoingActivity, DatabaseError> {
assert_ne!(&post.id, repost_id);
let (recipients, primary_recipient) = get_announce_recipients(
db_client,
&instance.url(),
sender,
post,
).await?;
let (recipients, primary_recipient) =
get_announce_recipients(db_client, &instance.url(), sender, post).await?;
let activity = build_undo_announce(
&instance.url(),
&sender.profile,
@ -82,17 +73,14 @@ pub async fn prepare_undo_announce(
&primary_recipient,
);
Ok(OutgoingActivity::new(
instance,
sender,
activity,
recipients,
instance, sender, activity, recipients,
))
}
#[cfg(test)]
mod tests {
use mitra_utils::id::generate_ulid;
use super::*;
use mitra_utils::id::generate_ulid;
const INSTANCE_URL: &str = "https://example.com";
@ -101,12 +89,7 @@ mod tests {
let announcer = DbActorProfile::default();
let post_author_id = "https://example.com/users/test";
let repost_id = generate_ulid();
let activity = build_undo_announce(
INSTANCE_URL,
&announcer,
&repost_id,
post_author_id,
);
let activity = build_undo_announce(INSTANCE_URL, &announcer, &repost_id, post_author_id);
assert_eq!(
activity.id,
format!("{}/objects/{}/undo", INSTANCE_URL, repost_id),

View file

@ -37,14 +37,8 @@ fn build_undo_follow(
target_actor_id: &str,
follow_request_id: &Uuid,
) -> UndoFollow {
let follow_activity_id = local_object_id(
instance_url,
follow_request_id,
);
let follow_actor_id = local_actor_id(
instance_url,
&actor_profile.username,
);
let follow_activity_id = local_object_id(instance_url, follow_request_id);
let follow_actor_id = local_actor_id(instance_url, &actor_profile.username);
let object = Follow {
context: build_default_context(),
activity_type: FOLLOW.to_string(),
@ -78,18 +72,13 @@ pub fn prepare_undo_follow(
follow_request_id,
);
let recipients = vec![target_actor.clone()];
OutgoingActivity::new(
instance,
sender,
activity,
recipients,
)
OutgoingActivity::new(instance, sender, activity, recipients)
}
#[cfg(test)]
mod tests {
use mitra_utils::id::generate_ulid;
use super::*;
use mitra_utils::id::generate_ulid;
const INSTANCE_URL: &str = "https://example.com";

View file

@ -16,10 +16,7 @@ use crate::activitypub::{
vocabulary::UNDO,
};
use super::like::{
get_like_audience,
get_like_recipients,
};
use super::like::{get_like_audience, get_like_recipients};
#[derive(Serialize)]
struct UndoLike {
@ -47,8 +44,7 @@ fn build_undo_like(
let object_id = local_object_id(instance_url, reaction_id);
let activity_id = format!("{}/undo", object_id);
let actor_id = local_actor_id(instance_url, &actor_profile.username);
let (primary_audience, secondary_audience) =
get_like_audience(post_author_id, post_visibility);
let (primary_audience, secondary_audience) = get_like_audience(post_author_id, post_visibility);
UndoLike {
context: build_default_context(),
activity_type: UNDO.to_string(),
@ -67,11 +63,7 @@ pub async fn prepare_undo_like(
post: &Post,
reaction_id: &Uuid,
) -> Result<OutgoingActivity, DatabaseError> {
let recipients = get_like_recipients(
db_client,
&instance.url(),
post,
).await?;
let recipients = get_like_recipients(db_client, &instance.url(), post).await?;
let post_author_id = profile_actor_id(&instance.url(), &post.author);
let activity = build_undo_like(
&instance.url(),
@ -81,18 +73,15 @@ pub async fn prepare_undo_like(
&post.visibility,
);
Ok(OutgoingActivity::new(
instance,
sender,
activity,
recipients,
instance, sender, activity, recipients,
))
}
#[cfg(test)]
mod tests {
use mitra_utils::id::generate_ulid;
use crate::activitypub::constants::AP_PUBLIC;
use super::*;
use crate::activitypub::constants::AP_PUBLIC;
use mitra_utils::id::generate_ulid;
const INSTANCE_URL: &str = "https://example.com";

View file

@ -41,8 +41,7 @@ pub fn build_update_person(
) -> Result<UpdatePerson, ActorKeyError> {
let actor = get_local_actor(user, instance_url)?;
// Update(Person) is idempotent so its ID can be random
let internal_activity_id =
maybe_internal_activity_id.unwrap_or(generate_ulid());
let internal_activity_id = maybe_internal_activity_id.unwrap_or(generate_ulid());
let activity_id = local_object_id(instance_url, &internal_activity_id);
let activity = UpdatePerson {
context: build_default_context(),
@ -68,7 +67,7 @@ async fn get_update_person_recipients(
if let Some(remote_actor) = profile.actor_json {
recipients.push(remote_actor);
};
};
}
Ok(recipients)
}
@ -78,28 +77,17 @@ pub async fn prepare_update_person(
user: &User,
maybe_internal_activity_id: Option<Uuid>,
) -> Result<OutgoingActivity, DatabaseError> {
let activity = build_update_person(
&instance.url(),
user,
maybe_internal_activity_id,
).map_err(|_| DatabaseTypeError)?;
let activity = build_update_person(&instance.url(), user, maybe_internal_activity_id)
.map_err(|_| DatabaseTypeError)?;
let recipients = get_update_person_recipients(db_client, &user.id).await?;
Ok(OutgoingActivity::new(
instance,
user,
activity,
recipients,
))
Ok(OutgoingActivity::new(instance, user, activity, recipients))
}
#[cfg(test)]
mod tests {
use mitra_models::profiles::types::DbActorProfile;
use mitra_utils::crypto_rsa::{
generate_weak_rsa_key,
serialize_private_key,
};
use super::*;
use mitra_models::profiles::types::DbActorProfile;
use mitra_utils::crypto_rsa::{generate_weak_rsa_key, serialize_private_key};
const INSTANCE_URL: &str = "https://example.com";
@ -116,11 +104,7 @@ mod tests {
..Default::default()
};
let internal_id = generate_ulid();
let activity = build_update_person(
INSTANCE_URL,
&user,
Some(internal_id),
).unwrap();
let activity = build_update_person(INSTANCE_URL, &user, Some(internal_id)).unwrap();
assert_eq!(
activity.id,
format!("{}/objects/{}", INSTANCE_URL, internal_id),
@ -129,9 +113,12 @@ mod tests {
activity.object.id,
format!("{}/users/testuser", INSTANCE_URL),
);
assert_eq!(activity.to, vec![
AP_PUBLIC.to_string(),
format!("{}/users/testuser/followers", INSTANCE_URL),
]);
assert_eq!(
activity.to,
vec![
AP_PUBLIC.to_string(),
format!("{}/users/testuser/followers", INSTANCE_URL),
]
);
}
}

View file

@ -1,5 +1,5 @@
use serde::Serialize;
use serde_json::{Value as JsonValue};
use serde_json::Value as JsonValue;
use super::types::{build_default_context, Context};
use super::vocabulary::{ORDERED_COLLECTION, ORDERED_COLLECTION_PAGE};
@ -53,10 +53,7 @@ pub struct OrderedCollectionPage {
}
impl OrderedCollectionPage {
pub fn new(
collection_page_id: String,
items: Vec<JsonValue>,
) -> Self {
pub fn new(collection_page_id: String, items: Vec<JsonValue>) -> Self {
Self {
context: build_default_context(),
id: collection_page_id,

View file

@ -1,5 +1,6 @@
// https://www.w3.org/TR/activitypub/#server-to-server-interactions
pub const AP_MEDIA_TYPE: &str = r#"application/ld+json; profile="https://www.w3.org/ns/activitystreams""#;
pub const AP_MEDIA_TYPE: &str =
r#"application/ld+json; profile="https://www.w3.org/ns/activitystreams""#;
pub const AS_MEDIA_TYPE: &str = "application/activity+json";
// Contexts

View file

@ -8,24 +8,14 @@ use serde_json::Value;
use mitra_config::Instance;
use mitra_models::{
database::{
DatabaseClient,
DatabaseError,
},
database::{DatabaseClient, DatabaseError},
profiles::types::DbActor,
users::types::User,
};
use mitra_utils::crypto_rsa::deserialize_private_key;
use crate::http_signatures::create::{
create_http_signature,
HttpSignatureError,
};
use crate::json_signatures::create::{
is_object_signed,
sign_object,
JsonSignatureError,
};
use crate::http_signatures::create::{create_http_signature, HttpSignatureError};
use crate::json_signatures::create::{is_object_signed, sign_object, JsonSignatureError};
use super::{
constants::AP_MEDIA_TYPE,
@ -58,16 +48,9 @@ pub enum DelivererError {
HttpError(reqwest::StatusCode),
}
fn build_client(
instance: &Instance,
request_url: &str,
) -> Result<Client, DelivererError> {
fn build_client(instance: &Instance, request_url: &str) -> Result<Client, DelivererError> {
let network = get_network_type(request_url)?;
let client = build_federation_client(
instance,
network,
instance.deliverer_timeout,
)?;
let client = build_federation_client(instance, network, instance.deliverer_timeout)?;
Ok(client)
}
@ -87,7 +70,8 @@ async fn send_activity(
)?;
let client = build_client(instance, inbox_url)?;
let request = client.post(inbox_url)
let request = client
.post(inbox_url)
.header("Host", headers.host)
.header("Date", headers.date)
.header("Digest", headers.digest.unwrap())
@ -97,15 +81,16 @@ async fn send_activity(
.body(activity_json.to_owned());
if instance.is_private {
log::info!(
"private mode: not sending activity to {}",
inbox_url,
);
log::info!("private mode: not sending activity to {}", inbox_url,);
} else {
let response = request.send().await?;
let response_status = response.status();
let response_text: String = response.text().await?
.chars().filter(|chr| *chr != '\n' && *chr != '\r').take(75)
let response_text: String = response
.text()
.await?
.chars()
.filter(|chr| *chr != '\n' && *chr != '\r')
.take(75)
.collect();
log::info!(
"response from {}: [{}] {}",
@ -135,10 +120,7 @@ async fn deliver_activity_worker(
recipients: &mut [Recipient],
) -> Result<(), DelivererError> {
let actor_key = deserialize_private_key(&sender.private_key)?;
let actor_id = local_actor_id(
&instance.url(),
&sender.profile.username,
);
let actor_id = local_actor_id(&instance.url(), &sender.profile.username);
let actor_key_id = local_actor_key_id(&actor_id);
let activity_signed = if is_object_signed(&activity) {
log::warn!("activity is already signed");
@ -158,7 +140,9 @@ async fn deliver_activity_worker(
&actor_key_id,
&activity_json,
&recipient.inbox,
).await {
)
.await
{
log::warn!(
"failed to deliver activity to {}: {}",
recipient.inbox,
@ -167,7 +151,7 @@ async fn deliver_activity_worker(
} else {
recipient.is_delivered = true;
};
};
}
Ok(())
}
@ -196,32 +180,27 @@ impl OutgoingActivity {
};
recipient_map.insert(actor.id, recipient);
};
};
}
Self {
instance: instance.clone(),
sender: sender.clone(),
activity: serde_json::to_value(activity)
.expect("activity should be serializable"),
activity: serde_json::to_value(activity).expect("activity should be serializable"),
recipients: recipient_map.into_values().collect(),
}
}
pub(super) async fn deliver(
mut self,
) -> Result<Vec<Recipient>, DelivererError> {
pub(super) async fn deliver(mut self) -> Result<Vec<Recipient>, DelivererError> {
deliver_activity_worker(
self.instance,
self.sender,
self.activity,
&mut self.recipients,
).await?;
)
.await?;
Ok(self.recipients)
}
pub async fn enqueue(
self,
db_client: &impl DatabaseClient,
) -> Result<(), DatabaseError> {
pub async fn enqueue(self, db_client: &impl DatabaseClient) -> Result<(), DatabaseError> {
if self.recipients.is_empty() {
return Ok(());
};

View file

@ -2,13 +2,10 @@ use std::path::Path;
use reqwest::{Client, Method, RequestBuilder};
use serde::Deserialize;
use serde_json::{Value as JsonValue};
use serde_json::Value as JsonValue;
use mitra_config::Instance;
use mitra_utils::{
files::sniff_media_type,
urls::guess_protocol,
};
use mitra_utils::{files::sniff_media_type, urls::guess_protocol};
use crate::activitypub::{
actors::types::Actor,
@ -18,10 +15,7 @@ use crate::activitypub::{
types::Object,
vocabulary::GROUP,
};
use crate::http_signatures::create::{
create_http_signature,
HttpSignatureError,
};
use crate::http_signatures::create::{create_http_signature, HttpSignatureError};
use crate::media::{save_file, SUPPORTED_MEDIA_TYPES};
use crate::webfinger::types::{ActorAddress, JsonResourceDescriptor};
@ -52,39 +46,23 @@ pub enum FetchError {
OtherError(&'static str),
}
fn build_client(
instance: &Instance,
request_url: &str,
) -> Result<Client, FetchError> {
fn build_client(instance: &Instance, request_url: &str) -> Result<Client, FetchError> {
let network = get_network_type(request_url)?;
let client = build_federation_client(
instance,
network,
instance.fetcher_timeout,
)?;
let client = build_federation_client(instance, network, instance.fetcher_timeout)?;
Ok(client)
}
fn build_request(
instance: &Instance,
client: Client,
method: Method,
url: &str,
) -> RequestBuilder {
fn build_request(instance: &Instance, client: Client, method: Method, url: &str) -> RequestBuilder {
let mut request_builder = client.request(method, url);
if !instance.is_private {
// Public instances should set User-Agent header
request_builder = request_builder
.header(reqwest::header::USER_AGENT, instance.agent());
request_builder = request_builder.header(reqwest::header::USER_AGENT, instance.agent());
};
request_builder
}
/// Sends GET request to fetch AP object
async fn send_request(
instance: &Instance,
url: &str,
) -> Result<String, FetchError> {
async fn send_request(instance: &Instance, url: &str) -> Result<String, FetchError> {
let client = build_client(instance, url)?;
let mut request_builder = build_request(instance, client, Method::GET, url)
.header(reqwest::header::ACCEPT, AP_MEDIA_TYPE);
@ -107,9 +85,11 @@ async fn send_request(
};
let data = request_builder
.send().await?
.send()
.await?
.error_for_status()?
.text().await?;
.text()
.await?;
Ok(data)
}
@ -121,12 +101,10 @@ pub async fn fetch_file(
output_dir: &Path,
) -> Result<(String, usize, Option<String>), FetchError> {
let client = build_client(instance, url)?;
let request_builder =
build_request(instance, client, Method::GET, url);
let request_builder = build_request(instance, client, Method::GET, url);
let response = request_builder.send().await?.error_for_status()?;
if let Some(file_size) = response.content_length() {
let file_size: usize = file_size.try_into()
.expect("value should be within bounds");
let file_size: usize = file_size.try_into().expect("value should be within bounds");
if file_size > file_max_size {
return Err(FetchError::FileTooLarge);
};
@ -145,19 +123,11 @@ pub async fn fetch_file(
if SUPPORTED_MEDIA_TYPES.contains(&media_type.as_str()) {
true
} else {
log::info!(
"unsupported media type {}: {}",
media_type,
url,
);
log::info!("unsupported media type {}: {}", media_type, url,);
false
}
});
let file_name = save_file(
file_data.to_vec(),
output_dir,
maybe_media_type.as_deref(),
)?;
let file_name = save_file(file_data.to_vec(), output_dir, maybe_media_type.as_deref())?;
Ok((file_name, file_size, maybe_media_type))
}
@ -172,43 +142,45 @@ pub async fn perform_webfinger_query(
actor_address.hostname,
);
let client = build_client(instance, &webfinger_url)?;
let request_builder =
build_request(instance, client, Method::GET, &webfinger_url);
let request_builder = build_request(instance, client, Method::GET, &webfinger_url);
let webfinger_data = request_builder
.query(&[("resource", webfinger_account_uri)])
.send().await?
.send()
.await?
.error_for_status()?
.text().await?;
.text()
.await?;
let jrd: JsonResourceDescriptor = serde_json::from_str(&webfinger_data)?;
// Lemmy servers can have Group and Person actors with the same name
// https://github.com/LemmyNet/lemmy/issues/2037
let ap_type_property = format!("{}#type", AP_CONTEXT);
let group_link = jrd.links.iter()
.find(|link| {
link.rel == "self" &&
link.properties
let group_link = jrd.links.iter().find(|link| {
link.rel == "self"
&& link
.properties
.get(&ap_type_property)
.map(|val| val.as_str()) == Some(GROUP)
});
.map(|val| val.as_str())
== Some(GROUP)
});
let link = if let Some(link) = group_link {
// Prefer Group if the actor type is provided
link
} else {
// Otherwise take first "self" link
jrd.links.iter()
jrd.links
.iter()
.find(|link| link.rel == "self")
.ok_or(FetchError::OtherError("self link not found"))?
};
let actor_url = link.href.as_ref()
let actor_url = link
.href
.as_ref()
.ok_or(FetchError::OtherError("account href not found"))?
.to_string();
Ok(actor_url)
}
pub async fn fetch_actor(
instance: &Instance,
actor_url: &str,
) -> Result<Actor, FetchError> {
pub async fn fetch_actor(instance: &Instance, actor_url: &str) -> Result<Actor, FetchError> {
let actor_json = send_request(instance, actor_url).await?;
let actor: Actor = serde_json::from_str(&actor_json)?;
if actor.id != actor_url {
@ -217,10 +189,7 @@ pub async fn fetch_actor(
Ok(actor)
}
pub async fn fetch_object(
instance: &Instance,
object_url: &str,
) -> Result<Object, FetchError> {
pub async fn fetch_object(instance: &Instance, object_url: &str) -> Result<Object, FetchError> {
let object_json = send_request(instance, object_url).await?;
let object_value: JsonValue = serde_json::from_str(&object_json)?;
let object: Object = serde_json::from_value(object_value)?;
@ -245,7 +214,6 @@ pub async fn fetch_outbox(
let collection: Collection = serde_json::from_str(&collection_json)?;
let page_json = send_request(instance, &collection.first).await?;
let page: CollectionPage = serde_json::from_str(&page_json)?;
let activities = page.ordered_items.into_iter()
.take(limit).collect();
let activities = page.ordered_items.into_iter().take(limit).collect();
Ok(activities)
}

View file

@ -6,13 +6,13 @@ use mitra_models::{
posts::helpers::get_local_post_by_id,
posts::queries::get_post_by_remote_object_id,
posts::types::Post,
profiles::queries::{
get_profile_by_acct,
get_profile_by_remote_actor_id,
},
profiles::queries::{get_profile_by_acct, get_profile_by_remote_actor_id},
profiles::types::DbActorProfile,
};
use super::fetchers::{
fetch_actor, fetch_object, fetch_outbox, perform_webfinger_query, FetchError,
};
use crate::activitypub::{
actors::helpers::{create_remote_profile, update_remote_profile},
handlers::create::{get_object_links, handle_note},
@ -23,13 +23,6 @@ use crate::activitypub::{
use crate::errors::ValidationError;
use crate::media::MediaStorage;
use crate::webfinger::types::ActorAddress;
use super::fetchers::{
fetch_actor,
fetch_object,
fetch_outbox,
perform_webfinger_query,
FetchError,
};
pub async fn get_or_import_profile_by_actor_id(
db_client: &mut impl DatabaseClient,
@ -40,37 +33,28 @@ pub async fn get_or_import_profile_by_actor_id(
if actor_id.starts_with(&instance.url()) {
return Err(HandlerError::LocalObject);
};
let profile = match get_profile_by_remote_actor_id(
db_client,
actor_id,
).await {
let profile = match get_profile_by_remote_actor_id(db_client, actor_id).await {
Ok(profile) => {
if profile.possibly_outdated() {
// Try to re-fetch actor profile
match fetch_actor(instance, actor_id).await {
Ok(actor) => {
log::info!("re-fetched profile {}", profile.acct);
let profile_updated = update_remote_profile(
db_client,
instance,
storage,
profile,
actor,
).await?;
let profile_updated =
update_remote_profile(db_client, instance, storage, profile, actor)
.await?;
profile_updated
},
}
Err(err) => {
// Ignore error and return stored profile
log::warn!(
"failed to re-fetch {} ({})", profile.acct, err,
);
log::warn!("failed to re-fetch {} ({})", profile.acct, err,);
profile
},
}
}
} else {
profile
}
},
}
Err(DatabaseError::NotFound(_)) => {
let actor = fetch_actor(instance, actor_id).await?;
let actor_address = actor.address()?;
@ -79,28 +63,19 @@ pub async fn get_or_import_profile_by_actor_id(
Ok(profile) => {
// WARNING: Possible actor ID change
log::info!("re-fetched profile {}", profile.acct);
let profile_updated = update_remote_profile(
db_client,
instance,
storage,
profile,
actor,
).await?;
let profile_updated =
update_remote_profile(db_client, instance, storage, profile, actor).await?;
profile_updated
},
}
Err(DatabaseError::NotFound(_)) => {
log::info!("fetched profile {}", acct);
let profile = create_remote_profile(
db_client,
instance,
storage,
actor,
).await?;
let profile =
create_remote_profile(db_client, instance, storage, actor).await?;
profile
},
}
Err(other_error) => return Err(other_error.into()),
}
},
}
Err(other_error) => return Err(other_error.into()),
};
Ok(profile)
@ -128,12 +103,7 @@ pub async fn import_profile_by_actor_address(
};
};
log::info!("fetched profile {}", profile_acct);
let profile = create_remote_profile(
db_client,
instance,
storage,
actor,
).await?;
let profile = create_remote_profile(db_client, instance, storage, actor).await?;
Ok(profile)
}
@ -145,22 +115,14 @@ pub async fn get_or_import_profile_by_actor_address(
actor_address: &ActorAddress,
) -> Result<DbActorProfile, HandlerError> {
let acct = actor_address.acct(&instance.hostname());
let profile = match get_profile_by_acct(
db_client,
&acct,
).await {
let profile = match get_profile_by_acct(db_client, &acct).await {
Ok(profile) => profile,
Err(db_error @ DatabaseError::NotFound(_)) => {
if actor_address.hostname == instance.hostname() {
return Err(db_error.into());
};
import_profile_by_actor_address(
db_client,
instance,
storage,
actor_address,
).await?
},
import_profile_by_actor_address(db_client, instance, storage, actor_address).await?
}
Err(other_error) => return Err(other_error.into()),
};
Ok(profile)
@ -176,12 +138,12 @@ pub async fn get_post_by_object_id(
// Local post
let post = get_local_post_by_id(db_client, &post_id).await?;
Ok(post)
},
}
Err(_) => {
// Remote post
let post = get_post_by_remote_object_id(db_client, object_id).await?;
Ok(post)
},
}
}
}
@ -222,10 +184,7 @@ pub async fn import_post(
get_local_post_by_id(db_client, &post_id).await?;
continue;
};
match get_post_by_remote_object_id(
db_client,
&object_id,
).await {
match get_post_by_remote_object_id(db_client, &object_id).await {
Ok(post) => {
// Object already fetched
if objects.len() == 0 {
@ -233,16 +192,16 @@ pub async fn import_post(
return Ok(post);
};
continue;
},
}
Err(DatabaseError::NotFound(_)) => (),
Err(other_error) => return Err(other_error.into()),
};
object_id
},
}
None => {
// No object to fetch
break;
},
}
};
let object = match maybe_object {
Some(object) => object,
@ -251,15 +210,14 @@ pub async fn import_post(
// TODO: create tombstone
return Err(FetchError::RecursionError.into());
};
let object = fetch_object(instance, &object_id).await
.map_err(|err| {
log::warn!("{}", err);
ValidationError("failed to fetch object")
})?;
let object = fetch_object(instance, &object_id).await.map_err(|err| {
log::warn!("{}", err);
ValidationError("failed to fetch object")
})?;
log::info!("fetched object {}", object.id);
fetch_count += 1;
fetch_count += 1;
object
},
}
};
if object.id != object_id {
// ID of fetched object doesn't match requested ID
@ -277,27 +235,22 @@ pub async fn import_post(
for object_id in get_object_links(&object) {
// Fetch linked objects after fetching current thread
queue.insert(0, object_id);
};
}
maybe_object = None;
objects.push(object);
};
}
let initial_object_id = objects[0].id.clone();
// Objects are ordered according to their place in reply tree,
// starting with the root
objects.reverse();
for object in objects {
let post = handle_note(
db_client,
instance,
storage,
object,
&redirects,
).await?;
let post = handle_note(db_client, instance, storage, object, &redirects).await?;
posts.push(post);
};
}
let initial_post = posts.into_iter()
let initial_post = posts
.into_iter()
.find(|post| post.object_id.as_ref() == Some(&initial_object_id))
.unwrap();
Ok(initial_post)
@ -314,24 +267,20 @@ pub async fn import_from_outbox(
let activities = fetch_outbox(&instance, &actor.outbox, limit).await?;
log::info!("fetched {} activities", activities.len());
for activity in activities {
let activity_actor = activity["actor"].as_str()
let activity_actor = activity["actor"]
.as_str()
.ok_or(ValidationError("actor property is missing"))?;
if activity_actor != actor.id {
log::warn!("activity doesn't belong to outbox owner");
continue;
};
handle_activity(
config,
db_client,
&activity,
true, // is authenticated
).await.unwrap_or_else(|error| {
log::warn!(
"failed to process activity ({}): {}",
error,
activity,
);
config, db_client, &activity, true, // is authenticated
)
.await
.unwrap_or_else(|error| {
log::warn!("failed to process activity ({}): {}", error, activity,);
});
};
}
Ok(())
}

View file

@ -5,17 +5,12 @@ use mitra_config::Config;
use mitra_models::{
database::DatabaseClient,
profiles::queries::get_profile_by_remote_actor_id,
relationships::queries::{
follow_request_accepted,
get_follow_request_by_id,
},
relationships::queries::{follow_request_accepted, get_follow_request_by_id},
relationships::types::FollowRequestStatus,
};
use crate::activitypub::{
identifiers::parse_local_object_id,
receiver::deserialize_into_object_id,
vocabulary::FOLLOW,
identifiers::parse_local_object_id, receiver::deserialize_into_object_id, vocabulary::FOLLOW,
};
use crate::errors::ValidationError;
@ -36,14 +31,8 @@ pub async fn handle_accept(
// Accept(Follow)
let activity: Accept = serde_json::from_value(activity)
.map_err(|_| ValidationError("unexpected activity structure"))?;
let actor_profile = get_profile_by_remote_actor_id(
db_client,
&activity.actor,
).await?;
let follow_request_id = parse_local_object_id(
&config.instance_url(),
&activity.object,
)?;
let actor_profile = get_profile_by_remote_actor_id(db_client, &activity.actor).await?;
let follow_request_id = parse_local_object_id(&config.instance_url(), &activity.object)?;
let follow_request = get_follow_request_by_id(db_client, &follow_request_id).await?;
if follow_request.target_id != actor_profile.id {
return Err(ValidationError("actor is not a target").into());

View file

@ -3,18 +3,13 @@ use serde_json::Value;
use mitra_config::Config;
use mitra_models::{
database::DatabaseClient,
profiles::queries::get_profile_by_remote_actor_id,
relationships::queries::subscribe_opt,
users::queries::get_user_by_name,
database::DatabaseClient, profiles::queries::get_profile_by_remote_actor_id,
relationships::queries::subscribe_opt, users::queries::get_user_by_name,
};
use crate::activitypub::{
identifiers::parse_local_actor_id,
vocabulary::PERSON,
};
use crate::errors::ValidationError;
use super::{HandlerError, HandlerResult};
use crate::activitypub::{identifiers::parse_local_actor_id, vocabulary::PERSON};
use crate::errors::ValidationError;
#[derive(Deserialize)]
struct Add {
@ -30,17 +25,11 @@ pub async fn handle_add(
) -> HandlerResult {
let activity: Add = serde_json::from_value(activity)
.map_err(|_| ValidationError("unexpected activity structure"))?;
let actor_profile = get_profile_by_remote_actor_id(
db_client,
&activity.actor,
).await?;
let actor_profile = get_profile_by_remote_actor_id(db_client, &activity.actor).await?;
let actor = actor_profile.actor_json.ok_or(HandlerError::LocalObject)?;
if Some(activity.target) == actor.subscribers {
// Adding to subscribers
let username = parse_local_actor_id(
&config.instance_url(),
&activity.object,
)?;
let username = parse_local_actor_id(&config.instance_url(), &activity.object)?;
let user = get_user_by_name(db_client, &username).await?;
subscribe_opt(db_client, &user.id, &actor_profile.id).await?;
return Ok(Some(PERSON));

View file

@ -4,13 +4,11 @@ use serde_json::Value;
use mitra_config::Config;
use mitra_models::{
database::{DatabaseClient, DatabaseError},
posts::queries::{
create_post,
get_post_by_remote_object_id,
},
posts::queries::{create_post, get_post_by_remote_object_id},
posts::types::PostCreateData,
};
use super::HandlerResult;
use crate::activitypub::{
fetcher::helpers::{get_or_import_profile_by_actor_id, import_post},
identifiers::parse_local_object_id,
@ -19,7 +17,6 @@ use crate::activitypub::{
};
use crate::errors::ValidationError;
use crate::media::MediaStorage;
use super::HandlerResult;
#[derive(Deserialize)]
struct Announce {
@ -44,43 +41,24 @@ pub async fn handle_announce(
let activity: Announce = serde_json::from_value(activity)
.map_err(|_| ValidationError("unexpected activity structure"))?;
let repost_object_id = activity.id;
match get_post_by_remote_object_id(
db_client,
&repost_object_id,
).await {
match get_post_by_remote_object_id(db_client, &repost_object_id).await {
Ok(_) => return Ok(None), // Ignore if repost already exists
Err(DatabaseError::NotFound(_)) => (),
Err(other_error) => return Err(other_error.into()),
};
let instance = config.instance();
let storage = MediaStorage::from(config);
let author = get_or_import_profile_by_actor_id(
db_client,
&instance,
&storage,
&activity.actor,
).await?;
let post_id = match parse_local_object_id(
&instance.url(),
&activity.object,
) {
let author =
get_or_import_profile_by_actor_id(db_client, &instance, &storage, &activity.actor).await?;
let post_id = match parse_local_object_id(&instance.url(), &activity.object) {
Ok(post_id) => post_id,
Err(_) => {
// Try to get remote post
let post = import_post(
db_client,
&instance,
&storage,
activity.object,
None,
).await?;
let post = import_post(db_client, &instance, &storage, activity.object, None).await?;
post.id
},
}
};
let repost_data = PostCreateData::repost(
post_id,
Some(repost_object_id.clone()),
);
let repost_data = PostCreateData::repost(post_id, Some(repost_object_id.clone()));
match create_post(db_client, &author.id, repost_data).await {
Ok(_) => Ok(Some(NOTE)),
Err(DatabaseError::AlreadyExists("post")) => {
@ -88,7 +66,7 @@ pub async fn handle_announce(
// object ID, or due to race condition in a handler).
log::warn!("repost already exists: {}", repost_object_id);
Ok(None)
},
}
// May return "post not found" error if post if not public
Err(other_error) => Err(other_error.into()),
}
@ -96,8 +74,8 @@ pub async fn handle_announce(
#[cfg(test)]
mod tests {
use serde_json::json;
use super::*;
use serde_json::json;
#[test]
fn test_deserialize_announce() {

View file

@ -2,18 +2,14 @@ use std::collections::HashMap;
use chrono::Utc;
use serde::Deserialize;
use serde_json::{Value as JsonValue};
use serde_json::Value as JsonValue;
use uuid::Uuid;
use mitra_config::{Config, Instance};
use mitra_models::{
attachments::queries::create_attachment,
database::{DatabaseClient, DatabaseError},
emojis::queries::{
create_emoji,
get_emoji_by_remote_object_id,
update_emoji,
},
emojis::queries::{create_emoji, get_emoji_by_remote_object_id, update_emoji},
emojis::types::{DbEmoji, EmojiImage},
posts::{
queries::create_post,
@ -23,19 +19,15 @@ use mitra_models::{
relationships::queries::has_local_followers,
users::queries::get_user_by_name,
};
use mitra_utils::{
html::clean_html,
urls::get_hostname,
};
use mitra_utils::{html::clean_html, urls::get_hostname};
use super::HandlerResult;
use crate::activitypub::{
constants::{AP_MEDIA_TYPE, AP_PUBLIC, AS_MEDIA_TYPE},
fetcher::fetchers::{fetch_file, FetchError},
fetcher::helpers::{
get_or_import_profile_by_actor_address,
get_or_import_profile_by_actor_id,
get_post_by_object_id,
import_post,
get_or_import_profile_by_actor_address, get_or_import_profile_by_actor_id,
get_post_by_object_id, import_post,
},
identifiers::{parse_local_actor_id, profile_actor_id},
receiver::{parse_array, parse_property_value, HandlerError},
@ -45,28 +37,19 @@ use crate::activitypub::{
use crate::errors::ValidationError;
use crate::media::MediaStorage;
use crate::validators::{
emojis::{
validate_emoji_name,
EMOJI_MEDIA_TYPES,
},
emojis::{validate_emoji_name, EMOJI_MEDIA_TYPES},
posts::{
content_allowed_classes,
ATTACHMENT_LIMIT,
CONTENT_MAX_SIZE,
EMOJI_LIMIT,
LINK_LIMIT,
MENTION_LIMIT,
OBJECT_ID_SIZE_MAX,
content_allowed_classes, ATTACHMENT_LIMIT, CONTENT_MAX_SIZE, EMOJI_LIMIT, LINK_LIMIT,
MENTION_LIMIT, OBJECT_ID_SIZE_MAX,
},
tags::validate_hashtag,
};
use crate::webfinger::types::ActorAddress;
use super::HandlerResult;
fn get_object_attributed_to(object: &Object)
-> Result<String, ValidationError>
{
let attributed_to = object.attributed_to.as_ref()
fn get_object_attributed_to(object: &Object) -> Result<String, ValidationError> {
let attributed_to = object
.attributed_to
.as_ref()
.ok_or(ValidationError("unattributed note"))?;
let author_id = parse_array(attributed_to)
.map_err(|_| ValidationError("invalid attributedTo property"))?
@ -83,7 +66,7 @@ pub fn get_object_url(object: &Object) -> Result<String, ValidationError> {
let links: Vec<Link> = parse_property_value(other_value)
.map_err(|_| ValidationError("invalid object URL"))?;
links.get(0).map(|link| link.href.clone())
},
}
None => None,
};
let object_url = maybe_object_url.unwrap_or(object.id.clone());
@ -110,10 +93,7 @@ pub fn get_object_content(object: &Object) -> Result<String, ValidationError> {
}
pub fn create_content_link(url: String) -> String {
format!(
r#"<p><a href="{0}" rel="noopener">{0}</a></p>"#,
url,
)
format!(r#"<p><a href="{0}" rel="noopener">{0}</a></p>"#, url,)
}
fn is_gnu_social_link(author_id: &str, attachment: &Attachment) -> bool {
@ -147,21 +127,16 @@ pub async fn get_object_attachments(
match attachment.attachment_type.as_str() {
DOCUMENT | IMAGE | VIDEO => (),
_ => {
log::warn!(
"skipping attachment of type {}",
attachment.attachment_type,
);
log::warn!("skipping attachment of type {}", attachment.attachment_type,);
continue;
},
}
};
if is_gnu_social_link(
&profile_actor_id(&instance.url(), author),
&attachment,
) {
if is_gnu_social_link(&profile_actor_id(&instance.url(), author), &attachment) {
// Don't fetch HTML pages attached by GNU Social
continue;
};
let attachment_url = attachment.url
let attachment_url = attachment
.url
.ok_or(ValidationError("attachment URL is missing"))?;
let (file_name, file_size, maybe_media_type) = match fetch_file(
instance,
@ -169,17 +144,19 @@ pub async fn get_object_attachments(
attachment.media_type.as_deref(),
storage.file_size_limit,
&storage.media_dir,
).await {
)
.await
{
Ok(file) => file,
Err(FetchError::FileTooLarge) => {
log::warn!("attachment is too large: {}", attachment_url);
unprocessed.push(attachment_url);
continue;
},
}
Err(other_error) => {
log::warn!("{}", other_error);
return Err(ValidationError("failed to fetch attachment").into());
},
}
};
log::info!("downloaded attachment {}", attachment_url);
downloaded.push((file_name, file_size, maybe_media_type));
@ -188,7 +165,7 @@ pub async fn get_object_attachments(
log::warn!("too many attachments");
break;
};
};
}
for (file_name, file_size, maybe_media_type) in downloaded {
let db_attachment = create_attachment(
db_client,
@ -196,9 +173,10 @@ pub async fn get_object_attachments(
file_name,
file_size,
maybe_media_type,
).await?;
)
.await?;
attachments.push(db_attachment.id);
};
}
};
Ok((attachments, unprocessed))
}
@ -209,9 +187,7 @@ fn normalize_hashtag(tag: &str) -> Result<String, ValidationError> {
Ok(tag_name.to_lowercase())
}
pub fn get_object_links(
object: &Object,
) -> Vec<String> {
pub fn get_object_links(object: &Object) -> Vec<String> {
let mut links = vec![];
for tag_value in object.tag.clone() {
let tag_type = tag_value["type"].as_str().unwrap_or(HASHTAG);
@ -221,11 +197,9 @@ pub fn get_object_links(
Err(_) => {
log::warn!("invalid link tag");
continue;
},
}
};
if tag.media_type != AP_MEDIA_TYPE &&
tag.media_type != AS_MEDIA_TYPE
{
if tag.media_type != AP_MEDIA_TYPE && tag.media_type != AS_MEDIA_TYPE {
// Unknown media type
continue;
};
@ -233,7 +207,7 @@ pub fn get_object_links(
links.push(tag.href);
};
};
};
}
if let Some(ref object_id) = object.quote_url {
if !links.contains(object_id) {
links.push(object_id.to_owned());
@ -253,17 +227,14 @@ pub async fn handle_emoji(
Err(error) => {
log::warn!("invalid emoji tag: {}", error);
return Ok(None);
},
}
};
let emoji_name = tag.name.trim_matches(':');
if validate_emoji_name(emoji_name).is_err() {
log::warn!("invalid emoji name: {}", emoji_name);
return Ok(None);
};
let maybe_emoji_id = match get_emoji_by_remote_object_id(
db_client,
&tag.id,
).await {
let maybe_emoji_id = match get_emoji_by_remote_object_id(db_client, &tag.id).await {
Ok(emoji) => {
if emoji.updated_at >= tag.updated {
// Emoji already exists and is up to date
@ -274,7 +245,7 @@ pub async fn handle_emoji(
return Ok(None);
};
Some(emoji.id)
},
}
Err(DatabaseError::NotFound("emoji")) => None,
Err(other_error) => return Err(other_error.into()),
};
@ -284,37 +255,32 @@ pub async fn handle_emoji(
tag.icon.media_type.as_deref(),
storage.emoji_size_limit,
&storage.media_dir,
).await {
)
.await
{
Ok(file) => file,
Err(error) => {
log::warn!("failed to fetch emoji: {}", error);
return Ok(None);
},
}
};
let media_type = match maybe_media_type {
Some(media_type) if EMOJI_MEDIA_TYPES.contains(&media_type.as_str()) => {
media_type
},
Some(media_type) if EMOJI_MEDIA_TYPES.contains(&media_type.as_str()) => media_type,
_ => {
log::warn!(
"unexpected emoji media type: {:?}",
maybe_media_type,
);
log::warn!("unexpected emoji media type: {:?}", maybe_media_type,);
return Ok(None);
},
}
};
log::info!("downloaded emoji {}", tag.icon.url);
let image = EmojiImage { file_name, file_size, media_type };
let image = EmojiImage {
file_name,
file_size,
media_type,
};
let emoji = if let Some(emoji_id) = maybe_emoji_id {
update_emoji(
db_client,
&emoji_id,
image,
&tag.updated,
).await?
update_emoji(db_client, &emoji_id, image, &tag.updated).await?
} else {
let hostname = get_hostname(&tag.id)
.map_err(|_| ValidationError("invalid emoji ID"))?;
let hostname = get_hostname(&tag.id).map_err(|_| ValidationError("invalid emoji ID"))?;
match create_emoji(
db_client,
emoji_name,
@ -322,12 +288,14 @@ pub async fn handle_emoji(
image,
Some(&tag.id),
&tag.updated,
).await {
)
.await
{
Ok(emoji) => emoji,
Err(DatabaseError::AlreadyExists(_)) => {
log::warn!("emoji name is not unique: {}", emoji_name);
return Ok(None);
},
}
Err(other_error) => return Err(other_error.into()),
}
};
@ -353,7 +321,7 @@ pub async fn get_object_tags(
Err(_) => {
log::warn!("invalid hashtag");
continue;
},
}
};
if let Some(tag_name) = tag.name {
// Ignore invalid tags
@ -373,7 +341,7 @@ pub async fn get_object_tags(
Err(_) => {
log::warn!("invalid mention");
continue;
},
}
};
// Try to find profile by actor ID.
if let Some(href) = tag.href {
@ -386,25 +354,16 @@ pub async fn get_object_tags(
};
// NOTE: `href` attribute is usually actor ID
// but also can be actor URL (profile link).
match get_or_import_profile_by_actor_id(
db_client,
instance,
storage,
&href,
).await {
match get_or_import_profile_by_actor_id(db_client, instance, storage, &href).await {
Ok(profile) => {
if !mentions.contains(&profile.id) {
mentions.push(profile.id);
};
continue;
},
}
Err(error) => {
log::warn!(
"failed to find mentioned profile by ID {}: {}",
href,
error,
);
},
log::warn!("failed to find mentioned profile by ID {}: {}", href, error,);
}
};
};
// Try to find profile by actor address
@ -413,7 +372,7 @@ pub async fn get_object_tags(
None => {
log::warn!("failed to parse mention");
continue;
},
}
};
if let Ok(actor_address) = ActorAddress::from_mention(&tag_name) {
let profile = match get_or_import_profile_by_actor_address(
@ -421,12 +380,14 @@ pub async fn get_object_tags(
instance,
storage,
&actor_address,
).await {
)
.await
{
Ok(profile) => profile,
Err(error @ (
HandlerError::FetchError(_) |
HandlerError::DatabaseError(DatabaseError::NotFound(_))
)) => {
Err(
error @ (HandlerError::FetchError(_)
| HandlerError::DatabaseError(DatabaseError::NotFound(_))),
) => {
// Ignore mention if fetcher fails
// Ignore mention if local address is not valid
log::warn!(
@ -435,7 +396,7 @@ pub async fn get_object_tags(
error,
);
continue;
},
}
Err(other_error) => return Err(other_error),
};
if !mentions.contains(&profile.id) {
@ -454,20 +415,14 @@ pub async fn get_object_tags(
Err(_) => {
log::warn!("invalid link tag");
continue;
},
}
};
if tag.media_type != AP_MEDIA_TYPE &&
tag.media_type != AS_MEDIA_TYPE
{
if tag.media_type != AP_MEDIA_TYPE && tag.media_type != AS_MEDIA_TYPE {
// Unknown media type
continue;
};
let href = redirects.get(&tag.href).unwrap_or(&tag.href);
let linked = get_post_by_object_id(
db_client,
&instance.url(),
href,
).await?;
let linked = get_post_by_object_id(db_client, &instance.url(), href).await?;
if !links.contains(&linked.id) {
links.push(linked.id);
};
@ -476,30 +431,21 @@ pub async fn get_object_tags(
log::warn!("too many emojis");
continue;
};
match handle_emoji(
db_client,
instance,
storage,
tag_value,
).await? {
match handle_emoji(db_client, instance, storage, tag_value).await? {
Some(emoji) => {
if !emojis.contains(&emoji.id) {
emojis.push(emoji.id);
};
},
}
None => continue,
};
} else {
log::warn!("skipping tag of type {}", tag_type);
};
};
}
if let Some(ref object_id) = object.quote_url {
let object_id = redirects.get(object_id).unwrap_or(object_id);
let linked = get_post_by_object_id(
db_client,
&instance.url(),
object_id,
).await?;
let linked = get_post_by_object_id(db_client, &instance.url(), object_id).await?;
if !links.contains(&linked.id) {
links.push(linked.id);
};
@ -510,16 +456,14 @@ pub async fn get_object_tags(
fn get_audience(object: &Object) -> Result<Vec<String>, ValidationError> {
let primary_audience = match object.to {
Some(ref value) => {
parse_array(value)
.map_err(|_| ValidationError("invalid 'to' property value"))?
},
parse_array(value).map_err(|_| ValidationError("invalid 'to' property value"))?
}
None => vec![],
};
let secondary_audience = match object.cc {
Some(ref value) => {
parse_array(value)
.map_err(|_| ValidationError("invalid 'cc' property value"))?
},
parse_array(value).map_err(|_| ValidationError("invalid 'cc' property value"))?
}
None => vec![],
};
let audience = [primary_audience, secondary_audience].concat();
@ -528,22 +472,19 @@ fn get_audience(object: &Object) -> Result<Vec<String>, ValidationError> {
fn is_public_object(audience: &[String]) -> bool {
// Some servers (e.g. Takahe) use "as" namespace
const PUBLIC_VARIANTS: [&str; 3] = [
AP_PUBLIC,
"as:Public",
"Public",
];
audience.iter().any(|item| PUBLIC_VARIANTS.contains(&item.as_str()))
const PUBLIC_VARIANTS: [&str; 3] = [AP_PUBLIC, "as:Public", "Public"];
audience
.iter()
.any(|item| PUBLIC_VARIANTS.contains(&item.as_str()))
}
fn get_object_visibility(
author: &DbActorProfile,
audience: &[String],
) -> Visibility {
fn get_object_visibility(author: &DbActorProfile, audience: &[String]) -> Visibility {
if is_public_object(audience) {
return Visibility::Public;
return Visibility::Public;
};
let actor = author.actor_json.as_ref()
let actor = author
.actor_json
.as_ref()
.expect("actor data should be present");
if let Some(ref followers) = actor.followers {
if audience.contains(followers) {
@ -569,26 +510,23 @@ pub async fn handle_note(
NOTE => (),
ARTICLE | EVENT | QUESTION | PAGE | VIDEO => {
log::info!("processing object of type {}", object.object_type);
},
}
other_type => {
log::warn!("discarding object of type {}", other_type);
return Err(ValidationError("unsupported object type").into());
},
}
};
if object.id.len() > OBJECT_ID_SIZE_MAX {
return Err(ValidationError("object ID is too long").into());
};
let author_id = get_object_attributed_to(&object)?;
let author = get_or_import_profile_by_actor_id(
db_client,
instance,
storage,
&author_id,
).await.map_err(|err| {
log::warn!("failed to import {} ({})", author_id, err);
err
})?;
let author = get_or_import_profile_by_actor_id(db_client, instance, storage, &author_id)
.await
.map_err(|err| {
log::warn!("failed to import {} ({})", author_id, err);
err
})?;
let mut content = get_object_content(&object)?;
if object.object_type != NOTE {
@ -596,38 +534,24 @@ pub async fn handle_note(
let object_url = get_object_url(&object)?;
content += &create_content_link(object_url);
};
let (attachments, unprocessed) = get_object_attachments(
db_client,
instance,
storage,
&object,
&author,
).await?;
let (attachments, unprocessed) =
get_object_attachments(db_client, instance, storage, &object, &author).await?;
for attachment_url in unprocessed {
content += &create_content_link(attachment_url);
};
}
if content.is_empty() && attachments.is_empty() {
return Err(ValidationError("post is empty").into());
};
let (mentions, hashtags, links, emojis) = get_object_tags(
db_client,
instance,
storage,
&object,
redirects,
).await?;
let (mentions, hashtags, links, emojis) =
get_object_tags(db_client, instance, storage, &object, redirects).await?;
let in_reply_to_id = match object.in_reply_to {
Some(ref object_id) => {
let object_id = redirects.get(object_id).unwrap_or(object_id);
let in_reply_to = get_post_by_object_id(
db_client,
&instance.url(),
object_id,
).await?;
let in_reply_to = get_post_by_object_id(db_client, &instance.url(), object_id).await?;
Some(in_reply_to.id)
},
}
None => None,
};
let audience = get_audience(&object)?;
@ -665,17 +589,15 @@ pub async fn is_unsolicited_message(
object: &Object,
) -> Result<bool, HandlerError> {
let author_id = get_object_attributed_to(object)?;
let author_has_followers =
has_local_followers(db_client, &author_id).await?;
let author_has_followers = has_local_followers(db_client, &author_id).await?;
let audience = get_audience(object)?;
let has_local_recipients = audience.iter().any(|actor_id| {
parse_local_actor_id(instance_url, actor_id).is_ok()
});
let result =
object.in_reply_to.is_none() &&
is_public_object(&audience) &&
!has_local_recipients &&
!author_has_followers;
let has_local_recipients = audience
.iter()
.any(|actor_id| parse_local_actor_id(instance_url, actor_id).is_ok());
let result = object.in_reply_to.is_none()
&& is_public_object(&audience)
&& !has_local_recipients
&& !author_has_followers;
Ok(result)
}
@ -691,8 +613,8 @@ pub async fn handle_create(
activity: JsonValue,
mut is_authenticated: bool,
) -> HandlerResult {
let activity: CreateNote = serde_json::from_value(activity)
.map_err(|_| ValidationError("invalid object"))?;
let activity: CreateNote =
serde_json::from_value(activity).map_err(|_| ValidationError("invalid object"))?;
let object = activity.object;
// Verify attribution
@ -716,23 +638,21 @@ pub async fn handle_create(
&MediaStorage::from(config),
object_id,
object_received,
).await?;
)
.await?;
Ok(Some(NOTE))
}
#[cfg(test)]
mod tests {
use serde_json::json;
use mitra_models::profiles::types::DbActor;
use crate::activitypub::{
types::Object,
vocabulary::NOTE,
};
use super::*;
use crate::activitypub::{types::Object, vocabulary::NOTE};
use mitra_models::profiles::types::DbActor;
use serde_json::json;
#[test]
fn test_get_object_attributed_to() {
let object = Object {
let object = Object {
object_type: NOTE.to_string(),
attributed_to: Some(json!(["https://example.org/1"])),
..Default::default()

View file

@ -4,14 +4,8 @@ use serde_json::Value;
use mitra_config::Config;
use mitra_models::{
database::{DatabaseClient, DatabaseError},
posts::queries::{
delete_post,
get_post_by_remote_object_id,
},
profiles::queries::{
delete_profile,
get_profile_by_remote_actor_id,
},
posts::queries::{delete_post, get_post_by_remote_object_id},
profiles::queries::{delete_profile, get_profile_by_remote_actor_id},
};
use crate::activitypub::{
@ -39,10 +33,7 @@ pub async fn handle_delete(
.map_err(|_| ValidationError("unexpected activity structure"))?;
if activity.object == activity.actor {
// Self-delete
let profile = match get_profile_by_remote_actor_id(
db_client,
&activity.object,
).await {
let profile = match get_profile_by_remote_actor_id(db_client, &activity.object).await {
Ok(profile) => profile,
// Ignore Delete(Person) if profile is not found
Err(DatabaseError::NotFound(_)) => return Ok(None),
@ -56,19 +47,13 @@ pub async fn handle_delete(
log::info!("deleted profile {}", profile.acct);
return Ok(Some(PERSON));
};
let post = match get_post_by_remote_object_id(
db_client,
&activity.object,
).await {
let post = match get_post_by_remote_object_id(db_client, &activity.object).await {
Ok(post) => post,
// Ignore Delete(Note) if post is not found
Err(DatabaseError::NotFound(_)) => return Ok(None),
Err(other_error) => return Err(other_error.into()),
};
let actor_profile = get_profile_by_remote_actor_id(
db_client,
&activity.actor,
).await?;
let actor_profile = get_profile_by_remote_actor_id(db_client, &activity.actor).await?;
if post.author.id != actor_profile.id {
return Err(ValidationError("actor is not an author").into());
};

View file

@ -4,23 +4,18 @@ use serde_json::Value;
use mitra_config::Config;
use mitra_models::{
database::{DatabaseClient, DatabaseError},
relationships::queries::{
create_remote_follow_request_opt,
follow_request_accepted,
},
relationships::queries::{create_remote_follow_request_opt, follow_request_accepted},
users::queries::get_user_by_name,
};
use super::{HandlerError, HandlerResult};
use crate::activitypub::{
builders::accept_follow::prepare_accept_follow,
fetcher::helpers::get_or_import_profile_by_actor_id,
identifiers::parse_local_actor_id,
receiver::deserialize_into_object_id,
vocabulary::PERSON,
fetcher::helpers::get_or_import_profile_by_actor_id, identifiers::parse_local_actor_id,
receiver::deserialize_into_object_id, vocabulary::PERSON,
};
use crate::errors::ValidationError;
use crate::media::MediaStorage;
use super::{HandlerError, HandlerResult};
#[derive(Deserialize)]
struct Follow {
@ -43,20 +38,18 @@ pub async fn handle_follow(
&config.instance(),
&MediaStorage::from(config),
&activity.actor,
).await?;
let source_actor = source_profile.actor_json
.ok_or(HandlerError::LocalObject)?;
let target_username = parse_local_actor_id(
&config.instance_url(),
&activity.object,
)?;
)
.await?;
let source_actor = source_profile.actor_json.ok_or(HandlerError::LocalObject)?;
let target_username = parse_local_actor_id(&config.instance_url(), &activity.object)?;
let target_user = get_user_by_name(db_client, &target_username).await?;
let follow_request = create_remote_follow_request_opt(
db_client,
&source_profile.id,
&target_user.id,
&activity.id,
).await?;
)
.await?;
match follow_request_accepted(db_client, &follow_request.id).await {
Ok(_) => (),
// Proceed even if relationship already exists
@ -70,7 +63,9 @@ pub async fn handle_follow(
&target_user,
&source_actor,
&activity.id,
).enqueue(db_client).await?;
)
.enqueue(db_client)
.await?;
Ok(Some(PERSON))
}

View file

@ -8,10 +8,7 @@ use mitra_models::{
};
use crate::activitypub::{
fetcher::helpers::{
get_or_import_profile_by_actor_id,
get_post_by_object_id,
},
fetcher::helpers::{get_or_import_profile_by_actor_id, get_post_by_object_id},
receiver::deserialize_into_object_id,
vocabulary::NOTE,
};
@ -40,23 +37,16 @@ pub async fn handle_like(
&config.instance(),
&MediaStorage::from(config),
&activity.actor,
).await?;
let post_id = match get_post_by_object_id(
db_client,
&config.instance_url(),
&activity.object,
).await {
Ok(post) => post.id,
// Ignore like if post is not found locally
Err(DatabaseError::NotFound(_)) => return Ok(None),
Err(other_error) => return Err(other_error.into()),
};
match create_reaction(
db_client,
&author.id,
&post_id,
Some(&activity.id),
).await {
)
.await?;
let post_id =
match get_post_by_object_id(db_client, &config.instance_url(), &activity.object).await {
Ok(post) => post.id,
// Ignore like if post is not found locally
Err(DatabaseError::NotFound(_)) => return Ok(None),
Err(other_error) => return Err(other_error.into()),
};
match create_reaction(db_client, &author.id, &post_id, Some(&activity.id)).await {
Ok(_) => (),
// Ignore activity if reaction is already saved
Err(DatabaseError::AlreadyExists(_)) => return Ok(None),

View file

@ -6,18 +6,12 @@ use mitra_models::{
database::DatabaseClient,
notifications::queries::create_move_notification,
profiles::helpers::find_verified_aliases,
relationships::queries::{
get_followers,
unfollow,
},
relationships::queries::{get_followers, unfollow},
users::queries::{get_user_by_id, get_user_by_name},
};
use crate::activitypub::{
builders::{
follow::follow_or_create_request,
undo_follow::prepare_undo_follow,
},
builders::{follow::follow_or_create_request, undo_follow::prepare_undo_follow},
fetcher::helpers::get_or_import_profile_by_actor_id,
identifiers::{parse_local_actor_id, profile_actor_id},
vocabulary::PERSON,
@ -50,39 +44,26 @@ pub async fn handle_move(
let instance = config.instance();
let storage = MediaStorage::from(config);
let old_profile = if let Ok(username) = parse_local_actor_id(
&instance.url(),
&activity.object,
) {
let old_profile = if let Ok(username) = parse_local_actor_id(&instance.url(), &activity.object)
{
let old_user = get_user_by_name(db_client, &username).await?;
old_user.profile
} else {
get_or_import_profile_by_actor_id(
db_client,
&instance,
&storage,
&activity.object,
).await?
get_or_import_profile_by_actor_id(db_client, &instance, &storage, &activity.object).await?
};
let old_actor_id = profile_actor_id(&instance.url(), &old_profile);
let new_profile = if let Ok(username) = parse_local_actor_id(
&instance.url(),
&activity.target,
) {
let new_profile = if let Ok(username) = parse_local_actor_id(&instance.url(), &activity.target)
{
let new_user = get_user_by_name(db_client, &username).await?;
new_user.profile
} else {
get_or_import_profile_by_actor_id(
db_client,
&instance,
&storage,
&activity.target,
).await?
get_or_import_profile_by_actor_id(db_client, &instance, &storage, &activity.target).await?
};
// Find aliases by DIDs (verified)
let mut aliases = find_verified_aliases(db_client, &new_profile).await?
let mut aliases = find_verified_aliases(db_client, &new_profile)
.await?
.into_iter()
.map(|profile| profile_actor_id(&instance.url(), &profile))
.collect::<Vec<_>>();
@ -96,39 +77,22 @@ pub async fn handle_move(
for follower in followers {
let follower = get_user_by_id(db_client, &follower.id).await?;
// Unfollow old profile
let maybe_follow_request_id = unfollow(
db_client,
&follower.id,
&old_profile.id,
).await?;
let maybe_follow_request_id = unfollow(db_client, &follower.id, &old_profile.id).await?;
// Send Undo(Follow) if old actor is not local
if let Some(ref old_actor) = old_profile.actor_json {
let follow_request_id = maybe_follow_request_id
.expect("follow request must exist");
prepare_undo_follow(
&instance,
&follower,
old_actor,
&follow_request_id,
).enqueue(db_client).await?;
let follow_request_id = maybe_follow_request_id.expect("follow request must exist");
prepare_undo_follow(&instance, &follower, old_actor, &follow_request_id)
.enqueue(db_client)
.await?;
};
if follower.id == new_profile.id {
// Don't self-follow
continue;
};
// Follow new profile
follow_or_create_request(
db_client,
&instance,
&follower,
&new_profile,
).await?;
create_move_notification(
db_client,
&new_profile.id,
&follower.id,
).await?;
};
follow_or_create_request(db_client, &instance, &follower, &new_profile).await?;
create_move_notification(db_client, &new_profile.id, &follower.id).await?;
}
Ok(Some(PERSON))
}

View file

@ -5,17 +5,12 @@ use mitra_config::Config;
use mitra_models::{
database::DatabaseClient,
profiles::queries::get_profile_by_remote_actor_id,
relationships::queries::{
follow_request_rejected,
get_follow_request_by_id,
},
relationships::queries::{follow_request_rejected, get_follow_request_by_id},
relationships::types::FollowRequestStatus,
};
use crate::activitypub::{
identifiers::parse_local_object_id,
receiver::deserialize_into_object_id,
vocabulary::FOLLOW,
identifiers::parse_local_object_id, receiver::deserialize_into_object_id, vocabulary::FOLLOW,
};
use crate::errors::ValidationError;
@ -36,14 +31,8 @@ pub async fn handle_reject(
// Reject(Follow)
let activity: Reject = serde_json::from_value(activity)
.map_err(|_| ValidationError("unexpected activity structure"))?;
let actor_profile = get_profile_by_remote_actor_id(
db_client,
&activity.actor,
).await?;
let follow_request_id = parse_local_object_id(
&config.instance_url(),
&activity.object,
)?;
let actor_profile = get_profile_by_remote_actor_id(db_client, &activity.actor).await?;
let follow_request_id = parse_local_object_id(&config.instance_url(), &activity.object)?;
let follow_request = get_follow_request_by_id(db_client, &follow_request_id).await?;
if follow_request.target_id != actor_profile.id {
return Err(ValidationError("actor is not a target").into());

View file

@ -4,18 +4,13 @@ use serde_json::Value;
use mitra_config::Config;
use mitra_models::{
database::{DatabaseClient, DatabaseError},
notifications::queries::{
create_subscription_expiration_notification,
},
notifications::queries::create_subscription_expiration_notification,
profiles::queries::get_profile_by_remote_actor_id,
relationships::queries::unsubscribe,
users::queries::get_user_by_name,
};
use crate::activitypub::{
identifiers::parse_local_actor_id,
vocabulary::PERSON,
};
use crate::activitypub::{identifiers::parse_local_actor_id, vocabulary::PERSON};
use crate::errors::ValidationError;
use super::{HandlerError, HandlerResult};
@ -34,28 +29,19 @@ pub async fn handle_remove(
) -> HandlerResult {
let activity: Remove = serde_json::from_value(activity)
.map_err(|_| ValidationError("unexpected activity structure"))?;
let actor_profile = get_profile_by_remote_actor_id(
db_client,
&activity.actor,
).await?;
let actor_profile = get_profile_by_remote_actor_id(db_client, &activity.actor).await?;
let actor = actor_profile.actor_json.ok_or(HandlerError::LocalObject)?;
if Some(activity.target) == actor.subscribers {
// Removing from subscribers
let username = parse_local_actor_id(
&config.instance_url(),
&activity.object,
)?;
let username = parse_local_actor_id(&config.instance_url(), &activity.object)?;
let user = get_user_by_name(db_client, &username).await?;
// actor is recipient, user is sender
match unsubscribe(db_client, &user.id, &actor_profile.id).await {
Ok(_) => {
create_subscription_expiration_notification(
db_client,
&actor_profile.id,
&user.id,
).await?;
create_subscription_expiration_notification(db_client, &actor_profile.id, &user.id)
.await?;
return Ok(Some(PERSON));
},
}
// Ignore removal if relationship does not exist
Err(DatabaseError::NotFound(_)) => return Ok(None),
Err(other_error) => return Err(other_error.into()),

View file

@ -4,22 +4,10 @@ use serde_json::Value;
use mitra_config::Config;
use mitra_models::{
database::{DatabaseClient, DatabaseError},
posts::queries::{
delete_post,
get_post_by_remote_object_id,
},
profiles::queries::{
get_profile_by_acct,
get_profile_by_remote_actor_id,
},
reactions::queries::{
delete_reaction,
get_reaction_by_remote_activity_id,
},
relationships::queries::{
get_follow_request_by_activity_id,
unfollow,
},
posts::queries::{delete_post, get_post_by_remote_object_id},
profiles::queries::{get_profile_by_acct, get_profile_by_remote_actor_id},
reactions::queries::{delete_reaction, get_reaction_by_remote_activity_id},
relationships::queries::{get_follow_request_by_activity_id, unfollow},
};
use crate::activitypub::{
@ -44,15 +32,9 @@ async fn handle_undo_follow(
) -> HandlerResult {
let activity: UndoFollow = serde_json::from_value(activity)
.map_err(|_| ValidationError("unexpected activity structure"))?;
let source_profile = get_profile_by_remote_actor_id(
db_client,
&activity.actor,
).await?;
let source_profile = get_profile_by_remote_actor_id(db_client, &activity.actor).await?;
let target_actor_id = find_object_id(&activity.object["object"])?;
let target_username = parse_local_actor_id(
&config.instance_url(),
&target_actor_id,
)?;
let target_username = parse_local_actor_id(&config.instance_url(), &target_actor_id)?;
// acct equals username if profile is local
let target_profile = get_profile_by_acct(db_client, &target_username).await?;
match unfollow(db_client, &source_profile.id, &target_profile.id).await {
@ -83,10 +65,7 @@ pub async fn handle_undo(
let activity: Undo = serde_json::from_value(activity)
.map_err(|_| ValidationError("unexpected activity structure"))?;
let actor_profile = get_profile_by_remote_actor_id(
db_client,
&activity.actor,
).await?;
let actor_profile = get_profile_by_remote_actor_id(db_client, &activity.actor).await?;
match get_follow_request_by_activity_id(db_client, &activity.object).await {
Ok(follow_request) => {
@ -98,9 +77,10 @@ pub async fn handle_undo(
db_client,
&follow_request.source_id,
&follow_request.target_id,
).await?;
)
.await?;
return Ok(Some(FOLLOW));
},
}
Err(DatabaseError::NotFound(_)) => (), // try other object types
Err(other_error) => return Err(other_error.into()),
};
@ -111,19 +91,12 @@ pub async fn handle_undo(
if reaction.author_id != actor_profile.id {
return Err(ValidationError("actor is not an author").into());
};
delete_reaction(
db_client,
&reaction.author_id,
&reaction.post_id,
).await?;
delete_reaction(db_client, &reaction.author_id, &reaction.post_id).await?;
Ok(Some(LIKE))
},
}
Err(DatabaseError::NotFound(_)) => {
// Undo(Announce)
let post = match get_post_by_remote_object_id(
db_client,
&activity.object,
).await {
let post = match get_post_by_remote_object_id(db_client, &activity.object).await {
Ok(post) => post,
// Ignore undo if neither reaction nor repost is found
Err(DatabaseError::NotFound(_)) => return Ok(None),
@ -139,7 +112,7 @@ pub async fn handle_undo(
None => return Err(ValidationError("object is not a repost").into()),
};
Ok(Some(ANNOUNCE))
},
}
Err(other_error) => Err(other_error.into()),
}
}

View file

@ -7,24 +7,15 @@ use serde_json::Value;
use mitra_config::Config;
use mitra_models::{
database::{DatabaseClient, DatabaseError},
posts::queries::{
get_post_by_remote_object_id,
update_post,
},
posts::queries::{get_post_by_remote_object_id, update_post},
posts::types::PostUpdateData,
profiles::queries::get_profile_by_remote_actor_id,
};
use crate::activitypub::{
actors::{
helpers::update_remote_profile,
types::Actor,
},
actors::{helpers::update_remote_profile, types::Actor},
handlers::create::{
create_content_link,
get_object_attachments,
get_object_content,
get_object_tags,
create_content_link, get_object_attachments, get_object_content, get_object_tags,
get_object_url,
},
identifiers::profile_actor_id,
@ -47,13 +38,10 @@ async fn handle_update_note(
db_client: &mut impl DatabaseClient,
activity: Value,
) -> HandlerResult {
let activity: UpdateNote = serde_json::from_value(activity)
.map_err(|_| ValidationError("invalid object"))?;
let activity: UpdateNote =
serde_json::from_value(activity).map_err(|_| ValidationError("invalid object"))?;
let object = activity.object;
let post = match get_post_by_remote_object_id(
db_client,
&object.id,
).await {
let post = match get_post_by_remote_object_id(db_client, &object.id).await {
Ok(post) => post,
// Ignore Update if post is not found locally
Err(DatabaseError::NotFound(_)) => return Ok(None),
@ -71,26 +59,16 @@ async fn handle_update_note(
};
let is_sensitive = object.sensitive.unwrap_or(false);
let storage = MediaStorage::from(config);
let (attachments, unprocessed) = get_object_attachments(
db_client,
&instance,
&storage,
&object,
&post.author,
).await?;
let (attachments, unprocessed) =
get_object_attachments(db_client, &instance, &storage, &object, &post.author).await?;
for attachment_url in unprocessed {
content += &create_content_link(attachment_url);
};
}
if content.is_empty() && attachments.is_empty() {
return Err(ValidationError("post is empty").into());
};
let (mentions, hashtags, links, emojis) = get_object_tags(
db_client,
&instance,
&storage,
&object,
&HashMap::new(),
).await?;
let (mentions, hashtags, links, emojis) =
get_object_tags(db_client, &instance, &storage, &object, &HashMap::new()).await?;
let updated_at = object.updated.unwrap_or(Utc::now());
let post_data = PostUpdateData {
content,
@ -117,22 +95,20 @@ async fn handle_update_person(
db_client: &mut impl DatabaseClient,
activity: Value,
) -> HandlerResult {
let activity: UpdatePerson = serde_json::from_value(activity)
.map_err(|_| ValidationError("invalid actor data"))?;
let activity: UpdatePerson =
serde_json::from_value(activity).map_err(|_| ValidationError("invalid actor data"))?;
if activity.object.id != activity.actor {
return Err(ValidationError("actor ID mismatch").into());
};
let profile = get_profile_by_remote_actor_id(
db_client,
&activity.object.id,
).await?;
let profile = get_profile_by_remote_actor_id(db_client, &activity.object.id).await?;
update_remote_profile(
db_client,
&config.instance(),
&MediaStorage::from(config),
profile,
activity.object,
).await?;
)
.await?;
Ok(Some(PERSON))
}
@ -141,18 +117,15 @@ pub async fn handle_update(
db_client: &mut impl DatabaseClient,
activity: Value,
) -> HandlerResult {
let object_type = activity["object"]["type"].as_str()
let object_type = activity["object"]["type"]
.as_str()
.ok_or(ValidationError("unknown object type"))?;
match object_type {
NOTE => {
handle_update_note(config, db_client, activity).await
},
PERSON => {
handle_update_person(config, db_client, activity).await
},
NOTE => handle_update_note(config, db_client, activity).await,
PERSON => handle_update_person(config, db_client, activity).await,
_ => {
log::warn!("unexpected object type {}", object_type);
Ok(None)
},
}
}
}

View file

@ -14,9 +14,7 @@ pub enum Network {
I2p,
}
pub fn get_network_type(request_url: &str) ->
Result<Network, url::ParseError>
{
pub fn get_network_type(request_url: &str) -> Result<Network, url::ParseError> {
let hostname = get_hostname(request_url)?;
let network = if hostname.ends_with(".onion") {
Network::Tor
@ -38,23 +36,18 @@ pub fn build_federation_client(
match network {
Network::Default => (),
Network::Tor => {
maybe_proxy_url = instance.onion_proxy_url.as_ref()
.or(maybe_proxy_url);
},
maybe_proxy_url = instance.onion_proxy_url.as_ref().or(maybe_proxy_url);
}
Network::I2p => {
maybe_proxy_url = instance.i2p_proxy_url.as_ref()
.or(maybe_proxy_url);
},
maybe_proxy_url = instance.i2p_proxy_url.as_ref().or(maybe_proxy_url);
}
};
if let Some(proxy_url) = maybe_proxy_url {
let proxy = Proxy::all(proxy_url)?;
client_builder = client_builder.proxy(proxy);
};
let request_timeout = Duration::from_secs(timeout);
let connect_timeout = Duration::from_secs(max(
timeout,
CONNECTION_TIMEOUT,
));
let connect_timeout = Duration::from_secs(max(timeout, CONNECTION_TIMEOUT));
client_builder
.timeout(request_timeout)
.connect_timeout(connect_timeout)

View file

@ -1,10 +1,7 @@
use regex::Regex;
use uuid::Uuid;
use mitra_models::{
posts::types::Post,
profiles::types::DbActorProfile,
};
use mitra_models::{posts::types::Post, profiles::types::DbActorProfile};
use mitra_utils::urls::get_hostname;
use crate::errors::ValidationError;
@ -83,45 +80,41 @@ pub fn local_tag_collection(instance_url: &str, tag_name: &str) -> String {
}
pub fn validate_object_id(object_id: &str) -> Result<(), ValidationError> {
get_hostname(object_id)
.map_err(|_| ValidationError("invalid object ID"))?;
get_hostname(object_id).map_err(|_| ValidationError("invalid object ID"))?;
Ok(())
}
pub fn parse_local_actor_id(
instance_url: &str,
actor_id: &str,
) -> Result<String, ValidationError> {
pub fn parse_local_actor_id(instance_url: &str, actor_id: &str) -> Result<String, ValidationError> {
let url_regexp_str = format!(
"^{}/users/(?P<username>[0-9a-z_]+)$",
instance_url.replace('.', r"\."),
);
let url_regexp = Regex::new(&url_regexp_str)
.map_err(|_| ValidationError("error"))?;
let url_caps = url_regexp.captures(actor_id)
let url_regexp = Regex::new(&url_regexp_str).map_err(|_| ValidationError("error"))?;
let url_caps = url_regexp
.captures(actor_id)
.ok_or(ValidationError("invalid actor ID"))?;
let username = url_caps.name("username")
let username = url_caps
.name("username")
.ok_or(ValidationError("invalid actor ID"))?
.as_str()
.to_owned();
Ok(username)
}
pub fn parse_local_object_id(
instance_url: &str,
object_id: &str,
) -> Result<Uuid, ValidationError> {
pub fn parse_local_object_id(instance_url: &str, object_id: &str) -> Result<Uuid, ValidationError> {
let url_regexp_str = format!(
"^{}/objects/(?P<uuid>[0-9a-f-]+)$",
instance_url.replace('.', r"\."),
);
let url_regexp = Regex::new(&url_regexp_str)
.map_err(|_| ValidationError("error"))?;
let url_caps = url_regexp.captures(object_id)
let url_regexp = Regex::new(&url_regexp_str).map_err(|_| ValidationError("error"))?;
let url_caps = url_regexp
.captures(object_id)
.ok_or(ValidationError("invalid object ID"))?;
let internal_object_id: Uuid = url_caps.name("uuid")
let internal_object_id: Uuid = url_caps
.name("uuid")
.ok_or(ValidationError("invalid object ID"))?
.as_str().parse()
.as_str()
.parse()
.map_err(|_| ValidationError("invalid object ID"))?;
Ok(internal_object_id)
}
@ -151,68 +144,51 @@ pub fn profile_actor_url(instance_url: &str, profile: &DbActorProfile) -> String
#[cfg(test)]
mod tests {
use mitra_utils::id::generate_ulid;
use super::*;
use mitra_utils::id::generate_ulid;
const INSTANCE_URL: &str = "https://example.org";
#[test]
fn test_parse_local_actor_id() {
let username = parse_local_actor_id(
INSTANCE_URL,
"https://example.org/users/test",
).unwrap();
let username =
parse_local_actor_id(INSTANCE_URL, "https://example.org/users/test").unwrap();
assert_eq!(username, "test".to_string());
}
#[test]
fn test_parse_local_actor_id_wrong_path() {
let error = parse_local_actor_id(
INSTANCE_URL,
"https://example.org/user/test",
).unwrap_err();
let error =
parse_local_actor_id(INSTANCE_URL, "https://example.org/user/test").unwrap_err();
assert_eq!(error.to_string(), "invalid actor ID");
}
#[test]
fn test_parse_local_actor_id_invalid_username() {
let error = parse_local_actor_id(
INSTANCE_URL,
"https://example.org/users/tes-t",
).unwrap_err();
let error =
parse_local_actor_id(INSTANCE_URL, "https://example.org/users/tes-t").unwrap_err();
assert_eq!(error.to_string(), "invalid actor ID");
}
#[test]
fn test_parse_local_actor_id_invalid_instance_url() {
let error = parse_local_actor_id(
INSTANCE_URL,
"https://example.gov/users/test",
).unwrap_err();
let error =
parse_local_actor_id(INSTANCE_URL, "https://example.gov/users/test").unwrap_err();
assert_eq!(error.to_string(), "invalid actor ID");
}
#[test]
fn test_parse_local_object_id() {
let expected_uuid = generate_ulid();
let object_id = format!(
"https://example.org/objects/{}",
expected_uuid,
);
let internal_object_id = parse_local_object_id(
INSTANCE_URL,
&object_id,
).unwrap();
let object_id = format!("https://example.org/objects/{}", expected_uuid,);
let internal_object_id = parse_local_object_id(INSTANCE_URL, &object_id).unwrap();
assert_eq!(internal_object_id, expected_uuid);
}
#[test]
fn test_parse_local_object_id_invalid_uuid() {
let object_id = "https://example.org/objects/1234";
let error = parse_local_object_id(
INSTANCE_URL,
object_id,
).unwrap_err();
let error = parse_local_object_id(INSTANCE_URL, object_id).unwrap_err();
assert_eq!(error.to_string(), "invalid object ID");
}
@ -223,9 +199,6 @@ mod tests {
..Default::default()
};
let profile_url = profile_actor_url(INSTANCE_URL, &profile);
assert_eq!(
profile_url,
"https://example.org/users/test",
);
assert_eq!(profile_url, "https://example.org/users/test",);
}
}

View file

@ -5,19 +5,9 @@ use uuid::Uuid;
use mitra_config::Config;
use mitra_models::{
background_jobs::queries::{
enqueue_job,
get_job_batch,
delete_job_from_queue,
},
background_jobs::queries::{delete_job_from_queue, enqueue_job, get_job_batch},
background_jobs::types::JobType,
database::{
get_database_client,
DatabaseClient,
DatabaseError,
DatabaseTypeError,
DbPool,
},
database::{get_database_client, DatabaseClient, DatabaseError, DatabaseTypeError, DbPool},
profiles::queries::set_reachability_status,
users::queries::get_user_by_id,
};
@ -49,15 +39,15 @@ impl IncomingActivityJobData {
db_client: &impl DatabaseClient,
delay: u32,
) -> Result<(), DatabaseError> {
let job_data = serde_json::to_value(self)
.expect("activity should be serializable");
let job_data = serde_json::to_value(self).expect("activity should be serializable");
let scheduled_for = Utc::now() + Duration::seconds(delay.into());
enqueue_job(
db_client,
&JobType::IncomingActivity,
&job_data,
&scheduled_for,
).await
)
.await
}
}
@ -78,11 +68,11 @@ pub async fn process_queued_incoming_activities(
&JobType::IncomingActivity,
INCOMING_QUEUE_BATCH_SIZE,
JOB_TIMEOUT,
).await?;
)
.await?;
for job in batch {
let mut job_data: IncomingActivityJobData =
serde_json::from_value(job.job_data)
.map_err(|_| DatabaseTypeError)?;
serde_json::from_value(job.job_data).map_err(|_| DatabaseTypeError)?;
// See also: activitypub::queues::JOB_TIMEOUT
let duration_max = std::time::Duration::from_secs(600);
let handler_future = handle_activity(
@ -91,10 +81,7 @@ pub async fn process_queued_incoming_activities(
&job_data.activity,
job_data.is_authenticated,
);
let handler_result = match tokio::time::timeout(
duration_max,
handler_future,
).await {
let handler_result = match tokio::time::timeout(duration_max, handler_future).await {
Ok(result) => result,
Err(_) => {
log::error!(
@ -103,12 +90,12 @@ pub async fn process_queued_incoming_activities(
);
delete_job_from_queue(db_client, &job.id).await?;
continue;
},
}
};
if let Err(error) = handler_result {
job_data.failure_count += 1;
if let HandlerError::DatabaseError(
DatabaseError::DatabaseClientError(ref pg_error)) = error
if let HandlerError::DatabaseError(DatabaseError::DatabaseClientError(ref pg_error)) =
error
{
log::error!("database client error: {}", pg_error);
};
@ -129,7 +116,7 @@ pub async fn process_queued_incoming_activities(
};
};
delete_job_from_queue(db_client, &job.id).await?;
};
}
Ok(())
}
@ -147,15 +134,15 @@ impl OutgoingActivityJobData {
db_client: &impl DatabaseClient,
delay: u32,
) -> Result<(), DatabaseError> {
let job_data = serde_json::to_value(self)
.expect("activity should be serializable");
let job_data = serde_json::to_value(self).expect("activity should be serializable");
let scheduled_for = Utc::now() + Duration::seconds(delay.into());
enqueue_job(
db_client,
&JobType::OutgoingActivity,
&job_data,
&scheduled_for,
).await
)
.await
}
}
@ -178,11 +165,11 @@ pub async fn process_queued_outgoing_activities(
&JobType::OutgoingActivity,
OUTGOING_QUEUE_BATCH_SIZE,
JOB_TIMEOUT,
).await?;
)
.await?;
for job in batch {
let mut job_data: OutgoingActivityJobData =
serde_json::from_value(job.job_data)
.map_err(|_| DatabaseTypeError)?;
serde_json::from_value(job.job_data).map_err(|_| DatabaseTypeError)?;
let sender = get_user_by_id(db_client, &job_data.sender_id).await?;
let outgoing_activity = OutgoingActivity {
instance: config.instance(),
@ -198,7 +185,7 @@ pub async fn process_queued_outgoing_activities(
log::error!("{}", error);
delete_job_from_queue(db_client, &job.id).await?;
return Ok(());
},
}
};
log::info!(
"delivery job: {} delivered, {} errors (attempt #{})",
@ -206,8 +193,8 @@ pub async fn process_queued_outgoing_activities(
recipients.iter().filter(|item| !item.is_delivered).count(),
job_data.failure_count + 1,
);
if recipients.iter().any(|recipient| !recipient.is_delivered) &&
job_data.failure_count < OUTGOING_QUEUE_RETRIES_MAX
if recipients.iter().any(|recipient| !recipient.is_delivered)
&& job_data.failure_count < OUTGOING_QUEUE_RETRIES_MAX
{
job_data.failure_count += 1;
// Re-queue if some deliveries are not successful
@ -219,15 +206,11 @@ pub async fn process_queued_outgoing_activities(
// Update inbox status if all deliveries are successful
// or if retry limit is reached
for recipient in recipients {
set_reachability_status(
db_client,
&recipient.id,
recipient.is_delivered,
).await?;
};
set_reachability_status(db_client, &recipient.id, recipient.is_delivered).await?;
}
};
delete_job_from_queue(db_client, &job.id).await?;
};
}
Ok(())
}

View file

@ -1,35 +1,17 @@
use actix_web::HttpRequest;
use serde::{
Deserialize,
Deserializer,
de::DeserializeOwned,
de::Error as DeserializerError,
};
use serde::{de::DeserializeOwned, de::Error as DeserializerError, Deserialize, Deserializer};
use serde_json::Value;
use mitra_config::Config;
use mitra_models::database::{DatabaseClient, DatabaseError};
use crate::errors::{
ConversionError,
HttpError,
ValidationError,
};
use super::authentication::{
verify_signed_activity,
verify_signed_request,
AuthenticationError,
};
use super::authentication::{verify_signed_activity, verify_signed_request, AuthenticationError};
use super::fetcher::fetchers::FetchError;
use super::handlers::{
accept::handle_accept,
add::handle_add,
announce::handle_announce,
create::{
handle_create,
is_unsolicited_message,
CreateNote,
},
create::{handle_create, is_unsolicited_message, CreateNote},
delete::handle_delete,
follow::handle_follow,
like::handle_like,
@ -42,6 +24,7 @@ use super::handlers::{
use super::identifiers::profile_actor_id;
use super::queues::IncomingActivityJobData;
use super::vocabulary::*;
use crate::errors::{ConversionError, HttpError, ValidationError};
#[derive(thiserror::Error, Debug)]
pub enum HandlerError {
@ -65,14 +48,10 @@ impl From<HandlerError> for HttpError {
fn from(error: HandlerError) -> Self {
match error {
HandlerError::LocalObject => HttpError::InternalError,
HandlerError::FetchError(error) => {
HttpError::ValidationError(error.to_string())
},
HandlerError::FetchError(error) => HttpError::ValidationError(error.to_string()),
HandlerError::ValidationError(error) => error.into(),
HandlerError::DatabaseError(error) => error.into(),
HandlerError::AuthError(_) => {
HttpError::AuthError("invalid signature")
},
HandlerError::AuthError(_) => HttpError::AuthError("invalid signature"),
}
}
}
@ -93,13 +72,13 @@ pub fn parse_array(value: &Value) -> Result<Vec<String>, ConversionError> {
// id property is missing
return Err(ConversionError);
};
},
}
// Unexpected array item type
_ => return Err(ConversionError),
};
};
}
results
},
}
// Unexpected value type
_ => return Err(ConversionError),
};
@ -116,10 +95,9 @@ pub fn parse_property_value<T: DeserializeOwned>(value: &Value) -> Result<Vec<T>
};
let mut items = vec![];
for object in objects {
let item: T = serde_json::from_value(object)
.map_err(|_| ConversionError)?;
let item: T = serde_json::from_value(object).map_err(|_| ConversionError)?;
items.push(item);
};
}
Ok(items)
}
@ -128,23 +106,22 @@ pub fn find_object_id(object: &Value) -> Result<String, ValidationError> {
let object_id = match object.as_str() {
Some(object_id) => object_id.to_owned(),
None => {
let object_id = object["id"].as_str()
let object_id = object["id"]
.as_str()
.ok_or(ValidationError("missing object ID"))?
.to_string();
object_id
},
}
};
Ok(object_id)
}
pub fn deserialize_into_object_id<'de, D>(
deserializer: D,
) -> Result<String, D::Error>
where D: Deserializer<'de>
pub fn deserialize_into_object_id<'de, D>(deserializer: D) -> Result<String, D::Error>
where
D: Deserializer<'de>,
{
let value = Value::deserialize(deserializer)?;
let object_id = find_object_id(&value)
.map_err(DeserializerError::custom)?;
let object_id = find_object_id(&value).map_err(DeserializerError::custom)?;
Ok(object_id)
}
@ -154,54 +131,32 @@ pub async fn handle_activity(
activity: &Value,
is_authenticated: bool,
) -> Result<(), HandlerError> {
let activity_type = activity["type"].as_str()
let activity_type = activity["type"]
.as_str()
.ok_or(ValidationError("type property is missing"))?
.to_owned();
let activity_actor = activity["actor"].as_str()
let activity_actor = activity["actor"]
.as_str()
.ok_or(ValidationError("actor property is missing"))?
.to_owned();
let activity = activity.clone();
let maybe_object_type = match activity_type.as_str() {
ACCEPT => {
handle_accept(config, db_client, activity).await?
},
ADD => {
handle_add(config, db_client, activity).await?
},
ANNOUNCE => {
handle_announce(config, db_client, activity).await?
},
CREATE => {
handle_create(config, db_client, activity, is_authenticated).await?
},
DELETE => {
handle_delete(config, db_client, activity).await?
},
FOLLOW => {
handle_follow(config, db_client, activity).await?
},
LIKE | EMOJI_REACT => {
handle_like(config, db_client, activity).await?
},
MOVE => {
handle_move(config, db_client, activity).await?
},
REJECT => {
handle_reject(config, db_client, activity).await?
},
REMOVE => {
handle_remove(config, db_client, activity).await?
},
UNDO => {
handle_undo(config, db_client, activity).await?
},
UPDATE => {
handle_update(config, db_client, activity).await?
},
ACCEPT => handle_accept(config, db_client, activity).await?,
ADD => handle_add(config, db_client, activity).await?,
ANNOUNCE => handle_announce(config, db_client, activity).await?,
CREATE => handle_create(config, db_client, activity, is_authenticated).await?,
DELETE => handle_delete(config, db_client, activity).await?,
FOLLOW => handle_follow(config, db_client, activity).await?,
LIKE | EMOJI_REACT => handle_like(config, db_client, activity).await?,
MOVE => handle_move(config, db_client, activity).await?,
REJECT => handle_reject(config, db_client, activity).await?,
REMOVE => handle_remove(config, db_client, activity).await?,
UNDO => handle_undo(config, db_client, activity).await?,
UPDATE => handle_update(config, db_client, activity).await?,
_ => {
log::warn!("activity type is not supported: {}", activity);
None
},
}
};
if let Some(object_type) = maybe_object_type {
log::info!(
@ -220,9 +175,11 @@ pub async fn receive_activity(
request: &HttpRequest,
activity: &Value,
) -> Result<(), HandlerError> {
let activity_type = activity["type"].as_str()
let activity_type = activity["type"]
.as_str()
.ok_or(ValidationError("type property is missing"))?;
let activity_actor = activity["actor"].as_str()
let activity_actor = activity["actor"]
.as_str()
.ok_or(ValidationError("actor property is missing"))?;
let actor_hostname = url::Url::parse(activity_actor)
@ -230,7 +187,9 @@ pub async fn receive_activity(
.host_str()
.ok_or(ValidationError("invalid actor ID"))?
.to_string();
if config.blocked_instances.iter()
if config
.blocked_instances
.iter()
.any(|instance_hostname| &actor_hostname == instance_hostname)
{
log::warn!("ignoring activity from blocked instance: {}", activity);
@ -240,7 +199,9 @@ pub async fn receive_activity(
let is_self_delete = if activity_type == DELETE {
let object_id = find_object_id(&activity["object"])?;
object_id == activity_actor
} else { false };
} else {
false
};
// HTTP signature is required
let mut signer = match verify_signed_request(
@ -249,24 +210,28 @@ pub async fn receive_activity(
request,
// Don't fetch signer if this is Delete(Person) activity
is_self_delete,
).await {
)
.await
{
Ok(request_signer) => {
log::debug!("request signed by {}", request_signer.acct);
request_signer
},
}
Err(error) => {
if is_self_delete && matches!(
error,
AuthenticationError::NoHttpSignature |
AuthenticationError::DatabaseError(DatabaseError::NotFound(_))
) {
if is_self_delete
&& matches!(
error,
AuthenticationError::NoHttpSignature
| AuthenticationError::DatabaseError(DatabaseError::NotFound(_))
)
{
// Ignore Delete(Person) activities without HTTP signatures
// or if signer is not found in local database
return Ok(());
};
log::warn!("invalid HTTP signature: {}", error);
return Err(error.into());
},
}
};
// JSON signature is optional
@ -276,7 +241,9 @@ pub async fn receive_activity(
activity,
// Don't fetch actor if this is Delete(Person) activity
is_self_delete,
).await {
)
.await
{
Ok(activity_signer) => {
if activity_signer.acct != signer.acct {
log::warn!(
@ -289,14 +256,16 @@ pub async fn receive_activity(
};
// Activity signature has higher priority
signer = activity_signer;
},
}
Err(AuthenticationError::NoJsonSignature) => (), // ignore
Err(other_error) => {
log::warn!("invalid JSON signature: {}", other_error);
},
}
};
if config.blocked_instances.iter()
if config
.blocked_instances
.iter()
.any(|instance| signer.hostname.as_ref() == Some(instance))
{
log::warn!("ignoring activity from blocked instance: {}", activity);
@ -311,7 +280,7 @@ pub async fn receive_activity(
DELETE | LIKE => {
// Ignore forwarded Delete and Like activities
return Ok(());
},
}
_ => {
// Reject other types
log::warn!(
@ -320,7 +289,7 @@ pub async fn receive_activity(
activity_actor,
);
return Err(AuthenticationError::UnexpectedSigner.into());
},
}
};
};
@ -336,31 +305,24 @@ pub async fn receive_activity(
if let ANNOUNCE | CREATE | DELETE | MOVE | UNDO | UPDATE = activity_type {
// Add activity to job queue and release lock
IncomingActivityJobData::new(activity, is_authenticated)
.into_job(db_client, 0).await?;
.into_job(db_client, 0)
.await?;
log::debug!("activity added to the queue: {}", activity_type);
return Ok(());
};
handle_activity(
config,
db_client,
activity,
is_authenticated,
).await
handle_activity(config, db_client, activity, is_authenticated).await
}
#[cfg(test)]
mod tests {
use serde_json::json;
use super::*;
use serde_json::json;
#[test]
fn test_parse_array_with_string() {
let value = json!("test");
assert_eq!(
parse_array(&value).unwrap(),
vec!["test".to_string()],
);
assert_eq!(parse_array(&value).unwrap(), vec!["test".to_string()],);
}
#[test]

View file

@ -1,19 +1,11 @@
use std::collections::HashMap;
use chrono::{DateTime, Utc};
use serde::{
Deserialize,
Deserializer,
Serialize,
de::{Error as DeserializerError},
};
use serde::{de::Error as DeserializerError, Deserialize, Deserializer, Serialize};
use serde_json::Value;
use super::constants::{
AP_CONTEXT,
MITRA_CONTEXT,
W3ID_DATA_INTEGRITY_CONTEXT,
W3ID_SECURITY_CONTEXT,
AP_CONTEXT, MITRA_CONTEXT, W3ID_DATA_INTEGRITY_CONTEXT, W3ID_SECURITY_CONTEXT,
};
use super::receiver::parse_property_value;
use super::vocabulary::HASHTAG;
@ -35,7 +27,9 @@ pub struct Link {
pub href: String,
}
fn default_tag_type() -> String { HASHTAG.to_string() }
fn default_tag_type() -> String {
HASHTAG.to_string()
}
#[derive(Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
@ -88,10 +82,9 @@ pub struct EmojiTag {
pub updated: DateTime<Utc>,
}
pub fn deserialize_value_array<'de, D>(
deserializer: D,
) -> Result<Vec<Value>, D::Error>
where D: Deserializer<'de>
pub fn deserialize_value_array<'de, D>(deserializer: D) -> Result<Vec<Value>, D::Error>
where
D: Deserializer<'de>,
{
let maybe_value: Option<Value> = Option::deserialize(deserializer)?;
let values = if let Some(value) = maybe_value {
@ -124,10 +117,7 @@ pub struct Object {
pub quote_url: Option<String>,
pub sensitive: Option<bool>,
#[serde(
default,
deserialize_with = "deserialize_value_array",
)]
#[serde(default, deserialize_with = "deserialize_value_array")]
pub tag: Vec<Value>,
pub to: Option<Value>,

View file

@ -1,14 +1,8 @@
use std::time::Instant;
use actix_web::{
get,
post,
web,
http::header as http_header,
http::header::HeaderMap,
HttpRequest,
HttpResponse,
Scope,
get, http::header as http_header, http::header::HeaderMap, post, web, HttpRequest,
HttpResponse, Scope,
};
use serde::Deserialize;
use tokio::sync::Mutex;
@ -23,36 +17,23 @@ use mitra_models::{
users::queries::get_user_by_name,
};
use crate::errors::HttpError;
use crate::web_client::urls::{
get_post_page_url,
get_profile_page_url,
get_tag_page_url,
};
use super::actors::types::{get_local_actor, get_instance_actor};
use super::actors::types::{get_instance_actor, get_local_actor};
use super::builders::{
announce::build_announce,
create_note::{
build_emoji_tag,
build_note,
build_create_note,
},
};
use super::collections::{
OrderedCollection,
OrderedCollectionPage,
create_note::{build_create_note, build_emoji_tag, build_note},
};
use super::collections::{OrderedCollection, OrderedCollectionPage};
use super::constants::{AP_MEDIA_TYPE, AS_MEDIA_TYPE};
use super::identifiers::{
local_actor_followers,
local_actor_following,
local_actor_subscribers,
local_actor_outbox,
local_actor_followers, local_actor_following, local_actor_outbox, local_actor_subscribers,
};
use super::receiver::{receive_activity, HandlerError};
use crate::errors::HttpError;
use crate::web_client::urls::{get_post_page_url, get_profile_page_url, get_tag_page_url};
pub fn is_activitypub_request(headers: &HeaderMap) -> bool {
let maybe_user_agent = headers.get(http_header::USER_AGENT)
let maybe_user_agent = headers
.get(http_header::USER_AGENT)
.and_then(|value| value.to_str().ok());
if let Some(user_agent) = maybe_user_agent {
if user_agent.contains("THIS. IS. GNU social!!!!") {
@ -67,7 +48,9 @@ pub fn is_activitypub_request(headers: &HeaderMap) -> bool {
"application/json",
];
if let Some(content_type) = headers.get(http_header::ACCEPT) {
let content_type_str = content_type.to_str().ok()
let content_type_str = content_type
.to_str()
.ok()
// Take first content type if there are many
.and_then(|value| value.split(',').next())
.unwrap_or("");
@ -86,20 +69,15 @@ async fn actor_view(
let db_client = &**get_database_client(&db_pool).await?;
let user = get_user_by_name(db_client, &username).await?;
if !is_activitypub_request(request.headers()) {
let page_url = get_profile_page_url(
&config.instance_url(),
&user.profile.username,
);
let page_url = get_profile_page_url(&config.instance_url(), &user.profile.username);
let response = HttpResponse::Found()
.append_header((http_header::LOCATION, page_url))
.finish();
return Ok(response);
};
let actor = get_local_actor(&user, &config.instance_url())
.map_err(|_| HttpError::InternalError)?;
let response = HttpResponse::Ok()
.content_type(AP_MEDIA_TYPE)
.json(actor);
let actor =
get_local_actor(&user, &config.instance_url()).map_err(|_| HttpError::InternalError)?;
let response = HttpResponse::Ok().content_type(AP_MEDIA_TYPE).json(actor);
Ok(response)
}
@ -126,19 +104,16 @@ async fn inbox(
activity["id"].as_str().unwrap_or_default(),
);
let db_client = &mut **get_database_client(&db_pool).await?;
receive_activity(&config, db_client, &request, &activity).await
receive_activity(&config, db_client, &request, &activity)
.await
.map_err(|error| {
// TODO: preserve original error text in DatabaseError
if let HandlerError::DatabaseError(
DatabaseError::DatabaseClientError(ref pg_error)) = error
if let HandlerError::DatabaseError(DatabaseError::DatabaseClientError(ref pg_error)) =
error
{
log::error!("database client error: {}", pg_error);
};
log::warn!(
"failed to process activity ({}): {}",
error,
activity,
);
log::warn!("failed to process activity ({}): {}", error, activity,);
error
})?;
Ok(HttpResponse::Accepted().finish())
@ -160,11 +135,7 @@ async fn outbox(
let collection_id = local_actor_outbox(&instance.url(), &username);
let first_page_id = format!("{}?page=true", collection_id);
if query_params.page.is_none() {
let collection = OrderedCollection::new(
collection_id,
Some(first_page_id),
None,
);
let collection = OrderedCollection::new(collection_id, Some(first_page_id), None);
let response = HttpResponse::Ok()
.content_type(AP_MEDIA_TYPE)
.json(collection);
@ -182,27 +153,22 @@ async fn outbox(
true, // include reposts
None,
COLLECTION_PAGE_SIZE,
).await?;
)
.await?;
add_related_posts(db_client, posts.iter_mut().collect()).await?;
let activities = posts.iter().map(|post| {
if post.repost_of_id.is_some() {
let activity = build_announce(&instance.url(), post);
serde_json::to_value(activity)
.expect("activity should be serializable")
} else {
let activity = build_create_note(
&instance.hostname(),
&instance.url(),
post,
);
serde_json::to_value(activity)
.expect("activity should be serializable")
}
}).collect();
let collection_page = OrderedCollectionPage::new(
first_page_id,
activities,
);
let activities = posts
.iter()
.map(|post| {
if post.repost_of_id.is_some() {
let activity = build_announce(&instance.url(), post);
serde_json::to_value(activity).expect("activity should be serializable")
} else {
let activity = build_create_note(&instance.hostname(), &instance.url(), post);
serde_json::to_value(activity).expect("activity should be serializable")
}
})
.collect();
let collection_page = OrderedCollectionPage::new(first_page_id, activities);
let response = HttpResponse::Ok()
.content_type(AP_MEDIA_TYPE)
.json(collection_page);
@ -227,15 +193,8 @@ async fn followers_collection(
};
let db_client = &**get_database_client(&db_pool).await?;
let user = get_user_by_name(db_client, &username).await?;
let collection_id = local_actor_followers(
&config.instance_url(),
&username,
);
let collection = OrderedCollection::new(
collection_id,
None,
Some(user.profile.follower_count),
);
let collection_id = local_actor_followers(&config.instance_url(), &username);
let collection = OrderedCollection::new(collection_id, None, Some(user.profile.follower_count));
let response = HttpResponse::Ok()
.content_type(AP_MEDIA_TYPE)
.json(collection);
@ -255,15 +214,9 @@ async fn following_collection(
};
let db_client = &**get_database_client(&db_pool).await?;
let user = get_user_by_name(db_client, &username).await?;
let collection_id = local_actor_following(
&config.instance_url(),
&username,
);
let collection = OrderedCollection::new(
collection_id,
None,
Some(user.profile.following_count),
);
let collection_id = local_actor_following(&config.instance_url(), &username);
let collection =
OrderedCollection::new(collection_id, None, Some(user.profile.following_count));
let response = HttpResponse::Ok()
.content_type(AP_MEDIA_TYPE)
.json(collection);
@ -283,15 +236,9 @@ async fn subscribers_collection(
};
let db_client = &**get_database_client(&db_pool).await?;
let user = get_user_by_name(db_client, &username).await?;
let collection_id = local_actor_subscribers(
&config.instance_url(),
&username,
);
let collection = OrderedCollection::new(
collection_id,
None,
Some(user.profile.subscriber_count),
);
let collection_id = local_actor_subscribers(&config.instance_url(), &username);
let collection =
OrderedCollection::new(collection_id, None, Some(user.profile.subscriber_count));
let response = HttpResponse::Ok()
.content_type(AP_MEDIA_TYPE)
.json(collection);
@ -310,14 +257,9 @@ pub fn actor_scope() -> Scope {
}
#[get("")]
async fn instance_actor_view(
config: web::Data<Config>,
) -> Result<HttpResponse, HttpError> {
let actor = get_instance_actor(&config.instance())
.map_err(|_| HttpError::InternalError)?;
let response = HttpResponse::Ok()
.content_type(AP_MEDIA_TYPE)
.json(actor);
async fn instance_actor_view(config: web::Data<Config>) -> Result<HttpResponse, HttpError> {
let actor = get_instance_actor(&config.instance()).map_err(|_| HttpError::InternalError)?;
let response = HttpResponse::Ok().content_type(AP_MEDIA_TYPE).json(actor);
Ok(response)
}
@ -370,9 +312,7 @@ pub async fn object_view(
&config.instance().url(),
&post,
);
let response = HttpResponse::Ok()
.content_type(AP_MEDIA_TYPE)
.json(object);
let response = HttpResponse::Ok().content_type(AP_MEDIA_TYPE).json(object);
Ok(response)
}
@ -383,17 +323,9 @@ pub async fn emoji_view(
emoji_name: web::Path<String>,
) -> Result<HttpResponse, HttpError> {
let db_client = &**get_database_client(&db_pool).await?;
let emoji = get_local_emoji_by_name(
db_client,
&emoji_name,
).await?;
let object = build_emoji_tag(
&config.instance().url(),
&emoji,
);
let response = HttpResponse::Ok()
.content_type(AP_MEDIA_TYPE)
.json(object);
let emoji = get_local_emoji_by_name(db_client, &emoji_name).await?;
let object = build_emoji_tag(&config.instance().url(), &emoji);
let response = HttpResponse::Ok().content_type(AP_MEDIA_TYPE).json(object);
Ok(response)
}
@ -411,11 +343,11 @@ pub async fn tag_view(
#[cfg(test)]
mod tests {
use super::*;
use actix_web::http::{
header,
header::{HeaderMap, HeaderValue},
};
use super::*;
#[test]
fn test_is_activitypub_request_mastodon() {
@ -442,10 +374,7 @@ mod tests {
#[test]
fn test_is_activitypub_request_browser() {
let mut request_headers = HeaderMap::new();
request_headers.insert(
header::ACCEPT,
HeaderValue::from_static("text/html"),
);
request_headers.insert(header::ACCEPT, HeaderValue::from_static("text/html"));
let result = is_activitypub_request(&request_headers);
assert_eq!(result, false);
}

View file

@ -1,8 +1,5 @@
use mitra_config::Instance;
use mitra_models::{
posts::types::Post,
profiles::types::DbActorProfile,
};
use mitra_models::{posts::types::Post, profiles::types::DbActorProfile};
use mitra_utils::{
datetime::get_min_datetime,
html::{clean_html_all, escape_html},
@ -13,19 +10,16 @@ use crate::webfinger::types::ActorAddress;
const ENTRY_TITLE_MAX_LENGTH: usize = 75;
fn make_entry(
instance_url: &str,
post: &Post,
) -> String {
fn make_entry(instance_url: &str, post: &Post) -> String {
let object_id = local_object_id(instance_url, &post.id);
let content_escaped = escape_html(&post.content);
let content_cleaned = clean_html_all(&post.content);
// Use trimmed content for title
let mut title: String = content_cleaned.chars()
let mut title: String = content_cleaned
.chars()
.take(ENTRY_TITLE_MAX_LENGTH)
.collect();
if title.len() == ENTRY_TITLE_MAX_LENGTH &&
content_cleaned.len() != ENTRY_TITLE_MAX_LENGTH {
if title.len() == ENTRY_TITLE_MAX_LENGTH && content_cleaned.len() != ENTRY_TITLE_MAX_LENGTH {
title += "...";
};
format!(
@ -37,11 +31,11 @@ fn make_entry(
<content type="html">{content}</content>
<link rel="alternate" href="{url}"/>
</entry>"#,
url=object_id,
title=title,
updated_at=post.created_at.to_rfc3339(),
author=post.author.username,
content=content_escaped,
url = object_id,
title = title,
updated_at = post.created_at.to_rfc3339(),
author = post.author.username,
content = content_escaped,
)
}
@ -49,18 +43,10 @@ fn get_feed_url(instance_url: &str, username: &str) -> String {
format!("{}/feeds/users/{}", instance_url, username)
}
pub fn make_feed(
instance: &Instance,
profile: &DbActorProfile,
posts: Vec<Post>,
) -> String {
pub fn make_feed(instance: &Instance, profile: &DbActorProfile, posts: Vec<Post>) -> String {
let actor_id = local_actor_id(&instance.url(), &profile.username);
let actor_name = profile.display_name.as_ref()
.unwrap_or(&profile.username);
let actor_address = ActorAddress::from_profile(
&instance.hostname(),
profile,
);
let actor_name = profile.display_name.as_ref().unwrap_or(&profile.username);
let actor_address = ActorAddress::from_profile(&instance.hostname(), profile);
let feed_url = get_feed_url(&instance.url(), &profile.username);
let feed_title = format!("{} (@{})", actor_name, actor_address);
let mut entries = vec![];
@ -71,7 +57,7 @@ pub fn make_feed(
if post.created_at > feed_updated_at {
feed_updated_at = post.created_at;
};
};
}
format!(
r#"<?xml version="1.0" encoding="utf-8"?>
<feed xmlns="http://www.w3.org/2005/Atom">
@ -81,19 +67,19 @@ pub fn make_feed(
<updated>{updated_at}</updated>
{entries}
</feed>"#,
id=actor_id,
url=feed_url,
title=feed_title,
updated_at=feed_updated_at.to_rfc3339(),
entries=entries.join("\n"),
id = actor_id,
url = feed_url,
title = feed_title,
updated_at = feed_updated_at.to_rfc3339(),
entries = entries.join("\n"),
)
}
#[cfg(test)]
mod tests {
use super::*;
use chrono::{TimeZone, Utc};
use uuid::uuid;
use super::*;
#[test]
fn test_make_entry() {
@ -118,8 +104,10 @@ mod tests {
" <title>titletext text text</title>\n",
" <updated>2020-03-03T03:03:03+00:00</updated>\n",
" <author><name>username</name></author>\n",
r#" <content type="html">&lt;p&gt;title&lt;&#47;p&gt;&lt;p&gt;text&#32;text&#32;text&lt;&#47;p&gt;</content>"#, "\n",
r#" <link rel="alternate" href="https://example.org/objects/67e55044-10b1-426f-9247-bb680e5fe0c8"/>"#, "\n",
r#" <content type="html">&lt;p&gt;title&lt;&#47;p&gt;&lt;p&gt;text&#32;text&#32;text&lt;&#47;p&gt;</content>"#,
"\n",
r#" <link rel="alternate" href="https://example.org/objects/67e55044-10b1-426f-9247-bb680e5fe0c8"/>"#,
"\n",
"</entry>",
);
assert_eq!(entry, expected_entry);

Some files were not shown because too many files have changed in this diff Show more