Apply cargo fmt

This commit is contained in:
Rafael Caricio 2023-04-24 17:35:32 +02:00
parent b4ff7abbc1
commit 47529ff703
Signed by: rafaelcaricio
GPG key ID: 3C86DBCE8E93C947
165 changed files with 3779 additions and 5132 deletions

View file

@ -22,7 +22,7 @@ async fn main() {
log::info!("config loaded from {}", config.config_path); log::info!("config loaded from {}", config.config_path);
for warning in config_warnings { for warning in config_warnings {
log::warn!("{}", warning); log::warn!("{}", warning);
}; }
let db_config = config.database_url.parse().unwrap(); let db_config = config.database_url.parse().unwrap();
let db_client = &mut create_database_client(&db_config).await; let db_client = &mut create_database_client(&db_config).await;
@ -39,19 +39,37 @@ async fn main() {
SubCommand::DeleteProfile(cmd) => cmd.execute(&config, db_client).await.unwrap(), SubCommand::DeleteProfile(cmd) => cmd.execute(&config, db_client).await.unwrap(),
SubCommand::DeletePost(cmd) => cmd.execute(&config, db_client).await.unwrap(), SubCommand::DeletePost(cmd) => cmd.execute(&config, db_client).await.unwrap(),
SubCommand::DeleteEmoji(cmd) => cmd.execute(&config, db_client).await.unwrap(), SubCommand::DeleteEmoji(cmd) => cmd.execute(&config, db_client).await.unwrap(),
SubCommand::DeleteExtraneousPosts(cmd) => cmd.execute(&config, db_client).await.unwrap(), SubCommand::DeleteExtraneousPosts(cmd) => {
SubCommand::DeleteUnusedAttachments(cmd) => cmd.execute(&config, db_client).await.unwrap(), cmd.execute(&config, db_client).await.unwrap()
SubCommand::DeleteOrphanedFiles(cmd) => cmd.execute(&config, db_client).await.unwrap(), }
SubCommand::DeleteEmptyProfiles(cmd) => cmd.execute(&config, db_client).await.unwrap(), SubCommand::DeleteUnusedAttachments(cmd) => {
SubCommand::PruneRemoteEmojis(cmd) => cmd.execute(&config, db_client).await.unwrap(), cmd.execute(&config, db_client).await.unwrap()
SubCommand::ListUnreachableActors(cmd) => cmd.execute(&config, db_client).await.unwrap(), }
SubCommand::DeleteOrphanedFiles(cmd) => {
cmd.execute(&config, db_client).await.unwrap()
}
SubCommand::DeleteEmptyProfiles(cmd) => {
cmd.execute(&config, db_client).await.unwrap()
}
SubCommand::PruneRemoteEmojis(cmd) => {
cmd.execute(&config, db_client).await.unwrap()
}
SubCommand::ListUnreachableActors(cmd) => {
cmd.execute(&config, db_client).await.unwrap()
}
SubCommand::ImportEmoji(cmd) => cmd.execute(&config, db_client).await.unwrap(), SubCommand::ImportEmoji(cmd) => cmd.execute(&config, db_client).await.unwrap(),
SubCommand::UpdateCurrentBlock(cmd) => cmd.execute(&config, db_client).await.unwrap(), SubCommand::UpdateCurrentBlock(cmd) => {
SubCommand::ResetSubscriptions(cmd) => cmd.execute(&config, db_client).await.unwrap(), cmd.execute(&config, db_client).await.unwrap()
}
SubCommand::ResetSubscriptions(cmd) => {
cmd.execute(&config, db_client).await.unwrap()
}
SubCommand::CreateMoneroWallet(cmd) => cmd.execute(&config).await.unwrap(), SubCommand::CreateMoneroWallet(cmd) => cmd.execute(&config).await.unwrap(),
SubCommand::CheckExpiredInvoice(cmd) => cmd.execute(&config, db_client).await.unwrap(), SubCommand::CheckExpiredInvoice(cmd) => {
cmd.execute(&config, db_client).await.unwrap()
}
_ => unreachable!(), _ => unreachable!(),
}; };
}, }
}; };
} }

View file

@ -1,6 +1,6 @@
use std::path::PathBuf; use std::path::PathBuf;
use log::{Level as LogLevel}; use log::Level as LogLevel;
use rsa::RsaPrivateKey; use rsa::RsaPrivateKey;
use serde::Deserialize; use serde::Deserialize;
use url::Url; use url::Url;
@ -14,9 +14,13 @@ use super::registration::RegistrationConfig;
use super::retention::RetentionConfig; use super::retention::RetentionConfig;
use super::REEF_VERSION; use super::REEF_VERSION;
fn default_log_level() -> LogLevel { LogLevel::Info } fn default_log_level() -> LogLevel {
LogLevel::Info
}
fn default_login_message() -> String { "What?!".to_string() } fn default_login_message() -> String {
"What?!".to_string()
}
#[derive(Clone, Deserialize)] #[derive(Clone, Deserialize)]
pub struct Config { pub struct Config {
@ -95,9 +99,8 @@ impl Config {
onion_proxy_url: self.federation.onion_proxy_url.clone(), onion_proxy_url: self.federation.onion_proxy_url.clone(),
i2p_proxy_url: self.federation.i2p_proxy_url.clone(), i2p_proxy_url: self.federation.i2p_proxy_url.clone(),
// Private instance doesn't send activities and sign requests // Private instance doesn't send activities and sign requests
is_private: is_private: !self.federation.enabled
!self.federation.enabled || || matches!(self.environment, Environment::Development),
matches!(self.environment, Environment::Development),
fetcher_timeout: self.federation.fetcher_timeout, fetcher_timeout: self.federation.fetcher_timeout,
deliverer_timeout: self.federation.deliverer_timeout, deliverer_timeout: self.federation.deliverer_timeout,
} }
@ -139,8 +142,8 @@ impl Instance {
pub fn agent(&self) -> String { pub fn agent(&self) -> String {
format!( format!(
"Reef {version}; {instance_url}", "Reef {version}; {instance_url}",
version= REEF_VERSION, version = REEF_VERSION,
instance_url=self.url(), instance_url = self.url(),
) )
} }
} }
@ -164,8 +167,8 @@ impl Instance {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use mitra_utils::crypto_rsa::generate_weak_rsa_key;
use super::*; use super::*;
use mitra_utils::crypto_rsa::generate_weak_rsa_key;
#[test] #[test]
fn test_instance_url_https_dns() { fn test_instance_url_https_dns() {

View file

@ -10,9 +10,13 @@ pub enum Environment {
impl Default for Environment { impl Default for Environment {
#[cfg(feature = "production")] #[cfg(feature = "production")]
fn default() -> Self { Self::Production } fn default() -> Self {
Self::Production
}
#[cfg(not(feature = "production"))] #[cfg(not(feature = "production"))]
fn default() -> Self { Self::Development } fn default() -> Self {
Self::Development
}
} }
impl FromStr for Environment { impl FromStr for Environment {

View file

@ -1,9 +1,15 @@
use serde::Deserialize; use serde::Deserialize;
fn default_federation_enabled() -> bool { true } fn default_federation_enabled() -> bool {
true
}
const fn default_fetcher_timeout() -> u64 { 300 } const fn default_fetcher_timeout() -> u64 {
const fn default_deliverer_timeout() -> u64 { 30 } 300
}
const fn default_deliverer_timeout() -> u64 {
30
}
#[derive(Clone, Deserialize)] #[derive(Clone, Deserialize)]
pub struct FederationConfig { pub struct FederationConfig {

View file

@ -1,19 +1,17 @@
use regex::Regex;
use serde::{
Deserialize,
Deserializer,
de::{Error as DeserializerError},
};
use super::ConfigError; use super::ConfigError;
use regex::Regex;
use serde::{de::Error as DeserializerError, Deserialize, Deserializer};
const FILE_SIZE_RE: &str = r#"^(?i)(?P<size>\d+)(?P<unit>[kmg]?)b?$"#; const FILE_SIZE_RE: &str = r#"^(?i)(?P<size>\d+)(?P<unit>[kmg]?)b?$"#;
fn parse_file_size(value: &str) -> Result<usize, ConfigError> { fn parse_file_size(value: &str) -> Result<usize, ConfigError> {
let file_size_re = Regex::new(FILE_SIZE_RE) let file_size_re = Regex::new(FILE_SIZE_RE).expect("regexp should be valid");
.expect("regexp should be valid"); let caps = file_size_re
let caps = file_size_re.captures(value) .captures(value)
.ok_or(ConfigError("invalid file size"))?; .ok_or(ConfigError("invalid file size"))?;
let size: usize = caps["size"].to_string().parse() let size: usize = caps["size"]
.to_string()
.parse()
.map_err(|_| ConfigError("invalid file size"))?; .map_err(|_| ConfigError("invalid file size"))?;
let unit = caps["unit"].to_string().to_lowercase(); let unit = caps["unit"].to_string().to_lowercase();
let multiplier = match unit.as_str() { let multiplier = match unit.as_str() {
@ -26,31 +24,33 @@ fn parse_file_size(value: &str) -> Result<usize, ConfigError> {
Ok(size * multiplier) Ok(size * multiplier)
} }
fn deserialize_file_size<'de, D>( fn deserialize_file_size<'de, D>(deserializer: D) -> Result<usize, D::Error>
deserializer: D, where
) -> Result<usize, D::Error> D: Deserializer<'de>,
where D: Deserializer<'de>
{ {
let file_size_str = String::deserialize(deserializer)?; let file_size_str = String::deserialize(deserializer)?;
let file_size = parse_file_size(&file_size_str) let file_size = parse_file_size(&file_size_str).map_err(DeserializerError::custom)?;
.map_err(DeserializerError::custom)?;
Ok(file_size) Ok(file_size)
} }
const fn default_file_size_limit() -> usize { 20_000_000 } // 20 MB const fn default_file_size_limit() -> usize {
const fn default_emoji_size_limit() -> usize { 500_000 } // 500 kB 20_000_000
} // 20 MB
const fn default_emoji_size_limit() -> usize {
500_000
} // 500 kB
#[derive(Clone, Deserialize)] #[derive(Clone, Deserialize)]
pub struct MediaLimits { pub struct MediaLimits {
#[serde( #[serde(
default = "default_file_size_limit", default = "default_file_size_limit",
deserialize_with = "deserialize_file_size", deserialize_with = "deserialize_file_size"
)] )]
pub file_size_limit: usize, pub file_size_limit: usize,
#[serde( #[serde(
default = "default_emoji_size_limit", default = "default_emoji_size_limit",
deserialize_with = "deserialize_file_size", deserialize_with = "deserialize_file_size"
)] )]
pub emoji_size_limit: usize, pub emoji_size_limit: usize,
} }
@ -64,7 +64,9 @@ impl Default for MediaLimits {
} }
} }
const fn default_post_character_limit() -> usize { 2000 } const fn default_post_character_limit() -> usize {
2000
}
#[derive(Clone, Deserialize)] #[derive(Clone, Deserialize)]
pub struct PostLimits { pub struct PostLimits {

View file

@ -5,11 +5,7 @@ use std::str::FromStr;
use rsa::RsaPrivateKey; use rsa::RsaPrivateKey;
use mitra_utils::{ use mitra_utils::{
crypto_rsa::{ crypto_rsa::{deserialize_private_key, generate_rsa_key, serialize_private_key},
deserialize_private_key,
generate_rsa_key,
serialize_private_key,
},
files::{set_file_permissions, write_file}, files::{set_file_permissions, write_file},
}; };
@ -30,9 +26,9 @@ const DEFAULT_CONFIG_PATH: &str = "config.yaml";
fn parse_env() -> EnvConfig { fn parse_env() -> EnvConfig {
dotenv::from_filename(".env.local").ok(); dotenv::from_filename(".env.local").ok();
dotenv::dotenv().ok(); dotenv::dotenv().ok();
let config_path = std::env::var("CONFIG_PATH") let config_path = std::env::var("CONFIG_PATH").unwrap_or(DEFAULT_CONFIG_PATH.to_string());
.unwrap_or(DEFAULT_CONFIG_PATH.to_string()); let environment = std::env::var("ENVIRONMENT")
let environment = std::env::var("ENVIRONMENT").ok() .ok()
.map(|val| Environment::from_str(&val).expect("invalid environment type")); .map(|val| Environment::from_str(&val).expect("invalid environment type"));
EnvConfig { EnvConfig {
config_path, config_path,
@ -45,8 +41,7 @@ extern "C" {
} }
fn check_directory_owner(path: &Path) -> () { fn check_directory_owner(path: &Path) -> () {
let metadata = std::fs::metadata(path) let metadata = std::fs::metadata(path).expect("can't read file metadata");
.expect("can't read file metadata");
let owner_uid = metadata.uid(); let owner_uid = metadata.uid();
let current_uid = unsafe { geteuid() }; let current_uid = unsafe { geteuid() };
if owner_uid != current_uid { if owner_uid != current_uid {
@ -63,16 +58,15 @@ fn check_directory_owner(path: &Path) -> () {
fn read_instance_rsa_key(storage_dir: &Path) -> RsaPrivateKey { fn read_instance_rsa_key(storage_dir: &Path) -> RsaPrivateKey {
let private_key_path = storage_dir.join("instance_rsa_key"); let private_key_path = storage_dir.join("instance_rsa_key");
if private_key_path.exists() { if private_key_path.exists() {
let private_key_str = std::fs::read_to_string(&private_key_path) let private_key_str =
.expect("failed to read instance RSA key"); std::fs::read_to_string(&private_key_path).expect("failed to read instance RSA key");
let private_key = deserialize_private_key(&private_key_str) let private_key =
.expect("failed to read instance RSA key"); deserialize_private_key(&private_key_str).expect("failed to read instance RSA key");
private_key private_key
} else { } else {
let private_key = generate_rsa_key() let private_key = generate_rsa_key().expect("failed to generate RSA key");
.expect("failed to generate RSA key"); let private_key_str =
let private_key_str = serialize_private_key(&private_key) serialize_private_key(&private_key).expect("failed to serialize RSA key");
.expect("failed to serialize RSA key");
write_file(private_key_str.as_bytes(), &private_key_path) write_file(private_key_str.as_bytes(), &private_key_path)
.expect("failed to write instance RSA key"); .expect("failed to write instance RSA key");
set_file_permissions(&private_key_path, 0o600) set_file_permissions(&private_key_path, 0o600)
@ -83,10 +77,9 @@ fn read_instance_rsa_key(storage_dir: &Path) -> RsaPrivateKey {
pub fn parse_config() -> (Config, Vec<&'static str>) { pub fn parse_config() -> (Config, Vec<&'static str>) {
let env = parse_env(); let env = parse_env();
let config_yaml = std::fs::read_to_string(&env.config_path) let config_yaml =
.expect("failed to load config file"); std::fs::read_to_string(&env.config_path).expect("failed to load config file");
let mut config = serde_yaml::from_str::<Config>(&config_yaml) let mut config = serde_yaml::from_str::<Config>(&config_yaml).expect("invalid yaml data");
.expect("invalid yaml data");
let mut warnings = vec![]; let mut warnings = vec![];
// Set parameters from environment // Set parameters from environment
@ -109,7 +102,8 @@ pub fn parse_config() -> (Config, Vec<&'static str>) {
// Migrations // Migrations
if let Some(registrations_open) = config.registrations_open { if let Some(registrations_open) = config.registrations_open {
// Change type if 'registrations_open' parameter is used // Change type if 'registrations_open' parameter is used
warnings.push("'registrations_open' setting is deprecated, use 'registration.type' instead"); warnings
.push("'registrations_open' setting is deprecated, use 'registration.type' instead");
if registrations_open { if registrations_open {
config.registration.registration_type = RegistrationType::Open; config.registration.registration_type = RegistrationType::Open;
} else { } else {

View file

@ -1,8 +1,4 @@
use serde::{ use serde::{de::Error as DeserializerError, Deserialize, Deserializer};
Deserialize,
Deserializer,
de::Error as DeserializerError,
};
#[derive(Clone, PartialEq)] #[derive(Clone, PartialEq)]
pub enum RegistrationType { pub enum RegistrationType {
@ -11,12 +7,15 @@ pub enum RegistrationType {
} }
impl Default for RegistrationType { impl Default for RegistrationType {
fn default() -> Self { Self::Invite } fn default() -> Self {
Self::Invite
}
} }
impl<'de> Deserialize<'de> for RegistrationType { impl<'de> Deserialize<'de> for RegistrationType {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error> fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where D: Deserializer<'de> where
D: Deserializer<'de>,
{ {
let registration_type_str = String::deserialize(deserializer)?; let registration_type_str = String::deserialize(deserializer)?;
let registration_type = match registration_type_str.as_str() { let registration_type = match registration_type_str.as_str() {
@ -35,12 +34,15 @@ pub enum DefaultRole {
} }
impl Default for DefaultRole { impl Default for DefaultRole {
fn default() -> Self { Self::NormalUser } fn default() -> Self {
Self::NormalUser
}
} }
impl<'de> Deserialize<'de> for DefaultRole { impl<'de> Deserialize<'de> for DefaultRole {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error> fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where D: Deserializer<'de> where
D: Deserializer<'de>,
{ {
let role_str = String::deserialize(deserializer)?; let role_str = String::deserialize(deserializer)?;
let role = match role_str.as_str() { let role = match role_str.as_str() {

View file

@ -3,11 +3,7 @@ use uuid::Uuid;
use mitra_utils::id::generate_ulid; use mitra_utils::id::generate_ulid;
use crate::cleanup::{ use crate::cleanup::{find_orphaned_files, find_orphaned_ipfs_objects, DeletionQueue};
find_orphaned_files,
find_orphaned_ipfs_objects,
DeletionQueue,
};
use crate::database::{DatabaseClient, DatabaseError}; use crate::database::{DatabaseClient, DatabaseError};
use super::types::DbMediaAttachment; use super::types::DbMediaAttachment;
@ -20,10 +16,10 @@ pub async fn create_attachment(
media_type: Option<String>, media_type: Option<String>,
) -> Result<DbMediaAttachment, DatabaseError> { ) -> Result<DbMediaAttachment, DatabaseError> {
let attachment_id = generate_ulid(); let attachment_id = generate_ulid();
let file_size: i32 = file_size.try_into() let file_size: i32 = file_size.try_into().expect("value should be within bounds");
.expect("value should be within bounds"); let inserted_row = db_client
let inserted_row = db_client.query_one( .query_one(
" "
INSERT INTO media_attachment ( INSERT INTO media_attachment (
id, id,
owner_id, owner_id,
@ -34,14 +30,15 @@ pub async fn create_attachment(
VALUES ($1, $2, $3, $4, $5) VALUES ($1, $2, $3, $4, $5)
RETURNING media_attachment RETURNING media_attachment
", ",
&[ &[
&attachment_id, &attachment_id,
&owner_id, &owner_id,
&file_name, &file_name,
&file_size, &file_size,
&media_type, &media_type,
], ],
).await?; )
.await?;
let db_attachment: DbMediaAttachment = inserted_row.try_get("media_attachment")?; let db_attachment: DbMediaAttachment = inserted_row.try_get("media_attachment")?;
Ok(db_attachment) Ok(db_attachment)
} }
@ -51,15 +48,17 @@ pub async fn set_attachment_ipfs_cid(
attachment_id: &Uuid, attachment_id: &Uuid,
ipfs_cid: &str, ipfs_cid: &str,
) -> Result<DbMediaAttachment, DatabaseError> { ) -> Result<DbMediaAttachment, DatabaseError> {
let maybe_row = db_client.query_opt( let maybe_row = db_client
" .query_opt(
"
UPDATE media_attachment UPDATE media_attachment
SET ipfs_cid = $1 SET ipfs_cid = $1
WHERE id = $2 AND ipfs_cid IS NULL WHERE id = $2 AND ipfs_cid IS NULL
RETURNING media_attachment RETURNING media_attachment
", ",
&[&ipfs_cid, &attachment_id], &[&ipfs_cid, &attachment_id],
).await?; )
.await?;
let row = maybe_row.ok_or(DatabaseError::NotFound("attachment"))?; let row = maybe_row.ok_or(DatabaseError::NotFound("attachment"))?;
let db_attachment = row.try_get("media_attachment")?; let db_attachment = row.try_get("media_attachment")?;
Ok(db_attachment) Ok(db_attachment)
@ -69,14 +68,16 @@ pub async fn delete_unused_attachments(
db_client: &impl DatabaseClient, db_client: &impl DatabaseClient,
created_before: &DateTime<Utc>, created_before: &DateTime<Utc>,
) -> Result<DeletionQueue, DatabaseError> { ) -> Result<DeletionQueue, DatabaseError> {
let rows = db_client.query( let rows = db_client
" .query(
"
DELETE FROM media_attachment DELETE FROM media_attachment
WHERE post_id IS NULL AND created_at < $1 WHERE post_id IS NULL AND created_at < $1
RETURNING file_name, ipfs_cid RETURNING file_name, ipfs_cid
", ",
&[&created_before], &[&created_before],
).await?; )
.await?;
let mut files = vec![]; let mut files = vec![];
let mut ipfs_objects = vec![]; let mut ipfs_objects = vec![];
for row in rows { for row in rows {
@ -85,7 +86,7 @@ pub async fn delete_unused_attachments(
if let Some(ipfs_cid) = row.try_get("ipfs_cid")? { if let Some(ipfs_cid) = row.try_get("ipfs_cid")? {
ipfs_objects.push(ipfs_cid); ipfs_objects.push(ipfs_cid);
}; };
}; }
let orphaned_files = find_orphaned_files(db_client, files).await?; let orphaned_files = find_orphaned_files(db_client, files).await?;
let orphaned_ipfs_objects = find_orphaned_ipfs_objects(db_client, ipfs_objects).await?; let orphaned_ipfs_objects = find_orphaned_ipfs_objects(db_client, ipfs_objects).await?;
Ok(DeletionQueue { Ok(DeletionQueue {
@ -96,13 +97,10 @@ pub async fn delete_unused_attachments(
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use serial_test::serial;
use crate::database::test_utils::create_test_database;
use crate::profiles::{
queries::create_profile,
types::ProfileCreateData,
};
use super::*; use super::*;
use crate::database::test_utils::create_test_database;
use crate::profiles::{queries::create_profile, types::ProfileCreateData};
use serial_test::serial;
#[tokio::test] #[tokio::test]
#[serial] #[serial]
@ -122,7 +120,9 @@ mod tests {
file_name.to_string(), file_name.to_string(),
file_size, file_size,
Some(media_type.to_string()), Some(media_type.to_string()),
).await.unwrap(); )
.await
.unwrap();
assert_eq!(attachment.owner_id, profile.id); assert_eq!(attachment.owner_id, profile.id);
assert_eq!(attachment.file_name, file_name); assert_eq!(attachment.file_name, file_name);
assert_eq!(attachment.file_size.unwrap(), file_size as i32); assert_eq!(attachment.file_size.unwrap(), file_size as i32);

View file

@ -35,7 +35,7 @@ impl AttachmentType {
} else { } else {
Self::Unknown Self::Unknown
} }
}, }
None => Self::Unknown, None => Self::Unknown,
} }
} }

View file

@ -2,8 +2,8 @@ use chrono::{DateTime, Utc};
use serde_json::Value; use serde_json::Value;
use uuid::Uuid; use uuid::Uuid;
use crate::database::{DatabaseClient, DatabaseError};
use super::types::{DbBackgroundJob, JobStatus, JobType}; use super::types::{DbBackgroundJob, JobStatus, JobType};
use crate::database::{DatabaseClient, DatabaseError};
pub async fn enqueue_job( pub async fn enqueue_job(
db_client: &impl DatabaseClient, db_client: &impl DatabaseClient,
@ -12,8 +12,9 @@ pub async fn enqueue_job(
scheduled_for: &DateTime<Utc>, scheduled_for: &DateTime<Utc>,
) -> Result<(), DatabaseError> { ) -> Result<(), DatabaseError> {
let job_id = Uuid::new_v4(); let job_id = Uuid::new_v4();
db_client.execute( db_client
" .execute(
"
INSERT INTO background_job ( INSERT INTO background_job (
id, id,
job_type, job_type,
@ -22,8 +23,9 @@ pub async fn enqueue_job(
) )
VALUES ($1, $2, $3, $4) VALUES ($1, $2, $3, $4)
", ",
&[&job_id, &job_type, &job_data, &scheduled_for], &[&job_id, &job_type, &job_data, &scheduled_for],
).await?; )
.await?;
Ok(()) Ok(())
} }
@ -35,8 +37,9 @@ pub async fn get_job_batch(
) -> Result<Vec<DbBackgroundJob>, DatabaseError> { ) -> Result<Vec<DbBackgroundJob>, DatabaseError> {
// https://github.com/sfackler/rust-postgres/issues/60 // https://github.com/sfackler/rust-postgres/issues/60
let job_timeout_pg = format!("{}S", job_timeout); // interval let job_timeout_pg = format!("{}S", job_timeout); // interval
let rows = db_client.query( let rows = db_client
" .query(
"
UPDATE background_job UPDATE background_job
SET SET
job_status = $1, job_status = $1,
@ -62,15 +65,17 @@ pub async fn get_job_batch(
) )
RETURNING background_job RETURNING background_job
", ",
&[ &[
&JobStatus::Running, &JobStatus::Running,
&job_type, &job_type,
&JobStatus::Queued, &JobStatus::Queued,
&i64::from(batch_size), &i64::from(batch_size),
&job_timeout_pg, &job_timeout_pg,
], ],
).await?; )
let jobs = rows.iter() .await?;
let jobs = rows
.iter()
.map(|row| row.try_get("background_job")) .map(|row| row.try_get("background_job"))
.collect::<Result<_, _>>()?; .collect::<Result<_, _>>()?;
Ok(jobs) Ok(jobs)
@ -80,13 +85,15 @@ pub async fn delete_job_from_queue(
db_client: &impl DatabaseClient, db_client: &impl DatabaseClient,
job_id: &Uuid, job_id: &Uuid,
) -> Result<(), DatabaseError> { ) -> Result<(), DatabaseError> {
let deleted_count = db_client.execute( let deleted_count = db_client
" .execute(
"
DELETE FROM background_job DELETE FROM background_job
WHERE id = $1 WHERE id = $1
", ",
&[&job_id], &[&job_id],
).await?; )
.await?;
if deleted_count == 0 { if deleted_count == 0 {
return Err(DatabaseError::NotFound("background job")); return Err(DatabaseError::NotFound("background job"));
}; };
@ -95,10 +102,10 @@ pub async fn delete_job_from_queue(
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*;
use crate::database::test_utils::create_test_database;
use serde_json::json; use serde_json::json;
use serial_test::serial; use serial_test::serial;
use crate::database::test_utils::create_test_database;
use super::*;
#[tokio::test] #[tokio::test]
#[serial] #[serial]
@ -111,7 +118,9 @@ mod tests {
"failure_count": 0, "failure_count": 0,
}); });
let scheduled_for = Utc::now(); let scheduled_for = Utc::now();
enqueue_job(db_client, &job_type, &job_data, &scheduled_for).await.unwrap(); enqueue_job(db_client, &job_type, &job_data, &scheduled_for)
.await
.unwrap();
let batch_1 = get_job_batch(db_client, &job_type, 10, 3600).await.unwrap(); let batch_1 = get_job_batch(db_client, &job_type, 10, 3600).await.unwrap();
assert_eq!(batch_1.len(), 1); assert_eq!(batch_1.len(), 1);

View file

@ -1,6 +1,6 @@
use chrono::{DateTime, Utc}; use chrono::{DateTime, Utc};
use serde_json::Value;
use postgres_types::FromSql; use postgres_types::FromSql;
use serde_json::Value;
use uuid::Uuid; use uuid::Uuid;
use crate::database::{ use crate::database::{

View file

@ -9,8 +9,9 @@ pub async fn find_orphaned_files(
db_client: &impl DatabaseClient, db_client: &impl DatabaseClient,
files: Vec<String>, files: Vec<String>,
) -> Result<Vec<String>, DatabaseError> { ) -> Result<Vec<String>, DatabaseError> {
let rows = db_client.query( let rows = db_client
" .query(
"
SELECT DISTINCT fname SELECT DISTINCT fname
FROM unnest($1::text[]) AS fname FROM unnest($1::text[]) AS fname
WHERE WHERE
@ -27,9 +28,11 @@ pub async fn find_orphaned_files(
WHERE image ->> 'file_name' = fname WHERE image ->> 'file_name' = fname
) )
", ",
&[&files], &[&files],
).await?; )
let orphaned_files = rows.iter() .await?;
let orphaned_files = rows
.iter()
.map(|row| row.try_get("fname")) .map(|row| row.try_get("fname"))
.collect::<Result<_, _>>()?; .collect::<Result<_, _>>()?;
Ok(orphaned_files) Ok(orphaned_files)
@ -39,8 +42,9 @@ pub async fn find_orphaned_ipfs_objects(
db_client: &impl DatabaseClient, db_client: &impl DatabaseClient,
ipfs_objects: Vec<String>, ipfs_objects: Vec<String>,
) -> Result<Vec<String>, DatabaseError> { ) -> Result<Vec<String>, DatabaseError> {
let rows = db_client.query( let rows = db_client
" .query(
"
SELECT DISTINCT cid SELECT DISTINCT cid
FROM unnest($1::text[]) AS cid FROM unnest($1::text[]) AS cid
WHERE WHERE
@ -51,9 +55,11 @@ pub async fn find_orphaned_ipfs_objects(
SELECT 1 FROM post WHERE ipfs_cid = cid SELECT 1 FROM post WHERE ipfs_cid = cid
) )
", ",
&[&ipfs_objects], &[&ipfs_objects],
).await?; )
let orphaned_ipfs_objects = rows.iter() .await?;
let orphaned_ipfs_objects = rows
.iter()
.map(|row| row.try_get("cid")) .map(|row| row.try_get("cid"))
.collect::<Result<_, _>>()?; .collect::<Result<_, _>>()?;
Ok(orphaned_ipfs_objects) Ok(orphaned_ipfs_objects)

View file

@ -12,7 +12,7 @@ macro_rules! int_enum_from_sql {
postgres_types::accepts!(INT2); postgres_types::accepts!(INT2);
} }
} };
} }
macro_rules! int_enum_to_sql { macro_rules! int_enum_to_sql {
@ -31,7 +31,7 @@ macro_rules! int_enum_to_sql {
postgres_types::accepts!(INT2); postgres_types::accepts!(INT2);
postgres_types::to_sql_checked!(); postgres_types::to_sql_checked!();
} }
} };
} }
pub(crate) use {int_enum_from_sql, int_enum_to_sql}; pub(crate) use {int_enum_from_sql, int_enum_to_sql};

View file

@ -14,7 +14,7 @@ macro_rules! json_from_sql {
postgres_types::accepts!(JSON, JSONB); postgres_types::accepts!(JSON, JSONB);
} }
} };
} }
/// Implements ToSql trait for any serializable type /// Implements ToSql trait for any serializable type
@ -33,7 +33,7 @@ macro_rules! json_to_sql {
postgres_types::accepts!(JSON, JSONB); postgres_types::accepts!(JSON, JSONB);
postgres_types::to_sql_checked!(); postgres_types::to_sql_checked!();
} }
} };
} }
pub(crate) use {json_from_sql, json_to_sql}; pub(crate) use {json_from_sql, json_to_sql};

View file

@ -8,7 +8,8 @@ mod embedded {
pub async fn apply_migrations(db_client: &mut Client) { pub async fn apply_migrations(db_client: &mut Client) {
let migration_report = embedded::migrations::runner() let migration_report = embedded::migrations::runner()
.run_async(db_client) .run_async(db_client)
.await.unwrap(); .await
.unwrap();
for migration in migration_report.applied_migrations() { for migration in migration_report.applied_migrations() {
log::info!( log::info!(

View file

@ -1,4 +1,4 @@
use tokio_postgres::config::{Config as DatabaseConfig}; use tokio_postgres::config::Config as DatabaseConfig;
use tokio_postgres::error::{Error as PgError, SqlState}; use tokio_postgres::error::{Error as PgError, SqlState};
pub mod int_enum; pub mod int_enum;
@ -10,7 +10,7 @@ pub mod query_macro;
pub mod test_utils; pub mod test_utils;
pub type DbPool = deadpool_postgres::Pool; pub type DbPool = deadpool_postgres::Pool;
pub use tokio_postgres::{GenericClient as DatabaseClient}; pub use tokio_postgres::GenericClient as DatabaseClient;
#[derive(thiserror::Error, Debug)] #[derive(thiserror::Error, Debug)]
#[error("database type error")] #[error("database type error")]
@ -37,11 +37,8 @@ pub enum DatabaseError {
AlreadyExists(&'static str), // object type AlreadyExists(&'static str), // object type
} }
pub async fn create_database_client(db_config: &DatabaseConfig) pub async fn create_database_client(db_config: &DatabaseConfig) -> tokio_postgres::Client {
-> tokio_postgres::Client let (client, connection) = db_config.connect(tokio_postgres::NoTls).await.unwrap();
{
let (client, connection) = db_config.connect(tokio_postgres::NoTls)
.await.unwrap();
tokio::spawn(async move { tokio::spawn(async move {
if let Err(err) = connection.await { if let Err(err) = connection.await {
log::error!("connection error: {}", err); log::error!("connection error: {}", err);
@ -55,21 +52,22 @@ pub fn create_pool(database_url: &str, pool_size: usize) -> DbPool {
database_url.parse().expect("invalid database URL"), database_url.parse().expect("invalid database URL"),
tokio_postgres::NoTls, tokio_postgres::NoTls,
); );
DbPool::builder(manager).max_size(pool_size).build().unwrap() DbPool::builder(manager)
.max_size(pool_size)
.build()
.unwrap()
} }
pub async fn get_database_client(db_pool: &DbPool) pub async fn get_database_client(
-> Result<deadpool_postgres::Client, DatabaseError> db_pool: &DbPool,
{ ) -> Result<deadpool_postgres::Client, DatabaseError> {
// Returns wrapped client // Returns wrapped client
// https://github.com/bikeshedder/deadpool/issues/56 // https://github.com/bikeshedder/deadpool/issues/56
let client = db_pool.get().await?; let client = db_pool.get().await?;
Ok(client) Ok(client)
} }
pub fn catch_unique_violation( pub fn catch_unique_violation(object_type: &'static str) -> impl Fn(PgError) -> DatabaseError {
object_type: &'static str,
) -> impl Fn(PgError) -> DatabaseError {
move |err| { move |err| {
if let Some(code) = err.code() { if let Some(code) = err.code() {
if code == &SqlState::UNIQUE_VIOLATION { if code == &SqlState::UNIQUE_VIOLATION {

View file

@ -1,31 +1,30 @@
use tokio_postgres::Client;
use tokio_postgres::config::Config;
use super::create_database_client; use super::create_database_client;
use super::migrate::apply_migrations; use super::migrate::apply_migrations;
use tokio_postgres::config::Config;
use tokio_postgres::Client;
const DEFAULT_CONNECTION_URL: &str = "postgres://mitra:mitra@127.0.0.1:55432/mitra-test"; const DEFAULT_CONNECTION_URL: &str = "postgres://mitra:mitra@127.0.0.1:55432/mitra-test";
pub async fn create_test_database() -> Client { pub async fn create_test_database() -> Client {
let connection_url = std::env::var("TEST_DATABASE_URL") let connection_url =
.unwrap_or(DEFAULT_CONNECTION_URL.to_string()); std::env::var("TEST_DATABASE_URL").unwrap_or(DEFAULT_CONNECTION_URL.to_string());
let mut db_config: Config = connection_url.parse() let mut db_config: Config = connection_url
.parse()
.expect("invalid database connection URL"); .expect("invalid database connection URL");
let db_name = db_config.get_dbname() let db_name = db_config
.get_dbname()
.expect("database name not specified") .expect("database name not specified")
.to_string(); .to_string();
// Create connection without database name // Create connection without database name
db_config.dbname(""); db_config.dbname("");
let db_client = create_database_client(&db_config).await; let db_client = create_database_client(&db_config).await;
let drop_db_statement = format!( let drop_db_statement = format!("DROP DATABASE IF EXISTS {db_name:?}", db_name = db_name,);
"DROP DATABASE IF EXISTS {db_name:?}",
db_name=db_name,
);
db_client.execute(&drop_db_statement, &[]).await.unwrap(); db_client.execute(&drop_db_statement, &[]).await.unwrap();
let create_db_statement = format!( let create_db_statement = format!(
"CREATE DATABASE {db_name:?} WITH OWNER={owner:?};", "CREATE DATABASE {db_name:?} WITH OWNER={owner:?};",
db_name=db_name, db_name = db_name,
owner=db_config.get_user().unwrap(), owner = db_config.get_user().unwrap(),
); );
db_client.execute(&create_db_statement, &[]).await.unwrap(); db_client.execute(&create_db_statement, &[]).await.unwrap();

View file

@ -1,10 +1,7 @@
use crate::database::{DatabaseClient, DatabaseError}; use crate::database::{DatabaseClient, DatabaseError};
use super::queries::{get_emoji_by_name_and_hostname, get_local_emoji_by_name};
use super::types::DbEmoji; use super::types::DbEmoji;
use super::queries::{
get_local_emoji_by_name,
get_emoji_by_name_and_hostname,
};
pub async fn get_emoji_by_name( pub async fn get_emoji_by_name(
db_client: &impl DatabaseClient, db_client: &impl DatabaseClient,

View file

@ -4,11 +4,7 @@ use uuid::Uuid;
use mitra_utils::id::generate_ulid; use mitra_utils::id::generate_ulid;
use crate::cleanup::{find_orphaned_files, DeletionQueue}; use crate::cleanup::{find_orphaned_files, DeletionQueue};
use crate::database::{ use crate::database::{catch_unique_violation, DatabaseClient, DatabaseError};
catch_unique_violation,
DatabaseClient,
DatabaseError,
};
use crate::instances::queries::create_instance; use crate::instances::queries::create_instance;
use crate::profiles::queries::update_emoji_caches; use crate::profiles::queries::update_emoji_caches;
@ -26,8 +22,9 @@ pub async fn create_emoji(
if let Some(hostname) = hostname { if let Some(hostname) = hostname {
create_instance(db_client, hostname).await?; create_instance(db_client, hostname).await?;
}; };
let row = db_client.query_one( let row = db_client
" .query_one(
"
INSERT INTO emoji ( INSERT INTO emoji (
id, id,
emoji_name, emoji_name,
@ -39,15 +36,17 @@ pub async fn create_emoji(
VALUES ($1, $2, $3, $4, $5, $6) VALUES ($1, $2, $3, $4, $5, $6)
RETURNING emoji RETURNING emoji
", ",
&[ &[
&emoji_id, &emoji_id,
&emoji_name, &emoji_name,
&hostname, &hostname,
&image, &image,
&object_id, &object_id,
&updated_at, &updated_at,
], ],
).await.map_err(catch_unique_violation("emoji"))?; )
.await
.map_err(catch_unique_violation("emoji"))?;
let emoji = row.try_get("emoji")?; let emoji = row.try_get("emoji")?;
Ok(emoji) Ok(emoji)
} }
@ -58,8 +57,9 @@ pub async fn update_emoji(
image: EmojiImage, image: EmojiImage,
updated_at: &DateTime<Utc>, updated_at: &DateTime<Utc>,
) -> Result<DbEmoji, DatabaseError> { ) -> Result<DbEmoji, DatabaseError> {
let row = db_client.query_one( let row = db_client
" .query_one(
"
UPDATE emoji UPDATE emoji
SET SET
image = $1, image = $1,
@ -67,12 +67,9 @@ pub async fn update_emoji(
WHERE id = $3 WHERE id = $3
RETURNING emoji RETURNING emoji
", ",
&[ &[&image, &updated_at, &emoji_id],
&image, )
&updated_at, .await?;
&emoji_id,
],
).await?;
let emoji: DbEmoji = row.try_get("emoji")?; let emoji: DbEmoji = row.try_get("emoji")?;
update_emoji_caches(db_client, &emoji.id).await?; update_emoji_caches(db_client, &emoji.id).await?;
Ok(emoji) Ok(emoji)
@ -82,14 +79,16 @@ pub async fn get_local_emoji_by_name(
db_client: &impl DatabaseClient, db_client: &impl DatabaseClient,
emoji_name: &str, emoji_name: &str,
) -> Result<DbEmoji, DatabaseError> { ) -> Result<DbEmoji, DatabaseError> {
let maybe_row = db_client.query_opt( let maybe_row = db_client
" .query_opt(
"
SELECT emoji SELECT emoji
FROM emoji FROM emoji
WHERE hostname IS NULL AND emoji_name = $1 WHERE hostname IS NULL AND emoji_name = $1
", ",
&[&emoji_name], &[&emoji_name],
).await?; )
.await?;
let row = maybe_row.ok_or(DatabaseError::NotFound("emoji"))?; let row = maybe_row.ok_or(DatabaseError::NotFound("emoji"))?;
let emoji = row.try_get("emoji")?; let emoji = row.try_get("emoji")?;
Ok(emoji) Ok(emoji)
@ -99,15 +98,18 @@ pub async fn get_local_emojis_by_names(
db_client: &impl DatabaseClient, db_client: &impl DatabaseClient,
names: &[String], names: &[String],
) -> Result<Vec<DbEmoji>, DatabaseError> { ) -> Result<Vec<DbEmoji>, DatabaseError> {
let rows = db_client.query( let rows = db_client
" .query(
"
SELECT emoji SELECT emoji
FROM emoji FROM emoji
WHERE hostname IS NULL AND emoji_name = ANY($1) WHERE hostname IS NULL AND emoji_name = ANY($1)
", ",
&[&names], &[&names],
).await?; )
let emojis = rows.iter() .await?;
let emojis = rows
.iter()
.map(|row| row.try_get("emoji")) .map(|row| row.try_get("emoji"))
.collect::<Result<_, _>>()?; .collect::<Result<_, _>>()?;
Ok(emojis) Ok(emojis)
@ -116,15 +118,18 @@ pub async fn get_local_emojis_by_names(
pub async fn get_local_emojis( pub async fn get_local_emojis(
db_client: &impl DatabaseClient, db_client: &impl DatabaseClient,
) -> Result<Vec<DbEmoji>, DatabaseError> { ) -> Result<Vec<DbEmoji>, DatabaseError> {
let rows = db_client.query( let rows = db_client
" .query(
"
SELECT emoji SELECT emoji
FROM emoji FROM emoji
WHERE hostname IS NULL WHERE hostname IS NULL
", ",
&[], &[],
).await?; )
let emojis = rows.iter() .await?;
let emojis = rows
.iter()
.map(|row| row.try_get("emoji")) .map(|row| row.try_get("emoji"))
.collect::<Result<_, _>>()?; .collect::<Result<_, _>>()?;
Ok(emojis) Ok(emojis)
@ -135,13 +140,15 @@ pub async fn get_emoji_by_name_and_hostname(
emoji_name: &str, emoji_name: &str,
hostname: &str, hostname: &str,
) -> Result<DbEmoji, DatabaseError> { ) -> Result<DbEmoji, DatabaseError> {
let maybe_row = db_client.query_opt( let maybe_row = db_client
" .query_opt(
"
SELECT emoji SELECT emoji
FROM emoji WHERE emoji_name = $1 AND hostname = $2 FROM emoji WHERE emoji_name = $1 AND hostname = $2
", ",
&[&emoji_name, &hostname], &[&emoji_name, &hostname],
).await?; )
.await?;
let row = maybe_row.ok_or(DatabaseError::NotFound("emoji"))?; let row = maybe_row.ok_or(DatabaseError::NotFound("emoji"))?;
let emoji = row.try_get("emoji")?; let emoji = row.try_get("emoji")?;
Ok(emoji) Ok(emoji)
@ -151,13 +158,15 @@ pub async fn get_emoji_by_remote_object_id(
db_client: &impl DatabaseClient, db_client: &impl DatabaseClient,
object_id: &str, object_id: &str,
) -> Result<DbEmoji, DatabaseError> { ) -> Result<DbEmoji, DatabaseError> {
let maybe_row = db_client.query_opt( let maybe_row = db_client
" .query_opt(
"
SELECT emoji SELECT emoji
FROM emoji WHERE object_id = $1 FROM emoji WHERE object_id = $1
", ",
&[&object_id], &[&object_id],
).await?; )
.await?;
let row = maybe_row.ok_or(DatabaseError::NotFound("emoji"))?; let row = maybe_row.ok_or(DatabaseError::NotFound("emoji"))?;
let emoji = row.try_get("emoji")?; let emoji = row.try_get("emoji")?;
Ok(emoji) Ok(emoji)
@ -167,20 +176,19 @@ pub async fn delete_emoji(
db_client: &impl DatabaseClient, db_client: &impl DatabaseClient,
emoji_id: &Uuid, emoji_id: &Uuid,
) -> Result<DeletionQueue, DatabaseError> { ) -> Result<DeletionQueue, DatabaseError> {
let maybe_row = db_client.query_opt( let maybe_row = db_client
" .query_opt(
"
DELETE FROM emoji WHERE id = $1 DELETE FROM emoji WHERE id = $1
RETURNING emoji RETURNING emoji
", ",
&[&emoji_id], &[&emoji_id],
).await?; )
.await?;
let row = maybe_row.ok_or(DatabaseError::NotFound("emoji"))?; let row = maybe_row.ok_or(DatabaseError::NotFound("emoji"))?;
let emoji: DbEmoji = row.try_get("emoji")?; let emoji: DbEmoji = row.try_get("emoji")?;
update_emoji_caches(db_client, &emoji.id).await?; update_emoji_caches(db_client, &emoji.id).await?;
let orphaned_files = find_orphaned_files( let orphaned_files = find_orphaned_files(db_client, vec![emoji.image.file_name]).await?;
db_client,
vec![emoji.image.file_name],
).await?;
Ok(DeletionQueue { Ok(DeletionQueue {
files: orphaned_files, files: orphaned_files,
ipfs_objects: vec![], ipfs_objects: vec![],
@ -190,8 +198,9 @@ pub async fn delete_emoji(
pub async fn find_unused_remote_emojis( pub async fn find_unused_remote_emojis(
db_client: &impl DatabaseClient, db_client: &impl DatabaseClient,
) -> Result<Vec<Uuid>, DatabaseError> { ) -> Result<Vec<Uuid>, DatabaseError> {
let rows = db_client.query( let rows = db_client
" .query(
"
SELECT emoji.id SELECT emoji.id
FROM emoji FROM emoji
WHERE WHERE
@ -207,9 +216,11 @@ pub async fn find_unused_remote_emojis(
WHERE profile_emoji.emoji_id = emoji.id WHERE profile_emoji.emoji_id = emoji.id
) )
", ",
&[], &[],
).await?; )
let ids: Vec<Uuid> = rows.iter() .await?;
let ids: Vec<Uuid> = rows
.iter()
.map(|row| row.try_get("id")) .map(|row| row.try_get("id"))
.collect::<Result<_, _>>()?; .collect::<Result<_, _>>()?;
Ok(ids) Ok(ids)
@ -217,9 +228,9 @@ pub async fn find_unused_remote_emojis(
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use serial_test::serial;
use crate::database::test_utils::create_test_database;
use super::*; use super::*;
use crate::database::test_utils::create_test_database;
use serial_test::serial;
#[tokio::test] #[tokio::test]
#[serial] #[serial]
@ -241,11 +252,12 @@ mod tests {
image, image,
Some(object_id), Some(object_id),
&updated_at, &updated_at,
).await.unwrap(); )
let emoji = get_emoji_by_remote_object_id( .await
db_client, .unwrap();
object_id, let emoji = get_emoji_by_remote_object_id(db_client, object_id)
).await.unwrap(); .await
.unwrap();
assert_eq!(emoji.id, emoji_id); assert_eq!(emoji.id, emoji_id);
assert_eq!(emoji.emoji_name, emoji_name); assert_eq!(emoji.emoji_name, emoji_name);
assert_eq!(emoji.hostname, Some(hostname.to_string())); assert_eq!(emoji.hostname, Some(hostname.to_string()));
@ -256,20 +268,12 @@ mod tests {
async fn test_update_emoji() { async fn test_update_emoji() {
let db_client = &create_test_database().await; let db_client = &create_test_database().await;
let image = EmojiImage::default(); let image = EmojiImage::default();
let emoji = create_emoji( let emoji = create_emoji(db_client, "test", None, image.clone(), None, &Utc::now())
db_client, .await
"test", .unwrap();
None, let updated_emoji = update_emoji(db_client, &emoji.id, image, &Utc::now())
image.clone(), .await
None, .unwrap();
&Utc::now(),
).await.unwrap();
let updated_emoji = update_emoji(
db_client,
&emoji.id,
image,
&Utc::now(),
).await.unwrap();
assert_ne!(updated_emoji.updated_at, emoji.updated_at); assert_ne!(updated_emoji.updated_at, emoji.updated_at);
} }
@ -278,14 +282,9 @@ mod tests {
async fn test_delete_emoji() { async fn test_delete_emoji() {
let db_client = &create_test_database().await; let db_client = &create_test_database().await;
let image = EmojiImage::default(); let image = EmojiImage::default();
let emoji = create_emoji( let emoji = create_emoji(db_client, "test", None, image, None, &Utc::now())
db_client, .await
"test", .unwrap();
None,
image,
None,
&Utc::now(),
).await.unwrap();
let deletion_queue = delete_emoji(db_client, &emoji.id).await.unwrap(); let deletion_queue = delete_emoji(db_client, &emoji.id).await.unwrap();
assert_eq!(deletion_queue.files.len(), 1); assert_eq!(deletion_queue.files.len(), 1);
assert_eq!(deletion_queue.ipfs_objects.len(), 0); assert_eq!(deletion_queue.ipfs_objects.len(), 0);

View file

@ -6,7 +6,9 @@ use uuid::Uuid;
use crate::database::json_macro::{json_from_sql, json_to_sql}; use crate::database::json_macro::{json_from_sql, json_to_sql};
// Migration // Migration
fn default_emoji_file_size() -> usize { 250 * 1000 } fn default_emoji_file_size() -> usize {
250 * 1000
}
#[derive(Clone, Debug, Deserialize, Serialize)] #[derive(Clone, Debug, Deserialize, Serialize)]
#[cfg_attr(feature = "test-utils", derive(Default))] #[cfg_attr(feature = "test-utils", derive(Default))]

View file

@ -4,38 +4,38 @@ pub async fn create_instance(
db_client: &impl DatabaseClient, db_client: &impl DatabaseClient,
hostname: &str, hostname: &str,
) -> Result<(), DatabaseError> { ) -> Result<(), DatabaseError> {
db_client.execute( db_client
" .execute(
"
INSERT INTO instance VALUES ($1) INSERT INTO instance VALUES ($1)
ON CONFLICT DO NOTHING ON CONFLICT DO NOTHING
", ",
&[&hostname], &[&hostname],
).await?; )
.await?;
Ok(()) Ok(())
} }
pub async fn get_peers( pub async fn get_peers(db_client: &impl DatabaseClient) -> Result<Vec<String>, DatabaseError> {
db_client: &impl DatabaseClient, let rows = db_client
) -> Result<Vec<String>, DatabaseError> { .query(
let rows = db_client.query( "
"
SELECT instance.hostname FROM instance SELECT instance.hostname FROM instance
", ",
&[], &[],
).await?; )
let peers = rows.iter() .await?;
let peers = rows
.iter()
.map(|row| row.try_get("hostname")) .map(|row| row.try_get("hostname"))
.collect::<Result<_, _>>()?; .collect::<Result<_, _>>()?;
Ok(peers) Ok(peers)
} }
pub async fn get_peer_count( pub async fn get_peer_count(db_client: &impl DatabaseClient) -> Result<i64, DatabaseError> {
db_client: &impl DatabaseClient, let row = db_client
) -> Result<i64, DatabaseError> { .query_one("SELECT count(instance) FROM instance", &[])
let row = db_client.query_one( .await?;
"SELECT count(instance) FROM instance",
&[],
).await?;
let count = row.try_get("count")?; let count = row.try_get("count")?;
Ok(count) Ok(count)
} }

View file

@ -1,15 +1,8 @@
use uuid::Uuid; use uuid::Uuid;
use mitra_utils::{ use mitra_utils::{caip2::ChainId, id::generate_ulid};
caip2::ChainId,
id::generate_ulid,
};
use crate::database::{ use crate::database::{catch_unique_violation, DatabaseClient, DatabaseError};
catch_unique_violation,
DatabaseClient,
DatabaseError,
};
use super::types::{DbChainId, DbInvoice, InvoiceStatus}; use super::types::{DbChainId, DbInvoice, InvoiceStatus};
@ -22,8 +15,9 @@ pub async fn create_invoice(
amount: i64, amount: i64,
) -> Result<DbInvoice, DatabaseError> { ) -> Result<DbInvoice, DatabaseError> {
let invoice_id = generate_ulid(); let invoice_id = generate_ulid();
let row = db_client.query_one( let row = db_client
" .query_one(
"
INSERT INTO invoice ( INSERT INTO invoice (
id, id,
sender_id, sender_id,
@ -35,15 +29,17 @@ pub async fn create_invoice(
VALUES ($1, $2, $3, $4, $5, $6) VALUES ($1, $2, $3, $4, $5, $6)
RETURNING invoice RETURNING invoice
", ",
&[ &[
&invoice_id, &invoice_id,
&sender_id, &sender_id,
&recipient_id, &recipient_id,
&DbChainId::new(chain_id), &DbChainId::new(chain_id),
&payment_address, &payment_address,
&amount, &amount,
], ],
).await.map_err(catch_unique_violation("invoice"))?; )
.await
.map_err(catch_unique_violation("invoice"))?;
let invoice = row.try_get("invoice")?; let invoice = row.try_get("invoice")?;
Ok(invoice) Ok(invoice)
} }
@ -52,13 +48,15 @@ pub async fn get_invoice_by_id(
db_client: &impl DatabaseClient, db_client: &impl DatabaseClient,
invoice_id: &Uuid, invoice_id: &Uuid,
) -> Result<DbInvoice, DatabaseError> { ) -> Result<DbInvoice, DatabaseError> {
let maybe_row = db_client.query_opt( let maybe_row = db_client
" .query_opt(
"
SELECT invoice SELECT invoice
FROM invoice WHERE id = $1 FROM invoice WHERE id = $1
", ",
&[&invoice_id], &[&invoice_id],
).await?; )
.await?;
let row = maybe_row.ok_or(DatabaseError::NotFound("invoice"))?; let row = maybe_row.ok_or(DatabaseError::NotFound("invoice"))?;
let invoice = row.try_get("invoice")?; let invoice = row.try_get("invoice")?;
Ok(invoice) Ok(invoice)
@ -69,13 +67,15 @@ pub async fn get_invoice_by_address(
chain_id: &ChainId, chain_id: &ChainId,
payment_address: &str, payment_address: &str,
) -> Result<DbInvoice, DatabaseError> { ) -> Result<DbInvoice, DatabaseError> {
let maybe_row = db_client.query_opt( let maybe_row = db_client
" .query_opt(
"
SELECT invoice SELECT invoice
FROM invoice WHERE chain_id = $1 AND payment_address = $2 FROM invoice WHERE chain_id = $1 AND payment_address = $2
", ",
&[&DbChainId::new(chain_id), &payment_address], &[&DbChainId::new(chain_id), &payment_address],
).await?; )
.await?;
let row = maybe_row.ok_or(DatabaseError::NotFound("invoice"))?; let row = maybe_row.ok_or(DatabaseError::NotFound("invoice"))?;
let invoice = row.try_get("invoice")?; let invoice = row.try_get("invoice")?;
Ok(invoice) Ok(invoice)
@ -86,14 +86,17 @@ pub async fn get_invoices_by_status(
chain_id: &ChainId, chain_id: &ChainId,
status: InvoiceStatus, status: InvoiceStatus,
) -> Result<Vec<DbInvoice>, DatabaseError> { ) -> Result<Vec<DbInvoice>, DatabaseError> {
let rows = db_client.query( let rows = db_client
" .query(
"
SELECT invoice SELECT invoice
FROM invoice WHERE chain_id = $1 AND invoice_status = $2 FROM invoice WHERE chain_id = $1 AND invoice_status = $2
", ",
&[&DbChainId::new(chain_id), &status], &[&DbChainId::new(chain_id), &status],
).await?; )
let invoices = rows.iter() .await?;
let invoices = rows
.iter()
.map(|row| row.try_get("invoice")) .map(|row| row.try_get("invoice"))
.collect::<Result<_, _>>()?; .collect::<Result<_, _>>()?;
Ok(invoices) Ok(invoices)
@ -104,13 +107,15 @@ pub async fn set_invoice_status(
invoice_id: &Uuid, invoice_id: &Uuid,
status: InvoiceStatus, status: InvoiceStatus,
) -> Result<(), DatabaseError> { ) -> Result<(), DatabaseError> {
let updated_count = db_client.execute( let updated_count = db_client
" .execute(
"
UPDATE invoice SET invoice_status = $2 UPDATE invoice SET invoice_status = $2
WHERE id = $1 WHERE id = $1
", ",
&[&invoice_id, &status], &[&invoice_id, &status],
).await?; )
.await?;
if updated_count == 0 { if updated_count == 0 {
return Err(DatabaseError::NotFound("invoice")); return Err(DatabaseError::NotFound("invoice"));
}; };
@ -119,17 +124,11 @@ pub async fn set_invoice_status(
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use serial_test::serial;
use crate::database::test_utils::create_test_database;
use crate::profiles::{
queries::create_profile,
types::ProfileCreateData,
};
use crate::users::{
queries::create_user,
types::UserCreateData,
};
use super::*; use super::*;
use crate::database::test_utils::create_test_database;
use crate::profiles::{queries::create_profile, types::ProfileCreateData};
use crate::users::{queries::create_user, types::UserCreateData};
use serial_test::serial;
#[tokio::test] #[tokio::test]
#[serial] #[serial]
@ -159,7 +158,9 @@ mod tests {
&chain_id, &chain_id,
payment_address, payment_address,
amount, amount,
).await.unwrap(); )
.await
.unwrap();
assert_eq!(invoice.sender_id, sender.id); assert_eq!(invoice.sender_id, sender.id);
assert_eq!(invoice.recipient_id, recipient.id); assert_eq!(invoice.recipient_id, recipient.id);
assert_eq!(invoice.chain_id.into_inner(), chain_id); assert_eq!(invoice.chain_id.into_inner(), chain_id);

View file

@ -1,14 +1,6 @@
use chrono::{DateTime, Utc}; use chrono::{DateTime, Utc};
use postgres_protocol::types::{text_from_sql, text_to_sql}; use postgres_protocol::types::{text_from_sql, text_to_sql};
use postgres_types::{ use postgres_types::{accepts, private::BytesMut, to_sql_checked, FromSql, IsNull, ToSql, Type};
accepts,
private::BytesMut,
to_sql_checked,
FromSql,
IsNull,
ToSql,
Type,
};
use uuid::Uuid; use uuid::Uuid;
use mitra_utils::caip2::ChainId; use mitra_utils::caip2::ChainId;
@ -44,10 +36,7 @@ impl PartialEq<ChainId> for DbChainId {
} }
impl<'a> FromSql<'a> for DbChainId { impl<'a> FromSql<'a> for DbChainId {
fn from_sql( fn from_sql(_: &Type, raw: &'a [u8]) -> Result<Self, Box<dyn std::error::Error + Sync + Send>> {
_: &Type,
raw: &'a [u8],
) -> Result<Self, Box<dyn std::error::Error + Sync + Send>> {
let value_str = text_from_sql(raw)?; let value_str = text_from_sql(raw)?;
let value: ChainId = value_str.parse()?; let value: ChainId = value_str.parse()?;
Ok(Self(value)) Ok(Self(value))

View file

@ -1,7 +1,7 @@
pub mod attachments; pub mod attachments;
pub mod background_jobs; pub mod background_jobs;
pub mod database;
pub mod cleanup; pub mod cleanup;
pub mod database;
pub mod emojis; pub mod emojis;
pub mod instances; pub mod instances;
pub mod invoices; pub mod invoices;

View file

@ -1,7 +1,7 @@
use uuid::Uuid; use uuid::Uuid;
use crate::database::{DatabaseClient, DatabaseError};
use super::types::{DbTimelineMarker, Timeline}; use super::types::{DbTimelineMarker, Timeline};
use crate::database::{DatabaseClient, DatabaseError};
pub async fn create_or_update_marker( pub async fn create_or_update_marker(
db_client: &impl DatabaseClient, db_client: &impl DatabaseClient,
@ -9,16 +9,18 @@ pub async fn create_or_update_marker(
timeline: Timeline, timeline: Timeline,
last_read_id: String, last_read_id: String,
) -> Result<DbTimelineMarker, DatabaseError> { ) -> Result<DbTimelineMarker, DatabaseError> {
let row = db_client.query_one( let row = db_client
" .query_one(
"
INSERT INTO timeline_marker (user_id, timeline, last_read_id) INSERT INTO timeline_marker (user_id, timeline, last_read_id)
VALUES ($1, $2, $3) VALUES ($1, $2, $3)
ON CONFLICT (user_id, timeline) DO UPDATE ON CONFLICT (user_id, timeline) DO UPDATE
SET last_read_id = $3, updated_at = now() SET last_read_id = $3, updated_at = now()
RETURNING timeline_marker RETURNING timeline_marker
", ",
&[&user_id, &timeline, &last_read_id], &[&user_id, &timeline, &last_read_id],
).await?; )
.await?;
let marker = row.try_get("timeline_marker")?; let marker = row.try_get("timeline_marker")?;
Ok(marker) Ok(marker)
} }
@ -28,14 +30,16 @@ pub async fn get_marker_opt(
user_id: &Uuid, user_id: &Uuid,
timeline: Timeline, timeline: Timeline,
) -> Result<Option<DbTimelineMarker>, DatabaseError> { ) -> Result<Option<DbTimelineMarker>, DatabaseError> {
let maybe_row = db_client.query_opt( let maybe_row = db_client
" .query_opt(
"
SELECT timeline_marker SELECT timeline_marker
FROM timeline_marker FROM timeline_marker
WHERE user_id = $1 AND timeline = $2 WHERE user_id = $1 AND timeline = $2
", ",
&[&user_id, &timeline], &[&user_id, &timeline],
).await?; )
.await?;
let maybe_marker = match maybe_row { let maybe_marker = match maybe_row {
Some(row) => row.try_get("timeline_marker")?, Some(row) => row.try_get("timeline_marker")?,
None => None, None => None,

View file

@ -3,13 +3,7 @@ use uuid::Uuid;
use crate::database::{DatabaseClient, DatabaseError}; use crate::database::{DatabaseClient, DatabaseError};
use crate::posts::{ use crate::posts::{
helpers::{add_related_posts, add_user_actions}, helpers::{add_related_posts, add_user_actions},
queries::{ queries::{RELATED_ATTACHMENTS, RELATED_EMOJIS, RELATED_LINKS, RELATED_MENTIONS, RELATED_TAGS},
RELATED_ATTACHMENTS,
RELATED_EMOJIS,
RELATED_LINKS,
RELATED_MENTIONS,
RELATED_TAGS,
},
}; };
use super::types::{EventType, Notification}; use super::types::{EventType, Notification};
@ -21,8 +15,9 @@ async fn create_notification(
post_id: Option<&Uuid>, post_id: Option<&Uuid>,
event_type: EventType, event_type: EventType,
) -> Result<(), DatabaseError> { ) -> Result<(), DatabaseError> {
db_client.execute( db_client
" .execute(
"
INSERT INTO notification ( INSERT INTO notification (
sender_id, sender_id,
recipient_id, recipient_id,
@ -31,8 +26,9 @@ async fn create_notification(
) )
VALUES ($1, $2, $3, $4) VALUES ($1, $2, $3, $4)
", ",
&[&sender_id, &recipient_id, &post_id, &event_type], &[&sender_id, &recipient_id, &post_id, &event_type],
).await?; )
.await?;
Ok(()) Ok(())
} }
@ -41,10 +37,7 @@ pub async fn create_follow_notification(
sender_id: &Uuid, sender_id: &Uuid,
recipient_id: &Uuid, recipient_id: &Uuid,
) -> Result<(), DatabaseError> { ) -> Result<(), DatabaseError> {
create_notification( create_notification(db_client, sender_id, recipient_id, None, EventType::Follow).await
db_client, sender_id, recipient_id, None,
EventType::Follow,
).await
} }
pub async fn create_reply_notification( pub async fn create_reply_notification(
@ -54,9 +47,13 @@ pub async fn create_reply_notification(
post_id: &Uuid, post_id: &Uuid,
) -> Result<(), DatabaseError> { ) -> Result<(), DatabaseError> {
create_notification( create_notification(
db_client, sender_id, recipient_id, Some(post_id), db_client,
sender_id,
recipient_id,
Some(post_id),
EventType::Reply, EventType::Reply,
).await )
.await
} }
pub async fn create_reaction_notification( pub async fn create_reaction_notification(
@ -66,9 +63,13 @@ pub async fn create_reaction_notification(
post_id: &Uuid, post_id: &Uuid,
) -> Result<(), DatabaseError> { ) -> Result<(), DatabaseError> {
create_notification( create_notification(
db_client, sender_id, recipient_id, Some(post_id), db_client,
sender_id,
recipient_id,
Some(post_id),
EventType::Reaction, EventType::Reaction,
).await )
.await
} }
pub async fn create_mention_notification( pub async fn create_mention_notification(
@ -78,9 +79,13 @@ pub async fn create_mention_notification(
post_id: &Uuid, post_id: &Uuid,
) -> Result<(), DatabaseError> { ) -> Result<(), DatabaseError> {
create_notification( create_notification(
db_client, sender_id, recipient_id, Some(post_id), db_client,
sender_id,
recipient_id,
Some(post_id),
EventType::Mention, EventType::Mention,
).await )
.await
} }
pub async fn create_repost_notification( pub async fn create_repost_notification(
@ -90,9 +95,13 @@ pub async fn create_repost_notification(
post_id: &Uuid, post_id: &Uuid,
) -> Result<(), DatabaseError> { ) -> Result<(), DatabaseError> {
create_notification( create_notification(
db_client, sender_id, recipient_id, Some(post_id), db_client,
sender_id,
recipient_id,
Some(post_id),
EventType::Repost, EventType::Repost,
).await )
.await
} }
pub async fn create_subscription_notification( pub async fn create_subscription_notification(
@ -101,9 +110,13 @@ pub async fn create_subscription_notification(
recipient_id: &Uuid, recipient_id: &Uuid,
) -> Result<(), DatabaseError> { ) -> Result<(), DatabaseError> {
create_notification( create_notification(
db_client, sender_id, recipient_id, None, db_client,
sender_id,
recipient_id,
None,
EventType::Subscription, EventType::Subscription,
).await )
.await
} }
pub async fn create_subscription_expiration_notification( pub async fn create_subscription_expiration_notification(
@ -112,9 +125,13 @@ pub async fn create_subscription_expiration_notification(
recipient_id: &Uuid, recipient_id: &Uuid,
) -> Result<(), DatabaseError> { ) -> Result<(), DatabaseError> {
create_notification( create_notification(
db_client, sender_id, recipient_id, None, db_client,
sender_id,
recipient_id,
None,
EventType::SubscriptionExpiration, EventType::SubscriptionExpiration,
).await )
.await
} }
pub async fn create_move_notification( pub async fn create_move_notification(
@ -122,10 +139,7 @@ pub async fn create_move_notification(
sender_id: &Uuid, sender_id: &Uuid,
recipient_id: &Uuid, recipient_id: &Uuid,
) -> Result<(), DatabaseError> { ) -> Result<(), DatabaseError> {
create_notification( create_notification(db_client, sender_id, recipient_id, None, EventType::Move).await
db_client, sender_id, recipient_id, None,
EventType::Move,
).await
} }
pub async fn get_notifications( pub async fn get_notifications(
@ -156,31 +170,35 @@ pub async fn get_notifications(
ORDER BY notification.id DESC ORDER BY notification.id DESC
LIMIT $3 LIMIT $3
", ",
related_attachments=RELATED_ATTACHMENTS, related_attachments = RELATED_ATTACHMENTS,
related_mentions=RELATED_MENTIONS, related_mentions = RELATED_MENTIONS,
related_tags=RELATED_TAGS, related_tags = RELATED_TAGS,
related_links=RELATED_LINKS, related_links = RELATED_LINKS,
related_emojis=RELATED_EMOJIS, related_emojis = RELATED_EMOJIS,
); );
let rows = db_client.query( let rows = db_client
&statement, .query(&statement, &[&recipient_id, &max_id, &i64::from(limit)])
&[&recipient_id, &max_id, &i64::from(limit)], .await?;
).await?; let mut notifications: Vec<Notification> = rows
let mut notifications: Vec<Notification> = rows.iter() .iter()
.map(Notification::try_from) .map(Notification::try_from)
.collect::<Result<_, _>>()?; .collect::<Result<_, _>>()?;
add_related_posts( add_related_posts(
db_client, db_client,
notifications.iter_mut() notifications
.iter_mut()
.filter_map(|item| item.post.as_mut()) .filter_map(|item| item.post.as_mut())
.collect(), .collect(),
).await?; )
.await?;
add_user_actions( add_user_actions(
db_client, db_client,
recipient_id, recipient_id,
notifications.iter_mut() notifications
.iter_mut()
.filter_map(|item| item.post.as_mut()) .filter_map(|item| item.post.as_mut())
.collect(), .collect(),
).await?; )
.await?;
Ok(notifications) Ok(notifications)
} }

View file

@ -6,8 +6,7 @@ use uuid::Uuid;
use crate::attachments::types::DbMediaAttachment; use crate::attachments::types::DbMediaAttachment;
use crate::database::{ use crate::database::{
int_enum::{int_enum_from_sql, int_enum_to_sql}, int_enum::{int_enum_from_sql, int_enum_to_sql},
DatabaseError, DatabaseError, DatabaseTypeError,
DatabaseTypeError,
}; };
use crate::emojis::types::DbEmoji; use crate::emojis::types::DbEmoji;
use crate::posts::types::{DbPost, Post}; use crate::posts::types::{DbPost, Post};
@ -89,7 +88,6 @@ pub struct Notification {
} }
impl TryFrom<&Row> for Notification { impl TryFrom<&Row> for Notification {
type Error = DatabaseError; type Error = DatabaseError;
fn try_from(row: &Row) -> Result<Self, Self::Error> { fn try_from(row: &Row) -> Result<Self, Self::Error> {
@ -114,7 +112,7 @@ impl TryFrom<&Row> for Notification {
db_emojis, db_emojis,
)?; )?;
Some(post) Some(post)
}, }
None => None, None => None,
}; };
let notification = Self { let notification = Self {

View file

@ -1,2 +1,2 @@
pub mod types;
pub mod queries; pub mod queries;
pub mod types;

View file

@ -1,11 +1,7 @@
use chrono::{DateTime, Utc}; use chrono::{DateTime, Utc};
use uuid::Uuid; use uuid::Uuid;
use crate::database::{ use crate::database::{catch_unique_violation, DatabaseClient, DatabaseError};
catch_unique_violation,
DatabaseClient,
DatabaseError,
};
use crate::profiles::types::DbActorProfile; use crate::profiles::types::DbActorProfile;
use crate::users::types::{DbUser, User}; use crate::users::types::{DbUser, User};
@ -15,8 +11,9 @@ pub async fn create_oauth_app(
db_client: &impl DatabaseClient, db_client: &impl DatabaseClient,
app_data: DbOauthAppData, app_data: DbOauthAppData,
) -> Result<DbOauthApp, DatabaseError> { ) -> Result<DbOauthApp, DatabaseError> {
let row = db_client.query_one( let row = db_client
" .query_one(
"
INSERT INTO oauth_application ( INSERT INTO oauth_application (
app_name, app_name,
website, website,
@ -28,15 +25,17 @@ pub async fn create_oauth_app(
VALUES ($1, $2, $3, $4, $5, $6) VALUES ($1, $2, $3, $4, $5, $6)
RETURNING oauth_application RETURNING oauth_application
", ",
&[ &[
&app_data.app_name, &app_data.app_name,
&app_data.website, &app_data.website,
&app_data.scopes, &app_data.scopes,
&app_data.redirect_uri, &app_data.redirect_uri,
&app_data.client_id, &app_data.client_id,
&app_data.client_secret, &app_data.client_secret,
], ],
).await.map_err(catch_unique_violation("oauth_application"))?; )
.await
.map_err(catch_unique_violation("oauth_application"))?;
let app = row.try_get("oauth_application")?; let app = row.try_get("oauth_application")?;
Ok(app) Ok(app)
} }
@ -45,14 +44,16 @@ pub async fn get_oauth_app_by_client_id(
db_client: &impl DatabaseClient, db_client: &impl DatabaseClient,
client_id: &Uuid, client_id: &Uuid,
) -> Result<DbOauthApp, DatabaseError> { ) -> Result<DbOauthApp, DatabaseError> {
let maybe_row = db_client.query_opt( let maybe_row = db_client
" .query_opt(
"
SELECT oauth_application SELECT oauth_application
FROM oauth_application FROM oauth_application
WHERE client_id = $1 WHERE client_id = $1
", ",
&[&client_id], &[&client_id],
).await?; )
.await?;
let row = maybe_row.ok_or(DatabaseError::NotFound("oauth application"))?; let row = maybe_row.ok_or(DatabaseError::NotFound("oauth application"))?;
let app = row.try_get("oauth_application")?; let app = row.try_get("oauth_application")?;
Ok(app) Ok(app)
@ -67,8 +68,9 @@ pub async fn create_oauth_authorization(
created_at: &DateTime<Utc>, created_at: &DateTime<Utc>,
expires_at: &DateTime<Utc>, expires_at: &DateTime<Utc>,
) -> Result<(), DatabaseError> { ) -> Result<(), DatabaseError> {
db_client.execute( db_client
" .execute(
"
INSERT INTO oauth_authorization ( INSERT INTO oauth_authorization (
code, code,
user_id, user_id,
@ -79,15 +81,16 @@ pub async fn create_oauth_authorization(
) )
VALUES ($1, $2, $3, $4, $5, $6) VALUES ($1, $2, $3, $4, $5, $6)
", ",
&[ &[
&authorization_code, &authorization_code,
&user_id, &user_id,
&application_id, &application_id,
&scopes, &scopes,
&created_at, &created_at,
&expires_at, &expires_at,
], ],
).await?; )
.await?;
Ok(()) Ok(())
} }
@ -95,8 +98,9 @@ pub async fn get_user_by_authorization_code(
db_client: &impl DatabaseClient, db_client: &impl DatabaseClient,
authorization_code: &str, authorization_code: &str,
) -> Result<User, DatabaseError> { ) -> Result<User, DatabaseError> {
let maybe_row = db_client.query_opt( let maybe_row = db_client
" .query_opt(
"
SELECT user_account, actor_profile SELECT user_account, actor_profile
FROM oauth_authorization FROM oauth_authorization
JOIN user_account ON oauth_authorization.user_id = user_account.id JOIN user_account ON oauth_authorization.user_id = user_account.id
@ -105,8 +109,9 @@ pub async fn get_user_by_authorization_code(
oauth_authorization.code = $1 oauth_authorization.code = $1
AND oauth_authorization.expires_at > CURRENT_TIMESTAMP AND oauth_authorization.expires_at > CURRENT_TIMESTAMP
", ",
&[&authorization_code], &[&authorization_code],
).await?; )
.await?;
let row = maybe_row.ok_or(DatabaseError::NotFound("authorization"))?; let row = maybe_row.ok_or(DatabaseError::NotFound("authorization"))?;
let db_user: DbUser = row.try_get("user_account")?; let db_user: DbUser = row.try_get("user_account")?;
let db_profile: DbActorProfile = row.try_get("actor_profile")?; let db_profile: DbActorProfile = row.try_get("actor_profile")?;
@ -121,13 +126,15 @@ pub async fn save_oauth_token(
created_at: &DateTime<Utc>, created_at: &DateTime<Utc>,
expires_at: &DateTime<Utc>, expires_at: &DateTime<Utc>,
) -> Result<(), DatabaseError> { ) -> Result<(), DatabaseError> {
db_client.execute( db_client
" .execute(
"
INSERT INTO oauth_token (owner_id, token, created_at, expires_at) INSERT INTO oauth_token (owner_id, token, created_at, expires_at)
VALUES ($1, $2, $3, $4) VALUES ($1, $2, $3, $4)
", ",
&[&owner_id, &token, &created_at, &expires_at], &[&owner_id, &token, &created_at, &expires_at],
).await?; )
.await?;
Ok(()) Ok(())
} }
@ -137,24 +144,25 @@ pub async fn delete_oauth_token(
token: &str, token: &str,
) -> Result<(), DatabaseError> { ) -> Result<(), DatabaseError> {
let transaction = db_client.transaction().await?; let transaction = db_client.transaction().await?;
let maybe_row = transaction.query_opt( let maybe_row = transaction
" .query_opt(
"
SELECT owner_id FROM oauth_token SELECT owner_id FROM oauth_token
WHERE token = $1 WHERE token = $1
FOR UPDATE FOR UPDATE
", ",
&[&token], &[&token],
).await?; )
.await?;
if let Some(row) = maybe_row { if let Some(row) = maybe_row {
let owner_id: Uuid = row.try_get("owner_id")?; let owner_id: Uuid = row.try_get("owner_id")?;
if owner_id != *current_user_id { if owner_id != *current_user_id {
// Return error if token is owned by a different user // Return error if token is owned by a different user
return Err(DatabaseError::NotFound("token")); return Err(DatabaseError::NotFound("token"));
} else { } else {
transaction.execute( transaction
"DELETE FROM oauth_token WHERE token = $1", .execute("DELETE FROM oauth_token WHERE token = $1", &[&token])
&[&token], .await?;
).await?;
}; };
}; };
transaction.commit().await?; transaction.commit().await?;
@ -165,10 +173,9 @@ pub async fn delete_oauth_tokens(
db_client: &impl DatabaseClient, db_client: &impl DatabaseClient,
owner_id: &Uuid, owner_id: &Uuid,
) -> Result<(), DatabaseError> { ) -> Result<(), DatabaseError> {
db_client.execute( db_client
"DELETE FROM oauth_token WHERE owner_id = $1", .execute("DELETE FROM oauth_token WHERE owner_id = $1", &[&owner_id])
&[&owner_id], .await?;
).await?;
Ok(()) Ok(())
} }
@ -176,8 +183,9 @@ pub async fn get_user_by_oauth_token(
db_client: &impl DatabaseClient, db_client: &impl DatabaseClient,
access_token: &str, access_token: &str,
) -> Result<User, DatabaseError> { ) -> Result<User, DatabaseError> {
let maybe_row = db_client.query_opt( let maybe_row = db_client
" .query_opt(
"
SELECT user_account, actor_profile SELECT user_account, actor_profile
FROM oauth_token FROM oauth_token
JOIN user_account ON oauth_token.owner_id = user_account.id JOIN user_account ON oauth_token.owner_id = user_account.id
@ -186,8 +194,9 @@ pub async fn get_user_by_oauth_token(
oauth_token.token = $1 oauth_token.token = $1
AND oauth_token.expires_at > CURRENT_TIMESTAMP AND oauth_token.expires_at > CURRENT_TIMESTAMP
", ",
&[&access_token], &[&access_token],
).await?; )
.await?;
let row = maybe_row.ok_or(DatabaseError::NotFound("user"))?; let row = maybe_row.ok_or(DatabaseError::NotFound("user"))?;
let db_user: DbUser = row.try_get("user_account")?; let db_user: DbUser = row.try_get("user_account")?;
let db_profile: DbActorProfile = row.try_get("actor_profile")?; let db_profile: DbActorProfile = row.try_get("actor_profile")?;
@ -197,13 +206,10 @@ pub async fn get_user_by_oauth_token(
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use serial_test::serial;
use crate::database::test_utils::create_test_database;
use crate::users::{
queries::create_user,
types::UserCreateData,
};
use super::*; use super::*;
use crate::database::test_utils::create_test_database;
use crate::users::{queries::create_user, types::UserCreateData};
use serial_test::serial;
#[tokio::test] #[tokio::test]
#[serial] #[serial]
@ -240,7 +246,9 @@ mod tests {
"read write", "read write",
&Utc::now(), &Utc::now(),
&Utc::now(), &Utc::now(),
).await.unwrap(); )
.await
.unwrap();
} }
#[tokio::test] #[tokio::test]
@ -254,17 +262,11 @@ mod tests {
}; };
let user = create_user(db_client, user_data).await.unwrap(); let user = create_user(db_client, user_data).await.unwrap();
let token = "test-token"; let token = "test-token";
save_oauth_token( save_oauth_token(db_client, &user.id, token, &Utc::now(), &Utc::now())
db_client, .await
&user.id, .unwrap();
token, delete_oauth_token(db_client, &user.id, token)
&Utc::now(), .await
&Utc::now(), .unwrap();
).await.unwrap();
delete_oauth_token(
db_client,
&user.id,
token,
).await.unwrap();
} }
} }

View file

@ -2,17 +2,10 @@ use uuid::Uuid;
use crate::database::{DatabaseClient, DatabaseError}; use crate::database::{DatabaseClient, DatabaseError};
use crate::reactions::queries::find_favourited_by_user; use crate::reactions::queries::find_favourited_by_user;
use crate::relationships::{ use crate::relationships::{queries::has_relationship, types::RelationshipType};
queries::has_relationship,
types::RelationshipType,
};
use crate::users::types::{Permission, User}; use crate::users::types::{Permission, User};
use super::queries::{ use super::queries::{find_reposted_by_user, get_post_by_id, get_related_posts};
get_post_by_id,
get_related_posts,
find_reposted_by_user,
};
use super::types::{Post, PostActions, Visibility}; use super::types::{Post, PostActions, Visibility};
pub async fn add_related_posts( pub async fn add_related_posts(
@ -22,7 +15,8 @@ pub async fn add_related_posts(
let posts_ids = posts.iter().map(|post| post.id).collect(); let posts_ids = posts.iter().map(|post| post.id).collect();
let related = get_related_posts(db_client, posts_ids).await?; let related = get_related_posts(db_client, posts_ids).await?;
let get_post = |post_id: &Uuid| -> Result<Post, DatabaseError> { let get_post = |post_id: &Uuid| -> Result<Post, DatabaseError> {
let post = related.iter() let post = related
.iter()
.find(|post| post.id == *post_id) .find(|post| post.id == *post_id)
.ok_or(DatabaseError::NotFound("post"))? .ok_or(DatabaseError::NotFound("post"))?
.clone(); .clone();
@ -38,14 +32,14 @@ pub async fn add_related_posts(
for linked_id in repost_of.links.iter() { for linked_id in repost_of.links.iter() {
let linked = get_post(linked_id)?; let linked = get_post(linked_id)?;
repost_of.linked.push(linked); repost_of.linked.push(linked);
}; }
post.repost_of = Some(Box::new(repost_of)); post.repost_of = Some(Box::new(repost_of));
}; };
for linked_id in post.links.iter() { for linked_id in post.links.iter() {
let linked = get_post(linked_id)?; let linked = get_post(linked_id)?;
post.linked.push(linked); post.linked.push(linked);
}; }
}; }
Ok(()) Ok(())
} }
@ -54,12 +48,14 @@ pub async fn add_user_actions(
user_id: &Uuid, user_id: &Uuid,
posts: Vec<&mut Post>, posts: Vec<&mut Post>,
) -> Result<(), DatabaseError> { ) -> Result<(), DatabaseError> {
let posts_ids: Vec<Uuid> = posts.iter() let posts_ids: Vec<Uuid> = posts
.iter()
.map(|post| post.id) .map(|post| post.id)
.chain( .chain(
posts.iter() posts
.iter()
.filter_map(|post| post.repost_of.as_ref()) .filter_map(|post| post.repost_of.as_ref())
.map(|post| post.id) .map(|post| post.id),
) )
.collect(); .collect();
let favourites = find_favourited_by_user(db_client, user_id, &posts_ids).await?; let favourites = find_favourited_by_user(db_client, user_id, &posts_ids).await?;
@ -87,7 +83,9 @@ pub async fn can_view_post(
post: &Post, post: &Post,
) -> Result<bool, DatabaseError> { ) -> Result<bool, DatabaseError> {
let is_mentioned = |user: &User| { let is_mentioned = |user: &User| {
post.mentions.iter().any(|profile| profile.id == user.profile.id) post.mentions
.iter()
.any(|profile| profile.id == user.profile.id)
}; };
let result = match post.visibility { let result = match post.visibility {
Visibility::Public => true, Visibility::Public => true,
@ -98,7 +96,7 @@ pub async fn can_view_post(
} else { } else {
false false
} }
}, }
Visibility::Followers => { Visibility::Followers => {
if let Some(user) = user { if let Some(user) = user {
let is_following = has_relationship( let is_following = has_relationship(
@ -106,12 +104,13 @@ pub async fn can_view_post(
&user.id, &user.id,
&post.author.id, &post.author.id,
RelationshipType::Follow, RelationshipType::Follow,
).await?; )
.await?;
is_following || is_mentioned(user) is_following || is_mentioned(user)
} else { } else {
false false
} }
}, }
Visibility::Subscribers => { Visibility::Subscribers => {
if let Some(user) = user { if let Some(user) = user {
// Can view only if mentioned // Can view only if mentioned
@ -119,14 +118,12 @@ pub async fn can_view_post(
} else { } else {
false false
} }
}, }
}; };
Ok(result) Ok(result)
} }
pub fn can_create_post( pub fn can_create_post(user: &User) -> bool {
user: &User,
) -> bool {
user.role.has_permission(Permission::CreatePost) user.role.has_permission(Permission::CreatePost)
} }
@ -143,19 +140,16 @@ pub async fn get_local_post_by_id(
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use serial_test::serial; use super::*;
use tokio_postgres::Client;
use crate::database::test_utils::create_test_database; use crate::database::test_utils::create_test_database;
use crate::posts::{ use crate::posts::{queries::create_post, types::PostCreateData};
queries::create_post,
types::PostCreateData,
};
use crate::relationships::queries::{follow, subscribe}; use crate::relationships::queries::{follow, subscribe};
use crate::users::{ use crate::users::{
queries::create_user, queries::create_user,
types::{Role, User, UserCreateData}, types::{Role, User, UserCreateData},
}; };
use super::*; use serial_test::serial;
use tokio_postgres::Client;
async fn create_test_user(db_client: &mut Client, username: &str) -> User { async fn create_test_user(db_client: &mut Client, username: &str) -> User {
let user_data = UserCreateData { let user_data = UserCreateData {
@ -181,8 +175,12 @@ mod tests {
in_reply_to_id: Some(post.id.clone()), in_reply_to_id: Some(post.id.clone()),
..Default::default() ..Default::default()
}; };
let mut reply = create_post(db_client, &author.id, reply_data).await.unwrap(); let mut reply = create_post(db_client, &author.id, reply_data)
add_related_posts(db_client, vec![&mut reply]).await.unwrap(); .await
.unwrap();
add_related_posts(db_client, vec![&mut reply])
.await
.unwrap();
assert_eq!(reply.in_reply_to.unwrap().id, post.id); assert_eq!(reply.in_reply_to.unwrap().id, post.id);
assert_eq!(reply.repost_of.is_none(), true); assert_eq!(reply.repost_of.is_none(), true);
assert_eq!(reply.linked.is_empty(), true); assert_eq!(reply.linked.is_empty(), true);
@ -253,7 +251,9 @@ mod tests {
visibility: Visibility::Followers, visibility: Visibility::Followers,
..Default::default() ..Default::default()
}; };
let result = can_view_post(db_client, Some(&follower), &post).await.unwrap(); let result = can_view_post(db_client, Some(&follower), &post)
.await
.unwrap();
assert_eq!(result, true); assert_eq!(result, true);
} }
@ -265,23 +265,26 @@ mod tests {
let follower = create_test_user(db_client, "follower").await; let follower = create_test_user(db_client, "follower").await;
follow(db_client, &follower.id, &author.id).await.unwrap(); follow(db_client, &follower.id, &author.id).await.unwrap();
let subscriber = create_test_user(db_client, "subscriber").await; let subscriber = create_test_user(db_client, "subscriber").await;
subscribe(db_client, &subscriber.id, &author.id).await.unwrap(); subscribe(db_client, &subscriber.id, &author.id)
.await
.unwrap();
let post = Post { let post = Post {
author: author.profile, author: author.profile,
visibility: Visibility::Subscribers, visibility: Visibility::Subscribers,
mentions: vec![subscriber.profile.clone()], mentions: vec![subscriber.profile.clone()],
..Default::default() ..Default::default()
}; };
assert_eq!(can_view_post(db_client, None, &post).await.unwrap(), false,);
assert_eq!( assert_eq!(
can_view_post(db_client, None, &post).await.unwrap(), can_view_post(db_client, Some(&follower), &post)
.await
.unwrap(),
false, false,
); );
assert_eq!( assert_eq!(
can_view_post(db_client, Some(&follower), &post).await.unwrap(), can_view_post(db_client, Some(&subscriber), &post)
false, .await
); .unwrap(),
assert_eq!(
can_view_post(db_client, Some(&subscriber), &post).await.unwrap(),
true, true,
); );
} }

File diff suppressed because it is too large Load diff

View file

@ -6,8 +6,7 @@ use uuid::Uuid;
use crate::attachments::types::DbMediaAttachment; use crate::attachments::types::DbMediaAttachment;
use crate::database::{ use crate::database::{
int_enum::{int_enum_from_sql, int_enum_to_sql}, int_enum::{int_enum_from_sql, int_enum_to_sql},
DatabaseError, DatabaseError, DatabaseTypeError,
DatabaseTypeError,
}; };
use crate::emojis::types::DbEmoji; use crate::emojis::types::DbEmoji;
use crate::profiles::types::DbActorProfile; use crate::profiles::types::DbActorProfile;
@ -21,7 +20,9 @@ pub enum Visibility {
} }
impl Default for Visibility { impl Default for Visibility {
fn default() -> Self { Self::Public } fn default() -> Self {
Self::Public
}
} }
impl From<&Visibility> for i16 { impl From<&Visibility> for i16 {
@ -130,19 +131,19 @@ impl Post {
if db_author.is_local() != db_post.object_id.is_none() { if db_author.is_local() != db_post.object_id.is_none() {
return Err(DatabaseTypeError); return Err(DatabaseTypeError);
}; };
if db_post.repost_of_id.is_some() && ( if db_post.repost_of_id.is_some()
db_post.content.len() != 0 || && (db_post.content.len() != 0
db_post.is_sensitive || || db_post.is_sensitive
db_post.in_reply_to_id.is_some() || || db_post.in_reply_to_id.is_some()
db_post.ipfs_cid.is_some() || || db_post.ipfs_cid.is_some()
db_post.token_id.is_some() || || db_post.token_id.is_some()
db_post.token_tx_id.is_some() || || db_post.token_tx_id.is_some()
!db_attachments.is_empty() || || !db_attachments.is_empty()
!db_mentions.is_empty() || || !db_mentions.is_empty()
!db_tags.is_empty() || || !db_tags.is_empty()
!db_links.is_empty() || || !db_links.is_empty()
!db_emojis.is_empty() || !db_emojis.is_empty())
) { {
return Err(DatabaseTypeError); return Err(DatabaseTypeError);
}; };
let post = Self { let post = Self {
@ -218,7 +219,6 @@ impl Default for Post {
} }
impl TryFrom<&Row> for Post { impl TryFrom<&Row> for Post {
type Error = DatabaseError; type Error = DatabaseError;
fn try_from(row: &Row) -> Result<Self, Self::Error> { fn try_from(row: &Row) -> Result<Self, Self::Error> {
@ -259,10 +259,7 @@ pub struct PostCreateData {
} }
impl PostCreateData { impl PostCreateData {
pub fn repost( pub fn repost(repost_of_id: Uuid, object_id: Option<String>) -> Self {
repost_of_id: Uuid,
object_id: Option<String>,
) -> Self {
Self { Self {
repost_of_id: Some(repost_of_id), repost_of_id: Some(repost_of_id),
object_id: object_id, object_id: object_id,

View file

@ -1,9 +1,6 @@
use crate::database::{DatabaseClient, DatabaseError}; use crate::database::{DatabaseClient, DatabaseError};
use super::queries::{ use super::queries::{get_profile_by_remote_actor_id, search_profiles_by_did_only};
get_profile_by_remote_actor_id,
search_profiles_by_did_only,
};
use super::types::DbActorProfile; use super::types::DbActorProfile;
pub async fn find_declared_aliases( pub async fn find_declared_aliases(
@ -12,17 +9,14 @@ pub async fn find_declared_aliases(
) -> Result<Vec<DbActorProfile>, DatabaseError> { ) -> Result<Vec<DbActorProfile>, DatabaseError> {
let mut results = vec![]; let mut results = vec![];
for actor_id in profile.aliases.clone().into_actor_ids() { for actor_id in profile.aliases.clone().into_actor_ids() {
let alias = match get_profile_by_remote_actor_id( let alias = match get_profile_by_remote_actor_id(db_client, &actor_id).await {
db_client,
&actor_id,
).await {
Ok(profile) => profile, Ok(profile) => profile,
// Ignore unknown profiles // Ignore unknown profiles
Err(DatabaseError::NotFound(_)) => continue, Err(DatabaseError::NotFound(_)) => continue,
Err(other_error) => return Err(other_error), Err(other_error) => return Err(other_error),
}; };
results.push(alias); results.push(alias);
}; }
Ok(results) Ok(results)
} }
@ -32,16 +26,13 @@ pub async fn find_verified_aliases(
) -> Result<Vec<DbActorProfile>, DatabaseError> { ) -> Result<Vec<DbActorProfile>, DatabaseError> {
let mut results = vec![]; let mut results = vec![];
for identity_proof in profile.identity_proofs.inner() { for identity_proof in profile.identity_proofs.inner() {
let aliases = search_profiles_by_did_only( let aliases = search_profiles_by_did_only(db_client, &identity_proof.issuer).await?;
db_client,
&identity_proof.issuer,
).await?;
for alias in aliases { for alias in aliases {
if alias.id == profile.id { if alias.id == profile.id {
continue; continue;
}; };
results.push(alias); results.push(alias);
}; }
}; }
Ok(results) Ok(results)
} }

View file

@ -1,35 +1,16 @@
use chrono::{DateTime, Utc}; use chrono::{DateTime, Utc};
use uuid::Uuid; use uuid::Uuid;
use mitra_utils::{ use mitra_utils::{currencies::Currency, did::Did, did_pkh::DidPkh, id::generate_ulid};
currencies::Currency,
did::Did,
did_pkh::DidPkh,
id::generate_ulid,
};
use crate::cleanup::{ use crate::cleanup::{find_orphaned_files, find_orphaned_ipfs_objects, DeletionQueue};
find_orphaned_files, use crate::database::{catch_unique_violation, query_macro::query, DatabaseClient, DatabaseError};
find_orphaned_ipfs_objects,
DeletionQueue,
};
use crate::database::{
catch_unique_violation,
query_macro::query,
DatabaseClient,
DatabaseError,
};
use crate::emojis::types::DbEmoji; use crate::emojis::types::DbEmoji;
use crate::instances::queries::create_instance; use crate::instances::queries::create_instance;
use crate::relationships::types::RelationshipType; use crate::relationships::types::RelationshipType;
use super::types::{ use super::types::{
Aliases, Aliases, DbActorProfile, ExtraFields, IdentityProofs, PaymentOptions, ProfileCreateData,
DbActorProfile,
ExtraFields,
IdentityProofs,
PaymentOptions,
ProfileCreateData,
ProfileUpdateData, ProfileUpdateData,
}; };
@ -38,8 +19,9 @@ async fn create_profile_emojis(
profile_id: &Uuid, profile_id: &Uuid,
emojis: Vec<Uuid>, emojis: Vec<Uuid>,
) -> Result<Vec<DbEmoji>, DatabaseError> { ) -> Result<Vec<DbEmoji>, DatabaseError> {
let emojis_rows = db_client.query( let emojis_rows = db_client
" .query(
"
INSERT INTO profile_emoji (profile_id, emoji_id) INSERT INTO profile_emoji (profile_id, emoji_id)
SELECT $1, emoji.id FROM emoji WHERE id = ANY($2) SELECT $1, emoji.id FROM emoji WHERE id = ANY($2)
RETURNING ( RETURNING (
@ -47,12 +29,14 @@ async fn create_profile_emojis(
WHERE emoji.id = emoji_id WHERE emoji.id = emoji_id
) )
", ",
&[&profile_id, &emojis], &[&profile_id, &emojis],
).await?; )
.await?;
if emojis_rows.len() != emojis.len() { if emojis_rows.len() != emojis.len() {
return Err(DatabaseError::NotFound("emoji")); return Err(DatabaseError::NotFound("emoji"));
}; };
let emojis = emojis_rows.iter() let emojis = emojis_rows
.iter()
.map(|row| row.try_get("emoji")) .map(|row| row.try_get("emoji"))
.collect::<Result<_, _>>()?; .collect::<Result<_, _>>()?;
Ok(emojis) Ok(emojis)
@ -62,8 +46,9 @@ async fn update_emoji_cache(
db_client: &impl DatabaseClient, db_client: &impl DatabaseClient,
profile_id: &Uuid, profile_id: &Uuid,
) -> Result<DbActorProfile, DatabaseError> { ) -> Result<DbActorProfile, DatabaseError> {
let maybe_row = db_client.query_opt( let maybe_row = db_client
" .query_opt(
"
WITH profile_emojis AS ( WITH profile_emojis AS (
SELECT SELECT
actor_profile.id AS profile_id, actor_profile.id AS profile_id,
@ -83,8 +68,9 @@ async fn update_emoji_cache(
WHERE actor_profile.id = profile_emojis.profile_id WHERE actor_profile.id = profile_emojis.profile_id
RETURNING actor_profile RETURNING actor_profile
", ",
&[&profile_id], &[&profile_id],
).await?; )
.await?;
let row = maybe_row.ok_or(DatabaseError::NotFound("profile"))?; let row = maybe_row.ok_or(DatabaseError::NotFound("profile"))?;
let profile: DbActorProfile = row.try_get("actor_profile")?; let profile: DbActorProfile = row.try_get("actor_profile")?;
Ok(profile) Ok(profile)
@ -94,8 +80,9 @@ pub async fn update_emoji_caches(
db_client: &impl DatabaseClient, db_client: &impl DatabaseClient,
emoji_id: &Uuid, emoji_id: &Uuid,
) -> Result<(), DatabaseError> { ) -> Result<(), DatabaseError> {
db_client.execute( db_client
" .execute(
"
WITH profile_emojis AS ( WITH profile_emojis AS (
SELECT SELECT
actor_profile.id AS profile_id, actor_profile.id AS profile_id,
@ -115,8 +102,9 @@ pub async fn update_emoji_caches(
FROM profile_emojis FROM profile_emojis
WHERE actor_profile.id = profile_emojis.profile_id WHERE actor_profile.id = profile_emojis.profile_id
", ",
&[&emoji_id], &[&emoji_id],
).await?; )
.await?;
Ok(()) Ok(())
} }
@ -130,8 +118,9 @@ pub async fn create_profile(
if let Some(ref hostname) = profile_data.hostname { if let Some(ref hostname) = profile_data.hostname {
create_instance(&transaction, hostname).await?; create_instance(&transaction, hostname).await?;
}; };
transaction.execute( transaction
" .execute(
"
INSERT INTO actor_profile ( INSERT INTO actor_profile (
id, id,
username, username,
@ -151,30 +140,28 @@ pub async fn create_profile(
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14)
RETURNING actor_profile RETURNING actor_profile
", ",
&[ &[
&profile_id, &profile_id,
&profile_data.username, &profile_data.username,
&profile_data.hostname, &profile_data.hostname,
&profile_data.display_name, &profile_data.display_name,
&profile_data.bio, &profile_data.bio,
&profile_data.bio, &profile_data.bio,
&profile_data.avatar, &profile_data.avatar,
&profile_data.banner, &profile_data.banner,
&profile_data.manually_approves_followers, &profile_data.manually_approves_followers,
&IdentityProofs(profile_data.identity_proofs), &IdentityProofs(profile_data.identity_proofs),
&PaymentOptions(profile_data.payment_options), &PaymentOptions(profile_data.payment_options),
&ExtraFields(profile_data.extra_fields), &ExtraFields(profile_data.extra_fields),
&Aliases::new(profile_data.aliases), &Aliases::new(profile_data.aliases),
&profile_data.actor_json, &profile_data.actor_json,
], ],
).await.map_err(catch_unique_violation("profile"))?; )
.await
.map_err(catch_unique_violation("profile"))?;
// Create related objects // Create related objects
create_profile_emojis( create_profile_emojis(&transaction, &profile_id, profile_data.emojis).await?;
&transaction,
&profile_id,
profile_data.emojis,
).await?;
let profile = update_emoji_cache(&transaction, &profile_id).await?; let profile = update_emoji_cache(&transaction, &profile_id).await?;
transaction.commit().await?; transaction.commit().await?;
@ -187,8 +174,9 @@ pub async fn update_profile(
profile_data: ProfileUpdateData, profile_data: ProfileUpdateData,
) -> Result<DbActorProfile, DatabaseError> { ) -> Result<DbActorProfile, DatabaseError> {
let transaction = db_client.transaction().await?; let transaction = db_client.transaction().await?;
transaction.execute( transaction
" .execute(
"
UPDATE actor_profile UPDATE actor_profile
SET SET
display_name = $1, display_name = $1,
@ -206,32 +194,31 @@ pub async fn update_profile(
WHERE id = $12 WHERE id = $12
RETURNING actor_profile RETURNING actor_profile
", ",
&[ &[
&profile_data.display_name, &profile_data.display_name,
&profile_data.bio, &profile_data.bio,
&profile_data.bio_source, &profile_data.bio_source,
&profile_data.avatar, &profile_data.avatar,
&profile_data.banner, &profile_data.banner,
&profile_data.manually_approves_followers, &profile_data.manually_approves_followers,
&IdentityProofs(profile_data.identity_proofs), &IdentityProofs(profile_data.identity_proofs),
&PaymentOptions(profile_data.payment_options), &PaymentOptions(profile_data.payment_options),
&ExtraFields(profile_data.extra_fields), &ExtraFields(profile_data.extra_fields),
&Aliases::new(profile_data.aliases), &Aliases::new(profile_data.aliases),
&profile_data.actor_json, &profile_data.actor_json,
&profile_id, &profile_id,
], ],
).await?; )
.await?;
// Delete and re-create related objects // Delete and re-create related objects
transaction.execute( transaction
"DELETE FROM profile_emoji WHERE profile_id = $1", .execute(
&[profile_id], "DELETE FROM profile_emoji WHERE profile_id = $1",
).await?; &[profile_id],
create_profile_emojis( )
&transaction, .await?;
profile_id, create_profile_emojis(&transaction, profile_id, profile_data.emojis).await?;
profile_data.emojis,
).await?;
let profile = update_emoji_cache(&transaction, profile_id).await?; let profile = update_emoji_cache(&transaction, profile_id).await?;
transaction.commit().await?; transaction.commit().await?;
@ -242,14 +229,16 @@ pub async fn get_profile_by_id(
db_client: &impl DatabaseClient, db_client: &impl DatabaseClient,
profile_id: &Uuid, profile_id: &Uuid,
) -> Result<DbActorProfile, DatabaseError> { ) -> Result<DbActorProfile, DatabaseError> {
let result = db_client.query_opt( let result = db_client
" .query_opt(
"
SELECT actor_profile SELECT actor_profile
FROM actor_profile FROM actor_profile
WHERE id = $1 WHERE id = $1
", ",
&[&profile_id], &[&profile_id],
).await?; )
.await?;
let profile = match result { let profile = match result {
Some(row) => row.try_get("actor_profile")?, Some(row) => row.try_get("actor_profile")?,
None => return Err(DatabaseError::NotFound("profile")), None => return Err(DatabaseError::NotFound("profile")),
@ -261,14 +250,16 @@ pub async fn get_profile_by_remote_actor_id(
db_client: &impl DatabaseClient, db_client: &impl DatabaseClient,
actor_id: &str, actor_id: &str,
) -> Result<DbActorProfile, DatabaseError> { ) -> Result<DbActorProfile, DatabaseError> {
let maybe_row = db_client.query_opt( let maybe_row = db_client
" .query_opt(
"
SELECT actor_profile SELECT actor_profile
FROM actor_profile FROM actor_profile
WHERE actor_id = $1 WHERE actor_id = $1
", ",
&[&actor_id], &[&actor_id],
).await?; )
.await?;
let row = maybe_row.ok_or(DatabaseError::NotFound("profile"))?; let row = maybe_row.ok_or(DatabaseError::NotFound("profile"))?;
let profile: DbActorProfile = row.try_get("actor_profile")?; let profile: DbActorProfile = row.try_get("actor_profile")?;
profile.check_remote()?; profile.check_remote()?;
@ -279,14 +270,16 @@ pub async fn get_profile_by_acct(
db_client: &impl DatabaseClient, db_client: &impl DatabaseClient,
acct: &str, acct: &str,
) -> Result<DbActorProfile, DatabaseError> { ) -> Result<DbActorProfile, DatabaseError> {
let result = db_client.query_opt( let result = db_client
" .query_opt(
"
SELECT actor_profile SELECT actor_profile
FROM actor_profile FROM actor_profile
WHERE actor_profile.acct = $1 WHERE actor_profile.acct = $1
", ",
&[&acct], &[&acct],
).await?; )
.await?;
let profile = match result { let profile = match result {
Some(row) => row.try_get("actor_profile")?, Some(row) => row.try_get("actor_profile")?,
None => return Err(DatabaseError::NotFound("profile")), None => return Err(DatabaseError::NotFound("profile")),
@ -300,7 +293,11 @@ pub async fn get_profiles(
offset: u16, offset: u16,
limit: u16, limit: u16,
) -> Result<Vec<DbActorProfile>, DatabaseError> { ) -> Result<Vec<DbActorProfile>, DatabaseError> {
let condition = if only_local { "WHERE actor_id IS NULL" } else { "" }; let condition = if only_local {
"WHERE actor_id IS NULL"
} else {
""
};
let statement = format!( let statement = format!(
" "
SELECT actor_profile SELECT actor_profile
@ -309,13 +306,13 @@ pub async fn get_profiles(
ORDER BY username ORDER BY username
LIMIT $1 OFFSET $2 LIMIT $1 OFFSET $2
", ",
condition=condition, condition = condition,
); );
let rows = db_client.query( let rows = db_client
&statement, .query(&statement, &[&i64::from(limit), &i64::from(offset)])
&[&i64::from(limit), &i64::from(offset)], .await?;
).await?; let profiles = rows
let profiles = rows.iter() .iter()
.map(|row| row.try_get("actor_profile")) .map(|row| row.try_get("actor_profile"))
.collect::<Result<_, _>>()?; .collect::<Result<_, _>>()?;
Ok(profiles) Ok(profiles)
@ -325,15 +322,18 @@ pub async fn get_profiles_by_accts(
db_client: &impl DatabaseClient, db_client: &impl DatabaseClient,
accts: Vec<String>, accts: Vec<String>,
) -> Result<Vec<DbActorProfile>, DatabaseError> { ) -> Result<Vec<DbActorProfile>, DatabaseError> {
let rows = db_client.query( let rows = db_client
" .query(
"
SELECT actor_profile SELECT actor_profile
FROM actor_profile FROM actor_profile
WHERE acct = ANY($1) WHERE acct = ANY($1)
", ",
&[&accts], &[&accts],
).await?; )
let profiles = rows.iter() .await?;
let profiles = rows
.iter()
.map(|row| row.try_get("actor_profile")) .map(|row| row.try_get("actor_profile"))
.collect::<Result<_, _>>()?; .collect::<Result<_, _>>()?;
Ok(profiles) Ok(profiles)
@ -347,8 +347,9 @@ pub async fn delete_profile(
let transaction = db_client.transaction().await?; let transaction = db_client.transaction().await?;
// Select all posts authored by given actor, // Select all posts authored by given actor,
// their descendants and reposts. // their descendants and reposts.
let posts_rows = transaction.query( let posts_rows = transaction
" .query(
"
WITH RECURSIVE context (post_id) AS ( WITH RECURSIVE context (post_id) AS (
SELECT post.id FROM post SELECT post.id FROM post
WHERE post.author_id = $1 WHERE post.author_id = $1
@ -361,14 +362,17 @@ pub async fn delete_profile(
) )
SELECT post_id FROM context SELECT post_id FROM context
", ",
&[&profile_id], &[&profile_id],
).await?; )
let posts: Vec<Uuid> = posts_rows.iter() .await?;
let posts: Vec<Uuid> = posts_rows
.iter()
.map(|row| row.try_get("post_id")) .map(|row| row.try_get("post_id"))
.collect::<Result<_, _>>()?; .collect::<Result<_, _>>()?;
// Get list of media files // Get list of media files
let files_rows = transaction.query( let files_rows = transaction
" .query(
"
SELECT unnest(array_remove( SELECT unnest(array_remove(
ARRAY[ ARRAY[
avatar ->> 'file_name', avatar ->> 'file_name',
@ -381,14 +385,17 @@ pub async fn delete_profile(
SELECT file_name SELECT file_name
FROM media_attachment WHERE post_id = ANY($2) FROM media_attachment WHERE post_id = ANY($2)
", ",
&[&profile_id, &posts], &[&profile_id, &posts],
).await?; )
let files: Vec<String> = files_rows.iter() .await?;
let files: Vec<String> = files_rows
.iter()
.map(|row| row.try_get("file_name")) .map(|row| row.try_get("file_name"))
.collect::<Result<_, _>>()?; .collect::<Result<_, _>>()?;
// Get list of IPFS objects // Get list of IPFS objects
let ipfs_objects_rows = transaction.query( let ipfs_objects_rows = transaction
" .query(
"
SELECT ipfs_cid SELECT ipfs_cid
FROM media_attachment FROM media_attachment
WHERE post_id = ANY($1) AND ipfs_cid IS NOT NULL WHERE post_id = ANY($1) AND ipfs_cid IS NOT NULL
@ -397,14 +404,17 @@ pub async fn delete_profile(
FROM post FROM post
WHERE id = ANY($1) AND ipfs_cid IS NOT NULL WHERE id = ANY($1) AND ipfs_cid IS NOT NULL
", ",
&[&posts], &[&posts],
).await?; )
let ipfs_objects: Vec<String> = ipfs_objects_rows.iter() .await?;
let ipfs_objects: Vec<String> = ipfs_objects_rows
.iter()
.map(|row| row.try_get("ipfs_cid")) .map(|row| row.try_get("ipfs_cid"))
.collect::<Result<_, _>>()?; .collect::<Result<_, _>>()?;
// Update post counters // Update post counters
transaction.execute( transaction
" .execute(
"
UPDATE actor_profile UPDATE actor_profile
SET post_count = post_count - post.count SET post_count = post_count - post.count
FROM ( FROM (
@ -414,11 +424,13 @@ pub async fn delete_profile(
) AS post ) AS post
WHERE actor_profile.id = post.author_id WHERE actor_profile.id = post.author_id
", ",
&[&posts], &[&posts],
).await?; )
.await?;
// Update counters // Update counters
transaction.execute( transaction
" .execute(
"
UPDATE actor_profile UPDATE actor_profile
SET follower_count = follower_count - 1 SET follower_count = follower_count - 1
FROM relationship FROM relationship
@ -427,10 +439,12 @@ pub async fn delete_profile(
AND relationship.target_id = actor_profile.id AND relationship.target_id = actor_profile.id
AND relationship.relationship_type = $2 AND relationship.relationship_type = $2
", ",
&[&profile_id, &RelationshipType::Follow], &[&profile_id, &RelationshipType::Follow],
).await?; )
transaction.execute( .await?;
" transaction
.execute(
"
UPDATE actor_profile UPDATE actor_profile
SET following_count = following_count - 1 SET following_count = following_count - 1
FROM relationship FROM relationship
@ -439,10 +453,12 @@ pub async fn delete_profile(
AND relationship.target_id = $1 AND relationship.target_id = $1
AND relationship.relationship_type = $2 AND relationship.relationship_type = $2
", ",
&[&profile_id, &RelationshipType::Follow], &[&profile_id, &RelationshipType::Follow],
).await?; )
transaction.execute( .await?;
" transaction
.execute(
"
UPDATE actor_profile UPDATE actor_profile
SET subscriber_count = subscriber_count - 1 SET subscriber_count = subscriber_count - 1
FROM relationship FROM relationship
@ -451,10 +467,12 @@ pub async fn delete_profile(
AND relationship.target_id = actor_profile.id AND relationship.target_id = actor_profile.id
AND relationship.relationship_type = $2 AND relationship.relationship_type = $2
", ",
&[&profile_id, &RelationshipType::Subscription], &[&profile_id, &RelationshipType::Subscription],
).await?; )
transaction.execute( .await?;
" transaction
.execute(
"
UPDATE post UPDATE post
SET reply_count = reply_count - reply.count SET reply_count = reply_count - reply.count
FROM ( FROM (
@ -464,10 +482,12 @@ pub async fn delete_profile(
) AS reply ) AS reply
WHERE post.id = reply.in_reply_to_id WHERE post.id = reply.in_reply_to_id
", ",
&[&profile_id], &[&profile_id],
).await?; )
transaction.execute( .await?;
" transaction
.execute(
"
UPDATE post UPDATE post
SET reaction_count = reaction_count - 1 SET reaction_count = reaction_count - 1
FROM post_reaction FROM post_reaction
@ -475,10 +495,12 @@ pub async fn delete_profile(
post_reaction.post_id = post.id post_reaction.post_id = post.id
AND post_reaction.author_id = $1 AND post_reaction.author_id = $1
", ",
&[&profile_id], &[&profile_id],
).await?; )
transaction.execute( .await?;
" transaction
.execute(
"
UPDATE post UPDATE post
SET repost_count = post.repost_count - 1 SET repost_count = post.repost_count - 1
FROM post AS repost FROM post AS repost
@ -486,16 +508,19 @@ pub async fn delete_profile(
repost.repost_of_id = post.id repost.repost_of_id = post.id
AND repost.author_id = $1 AND repost.author_id = $1
", ",
&[&profile_id], &[&profile_id],
).await?; )
.await?;
// Delete profile // Delete profile
let deleted_count = transaction.execute( let deleted_count = transaction
" .execute(
"
DELETE FROM actor_profile WHERE id = $1 DELETE FROM actor_profile WHERE id = $1
RETURNING actor_profile RETURNING actor_profile
", ",
&[&profile_id], &[&profile_id],
).await?; )
.await?;
if deleted_count == 0 { if deleted_count == 0 {
return Err(DatabaseError::NotFound("profile")); return Err(DatabaseError::NotFound("profile"));
} }
@ -518,22 +543,25 @@ pub async fn search_profiles(
Some(hostname) => { Some(hostname) => {
// Search for exact actor address // Search for exact actor address
format!("{}@{}", username, hostname) format!("{}@{}", username, hostname)
}, }
None => { None => {
// Fuzzy search for username // Fuzzy search for username
format!("%{}%", username) format!("%{}%", username)
}, }
}; };
let rows = db_client.query( let rows = db_client
" .query(
"
SELECT actor_profile SELECT actor_profile
FROM actor_profile FROM actor_profile
WHERE acct ILIKE $1 WHERE acct ILIKE $1
LIMIT $2 LIMIT $2
", ",
&[&db_search_query, &i64::from(limit)], &[&db_search_query, &i64::from(limit)],
).await?; )
let profiles: Vec<DbActorProfile> = rows.iter() .await?;
let profiles: Vec<DbActorProfile> = rows
.iter()
.map(|row| row.try_get("actor_profile")) .map(|row| row.try_get("actor_profile"))
.collect::<Result<_, _>>()?; .collect::<Result<_, _>>()?;
Ok(profiles) Ok(profiles)
@ -543,8 +571,9 @@ pub async fn search_profiles_by_did_only(
db_client: &impl DatabaseClient, db_client: &impl DatabaseClient,
did: &Did, did: &Did,
) -> Result<Vec<DbActorProfile>, DatabaseError> { ) -> Result<Vec<DbActorProfile>, DatabaseError> {
let rows = db_client.query( let rows = db_client
" .query(
"
SELECT actor_profile SELECT actor_profile
FROM actor_profile FROM actor_profile
WHERE WHERE
@ -554,9 +583,11 @@ pub async fn search_profiles_by_did_only(
WHERE proof ->> 'issuer' = $1 WHERE proof ->> 'issuer' = $1
) )
", ",
&[&did.to_string()], &[&did.to_string()],
).await?; )
let profiles: Vec<DbActorProfile> = rows.iter() .await?;
let profiles: Vec<DbActorProfile> = rows
.iter()
.map(|row| row.try_get("actor_profile")) .map(|row| row.try_get("actor_profile"))
.collect::<Result<_, _>>()?; .collect::<Result<_, _>>()?;
Ok(profiles) Ok(profiles)
@ -569,10 +600,9 @@ pub async fn search_profiles_by_did(
) -> Result<Vec<DbActorProfile>, DatabaseError> { ) -> Result<Vec<DbActorProfile>, DatabaseError> {
let verified = search_profiles_by_did_only(db_client, did).await?; let verified = search_profiles_by_did_only(db_client, did).await?;
let maybe_currency_address = match did { let maybe_currency_address = match did {
Did::Pkh(did_pkh) => { Did::Pkh(did_pkh) => did_pkh
did_pkh.currency() .currency()
.map(|currency| (currency, did_pkh.address.clone())) .map(|currency| (currency, did_pkh.address.clone())),
},
_ => None, _ => None,
}; };
let unverified = if let Some((currency, address)) = maybe_currency_address { let unverified = if let Some((currency, address)) = maybe_currency_address {
@ -597,16 +627,13 @@ pub async fn search_profiles_by_did(
AND field ->> 'value' {value_op} $field_value AND field ->> 'value' {value_op} $field_value
) )
", ",
value_op=value_op, value_op = value_op,
); );
let field_name = currency.field_name(); let field_name = currency.field_name();
let query = query!( let query = query!(&statement, field_name = field_name, field_value = address,)?;
&statement,
field_name=field_name,
field_value=address,
)?;
let rows = db_client.query(query.sql(), query.parameters()).await?; let rows = db_client.query(query.sql(), query.parameters()).await?;
let unverified = rows.iter() let unverified = rows
.iter()
.map(|row| row.try_get("actor_profile")) .map(|row| row.try_get("actor_profile"))
.collect::<Result<Vec<DbActorProfile>, _>>()? .collect::<Result<Vec<DbActorProfile>, _>>()?
.into_iter() .into_iter()
@ -641,15 +668,17 @@ pub async fn update_follower_count(
profile_id: &Uuid, profile_id: &Uuid,
change: i32, change: i32,
) -> Result<DbActorProfile, DatabaseError> { ) -> Result<DbActorProfile, DatabaseError> {
let maybe_row = db_client.query_opt( let maybe_row = db_client
" .query_opt(
"
UPDATE actor_profile UPDATE actor_profile
SET follower_count = follower_count + $1 SET follower_count = follower_count + $1
WHERE id = $2 WHERE id = $2
RETURNING actor_profile RETURNING actor_profile
", ",
&[&change, &profile_id], &[&change, &profile_id],
).await?; )
.await?;
let row = maybe_row.ok_or(DatabaseError::NotFound("profile"))?; let row = maybe_row.ok_or(DatabaseError::NotFound("profile"))?;
let profile = row.try_get("actor_profile")?; let profile = row.try_get("actor_profile")?;
Ok(profile) Ok(profile)
@ -660,15 +689,17 @@ pub async fn update_following_count(
profile_id: &Uuid, profile_id: &Uuid,
change: i32, change: i32,
) -> Result<DbActorProfile, DatabaseError> { ) -> Result<DbActorProfile, DatabaseError> {
let maybe_row = db_client.query_opt( let maybe_row = db_client
" .query_opt(
"
UPDATE actor_profile UPDATE actor_profile
SET following_count = following_count + $1 SET following_count = following_count + $1
WHERE id = $2 WHERE id = $2
RETURNING actor_profile RETURNING actor_profile
", ",
&[&change, &profile_id], &[&change, &profile_id],
).await?; )
.await?;
let row = maybe_row.ok_or(DatabaseError::NotFound("profile"))?; let row = maybe_row.ok_or(DatabaseError::NotFound("profile"))?;
let profile = row.try_get("actor_profile")?; let profile = row.try_get("actor_profile")?;
Ok(profile) Ok(profile)
@ -679,15 +710,17 @@ pub async fn update_subscriber_count(
profile_id: &Uuid, profile_id: &Uuid,
change: i32, change: i32,
) -> Result<DbActorProfile, DatabaseError> { ) -> Result<DbActorProfile, DatabaseError> {
let maybe_row = db_client.query_opt( let maybe_row = db_client
" .query_opt(
"
UPDATE actor_profile UPDATE actor_profile
SET subscriber_count = subscriber_count + $1 SET subscriber_count = subscriber_count + $1
WHERE id = $2 WHERE id = $2
RETURNING actor_profile RETURNING actor_profile
", ",
&[&change, &profile_id], &[&change, &profile_id],
).await?; )
.await?;
let row = maybe_row.ok_or(DatabaseError::NotFound("profile"))?; let row = maybe_row.ok_or(DatabaseError::NotFound("profile"))?;
let profile = row.try_get("actor_profile")?; let profile = row.try_get("actor_profile")?;
Ok(profile) Ok(profile)
@ -698,15 +731,17 @@ pub async fn update_post_count(
profile_id: &Uuid, profile_id: &Uuid,
change: i32, change: i32,
) -> Result<DbActorProfile, DatabaseError> { ) -> Result<DbActorProfile, DatabaseError> {
let maybe_row = db_client.query_opt( let maybe_row = db_client
" .query_opt(
"
UPDATE actor_profile UPDATE actor_profile
SET post_count = post_count + $1 SET post_count = post_count + $1
WHERE id = $2 WHERE id = $2
RETURNING actor_profile RETURNING actor_profile
", ",
&[&change, &profile_id], &[&change, &profile_id],
).await?; )
.await?;
let row = maybe_row.ok_or(DatabaseError::NotFound("profile"))?; let row = maybe_row.ok_or(DatabaseError::NotFound("profile"))?;
let profile = row.try_get("actor_profile")?; let profile = row.try_get("actor_profile")?;
Ok(profile) Ok(profile)
@ -720,24 +755,28 @@ pub async fn set_reachability_status(
) -> Result<(), DatabaseError> { ) -> Result<(), DatabaseError> {
if !is_reachable { if !is_reachable {
// Don't update profile if unreachable_since is already set // Don't update profile if unreachable_since is already set
db_client.execute( db_client
" .execute(
"
UPDATE actor_profile UPDATE actor_profile
SET unreachable_since = CURRENT_TIMESTAMP SET unreachable_since = CURRENT_TIMESTAMP
WHERE actor_id = $1 AND unreachable_since IS NULL WHERE actor_id = $1 AND unreachable_since IS NULL
", ",
&[&actor_id], &[&actor_id],
).await?; )
.await?;
} else { } else {
// Remove status (if set) // Remove status (if set)
db_client.execute( db_client
" .execute(
"
UPDATE actor_profile UPDATE actor_profile
SET unreachable_since = NULL SET unreachable_since = NULL
WHERE actor_id = $1 WHERE actor_id = $1
", ",
&[&actor_id], &[&actor_id],
).await?; )
.await?;
}; };
Ok(()) Ok(())
} }
@ -746,16 +785,19 @@ pub async fn find_unreachable(
db_client: &impl DatabaseClient, db_client: &impl DatabaseClient,
unreachable_since: &DateTime<Utc>, unreachable_since: &DateTime<Utc>,
) -> Result<Vec<DbActorProfile>, DatabaseError> { ) -> Result<Vec<DbActorProfile>, DatabaseError> {
let rows = db_client.query( let rows = db_client
" .query(
"
SELECT actor_profile SELECT actor_profile
FROM actor_profile FROM actor_profile
WHERE unreachable_since < $1 WHERE unreachable_since < $1
ORDER BY hostname, username ORDER BY hostname, username
", ",
&[&unreachable_since], &[&unreachable_since],
).await?; )
let profiles = rows.iter() .await?;
let profiles = rows
.iter()
.map(|row| row.try_get("actor_profile")) .map(|row| row.try_get("actor_profile"))
.collect::<Result<_, _>>()?; .collect::<Result<_, _>>()?;
Ok(profiles) Ok(profiles)
@ -768,8 +810,9 @@ pub async fn find_empty_profiles(
db_client: &impl DatabaseClient, db_client: &impl DatabaseClient,
updated_before: &DateTime<Utc>, updated_before: &DateTime<Utc>,
) -> Result<Vec<Uuid>, DatabaseError> { ) -> Result<Vec<Uuid>, DatabaseError> {
let rows = db_client.query( let rows = db_client
" .query(
"
SELECT actor_profile.id SELECT actor_profile.id
FROM actor_profile FROM actor_profile
WHERE WHERE
@ -816,9 +859,11 @@ pub async fn find_empty_profiles(
WHERE sender_id = actor_profile.id WHERE sender_id = actor_profile.id
) )
", ",
&[&updated_before], &[&updated_before],
).await?; )
let ids: Vec<Uuid> = rows.iter() .await?;
let ids: Vec<Uuid> = rows
.iter()
.map(|row| row.try_get("id")) .map(|row| row.try_get("id"))
.collect::<Result<_, _>>()?; .collect::<Result<_, _>>()?;
Ok(ids) Ok(ids)
@ -826,30 +871,21 @@ pub async fn find_empty_profiles(
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use serial_test::serial; use super::*;
use crate::database::test_utils::create_test_database; use crate::database::test_utils::create_test_database;
use crate::emojis::{ use crate::emojis::{queries::create_emoji, types::EmojiImage};
queries::create_emoji,
types::EmojiImage,
};
use crate::profiles::{ use crate::profiles::{
queries::create_profile, queries::create_profile,
types::{ types::{DbActor, ExtraField, IdentityProof, IdentityProofType, ProfileCreateData},
DbActor,
ExtraField,
IdentityProof,
IdentityProofType,
ProfileCreateData,
},
}; };
use crate::users::{ use crate::users::{queries::create_user, types::UserCreateData};
queries::create_user, use serial_test::serial;
types::UserCreateData,
};
use super::*;
fn create_test_actor(actor_id: &str) -> DbActor { fn create_test_actor(actor_id: &str) -> DbActor {
DbActor { id: actor_id.to_string(), ..Default::default() } DbActor {
id: actor_id.to_string(),
..Default::default()
}
} }
#[tokio::test] #[tokio::test]
@ -883,10 +919,7 @@ mod tests {
assert_eq!(profile.username, "test"); assert_eq!(profile.username, "test");
assert_eq!(profile.hostname.unwrap(), "example.com"); assert_eq!(profile.hostname.unwrap(), "example.com");
assert_eq!(profile.acct, "test@example.com"); assert_eq!(profile.acct, "test@example.com");
assert_eq!( assert_eq!(profile.actor_id.unwrap(), "https://example.com/users/test",);
profile.actor_id.unwrap(),
"https://example.com/users/test",
);
} }
#[tokio::test] #[tokio::test]
@ -894,14 +927,9 @@ mod tests {
async fn test_create_profile_with_emoji() { async fn test_create_profile_with_emoji() {
let db_client = &mut create_test_database().await; let db_client = &mut create_test_database().await;
let image = EmojiImage::default(); let image = EmojiImage::default();
let emoji = create_emoji( let emoji = create_emoji(db_client, "testemoji", None, image, None, &Utc::now())
db_client, .await
"testemoji", .unwrap();
None,
image,
None,
&Utc::now(),
).await.unwrap();
let profile_data = ProfileCreateData { let profile_data = ProfileCreateData {
username: "test".to_string(), username: "test".to_string(),
emojis: vec![emoji.id.clone()], emojis: vec![emoji.id.clone()],
@ -931,7 +959,10 @@ mod tests {
actor_json: Some(create_test_actor(actor_id)), actor_json: Some(create_test_actor(actor_id)),
..Default::default() ..Default::default()
}; };
let error = create_profile(db_client, profile_data_2).await.err().unwrap(); let error = create_profile(db_client, profile_data_2)
.await
.err()
.unwrap();
assert_eq!(error.to_string(), "profile already exists"); assert_eq!(error.to_string(), "profile already exists");
} }
@ -947,11 +978,9 @@ mod tests {
let mut profile_data = ProfileUpdateData::from(&profile); let mut profile_data = ProfileUpdateData::from(&profile);
let bio = "test bio"; let bio = "test bio";
profile_data.bio = Some(bio.to_string()); profile_data.bio = Some(bio.to_string());
let profile_updated = update_profile( let profile_updated = update_profile(db_client, &profile.id, profile_data)
db_client, .await
&profile.id, .unwrap();
profile_data,
).await.unwrap();
assert_eq!(profile_updated.username, profile.username); assert_eq!(profile_updated.username, profile.username);
assert_eq!(profile_updated.bio.unwrap(), bio); assert_eq!(profile_updated.bio.unwrap(), bio);
assert!(profile_updated.updated_at != profile.updated_at); assert!(profile_updated.updated_at != profile.updated_at);
@ -980,8 +1009,10 @@ mod tests {
..Default::default() ..Default::default()
}; };
let _user = create_user(db_client, user_data).await.unwrap(); let _user = create_user(db_client, user_data).await.unwrap();
let profiles = search_profiles_by_wallet_address( let profiles =
db_client, &ETHEREUM, wallet_address, false).await.unwrap(); search_profiles_by_wallet_address(db_client, &ETHEREUM, wallet_address, false)
.await
.unwrap();
// Login address must not be exposed // Login address must not be exposed
assert_eq!(profiles.len(), 0); assert_eq!(profiles.len(), 0);
@ -1001,8 +1032,9 @@ mod tests {
..Default::default() ..Default::default()
}; };
let profile = create_profile(db_client, profile_data).await.unwrap(); let profile = create_profile(db_client, profile_data).await.unwrap();
let profiles = search_profiles_by_wallet_address( let profiles = search_profiles_by_wallet_address(db_client, &ETHEREUM, "0x1234abcd", false)
db_client, &ETHEREUM, "0x1234abcd", false).await.unwrap(); .await
.unwrap();
assert_eq!(profiles.len(), 1); assert_eq!(profiles.len(), 1);
assert_eq!(profiles[0].id, profile.id); assert_eq!(profiles[0].id, profile.id);
@ -1022,8 +1054,9 @@ mod tests {
..Default::default() ..Default::default()
}; };
let profile = create_profile(db_client, profile_data).await.unwrap(); let profile = create_profile(db_client, profile_data).await.unwrap();
let profiles = search_profiles_by_wallet_address( let profiles = search_profiles_by_wallet_address(db_client, &ETHEREUM, "0x1234abcd", false)
db_client, &ETHEREUM, "0x1234abcd", false).await.unwrap(); .await
.unwrap();
assert_eq!(profiles.len(), 1); assert_eq!(profiles.len(), 1);
assert_eq!(profiles[0].id, profile.id); assert_eq!(profiles[0].id, profile.id);
@ -1041,7 +1074,9 @@ mod tests {
..Default::default() ..Default::default()
}; };
let profile = create_profile(db_client, profile_data).await.unwrap(); let profile = create_profile(db_client, profile_data).await.unwrap();
set_reachability_status(db_client, actor_id, false).await.unwrap(); set_reachability_status(db_client, actor_id, false)
.await
.unwrap();
let profile = get_profile_by_id(db_client, &profile.id).await.unwrap(); let profile = get_profile_by_id(db_client, &profile.id).await.unwrap();
assert_eq!(profile.unreachable_since.is_some(), true); assert_eq!(profile.unreachable_since.is_some(), true);
} }
@ -1051,7 +1086,9 @@ mod tests {
async fn test_find_empty_profiles() { async fn test_find_empty_profiles() {
let db_client = &mut create_test_database().await; let db_client = &mut create_test_database().await;
let updated_before = Utc::now(); let updated_before = Utc::now();
let profiles = find_empty_profiles(db_client, &updated_before).await.unwrap(); let profiles = find_empty_profiles(db_client, &updated_before)
.await
.unwrap();
assert_eq!(profiles.is_empty(), true); assert_eq!(profiles.is_empty(), true);
} }
} }

View file

@ -1,17 +1,12 @@
use chrono::{DateTime, Duration, Utc}; use chrono::{DateTime, Duration, Utc};
use postgres_types::FromSql; use postgres_types::FromSql;
use serde::{ use serde::{
Deserialize, Deserializer, Serialize, Serializer, Deserialize, Deserializer, Serialize, Serializer, __private::ser::FlatMapSerializer,
de::Error as DeserializerError, de::Error as DeserializerError, ser::SerializeMap,
ser::SerializeMap,
__private::ser::FlatMapSerializer,
}; };
use uuid::Uuid; use uuid::Uuid;
use mitra_utils::{ use mitra_utils::{caip2::ChainId, did::Did};
caip2::ChainId,
did::Did,
};
use crate::database::{ use crate::database::{
json_macro::{json_from_sql, json_to_sql}, json_macro::{json_from_sql, json_to_sql},
@ -27,11 +22,7 @@ pub struct ProfileImage {
} }
impl ProfileImage { impl ProfileImage {
pub fn new( pub fn new(file_name: String, file_size: usize, media_type: Option<String>) -> Self {
file_name: String,
file_size: usize,
media_type: Option<String>,
) -> Self {
Self { Self {
file_name, file_name,
file_size: Some(file_size), file_size: Some(file_size),
@ -73,16 +64,19 @@ impl TryFrom<i16> for IdentityProofType {
impl<'de> Deserialize<'de> for IdentityProofType { impl<'de> Deserialize<'de> for IdentityProofType {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error> fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where D: Deserializer<'de> where
D: Deserializer<'de>,
{ {
i16::deserialize(deserializer)? i16::deserialize(deserializer)?
.try_into().map_err(DeserializerError::custom) .try_into()
.map_err(DeserializerError::custom)
} }
} }
impl Serialize for IdentityProofType { impl Serialize for IdentityProofType {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error> fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where S: Serializer where
S: Serializer,
{ {
serializer.serialize_i16(self.into()) serializer.serialize_i16(self.into())
} }
@ -194,30 +188,31 @@ impl PaymentOption {
// Workaround: https://stackoverflow.com/a/65576570 // Workaround: https://stackoverflow.com/a/65576570
impl<'de> Deserialize<'de> for PaymentOption { impl<'de> Deserialize<'de> for PaymentOption {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error> fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where D: Deserializer<'de> where
D: Deserializer<'de>,
{ {
let value = serde_json::Value::deserialize(deserializer)?; let value = serde_json::Value::deserialize(deserializer)?;
let payment_type = value.get("payment_type") let payment_type = value
.get("payment_type")
.and_then(serde_json::Value::as_u64) .and_then(serde_json::Value::as_u64)
.and_then(|val| i16::try_from(val).ok()) .and_then(|val| i16::try_from(val).ok())
.and_then(|val| PaymentType::try_from(val).ok()) .and_then(|val| PaymentType::try_from(val).ok())
.ok_or(DeserializerError::custom("invalid payment type"))?; .ok_or(DeserializerError::custom("invalid payment type"))?;
let payment_option = match payment_type { let payment_option = match payment_type {
PaymentType::Link => { PaymentType::Link => {
let link = PaymentLink::deserialize(value) let link = PaymentLink::deserialize(value).map_err(DeserializerError::custom)?;
.map_err(DeserializerError::custom)?;
Self::Link(link) Self::Link(link)
}, }
PaymentType::EthereumSubscription => { PaymentType::EthereumSubscription => {
let payment_info = EthereumSubscription::deserialize(value) let payment_info =
.map_err(DeserializerError::custom)?; EthereumSubscription::deserialize(value).map_err(DeserializerError::custom)?;
Self::EthereumSubscription(payment_info) Self::EthereumSubscription(payment_info)
}, }
PaymentType::MoneroSubscription => { PaymentType::MoneroSubscription => {
let payment_info = MoneroSubscription::deserialize(value) let payment_info =
.map_err(DeserializerError::custom)?; MoneroSubscription::deserialize(value).map_err(DeserializerError::custom)?;
Self::MoneroSubscription(payment_info) Self::MoneroSubscription(payment_info)
}, }
}; };
Ok(payment_option) Ok(payment_option)
} }
@ -225,7 +220,8 @@ impl<'de> Deserialize<'de> for PaymentOption {
impl Serialize for PaymentOption { impl Serialize for PaymentOption {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error> fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where S: Serializer, where
S: Serializer,
{ {
let mut map = serializer.serialize_map(None)?; let mut map = serializer.serialize_map(None)?;
let payment_type = self.payment_type(); let payment_type = self.payment_type();
@ -235,10 +231,10 @@ impl Serialize for PaymentOption {
Self::Link(link) => link.serialize(FlatMapSerializer(&mut map))?, Self::Link(link) => link.serialize(FlatMapSerializer(&mut map))?,
Self::EthereumSubscription(payment_info) => { Self::EthereumSubscription(payment_info) => {
payment_info.serialize(FlatMapSerializer(&mut map))? payment_info.serialize(FlatMapSerializer(&mut map))?
}, }
Self::MoneroSubscription(payment_info) => { Self::MoneroSubscription(payment_info) => {
payment_info.serialize(FlatMapSerializer(&mut map))? payment_info.serialize(FlatMapSerializer(&mut map))?
}, }
}; };
map.end() map.end()
} }
@ -267,7 +263,8 @@ impl PaymentOptions {
/// of the given type. /// of the given type.
pub fn any(&self, payment_type: PaymentType) -> bool { pub fn any(&self, payment_type: PaymentType) -> bool {
let Self(payment_options) = self; let Self(payment_options) = self;
payment_options.iter() payment_options
.iter()
.any(|option| option.payment_type() == payment_type) .any(|option| option.payment_type() == payment_type)
} }
} }
@ -306,7 +303,8 @@ pub struct Aliases(Vec<Alias>);
impl Aliases { impl Aliases {
pub fn new(actor_ids: Vec<String>) -> Self { pub fn new(actor_ids: Vec<String>) -> Self {
// Not signed // Not signed
let aliases = actor_ids.into_iter() let aliases = actor_ids
.into_iter()
.map(|actor_id| Alias { id: actor_id }) .map(|actor_id| Alias { id: actor_id })
.collect(); .collect();
Self(aliases) Self(aliases)
@ -369,7 +367,7 @@ pub struct DbActorProfile {
pub username: String, pub username: String,
pub hostname: Option<String>, pub hostname: Option<String>,
pub display_name: Option<String>, pub display_name: Option<String>,
pub bio: Option<String>, // html pub bio: Option<String>, // html
pub bio_source: Option<String>, // plaintext or markdown pub bio_source: Option<String>, // plaintext or markdown
pub avatar: Option<ProfileImage>, pub avatar: Option<ProfileImage>,
pub banner: Option<ProfileImage>, pub banner: Option<ProfileImage>,
@ -491,16 +489,16 @@ impl ProfileUpdateData {
/// Adds new identity proof /// Adds new identity proof
/// or replaces the existing one if it has the same issuer. /// or replaces the existing one if it has the same issuer.
pub fn add_identity_proof(&mut self, proof: IdentityProof) -> () { pub fn add_identity_proof(&mut self, proof: IdentityProof) -> () {
self.identity_proofs.retain(|item| item.issuer != proof.issuer); self.identity_proofs
.retain(|item| item.issuer != proof.issuer);
self.identity_proofs.push(proof); self.identity_proofs.push(proof);
} }
/// Adds new payment option /// Adds new payment option
/// or replaces the existing one if it has the same type. /// or replaces the existing one if it has the same type.
pub fn add_payment_option(&mut self, option: PaymentOption) -> () { pub fn add_payment_option(&mut self, option: PaymentOption) -> () {
self.payment_options.retain(|item| { self.payment_options
item.payment_type() != option.payment_type() .retain(|item| item.payment_type() != option.payment_type());
});
self.payment_options.push(option); self.payment_options.push(option);
} }
} }
@ -519,7 +517,10 @@ impl From<&DbActorProfile> for ProfileUpdateData {
payment_options: profile.payment_options.into_inner(), payment_options: profile.payment_options.into_inner(),
extra_fields: profile.extra_fields.into_inner(), extra_fields: profile.extra_fields.into_inner(),
aliases: profile.aliases.into_actor_ids(), aliases: profile.aliases.into_actor_ids(),
emojis: profile.emojis.into_inner().into_iter() emojis: profile
.emojis
.into_inner()
.into_iter()
.map(|emoji| emoji.id) .map(|emoji| emoji.id)
.collect(), .collect(),
actor_json: profile.actor_json, actor_json: profile.actor_json,
@ -539,7 +540,10 @@ mod tests {
Did::Pkh(ref did_pkh) => did_pkh, Did::Pkh(ref did_pkh) => did_pkh,
_ => panic!("unexpected did method"), _ => panic!("unexpected did method"),
}; };
assert_eq!(did_pkh.address, "0xb9c5714089478a327f09197987f16f9e5d936e8a"); assert_eq!(
did_pkh.address,
"0xb9c5714089478a327f09197987f16f9e5d936e8a"
);
let serialized = serde_json::to_string(&proof).unwrap(); let serialized = serde_json::to_string(&proof).unwrap();
assert_eq!(serialized, json_data); assert_eq!(serialized, json_data);
} }

View file

@ -1,31 +1,25 @@
use serde::{ use serde::{de::DeserializeOwned, Serialize};
de::DeserializeOwned,
Serialize,
};
use serde_json::Value as JsonValue; use serde_json::Value as JsonValue;
use crate::database::{ use crate::database::{DatabaseClient, DatabaseError, DatabaseTypeError};
DatabaseClient,
DatabaseError,
DatabaseTypeError,
};
pub async fn set_internal_property( pub async fn set_internal_property(
db_client: &impl DatabaseClient, db_client: &impl DatabaseClient,
name: &str, name: &str,
value: &impl Serialize, value: &impl Serialize,
) -> Result<(), DatabaseError> { ) -> Result<(), DatabaseError> {
let value_json = serde_json::to_value(value) let value_json = serde_json::to_value(value).map_err(|_| DatabaseTypeError)?;
.map_err(|_| DatabaseTypeError)?; db_client
db_client.execute( .execute(
" "
INSERT INTO internal_property (property_name, property_value) INSERT INTO internal_property (property_name, property_value)
VALUES ($1, $2) VALUES ($1, $2)
ON CONFLICT (property_name) DO UPDATE ON CONFLICT (property_name) DO UPDATE
SET property_value = $2 SET property_value = $2
", ",
&[&name, &value_json], &[&name, &value_json],
).await?; )
.await?;
Ok(()) Ok(())
} }
@ -33,21 +27,22 @@ pub async fn get_internal_property<T: DeserializeOwned>(
db_client: &impl DatabaseClient, db_client: &impl DatabaseClient,
name: &str, name: &str,
) -> Result<Option<T>, DatabaseError> { ) -> Result<Option<T>, DatabaseError> {
let maybe_row = db_client.query_opt( let maybe_row = db_client
" .query_opt(
"
SELECT property_value SELECT property_value
FROM internal_property FROM internal_property
WHERE property_name = $1 WHERE property_name = $1
", ",
&[&name], &[&name],
).await?; )
.await?;
let maybe_value = match maybe_row { let maybe_value = match maybe_row {
Some(row) => { Some(row) => {
let value_json: JsonValue = row.try_get("property_value")?; let value_json: JsonValue = row.try_get("property_value")?;
let value: T = serde_json::from_value(value_json) let value: T = serde_json::from_value(value_json).map_err(|_| DatabaseTypeError)?;
.map_err(|_| DatabaseTypeError)?;
Some(value) Some(value)
}, }
None => None, None => None,
}; };
Ok(maybe_value) Ok(maybe_value)
@ -55,9 +50,9 @@ pub async fn get_internal_property<T: DeserializeOwned>(
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use serial_test::serial;
use crate::database::test_utils::create_test_database;
use super::*; use super::*;
use crate::database::test_utils::create_test_database;
use serial_test::serial;
#[tokio::test] #[tokio::test]
#[serial] #[serial]
@ -65,9 +60,13 @@ mod tests {
let db_client = &create_test_database().await; let db_client = &create_test_database().await;
let name = "myproperty"; let name = "myproperty";
let value = 100; let value = 100;
set_internal_property(db_client, name, &value).await.unwrap(); set_internal_property(db_client, name, &value)
let db_value: u32 = get_internal_property(db_client, name).await .await
.unwrap().unwrap_or_default(); .unwrap();
let db_value: u32 = get_internal_property(db_client, name)
.await
.unwrap()
.unwrap_or_default();
assert_eq!(db_value, value); assert_eq!(db_value, value);
} }
} }

View file

@ -2,16 +2,9 @@ use uuid::Uuid;
use mitra_utils::id::generate_ulid; use mitra_utils::id::generate_ulid;
use crate::database::{ use crate::database::{catch_unique_violation, DatabaseClient, DatabaseError};
catch_unique_violation,
DatabaseClient,
DatabaseError,
};
use crate::notifications::queries::create_reaction_notification; use crate::notifications::queries::create_reaction_notification;
use crate::posts::queries::{ use crate::posts::queries::{get_post_author, update_reaction_count};
update_reaction_count,
get_post_author,
};
use super::types::DbReaction; use super::types::DbReaction;
@ -24,8 +17,9 @@ pub async fn create_reaction(
let transaction = db_client.transaction().await?; let transaction = db_client.transaction().await?;
let reaction_id = generate_ulid(); let reaction_id = generate_ulid();
// Reactions to reposts are not allowed // Reactions to reposts are not allowed
let maybe_row = transaction.query_opt( let maybe_row = transaction
" .query_opt(
"
INSERT INTO post_reaction (id, author_id, post_id, activity_id) INSERT INTO post_reaction (id, author_id, post_id, activity_id)
SELECT $1, $2, $3, $4 SELECT $1, $2, $3, $4
WHERE NOT EXISTS ( WHERE NOT EXISTS (
@ -34,19 +28,16 @@ pub async fn create_reaction(
) )
RETURNING post_reaction RETURNING post_reaction
", ",
&[&reaction_id, &author_id, &post_id, &activity_id], &[&reaction_id, &author_id, &post_id, &activity_id],
).await.map_err(catch_unique_violation("reaction"))?; )
.await
.map_err(catch_unique_violation("reaction"))?;
let row = maybe_row.ok_or(DatabaseError::NotFound("post"))?; let row = maybe_row.ok_or(DatabaseError::NotFound("post"))?;
let reaction: DbReaction = row.try_get("post_reaction")?; let reaction: DbReaction = row.try_get("post_reaction")?;
update_reaction_count(&transaction, post_id, 1).await?; update_reaction_count(&transaction, post_id, 1).await?;
let post_author = get_post_author(&transaction, post_id).await?; let post_author = get_post_author(&transaction, post_id).await?;
if post_author.is_local() && post_author.id != *author_id { if post_author.is_local() && post_author.id != *author_id {
create_reaction_notification( create_reaction_notification(&transaction, author_id, &post_author.id, post_id).await?;
&transaction,
author_id,
&post_author.id,
post_id,
).await?;
}; };
transaction.commit().await?; transaction.commit().await?;
Ok(reaction) Ok(reaction)
@ -56,14 +47,16 @@ pub async fn get_reaction_by_remote_activity_id(
db_client: &impl DatabaseClient, db_client: &impl DatabaseClient,
activity_id: &str, activity_id: &str,
) -> Result<DbReaction, DatabaseError> { ) -> Result<DbReaction, DatabaseError> {
let maybe_row = db_client.query_opt( let maybe_row = db_client
" .query_opt(
"
SELECT post_reaction SELECT post_reaction
FROM post_reaction FROM post_reaction
WHERE activity_id = $1 WHERE activity_id = $1
", ",
&[&activity_id], &[&activity_id],
).await?; )
.await?;
let row = maybe_row.ok_or(DatabaseError::NotFound("reaction"))?; let row = maybe_row.ok_or(DatabaseError::NotFound("reaction"))?;
let reaction = row.try_get("post_reaction")?; let reaction = row.try_get("post_reaction")?;
Ok(reaction) Ok(reaction)
@ -75,14 +68,16 @@ pub async fn delete_reaction(
post_id: &Uuid, post_id: &Uuid,
) -> Result<Uuid, DatabaseError> { ) -> Result<Uuid, DatabaseError> {
let transaction = db_client.transaction().await?; let transaction = db_client.transaction().await?;
let maybe_row = transaction.query_opt( let maybe_row = transaction
" .query_opt(
"
DELETE FROM post_reaction DELETE FROM post_reaction
WHERE author_id = $1 AND post_id = $2 WHERE author_id = $1 AND post_id = $2
RETURNING post_reaction.id RETURNING post_reaction.id
", ",
&[&author_id, &post_id], &[&author_id, &post_id],
).await?; )
.await?;
let row = maybe_row.ok_or(DatabaseError::NotFound("reaction"))?; let row = maybe_row.ok_or(DatabaseError::NotFound("reaction"))?;
let reaction_id = row.try_get("id")?; let reaction_id = row.try_get("id")?;
update_reaction_count(&transaction, post_id, -1).await?; update_reaction_count(&transaction, post_id, -1).await?;
@ -96,15 +91,18 @@ pub async fn find_favourited_by_user(
user_id: &Uuid, user_id: &Uuid,
posts_ids: &[Uuid], posts_ids: &[Uuid],
) -> Result<Vec<Uuid>, DatabaseError> { ) -> Result<Vec<Uuid>, DatabaseError> {
let rows = db_client.query( let rows = db_client
" .query(
"
SELECT post_id SELECT post_id
FROM post_reaction FROM post_reaction
WHERE author_id = $1 AND post_id = ANY($2) WHERE author_id = $1 AND post_id = ANY($2)
", ",
&[&user_id, &posts_ids], &[&user_id, &posts_ids],
).await?; )
let favourites: Vec<Uuid> = rows.iter() .await?;
let favourites: Vec<Uuid> = rows
.iter()
.map(|row| row.try_get("post_id")) .map(|row| row.try_get("post_id"))
.collect::<Result<_, _>>()?; .collect::<Result<_, _>>()?;
Ok(favourites) Ok(favourites)

View file

@ -2,27 +2,15 @@ use uuid::Uuid;
use mitra_utils::id::generate_ulid; use mitra_utils::id::generate_ulid;
use crate::database::{ use crate::database::{catch_unique_violation, DatabaseClient, DatabaseError};
catch_unique_violation,
DatabaseClient,
DatabaseError,
};
use crate::notifications::queries::create_follow_notification; use crate::notifications::queries::create_follow_notification;
use crate::profiles::{ use crate::profiles::{
queries::{ queries::{update_follower_count, update_following_count, update_subscriber_count},
update_follower_count,
update_following_count,
update_subscriber_count,
},
types::DbActorProfile, types::DbActorProfile,
}; };
use super::types::{ use super::types::{
DbFollowRequest, DbFollowRequest, DbRelationship, FollowRequestStatus, RelatedActorProfile, RelationshipType,
DbRelationship,
FollowRequestStatus,
RelatedActorProfile,
RelationshipType,
}; };
pub async fn get_relationships( pub async fn get_relationships(
@ -30,8 +18,9 @@ pub async fn get_relationships(
source_id: &Uuid, source_id: &Uuid,
target_id: &Uuid, target_id: &Uuid,
) -> Result<Vec<DbRelationship>, DatabaseError> { ) -> Result<Vec<DbRelationship>, DatabaseError> {
let rows = db_client.query( let rows = db_client
" .query(
"
SELECT source_id, target_id, relationship_type SELECT source_id, target_id, relationship_type
FROM relationship FROM relationship
WHERE WHERE
@ -45,14 +34,16 @@ pub async fn get_relationships(
source_id = $1 AND target_id = $2 source_id = $1 AND target_id = $2
AND request_status = $3 AND request_status = $3
", ",
&[ &[
&source_id, &source_id,
&target_id, &target_id,
&FollowRequestStatus::Pending, &FollowRequestStatus::Pending,
&RelationshipType::FollowRequest, &RelationshipType::FollowRequest,
], ],
).await?; )
let relationships = rows.iter() .await?;
let relationships = rows
.iter()
.map(DbRelationship::try_from) .map(DbRelationship::try_from)
.collect::<Result<_, _>>()?; .collect::<Result<_, _>>()?;
Ok(relationships) Ok(relationships)
@ -64,20 +55,18 @@ pub async fn has_relationship(
target_id: &Uuid, target_id: &Uuid,
relationship_type: RelationshipType, relationship_type: RelationshipType,
) -> Result<bool, DatabaseError> { ) -> Result<bool, DatabaseError> {
let maybe_row = db_client.query_opt( let maybe_row = db_client
" .query_opt(
"
SELECT 1 SELECT 1
FROM relationship FROM relationship
WHERE WHERE
source_id = $1 AND target_id = $2 source_id = $1 AND target_id = $2
AND relationship_type = $3 AND relationship_type = $3
", ",
&[ &[&source_id, &target_id, &relationship_type],
&source_id, )
&target_id, .await?;
&relationship_type,
],
).await?;
Ok(maybe_row.is_some()) Ok(maybe_row.is_some())
} }
@ -87,13 +76,16 @@ pub async fn follow(
target_id: &Uuid, target_id: &Uuid,
) -> Result<(), DatabaseError> { ) -> Result<(), DatabaseError> {
let transaction = db_client.transaction().await?; let transaction = db_client.transaction().await?;
transaction.execute( transaction
" .execute(
"
INSERT INTO relationship (source_id, target_id, relationship_type) INSERT INTO relationship (source_id, target_id, relationship_type)
VALUES ($1, $2, $3) VALUES ($1, $2, $3)
", ",
&[&source_id, &target_id, &RelationshipType::Follow], &[&source_id, &target_id, &RelationshipType::Follow],
).await.map_err(catch_unique_violation("relationship"))?; )
.await
.map_err(catch_unique_violation("relationship"))?;
let target_profile = update_follower_count(&transaction, target_id, 1).await?; let target_profile = update_follower_count(&transaction, target_id, 1).await?;
update_following_count(&transaction, source_id, 1).await?; update_following_count(&transaction, source_id, 1).await?;
if target_profile.is_local() { if target_profile.is_local() {
@ -109,22 +101,21 @@ pub async fn unfollow(
target_id: &Uuid, target_id: &Uuid,
) -> Result<Option<Uuid>, DatabaseError> { ) -> Result<Option<Uuid>, DatabaseError> {
let transaction = db_client.transaction().await?; let transaction = db_client.transaction().await?;
let deleted_count = transaction.execute( let deleted_count = transaction
" .execute(
"
DELETE FROM relationship DELETE FROM relationship
WHERE WHERE
source_id = $1 AND target_id = $2 source_id = $1 AND target_id = $2
AND relationship_type = $3 AND relationship_type = $3
", ",
&[&source_id, &target_id, &RelationshipType::Follow], &[&source_id, &target_id, &RelationshipType::Follow],
).await?; )
.await?;
let relationship_deleted = deleted_count > 0; let relationship_deleted = deleted_count > 0;
// Delete follow request (for remote follows) // Delete follow request (for remote follows)
let follow_request_deleted = delete_follow_request_opt( let follow_request_deleted =
&transaction, delete_follow_request_opt(&transaction, source_id, target_id).await?;
source_id,
target_id,
).await?;
if !relationship_deleted && follow_request_deleted.is_none() { if !relationship_deleted && follow_request_deleted.is_none() {
return Err(DatabaseError::NotFound("relationship")); return Err(DatabaseError::NotFound("relationship"));
}; };
@ -147,21 +138,24 @@ pub async fn create_follow_request(
target_id: &Uuid, target_id: &Uuid,
) -> Result<DbFollowRequest, DatabaseError> { ) -> Result<DbFollowRequest, DatabaseError> {
let request_id = generate_ulid(); let request_id = generate_ulid();
let row = db_client.query_one( let row = db_client
" .query_one(
"
INSERT INTO follow_request ( INSERT INTO follow_request (
id, source_id, target_id, request_status id, source_id, target_id, request_status
) )
VALUES ($1, $2, $3, $4) VALUES ($1, $2, $3, $4)
RETURNING follow_request RETURNING follow_request
", ",
&[ &[
&request_id, &request_id,
&source_id, &source_id,
&target_id, &target_id,
&FollowRequestStatus::Pending, &FollowRequestStatus::Pending,
], ],
).await.map_err(catch_unique_violation("follow request"))?; )
.await
.map_err(catch_unique_violation("follow request"))?;
let request = row.try_get("follow_request")?; let request = row.try_get("follow_request")?;
Ok(request) Ok(request)
} }
@ -174,8 +168,9 @@ pub async fn create_remote_follow_request_opt(
activity_id: &str, activity_id: &str,
) -> Result<DbFollowRequest, DatabaseError> { ) -> Result<DbFollowRequest, DatabaseError> {
let request_id = generate_ulid(); let request_id = generate_ulid();
let row = db_client.query_one( let row = db_client
" .query_one(
"
INSERT INTO follow_request ( INSERT INTO follow_request (
id, id,
source_id, source_id,
@ -188,14 +183,15 @@ pub async fn create_remote_follow_request_opt(
DO UPDATE SET activity_id = $4 DO UPDATE SET activity_id = $4
RETURNING follow_request RETURNING follow_request
", ",
&[ &[
&request_id, &request_id,
&source_id, &source_id,
&target_id, &target_id,
&activity_id, &activity_id,
&FollowRequestStatus::Pending, &FollowRequestStatus::Pending,
], ],
).await?; )
.await?;
let request = row.try_get("follow_request")?; let request = row.try_get("follow_request")?;
Ok(request) Ok(request)
} }
@ -205,15 +201,17 @@ pub async fn follow_request_accepted(
request_id: &Uuid, request_id: &Uuid,
) -> Result<(), DatabaseError> { ) -> Result<(), DatabaseError> {
let mut transaction = db_client.transaction().await?; let mut transaction = db_client.transaction().await?;
let maybe_row = transaction.query_opt( let maybe_row = transaction
" .query_opt(
"
UPDATE follow_request UPDATE follow_request
SET request_status = $1 SET request_status = $1
WHERE id = $2 WHERE id = $2
RETURNING source_id, target_id RETURNING source_id, target_id
", ",
&[&FollowRequestStatus::Accepted, &request_id], &[&FollowRequestStatus::Accepted, &request_id],
).await?; )
.await?;
let row = maybe_row.ok_or(DatabaseError::NotFound("follow request"))?; let row = maybe_row.ok_or(DatabaseError::NotFound("follow request"))?;
let source_id: Uuid = row.try_get("source_id")?; let source_id: Uuid = row.try_get("source_id")?;
let target_id: Uuid = row.try_get("target_id")?; let target_id: Uuid = row.try_get("target_id")?;
@ -226,14 +224,16 @@ pub async fn follow_request_rejected(
db_client: &impl DatabaseClient, db_client: &impl DatabaseClient,
request_id: &Uuid, request_id: &Uuid,
) -> Result<(), DatabaseError> { ) -> Result<(), DatabaseError> {
let updated_count = db_client.execute( let updated_count = db_client
" .execute(
"
UPDATE follow_request UPDATE follow_request
SET request_status = $1 SET request_status = $1
WHERE id = $2 WHERE id = $2
", ",
&[&FollowRequestStatus::Rejected, &request_id], &[&FollowRequestStatus::Rejected, &request_id],
).await?; )
.await?;
if updated_count == 0 { if updated_count == 0 {
return Err(DatabaseError::NotFound("follow request")); return Err(DatabaseError::NotFound("follow request"));
} }
@ -245,33 +245,39 @@ async fn delete_follow_request_opt(
source_id: &Uuid, source_id: &Uuid,
target_id: &Uuid, target_id: &Uuid,
) -> Result<Option<Uuid>, DatabaseError> { ) -> Result<Option<Uuid>, DatabaseError> {
let maybe_row = db_client.query_opt( let maybe_row = db_client
" .query_opt(
"
DELETE FROM follow_request DELETE FROM follow_request
WHERE source_id = $1 AND target_id = $2 WHERE source_id = $1 AND target_id = $2
RETURNING id RETURNING id
", ",
&[&source_id, &target_id], &[&source_id, &target_id],
).await?; )
.await?;
let maybe_request_id = if let Some(row) = maybe_row { let maybe_request_id = if let Some(row) = maybe_row {
let request_id: Uuid = row.try_get("id")?; let request_id: Uuid = row.try_get("id")?;
Some(request_id) Some(request_id)
} else { None }; } else {
None
};
Ok(maybe_request_id) Ok(maybe_request_id)
} }
pub async fn get_follow_request_by_id( pub async fn get_follow_request_by_id(
db_client: &impl DatabaseClient, db_client: &impl DatabaseClient,
request_id: &Uuid, request_id: &Uuid,
) -> Result<DbFollowRequest, DatabaseError> { ) -> Result<DbFollowRequest, DatabaseError> {
let maybe_row = db_client.query_opt( let maybe_row = db_client
" .query_opt(
"
SELECT follow_request SELECT follow_request
FROM follow_request FROM follow_request
WHERE id = $1 WHERE id = $1
", ",
&[&request_id], &[&request_id],
).await?; )
.await?;
let row = maybe_row.ok_or(DatabaseError::NotFound("follow request"))?; let row = maybe_row.ok_or(DatabaseError::NotFound("follow request"))?;
let request = row.try_get("follow_request")?; let request = row.try_get("follow_request")?;
Ok(request) Ok(request)
@ -281,14 +287,16 @@ pub async fn get_follow_request_by_activity_id(
db_client: &impl DatabaseClient, db_client: &impl DatabaseClient,
activity_id: &str, activity_id: &str,
) -> Result<DbFollowRequest, DatabaseError> { ) -> Result<DbFollowRequest, DatabaseError> {
let maybe_row = db_client.query_opt( let maybe_row = db_client
" .query_opt(
"
SELECT follow_request SELECT follow_request
FROM follow_request FROM follow_request
WHERE activity_id = $1 WHERE activity_id = $1
", ",
&[&activity_id], &[&activity_id],
).await?; )
.await?;
let row = maybe_row.ok_or(DatabaseError::NotFound("follow request"))?; let row = maybe_row.ok_or(DatabaseError::NotFound("follow request"))?;
let request = row.try_get("follow_request")?; let request = row.try_get("follow_request")?;
Ok(request) Ok(request)
@ -298,8 +306,9 @@ pub async fn get_followers(
db_client: &impl DatabaseClient, db_client: &impl DatabaseClient,
profile_id: &Uuid, profile_id: &Uuid,
) -> Result<Vec<DbActorProfile>, DatabaseError> { ) -> Result<Vec<DbActorProfile>, DatabaseError> {
let rows = db_client.query( let rows = db_client
" .query(
"
SELECT actor_profile SELECT actor_profile
FROM actor_profile FROM actor_profile
JOIN relationship JOIN relationship
@ -308,9 +317,11 @@ pub async fn get_followers(
relationship.target_id = $1 relationship.target_id = $1
AND relationship.relationship_type = $2 AND relationship.relationship_type = $2
", ",
&[&profile_id, &RelationshipType::Follow], &[&profile_id, &RelationshipType::Follow],
).await?; )
let profiles = rows.iter() .await?;
let profiles = rows
.iter()
.map(|row| row.try_get("actor_profile")) .map(|row| row.try_get("actor_profile"))
.collect::<Result<_, _>>()?; .collect::<Result<_, _>>()?;
Ok(profiles) Ok(profiles)
@ -322,8 +333,9 @@ pub async fn get_followers_paginated(
max_relationship_id: Option<i32>, max_relationship_id: Option<i32>,
limit: u16, limit: u16,
) -> Result<Vec<RelatedActorProfile>, DatabaseError> { ) -> Result<Vec<RelatedActorProfile>, DatabaseError> {
let rows = db_client.query( let rows = db_client
" .query(
"
SELECT relationship.id, actor_profile SELECT relationship.id, actor_profile
FROM actor_profile FROM actor_profile
JOIN relationship JOIN relationship
@ -335,14 +347,16 @@ pub async fn get_followers_paginated(
ORDER BY relationship.id DESC ORDER BY relationship.id DESC
LIMIT $4 LIMIT $4
", ",
&[ &[
&profile_id, &profile_id,
&RelationshipType::Follow, &RelationshipType::Follow,
&max_relationship_id, &max_relationship_id,
&i64::from(limit), &i64::from(limit),
], ],
).await?; )
let related_profiles = rows.iter() .await?;
let related_profiles = rows
.iter()
.map(RelatedActorProfile::try_from) .map(RelatedActorProfile::try_from)
.collect::<Result<_, _>>()?; .collect::<Result<_, _>>()?;
Ok(related_profiles) Ok(related_profiles)
@ -352,8 +366,9 @@ pub async fn has_local_followers(
db_client: &impl DatabaseClient, db_client: &impl DatabaseClient,
actor_id: &str, actor_id: &str,
) -> Result<bool, DatabaseError> { ) -> Result<bool, DatabaseError> {
let maybe_row = db_client.query_opt( let maybe_row = db_client
" .query_opt(
"
SELECT 1 SELECT 1
FROM relationship FROM relationship
JOIN actor_profile ON (relationship.target_id = actor_profile.id) JOIN actor_profile ON (relationship.target_id = actor_profile.id)
@ -362,8 +377,9 @@ pub async fn has_local_followers(
AND relationship_type = $2 AND relationship_type = $2
LIMIT 1 LIMIT 1
", ",
&[&actor_id, &RelationshipType::Follow] &[&actor_id, &RelationshipType::Follow],
).await?; )
.await?;
Ok(maybe_row.is_some()) Ok(maybe_row.is_some())
} }
@ -371,8 +387,9 @@ pub async fn get_following(
db_client: &impl DatabaseClient, db_client: &impl DatabaseClient,
profile_id: &Uuid, profile_id: &Uuid,
) -> Result<Vec<DbActorProfile>, DatabaseError> { ) -> Result<Vec<DbActorProfile>, DatabaseError> {
let rows = db_client.query( let rows = db_client
" .query(
"
SELECT actor_profile SELECT actor_profile
FROM actor_profile FROM actor_profile
JOIN relationship JOIN relationship
@ -381,9 +398,11 @@ pub async fn get_following(
relationship.source_id = $1 relationship.source_id = $1
AND relationship.relationship_type = $2 AND relationship.relationship_type = $2
", ",
&[&profile_id, &RelationshipType::Follow], &[&profile_id, &RelationshipType::Follow],
).await?; )
let profiles = rows.iter() .await?;
let profiles = rows
.iter()
.map(|row| row.try_get("actor_profile")) .map(|row| row.try_get("actor_profile"))
.collect::<Result<_, _>>()?; .collect::<Result<_, _>>()?;
Ok(profiles) Ok(profiles)
@ -395,8 +414,9 @@ pub async fn get_following_paginated(
max_relationship_id: Option<i32>, max_relationship_id: Option<i32>,
limit: u16, limit: u16,
) -> Result<Vec<RelatedActorProfile>, DatabaseError> { ) -> Result<Vec<RelatedActorProfile>, DatabaseError> {
let rows = db_client.query( let rows = db_client
" .query(
"
SELECT relationship.id, actor_profile SELECT relationship.id, actor_profile
FROM actor_profile FROM actor_profile
JOIN relationship JOIN relationship
@ -408,14 +428,16 @@ pub async fn get_following_paginated(
ORDER BY relationship.id DESC ORDER BY relationship.id DESC
LIMIT $4 LIMIT $4
", ",
&[ &[
&profile_id, &profile_id,
&RelationshipType::Follow, &RelationshipType::Follow,
&max_relationship_id, &max_relationship_id,
&i64::from(limit), &i64::from(limit),
], ],
).await?; )
let related_profiles = rows.iter() .await?;
let related_profiles = rows
.iter()
.map(RelatedActorProfile::try_from) .map(RelatedActorProfile::try_from)
.collect::<Result<_, _>>()?; .collect::<Result<_, _>>()?;
Ok(related_profiles) Ok(related_profiles)
@ -427,13 +449,16 @@ pub async fn subscribe(
target_id: &Uuid, target_id: &Uuid,
) -> Result<(), DatabaseError> { ) -> Result<(), DatabaseError> {
let transaction = db_client.transaction().await?; let transaction = db_client.transaction().await?;
transaction.execute( transaction
" .execute(
"
INSERT INTO relationship (source_id, target_id, relationship_type) INSERT INTO relationship (source_id, target_id, relationship_type)
VALUES ($1, $2, $3) VALUES ($1, $2, $3)
", ",
&[&source_id, &target_id, &RelationshipType::Subscription], &[&source_id, &target_id, &RelationshipType::Subscription],
).await.map_err(catch_unique_violation("relationship"))?; )
.await
.map_err(catch_unique_violation("relationship"))?;
update_subscriber_count(&transaction, target_id, 1).await?; update_subscriber_count(&transaction, target_id, 1).await?;
transaction.commit().await?; transaction.commit().await?;
Ok(()) Ok(())
@ -445,14 +470,16 @@ pub async fn subscribe_opt(
target_id: &Uuid, target_id: &Uuid,
) -> Result<(), DatabaseError> { ) -> Result<(), DatabaseError> {
let transaction = db_client.transaction().await?; let transaction = db_client.transaction().await?;
let inserted_count = transaction.execute( let inserted_count = transaction
" .execute(
"
INSERT INTO relationship (source_id, target_id, relationship_type) INSERT INTO relationship (source_id, target_id, relationship_type)
VALUES ($1, $2, $3) VALUES ($1, $2, $3)
ON CONFLICT (source_id, target_id, relationship_type) DO NOTHING ON CONFLICT (source_id, target_id, relationship_type) DO NOTHING
", ",
&[&source_id, &target_id, &RelationshipType::Subscription], &[&source_id, &target_id, &RelationshipType::Subscription],
).await?; )
.await?;
if inserted_count > 0 { if inserted_count > 0 {
update_subscriber_count(&transaction, target_id, 1).await?; update_subscriber_count(&transaction, target_id, 1).await?;
}; };
@ -466,15 +493,17 @@ pub async fn unsubscribe(
target_id: &Uuid, target_id: &Uuid,
) -> Result<(), DatabaseError> { ) -> Result<(), DatabaseError> {
let transaction = db_client.transaction().await?; let transaction = db_client.transaction().await?;
let deleted_count = transaction.execute( let deleted_count = transaction
" .execute(
"
DELETE FROM relationship DELETE FROM relationship
WHERE WHERE
source_id = $1 AND target_id = $2 source_id = $1 AND target_id = $2
AND relationship_type = $3 AND relationship_type = $3
", ",
&[&source_id, &target_id, &RelationshipType::Subscription], &[&source_id, &target_id, &RelationshipType::Subscription],
).await?; )
.await?;
if deleted_count == 0 { if deleted_count == 0 {
return Err(DatabaseError::NotFound("relationship")); return Err(DatabaseError::NotFound("relationship"));
}; };
@ -487,8 +516,9 @@ pub async fn get_subscribers(
db_client: &impl DatabaseClient, db_client: &impl DatabaseClient,
profile_id: &Uuid, profile_id: &Uuid,
) -> Result<Vec<DbActorProfile>, DatabaseError> { ) -> Result<Vec<DbActorProfile>, DatabaseError> {
let rows = db_client.query( let rows = db_client
" .query(
"
SELECT actor_profile SELECT actor_profile
FROM actor_profile FROM actor_profile
JOIN relationship JOIN relationship
@ -498,9 +528,11 @@ pub async fn get_subscribers(
AND relationship.relationship_type = $2 AND relationship.relationship_type = $2
ORDER BY relationship.id DESC ORDER BY relationship.id DESC
", ",
&[&profile_id, &RelationshipType::Subscription], &[&profile_id, &RelationshipType::Subscription],
).await?; )
let profiles = rows.iter() .await?;
let profiles = rows
.iter()
.map(|row| row.try_get("actor_profile")) .map(|row| row.try_get("actor_profile"))
.collect::<Result<_, _>>()?; .collect::<Result<_, _>>()?;
Ok(profiles) Ok(profiles)
@ -511,14 +543,16 @@ pub async fn hide_reposts(
source_id: &Uuid, source_id: &Uuid,
target_id: &Uuid, target_id: &Uuid,
) -> Result<(), DatabaseError> { ) -> Result<(), DatabaseError> {
db_client.execute( db_client
" .execute(
"
INSERT INTO relationship (source_id, target_id, relationship_type) INSERT INTO relationship (source_id, target_id, relationship_type)
VALUES ($1, $2, $3) VALUES ($1, $2, $3)
ON CONFLICT (source_id, target_id, relationship_type) DO NOTHING ON CONFLICT (source_id, target_id, relationship_type) DO NOTHING
", ",
&[&source_id, &target_id, &RelationshipType::HideReposts], &[&source_id, &target_id, &RelationshipType::HideReposts],
).await?; )
.await?;
Ok(()) Ok(())
} }
@ -528,15 +562,17 @@ pub async fn show_reposts(
target_id: &Uuid, target_id: &Uuid,
) -> Result<(), DatabaseError> { ) -> Result<(), DatabaseError> {
// Does not return NotFound error // Does not return NotFound error
db_client.execute( db_client
" .execute(
"
DELETE FROM relationship DELETE FROM relationship
WHERE WHERE
source_id = $1 AND target_id = $2 source_id = $1 AND target_id = $2
AND relationship_type = $3 AND relationship_type = $3
", ",
&[&source_id, &target_id, &RelationshipType::HideReposts], &[&source_id, &target_id, &RelationshipType::HideReposts],
).await?; )
.await?;
Ok(()) Ok(())
} }
@ -545,14 +581,16 @@ pub async fn hide_replies(
source_id: &Uuid, source_id: &Uuid,
target_id: &Uuid, target_id: &Uuid,
) -> Result<(), DatabaseError> { ) -> Result<(), DatabaseError> {
db_client.execute( db_client
" .execute(
"
INSERT INTO relationship (source_id, target_id, relationship_type) INSERT INTO relationship (source_id, target_id, relationship_type)
VALUES ($1, $2, $3) VALUES ($1, $2, $3)
ON CONFLICT (source_id, target_id, relationship_type) DO NOTHING ON CONFLICT (source_id, target_id, relationship_type) DO NOTHING
", ",
&[&source_id, &target_id, &RelationshipType::HideReplies], &[&source_id, &target_id, &RelationshipType::HideReplies],
).await?; )
.await?;
Ok(()) Ok(())
} }
@ -562,34 +600,30 @@ pub async fn show_replies(
target_id: &Uuid, target_id: &Uuid,
) -> Result<(), DatabaseError> { ) -> Result<(), DatabaseError> {
// Does not return NotFound error // Does not return NotFound error
db_client.execute( db_client
" .execute(
"
DELETE FROM relationship DELETE FROM relationship
WHERE WHERE
source_id = $1 AND target_id = $2 source_id = $1 AND target_id = $2
AND relationship_type = $3 AND relationship_type = $3
", ",
&[&source_id, &target_id, &RelationshipType::HideReplies], &[&source_id, &target_id, &RelationshipType::HideReplies],
).await?; )
.await?;
Ok(()) Ok(())
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use serial_test::serial; use super::*;
use crate::database::{ use crate::database::{test_utils::create_test_database, DatabaseError};
test_utils::create_test_database,
DatabaseError,
};
use crate::profiles::{ use crate::profiles::{
queries::create_profile, queries::create_profile,
types::{DbActor, ProfileCreateData}, types::{DbActor, ProfileCreateData},
}; };
use crate::users::{ use crate::users::{queries::create_user, types::UserCreateData};
queries::create_user, use serial_test::serial;
types::UserCreateData,
};
use super::*;
#[tokio::test] #[tokio::test]
#[serial] #[serial]
@ -614,7 +648,8 @@ mod tests {
let target = create_profile(db_client, target_data).await.unwrap(); let target = create_profile(db_client, target_data).await.unwrap();
// Create follow request // Create follow request
let follow_request = create_follow_request(db_client, &source.id, &target.id) let follow_request = create_follow_request(db_client, &source.id, &target.id)
.await.unwrap(); .await
.unwrap();
assert_eq!(follow_request.source_id, source.id); assert_eq!(follow_request.source_id, source.id);
assert_eq!(follow_request.target_id, target.id); assert_eq!(follow_request.target_id, target.id);
assert_eq!(follow_request.activity_id, None); assert_eq!(follow_request.activity_id, None);
@ -622,22 +657,27 @@ mod tests {
let following = get_following(db_client, &source.id).await.unwrap(); let following = get_following(db_client, &source.id).await.unwrap();
assert!(following.is_empty()); assert!(following.is_empty());
// Accept follow request // Accept follow request
follow_request_accepted(db_client, &follow_request.id).await.unwrap(); follow_request_accepted(db_client, &follow_request.id)
.await
.unwrap();
let follow_request = get_follow_request_by_id(db_client, &follow_request.id) let follow_request = get_follow_request_by_id(db_client, &follow_request.id)
.await.unwrap(); .await
.unwrap();
assert_eq!(follow_request.request_status, FollowRequestStatus::Accepted); assert_eq!(follow_request.request_status, FollowRequestStatus::Accepted);
let following = get_following(db_client, &source.id).await.unwrap(); let following = get_following(db_client, &source.id).await.unwrap();
assert_eq!(following[0].id, target.id); assert_eq!(following[0].id, target.id);
let target_has_followers = let target_has_followers = has_local_followers(db_client, target_actor_id)
has_local_followers(db_client, target_actor_id).await.unwrap(); .await
.unwrap();
assert_eq!(target_has_followers, true); assert_eq!(target_has_followers, true);
// Unfollow // Unfollow
let follow_request_id = unfollow(db_client, &source.id, &target.id) let follow_request_id = unfollow(db_client, &source.id, &target.id)
.await.unwrap().unwrap(); .await
.unwrap()
.unwrap();
assert_eq!(follow_request_id, follow_request.id); assert_eq!(follow_request_id, follow_request.id);
let follow_request_result = let follow_request_result = get_follow_request_by_id(db_client, &follow_request_id).await;
get_follow_request_by_id(db_client, &follow_request_id).await;
assert!(matches!( assert!(matches!(
follow_request_result, follow_request_result,
Err(DatabaseError::NotFound("follow request")), Err(DatabaseError::NotFound("follow request")),
@ -665,21 +705,26 @@ mod tests {
let target = create_user(db_client, target_data).await.unwrap(); let target = create_user(db_client, target_data).await.unwrap();
// Create follow request // Create follow request
let activity_id = "https://example.org/objects/123"; let activity_id = "https://example.org/objects/123";
let _follow_request = create_remote_follow_request_opt( let _follow_request =
db_client, &source.id, &target.id, activity_id, create_remote_follow_request_opt(db_client, &source.id, &target.id, activity_id)
).await.unwrap(); .await
.unwrap();
// Repeat // Repeat
let follow_request = create_remote_follow_request_opt( let follow_request =
db_client, &source.id, &target.id, activity_id, create_remote_follow_request_opt(db_client, &source.id, &target.id, activity_id)
).await.unwrap(); .await
.unwrap();
assert_eq!(follow_request.source_id, source.id); assert_eq!(follow_request.source_id, source.id);
assert_eq!(follow_request.target_id, target.id); assert_eq!(follow_request.target_id, target.id);
assert_eq!(follow_request.activity_id, Some(activity_id.to_string())); assert_eq!(follow_request.activity_id, Some(activity_id.to_string()));
assert_eq!(follow_request.request_status, FollowRequestStatus::Pending); assert_eq!(follow_request.request_status, FollowRequestStatus::Pending);
// Accept follow request // Accept follow request
follow_request_accepted(db_client, &follow_request.id).await.unwrap(); follow_request_accepted(db_client, &follow_request.id)
.await
.unwrap();
let follow_request = get_follow_request_by_id(db_client, &follow_request.id) let follow_request = get_follow_request_by_id(db_client, &follow_request.id)
.await.unwrap(); .await
.unwrap();
assert_eq!(follow_request.request_status, FollowRequestStatus::Accepted); assert_eq!(follow_request.request_status, FollowRequestStatus::Accepted);
} }
} }

View file

@ -5,8 +5,7 @@ use uuid::Uuid;
use crate::{ use crate::{
database::{ database::{
int_enum::{int_enum_from_sql, int_enum_to_sql}, int_enum::{int_enum_from_sql, int_enum_to_sql},
DatabaseError, DatabaseError, DatabaseTypeError,
DatabaseTypeError,
}, },
profiles::types::DbActorProfile, profiles::types::DbActorProfile,
}; };
@ -58,11 +57,7 @@ pub struct DbRelationship {
} }
impl DbRelationship { impl DbRelationship {
pub fn is_direct( pub fn is_direct(&self, source_id: &Uuid, target_id: &Uuid) -> Result<bool, DatabaseTypeError> {
&self,
source_id: &Uuid,
target_id: &Uuid,
) -> Result<bool, DatabaseTypeError> {
if &self.source_id == source_id && &self.target_id == target_id { if &self.source_id == source_id && &self.target_id == target_id {
Ok(true) Ok(true)
} else if &self.source_id == target_id && &self.target_id == source_id { } else if &self.source_id == target_id && &self.target_id == source_id {
@ -74,7 +69,6 @@ impl DbRelationship {
} }
impl TryFrom<&Row> for DbRelationship { impl TryFrom<&Row> for DbRelationship {
type Error = tokio_postgres::Error; type Error = tokio_postgres::Error;
fn try_from(row: &Row) -> Result<Self, Self::Error> { fn try_from(row: &Row) -> Result<Self, Self::Error> {
@ -97,7 +91,7 @@ pub enum FollowRequestStatus {
impl From<&FollowRequestStatus> for i16 { impl From<&FollowRequestStatus> for i16 {
fn from(value: &FollowRequestStatus) -> i16 { fn from(value: &FollowRequestStatus) -> i16 {
match value { match value {
FollowRequestStatus::Pending => 1, FollowRequestStatus::Pending => 1,
FollowRequestStatus::Accepted => 2, FollowRequestStatus::Accepted => 2,
FollowRequestStatus::Rejected => 3, FollowRequestStatus::Rejected => 3,
} }
@ -137,12 +131,14 @@ pub struct RelatedActorProfile {
} }
impl TryFrom<&Row> for RelatedActorProfile { impl TryFrom<&Row> for RelatedActorProfile {
type Error = DatabaseError; type Error = DatabaseError;
fn try_from(row: &Row) -> Result<Self, Self::Error> { fn try_from(row: &Row) -> Result<Self, Self::Error> {
let relationship_id = row.try_get("id")?; let relationship_id = row.try_get("id")?;
let profile = row.try_get("actor_profile")?; let profile = row.try_get("actor_profile")?;
Ok(Self { relationship_id, profile }) Ok(Self {
relationship_id,
profile,
})
} }
} }

View file

@ -3,11 +3,7 @@ use uuid::Uuid;
use mitra_utils::caip2::ChainId; use mitra_utils::caip2::ChainId;
use crate::database::{ use crate::database::{catch_unique_violation, DatabaseClient, DatabaseError};
catch_unique_violation,
DatabaseClient,
DatabaseError,
};
use crate::invoices::types::DbChainId; use crate::invoices::types::DbChainId;
use crate::profiles::types::PaymentType; use crate::profiles::types::PaymentType;
use crate::relationships::{ use crate::relationships::{
@ -28,8 +24,9 @@ pub async fn create_subscription(
) -> Result<(), DatabaseError> { ) -> Result<(), DatabaseError> {
assert!(chain_id.is_ethereum() == sender_address.is_some()); assert!(chain_id.is_ethereum() == sender_address.is_some());
let mut transaction = db_client.transaction().await?; let mut transaction = db_client.transaction().await?;
transaction.execute( transaction
" .execute(
"
INSERT INTO subscription ( INSERT INTO subscription (
sender_id, sender_id,
sender_address, sender_address,
@ -40,15 +37,17 @@ pub async fn create_subscription(
) )
VALUES ($1, $2, $3, $4, $5, $6) VALUES ($1, $2, $3, $4, $5, $6)
", ",
&[ &[
&sender_id, &sender_id,
&sender_address, &sender_address,
&recipient_id, &recipient_id,
&DbChainId::new(chain_id), &DbChainId::new(chain_id),
&expires_at, &expires_at,
&updated_at, &updated_at,
], ],
).await.map_err(catch_unique_violation("subscription"))?; )
.await
.map_err(catch_unique_violation("subscription"))?;
subscribe(&mut transaction, sender_id, recipient_id).await?; subscribe(&mut transaction, sender_id, recipient_id).await?;
transaction.commit().await?; transaction.commit().await?;
Ok(()) Ok(())
@ -61,8 +60,9 @@ pub async fn update_subscription(
updated_at: &DateTime<Utc>, updated_at: &DateTime<Utc>,
) -> Result<(), DatabaseError> { ) -> Result<(), DatabaseError> {
let mut transaction = db_client.transaction().await?; let mut transaction = db_client.transaction().await?;
let maybe_row = transaction.query_opt( let maybe_row = transaction
" .query_opt(
"
UPDATE subscription UPDATE subscription
SET SET
expires_at = $2, expires_at = $2,
@ -70,12 +70,9 @@ pub async fn update_subscription(
WHERE id = $1 WHERE id = $1
RETURNING sender_id, recipient_id RETURNING sender_id, recipient_id
", ",
&[ &[&subscription_id, &expires_at, &updated_at],
&subscription_id, )
&expires_at, .await?;
&updated_at,
],
).await?;
let row = maybe_row.ok_or(DatabaseError::NotFound("subscription"))?; let row = maybe_row.ok_or(DatabaseError::NotFound("subscription"))?;
let sender_id: Uuid = row.try_get("sender_id")?; let sender_id: Uuid = row.try_get("sender_id")?;
let recipient_id: Uuid = row.try_get("recipient_id")?; let recipient_id: Uuid = row.try_get("recipient_id")?;
@ -91,14 +88,16 @@ pub async fn get_subscription_by_participants(
sender_id: &Uuid, sender_id: &Uuid,
recipient_id: &Uuid, recipient_id: &Uuid,
) -> Result<DbSubscription, DatabaseError> { ) -> Result<DbSubscription, DatabaseError> {
let maybe_row = db_client.query_opt( let maybe_row = db_client
" .query_opt(
"
SELECT subscription SELECT subscription
FROM subscription FROM subscription
WHERE sender_id = $1 AND recipient_id = $2 WHERE sender_id = $1 AND recipient_id = $2
", ",
&[sender_id, recipient_id], &[sender_id, recipient_id],
).await?; )
.await?;
let row = maybe_row.ok_or(DatabaseError::NotFound("subscription"))?; let row = maybe_row.ok_or(DatabaseError::NotFound("subscription"))?;
let subscription: DbSubscription = row.try_get("subscription")?; let subscription: DbSubscription = row.try_get("subscription")?;
Ok(subscription) Ok(subscription)
@ -107,8 +106,9 @@ pub async fn get_subscription_by_participants(
pub async fn get_expired_subscriptions( pub async fn get_expired_subscriptions(
db_client: &impl DatabaseClient, db_client: &impl DatabaseClient,
) -> Result<Vec<DbSubscription>, DatabaseError> { ) -> Result<Vec<DbSubscription>, DatabaseError> {
let rows = db_client.query( let rows = db_client
" .query(
"
SELECT subscription SELECT subscription
FROM subscription FROM subscription
JOIN relationship JOIN relationship
@ -119,9 +119,11 @@ pub async fn get_expired_subscriptions(
) )
WHERE subscription.expires_at <= CURRENT_TIMESTAMP WHERE subscription.expires_at <= CURRENT_TIMESTAMP
", ",
&[&RelationshipType::Subscription], &[&RelationshipType::Subscription],
).await?; )
let subscriptions = rows.iter() .await?;
let subscriptions = rows
.iter()
.map(|row| row.try_get("subscription")) .map(|row| row.try_get("subscription"))
.collect::<Result<_, _>>()?; .collect::<Result<_, _>>()?;
Ok(subscriptions) Ok(subscriptions)
@ -133,8 +135,9 @@ pub async fn get_incoming_subscriptions(
max_subscription_id: Option<i32>, max_subscription_id: Option<i32>,
limit: u16, limit: u16,
) -> Result<Vec<Subscription>, DatabaseError> { ) -> Result<Vec<Subscription>, DatabaseError> {
let rows = db_client.query( let rows = db_client
" .query(
"
SELECT subscription, actor_profile AS sender SELECT subscription, actor_profile AS sender
FROM actor_profile FROM actor_profile
JOIN subscription JOIN subscription
@ -145,9 +148,11 @@ pub async fn get_incoming_subscriptions(
ORDER BY subscription.id DESC ORDER BY subscription.id DESC
LIMIT $3 LIMIT $3
", ",
&[&recipient_id, &max_subscription_id, &i64::from(limit)], &[&recipient_id, &max_subscription_id, &i64::from(limit)],
).await?; )
let subscriptions = rows.iter() .await?;
let subscriptions = rows
.iter()
.map(Subscription::try_from) .map(Subscription::try_from)
.collect::<Result<_, _>>()?; .collect::<Result<_, _>>()?;
Ok(subscriptions) Ok(subscriptions)
@ -161,8 +166,9 @@ pub async fn reset_subscriptions(
if ethereum_contract_replaced { if ethereum_contract_replaced {
// Ethereum subscription configuration is stored in contract. // Ethereum subscription configuration is stored in contract.
// If contract is replaced, payment option needs to be deleted. // If contract is replaced, payment option needs to be deleted.
transaction.execute( transaction
" .execute(
"
UPDATE actor_profile UPDATE actor_profile
SET payment_options = '[]' SET payment_options = '[]'
WHERE WHERE
@ -174,19 +180,22 @@ pub async fn reset_subscriptions(
WHERE CAST(option ->> 'payment_type' AS SMALLINT) = $1 WHERE CAST(option ->> 'payment_type' AS SMALLINT) = $1
) )
", ",
&[&i16::from(&PaymentType::EthereumSubscription)], &[&i16::from(&PaymentType::EthereumSubscription)],
).await?; )
.await?;
}; };
transaction.execute( transaction
" .execute(
"
DELETE FROM relationship DELETE FROM relationship
WHERE relationship_type = $1 WHERE relationship_type = $1
", ",
&[&RelationshipType::Subscription], &[&RelationshipType::Subscription],
).await?; )
transaction.execute( .await?;
"UPDATE actor_profile SET subscriber_count = 0", &[], transaction
).await?; .execute("UPDATE actor_profile SET subscriber_count = 0", &[])
.await?;
transaction.execute("DELETE FROM subscription", &[]).await?; transaction.execute("DELETE FROM subscription", &[]).await?;
transaction.commit().await?; transaction.commit().await?;
Ok(()) Ok(())
@ -194,21 +203,12 @@ pub async fn reset_subscriptions(
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use serial_test::serial;
use crate::database::test_utils::create_test_database;
use crate::profiles::{
queries::create_profile,
types::ProfileCreateData,
};
use crate::relationships::{
queries::has_relationship,
types::RelationshipType,
};
use crate::users::{
queries::create_user,
types::UserCreateData,
};
use super::*; use super::*;
use crate::database::test_utils::create_test_database;
use crate::profiles::{queries::create_profile, types::ProfileCreateData};
use crate::relationships::{queries::has_relationship, types::RelationshipType};
use crate::users::{queries::create_user, types::UserCreateData};
use serial_test::serial;
#[tokio::test] #[tokio::test]
#[serial] #[serial]
@ -237,14 +237,18 @@ mod tests {
&chain_id, &chain_id,
&expires_at, &expires_at,
&updated_at, &updated_at,
).await.unwrap(); )
.await
.unwrap();
let is_subscribed = has_relationship( let is_subscribed = has_relationship(
db_client, db_client,
&sender.id, &sender.id,
&recipient.id, &recipient.id,
RelationshipType::Subscription, RelationshipType::Subscription,
).await.unwrap(); )
.await
.unwrap();
assert_eq!(is_subscribed, true); assert_eq!(is_subscribed, true);
} }
} }

View file

@ -27,7 +27,6 @@ pub struct Subscription {
} }
impl TryFrom<&Row> for Subscription { impl TryFrom<&Row> for Subscription {
type Error = DatabaseError; type Error = DatabaseError;
fn try_from(row: &Row) -> Result<Self, Self::Error> { fn try_from(row: &Row) -> Result<Self, Self::Error> {

View file

@ -6,16 +6,19 @@ pub async fn search_tags(
limit: u16, limit: u16,
) -> Result<Vec<String>, DatabaseError> { ) -> Result<Vec<String>, DatabaseError> {
let db_search_query = format!("%{}%", search_query); let db_search_query = format!("%{}%", search_query);
let rows = db_client.query( let rows = db_client
" .query(
"
SELECT tag_name SELECT tag_name
FROM tag FROM tag
WHERE tag_name ILIKE $1 WHERE tag_name ILIKE $1
LIMIT $2 LIMIT $2
", ",
&[&db_search_query, &i64::from(limit)], &[&db_search_query, &i64::from(limit)],
).await?; )
let tags: Vec<String> = rows.iter() .await?;
let tags: Vec<String> = rows
.iter()
.map(|row| row.try_get("tag_name")) .map(|row| row.try_get("tag_name"))
.collect::<Result<_, _>>()?; .collect::<Result<_, _>>()?;
Ok(tags) Ok(tags)

View file

@ -1,30 +1,16 @@
use serde_json::{Value as JsonValue}; use serde_json::Value as JsonValue;
use uuid::Uuid; use uuid::Uuid;
use mitra_utils::{ use mitra_utils::{currencies::Currency, did::Did, did_pkh::DidPkh};
currencies::Currency,
did::Did,
did_pkh::DidPkh,
};
use crate::database::{ use crate::database::{catch_unique_violation, DatabaseClient, DatabaseError};
catch_unique_violation,
DatabaseClient,
DatabaseError,
};
use crate::profiles::{ use crate::profiles::{
queries::create_profile, queries::create_profile,
types::{DbActorProfile, ProfileCreateData}, types::{DbActorProfile, ProfileCreateData},
}; };
use super::types::{ use super::types::{
ClientConfig, ClientConfig, DbClientConfig, DbInviteCode, DbUser, Role, User, UserCreateData,
DbClientConfig,
DbInviteCode,
DbUser,
Role,
User,
UserCreateData,
}; };
use super::utils::generate_invite_code; use super::utils::generate_invite_code;
@ -33,28 +19,33 @@ pub async fn create_invite_code(
note: Option<&str>, note: Option<&str>,
) -> Result<String, DatabaseError> { ) -> Result<String, DatabaseError> {
let invite_code = generate_invite_code(); let invite_code = generate_invite_code();
db_client.execute( db_client
" .execute(
"
INSERT INTO user_invite_code (code, note) INSERT INTO user_invite_code (code, note)
VALUES ($1, $2) VALUES ($1, $2)
", ",
&[&invite_code, &note], &[&invite_code, &note],
).await?; )
.await?;
Ok(invite_code) Ok(invite_code)
} }
pub async fn get_invite_codes( pub async fn get_invite_codes(
db_client: &impl DatabaseClient, db_client: &impl DatabaseClient,
) -> Result<Vec<DbInviteCode>, DatabaseError> { ) -> Result<Vec<DbInviteCode>, DatabaseError> {
let rows = db_client.query( let rows = db_client
" .query(
"
SELECT user_invite_code SELECT user_invite_code
FROM user_invite_code FROM user_invite_code
WHERE used = FALSE WHERE used = FALSE
", ",
&[], &[],
).await?; )
let codes = rows.iter() .await?;
let codes = rows
.iter()
.map(|row| row.try_get("user_invite_code")) .map(|row| row.try_get("user_invite_code"))
.collect::<Result<_, _>>()?; .collect::<Result<_, _>>()?;
Ok(codes) Ok(codes)
@ -64,13 +55,15 @@ pub async fn is_valid_invite_code(
db_client: &impl DatabaseClient, db_client: &impl DatabaseClient,
invite_code: &str, invite_code: &str,
) -> Result<bool, DatabaseError> { ) -> Result<bool, DatabaseError> {
let maybe_row = db_client.query_opt( let maybe_row = db_client
" .query_opt(
"
SELECT 1 FROM user_invite_code SELECT 1 FROM user_invite_code
WHERE code = $1 AND used = FALSE WHERE code = $1 AND used = FALSE
", ",
&[&invite_code], &[&invite_code],
).await?; )
.await?;
Ok(maybe_row.is_some()) Ok(maybe_row.is_some())
} }
@ -78,37 +71,39 @@ pub async fn create_user(
db_client: &mut impl DatabaseClient, db_client: &mut impl DatabaseClient,
user_data: UserCreateData, user_data: UserCreateData,
) -> Result<User, DatabaseError> { ) -> Result<User, DatabaseError> {
assert!(user_data.password_hash.is_some() || assert!(user_data.password_hash.is_some() || user_data.wallet_address.is_some());
user_data.wallet_address.is_some());
let mut transaction = db_client.transaction().await?; let mut transaction = db_client.transaction().await?;
// Prevent changes to actor_profile table // Prevent changes to actor_profile table
transaction.execute( transaction
"LOCK TABLE actor_profile IN EXCLUSIVE MODE", .execute("LOCK TABLE actor_profile IN EXCLUSIVE MODE", &[])
&[], .await?;
).await?;
// Ensure there are no local accounts with a similar name // Ensure there are no local accounts with a similar name
let maybe_row = transaction.query_opt( let maybe_row = transaction
" .query_opt(
"
SELECT 1 SELECT 1
FROM user_account JOIN actor_profile USING (id) FROM user_account JOIN actor_profile USING (id)
WHERE actor_profile.username ILIKE $1 WHERE actor_profile.username ILIKE $1
LIMIT 1 LIMIT 1
", ",
&[&user_data.username], &[&user_data.username],
).await?; )
.await?;
if maybe_row.is_some() { if maybe_row.is_some() {
return Err(DatabaseError::AlreadyExists("user")); return Err(DatabaseError::AlreadyExists("user"));
}; };
// Use invite code // Use invite code
if let Some(ref invite_code) = user_data.invite_code { if let Some(ref invite_code) = user_data.invite_code {
let updated_count = transaction.execute( let updated_count = transaction
" .execute(
"
UPDATE user_invite_code UPDATE user_invite_code
SET used = TRUE SET used = TRUE
WHERE code = $1 AND used = FALSE WHERE code = $1 AND used = FALSE
", ",
&[&invite_code], &[&invite_code],
).await?; )
.await?;
if updated_count == 0 { if updated_count == 0 {
return Err(DatabaseError::NotFound("invite code")); return Err(DatabaseError::NotFound("invite code"));
}; };
@ -131,8 +126,9 @@ pub async fn create_user(
}; };
let profile = create_profile(&mut transaction, profile_data).await?; let profile = create_profile(&mut transaction, profile_data).await?;
// Create user // Create user
let row = transaction.query_one( let row = transaction
" .query_one(
"
INSERT INTO user_account ( INSERT INTO user_account (
id, id,
wallet_address, wallet_address,
@ -144,15 +140,17 @@ pub async fn create_user(
VALUES ($1, $2, $3, $4, $5, $6) VALUES ($1, $2, $3, $4, $5, $6)
RETURNING user_account RETURNING user_account
", ",
&[ &[
&profile.id, &profile.id,
&user_data.wallet_address, &user_data.wallet_address,
&user_data.password_hash, &user_data.password_hash,
&user_data.private_key_pem, &user_data.private_key_pem,
&user_data.invite_code, &user_data.invite_code,
&user_data.role, &user_data.role,
], ],
).await.map_err(catch_unique_violation("user"))?; )
.await
.map_err(catch_unique_violation("user"))?;
let db_user: DbUser = row.try_get("user_account")?; let db_user: DbUser = row.try_get("user_account")?;
let user = User::new(db_user, profile); let user = User::new(db_user, profile);
transaction.commit().await?; transaction.commit().await?;
@ -164,13 +162,15 @@ pub async fn set_user_password(
user_id: &Uuid, user_id: &Uuid,
password_hash: String, password_hash: String,
) -> Result<(), DatabaseError> { ) -> Result<(), DatabaseError> {
let updated_count = db_client.execute( let updated_count = db_client
" .execute(
"
UPDATE user_account SET password_hash = $1 UPDATE user_account SET password_hash = $1
WHERE id = $2 WHERE id = $2
", ",
&[&password_hash, &user_id], &[&password_hash, &user_id],
).await?; )
.await?;
if updated_count == 0 { if updated_count == 0 {
return Err(DatabaseError::NotFound("user")); return Err(DatabaseError::NotFound("user"));
}; };
@ -182,13 +182,15 @@ pub async fn set_user_role(
user_id: &Uuid, user_id: &Uuid,
role: Role, role: Role,
) -> Result<(), DatabaseError> { ) -> Result<(), DatabaseError> {
let updated_count = db_client.execute( let updated_count = db_client
" .execute(
"
UPDATE user_account SET user_role = $1 UPDATE user_account SET user_role = $1
WHERE id = $2 WHERE id = $2
", ",
&[&role, &user_id], &[&role, &user_id],
).await?; )
.await?;
if updated_count == 0 { if updated_count == 0 {
return Err(DatabaseError::NotFound("user")); return Err(DatabaseError::NotFound("user"));
}; };
@ -201,15 +203,17 @@ pub async fn update_client_config(
client_name: &str, client_name: &str,
client_config_value: &JsonValue, client_config_value: &JsonValue,
) -> Result<ClientConfig, DatabaseError> { ) -> Result<ClientConfig, DatabaseError> {
let maybe_row = db_client.query_opt( let maybe_row = db_client
" .query_opt(
"
UPDATE user_account UPDATE user_account
SET client_config = jsonb_set(client_config, ARRAY[$1], $2, true) SET client_config = jsonb_set(client_config, ARRAY[$1], $2, true)
WHERE id = $3 WHERE id = $3
RETURNING client_config RETURNING client_config
", ",
&[&client_name, &client_config_value, &user_id], &[&client_name, &client_config_value, &user_id],
).await?; )
.await?;
let row = maybe_row.ok_or(DatabaseError::NotFound("user"))?; let row = maybe_row.ok_or(DatabaseError::NotFound("user"))?;
let client_config: DbClientConfig = row.try_get("client_config")?; let client_config: DbClientConfig = row.try_get("client_config")?;
Ok(client_config.into_inner()) Ok(client_config.into_inner())
@ -219,14 +223,16 @@ pub async fn get_user_by_id(
db_client: &impl DatabaseClient, db_client: &impl DatabaseClient,
user_id: &Uuid, user_id: &Uuid,
) -> Result<User, DatabaseError> { ) -> Result<User, DatabaseError> {
let maybe_row = db_client.query_opt( let maybe_row = db_client
" .query_opt(
"
SELECT user_account, actor_profile SELECT user_account, actor_profile
FROM user_account JOIN actor_profile USING (id) FROM user_account JOIN actor_profile USING (id)
WHERE id = $1 WHERE id = $1
", ",
&[&user_id], &[&user_id],
).await?; )
.await?;
let row = maybe_row.ok_or(DatabaseError::NotFound("user"))?; let row = maybe_row.ok_or(DatabaseError::NotFound("user"))?;
let db_user: DbUser = row.try_get("user_account")?; let db_user: DbUser = row.try_get("user_account")?;
let db_profile: DbActorProfile = row.try_get("actor_profile")?; let db_profile: DbActorProfile = row.try_get("actor_profile")?;
@ -238,14 +244,16 @@ pub async fn get_user_by_name(
db_client: &impl DatabaseClient, db_client: &impl DatabaseClient,
username: &str, username: &str,
) -> Result<User, DatabaseError> { ) -> Result<User, DatabaseError> {
let maybe_row = db_client.query_opt( let maybe_row = db_client
" .query_opt(
"
SELECT user_account, actor_profile SELECT user_account, actor_profile
FROM user_account JOIN actor_profile USING (id) FROM user_account JOIN actor_profile USING (id)
WHERE actor_profile.username = $1 WHERE actor_profile.username = $1
", ",
&[&username], &[&username],
).await?; )
.await?;
let row = maybe_row.ok_or(DatabaseError::NotFound("user"))?; let row = maybe_row.ok_or(DatabaseError::NotFound("user"))?;
let db_user: DbUser = row.try_get("user_account")?; let db_user: DbUser = row.try_get("user_account")?;
let db_profile: DbActorProfile = row.try_get("actor_profile")?; let db_profile: DbActorProfile = row.try_get("actor_profile")?;
@ -257,13 +265,15 @@ pub async fn is_registered_user(
db_client: &impl DatabaseClient, db_client: &impl DatabaseClient,
username: &str, username: &str,
) -> Result<bool, DatabaseError> { ) -> Result<bool, DatabaseError> {
let maybe_row = db_client.query_opt( let maybe_row = db_client
" .query_opt(
"
SELECT 1 FROM user_account JOIN actor_profile USING (id) SELECT 1 FROM user_account JOIN actor_profile USING (id)
WHERE actor_profile.username = $1 WHERE actor_profile.username = $1
", ",
&[&username], &[&username],
).await?; )
.await?;
Ok(maybe_row.is_some()) Ok(maybe_row.is_some())
} }
@ -271,14 +281,16 @@ pub async fn get_user_by_login_address(
db_client: &impl DatabaseClient, db_client: &impl DatabaseClient,
wallet_address: &str, wallet_address: &str,
) -> Result<User, DatabaseError> { ) -> Result<User, DatabaseError> {
let maybe_row = db_client.query_opt( let maybe_row = db_client
" .query_opt(
"
SELECT user_account, actor_profile SELECT user_account, actor_profile
FROM user_account JOIN actor_profile USING (id) FROM user_account JOIN actor_profile USING (id)
WHERE wallet_address = $1 WHERE wallet_address = $1
", ",
&[&wallet_address], &[&wallet_address],
).await?; )
.await?;
let row = maybe_row.ok_or(DatabaseError::NotFound("user"))?; let row = maybe_row.ok_or(DatabaseError::NotFound("user"))?;
let db_user: DbUser = row.try_get("user_account")?; let db_user: DbUser = row.try_get("user_account")?;
let db_profile: DbActorProfile = row.try_get("actor_profile")?; let db_profile: DbActorProfile = row.try_get("actor_profile")?;
@ -291,8 +303,9 @@ pub async fn get_user_by_did(
did: &Did, did: &Did,
) -> Result<User, DatabaseError> { ) -> Result<User, DatabaseError> {
// DIDs must be locally unique // DIDs must be locally unique
let maybe_row = db_client.query_opt( let maybe_row = db_client
" .query_opt(
"
SELECT user_account, actor_profile SELECT user_account, actor_profile
FROM user_account JOIN actor_profile USING (id) FROM user_account JOIN actor_profile USING (id)
WHERE WHERE
@ -302,8 +315,9 @@ pub async fn get_user_by_did(
WHERE proof ->> 'issuer' = $1 WHERE proof ->> 'issuer' = $1
) )
", ",
&[&did.to_string()], &[&did.to_string()],
).await?; )
.await?;
let row = maybe_row.ok_or(DatabaseError::NotFound("user"))?; let row = maybe_row.ok_or(DatabaseError::NotFound("user"))?;
let db_user: DbUser = row.try_get("user_account")?; let db_user: DbUser = row.try_get("user_account")?;
let db_profile: DbActorProfile = row.try_get("actor_profile")?; let db_profile: DbActorProfile = row.try_get("actor_profile")?;
@ -321,24 +335,21 @@ pub async fn get_user_by_public_wallet_address(
get_user_by_did(db_client, &did).await get_user_by_did(db_client, &did).await
} }
pub async fn get_user_count( pub async fn get_user_count(db_client: &impl DatabaseClient) -> Result<i64, DatabaseError> {
db_client: &impl DatabaseClient, let row = db_client
) -> Result<i64, DatabaseError> { .query_one("SELECT count(user_account) FROM user_account", &[])
let row = db_client.query_one( .await?;
"SELECT count(user_account) FROM user_account",
&[],
).await?;
let count = row.try_get("count")?; let count = row.try_get("count")?;
Ok(count) Ok(count)
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use serde_json::json; use super::*;
use serial_test::serial;
use crate::database::test_utils::create_test_database; use crate::database::test_utils::create_test_database;
use crate::users::types::Role; use crate::users::types::Role;
use super::*; use serde_json::json;
use serial_test::serial;
#[tokio::test] #[tokio::test]
#[serial] #[serial]
@ -392,7 +403,9 @@ mod tests {
}; };
let user = create_user(db_client, user_data).await.unwrap(); let user = create_user(db_client, user_data).await.unwrap();
assert_eq!(user.role, Role::NormalUser); assert_eq!(user.role, Role::NormalUser);
set_user_role(db_client, &user.id, Role::ReadOnlyUser).await.unwrap(); set_user_role(db_client, &user.id, Role::ReadOnlyUser)
.await
.unwrap();
let user = get_user_by_id(db_client, &user.id).await.unwrap(); let user = get_user_by_id(db_client, &user.id).await.unwrap();
assert_eq!(user.role, Role::ReadOnlyUser); assert_eq!(user.role, Role::ReadOnlyUser);
} }
@ -410,12 +423,10 @@ mod tests {
assert_eq!(user.client_config.is_empty(), true); assert_eq!(user.client_config.is_empty(), true);
let client_name = "test"; let client_name = "test";
let client_config_value = json!({"a": 1}); let client_config_value = json!({"a": 1});
let client_config = update_client_config( let client_config =
db_client, update_client_config(db_client, &user.id, client_name, &client_config_value)
&user.id, .await
client_name, .unwrap();
&client_config_value,
).await.unwrap();
assert_eq!( assert_eq!(
client_config.get(client_name).unwrap(), client_config.get(client_name).unwrap(),
&client_config_value, &client_config_value,

View file

@ -3,13 +3,10 @@ use std::collections::HashMap;
use chrono::{DateTime, Utc}; use chrono::{DateTime, Utc};
use postgres_types::FromSql; use postgres_types::FromSql;
use serde::Deserialize; use serde::Deserialize;
use serde_json::{Value as JsonValue}; use serde_json::Value as JsonValue;
use uuid::Uuid; use uuid::Uuid;
use mitra_utils::{ use mitra_utils::{currencies::Currency, did::Did};
currencies::Currency,
did::Did,
};
use crate::database::{ use crate::database::{
int_enum::{int_enum_from_sql, int_enum_to_sql}, int_enum::{int_enum_from_sql, int_enum_to_sql},
@ -46,7 +43,9 @@ pub enum Role {
} }
impl Default for Role { impl Default for Role {
fn default() -> Self { Self::NormalUser } fn default() -> Self {
Self::NormalUser
}
} }
impl Role { impl Role {
@ -65,9 +64,7 @@ impl Role {
Permission::DeleteAnyProfile, Permission::DeleteAnyProfile,
Permission::ManageSubscriptionOptions, Permission::ManageSubscriptionOptions,
], ],
Self::ReadOnlyUser => vec![ Self::ReadOnlyUser => vec![Permission::CreateFollowRequest],
Permission::CreateFollowRequest,
],
} }
} }
@ -147,10 +144,7 @@ pub struct User {
} }
impl User { impl User {
pub fn new( pub fn new(db_user: DbUser, db_profile: DbActorProfile) -> Self {
db_user: DbUser,
db_profile: DbActorProfile,
) -> Self {
assert_eq!(db_user.id, db_profile.id); assert_eq!(db_user.id, db_profile.id);
Self { Self {
id: db_user.id, id: db_user.id,
@ -177,7 +171,7 @@ impl User {
return Some(did_pkh.address); return Some(did_pkh.address);
}; };
}; };
}; }
None None
} }
} }

View file

@ -3,13 +3,7 @@ use std::fmt;
use std::str::FromStr; use std::str::FromStr;
use regex::Regex; use regex::Regex;
use serde::{ use serde::{de::Error as DeserializerError, Deserialize, Deserializer, Serialize, Serializer};
Deserialize,
Deserializer,
Serialize,
Serializer,
de::Error as DeserializerError,
};
use super::currencies::Currency; use super::currencies::Currency;
@ -52,7 +46,9 @@ impl ChainId {
if !self.is_ethereum() { if !self.is_ethereum() {
return Err(ChainIdError("namespace is not eip155")); return Err(ChainIdError("namespace is not eip155"));
}; };
let chain_id: u32 = self.reference.parse() let chain_id: u32 = self
.reference
.parse()
.map_err(|_| ChainIdError("invalid EIP-155 chain ID"))?; .map_err(|_| ChainIdError("invalid EIP-155 chain ID"))?;
Ok(chain_id) Ok(chain_id)
} }
@ -76,7 +72,8 @@ impl FromStr for ChainId {
fn from_str(value: &str) -> Result<Self, Self::Err> { fn from_str(value: &str) -> Result<Self, Self::Err> {
let caip2_re = Regex::new(CAIP2_RE).unwrap(); let caip2_re = Regex::new(CAIP2_RE).unwrap();
let caps = caip2_re.captures(value) let caps = caip2_re
.captures(value)
.ok_or(ChainIdError("invalid chain ID"))?; .ok_or(ChainIdError("invalid chain ID"))?;
let chain_id = Self { let chain_id = Self {
namespace: caps["namespace"].to_string(), namespace: caps["namespace"].to_string(),
@ -94,7 +91,8 @@ impl fmt::Display for ChainId {
impl Serialize for ChainId { impl Serialize for ChainId {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error> fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where S: Serializer where
S: Serializer,
{ {
serializer.serialize_str(&self.to_string()) serializer.serialize_str(&self.to_string())
} }
@ -102,10 +100,12 @@ impl Serialize for ChainId {
impl<'de> Deserialize<'de> for ChainId { impl<'de> Deserialize<'de> for ChainId {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error> fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where D: Deserializer<'de> where
D: Deserializer<'de>,
{ {
String::deserialize(deserializer)? String::deserialize(deserializer)?
.parse().map_err(DeserializerError::custom) .parse()
.map_err(DeserializerError::custom)
} }
} }

View file

@ -5,9 +5,7 @@ use serde::Serialize;
pub struct CanonicalizationError(#[from] serde_json::Error); pub struct CanonicalizationError(#[from] serde_json::Error);
/// JCS: https://www.rfc-editor.org/rfc/rfc8785 /// JCS: https://www.rfc-editor.org/rfc/rfc8785
pub fn canonicalize_object( pub fn canonicalize_object(object: &impl Serialize) -> Result<String, CanonicalizationError> {
object: &impl Serialize,
) -> Result<String, CanonicalizationError> {
let object_str = serde_jcs::to_string(object)?; let object_str = serde_jcs::to_string(object)?;
Ok(object_str) Ok(object_str)
} }

View file

@ -1,5 +1,5 @@
use rsa::{Hash, PaddingScheme, PublicKey, RsaPrivateKey, RsaPublicKey};
use rsa::pkcs8::{FromPrivateKey, FromPublicKey, ToPrivateKey, ToPublicKey}; use rsa::pkcs8::{FromPrivateKey, FromPublicKey, ToPrivateKey, ToPublicKey};
use rsa::{Hash, PaddingScheme, PublicKey, RsaPrivateKey, RsaPublicKey};
use sha2::{Digest, Sha256}; use sha2::{Digest, Sha256};
pub fn generate_rsa_key() -> Result<RsaPrivateKey, rsa::errors::Error> { pub fn generate_rsa_key() -> Result<RsaPrivateKey, rsa::errors::Error> {
@ -16,32 +16,24 @@ pub fn generate_weak_rsa_key() -> Result<RsaPrivateKey, rsa::errors::Error> {
RsaPrivateKey::new(&mut rng, bits) RsaPrivateKey::new(&mut rng, bits)
} }
pub fn serialize_private_key( pub fn serialize_private_key(private_key: &RsaPrivateKey) -> Result<String, rsa::pkcs8::Error> {
private_key: &RsaPrivateKey,
) -> Result<String, rsa::pkcs8::Error> {
private_key.to_pkcs8_pem().map(|val| val.to_string()) private_key.to_pkcs8_pem().map(|val| val.to_string())
} }
pub fn deserialize_private_key( pub fn deserialize_private_key(private_key_pem: &str) -> Result<RsaPrivateKey, rsa::pkcs8::Error> {
private_key_pem: &str,
) -> Result<RsaPrivateKey, rsa::pkcs8::Error> {
RsaPrivateKey::from_pkcs8_pem(private_key_pem) RsaPrivateKey::from_pkcs8_pem(private_key_pem)
} }
pub fn get_public_key_pem( pub fn get_public_key_pem(private_key: &RsaPrivateKey) -> Result<String, rsa::pkcs8::Error> {
private_key: &RsaPrivateKey,
) -> Result<String, rsa::pkcs8::Error> {
let public_key = RsaPublicKey::from(private_key); let public_key = RsaPublicKey::from(private_key);
public_key.to_public_key_pem() public_key.to_public_key_pem()
} }
pub fn deserialize_public_key( pub fn deserialize_public_key(public_key_pem: &str) -> Result<RsaPublicKey, rsa::pkcs8::Error> {
public_key_pem: &str,
) -> Result<RsaPublicKey, rsa::pkcs8::Error> {
// rsa package can't decode PEM string with non-standard wrap width, // rsa package can't decode PEM string with non-standard wrap width,
// so the input should be normalized first // so the input should be normalized first
let parsed_pem = pem::parse(public_key_pem.trim().as_bytes()) let parsed_pem =
.map_err(|_| rsa::pkcs8::Error::Pem)?; pem::parse(public_key_pem.trim().as_bytes()).map_err(|_| rsa::pkcs8::Error::Pem)?;
let normalized_pem = pem::encode(&parsed_pem); let normalized_pem = pem::encode(&parsed_pem);
RsaPublicKey::from_public_key_pem(&normalized_pem) RsaPublicKey::from_public_key_pem(&normalized_pem)
} }
@ -70,11 +62,7 @@ pub fn verify_rsa_sha256_signature(
) -> bool { ) -> bool {
let digest = Sha256::digest(message.as_bytes()); let digest = Sha256::digest(message.as_bytes());
let padding = PaddingScheme::new_pkcs1v15_sign(Some(Hash::SHA2_256)); let padding = PaddingScheme::new_pkcs1v15_sign(Some(Hash::SHA2_256));
let is_valid = public_key.verify( let is_valid = public_key.verify(padding, &digest, signature).is_ok();
padding,
&digest,
signature,
).is_ok();
is_valid is_valid
} }
@ -102,17 +90,10 @@ YsFtrgWDQ/s8k86sNBU+Ce2GOL7seh46kyAWgJeohh4Rcrr23rftHbvxOcRM8VzYuCeb1DgVhPGtA0xU
fn test_verify_rsa_signature() { fn test_verify_rsa_signature() {
let private_key = generate_weak_rsa_key().unwrap(); let private_key = generate_weak_rsa_key().unwrap();
let message = "test".to_string(); let message = "test".to_string();
let signature = create_rsa_sha256_signature( let signature = create_rsa_sha256_signature(&private_key, &message).unwrap();
&private_key,
&message,
).unwrap();
let public_key = RsaPublicKey::from(&private_key); let public_key = RsaPublicKey::from(&private_key);
let is_valid = verify_rsa_sha256_signature( let is_valid = verify_rsa_sha256_signature(&public_key, &message, &signature);
&public_key,
&message,
&signature,
);
assert_eq!(is_valid, true); assert_eq!(is_valid, true);
} }
} }

View file

@ -9,7 +9,8 @@ impl Currency {
match self { match self {
Self::Ethereum => "ETH", Self::Ethereum => "ETH",
Self::Monero => "XMR", Self::Monero => "XMR",
}.to_string() }
.to_string()
} }
pub fn field_name(&self) -> String { pub fn field_name(&self) -> String {

View file

@ -1,8 +1,7 @@
use chrono::{DateTime, Duration, NaiveDateTime, Utc}; use chrono::{DateTime, Duration, NaiveDateTime, Utc};
pub fn get_min_datetime() -> DateTime<Utc> { pub fn get_min_datetime() -> DateTime<Utc> {
let native = NaiveDateTime::from_timestamp_opt(0, 0) let native = NaiveDateTime::from_timestamp_opt(0, 0).expect("0 should be a valid argument");
.expect("0 should be a valid argument");
DateTime::from_utc(native, Utc) DateTime::from_utc(native, Utc)
} }

View file

@ -3,10 +3,7 @@ use std::fmt;
use std::str::FromStr; use std::str::FromStr;
use regex::Regex; use regex::Regex;
use serde::{ use serde::{de::Error as DeserializerError, Deserialize, Deserializer, Serialize, Serializer};
Deserialize, Deserializer, Serialize, Serializer,
de::Error as DeserializerError,
};
use super::did_key::DidKey; use super::did_key::DidKey;
use super::did_pkh::DidPkh; use super::did_pkh::DidPkh;
@ -33,11 +30,11 @@ impl FromStr for Did {
"key" => { "key" => {
let did_key = DidKey::from_str(value)?; let did_key = DidKey::from_str(value)?;
Self::Key(did_key) Self::Key(did_key)
}, }
"pkh" => { "pkh" => {
let did_pkh = DidPkh::from_str(value)?; let did_pkh = DidPkh::from_str(value)?;
Self::Pkh(did_pkh) Self::Pkh(did_pkh)
}, }
_ => return Err(DidParseError), _ => return Err(DidParseError),
}; };
Ok(did) Ok(did)
@ -56,7 +53,8 @@ impl fmt::Display for Did {
impl<'de> Deserialize<'de> for Did { impl<'de> Deserialize<'de> for Did {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error> fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where D: Deserializer<'de> where
D: Deserializer<'de>,
{ {
let did_str: String = Deserialize::deserialize(deserializer)?; let did_str: String = Deserialize::deserialize(deserializer)?;
did_str.parse().map_err(DeserializerError::custom) did_str.parse().map_err(DeserializerError::custom)
@ -65,7 +63,8 @@ impl<'de> Deserialize<'de> for Did {
impl Serialize for Did { impl Serialize for Did {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error> fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where S: Serializer where
S: Serializer,
{ {
let did_str = self.to_string(); let did_str = self.to_string();
serializer.serialize_str(&did_str) serializer.serialize_str(&did_str)

View file

@ -6,10 +6,7 @@ use regex::Regex;
use super::{ use super::{
did::DidParseError, did::DidParseError,
multibase::{ multibase::{decode_multibase_base58btc, encode_multibase_base58btc},
decode_multibase_base58btc,
encode_multibase_base58btc,
},
}; };
const DID_KEY_RE: &str = r"did:key:(?P<key>z[a-km-zA-HJ-NP-Z1-9]+)"; const DID_KEY_RE: &str = r"did:key:(?P<key>z[a-km-zA-HJ-NP-Z1-9]+)";
@ -34,10 +31,7 @@ impl DidKey {
} }
pub fn from_ed25519_key(key: [u8; 32]) -> Self { pub fn from_ed25519_key(key: [u8; 32]) -> Self {
let prefixed_key = [ let prefixed_key = [MULTICODEC_ED25519_PREFIX.to_vec(), key.to_vec()].concat();
MULTICODEC_ED25519_PREFIX.to_vec(),
key.to_vec(),
].concat();
Self { key: prefixed_key } Self { key: prefixed_key }
} }
@ -62,8 +56,7 @@ impl FromStr for DidKey {
fn from_str(value: &str) -> Result<Self, Self::Err> { fn from_str(value: &str) -> Result<Self, Self::Err> {
let did_key_re = Regex::new(DID_KEY_RE).unwrap(); let did_key_re = Regex::new(DID_KEY_RE).unwrap();
let caps = did_key_re.captures(value).ok_or(DidParseError)?; let caps = did_key_re.captures(value).ok_or(DidParseError)?;
let key = decode_multibase_base58btc(&caps["key"]) let key = decode_multibase_base58btc(&caps["key"]).map_err(|_| DidParseError)?;
.map_err(|_| DidParseError)?;
let did_key = Self { key }; let did_key = Self { key };
Ok(did_key) Ok(did_key)
} }

View file

@ -4,11 +4,7 @@ use std::str::FromStr;
use regex::Regex; use regex::Regex;
use super::{ use super::{caip2::ChainId, currencies::Currency, did::DidParseError};
caip2::ChainId,
currencies::Currency,
did::DidParseError,
};
// https://github.com/ChainAgnostic/CAIPs/blob/master/CAIPs/caip-10.md#syntax // https://github.com/ChainAgnostic/CAIPs/blob/master/CAIPs/caip-10.md#syntax
const DID_PKH_RE: &str = r"did:pkh:(?P<network>[-a-z0-9]{3,8}):(?P<chain>[-a-zA-Z0-9]{1,32}):(?P<address>[a-zA-Z0-9]{1,64})"; const DID_PKH_RE: &str = r"did:pkh:(?P<network>[-a-z0-9]{3,8}):(?P<chain>[-a-zA-Z0-9]{1,32}):(?P<address>[a-zA-Z0-9]{1,64})";
@ -38,9 +34,7 @@ impl fmt::Display for DidPkh {
fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result { fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
let did_str = format!( let did_str = format!(
"did:pkh:{}:{}:{}", "did:pkh:{}:{}:{}",
self.chain_id.namespace, self.chain_id.namespace, self.chain_id.reference, self.address,
self.chain_id.reference,
self.address,
); );
write!(formatter, "{}", did_str) write!(formatter, "{}", did_str)
} }

View file

@ -1,10 +1,6 @@
use std::fs::{ use std::fs::{set_permissions, File, Permissions};
set_permissions,
File,
Permissions,
};
use std::io::Error;
use std::io::prelude::*; use std::io::prelude::*;
use std::io::Error;
use std::os::unix::fs::PermissionsExt; use std::os::unix::fs::PermissionsExt;
use std::path::Path; use std::path::Path;
@ -19,11 +15,9 @@ pub fn get_media_type_extension(media_type: &str) -> Option<&'static str> {
match media_type { match media_type {
// Override extension provided by mime_guess // Override extension provided by mime_guess
"image/jpeg" => Some("jpg"), "image/jpeg" => Some("jpg"),
_ => { _ => get_mime_extensions_str(media_type)
get_mime_extensions_str(media_type) .and_then(|extensions| extensions.first())
.and_then(|extensions| extensions.first()) .copied(),
.copied()
}
} }
} }
@ -45,13 +39,7 @@ mod tests {
#[test] #[test]
fn test_get_media_type_extension() { fn test_get_media_type_extension() {
assert_eq!( assert_eq!(get_media_type_extension("image/png"), Some("png"),);
get_media_type_extension("image/png"), assert_eq!(get_media_type_extension("image/jpeg"), Some("jpg"),);
Some("png"),
);
assert_eq!(
get_media_type_extension("image/jpeg"),
Some("jpg"),
);
} }
} }

View file

@ -3,7 +3,7 @@ use std::iter::FromIterator;
use ammonia::Builder; use ammonia::Builder;
pub use ammonia::{clean_text as escape_html}; pub use ammonia::clean_text as escape_html;
pub fn clean_html( pub fn clean_html(
unsafe_html: &str, unsafe_html: &str,
@ -12,7 +12,7 @@ pub fn clean_html(
let mut builder = Builder::default(); let mut builder = Builder::default();
for (tag, classes) in allowed_classes.iter() { for (tag, classes) in allowed_classes.iter() {
builder.add_allowed_classes(tag, classes); builder.add_allowed_classes(tag, classes);
}; }
let safe_html = builder let safe_html = builder
// Remove src from external images to prevent tracking // Remove src from external images to prevent tracking
.set_tag_attribute_value("img", "src", "") .set_tag_attribute_value("img", "src", "")
@ -28,15 +28,11 @@ pub fn clean_html_strict(
allowed_tags: &[&str], allowed_tags: &[&str],
allowed_classes: Vec<(&'static str, Vec<&'static str>)>, allowed_classes: Vec<(&'static str, Vec<&'static str>)>,
) -> String { ) -> String {
let allowed_tags = let allowed_tags = HashSet::from_iter(allowed_tags.iter().copied());
HashSet::from_iter(allowed_tags.iter().copied());
let mut allowed_classes_map = HashMap::new(); let mut allowed_classes_map = HashMap::new();
for (tag, classes) in allowed_classes { for (tag, classes) in allowed_classes {
allowed_classes_map.insert( allowed_classes_map.insert(tag, HashSet::from_iter(classes.into_iter()));
tag, }
HashSet::from_iter(classes.into_iter()),
);
};
let safe_html = Builder::default() let safe_html = Builder::default()
.tags(allowed_tags) .tags(allowed_tags)
.allowed_classes(allowed_classes_map) .allowed_classes(allowed_classes_map)
@ -47,9 +43,7 @@ pub fn clean_html_strict(
} }
pub fn clean_html_all(html: &str) -> String { pub fn clean_html_all(html: &str) -> String {
let text = Builder::empty() let text = Builder::empty().clean(html).to_string();
.clean(html)
.to_string();
text text
} }
@ -69,10 +63,7 @@ mod tests {
); );
let safe_html = clean_html( let safe_html = clean_html(
unsafe_html, unsafe_html,
vec![ vec![("a", vec!["mention", "u-url"]), ("span", vec!["h-card"])],
("a", vec!["mention", "u-url"]),
("span", vec!["h-card"]),
],
); );
assert_eq!(safe_html, expected_safe_html); assert_eq!(safe_html, expected_safe_html);
} }
@ -83,12 +74,12 @@ mod tests {
let safe_html = clean_html_strict( let safe_html = clean_html_strict(
unsafe_html, unsafe_html,
&["a", "br", "code", "p", "span"], &["a", "br", "code", "p", "span"],
vec![ vec![("a", vec!["mention", "u-url"]), ("span", vec!["h-card"])],
("a", vec!["mention", "u-url"]), );
("span", vec!["h-card"]), assert_eq!(
], safe_html,
r#"<p><span class="h-card"><a href="https://example.com/user" class="u-url mention" rel="noopener">@<span>user</span></a></span> test bold with <a href="https://example.com" rel="noopener">link</a> and <code>code</code></p>"#
); );
assert_eq!(safe_html, r#"<p><span class="h-card"><a href="https://example.com/user" class="u-url mention" rel="noopener">@<span>user</span></a></span> test bold with <a href="https://example.com" rel="noopener">link</a> and <code>code</code></p>"#);
} }
#[test] #[test]

View file

@ -2,14 +2,9 @@ use std::cell::RefCell;
use comrak::{ use comrak::{
arena_tree::Node, arena_tree::Node,
format_commonmark, format_commonmark, format_html,
format_html,
nodes::{Ast, AstNode, ListType, NodeValue}, nodes::{Ast, AstNode, ListType, NodeValue},
parse_document, parse_document, Arena, ComrakExtensionOptions, ComrakOptions, ComrakParseOptions,
Arena,
ComrakOptions,
ComrakExtensionOptions,
ComrakParseOptions,
ComrakRenderOptions, ComrakRenderOptions,
}; };
@ -37,16 +32,14 @@ fn build_comrak_options() -> ComrakOptions {
} }
} }
fn iter_nodes<'a, F>( fn iter_nodes<'a, F>(node: &'a AstNode<'a>, func: &F) -> Result<(), MarkdownError>
node: &'a AstNode<'a>, where
func: &F, F: Fn(&'a AstNode<'a>) -> Result<(), MarkdownError>,
) -> Result<(), MarkdownError>
where F: Fn(&'a AstNode<'a>) -> Result<(), MarkdownError>
{ {
func(node)?; func(node)?;
for child in node.children() { for child in node.children() {
iter_nodes(child, func)?; iter_nodes(child, func)?;
}; }
Ok(()) Ok(())
} }
@ -80,15 +73,13 @@ fn replace_with_markdown<'a>(
let markdown = node_to_markdown(node, options)?; let markdown = node_to_markdown(node, options)?;
for child in node.children() { for child in node.children() {
child.detach(); child.detach();
}; }
let text = NodeValue::Text(markdown); let text = NodeValue::Text(markdown);
replace_node_value(node, text); replace_node_value(node, text);
Ok(()) Ok(())
} }
fn fix_microsyntaxes<'a>( fn fix_microsyntaxes<'a>(node: &'a AstNode<'a>) -> Result<(), MarkdownError> {
node: &'a AstNode<'a>,
) -> Result<(), MarkdownError> {
if let Some(prev) = node.previous_sibling() { if let Some(prev) = node.previous_sibling() {
if let NodeValue::Text(ref prev_text) = prev.data.borrow().value { if let NodeValue::Text(ref prev_text) = prev.data.borrow().value {
// Remove autolink if mention or object link syntax is found // Remove autolink if mention or object link syntax is found
@ -100,7 +91,7 @@ fn fix_microsyntaxes<'a>(
if let NodeValue::Text(child_text) = child_value { if let NodeValue::Text(child_text) = child_value {
link_text.push_str(child_text); link_text.push_str(child_text);
}; };
}; }
let text = NodeValue::Text(link_text); let text = NodeValue::Text(link_text);
replace_node_value(node, text); replace_node_value(node, text);
}; };
@ -138,11 +129,7 @@ fn fix_linebreaks(html: &str) -> String {
pub fn markdown_lite_to_html(text: &str) -> Result<String, MarkdownError> { pub fn markdown_lite_to_html(text: &str) -> Result<String, MarkdownError> {
let options = build_comrak_options(); let options = build_comrak_options();
let arena = Arena::new(); let arena = Arena::new();
let root = parse_document( let root = parse_document(&arena, text, &options);
&arena,
text,
&options,
);
// Re-render blockquotes, headings, HRs, images and lists // Re-render blockquotes, headings, HRs, images and lists
// Headings: poorly degrade on Pleroma // Headings: poorly degrade on Pleroma
@ -160,12 +147,12 @@ pub fn markdown_lite_to_html(text: &str) -> Result<String, MarkdownError> {
}; };
for child in node.children() { for child in node.children() {
child.detach(); child.detach();
}; }
let text = NodeValue::Text(markdown); let text = NodeValue::Text(markdown);
let text_node = arena.alloc(create_node(text)); let text_node = arena.alloc(create_node(text));
node.append(text_node); node.append(text_node);
replace_node_value(node, NodeValue::Paragraph); replace_node_value(node, NodeValue::Paragraph);
}, }
NodeValue::Image(_) => replace_with_markdown(node, &options)?, NodeValue::Image(_) => replace_with_markdown(node, &options)?,
NodeValue::List(_) => { NodeValue::List(_) => {
// Replace list and list item nodes // Replace list and list item nodes
@ -176,11 +163,10 @@ pub fn markdown_lite_to_html(text: &str) -> Result<String, MarkdownError> {
for paragraph in list_item.children() { for paragraph in list_item.children() {
for content_node in paragraph.children() { for content_node in paragraph.children() {
contents.push(content_node); contents.push(content_node);
}; }
paragraph.detach(); paragraph.detach();
}; }
let mut list_prefix_markdown = let mut list_prefix_markdown = node_to_markdown(list_item, &options)?;
node_to_markdown(list_item, &options)?;
if let NodeValue::Item(item) = list_item.data.borrow().value { if let NodeValue::Item(item) = list_item.data.borrow().value {
if item.list_type == ListType::Ordered { if item.list_type == ListType::Ordered {
// Preserve numbering in ordered lists // Preserve numbering in ordered lists
@ -200,14 +186,14 @@ pub fn markdown_lite_to_html(text: &str) -> Result<String, MarkdownError> {
replacements.push(list_prefix_node); replacements.push(list_prefix_node);
for content_node in contents { for content_node in contents {
replacements.push(content_node); replacements.push(content_node);
}; }
list_item.detach(); list_item.detach();
}; }
for child_node in replacements { for child_node in replacements {
node.append(child_node); node.append(child_node);
}; }
replace_node_value(node, NodeValue::Paragraph); replace_node_value(node, NodeValue::Paragraph);
}, }
NodeValue::Link(_) => fix_microsyntaxes(node)?, NodeValue::Link(_) => fix_microsyntaxes(node)?,
_ => (), _ => (),
}; };
@ -224,20 +210,15 @@ pub fn markdown_lite_to_html(text: &str) -> Result<String, MarkdownError> {
pub fn markdown_basic_to_html(text: &str) -> Result<String, MarkdownError> { pub fn markdown_basic_to_html(text: &str) -> Result<String, MarkdownError> {
let options = build_comrak_options(); let options = build_comrak_options();
let arena = Arena::new(); let arena = Arena::new();
let root = parse_document( let root = parse_document(&arena, text, &options);
&arena,
text,
&options,
);
iter_nodes(root, &|node| { iter_nodes(root, &|node| {
let node_value = node.data.borrow().value.clone(); let node_value = node.data.borrow().value.clone();
match node_value { match node_value {
NodeValue::Document | NodeValue::Document
NodeValue::Text(_) | | NodeValue::Text(_)
NodeValue::SoftBreak | | NodeValue::SoftBreak
NodeValue::LineBreak | NodeValue::LineBreak => (),
=> (),
NodeValue::Link(_) => fix_microsyntaxes(node)?, NodeValue::Link(_) => fix_microsyntaxes(node)?,
NodeValue::Paragraph => { NodeValue::Paragraph => {
if node.next_sibling().is_some() { if node.next_sibling().is_some() {
@ -248,13 +229,12 @@ pub fn markdown_basic_to_html(text: &str) -> Result<String, MarkdownError> {
let last_child_value = &last_child.data.borrow().value; let last_child_value = &last_child.data.borrow().value;
if !matches!(last_child_value, NodeValue::LineBreak) { if !matches!(last_child_value, NodeValue::LineBreak) {
let line_break = NodeValue::LineBreak; let line_break = NodeValue::LineBreak;
let line_break_node = let line_break_node = arena.alloc(create_node(line_break));
arena.alloc(create_node(line_break));
node.append(line_break_node); node.append(line_break_node);
}; };
}; };
}; };
}, }
_ => replace_with_markdown(node, &options)?, _ => replace_with_markdown(node, &options)?,
}; };
Ok(()) Ok(())
@ -340,9 +320,6 @@ mod tests {
fn test_markdown_to_html() { fn test_markdown_to_html() {
let text = "# heading\n\ntest"; let text = "# heading\n\ntest";
let html = markdown_to_html(text); let html = markdown_to_html(text);
assert_eq!( assert_eq!(html, "<h1>heading</h1>\n<p>test</p>\n",);
html,
"<h1>heading</h1>\n<p>test</p>\n",
);
} }
} }

View file

@ -12,10 +12,10 @@ pub enum MultibaseError {
/// Decodes multibase base58 (bitcoin) value /// Decodes multibase base58 (bitcoin) value
/// https://github.com/multiformats/multibase /// https://github.com/multiformats/multibase
pub fn decode_multibase_base58btc(value: &str) pub fn decode_multibase_base58btc(value: &str) -> Result<Vec<u8>, MultibaseError> {
-> Result<Vec<u8>, MultibaseError> let base = value
{ .chars()
let base = value.chars().next() .next()
.ok_or(MultibaseError::InvalidBaseString)?; .ok_or(MultibaseError::InvalidBaseString)?;
// z == base58btc // z == base58btc
// https://github.com/multiformats/multibase#multibase-table // https://github.com/multiformats/multibase#multibase-table

View file

@ -7,10 +7,7 @@ pub fn hash_password(password: &str) -> Result<String, argon2::Error> {
argon2::hash_encoded(password.as_bytes(), &salt, &config) argon2::hash_encoded(password.as_bytes(), &salt, &config)
} }
pub fn verify_password( pub fn verify_password(password_hash: &str, password: &str) -> Result<bool, argon2::Error> {
password_hash: &str,
password: &str,
) -> Result<bool, argon2::Error> {
argon2::verify_encoded(password_hash, password.as_bytes()) argon2::verify_encoded(password_hash, password.as_bytes())
} }

View file

@ -2,10 +2,7 @@ use std::net::{Ipv4Addr, Ipv6Addr};
use url::{Host, ParseError, Url}; use url::{Host, ParseError, Url};
pub fn get_hostname(url: &str) -> Result<String, ParseError> { pub fn get_hostname(url: &str) -> Result<String, ParseError> {
let hostname = match Url::parse(url)? let hostname = match Url::parse(url)?.host().ok_or(ParseError::EmptyHost)? {
.host()
.ok_or(ParseError::EmptyHost)?
{
Host::Domain(domain) => domain.to_string(), Host::Domain(domain) => domain.to_string(),
Host::Ipv4(addr) => addr.to_string(), Host::Ipv4(addr) => addr.to_string(),
Host::Ipv6(addr) => addr.to_string(), Host::Ipv6(addr) => addr.to_string(),
@ -35,10 +32,7 @@ pub fn guess_protocol(hostname: &str) -> &'static str {
} }
pub fn normalize_url(url: &str) -> Result<Url, url::ParseError> { pub fn normalize_url(url: &str) -> Result<Url, url::ParseError> {
let normalized_url = if let normalized_url = if url.starts_with("http://") || url.starts_with("https://") {
url.starts_with("http://") ||
url.starts_with("https://")
{
url.to_string() url.to_string()
} else { } else {
// Add scheme // Add scheme
@ -49,11 +43,7 @@ pub fn normalize_url(url: &str) -> Result<Url, url::ParseError> {
url url
}; };
let url_scheme = guess_protocol(hostname); let url_scheme = guess_protocol(hostname);
format!( format!("{}://{}", url_scheme, url,)
"{}://{}",
url_scheme,
url,
)
}; };
let url = Url::parse(&normalized_url)?; let url = Url::parse(&normalized_url)?;
url.host().ok_or(ParseError::EmptyHost)?; // validates URL url.host().ok_or(ParseError::EmptyHost)?; // validates URL
@ -82,7 +72,10 @@ mod tests {
fn test_get_hostname_tor() { fn test_get_hostname_tor() {
let url = "http://2gzyxa5ihm7nsggfxnu52rck2vv4rvmdlkiu3zzui5du4xyclen53wid.onion/objects/1"; let url = "http://2gzyxa5ihm7nsggfxnu52rck2vv4rvmdlkiu3zzui5du4xyclen53wid.onion/objects/1";
let hostname = get_hostname(url).unwrap(); let hostname = get_hostname(url).unwrap();
assert_eq!(hostname, "2gzyxa5ihm7nsggfxnu52rck2vv4rvmdlkiu3zzui5du4xyclen53wid.onion"); assert_eq!(
hostname,
"2gzyxa5ihm7nsggfxnu52rck2vv4rvmdlkiu3zzui5du4xyclen53wid.onion"
);
} }
#[test] #[test]
@ -101,28 +94,19 @@ mod tests {
#[test] #[test]
fn test_guess_protocol() { fn test_guess_protocol() {
assert_eq!( assert_eq!(guess_protocol("example.org"), "https",);
guess_protocol("example.org"),
"https",
);
assert_eq!( assert_eq!(
guess_protocol("2gzyxa5ihm7nsggfxnu52rck2vv4rvmdlkiu3zzui5du4xyclen53wid.onion"), guess_protocol("2gzyxa5ihm7nsggfxnu52rck2vv4rvmdlkiu3zzui5du4xyclen53wid.onion"),
"http", "http",
); );
assert_eq!( assert_eq!(guess_protocol("zzz.i2p"), "http",);
guess_protocol("zzz.i2p"),
"http",
);
// Yggdrasil // Yggdrasil
assert_eq!( assert_eq!(
guess_protocol("319:3cf0:dd1d:47b9:20c:29ff:fe2c:39be"), guess_protocol("319:3cf0:dd1d:47b9:20c:29ff:fe2c:39be"),
"http", "http",
); );
// localhost // localhost
assert_eq!( assert_eq!(guess_protocol("127.0.0.1"), "http",);
guess_protocol("127.0.0.1"),
"http",
);
} }
#[test] #[test]

View file

@ -1,36 +1,20 @@
use mitra_models::profiles::types::{ use mitra_models::profiles::types::{
ExtraField, ExtraField, IdentityProof, IdentityProofType, PaymentLink, PaymentOption,
IdentityProof,
IdentityProofType,
PaymentLink,
PaymentOption,
}; };
use mitra_utils::did::Did; use mitra_utils::did::Did;
use crate::activitypub::vocabulary::{ use crate::activitypub::vocabulary::{IDENTITY_PROOF, LINK, PROPERTY_VALUE};
IDENTITY_PROOF,
LINK,
PROPERTY_VALUE,
};
use crate::errors::ValidationError; use crate::errors::ValidationError;
use crate::identity::{ use crate::identity::{
claims::create_identity_claim, claims::create_identity_claim,
minisign::{ minisign::{parse_minisign_signature, verify_minisign_signature},
parse_minisign_signature,
verify_minisign_signature,
},
};
use crate::json_signatures::proofs::{
PROOF_TYPE_ID_EIP191,
PROOF_TYPE_ID_MINISIGN,
}; };
use crate::json_signatures::proofs::{PROOF_TYPE_ID_EIP191, PROOF_TYPE_ID_MINISIGN};
use crate::web_client::urls::get_subscription_page_url; use crate::web_client::urls::get_subscription_page_url;
use super::types::ActorAttachment; use super::types::ActorAttachment;
pub fn attach_identity_proof( pub fn attach_identity_proof(proof: IdentityProof) -> ActorAttachment {
proof: IdentityProof,
) -> ActorAttachment {
let proof_type_str = match proof.proof_type { let proof_type_str = match proof.proof_type {
IdentityProofType::LegacyEip191IdentityProof => PROOF_TYPE_ID_EIP191, IdentityProofType::LegacyEip191IdentityProof => PROOF_TYPE_ID_EIP191,
IdentityProofType::LegacyMinisignIdentityProof => PROOF_TYPE_ID_MINISIGN, IdentityProofType::LegacyMinisignIdentityProof => PROOF_TYPE_ID_MINISIGN,
@ -52,18 +36,24 @@ pub fn parse_identity_proof(
if attachment.object_type != IDENTITY_PROOF { if attachment.object_type != IDENTITY_PROOF {
return Err(ValidationError("invalid attachment type")); return Err(ValidationError("invalid attachment type"));
}; };
let proof_type_str = attachment.signature_algorithm.as_ref() let proof_type_str = attachment
.signature_algorithm
.as_ref()
.ok_or(ValidationError("missing proof type"))?; .ok_or(ValidationError("missing proof type"))?;
let proof_type = match proof_type_str.as_str() { let proof_type = match proof_type_str.as_str() {
PROOF_TYPE_ID_EIP191 => IdentityProofType::LegacyEip191IdentityProof, PROOF_TYPE_ID_EIP191 => IdentityProofType::LegacyEip191IdentityProof,
PROOF_TYPE_ID_MINISIGN => IdentityProofType::LegacyMinisignIdentityProof, PROOF_TYPE_ID_MINISIGN => IdentityProofType::LegacyMinisignIdentityProof,
_ => return Err(ValidationError("unsupported proof type")), _ => return Err(ValidationError("unsupported proof type")),
}; };
let did = attachment.name.parse::<Did>() let did = attachment
.name
.parse::<Did>()
.map_err(|_| ValidationError("invalid DID"))?; .map_err(|_| ValidationError("invalid DID"))?;
let message = create_identity_claim(actor_id, &did) let message =
.map_err(|_| ValidationError("invalid claim"))?; create_identity_claim(actor_id, &did).map_err(|_| ValidationError("invalid claim"))?;
let signature = attachment.signature_value.as_ref() let signature = attachment
.signature_value
.as_ref()
.ok_or(ValidationError("missing signature"))?; .ok_or(ValidationError("missing signature"))?;
match did { match did {
Did::Key(ref did_key) => { Did::Key(ref did_key) => {
@ -72,15 +62,12 @@ pub fn parse_identity_proof(
}; };
let signature_bin = parse_minisign_signature(signature) let signature_bin = parse_minisign_signature(signature)
.map_err(|_| ValidationError("invalid signature encoding"))?; .map_err(|_| ValidationError("invalid signature encoding"))?;
verify_minisign_signature( verify_minisign_signature(did_key, &message, &signature_bin)
did_key, .map_err(|_| ValidationError("invalid identity proof"))?;
&message, }
&signature_bin,
).map_err(|_| ValidationError("invalid identity proof"))?;
},
Did::Pkh(ref _did_pkh) => { Did::Pkh(ref _did_pkh) => {
return Err(ValidationError("incorrect proof type")); return Err(ValidationError("incorrect proof type"));
}, }
}; };
let proof = IdentityProof { let proof = IdentityProof {
issuer: did, issuer: did,
@ -102,12 +89,12 @@ pub fn attach_payment_option(
let name = "EthereumSubscription".to_string(); let name = "EthereumSubscription".to_string();
let href = get_subscription_page_url(instance_url, username); let href = get_subscription_page_url(instance_url, username);
(name, href) (name, href)
}, }
PaymentOption::MoneroSubscription(_) => { PaymentOption::MoneroSubscription(_) => {
let name = "MoneroSubscription".to_string(); let name = "MoneroSubscription".to_string();
let href = get_subscription_page_url(instance_url, username); let href = get_subscription_page_url(instance_url, username);
(name, href) (name, href)
}, }
}; };
ActorAttachment { ActorAttachment {
object_type: LINK.to_string(), object_type: LINK.to_string(),
@ -125,7 +112,9 @@ pub fn parse_payment_option(
if attachment.object_type != LINK { if attachment.object_type != LINK {
return Err(ValidationError("invalid attachment type")); return Err(ValidationError("invalid attachment type"));
}; };
let href = attachment.href.as_ref() let href = attachment
.href
.as_ref()
.ok_or(ValidationError("href attribute is required"))? .ok_or(ValidationError("href attribute is required"))?
.to_string(); .to_string();
let payment_option = PaymentOption::Link(PaymentLink { let payment_option = PaymentOption::Link(PaymentLink {
@ -135,9 +124,7 @@ pub fn parse_payment_option(
Ok(payment_option) Ok(payment_option)
} }
pub fn attach_extra_field( pub fn attach_extra_field(field: ExtraField) -> ActorAttachment {
field: ExtraField,
) -> ActorAttachment {
ActorAttachment { ActorAttachment {
object_type: PROPERTY_VALUE.to_string(), object_type: PROPERTY_VALUE.to_string(),
name: field.name, name: field.name,
@ -148,13 +135,13 @@ pub fn attach_extra_field(
} }
} }
pub fn parse_extra_field( pub fn parse_extra_field(attachment: &ActorAttachment) -> Result<ExtraField, ValidationError> {
attachment: &ActorAttachment,
) -> Result<ExtraField, ValidationError> {
if attachment.object_type != PROPERTY_VALUE { if attachment.object_type != PROPERTY_VALUE {
return Err(ValidationError("invalid attachment type")); return Err(ValidationError("invalid attachment type"));
}; };
let property_value = attachment.value.as_ref() let property_value = attachment
.value
.as_ref()
.ok_or(ValidationError("missing property value"))?; .ok_or(ValidationError("missing property value"))?;
let field = ExtraField { let field = ExtraField {
name: attachment.name.clone(), name: attachment.name.clone(),
@ -166,10 +153,8 @@ pub fn parse_extra_field(
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use mitra_utils::{
caip2::ChainId,
};
use super::*; use super::*;
use mitra_utils::caip2::ChainId;
const INSTANCE_URL: &str = "https://example.com"; const INSTANCE_URL: &str = "https://example.com";
@ -191,15 +176,9 @@ mod tests {
#[test] #[test]
fn test_payment_option() { fn test_payment_option() {
let username = "testuser"; let username = "testuser";
let payment_option = let payment_option = PaymentOption::ethereum_subscription(ChainId::ethereum_mainnet());
PaymentOption::ethereum_subscription(ChainId::ethereum_mainnet()); let subscription_page_url = "https://example.com/@testuser/subscription";
let subscription_page_url = let attachment = attach_payment_option(INSTANCE_URL, username, payment_option);
"https://example.com/@testuser/subscription";
let attachment = attach_payment_option(
INSTANCE_URL,
username,
payment_option,
);
assert_eq!(attachment.object_type, LINK); assert_eq!(attachment.object_type, LINK);
assert_eq!(attachment.name, "EthereumSubscription"); assert_eq!(attachment.name, "EthereumSubscription");
assert_eq!(attachment.href.as_deref().unwrap(), subscription_page_url); assert_eq!(attachment.href.as_deref().unwrap(), subscription_page_url);

View file

@ -6,12 +6,7 @@ use mitra_config::Instance;
use mitra_models::{ use mitra_models::{
database::DatabaseClient, database::DatabaseClient,
profiles::queries::{create_profile, update_profile}, profiles::queries::{create_profile, update_profile},
profiles::types::{ profiles::types::{DbActorProfile, ProfileCreateData, ProfileImage, ProfileUpdateData},
DbActorProfile,
ProfileImage,
ProfileCreateData,
ProfileUpdateData,
},
}; };
use crate::activitypub::{ use crate::activitypub::{
@ -44,19 +39,17 @@ async fn fetch_actor_images(
icon.media_type.as_deref(), icon.media_type.as_deref(),
ACTOR_IMAGE_MAX_SIZE, ACTOR_IMAGE_MAX_SIZE,
media_dir, media_dir,
).await { )
.await
{
Ok((file_name, file_size, maybe_media_type)) => { Ok((file_name, file_size, maybe_media_type)) => {
let image = ProfileImage::new( let image = ProfileImage::new(file_name, file_size, maybe_media_type);
file_name,
file_size,
maybe_media_type,
);
Some(image) Some(image)
}, }
Err(error) => { Err(error) => {
log::warn!("failed to fetch avatar ({})", error); log::warn!("failed to fetch avatar ({})", error);
default_avatar default_avatar
}, }
} }
} else { } else {
None None
@ -68,19 +61,17 @@ async fn fetch_actor_images(
image.media_type.as_deref(), image.media_type.as_deref(),
ACTOR_IMAGE_MAX_SIZE, ACTOR_IMAGE_MAX_SIZE,
media_dir, media_dir,
).await { )
.await
{
Ok((file_name, file_size, maybe_media_type)) => { Ok((file_name, file_size, maybe_media_type)) => {
let image = ProfileImage::new( let image = ProfileImage::new(file_name, file_size, maybe_media_type);
file_name,
file_size,
maybe_media_type,
);
Some(image) Some(image)
}, }
Err(error) => { Err(error) => {
log::warn!("failed to fetch banner ({})", error); log::warn!("failed to fetch banner ({})", error);
default_banner default_banner
}, }
} }
} else { } else {
None None
@ -90,24 +81,24 @@ async fn fetch_actor_images(
fn parse_aliases(actor: &Actor) -> Vec<String> { fn parse_aliases(actor: &Actor) -> Vec<String> {
// Aliases reported by server (not signed) // Aliases reported by server (not signed)
actor.also_known_as.as_ref() actor
.and_then(|value| { .also_known_as
match parse_array(value) { .as_ref()
Ok(array) => { .and_then(|value| match parse_array(value) {
let mut aliases = vec![]; Ok(array) => {
for actor_id in array { let mut aliases = vec![];
if validate_object_id(&actor_id).is_err() { for actor_id in array {
log::warn!("invalid alias: {}", actor_id); if validate_object_id(&actor_id).is_err() {
continue; log::warn!("invalid alias: {}", actor_id);
}; continue;
aliases.push(actor_id);
}; };
Some(aliases) aliases.push(actor_id);
}, }
Err(_) => { Some(aliases)
log::warn!("invalid alias list: {}", value); }
None Err(_) => {
}, log::warn!("invalid alias list: {}", value);
None
} }
}) })
.unwrap_or_default() .unwrap_or_default()
@ -127,23 +118,18 @@ async fn parse_tags(
log::warn!("too many emojis"); log::warn!("too many emojis");
continue; continue;
}; };
match handle_emoji( match handle_emoji(db_client, instance, storage, tag_value).await? {
db_client,
instance,
storage,
tag_value,
).await? {
Some(emoji) => { Some(emoji) => {
if !emojis.contains(&emoji.id) { if !emojis.contains(&emoji.id) {
emojis.push(emoji.id); emojis.push(emoji.id);
}; };
}, }
None => continue, None => continue,
}; };
} else { } else {
log::warn!("skipping actor tag of type {}", tag_type); log::warn!("skipping actor tag of type {}", tag_type);
}; };
}; }
Ok(emojis) Ok(emojis)
} }
@ -157,22 +143,11 @@ pub async fn create_remote_profile(
if actor_address.hostname == instance.hostname() { if actor_address.hostname == instance.hostname() {
return Err(HandlerError::LocalObject); return Err(HandlerError::LocalObject);
}; };
let (maybe_avatar, maybe_banner) = fetch_actor_images( let (maybe_avatar, maybe_banner) =
instance, fetch_actor_images(instance, &actor, &storage.media_dir, None, None).await;
&actor, let (identity_proofs, payment_options, extra_fields) = actor.parse_attachments();
&storage.media_dir,
None,
None,
).await;
let (identity_proofs, payment_options, extra_fields) =
actor.parse_attachments();
let aliases = parse_aliases(&actor); let aliases = parse_aliases(&actor);
let emojis = parse_tags( let emojis = parse_tags(db_client, instance, storage, &actor).await?;
db_client,
instance,
storage,
&actor,
).await?;
let mut profile_data = ProfileCreateData { let mut profile_data = ProfileCreateData {
username: actor.preferred_username.clone(), username: actor.preferred_username.clone(),
hostname: Some(actor_address.hostname), hostname: Some(actor_address.hostname),
@ -203,11 +178,7 @@ pub async fn update_remote_profile(
) -> Result<DbActorProfile, HandlerError> { ) -> Result<DbActorProfile, HandlerError> {
let actor_old = profile.actor_json.ok_or(HandlerError::LocalObject)?; let actor_old = profile.actor_json.ok_or(HandlerError::LocalObject)?;
if actor_old.id != actor.id { if actor_old.id != actor.id {
log::warn!( log::warn!("actor ID changed from {} to {}", actor_old.id, actor.id,);
"actor ID changed from {} to {}",
actor_old.id,
actor.id,
);
}; };
if actor_old.public_key.public_key_pem != actor.public_key.public_key_pem { if actor_old.public_key.public_key_pem != actor.public_key.public_key_pem {
log::warn!( log::warn!(
@ -222,16 +193,11 @@ pub async fn update_remote_profile(
&storage.media_dir, &storage.media_dir,
profile.avatar, profile.avatar,
profile.banner, profile.banner,
).await; )
let (identity_proofs, payment_options, extra_fields) = .await;
actor.parse_attachments(); let (identity_proofs, payment_options, extra_fields) = actor.parse_attachments();
let aliases = parse_aliases(&actor); let aliases = parse_aliases(&actor);
let emojis = parse_tags( let emojis = parse_tags(db_client, instance, storage, &actor).await?;
db_client,
instance,
storage,
&actor,
).await?;
let mut profile_data = ProfileUpdateData { let mut profile_data = ProfileUpdateData {
display_name: actor.name.clone(), display_name: actor.name.clone(),
bio: actor.summary.clone(), bio: actor.summary.clone(),

View file

@ -1,22 +1,11 @@
use std::collections::HashMap; use std::collections::HashMap;
use serde::{ use serde::{de::Error as DeserializerError, Deserialize, Deserializer, Serialize};
Deserialize,
Deserializer,
Serialize,
de::{Error as DeserializerError},
};
use serde_json::{json, Value}; use serde_json::{json, Value};
use mitra_config::Instance; use mitra_config::Instance;
use mitra_models::{ use mitra_models::{
profiles::types::{ profiles::types::{DbActor, DbActorPublicKey, ExtraField, IdentityProof, PaymentOption},
DbActor,
DbActorPublicKey,
ExtraField,
IdentityProof,
PaymentOption,
},
users::types::User, users::types::User,
}; };
use mitra_utils::{ use mitra_utils::{
@ -26,17 +15,10 @@ use mitra_utils::{
use crate::activitypub::{ use crate::activitypub::{
constants::{ constants::{
AP_CONTEXT, AP_CONTEXT, MASTODON_CONTEXT, MITRA_CONTEXT, SCHEMA_ORG_CONTEXT, W3ID_SECURITY_CONTEXT,
MASTODON_CONTEXT,
MITRA_CONTEXT,
SCHEMA_ORG_CONTEXT,
W3ID_SECURITY_CONTEXT,
}, },
identifiers::{ identifiers::{
local_actor_id, local_actor_id, local_actor_key_id, local_instance_actor_id, LocalActorCollection,
local_actor_key_id,
local_instance_actor_id,
LocalActorCollection,
}, },
types::deserialize_value_array, types::deserialize_value_array,
vocabulary::{IDENTITY_PROOF, IMAGE, LINK, PERSON, PROPERTY_VALUE, SERVICE}, vocabulary::{IDENTITY_PROOF, IMAGE, LINK, PERSON, PROPERTY_VALUE, SERVICE},
@ -46,12 +28,8 @@ use crate::media::get_file_url;
use crate::webfinger::types::ActorAddress; use crate::webfinger::types::ActorAddress;
use super::attachments::{ use super::attachments::{
attach_extra_field, attach_extra_field, attach_identity_proof, attach_payment_option, parse_extra_field,
attach_identity_proof, parse_identity_proof, parse_payment_option,
attach_payment_option,
parse_extra_field,
parse_identity_proof,
parse_payment_option,
}; };
#[derive(Deserialize, Serialize)] #[derive(Deserialize, Serialize)]
@ -94,21 +72,17 @@ pub struct ActorAttachment {
} }
// Some implementations use empty object instead of null // Some implementations use empty object instead of null
fn deserialize_image_opt<'de, D>( fn deserialize_image_opt<'de, D>(deserializer: D) -> Result<Option<ActorImage>, D::Error>
deserializer: D, where
) -> Result<Option<ActorImage>, D::Error> D: Deserializer<'de>,
where D: Deserializer<'de>
{ {
let maybe_value: Option<Value> = Option::deserialize(deserializer)?; let maybe_value: Option<Value> = Option::deserialize(deserializer)?;
let maybe_image = if let Some(value) = maybe_value { let maybe_image = if let Some(value) = maybe_value {
let is_empty_object = value.as_object() let is_empty_object = value.as_object().map(|map| map.is_empty()).unwrap_or(false);
.map(|map| map.is_empty())
.unwrap_or(false);
if is_empty_object { if is_empty_object {
None None
} else { } else {
let image = ActorImage::deserialize(value) let image = ActorImage::deserialize(value).map_err(DeserializerError::custom)?;
.map_err(DeserializerError::custom)?;
Some(image) Some(image)
} }
} else { } else {
@ -149,7 +123,7 @@ pub struct Actor {
#[serde( #[serde(
default, default,
deserialize_with = "deserialize_image_opt", deserialize_with = "deserialize_image_opt",
skip_serializing_if = "Option::is_none", skip_serializing_if = "Option::is_none"
)] )]
pub icon: Option<ActorImage>, pub icon: Option<ActorImage>,
@ -165,7 +139,7 @@ pub struct Actor {
#[serde( #[serde(
default, default,
deserialize_with = "deserialize_value_array", deserialize_with = "deserialize_value_array",
skip_serializing_if = "Vec::is_empty", skip_serializing_if = "Vec::is_empty"
)] )]
pub attachment: Vec<Value>, pub attachment: Vec<Value>,
@ -175,7 +149,7 @@ pub struct Actor {
#[serde( #[serde(
default, default,
deserialize_with = "deserialize_value_array", deserialize_with = "deserialize_value_array",
skip_serializing_if = "Vec::is_empty", skip_serializing_if = "Vec::is_empty"
)] )]
pub tag: Vec<Value>, pub tag: Vec<Value>,
@ -184,11 +158,8 @@ pub struct Actor {
} }
impl Actor { impl Actor {
pub fn address( pub fn address(&self) -> Result<ActorAddress, ValidationError> {
&self, let hostname = get_hostname(&self.id).map_err(|_| ValidationError("invalid actor ID"))?;
) -> Result<ActorAddress, ValidationError> {
let hostname = get_hostname(&self.id)
.map_err(|_| ValidationError("invalid actor ID"))?;
let actor_address = ActorAddress { let actor_address = ActorAddress {
username: self.preferred_username.clone(), username: self.preferred_username.clone(),
hostname: hostname, hostname: hostname,
@ -213,11 +184,7 @@ impl Actor {
} }
} }
pub fn parse_attachments(&self) -> ( pub fn parse_attachments(&self) -> (Vec<IdentityProof>, Vec<PaymentOption>, Vec<ExtraField>) {
Vec<IdentityProof>,
Vec<PaymentOption>,
Vec<ExtraField>,
) {
let mut identity_proofs = vec![]; let mut identity_proofs = vec![];
let mut payment_options = vec![]; let mut payment_options = vec![];
let mut extra_fields = vec![]; let mut extra_fields = vec![];
@ -229,17 +196,13 @@ impl Actor {
); );
}; };
for attachment_value in self.attachment.iter() { for attachment_value in self.attachment.iter() {
let attachment_type = let attachment_type = attachment_value["type"].as_str().unwrap_or("Unknown");
attachment_value["type"].as_str().unwrap_or("Unknown");
let attachment = match serde_json::from_value(attachment_value.clone()) { let attachment = match serde_json::from_value(attachment_value.clone()) {
Ok(attachment) => attachment, Ok(attachment) => attachment,
Err(_) => { Err(_) => {
log_error( log_error(attachment_type, ValidationError("invalid attachment"));
attachment_type,
ValidationError("invalid attachment"),
);
continue; continue;
}, }
}; };
match attachment_type { match attachment_type {
IDENTITY_PROOF => { IDENTITY_PROOF => {
@ -247,27 +210,27 @@ impl Actor {
Ok(proof) => identity_proofs.push(proof), Ok(proof) => identity_proofs.push(proof),
Err(error) => log_error(attachment_type, error), Err(error) => log_error(attachment_type, error),
}; };
}, }
LINK => { LINK => {
match parse_payment_option(&attachment) { match parse_payment_option(&attachment) {
Ok(option) => payment_options.push(option), Ok(option) => payment_options.push(option),
Err(error) => log_error(attachment_type, error), Err(error) => log_error(attachment_type, error),
}; };
}, }
PROPERTY_VALUE => { PROPERTY_VALUE => {
match parse_extra_field(&attachment) { match parse_extra_field(&attachment) {
Ok(field) => extra_fields.push(field), Ok(field) => extra_fields.push(field),
Err(error) => log_error(attachment_type, error), Err(error) => log_error(attachment_type, error),
}; };
}, }
_ => { _ => {
log_error( log_error(
attachment_type, attachment_type,
ValidationError("unsupported attachment type"), ValidationError("unsupported attachment type"),
); );
}, }
}; };
}; }
(identity_proofs, payment_options, extra_fields) (identity_proofs, payment_options, extra_fields)
} }
} }
@ -295,10 +258,7 @@ fn build_actor_context() -> (
) )
} }
pub fn get_local_actor( pub fn get_local_actor(user: &User, instance_url: &str) -> Result<Actor, ActorKeyError> {
user: &User,
instance_url: &str,
) -> Result<Actor, ActorKeyError> {
let username = &user.profile.username; let username = &user.profile.username;
let actor_id = local_actor_id(instance_url, username); let actor_id = local_actor_id(instance_url, username);
let inbox = LocalActorCollection::Inbox.of(&actor_id); let inbox = LocalActorCollection::Inbox.of(&actor_id);
@ -322,7 +282,7 @@ pub fn get_local_actor(
media_type: image.media_type.clone(), media_type: image.media_type.clone(),
}; };
Some(actor_image) Some(actor_image)
}, }
None => None, None => None,
}; };
let banner = match &user.profile.banner { let banner = match &user.profile.banner {
@ -333,32 +293,29 @@ pub fn get_local_actor(
media_type: image.media_type.clone(), media_type: image.media_type.clone(),
}; };
Some(actor_image) Some(actor_image)
}, }
None => None, None => None,
}; };
let mut attachments = vec![]; let mut attachments = vec![];
for proof in user.profile.identity_proofs.clone().into_inner() { for proof in user.profile.identity_proofs.clone().into_inner() {
let attachment = attach_identity_proof(proof); let attachment = attach_identity_proof(proof);
let attachment_value = serde_json::to_value(attachment) let attachment_value =
.expect("attachment should be serializable"); serde_json::to_value(attachment).expect("attachment should be serializable");
attachments.push(attachment_value); attachments.push(attachment_value);
}; }
for payment_option in user.profile.payment_options.clone().into_inner() { for payment_option in user.profile.payment_options.clone().into_inner() {
let attachment = attach_payment_option( let attachment =
instance_url, attach_payment_option(instance_url, &user.profile.username, payment_option);
&user.profile.username, let attachment_value =
payment_option, serde_json::to_value(attachment).expect("attachment should be serializable");
);
let attachment_value = serde_json::to_value(attachment)
.expect("attachment should be serializable");
attachments.push(attachment_value); attachments.push(attachment_value);
}; }
for field in user.profile.extra_fields.clone().into_inner() { for field in user.profile.extra_fields.clone().into_inner() {
let attachment = attach_extra_field(field); let attachment = attach_extra_field(field);
let attachment_value = serde_json::to_value(attachment) let attachment_value =
.expect("attachment should be serializable"); serde_json::to_value(attachment).expect("attachment should be serializable");
attachments.push(attachment_value); attachments.push(attachment_value);
}; }
let aliases = user.profile.aliases.clone().into_actor_ids(); let aliases = user.profile.aliases.clone().into_actor_ids();
let actor = Actor { let actor = Actor {
context: Some(json!(build_actor_context())), context: Some(json!(build_actor_context())),
@ -384,9 +341,7 @@ pub fn get_local_actor(
Ok(actor) Ok(actor)
} }
pub fn get_instance_actor( pub fn get_instance_actor(instance: &Instance) -> Result<Actor, ActorKeyError> {
instance: &Instance,
) -> Result<Actor, ActorKeyError> {
let actor_id = local_instance_actor_id(&instance.url()); let actor_id = local_instance_actor_id(&instance.url());
let actor_inbox = LocalActorCollection::Inbox.of(&actor_id); let actor_inbox = LocalActorCollection::Inbox.of(&actor_id);
let actor_outbox = LocalActorCollection::Outbox.of(&actor_id); let actor_outbox = LocalActorCollection::Outbox.of(&actor_id);
@ -422,12 +377,9 @@ pub fn get_instance_actor(
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use mitra_models::profiles::types::DbActorProfile;
use mitra_utils::crypto_rsa::{
generate_weak_rsa_key,
serialize_private_key,
};
use super::*; use super::*;
use mitra_models::profiles::types::DbActorProfile;
use mitra_utils::crypto_rsa::{generate_weak_rsa_key, serialize_private_key};
const INSTANCE_HOSTNAME: &str = "example.com"; const INSTANCE_HOSTNAME: &str = "example.com";
const INSTANCE_URL: &str = "https://example.com"; const INSTANCE_URL: &str = "https://example.com";

View file

@ -7,24 +7,17 @@ use mitra_models::{
profiles::queries::get_profile_by_remote_actor_id, profiles::queries::get_profile_by_remote_actor_id,
profiles::types::DbActorProfile, profiles::types::DbActorProfile,
}; };
use mitra_utils::{ use mitra_utils::{crypto_rsa::deserialize_public_key, did::Did};
crypto_rsa::deserialize_public_key,
did::Did,
};
use crate::http_signatures::verify::{ use crate::http_signatures::verify::{
parse_http_signature, parse_http_signature, verify_http_signature,
verify_http_signature,
HttpSignatureVerificationError as HttpSignatureError, HttpSignatureVerificationError as HttpSignatureError,
}; };
use crate::json_signatures::{ use crate::json_signatures::{
proofs::ProofType, proofs::ProofType,
verify::{ verify::{
get_json_signature, get_json_signature, verify_ed25519_json_signature, verify_eip191_json_signature,
verify_ed25519_json_signature, verify_rsa_json_signature, JsonSignatureVerificationError as JsonSignatureError,
verify_eip191_json_signature,
verify_rsa_json_signature,
JsonSignatureVerificationError as JsonSignatureError,
JsonSigner, JsonSigner,
}, },
}; };
@ -93,12 +86,14 @@ async fn get_signer(
&config.instance(), &config.instance(),
&MediaStorage::from(config), &MediaStorage::from(config),
signer_id, signer_id,
).await { )
.await
{
Ok(profile) => profile, Ok(profile) => profile,
Err(HandlerError::DatabaseError(error)) => return Err(error.into()), Err(HandlerError::DatabaseError(error)) => return Err(error.into()),
Err(other_error) => { Err(other_error) => {
return Err(AuthenticationError::ImportError(other_error.to_string())); return Err(AuthenticationError::ImportError(other_error.to_string()));
}, }
} }
}; };
Ok(signer) Ok(signer)
@ -111,24 +106,22 @@ pub async fn verify_signed_request(
request: &HttpRequest, request: &HttpRequest,
no_fetch: bool, no_fetch: bool,
) -> Result<DbActorProfile, AuthenticationError> { ) -> Result<DbActorProfile, AuthenticationError> {
let signature_data = match parse_http_signature( let signature_data =
request.method(), match parse_http_signature(request.method(), request.uri(), request.headers()) {
request.uri(), Ok(signature_data) => signature_data,
request.headers(), Err(HttpSignatureError::NoSignature) => {
) { return Err(AuthenticationError::NoHttpSignature);
Ok(signature_data) => signature_data, }
Err(HttpSignatureError::NoSignature) => { Err(other_error) => return Err(other_error.into()),
return Err(AuthenticationError::NoHttpSignature); };
},
Err(other_error) => return Err(other_error.into()),
};
let signer_id = key_id_to_actor_id(&signature_data.key_id)?; let signer_id = key_id_to_actor_id(&signature_data.key_id)?;
let signer = get_signer(config, db_client, &signer_id, no_fetch).await?; let signer = get_signer(config, db_client, &signer_id, no_fetch).await?;
let signer_actor = signer.actor_json.as_ref() let signer_actor = signer
.actor_json
.as_ref()
.expect("request should be signed by remote actor"); .expect("request should be signed by remote actor");
let signer_key = let signer_key = deserialize_public_key(&signer_actor.public_key.public_key_pem)?;
deserialize_public_key(&signer_actor.public_key.public_key_pem)?;
verify_http_signature(&signature_data, &signer_key)?; verify_http_signature(&signature_data, &signer_key)?;
@ -146,13 +139,14 @@ pub async fn verify_signed_activity(
Ok(signature_data) => signature_data, Ok(signature_data) => signature_data,
Err(JsonSignatureError::NoProof) => { Err(JsonSignatureError::NoProof) => {
return Err(AuthenticationError::NoJsonSignature); return Err(AuthenticationError::NoJsonSignature);
}, }
Err(other_error) => return Err(other_error.into()), Err(other_error) => return Err(other_error.into()),
}; };
// Signed activities must have `actor` property, to avoid situations // Signed activities must have `actor` property, to avoid situations
// where signer is identified by DID but there is no matching // where signer is identified by DID but there is no matching
// identity proof in the local database. // identity proof in the local database.
let actor_id = activity["actor"].as_str() let actor_id = activity["actor"]
.as_str()
.ok_or(AuthenticationError::ActorError("unknown actor"))?; .ok_or(AuthenticationError::ActorError("unknown actor"))?;
let actor_profile = get_signer(config, db_client, actor_id, no_fetch).await?; let actor_profile = get_signer(config, db_client, actor_id, no_fetch).await?;
@ -165,12 +159,13 @@ pub async fn verify_signed_activity(
if signer_id != actor_id { if signer_id != actor_id {
return Err(AuthenticationError::UnexpectedSigner); return Err(AuthenticationError::UnexpectedSigner);
}; };
let signer_actor = actor_profile.actor_json.as_ref() let signer_actor = actor_profile
.actor_json
.as_ref()
.expect("activity should be signed by remote actor"); .expect("activity should be signed by remote actor");
let signer_key = let signer_key = deserialize_public_key(&signer_actor.public_key.public_key_pem)?;
deserialize_public_key(&signer_actor.public_key.public_key_pem)?;
verify_rsa_json_signature(&signature_data, &signer_key)?; verify_rsa_json_signature(&signature_data, &signer_key)?;
}, }
JsonSigner::Did(did) => { JsonSigner::Did(did) => {
if !actor_profile.identity_proofs.any(&did) { if !actor_profile.identity_proofs.any(&did) {
return Err(AuthenticationError::UnexpectedSigner); return Err(AuthenticationError::UnexpectedSigner);
@ -186,7 +181,7 @@ pub async fn verify_signed_activity(
&signature_data.message, &signature_data.message,
&signature_data.signature, &signature_data.signature,
)?; )?;
}, }
ProofType::JcsEip191Signature => { ProofType::JcsEip191Signature => {
let did_pkh = match did { let did_pkh = match did {
Did::Pkh(did_pkh) => did_pkh, Did::Pkh(did_pkh) => did_pkh,
@ -197,10 +192,10 @@ pub async fn verify_signed_activity(
&signature_data.message, &signature_data.message,
&signature_data.signature, &signature_data.signature,
)?; )?;
}, }
_ => return Err(AuthenticationError::InvalidJsonSignatureType), _ => return Err(AuthenticationError::InvalidJsonSignatureType),
}; };
}, }
}; };
// Signer is actor // Signer is actor
Ok(actor_profile) Ok(actor_profile)

View file

@ -61,12 +61,7 @@ pub fn prepare_accept_follow(
follow_activity_id, follow_activity_id,
); );
let recipients = vec![source_actor.clone()]; let recipients = vec![source_actor.clone()];
OutgoingActivity::new( OutgoingActivity::new(instance, sender, activity, recipients)
instance,
sender,
activity,
recipients,
)
} }
#[cfg(test)] #[cfg(test)]
@ -83,12 +78,7 @@ mod tests {
}; };
let follow_activity_id = "https://test.remote/objects/999"; let follow_activity_id = "https://test.remote/objects/999";
let follower_id = "https://test.remote/users/123"; let follower_id = "https://test.remote/users/123";
let activity = build_accept_follow( let activity = build_accept_follow(INSTANCE_URL, &target, follower_id, follow_activity_id);
INSTANCE_URL,
&target,
follower_id,
follow_activity_id,
);
assert_eq!(activity.id.starts_with(INSTANCE_URL), true); assert_eq!(activity.id.starts_with(INSTANCE_URL), true);
assert_eq!(activity.activity_type, "Accept"); assert_eq!(activity.activity_type, "Accept");

View file

@ -1,10 +1,7 @@
use serde::Serialize; use serde::Serialize;
use mitra_config::Instance; use mitra_config::Instance;
use mitra_models::{ use mitra_models::{profiles::types::DbActor, users::types::User};
profiles::types::DbActor,
users::types::User,
};
use mitra_utils::id::generate_ulid; use mitra_utils::id::generate_ulid;
use crate::activitypub::{ use crate::activitypub::{
@ -67,12 +64,7 @@ pub fn prepare_update_collection(
remove, remove,
); );
let recipients = vec![person.clone()]; let recipients = vec![person.clone()];
OutgoingActivity::new( OutgoingActivity::new(instance, sender, activity, recipients)
instance,
sender,
activity,
recipients,
)
} }
pub fn prepare_add_person( pub fn prepare_add_person(
@ -95,13 +87,8 @@ mod tests {
let sender_username = "local"; let sender_username = "local";
let person_id = "https://test.remote/actor/test"; let person_id = "https://test.remote/actor/test";
let collection = LocalActorCollection::Subscribers; let collection = LocalActorCollection::Subscribers;
let activity = build_update_collection( let activity =
INSTANCE_URL, build_update_collection(INSTANCE_URL, sender_username, person_id, collection, false);
sender_username,
person_id,
collection,
false,
);
assert_eq!(activity.activity_type, "Add"); assert_eq!(activity.activity_type, "Add");
assert_eq!( assert_eq!(

View file

@ -14,11 +14,7 @@ use crate::activitypub::{
constants::AP_PUBLIC, constants::AP_PUBLIC,
deliverer::OutgoingActivity, deliverer::OutgoingActivity,
identifiers::{ identifiers::{
local_actor_followers, local_actor_followers, local_actor_id, local_object_id, post_object_id, profile_actor_id,
local_actor_id,
local_object_id,
post_object_id,
profile_actor_id,
}, },
types::{build_default_context, Context}, types::{build_default_context, Context},
vocabulary::ANNOUNCE, vocabulary::ANNOUNCE,
@ -41,12 +37,11 @@ pub struct Announce {
cc: Vec<String>, cc: Vec<String>,
} }
pub fn build_announce( pub fn build_announce(instance_url: &str, repost: &Post) -> Announce {
instance_url: &str,
repost: &Post,
) -> Announce {
let actor_id = local_actor_id(instance_url, &repost.author.username); let actor_id = local_actor_id(instance_url, &repost.author.username);
let post = repost.repost_of.as_ref() let post = repost
.repost_of
.as_ref()
.expect("repost_of field should be populated"); .expect("repost_of field should be populated");
let object_id = post_object_id(instance_url, post); let object_id = post_object_id(instance_url, post);
let activity_id = local_object_id(instance_url, &repost.id); let activity_id = local_object_id(instance_url, &repost.id);
@ -76,7 +71,7 @@ pub async fn get_announce_recipients(
if let Some(remote_actor) = profile.actor_json { if let Some(remote_actor) = profile.actor_json {
recipients.push(remote_actor); recipients.push(remote_actor);
}; };
}; }
let primary_recipient = profile_actor_id(instance_url, &post.author); let primary_recipient = profile_actor_id(instance_url, &post.author);
if let Some(remote_actor) = post.author.actor_json.as_ref() { if let Some(remote_actor) = post.author.actor_json.as_ref() {
recipients.push(remote_actor.clone()); recipients.push(remote_actor.clone());
@ -91,30 +86,21 @@ pub async fn prepare_announce(
repost: &Post, repost: &Post,
) -> Result<OutgoingActivity, DatabaseError> { ) -> Result<OutgoingActivity, DatabaseError> {
assert_eq!(sender.id, repost.author.id); assert_eq!(sender.id, repost.author.id);
let post = repost.repost_of.as_ref() let post = repost
.repost_of
.as_ref()
.expect("repost_of field should be populated"); .expect("repost_of field should be populated");
let (recipients, _) = get_announce_recipients( let (recipients, _) = get_announce_recipients(db_client, &instance.url(), sender, post).await?;
db_client, let activity = build_announce(&instance.url(), repost);
&instance.url(),
sender,
post,
).await?;
let activity = build_announce(
&instance.url(),
repost,
);
Ok(OutgoingActivity::new( Ok(OutgoingActivity::new(
instance, instance, sender, activity, recipients,
sender,
activity,
recipients,
)) ))
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use mitra_models::profiles::types::DbActorProfile;
use super::*; use super::*;
use mitra_models::profiles::types::DbActorProfile;
const INSTANCE_URL: &str = "https://example.com"; const INSTANCE_URL: &str = "https://example.com";
@ -145,18 +131,12 @@ mod tests {
repost_of: Some(Box::new(post)), repost_of: Some(Box::new(post)),
..Default::default() ..Default::default()
}; };
let activity = build_announce( let activity = build_announce(INSTANCE_URL, &repost);
INSTANCE_URL,
&repost,
);
assert_eq!( assert_eq!(
activity.id, activity.id,
format!("{}/objects/{}", INSTANCE_URL, repost.id), format!("{}/objects/{}", INSTANCE_URL, repost.id),
); );
assert_eq!( assert_eq!(activity.actor, format!("{}/users/announcer", INSTANCE_URL),);
activity.actor,
format!("{}/users/announcer", INSTANCE_URL),
);
assert_eq!(activity.object, post_id); assert_eq!(activity.object, post_id);
assert_eq!(activity.to, vec![AP_PUBLIC, post_author_id]); assert_eq!(activity.to, vec![AP_PUBLIC, post_author_id]);
} }

View file

@ -16,23 +16,11 @@ use crate::activitypub::{
constants::{AP_MEDIA_TYPE, AP_PUBLIC}, constants::{AP_MEDIA_TYPE, AP_PUBLIC},
deliverer::OutgoingActivity, deliverer::OutgoingActivity,
identifiers::{ identifiers::{
local_actor_id, local_actor_followers, local_actor_id, local_actor_subscribers, local_emoji_id,
local_actor_followers, local_object_id, local_tag_collection, post_object_id, profile_actor_id,
local_actor_subscribers,
local_emoji_id,
local_object_id,
local_tag_collection,
post_object_id,
profile_actor_id,
}, },
types::{ types::{
build_default_context, build_default_context, Attachment, Context, EmojiTag, EmojiTagImage, LinkTag, SimpleTag,
Attachment,
Context,
EmojiTag,
EmojiTagImage,
LinkTag,
SimpleTag,
}, },
vocabulary::*, vocabulary::*,
}; };
@ -96,50 +84,45 @@ pub fn build_emoji_tag(instance_url: &str, emoji: &DbEmoji) -> EmojiTag {
} }
} }
pub fn build_note( pub fn build_note(instance_hostname: &str, instance_url: &str, post: &Post) -> Note {
instance_hostname: &str,
instance_url: &str,
post: &Post,
) -> Note {
let object_id = local_object_id(instance_url, &post.id); let object_id = local_object_id(instance_url, &post.id);
let actor_id = local_actor_id(instance_url, &post.author.username); let actor_id = local_actor_id(instance_url, &post.author.username);
let attachments: Vec<Attachment> = post.attachments.iter().map(|db_item| { let attachments: Vec<Attachment> = post
let url = get_file_url(instance_url, &db_item.file_name); .attachments
let media_type = db_item.media_type.clone(); .iter()
Attachment { .map(|db_item| {
name: None, let url = get_file_url(instance_url, &db_item.file_name);
attachment_type: DOCUMENT.to_string(), let media_type = db_item.media_type.clone();
media_type, Attachment {
url: Some(url), name: None,
} attachment_type: DOCUMENT.to_string(),
}).collect(); media_type,
url: Some(url),
}
})
.collect();
let mut primary_audience = vec![]; let mut primary_audience = vec![];
let mut secondary_audience = vec![]; let mut secondary_audience = vec![];
let followers_collection_id = let followers_collection_id = local_actor_followers(instance_url, &post.author.username);
local_actor_followers(instance_url, &post.author.username); let subscribers_collection_id = local_actor_subscribers(instance_url, &post.author.username);
let subscribers_collection_id =
local_actor_subscribers(instance_url, &post.author.username);
match post.visibility { match post.visibility {
Visibility::Public => { Visibility::Public => {
primary_audience.push(AP_PUBLIC.to_string()); primary_audience.push(AP_PUBLIC.to_string());
secondary_audience.push(followers_collection_id); secondary_audience.push(followers_collection_id);
}, }
Visibility::Followers => { Visibility::Followers => {
primary_audience.push(followers_collection_id); primary_audience.push(followers_collection_id);
}, }
Visibility::Subscribers => { Visibility::Subscribers => {
primary_audience.push(subscribers_collection_id); primary_audience.push(subscribers_collection_id);
}, }
Visibility::Direct => (), Visibility::Direct => (),
}; };
let mut tags = vec![]; let mut tags = vec![];
for profile in &post.mentions { for profile in &post.mentions {
let actor_address = ActorAddress::from_profile( let actor_address = ActorAddress::from_profile(instance_hostname, profile);
instance_hostname,
profile,
);
let tag_name = format!("@{}", actor_address); let tag_name = format!("@{}", actor_address);
let actor_id = profile_actor_id(instance_url, profile); let actor_id = profile_actor_id(instance_url, profile);
if !primary_audience.contains(&actor_id) { if !primary_audience.contains(&actor_id) {
@ -151,7 +134,7 @@ pub fn build_note(
href: actor_id, href: actor_id,
}; };
tags.push(Tag::SimpleTag(tag)); tags.push(Tag::SimpleTag(tag));
}; }
for tag_name in &post.tags { for tag_name in &post.tags {
let tag_href = local_tag_collection(instance_url, tag_name); let tag_href = local_tag_collection(instance_url, tag_name);
let tag = SimpleTag { let tag = SimpleTag {
@ -160,7 +143,7 @@ pub fn build_note(
href: tag_href, href: tag_href,
}; };
tags.push(Tag::SimpleTag(tag)); tags.push(Tag::SimpleTag(tag));
}; }
assert_eq!(post.links.len(), post.linked.len()); assert_eq!(post.links.len(), post.linked.len());
for linked in &post.linked { for linked in &post.linked {
@ -168,7 +151,7 @@ pub fn build_note(
// https://codeberg.org/fediverse/fep/src/branch/main/feps/fep-e232.md // https://codeberg.org/fediverse/fep/src/branch/main/feps/fep-e232.md
let link_href = post_object_id(instance_url, linked); let link_href = post_object_id(instance_url, linked);
let tag = LinkTag { let tag = LinkTag {
name: None, // no microsyntax name: None, // no microsyntax
tag_type: LINK.to_string(), tag_type: LINK.to_string(),
href: link_href, href: link_href,
media_type: AP_MEDIA_TYPE.to_string(), media_type: AP_MEDIA_TYPE.to_string(),
@ -176,29 +159,30 @@ pub fn build_note(
if cfg!(feature = "fep-e232") { if cfg!(feature = "fep-e232") {
tags.push(Tag::LinkTag(tag)); tags.push(Tag::LinkTag(tag));
}; };
}; }
let maybe_quote_url = post.linked.get(0) let maybe_quote_url = post
.linked
.get(0)
.map(|linked| post_object_id(instance_url, linked)); .map(|linked| post_object_id(instance_url, linked));
for emoji in &post.emojis { for emoji in &post.emojis {
let tag = build_emoji_tag(instance_url, emoji); let tag = build_emoji_tag(instance_url, emoji);
tags.push(Tag::EmojiTag(tag)); tags.push(Tag::EmojiTag(tag));
}; }
let in_reply_to_object_id = match post.in_reply_to_id { let in_reply_to_object_id = match post.in_reply_to_id {
Some(in_reply_to_id) => { Some(in_reply_to_id) => {
let in_reply_to = post.in_reply_to.as_ref() let in_reply_to = post
.in_reply_to
.as_ref()
.expect("in_reply_to should be populated"); .expect("in_reply_to should be populated");
assert_eq!(in_reply_to.id, in_reply_to_id); assert_eq!(in_reply_to.id, in_reply_to_id);
let in_reply_to_actor_id = profile_actor_id( let in_reply_to_actor_id = profile_actor_id(instance_url, &in_reply_to.author);
instance_url,
&in_reply_to.author,
);
if !primary_audience.contains(&in_reply_to_actor_id) { if !primary_audience.contains(&in_reply_to_actor_id) {
primary_audience.push(in_reply_to_actor_id); primary_audience.push(in_reply_to_actor_id);
}; };
Some(post_object_id(instance_url, in_reply_to)) Some(post_object_id(instance_url, in_reply_to))
}, }
None => None, None => None,
}; };
Note { Note {
@ -234,11 +218,7 @@ pub struct CreateNote {
cc: Vec<String>, cc: Vec<String>,
} }
pub fn build_create_note( pub fn build_create_note(instance_hostname: &str, instance_url: &str, post: &Post) -> CreateNote {
instance_hostname: &str,
instance_url: &str,
post: &Post,
) -> CreateNote {
let object = build_note(instance_hostname, instance_url, post); let object = build_note(instance_hostname, instance_url, post);
let primary_audience = object.to.clone(); let primary_audience = object.to.clone();
let secondary_audience = object.cc.clone(); let secondary_audience = object.cc.clone();
@ -264,11 +244,11 @@ pub async fn get_note_recipients(
Visibility::Public | Visibility::Followers => { Visibility::Public | Visibility::Followers => {
let followers = get_followers(db_client, &current_user.id).await?; let followers = get_followers(db_client, &current_user.id).await?;
audience.extend(followers); audience.extend(followers);
}, }
Visibility::Subscribers => { Visibility::Subscribers => {
let subscribers = get_subscribers(db_client, &current_user.id).await?; let subscribers = get_subscribers(db_client, &current_user.id).await?;
audience.extend(subscribers); audience.extend(subscribers);
}, }
Visibility::Direct => (), Visibility::Direct => (),
}; };
if let Some(in_reply_to_id) = post.in_reply_to_id { if let Some(in_reply_to_id) = post.in_reply_to_id {
@ -283,7 +263,7 @@ pub async fn get_note_recipients(
if let Some(remote_actor) = profile.actor_json { if let Some(remote_actor) = profile.actor_json {
recipients.push(remote_actor); recipients.push(remote_actor);
}; };
}; }
Ok(recipients) Ok(recipients)
} }
@ -294,25 +274,18 @@ pub async fn prepare_create_note(
post: &Post, post: &Post,
) -> Result<OutgoingActivity, DatabaseError> { ) -> Result<OutgoingActivity, DatabaseError> {
assert_eq!(author.id, post.author.id); assert_eq!(author.id, post.author.id);
let activity = build_create_note( let activity = build_create_note(&instance.hostname(), &instance.url(), post);
&instance.hostname(),
&instance.url(),
post,
);
let recipients = get_note_recipients(db_client, author, post).await?; let recipients = get_note_recipients(db_client, author, post).await?;
Ok(OutgoingActivity::new( Ok(OutgoingActivity::new(
instance, instance, author, activity, recipients,
author,
activity,
recipients,
)) ))
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use serde_json::json;
use mitra_models::profiles::types::DbActorProfile;
use super::*; use super::*;
use mitra_models::profiles::types::DbActorProfile;
use serde_json::json;
const INSTANCE_HOSTNAME: &str = "example.com"; const INSTANCE_HOSTNAME: &str = "example.com";
const INSTANCE_URL: &str = "https://example.com"; const INSTANCE_URL: &str = "https://example.com";
@ -326,11 +299,14 @@ mod tests {
}; };
let tag = Tag::SimpleTag(simple_tag); let tag = Tag::SimpleTag(simple_tag);
let value = serde_json::to_value(tag).unwrap(); let value = serde_json::to_value(tag).unwrap();
assert_eq!(value, json!({ assert_eq!(
"type": "Hashtag", value,
"href": "https://example.org/tags/test", json!({
"name": "#test", "type": "Hashtag",
})); "href": "https://example.org/tags/test",
"name": "#test",
})
);
} }
#[test] #[test]
@ -357,10 +333,7 @@ mod tests {
}; };
let note = build_note(INSTANCE_HOSTNAME, INSTANCE_URL, &post); let note = build_note(INSTANCE_HOSTNAME, INSTANCE_URL, &post);
assert_eq!( assert_eq!(note.id, format!("{}/objects/{}", INSTANCE_URL, post.id),);
note.id,
format!("{}/objects/{}", INSTANCE_URL, post.id),
);
assert_eq!(note.attachment.len(), 0); assert_eq!(note.attachment.len(), 0);
assert_eq!( assert_eq!(
note.attributed_to, note.attributed_to,
@ -369,9 +342,10 @@ mod tests {
assert_eq!(note.in_reply_to.is_none(), true); assert_eq!(note.in_reply_to.is_none(), true);
assert_eq!(note.content, post.content); assert_eq!(note.content, post.content);
assert_eq!(note.to, vec![AP_PUBLIC]); assert_eq!(note.to, vec![AP_PUBLIC]);
assert_eq!(note.cc, vec![ assert_eq!(
local_actor_followers(INSTANCE_URL, "author"), note.cc,
]); vec![local_actor_followers(INSTANCE_URL, "author"),]
);
assert_eq!(note.tag.len(), 1); assert_eq!(note.tag.len(), 1);
let tag = match note.tag[0] { let tag = match note.tag[0] {
Tag::SimpleTag(ref tag) => tag, Tag::SimpleTag(ref tag) => tag,
@ -389,9 +363,10 @@ mod tests {
}; };
let note = build_note(INSTANCE_HOSTNAME, INSTANCE_URL, &post); let note = build_note(INSTANCE_HOSTNAME, INSTANCE_URL, &post);
assert_eq!(note.to, vec![ assert_eq!(
local_actor_followers(INSTANCE_URL, &post.author.username), note.to,
]); vec![local_actor_followers(INSTANCE_URL, &post.author.username),]
);
assert_eq!(note.cc.is_empty(), true); assert_eq!(note.cc.is_empty(), true);
} }
@ -415,10 +390,13 @@ mod tests {
}; };
let note = build_note(INSTANCE_HOSTNAME, INSTANCE_URL, &post); let note = build_note(INSTANCE_HOSTNAME, INSTANCE_URL, &post);
assert_eq!(note.to, vec![ assert_eq!(
local_actor_subscribers(INSTANCE_URL, &post.author.username), note.to,
subscriber_id.to_string(), vec![
]); local_actor_subscribers(INSTANCE_URL, &post.author.username),
subscriber_id.to_string(),
]
);
assert_eq!(note.cc.is_empty(), true); assert_eq!(note.cc.is_empty(), true);
} }
@ -460,10 +438,13 @@ mod tests {
note.in_reply_to.unwrap(), note.in_reply_to.unwrap(),
format!("{}/objects/{}", INSTANCE_URL, parent.id), format!("{}/objects/{}", INSTANCE_URL, parent.id),
); );
assert_eq!(note.to, vec![ assert_eq!(
AP_PUBLIC.to_string(), note.to,
local_actor_id(INSTANCE_URL, &parent.author.username), vec![
]); AP_PUBLIC.to_string(),
local_actor_id(INSTANCE_URL, &parent.author.username),
]
);
} }
#[test] #[test]
@ -496,10 +477,7 @@ mod tests {
}; };
let note = build_note(INSTANCE_HOSTNAME, INSTANCE_URL, &post); let note = build_note(INSTANCE_HOSTNAME, INSTANCE_URL, &post);
assert_eq!( assert_eq!(note.in_reply_to.unwrap(), parent.object_id.unwrap(),);
note.in_reply_to.unwrap(),
parent.object_id.unwrap(),
);
let tags = note.tag; let tags = note.tag;
assert_eq!(tags.len(), 1); assert_eq!(tags.len(), 1);
let tag = match tags[0] { let tag = match tags[0] {
@ -518,12 +496,11 @@ mod tests {
username: author_username.to_string(), username: author_username.to_string(),
..Default::default() ..Default::default()
}; };
let post = Post { author, ..Default::default() }; let post = Post {
let activity = build_create_note( author,
INSTANCE_HOSTNAME, ..Default::default()
INSTANCE_URL, };
&post, let activity = build_create_note(INSTANCE_HOSTNAME, INSTANCE_URL, &post);
);
assert_eq!( assert_eq!(
activity.id, activity.id,

View file

@ -15,11 +15,7 @@ use crate::activitypub::{
vocabulary::{DELETE, NOTE, TOMBSTONE}, vocabulary::{DELETE, NOTE, TOMBSTONE},
}; };
use super::create_note::{ use super::create_note::{build_note, get_note_recipients, Note};
build_note,
get_note_recipients,
Note,
};
#[derive(Serialize)] #[derive(Serialize)]
#[serde(rename_all = "camelCase")] #[serde(rename_all = "camelCase")]
@ -48,20 +44,12 @@ struct DeleteNote {
cc: Vec<String>, cc: Vec<String>,
} }
fn build_delete_note( fn build_delete_note(instance_hostname: &str, instance_url: &str, post: &Post) -> DeleteNote {
instance_hostname: &str,
instance_url: &str,
post: &Post,
) -> DeleteNote {
assert!(post.is_local()); assert!(post.is_local());
let object_id = local_object_id(instance_url, &post.id); let object_id = local_object_id(instance_url, &post.id);
let activity_id = format!("{}/delete", object_id); let activity_id = format!("{}/delete", object_id);
let actor_id = local_actor_id(instance_url, &post.author.username); let actor_id = local_actor_id(instance_url, &post.author.username);
let Note { to, cc, .. } = build_note( let Note { to, cc, .. } = build_note(instance_hostname, instance_url, post);
instance_hostname,
instance_url,
post,
);
DeleteNote { DeleteNote {
context: build_default_context(), context: build_default_context(),
activity_type: DELETE.to_string(), activity_type: DELETE.to_string(),
@ -86,28 +74,18 @@ pub async fn prepare_delete_note(
assert_eq!(author.id, post.author.id); assert_eq!(author.id, post.author.id);
let mut post = post.clone(); let mut post = post.clone();
add_related_posts(db_client, vec![&mut post]).await?; add_related_posts(db_client, vec![&mut post]).await?;
let activity = build_delete_note( let activity = build_delete_note(&instance.hostname(), &instance.url(), &post);
&instance.hostname(),
&instance.url(),
&post,
);
let recipients = get_note_recipients(db_client, author, &post).await?; let recipients = get_note_recipients(db_client, author, &post).await?;
Ok(OutgoingActivity::new( Ok(OutgoingActivity::new(
instance, instance, author, activity, recipients,
author,
activity,
recipients,
)) ))
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use mitra_models::profiles::types::DbActorProfile;
use crate::activitypub::{
constants::AP_PUBLIC,
identifiers::local_actor_followers,
};
use super::*; use super::*;
use crate::activitypub::{constants::AP_PUBLIC, identifiers::local_actor_followers};
use mitra_models::profiles::types::DbActorProfile;
const INSTANCE_HOSTNAME: &str = "example.com"; const INSTANCE_HOSTNAME: &str = "example.com";
const INSTANCE_URL: &str = "https://example.com"; const INSTANCE_URL: &str = "https://example.com";
@ -118,12 +96,11 @@ mod tests {
username: "author".to_string(), username: "author".to_string(),
..Default::default() ..Default::default()
}; };
let post = Post { author, ..Default::default() }; let post = Post {
let activity = build_delete_note( author,
INSTANCE_HOSTNAME, ..Default::default()
INSTANCE_URL, };
&post, let activity = build_delete_note(INSTANCE_HOSTNAME, INSTANCE_URL, &post);
);
assert_eq!( assert_eq!(
activity.id, activity.id,

View file

@ -32,10 +32,7 @@ struct DeletePerson {
to: Vec<String>, to: Vec<String>,
} }
fn build_delete_person( fn build_delete_person(instance_url: &str, user: &User) -> DeletePerson {
instance_url: &str,
user: &User,
) -> DeletePerson {
let actor_id = local_actor_id(instance_url, &user.profile.username); let actor_id = local_actor_id(instance_url, &user.profile.username);
let activity_id = format!("{}/delete", actor_id); let activity_id = format!("{}/delete", actor_id);
DeletePerson { DeletePerson {
@ -59,7 +56,7 @@ async fn get_delete_person_recipients(
if let Some(remote_actor) = profile.actor_json { if let Some(remote_actor) = profile.actor_json {
recipients.push(remote_actor); recipients.push(remote_actor);
}; };
}; }
Ok(recipients) Ok(recipients)
} }
@ -70,18 +67,13 @@ pub async fn prepare_delete_person(
) -> Result<OutgoingActivity, DatabaseError> { ) -> Result<OutgoingActivity, DatabaseError> {
let activity = build_delete_person(&instance.url(), user); let activity = build_delete_person(&instance.url(), user);
let recipients = get_delete_person_recipients(db_client, &user.id).await?; let recipients = get_delete_person_recipients(db_client, &user.id).await?;
Ok(OutgoingActivity::new( Ok(OutgoingActivity::new(instance, user, activity, recipients))
instance,
user,
activity,
recipients,
))
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use mitra_models::profiles::types::DbActorProfile;
use super::*; use super::*;
use mitra_models::profiles::types::DbActorProfile;
const INSTANCE_URL: &str = "https://example.com"; const INSTANCE_URL: &str = "https://example.com";
@ -100,10 +92,7 @@ mod tests {
format!("{}/users/testuser/delete", INSTANCE_URL), format!("{}/users/testuser/delete", INSTANCE_URL),
); );
assert_eq!(activity.actor, activity.object); assert_eq!(activity.actor, activity.object);
assert_eq!( assert_eq!(activity.object, format!("{}/users/testuser", INSTANCE_URL),);
activity.object,
format!("{}/users/testuser", INSTANCE_URL),
);
assert_eq!(activity.to, vec![AP_PUBLIC]); assert_eq!(activity.to, vec![AP_PUBLIC]);
} }
} }

View file

@ -62,12 +62,7 @@ pub fn prepare_follow(
follow_request_id, follow_request_id,
); );
let recipients = vec![target_actor.clone()]; let recipients = vec![target_actor.clone()];
OutgoingActivity::new( OutgoingActivity::new(instance, sender, activity, recipients)
instance,
sender,
activity,
recipients,
)
} }
pub async fn follow_or_create_request( pub async fn follow_or_create_request(
@ -78,19 +73,12 @@ pub async fn follow_or_create_request(
) -> Result<(), DatabaseError> { ) -> Result<(), DatabaseError> {
if let Some(ref remote_actor) = target_profile.actor_json { if let Some(ref remote_actor) = target_profile.actor_json {
// Create follow request if target is remote // Create follow request if target is remote
match create_follow_request( match create_follow_request(db_client, &current_user.id, &target_profile.id).await {
db_client,
&current_user.id,
&target_profile.id,
).await {
Ok(follow_request) => { Ok(follow_request) => {
prepare_follow( prepare_follow(instance, current_user, remote_actor, &follow_request.id)
instance, .enqueue(db_client)
current_user, .await?;
remote_actor, }
&follow_request.id,
).enqueue(db_client).await?;
},
Err(DatabaseError::AlreadyExists(_)) => (), // already following Err(DatabaseError::AlreadyExists(_)) => (), // already following
Err(other_error) => return Err(other_error), Err(other_error) => return Err(other_error),
}; };
@ -106,8 +94,8 @@ pub async fn follow_or_create_request(
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use mitra_utils::id::generate_ulid;
use super::*; use super::*;
use mitra_utils::id::generate_ulid;
const INSTANCE_URL: &str = "https://example.com"; const INSTANCE_URL: &str = "https://example.com";
@ -119,12 +107,7 @@ mod tests {
}; };
let follow_request_id = generate_ulid(); let follow_request_id = generate_ulid();
let target_actor_id = "https://test.remote/actor/test"; let target_actor_id = "https://test.remote/actor/test";
let activity = build_follow( let activity = build_follow(INSTANCE_URL, &follower, target_actor_id, &follow_request_id);
INSTANCE_URL,
&follower,
target_actor_id,
&follow_request_id,
);
assert_eq!( assert_eq!(
activity.id, activity.id,

View file

@ -12,12 +12,7 @@ use mitra_models::{
use crate::activitypub::{ use crate::activitypub::{
constants::AP_PUBLIC, constants::AP_PUBLIC,
deliverer::OutgoingActivity, deliverer::OutgoingActivity,
identifiers::{ identifiers::{local_actor_id, local_object_id, post_object_id, profile_actor_id},
local_actor_id,
local_object_id,
post_object_id,
profile_actor_id,
},
types::{build_default_context, Context}, types::{build_default_context, Context},
vocabulary::LIKE, vocabulary::LIKE,
}; };
@ -60,8 +55,7 @@ fn build_like(
) -> Like { ) -> Like {
let activity_id = local_object_id(instance_url, reaction_id); let activity_id = local_object_id(instance_url, reaction_id);
let actor_id = local_actor_id(instance_url, &actor_profile.username); let actor_id = local_actor_id(instance_url, &actor_profile.username);
let (primary_audience, secondary_audience) = let (primary_audience, secondary_audience) = get_like_audience(post_author_id, post_visibility);
get_like_audience(post_author_id, post_visibility);
Like { Like {
context: build_default_context(), context: build_default_context(),
activity_type: LIKE.to_string(), activity_type: LIKE.to_string(),
@ -92,11 +86,7 @@ pub async fn prepare_like(
post: &Post, post: &Post,
reaction_id: &Uuid, reaction_id: &Uuid,
) -> Result<OutgoingActivity, DatabaseError> { ) -> Result<OutgoingActivity, DatabaseError> {
let recipients = get_like_recipients( let recipients = get_like_recipients(db_client, &instance.url(), post).await?;
db_client,
&instance.url(),
post,
).await?;
let object_id = post_object_id(&instance.url(), post); let object_id = post_object_id(&instance.url(), post);
let post_author_id = profile_actor_id(&instance.url(), &post.author); let post_author_id = profile_actor_id(&instance.url(), &post.author);
let activity = build_like( let activity = build_like(
@ -108,17 +98,14 @@ pub async fn prepare_like(
&post.visibility, &post.visibility,
); );
Ok(OutgoingActivity::new( Ok(OutgoingActivity::new(
instance, instance, sender, activity, recipients,
sender,
activity,
recipients,
)) ))
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use mitra_utils::id::generate_ulid;
use super::*; use super::*;
use mitra_utils::id::generate_ulid;
const INSTANCE_URL: &str = "https://example.com"; const INSTANCE_URL: &str = "https://example.com";

View file

@ -2,10 +2,7 @@ use serde::Serialize;
use uuid::Uuid; use uuid::Uuid;
use mitra_config::Instance; use mitra_config::Instance;
use mitra_models::{ use mitra_models::{profiles::types::DbActor, users::types::User};
profiles::types::DbActor,
users::types::User,
};
use mitra_utils::id::generate_ulid; use mitra_utils::id::generate_ulid;
use crate::activitypub::{ use crate::activitypub::{
@ -38,7 +35,8 @@ pub fn build_move_person(
followers: &[String], followers: &[String],
maybe_internal_activity_id: Option<&Uuid>, maybe_internal_activity_id: Option<&Uuid>,
) -> MovePerson { ) -> MovePerson {
let internal_activity_id = maybe_internal_activity_id.copied() let internal_activity_id = maybe_internal_activity_id
.copied()
.unwrap_or(generate_ulid()); .unwrap_or(generate_ulid());
let activity_id = local_object_id(instance_url, &internal_activity_id); let activity_id = local_object_id(instance_url, &internal_activity_id);
let actor_id = local_actor_id(instance_url, &sender.profile.username); let actor_id = local_actor_id(instance_url, &sender.profile.username);
@ -60,9 +58,7 @@ pub fn prepare_move_person(
followers: Vec<DbActor>, followers: Vec<DbActor>,
maybe_internal_activity_id: Option<&Uuid>, maybe_internal_activity_id: Option<&Uuid>,
) -> OutgoingActivity { ) -> OutgoingActivity {
let followers_ids: Vec<String> = followers.iter() let followers_ids: Vec<String> = followers.iter().map(|actor| actor.id.clone()).collect();
.map(|actor| actor.id.clone())
.collect();
let activity = build_move_person( let activity = build_move_person(
&instance.url(), &instance.url(),
sender, sender,
@ -70,19 +66,14 @@ pub fn prepare_move_person(
&followers_ids, &followers_ids,
maybe_internal_activity_id, maybe_internal_activity_id,
); );
OutgoingActivity::new( OutgoingActivity::new(instance, sender, activity, followers)
instance,
sender,
activity,
followers,
)
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use mitra_utils::id::generate_ulid;
use mitra_models::profiles::types::DbActorProfile;
use super::*; use super::*;
use mitra_models::profiles::types::DbActorProfile;
use mitra_utils::id::generate_ulid;
const INSTANCE_URL: &str = "https://example.com"; const INSTANCE_URL: &str = "https://example.com";

View file

@ -1,13 +1,7 @@
use mitra_config::Instance; use mitra_config::Instance;
use mitra_models::{ use mitra_models::{profiles::types::DbActor, users::types::User};
profiles::types::DbActor,
users::types::User,
};
use crate::activitypub::{ use crate::activitypub::{deliverer::OutgoingActivity, identifiers::LocalActorCollection};
deliverer::OutgoingActivity,
identifiers::LocalActorCollection,
};
use super::add_person::prepare_update_collection; use super::add_person::prepare_update_collection;

View file

@ -9,14 +9,14 @@ use mitra_models::{
users::types::User, users::types::User,
}; };
use super::announce::get_announce_recipients;
use crate::activitypub::{ use crate::activitypub::{
constants::AP_PUBLIC, constants::AP_PUBLIC,
deliverer::OutgoingActivity, deliverer::OutgoingActivity,
identifiers::{local_actor_id, local_actor_followers, local_object_id}, identifiers::{local_actor_followers, local_actor_id, local_object_id},
types::{build_default_context, Context}, types::{build_default_context, Context},
vocabulary::UNDO, vocabulary::UNDO,
}; };
use super::announce::get_announce_recipients;
#[derive(Serialize)] #[derive(Serialize)]
struct UndoAnnounce { struct UndoAnnounce {
@ -43,13 +43,8 @@ fn build_undo_announce(
let object_id = local_object_id(instance_url, repost_id); let object_id = local_object_id(instance_url, repost_id);
let activity_id = format!("{}/undo", object_id); let activity_id = format!("{}/undo", object_id);
let actor_id = local_actor_id(instance_url, &actor_profile.username); let actor_id = local_actor_id(instance_url, &actor_profile.username);
let primary_audience = vec![ let primary_audience = vec![AP_PUBLIC.to_string(), recipient_id.to_string()];
AP_PUBLIC.to_string(), let secondary_audience = vec![local_actor_followers(instance_url, &actor_profile.username)];
recipient_id.to_string(),
];
let secondary_audience = vec![
local_actor_followers(instance_url, &actor_profile.username),
];
UndoAnnounce { UndoAnnounce {
context: build_default_context(), context: build_default_context(),
activity_type: UNDO.to_string(), activity_type: UNDO.to_string(),
@ -69,12 +64,8 @@ pub async fn prepare_undo_announce(
repost_id: &Uuid, repost_id: &Uuid,
) -> Result<OutgoingActivity, DatabaseError> { ) -> Result<OutgoingActivity, DatabaseError> {
assert_ne!(&post.id, repost_id); assert_ne!(&post.id, repost_id);
let (recipients, primary_recipient) = get_announce_recipients( let (recipients, primary_recipient) =
db_client, get_announce_recipients(db_client, &instance.url(), sender, post).await?;
&instance.url(),
sender,
post,
).await?;
let activity = build_undo_announce( let activity = build_undo_announce(
&instance.url(), &instance.url(),
&sender.profile, &sender.profile,
@ -82,17 +73,14 @@ pub async fn prepare_undo_announce(
&primary_recipient, &primary_recipient,
); );
Ok(OutgoingActivity::new( Ok(OutgoingActivity::new(
instance, instance, sender, activity, recipients,
sender,
activity,
recipients,
)) ))
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use mitra_utils::id::generate_ulid;
use super::*; use super::*;
use mitra_utils::id::generate_ulid;
const INSTANCE_URL: &str = "https://example.com"; const INSTANCE_URL: &str = "https://example.com";
@ -101,12 +89,7 @@ mod tests {
let announcer = DbActorProfile::default(); let announcer = DbActorProfile::default();
let post_author_id = "https://example.com/users/test"; let post_author_id = "https://example.com/users/test";
let repost_id = generate_ulid(); let repost_id = generate_ulid();
let activity = build_undo_announce( let activity = build_undo_announce(INSTANCE_URL, &announcer, &repost_id, post_author_id);
INSTANCE_URL,
&announcer,
&repost_id,
post_author_id,
);
assert_eq!( assert_eq!(
activity.id, activity.id,
format!("{}/objects/{}/undo", INSTANCE_URL, repost_id), format!("{}/objects/{}/undo", INSTANCE_URL, repost_id),

View file

@ -37,14 +37,8 @@ fn build_undo_follow(
target_actor_id: &str, target_actor_id: &str,
follow_request_id: &Uuid, follow_request_id: &Uuid,
) -> UndoFollow { ) -> UndoFollow {
let follow_activity_id = local_object_id( let follow_activity_id = local_object_id(instance_url, follow_request_id);
instance_url, let follow_actor_id = local_actor_id(instance_url, &actor_profile.username);
follow_request_id,
);
let follow_actor_id = local_actor_id(
instance_url,
&actor_profile.username,
);
let object = Follow { let object = Follow {
context: build_default_context(), context: build_default_context(),
activity_type: FOLLOW.to_string(), activity_type: FOLLOW.to_string(),
@ -78,18 +72,13 @@ pub fn prepare_undo_follow(
follow_request_id, follow_request_id,
); );
let recipients = vec![target_actor.clone()]; let recipients = vec![target_actor.clone()];
OutgoingActivity::new( OutgoingActivity::new(instance, sender, activity, recipients)
instance,
sender,
activity,
recipients,
)
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use mitra_utils::id::generate_ulid;
use super::*; use super::*;
use mitra_utils::id::generate_ulid;
const INSTANCE_URL: &str = "https://example.com"; const INSTANCE_URL: &str = "https://example.com";

View file

@ -16,10 +16,7 @@ use crate::activitypub::{
vocabulary::UNDO, vocabulary::UNDO,
}; };
use super::like::{ use super::like::{get_like_audience, get_like_recipients};
get_like_audience,
get_like_recipients,
};
#[derive(Serialize)] #[derive(Serialize)]
struct UndoLike { struct UndoLike {
@ -47,8 +44,7 @@ fn build_undo_like(
let object_id = local_object_id(instance_url, reaction_id); let object_id = local_object_id(instance_url, reaction_id);
let activity_id = format!("{}/undo", object_id); let activity_id = format!("{}/undo", object_id);
let actor_id = local_actor_id(instance_url, &actor_profile.username); let actor_id = local_actor_id(instance_url, &actor_profile.username);
let (primary_audience, secondary_audience) = let (primary_audience, secondary_audience) = get_like_audience(post_author_id, post_visibility);
get_like_audience(post_author_id, post_visibility);
UndoLike { UndoLike {
context: build_default_context(), context: build_default_context(),
activity_type: UNDO.to_string(), activity_type: UNDO.to_string(),
@ -67,11 +63,7 @@ pub async fn prepare_undo_like(
post: &Post, post: &Post,
reaction_id: &Uuid, reaction_id: &Uuid,
) -> Result<OutgoingActivity, DatabaseError> { ) -> Result<OutgoingActivity, DatabaseError> {
let recipients = get_like_recipients( let recipients = get_like_recipients(db_client, &instance.url(), post).await?;
db_client,
&instance.url(),
post,
).await?;
let post_author_id = profile_actor_id(&instance.url(), &post.author); let post_author_id = profile_actor_id(&instance.url(), &post.author);
let activity = build_undo_like( let activity = build_undo_like(
&instance.url(), &instance.url(),
@ -81,18 +73,15 @@ pub async fn prepare_undo_like(
&post.visibility, &post.visibility,
); );
Ok(OutgoingActivity::new( Ok(OutgoingActivity::new(
instance, instance, sender, activity, recipients,
sender,
activity,
recipients,
)) ))
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use mitra_utils::id::generate_ulid;
use crate::activitypub::constants::AP_PUBLIC;
use super::*; use super::*;
use crate::activitypub::constants::AP_PUBLIC;
use mitra_utils::id::generate_ulid;
const INSTANCE_URL: &str = "https://example.com"; const INSTANCE_URL: &str = "https://example.com";

View file

@ -41,8 +41,7 @@ pub fn build_update_person(
) -> Result<UpdatePerson, ActorKeyError> { ) -> Result<UpdatePerson, ActorKeyError> {
let actor = get_local_actor(user, instance_url)?; let actor = get_local_actor(user, instance_url)?;
// Update(Person) is idempotent so its ID can be random // Update(Person) is idempotent so its ID can be random
let internal_activity_id = let internal_activity_id = maybe_internal_activity_id.unwrap_or(generate_ulid());
maybe_internal_activity_id.unwrap_or(generate_ulid());
let activity_id = local_object_id(instance_url, &internal_activity_id); let activity_id = local_object_id(instance_url, &internal_activity_id);
let activity = UpdatePerson { let activity = UpdatePerson {
context: build_default_context(), context: build_default_context(),
@ -68,7 +67,7 @@ async fn get_update_person_recipients(
if let Some(remote_actor) = profile.actor_json { if let Some(remote_actor) = profile.actor_json {
recipients.push(remote_actor); recipients.push(remote_actor);
}; };
}; }
Ok(recipients) Ok(recipients)
} }
@ -78,28 +77,17 @@ pub async fn prepare_update_person(
user: &User, user: &User,
maybe_internal_activity_id: Option<Uuid>, maybe_internal_activity_id: Option<Uuid>,
) -> Result<OutgoingActivity, DatabaseError> { ) -> Result<OutgoingActivity, DatabaseError> {
let activity = build_update_person( let activity = build_update_person(&instance.url(), user, maybe_internal_activity_id)
&instance.url(), .map_err(|_| DatabaseTypeError)?;
user,
maybe_internal_activity_id,
).map_err(|_| DatabaseTypeError)?;
let recipients = get_update_person_recipients(db_client, &user.id).await?; let recipients = get_update_person_recipients(db_client, &user.id).await?;
Ok(OutgoingActivity::new( Ok(OutgoingActivity::new(instance, user, activity, recipients))
instance,
user,
activity,
recipients,
))
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use mitra_models::profiles::types::DbActorProfile;
use mitra_utils::crypto_rsa::{
generate_weak_rsa_key,
serialize_private_key,
};
use super::*; use super::*;
use mitra_models::profiles::types::DbActorProfile;
use mitra_utils::crypto_rsa::{generate_weak_rsa_key, serialize_private_key};
const INSTANCE_URL: &str = "https://example.com"; const INSTANCE_URL: &str = "https://example.com";
@ -116,11 +104,7 @@ mod tests {
..Default::default() ..Default::default()
}; };
let internal_id = generate_ulid(); let internal_id = generate_ulid();
let activity = build_update_person( let activity = build_update_person(INSTANCE_URL, &user, Some(internal_id)).unwrap();
INSTANCE_URL,
&user,
Some(internal_id),
).unwrap();
assert_eq!( assert_eq!(
activity.id, activity.id,
format!("{}/objects/{}", INSTANCE_URL, internal_id), format!("{}/objects/{}", INSTANCE_URL, internal_id),
@ -129,9 +113,12 @@ mod tests {
activity.object.id, activity.object.id,
format!("{}/users/testuser", INSTANCE_URL), format!("{}/users/testuser", INSTANCE_URL),
); );
assert_eq!(activity.to, vec![ assert_eq!(
AP_PUBLIC.to_string(), activity.to,
format!("{}/users/testuser/followers", INSTANCE_URL), vec![
]); AP_PUBLIC.to_string(),
format!("{}/users/testuser/followers", INSTANCE_URL),
]
);
} }
} }

View file

@ -1,5 +1,5 @@
use serde::Serialize; use serde::Serialize;
use serde_json::{Value as JsonValue}; use serde_json::Value as JsonValue;
use super::types::{build_default_context, Context}; use super::types::{build_default_context, Context};
use super::vocabulary::{ORDERED_COLLECTION, ORDERED_COLLECTION_PAGE}; use super::vocabulary::{ORDERED_COLLECTION, ORDERED_COLLECTION_PAGE};
@ -53,10 +53,7 @@ pub struct OrderedCollectionPage {
} }
impl OrderedCollectionPage { impl OrderedCollectionPage {
pub fn new( pub fn new(collection_page_id: String, items: Vec<JsonValue>) -> Self {
collection_page_id: String,
items: Vec<JsonValue>,
) -> Self {
Self { Self {
context: build_default_context(), context: build_default_context(),
id: collection_page_id, id: collection_page_id,

View file

@ -1,5 +1,6 @@
// https://www.w3.org/TR/activitypub/#server-to-server-interactions // https://www.w3.org/TR/activitypub/#server-to-server-interactions
pub const AP_MEDIA_TYPE: &str = r#"application/ld+json; profile="https://www.w3.org/ns/activitystreams""#; pub const AP_MEDIA_TYPE: &str =
r#"application/ld+json; profile="https://www.w3.org/ns/activitystreams""#;
pub const AS_MEDIA_TYPE: &str = "application/activity+json"; pub const AS_MEDIA_TYPE: &str = "application/activity+json";
// Contexts // Contexts

View file

@ -8,24 +8,14 @@ use serde_json::Value;
use mitra_config::Instance; use mitra_config::Instance;
use mitra_models::{ use mitra_models::{
database::{ database::{DatabaseClient, DatabaseError},
DatabaseClient,
DatabaseError,
},
profiles::types::DbActor, profiles::types::DbActor,
users::types::User, users::types::User,
}; };
use mitra_utils::crypto_rsa::deserialize_private_key; use mitra_utils::crypto_rsa::deserialize_private_key;
use crate::http_signatures::create::{ use crate::http_signatures::create::{create_http_signature, HttpSignatureError};
create_http_signature, use crate::json_signatures::create::{is_object_signed, sign_object, JsonSignatureError};
HttpSignatureError,
};
use crate::json_signatures::create::{
is_object_signed,
sign_object,
JsonSignatureError,
};
use super::{ use super::{
constants::AP_MEDIA_TYPE, constants::AP_MEDIA_TYPE,
@ -58,16 +48,9 @@ pub enum DelivererError {
HttpError(reqwest::StatusCode), HttpError(reqwest::StatusCode),
} }
fn build_client( fn build_client(instance: &Instance, request_url: &str) -> Result<Client, DelivererError> {
instance: &Instance,
request_url: &str,
) -> Result<Client, DelivererError> {
let network = get_network_type(request_url)?; let network = get_network_type(request_url)?;
let client = build_federation_client( let client = build_federation_client(instance, network, instance.deliverer_timeout)?;
instance,
network,
instance.deliverer_timeout,
)?;
Ok(client) Ok(client)
} }
@ -87,7 +70,8 @@ async fn send_activity(
)?; )?;
let client = build_client(instance, inbox_url)?; let client = build_client(instance, inbox_url)?;
let request = client.post(inbox_url) let request = client
.post(inbox_url)
.header("Host", headers.host) .header("Host", headers.host)
.header("Date", headers.date) .header("Date", headers.date)
.header("Digest", headers.digest.unwrap()) .header("Digest", headers.digest.unwrap())
@ -97,15 +81,16 @@ async fn send_activity(
.body(activity_json.to_owned()); .body(activity_json.to_owned());
if instance.is_private { if instance.is_private {
log::info!( log::info!("private mode: not sending activity to {}", inbox_url,);
"private mode: not sending activity to {}",
inbox_url,
);
} else { } else {
let response = request.send().await?; let response = request.send().await?;
let response_status = response.status(); let response_status = response.status();
let response_text: String = response.text().await? let response_text: String = response
.chars().filter(|chr| *chr != '\n' && *chr != '\r').take(75) .text()
.await?
.chars()
.filter(|chr| *chr != '\n' && *chr != '\r')
.take(75)
.collect(); .collect();
log::info!( log::info!(
"response from {}: [{}] {}", "response from {}: [{}] {}",
@ -135,10 +120,7 @@ async fn deliver_activity_worker(
recipients: &mut [Recipient], recipients: &mut [Recipient],
) -> Result<(), DelivererError> { ) -> Result<(), DelivererError> {
let actor_key = deserialize_private_key(&sender.private_key)?; let actor_key = deserialize_private_key(&sender.private_key)?;
let actor_id = local_actor_id( let actor_id = local_actor_id(&instance.url(), &sender.profile.username);
&instance.url(),
&sender.profile.username,
);
let actor_key_id = local_actor_key_id(&actor_id); let actor_key_id = local_actor_key_id(&actor_id);
let activity_signed = if is_object_signed(&activity) { let activity_signed = if is_object_signed(&activity) {
log::warn!("activity is already signed"); log::warn!("activity is already signed");
@ -158,7 +140,9 @@ async fn deliver_activity_worker(
&actor_key_id, &actor_key_id,
&activity_json, &activity_json,
&recipient.inbox, &recipient.inbox,
).await { )
.await
{
log::warn!( log::warn!(
"failed to deliver activity to {}: {}", "failed to deliver activity to {}: {}",
recipient.inbox, recipient.inbox,
@ -167,7 +151,7 @@ async fn deliver_activity_worker(
} else { } else {
recipient.is_delivered = true; recipient.is_delivered = true;
}; };
}; }
Ok(()) Ok(())
} }
@ -196,32 +180,27 @@ impl OutgoingActivity {
}; };
recipient_map.insert(actor.id, recipient); recipient_map.insert(actor.id, recipient);
}; };
}; }
Self { Self {
instance: instance.clone(), instance: instance.clone(),
sender: sender.clone(), sender: sender.clone(),
activity: serde_json::to_value(activity) activity: serde_json::to_value(activity).expect("activity should be serializable"),
.expect("activity should be serializable"),
recipients: recipient_map.into_values().collect(), recipients: recipient_map.into_values().collect(),
} }
} }
pub(super) async fn deliver( pub(super) async fn deliver(mut self) -> Result<Vec<Recipient>, DelivererError> {
mut self,
) -> Result<Vec<Recipient>, DelivererError> {
deliver_activity_worker( deliver_activity_worker(
self.instance, self.instance,
self.sender, self.sender,
self.activity, self.activity,
&mut self.recipients, &mut self.recipients,
).await?; )
.await?;
Ok(self.recipients) Ok(self.recipients)
} }
pub async fn enqueue( pub async fn enqueue(self, db_client: &impl DatabaseClient) -> Result<(), DatabaseError> {
self,
db_client: &impl DatabaseClient,
) -> Result<(), DatabaseError> {
if self.recipients.is_empty() { if self.recipients.is_empty() {
return Ok(()); return Ok(());
}; };

View file

@ -2,13 +2,10 @@ use std::path::Path;
use reqwest::{Client, Method, RequestBuilder}; use reqwest::{Client, Method, RequestBuilder};
use serde::Deserialize; use serde::Deserialize;
use serde_json::{Value as JsonValue}; use serde_json::Value as JsonValue;
use mitra_config::Instance; use mitra_config::Instance;
use mitra_utils::{ use mitra_utils::{files::sniff_media_type, urls::guess_protocol};
files::sniff_media_type,
urls::guess_protocol,
};
use crate::activitypub::{ use crate::activitypub::{
actors::types::Actor, actors::types::Actor,
@ -18,10 +15,7 @@ use crate::activitypub::{
types::Object, types::Object,
vocabulary::GROUP, vocabulary::GROUP,
}; };
use crate::http_signatures::create::{ use crate::http_signatures::create::{create_http_signature, HttpSignatureError};
create_http_signature,
HttpSignatureError,
};
use crate::media::{save_file, SUPPORTED_MEDIA_TYPES}; use crate::media::{save_file, SUPPORTED_MEDIA_TYPES};
use crate::webfinger::types::{ActorAddress, JsonResourceDescriptor}; use crate::webfinger::types::{ActorAddress, JsonResourceDescriptor};
@ -52,39 +46,23 @@ pub enum FetchError {
OtherError(&'static str), OtherError(&'static str),
} }
fn build_client( fn build_client(instance: &Instance, request_url: &str) -> Result<Client, FetchError> {
instance: &Instance,
request_url: &str,
) -> Result<Client, FetchError> {
let network = get_network_type(request_url)?; let network = get_network_type(request_url)?;
let client = build_federation_client( let client = build_federation_client(instance, network, instance.fetcher_timeout)?;
instance,
network,
instance.fetcher_timeout,
)?;
Ok(client) Ok(client)
} }
fn build_request( fn build_request(instance: &Instance, client: Client, method: Method, url: &str) -> RequestBuilder {
instance: &Instance,
client: Client,
method: Method,
url: &str,
) -> RequestBuilder {
let mut request_builder = client.request(method, url); let mut request_builder = client.request(method, url);
if !instance.is_private { if !instance.is_private {
// Public instances should set User-Agent header // Public instances should set User-Agent header
request_builder = request_builder request_builder = request_builder.header(reqwest::header::USER_AGENT, instance.agent());
.header(reqwest::header::USER_AGENT, instance.agent());
}; };
request_builder request_builder
} }
/// Sends GET request to fetch AP object /// Sends GET request to fetch AP object
async fn send_request( async fn send_request(instance: &Instance, url: &str) -> Result<String, FetchError> {
instance: &Instance,
url: &str,
) -> Result<String, FetchError> {
let client = build_client(instance, url)?; let client = build_client(instance, url)?;
let mut request_builder = build_request(instance, client, Method::GET, url) let mut request_builder = build_request(instance, client, Method::GET, url)
.header(reqwest::header::ACCEPT, AP_MEDIA_TYPE); .header(reqwest::header::ACCEPT, AP_MEDIA_TYPE);
@ -107,9 +85,11 @@ async fn send_request(
}; };
let data = request_builder let data = request_builder
.send().await? .send()
.await?
.error_for_status()? .error_for_status()?
.text().await?; .text()
.await?;
Ok(data) Ok(data)
} }
@ -121,12 +101,10 @@ pub async fn fetch_file(
output_dir: &Path, output_dir: &Path,
) -> Result<(String, usize, Option<String>), FetchError> { ) -> Result<(String, usize, Option<String>), FetchError> {
let client = build_client(instance, url)?; let client = build_client(instance, url)?;
let request_builder = let request_builder = build_request(instance, client, Method::GET, url);
build_request(instance, client, Method::GET, url);
let response = request_builder.send().await?.error_for_status()?; let response = request_builder.send().await?.error_for_status()?;
if let Some(file_size) = response.content_length() { if let Some(file_size) = response.content_length() {
let file_size: usize = file_size.try_into() let file_size: usize = file_size.try_into().expect("value should be within bounds");
.expect("value should be within bounds");
if file_size > file_max_size { if file_size > file_max_size {
return Err(FetchError::FileTooLarge); return Err(FetchError::FileTooLarge);
}; };
@ -145,19 +123,11 @@ pub async fn fetch_file(
if SUPPORTED_MEDIA_TYPES.contains(&media_type.as_str()) { if SUPPORTED_MEDIA_TYPES.contains(&media_type.as_str()) {
true true
} else { } else {
log::info!( log::info!("unsupported media type {}: {}", media_type, url,);
"unsupported media type {}: {}",
media_type,
url,
);
false false
} }
}); });
let file_name = save_file( let file_name = save_file(file_data.to_vec(), output_dir, maybe_media_type.as_deref())?;
file_data.to_vec(),
output_dir,
maybe_media_type.as_deref(),
)?;
Ok((file_name, file_size, maybe_media_type)) Ok((file_name, file_size, maybe_media_type))
} }
@ -172,43 +142,45 @@ pub async fn perform_webfinger_query(
actor_address.hostname, actor_address.hostname,
); );
let client = build_client(instance, &webfinger_url)?; let client = build_client(instance, &webfinger_url)?;
let request_builder = let request_builder = build_request(instance, client, Method::GET, &webfinger_url);
build_request(instance, client, Method::GET, &webfinger_url);
let webfinger_data = request_builder let webfinger_data = request_builder
.query(&[("resource", webfinger_account_uri)]) .query(&[("resource", webfinger_account_uri)])
.send().await? .send()
.await?
.error_for_status()? .error_for_status()?
.text().await?; .text()
.await?;
let jrd: JsonResourceDescriptor = serde_json::from_str(&webfinger_data)?; let jrd: JsonResourceDescriptor = serde_json::from_str(&webfinger_data)?;
// Lemmy servers can have Group and Person actors with the same name // Lemmy servers can have Group and Person actors with the same name
// https://github.com/LemmyNet/lemmy/issues/2037 // https://github.com/LemmyNet/lemmy/issues/2037
let ap_type_property = format!("{}#type", AP_CONTEXT); let ap_type_property = format!("{}#type", AP_CONTEXT);
let group_link = jrd.links.iter() let group_link = jrd.links.iter().find(|link| {
.find(|link| { link.rel == "self"
link.rel == "self" && && link
link.properties .properties
.get(&ap_type_property) .get(&ap_type_property)
.map(|val| val.as_str()) == Some(GROUP) .map(|val| val.as_str())
}); == Some(GROUP)
});
let link = if let Some(link) = group_link { let link = if let Some(link) = group_link {
// Prefer Group if the actor type is provided // Prefer Group if the actor type is provided
link link
} else { } else {
// Otherwise take first "self" link // Otherwise take first "self" link
jrd.links.iter() jrd.links
.iter()
.find(|link| link.rel == "self") .find(|link| link.rel == "self")
.ok_or(FetchError::OtherError("self link not found"))? .ok_or(FetchError::OtherError("self link not found"))?
}; };
let actor_url = link.href.as_ref() let actor_url = link
.href
.as_ref()
.ok_or(FetchError::OtherError("account href not found"))? .ok_or(FetchError::OtherError("account href not found"))?
.to_string(); .to_string();
Ok(actor_url) Ok(actor_url)
} }
pub async fn fetch_actor( pub async fn fetch_actor(instance: &Instance, actor_url: &str) -> Result<Actor, FetchError> {
instance: &Instance,
actor_url: &str,
) -> Result<Actor, FetchError> {
let actor_json = send_request(instance, actor_url).await?; let actor_json = send_request(instance, actor_url).await?;
let actor: Actor = serde_json::from_str(&actor_json)?; let actor: Actor = serde_json::from_str(&actor_json)?;
if actor.id != actor_url { if actor.id != actor_url {
@ -217,10 +189,7 @@ pub async fn fetch_actor(
Ok(actor) Ok(actor)
} }
pub async fn fetch_object( pub async fn fetch_object(instance: &Instance, object_url: &str) -> Result<Object, FetchError> {
instance: &Instance,
object_url: &str,
) -> Result<Object, FetchError> {
let object_json = send_request(instance, object_url).await?; let object_json = send_request(instance, object_url).await?;
let object_value: JsonValue = serde_json::from_str(&object_json)?; let object_value: JsonValue = serde_json::from_str(&object_json)?;
let object: Object = serde_json::from_value(object_value)?; let object: Object = serde_json::from_value(object_value)?;
@ -245,7 +214,6 @@ pub async fn fetch_outbox(
let collection: Collection = serde_json::from_str(&collection_json)?; let collection: Collection = serde_json::from_str(&collection_json)?;
let page_json = send_request(instance, &collection.first).await?; let page_json = send_request(instance, &collection.first).await?;
let page: CollectionPage = serde_json::from_str(&page_json)?; let page: CollectionPage = serde_json::from_str(&page_json)?;
let activities = page.ordered_items.into_iter() let activities = page.ordered_items.into_iter().take(limit).collect();
.take(limit).collect();
Ok(activities) Ok(activities)
} }

View file

@ -6,13 +6,13 @@ use mitra_models::{
posts::helpers::get_local_post_by_id, posts::helpers::get_local_post_by_id,
posts::queries::get_post_by_remote_object_id, posts::queries::get_post_by_remote_object_id,
posts::types::Post, posts::types::Post,
profiles::queries::{ profiles::queries::{get_profile_by_acct, get_profile_by_remote_actor_id},
get_profile_by_acct,
get_profile_by_remote_actor_id,
},
profiles::types::DbActorProfile, profiles::types::DbActorProfile,
}; };
use super::fetchers::{
fetch_actor, fetch_object, fetch_outbox, perform_webfinger_query, FetchError,
};
use crate::activitypub::{ use crate::activitypub::{
actors::helpers::{create_remote_profile, update_remote_profile}, actors::helpers::{create_remote_profile, update_remote_profile},
handlers::create::{get_object_links, handle_note}, handlers::create::{get_object_links, handle_note},
@ -23,13 +23,6 @@ use crate::activitypub::{
use crate::errors::ValidationError; use crate::errors::ValidationError;
use crate::media::MediaStorage; use crate::media::MediaStorage;
use crate::webfinger::types::ActorAddress; use crate::webfinger::types::ActorAddress;
use super::fetchers::{
fetch_actor,
fetch_object,
fetch_outbox,
perform_webfinger_query,
FetchError,
};
pub async fn get_or_import_profile_by_actor_id( pub async fn get_or_import_profile_by_actor_id(
db_client: &mut impl DatabaseClient, db_client: &mut impl DatabaseClient,
@ -40,37 +33,28 @@ pub async fn get_or_import_profile_by_actor_id(
if actor_id.starts_with(&instance.url()) { if actor_id.starts_with(&instance.url()) {
return Err(HandlerError::LocalObject); return Err(HandlerError::LocalObject);
}; };
let profile = match get_profile_by_remote_actor_id( let profile = match get_profile_by_remote_actor_id(db_client, actor_id).await {
db_client,
actor_id,
).await {
Ok(profile) => { Ok(profile) => {
if profile.possibly_outdated() { if profile.possibly_outdated() {
// Try to re-fetch actor profile // Try to re-fetch actor profile
match fetch_actor(instance, actor_id).await { match fetch_actor(instance, actor_id).await {
Ok(actor) => { Ok(actor) => {
log::info!("re-fetched profile {}", profile.acct); log::info!("re-fetched profile {}", profile.acct);
let profile_updated = update_remote_profile( let profile_updated =
db_client, update_remote_profile(db_client, instance, storage, profile, actor)
instance, .await?;
storage,
profile,
actor,
).await?;
profile_updated profile_updated
}, }
Err(err) => { Err(err) => {
// Ignore error and return stored profile // Ignore error and return stored profile
log::warn!( log::warn!("failed to re-fetch {} ({})", profile.acct, err,);
"failed to re-fetch {} ({})", profile.acct, err,
);
profile profile
}, }
} }
} else { } else {
profile profile
} }
}, }
Err(DatabaseError::NotFound(_)) => { Err(DatabaseError::NotFound(_)) => {
let actor = fetch_actor(instance, actor_id).await?; let actor = fetch_actor(instance, actor_id).await?;
let actor_address = actor.address()?; let actor_address = actor.address()?;
@ -79,28 +63,19 @@ pub async fn get_or_import_profile_by_actor_id(
Ok(profile) => { Ok(profile) => {
// WARNING: Possible actor ID change // WARNING: Possible actor ID change
log::info!("re-fetched profile {}", profile.acct); log::info!("re-fetched profile {}", profile.acct);
let profile_updated = update_remote_profile( let profile_updated =
db_client, update_remote_profile(db_client, instance, storage, profile, actor).await?;
instance,
storage,
profile,
actor,
).await?;
profile_updated profile_updated
}, }
Err(DatabaseError::NotFound(_)) => { Err(DatabaseError::NotFound(_)) => {
log::info!("fetched profile {}", acct); log::info!("fetched profile {}", acct);
let profile = create_remote_profile( let profile =
db_client, create_remote_profile(db_client, instance, storage, actor).await?;
instance,
storage,
actor,
).await?;
profile profile
}, }
Err(other_error) => return Err(other_error.into()), Err(other_error) => return Err(other_error.into()),
} }
}, }
Err(other_error) => return Err(other_error.into()), Err(other_error) => return Err(other_error.into()),
}; };
Ok(profile) Ok(profile)
@ -128,12 +103,7 @@ pub async fn import_profile_by_actor_address(
}; };
}; };
log::info!("fetched profile {}", profile_acct); log::info!("fetched profile {}", profile_acct);
let profile = create_remote_profile( let profile = create_remote_profile(db_client, instance, storage, actor).await?;
db_client,
instance,
storage,
actor,
).await?;
Ok(profile) Ok(profile)
} }
@ -145,22 +115,14 @@ pub async fn get_or_import_profile_by_actor_address(
actor_address: &ActorAddress, actor_address: &ActorAddress,
) -> Result<DbActorProfile, HandlerError> { ) -> Result<DbActorProfile, HandlerError> {
let acct = actor_address.acct(&instance.hostname()); let acct = actor_address.acct(&instance.hostname());
let profile = match get_profile_by_acct( let profile = match get_profile_by_acct(db_client, &acct).await {
db_client,
&acct,
).await {
Ok(profile) => profile, Ok(profile) => profile,
Err(db_error @ DatabaseError::NotFound(_)) => { Err(db_error @ DatabaseError::NotFound(_)) => {
if actor_address.hostname == instance.hostname() { if actor_address.hostname == instance.hostname() {
return Err(db_error.into()); return Err(db_error.into());
}; };
import_profile_by_actor_address( import_profile_by_actor_address(db_client, instance, storage, actor_address).await?
db_client, }
instance,
storage,
actor_address,
).await?
},
Err(other_error) => return Err(other_error.into()), Err(other_error) => return Err(other_error.into()),
}; };
Ok(profile) Ok(profile)
@ -176,12 +138,12 @@ pub async fn get_post_by_object_id(
// Local post // Local post
let post = get_local_post_by_id(db_client, &post_id).await?; let post = get_local_post_by_id(db_client, &post_id).await?;
Ok(post) Ok(post)
}, }
Err(_) => { Err(_) => {
// Remote post // Remote post
let post = get_post_by_remote_object_id(db_client, object_id).await?; let post = get_post_by_remote_object_id(db_client, object_id).await?;
Ok(post) Ok(post)
}, }
} }
} }
@ -222,10 +184,7 @@ pub async fn import_post(
get_local_post_by_id(db_client, &post_id).await?; get_local_post_by_id(db_client, &post_id).await?;
continue; continue;
}; };
match get_post_by_remote_object_id( match get_post_by_remote_object_id(db_client, &object_id).await {
db_client,
&object_id,
).await {
Ok(post) => { Ok(post) => {
// Object already fetched // Object already fetched
if objects.len() == 0 { if objects.len() == 0 {
@ -233,16 +192,16 @@ pub async fn import_post(
return Ok(post); return Ok(post);
}; };
continue; continue;
}, }
Err(DatabaseError::NotFound(_)) => (), Err(DatabaseError::NotFound(_)) => (),
Err(other_error) => return Err(other_error.into()), Err(other_error) => return Err(other_error.into()),
}; };
object_id object_id
}, }
None => { None => {
// No object to fetch // No object to fetch
break; break;
}, }
}; };
let object = match maybe_object { let object = match maybe_object {
Some(object) => object, Some(object) => object,
@ -251,15 +210,14 @@ pub async fn import_post(
// TODO: create tombstone // TODO: create tombstone
return Err(FetchError::RecursionError.into()); return Err(FetchError::RecursionError.into());
}; };
let object = fetch_object(instance, &object_id).await let object = fetch_object(instance, &object_id).await.map_err(|err| {
.map_err(|err| { log::warn!("{}", err);
log::warn!("{}", err); ValidationError("failed to fetch object")
ValidationError("failed to fetch object") })?;
})?;
log::info!("fetched object {}", object.id); log::info!("fetched object {}", object.id);
fetch_count += 1; fetch_count += 1;
object object
}, }
}; };
if object.id != object_id { if object.id != object_id {
// ID of fetched object doesn't match requested ID // ID of fetched object doesn't match requested ID
@ -277,27 +235,22 @@ pub async fn import_post(
for object_id in get_object_links(&object) { for object_id in get_object_links(&object) {
// Fetch linked objects after fetching current thread // Fetch linked objects after fetching current thread
queue.insert(0, object_id); queue.insert(0, object_id);
}; }
maybe_object = None; maybe_object = None;
objects.push(object); objects.push(object);
}; }
let initial_object_id = objects[0].id.clone(); let initial_object_id = objects[0].id.clone();
// Objects are ordered according to their place in reply tree, // Objects are ordered according to their place in reply tree,
// starting with the root // starting with the root
objects.reverse(); objects.reverse();
for object in objects { for object in objects {
let post = handle_note( let post = handle_note(db_client, instance, storage, object, &redirects).await?;
db_client,
instance,
storage,
object,
&redirects,
).await?;
posts.push(post); posts.push(post);
}; }
let initial_post = posts.into_iter() let initial_post = posts
.into_iter()
.find(|post| post.object_id.as_ref() == Some(&initial_object_id)) .find(|post| post.object_id.as_ref() == Some(&initial_object_id))
.unwrap(); .unwrap();
Ok(initial_post) Ok(initial_post)
@ -314,24 +267,20 @@ pub async fn import_from_outbox(
let activities = fetch_outbox(&instance, &actor.outbox, limit).await?; let activities = fetch_outbox(&instance, &actor.outbox, limit).await?;
log::info!("fetched {} activities", activities.len()); log::info!("fetched {} activities", activities.len());
for activity in activities { for activity in activities {
let activity_actor = activity["actor"].as_str() let activity_actor = activity["actor"]
.as_str()
.ok_or(ValidationError("actor property is missing"))?; .ok_or(ValidationError("actor property is missing"))?;
if activity_actor != actor.id { if activity_actor != actor.id {
log::warn!("activity doesn't belong to outbox owner"); log::warn!("activity doesn't belong to outbox owner");
continue; continue;
}; };
handle_activity( handle_activity(
config, config, db_client, &activity, true, // is authenticated
db_client, )
&activity, .await
true, // is authenticated .unwrap_or_else(|error| {
).await.unwrap_or_else(|error| { log::warn!("failed to process activity ({}): {}", error, activity,);
log::warn!(
"failed to process activity ({}): {}",
error,
activity,
);
}); });
}; }
Ok(()) Ok(())
} }

View file

@ -5,17 +5,12 @@ use mitra_config::Config;
use mitra_models::{ use mitra_models::{
database::DatabaseClient, database::DatabaseClient,
profiles::queries::get_profile_by_remote_actor_id, profiles::queries::get_profile_by_remote_actor_id,
relationships::queries::{ relationships::queries::{follow_request_accepted, get_follow_request_by_id},
follow_request_accepted,
get_follow_request_by_id,
},
relationships::types::FollowRequestStatus, relationships::types::FollowRequestStatus,
}; };
use crate::activitypub::{ use crate::activitypub::{
identifiers::parse_local_object_id, identifiers::parse_local_object_id, receiver::deserialize_into_object_id, vocabulary::FOLLOW,
receiver::deserialize_into_object_id,
vocabulary::FOLLOW,
}; };
use crate::errors::ValidationError; use crate::errors::ValidationError;
@ -36,14 +31,8 @@ pub async fn handle_accept(
// Accept(Follow) // Accept(Follow)
let activity: Accept = serde_json::from_value(activity) let activity: Accept = serde_json::from_value(activity)
.map_err(|_| ValidationError("unexpected activity structure"))?; .map_err(|_| ValidationError("unexpected activity structure"))?;
let actor_profile = get_profile_by_remote_actor_id( let actor_profile = get_profile_by_remote_actor_id(db_client, &activity.actor).await?;
db_client, let follow_request_id = parse_local_object_id(&config.instance_url(), &activity.object)?;
&activity.actor,
).await?;
let follow_request_id = parse_local_object_id(
&config.instance_url(),
&activity.object,
)?;
let follow_request = get_follow_request_by_id(db_client, &follow_request_id).await?; let follow_request = get_follow_request_by_id(db_client, &follow_request_id).await?;
if follow_request.target_id != actor_profile.id { if follow_request.target_id != actor_profile.id {
return Err(ValidationError("actor is not a target").into()); return Err(ValidationError("actor is not a target").into());

View file

@ -3,18 +3,13 @@ use serde_json::Value;
use mitra_config::Config; use mitra_config::Config;
use mitra_models::{ use mitra_models::{
database::DatabaseClient, database::DatabaseClient, profiles::queries::get_profile_by_remote_actor_id,
profiles::queries::get_profile_by_remote_actor_id, relationships::queries::subscribe_opt, users::queries::get_user_by_name,
relationships::queries::subscribe_opt,
users::queries::get_user_by_name,
}; };
use crate::activitypub::{
identifiers::parse_local_actor_id,
vocabulary::PERSON,
};
use crate::errors::ValidationError;
use super::{HandlerError, HandlerResult}; use super::{HandlerError, HandlerResult};
use crate::activitypub::{identifiers::parse_local_actor_id, vocabulary::PERSON};
use crate::errors::ValidationError;
#[derive(Deserialize)] #[derive(Deserialize)]
struct Add { struct Add {
@ -30,17 +25,11 @@ pub async fn handle_add(
) -> HandlerResult { ) -> HandlerResult {
let activity: Add = serde_json::from_value(activity) let activity: Add = serde_json::from_value(activity)
.map_err(|_| ValidationError("unexpected activity structure"))?; .map_err(|_| ValidationError("unexpected activity structure"))?;
let actor_profile = get_profile_by_remote_actor_id( let actor_profile = get_profile_by_remote_actor_id(db_client, &activity.actor).await?;
db_client,
&activity.actor,
).await?;
let actor = actor_profile.actor_json.ok_or(HandlerError::LocalObject)?; let actor = actor_profile.actor_json.ok_or(HandlerError::LocalObject)?;
if Some(activity.target) == actor.subscribers { if Some(activity.target) == actor.subscribers {
// Adding to subscribers // Adding to subscribers
let username = parse_local_actor_id( let username = parse_local_actor_id(&config.instance_url(), &activity.object)?;
&config.instance_url(),
&activity.object,
)?;
let user = get_user_by_name(db_client, &username).await?; let user = get_user_by_name(db_client, &username).await?;
subscribe_opt(db_client, &user.id, &actor_profile.id).await?; subscribe_opt(db_client, &user.id, &actor_profile.id).await?;
return Ok(Some(PERSON)); return Ok(Some(PERSON));

View file

@ -4,13 +4,11 @@ use serde_json::Value;
use mitra_config::Config; use mitra_config::Config;
use mitra_models::{ use mitra_models::{
database::{DatabaseClient, DatabaseError}, database::{DatabaseClient, DatabaseError},
posts::queries::{ posts::queries::{create_post, get_post_by_remote_object_id},
create_post,
get_post_by_remote_object_id,
},
posts::types::PostCreateData, posts::types::PostCreateData,
}; };
use super::HandlerResult;
use crate::activitypub::{ use crate::activitypub::{
fetcher::helpers::{get_or_import_profile_by_actor_id, import_post}, fetcher::helpers::{get_or_import_profile_by_actor_id, import_post},
identifiers::parse_local_object_id, identifiers::parse_local_object_id,
@ -19,7 +17,6 @@ use crate::activitypub::{
}; };
use crate::errors::ValidationError; use crate::errors::ValidationError;
use crate::media::MediaStorage; use crate::media::MediaStorage;
use super::HandlerResult;
#[derive(Deserialize)] #[derive(Deserialize)]
struct Announce { struct Announce {
@ -44,43 +41,24 @@ pub async fn handle_announce(
let activity: Announce = serde_json::from_value(activity) let activity: Announce = serde_json::from_value(activity)
.map_err(|_| ValidationError("unexpected activity structure"))?; .map_err(|_| ValidationError("unexpected activity structure"))?;
let repost_object_id = activity.id; let repost_object_id = activity.id;
match get_post_by_remote_object_id( match get_post_by_remote_object_id(db_client, &repost_object_id).await {
db_client,
&repost_object_id,
).await {
Ok(_) => return Ok(None), // Ignore if repost already exists Ok(_) => return Ok(None), // Ignore if repost already exists
Err(DatabaseError::NotFound(_)) => (), Err(DatabaseError::NotFound(_)) => (),
Err(other_error) => return Err(other_error.into()), Err(other_error) => return Err(other_error.into()),
}; };
let instance = config.instance(); let instance = config.instance();
let storage = MediaStorage::from(config); let storage = MediaStorage::from(config);
let author = get_or_import_profile_by_actor_id( let author =
db_client, get_or_import_profile_by_actor_id(db_client, &instance, &storage, &activity.actor).await?;
&instance, let post_id = match parse_local_object_id(&instance.url(), &activity.object) {
&storage,
&activity.actor,
).await?;
let post_id = match parse_local_object_id(
&instance.url(),
&activity.object,
) {
Ok(post_id) => post_id, Ok(post_id) => post_id,
Err(_) => { Err(_) => {
// Try to get remote post // Try to get remote post
let post = import_post( let post = import_post(db_client, &instance, &storage, activity.object, None).await?;
db_client,
&instance,
&storage,
activity.object,
None,
).await?;
post.id post.id
}, }
}; };
let repost_data = PostCreateData::repost( let repost_data = PostCreateData::repost(post_id, Some(repost_object_id.clone()));
post_id,
Some(repost_object_id.clone()),
);
match create_post(db_client, &author.id, repost_data).await { match create_post(db_client, &author.id, repost_data).await {
Ok(_) => Ok(Some(NOTE)), Ok(_) => Ok(Some(NOTE)),
Err(DatabaseError::AlreadyExists("post")) => { Err(DatabaseError::AlreadyExists("post")) => {
@ -88,7 +66,7 @@ pub async fn handle_announce(
// object ID, or due to race condition in a handler). // object ID, or due to race condition in a handler).
log::warn!("repost already exists: {}", repost_object_id); log::warn!("repost already exists: {}", repost_object_id);
Ok(None) Ok(None)
}, }
// May return "post not found" error if post if not public // May return "post not found" error if post if not public
Err(other_error) => Err(other_error.into()), Err(other_error) => Err(other_error.into()),
} }
@ -96,8 +74,8 @@ pub async fn handle_announce(
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use serde_json::json;
use super::*; use super::*;
use serde_json::json;
#[test] #[test]
fn test_deserialize_announce() { fn test_deserialize_announce() {

View file

@ -2,18 +2,14 @@ use std::collections::HashMap;
use chrono::Utc; use chrono::Utc;
use serde::Deserialize; use serde::Deserialize;
use serde_json::{Value as JsonValue}; use serde_json::Value as JsonValue;
use uuid::Uuid; use uuid::Uuid;
use mitra_config::{Config, Instance}; use mitra_config::{Config, Instance};
use mitra_models::{ use mitra_models::{
attachments::queries::create_attachment, attachments::queries::create_attachment,
database::{DatabaseClient, DatabaseError}, database::{DatabaseClient, DatabaseError},
emojis::queries::{ emojis::queries::{create_emoji, get_emoji_by_remote_object_id, update_emoji},
create_emoji,
get_emoji_by_remote_object_id,
update_emoji,
},
emojis::types::{DbEmoji, EmojiImage}, emojis::types::{DbEmoji, EmojiImage},
posts::{ posts::{
queries::create_post, queries::create_post,
@ -23,19 +19,15 @@ use mitra_models::{
relationships::queries::has_local_followers, relationships::queries::has_local_followers,
users::queries::get_user_by_name, users::queries::get_user_by_name,
}; };
use mitra_utils::{ use mitra_utils::{html::clean_html, urls::get_hostname};
html::clean_html,
urls::get_hostname,
};
use super::HandlerResult;
use crate::activitypub::{ use crate::activitypub::{
constants::{AP_MEDIA_TYPE, AP_PUBLIC, AS_MEDIA_TYPE}, constants::{AP_MEDIA_TYPE, AP_PUBLIC, AS_MEDIA_TYPE},
fetcher::fetchers::{fetch_file, FetchError}, fetcher::fetchers::{fetch_file, FetchError},
fetcher::helpers::{ fetcher::helpers::{
get_or_import_profile_by_actor_address, get_or_import_profile_by_actor_address, get_or_import_profile_by_actor_id,
get_or_import_profile_by_actor_id, get_post_by_object_id, import_post,
get_post_by_object_id,
import_post,
}, },
identifiers::{parse_local_actor_id, profile_actor_id}, identifiers::{parse_local_actor_id, profile_actor_id},
receiver::{parse_array, parse_property_value, HandlerError}, receiver::{parse_array, parse_property_value, HandlerError},
@ -45,28 +37,19 @@ use crate::activitypub::{
use crate::errors::ValidationError; use crate::errors::ValidationError;
use crate::media::MediaStorage; use crate::media::MediaStorage;
use crate::validators::{ use crate::validators::{
emojis::{ emojis::{validate_emoji_name, EMOJI_MEDIA_TYPES},
validate_emoji_name,
EMOJI_MEDIA_TYPES,
},
posts::{ posts::{
content_allowed_classes, content_allowed_classes, ATTACHMENT_LIMIT, CONTENT_MAX_SIZE, EMOJI_LIMIT, LINK_LIMIT,
ATTACHMENT_LIMIT, MENTION_LIMIT, OBJECT_ID_SIZE_MAX,
CONTENT_MAX_SIZE,
EMOJI_LIMIT,
LINK_LIMIT,
MENTION_LIMIT,
OBJECT_ID_SIZE_MAX,
}, },
tags::validate_hashtag, tags::validate_hashtag,
}; };
use crate::webfinger::types::ActorAddress; use crate::webfinger::types::ActorAddress;
use super::HandlerResult;
fn get_object_attributed_to(object: &Object) fn get_object_attributed_to(object: &Object) -> Result<String, ValidationError> {
-> Result<String, ValidationError> let attributed_to = object
{ .attributed_to
let attributed_to = object.attributed_to.as_ref() .as_ref()
.ok_or(ValidationError("unattributed note"))?; .ok_or(ValidationError("unattributed note"))?;
let author_id = parse_array(attributed_to) let author_id = parse_array(attributed_to)
.map_err(|_| ValidationError("invalid attributedTo property"))? .map_err(|_| ValidationError("invalid attributedTo property"))?
@ -83,7 +66,7 @@ pub fn get_object_url(object: &Object) -> Result<String, ValidationError> {
let links: Vec<Link> = parse_property_value(other_value) let links: Vec<Link> = parse_property_value(other_value)
.map_err(|_| ValidationError("invalid object URL"))?; .map_err(|_| ValidationError("invalid object URL"))?;
links.get(0).map(|link| link.href.clone()) links.get(0).map(|link| link.href.clone())
}, }
None => None, None => None,
}; };
let object_url = maybe_object_url.unwrap_or(object.id.clone()); let object_url = maybe_object_url.unwrap_or(object.id.clone());
@ -110,10 +93,7 @@ pub fn get_object_content(object: &Object) -> Result<String, ValidationError> {
} }
pub fn create_content_link(url: String) -> String { pub fn create_content_link(url: String) -> String {
format!( format!(r#"<p><a href="{0}" rel="noopener">{0}</a></p>"#, url,)
r#"<p><a href="{0}" rel="noopener">{0}</a></p>"#,
url,
)
} }
fn is_gnu_social_link(author_id: &str, attachment: &Attachment) -> bool { fn is_gnu_social_link(author_id: &str, attachment: &Attachment) -> bool {
@ -147,21 +127,16 @@ pub async fn get_object_attachments(
match attachment.attachment_type.as_str() { match attachment.attachment_type.as_str() {
DOCUMENT | IMAGE | VIDEO => (), DOCUMENT | IMAGE | VIDEO => (),
_ => { _ => {
log::warn!( log::warn!("skipping attachment of type {}", attachment.attachment_type,);
"skipping attachment of type {}",
attachment.attachment_type,
);
continue; continue;
}, }
}; };
if is_gnu_social_link( if is_gnu_social_link(&profile_actor_id(&instance.url(), author), &attachment) {
&profile_actor_id(&instance.url(), author),
&attachment,
) {
// Don't fetch HTML pages attached by GNU Social // Don't fetch HTML pages attached by GNU Social
continue; continue;
}; };
let attachment_url = attachment.url let attachment_url = attachment
.url
.ok_or(ValidationError("attachment URL is missing"))?; .ok_or(ValidationError("attachment URL is missing"))?;
let (file_name, file_size, maybe_media_type) = match fetch_file( let (file_name, file_size, maybe_media_type) = match fetch_file(
instance, instance,
@ -169,17 +144,19 @@ pub async fn get_object_attachments(
attachment.media_type.as_deref(), attachment.media_type.as_deref(),
storage.file_size_limit, storage.file_size_limit,
&storage.media_dir, &storage.media_dir,
).await { )
.await
{
Ok(file) => file, Ok(file) => file,
Err(FetchError::FileTooLarge) => { Err(FetchError::FileTooLarge) => {
log::warn!("attachment is too large: {}", attachment_url); log::warn!("attachment is too large: {}", attachment_url);
unprocessed.push(attachment_url); unprocessed.push(attachment_url);
continue; continue;
}, }
Err(other_error) => { Err(other_error) => {
log::warn!("{}", other_error); log::warn!("{}", other_error);
return Err(ValidationError("failed to fetch attachment").into()); return Err(ValidationError("failed to fetch attachment").into());
}, }
}; };
log::info!("downloaded attachment {}", attachment_url); log::info!("downloaded attachment {}", attachment_url);
downloaded.push((file_name, file_size, maybe_media_type)); downloaded.push((file_name, file_size, maybe_media_type));
@ -188,7 +165,7 @@ pub async fn get_object_attachments(
log::warn!("too many attachments"); log::warn!("too many attachments");
break; break;
}; };
}; }
for (file_name, file_size, maybe_media_type) in downloaded { for (file_name, file_size, maybe_media_type) in downloaded {
let db_attachment = create_attachment( let db_attachment = create_attachment(
db_client, db_client,
@ -196,9 +173,10 @@ pub async fn get_object_attachments(
file_name, file_name,
file_size, file_size,
maybe_media_type, maybe_media_type,
).await?; )
.await?;
attachments.push(db_attachment.id); attachments.push(db_attachment.id);
}; }
}; };
Ok((attachments, unprocessed)) Ok((attachments, unprocessed))
} }
@ -209,9 +187,7 @@ fn normalize_hashtag(tag: &str) -> Result<String, ValidationError> {
Ok(tag_name.to_lowercase()) Ok(tag_name.to_lowercase())
} }
pub fn get_object_links( pub fn get_object_links(object: &Object) -> Vec<String> {
object: &Object,
) -> Vec<String> {
let mut links = vec![]; let mut links = vec![];
for tag_value in object.tag.clone() { for tag_value in object.tag.clone() {
let tag_type = tag_value["type"].as_str().unwrap_or(HASHTAG); let tag_type = tag_value["type"].as_str().unwrap_or(HASHTAG);
@ -221,11 +197,9 @@ pub fn get_object_links(
Err(_) => { Err(_) => {
log::warn!("invalid link tag"); log::warn!("invalid link tag");
continue; continue;
}, }
}; };
if tag.media_type != AP_MEDIA_TYPE && if tag.media_type != AP_MEDIA_TYPE && tag.media_type != AS_MEDIA_TYPE {
tag.media_type != AS_MEDIA_TYPE
{
// Unknown media type // Unknown media type
continue; continue;
}; };
@ -233,7 +207,7 @@ pub fn get_object_links(
links.push(tag.href); links.push(tag.href);
}; };
}; };
}; }
if let Some(ref object_id) = object.quote_url { if let Some(ref object_id) = object.quote_url {
if !links.contains(object_id) { if !links.contains(object_id) {
links.push(object_id.to_owned()); links.push(object_id.to_owned());
@ -253,17 +227,14 @@ pub async fn handle_emoji(
Err(error) => { Err(error) => {
log::warn!("invalid emoji tag: {}", error); log::warn!("invalid emoji tag: {}", error);
return Ok(None); return Ok(None);
}, }
}; };
let emoji_name = tag.name.trim_matches(':'); let emoji_name = tag.name.trim_matches(':');
if validate_emoji_name(emoji_name).is_err() { if validate_emoji_name(emoji_name).is_err() {
log::warn!("invalid emoji name: {}", emoji_name); log::warn!("invalid emoji name: {}", emoji_name);
return Ok(None); return Ok(None);
}; };
let maybe_emoji_id = match get_emoji_by_remote_object_id( let maybe_emoji_id = match get_emoji_by_remote_object_id(db_client, &tag.id).await {
db_client,
&tag.id,
).await {
Ok(emoji) => { Ok(emoji) => {
if emoji.updated_at >= tag.updated { if emoji.updated_at >= tag.updated {
// Emoji already exists and is up to date // Emoji already exists and is up to date
@ -274,7 +245,7 @@ pub async fn handle_emoji(
return Ok(None); return Ok(None);
}; };
Some(emoji.id) Some(emoji.id)
}, }
Err(DatabaseError::NotFound("emoji")) => None, Err(DatabaseError::NotFound("emoji")) => None,
Err(other_error) => return Err(other_error.into()), Err(other_error) => return Err(other_error.into()),
}; };
@ -284,37 +255,32 @@ pub async fn handle_emoji(
tag.icon.media_type.as_deref(), tag.icon.media_type.as_deref(),
storage.emoji_size_limit, storage.emoji_size_limit,
&storage.media_dir, &storage.media_dir,
).await { )
.await
{
Ok(file) => file, Ok(file) => file,
Err(error) => { Err(error) => {
log::warn!("failed to fetch emoji: {}", error); log::warn!("failed to fetch emoji: {}", error);
return Ok(None); return Ok(None);
}, }
}; };
let media_type = match maybe_media_type { let media_type = match maybe_media_type {
Some(media_type) if EMOJI_MEDIA_TYPES.contains(&media_type.as_str()) => { Some(media_type) if EMOJI_MEDIA_TYPES.contains(&media_type.as_str()) => media_type,
media_type
},
_ => { _ => {
log::warn!( log::warn!("unexpected emoji media type: {:?}", maybe_media_type,);
"unexpected emoji media type: {:?}",
maybe_media_type,
);
return Ok(None); return Ok(None);
}, }
}; };
log::info!("downloaded emoji {}", tag.icon.url); log::info!("downloaded emoji {}", tag.icon.url);
let image = EmojiImage { file_name, file_size, media_type }; let image = EmojiImage {
file_name,
file_size,
media_type,
};
let emoji = if let Some(emoji_id) = maybe_emoji_id { let emoji = if let Some(emoji_id) = maybe_emoji_id {
update_emoji( update_emoji(db_client, &emoji_id, image, &tag.updated).await?
db_client,
&emoji_id,
image,
&tag.updated,
).await?
} else { } else {
let hostname = get_hostname(&tag.id) let hostname = get_hostname(&tag.id).map_err(|_| ValidationError("invalid emoji ID"))?;
.map_err(|_| ValidationError("invalid emoji ID"))?;
match create_emoji( match create_emoji(
db_client, db_client,
emoji_name, emoji_name,
@ -322,12 +288,14 @@ pub async fn handle_emoji(
image, image,
Some(&tag.id), Some(&tag.id),
&tag.updated, &tag.updated,
).await { )
.await
{
Ok(emoji) => emoji, Ok(emoji) => emoji,
Err(DatabaseError::AlreadyExists(_)) => { Err(DatabaseError::AlreadyExists(_)) => {
log::warn!("emoji name is not unique: {}", emoji_name); log::warn!("emoji name is not unique: {}", emoji_name);
return Ok(None); return Ok(None);
}, }
Err(other_error) => return Err(other_error.into()), Err(other_error) => return Err(other_error.into()),
} }
}; };
@ -353,7 +321,7 @@ pub async fn get_object_tags(
Err(_) => { Err(_) => {
log::warn!("invalid hashtag"); log::warn!("invalid hashtag");
continue; continue;
}, }
}; };
if let Some(tag_name) = tag.name { if let Some(tag_name) = tag.name {
// Ignore invalid tags // Ignore invalid tags
@ -373,7 +341,7 @@ pub async fn get_object_tags(
Err(_) => { Err(_) => {
log::warn!("invalid mention"); log::warn!("invalid mention");
continue; continue;
}, }
}; };
// Try to find profile by actor ID. // Try to find profile by actor ID.
if let Some(href) = tag.href { if let Some(href) = tag.href {
@ -386,25 +354,16 @@ pub async fn get_object_tags(
}; };
// NOTE: `href` attribute is usually actor ID // NOTE: `href` attribute is usually actor ID
// but also can be actor URL (profile link). // but also can be actor URL (profile link).
match get_or_import_profile_by_actor_id( match get_or_import_profile_by_actor_id(db_client, instance, storage, &href).await {
db_client,
instance,
storage,
&href,
).await {
Ok(profile) => { Ok(profile) => {
if !mentions.contains(&profile.id) { if !mentions.contains(&profile.id) {
mentions.push(profile.id); mentions.push(profile.id);
}; };
continue; continue;
}, }
Err(error) => { Err(error) => {
log::warn!( log::warn!("failed to find mentioned profile by ID {}: {}", href, error,);
"failed to find mentioned profile by ID {}: {}", }
href,
error,
);
},
}; };
}; };
// Try to find profile by actor address // Try to find profile by actor address
@ -413,7 +372,7 @@ pub async fn get_object_tags(
None => { None => {
log::warn!("failed to parse mention"); log::warn!("failed to parse mention");
continue; continue;
}, }
}; };
if let Ok(actor_address) = ActorAddress::from_mention(&tag_name) { if let Ok(actor_address) = ActorAddress::from_mention(&tag_name) {
let profile = match get_or_import_profile_by_actor_address( let profile = match get_or_import_profile_by_actor_address(
@ -421,12 +380,14 @@ pub async fn get_object_tags(
instance, instance,
storage, storage,
&actor_address, &actor_address,
).await { )
.await
{
Ok(profile) => profile, Ok(profile) => profile,
Err(error @ ( Err(
HandlerError::FetchError(_) | error @ (HandlerError::FetchError(_)
HandlerError::DatabaseError(DatabaseError::NotFound(_)) | HandlerError::DatabaseError(DatabaseError::NotFound(_))),
)) => { ) => {
// Ignore mention if fetcher fails // Ignore mention if fetcher fails
// Ignore mention if local address is not valid // Ignore mention if local address is not valid
log::warn!( log::warn!(
@ -435,7 +396,7 @@ pub async fn get_object_tags(
error, error,
); );
continue; continue;
}, }
Err(other_error) => return Err(other_error), Err(other_error) => return Err(other_error),
}; };
if !mentions.contains(&profile.id) { if !mentions.contains(&profile.id) {
@ -454,20 +415,14 @@ pub async fn get_object_tags(
Err(_) => { Err(_) => {
log::warn!("invalid link tag"); log::warn!("invalid link tag");
continue; continue;
}, }
}; };
if tag.media_type != AP_MEDIA_TYPE && if tag.media_type != AP_MEDIA_TYPE && tag.media_type != AS_MEDIA_TYPE {
tag.media_type != AS_MEDIA_TYPE
{
// Unknown media type // Unknown media type
continue; continue;
}; };
let href = redirects.get(&tag.href).unwrap_or(&tag.href); let href = redirects.get(&tag.href).unwrap_or(&tag.href);
let linked = get_post_by_object_id( let linked = get_post_by_object_id(db_client, &instance.url(), href).await?;
db_client,
&instance.url(),
href,
).await?;
if !links.contains(&linked.id) { if !links.contains(&linked.id) {
links.push(linked.id); links.push(linked.id);
}; };
@ -476,30 +431,21 @@ pub async fn get_object_tags(
log::warn!("too many emojis"); log::warn!("too many emojis");
continue; continue;
}; };
match handle_emoji( match handle_emoji(db_client, instance, storage, tag_value).await? {
db_client,
instance,
storage,
tag_value,
).await? {
Some(emoji) => { Some(emoji) => {
if !emojis.contains(&emoji.id) { if !emojis.contains(&emoji.id) {
emojis.push(emoji.id); emojis.push(emoji.id);
}; };
}, }
None => continue, None => continue,
}; };
} else { } else {
log::warn!("skipping tag of type {}", tag_type); log::warn!("skipping tag of type {}", tag_type);
}; };
}; }
if let Some(ref object_id) = object.quote_url { if let Some(ref object_id) = object.quote_url {
let object_id = redirects.get(object_id).unwrap_or(object_id); let object_id = redirects.get(object_id).unwrap_or(object_id);
let linked = get_post_by_object_id( let linked = get_post_by_object_id(db_client, &instance.url(), object_id).await?;
db_client,
&instance.url(),
object_id,
).await?;
if !links.contains(&linked.id) { if !links.contains(&linked.id) {
links.push(linked.id); links.push(linked.id);
}; };
@ -510,16 +456,14 @@ pub async fn get_object_tags(
fn get_audience(object: &Object) -> Result<Vec<String>, ValidationError> { fn get_audience(object: &Object) -> Result<Vec<String>, ValidationError> {
let primary_audience = match object.to { let primary_audience = match object.to {
Some(ref value) => { Some(ref value) => {
parse_array(value) parse_array(value).map_err(|_| ValidationError("invalid 'to' property value"))?
.map_err(|_| ValidationError("invalid 'to' property value"))? }
},
None => vec![], None => vec![],
}; };
let secondary_audience = match object.cc { let secondary_audience = match object.cc {
Some(ref value) => { Some(ref value) => {
parse_array(value) parse_array(value).map_err(|_| ValidationError("invalid 'cc' property value"))?
.map_err(|_| ValidationError("invalid 'cc' property value"))? }
},
None => vec![], None => vec![],
}; };
let audience = [primary_audience, secondary_audience].concat(); let audience = [primary_audience, secondary_audience].concat();
@ -528,22 +472,19 @@ fn get_audience(object: &Object) -> Result<Vec<String>, ValidationError> {
fn is_public_object(audience: &[String]) -> bool { fn is_public_object(audience: &[String]) -> bool {
// Some servers (e.g. Takahe) use "as" namespace // Some servers (e.g. Takahe) use "as" namespace
const PUBLIC_VARIANTS: [&str; 3] = [ const PUBLIC_VARIANTS: [&str; 3] = [AP_PUBLIC, "as:Public", "Public"];
AP_PUBLIC, audience
"as:Public", .iter()
"Public", .any(|item| PUBLIC_VARIANTS.contains(&item.as_str()))
];
audience.iter().any(|item| PUBLIC_VARIANTS.contains(&item.as_str()))
} }
fn get_object_visibility( fn get_object_visibility(author: &DbActorProfile, audience: &[String]) -> Visibility {
author: &DbActorProfile,
audience: &[String],
) -> Visibility {
if is_public_object(audience) { if is_public_object(audience) {
return Visibility::Public; return Visibility::Public;
}; };
let actor = author.actor_json.as_ref() let actor = author
.actor_json
.as_ref()
.expect("actor data should be present"); .expect("actor data should be present");
if let Some(ref followers) = actor.followers { if let Some(ref followers) = actor.followers {
if audience.contains(followers) { if audience.contains(followers) {
@ -569,26 +510,23 @@ pub async fn handle_note(
NOTE => (), NOTE => (),
ARTICLE | EVENT | QUESTION | PAGE | VIDEO => { ARTICLE | EVENT | QUESTION | PAGE | VIDEO => {
log::info!("processing object of type {}", object.object_type); log::info!("processing object of type {}", object.object_type);
}, }
other_type => { other_type => {
log::warn!("discarding object of type {}", other_type); log::warn!("discarding object of type {}", other_type);
return Err(ValidationError("unsupported object type").into()); return Err(ValidationError("unsupported object type").into());
}, }
}; };
if object.id.len() > OBJECT_ID_SIZE_MAX { if object.id.len() > OBJECT_ID_SIZE_MAX {
return Err(ValidationError("object ID is too long").into()); return Err(ValidationError("object ID is too long").into());
}; };
let author_id = get_object_attributed_to(&object)?; let author_id = get_object_attributed_to(&object)?;
let author = get_or_import_profile_by_actor_id( let author = get_or_import_profile_by_actor_id(db_client, instance, storage, &author_id)
db_client, .await
instance, .map_err(|err| {
storage, log::warn!("failed to import {} ({})", author_id, err);
&author_id, err
).await.map_err(|err| { })?;
log::warn!("failed to import {} ({})", author_id, err);
err
})?;
let mut content = get_object_content(&object)?; let mut content = get_object_content(&object)?;
if object.object_type != NOTE { if object.object_type != NOTE {
@ -596,38 +534,24 @@ pub async fn handle_note(
let object_url = get_object_url(&object)?; let object_url = get_object_url(&object)?;
content += &create_content_link(object_url); content += &create_content_link(object_url);
}; };
let (attachments, unprocessed) = get_object_attachments( let (attachments, unprocessed) =
db_client, get_object_attachments(db_client, instance, storage, &object, &author).await?;
instance,
storage,
&object,
&author,
).await?;
for attachment_url in unprocessed { for attachment_url in unprocessed {
content += &create_content_link(attachment_url); content += &create_content_link(attachment_url);
}; }
if content.is_empty() && attachments.is_empty() { if content.is_empty() && attachments.is_empty() {
return Err(ValidationError("post is empty").into()); return Err(ValidationError("post is empty").into());
}; };
let (mentions, hashtags, links, emojis) = get_object_tags( let (mentions, hashtags, links, emojis) =
db_client, get_object_tags(db_client, instance, storage, &object, redirects).await?;
instance,
storage,
&object,
redirects,
).await?;
let in_reply_to_id = match object.in_reply_to { let in_reply_to_id = match object.in_reply_to {
Some(ref object_id) => { Some(ref object_id) => {
let object_id = redirects.get(object_id).unwrap_or(object_id); let object_id = redirects.get(object_id).unwrap_or(object_id);
let in_reply_to = get_post_by_object_id( let in_reply_to = get_post_by_object_id(db_client, &instance.url(), object_id).await?;
db_client,
&instance.url(),
object_id,
).await?;
Some(in_reply_to.id) Some(in_reply_to.id)
}, }
None => None, None => None,
}; };
let audience = get_audience(&object)?; let audience = get_audience(&object)?;
@ -665,17 +589,15 @@ pub async fn is_unsolicited_message(
object: &Object, object: &Object,
) -> Result<bool, HandlerError> { ) -> Result<bool, HandlerError> {
let author_id = get_object_attributed_to(object)?; let author_id = get_object_attributed_to(object)?;
let author_has_followers = let author_has_followers = has_local_followers(db_client, &author_id).await?;
has_local_followers(db_client, &author_id).await?;
let audience = get_audience(object)?; let audience = get_audience(object)?;
let has_local_recipients = audience.iter().any(|actor_id| { let has_local_recipients = audience
parse_local_actor_id(instance_url, actor_id).is_ok() .iter()
}); .any(|actor_id| parse_local_actor_id(instance_url, actor_id).is_ok());
let result = let result = object.in_reply_to.is_none()
object.in_reply_to.is_none() && && is_public_object(&audience)
is_public_object(&audience) && && !has_local_recipients
!has_local_recipients && && !author_has_followers;
!author_has_followers;
Ok(result) Ok(result)
} }
@ -691,8 +613,8 @@ pub async fn handle_create(
activity: JsonValue, activity: JsonValue,
mut is_authenticated: bool, mut is_authenticated: bool,
) -> HandlerResult { ) -> HandlerResult {
let activity: CreateNote = serde_json::from_value(activity) let activity: CreateNote =
.map_err(|_| ValidationError("invalid object"))?; serde_json::from_value(activity).map_err(|_| ValidationError("invalid object"))?;
let object = activity.object; let object = activity.object;
// Verify attribution // Verify attribution
@ -716,23 +638,21 @@ pub async fn handle_create(
&MediaStorage::from(config), &MediaStorage::from(config),
object_id, object_id,
object_received, object_received,
).await?; )
.await?;
Ok(Some(NOTE)) Ok(Some(NOTE))
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use serde_json::json;
use mitra_models::profiles::types::DbActor;
use crate::activitypub::{
types::Object,
vocabulary::NOTE,
};
use super::*; use super::*;
use crate::activitypub::{types::Object, vocabulary::NOTE};
use mitra_models::profiles::types::DbActor;
use serde_json::json;
#[test] #[test]
fn test_get_object_attributed_to() { fn test_get_object_attributed_to() {
let object = Object { let object = Object {
object_type: NOTE.to_string(), object_type: NOTE.to_string(),
attributed_to: Some(json!(["https://example.org/1"])), attributed_to: Some(json!(["https://example.org/1"])),
..Default::default() ..Default::default()

View file

@ -4,14 +4,8 @@ use serde_json::Value;
use mitra_config::Config; use mitra_config::Config;
use mitra_models::{ use mitra_models::{
database::{DatabaseClient, DatabaseError}, database::{DatabaseClient, DatabaseError},
posts::queries::{ posts::queries::{delete_post, get_post_by_remote_object_id},
delete_post, profiles::queries::{delete_profile, get_profile_by_remote_actor_id},
get_post_by_remote_object_id,
},
profiles::queries::{
delete_profile,
get_profile_by_remote_actor_id,
},
}; };
use crate::activitypub::{ use crate::activitypub::{
@ -39,10 +33,7 @@ pub async fn handle_delete(
.map_err(|_| ValidationError("unexpected activity structure"))?; .map_err(|_| ValidationError("unexpected activity structure"))?;
if activity.object == activity.actor { if activity.object == activity.actor {
// Self-delete // Self-delete
let profile = match get_profile_by_remote_actor_id( let profile = match get_profile_by_remote_actor_id(db_client, &activity.object).await {
db_client,
&activity.object,
).await {
Ok(profile) => profile, Ok(profile) => profile,
// Ignore Delete(Person) if profile is not found // Ignore Delete(Person) if profile is not found
Err(DatabaseError::NotFound(_)) => return Ok(None), Err(DatabaseError::NotFound(_)) => return Ok(None),
@ -56,19 +47,13 @@ pub async fn handle_delete(
log::info!("deleted profile {}", profile.acct); log::info!("deleted profile {}", profile.acct);
return Ok(Some(PERSON)); return Ok(Some(PERSON));
}; };
let post = match get_post_by_remote_object_id( let post = match get_post_by_remote_object_id(db_client, &activity.object).await {
db_client,
&activity.object,
).await {
Ok(post) => post, Ok(post) => post,
// Ignore Delete(Note) if post is not found // Ignore Delete(Note) if post is not found
Err(DatabaseError::NotFound(_)) => return Ok(None), Err(DatabaseError::NotFound(_)) => return Ok(None),
Err(other_error) => return Err(other_error.into()), Err(other_error) => return Err(other_error.into()),
}; };
let actor_profile = get_profile_by_remote_actor_id( let actor_profile = get_profile_by_remote_actor_id(db_client, &activity.actor).await?;
db_client,
&activity.actor,
).await?;
if post.author.id != actor_profile.id { if post.author.id != actor_profile.id {
return Err(ValidationError("actor is not an author").into()); return Err(ValidationError("actor is not an author").into());
}; };

View file

@ -4,23 +4,18 @@ use serde_json::Value;
use mitra_config::Config; use mitra_config::Config;
use mitra_models::{ use mitra_models::{
database::{DatabaseClient, DatabaseError}, database::{DatabaseClient, DatabaseError},
relationships::queries::{ relationships::queries::{create_remote_follow_request_opt, follow_request_accepted},
create_remote_follow_request_opt,
follow_request_accepted,
},
users::queries::get_user_by_name, users::queries::get_user_by_name,
}; };
use super::{HandlerError, HandlerResult};
use crate::activitypub::{ use crate::activitypub::{
builders::accept_follow::prepare_accept_follow, builders::accept_follow::prepare_accept_follow,
fetcher::helpers::get_or_import_profile_by_actor_id, fetcher::helpers::get_or_import_profile_by_actor_id, identifiers::parse_local_actor_id,
identifiers::parse_local_actor_id, receiver::deserialize_into_object_id, vocabulary::PERSON,
receiver::deserialize_into_object_id,
vocabulary::PERSON,
}; };
use crate::errors::ValidationError; use crate::errors::ValidationError;
use crate::media::MediaStorage; use crate::media::MediaStorage;
use super::{HandlerError, HandlerResult};
#[derive(Deserialize)] #[derive(Deserialize)]
struct Follow { struct Follow {
@ -43,20 +38,18 @@ pub async fn handle_follow(
&config.instance(), &config.instance(),
&MediaStorage::from(config), &MediaStorage::from(config),
&activity.actor, &activity.actor,
).await?; )
let source_actor = source_profile.actor_json .await?;
.ok_or(HandlerError::LocalObject)?; let source_actor = source_profile.actor_json.ok_or(HandlerError::LocalObject)?;
let target_username = parse_local_actor_id( let target_username = parse_local_actor_id(&config.instance_url(), &activity.object)?;
&config.instance_url(),
&activity.object,
)?;
let target_user = get_user_by_name(db_client, &target_username).await?; let target_user = get_user_by_name(db_client, &target_username).await?;
let follow_request = create_remote_follow_request_opt( let follow_request = create_remote_follow_request_opt(
db_client, db_client,
&source_profile.id, &source_profile.id,
&target_user.id, &target_user.id,
&activity.id, &activity.id,
).await?; )
.await?;
match follow_request_accepted(db_client, &follow_request.id).await { match follow_request_accepted(db_client, &follow_request.id).await {
Ok(_) => (), Ok(_) => (),
// Proceed even if relationship already exists // Proceed even if relationship already exists
@ -70,7 +63,9 @@ pub async fn handle_follow(
&target_user, &target_user,
&source_actor, &source_actor,
&activity.id, &activity.id,
).enqueue(db_client).await?; )
.enqueue(db_client)
.await?;
Ok(Some(PERSON)) Ok(Some(PERSON))
} }

View file

@ -8,10 +8,7 @@ use mitra_models::{
}; };
use crate::activitypub::{ use crate::activitypub::{
fetcher::helpers::{ fetcher::helpers::{get_or_import_profile_by_actor_id, get_post_by_object_id},
get_or_import_profile_by_actor_id,
get_post_by_object_id,
},
receiver::deserialize_into_object_id, receiver::deserialize_into_object_id,
vocabulary::NOTE, vocabulary::NOTE,
}; };
@ -40,23 +37,16 @@ pub async fn handle_like(
&config.instance(), &config.instance(),
&MediaStorage::from(config), &MediaStorage::from(config),
&activity.actor, &activity.actor,
).await?; )
let post_id = match get_post_by_object_id( .await?;
db_client, let post_id =
&config.instance_url(), match get_post_by_object_id(db_client, &config.instance_url(), &activity.object).await {
&activity.object, Ok(post) => post.id,
).await { // Ignore like if post is not found locally
Ok(post) => post.id, Err(DatabaseError::NotFound(_)) => return Ok(None),
// Ignore like if post is not found locally Err(other_error) => return Err(other_error.into()),
Err(DatabaseError::NotFound(_)) => return Ok(None), };
Err(other_error) => return Err(other_error.into()), match create_reaction(db_client, &author.id, &post_id, Some(&activity.id)).await {
};
match create_reaction(
db_client,
&author.id,
&post_id,
Some(&activity.id),
).await {
Ok(_) => (), Ok(_) => (),
// Ignore activity if reaction is already saved // Ignore activity if reaction is already saved
Err(DatabaseError::AlreadyExists(_)) => return Ok(None), Err(DatabaseError::AlreadyExists(_)) => return Ok(None),

View file

@ -6,18 +6,12 @@ use mitra_models::{
database::DatabaseClient, database::DatabaseClient,
notifications::queries::create_move_notification, notifications::queries::create_move_notification,
profiles::helpers::find_verified_aliases, profiles::helpers::find_verified_aliases,
relationships::queries::{ relationships::queries::{get_followers, unfollow},
get_followers,
unfollow,
},
users::queries::{get_user_by_id, get_user_by_name}, users::queries::{get_user_by_id, get_user_by_name},
}; };
use crate::activitypub::{ use crate::activitypub::{
builders::{ builders::{follow::follow_or_create_request, undo_follow::prepare_undo_follow},
follow::follow_or_create_request,
undo_follow::prepare_undo_follow,
},
fetcher::helpers::get_or_import_profile_by_actor_id, fetcher::helpers::get_or_import_profile_by_actor_id,
identifiers::{parse_local_actor_id, profile_actor_id}, identifiers::{parse_local_actor_id, profile_actor_id},
vocabulary::PERSON, vocabulary::PERSON,
@ -50,39 +44,26 @@ pub async fn handle_move(
let instance = config.instance(); let instance = config.instance();
let storage = MediaStorage::from(config); let storage = MediaStorage::from(config);
let old_profile = if let Ok(username) = parse_local_actor_id( let old_profile = if let Ok(username) = parse_local_actor_id(&instance.url(), &activity.object)
&instance.url(), {
&activity.object,
) {
let old_user = get_user_by_name(db_client, &username).await?; let old_user = get_user_by_name(db_client, &username).await?;
old_user.profile old_user.profile
} else { } else {
get_or_import_profile_by_actor_id( get_or_import_profile_by_actor_id(db_client, &instance, &storage, &activity.object).await?
db_client,
&instance,
&storage,
&activity.object,
).await?
}; };
let old_actor_id = profile_actor_id(&instance.url(), &old_profile); let old_actor_id = profile_actor_id(&instance.url(), &old_profile);
let new_profile = if let Ok(username) = parse_local_actor_id( let new_profile = if let Ok(username) = parse_local_actor_id(&instance.url(), &activity.target)
&instance.url(), {
&activity.target,
) {
let new_user = get_user_by_name(db_client, &username).await?; let new_user = get_user_by_name(db_client, &username).await?;
new_user.profile new_user.profile
} else { } else {
get_or_import_profile_by_actor_id( get_or_import_profile_by_actor_id(db_client, &instance, &storage, &activity.target).await?
db_client,
&instance,
&storage,
&activity.target,
).await?
}; };
// Find aliases by DIDs (verified) // Find aliases by DIDs (verified)
let mut aliases = find_verified_aliases(db_client, &new_profile).await? let mut aliases = find_verified_aliases(db_client, &new_profile)
.await?
.into_iter() .into_iter()
.map(|profile| profile_actor_id(&instance.url(), &profile)) .map(|profile| profile_actor_id(&instance.url(), &profile))
.collect::<Vec<_>>(); .collect::<Vec<_>>();
@ -96,39 +77,22 @@ pub async fn handle_move(
for follower in followers { for follower in followers {
let follower = get_user_by_id(db_client, &follower.id).await?; let follower = get_user_by_id(db_client, &follower.id).await?;
// Unfollow old profile // Unfollow old profile
let maybe_follow_request_id = unfollow( let maybe_follow_request_id = unfollow(db_client, &follower.id, &old_profile.id).await?;
db_client,
&follower.id,
&old_profile.id,
).await?;
// Send Undo(Follow) if old actor is not local // Send Undo(Follow) if old actor is not local
if let Some(ref old_actor) = old_profile.actor_json { if let Some(ref old_actor) = old_profile.actor_json {
let follow_request_id = maybe_follow_request_id let follow_request_id = maybe_follow_request_id.expect("follow request must exist");
.expect("follow request must exist"); prepare_undo_follow(&instance, &follower, old_actor, &follow_request_id)
prepare_undo_follow( .enqueue(db_client)
&instance, .await?;
&follower,
old_actor,
&follow_request_id,
).enqueue(db_client).await?;
}; };
if follower.id == new_profile.id { if follower.id == new_profile.id {
// Don't self-follow // Don't self-follow
continue; continue;
}; };
// Follow new profile // Follow new profile
follow_or_create_request( follow_or_create_request(db_client, &instance, &follower, &new_profile).await?;
db_client, create_move_notification(db_client, &new_profile.id, &follower.id).await?;
&instance, }
&follower,
&new_profile,
).await?;
create_move_notification(
db_client,
&new_profile.id,
&follower.id,
).await?;
};
Ok(Some(PERSON)) Ok(Some(PERSON))
} }

View file

@ -5,17 +5,12 @@ use mitra_config::Config;
use mitra_models::{ use mitra_models::{
database::DatabaseClient, database::DatabaseClient,
profiles::queries::get_profile_by_remote_actor_id, profiles::queries::get_profile_by_remote_actor_id,
relationships::queries::{ relationships::queries::{follow_request_rejected, get_follow_request_by_id},
follow_request_rejected,
get_follow_request_by_id,
},
relationships::types::FollowRequestStatus, relationships::types::FollowRequestStatus,
}; };
use crate::activitypub::{ use crate::activitypub::{
identifiers::parse_local_object_id, identifiers::parse_local_object_id, receiver::deserialize_into_object_id, vocabulary::FOLLOW,
receiver::deserialize_into_object_id,
vocabulary::FOLLOW,
}; };
use crate::errors::ValidationError; use crate::errors::ValidationError;
@ -36,14 +31,8 @@ pub async fn handle_reject(
// Reject(Follow) // Reject(Follow)
let activity: Reject = serde_json::from_value(activity) let activity: Reject = serde_json::from_value(activity)
.map_err(|_| ValidationError("unexpected activity structure"))?; .map_err(|_| ValidationError("unexpected activity structure"))?;
let actor_profile = get_profile_by_remote_actor_id( let actor_profile = get_profile_by_remote_actor_id(db_client, &activity.actor).await?;
db_client, let follow_request_id = parse_local_object_id(&config.instance_url(), &activity.object)?;
&activity.actor,
).await?;
let follow_request_id = parse_local_object_id(
&config.instance_url(),
&activity.object,
)?;
let follow_request = get_follow_request_by_id(db_client, &follow_request_id).await?; let follow_request = get_follow_request_by_id(db_client, &follow_request_id).await?;
if follow_request.target_id != actor_profile.id { if follow_request.target_id != actor_profile.id {
return Err(ValidationError("actor is not a target").into()); return Err(ValidationError("actor is not a target").into());

View file

@ -4,18 +4,13 @@ use serde_json::Value;
use mitra_config::Config; use mitra_config::Config;
use mitra_models::{ use mitra_models::{
database::{DatabaseClient, DatabaseError}, database::{DatabaseClient, DatabaseError},
notifications::queries::{ notifications::queries::create_subscription_expiration_notification,
create_subscription_expiration_notification,
},
profiles::queries::get_profile_by_remote_actor_id, profiles::queries::get_profile_by_remote_actor_id,
relationships::queries::unsubscribe, relationships::queries::unsubscribe,
users::queries::get_user_by_name, users::queries::get_user_by_name,
}; };
use crate::activitypub::{ use crate::activitypub::{identifiers::parse_local_actor_id, vocabulary::PERSON};
identifiers::parse_local_actor_id,
vocabulary::PERSON,
};
use crate::errors::ValidationError; use crate::errors::ValidationError;
use super::{HandlerError, HandlerResult}; use super::{HandlerError, HandlerResult};
@ -34,28 +29,19 @@ pub async fn handle_remove(
) -> HandlerResult { ) -> HandlerResult {
let activity: Remove = serde_json::from_value(activity) let activity: Remove = serde_json::from_value(activity)
.map_err(|_| ValidationError("unexpected activity structure"))?; .map_err(|_| ValidationError("unexpected activity structure"))?;
let actor_profile = get_profile_by_remote_actor_id( let actor_profile = get_profile_by_remote_actor_id(db_client, &activity.actor).await?;
db_client,
&activity.actor,
).await?;
let actor = actor_profile.actor_json.ok_or(HandlerError::LocalObject)?; let actor = actor_profile.actor_json.ok_or(HandlerError::LocalObject)?;
if Some(activity.target) == actor.subscribers { if Some(activity.target) == actor.subscribers {
// Removing from subscribers // Removing from subscribers
let username = parse_local_actor_id( let username = parse_local_actor_id(&config.instance_url(), &activity.object)?;
&config.instance_url(),
&activity.object,
)?;
let user = get_user_by_name(db_client, &username).await?; let user = get_user_by_name(db_client, &username).await?;
// actor is recipient, user is sender // actor is recipient, user is sender
match unsubscribe(db_client, &user.id, &actor_profile.id).await { match unsubscribe(db_client, &user.id, &actor_profile.id).await {
Ok(_) => { Ok(_) => {
create_subscription_expiration_notification( create_subscription_expiration_notification(db_client, &actor_profile.id, &user.id)
db_client, .await?;
&actor_profile.id,
&user.id,
).await?;
return Ok(Some(PERSON)); return Ok(Some(PERSON));
}, }
// Ignore removal if relationship does not exist // Ignore removal if relationship does not exist
Err(DatabaseError::NotFound(_)) => return Ok(None), Err(DatabaseError::NotFound(_)) => return Ok(None),
Err(other_error) => return Err(other_error.into()), Err(other_error) => return Err(other_error.into()),

View file

@ -4,22 +4,10 @@ use serde_json::Value;
use mitra_config::Config; use mitra_config::Config;
use mitra_models::{ use mitra_models::{
database::{DatabaseClient, DatabaseError}, database::{DatabaseClient, DatabaseError},
posts::queries::{ posts::queries::{delete_post, get_post_by_remote_object_id},
delete_post, profiles::queries::{get_profile_by_acct, get_profile_by_remote_actor_id},
get_post_by_remote_object_id, reactions::queries::{delete_reaction, get_reaction_by_remote_activity_id},
}, relationships::queries::{get_follow_request_by_activity_id, unfollow},
profiles::queries::{
get_profile_by_acct,
get_profile_by_remote_actor_id,
},
reactions::queries::{
delete_reaction,
get_reaction_by_remote_activity_id,
},
relationships::queries::{
get_follow_request_by_activity_id,
unfollow,
},
}; };
use crate::activitypub::{ use crate::activitypub::{
@ -44,15 +32,9 @@ async fn handle_undo_follow(
) -> HandlerResult { ) -> HandlerResult {
let activity: UndoFollow = serde_json::from_value(activity) let activity: UndoFollow = serde_json::from_value(activity)
.map_err(|_| ValidationError("unexpected activity structure"))?; .map_err(|_| ValidationError("unexpected activity structure"))?;
let source_profile = get_profile_by_remote_actor_id( let source_profile = get_profile_by_remote_actor_id(db_client, &activity.actor).await?;
db_client,
&activity.actor,
).await?;
let target_actor_id = find_object_id(&activity.object["object"])?; let target_actor_id = find_object_id(&activity.object["object"])?;
let target_username = parse_local_actor_id( let target_username = parse_local_actor_id(&config.instance_url(), &target_actor_id)?;
&config.instance_url(),
&target_actor_id,
)?;
// acct equals username if profile is local // acct equals username if profile is local
let target_profile = get_profile_by_acct(db_client, &target_username).await?; let target_profile = get_profile_by_acct(db_client, &target_username).await?;
match unfollow(db_client, &source_profile.id, &target_profile.id).await { match unfollow(db_client, &source_profile.id, &target_profile.id).await {
@ -83,10 +65,7 @@ pub async fn handle_undo(
let activity: Undo = serde_json::from_value(activity) let activity: Undo = serde_json::from_value(activity)
.map_err(|_| ValidationError("unexpected activity structure"))?; .map_err(|_| ValidationError("unexpected activity structure"))?;
let actor_profile = get_profile_by_remote_actor_id( let actor_profile = get_profile_by_remote_actor_id(db_client, &activity.actor).await?;
db_client,
&activity.actor,
).await?;
match get_follow_request_by_activity_id(db_client, &activity.object).await { match get_follow_request_by_activity_id(db_client, &activity.object).await {
Ok(follow_request) => { Ok(follow_request) => {
@ -98,9 +77,10 @@ pub async fn handle_undo(
db_client, db_client,
&follow_request.source_id, &follow_request.source_id,
&follow_request.target_id, &follow_request.target_id,
).await?; )
.await?;
return Ok(Some(FOLLOW)); return Ok(Some(FOLLOW));
}, }
Err(DatabaseError::NotFound(_)) => (), // try other object types Err(DatabaseError::NotFound(_)) => (), // try other object types
Err(other_error) => return Err(other_error.into()), Err(other_error) => return Err(other_error.into()),
}; };
@ -111,19 +91,12 @@ pub async fn handle_undo(
if reaction.author_id != actor_profile.id { if reaction.author_id != actor_profile.id {
return Err(ValidationError("actor is not an author").into()); return Err(ValidationError("actor is not an author").into());
}; };
delete_reaction( delete_reaction(db_client, &reaction.author_id, &reaction.post_id).await?;
db_client,
&reaction.author_id,
&reaction.post_id,
).await?;
Ok(Some(LIKE)) Ok(Some(LIKE))
}, }
Err(DatabaseError::NotFound(_)) => { Err(DatabaseError::NotFound(_)) => {
// Undo(Announce) // Undo(Announce)
let post = match get_post_by_remote_object_id( let post = match get_post_by_remote_object_id(db_client, &activity.object).await {
db_client,
&activity.object,
).await {
Ok(post) => post, Ok(post) => post,
// Ignore undo if neither reaction nor repost is found // Ignore undo if neither reaction nor repost is found
Err(DatabaseError::NotFound(_)) => return Ok(None), Err(DatabaseError::NotFound(_)) => return Ok(None),
@ -139,7 +112,7 @@ pub async fn handle_undo(
None => return Err(ValidationError("object is not a repost").into()), None => return Err(ValidationError("object is not a repost").into()),
}; };
Ok(Some(ANNOUNCE)) Ok(Some(ANNOUNCE))
}, }
Err(other_error) => Err(other_error.into()), Err(other_error) => Err(other_error.into()),
} }
} }

View file

@ -7,24 +7,15 @@ use serde_json::Value;
use mitra_config::Config; use mitra_config::Config;
use mitra_models::{ use mitra_models::{
database::{DatabaseClient, DatabaseError}, database::{DatabaseClient, DatabaseError},
posts::queries::{ posts::queries::{get_post_by_remote_object_id, update_post},
get_post_by_remote_object_id,
update_post,
},
posts::types::PostUpdateData, posts::types::PostUpdateData,
profiles::queries::get_profile_by_remote_actor_id, profiles::queries::get_profile_by_remote_actor_id,
}; };
use crate::activitypub::{ use crate::activitypub::{
actors::{ actors::{helpers::update_remote_profile, types::Actor},
helpers::update_remote_profile,
types::Actor,
},
handlers::create::{ handlers::create::{
create_content_link, create_content_link, get_object_attachments, get_object_content, get_object_tags,
get_object_attachments,
get_object_content,
get_object_tags,
get_object_url, get_object_url,
}, },
identifiers::profile_actor_id, identifiers::profile_actor_id,
@ -47,13 +38,10 @@ async fn handle_update_note(
db_client: &mut impl DatabaseClient, db_client: &mut impl DatabaseClient,
activity: Value, activity: Value,
) -> HandlerResult { ) -> HandlerResult {
let activity: UpdateNote = serde_json::from_value(activity) let activity: UpdateNote =
.map_err(|_| ValidationError("invalid object"))?; serde_json::from_value(activity).map_err(|_| ValidationError("invalid object"))?;
let object = activity.object; let object = activity.object;
let post = match get_post_by_remote_object_id( let post = match get_post_by_remote_object_id(db_client, &object.id).await {
db_client,
&object.id,
).await {
Ok(post) => post, Ok(post) => post,
// Ignore Update if post is not found locally // Ignore Update if post is not found locally
Err(DatabaseError::NotFound(_)) => return Ok(None), Err(DatabaseError::NotFound(_)) => return Ok(None),
@ -71,26 +59,16 @@ async fn handle_update_note(
}; };
let is_sensitive = object.sensitive.unwrap_or(false); let is_sensitive = object.sensitive.unwrap_or(false);
let storage = MediaStorage::from(config); let storage = MediaStorage::from(config);
let (attachments, unprocessed) = get_object_attachments( let (attachments, unprocessed) =
db_client, get_object_attachments(db_client, &instance, &storage, &object, &post.author).await?;
&instance,
&storage,
&object,
&post.author,
).await?;
for attachment_url in unprocessed { for attachment_url in unprocessed {
content += &create_content_link(attachment_url); content += &create_content_link(attachment_url);
}; }
if content.is_empty() && attachments.is_empty() { if content.is_empty() && attachments.is_empty() {
return Err(ValidationError("post is empty").into()); return Err(ValidationError("post is empty").into());
}; };
let (mentions, hashtags, links, emojis) = get_object_tags( let (mentions, hashtags, links, emojis) =
db_client, get_object_tags(db_client, &instance, &storage, &object, &HashMap::new()).await?;
&instance,
&storage,
&object,
&HashMap::new(),
).await?;
let updated_at = object.updated.unwrap_or(Utc::now()); let updated_at = object.updated.unwrap_or(Utc::now());
let post_data = PostUpdateData { let post_data = PostUpdateData {
content, content,
@ -117,22 +95,20 @@ async fn handle_update_person(
db_client: &mut impl DatabaseClient, db_client: &mut impl DatabaseClient,
activity: Value, activity: Value,
) -> HandlerResult { ) -> HandlerResult {
let activity: UpdatePerson = serde_json::from_value(activity) let activity: UpdatePerson =
.map_err(|_| ValidationError("invalid actor data"))?; serde_json::from_value(activity).map_err(|_| ValidationError("invalid actor data"))?;
if activity.object.id != activity.actor { if activity.object.id != activity.actor {
return Err(ValidationError("actor ID mismatch").into()); return Err(ValidationError("actor ID mismatch").into());
}; };
let profile = get_profile_by_remote_actor_id( let profile = get_profile_by_remote_actor_id(db_client, &activity.object.id).await?;
db_client,
&activity.object.id,
).await?;
update_remote_profile( update_remote_profile(
db_client, db_client,
&config.instance(), &config.instance(),
&MediaStorage::from(config), &MediaStorage::from(config),
profile, profile,
activity.object, activity.object,
).await?; )
.await?;
Ok(Some(PERSON)) Ok(Some(PERSON))
} }
@ -141,18 +117,15 @@ pub async fn handle_update(
db_client: &mut impl DatabaseClient, db_client: &mut impl DatabaseClient,
activity: Value, activity: Value,
) -> HandlerResult { ) -> HandlerResult {
let object_type = activity["object"]["type"].as_str() let object_type = activity["object"]["type"]
.as_str()
.ok_or(ValidationError("unknown object type"))?; .ok_or(ValidationError("unknown object type"))?;
match object_type { match object_type {
NOTE => { NOTE => handle_update_note(config, db_client, activity).await,
handle_update_note(config, db_client, activity).await PERSON => handle_update_person(config, db_client, activity).await,
},
PERSON => {
handle_update_person(config, db_client, activity).await
},
_ => { _ => {
log::warn!("unexpected object type {}", object_type); log::warn!("unexpected object type {}", object_type);
Ok(None) Ok(None)
}, }
} }
} }

View file

@ -14,9 +14,7 @@ pub enum Network {
I2p, I2p,
} }
pub fn get_network_type(request_url: &str) -> pub fn get_network_type(request_url: &str) -> Result<Network, url::ParseError> {
Result<Network, url::ParseError>
{
let hostname = get_hostname(request_url)?; let hostname = get_hostname(request_url)?;
let network = if hostname.ends_with(".onion") { let network = if hostname.ends_with(".onion") {
Network::Tor Network::Tor
@ -38,23 +36,18 @@ pub fn build_federation_client(
match network { match network {
Network::Default => (), Network::Default => (),
Network::Tor => { Network::Tor => {
maybe_proxy_url = instance.onion_proxy_url.as_ref() maybe_proxy_url = instance.onion_proxy_url.as_ref().or(maybe_proxy_url);
.or(maybe_proxy_url); }
},
Network::I2p => { Network::I2p => {
maybe_proxy_url = instance.i2p_proxy_url.as_ref() maybe_proxy_url = instance.i2p_proxy_url.as_ref().or(maybe_proxy_url);
.or(maybe_proxy_url); }
},
}; };
if let Some(proxy_url) = maybe_proxy_url { if let Some(proxy_url) = maybe_proxy_url {
let proxy = Proxy::all(proxy_url)?; let proxy = Proxy::all(proxy_url)?;
client_builder = client_builder.proxy(proxy); client_builder = client_builder.proxy(proxy);
}; };
let request_timeout = Duration::from_secs(timeout); let request_timeout = Duration::from_secs(timeout);
let connect_timeout = Duration::from_secs(max( let connect_timeout = Duration::from_secs(max(timeout, CONNECTION_TIMEOUT));
timeout,
CONNECTION_TIMEOUT,
));
client_builder client_builder
.timeout(request_timeout) .timeout(request_timeout)
.connect_timeout(connect_timeout) .connect_timeout(connect_timeout)

View file

@ -1,10 +1,7 @@
use regex::Regex; use regex::Regex;
use uuid::Uuid; use uuid::Uuid;
use mitra_models::{ use mitra_models::{posts::types::Post, profiles::types::DbActorProfile};
posts::types::Post,
profiles::types::DbActorProfile,
};
use mitra_utils::urls::get_hostname; use mitra_utils::urls::get_hostname;
use crate::errors::ValidationError; use crate::errors::ValidationError;
@ -83,45 +80,41 @@ pub fn local_tag_collection(instance_url: &str, tag_name: &str) -> String {
} }
pub fn validate_object_id(object_id: &str) -> Result<(), ValidationError> { pub fn validate_object_id(object_id: &str) -> Result<(), ValidationError> {
get_hostname(object_id) get_hostname(object_id).map_err(|_| ValidationError("invalid object ID"))?;
.map_err(|_| ValidationError("invalid object ID"))?;
Ok(()) Ok(())
} }
pub fn parse_local_actor_id( pub fn parse_local_actor_id(instance_url: &str, actor_id: &str) -> Result<String, ValidationError> {
instance_url: &str,
actor_id: &str,
) -> Result<String, ValidationError> {
let url_regexp_str = format!( let url_regexp_str = format!(
"^{}/users/(?P<username>[0-9a-z_]+)$", "^{}/users/(?P<username>[0-9a-z_]+)$",
instance_url.replace('.', r"\."), instance_url.replace('.', r"\."),
); );
let url_regexp = Regex::new(&url_regexp_str) let url_regexp = Regex::new(&url_regexp_str).map_err(|_| ValidationError("error"))?;
.map_err(|_| ValidationError("error"))?; let url_caps = url_regexp
let url_caps = url_regexp.captures(actor_id) .captures(actor_id)
.ok_or(ValidationError("invalid actor ID"))?; .ok_or(ValidationError("invalid actor ID"))?;
let username = url_caps.name("username") let username = url_caps
.name("username")
.ok_or(ValidationError("invalid actor ID"))? .ok_or(ValidationError("invalid actor ID"))?
.as_str() .as_str()
.to_owned(); .to_owned();
Ok(username) Ok(username)
} }
pub fn parse_local_object_id( pub fn parse_local_object_id(instance_url: &str, object_id: &str) -> Result<Uuid, ValidationError> {
instance_url: &str,
object_id: &str,
) -> Result<Uuid, ValidationError> {
let url_regexp_str = format!( let url_regexp_str = format!(
"^{}/objects/(?P<uuid>[0-9a-f-]+)$", "^{}/objects/(?P<uuid>[0-9a-f-]+)$",
instance_url.replace('.', r"\."), instance_url.replace('.', r"\."),
); );
let url_regexp = Regex::new(&url_regexp_str) let url_regexp = Regex::new(&url_regexp_str).map_err(|_| ValidationError("error"))?;
.map_err(|_| ValidationError("error"))?; let url_caps = url_regexp
let url_caps = url_regexp.captures(object_id) .captures(object_id)
.ok_or(ValidationError("invalid object ID"))?; .ok_or(ValidationError("invalid object ID"))?;
let internal_object_id: Uuid = url_caps.name("uuid") let internal_object_id: Uuid = url_caps
.name("uuid")
.ok_or(ValidationError("invalid object ID"))? .ok_or(ValidationError("invalid object ID"))?
.as_str().parse() .as_str()
.parse()
.map_err(|_| ValidationError("invalid object ID"))?; .map_err(|_| ValidationError("invalid object ID"))?;
Ok(internal_object_id) Ok(internal_object_id)
} }
@ -151,68 +144,51 @@ pub fn profile_actor_url(instance_url: &str, profile: &DbActorProfile) -> String
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use mitra_utils::id::generate_ulid;
use super::*; use super::*;
use mitra_utils::id::generate_ulid;
const INSTANCE_URL: &str = "https://example.org"; const INSTANCE_URL: &str = "https://example.org";
#[test] #[test]
fn test_parse_local_actor_id() { fn test_parse_local_actor_id() {
let username = parse_local_actor_id( let username =
INSTANCE_URL, parse_local_actor_id(INSTANCE_URL, "https://example.org/users/test").unwrap();
"https://example.org/users/test",
).unwrap();
assert_eq!(username, "test".to_string()); assert_eq!(username, "test".to_string());
} }
#[test] #[test]
fn test_parse_local_actor_id_wrong_path() { fn test_parse_local_actor_id_wrong_path() {
let error = parse_local_actor_id( let error =
INSTANCE_URL, parse_local_actor_id(INSTANCE_URL, "https://example.org/user/test").unwrap_err();
"https://example.org/user/test",
).unwrap_err();
assert_eq!(error.to_string(), "invalid actor ID"); assert_eq!(error.to_string(), "invalid actor ID");
} }
#[test] #[test]
fn test_parse_local_actor_id_invalid_username() { fn test_parse_local_actor_id_invalid_username() {
let error = parse_local_actor_id( let error =
INSTANCE_URL, parse_local_actor_id(INSTANCE_URL, "https://example.org/users/tes-t").unwrap_err();
"https://example.org/users/tes-t",
).unwrap_err();
assert_eq!(error.to_string(), "invalid actor ID"); assert_eq!(error.to_string(), "invalid actor ID");
} }
#[test] #[test]
fn test_parse_local_actor_id_invalid_instance_url() { fn test_parse_local_actor_id_invalid_instance_url() {
let error = parse_local_actor_id( let error =
INSTANCE_URL, parse_local_actor_id(INSTANCE_URL, "https://example.gov/users/test").unwrap_err();
"https://example.gov/users/test",
).unwrap_err();
assert_eq!(error.to_string(), "invalid actor ID"); assert_eq!(error.to_string(), "invalid actor ID");
} }
#[test] #[test]
fn test_parse_local_object_id() { fn test_parse_local_object_id() {
let expected_uuid = generate_ulid(); let expected_uuid = generate_ulid();
let object_id = format!( let object_id = format!("https://example.org/objects/{}", expected_uuid,);
"https://example.org/objects/{}", let internal_object_id = parse_local_object_id(INSTANCE_URL, &object_id).unwrap();
expected_uuid,
);
let internal_object_id = parse_local_object_id(
INSTANCE_URL,
&object_id,
).unwrap();
assert_eq!(internal_object_id, expected_uuid); assert_eq!(internal_object_id, expected_uuid);
} }
#[test] #[test]
fn test_parse_local_object_id_invalid_uuid() { fn test_parse_local_object_id_invalid_uuid() {
let object_id = "https://example.org/objects/1234"; let object_id = "https://example.org/objects/1234";
let error = parse_local_object_id( let error = parse_local_object_id(INSTANCE_URL, object_id).unwrap_err();
INSTANCE_URL,
object_id,
).unwrap_err();
assert_eq!(error.to_string(), "invalid object ID"); assert_eq!(error.to_string(), "invalid object ID");
} }
@ -223,9 +199,6 @@ mod tests {
..Default::default() ..Default::default()
}; };
let profile_url = profile_actor_url(INSTANCE_URL, &profile); let profile_url = profile_actor_url(INSTANCE_URL, &profile);
assert_eq!( assert_eq!(profile_url, "https://example.org/users/test",);
profile_url,
"https://example.org/users/test",
);
} }
} }

View file

@ -5,19 +5,9 @@ use uuid::Uuid;
use mitra_config::Config; use mitra_config::Config;
use mitra_models::{ use mitra_models::{
background_jobs::queries::{ background_jobs::queries::{delete_job_from_queue, enqueue_job, get_job_batch},
enqueue_job,
get_job_batch,
delete_job_from_queue,
},
background_jobs::types::JobType, background_jobs::types::JobType,
database::{ database::{get_database_client, DatabaseClient, DatabaseError, DatabaseTypeError, DbPool},
get_database_client,
DatabaseClient,
DatabaseError,
DatabaseTypeError,
DbPool,
},
profiles::queries::set_reachability_status, profiles::queries::set_reachability_status,
users::queries::get_user_by_id, users::queries::get_user_by_id,
}; };
@ -49,15 +39,15 @@ impl IncomingActivityJobData {
db_client: &impl DatabaseClient, db_client: &impl DatabaseClient,
delay: u32, delay: u32,
) -> Result<(), DatabaseError> { ) -> Result<(), DatabaseError> {
let job_data = serde_json::to_value(self) let job_data = serde_json::to_value(self).expect("activity should be serializable");
.expect("activity should be serializable");
let scheduled_for = Utc::now() + Duration::seconds(delay.into()); let scheduled_for = Utc::now() + Duration::seconds(delay.into());
enqueue_job( enqueue_job(
db_client, db_client,
&JobType::IncomingActivity, &JobType::IncomingActivity,
&job_data, &job_data,
&scheduled_for, &scheduled_for,
).await )
.await
} }
} }
@ -78,11 +68,11 @@ pub async fn process_queued_incoming_activities(
&JobType::IncomingActivity, &JobType::IncomingActivity,
INCOMING_QUEUE_BATCH_SIZE, INCOMING_QUEUE_BATCH_SIZE,
JOB_TIMEOUT, JOB_TIMEOUT,
).await?; )
.await?;
for job in batch { for job in batch {
let mut job_data: IncomingActivityJobData = let mut job_data: IncomingActivityJobData =
serde_json::from_value(job.job_data) serde_json::from_value(job.job_data).map_err(|_| DatabaseTypeError)?;
.map_err(|_| DatabaseTypeError)?;
// See also: activitypub::queues::JOB_TIMEOUT // See also: activitypub::queues::JOB_TIMEOUT
let duration_max = std::time::Duration::from_secs(600); let duration_max = std::time::Duration::from_secs(600);
let handler_future = handle_activity( let handler_future = handle_activity(
@ -91,10 +81,7 @@ pub async fn process_queued_incoming_activities(
&job_data.activity, &job_data.activity,
job_data.is_authenticated, job_data.is_authenticated,
); );
let handler_result = match tokio::time::timeout( let handler_result = match tokio::time::timeout(duration_max, handler_future).await {
duration_max,
handler_future,
).await {
Ok(result) => result, Ok(result) => result,
Err(_) => { Err(_) => {
log::error!( log::error!(
@ -103,12 +90,12 @@ pub async fn process_queued_incoming_activities(
); );
delete_job_from_queue(db_client, &job.id).await?; delete_job_from_queue(db_client, &job.id).await?;
continue; continue;
}, }
}; };
if let Err(error) = handler_result { if let Err(error) = handler_result {
job_data.failure_count += 1; job_data.failure_count += 1;
if let HandlerError::DatabaseError( if let HandlerError::DatabaseError(DatabaseError::DatabaseClientError(ref pg_error)) =
DatabaseError::DatabaseClientError(ref pg_error)) = error error
{ {
log::error!("database client error: {}", pg_error); log::error!("database client error: {}", pg_error);
}; };
@ -129,7 +116,7 @@ pub async fn process_queued_incoming_activities(
}; };
}; };
delete_job_from_queue(db_client, &job.id).await?; delete_job_from_queue(db_client, &job.id).await?;
}; }
Ok(()) Ok(())
} }
@ -147,15 +134,15 @@ impl OutgoingActivityJobData {
db_client: &impl DatabaseClient, db_client: &impl DatabaseClient,
delay: u32, delay: u32,
) -> Result<(), DatabaseError> { ) -> Result<(), DatabaseError> {
let job_data = serde_json::to_value(self) let job_data = serde_json::to_value(self).expect("activity should be serializable");
.expect("activity should be serializable");
let scheduled_for = Utc::now() + Duration::seconds(delay.into()); let scheduled_for = Utc::now() + Duration::seconds(delay.into());
enqueue_job( enqueue_job(
db_client, db_client,
&JobType::OutgoingActivity, &JobType::OutgoingActivity,
&job_data, &job_data,
&scheduled_for, &scheduled_for,
).await )
.await
} }
} }
@ -178,11 +165,11 @@ pub async fn process_queued_outgoing_activities(
&JobType::OutgoingActivity, &JobType::OutgoingActivity,
OUTGOING_QUEUE_BATCH_SIZE, OUTGOING_QUEUE_BATCH_SIZE,
JOB_TIMEOUT, JOB_TIMEOUT,
).await?; )
.await?;
for job in batch { for job in batch {
let mut job_data: OutgoingActivityJobData = let mut job_data: OutgoingActivityJobData =
serde_json::from_value(job.job_data) serde_json::from_value(job.job_data).map_err(|_| DatabaseTypeError)?;
.map_err(|_| DatabaseTypeError)?;
let sender = get_user_by_id(db_client, &job_data.sender_id).await?; let sender = get_user_by_id(db_client, &job_data.sender_id).await?;
let outgoing_activity = OutgoingActivity { let outgoing_activity = OutgoingActivity {
instance: config.instance(), instance: config.instance(),
@ -198,7 +185,7 @@ pub async fn process_queued_outgoing_activities(
log::error!("{}", error); log::error!("{}", error);
delete_job_from_queue(db_client, &job.id).await?; delete_job_from_queue(db_client, &job.id).await?;
return Ok(()); return Ok(());
}, }
}; };
log::info!( log::info!(
"delivery job: {} delivered, {} errors (attempt #{})", "delivery job: {} delivered, {} errors (attempt #{})",
@ -206,8 +193,8 @@ pub async fn process_queued_outgoing_activities(
recipients.iter().filter(|item| !item.is_delivered).count(), recipients.iter().filter(|item| !item.is_delivered).count(),
job_data.failure_count + 1, job_data.failure_count + 1,
); );
if recipients.iter().any(|recipient| !recipient.is_delivered) && if recipients.iter().any(|recipient| !recipient.is_delivered)
job_data.failure_count < OUTGOING_QUEUE_RETRIES_MAX && job_data.failure_count < OUTGOING_QUEUE_RETRIES_MAX
{ {
job_data.failure_count += 1; job_data.failure_count += 1;
// Re-queue if some deliveries are not successful // Re-queue if some deliveries are not successful
@ -219,15 +206,11 @@ pub async fn process_queued_outgoing_activities(
// Update inbox status if all deliveries are successful // Update inbox status if all deliveries are successful
// or if retry limit is reached // or if retry limit is reached
for recipient in recipients { for recipient in recipients {
set_reachability_status( set_reachability_status(db_client, &recipient.id, recipient.is_delivered).await?;
db_client, }
&recipient.id,
recipient.is_delivered,
).await?;
};
}; };
delete_job_from_queue(db_client, &job.id).await?; delete_job_from_queue(db_client, &job.id).await?;
}; }
Ok(()) Ok(())
} }

View file

@ -1,35 +1,17 @@
use actix_web::HttpRequest; use actix_web::HttpRequest;
use serde::{ use serde::{de::DeserializeOwned, de::Error as DeserializerError, Deserialize, Deserializer};
Deserialize,
Deserializer,
de::DeserializeOwned,
de::Error as DeserializerError,
};
use serde_json::Value; use serde_json::Value;
use mitra_config::Config; use mitra_config::Config;
use mitra_models::database::{DatabaseClient, DatabaseError}; use mitra_models::database::{DatabaseClient, DatabaseError};
use crate::errors::{ use super::authentication::{verify_signed_activity, verify_signed_request, AuthenticationError};
ConversionError,
HttpError,
ValidationError,
};
use super::authentication::{
verify_signed_activity,
verify_signed_request,
AuthenticationError,
};
use super::fetcher::fetchers::FetchError; use super::fetcher::fetchers::FetchError;
use super::handlers::{ use super::handlers::{
accept::handle_accept, accept::handle_accept,
add::handle_add, add::handle_add,
announce::handle_announce, announce::handle_announce,
create::{ create::{handle_create, is_unsolicited_message, CreateNote},
handle_create,
is_unsolicited_message,
CreateNote,
},
delete::handle_delete, delete::handle_delete,
follow::handle_follow, follow::handle_follow,
like::handle_like, like::handle_like,
@ -42,6 +24,7 @@ use super::handlers::{
use super::identifiers::profile_actor_id; use super::identifiers::profile_actor_id;
use super::queues::IncomingActivityJobData; use super::queues::IncomingActivityJobData;
use super::vocabulary::*; use super::vocabulary::*;
use crate::errors::{ConversionError, HttpError, ValidationError};
#[derive(thiserror::Error, Debug)] #[derive(thiserror::Error, Debug)]
pub enum HandlerError { pub enum HandlerError {
@ -65,14 +48,10 @@ impl From<HandlerError> for HttpError {
fn from(error: HandlerError) -> Self { fn from(error: HandlerError) -> Self {
match error { match error {
HandlerError::LocalObject => HttpError::InternalError, HandlerError::LocalObject => HttpError::InternalError,
HandlerError::FetchError(error) => { HandlerError::FetchError(error) => HttpError::ValidationError(error.to_string()),
HttpError::ValidationError(error.to_string())
},
HandlerError::ValidationError(error) => error.into(), HandlerError::ValidationError(error) => error.into(),
HandlerError::DatabaseError(error) => error.into(), HandlerError::DatabaseError(error) => error.into(),
HandlerError::AuthError(_) => { HandlerError::AuthError(_) => HttpError::AuthError("invalid signature"),
HttpError::AuthError("invalid signature")
},
} }
} }
} }
@ -93,13 +72,13 @@ pub fn parse_array(value: &Value) -> Result<Vec<String>, ConversionError> {
// id property is missing // id property is missing
return Err(ConversionError); return Err(ConversionError);
}; };
}, }
// Unexpected array item type // Unexpected array item type
_ => return Err(ConversionError), _ => return Err(ConversionError),
}; };
}; }
results results
}, }
// Unexpected value type // Unexpected value type
_ => return Err(ConversionError), _ => return Err(ConversionError),
}; };
@ -116,10 +95,9 @@ pub fn parse_property_value<T: DeserializeOwned>(value: &Value) -> Result<Vec<T>
}; };
let mut items = vec![]; let mut items = vec![];
for object in objects { for object in objects {
let item: T = serde_json::from_value(object) let item: T = serde_json::from_value(object).map_err(|_| ConversionError)?;
.map_err(|_| ConversionError)?;
items.push(item); items.push(item);
}; }
Ok(items) Ok(items)
} }
@ -128,23 +106,22 @@ pub fn find_object_id(object: &Value) -> Result<String, ValidationError> {
let object_id = match object.as_str() { let object_id = match object.as_str() {
Some(object_id) => object_id.to_owned(), Some(object_id) => object_id.to_owned(),
None => { None => {
let object_id = object["id"].as_str() let object_id = object["id"]
.as_str()
.ok_or(ValidationError("missing object ID"))? .ok_or(ValidationError("missing object ID"))?
.to_string(); .to_string();
object_id object_id
}, }
}; };
Ok(object_id) Ok(object_id)
} }
pub fn deserialize_into_object_id<'de, D>( pub fn deserialize_into_object_id<'de, D>(deserializer: D) -> Result<String, D::Error>
deserializer: D, where
) -> Result<String, D::Error> D: Deserializer<'de>,
where D: Deserializer<'de>
{ {
let value = Value::deserialize(deserializer)?; let value = Value::deserialize(deserializer)?;
let object_id = find_object_id(&value) let object_id = find_object_id(&value).map_err(DeserializerError::custom)?;
.map_err(DeserializerError::custom)?;
Ok(object_id) Ok(object_id)
} }
@ -154,54 +131,32 @@ pub async fn handle_activity(
activity: &Value, activity: &Value,
is_authenticated: bool, is_authenticated: bool,
) -> Result<(), HandlerError> { ) -> Result<(), HandlerError> {
let activity_type = activity["type"].as_str() let activity_type = activity["type"]
.as_str()
.ok_or(ValidationError("type property is missing"))? .ok_or(ValidationError("type property is missing"))?
.to_owned(); .to_owned();
let activity_actor = activity["actor"].as_str() let activity_actor = activity["actor"]
.as_str()
.ok_or(ValidationError("actor property is missing"))? .ok_or(ValidationError("actor property is missing"))?
.to_owned(); .to_owned();
let activity = activity.clone(); let activity = activity.clone();
let maybe_object_type = match activity_type.as_str() { let maybe_object_type = match activity_type.as_str() {
ACCEPT => { ACCEPT => handle_accept(config, db_client, activity).await?,
handle_accept(config, db_client, activity).await? ADD => handle_add(config, db_client, activity).await?,
}, ANNOUNCE => handle_announce(config, db_client, activity).await?,
ADD => { CREATE => handle_create(config, db_client, activity, is_authenticated).await?,
handle_add(config, db_client, activity).await? DELETE => handle_delete(config, db_client, activity).await?,
}, FOLLOW => handle_follow(config, db_client, activity).await?,
ANNOUNCE => { LIKE | EMOJI_REACT => handle_like(config, db_client, activity).await?,
handle_announce(config, db_client, activity).await? MOVE => handle_move(config, db_client, activity).await?,
}, REJECT => handle_reject(config, db_client, activity).await?,
CREATE => { REMOVE => handle_remove(config, db_client, activity).await?,
handle_create(config, db_client, activity, is_authenticated).await? UNDO => handle_undo(config, db_client, activity).await?,
}, UPDATE => handle_update(config, db_client, activity).await?,
DELETE => {
handle_delete(config, db_client, activity).await?
},
FOLLOW => {
handle_follow(config, db_client, activity).await?
},
LIKE | EMOJI_REACT => {
handle_like(config, db_client, activity).await?
},
MOVE => {
handle_move(config, db_client, activity).await?
},
REJECT => {
handle_reject(config, db_client, activity).await?
},
REMOVE => {
handle_remove(config, db_client, activity).await?
},
UNDO => {
handle_undo(config, db_client, activity).await?
},
UPDATE => {
handle_update(config, db_client, activity).await?
},
_ => { _ => {
log::warn!("activity type is not supported: {}", activity); log::warn!("activity type is not supported: {}", activity);
None None
}, }
}; };
if let Some(object_type) = maybe_object_type { if let Some(object_type) = maybe_object_type {
log::info!( log::info!(
@ -220,9 +175,11 @@ pub async fn receive_activity(
request: &HttpRequest, request: &HttpRequest,
activity: &Value, activity: &Value,
) -> Result<(), HandlerError> { ) -> Result<(), HandlerError> {
let activity_type = activity["type"].as_str() let activity_type = activity["type"]
.as_str()
.ok_or(ValidationError("type property is missing"))?; .ok_or(ValidationError("type property is missing"))?;
let activity_actor = activity["actor"].as_str() let activity_actor = activity["actor"]
.as_str()
.ok_or(ValidationError("actor property is missing"))?; .ok_or(ValidationError("actor property is missing"))?;
let actor_hostname = url::Url::parse(activity_actor) let actor_hostname = url::Url::parse(activity_actor)
@ -230,7 +187,9 @@ pub async fn receive_activity(
.host_str() .host_str()
.ok_or(ValidationError("invalid actor ID"))? .ok_or(ValidationError("invalid actor ID"))?
.to_string(); .to_string();
if config.blocked_instances.iter() if config
.blocked_instances
.iter()
.any(|instance_hostname| &actor_hostname == instance_hostname) .any(|instance_hostname| &actor_hostname == instance_hostname)
{ {
log::warn!("ignoring activity from blocked instance: {}", activity); log::warn!("ignoring activity from blocked instance: {}", activity);
@ -240,7 +199,9 @@ pub async fn receive_activity(
let is_self_delete = if activity_type == DELETE { let is_self_delete = if activity_type == DELETE {
let object_id = find_object_id(&activity["object"])?; let object_id = find_object_id(&activity["object"])?;
object_id == activity_actor object_id == activity_actor
} else { false }; } else {
false
};
// HTTP signature is required // HTTP signature is required
let mut signer = match verify_signed_request( let mut signer = match verify_signed_request(
@ -249,24 +210,28 @@ pub async fn receive_activity(
request, request,
// Don't fetch signer if this is Delete(Person) activity // Don't fetch signer if this is Delete(Person) activity
is_self_delete, is_self_delete,
).await { )
.await
{
Ok(request_signer) => { Ok(request_signer) => {
log::debug!("request signed by {}", request_signer.acct); log::debug!("request signed by {}", request_signer.acct);
request_signer request_signer
}, }
Err(error) => { Err(error) => {
if is_self_delete && matches!( if is_self_delete
error, && matches!(
AuthenticationError::NoHttpSignature | error,
AuthenticationError::DatabaseError(DatabaseError::NotFound(_)) AuthenticationError::NoHttpSignature
) { | AuthenticationError::DatabaseError(DatabaseError::NotFound(_))
)
{
// Ignore Delete(Person) activities without HTTP signatures // Ignore Delete(Person) activities without HTTP signatures
// or if signer is not found in local database // or if signer is not found in local database
return Ok(()); return Ok(());
}; };
log::warn!("invalid HTTP signature: {}", error); log::warn!("invalid HTTP signature: {}", error);
return Err(error.into()); return Err(error.into());
}, }
}; };
// JSON signature is optional // JSON signature is optional
@ -276,7 +241,9 @@ pub async fn receive_activity(
activity, activity,
// Don't fetch actor if this is Delete(Person) activity // Don't fetch actor if this is Delete(Person) activity
is_self_delete, is_self_delete,
).await { )
.await
{
Ok(activity_signer) => { Ok(activity_signer) => {
if activity_signer.acct != signer.acct { if activity_signer.acct != signer.acct {
log::warn!( log::warn!(
@ -289,14 +256,16 @@ pub async fn receive_activity(
}; };
// Activity signature has higher priority // Activity signature has higher priority
signer = activity_signer; signer = activity_signer;
}, }
Err(AuthenticationError::NoJsonSignature) => (), // ignore Err(AuthenticationError::NoJsonSignature) => (), // ignore
Err(other_error) => { Err(other_error) => {
log::warn!("invalid JSON signature: {}", other_error); log::warn!("invalid JSON signature: {}", other_error);
}, }
}; };
if config.blocked_instances.iter() if config
.blocked_instances
.iter()
.any(|instance| signer.hostname.as_ref() == Some(instance)) .any(|instance| signer.hostname.as_ref() == Some(instance))
{ {
log::warn!("ignoring activity from blocked instance: {}", activity); log::warn!("ignoring activity from blocked instance: {}", activity);
@ -311,7 +280,7 @@ pub async fn receive_activity(
DELETE | LIKE => { DELETE | LIKE => {
// Ignore forwarded Delete and Like activities // Ignore forwarded Delete and Like activities
return Ok(()); return Ok(());
}, }
_ => { _ => {
// Reject other types // Reject other types
log::warn!( log::warn!(
@ -320,7 +289,7 @@ pub async fn receive_activity(
activity_actor, activity_actor,
); );
return Err(AuthenticationError::UnexpectedSigner.into()); return Err(AuthenticationError::UnexpectedSigner.into());
}, }
}; };
}; };
@ -336,31 +305,24 @@ pub async fn receive_activity(
if let ANNOUNCE | CREATE | DELETE | MOVE | UNDO | UPDATE = activity_type { if let ANNOUNCE | CREATE | DELETE | MOVE | UNDO | UPDATE = activity_type {
// Add activity to job queue and release lock // Add activity to job queue and release lock
IncomingActivityJobData::new(activity, is_authenticated) IncomingActivityJobData::new(activity, is_authenticated)
.into_job(db_client, 0).await?; .into_job(db_client, 0)
.await?;
log::debug!("activity added to the queue: {}", activity_type); log::debug!("activity added to the queue: {}", activity_type);
return Ok(()); return Ok(());
}; };
handle_activity( handle_activity(config, db_client, activity, is_authenticated).await
config,
db_client,
activity,
is_authenticated,
).await
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use serde_json::json;
use super::*; use super::*;
use serde_json::json;
#[test] #[test]
fn test_parse_array_with_string() { fn test_parse_array_with_string() {
let value = json!("test"); let value = json!("test");
assert_eq!( assert_eq!(parse_array(&value).unwrap(), vec!["test".to_string()],);
parse_array(&value).unwrap(),
vec!["test".to_string()],
);
} }
#[test] #[test]

View file

@ -1,19 +1,11 @@
use std::collections::HashMap; use std::collections::HashMap;
use chrono::{DateTime, Utc}; use chrono::{DateTime, Utc};
use serde::{ use serde::{de::Error as DeserializerError, Deserialize, Deserializer, Serialize};
Deserialize,
Deserializer,
Serialize,
de::{Error as DeserializerError},
};
use serde_json::Value; use serde_json::Value;
use super::constants::{ use super::constants::{
AP_CONTEXT, AP_CONTEXT, MITRA_CONTEXT, W3ID_DATA_INTEGRITY_CONTEXT, W3ID_SECURITY_CONTEXT,
MITRA_CONTEXT,
W3ID_DATA_INTEGRITY_CONTEXT,
W3ID_SECURITY_CONTEXT,
}; };
use super::receiver::parse_property_value; use super::receiver::parse_property_value;
use super::vocabulary::HASHTAG; use super::vocabulary::HASHTAG;
@ -35,7 +27,9 @@ pub struct Link {
pub href: String, pub href: String,
} }
fn default_tag_type() -> String { HASHTAG.to_string() } fn default_tag_type() -> String {
HASHTAG.to_string()
}
#[derive(Deserialize, Serialize)] #[derive(Deserialize, Serialize)]
#[serde(rename_all = "camelCase")] #[serde(rename_all = "camelCase")]
@ -88,10 +82,9 @@ pub struct EmojiTag {
pub updated: DateTime<Utc>, pub updated: DateTime<Utc>,
} }
pub fn deserialize_value_array<'de, D>( pub fn deserialize_value_array<'de, D>(deserializer: D) -> Result<Vec<Value>, D::Error>
deserializer: D, where
) -> Result<Vec<Value>, D::Error> D: Deserializer<'de>,
where D: Deserializer<'de>
{ {
let maybe_value: Option<Value> = Option::deserialize(deserializer)?; let maybe_value: Option<Value> = Option::deserialize(deserializer)?;
let values = if let Some(value) = maybe_value { let values = if let Some(value) = maybe_value {
@ -124,10 +117,7 @@ pub struct Object {
pub quote_url: Option<String>, pub quote_url: Option<String>,
pub sensitive: Option<bool>, pub sensitive: Option<bool>,
#[serde( #[serde(default, deserialize_with = "deserialize_value_array")]
default,
deserialize_with = "deserialize_value_array",
)]
pub tag: Vec<Value>, pub tag: Vec<Value>,
pub to: Option<Value>, pub to: Option<Value>,

View file

@ -1,14 +1,8 @@
use std::time::Instant; use std::time::Instant;
use actix_web::{ use actix_web::{
get, get, http::header as http_header, http::header::HeaderMap, post, web, HttpRequest,
post, HttpResponse, Scope,
web,
http::header as http_header,
http::header::HeaderMap,
HttpRequest,
HttpResponse,
Scope,
}; };
use serde::Deserialize; use serde::Deserialize;
use tokio::sync::Mutex; use tokio::sync::Mutex;
@ -23,36 +17,23 @@ use mitra_models::{
users::queries::get_user_by_name, users::queries::get_user_by_name,
}; };
use crate::errors::HttpError; use super::actors::types::{get_instance_actor, get_local_actor};
use crate::web_client::urls::{
get_post_page_url,
get_profile_page_url,
get_tag_page_url,
};
use super::actors::types::{get_local_actor, get_instance_actor};
use super::builders::{ use super::builders::{
announce::build_announce, announce::build_announce,
create_note::{ create_note::{build_create_note, build_emoji_tag, build_note},
build_emoji_tag,
build_note,
build_create_note,
},
};
use super::collections::{
OrderedCollection,
OrderedCollectionPage,
}; };
use super::collections::{OrderedCollection, OrderedCollectionPage};
use super::constants::{AP_MEDIA_TYPE, AS_MEDIA_TYPE}; use super::constants::{AP_MEDIA_TYPE, AS_MEDIA_TYPE};
use super::identifiers::{ use super::identifiers::{
local_actor_followers, local_actor_followers, local_actor_following, local_actor_outbox, local_actor_subscribers,
local_actor_following,
local_actor_subscribers,
local_actor_outbox,
}; };
use super::receiver::{receive_activity, HandlerError}; use super::receiver::{receive_activity, HandlerError};
use crate::errors::HttpError;
use crate::web_client::urls::{get_post_page_url, get_profile_page_url, get_tag_page_url};
pub fn is_activitypub_request(headers: &HeaderMap) -> bool { pub fn is_activitypub_request(headers: &HeaderMap) -> bool {
let maybe_user_agent = headers.get(http_header::USER_AGENT) let maybe_user_agent = headers
.get(http_header::USER_AGENT)
.and_then(|value| value.to_str().ok()); .and_then(|value| value.to_str().ok());
if let Some(user_agent) = maybe_user_agent { if let Some(user_agent) = maybe_user_agent {
if user_agent.contains("THIS. IS. GNU social!!!!") { if user_agent.contains("THIS. IS. GNU social!!!!") {
@ -67,7 +48,9 @@ pub fn is_activitypub_request(headers: &HeaderMap) -> bool {
"application/json", "application/json",
]; ];
if let Some(content_type) = headers.get(http_header::ACCEPT) { if let Some(content_type) = headers.get(http_header::ACCEPT) {
let content_type_str = content_type.to_str().ok() let content_type_str = content_type
.to_str()
.ok()
// Take first content type if there are many // Take first content type if there are many
.and_then(|value| value.split(',').next()) .and_then(|value| value.split(',').next())
.unwrap_or(""); .unwrap_or("");
@ -86,20 +69,15 @@ async fn actor_view(
let db_client = &**get_database_client(&db_pool).await?; let db_client = &**get_database_client(&db_pool).await?;
let user = get_user_by_name(db_client, &username).await?; let user = get_user_by_name(db_client, &username).await?;
if !is_activitypub_request(request.headers()) { if !is_activitypub_request(request.headers()) {
let page_url = get_profile_page_url( let page_url = get_profile_page_url(&config.instance_url(), &user.profile.username);
&config.instance_url(),
&user.profile.username,
);
let response = HttpResponse::Found() let response = HttpResponse::Found()
.append_header((http_header::LOCATION, page_url)) .append_header((http_header::LOCATION, page_url))
.finish(); .finish();
return Ok(response); return Ok(response);
}; };
let actor = get_local_actor(&user, &config.instance_url()) let actor =
.map_err(|_| HttpError::InternalError)?; get_local_actor(&user, &config.instance_url()).map_err(|_| HttpError::InternalError)?;
let response = HttpResponse::Ok() let response = HttpResponse::Ok().content_type(AP_MEDIA_TYPE).json(actor);
.content_type(AP_MEDIA_TYPE)
.json(actor);
Ok(response) Ok(response)
} }
@ -126,19 +104,16 @@ async fn inbox(
activity["id"].as_str().unwrap_or_default(), activity["id"].as_str().unwrap_or_default(),
); );
let db_client = &mut **get_database_client(&db_pool).await?; let db_client = &mut **get_database_client(&db_pool).await?;
receive_activity(&config, db_client, &request, &activity).await receive_activity(&config, db_client, &request, &activity)
.await
.map_err(|error| { .map_err(|error| {
// TODO: preserve original error text in DatabaseError // TODO: preserve original error text in DatabaseError
if let HandlerError::DatabaseError( if let HandlerError::DatabaseError(DatabaseError::DatabaseClientError(ref pg_error)) =
DatabaseError::DatabaseClientError(ref pg_error)) = error error
{ {
log::error!("database client error: {}", pg_error); log::error!("database client error: {}", pg_error);
}; };
log::warn!( log::warn!("failed to process activity ({}): {}", error, activity,);
"failed to process activity ({}): {}",
error,
activity,
);
error error
})?; })?;
Ok(HttpResponse::Accepted().finish()) Ok(HttpResponse::Accepted().finish())
@ -160,11 +135,7 @@ async fn outbox(
let collection_id = local_actor_outbox(&instance.url(), &username); let collection_id = local_actor_outbox(&instance.url(), &username);
let first_page_id = format!("{}?page=true", collection_id); let first_page_id = format!("{}?page=true", collection_id);
if query_params.page.is_none() { if query_params.page.is_none() {
let collection = OrderedCollection::new( let collection = OrderedCollection::new(collection_id, Some(first_page_id), None);
collection_id,
Some(first_page_id),
None,
);
let response = HttpResponse::Ok() let response = HttpResponse::Ok()
.content_type(AP_MEDIA_TYPE) .content_type(AP_MEDIA_TYPE)
.json(collection); .json(collection);
@ -182,27 +153,22 @@ async fn outbox(
true, // include reposts true, // include reposts
None, None,
COLLECTION_PAGE_SIZE, COLLECTION_PAGE_SIZE,
).await?; )
.await?;
add_related_posts(db_client, posts.iter_mut().collect()).await?; add_related_posts(db_client, posts.iter_mut().collect()).await?;
let activities = posts.iter().map(|post| { let activities = posts
if post.repost_of_id.is_some() { .iter()
let activity = build_announce(&instance.url(), post); .map(|post| {
serde_json::to_value(activity) if post.repost_of_id.is_some() {
.expect("activity should be serializable") let activity = build_announce(&instance.url(), post);
} else { serde_json::to_value(activity).expect("activity should be serializable")
let activity = build_create_note( } else {
&instance.hostname(), let activity = build_create_note(&instance.hostname(), &instance.url(), post);
&instance.url(), serde_json::to_value(activity).expect("activity should be serializable")
post, }
); })
serde_json::to_value(activity) .collect();
.expect("activity should be serializable") let collection_page = OrderedCollectionPage::new(first_page_id, activities);
}
}).collect();
let collection_page = OrderedCollectionPage::new(
first_page_id,
activities,
);
let response = HttpResponse::Ok() let response = HttpResponse::Ok()
.content_type(AP_MEDIA_TYPE) .content_type(AP_MEDIA_TYPE)
.json(collection_page); .json(collection_page);
@ -227,15 +193,8 @@ async fn followers_collection(
}; };
let db_client = &**get_database_client(&db_pool).await?; let db_client = &**get_database_client(&db_pool).await?;
let user = get_user_by_name(db_client, &username).await?; let user = get_user_by_name(db_client, &username).await?;
let collection_id = local_actor_followers( let collection_id = local_actor_followers(&config.instance_url(), &username);
&config.instance_url(), let collection = OrderedCollection::new(collection_id, None, Some(user.profile.follower_count));
&username,
);
let collection = OrderedCollection::new(
collection_id,
None,
Some(user.profile.follower_count),
);
let response = HttpResponse::Ok() let response = HttpResponse::Ok()
.content_type(AP_MEDIA_TYPE) .content_type(AP_MEDIA_TYPE)
.json(collection); .json(collection);
@ -255,15 +214,9 @@ async fn following_collection(
}; };
let db_client = &**get_database_client(&db_pool).await?; let db_client = &**get_database_client(&db_pool).await?;
let user = get_user_by_name(db_client, &username).await?; let user = get_user_by_name(db_client, &username).await?;
let collection_id = local_actor_following( let collection_id = local_actor_following(&config.instance_url(), &username);
&config.instance_url(), let collection =
&username, OrderedCollection::new(collection_id, None, Some(user.profile.following_count));
);
let collection = OrderedCollection::new(
collection_id,
None,
Some(user.profile.following_count),
);
let response = HttpResponse::Ok() let response = HttpResponse::Ok()
.content_type(AP_MEDIA_TYPE) .content_type(AP_MEDIA_TYPE)
.json(collection); .json(collection);
@ -283,15 +236,9 @@ async fn subscribers_collection(
}; };
let db_client = &**get_database_client(&db_pool).await?; let db_client = &**get_database_client(&db_pool).await?;
let user = get_user_by_name(db_client, &username).await?; let user = get_user_by_name(db_client, &username).await?;
let collection_id = local_actor_subscribers( let collection_id = local_actor_subscribers(&config.instance_url(), &username);
&config.instance_url(), let collection =
&username, OrderedCollection::new(collection_id, None, Some(user.profile.subscriber_count));
);
let collection = OrderedCollection::new(
collection_id,
None,
Some(user.profile.subscriber_count),
);
let response = HttpResponse::Ok() let response = HttpResponse::Ok()
.content_type(AP_MEDIA_TYPE) .content_type(AP_MEDIA_TYPE)
.json(collection); .json(collection);
@ -310,14 +257,9 @@ pub fn actor_scope() -> Scope {
} }
#[get("")] #[get("")]
async fn instance_actor_view( async fn instance_actor_view(config: web::Data<Config>) -> Result<HttpResponse, HttpError> {
config: web::Data<Config>, let actor = get_instance_actor(&config.instance()).map_err(|_| HttpError::InternalError)?;
) -> Result<HttpResponse, HttpError> { let response = HttpResponse::Ok().content_type(AP_MEDIA_TYPE).json(actor);
let actor = get_instance_actor(&config.instance())
.map_err(|_| HttpError::InternalError)?;
let response = HttpResponse::Ok()
.content_type(AP_MEDIA_TYPE)
.json(actor);
Ok(response) Ok(response)
} }
@ -370,9 +312,7 @@ pub async fn object_view(
&config.instance().url(), &config.instance().url(),
&post, &post,
); );
let response = HttpResponse::Ok() let response = HttpResponse::Ok().content_type(AP_MEDIA_TYPE).json(object);
.content_type(AP_MEDIA_TYPE)
.json(object);
Ok(response) Ok(response)
} }
@ -383,17 +323,9 @@ pub async fn emoji_view(
emoji_name: web::Path<String>, emoji_name: web::Path<String>,
) -> Result<HttpResponse, HttpError> { ) -> Result<HttpResponse, HttpError> {
let db_client = &**get_database_client(&db_pool).await?; let db_client = &**get_database_client(&db_pool).await?;
let emoji = get_local_emoji_by_name( let emoji = get_local_emoji_by_name(db_client, &emoji_name).await?;
db_client, let object = build_emoji_tag(&config.instance().url(), &emoji);
&emoji_name, let response = HttpResponse::Ok().content_type(AP_MEDIA_TYPE).json(object);
).await?;
let object = build_emoji_tag(
&config.instance().url(),
&emoji,
);
let response = HttpResponse::Ok()
.content_type(AP_MEDIA_TYPE)
.json(object);
Ok(response) Ok(response)
} }
@ -411,11 +343,11 @@ pub async fn tag_view(
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*;
use actix_web::http::{ use actix_web::http::{
header, header,
header::{HeaderMap, HeaderValue}, header::{HeaderMap, HeaderValue},
}; };
use super::*;
#[test] #[test]
fn test_is_activitypub_request_mastodon() { fn test_is_activitypub_request_mastodon() {
@ -442,10 +374,7 @@ mod tests {
#[test] #[test]
fn test_is_activitypub_request_browser() { fn test_is_activitypub_request_browser() {
let mut request_headers = HeaderMap::new(); let mut request_headers = HeaderMap::new();
request_headers.insert( request_headers.insert(header::ACCEPT, HeaderValue::from_static("text/html"));
header::ACCEPT,
HeaderValue::from_static("text/html"),
);
let result = is_activitypub_request(&request_headers); let result = is_activitypub_request(&request_headers);
assert_eq!(result, false); assert_eq!(result, false);
} }

View file

@ -1,8 +1,5 @@
use mitra_config::Instance; use mitra_config::Instance;
use mitra_models::{ use mitra_models::{posts::types::Post, profiles::types::DbActorProfile};
posts::types::Post,
profiles::types::DbActorProfile,
};
use mitra_utils::{ use mitra_utils::{
datetime::get_min_datetime, datetime::get_min_datetime,
html::{clean_html_all, escape_html}, html::{clean_html_all, escape_html},
@ -13,19 +10,16 @@ use crate::webfinger::types::ActorAddress;
const ENTRY_TITLE_MAX_LENGTH: usize = 75; const ENTRY_TITLE_MAX_LENGTH: usize = 75;
fn make_entry( fn make_entry(instance_url: &str, post: &Post) -> String {
instance_url: &str,
post: &Post,
) -> String {
let object_id = local_object_id(instance_url, &post.id); let object_id = local_object_id(instance_url, &post.id);
let content_escaped = escape_html(&post.content); let content_escaped = escape_html(&post.content);
let content_cleaned = clean_html_all(&post.content); let content_cleaned = clean_html_all(&post.content);
// Use trimmed content for title // Use trimmed content for title
let mut title: String = content_cleaned.chars() let mut title: String = content_cleaned
.chars()
.take(ENTRY_TITLE_MAX_LENGTH) .take(ENTRY_TITLE_MAX_LENGTH)
.collect(); .collect();
if title.len() == ENTRY_TITLE_MAX_LENGTH && if title.len() == ENTRY_TITLE_MAX_LENGTH && content_cleaned.len() != ENTRY_TITLE_MAX_LENGTH {
content_cleaned.len() != ENTRY_TITLE_MAX_LENGTH {
title += "..."; title += "...";
}; };
format!( format!(
@ -37,11 +31,11 @@ fn make_entry(
<content type="html">{content}</content> <content type="html">{content}</content>
<link rel="alternate" href="{url}"/> <link rel="alternate" href="{url}"/>
</entry>"#, </entry>"#,
url=object_id, url = object_id,
title=title, title = title,
updated_at=post.created_at.to_rfc3339(), updated_at = post.created_at.to_rfc3339(),
author=post.author.username, author = post.author.username,
content=content_escaped, content = content_escaped,
) )
} }
@ -49,18 +43,10 @@ fn get_feed_url(instance_url: &str, username: &str) -> String {
format!("{}/feeds/users/{}", instance_url, username) format!("{}/feeds/users/{}", instance_url, username)
} }
pub fn make_feed( pub fn make_feed(instance: &Instance, profile: &DbActorProfile, posts: Vec<Post>) -> String {
instance: &Instance,
profile: &DbActorProfile,
posts: Vec<Post>,
) -> String {
let actor_id = local_actor_id(&instance.url(), &profile.username); let actor_id = local_actor_id(&instance.url(), &profile.username);
let actor_name = profile.display_name.as_ref() let actor_name = profile.display_name.as_ref().unwrap_or(&profile.username);
.unwrap_or(&profile.username); let actor_address = ActorAddress::from_profile(&instance.hostname(), profile);
let actor_address = ActorAddress::from_profile(
&instance.hostname(),
profile,
);
let feed_url = get_feed_url(&instance.url(), &profile.username); let feed_url = get_feed_url(&instance.url(), &profile.username);
let feed_title = format!("{} (@{})", actor_name, actor_address); let feed_title = format!("{} (@{})", actor_name, actor_address);
let mut entries = vec![]; let mut entries = vec![];
@ -71,7 +57,7 @@ pub fn make_feed(
if post.created_at > feed_updated_at { if post.created_at > feed_updated_at {
feed_updated_at = post.created_at; feed_updated_at = post.created_at;
}; };
}; }
format!( format!(
r#"<?xml version="1.0" encoding="utf-8"?> r#"<?xml version="1.0" encoding="utf-8"?>
<feed xmlns="http://www.w3.org/2005/Atom"> <feed xmlns="http://www.w3.org/2005/Atom">
@ -81,19 +67,19 @@ pub fn make_feed(
<updated>{updated_at}</updated> <updated>{updated_at}</updated>
{entries} {entries}
</feed>"#, </feed>"#,
id=actor_id, id = actor_id,
url=feed_url, url = feed_url,
title=feed_title, title = feed_title,
updated_at=feed_updated_at.to_rfc3339(), updated_at = feed_updated_at.to_rfc3339(),
entries=entries.join("\n"), entries = entries.join("\n"),
) )
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*;
use chrono::{TimeZone, Utc}; use chrono::{TimeZone, Utc};
use uuid::uuid; use uuid::uuid;
use super::*;
#[test] #[test]
fn test_make_entry() { fn test_make_entry() {
@ -118,8 +104,10 @@ mod tests {
" <title>titletext text text</title>\n", " <title>titletext text text</title>\n",
" <updated>2020-03-03T03:03:03+00:00</updated>\n", " <updated>2020-03-03T03:03:03+00:00</updated>\n",
" <author><name>username</name></author>\n", " <author><name>username</name></author>\n",
r#" <content type="html">&lt;p&gt;title&lt;&#47;p&gt;&lt;p&gt;text&#32;text&#32;text&lt;&#47;p&gt;</content>"#, "\n", r#" <content type="html">&lt;p&gt;title&lt;&#47;p&gt;&lt;p&gt;text&#32;text&#32;text&lt;&#47;p&gt;</content>"#,
r#" <link rel="alternate" href="https://example.org/objects/67e55044-10b1-426f-9247-bb680e5fe0c8"/>"#, "\n", "\n",
r#" <link rel="alternate" href="https://example.org/objects/67e55044-10b1-426f-9247-bb680e5fe0c8"/>"#,
"\n",
"</entry>", "</entry>",
); );
assert_eq!(entry, expected_entry); assert_eq!(entry, expected_entry);

Some files were not shown because too many files have changed in this diff Show more