Refactor create_post() function

This commit is contained in:
silverpill 2023-01-31 12:47:40 +00:00
parent 70455e5eeb
commit f142bee72b

View file

@ -32,10 +32,138 @@ use super::types::{
Visibility,
};
async fn create_post_attachments(
db_client: &impl DatabaseClient,
post_id: &Uuid,
author_id: &Uuid,
attachments: Vec<Uuid>,
) -> Result<Vec<DbMediaAttachment>, DatabaseError> {
let attachments_rows = db_client.query(
"
UPDATE media_attachment
SET post_id = $1
WHERE owner_id = $2 AND id = ANY($3)
RETURNING media_attachment
",
&[&post_id, &author_id, &attachments],
).await?;
if attachments_rows.len() != attachments.len() {
// Some attachments were not found
return Err(DatabaseError::NotFound("attachment"));
};
let attachments = attachments_rows.iter()
.map(|row| row.try_get("media_attachment"))
.collect::<Result<_, _>>()?;
Ok(attachments)
}
async fn create_post_mentions(
db_client: &impl DatabaseClient,
post_id: &Uuid,
mentions: Vec<Uuid>,
) -> Result<Vec<DbActorProfile>, DatabaseError> {
let mentions_rows = db_client.query(
"
INSERT INTO mention (post_id, profile_id)
SELECT $1, actor_profile.id FROM actor_profile WHERE id = ANY($2)
RETURNING (
SELECT actor_profile FROM actor_profile
WHERE actor_profile.id = profile_id
) AS actor_profile
",
&[&post_id, &mentions],
).await?;
if mentions_rows.len() != mentions.len() {
// Some profiles were not found
return Err(DatabaseError::NotFound("profile"));
};
let profiles = mentions_rows.iter()
.map(|row| row.try_get("actor_profile"))
.collect::<Result<_, _>>()?;
Ok(profiles)
}
async fn create_post_tags(
db_client: &impl DatabaseClient,
post_id: &Uuid,
tags: Vec<String>,
) -> Result<Vec<String>, DatabaseError> {
db_client.execute(
"
INSERT INTO tag (tag_name)
SELECT unnest($1::text[])
ON CONFLICT (tag_name) DO NOTHING
",
&[&tags],
).await?;
let tags_rows = db_client.query(
"
INSERT INTO post_tag (post_id, tag_id)
SELECT $1, tag.id FROM tag WHERE tag_name = ANY($2)
RETURNING (SELECT tag_name FROM tag WHERE tag.id = tag_id)
",
&[&post_id, &tags],
).await?;
if tags_rows.len() != tags.len() {
return Err(DatabaseError::NotFound("tag"));
};
let tags = tags_rows.iter()
.map(|row| row.try_get("tag_name"))
.collect::<Result<_, _>>()?;
Ok(tags)
}
async fn create_post_links(
db_client: &impl DatabaseClient,
post_id: &Uuid,
links: Vec<Uuid>,
) -> Result<Vec<Uuid>, DatabaseError> {
let links_rows = db_client.query(
"
INSERT INTO post_link (source_id, target_id)
SELECT $1, post.id FROM post WHERE id = ANY($2)
RETURNING target_id
",
&[&post_id, &links],
).await?;
if links_rows.len() != links.len() {
return Err(DatabaseError::NotFound("post"));
};
let links = links_rows.iter()
.map(|row| row.try_get("target_id"))
.collect::<Result<_, _>>()?;
Ok(links)
}
async fn create_post_emojis(
db_client: &impl DatabaseClient,
post_id: &Uuid,
emojis: Vec<Uuid>,
) -> Result<Vec<DbEmoji>, DatabaseError> {
let emojis_rows = db_client.query(
"
INSERT INTO post_emoji (post_id, emoji_id)
SELECT $1, emoji.id FROM emoji WHERE id = ANY($2)
RETURNING (
SELECT emoji FROM emoji
WHERE emoji.id = emoji_id
)
",
&[&post_id, &emojis],
).await?;
if emojis_rows.len() != emojis.len() {
return Err(DatabaseError::NotFound("emoji"));
};
let emojis = emojis_rows.iter()
.map(|row| row.try_get("emoji"))
.collect::<Result<_, _>>()?;
Ok(emojis)
}
pub async fn create_post(
db_client: &mut impl DatabaseClient,
author_id: &Uuid,
data: PostCreateData,
post_data: PostCreateData,
) -> Result<Post, DatabaseError> {
let transaction = db_client.transaction().await?;
let post_id = new_uuid();
@ -73,109 +201,46 @@ pub async fn create_post(
&[
&post_id,
&author_id,
&data.content,
&data.in_reply_to_id,
&data.repost_of_id,
&data.visibility,
&data.object_id,
&data.created_at,
&post_data.content,
&post_data.in_reply_to_id,
&post_data.repost_of_id,
&post_data.visibility,
&post_data.object_id,
&post_data.created_at,
],
).await.map_err(catch_unique_violation("post"))?;
// Return NotFound error if reply/repost is not allowed
let post_row = maybe_post_row.ok_or(DatabaseError::NotFound("post"))?;
let db_post: DbPost = post_row.try_get("post")?;
// Create links to attachments
let attachments_rows = transaction.query(
"
UPDATE media_attachment
SET post_id = $1
WHERE owner_id = $2 AND id = ANY($3)
RETURNING media_attachment
",
&[&post_id, &author_id, &data.attachments],
// Create related objects
let db_attachments = create_post_attachments(
&transaction,
&db_post.id,
&db_post.author_id,
post_data.attachments,
).await?;
if attachments_rows.len() != data.attachments.len() {
// Some attachments were not found
return Err(DatabaseError::NotFound("attachment"));
};
let db_attachments: Vec<DbMediaAttachment> = attachments_rows.iter()
.map(|row| row.try_get("media_attachment"))
.collect::<Result<_, _>>()?;
// Create mentions
let mentions_rows = transaction.query(
"
INSERT INTO mention (post_id, profile_id)
SELECT $1, actor_profile.id FROM actor_profile WHERE id = ANY($2)
RETURNING (
SELECT actor_profile FROM actor_profile
WHERE actor_profile.id = profile_id
) AS actor_profile
",
&[&db_post.id, &data.mentions],
let db_mentions = create_post_mentions(
&transaction,
&db_post.id,
post_data.mentions,
).await?;
if mentions_rows.len() != data.mentions.len() {
// Some profiles were not found
return Err(DatabaseError::NotFound("profile"));
};
let db_mentions: Vec<DbActorProfile> = mentions_rows.iter()
.map(|row| row.try_get("actor_profile"))
.collect::<Result<_, _>>()?;
// Create tags
transaction.execute(
"
INSERT INTO tag (tag_name)
SELECT unnest($1::text[])
ON CONFLICT (tag_name) DO NOTHING
",
&[&data.tags],
let db_tags = create_post_tags(
&transaction,
&db_post.id,
post_data.tags,
).await?;
let tags_rows = transaction.query(
"
INSERT INTO post_tag (post_id, tag_id)
SELECT $1, tag.id FROM tag WHERE tag_name = ANY($2)
RETURNING (SELECT tag_name FROM tag WHERE tag.id = tag_id)
",
&[&db_post.id, &data.tags],
let db_links = create_post_links(
&transaction,
&db_post.id,
post_data.links,
).await?;
if tags_rows.len() != data.tags.len() {
return Err(DatabaseError::NotFound("tag"));
};
let db_tags: Vec<String> = tags_rows.iter()
.map(|row| row.try_get("tag_name"))
.collect::<Result<_, _>>()?;
// Create links
let links_rows = transaction.query(
"
INSERT INTO post_link (source_id, target_id)
SELECT $1, post.id FROM post WHERE id = ANY($2)
RETURNING target_id
",
&[&db_post.id, &data.links],
let db_emojis = create_post_emojis(
&transaction,
&db_post.id,
post_data.emojis,
).await?;
if links_rows.len() != data.links.len() {
return Err(DatabaseError::NotFound("post"));
};
let db_links: Vec<Uuid> = links_rows.iter()
.map(|row| row.try_get("target_id"))
.collect::<Result<_, _>>()?;
// Create emojis
let emojis_rows = transaction.query(
"
INSERT INTO post_emoji (post_id, emoji_id)
SELECT $1, emoji.id FROM emoji WHERE id = ANY($2)
RETURNING (
SELECT emoji FROM emoji
WHERE emoji.id = emoji_id
)
",
&[&db_post.id, &data.emojis],
).await?;
if emojis_rows.len() != data.emojis.len() {
return Err(DatabaseError::NotFound("emoji"));
};
let db_emojis: Vec<DbEmoji> = emojis_rows.iter()
.map(|row| row.try_get("emoji"))
.collect::<Result<_, _>>()?;
// Update counters
let author = update_post_count(&transaction, &db_post.author_id, 1).await?;
let mut notified_users = vec![];
@ -226,8 +291,7 @@ pub async fn create_post(
).await?;
};
};
transaction.commit().await?;
// Construct post object
let post = Post::new(
db_post,
author,
@ -237,9 +301,38 @@ pub async fn create_post(
db_links,
db_emojis,
)?;
transaction.commit().await?;
Ok(post)
}
pub async fn update_post(
db_client: &impl DatabaseClient,
post_id: &Uuid,
post_data: PostUpdateData,
) -> Result<(), DatabaseError> {
// Reposts and immutable posts can't be updated
let updated_count = db_client.execute(
"
UPDATE post
SET
content = $1,
updated_at = $2
WHERE id = $3
AND repost_of_id IS NULL
AND ipfs_cid IS NULL
",
&[
&post_data.content,
&post_data.updated_at,
&post_id,
],
).await?;
if updated_count == 0 {
return Err(DatabaseError::NotFound("post"));
};
Ok(())
}
pub const RELATED_ATTACHMENTS: &str =
"ARRAY(
SELECT media_attachment
@ -765,34 +858,6 @@ pub async fn get_post_by_ipfs_cid(
Ok(post)
}
pub async fn update_post(
db_client: &impl DatabaseClient,
post_id: &Uuid,
post_data: PostUpdateData,
) -> Result<(), DatabaseError> {
// Reposts and immutable posts can't be updated
let updated_count = db_client.execute(
"
UPDATE post
SET
content = $1,
updated_at = $2
WHERE id = $3
AND repost_of_id IS NULL
AND ipfs_cid IS NULL
",
&[
&post_data.content,
&post_data.updated_at,
&post_id,
],
).await?;
if updated_count == 0 {
return Err(DatabaseError::NotFound("post"));
}
Ok(())
}
pub async fn update_reply_count(
db_client: &impl DatabaseClient,
post_id: &Uuid,