diff --git a/src/endpoint.rs b/src/endpoint.rs index 9e8a978..664834a 100644 --- a/src/endpoint.rs +++ b/src/endpoint.rs @@ -9,11 +9,12 @@ use axum::{ use http_digest_headers::{DigestHeader}; -use sigh::{Signature, PublicKey, Key}; +use sigh::{Signature, PublicKey, Key, PrivateKey}; -use crate::fetch::fetch; +use crate::fetch::authorized_fetch; use crate::activitypub::Actor; +use crate::error::Error; const SIGNATURE_HEADERS_REQUIRED: &[&str] = &[ "(request-target)", @@ -21,20 +22,14 @@ const SIGNATURE_HEADERS_REQUIRED: &[&str] = &[ "digest", ]; -#[derive(Clone, Debug)] -pub struct Endpoint { +pub struct Endpoint<'a> { pub payload: serde_json::Value, - pub actor: Actor, + signature: Signature<'a>, + remote_actor_uri: String, } -// impl Endpoint { -// pub fn parse(self) -> Result { -// serde_json::from_value(self.payload) -// } -// } - #[async_trait] -impl FromRequest for Endpoint +impl<'a, S, B> FromRequest for Endpoint<'a> where B: HttpBody + Send + 'static, B::Data: Send, @@ -95,27 +90,32 @@ where // parse body let payload: serde_json::Value = serde_json::from_slice(&bytes) .map_err(|_| (StatusCode::BAD_REQUEST, "Error parsing JSON".to_string()))?; - let actor_uri = if let Some(serde_json::Value::String(actor_uri)) = payload.get("actor") { - actor_uri + let remote_actor_uri = if let Some(serde_json::Value::String(actor_uri)) = payload.get("actor") { + actor_uri.to_string() } else { return Err((StatusCode::BAD_REQUEST, "Actor missing".to_string())); }; - // validate actor - let client = Arc::from_ref(state); - let actor: Actor = - serde_json::from_value( - fetch(&client, actor_uri).await - .map_err(|e| (StatusCode::BAD_GATEWAY, format!("{}", e)))? - ).map_err(|e| (StatusCode::BAD_GATEWAY, format!("Invalid actor: {}", e)))?; - let public_key = PublicKey::from_pem(actor.public_key.pem.as_bytes()) - .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, format!("{:?}", e)))?; - if !(signature.verify(&public_key) - .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, format!("{:?}", e)))?) - { - return Err((StatusCode::BAD_REQUEST, "Signature verification failed".to_string())); - } - - return Ok(Endpoint { payload, actor }); + return Ok(Endpoint { payload, signature, remote_actor_uri }); + } +} + +impl<'a> Endpoint<'a> { + /// Validates the requesting actor + pub async fn remote_actor( + &self, + client: &reqwest::Client, + key_id: &str, + private_key: &PrivateKey, + ) -> Result { + let remote_actor: Actor = serde_json::from_value( + authorized_fetch(&client, &self.remote_actor_uri, key_id, private_key).await? + )?; + let public_key = PublicKey::from_pem(remote_actor.public_key.pem.as_bytes())?; + if ! (self.signature.verify(&public_key)?) { + return Err(Error::SignatureFail); + } + + Ok(remote_actor) } } diff --git a/src/error.rs b/src/error.rs index 3cd811d..6980939 100644 --- a/src/error.rs +++ b/src/error.rs @@ -6,6 +6,8 @@ pub enum Error { Json(#[from] serde_json::Error), #[error("Signature error")] Signature(#[from] sigh::Error), + #[error("Signature verification failure")] + SignatureFail, #[error("HTTP request error")] HttpReq(#[from] http::Error), #[error("HTTP client error")] diff --git a/src/fetch.rs b/src/fetch.rs index 432d2c3..198c2be 100644 --- a/src/fetch.rs +++ b/src/fetch.rs @@ -3,18 +3,6 @@ use serde::de::DeserializeOwned; use sigh::{PrivateKey, SigningConfig, alg::RsaSha256}; use crate::{digest, error::Error}; -pub async fn fetch(client: &reqwest::Client, url: &str) -> Result -where - T: DeserializeOwned, -{ - client.get(url) - .header("accept", "application/activity+json") - .send() - .await? - .json() - .await -} - pub async fn authorized_fetch( client: &reqwest::Client, uri: &str, diff --git a/src/main.rs b/src/main.rs index bf87c57..1a87153 100644 --- a/src/main.rs +++ b/src/main.rs @@ -120,7 +120,7 @@ async fn get_instance_actor( async fn post_tag_relay( axum::extract::State(state): axum::extract::State, Path(tag): Path, - endpoint: endpoint::Endpoint + endpoint: endpoint::Endpoint<'_> ) -> Response { let target = actor::Actor { host: state.hostname.clone(), @@ -132,7 +132,7 @@ async fn post_tag_relay( async fn post_instance_relay( axum::extract::State(state): axum::extract::State, Path(instance): Path, - endpoint: endpoint::Endpoint + endpoint: endpoint::Endpoint<'_> ) -> Response { let target = actor::Actor { host: state.hostname.clone(), @@ -143,9 +143,19 @@ async fn post_instance_relay( async fn post_relay( state: State, - endpoint: endpoint::Endpoint, + endpoint: endpoint::Endpoint<'_>, target: actor::Actor ) -> Response { + let remote_actor = match endpoint.remote_actor(&state.client, &target.key_id(), &state.priv_key).await { + Ok(remote_actor) => remote_actor, + Err(e) => { + track_request("POST", "relay", "bad_actor"); + return ( + StatusCode::BAD_REQUEST, + format!("Bad actor: {:?}", e) + ).into_response(); + } + }; let action = match serde_json::from_value::>(endpoint.payload.clone()) { Ok(action) => action, Err(e) => { @@ -168,12 +178,12 @@ async fn post_relay( jsonld_context: serde_json::Value::String("https://www.w3.org/ns/activitystreams".to_string()), action_type: "Accept".to_string(), actor: target.uri(), - to: Some(json!(endpoint.actor.id.clone())), + to: Some(json!(remote_actor.id.clone())), id: action.id, object: Some(endpoint.payload), }; let result = send::send( - client.as_ref(), &endpoint.actor.inbox, + client.as_ref(), &remote_actor.inbox, &target.key_id(), &priv_key, &accept, @@ -181,8 +191,8 @@ async fn post_relay( match result { Ok(()) => { match state.database.add_follow( - &endpoint.actor.id, - &endpoint.actor.inbox, + &remote_actor.id, + &remote_actor.inbox, &target.uri(), ).await { Ok(()) => { @@ -208,7 +218,7 @@ async fn post_relay( ).into_response() } else if action.action_type == "Undo" && object_type == Some("Follow".to_string()) { match state.database.del_follow( - &endpoint.actor.id, + &remote_actor.id, &target.uri(), ).await { Ok(()) => {