diff --git a/examples/local_federation/main.rs b/examples/local_federation/main.rs index d23a594..f0ea3c4 100644 --- a/examples/local_federation/main.rs +++ b/examples/local_federation/main.rs @@ -7,6 +7,7 @@ use crate::{ }; use error::Error; use std::{env::args, str::FromStr}; +use tokio::try_join; use tracing::log::{info, LevelFilter}; mod activities; @@ -34,8 +35,10 @@ async fn main() -> Result<(), Error> { .map(|arg| Webserver::from_str(&arg).unwrap()) .unwrap_or(Webserver::Axum); - let alpha = new_instance("localhost:8001", "alpha".to_string()).await?; - let beta = new_instance("localhost:8002", "beta".to_string()).await?; + let (alpha, beta) = try_join!( + new_instance("localhost:8001", "alpha".to_string()), + new_instance("localhost:8002", "beta".to_string()) + )?; listen(&alpha, &webserver)?; listen(&beta, &webserver)?; info!("Local instances started"); diff --git a/src/config.rs b/src/config.rs index 63af704..73f7688 100644 --- a/src/config.rs +++ b/src/config.rs @@ -26,11 +26,14 @@ use bytes::Bytes; use derive_builder::Builder; use dyn_clone::{clone_trait_object, DynClone}; use moka::future::Cache; +use once_cell::sync::Lazy; +use regex::Regex; use reqwest::Request; use reqwest_middleware::{ClientWithMiddleware, RequestBuilder}; use rsa::{pkcs8::DecodePrivateKey, RsaPrivateKey}; use serde::de::DeserializeOwned; use std::{ + net::IpAddr, ops::Deref, sync::{ atomic::{AtomicU32, Ordering}, @@ -38,6 +41,7 @@ use std::{ }, time::Duration, }; +use tokio::net::lookup_host; use url::Url; /// Configuration for this library, with various federation related settings @@ -159,14 +163,44 @@ impl FederationConfig { return Ok(()); } - if url.domain().is_none() { + let Some(domain) = url.domain() else { return Err(Error::UrlVerificationError("Url must have a domain")); + }; + + static DOMAIN_REGEX: Lazy = + Lazy::new(|| Regex::new(r"^[a-zA-Z0-9.-]*$").expect("compile regex")); + if !DOMAIN_REGEX.is_match(domain) { + return Err(Error::UrlVerificationError("Invalid characters in domain")); } - if url.domain() == Some("localhost") && !self.debug { - return Err(Error::UrlVerificationError( - "Localhost is only allowed in debug mode", - )); + // Extra checks only for production mode + if !self.debug { + if url.port().is_some() { + return Err(Error::UrlVerificationError("Explicit port is not allowed")); + } + + // Resolve domain and see if it points to private IP + // TODO: Use is_global() once stabilized + // https://doc.rust-lang.org/std/net/enum.IpAddr.html#method.is_global + let invalid_ip = lookup_host(domain).await?.any(|addr| match addr.ip() { + IpAddr::V4(addr) => { + addr.is_private() + || addr.is_link_local() + || addr.is_loopback() + || addr.is_multicast() + } + IpAddr::V6(addr) => { + addr.is_loopback() + || addr.is_multicast() + || ((addr.segments()[0] & 0xfe00) == 0xfc00) // is_unique_local + || ((addr.segments()[0] & 0xffc0) == 0xfe80) // is_unicast_link_local + } + }); + if invalid_ip { + return Err(Error::UrlVerificationError( + "Localhost is only allowed in debug mode", + )); + } } self.url_verifier.verify(url).await?; diff --git a/src/error.rs b/src/error.rs index 1866e48..4a53fd8 100644 --- a/src/error.rs +++ b/src/error.rs @@ -78,6 +78,9 @@ pub enum Error { /// Attempted to fetch object but the response's id field doesn't match #[error("Attempted to fetch object from {0} but the response's id field doesn't match")] FetchWrongId(Url), + /// I/O error from OS + #[error(transparent)] + IoError(#[from] std::io::Error), /// Other generic errors #[error("{0}")] Other(String),