This commit is contained in:
Astro 2022-12-20 00:15:00 +01:00
parent 2efa522c8a
commit 0c13db943c
7 changed files with 30 additions and 43 deletions

View file

@ -25,12 +25,7 @@
cargoTestCommands = x: cargoTestCommands = x:
x ++ [ x ++ [
''cargo clippy --all --all-features --tests -- \ ''cargo clippy --all --all-features --tests -- \
-D clippy::pedantic \
-D warnings \ -D warnings \
-A clippy::module-name-repetitions \
-A clippy::too-many-lines \
-A clippy::cast-possible-wrap \
-A clippy::cast-possible-truncation \
-A clippy::nonminimal_bool'' -A clippy::nonminimal_bool''
]; ];
meta.description = "Send Prometheus alerts to XMPP Multi-User Chatrooms"; meta.description = "Send Prometheus alerts to XMPP Multi-User Chatrooms";

View file

@ -1,5 +1,5 @@
use axum::{response::IntoResponse, Json}; use axum::{response::IntoResponse, Json};
use serde::{Deserialize, Serialize, de::DeserializeOwned}; use serde::{Deserialize, Serialize};
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Actor { pub struct Actor {
@ -23,7 +23,7 @@ pub struct ActorPublicKey {
pub pem: String, pub pem: String,
} }
/// ActivityPub "activity" /// `ActivityPub` "activity"
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Action<O> { pub struct Action<O> {
#[serde(rename = "@context")] #[serde(rename = "@context")]

View file

@ -1,7 +1,7 @@
use std::sync::Arc; use std::sync::Arc;
use tokio_postgres::{Client, Error, NoTls, Statement}; use tokio_postgres::{Client, Error, NoTls, Statement};
use crate::actor;
const CREATE_SCHEMA_COMMANDS: &[&str] = &[ const CREATE_SCHEMA_COMMANDS: &[&str] = &[
"CREATE TABLE IF NOT EXISTS follows (id TEXT, inbox TEXT, actor TEXT, UNIQUE (inbox, actor))", "CREATE TABLE IF NOT EXISTS follows (id TEXT, inbox TEXT, actor TEXT, UNIQUE (inbox, actor))",

View file

@ -4,16 +4,13 @@ use axum::{
async_trait, async_trait,
body::{Bytes, HttpBody}, body::{Bytes, HttpBody},
extract::{FromRef, FromRequest}, extract::{FromRef, FromRequest},
http::{header::CONTENT_TYPE, Request, StatusCode}, http::{header::CONTENT_TYPE, Request, StatusCode}, BoxError,
response::{IntoResponse, Response},
routing::post,
Form, RequestExt, Router, BoxError,
}; };
use http_digest_headers::{DigestHeader, DigestMethod, Error as DigestError}; use http_digest_headers::{DigestHeader};
use serde::{Deserialize, Serialize, de::DeserializeOwned};
use sigh::{Signature, PublicKey, Key}; use sigh::{Signature, PublicKey, Key};
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
use crate::fetch::fetch; use crate::fetch::fetch;
use crate::activitypub::Actor; use crate::activitypub::Actor;
@ -67,7 +64,7 @@ where
let signature_headers = signature.headers() let signature_headers = signature.headers()
.ok_or((StatusCode::BAD_REQUEST, "No signed headers".to_string()))?; .ok_or((StatusCode::BAD_REQUEST, "No signed headers".to_string()))?;
for header in SIGNATURE_HEADERS_REQUIRED { for header in SIGNATURE_HEADERS_REQUIRED {
if signature_headers.iter().find(|h| *h == header) == None { if !signature_headers.iter().any(|h| h == header) {
return Err((StatusCode::BAD_REQUEST, format!("Header {:?} not signed", header))); return Err((StatusCode::BAD_REQUEST, format!("Header {:?} not signed", header)));
} }
} }
@ -83,8 +80,8 @@ where
digest_header.replace_range(..4, "sha-"); digest_header.replace_range(..4, "sha-");
} }
// mastodon uses base64::alphabet::STANDARD, not base64::alphabet::URL_SAFE // mastodon uses base64::alphabet::STANDARD, not base64::alphabet::URL_SAFE
digest_header = digest_header.replace("+", "-") digest_header = digest_header.replace('+', "-")
.replace("/", "_"); .replace('/', "_");
let digest: DigestHeader = digest_header.parse() let digest: DigestHeader = digest_header.parse()
.map_err(|e| (StatusCode::BAD_REQUEST, format!("Cannot parse Digest: header: {}", e)))?; .map_err(|e| (StatusCode::BAD_REQUEST, format!("Cannot parse Digest: header: {}", e)))?;
// read body // read body
@ -96,7 +93,7 @@ where
} }
// parse body // parse body
let payload: serde_json::Value = serde_json::from_slice(&bytes) let payload: serde_json::Value = serde_json::from_slice(&bytes)
.map_err(|_| (StatusCode::BAD_REQUEST, format!("Error parsing JSON")))?; .map_err(|_| (StatusCode::BAD_REQUEST, "Error parsing JSON".to_string()))?;
let actor_uri = if let Some(serde_json::Value::String(actor_uri)) = payload.get("actor") { let actor_uri = if let Some(serde_json::Value::String(actor_uri)) = payload.get("actor") {
actor_uri actor_uri
} else { } else {
@ -104,20 +101,20 @@ where
}; };
// validate actor // validate actor
let client = Arc::from_ref(&state); let client = Arc::from_ref(state);
let actor: Actor = let actor: Actor =
serde_json::from_value( serde_json::from_value(
fetch(&client, &actor_uri).await fetch(&client, actor_uri).await
.map_err(|e| (StatusCode::BAD_GATEWAY, format!("{}", e)))? .map_err(|e| (StatusCode::BAD_GATEWAY, format!("{}", e)))?
).map_err(|e| (StatusCode::BAD_GATEWAY, format!("Invalid actor: {}", e)))?; ).map_err(|e| (StatusCode::BAD_GATEWAY, format!("Invalid actor: {}", e)))?;
let public_key = PublicKey::from_pem(actor.public_key.pem.as_bytes()) let public_key = PublicKey::from_pem(actor.public_key.pem.as_bytes())
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, format!("{:?}", e)))?; .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, format!("{:?}", e)))?;
if signature.verify(&public_key) if !(signature.verify(&public_key)
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, format!("{:?}", e)))? != true .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, format!("{:?}", e)))?)
{ {
return Err((StatusCode::BAD_REQUEST, "Signature verification failed".to_string())); return Err((StatusCode::BAD_REQUEST, "Signature verification failed".to_string()));
} }
return Ok(Endpoint { actor, payload }); return Ok(Endpoint { payload, actor });
} }
} }

View file

@ -1,14 +1,12 @@
use axum::{ use axum::{
async_trait, extract::{FromRef, Path, Query},
extract::{FromRequest, FromRef, Path, Query}, http::{StatusCode},
http::{header::CONTENT_TYPE, Request, StatusCode},
response::{IntoResponse, Response}, response::{IntoResponse, Response},
routing::{get, post}, routing::{get}, Json, Router,
Form, Json, RequestExt, Router,
}; };
use serde::{Deserialize, Serialize};
use serde_json::json; use serde_json::json;
use sigh::{PrivateKey, PublicKey, alg::{RsaSha256, Algorithm}, Key}; use sigh::{PrivateKey, PublicKey};
use std::{net::SocketAddr, sync::Arc, time::Duration, collections::HashMap}; use std::{net::SocketAddr, sync::Arc, time::Duration, collections::HashMap};
use std::{panic, process}; use std::{panic, process};
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
@ -143,7 +141,7 @@ async fn post_relay(
}; };
let object_type = action.object let object_type = action.object
.and_then(|object| object.get("type").cloned()) .and_then(|object| object.get("type").cloned())
.and_then(|object_type| object_type.as_str().map(|s| s.to_string())); .and_then(|object_type| object_type.as_str().map(std::string::ToString::to_string));
if action.action_type == "Follow" { if action.action_type == "Follow" {
let priv_key = state.priv_key.clone(); let priv_key = state.priv_key.clone();
@ -222,9 +220,7 @@ async fn main() {
.init(); .init();
let config = config::Config::load( let config = config::Config::load(
&std::env::args() &std::env::args().nth(1)
.skip(1)
.next()
.expect("Call with config.yaml") .expect("Call with config.yaml")
); );
let database = db::Database::connect(&config.db).await; let database = db::Database::connect(&config.db).await;

View file

@ -20,7 +20,7 @@ impl Post<'_> {
reqwest::Url::parse(self.url?) reqwest::Url::parse(self.url?)
.ok() .ok()
.and_then(|url| url.domain() .and_then(|url| url.domain()
.map(|s| s.to_lowercase()) .map(str::to_lowercase)
) )
} }
@ -38,11 +38,11 @@ impl Post<'_> {
fn relay_target_kinds(&self) -> impl Iterator<Item = actor::ActorKind> { fn relay_target_kinds(&self) -> impl Iterator<Item = actor::ActorKind> {
self.host() self.host()
.into_iter() .into_iter()
.map(|host| actor::ActorKind::InstanceRelay(host.clone())) .map(actor::ActorKind::InstanceRelay)
.chain( .chain(
self.tags() self.tags()
.into_iter() .into_iter()
.map(|tag| actor::ActorKind::TagRelay(tag)) .map(actor::ActorKind::TagRelay)
) )
} }

View file

@ -1,9 +1,8 @@
use std::{sync::Arc, ops::Deref}; use std::{sync::Arc};
use futures::StreamExt;
use http::StatusCode; use http::StatusCode;
use http_digest_headers::{DigestHeader, DigestMethod}; use http_digest_headers::{DigestHeader, DigestMethod};
use reqwest::Body;
use serde::Serialize; use serde::Serialize;
use sigh::{PrivateKey, SigningConfig, alg::RsaSha256}; use sigh::{PrivateKey, SigningConfig, alg::RsaSha256};
@ -56,7 +55,7 @@ pub async fn send_raw(
// mastodon uses base64::alphabet::STANDARD, not base64::alphabet::URL_SAFE // mastodon uses base64::alphabet::STANDARD, not base64::alphabet::URL_SAFE
digest_header.replace_range( digest_header.replace_range(
7.., 7..,
&digest_header[7..].replace("-", "+").replace("_", "/") &digest_header[7..].replace('-', "+").replace('_', "/")
); );
let url = reqwest::Url::parse(uri) let url = reqwest::Url::parse(uri)
@ -71,7 +70,7 @@ pub async fn send_raw(
.header("digest", digest_header) .header("digest", digest_header)
.body(body.as_ref().clone()) .body(body.as_ref().clone())
.map_err(SendError::HttpReq)?; .map_err(SendError::HttpReq)?;
SigningConfig::new(RsaSha256, &private_key, key_id) SigningConfig::new(RsaSha256, private_key, key_id)
.sign(&mut req)?; .sign(&mut req)?;
let req: reqwest::Request = req.try_into()?; let req: reqwest::Request = req.try_into()?;
let res = client.execute(req) let res = client.execute(req)