diff --git a/Cargo.lock b/Cargo.lock index cbcf0c4..baab7c5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1708,6 +1708,8 @@ dependencies = [ "num_cpus", "postgres-protocol", "postgres-types", + "postgres_query", + "postgres_query_macro", "rand 0.8.3", "refinery", "regex", @@ -2124,6 +2126,34 @@ dependencies = [ "uuid", ] +[[package]] +name = "postgres_query" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bfc8ccaefa786c5b8b07f907a3e920bfcbe28612d845b86f1e86715036b3939d" +dependencies = [ + "async-trait", + "futures", + "postgres-types", + "postgres_query_macro", + "proc-macro-hack", + "serde", + "thiserror", + "tokio-postgres", +] + +[[package]] +name = "postgres_query_macro" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6f0dde0ed0b582c2b7a72c3c5a28fbc03d2fc75c0622f150c8628e9d9cbb5bad" +dependencies = [ + "proc-macro-hack", + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "ppv-lite86" version = "0.2.10" diff --git a/Cargo.toml b/Cargo.toml index 6307052..da41b9b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -71,6 +71,9 @@ tokio = { version = "0.2.25", features = ["macros"] } tokio-postgres = { version = "0.5.5", features = ["with-chrono-0_4", "with-uuid-0_8", "with-serde_json-1"] } postgres-types = { version = "0.1.2", features = ["derive", "with-chrono-0_4", "with-uuid-0_8", "with-serde_json-1"] } postgres-protocol = "0.5.3" +# Used to construct PostgreSQL queries +postgres_query = "0.3.3" +postgres_query_macro = "0.3.1" # Used to work with URLs url = "2.2.2" # Used to generate lexicographically sortable IDs diff --git a/src/database/mod.rs b/src/database/mod.rs index ec7ec08..8cd089e 100644 --- a/src/database/mod.rs +++ b/src/database/mod.rs @@ -2,6 +2,7 @@ use tokio_postgres::config::{Config as DbConfig}; pub mod int_enum; pub mod migrate; +pub mod query_macro; #[cfg(test)] pub mod test_utils; diff --git a/src/database/query_macro.rs b/src/database/query_macro.rs new file mode 100644 index 0000000..c980b46 --- /dev/null +++ b/src/database/query_macro.rs @@ -0,0 +1,7 @@ +macro_rules! query { + ($($tt:tt)*) => { + postgres_query_macro::proc_macro_hack_query_dynamic!($($tt)*) + }; +} + +pub(crate) use query; diff --git a/src/errors.rs b/src/errors.rs index 89d56fd..8f547d3 100644 --- a/src/errors.rs +++ b/src/errors.rs @@ -19,6 +19,9 @@ pub enum DatabaseError { #[error("database pool error")] DatabasePoolError(#[from] deadpool_postgres::PoolError), + #[error("database query error")] + DatabaseQueryError(#[from] postgres_query::Error), + #[error("database client error")] DatabaseClientError(#[from] tokio_postgres::Error), diff --git a/src/models/posts/queries.rs b/src/models/posts/queries.rs index 47ccd97..40a2f64 100644 --- a/src/models/posts/queries.rs +++ b/src/models/posts/queries.rs @@ -5,6 +5,7 @@ use tokio_postgres::GenericClient; use uuid::Uuid; use crate::database::catch_unique_violation; +use crate::database::query_macro::query; use crate::errors::DatabaseError; use crate::models::attachments::types::DbMediaAttachment; use crate::models::cleanup::{ @@ -223,33 +224,36 @@ pub async fn get_home_timeline( JOIN actor_profile ON post.author_id = actor_profile.id WHERE ( - post.author_id = $1 + post.author_id = $current_user_id OR EXISTS ( SELECT 1 FROM relationship - WHERE source_id = $1 AND target_id = post.author_id + WHERE source_id = $current_user_id AND target_id = post.author_id ) ) AND ( post.visibility = {visibility_public} - OR post.author_id = $1 + OR post.author_id = $current_user_id OR EXISTS ( SELECT 1 FROM mention - WHERE post_id = post.id AND profile_id = $1 + WHERE post_id = post.id AND profile_id = $current_user_id ) ) - AND ($2::uuid IS NULL OR post.id < $2) + AND ($max_post_id::uuid IS NULL OR post.id < $max_post_id) ORDER BY post.id DESC - LIMIT $3 + LIMIT $limit ", related_attachments=RELATED_ATTACHMENTS, related_mentions=RELATED_MENTIONS, related_tags=RELATED_TAGS, visibility_public=i16::from(&Visibility::Public), ); - let rows = db_client.query( - statement.as_str(), - &[¤t_user_id, &max_post_id, &limit], - ).await?; + let query = query!( + &statement, + current_user_id=current_user_id, + max_post_id=max_post_id, + limit=limit, + )?; + let rows = db_client.query(query.sql(), query.parameters()).await?; let posts: Vec = rows.iter() .map(Post::try_from) .collect::>()?; @@ -293,19 +297,19 @@ pub async fn get_posts_by_author( max_post_id: Option, limit: i64, ) -> Result, DatabaseError> { - let mut condition = "post.author_id = $1 - AND ($2::uuid IS NULL OR post.id < $2)".to_string(); + let mut condition = "post.author_id = $profile_id + AND ($max_post_id::uuid IS NULL OR post.id < $max_post_id)".to_string(); if !include_replies { condition.push_str(" AND post.in_reply_to_id IS NULL"); }; let visibility_filter = format!( " AND ( post.visibility = {visibility_public} - OR $4::uuid IS NULL - OR post.author_id = $4 + OR $current_user_id::uuid IS NULL + OR post.author_id = $current_user_id OR EXISTS ( SELECT 1 FROM mention - WHERE post_id = post.id AND profile_id = $4 + WHERE post_id = post.id AND profile_id = $current_user_id ) )", visibility_public=i16::from(&Visibility::Public), @@ -322,17 +326,21 @@ pub async fn get_posts_by_author( JOIN actor_profile ON post.author_id = actor_profile.id WHERE {condition} ORDER BY post.created_at DESC - LIMIT $3 + LIMIT $limit ", related_attachments=RELATED_ATTACHMENTS, related_mentions=RELATED_MENTIONS, related_tags=RELATED_TAGS, condition=condition, ); - let rows = db_client.query( - statement.as_str(), - &[&profile_id, &max_post_id, &limit, ¤t_user_id], - ).await?; + let query = query!( + &statement, + profile_id=profile_id, + current_user_id=current_user_id, + max_post_id=max_post_id, + limit=limit, + )?; + let rows = db_client.query(query.sql(), query.parameters()).await?; let posts: Vec = rows.iter() .map(Post::try_from) .collect::>()?; @@ -346,6 +354,7 @@ pub async fn get_posts_by_tag( max_post_id: Option, limit: i64, ) -> Result, DatabaseError> { + let tag_name = tag_name.to_lowercase(); let statement = format!( " SELECT @@ -358,30 +367,34 @@ pub async fn get_posts_by_tag( WHERE ( post.visibility = {visibility_public} - OR $4::uuid IS NULL - OR post.author_id = $4 + OR $current_user_id::uuid IS NULL + OR post.author_id = $current_user_id OR EXISTS ( SELECT 1 FROM mention - WHERE post_id = post.id AND profile_id = $4 + WHERE post_id = post.id AND profile_id = $current_user_id ) ) AND EXISTS ( SELECT 1 FROM post_tag JOIN tag ON post_tag.tag_id = tag.id - WHERE post_tag.post_id = post.id AND tag.tag_name = $1 + WHERE post_tag.post_id = post.id AND tag.tag_name = $tag_name ) - AND ($2::uuid IS NULL OR post.id < $2) + AND ($max_post_id::uuid IS NULL OR post.id < $max_post_id) ORDER BY post.id DESC - LIMIT $3 + LIMIT $limit ", related_attachments=RELATED_ATTACHMENTS, related_mentions=RELATED_MENTIONS, related_tags=RELATED_TAGS, visibility_public=i16::from(&Visibility::Public), ); - let rows = db_client.query( - statement.as_str(), - &[&tag_name.to_lowercase(), &max_post_id, &limit, ¤t_user_id], - ).await?; + let query = query!( + &statement, + tag_name=tag_name, + current_user_id=current_user_id, + max_post_id=max_post_id, + limit=limit, + )?; + let rows = db_client.query(query.sql(), query.parameters()).await?; let posts: Vec = rows.iter() .map(Post::try_from) .collect::>()?; @@ -428,11 +441,11 @@ pub async fn get_thread( let condition = format!( " post.visibility = {visibility_public} - OR $2::uuid IS NULL - OR post.author_id = $2 + OR $current_user_id::uuid IS NULL + OR post.author_id = $current_user_id OR EXISTS ( SELECT 1 FROM mention - WHERE post_id = post.id AND profile_id = $2 + WHERE post_id = post.id AND profile_id = $current_user_id ) ", visibility_public=i16::from(&Visibility::Public), @@ -443,7 +456,7 @@ pub async fn get_thread( WITH RECURSIVE ancestors (id, in_reply_to_id) AS ( SELECT post.id, post.in_reply_to_id FROM post - WHERE post.id = $1 AND ({condition}) + WHERE post.id = $post_id AND ({condition}) UNION ALL SELECT post.id, post.in_reply_to_id FROM post JOIN ancestors ON post.id = ancestors.in_reply_to_id @@ -471,10 +484,12 @@ pub async fn get_thread( related_tags=RELATED_TAGS, condition=condition, ); - let rows = db_client.query( - statement.as_str(), - &[&post_id, ¤t_user_id], - ).await?; + let query = query!( + &statement, + post_id=post_id, + current_user_id=current_user_id, + )?; + let rows = db_client.query(query.sql(), query.parameters()).await?; let posts: Vec = rows.iter() .map(Post::try_from) .collect::>()?;