Use macros to create FromSql/ToSql implementations for int enums

This commit is contained in:
silverpill 2021-11-19 17:32:51 +00:00
parent cf5d4db031
commit 5547403200
9 changed files with 73 additions and 92 deletions

View file

@ -9,7 +9,7 @@ use crate::activitypub::views::get_instance_actor_url;
use crate::errors::ConversionError; use crate::errors::ConversionError;
use crate::utils::crypto::deserialize_private_key; use crate::utils::crypto::deserialize_private_key;
#[derive(Clone, Debug, PartialEq)] #[derive(Clone, Debug)]
pub enum Environment { pub enum Environment {
Development, Development,
Production, Production,
@ -126,7 +126,7 @@ impl Config {
Instance { Instance {
_url: self.try_instance_url().unwrap(), _url: self.try_instance_url().unwrap(),
actor_key: self.try_instance_rsa_key().unwrap(), actor_key: self.try_instance_rsa_key().unwrap(),
is_private: self.environment == Environment::Development, is_private: matches!(self.environment, Environment::Development),
} }
} }

36
src/database/int_enum.rs Normal file
View file

@ -0,0 +1,36 @@
macro_rules! int_enum_from_sql {
($t:ty) => {
impl<'a> postgres_types::FromSql<'a> for $t {
fn from_sql(
_: &postgres_types::Type,
raw: &'a [u8],
) -> Result<$t, Box<dyn std::error::Error + Sync + Send>> {
let int_value = postgres_protocol::types::int2_from_sql(raw)?;
let value = <$t>::try_from(int_value)?;
Ok(value)
}
postgres_types::accepts!(INT2);
}
}
}
macro_rules! int_enum_to_sql {
($t:ty) => {
impl postgres_types::ToSql for $t {
fn to_sql(
&self, _: &postgres_types::Type,
out: &mut postgres_types::private::BytesMut,
) -> Result<postgres_types::IsNull, Box<dyn std::error::Error + Sync + Send>> {
let int_value: i16 = self.into();
postgres_protocol::types::int2_to_sql(int_value, out);
Ok(postgres_types::IsNull::No)
}
postgres_types::accepts!(INT2);
postgres_types::to_sql_checked!();
}
}
}
pub(crate) use {int_enum_from_sql, int_enum_to_sql};

View file

@ -1,3 +1,4 @@
pub mod int_enum;
pub mod migrate; pub mod migrate;
pub type Pool = deadpool_postgres::Pool; pub type Pool = deadpool_postgres::Pool;

View file

@ -1,14 +1,10 @@
use std::convert::TryFrom; use std::convert::TryFrom;
use chrono::{DateTime, Utc}; use chrono::{DateTime, Utc};
use postgres_protocol::types::{int2_from_sql, int2_to_sql}; use postgres_types::FromSql;
use postgres_types::{
FromSql, ToSql, IsNull, Type,
accepts, to_sql_checked,
private::BytesMut,
};
use uuid::Uuid; use uuid::Uuid;
use crate::database::int_enum::{int_enum_from_sql, int_enum_to_sql};
use crate::errors::ConversionError; use crate::errors::ConversionError;
#[derive(Debug)] #[derive(Debug)]
@ -39,28 +35,8 @@ impl TryFrom<i16> for Timeline {
} }
} }
type SqlError = Box<dyn std::error::Error + Sync + Send>; int_enum_from_sql!(Timeline);
int_enum_to_sql!(Timeline);
impl<'a> FromSql<'a> for Timeline {
fn from_sql(_: &Type, raw: &'a [u8]) -> Result<Timeline, SqlError> {
let int_value = int2_from_sql(raw)?;
let timeline = Timeline::try_from(int_value)?;
Ok(timeline)
}
accepts!(INT2);
}
impl ToSql for Timeline {
fn to_sql(&self, _: &Type, out: &mut BytesMut) -> Result<IsNull, SqlError> {
let int_value: i16 = self.into();
int2_to_sql(int_value, out);
Ok(IsNull::No)
}
accepts!(INT2);
to_sql_checked!();
}
#[allow(dead_code)] #[allow(dead_code)]
#[derive(FromSql)] #[derive(FromSql)]

View file

@ -25,7 +25,7 @@ async fn create_notification(
) )
VALUES ($1, $2, $3, $4) VALUES ($1, $2, $3, $4)
", ",
&[&sender_id, &recipient_id, &post_id, &i16::from(event_type)], &[&sender_id, &recipient_id, &post_id, &event_type],
).await?; ).await?;
Ok(()) Ok(())
} }

View file

@ -5,6 +5,7 @@ use postgres_types::FromSql;
use tokio_postgres::Row; use tokio_postgres::Row;
use uuid::Uuid; use uuid::Uuid;
use crate::database::int_enum::{int_enum_from_sql, int_enum_to_sql};
use crate::errors::{ConversionError, DatabaseError}; use crate::errors::{ConversionError, DatabaseError};
use crate::models::attachments::types::DbMediaAttachment; use crate::models::attachments::types::DbMediaAttachment;
use crate::models::posts::types::{DbPost, Post}; use crate::models::posts::types::{DbPost, Post};
@ -18,10 +19,11 @@ struct DbNotification {
sender_id: Uuid, sender_id: Uuid,
recipient_id: Uuid, recipient_id: Uuid,
post_id: Option<Uuid>, post_id: Option<Uuid>,
event_type: i16, event_type: EventType,
created_at: DateTime<Utc>, created_at: DateTime<Utc>,
} }
#[derive(Debug)]
pub enum EventType { pub enum EventType {
Follow, Follow,
FollowRequest, FollowRequest,
@ -29,8 +31,8 @@ pub enum EventType {
Reaction, Reaction,
} }
impl From<EventType> for i16 { impl From<&EventType> for i16 {
fn from(value: EventType) -> i16 { fn from(value: &EventType) -> i16 {
match value { match value {
EventType::Follow => 1, EventType::Follow => 1,
EventType::FollowRequest => 2, EventType::FollowRequest => 2,
@ -55,6 +57,9 @@ impl TryFrom<i16> for EventType {
} }
} }
int_enum_from_sql!(EventType);
int_enum_to_sql!(EventType);
pub struct Notification { pub struct Notification {
pub id: i32, pub id: i32,
pub sender: DbActorProfile, pub sender: DbActorProfile,
@ -84,7 +89,7 @@ impl TryFrom<&Row> for Notification {
id: db_notification.id, id: db_notification.id,
sender: db_sender, sender: db_sender,
post: maybe_post, post: maybe_post,
event_type: EventType::try_from(db_notification.event_type)?, event_type: db_notification.event_type,
created_at: db_notification.created_at, created_at: db_notification.created_at,
}; };
Ok(notification) Ok(notification)

View file

@ -1,15 +1,11 @@
use std::convert::TryFrom; use std::convert::TryFrom;
use chrono::{DateTime, Utc}; use chrono::{DateTime, Utc};
use postgres_protocol::types::{int2_from_sql, int2_to_sql}; use postgres_types::FromSql;
use postgres_types::{
FromSql, ToSql, IsNull, Type,
accepts, to_sql_checked,
private::BytesMut,
};
use tokio_postgres::Row; use tokio_postgres::Row;
use uuid::Uuid; use uuid::Uuid;
use crate::database::int_enum::{int_enum_from_sql, int_enum_to_sql};
use crate::errors::{ConversionError, ValidationError}; use crate::errors::{ConversionError, ValidationError};
use crate::models::attachments::types::DbMediaAttachment; use crate::models::attachments::types::DbMediaAttachment;
use crate::models::profiles::types::DbActorProfile; use crate::models::profiles::types::DbActorProfile;
@ -43,28 +39,8 @@ impl TryFrom<i16> for Visibility {
} }
} }
type SqlError = Box<dyn std::error::Error + Sync + Send>; int_enum_from_sql!(Visibility);
int_enum_to_sql!(Visibility);
impl<'a> FromSql<'a> for Visibility {
fn from_sql(_: &Type, raw: &'a [u8]) -> Result<Visibility, SqlError> {
let int_value = int2_from_sql(raw)?;
let visibility = Visibility::try_from(int_value)?;
Ok(visibility)
}
accepts!(INT2);
}
impl ToSql for Visibility {
fn to_sql(&self, _: &Type, out: &mut BytesMut) -> Result<IsNull, SqlError> {
let int_value: i16 = self.into();
int2_to_sql(int_value, out);
Ok(IsNull::No)
}
accepts!(INT2);
to_sql_checked!();
}
#[derive(FromSql)] #[derive(FromSql)]
#[postgres(name = "post")] #[postgres(name = "post")]

View file

@ -12,7 +12,6 @@ use crate::models::profiles::queries::{
}; };
use super::types::{ use super::types::{
DbFollowRequest, DbFollowRequest,
FollowRequest,
FollowRequestStatus, FollowRequestStatus,
Relationship, Relationship,
}; };
@ -142,12 +141,12 @@ pub async fn create_follow_request(
db_client: &impl GenericClient, db_client: &impl GenericClient,
source_id: &Uuid, source_id: &Uuid,
target_id: &Uuid, target_id: &Uuid,
) -> Result<FollowRequest, DatabaseError> { ) -> Result<DbFollowRequest, DatabaseError> {
let request = FollowRequest { let request = DbFollowRequest {
id: Uuid::new_v4(), id: Uuid::new_v4(),
source_id: source_id.to_owned(), source_id: source_id.to_owned(),
target_id: target_id.to_owned(), target_id: target_id.to_owned(),
status: FollowRequestStatus::Pending, request_status: FollowRequestStatus::Pending,
}; };
db_client.execute( db_client.execute(
" "
@ -160,7 +159,7 @@ pub async fn create_follow_request(
&request.id, &request.id,
&request.source_id, &request.source_id,
&request.target_id, &request.target_id,
&i16::from(request.status.clone()), &request.request_status,
], ],
).await?; ).await?;
Ok(request) Ok(request)
@ -171,7 +170,6 @@ pub async fn follow_request_accepted(
request_id: &Uuid, request_id: &Uuid,
) -> Result<(), DatabaseError> { ) -> Result<(), DatabaseError> {
let mut transaction = db_client.transaction().await?; let mut transaction = db_client.transaction().await?;
let status_sql = i16::from(FollowRequestStatus::Accepted);
let maybe_row = transaction.query_opt( let maybe_row = transaction.query_opt(
" "
UPDATE follow_request UPDATE follow_request
@ -179,7 +177,7 @@ pub async fn follow_request_accepted(
WHERE id = $2 WHERE id = $2
RETURNING source_id, target_id RETURNING source_id, target_id
", ",
&[&status_sql, &request_id], &[&FollowRequestStatus::Accepted, &request_id],
).await?; ).await?;
let row = maybe_row.ok_or(DatabaseError::NotFound("follow request"))?; let row = maybe_row.ok_or(DatabaseError::NotFound("follow request"))?;
let source_id: Uuid = row.try_get("source_id")?; let source_id: Uuid = row.try_get("source_id")?;
@ -193,14 +191,13 @@ pub async fn follow_request_rejected(
db_client: &impl GenericClient, db_client: &impl GenericClient,
request_id: &Uuid, request_id: &Uuid,
) -> Result<(), DatabaseError> { ) -> Result<(), DatabaseError> {
let status_sql: i16 = FollowRequestStatus::Rejected.into();
let updated_count = db_client.execute( let updated_count = db_client.execute(
" "
UPDATE follow_request UPDATE follow_request
SET request_status = $1 SET request_status = $1
WHERE id = $2 WHERE id = $2
", ",
&[&status_sql, &request_id], &[&FollowRequestStatus::Rejected, &request_id],
).await?; ).await?;
if updated_count == 0 { if updated_count == 0 {
return Err(DatabaseError::NotFound("follow request")); return Err(DatabaseError::NotFound("follow request"));
@ -228,7 +225,7 @@ pub async fn get_follow_request_by_path(
db_client: &impl GenericClient, db_client: &impl GenericClient,
source_id: &Uuid, source_id: &Uuid,
target_id: &Uuid, target_id: &Uuid,
) -> Result<FollowRequest, DatabaseError> { ) -> Result<DbFollowRequest, DatabaseError> {
let maybe_row = db_client.query_opt( let maybe_row = db_client.query_opt(
" "
SELECT follow_request SELECT follow_request
@ -238,13 +235,6 @@ pub async fn get_follow_request_by_path(
&[&source_id, &target_id], &[&source_id, &target_id],
).await?; ).await?;
let row = maybe_row.ok_or(DatabaseError::NotFound("follow request"))?; let row = maybe_row.ok_or(DatabaseError::NotFound("follow request"))?;
let db_request: DbFollowRequest = row.try_get("follow_request")?; let request: DbFollowRequest = row.try_get("follow_request")?;
let request_status = FollowRequestStatus::try_from(db_request.request_status)?;
let request = FollowRequest {
id: db_request.id,
source_id: db_request.source_id,
target_id: db_request.target_id,
status: request_status,
};
Ok(request) Ok(request)
} }

View file

@ -5,6 +5,7 @@ use serde::Serialize;
use tokio_postgres::Row; use tokio_postgres::Row;
use uuid::Uuid; use uuid::Uuid;
use crate::database::int_enum::{int_enum_from_sql, int_enum_to_sql};
use crate::errors::ConversionError; use crate::errors::ConversionError;
#[derive(Serialize)] #[derive(Serialize)]
@ -30,15 +31,15 @@ impl TryFrom<&Row> for Relationship {
} }
} }
#[derive(Clone, PartialEq)] #[derive(Debug)]
pub enum FollowRequestStatus { pub enum FollowRequestStatus {
Pending, Pending,
Accepted, Accepted,
Rejected, Rejected,
} }
impl From<FollowRequestStatus> for i16 { impl From<&FollowRequestStatus> for i16 {
fn from(value: FollowRequestStatus) -> i16 { fn from(value: &FollowRequestStatus) -> i16 {
match value { match value {
FollowRequestStatus::Pending => 1, FollowRequestStatus::Pending => 1,
FollowRequestStatus::Accepted => 2, FollowRequestStatus::Accepted => 2,
@ -61,18 +62,14 @@ impl TryFrom<i16> for FollowRequestStatus {
} }
} }
int_enum_from_sql!(FollowRequestStatus);
int_enum_to_sql!(FollowRequestStatus);
#[derive(FromSql)] #[derive(FromSql)]
#[postgres(name = "follow_request")] #[postgres(name = "follow_request")]
pub struct DbFollowRequest { pub struct DbFollowRequest {
pub id: Uuid, pub id: Uuid,
pub source_id: Uuid, pub source_id: Uuid,
pub target_id: Uuid, pub target_id: Uuid,
pub request_status: i16, pub request_status: FollowRequestStatus,
}
pub struct FollowRequest {
pub id: Uuid,
pub source_id: Uuid,
pub target_id: Uuid,
pub status: FollowRequestStatus,
} }