Merge branch 'aws-transcribe-auto-lang' into 'main'

aws: transcriber: add support for language identification

See merge request gstreamer/gst-plugins-rs!1518
This commit is contained in:
François Laignel 2024-05-03 06:16:48 +00:00
commit c3bf546cfa
4 changed files with 427 additions and 100 deletions

View file

@ -955,8 +955,32 @@
"type": "gchararray",
"writable": true
},
"identify-language": {
"blurb": "Enables automatic language identification, see <https://docs.rs/aws-sdk-transcribestreaming/1.17.0/aws_sdk_transcribestreaming/operation/start_stream_transcription/builders/struct.StartStreamTranscriptionFluentBuilder.html#method.identify_language>",
"conditionally-available": false,
"construct": false,
"construct-only": false,
"controllable": false,
"default": "false",
"mutable": "ready",
"readable": true,
"type": "gboolean",
"writable": true
},
"identify-multiple-languages": {
"blurb": "Enables automatic multi-language identification, see <https://docs.rs/aws-sdk-transcribestreaming/1.17.0/aws_sdk_transcribestreaming/operation/start_stream_transcription/builders/struct.StartStreamTranscriptionFluentBuilder.html#method.identify_multiple_languages>",
"conditionally-available": false,
"construct": false,
"construct-only": false,
"controllable": false,
"default": "false",
"mutable": "ready",
"readable": true,
"type": "gboolean",
"writable": true
},
"language-code": {
"blurb": "The Language of the Stream, see <https://docs.aws.amazon.com/transcribe/latest/dg/how-streaming-transcription.html> for an up to date list of allowed languages",
"blurb": "The Language of the Stream, see <https://docs.aws.amazon.com/transcribe/latest/dg/supported-languages.html> for an up to date list of allowed languages",
"conditionally-available": false,
"construct": false,
"construct-only": false,
@ -967,6 +991,18 @@
"type": "gchararray",
"writable": true
},
"language-options": {
"blurb": "Two or more language codes that represent the languages which may be present in the media, see <https://docs.rs/aws-sdk-transcribestreaming/1.17.0/aws_sdk_transcribestreaming/operation/start_stream_transcription/builders/struct.StartStreamTranscriptionFluentBuilder.html#method.language_options>",
"conditionally-available": false,
"construct": false,
"construct-only": false,
"controllable": false,
"default": "NULL",
"mutable": "ready",
"readable": true,
"type": "gchararray",
"writable": true
},
"latency": {
"blurb": "Amount of milliseconds to allow AWS transcribe (Deprecated. Use transcribe-latency)",
"conditionally-available": false,
@ -995,6 +1031,18 @@
"type": "guint",
"writable": true
},
"preferred-language": {
"blurb": "Preferred language from the subset of languages codes specified in `language-options`, see <https://docs.rs/aws-sdk-transcribestreaming/1.17.0/aws_sdk_transcribestreaming/operation/start_stream_transcription/builders/struct.StartStreamTranscriptionFluentBuilder.html#method.preferred_language>",
"conditionally-available": false,
"construct": false,
"construct-only": false,
"controllable": false,
"default": "NULL",
"mutable": "ready",
"readable": true,
"type": "gchararray",
"writable": true
},
"results-stability": {
"blurb": "Defines how fast results should stabilize",
"conditionally-available": false,
@ -1109,6 +1157,18 @@
"type": "gchararray",
"writable": true
},
"vocabulary-filter-names": {
"blurb": "The names of a custom filter vocabularies to be used with identify-language or identify-multiple-languages, see <https://docs.rs/aws-sdk-transcribestreaming/1.17.0/aws_sdk_transcribestreaming/operation/start_stream_transcription/builders/struct.StartStreamTranscriptionFluentBuilder.html#method.vocabulary_filter_names>",
"conditionally-available": false,
"construct": false,
"construct-only": false,
"controllable": false,
"default": "NULL",
"mutable": "ready",
"readable": true,
"type": "gchararray",
"writable": true
},
"vocabulary-name": {
"blurb": "The name of a custom vocabulary, see <https://docs.aws.amazon.com/transcribe/latest/dg/how-vocabulary.html> for more information",
"conditionally-available": false,
@ -1120,6 +1180,18 @@
"readable": true,
"type": "gchararray",
"writable": true
},
"vocabulary-names": {
"blurb": "The names of a custom vocabularies to be used with identify-language or identify-multiple-languages, see <https://docs.rs/aws-sdk-transcribestreaming/1.17.0/aws_sdk_transcribestreaming/operation/start_stream_transcription/builders/struct.StartStreamTranscriptionFluentBuilder.html#method.vocabulary_names>",
"conditionally-available": false,
"construct": false,
"construct-only": false,
"controllable": false,
"default": "NULL",
"mutable": "ready",
"readable": true,
"type": "gchararray",
"writable": true
}
},
"rank": "none"

View file

@ -88,13 +88,19 @@ pub(super) struct Settings {
translate_lookahead: gst::ClockTime,
lateness: gst::ClockTime,
pub language_code: String,
pub identify_language: bool,
pub language_options: Option<String>,
pub preferred_language: Option<String>,
pub identify_multiple_languages: bool,
pub vocabulary: Option<String>,
pub vocabularies: Option<String>,
pub session_id: Option<String>,
pub results_stability: AwsTranscriberResultStability,
access_key: Option<String>,
secret_access_key: Option<String>,
session_token: Option<String>,
pub vocabulary_filter: Option<String>,
pub vocabulary_filters: Option<String>,
pub vocabulary_filter_method: AwsTranscriberVocabularyFilterMethod,
}
@ -106,13 +112,19 @@ impl Default for Settings {
translate_lookahead: DEFAULT_TRANSLATE_LOOKAHEAD,
lateness: DEFAULT_LATENESS,
language_code: DEFAULT_INPUT_LANG_CODE.to_string(),
identify_language: false,
language_options: None,
preferred_language: None,
identify_multiple_languages: false,
vocabulary: None,
vocabularies: None,
session_id: None,
results_stability: DEFAULT_STABILITY,
access_key: None,
secret_access_key: None,
session_token: None,
vocabulary_filter: None,
vocabulary_filters: None,
vocabulary_filter_method: DEFAULT_VOCABULARY_FILTER_METHOD,
}
}
@ -680,11 +692,35 @@ impl ObjectImpl for Transcriber {
glib::ParamSpecString::builder("language-code")
.nick("Language Code")
.blurb("The Language of the Stream, see \
<https://docs.aws.amazon.com/transcribe/latest/dg/how-streaming-transcription.html> \
<https://docs.aws.amazon.com/transcribe/latest/dg/supported-languages.html> \
for an up to date list of allowed languages")
.default_value(Some(DEFAULT_INPUT_LANG_CODE))
.mutable_ready()
.build(),
glib::ParamSpecBoolean::builder("identify-language")
.nick("Identify Language")
.blurb("Enables automatic language identification, see \
<https://docs.rs/aws-sdk-transcribestreaming/1.17.0/aws_sdk_transcribestreaming/operation/start_stream_transcription/builders/struct.StartStreamTranscriptionFluentBuilder.html#method.identify_language>")
.mutable_ready()
.build(),
glib::ParamSpecString::builder("language-options")
.nick("Language Options")
.blurb("Two or more comma-separated language codes that represent the languages which may be present in the media, see \
<https://docs.rs/aws-sdk-transcribestreaming/1.17.0/aws_sdk_transcribestreaming/operation/start_stream_transcription/builders/struct.StartStreamTranscriptionFluentBuilder.html#method.language_options>")
.mutable_ready()
.build(),
glib::ParamSpecString::builder("preferred-language")
.nick("Preferred Language")
.blurb("Preferred language from the subset of languages codes specified in `language-options`, see \
<https://docs.rs/aws-sdk-transcribestreaming/1.17.0/aws_sdk_transcribestreaming/operation/start_stream_transcription/builders/struct.StartStreamTranscriptionFluentBuilder.html#method.preferred_language>")
.mutable_ready()
.build(),
glib::ParamSpecBoolean::builder("identify-multiple-languages")
.nick("Identify Multiple Languages")
.blurb("Enables automatic multi-language identification, see \
<https://docs.rs/aws-sdk-transcribestreaming/1.17.0/aws_sdk_transcribestreaming/operation/start_stream_transcription/builders/struct.StartStreamTranscriptionFluentBuilder.html#method.identify_multiple_languages>")
.mutable_ready()
.build(),
glib::ParamSpecUInt::builder(DEPRECATED_LATENCY_PROPERTY)
.nick("Latency")
.blurb("Amount of milliseconds to allow AWS transcribe (Deprecated. Use transcribe-latency)")
@ -729,6 +765,12 @@ impl ObjectImpl for Transcriber {
for more information")
.mutable_ready()
.build(),
glib::ParamSpecString::builder("vocabulary-names")
.nick("Vocabulary Names")
.blurb("The names of comma-separated custom vocabularies to be used with identify-language or identify-multiple-languages, see \
<https://docs.rs/aws-sdk-transcribestreaming/1.17.0/aws_sdk_transcribestreaming/operation/start_stream_transcription/builders/struct.StartStreamTranscriptionFluentBuilder.html#method.vocabulary_names>")
.mutable_ready()
.build(),
glib::ParamSpecString::builder("session-id")
.nick("Session ID")
.blurb("The ID of the transcription session, must be length 36")
@ -761,6 +803,12 @@ impl ObjectImpl for Transcriber {
for more information")
.mutable_ready()
.build(),
glib::ParamSpecString::builder("vocabulary-filter-names")
.nick("Vocabulary Filter Names")
.blurb("The names of comma-separated custom filter vocabularies to be used with identify-language or identify-multiple-languages, see \
<https://docs.rs/aws-sdk-transcribestreaming/1.17.0/aws_sdk_transcribestreaming/operation/start_stream_transcription/builders/struct.StartStreamTranscriptionFluentBuilder.html#method.vocabulary_filter_names>")
.mutable_ready()
.build(),
glib::ParamSpecEnum::builder_with_default("vocabulary-filter-method", DEFAULT_VOCABULARY_FILTER_METHOD)
.nick("Vocabulary Filter Method")
.blurb("Defines how filtered words will be edited, has no effect when vocabulary-filter-name isn't set")
@ -787,6 +835,22 @@ impl ObjectImpl for Transcriber {
let mut settings = self.settings.lock().unwrap();
settings.language_code = value.get().expect("type checked upstream");
}
"identify-language" => {
let mut settings = self.settings.lock().unwrap();
settings.identify_language = value.get().unwrap();
}
"language-options" => {
let mut settings = self.settings.lock().unwrap();
settings.language_options = value.get().unwrap();
}
"preferred-language" => {
let mut settings = self.settings.lock().unwrap();
settings.preferred_language = value.get().unwrap();
}
"identify-multiple-languages" => {
let mut settings = self.settings.lock().unwrap();
settings.identify_multiple_languages = value.get().unwrap();
}
DEPRECATED_LATENCY_PROPERTY => {
let mut settings = self.settings.lock().unwrap();
settings.transcribe_latency = gst::ClockTime::from_mseconds(
@ -817,6 +881,10 @@ impl ObjectImpl for Transcriber {
let mut settings = self.settings.lock().unwrap();
settings.vocabulary = value.get().expect("type checked upstream");
}
"vocabulary-names" => {
let mut settings = self.settings.lock().unwrap();
settings.vocabularies = value.get().expect("type checked upstream");
}
"session-id" => {
let mut settings = self.settings.lock().unwrap();
settings.session_id = value.get().expect("type checked upstream");
@ -843,6 +911,10 @@ impl ObjectImpl for Transcriber {
let mut settings = self.settings.lock().unwrap();
settings.vocabulary_filter = value.get().expect("type checked upstream");
}
"vocabulary-filter-names" => {
let mut settings = self.settings.lock().unwrap();
settings.vocabulary_filters = value.get().expect("type checked upstream");
}
"vocabulary-filter-method" => {
let mut settings = self.settings.lock().unwrap();
settings.vocabulary_filter_method = value
@ -859,6 +931,22 @@ impl ObjectImpl for Transcriber {
let settings = self.settings.lock().unwrap();
settings.language_code.to_value()
}
"identify-language" => {
let settings = self.settings.lock().unwrap();
settings.identify_language.to_value()
}
"language-options" => {
let settings = self.settings.lock().unwrap();
settings.language_options.to_value()
}
"preferred-language" => {
let settings = self.settings.lock().unwrap();
settings.preferred_language.to_value()
}
"identify-multiple-languages" => {
let settings = self.settings.lock().unwrap();
settings.identify_multiple_languages.to_value()
}
DEPRECATED_LATENCY_PROPERTY => {
let settings = self.settings.lock().unwrap();
(settings.transcribe_latency.mseconds() as u32).to_value()
@ -881,6 +969,10 @@ impl ObjectImpl for Transcriber {
let settings = self.settings.lock().unwrap();
settings.vocabulary.to_value()
}
"vocabulary-names" => {
let settings = self.settings.lock().unwrap();
settings.vocabularies.to_value()
}
"session-id" => {
let settings = self.settings.lock().unwrap();
settings.session_id.to_value()
@ -905,6 +997,10 @@ impl ObjectImpl for Transcriber {
let settings = self.settings.lock().unwrap();
settings.vocabulary_filter.to_value()
}
"vocabulary-filter-names" => {
let settings = self.settings.lock().unwrap();
settings.vocabulary_filters.to_value()
}
"vocabulary-filter-method" => {
let settings = self.settings.lock().unwrap();
settings.vocabulary_filter_method.to_value()
@ -1091,6 +1187,7 @@ impl ChildProxyImpl for Transcriber {
.map(|p| p.upcast())
}
}
struct TranslationPadTask {
pad: glib::subclass::ObjectImplRef<TranslateSrcPad>,
elem: super::Transcriber,
@ -1136,7 +1233,7 @@ impl TranslationPadTask {
}
needs_translate = TranslateSrcPad::needs_translation(
&elem_settings.language_code,
&elem_settings,
pad_settings.language_code.as_deref(),
);
@ -1147,7 +1244,6 @@ impl TranslationPadTask {
translation_loop = Some(TranslateLoop::new(
elem_imp,
pad,
&elem_settings.language_code,
pad_settings.language_code.as_deref().unwrap(),
pad_settings.tokenization_method,
to_loop_rx,
@ -1634,10 +1730,24 @@ impl TranslateSrcPad {
}
#[inline]
fn needs_translation(input_lang: &str, output_lang: Option<&str>) -> bool {
output_lang.map_or(false, |other| {
!input_lang.eq_ignore_ascii_case(other.as_ref())
})
fn needs_translation(elem_settings: &Settings, output_lang: Option<&str>) -> bool {
let Some(output_lang) = output_lang else {
return false;
};
if elem_settings.identify_language || elem_settings.identify_multiple_languages {
// TranslateLoop will determine on a case by case basis whether
// the Translate service must be called depending on the language
// detected by Transcribe.
return true;
}
// Transcript language is a 5 character localized language code: e.g. en-US
// Translate output language can be 2 (en) or 5 characters (en-US).
!elem_settings
.language_code
.to_ascii_lowercase()
.starts_with(&output_lang.to_ascii_lowercase())
}
#[inline]
@ -1645,10 +1755,7 @@ impl TranslateSrcPad {
elem_settings: &Settings,
pad_settings: &TranslatePadSettings,
) -> gst::ClockTime {
if Self::needs_translation(
&elem_settings.language_code,
pad_settings.language_code.as_deref(),
) {
if Self::needs_translation(elem_settings, pad_settings.language_code.as_deref()) {
elem_settings.transcribe_latency + elem_settings.translate_latency
} else {
elem_settings.transcribe_latency

View file

@ -25,9 +25,15 @@ use super::CAT;
#[derive(Debug)]
pub struct TranscriberSettings {
lang_code: types::LanguageCode,
identify_lang: bool,
lang_options: Option<String>,
preferred_lang: Option<types::LanguageCode>,
identify_multi_lang: bool,
sample_rate: i32,
vocabulary: Option<String>,
vocabularies: Option<String>,
vocabulary_filter: Option<String>,
vocabulary_filters: Option<String>,
vocabulary_filter_method: types::VocabularyFilterMethod,
session_id: Option<String>,
results_stability: types::PartialResultsStability,
@ -35,11 +41,22 @@ pub struct TranscriberSettings {
impl TranscriberSettings {
pub(super) fn from(settings: &Settings, sample_rate: i32) -> Self {
let preferred_lang = settings
.preferred_language
.as_ref()
.map(|pref_lang| pref_lang.as_str().into());
TranscriberSettings {
lang_code: settings.language_code.as_str().into(),
identify_lang: settings.identify_language,
lang_options: settings.language_options.clone(),
preferred_lang,
identify_multi_lang: settings.identify_multiple_languages,
sample_rate,
vocabulary: settings.vocabulary.clone(),
vocabularies: settings.vocabularies.clone(),
vocabulary_filter: settings.vocabulary_filter.clone(),
vocabulary_filters: settings.vocabulary_filters.clone(),
vocabulary_filter_method: settings.vocabulary_filter_method.into(),
session_id: settings.session_id.clone(),
results_stability: settings.results_stability.into(),
@ -47,10 +64,11 @@ impl TranscriberSettings {
}
}
#[derive(Clone, Debug, Default)]
#[derive(Clone, Debug)]
pub struct TranscriptItem {
pub pts: gst::ClockTime,
pub duration: gst::ClockTime,
pub lang_code: Option<types::LanguageCode>,
pub content: String,
pub is_punctuation: bool,
}
@ -58,6 +76,7 @@ pub struct TranscriptItem {
impl TranscriptItem {
pub fn from(
item: types::Item,
lang_code: Option<types::LanguageCode>,
lateness: gst::ClockTime,
discont_offset: gst::ClockTime,
) -> Option<TranscriptItem> {
@ -73,6 +92,7 @@ impl TranscriptItem {
Some(TranscriptItem {
pts: start_time,
duration: end_time - start_time,
lang_code,
content,
is_punctuation: matches!(item.r#type, Some(types::ItemType::Punctuation)),
})
@ -100,7 +120,8 @@ pub struct TranscriberStream {
imp: glib::subclass::ObjectImplRef<Transcriber>,
output: aws_transcribe::operation::start_stream_transcription::StartStreamTranscriptionOutput,
lateness: gst::ClockTime,
partial_index: usize,
last_stable_end_time: f64,
last_stable_is_punctuation: bool,
discont_offset_tracker: Arc<Mutex<DiscontOffsetTracker>>,
}
@ -147,18 +168,42 @@ impl TranscriberStream {
let mut transcribe_builder = client
.start_stream_transcription()
.language_code(settings.lang_code)
.media_sample_rate_hertz(settings.sample_rate)
.media_encoding(types::MediaEncoding::Pcm)
.enable_partial_results_stabilization(true)
.partial_results_stability(settings.results_stability)
.set_vocabulary_name(settings.vocabulary)
.set_session_id(settings.session_id);
if let Some(vocabulary_filter) = settings.vocabulary_filter {
// From the doc:
//
// > Note that you must include either LanguageCode or IdentifyLanguage or
// > IdentifyMultipleLanguages in your request. If you include more than one
// > of these parameters, your transcription job fails.
//
// https://docs.rs/aws-sdk-transcribestreaming/1.17.0/aws_sdk_transcribestreaming/operation/start_stream_transcription/builders/struct.StartStreamTranscriptionFluentBuilder.html#method.identify_language
if settings.identify_lang || settings.identify_multi_lang {
transcribe_builder = transcribe_builder
.vocabulary_filter_name(vocabulary_filter)
.vocabulary_filter_method(settings.vocabulary_filter_method);
.set_language_options(settings.lang_options)
.set_preferred_language(settings.preferred_lang)
.identify_language(!settings.identify_multi_lang)
.identify_multiple_languages(settings.identify_multi_lang)
.set_vocabulary_names(settings.vocabularies);
if let Some(vocabulary_filters) = settings.vocabulary_filters {
transcribe_builder = transcribe_builder
.vocabulary_filter_names(vocabulary_filters)
.vocabulary_filter_method(settings.vocabulary_filter_method);
}
} else {
transcribe_builder = transcribe_builder
.language_code(settings.lang_code)
.set_vocabulary_name(settings.vocabulary);
if let Some(vocabulary_filter) = settings.vocabulary_filter {
transcribe_builder = transcribe_builder
.vocabulary_filter_name(vocabulary_filter)
.vocabulary_filter_method(settings.vocabulary_filter_method);
}
}
let output = transcribe_builder
@ -175,12 +220,15 @@ impl TranscriberStream {
imp: imp.ref_counted(),
output,
lateness,
partial_index: 0,
last_stable_end_time: 0.0f64,
last_stable_is_punctuation: false,
discont_offset_tracker,
})
}
pub async fn next(&mut self) -> Result<TranscriptEvent, gst::ErrorMessage> {
let lang_code = self.output.language_code().cloned();
loop {
let event = self
.output
@ -199,8 +247,6 @@ impl TranscriberStream {
};
if let types::TranscriptResultStream::TranscriptEvent(transcript_evt) = event {
let mut ready_items = None;
if let Some(result) = transcript_evt
.transcript
.and_then(|transcript| transcript.results)
@ -212,14 +258,17 @@ impl TranscriberStream {
.alternatives
.and_then(|mut alternatives| alternatives.drain(..).next())
{
ready_items = alternative.items.and_then(|items| {
self.get_ready_transcript_items(items, result.is_partial)
});
}
}
if let Some(items) = alternative.items {
let ready_items = self.get_ready_transcript_items(
items,
result.language_code.or_else(|| lang_code.clone()),
);
if let Some(ready_items) = ready_items {
return Ok(ready_items.into());
if !ready_items.is_empty() {
return Ok(ready_items.into());
}
}
}
}
} else {
gst::warning!(
@ -234,57 +283,71 @@ impl TranscriberStream {
/// Builds a list from the provided stable items.
fn get_ready_transcript_items(
&mut self,
mut items: Vec<types::Item>,
partial: bool,
) -> Option<Vec<TranscriptItem>> {
if items.len() <= self.partial_index {
gst::error!(
CAT,
imp: self.imp,
"sanity check failed, alternative length {} < partial_index {}",
items.len(),
self.partial_index
);
if !partial {
self.partial_index = 0;
}
return None;
}
items: Vec<types::Item>,
lang_code: Option<types::LanguageCode>,
) -> Vec<TranscriptItem> {
let mut output = vec![];
for item in items.drain(self.partial_index..) {
// With language identification, we can receive several non-partial sub-segment
// results for individual sentences. E.g. starting from a segment with 3 sentences:
//
// - ... several partial results with some stabilized items.
// - partial result with all stable items for the segment (3 sentences).
// - non-partial sub-segment result with items & timestamps from 1st sentence.
// - non-partial sub-segment result with items & timestamps from 2nd sentence.
// - non-partial sub-segment result with items & timestamps from 3nd sentence.
// end_time matches the end_time of the last stable item.
// - partial result for next segment...
//
// Also had the case of a non-partial segment followed by the same
// segment flagged as partial.
//
// We can't expect the items sequence to be stable anymore and skip already
// processed items based on the partial_index. The approach here consists in
// using the item timestamp to determine which item should be skipped.
for item in items {
if !item.stable().unwrap_or(false) {
break;
}
let discont_offset = self.discont_offset_tracker.lock().unwrap().discont_offset;
if item.start_time < self.last_stable_end_time {
gst::trace!(CAT, imp: self.imp, "Skipping earlier item starting @ {}", item.start_time);
continue;
}
let Some(item) = TranscriptItem::from(item, self.lateness, discont_offset) else {
let is_punctuation = item
.r#type()
.map_or(false, |typ| *typ == types::ItemType::Punctuation);
if is_punctuation && self.last_stable_is_punctuation {
gst::trace!(CAT, imp: self.imp, "Skipping punctuation {:?} because last item is a punctuation too", item.content);
continue;
}
let discont_offset = self.discont_offset_tracker.lock().unwrap().discont_offset;
let end_time = item.end_time;
let Some(item) =
TranscriptItem::from(item, lang_code.clone(), self.lateness, discont_offset)
else {
continue;
};
gst::debug!(
CAT,
imp: self.imp,
"Item is ready for queuing: {}, PTS {}",
"Item is ready for queuing: {}, PTS {}, lang {:?}",
item.content,
item.pts,
item.lang_code,
);
self.partial_index += 1;
self.last_stable_end_time = end_time;
self.last_stable_is_punctuation = is_punctuation;
output.push(item);
}
if !partial {
self.partial_index = 0;
}
if output.is_empty() {
return None;
}
Some(output)
output
}
}

View file

@ -15,6 +15,7 @@ use aws_sdk_translate::error::ProvideErrorMetadata;
use futures::channel::mpsc;
use futures::prelude::*;
use std::ops::ControlFlow;
use std::sync::Arc;
use super::imp::TranslateSrcPad;
@ -44,7 +45,6 @@ impl From<&TranscriptItem> for TranslatedItem {
pub struct TranslateLoop {
pad: glib::subclass::ObjectImplRef<TranslateSrcPad>,
client: aws_translate::Client,
input_lang: String,
output_lang: String,
tokenization_method: TranslationTokenizationMethod,
transcript_rx: mpsc::Receiver<Arc<Vec<TranscriptItem>>>,
@ -55,7 +55,6 @@ impl TranslateLoop {
pub fn new(
imp: &super::imp::Transcriber,
pad: &TranslateSrcPad,
input_lang: &str,
output_lang: &str,
tokenization_method: TranslationTokenizationMethod,
transcript_rx: mpsc::Receiver<Arc<Vec<TranscriptItem>>>,
@ -69,7 +68,6 @@ impl TranslateLoop {
TranslateLoop {
pad: pad.ref_counted(),
client: aws_sdk_translate::Client::new(aws_config),
input_lang: input_lang.to_string(),
output_lang: output_lang.to_string(),
tokenization_method,
transcript_rx,
@ -110,11 +108,54 @@ impl TranslateLoop {
}
let mut ts_duration_list: Vec<(gst::ClockTime, gst::ClockTime)> = vec![];
let mut content: Vec<String> = vec![];
let mut content = String::new();
let mut content_lang = Option::<String>::None;
let mut needs_translation = false;
let mut it = transcript_items.iter().peekable();
while let Some(item) = it.next() {
let lang_changed = if !content.is_empty() {
// Some items already buffered
match (content_lang.as_ref(), item.lang_code.as_ref()) {
(Some(clang), Some(ilang)) => {
// Content and new item langs are defined
!clang.eq_ignore_ascii_case(ilang.as_str())
}
(None, Some(_)) | (Some(_), None) => {
// Content uses an undefined lang
// but new item's lang is defined
// or Content uses a defined lang
// but incoming item's lang is undefined
true
}
(None, None) => false,
}
} else {
false
};
if lang_changed
&& self
.handle_transcript_items(
&mut ts_duration_list,
&mut content,
needs_translation,
content_lang.take(),
)
.await?
.is_break()
{
gst::info!(CAT, imp: self.pad, "exiting translation loop");
break;
}
if content.is_empty() {
// Either first item or content drained above
content_lang = item.lang_code.as_ref().map(|lang| lang.to_string());
needs_translation = self.needs_translation(content_lang.as_deref());
}
let suffix = match it.peek() {
Some(next_item) => {
if next_item.is_punctuation {
@ -125,25 +166,65 @@ impl TranslateLoop {
}
None => "",
};
ts_duration_list.push((item.pts, item.duration));
content.push(match self.tokenization_method {
Tokenization::None => format!("{}{}", item.content, suffix),
Tokenization::SpanBased => {
let item_content =
if needs_translation && self.tokenization_method == Tokenization::SpanBased {
format!("{SPAN_START}{}{SPAN_END}{}", item.content, suffix)
}
});
} else {
format!("{}{}", item.content, suffix)
};
content.push_str(&item_content);
}
let content: String = content.join("");
if !content.is_empty()
&& self
.handle_transcript_items(
&mut ts_duration_list,
&mut content,
needs_translation,
content_lang.take(),
)
.await?
.is_break()
{
gst::info!(CAT, imp: self.pad, "exiting translation loop");
break;
}
}
gst::debug!(CAT, imp: self.pad, "Translating {content} with {ts_duration_list:?}");
Ok(())
}
#[inline]
fn needs_translation(&self, lang: Option<&str>) -> bool {
let Some(lang) = lang else { return false };
!lang.to_ascii_lowercase().starts_with(&self.output_lang)
}
async fn handle_transcript_items(
&mut self,
ts_duration_list: &mut Vec<(gst::ClockTime, gst::ClockTime)>,
content: &mut String,
needs_translation: bool,
content_lang: Option<String>,
) -> Result<ControlFlow<()>, gst::ErrorMessage> {
use std::mem;
use TranslationTokenizationMethod as Tokenization;
let output_text = if needs_translation {
gst::debug!(CAT, imp: self.pad,
"Translating: '{content}' from {content_lang:?} to {} with {ts_duration_list:?}",
self.output_lang,
);
let translated_text = self
.client
.translate_text()
.set_source_language_code(Some(self.input_lang.clone()))
.set_source_language_code(content_lang)
.set_target_language_code(Some(self.output_lang.clone()))
.set_text(Some(content))
.set_text(Some(mem::take(content)))
.send()
.await
.map_err(|err| {
@ -153,39 +234,43 @@ impl TranslateLoop {
})?
.translated_text;
gst::debug!(CAT, imp: self.pad, "Got translation {translated_text}");
gst::debug!(CAT, imp: self.pad, "Got translation: '{translated_text}'");
let translated_items = match self.tokenization_method {
Tokenization::None => {
// Push translation as a single item
let mut ts_duration_iter = ts_duration_list.into_iter().peekable();
translated_text
} else {
gst::debug!(CAT, imp: self.pad,
"Not translating: '{content}' from {content_lang:?} to {} with {ts_duration_list:?}",
self.output_lang,
);
let &(first_pts, _) = ts_duration_iter.peek().expect("at least one item");
let (last_pts, last_duration) =
ts_duration_iter.last().expect("at least one item");
mem::take(content)
};
vec![TranslatedItem {
pts: first_pts,
duration: last_pts.saturating_sub(first_pts) + last_duration,
content: translated_text,
}]
}
Tokenization::SpanBased => span_tokenize_items(&translated_text, ts_duration_list),
let translated_items =
if needs_translation && self.tokenization_method == Tokenization::SpanBased {
span_tokenize_items(&output_text, ts_duration_list.drain(..))
} else {
// Push translation as a single item
let mut ts_duration_iter = ts_duration_list.drain(..).peekable();
let &(first_pts, _) = ts_duration_iter.peek().expect("at least one item");
let (last_pts, last_duration) = ts_duration_iter.last().expect("at least one item");
vec![TranslatedItem {
pts: first_pts,
duration: last_pts.saturating_sub(first_pts) + last_duration,
content: output_text,
}]
};
gst::trace!(CAT, imp: self.pad, "Sending {translated_items:?}");
gst::trace!(CAT, imp: self.pad, "Sending {translated_items:?}");
if self.translate_tx.send(translated_items).await.is_err() {
gst::info!(
CAT,
imp: self.pad,
"translation chan terminated, exiting translation loop"
);
break;
}
if self.translate_tx.send(translated_items).await.is_err() {
gst::info!(CAT, imp: self.pad, "translation chan terminated");
return Ok(ControlFlow::Break(()));
}
Ok(())
Ok(ControlFlow::Continue(()))
}
}