Add macro for building SQL queries

This commit is contained in:
silverpill 2022-01-07 16:33:23 +00:00
parent 5bdee5585a
commit 3ff4c79f0d
6 changed files with 97 additions and 38 deletions

30
Cargo.lock generated
View file

@ -1708,6 +1708,8 @@ dependencies = [
"num_cpus", "num_cpus",
"postgres-protocol", "postgres-protocol",
"postgres-types", "postgres-types",
"postgres_query",
"postgres_query_macro",
"rand 0.8.3", "rand 0.8.3",
"refinery", "refinery",
"regex", "regex",
@ -2124,6 +2126,34 @@ dependencies = [
"uuid", "uuid",
] ]
[[package]]
name = "postgres_query"
version = "0.3.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bfc8ccaefa786c5b8b07f907a3e920bfcbe28612d845b86f1e86715036b3939d"
dependencies = [
"async-trait",
"futures",
"postgres-types",
"postgres_query_macro",
"proc-macro-hack",
"serde",
"thiserror",
"tokio-postgres",
]
[[package]]
name = "postgres_query_macro"
version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6f0dde0ed0b582c2b7a72c3c5a28fbc03d2fc75c0622f150c8628e9d9cbb5bad"
dependencies = [
"proc-macro-hack",
"proc-macro2",
"quote",
"syn",
]
[[package]] [[package]]
name = "ppv-lite86" name = "ppv-lite86"
version = "0.2.10" version = "0.2.10"

View file

@ -71,6 +71,9 @@ tokio = { version = "0.2.25", features = ["macros"] }
tokio-postgres = { version = "0.5.5", features = ["with-chrono-0_4", "with-uuid-0_8", "with-serde_json-1"] } tokio-postgres = { version = "0.5.5", features = ["with-chrono-0_4", "with-uuid-0_8", "with-serde_json-1"] }
postgres-types = { version = "0.1.2", features = ["derive", "with-chrono-0_4", "with-uuid-0_8", "with-serde_json-1"] } postgres-types = { version = "0.1.2", features = ["derive", "with-chrono-0_4", "with-uuid-0_8", "with-serde_json-1"] }
postgres-protocol = "0.5.3" postgres-protocol = "0.5.3"
# Used to construct PostgreSQL queries
postgres_query = "0.3.3"
postgres_query_macro = "0.3.1"
# Used to work with URLs # Used to work with URLs
url = "2.2.2" url = "2.2.2"
# Used to generate lexicographically sortable IDs # Used to generate lexicographically sortable IDs

View file

@ -2,6 +2,7 @@ use tokio_postgres::config::{Config as DbConfig};
pub mod int_enum; pub mod int_enum;
pub mod migrate; pub mod migrate;
pub mod query_macro;
#[cfg(test)] #[cfg(test)]
pub mod test_utils; pub mod test_utils;

View file

@ -0,0 +1,7 @@
macro_rules! query {
($($tt:tt)*) => {
postgres_query_macro::proc_macro_hack_query_dynamic!($($tt)*)
};
}
pub(crate) use query;

View file

@ -19,6 +19,9 @@ pub enum DatabaseError {
#[error("database pool error")] #[error("database pool error")]
DatabasePoolError(#[from] deadpool_postgres::PoolError), DatabasePoolError(#[from] deadpool_postgres::PoolError),
#[error("database query error")]
DatabaseQueryError(#[from] postgres_query::Error),
#[error("database client error")] #[error("database client error")]
DatabaseClientError(#[from] tokio_postgres::Error), DatabaseClientError(#[from] tokio_postgres::Error),

View file

@ -5,6 +5,7 @@ use tokio_postgres::GenericClient;
use uuid::Uuid; use uuid::Uuid;
use crate::database::catch_unique_violation; use crate::database::catch_unique_violation;
use crate::database::query_macro::query;
use crate::errors::DatabaseError; use crate::errors::DatabaseError;
use crate::models::attachments::types::DbMediaAttachment; use crate::models::attachments::types::DbMediaAttachment;
use crate::models::cleanup::{ use crate::models::cleanup::{
@ -223,33 +224,36 @@ pub async fn get_home_timeline(
JOIN actor_profile ON post.author_id = actor_profile.id JOIN actor_profile ON post.author_id = actor_profile.id
WHERE WHERE
( (
post.author_id = $1 post.author_id = $current_user_id
OR EXISTS ( OR EXISTS (
SELECT 1 FROM relationship SELECT 1 FROM relationship
WHERE source_id = $1 AND target_id = post.author_id WHERE source_id = $current_user_id AND target_id = post.author_id
) )
) )
AND ( AND (
post.visibility = {visibility_public} post.visibility = {visibility_public}
OR post.author_id = $1 OR post.author_id = $current_user_id
OR EXISTS ( OR EXISTS (
SELECT 1 FROM mention SELECT 1 FROM mention
WHERE post_id = post.id AND profile_id = $1 WHERE post_id = post.id AND profile_id = $current_user_id
) )
) )
AND ($2::uuid IS NULL OR post.id < $2) AND ($max_post_id::uuid IS NULL OR post.id < $max_post_id)
ORDER BY post.id DESC ORDER BY post.id DESC
LIMIT $3 LIMIT $limit
", ",
related_attachments=RELATED_ATTACHMENTS, related_attachments=RELATED_ATTACHMENTS,
related_mentions=RELATED_MENTIONS, related_mentions=RELATED_MENTIONS,
related_tags=RELATED_TAGS, related_tags=RELATED_TAGS,
visibility_public=i16::from(&Visibility::Public), visibility_public=i16::from(&Visibility::Public),
); );
let rows = db_client.query( let query = query!(
statement.as_str(), &statement,
&[&current_user_id, &max_post_id, &limit], current_user_id=current_user_id,
).await?; max_post_id=max_post_id,
limit=limit,
)?;
let rows = db_client.query(query.sql(), query.parameters()).await?;
let posts: Vec<Post> = rows.iter() let posts: Vec<Post> = rows.iter()
.map(Post::try_from) .map(Post::try_from)
.collect::<Result<_, _>>()?; .collect::<Result<_, _>>()?;
@ -293,19 +297,19 @@ pub async fn get_posts_by_author(
max_post_id: Option<Uuid>, max_post_id: Option<Uuid>,
limit: i64, limit: i64,
) -> Result<Vec<Post>, DatabaseError> { ) -> Result<Vec<Post>, DatabaseError> {
let mut condition = "post.author_id = $1 let mut condition = "post.author_id = $profile_id
AND ($2::uuid IS NULL OR post.id < $2)".to_string(); AND ($max_post_id::uuid IS NULL OR post.id < $max_post_id)".to_string();
if !include_replies { if !include_replies {
condition.push_str(" AND post.in_reply_to_id IS NULL"); condition.push_str(" AND post.in_reply_to_id IS NULL");
}; };
let visibility_filter = format!( let visibility_filter = format!(
" AND ( " AND (
post.visibility = {visibility_public} post.visibility = {visibility_public}
OR $4::uuid IS NULL OR $current_user_id::uuid IS NULL
OR post.author_id = $4 OR post.author_id = $current_user_id
OR EXISTS ( OR EXISTS (
SELECT 1 FROM mention SELECT 1 FROM mention
WHERE post_id = post.id AND profile_id = $4 WHERE post_id = post.id AND profile_id = $current_user_id
) )
)", )",
visibility_public=i16::from(&Visibility::Public), visibility_public=i16::from(&Visibility::Public),
@ -322,17 +326,21 @@ pub async fn get_posts_by_author(
JOIN actor_profile ON post.author_id = actor_profile.id JOIN actor_profile ON post.author_id = actor_profile.id
WHERE {condition} WHERE {condition}
ORDER BY post.created_at DESC ORDER BY post.created_at DESC
LIMIT $3 LIMIT $limit
", ",
related_attachments=RELATED_ATTACHMENTS, related_attachments=RELATED_ATTACHMENTS,
related_mentions=RELATED_MENTIONS, related_mentions=RELATED_MENTIONS,
related_tags=RELATED_TAGS, related_tags=RELATED_TAGS,
condition=condition, condition=condition,
); );
let rows = db_client.query( let query = query!(
statement.as_str(), &statement,
&[&profile_id, &max_post_id, &limit, &current_user_id], profile_id=profile_id,
).await?; current_user_id=current_user_id,
max_post_id=max_post_id,
limit=limit,
)?;
let rows = db_client.query(query.sql(), query.parameters()).await?;
let posts: Vec<Post> = rows.iter() let posts: Vec<Post> = rows.iter()
.map(Post::try_from) .map(Post::try_from)
.collect::<Result<_, _>>()?; .collect::<Result<_, _>>()?;
@ -346,6 +354,7 @@ pub async fn get_posts_by_tag(
max_post_id: Option<Uuid>, max_post_id: Option<Uuid>,
limit: i64, limit: i64,
) -> Result<Vec<Post>, DatabaseError> { ) -> Result<Vec<Post>, DatabaseError> {
let tag_name = tag_name.to_lowercase();
let statement = format!( let statement = format!(
" "
SELECT SELECT
@ -358,30 +367,34 @@ pub async fn get_posts_by_tag(
WHERE WHERE
( (
post.visibility = {visibility_public} post.visibility = {visibility_public}
OR $4::uuid IS NULL OR $current_user_id::uuid IS NULL
OR post.author_id = $4 OR post.author_id = $current_user_id
OR EXISTS ( OR EXISTS (
SELECT 1 FROM mention SELECT 1 FROM mention
WHERE post_id = post.id AND profile_id = $4 WHERE post_id = post.id AND profile_id = $current_user_id
) )
) )
AND EXISTS ( AND EXISTS (
SELECT 1 FROM post_tag JOIN tag ON post_tag.tag_id = tag.id SELECT 1 FROM post_tag JOIN tag ON post_tag.tag_id = tag.id
WHERE post_tag.post_id = post.id AND tag.tag_name = $1 WHERE post_tag.post_id = post.id AND tag.tag_name = $tag_name
) )
AND ($2::uuid IS NULL OR post.id < $2) AND ($max_post_id::uuid IS NULL OR post.id < $max_post_id)
ORDER BY post.id DESC ORDER BY post.id DESC
LIMIT $3 LIMIT $limit
", ",
related_attachments=RELATED_ATTACHMENTS, related_attachments=RELATED_ATTACHMENTS,
related_mentions=RELATED_MENTIONS, related_mentions=RELATED_MENTIONS,
related_tags=RELATED_TAGS, related_tags=RELATED_TAGS,
visibility_public=i16::from(&Visibility::Public), visibility_public=i16::from(&Visibility::Public),
); );
let rows = db_client.query( let query = query!(
statement.as_str(), &statement,
&[&tag_name.to_lowercase(), &max_post_id, &limit, &current_user_id], tag_name=tag_name,
).await?; current_user_id=current_user_id,
max_post_id=max_post_id,
limit=limit,
)?;
let rows = db_client.query(query.sql(), query.parameters()).await?;
let posts: Vec<Post> = rows.iter() let posts: Vec<Post> = rows.iter()
.map(Post::try_from) .map(Post::try_from)
.collect::<Result<_, _>>()?; .collect::<Result<_, _>>()?;
@ -428,11 +441,11 @@ pub async fn get_thread(
let condition = format!( let condition = format!(
" "
post.visibility = {visibility_public} post.visibility = {visibility_public}
OR $2::uuid IS NULL OR $current_user_id::uuid IS NULL
OR post.author_id = $2 OR post.author_id = $current_user_id
OR EXISTS ( OR EXISTS (
SELECT 1 FROM mention SELECT 1 FROM mention
WHERE post_id = post.id AND profile_id = $2 WHERE post_id = post.id AND profile_id = $current_user_id
) )
", ",
visibility_public=i16::from(&Visibility::Public), visibility_public=i16::from(&Visibility::Public),
@ -443,7 +456,7 @@ pub async fn get_thread(
WITH RECURSIVE WITH RECURSIVE
ancestors (id, in_reply_to_id) AS ( ancestors (id, in_reply_to_id) AS (
SELECT post.id, post.in_reply_to_id FROM post SELECT post.id, post.in_reply_to_id FROM post
WHERE post.id = $1 AND ({condition}) WHERE post.id = $post_id AND ({condition})
UNION ALL UNION ALL
SELECT post.id, post.in_reply_to_id FROM post SELECT post.id, post.in_reply_to_id FROM post
JOIN ancestors ON post.id = ancestors.in_reply_to_id JOIN ancestors ON post.id = ancestors.in_reply_to_id
@ -471,10 +484,12 @@ pub async fn get_thread(
related_tags=RELATED_TAGS, related_tags=RELATED_TAGS,
condition=condition, condition=condition,
); );
let rows = db_client.query( let query = query!(
statement.as_str(), &statement,
&[&post_id, &current_user_id], post_id=post_id,
).await?; current_user_id=current_user_id,
)?;
let rows = db_client.query(query.sql(), query.parameters()).await?;
let posts: Vec<Post> = rows.iter() let posts: Vec<Post> = rows.iter()
.map(Post::try_from) .map(Post::try_from)
.collect::<Result<_, _>>()?; .collect::<Result<_, _>>()?;