diff --git a/src/activitypub/handlers/create_note.rs b/src/activitypub/handlers/create_note.rs index 4b1dd9d..3d2ebd3 100644 --- a/src/activitypub/handlers/create_note.rs +++ b/src/activitypub/handlers/create_note.rs @@ -27,7 +27,7 @@ use crate::models::posts::{ mentions::mention_to_address, queries::create_post, types::{Post, PostCreateData, Visibility}, - validators::CONTENT_MAX_SIZE, + validators::{content_allowed_classes, CONTENT_MAX_SIZE}, }; use crate::models::profiles::queries::get_profile_by_acct; use crate::models::profiles::types::DbActorProfile; @@ -88,7 +88,7 @@ pub fn get_note_content(object: &Object) -> Result { if content.len() > CONTENT_MAX_SIZE { return Err(ValidationError("content is too long")); }; - let content_safe = clean_html(&content); + let content_safe = clean_html(&content, content_allowed_classes()); Ok(content_safe) } diff --git a/src/models/posts/validators.rs b/src/models/posts/validators.rs index a98be39..4d8ed89 100644 --- a/src/models/posts/validators.rs +++ b/src/models/posts/validators.rs @@ -13,7 +13,7 @@ const CONTENT_ALLOWED_TAGS: [&str; 8] = [ "span", ]; -fn content_allowed_classes() -> Vec<(&'static str, Vec<&'static str>)> { +pub fn content_allowed_classes() -> Vec<(&'static str, Vec<&'static str>)> { vec![ ("a", vec!["hashtag", "mention", "u-url"]), ("span", vec!["h-card"]), diff --git a/src/models/profiles/validators.rs b/src/models/profiles/validators.rs index 8190687..13a7412 100644 --- a/src/models/profiles/validators.rs +++ b/src/models/profiles/validators.rs @@ -37,7 +37,7 @@ pub fn clean_bio(bio: &str, is_remote: bool) -> Result let cleaned_bio = if is_remote { // Remote profile let truncated_bio: String = bio.chars().take(BIO_MAX_LENGTH).collect(); - clean_html(&truncated_bio) + clean_html(&truncated_bio, vec![]) } else { // Local profile if bio.chars().count() > BIO_MAX_LENGTH { diff --git a/src/utils/html.rs b/src/utils/html.rs index 14c9e50..8ccffd8 100644 --- a/src/utils/html.rs +++ b/src/utils/html.rs @@ -3,9 +3,15 @@ use std::iter::FromIterator; use ammonia::Builder; -pub fn clean_html(unsafe_html: &str) -> String { - let safe_html = Builder::default() - .add_generic_attributes(&["class"]) +pub fn clean_html( + unsafe_html: &str, + allowed_classes: Vec<(&'static str, Vec<&'static str>)>, +) -> String { + let mut builder = Builder::default(); + for (tag, classes) in allowed_classes.iter() { + builder.add_allowed_classes(tag, classes); + }; + let safe_html = builder // Remove src from external images to prevent tracking .set_tag_attribute_value("img", "src", "") // Always add rel="noopener" @@ -53,13 +59,19 @@ mod tests { fn test_clean_html() { let unsafe_html = concat!( r#"

@user test

"#, - r#"

"#, + r#"

"#, ); let expected_safe_html = concat!( r#"

@user test

"#, r#"

"#, ); - let safe_html = clean_html(unsafe_html); + let safe_html = clean_html( + unsafe_html, + vec![ + ("a", vec!["mention", "u-url"]), + ("span", vec!["h-card"]), + ], + ); assert_eq!(safe_html, expected_safe_html); }