endpoint: use authorized_fetch for remote_actor validation

This commit is contained in:
Astro 2023-05-14 23:14:55 +02:00
parent f782197b93
commit 01f8cc2f5f
4 changed files with 50 additions and 50 deletions

View file

@ -9,11 +9,12 @@ use axum::{
use http_digest_headers::{DigestHeader}; use http_digest_headers::{DigestHeader};
use sigh::{Signature, PublicKey, Key}; use sigh::{Signature, PublicKey, Key, PrivateKey};
use crate::fetch::fetch; use crate::fetch::authorized_fetch;
use crate::activitypub::Actor; use crate::activitypub::Actor;
use crate::error::Error;
const SIGNATURE_HEADERS_REQUIRED: &[&str] = &[ const SIGNATURE_HEADERS_REQUIRED: &[&str] = &[
"(request-target)", "(request-target)",
@ -21,20 +22,14 @@ const SIGNATURE_HEADERS_REQUIRED: &[&str] = &[
"digest", "digest",
]; ];
#[derive(Clone, Debug)] pub struct Endpoint<'a> {
pub struct Endpoint {
pub payload: serde_json::Value, pub payload: serde_json::Value,
pub actor: Actor, signature: Signature<'a>,
remote_actor_uri: String,
} }
// impl Endpoint {
// pub fn parse<T: DeserializeOwned>(self) -> Result<T, serde_json::Error> {
// serde_json::from_value(self.payload)
// }
// }
#[async_trait] #[async_trait]
impl<S, B> FromRequest<S, B> for Endpoint impl<'a, S, B> FromRequest<S, B> for Endpoint<'a>
where where
B: HttpBody + Send + 'static, B: HttpBody + Send + 'static,
B::Data: Send, B::Data: Send,
@ -95,27 +90,32 @@ where
// parse body // parse body
let payload: serde_json::Value = serde_json::from_slice(&bytes) let payload: serde_json::Value = serde_json::from_slice(&bytes)
.map_err(|_| (StatusCode::BAD_REQUEST, "Error parsing JSON".to_string()))?; .map_err(|_| (StatusCode::BAD_REQUEST, "Error parsing JSON".to_string()))?;
let actor_uri = if let Some(serde_json::Value::String(actor_uri)) = payload.get("actor") { let remote_actor_uri = if let Some(serde_json::Value::String(actor_uri)) = payload.get("actor") {
actor_uri actor_uri.to_string()
} else { } else {
return Err((StatusCode::BAD_REQUEST, "Actor missing".to_string())); return Err((StatusCode::BAD_REQUEST, "Actor missing".to_string()));
}; };
// validate actor return Ok(Endpoint { payload, signature, remote_actor_uri });
let client = Arc::from_ref(state); }
let actor: Actor =
serde_json::from_value(
fetch(&client, actor_uri).await
.map_err(|e| (StatusCode::BAD_GATEWAY, format!("{}", e)))?
).map_err(|e| (StatusCode::BAD_GATEWAY, format!("Invalid actor: {}", e)))?;
let public_key = PublicKey::from_pem(actor.public_key.pem.as_bytes())
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, format!("{:?}", e)))?;
if !(signature.verify(&public_key)
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, format!("{:?}", e)))?)
{
return Err((StatusCode::BAD_REQUEST, "Signature verification failed".to_string()));
} }
return Ok(Endpoint { payload, actor }); impl<'a> Endpoint<'a> {
/// Validates the requesting actor
pub async fn remote_actor(
&self,
client: &reqwest::Client,
key_id: &str,
private_key: &PrivateKey,
) -> Result<Actor, Error> {
let remote_actor: Actor = serde_json::from_value(
authorized_fetch(&client, &self.remote_actor_uri, key_id, private_key).await?
)?;
let public_key = PublicKey::from_pem(remote_actor.public_key.pem.as_bytes())?;
if ! (self.signature.verify(&public_key)?) {
return Err(Error::SignatureFail);
}
Ok(remote_actor)
} }
} }

View file

@ -6,6 +6,8 @@ pub enum Error {
Json(#[from] serde_json::Error), Json(#[from] serde_json::Error),
#[error("Signature error")] #[error("Signature error")]
Signature(#[from] sigh::Error), Signature(#[from] sigh::Error),
#[error("Signature verification failure")]
SignatureFail,
#[error("HTTP request error")] #[error("HTTP request error")]
HttpReq(#[from] http::Error), HttpReq(#[from] http::Error),
#[error("HTTP client error")] #[error("HTTP client error")]

View file

@ -3,18 +3,6 @@ use serde::de::DeserializeOwned;
use sigh::{PrivateKey, SigningConfig, alg::RsaSha256}; use sigh::{PrivateKey, SigningConfig, alg::RsaSha256};
use crate::{digest, error::Error}; use crate::{digest, error::Error};
pub async fn fetch<T>(client: &reqwest::Client, url: &str) -> Result<T, reqwest::Error>
where
T: DeserializeOwned,
{
client.get(url)
.header("accept", "application/activity+json")
.send()
.await?
.json()
.await
}
pub async fn authorized_fetch<T>( pub async fn authorized_fetch<T>(
client: &reqwest::Client, client: &reqwest::Client,
uri: &str, uri: &str,

View file

@ -120,7 +120,7 @@ async fn get_instance_actor(
async fn post_tag_relay( async fn post_tag_relay(
axum::extract::State(state): axum::extract::State<State>, axum::extract::State(state): axum::extract::State<State>,
Path(tag): Path<String>, Path(tag): Path<String>,
endpoint: endpoint::Endpoint endpoint: endpoint::Endpoint<'_>
) -> Response { ) -> Response {
let target = actor::Actor { let target = actor::Actor {
host: state.hostname.clone(), host: state.hostname.clone(),
@ -132,7 +132,7 @@ async fn post_tag_relay(
async fn post_instance_relay( async fn post_instance_relay(
axum::extract::State(state): axum::extract::State<State>, axum::extract::State(state): axum::extract::State<State>,
Path(instance): Path<String>, Path(instance): Path<String>,
endpoint: endpoint::Endpoint endpoint: endpoint::Endpoint<'_>
) -> Response { ) -> Response {
let target = actor::Actor { let target = actor::Actor {
host: state.hostname.clone(), host: state.hostname.clone(),
@ -143,9 +143,19 @@ async fn post_instance_relay(
async fn post_relay( async fn post_relay(
state: State, state: State,
endpoint: endpoint::Endpoint, endpoint: endpoint::Endpoint<'_>,
target: actor::Actor target: actor::Actor
) -> Response { ) -> Response {
let remote_actor = match endpoint.remote_actor(&state.client, &target.key_id(), &state.priv_key).await {
Ok(remote_actor) => remote_actor,
Err(e) => {
track_request("POST", "relay", "bad_actor");
return (
StatusCode::BAD_REQUEST,
format!("Bad actor: {:?}", e)
).into_response();
}
};
let action = match serde_json::from_value::<activitypub::Action<serde_json::Value>>(endpoint.payload.clone()) { let action = match serde_json::from_value::<activitypub::Action<serde_json::Value>>(endpoint.payload.clone()) {
Ok(action) => action, Ok(action) => action,
Err(e) => { Err(e) => {
@ -168,12 +178,12 @@ async fn post_relay(
jsonld_context: serde_json::Value::String("https://www.w3.org/ns/activitystreams".to_string()), jsonld_context: serde_json::Value::String("https://www.w3.org/ns/activitystreams".to_string()),
action_type: "Accept".to_string(), action_type: "Accept".to_string(),
actor: target.uri(), actor: target.uri(),
to: Some(json!(endpoint.actor.id.clone())), to: Some(json!(remote_actor.id.clone())),
id: action.id, id: action.id,
object: Some(endpoint.payload), object: Some(endpoint.payload),
}; };
let result = send::send( let result = send::send(
client.as_ref(), &endpoint.actor.inbox, client.as_ref(), &remote_actor.inbox,
&target.key_id(), &target.key_id(),
&priv_key, &priv_key,
&accept, &accept,
@ -181,8 +191,8 @@ async fn post_relay(
match result { match result {
Ok(()) => { Ok(()) => {
match state.database.add_follow( match state.database.add_follow(
&endpoint.actor.id, &remote_actor.id,
&endpoint.actor.inbox, &remote_actor.inbox,
&target.uri(), &target.uri(),
).await { ).await {
Ok(()) => { Ok(()) => {
@ -208,7 +218,7 @@ async fn post_relay(
).into_response() ).into_response()
} else if action.action_type == "Undo" && object_type == Some("Follow".to_string()) { } else if action.action_type == "Undo" && object_type == Some("Follow".to_string()) {
match state.database.del_follow( match state.database.del_follow(
&endpoint.actor.id, &remote_actor.id,
&target.uri(), &target.uri(),
).await { ).await {
Ok(()) => { Ok(()) => {