Refactor create_post() function

This commit is contained in:
silverpill 2023-01-31 12:47:40 +00:00
parent 70455e5eeb
commit f142bee72b

View file

@ -32,10 +32,138 @@ use super::types::{
Visibility, Visibility,
}; };
async fn create_post_attachments(
db_client: &impl DatabaseClient,
post_id: &Uuid,
author_id: &Uuid,
attachments: Vec<Uuid>,
) -> Result<Vec<DbMediaAttachment>, DatabaseError> {
let attachments_rows = db_client.query(
"
UPDATE media_attachment
SET post_id = $1
WHERE owner_id = $2 AND id = ANY($3)
RETURNING media_attachment
",
&[&post_id, &author_id, &attachments],
).await?;
if attachments_rows.len() != attachments.len() {
// Some attachments were not found
return Err(DatabaseError::NotFound("attachment"));
};
let attachments = attachments_rows.iter()
.map(|row| row.try_get("media_attachment"))
.collect::<Result<_, _>>()?;
Ok(attachments)
}
async fn create_post_mentions(
db_client: &impl DatabaseClient,
post_id: &Uuid,
mentions: Vec<Uuid>,
) -> Result<Vec<DbActorProfile>, DatabaseError> {
let mentions_rows = db_client.query(
"
INSERT INTO mention (post_id, profile_id)
SELECT $1, actor_profile.id FROM actor_profile WHERE id = ANY($2)
RETURNING (
SELECT actor_profile FROM actor_profile
WHERE actor_profile.id = profile_id
) AS actor_profile
",
&[&post_id, &mentions],
).await?;
if mentions_rows.len() != mentions.len() {
// Some profiles were not found
return Err(DatabaseError::NotFound("profile"));
};
let profiles = mentions_rows.iter()
.map(|row| row.try_get("actor_profile"))
.collect::<Result<_, _>>()?;
Ok(profiles)
}
async fn create_post_tags(
db_client: &impl DatabaseClient,
post_id: &Uuid,
tags: Vec<String>,
) -> Result<Vec<String>, DatabaseError> {
db_client.execute(
"
INSERT INTO tag (tag_name)
SELECT unnest($1::text[])
ON CONFLICT (tag_name) DO NOTHING
",
&[&tags],
).await?;
let tags_rows = db_client.query(
"
INSERT INTO post_tag (post_id, tag_id)
SELECT $1, tag.id FROM tag WHERE tag_name = ANY($2)
RETURNING (SELECT tag_name FROM tag WHERE tag.id = tag_id)
",
&[&post_id, &tags],
).await?;
if tags_rows.len() != tags.len() {
return Err(DatabaseError::NotFound("tag"));
};
let tags = tags_rows.iter()
.map(|row| row.try_get("tag_name"))
.collect::<Result<_, _>>()?;
Ok(tags)
}
async fn create_post_links(
db_client: &impl DatabaseClient,
post_id: &Uuid,
links: Vec<Uuid>,
) -> Result<Vec<Uuid>, DatabaseError> {
let links_rows = db_client.query(
"
INSERT INTO post_link (source_id, target_id)
SELECT $1, post.id FROM post WHERE id = ANY($2)
RETURNING target_id
",
&[&post_id, &links],
).await?;
if links_rows.len() != links.len() {
return Err(DatabaseError::NotFound("post"));
};
let links = links_rows.iter()
.map(|row| row.try_get("target_id"))
.collect::<Result<_, _>>()?;
Ok(links)
}
async fn create_post_emojis(
db_client: &impl DatabaseClient,
post_id: &Uuid,
emojis: Vec<Uuid>,
) -> Result<Vec<DbEmoji>, DatabaseError> {
let emojis_rows = db_client.query(
"
INSERT INTO post_emoji (post_id, emoji_id)
SELECT $1, emoji.id FROM emoji WHERE id = ANY($2)
RETURNING (
SELECT emoji FROM emoji
WHERE emoji.id = emoji_id
)
",
&[&post_id, &emojis],
).await?;
if emojis_rows.len() != emojis.len() {
return Err(DatabaseError::NotFound("emoji"));
};
let emojis = emojis_rows.iter()
.map(|row| row.try_get("emoji"))
.collect::<Result<_, _>>()?;
Ok(emojis)
}
pub async fn create_post( pub async fn create_post(
db_client: &mut impl DatabaseClient, db_client: &mut impl DatabaseClient,
author_id: &Uuid, author_id: &Uuid,
data: PostCreateData, post_data: PostCreateData,
) -> Result<Post, DatabaseError> { ) -> Result<Post, DatabaseError> {
let transaction = db_client.transaction().await?; let transaction = db_client.transaction().await?;
let post_id = new_uuid(); let post_id = new_uuid();
@ -73,109 +201,46 @@ pub async fn create_post(
&[ &[
&post_id, &post_id,
&author_id, &author_id,
&data.content, &post_data.content,
&data.in_reply_to_id, &post_data.in_reply_to_id,
&data.repost_of_id, &post_data.repost_of_id,
&data.visibility, &post_data.visibility,
&data.object_id, &post_data.object_id,
&data.created_at, &post_data.created_at,
], ],
).await.map_err(catch_unique_violation("post"))?; ).await.map_err(catch_unique_violation("post"))?;
// Return NotFound error if reply/repost is not allowed // Return NotFound error if reply/repost is not allowed
let post_row = maybe_post_row.ok_or(DatabaseError::NotFound("post"))?; let post_row = maybe_post_row.ok_or(DatabaseError::NotFound("post"))?;
let db_post: DbPost = post_row.try_get("post")?; let db_post: DbPost = post_row.try_get("post")?;
// Create links to attachments
let attachments_rows = transaction.query( // Create related objects
" let db_attachments = create_post_attachments(
UPDATE media_attachment &transaction,
SET post_id = $1 &db_post.id,
WHERE owner_id = $2 AND id = ANY($3) &db_post.author_id,
RETURNING media_attachment post_data.attachments,
",
&[&post_id, &author_id, &data.attachments],
).await?; ).await?;
if attachments_rows.len() != data.attachments.len() { let db_mentions = create_post_mentions(
// Some attachments were not found &transaction,
return Err(DatabaseError::NotFound("attachment")); &db_post.id,
}; post_data.mentions,
let db_attachments: Vec<DbMediaAttachment> = attachments_rows.iter()
.map(|row| row.try_get("media_attachment"))
.collect::<Result<_, _>>()?;
// Create mentions
let mentions_rows = transaction.query(
"
INSERT INTO mention (post_id, profile_id)
SELECT $1, actor_profile.id FROM actor_profile WHERE id = ANY($2)
RETURNING (
SELECT actor_profile FROM actor_profile
WHERE actor_profile.id = profile_id
) AS actor_profile
",
&[&db_post.id, &data.mentions],
).await?; ).await?;
if mentions_rows.len() != data.mentions.len() { let db_tags = create_post_tags(
// Some profiles were not found &transaction,
return Err(DatabaseError::NotFound("profile")); &db_post.id,
}; post_data.tags,
let db_mentions: Vec<DbActorProfile> = mentions_rows.iter()
.map(|row| row.try_get("actor_profile"))
.collect::<Result<_, _>>()?;
// Create tags
transaction.execute(
"
INSERT INTO tag (tag_name)
SELECT unnest($1::text[])
ON CONFLICT (tag_name) DO NOTHING
",
&[&data.tags],
).await?; ).await?;
let tags_rows = transaction.query( let db_links = create_post_links(
" &transaction,
INSERT INTO post_tag (post_id, tag_id) &db_post.id,
SELECT $1, tag.id FROM tag WHERE tag_name = ANY($2) post_data.links,
RETURNING (SELECT tag_name FROM tag WHERE tag.id = tag_id)
",
&[&db_post.id, &data.tags],
).await?; ).await?;
if tags_rows.len() != data.tags.len() { let db_emojis = create_post_emojis(
return Err(DatabaseError::NotFound("tag")); &transaction,
}; &db_post.id,
let db_tags: Vec<String> = tags_rows.iter() post_data.emojis,
.map(|row| row.try_get("tag_name"))
.collect::<Result<_, _>>()?;
// Create links
let links_rows = transaction.query(
"
INSERT INTO post_link (source_id, target_id)
SELECT $1, post.id FROM post WHERE id = ANY($2)
RETURNING target_id
",
&[&db_post.id, &data.links],
).await?; ).await?;
if links_rows.len() != data.links.len() {
return Err(DatabaseError::NotFound("post"));
};
let db_links: Vec<Uuid> = links_rows.iter()
.map(|row| row.try_get("target_id"))
.collect::<Result<_, _>>()?;
// Create emojis
let emojis_rows = transaction.query(
"
INSERT INTO post_emoji (post_id, emoji_id)
SELECT $1, emoji.id FROM emoji WHERE id = ANY($2)
RETURNING (
SELECT emoji FROM emoji
WHERE emoji.id = emoji_id
)
",
&[&db_post.id, &data.emojis],
).await?;
if emojis_rows.len() != data.emojis.len() {
return Err(DatabaseError::NotFound("emoji"));
};
let db_emojis: Vec<DbEmoji> = emojis_rows.iter()
.map(|row| row.try_get("emoji"))
.collect::<Result<_, _>>()?;
// Update counters // Update counters
let author = update_post_count(&transaction, &db_post.author_id, 1).await?; let author = update_post_count(&transaction, &db_post.author_id, 1).await?;
let mut notified_users = vec![]; let mut notified_users = vec![];
@ -226,8 +291,7 @@ pub async fn create_post(
).await?; ).await?;
}; };
}; };
// Construct post object
transaction.commit().await?;
let post = Post::new( let post = Post::new(
db_post, db_post,
author, author,
@ -237,9 +301,38 @@ pub async fn create_post(
db_links, db_links,
db_emojis, db_emojis,
)?; )?;
transaction.commit().await?;
Ok(post) Ok(post)
} }
pub async fn update_post(
db_client: &impl DatabaseClient,
post_id: &Uuid,
post_data: PostUpdateData,
) -> Result<(), DatabaseError> {
// Reposts and immutable posts can't be updated
let updated_count = db_client.execute(
"
UPDATE post
SET
content = $1,
updated_at = $2
WHERE id = $3
AND repost_of_id IS NULL
AND ipfs_cid IS NULL
",
&[
&post_data.content,
&post_data.updated_at,
&post_id,
],
).await?;
if updated_count == 0 {
return Err(DatabaseError::NotFound("post"));
};
Ok(())
}
pub const RELATED_ATTACHMENTS: &str = pub const RELATED_ATTACHMENTS: &str =
"ARRAY( "ARRAY(
SELECT media_attachment SELECT media_attachment
@ -765,34 +858,6 @@ pub async fn get_post_by_ipfs_cid(
Ok(post) Ok(post)
} }
pub async fn update_post(
db_client: &impl DatabaseClient,
post_id: &Uuid,
post_data: PostUpdateData,
) -> Result<(), DatabaseError> {
// Reposts and immutable posts can't be updated
let updated_count = db_client.execute(
"
UPDATE post
SET
content = $1,
updated_at = $2
WHERE id = $3
AND repost_of_id IS NULL
AND ipfs_cid IS NULL
",
&[
&post_data.content,
&post_data.updated_at,
&post_id,
],
).await?;
if updated_count == 0 {
return Err(DatabaseError::NotFound("post"));
}
Ok(())
}
pub async fn update_reply_count( pub async fn update_reply_count(
db_client: &impl DatabaseClient, db_client: &impl DatabaseClient,
post_id: &Uuid, post_id: &Uuid,