diff --git a/src/actor.rs b/src/actor.rs index 38596d6..8b4c9ed 100644 --- a/src/actor.rs +++ b/src/actor.rs @@ -27,26 +27,39 @@ pub struct Actor { impl Actor { pub fn from_uri(mut uri: &str) -> Option { - if ! uri.starts_with("https://") { + let kind; + let host; + if uri.starts_with("acct:tag-") { + let off = "acct:tag-".len(); + let Some(at) = uri.find('@') else { return None; }; + kind = ActorKind::TagRelay(uri[off..at].to_string()); + host = Arc::new(uri[at + 1..].to_string()); + } else if uri.starts_with("acct:instance-") { + let off = "acct:instance-".len(); + let Some(at) = uri.find('@') else { return None; }; + kind = ActorKind::InstanceRelay(uri[off..at].to_string()); + host = Arc::new(uri[at + 1..].to_string()); + } else if uri.starts_with("https://") { + uri = &uri[8..]; + + let parts = uri.split('/').collect::>(); + if parts.len() != 3 { + return None; + } + + let Ok(topic) = urlencoding::decode(parts[2]) else { return None; }; + kind = match parts[1] { + "tag" => + ActorKind::TagRelay(topic.to_string()), + "instance" => + ActorKind::InstanceRelay(topic.to_string()), + _ => + return None, + }; + host = Arc::new(parts[0].to_string()); + } else { return None; } - uri = &uri[8..]; - - let parts = uri.split("/").collect::>(); - if parts.len() != 3 { - return None; - } - - let Ok(topic) = urlencoding::decode(parts[2]) else { return None; }; - let kind = match parts[1] { - "tag" => - ActorKind::TagRelay(topic.to_string()), - "instance" => - ActorKind::InstanceRelay(topic.to_string()), - _ => - return None, - }; - let host = Arc::new(parts[0].to_string()); Some(Actor { host, kind }) } diff --git a/src/main.rs b/src/main.rs index 1e18e42..c111cd1 100644 --- a/src/main.rs +++ b/src/main.rs @@ -9,7 +9,7 @@ use metrics::increment_counter; use metrics_util::MetricKindMask; use metrics_exporter_prometheus::PrometheusBuilder; use serde_json::json; -use std::{net::SocketAddr, sync::Arc, time::Duration, collections::HashMap}; +use std::{net::SocketAddr, time::Duration, collections::HashMap}; use std::{panic, process}; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; use reqwest::Url; @@ -37,7 +37,6 @@ fn track_request(method: &'static str, controller: &'static str, result: &'stati } async fn webfinger( - axum::extract::State(state): axum::extract::State, Query(params): Query>, ) -> Response { let resource = match params.get("resource") { @@ -47,29 +46,7 @@ async fn webfinger( return StatusCode::NOT_FOUND.into_response(); }, }; - let target = if resource.starts_with("acct:") { - let (target_kind, target_host) = - if resource.starts_with("acct:tag-") { - let off = "acct:tag-".len(); - let at = resource.find('@'); - (actor::ActorKind::TagRelay(resource[off..at.unwrap_or(resource.len())].to_string()), - at.map_or_else(|| state.hostname.clone(), |at| Arc::new(resource[at + 1..].to_string()))) - } else if resource.starts_with("acct:instance-") { - let off = "acct:instance-".len(); - let at = resource.find('@'); - (actor::ActorKind::InstanceRelay(resource[off..at.unwrap_or(resource.len())].to_string()), - at.map_or_else(|| state.hostname.clone(), |at| Arc::new(resource[at + 1..].to_string()))) - } else { - track_request("GET", "webfinger", "not_found"); - return StatusCode::NOT_FOUND.into_response(); - }; - actor::Actor { - host: target_host, - kind: target_kind, - } - } else if let Some(target) = Actor::from_uri(resource) { - target - } else { + let Some(target) = Actor::from_uri(resource) else { track_request("GET", "webfinger", "not_found"); return StatusCode::NOT_FOUND.into_response(); };