Add diesel feature, add ObjectId::dereference_forced (#88)

* Add diesel feature

This can simplify Lemmy code and avoid converting back and forth
to DbUrl type all the time.

* Also add diesel derives for CollectionId

* Add ObjectId::dereference_forced

* no deprecated code

* fmt
This commit is contained in:
Nutomic 2023-12-11 15:04:18 +01:00 committed by GitHub
parent 1f7de85a53
commit 24830070f6
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 206 additions and 8 deletions

View file

@ -12,6 +12,7 @@ documentation = "https://docs.rs/activitypub_federation/"
default = ["actix-web", "axum"] default = ["actix-web", "axum"]
actix-web = ["dep:actix-web"] actix-web = ["dep:actix-web"]
axum = ["dep:axum", "dep:tower", "dep:hyper", "dep:http-body-util"] axum = ["dep:axum", "dep:tower", "dep:hyper", "dep:http-body-util"]
diesel = ["dep:diesel"]
[dependencies] [dependencies]
chrono = { version = "0.4.31", features = ["clock"], default-features = false } chrono = { version = "0.4.31", features = ["clock"], default-features = false }
@ -50,6 +51,7 @@ tokio = { version = "1.34.0", features = [
"rt-multi-thread", "rt-multi-thread",
"time", "time",
] } ] }
diesel = { version = "2.1.4", features = ["postgres"], default-features = false, optional = true }
futures = "0.3.29" futures = "0.3.29"
moka = { version = "0.12.1", features = ["future"] } moka = { version = "0.12.1", features = ["future"] }

View file

@ -102,3 +102,92 @@ where
self.0.eq(&other.0) && self.1 == other.1 self.0.eq(&other.0) && self.1 == other.1
} }
} }
#[cfg(feature = "diesel")]
const _IMPL_DIESEL_NEW_TYPE_FOR_COLLECTION_ID: () = {
use diesel::{
backend::Backend,
deserialize::{FromSql, FromStaticSqlRow},
expression::AsExpression,
internal::derives::as_expression::Bound,
pg::Pg,
query_builder::QueryId,
serialize,
serialize::{Output, ToSql},
sql_types::{HasSqlType, SingleValue, Text},
Expression,
Queryable,
};
// TODO: this impl only works for Postgres db because of to_string() call which requires reborrow
impl<Kind, ST> ToSql<ST, Pg> for CollectionId<Kind>
where
Kind: Collection,
for<'de2> <Kind as Collection>::Kind: Deserialize<'de2>,
String: ToSql<ST, Pg>,
{
fn to_sql<'b>(&'b self, out: &mut Output<'b, '_, Pg>) -> serialize::Result {
let v = self.0.to_string();
<String as ToSql<Text, Pg>>::to_sql(&v, &mut out.reborrow())
}
}
impl<'expr, Kind, ST> AsExpression<ST> for &'expr CollectionId<Kind>
where
Kind: Collection,
for<'de2> <Kind as Collection>::Kind: Deserialize<'de2>,
Bound<ST, String>: Expression<SqlType = ST>,
ST: SingleValue,
{
type Expression = Bound<ST, &'expr str>;
fn as_expression(self) -> Self::Expression {
Bound::new(self.0.as_str())
}
}
impl<Kind, ST> AsExpression<ST> for CollectionId<Kind>
where
Kind: Collection,
for<'de2> <Kind as Collection>::Kind: Deserialize<'de2>,
Bound<ST, String>: Expression<SqlType = ST>,
ST: SingleValue,
{
type Expression = Bound<ST, String>;
fn as_expression(self) -> Self::Expression {
Bound::new(self.0.to_string())
}
}
impl<Kind, ST, DB> FromSql<ST, DB> for CollectionId<Kind>
where
Kind: Collection + Send + 'static,
for<'de2> <Kind as Collection>::Kind: Deserialize<'de2>,
String: FromSql<ST, DB>,
DB: Backend,
DB: HasSqlType<ST>,
{
fn from_sql(
raw: DB::RawValue<'_>,
) -> Result<Self, Box<dyn ::std::error::Error + Send + Sync>> {
let string: String = FromSql::<ST, DB>::from_sql(raw)?;
Ok(CollectionId::parse(&string)?)
}
}
impl<Kind, ST, DB> Queryable<ST, DB> for CollectionId<Kind>
where
Kind: Collection + Send + 'static,
for<'de2> <Kind as Collection>::Kind: Deserialize<'de2>,
String: FromStaticSqlRow<ST, DB>,
DB: Backend,
DB: HasSqlType<ST>,
{
type Row = String;
fn build(row: Self::Row) -> diesel::deserialize::Result<Self> {
Ok(CollectionId::parse(&row)?)
}
}
impl<Kind> QueryId for CollectionId<Kind>
where
Kind: Collection + 'static,
for<'de2> <Kind as Collection>::Kind: Deserialize<'de2>,
{
type QueryId = Self;
}
};

View file

@ -57,12 +57,12 @@ where
pub struct ObjectId<Kind>(Box<Url>, PhantomData<Kind>) pub struct ObjectId<Kind>(Box<Url>, PhantomData<Kind>)
where where
Kind: Object, Kind: Object,
for<'de2> <Kind as Object>::Kind: serde::Deserialize<'de2>; for<'de2> <Kind as Object>::Kind: Deserialize<'de2>;
impl<Kind> ObjectId<Kind> impl<Kind> ObjectId<Kind>
where where
Kind: Object + Send + Debug + 'static, Kind: Object + Send + Debug + 'static,
for<'de2> <Kind as Object>::Kind: serde::Deserialize<'de2>, for<'de2> <Kind as Object>::Kind: Deserialize<'de2>,
{ {
/// Construct a new objectid instance /// Construct a new objectid instance
pub fn parse(url: &str) -> Result<Self, url::ParseError> { pub fn parse(url: &str) -> Result<Self, url::ParseError> {
@ -112,6 +112,24 @@ where
} }
} }
/// If this is a remote object, fetch it from origin instance unconditionally to get the
/// latest version, regardless of refresh interval.
pub async fn dereference_forced(
&self,
data: &Data<<Kind as Object>::DataType>,
) -> Result<Kind, <Kind as Object>::Error>
where
<Kind as Object>::Error: From<Error>,
{
if data.config.is_local_url(&self.0) {
self.dereference_from_db(data)
.await
.map(|o| o.ok_or(Error::NotFound.into()))?
} else {
self.dereference_from_http(data, None).await
}
}
/// Fetch an object from the local db. Instead of falling back to http, this throws an error if /// Fetch an object from the local db. Instead of falling back to http, this throws an error if
/// the object is not found in the database. /// the object is not found in the database.
pub async fn dereference_local( pub async fn dereference_local(
@ -163,7 +181,7 @@ where
impl<Kind> Clone for ObjectId<Kind> impl<Kind> Clone for ObjectId<Kind>
where where
Kind: Object, Kind: Object,
for<'de2> <Kind as Object>::Kind: serde::Deserialize<'de2>, for<'de2> <Kind as Object>::Kind: Deserialize<'de2>,
{ {
fn clone(&self) -> Self { fn clone(&self) -> Self {
ObjectId(self.0.clone(), self.1) ObjectId(self.0.clone(), self.1)
@ -190,7 +208,7 @@ fn should_refetch_object(last_refreshed: DateTime<Utc>) -> bool {
impl<Kind> Display for ObjectId<Kind> impl<Kind> Display for ObjectId<Kind>
where where
Kind: Object, Kind: Object,
for<'de2> <Kind as Object>::Kind: serde::Deserialize<'de2>, for<'de2> <Kind as Object>::Kind: Deserialize<'de2>,
{ {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.0.as_str()) write!(f, "{}", self.0.as_str())
@ -200,7 +218,7 @@ where
impl<Kind> Debug for ObjectId<Kind> impl<Kind> Debug for ObjectId<Kind>
where where
Kind: Object, Kind: Object,
for<'de2> <Kind as Object>::Kind: serde::Deserialize<'de2>, for<'de2> <Kind as Object>::Kind: Deserialize<'de2>,
{ {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.0.as_str()) write!(f, "{}", self.0.as_str())
@ -210,7 +228,7 @@ where
impl<Kind> From<ObjectId<Kind>> for Url impl<Kind> From<ObjectId<Kind>> for Url
where where
Kind: Object, Kind: Object,
for<'de2> <Kind as Object>::Kind: serde::Deserialize<'de2>, for<'de2> <Kind as Object>::Kind: Deserialize<'de2>,
{ {
fn from(id: ObjectId<Kind>) -> Self { fn from(id: ObjectId<Kind>) -> Self {
*id.0 *id.0
@ -220,7 +238,7 @@ where
impl<Kind> From<Url> for ObjectId<Kind> impl<Kind> From<Url> for ObjectId<Kind>
where where
Kind: Object + Send + 'static, Kind: Object + Send + 'static,
for<'de2> <Kind as Object>::Kind: serde::Deserialize<'de2>, for<'de2> <Kind as Object>::Kind: Deserialize<'de2>,
{ {
fn from(url: Url) -> Self { fn from(url: Url) -> Self {
ObjectId(Box::new(url), PhantomData::<Kind>) ObjectId(Box::new(url), PhantomData::<Kind>)
@ -230,13 +248,102 @@ where
impl<Kind> PartialEq for ObjectId<Kind> impl<Kind> PartialEq for ObjectId<Kind>
where where
Kind: Object, Kind: Object,
for<'de2> <Kind as Object>::Kind: serde::Deserialize<'de2>, for<'de2> <Kind as Object>::Kind: Deserialize<'de2>,
{ {
fn eq(&self, other: &Self) -> bool { fn eq(&self, other: &Self) -> bool {
self.0.eq(&other.0) && self.1 == other.1 self.0.eq(&other.0) && self.1 == other.1
} }
} }
#[cfg(feature = "diesel")]
const _IMPL_DIESEL_NEW_TYPE_FOR_OBJECT_ID: () = {
use diesel::{
backend::Backend,
deserialize::{FromSql, FromStaticSqlRow},
expression::AsExpression,
internal::derives::as_expression::Bound,
pg::Pg,
query_builder::QueryId,
serialize,
serialize::{Output, ToSql},
sql_types::{HasSqlType, SingleValue, Text},
Expression,
Queryable,
};
// TODO: this impl only works for Postgres db because of to_string() call which requires reborrow
impl<Kind, ST> ToSql<ST, Pg> for ObjectId<Kind>
where
Kind: Object,
for<'de2> <Kind as Object>::Kind: Deserialize<'de2>,
String: ToSql<ST, Pg>,
{
fn to_sql<'b>(&'b self, out: &mut Output<'b, '_, Pg>) -> serialize::Result {
let v = self.0.to_string();
<String as ToSql<Text, Pg>>::to_sql(&v, &mut out.reborrow())
}
}
impl<'expr, Kind, ST> AsExpression<ST> for &'expr ObjectId<Kind>
where
Kind: Object,
for<'de2> <Kind as Object>::Kind: Deserialize<'de2>,
Bound<ST, String>: Expression<SqlType = ST>,
ST: SingleValue,
{
type Expression = Bound<ST, &'expr str>;
fn as_expression(self) -> Self::Expression {
Bound::new(self.0.as_str())
}
}
impl<Kind, ST> AsExpression<ST> for ObjectId<Kind>
where
Kind: Object,
for<'de2> <Kind as Object>::Kind: Deserialize<'de2>,
Bound<ST, String>: Expression<SqlType = ST>,
ST: SingleValue,
{
type Expression = Bound<ST, String>;
fn as_expression(self) -> Self::Expression {
Bound::new(self.0.to_string())
}
}
impl<Kind, ST, DB> FromSql<ST, DB> for ObjectId<Kind>
where
Kind: Object + Send + 'static,
for<'de2> <Kind as Object>::Kind: Deserialize<'de2>,
String: FromSql<ST, DB>,
DB: Backend,
DB: HasSqlType<ST>,
{
fn from_sql(
raw: DB::RawValue<'_>,
) -> Result<Self, Box<dyn ::std::error::Error + Send + Sync>> {
let string: String = FromSql::<ST, DB>::from_sql(raw)?;
Ok(ObjectId::parse(&string)?)
}
}
impl<Kind, ST, DB> Queryable<ST, DB> for ObjectId<Kind>
where
Kind: Object + Send + 'static,
for<'de2> <Kind as Object>::Kind: Deserialize<'de2>,
String: FromStaticSqlRow<ST, DB>,
DB: Backend,
DB: HasSqlType<ST>,
{
type Row = String;
fn build(row: Self::Row) -> diesel::deserialize::Result<Self> {
Ok(ObjectId::parse(&row)?)
}
}
impl<Kind> QueryId for ObjectId<Kind>
where
Kind: Object + 'static,
for<'de2> <Kind as Object>::Kind: Deserialize<'de2>,
{
type QueryId = Self;
}
};
#[cfg(test)] #[cfg(test)]
pub mod tests { pub mod tests {
use super::*; use super::*;