diff --git a/examples/local_federation/main.rs b/examples/local_federation/main.rs index d23a594..f0ea3c4 100644 --- a/examples/local_federation/main.rs +++ b/examples/local_federation/main.rs @@ -7,6 +7,7 @@ use crate::{ }; use error::Error; use std::{env::args, str::FromStr}; +use tokio::try_join; use tracing::log::{info, LevelFilter}; mod activities; @@ -34,8 +35,10 @@ async fn main() -> Result<(), Error> { .map(|arg| Webserver::from_str(&arg).unwrap()) .unwrap_or(Webserver::Axum); - let alpha = new_instance("localhost:8001", "alpha".to_string()).await?; - let beta = new_instance("localhost:8002", "beta".to_string()).await?; + let (alpha, beta) = try_join!( + new_instance("localhost:8001", "alpha".to_string()), + new_instance("localhost:8002", "beta".to_string()) + )?; listen(&alpha, &webserver)?; listen(&beta, &webserver)?; info!("Local instances started"); diff --git a/src/config.rs b/src/config.rs index 2015750..b76c484 100644 --- a/src/config.rs +++ b/src/config.rs @@ -24,10 +24,14 @@ use async_trait::async_trait; use derive_builder::Builder; use dyn_clone::{clone_trait_object, DynClone}; use moka::future::Cache; +use once_cell::sync::Lazy; +use regex::Regex; +use reqwest::{redirect::Policy, Client}; use reqwest_middleware::ClientWithMiddleware; use rsa::{pkcs8::DecodePrivateKey, RsaPrivateKey}; use serde::de::DeserializeOwned; use std::{ + net::IpAddr, ops::Deref, sync::{ atomic::{AtomicU32, Ordering}, @@ -35,6 +39,7 @@ use std::{ }, time::Duration, }; +use tokio::net::lookup_host; use url::Url; /// Configuration for this library, with various federation related settings @@ -51,9 +56,14 @@ pub struct FederationConfig { /// [crate::fetch::object_id::ObjectId] for more details. #[builder(default = "20")] pub(crate) http_fetch_limit: u32, - #[builder(default = "reqwest::Client::default().into()")] - /// HTTP client used for all outgoing requests. Middleware can be used to add functionality - /// like log tracing or retry of failed requests. + #[builder(default = "default_client()")] + /// HTTP client used for all outgoing requests. When passing a custom client here you should + /// also disable redirects and set timeouts. + /// + /// Middleware can be used to add functionality like log tracing or retry of failed requests. + /// Redirects are disabled by default, because automatic redirect URLs can't be validated. + /// Instead a single redirect is handled manually. The default client sets a timeout of 10s + /// to avoid excessive resource usage when connecting to dead servers. pub(crate) client: ClientWithMiddleware, /// Run library in debug mode. This allows usage of http and localhost urls. It also sends /// outgoing activities synchronously, not in background thread. This helps to make tests @@ -102,6 +112,9 @@ pub struct FederationConfig { pub(crate) queue_retry_count: usize, } +pub(crate) static DOMAIN_REGEX: Lazy = + Lazy::new(|| Regex::new(r"^[a-zA-Z0-9.-]*$").expect("compile regex")); + impl FederationConfig { /// Returns a new config builder with default values. pub fn builder() -> FederationConfigBuilder { @@ -156,17 +169,56 @@ impl FederationConfig { return Ok(()); } - if url.domain().is_none() { + let Some(domain) = url.domain() else { return Err(Error::UrlVerificationError("Url must have a domain")); + }; + if !DOMAIN_REGEX.is_match(domain) { + return Err(Error::UrlVerificationError("Invalid characters in domain")); } - if url.domain() == Some("localhost") && !self.debug { - return Err(Error::UrlVerificationError( - "Localhost is only allowed in debug mode", - )); + // Extra checks only for production mode + if !self.debug { + if url.port().is_some() { + return Err(Error::UrlVerificationError("Explicit port is not allowed")); + } + + // Resolve domain and see if it points to private IP + // TODO: Use is_global() once stabilized + // https://doc.rust-lang.org/std/net/enum.IpAddr.html#method.is_global + let invalid_ip = + lookup_host((domain.to_owned(), 80)) + .await? + .any(|addr| match addr.ip() { + IpAddr::V4(addr) => { + addr.is_private() + || addr.is_link_local() + || addr.is_loopback() + || addr.is_multicast() + } + IpAddr::V6(addr) => { + addr.is_loopback() + || addr.is_multicast() + || ((addr.segments()[0] & 0xfe00) == 0xfc00) // is_unique_local + || ((addr.segments()[0] & 0xffc0) == 0xfe80) // is_unicast_link_local + } + }); + if invalid_ip { + return Err(Error::UrlVerificationError( + "Localhost is only allowed in debug mode", + )); + } } - self.url_verifier.verify(url).await?; + // It is valid but uncommon for domains to end with `.` char. Drop this so it cant be used + // to bypass domain blocklist. Avoid cloning url in common case. + if domain.ends_with('.') { + let mut url = url.clone(); + let domain = &domain[0..domain.len() - 1]; + url.set_host(Some(domain))?; + self.url_verifier.verify(&url).await?; + } else { + self.url_verifier.verify(url).await?; + } Ok(()) } @@ -348,6 +400,17 @@ impl FederationMiddleware { } } +fn default_client() -> ClientWithMiddleware { + let timeout = Duration::from_secs(10); + Client::builder() + .redirect(Policy::none()) + .timeout(timeout) + .connect_timeout(timeout) + .build() + .unwrap_or_else(|_| Client::default()) + .into() +} + #[cfg(test)] #[allow(clippy::unwrap_used)] mod test { diff --git a/src/error.rs b/src/error.rs index 1866e48..4a53fd8 100644 --- a/src/error.rs +++ b/src/error.rs @@ -78,6 +78,9 @@ pub enum Error { /// Attempted to fetch object but the response's id field doesn't match #[error("Attempted to fetch object from {0} but the response's id field doesn't match")] FetchWrongId(Url), + /// I/O error from OS + #[error(transparent)] + IoError(#[from] std::io::Error), /// Other generic errors #[error("{0}")] Other(String), diff --git a/src/fetch/mod.rs b/src/fetch/mod.rs index e36e8b2..325e9aa 100644 --- a/src/fetch/mod.rs +++ b/src/fetch/mod.rs @@ -11,7 +11,7 @@ use crate::{ FEDERATION_CONTENT_TYPE, }; use bytes::Bytes; -use http::{HeaderValue, StatusCode}; +use http::{header::LOCATION, HeaderValue, StatusCode}; use serde::de::DeserializeOwned; use std::sync::atomic::Ordering; use tracing::info; @@ -59,7 +59,7 @@ pub async fn fetch_object_http( r#"application/ld+json; profile="https://www.w3.org/ns/activitystreams""#, // activitypub standard r#"application/activity+json; charset=utf-8"#, // mastodon ]; - let res = fetch_object_http_with_accept(url, data, &FETCH_CONTENT_TYPE).await?; + let res = fetch_object_http_with_accept(url, data, &FETCH_CONTENT_TYPE, false).await?; // Ensure correct content-type to prevent vulnerabilities, with case insensitive comparison. let content_type = res @@ -74,6 +74,7 @@ pub async fn fetch_object_http( // Ensure id field matches final url after redirect if res.object_id.as_ref() != Some(&res.url) { if let Some(res_object_id) = res.object_id { + data.config.verify_url_valid(&res_object_id).await?; // If id is different but still on the same domain, attempt to request object // again from url in id field. if res_object_id.domain() == res.url.domain() { @@ -99,6 +100,7 @@ async fn fetch_object_http_with_accept( url: &Url, data: &Data, content_type: &HeaderValue, + recursive: bool, ) -> Result, Error> { let config = &data.config; config.verify_url_valid(url).await?; @@ -131,6 +133,19 @@ async fn fetch_object_http_with_accept( req.send().await? }; + // Allow a single redirect using recursion. Further redirects are ignored. + let location = res.headers().get(LOCATION).and_then(|l| l.to_str().ok()); + if let (Some(location), false) = (location, recursive) { + let location = location.parse()?; + return Box::pin(fetch_object_http_with_accept( + &location, + data, + content_type, + true, + )) + .await; + } + if res.status() == StatusCode::GONE { return Err(Error::ObjectDeleted(url.clone())); } diff --git a/src/fetch/webfinger.rs b/src/fetch/webfinger.rs index 8460245..8d53078 100644 --- a/src/fetch/webfinger.rs +++ b/src/fetch/webfinger.rs @@ -1,5 +1,5 @@ use crate::{ - config::Data, + config::{Data, DOMAIN_REGEX}, error::Error, fetch::{fetch_object_http_with_accept, object_id::ObjectId}, traits::{Actor, Object}, @@ -54,21 +54,31 @@ where .splitn(2, '@') .collect_tuple() .ok_or(WebFingerError::WrongFormat.into_crate_error())?; + + // For production mode make sure that domain doesnt contain any port or path. + if !data.config.debug && !DOMAIN_REGEX.is_match(domain) { + return Err(Error::UrlVerificationError("Invalid characters in domain").into()); + } + let protocol = if data.config.debug { "http" } else { "https" }; let fetch_url = format!("{protocol}://{domain}/.well-known/webfinger?resource=acct:{identifier}"); debug!("Fetching webfinger url: {}", &fetch_url); - let res: Webfinger = fetch_object_http_with_accept( + let res = fetch_object_http_with_accept::<_, Webfinger>( &Url::parse(&fetch_url).map_err(Error::UrlParse)?, data, &WEBFINGER_CONTENT_TYPE, + false, ) - .await? - .object; + .await?; + if res.url.as_str() != fetch_url { + data.config.verify_url_valid(&res.url).await?; + } - debug_assert_eq!(res.subject, format!("acct:{identifier}")); + debug_assert_eq!(res.object.subject, format!("acct:{identifier}")); let links: Vec = res + .object .links .iter() .filter(|link| {