main: consolidate actor uri parsing into Actor::from_uri()

This commit is contained in:
Astro 2023-10-31 04:12:03 +01:00
parent f9a103de9a
commit 316b34c689
2 changed files with 33 additions and 43 deletions

View file

@ -27,18 +27,28 @@ pub struct Actor {
impl Actor { impl Actor {
pub fn from_uri(mut uri: &str) -> Option<Self> { pub fn from_uri(mut uri: &str) -> Option<Self> {
if ! uri.starts_with("https://") { let kind;
return None; let host;
} if uri.starts_with("acct:tag-") {
let off = "acct:tag-".len();
let Some(at) = uri.find('@') else { return None; };
kind = ActorKind::TagRelay(uri[off..at].to_string());
host = Arc::new(uri[at + 1..].to_string());
} else if uri.starts_with("acct:instance-") {
let off = "acct:instance-".len();
let Some(at) = uri.find('@') else { return None; };
kind = ActorKind::InstanceRelay(uri[off..at].to_string());
host = Arc::new(uri[at + 1..].to_string());
} else if uri.starts_with("https://") {
uri = &uri[8..]; uri = &uri[8..];
let parts = uri.split("/").collect::<Vec<_>>(); let parts = uri.split('/').collect::<Vec<_>>();
if parts.len() != 3 { if parts.len() != 3 {
return None; return None;
} }
let Ok(topic) = urlencoding::decode(parts[2]) else { return None; }; let Ok(topic) = urlencoding::decode(parts[2]) else { return None; };
let kind = match parts[1] { kind = match parts[1] {
"tag" => "tag" =>
ActorKind::TagRelay(topic.to_string()), ActorKind::TagRelay(topic.to_string()),
"instance" => "instance" =>
@ -46,7 +56,10 @@ impl Actor {
_ => _ =>
return None, return None,
}; };
let host = Arc::new(parts[0].to_string()); host = Arc::new(parts[0].to_string());
} else {
return None;
}
Some(Actor { host, kind }) Some(Actor { host, kind })
} }

View file

@ -9,7 +9,7 @@ use metrics::increment_counter;
use metrics_util::MetricKindMask; use metrics_util::MetricKindMask;
use metrics_exporter_prometheus::PrometheusBuilder; use metrics_exporter_prometheus::PrometheusBuilder;
use serde_json::json; use serde_json::json;
use std::{net::SocketAddr, sync::Arc, time::Duration, collections::HashMap}; use std::{net::SocketAddr, time::Duration, collections::HashMap};
use std::{panic, process}; use std::{panic, process};
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
use reqwest::Url; use reqwest::Url;
@ -37,7 +37,6 @@ fn track_request(method: &'static str, controller: &'static str, result: &'stati
} }
async fn webfinger( async fn webfinger(
axum::extract::State(state): axum::extract::State<State>,
Query(params): Query<HashMap<String, String>>, Query(params): Query<HashMap<String, String>>,
) -> Response { ) -> Response {
let resource = match params.get("resource") { let resource = match params.get("resource") {
@ -47,29 +46,7 @@ async fn webfinger(
return StatusCode::NOT_FOUND.into_response(); return StatusCode::NOT_FOUND.into_response();
}, },
}; };
let target = if resource.starts_with("acct:") { let Some(target) = Actor::from_uri(resource) else {
let (target_kind, target_host) =
if resource.starts_with("acct:tag-") {
let off = "acct:tag-".len();
let at = resource.find('@');
(actor::ActorKind::TagRelay(resource[off..at.unwrap_or(resource.len())].to_string()),
at.map_or_else(|| state.hostname.clone(), |at| Arc::new(resource[at + 1..].to_string())))
} else if resource.starts_with("acct:instance-") {
let off = "acct:instance-".len();
let at = resource.find('@');
(actor::ActorKind::InstanceRelay(resource[off..at.unwrap_or(resource.len())].to_string()),
at.map_or_else(|| state.hostname.clone(), |at| Arc::new(resource[at + 1..].to_string())))
} else {
track_request("GET", "webfinger", "not_found");
return StatusCode::NOT_FOUND.into_response();
};
actor::Actor {
host: target_host,
kind: target_kind,
}
} else if let Some(target) = Actor::from_uri(resource) {
target
} else {
track_request("GET", "webfinger", "not_found"); track_request("GET", "webfinger", "not_found");
return StatusCode::NOT_FOUND.into_response(); return StatusCode::NOT_FOUND.into_response();
}; };