endpoint: use authorized_fetch for remote_actor validation

This commit is contained in:
Astro 2023-05-14 23:14:55 +02:00
parent f782197b93
commit 01f8cc2f5f
4 changed files with 50 additions and 50 deletions

View file

@ -9,11 +9,12 @@ use axum::{
use http_digest_headers::{DigestHeader};
use sigh::{Signature, PublicKey, Key};
use sigh::{Signature, PublicKey, Key, PrivateKey};
use crate::fetch::fetch;
use crate::fetch::authorized_fetch;
use crate::activitypub::Actor;
use crate::error::Error;
const SIGNATURE_HEADERS_REQUIRED: &[&str] = &[
"(request-target)",
@ -21,20 +22,14 @@ const SIGNATURE_HEADERS_REQUIRED: &[&str] = &[
"digest",
];
#[derive(Clone, Debug)]
pub struct Endpoint {
pub struct Endpoint<'a> {
pub payload: serde_json::Value,
pub actor: Actor,
signature: Signature<'a>,
remote_actor_uri: String,
}
// impl Endpoint {
// pub fn parse<T: DeserializeOwned>(self) -> Result<T, serde_json::Error> {
// serde_json::from_value(self.payload)
// }
// }
#[async_trait]
impl<S, B> FromRequest<S, B> for Endpoint
impl<'a, S, B> FromRequest<S, B> for Endpoint<'a>
where
B: HttpBody + Send + 'static,
B::Data: Send,
@ -95,27 +90,32 @@ where
// parse body
let payload: serde_json::Value = serde_json::from_slice(&bytes)
.map_err(|_| (StatusCode::BAD_REQUEST, "Error parsing JSON".to_string()))?;
let actor_uri = if let Some(serde_json::Value::String(actor_uri)) = payload.get("actor") {
actor_uri
let remote_actor_uri = if let Some(serde_json::Value::String(actor_uri)) = payload.get("actor") {
actor_uri.to_string()
} else {
return Err((StatusCode::BAD_REQUEST, "Actor missing".to_string()));
};
// validate actor
let client = Arc::from_ref(state);
let actor: Actor =
serde_json::from_value(
fetch(&client, actor_uri).await
.map_err(|e| (StatusCode::BAD_GATEWAY, format!("{}", e)))?
).map_err(|e| (StatusCode::BAD_GATEWAY, format!("Invalid actor: {}", e)))?;
let public_key = PublicKey::from_pem(actor.public_key.pem.as_bytes())
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, format!("{:?}", e)))?;
if !(signature.verify(&public_key)
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, format!("{:?}", e)))?)
{
return Err((StatusCode::BAD_REQUEST, "Signature verification failed".to_string()));
}
return Ok(Endpoint { payload, actor });
return Ok(Endpoint { payload, signature, remote_actor_uri });
}
}
impl<'a> Endpoint<'a> {
/// Validates the requesting actor
pub async fn remote_actor(
&self,
client: &reqwest::Client,
key_id: &str,
private_key: &PrivateKey,
) -> Result<Actor, Error> {
let remote_actor: Actor = serde_json::from_value(
authorized_fetch(&client, &self.remote_actor_uri, key_id, private_key).await?
)?;
let public_key = PublicKey::from_pem(remote_actor.public_key.pem.as_bytes())?;
if ! (self.signature.verify(&public_key)?) {
return Err(Error::SignatureFail);
}
Ok(remote_actor)
}
}

View file

@ -6,6 +6,8 @@ pub enum Error {
Json(#[from] serde_json::Error),
#[error("Signature error")]
Signature(#[from] sigh::Error),
#[error("Signature verification failure")]
SignatureFail,
#[error("HTTP request error")]
HttpReq(#[from] http::Error),
#[error("HTTP client error")]

View file

@ -3,18 +3,6 @@ use serde::de::DeserializeOwned;
use sigh::{PrivateKey, SigningConfig, alg::RsaSha256};
use crate::{digest, error::Error};
pub async fn fetch<T>(client: &reqwest::Client, url: &str) -> Result<T, reqwest::Error>
where
T: DeserializeOwned,
{
client.get(url)
.header("accept", "application/activity+json")
.send()
.await?
.json()
.await
}
pub async fn authorized_fetch<T>(
client: &reqwest::Client,
uri: &str,

View file

@ -120,7 +120,7 @@ async fn get_instance_actor(
async fn post_tag_relay(
axum::extract::State(state): axum::extract::State<State>,
Path(tag): Path<String>,
endpoint: endpoint::Endpoint
endpoint: endpoint::Endpoint<'_>
) -> Response {
let target = actor::Actor {
host: state.hostname.clone(),
@ -132,7 +132,7 @@ async fn post_tag_relay(
async fn post_instance_relay(
axum::extract::State(state): axum::extract::State<State>,
Path(instance): Path<String>,
endpoint: endpoint::Endpoint
endpoint: endpoint::Endpoint<'_>
) -> Response {
let target = actor::Actor {
host: state.hostname.clone(),
@ -143,9 +143,19 @@ async fn post_instance_relay(
async fn post_relay(
state: State,
endpoint: endpoint::Endpoint,
endpoint: endpoint::Endpoint<'_>,
target: actor::Actor
) -> Response {
let remote_actor = match endpoint.remote_actor(&state.client, &target.key_id(), &state.priv_key).await {
Ok(remote_actor) => remote_actor,
Err(e) => {
track_request("POST", "relay", "bad_actor");
return (
StatusCode::BAD_REQUEST,
format!("Bad actor: {:?}", e)
).into_response();
}
};
let action = match serde_json::from_value::<activitypub::Action<serde_json::Value>>(endpoint.payload.clone()) {
Ok(action) => action,
Err(e) => {
@ -168,12 +178,12 @@ async fn post_relay(
jsonld_context: serde_json::Value::String("https://www.w3.org/ns/activitystreams".to_string()),
action_type: "Accept".to_string(),
actor: target.uri(),
to: Some(json!(endpoint.actor.id.clone())),
to: Some(json!(remote_actor.id.clone())),
id: action.id,
object: Some(endpoint.payload),
};
let result = send::send(
client.as_ref(), &endpoint.actor.inbox,
client.as_ref(), &remote_actor.inbox,
&target.key_id(),
&priv_key,
&accept,
@ -181,8 +191,8 @@ async fn post_relay(
match result {
Ok(()) => {
match state.database.add_follow(
&endpoint.actor.id,
&endpoint.actor.inbox,
&remote_actor.id,
&remote_actor.inbox,
&target.uri(),
).await {
Ok(()) => {
@ -208,7 +218,7 @@ async fn post_relay(
).into_response()
} else if action.action_type == "Undo" && object_type == Some("Follow".to_string()) {
match state.database.del_follow(
&endpoint.actor.id,
&remote_actor.id,
&target.uri(),
).await {
Ok(()) => {