diff --git a/src/activity_queue.rs b/src/activity_queue.rs index 20852bd..8785971 100644 --- a/src/activity_queue.rs +++ b/src/activity_queue.rs @@ -26,7 +26,7 @@ use tokio::{ task::{JoinHandle, JoinSet}, }; use tracing::{info, warn}; -use url::Url; +use crate::url::Url; /// Send a new activity to the given inboxes with automatic retry on failure. Alternatively you /// can implement your own queue and then send activities using [[crate::activity_sending::SendActivityTask]]. diff --git a/src/activity_sending.rs b/src/activity_sending.rs index f9023ce..7a35a34 100644 --- a/src/activity_sending.rs +++ b/src/activity_sending.rs @@ -27,7 +27,7 @@ use std::{ time::{Duration, SystemTime}, }; use tracing::debug; -use url::Url; +use crate::url::Url; #[derive(Clone, Debug)] /// All info needed to sign and send one activity to one inbox. You should generally use @@ -202,7 +202,7 @@ where } pub(crate) fn generate_request_headers(inbox_url: &Url) -> HeaderMap { - let mut host = inbox_url.domain().expect("read inbox domain").to_string(); + let mut host = inbox_url.domain().to_string(); if let Some(port) = inbox_url.port() { host = format!("{}:{}", host, port); } diff --git a/src/actix_web/inbox.rs b/src/actix_web/inbox.rs index 7c10659..dbd1984 100644 --- a/src/actix_web/inbox.rs +++ b/src/actix_web/inbox.rs @@ -59,7 +59,7 @@ mod test { use reqwest::Client; use reqwest_middleware::ClientWithMiddleware; use serde_json::json; - use url::Url; + use crate::url::Url; #[tokio::test] async fn test_receive_activity() { @@ -108,7 +108,7 @@ mod test { async fn test_receive_unparseable_activity() { let (_, _, config) = setup_receive_test().await; - let actor = Url::parse("http://ds9.lemmy.ml/u/lemmy_alpha").unwrap(); + let actor = Url::from_str("http://ds9.lemmy.ml/u/lemmy_alpha").unwrap(); let id = "http://localhost:123/1"; let activity = json!({ "actor": actor.as_str(), @@ -140,7 +140,7 @@ mod test { async fn construct_request(body: &Bytes, actor: &Url) -> TestRequest { let inbox = "https://example.com/inbox"; - let headers = generate_request_headers(&Url::parse(inbox).unwrap()); + let headers = generate_request_headers(&Url::from_str(inbox).unwrap()); let request_builder = ClientWithMiddleware::from(Client::default()) .post(inbox) .headers(headers); diff --git a/src/config.rs b/src/config.rs index 2015750..d3c3608 100644 --- a/src/config.rs +++ b/src/config.rs @@ -35,7 +35,7 @@ use std::{ }, time::Duration, }; -use url::Url; +use crate::url::Url; /// Configuration for this library, with various federation related settings #[derive(Builder, Clone)] @@ -156,11 +156,7 @@ impl FederationConfig { return Ok(()); } - if url.domain().is_none() { - return Err(Error::UrlVerificationError("Url must have a domain")); - } - - if url.domain() == Some("localhost") && !self.debug { + if url.domain() == "localhost" && !self.debug { return Err(Error::UrlVerificationError( "Localhost is only allowed in debug mode", )); @@ -247,7 +243,7 @@ impl Deref for FederationConfig { /// /// ``` /// # use async_trait::async_trait; -/// # use url::Url; +/// # use crate::url::Url; /// # use activitypub_federation::config::UrlVerifier; /// # use activitypub_federation::error::Error; /// # #[derive(Clone)] @@ -351,6 +347,8 @@ impl FederationMiddleware { #[cfg(test)] #[allow(clippy::unwrap_used)] mod test { + use std::str::FromStr; + use super::*; async fn config() -> FederationConfig { @@ -365,10 +363,10 @@ mod test { #[tokio::test] async fn test_url_is_local() -> Result<(), Error> { let config = config().await; - assert!(config.is_local_url(&Url::parse("http://example.com")?)); - assert!(!config.is_local_url(&Url::parse("http://other.com")?)); + assert!(config.is_local_url(&Url::from_str("http://example.com")?)); + assert!(!config.is_local_url(&Url::from_str("http://other.com")?)); // ensure that missing domain doesnt cause crash - assert!(!config.is_local_url(&Url::parse("http://127.0.0.1")?)); + assert!(!config.is_local_url(&Url::from_str("http://127.0.0.1")?)); Ok(()) } diff --git a/src/error.rs b/src/error.rs index 1866e48..edaed0f 100644 --- a/src/error.rs +++ b/src/error.rs @@ -8,7 +8,7 @@ use rsa::{ }; use std::string::FromUtf8Error; use tokio::task::JoinError; -use url::Url; +use crate::url::Url; /// Error messages returned by this library #[derive(thiserror::Error, Debug)] diff --git a/src/fetch/collection_id.rs b/src/fetch/collection_id.rs index 8c796f4..7a58278 100644 --- a/src/fetch/collection_id.rs +++ b/src/fetch/collection_id.rs @@ -2,9 +2,9 @@ use crate::{config::Data, error::Error, fetch::fetch_object_http, traits::Collec use serde::{Deserialize, Serialize}; use std::{ fmt::{Debug, Display, Formatter}, - marker::PhantomData, + marker::PhantomData, str::FromStr, }; -use url::Url; +use crate::url::Url; /// Typed wrapper for Activitypub Collection ID which helps with dereferencing. #[derive(Serialize, Deserialize)] @@ -21,7 +21,7 @@ where { /// Construct a new CollectionId instance pub fn parse(url: &str) -> Result { - Ok(Self(Box::new(Url::parse(url)?), PhantomData::)) + Ok(Self(Box::new(Url::from_str(url)?), PhantomData::)) } /// Fetches collection over HTTP diff --git a/src/fetch/mod.rs b/src/fetch/mod.rs index f078cf6..eee5899 100644 --- a/src/fetch/mod.rs +++ b/src/fetch/mod.rs @@ -15,7 +15,7 @@ use http::{HeaderValue, StatusCode}; use serde::de::DeserializeOwned; use std::sync::atomic::Ordering; use tracing::info; -use url::Url; +use crate::url::Url; /// Typed wrapper for collection IDs pub mod collection_id; @@ -135,13 +135,13 @@ async fn fetch_object_http_with_accept( match serde_json::from_slice(&text) { Ok(object) => Ok(FetchObjectResponse { object, - url, + url: url.into(), content_type, object_id, }), Err(e) => Err(ParseFetchedObject( e, - url, + url.into(), String::from_utf8(Vec::from(text))?, )), } diff --git a/src/fetch/object_id.rs b/src/fetch/object_id.rs index ce52c43..b05b59d 100644 --- a/src/fetch/object_id.rs +++ b/src/fetch/object_id.rs @@ -6,7 +6,7 @@ use std::{ marker::PhantomData, str::FromStr, }; -use url::Url; +use crate::url::Url; impl FromStr for ObjectId where @@ -66,7 +66,7 @@ where { /// Construct a new objectid instance pub fn parse(url: &str) -> Result { - Ok(Self(Box::new(Url::parse(url)?), PhantomData::)) + Ok(Self(Box::new(Url::from_str(url)?), PhantomData::)) } /// Returns a reference to the wrapped URL value diff --git a/src/fetch/webfinger.rs b/src/fetch/webfinger.rs index 8460245..80100bb 100644 --- a/src/fetch/webfinger.rs +++ b/src/fetch/webfinger.rs @@ -10,9 +10,9 @@ use itertools::Itertools; use once_cell::sync::Lazy; use regex::Regex; use serde::{Deserialize, Serialize}; -use std::{collections::HashMap, fmt::Display}; +use std::{collections::HashMap, fmt::Display, str::FromStr}; use tracing::debug; -use url::Url; +use crate::url::Url; /// Errors relative to webfinger handling #[derive(thiserror::Error, Debug)] @@ -60,7 +60,7 @@ where debug!("Fetching webfinger url: {}", &fetch_url); let res: Webfinger = fetch_object_http_with_accept( - &Url::parse(&fetch_url).map_err(Error::UrlParse)?, + &Url::from_str(&fetch_url).map_err(Error::UrlParse)?, data, &WEBFINGER_CONTENT_TYPE, ) @@ -143,10 +143,10 @@ where /// of discovery. /// /// ``` -/// # use url::Url; +/// # use crate::url::Url; /// # use activitypub_federation::fetch::webfinger::build_webfinger_response; /// let subject = "acct:nutomic@lemmy.ml".to_string(); -/// let url = Url::parse("https://lemmy.ml/u/nutomic")?; +/// let url = Url::from_str("https://lemmy.ml/u/nutomic")?; /// build_webfinger_response(subject, url); /// # Ok::<(), anyhow::Error>(()) /// ``` @@ -162,11 +162,11 @@ pub fn build_webfinger_response(subject: String, url: Url) -> Webfinger { /// will be empty. /// /// ``` -/// # use url::Url; +/// # use crate::url::Url; /// # use activitypub_federation::fetch::webfinger::build_webfinger_response_with_type; /// let subject = "acct:nutomic@lemmy.ml".to_string(); -/// let user = Url::parse("https://lemmy.ml/u/nutomic")?; -/// let group = Url::parse("https://lemmy.ml/c/asklemmy")?; +/// let user = Url::from_str("https://lemmy.ml/u/nutomic")?; +/// let group = Url::from_str("https://lemmy.ml/c/asklemmy")?; /// build_webfinger_response_with_type(subject, vec![ /// (user, Some("Person")), /// (group, Some("Group"))]); diff --git a/src/http_signatures.rs b/src/http_signatures.rs index aa526f9..91a759f 100644 --- a/src/http_signatures.rs +++ b/src/http_signatures.rs @@ -30,9 +30,9 @@ use rsa::{ }; use serde::Deserialize; use sha2::{Digest, Sha256}; -use std::{collections::BTreeMap, fmt::Debug, time::Duration}; +use std::{collections::BTreeMap, fmt::Debug, str::FromStr, time::Duration}; use tracing::debug; -use url::Url; +use crate::url::Url; /// A private/public key pair used for HTTP signatures #[derive(Debug, Clone)] @@ -166,7 +166,7 @@ where None => return Err(Error::ActivitySignatureInvalid.into()), Some(caps) => caps.get(1).expect("regex error").as_str(), }; - let actor_url = Url::parse(actor_id).map_err(|_| Error::ActivitySignatureInvalid)?; + let actor_url = Url::from_str(actor_id).map_err(|_| Error::ActivitySignatureInvalid)?; let actor_id: ObjectId = actor_url.into(); let actor = actor_id.dereference(data).await?; @@ -287,9 +287,9 @@ pub mod test { use rsa::{pkcs1::DecodeRsaPrivateKey, pkcs8::DecodePrivateKey}; use std::str::FromStr; - static ACTOR_ID: Lazy = Lazy::new(|| Url::parse("https://example.com/u/alice").unwrap()); + static ACTOR_ID: Lazy = Lazy::new(|| Url::from_str("https://example.com/u/alice").unwrap()); static INBOX_URL: Lazy = - Lazy::new(|| Url::parse("https://example.com/u/alice/inbox").unwrap()); + Lazy::new(|| Url::from_str("https://example.com/u/alice/inbox").unwrap()); #[tokio::test] async fn test_sign() { diff --git a/src/lib.rs b/src/lib.rs index 0a44fc9..9a2abf7 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -23,6 +23,7 @@ pub mod http_signatures; pub mod protocol; pub(crate) mod reqwest_shim; pub mod traits; +pub mod url; use crate::{ config::Data, @@ -33,7 +34,7 @@ use crate::{ pub use activitystreams_kinds as kinds; use serde::{de::DeserializeOwned, Deserialize}; -use url::Url; +use crate::url::Url; /// Mime type for Activitypub data, used for `Accept` and `Content-Type` HTTP headers pub const FEDERATION_CONTENT_TYPE: &str = "application/activity+json"; diff --git a/src/protocol/context.rs b/src/protocol/context.rs index 027ff15..06e429e 100644 --- a/src/protocol/context.rs +++ b/src/protocol/context.rs @@ -22,7 +22,7 @@ use crate::{config::Data, traits::ActivityHandler}; use serde::{Deserialize, Serialize}; use serde_json::Value; -use url::Url; +use crate::url::Url; /// Default context used in Activitypub const DEFAULT_CONTEXT: &str = "https://www.w3.org/ns/activitystreams"; diff --git a/src/protocol/helpers.rs b/src/protocol/helpers.rs index 8c69f65..cb44680 100644 --- a/src/protocol/helpers.rs +++ b/src/protocol/helpers.rs @@ -9,7 +9,7 @@ use serde::{Deserialize, Deserializer}; /// /// ``` /// # use activitypub_federation::protocol::helpers::deserialize_one_or_many; -/// # use url::Url; +/// # use crate::url::Url; /// #[derive(serde::Deserialize)] /// struct Note { /// #[serde(deserialize_with = "deserialize_one_or_many")] @@ -52,7 +52,7 @@ where /// /// ``` /// # use activitypub_federation::protocol::helpers::deserialize_one; -/// # use url::Url; +/// # use crate::url::Url; /// #[derive(serde::Deserialize)] /// struct Note { /// #[serde(deserialize_with = "deserialize_one")] @@ -88,7 +88,7 @@ where /// /// ``` /// # use activitypub_federation::protocol::helpers::deserialize_skip_error; -/// # use url::Url; +/// # use crate::url::Url; /// #[derive(serde::Deserialize)] /// struct Note { /// content: String, @@ -121,7 +121,7 @@ mod tests { #[test] fn deserialize_one_multiple_values() { use crate::protocol::helpers::deserialize_one; - use url::Url; + use crate::url::Url; #[derive(serde::Deserialize)] struct Note { #[serde(deserialize_with = "deserialize_one")] diff --git a/src/protocol/public_key.rs b/src/protocol/public_key.rs index d36ee2b..fc7c31c 100644 --- a/src/protocol/public_key.rs +++ b/src/protocol/public_key.rs @@ -1,7 +1,7 @@ //! Struct which is used to federate actor key for HTTP signatures use serde::{Deserialize, Serialize}; -use url::Url; +use crate::url::Url; /// Public key of actors which is used for HTTP signatures. /// diff --git a/src/protocol/verification.rs b/src/protocol/verification.rs index 18595b9..a2392fd 100644 --- a/src/protocol/verification.rs +++ b/src/protocol/verification.rs @@ -1,17 +1,17 @@ //! Verify that received data is valid use crate::error::Error; -use url::Url; +use crate::url::Url; /// Check that both urls have the same domain. If not, return UrlVerificationError. /// /// ``` -/// # use url::Url; +/// # use crate::url::Url; /// # use activitypub_federation::protocol::verification::verify_domains_match; -/// let a = Url::parse("https://example.com/abc")?; -/// let b = Url::parse("https://sample.net/abc")?; +/// let a = Url::from_str("https://example.com/abc")?; +/// let b = Url::from_str("https://sample.net/abc")?; /// assert!(verify_domains_match(&a, &b).is_err()); -/// # Ok::<(), url::ParseError>(()) +/// # Ok::<(), Url::from_strError>(()) /// ``` pub fn verify_domains_match(a: &Url, b: &Url) -> Result<(), Error> { if a.domain() != b.domain() { @@ -23,12 +23,12 @@ pub fn verify_domains_match(a: &Url, b: &Url) -> Result<(), Error> { /// Check that both urls are identical. If not, return UrlVerificationError. /// /// ``` -/// # use url::Url; +/// # use crate::url::Url; /// # use activitypub_federation::protocol::verification::verify_urls_match; -/// let a = Url::parse("https://example.com/abc")?; -/// let b = Url::parse("https://example.com/123")?; +/// let a = Url::from_str("https://example.com/abc")?; +/// let b = Url::from_str("https://example.com/123")?; /// assert!(verify_urls_match(&a, &b).is_err()); -/// # Ok::<(), url::ParseError>(()) +/// # Ok::<(), Url::from_strError>(()) /// ``` pub fn verify_urls_match(a: &Url, b: &Url) -> Result<(), Error> { if a != b { diff --git a/src/traits.rs b/src/traits.rs index 9976bda..d5613a9 100644 --- a/src/traits.rs +++ b/src/traits.rs @@ -5,7 +5,7 @@ use async_trait::async_trait; use chrono::{DateTime, Utc}; use serde::Deserialize; use std::{fmt::Debug, ops::Deref}; -use url::Url; +use crate::url::Url; /// Helper for converting between database structs and federated protocol structs. /// @@ -13,7 +13,7 @@ use url::Url; /// # use activitystreams_kinds::{object::NoteType, public}; /// # use chrono::{Local, DateTime, Utc}; /// # use serde::{Deserialize, Serialize}; -/// # use url::Url; +/// # use crate::url::Url; /// # use activitypub_federation::protocol::{public_key::PublicKey, helpers::deserialize_one_or_many}; /// # use activitypub_federation::config::Data; /// # use activitypub_federation::fetch::object_id::ObjectId; @@ -162,7 +162,7 @@ pub trait Object: Sized + Debug { /// /// ``` /// # use activitystreams_kinds::activity::FollowType; -/// # use url::Url; +/// # use crate::url::Url; /// # use activitypub_federation::fetch::object_id::ObjectId; /// # use activitypub_federation::config::Data; /// # use activitypub_federation::traits::ActivityHandler; diff --git a/src/url.rs b/src/url.rs new file mode 100644 index 0000000..ea6ea72 --- /dev/null +++ b/src/url.rs @@ -0,0 +1,46 @@ +//! Wrapper for `url::Url` type. + +use std::{fmt::{Display, Formatter}, ops::Deref, str::FromStr}; + +use serde::{Deserialize, Serialize}; + +/// Wrapper for `url::Url` type. Has `domain` as mandatory field, and prints plain +/// string for debugging. +#[derive(Clone, PartialEq, Eq, Hash, Serialize, Deserialize, Debug)] +pub struct Url(url::Url); + +impl Deref for Url { + type Target = url::Url; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl Display for Url { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.0.to_string()) + } +} + +impl Url { + /// Returns domain of the url + pub fn domain(&self) -> &str { + // TODO: must have error handling, or ensure at creation that it has domain + self.0.domain().expect("has domain") + } +} + +impl From for Url { + fn from(value: url::Url) -> Self { + Url(value) + } +} + +impl FromStr for Url { + type Err = url::ParseError; + + fn from_str(s: &str) -> Result { + Ok(url::Url::from_str(s).map(Url).unwrap()) + } +} \ No newline at end of file