diff --git a/net/aws/src/aws_transcriber/imp.rs b/net/aws/src/aws_transcriber/imp.rs index 8fb8e57e..eb232847 100644 --- a/net/aws/src/aws_transcriber/imp.rs +++ b/net/aws/src/aws_transcriber/imp.rs @@ -44,7 +44,7 @@ use serde_derive::{Deserialize, Serialize}; use once_cell::sync::Lazy; -use super::AwsTranscriberResultStability; +use super::{AwsTranscriberResultStability, AwsTranscriberVocabularyFilterMethod}; const DEFAULT_TRANSCRIBER_REGION: &str = "us-east-1"; @@ -113,6 +113,8 @@ static RUNTIME: Lazy = Lazy::new(|| { const DEFAULT_LATENCY: gst::ClockTime = gst::ClockTime::from_seconds(8); const DEFAULT_LATENESS: gst::ClockTime = gst::ClockTime::from_seconds(0); const DEFAULT_STABILITY: AwsTranscriberResultStability = AwsTranscriberResultStability::Low; +const DEFAULT_VOCABULARY_FILTER_METHOD: AwsTranscriberVocabularyFilterMethod = + AwsTranscriberVocabularyFilterMethod::Mask; const GRANULARITY: gst::ClockTime = gst::ClockTime::from_mseconds(100); #[derive(Debug, Clone)] @@ -126,6 +128,8 @@ struct Settings { access_key: Option, secret_access_key: Option, session_token: Option, + vocabulary_filter: Option, + vocabulary_filter_method: AwsTranscriberVocabularyFilterMethod, } impl Default for Settings { @@ -140,6 +144,8 @@ impl Default for Settings { access_key: None, secret_access_key: None, session_token: None, + vocabulary_filter: None, + vocabulary_filter_method: DEFAULT_VOCABULARY_FILTER_METHOD, } } } @@ -936,6 +942,23 @@ impl Transcriber { query_params.push_str(format!("&vocabulary-name={}", vocabulary).as_str()); } + if let Some(ref vocabulary_filter) = settings.vocabulary_filter { + query_params + .push_str(format!("&vocabulary-filter-name={}", vocabulary_filter).as_str()); + } + + query_params.push_str( + format!( + "&vocabulary-filter-method={}", + match settings.vocabulary_filter_method { + AwsTranscriberVocabularyFilterMethod::Mask => "mask", + AwsTranscriberVocabularyFilterMethod::Remove => "remove", + AwsTranscriberVocabularyFilterMethod::Tag => "tag", + } + ) + .as_str(), + ); + if let Some(ref session_id) = settings.session_id { gst::debug!(CAT, obj: element, "Using session ID: {}", session_id); query_params.push_str(format!("&session-id={}", session_id).as_str()); @@ -1227,6 +1250,23 @@ impl ObjectImpl for Transcriber { None, glib::ParamFlags::READWRITE | gst::PARAM_FLAG_MUTABLE_READY, ), + glib::ParamSpecString::new( + "vocabulary-filter-name", + "Vocabulary Filter Name", + "The name of a custom filter vocabulary, see \ + \ + for more information", + None, + glib::ParamFlags::READWRITE | gst::PARAM_FLAG_MUTABLE_READY, + ), + glib::ParamSpecEnum::new( + "vocabulary-filter-method", + "Vocabulary Filter Method", + "Defines how filtered words will be edited, has no effect when vocabulary-filter-name isn't set", + AwsTranscriberVocabularyFilterMethod::static_type(), + DEFAULT_VOCABULARY_FILTER_METHOD as i32, + glib::ParamFlags::READWRITE | gst::PARAM_FLAG_MUTABLE_READY, + ), ] }); @@ -1291,6 +1331,16 @@ impl ObjectImpl for Transcriber { let mut settings = self.settings.lock().unwrap(); settings.session_token = value.get().expect("type checked upstream"); } + "vocabulary-filter-name" => { + let mut settings = self.settings.lock().unwrap(); + settings.vocabulary_filter = value.get().expect("type checked upstream"); + } + "vocabulary-filter-method" => { + let mut settings = self.settings.lock().unwrap(); + settings.vocabulary_filter_method = value + .get::() + .expect("type checked upstream"); + } _ => unimplemented!(), } } @@ -1333,6 +1383,14 @@ impl ObjectImpl for Transcriber { let settings = self.settings.lock().unwrap(); settings.session_token.to_value() } + "vocabulary-filter-name" => { + let settings = self.settings.lock().unwrap(); + settings.vocabulary_filter.to_value() + } + "vocabulary-filter-method" => { + let settings = self.settings.lock().unwrap(); + settings.vocabulary_filter_method.to_value() + } _ => unimplemented!(), } } diff --git a/net/aws/src/aws_transcriber/mod.rs b/net/aws/src/aws_transcriber/mod.rs index 02b50ba8..009123ca 100644 --- a/net/aws/src/aws_transcriber/mod.rs +++ b/net/aws/src/aws_transcriber/mod.rs @@ -31,6 +31,19 @@ pub enum AwsTranscriberResultStability { Low = 2, } +#[derive(Debug, Eq, PartialEq, Ord, PartialOrd, Hash, Clone, Copy, glib::Enum)] +#[repr(u32)] +#[enum_type(name = "GstAwsTranscriberVocabularyFilterMethod")] +#[non_exhaustive] +pub enum AwsTranscriberVocabularyFilterMethod { + #[enum_value(name = "Mask: replace words with ***", nick = "mask")] + Mask = 0, + #[enum_value(name = "Remove: delete words", nick = "remove")] + Remove = 1, + #[enum_value(name = "Tag: flag words without changing them", nick = "tag")] + Tag = 2, +} + glib::wrapper! { pub struct Transcriber(ObjectSubclass) @extends gst::Element, gst::Object; }