From 743e97738fe44bec8843f467828a1ec2aa710d91 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Laignel?= Date: Fri, 10 Mar 2023 14:47:38 +0100 Subject: [PATCH] net/aws/transcriber: add translation request src pads This commit adds an optional transcript translation feature implemented as request src Pads. When requesting a src Pad, the user can specify the translation language code using Pad properties 'language-code'. The following properties are defined on the Element: - 'transcribe-latency': formerly 'latency', defines the expected latency for the Transcribe webservice. - 'translate-latency': defines the expected latency for the Translate webservice. - 'transcript-lookahead': maximum transcript duration to send to translation when a transcript is hitting its deadline and no punctuation was found. When the input and output languages are the same, only the 'transcribe-latency' is used for the Pad. Otherwise, the resulting latency is the addition of 'transcribe-latency' and 'translate-latency'. Part-of: --- docs/plugins/gst_plugins_cache.json | 80 +- net/aws/Cargo.toml | 1 + net/aws/src/transcriber/imp.rs | 1845 +++++++++++++++---------- net/aws/src/transcriber/mod.rs | 19 +- net/aws/src/transcriber/transcribe.rs | 277 ++++ net/aws/src/transcriber/translate.rs | 215 +++ 6 files changed, 1707 insertions(+), 730 deletions(-) create mode 100644 net/aws/src/transcriber/transcribe.rs create mode 100644 net/aws/src/transcriber/translate.rs diff --git a/docs/plugins/gst_plugins_cache.json b/docs/plugins/gst_plugins_cache.json index e21ba9282..c6e6e16bb 100644 --- a/docs/plugins/gst_plugins_cache.json +++ b/docs/plugins/gst_plugins_cache.json @@ -628,6 +628,9 @@ "GInitiallyUnowned", "GObject" ], + "interfaces": [ + "GstChildProxy" + ], "klass": "Audio/Text/Filter", "long-name": "Transcriber", "pad-templates": { @@ -639,7 +642,14 @@ "src": { "caps": "text/x-raw:\n format: utf8\n", "direction": "src", - "presence": "always" + "presence": "always", + "type": "GstTranslationSrcPad" + }, + "src_%%u": { + "caps": "text/x-raw:\n format: utf8\n", + "direction": "src", + "presence": "request", + "type": "GstTranslationSrcPad" } }, "properties": { @@ -668,7 +678,7 @@ "writable": true }, "latency": { - "blurb": "Amount of milliseconds to allow AWS transcribe", + "blurb": "Amount of milliseconds to allow AWS transcribe (Deprecated. Use transcribe-latency)", "conditionally-available": false, "construct": false, "construct-only": false, @@ -743,6 +753,48 @@ "type": "gchararray", "writable": true }, + "transcribe-latency": { + "blurb": "Amount of milliseconds to allow AWS transcribe", + "conditionally-available": false, + "construct": false, + "construct-only": false, + "controllable": false, + "default": "8000", + "max": "-1", + "min": "0", + "mutable": "ready", + "readable": true, + "type": "guint", + "writable": true + }, + "transcript-lookahead": { + "blurb": "Maximum duration in milliseconds of transcript to lookahead before sending to translation when no separator was encountered", + "conditionally-available": false, + "construct": false, + "construct-only": false, + "controllable": false, + "default": "3000", + "max": "-1", + "min": "0", + "mutable": "ready", + "readable": true, + "type": "guint", + "writable": true + }, + "translate-latency": { + "blurb": "Amount of milliseconds to allow AWS translate (ignored if the input and output languages are the same)", + "conditionally-available": false, + "construct": false, + "construct-only": false, + "controllable": false, + "default": "500", + "max": "-1", + "min": "0", + "mutable": "ready", + "readable": true, + "type": "guint", + "writable": true + }, "vocabulary-filter-method": { "blurb": "Defines how filtered words will be edited, has no effect when vocabulary-filter-name isn't set", "conditionally-available": false, @@ -845,6 +897,30 @@ "value": "2" } ] + }, + "GstTranslationSrcPad": { + "hierarchy": [ + "GstTranslationSrcPad", + "GstPad", + "GstObject", + "GInitiallyUnowned", + "GObject" + ], + "kind": "object", + "properties": { + "language-code": { + "blurb": "The Language the Stream must be translated to", + "conditionally-available": false, + "construct": false, + "construct-only": false, + "controllable": false, + "default": "NULL", + "mutable": "ready", + "readable": true, + "type": "gchararray", + "writable": true + } + } } }, "package": "gst-plugin-aws", diff --git a/net/aws/Cargo.toml b/net/aws/Cargo.toml index c329009f7..663263f92 100644 --- a/net/aws/Cargo.toml +++ b/net/aws/Cargo.toml @@ -16,6 +16,7 @@ base32 = "0.4" aws-config = "0.54.0" aws-sdk-s3 = "0.24.0" aws-sdk-transcribestreaming = "0.24.0" +aws-sdk-translate = "0.24.0" aws-types = "0.54.0" aws-credential-types = "0.54.0" aws-sig-auth = "0.54.0" diff --git a/net/aws/src/transcriber/imp.rs b/net/aws/src/transcriber/imp.rs index e4116d324..d856e6d4a 100644 --- a/net/aws/src/transcriber/imp.rs +++ b/net/aws/src/transcriber/imp.rs @@ -7,72 +7,97 @@ // // SPDX-License-Identifier: MPL-2.0 -use gst::glib; -use gst::prelude::*; +//! AWS Transcriber element. +//! +//! This element calls AWS Transcribe to extract transcripts from an audio stream. +//! The element can optionally translate the resulting transcripts to one or +//! multiple languages. +//! +//! This module contains the element implementation as well as the `TranslationSrcPad` +//! sublcass and its `TranslationPadTask`. +//! +//! Web service specific code can be found in the `transcribe` and `translate` modules. + use gst::subclass::prelude::*; +use gst::{glib, prelude::*}; use aws_sdk_transcribestreaming as aws_transcribe; -use aws_sdk_transcribestreaming::model; use futures::channel::mpsc; use futures::future::AbortHandle; use futures::prelude::*; -use tokio::{runtime, task}; +use tokio::{runtime, sync::broadcast, task}; -use std::cmp::Ordering; -use std::collections::VecDeque; +use std::collections::{BTreeSet, VecDeque}; use std::sync::Mutex; use once_cell::sync::Lazy; -use super::{AwsTranscriberResultStability, AwsTranscriberVocabularyFilterMethod}; - -static CAT: Lazy = Lazy::new(|| { - gst::DebugCategory::new( - "awstranscribe", - gst::DebugColorFlags::empty(), - Some("AWS Transcribe element"), - ) -}); +use super::transcribe::{TranscriberLoop, TranscriptEvent, TranscriptItem, TranscriptionSettings}; +use super::translate::{TranslatedItem, TranslationLoop, TranslationQueue}; +use super::{AwsTranscriberResultStability, AwsTranscriberVocabularyFilterMethod, CAT}; static RUNTIME: Lazy = Lazy::new(|| { runtime::Builder::new_multi_thread() .enable_all() - .worker_threads(1) .build() .unwrap() }); const DEFAULT_TRANSCRIBER_REGION: &str = "us-east-1"; -const DEFAULT_LATENCY: gst::ClockTime = gst::ClockTime::from_seconds(8); + +// Deprecated in 0.11.0: due to evolutions of the transcriber element, +// this property has been replaced by `TRANSCRIBE_LATENCY_PROPERTY`. +const DEPRECATED_LATENCY_PROPERTY: &str = "latency"; + +const TRANSCRIBE_LATENCY_PROPERTY: &str = "transcribe-latency"; +pub const DEFAULT_TRANSCRIBE_LATENCY: gst::ClockTime = gst::ClockTime::from_seconds(8); + +const TRANSLATE_LATENCY_PROPERTY: &str = "translate-latency"; +pub const DEFAULT_TRANSLATE_LATENCY: gst::ClockTime = gst::ClockTime::from_mseconds(500); + +const TRANSCRIPT_LOOKAHEAD_PROPERTY: &str = "transcript-lookahead"; +pub const DEFAULT_TRANSCRIPT_LOOKAHEAD: gst::ClockTime = gst::ClockTime::from_seconds(5); + const DEFAULT_LATENESS: gst::ClockTime = gst::ClockTime::ZERO; -const DEFAULT_LANGUAGE_CODE: &str = "en-US"; +pub const DEFAULT_INPUT_LANG_CODE: &str = "en-US"; + const DEFAULT_STABILITY: AwsTranscriberResultStability = AwsTranscriberResultStability::Low; const DEFAULT_VOCABULARY_FILTER_METHOD: AwsTranscriberVocabularyFilterMethod = AwsTranscriberVocabularyFilterMethod::Mask; -const GRANULARITY: gst::ClockTime = gst::ClockTime::from_mseconds(100); + +// The period at which the event loops will check if they need to push +// anything downstream when no other events show up. +pub const GRANULARITY: gst::ClockTime = gst::ClockTime::from_mseconds(100); + +const OUTPUT_LANG_CODE_PROPERTY: &str = "language-code"; +const DEFAULT_OUTPUT_LANG_CODE: Option<&str> = None; #[derive(Debug, Clone)] -struct Settings { - latency: gst::ClockTime, +pub(super) struct Settings { + transcribe_latency: gst::ClockTime, + translate_latency: gst::ClockTime, + transcript_lookahead: gst::ClockTime, lateness: gst::ClockTime, - language_code: String, - vocabulary: Option, - session_id: Option, - results_stability: AwsTranscriberResultStability, + pub language_code: String, + pub vocabulary: Option, + pub session_id: Option, + pub results_stability: AwsTranscriberResultStability, access_key: Option, secret_access_key: Option, session_token: Option, - vocabulary_filter: Option, - vocabulary_filter_method: AwsTranscriberVocabularyFilterMethod, + pub vocabulary_filter: Option, + pub vocabulary_filter_method: AwsTranscriberVocabularyFilterMethod, } impl Default for Settings { fn default() -> Self { Self { - latency: DEFAULT_LATENCY, + transcribe_latency: DEFAULT_TRANSCRIBE_LATENCY, + translate_latency: DEFAULT_TRANSLATE_LATENCY, + transcript_lookahead: DEFAULT_TRANSCRIPT_LOOKAHEAD, lateness: DEFAULT_LATENESS, - language_code: DEFAULT_LANGUAGE_CODE.to_string(), + language_code: DEFAULT_INPUT_LANG_CODE.to_string(), vocabulary: None, session_id: None, results_stability: DEFAULT_STABILITY, @@ -85,578 +110,72 @@ impl Default for Settings { } } -#[derive(Debug)] -struct TranscriptionSettings { - lang_code: model::LanguageCode, - sample_rate: i32, - vocabulary: Option, - vocabulary_filter: Option, - vocabulary_filter_method: model::VocabularyFilterMethod, - session_id: Option, - results_stability: model::PartialResultsStability, -} - -impl TranscriptionSettings { - fn from(settings: &Settings, sample_rate: i32) -> Self { - TranscriptionSettings { - lang_code: settings.language_code.as_str().into(), - sample_rate, - vocabulary: settings.vocabulary.clone(), - vocabulary_filter: settings.vocabulary_filter.clone(), - vocabulary_filter_method: settings.vocabulary_filter_method.into(), - session_id: settings.session_id.clone(), - results_stability: settings.results_stability.into(), - } - } -} - -struct TranscriberLoop { - imp: glib::subclass::ObjectImplRef, - client: aws_transcribe::Client, - settings: TranscriptionSettings, - lateness: gst::ClockTime, - buffer_rx: mpsc::Receiver, - transcript_notif_tx: mpsc::Sender<()>, -} - -impl TranscriberLoop { - fn new( - imp: &Transcriber, - aws_config: &aws_config::SdkConfig, - settings: TranscriptionSettings, - lateness: gst::ClockTime, - buffer_rx: mpsc::Receiver, - transcript_notif_tx: mpsc::Sender<()>, - ) -> Self { - TranscriberLoop { - imp: imp.ref_counted(), - client: aws_transcribe::Client::new(aws_config), - settings, - lateness, - buffer_rx, - transcript_notif_tx, - } - } - - async fn run(mut self) -> Result<(), gst::ErrorMessage> { - // Stream the incoming buffers chunked - let chunk_stream = self.buffer_rx.flat_map(move |buffer: gst::Buffer| { - async_stream::stream! { - let data = buffer.map_readable().unwrap(); - use aws_transcribe::{model::{AudioEvent, AudioStream}, types::Blob}; - for chunk in data.chunks(8192) { - yield Ok(AudioStream::AudioEvent(AudioEvent::builder().audio_chunk(Blob::new(chunk)).build())); - } - } - }); - - let mut transcribe_builder = self - .client - .start_stream_transcription() - .language_code(self.settings.lang_code) - .media_sample_rate_hertz(self.settings.sample_rate) - .media_encoding(model::MediaEncoding::Pcm) - .enable_partial_results_stabilization(true) - .partial_results_stability(self.settings.results_stability) - .set_vocabulary_name(self.settings.vocabulary) - .set_session_id(self.settings.session_id); - - if let Some(vocabulary_filter) = self.settings.vocabulary_filter { - transcribe_builder = transcribe_builder - .vocabulary_filter_name(vocabulary_filter) - .vocabulary_filter_method(self.settings.vocabulary_filter_method); - } - - let mut output = transcribe_builder - .audio_stream(chunk_stream.into()) - .send() - .await - .map_err(|err| { - let err = format!("Transcribe ws init error: {err}"); - gst::error!(CAT, imp: self.imp, "{err}"); - gst::error_msg!(gst::LibraryError::Init, ["{err}"]) - })?; - - while let Some(event) = output - .transcript_result_stream - .recv() - .await - .map_err(|err| { - let err = format!("Transcribe ws stream error: {err}"); - gst::error!(CAT, imp: self.imp, "{err}"); - gst::error_msg!(gst::LibraryError::Failed, ["{err}"]) - })? - { - if let model::TranscriptResultStream::TranscriptEvent(transcript_evt) = event { - let mut enqueued = false; - - if let Some(result) = transcript_evt - .transcript - .and_then(|transcript| transcript.results) - .and_then(|mut results| results.drain(..).next()) - { - gst::trace!(CAT, imp: self.imp, "Received: {result:?}"); - - if let Some(alternative) = result - .alternatives - .and_then(|mut alternatives| alternatives.drain(..).next()) - { - if let Some(items) = alternative.items { - enqueued = self.imp.enqueue(items, result.is_partial, self.lateness); - } - } - } - - if enqueued && self.transcript_notif_tx.send(()).await.is_err() { - gst::debug!(CAT, imp: self.imp, "Terminated transcript_notif_tx channel"); - break; - } - } else { - gst::warning!( - CAT, - imp: self.imp, - "Transcribe ws returned unknown event: consider upgrading the SDK" - ) - } - } - - gst::debug!(CAT, imp: self.imp, "Exiting ws loop"); - - Ok(()) - } -} - -struct State { - aws_config: Option, +pub(super) struct State { buffer_tx: Option>, - transcript_notif_tx: Option>, - ws_loop_handle: Option>>, - task_abort_handle: Option, - in_segment: gst::FormattedSegment, - out_segment: gst::FormattedSegment, - seqnum: gst::Seqnum, - buffers: VecDeque, - send_eos: bool, - discont: bool, - partial_index: usize, - send_events: bool, - start_time: Option, + transcriber_loop_handle: Option>>, + srcpads: BTreeSet, + pad_serial: u32, + pub seqnum: gst::Seqnum, } impl Default for State { fn default() -> Self { Self { - aws_config: None, buffer_tx: None, - transcript_notif_tx: None, - ws_loop_handle: None, - task_abort_handle: None, - in_segment: gst::FormattedSegment::new(), - out_segment: gst::FormattedSegment::new(), + transcriber_loop_handle: None, + srcpads: Default::default(), + pad_serial: 0, seqnum: gst::Seqnum::next(), - buffers: VecDeque::new(), - send_eos: false, - discont: true, - partial_index: 0, - send_events: true, - start_time: None, } } } pub struct Transcriber { - srcpad: gst::Pad, + static_srcpad: super::TranslationSrcPad, sinkpad: gst::Pad, settings: Mutex, state: Mutex, + pub(super) aws_config: Mutex>, + // sender to broadcast transcript items to the translation src pads. + transcript_event_tx: broadcast::Sender, } impl Transcriber { - fn dequeue(&self) -> bool { - /* First, check our pending buffers */ - let mut items = vec![]; + fn start_srcpad_tasks(&self, state: &State) -> Result<(), gst::LoggableError> { + gst::debug!(CAT, imp: self, "Starting tasks"); - let Some(now) = self.obj().current_running_time() else { return true }; - - let latency = self.settings.lock().unwrap().latency; - - let mut state = self.state.lock().unwrap(); - - if state.start_time.is_none() { - state.start_time = Some(now); - state.out_segment.set_position(now); + if self.static_srcpad.is_linked() { + self.static_srcpad.imp().start_task()?; } - let start_time = state.start_time.unwrap(); - let mut last_position = state.out_segment.position().unwrap(); - - let send_eos = state.send_eos && state.buffers.is_empty(); - - while let Some(buf) = state.buffers.front() { - let pts = buf.pts().unwrap(); - gst::trace!( - CAT, - imp: self, - "Checking now {now} if item is ready for dequeuing, PTS {pts}, threshold {} vs {}", - pts + latency.saturating_sub(3 * GRANULARITY), - now - start_time - ); - - if pts + latency.saturating_sub(3 * GRANULARITY) < now - start_time { - /* Safe unwrap, we know we have an item */ - let mut buf = state.buffers.pop_front().unwrap(); - - { - let buf_mut = buf.get_mut().unwrap(); - - buf_mut.set_pts(start_time + pts); - } - - items.push(buf); - } else { - break; - } + for pad in state.srcpads.iter() { + pad.imp().start_task()?; } - let seqnum = state.seqnum; - - drop(state); - - /* We're EOS, we can pause and exit early */ - if send_eos { - let _ = self.srcpad.pause_task(); - - return self - .srcpad - .push_event(gst::event::Eos::builder().seqnum(seqnum).build()); - } - - for mut buf in items.drain(..) { - let mut pts = buf.pts().unwrap(); - let mut duration = buf.duration().unwrap(); - - match pts.cmp(&last_position) { - Ordering::Greater => { - let gap_event = gst::event::Gap::builder(last_position) - .duration(pts - last_position) - .seqnum(seqnum) - .build(); - gst::log!(CAT, "Pushing gap: {last_position} -> {pts}"); - if !self.srcpad.push_event(gap_event) { - return false; - } - } - Ordering::Less => { - let delta = last_position - pts; - - gst::warning!( - CAT, - imp: self, - "Updating item PTS ({pts} < {last_position}), consider increasing latency", - ); - - pts = last_position; - duration = duration.saturating_sub(delta); - - { - let buf_mut = buf.get_mut().unwrap(); - - buf_mut.set_pts(pts); - buf_mut.set_duration(duration); - } - } - _ => (), - } - - last_position = pts + duration; - - gst::debug!(CAT, "Pushing buffer: {pts} -> {}", pts + duration); - - if self.srcpad.push(buf).is_err() { - return false; - } - } - - /* next, push a gap if we're lagging behind the target position */ - gst::trace!( - CAT, - imp: self, - "Checking now: {now} if we need to push a gap, last_position: {last_position}, threshold: {}", - last_position + latency.saturating_sub(GRANULARITY) - ); - - if now > last_position + latency.saturating_sub(GRANULARITY) { - let duration = now - last_position - latency.saturating_sub(GRANULARITY); - - let gap_event = gst::event::Gap::builder(last_position) - .duration(duration) - .seqnum(seqnum) - .build(); - - gst::log!( - CAT, - "Pushing gap: {last_position} -> {}", - last_position + duration - ); - - last_position += duration; - - if !self.srcpad.push_event(gap_event) { - return false; - } - } - - self.state - .lock() - .unwrap() - .out_segment - .set_position(last_position); - - true - } - - /// Enqueues a buffer for each of the provided stable items. - /// - /// Returns `true` if at least one buffer was enqueued. - fn enqueue( - &self, - mut items: Vec, - partial: bool, - lateness: gst::ClockTime, - ) -> bool { - let mut state = self.state.lock().unwrap(); - - if items.len() <= state.partial_index { - gst::error!( - CAT, - imp: self, - "sanity check failed, alternative length {} < partial_index {}", - items.len(), - state.partial_index - ); - - if !partial { - state.partial_index = 0; - } - - return false; - } - - let mut enqueued = false; - - for item in items.drain(state.partial_index..) { - if !item.stable().unwrap_or(false) { - break; - } - - let Some(content) = item.content else { continue }; - - let start_time = ((item.start_time * 1_000_000_000.0) as u64).nseconds() + lateness; - let end_time = ((item.end_time * 1_000_000_000.0) as u64).nseconds() + lateness; - - /* Should be sent now */ - gst::debug!( - CAT, - imp: self, - "Item is ready for queuing: {content}, PTS {start_time}", - ); - - let mut buf = gst::Buffer::from_mut_slice(content.into_bytes()); - { - let buf = buf.get_mut().unwrap(); - - if state.discont { - buf.set_flags(gst::BufferFlags::DISCONT); - state.discont = false; - } - - buf.set_pts(start_time); - buf.set_duration(end_time - start_time); - } - - state.partial_index += 1; - - state.buffers.push_back(buf); - enqueued = true; - } - - if !partial { - state.partial_index = 0; - } - - enqueued - } - - fn pad_loop_fn(&self, transcript_notif_rx: &mut mpsc::Receiver<()>) { - let mut events = { - let mut events = vec![]; - - let state = self.state.lock().unwrap(); - if state.send_events { - events.push( - gst::event::StreamStart::builder("transcription") - .seqnum(state.seqnum) - .build(), - ); - - let caps = gst::Caps::builder("text/x-raw") - .field("format", "utf8") - .build(); - events.push( - gst::event::Caps::builder(&caps) - .seqnum(state.seqnum) - .build(), - ); - - events.push( - gst::event::Segment::builder(&state.out_segment) - .seqnum(state.seqnum) - .build(), - ); - } - - events - }; - - if !events.is_empty() { - for event in events.drain(..) { - gst::info!(CAT, imp: self, "Sending {event:?}"); - self.srcpad.push_event(event); - } - - self.state.lock().unwrap().send_events = false; - } - - let future = async move { - let timeout = tokio::time::sleep(GRANULARITY.into()).fuse(); - futures::pin_mut!(timeout); - - futures::select! { - notif = transcript_notif_rx.next() => { - if notif.is_none() { - // Transcriber loop terminated - self.state.lock().unwrap().send_eos = true; - }; - } - _ = timeout => (), - }; - - if !self.dequeue() { - gst::info!(CAT, imp: self, "Failed to dequeue buffer, pausing"); - let _ = self.srcpad.pause_task(); - } - }; - - let (abortable_future, abort_handle) = future::abortable(future); - self.state.lock().unwrap().task_abort_handle = Some(abort_handle); - - let _enter = RUNTIME.enter(); - if futures::executor::block_on(abortable_future).is_err() { - gst::debug!(CAT, imp: self, "task iter aborted"); - } - } - - fn start_task(&self) -> Result<(), gst::LoggableError> { - gst::debug!(CAT, imp: self, "Starting task"); - let mut state = self.state.lock().unwrap(); - - let (transcript_notif_tx, mut transcript_notif_rx) = mpsc::channel(1); - - let imp = self.ref_counted(); - let res = self - .srcpad - .start_task(move || imp.pad_loop_fn(&mut transcript_notif_rx)); - - if res.is_err() { - state.transcript_notif_tx = None; - return Err(gst::loggable_error!(CAT, "Failed to start pad task")); - } - - state.transcript_notif_tx = Some(transcript_notif_tx); - - gst::debug!(CAT, imp: self, "Task started"); + gst::debug!(CAT, imp: self, "Tasks Started"); Ok(()) } - fn stop_task(&self) { - gst::debug!(CAT, imp: self, "Stopping task"); + fn stop_tasks(&self, state: &mut State) { + gst::debug!(CAT, imp: self, "Stopping tasks"); - let _ = self.srcpad.stop_task(); - - let mut state = self.state.lock().unwrap(); - - if let Some(task_abort_handle) = state.task_abort_handle.take() { - task_abort_handle.abort(); + if self.static_srcpad.is_linked() { + self.static_srcpad.imp().stop_task(); } - if let Some(ws_loop_handle) = state.ws_loop_handle.take() { - ws_loop_handle.abort(); + for pad in state.srcpads.iter() { + pad.imp().stop_task(); } - state.transcript_notif_tx = None; + // Terminate the audio buffer stream state.buffer_tx = None; - gst::debug!(CAT, imp: self, "Task Stopped"); - } - - fn stop_ws_loop(&self) { - let mut state = self.state.lock().unwrap(); - - if let Some(ws_loop_handle) = state.ws_loop_handle.take() { - ws_loop_handle.abort(); + if let Some(transcriber_loop_handle) = state.transcriber_loop_handle.take() { + transcriber_loop_handle.abort(); } - state.buffer_tx = None; - } - - fn src_activatemode( - &self, - _pad: &gst::Pad, - _mode: gst::PadMode, - active: bool, - ) -> Result<(), gst::LoggableError> { - if active { - self.start_task()?; - } else { - self.stop_task(); - } - - Ok(()) - } - - fn src_query(&self, pad: &gst::Pad, query: &mut gst::QueryRef) -> bool { - gst::log!(CAT, obj: pad, "Handling query {query:?}"); - - use gst::QueryViewMut::*; - match query.view_mut() { - Latency(q) => { - let mut peer_query = gst::query::Latency::new(); - - let ret = self.sinkpad.peer_query(&mut peer_query); - - if ret { - let (_, min, _) = peer_query.result(); - let our_latency = self.settings.lock().unwrap().latency; - q.set(true, our_latency + min, gst::ClockTime::NONE); - } - ret - } - Position(q) => { - if q.format() == gst::Format::Time { - let state = self.state.lock().unwrap(); - q.set( - state - .out_segment - .to_stream_time(state.out_segment.position()), - ); - true - } else { - false - } - } - _ => gst::Pad::query_default(pad, Some(&*self.obj()), query), - } + gst::debug!(CAT, imp: self, "Tasks Stopped"); } fn sink_event(&self, pad: &gst::Pad, event: gst::Event) -> bool { @@ -665,14 +184,15 @@ impl Transcriber { use gst::EventView::*; match event.view() { Eos(_) => { - self.stop_ws_loop(); + // Terminate the audio buffer stream + self.state.lock().unwrap().buffer_tx = None; true } FlushStart(_) => { gst::info!(CAT, imp: self, "Received flush start, disconnecting"); let ret = gst::Pad::event_default(pad, Some(&*self.obj()), event); - self.stop_task(); + self.stop_tasks(&mut self.state.lock().unwrap()); ret } @@ -680,9 +200,10 @@ impl Transcriber { gst::info!(CAT, imp: self, "Received flush stop, restarting task"); if gst::Pad::event_default(pad, Some(&*self.obj()), event) { - match self.start_task() { + let state = self.state.lock().unwrap(); + match self.start_srcpad_tasks(&state) { Err(err) => { - gst::error!(CAT, imp: self, "Failed to start srcpad task: {err}"); + gst::error!(CAT, imp: self, "Failed to start srcpad tasks: {err}"); false } Ok(_) => true, @@ -692,22 +213,17 @@ impl Transcriber { } } Segment(e) => { - let segment = match e.segment().clone().downcast::() { - Err(segment) => { - gst::element_imp_error!( - self, - gst::StreamError::Format, - ["Only Time segments supported, got {:?}", segment.format()] - ); - return false; - } - Ok(segment) => segment, + let format = e.segment().format(); + if format != gst::Format::Time { + gst::element_imp_error!( + self, + gst::StreamError::Format, + ["Only Time segments supported, got {format:?}"] + ); + return false; }; - let mut state = self.state.lock().unwrap(); - - state.in_segment = segment; - state.seqnum = e.seqnum(); + self.state.lock().unwrap().seqnum = e.seqnum(); true } @@ -728,10 +244,7 @@ impl Transcriber { ) -> Result { gst::log!(CAT, obj: pad, "Handling {buffer:?}"); - self.ensure_connection().map_err(|err| { - gst::element_imp_error!(self, gst::StreamError::Failed, ["Streaming failed: {err}"]); - gst::FlowError::Error - })?; + self.ensure_connection(); let Some(mut buffer_tx) = self.state.lock().unwrap().buffer_tx.take() else { gst::log!(CAT, obj: pad, "Flushing"); @@ -748,132 +261,93 @@ impl Transcriber { Ok(gst::FlowSuccess::Ok) } - fn ensure_connection(&self) -> Result<(), gst::ErrorMessage> { - enum ConfigStatus { - Ready(aws_config::SdkConfig), - NotReady { - access_key: Option, - secret_access_key: Option, - session_token: Option, - }, + fn ensure_connection(&self) { + let mut state = self.state.lock().unwrap(); + + if state.buffer_tx.is_some() { + return; } - let (config_status, transcription_settings, lateness, transcript_notif_tx); - { - let mut state = self.state.lock().unwrap(); + let settings = self.settings.lock().unwrap(); - if let Some(ref ws_loop_handle) = state.ws_loop_handle { - if ws_loop_handle.is_finished() { - state.ws_loop_handle = None; + let in_caps = self.sinkpad.current_caps().unwrap(); + let s = in_caps.structure(0).unwrap(); + let sample_rate = s.get::("rate").unwrap(); - const ERR: &str = "ws loop terminated unexpectedly"; - gst::error!(CAT, imp: self, "{ERR}"); - return Err(gst::error_msg!(gst::LibraryError::Failed, ["{ERR}"])); - } - - return Ok(()); - } - - transcript_notif_tx = state - .transcript_notif_tx - .take() - .expect("attempting to spawn the ws loop, but the srcpad task hasn't been started"); - - let settings = self.settings.lock().unwrap(); - - lateness = settings.lateness; - if settings.latency + lateness <= 2 * GRANULARITY { - const ERR: &str = "latency + lateness must be greater than 200 milliseconds"; - gst::error!(CAT, imp: self, "{ERR}"); - return Err(gst::error_msg!(gst::LibraryError::Settings, ["{ERR}"])); - } - - let in_caps = self.sinkpad.current_caps().unwrap(); - let s = in_caps.structure(0).unwrap(); - let sample_rate = s.get::("rate").unwrap(); - - transcription_settings = TranscriptionSettings::from(&settings, sample_rate); - - config_status = if let Some(aws_config) = state.aws_config.take() { - ConfigStatus::Ready(aws_config) - } else { - ConfigStatus::NotReady { - access_key: settings.access_key.to_owned(), - secret_access_key: settings.secret_access_key.to_owned(), - session_token: settings.session_token.to_owned(), - } - }; - }; - - let aws_config = match config_status { - ConfigStatus::Ready(aws_config) => aws_config, - ConfigStatus::NotReady { - access_key, - secret_access_key, - session_token, - } => { - gst::info!(CAT, imp: self, "Loading aws config..."); - let _enter_guard = RUNTIME.enter(); - - let config_loader = match (access_key, secret_access_key) { - (Some(key), Some(secret_key)) => { - gst::debug!(CAT, imp: self, "Using settings credentials"); - aws_config::ConfigLoader::default().credentials_provider( - aws_transcribe::Credentials::new( - key, - secret_key, - session_token, - None, - "translate", - ), - ) - } - _ => { - gst::debug!(CAT, imp: self, "Attempting to get credentials from env..."); - aws_config::from_env() - } - }; - - let config_loader = config_loader.region( - aws_config::meta::region::RegionProviderChain::default_provider() - .or_else(DEFAULT_TRANSCRIBER_REGION), - ); - let config = futures::executor::block_on(config_loader.load()); - gst::debug!(CAT, imp: self, "Using region {}", config.region().unwrap()); - - config - } - }; - - let mut state = self.state.lock().unwrap(); + let transcription_settings = TranscriptionSettings::from(&settings, sample_rate); let (buffer_tx, buffer_rx) = mpsc::channel(1); - let ws_loop_ctx = TranscriberLoop::new( + let transcriber_loop = TranscriberLoop::new( self, - &aws_config, transcription_settings, - lateness, + settings.lateness, buffer_rx, - transcript_notif_tx, + self.transcript_event_tx.clone(), ); - let ws_loop_handle = RUNTIME.spawn(ws_loop_ctx.run()); + let transcriber_loop_handle = RUNTIME.spawn(transcriber_loop.run()); - state.aws_config = Some(aws_config); - state.ws_loop_handle = Some(ws_loop_handle); + state.transcriber_loop_handle = Some(transcriber_loop_handle); state.buffer_tx = Some(buffer_tx); + } + + fn prepare(&self) -> Result<(), gst::ErrorMessage> { + gst::debug!(CAT, imp: self, "Preparing"); + + let (access_key, secret_access_key, session_token); + { + let settings = self.settings.lock().unwrap(); + access_key = settings.access_key.to_owned(); + secret_access_key = settings.secret_access_key.to_owned(); + session_token = settings.session_token.to_owned(); + } + + gst::info!(CAT, imp: self, "Loading aws config..."); + let _enter_guard = RUNTIME.enter(); + + let config_loader = match (access_key, secret_access_key) { + (Some(key), Some(secret_key)) => { + gst::debug!(CAT, imp: self, "Using settings credentials"); + aws_config::ConfigLoader::default().credentials_provider( + aws_transcribe::Credentials::new( + key, + secret_key, + session_token, + None, + "translate", + ), + ) + } + _ => { + gst::debug!(CAT, imp: self, "Attempting to get credentials from env..."); + aws_config::from_env() + } + }; + + let config_loader = config_loader.region( + aws_config::meta::region::RegionProviderChain::default_provider() + .or_else(DEFAULT_TRANSCRIBER_REGION), + ); + + let config = futures::executor::block_on(config_loader.load()); + gst::debug!(CAT, imp: self, "Using region {}", config.region().unwrap()); + + *self.aws_config.lock().unwrap() = Some(config); + + gst::debug!(CAT, imp: self, "Prepared"); Ok(()) } fn disconnect(&self) { gst::info!(CAT, imp: self, "Unpreparing"); + let mut state = self.state.lock().unwrap(); - self.stop_task(); - - // Also resets discont to true - *self.state.lock().unwrap() = State::default(); + self.stop_tasks(&mut state); + for pad in state.srcpads.iter() { + pad.imp().set_discont(); + } gst::info!(CAT, imp: self, "Unprepared"); } } @@ -883,6 +357,7 @@ impl ObjectSubclass for Transcriber { const NAME: &'static str = "GstAwsTranscriber"; type Type = super::Transcriber; type ParentType = gst::Element; + type Interfaces = (gst::ChildProxy,); fn with_class(klass: &Self::Class) -> Self { let templ = klass.pad_template("sink").unwrap(); @@ -904,36 +379,42 @@ impl ObjectSubclass for Transcriber { .build(); let templ = klass.pad_template("src").unwrap(); - let srcpad = gst::Pad::builder_with_template(&templ, Some("src")) - .activatemode_function(|pad, parent, mode, active| { - Transcriber::catch_panic_pad_function( - parent, - || { - Err(gst::loggable_error!( - CAT, - "Panic activating src pad with mode" - )) - }, - |transcriber| transcriber.src_activatemode(pad, mode, active), - ) - }) - .query_function(|pad, parent, query| { - Transcriber::catch_panic_pad_function( - parent, - || false, - |transcriber| transcriber.src_query(pad, query), - ) - }) - .flags(gst::PadFlags::FIXED_CAPS) - .build(); + let static_srcpad = + gst::PadBuilder::::from_template(&templ, Some("src")) + .activatemode_function(|pad, parent, mode, active| { + Transcriber::catch_panic_pad_function( + parent, + || { + Err(gst::loggable_error!( + CAT, + "Panic activating TranslationSrcPad" + )) + }, + |elem| TranslationSrcPad::activatemode(elem, pad, mode, active), + ) + }) + .query_function(|pad, parent, query| { + Transcriber::catch_panic_pad_function( + parent, + || false, + |elem| TranslationSrcPad::src_query(elem, pad, query), + ) + }) + .flags(gst::PadFlags::FIXED_CAPS) + .build(); - let settings = Mutex::new(Settings::default()); + // Setting the channel capacity so that a TranslationSrcPad that would lag + // behind for some reasons get a chance to catch-up without loosing items. + // Receiver will be created by subscribing to sender later. + let (transcript_event_tx, _) = broadcast::channel(128); Self { - srcpad, + static_srcpad, sinkpad, - settings, - state: Mutex::new(State::default()), + settings: Default::default(), + state: Default::default(), + aws_config: Default::default(), + transcript_event_tx, } } } @@ -947,13 +428,38 @@ impl ObjectImpl for Transcriber { .blurb("The Language of the Stream, see \ \ for an up to date list of allowed languages") - .default_value(Some(DEFAULT_LANGUAGE_CODE)) + .default_value(Some(DEFAULT_INPUT_LANG_CODE)) .mutable_ready() .build(), - glib::ParamSpecUInt::builder("latency") + glib::ParamSpecUInt::builder(DEPRECATED_LATENCY_PROPERTY) .nick("Latency") + .blurb("Amount of milliseconds to allow AWS transcribe (Deprecated. Use transcribe-latency)") + .default_value(DEFAULT_TRANSCRIBE_LATENCY.mseconds() as u32) + .mutable_ready() + .deprecated() + .build(), + glib::ParamSpecUInt::builder(TRANSCRIBE_LATENCY_PROPERTY) + .nick("AWS Transcribe Latency") .blurb("Amount of milliseconds to allow AWS transcribe") - .default_value(DEFAULT_LATENCY.mseconds() as u32) + .default_value(DEFAULT_TRANSCRIBE_LATENCY.mseconds() as u32) + .mutable_ready() + .build(), + glib::ParamSpecUInt::builder(TRANSLATE_LATENCY_PROPERTY) + .nick("AWS Translate Latency") + .blurb(concat!( + "Amount of milliseconds to allow AWS translate ", + "(ignored if the input and output languages are the same)", + )) + .default_value(DEFAULT_TRANSLATE_LATENCY.mseconds() as u32) + .mutable_ready() + .build(), + glib::ParamSpecUInt::builder(TRANSCRIPT_LOOKAHEAD_PROPERTY) + .nick("Transcript chunk") + .blurb(concat!( + "Maximum duration in milliseconds of transcript to lookahead ", + "before sending to translation when no separator was encountered", + )) + .default_value(DEFAULT_TRANSCRIPT_LOOKAHEAD.mseconds() as u32) .mutable_ready() .build(), glib::ParamSpecUInt::builder("lateness") @@ -1017,7 +523,7 @@ impl ObjectImpl for Transcriber { let obj = self.obj(); obj.add_pad(&self.sinkpad).unwrap(); - obj.add_pad(&self.srcpad).unwrap(); + obj.add_pad(&self.static_srcpad).unwrap(); obj.set_element_flags(gst::ElementFlags::PROVIDE_CLOCK | gst::ElementFlags::REQUIRE_CLOCK); } @@ -1027,12 +533,26 @@ impl ObjectImpl for Transcriber { let mut settings = self.settings.lock().unwrap(); settings.language_code = value.get().expect("type checked upstream"); } - "latency" => { + DEPRECATED_LATENCY_PROPERTY => { let mut settings = self.settings.lock().unwrap(); - settings.latency = gst::ClockTime::from_mseconds( + settings.transcribe_latency = gst::ClockTime::from_mseconds( value.get::().expect("type checked upstream").into(), ); } + TRANSCRIBE_LATENCY_PROPERTY => { + let mut settings = self.settings.lock().unwrap(); + settings.transcribe_latency = gst::ClockTime::from_mseconds( + value.get::().expect("type checked upstream").into(), + ); + } + TRANSLATE_LATENCY_PROPERTY => { + self.settings.lock().unwrap().translate_latency = + gst::ClockTime::from_mseconds(value.get::().unwrap().into()); + } + TRANSCRIPT_LOOKAHEAD_PROPERTY => { + self.settings.lock().unwrap().transcript_lookahead = + gst::ClockTime::from_mseconds(value.get::().unwrap().into()); + } "lateness" => { let mut settings = self.settings.lock().unwrap(); settings.lateness = gst::ClockTime::from_mseconds( @@ -1085,10 +605,24 @@ impl ObjectImpl for Transcriber { let settings = self.settings.lock().unwrap(); settings.language_code.to_value() } - "latency" => { + DEPRECATED_LATENCY_PROPERTY => { let settings = self.settings.lock().unwrap(); - (settings.latency.mseconds() as u32).to_value() + (settings.transcribe_latency.mseconds() as u32).to_value() } + TRANSCRIBE_LATENCY_PROPERTY => { + let settings = self.settings.lock().unwrap(); + (settings.transcribe_latency.mseconds() as u32).to_value() + } + TRANSLATE_LATENCY_PROPERTY => { + (self.settings.lock().unwrap().translate_latency.mseconds() as u32).to_value() + } + TRANSCRIPT_LOOKAHEAD_PROPERTY => (self + .settings + .lock() + .unwrap() + .transcript_lookahead + .mseconds() as u32) + .to_value(), "lateness" => { let settings = self.settings.lock().unwrap(); (settings.lateness.mseconds() as u32).to_value() @@ -1151,11 +685,20 @@ impl ElementImpl for Transcriber { let src_caps = gst::Caps::builder("text/x-raw") .field("format", "utf8") .build(); - let src_pad_template = gst::PadTemplate::new( + let src_pad_template = gst::PadTemplate::with_gtype( "src", gst::PadDirection::Src, gst::PadPresence::Always, &src_caps, + super::TranslationSrcPad::static_type(), + ) + .unwrap(); + let req_src_pad_template = gst::PadTemplate::with_gtype( + "translation_src_%u", + gst::PadDirection::Src, + gst::PadPresence::Request, + &src_caps, + super::TranslationSrcPad::static_type(), ) .unwrap(); @@ -1172,7 +715,7 @@ impl ElementImpl for Transcriber { ) .unwrap(); - vec![src_pad_template, sink_pad_template] + vec![src_pad_template, req_src_pad_template, sink_pad_template] }); PAD_TEMPLATES.as_ref() @@ -1184,6 +727,13 @@ impl ElementImpl for Transcriber { ) -> Result { gst::info!(CAT, imp: self, "Changing state {transition:?}"); + if let gst::StateChange::NullToReady = transition { + self.prepare().map_err(|err| { + self.post_error_message(err); + gst::StateChangeError + })?; + } + let mut success = self.parent_change_state(transition)?; match transition { @@ -1202,7 +752,848 @@ impl ElementImpl for Transcriber { Ok(success) } + fn request_new_pad( + &self, + templ: &gst::PadTemplate, + _name: Option<&str>, + _caps: Option<&gst::Caps>, + ) -> Option { + let mut state = self.state.lock().unwrap(); + + let pad = gst::PadBuilder::::from_template( + templ, + Some(format!("translation_src_{}", state.pad_serial).as_str()), + ) + .activatemode_function(|pad, parent, mode, active| { + Transcriber::catch_panic_pad_function( + parent, + || { + Err(gst::loggable_error!( + CAT, + "Panic activating TranslationSrcPad" + )) + }, + |elem| TranslationSrcPad::activatemode(elem, pad, mode, active), + ) + }) + .query_function(|pad, parent, query| { + Transcriber::catch_panic_pad_function( + parent, + || false, + |elem| TranslationSrcPad::src_query(elem, pad, query), + ) + }) + .flags(gst::PadFlags::FIXED_CAPS) + .build(); + + state.srcpads.insert(pad.clone()); + + state.pad_serial += 1; + drop(state); + + self.obj().add_pad(&pad).unwrap(); + + let _ = self + .obj() + .post_message(gst::message::Latency::builder().src(&*self.obj()).build()); + + self.obj().child_added(&pad, &pad.name()); + Some(pad.upcast()) + } + + fn release_pad(&self, pad: &gst::Pad) { + pad.set_active(false).unwrap(); + self.obj().remove_pad(pad).unwrap(); + + self.obj().child_removed(pad, &pad.name()); + let _ = self + .obj() + .post_message(gst::message::Latency::builder().src(&*self.obj()).build()); + } + fn provide_clock(&self) -> Option { Some(gst::SystemClock::obtain()) } } + +// Implementation of gst::ChildProxy virtual methods. +// +// This allows accessing the pads and their properties from e.g. gst-launch. +impl ChildProxyImpl for Transcriber { + fn children_count(&self) -> u32 { + let object = self.obj(); + object.num_pads() as u32 + } + + fn child_by_name(&self, name: &str) -> Option { + let object = self.obj(); + object + .pads() + .into_iter() + .find(|p| p.name() == name) + .map(|p| p.upcast()) + } + + fn child_by_index(&self, index: u32) -> Option { + let object = self.obj(); + object + .pads() + .into_iter() + .nth(index as usize) + .map(|p| p.upcast()) + } +} +struct TranslationPadTask { + pad: glib::subclass::ObjectImplRef, + elem: super::Transcriber, + transcript_event_rx: broadcast::Receiver, + needs_translate: bool, + translation_queue: TranslationQueue, + translation_loop_handle: Option>>, + to_translation_tx: Option>, + from_translation_rx: Option>, + translate_latency: gst::ClockTime, + transcript_lookahead: gst::ClockTime, + send_events: bool, + translated_items: VecDeque, + our_latency: gst::ClockTime, + seqnum: gst::Seqnum, + send_eos: bool, + pending_translations: usize, + start_time: Option, +} + +impl TranslationPadTask { + fn try_new( + pad: &TranslationSrcPad, + elem: super::Transcriber, + transcript_event_rx: broadcast::Receiver, + ) -> Result { + let mut this = TranslationPadTask { + pad: pad.ref_counted(), + elem, + transcript_event_rx, + needs_translate: false, + translation_queue: TranslationQueue::default(), + translation_loop_handle: None, + to_translation_tx: None, + from_translation_rx: None, + translate_latency: DEFAULT_TRANSLATE_LATENCY, + transcript_lookahead: DEFAULT_TRANSCRIPT_LOOKAHEAD, + send_events: true, + translated_items: VecDeque::new(), + our_latency: DEFAULT_TRANSCRIBE_LATENCY, + seqnum: gst::Seqnum::next(), + send_eos: false, + pending_translations: 0, + start_time: None, + }; + + let _enter_guard = RUNTIME.enter(); + futures::executor::block_on(this.init_translate())?; + + Ok(this) + } +} + +impl Drop for TranslationPadTask { + fn drop(&mut self) { + if let Some(translation_loop_handle) = self.translation_loop_handle.take() { + translation_loop_handle.abort(); + } + } +} + +impl TranslationPadTask { + async fn run_iter(&mut self) -> Result<(), gst::ErrorMessage> { + self.ensure_init_events()?; + + if self.needs_translate { + self.translate_iter().await?; + } else { + self.passthrough_iter().await?; + } + + if !self.dequeue().await { + gst::info!(CAT, imp: self.pad, "Failed to dequeue buffer, pausing"); + let _ = self.pad.obj().pause_task(); + } + + Ok(()) + } + + async fn passthrough_iter(&mut self) -> Result<(), gst::ErrorMessage> { + // This is to make sure we send items on a timely basis or at least Gap events. + let timeout = tokio::time::sleep(GRANULARITY.into()).fuse(); + futures::pin_mut!(timeout); + + let transcript_event_rx = self.transcript_event_rx.recv().fuse(); + futures::pin_mut!(transcript_event_rx); + + // `timeout` takes precedence over `transcript_events` reception + // because we may need to `dequeue` `items` or push a `Gap` event + // before current latency budget is exhausted. + futures::select_biased! { + _ = timeout => (), + items_res = transcript_event_rx => { + use TranscriptEvent::*; + use broadcast::error::RecvError; + match items_res { + Ok(Items(transcript_items)) => { + for transcript_item in transcript_items.iter() { + self.translated_items.push_back(transcript_item.into()); + } + } + Ok(Eos) => { + gst::debug!(CAT, imp: self.pad, "Got eos"); + self.send_eos = true; + } + Err(RecvError::Lagged(nb_msg)) => { + gst::warning!(CAT, imp: self.pad, "Missed {nb_msg} transcript sets"); + } + Err(RecvError::Closed) => { + gst::debug!(CAT, imp: self.pad, "Transcript chan terminated: setting eos"); + self.send_eos = true; + } + } + } + } + + Ok(()) + } + + async fn translate_iter(&mut self) -> Result<(), gst::ErrorMessage> { + if self + .translation_loop_handle + .as_ref() + .map_or(true, task::JoinHandle::is_finished) + { + const ERR: &str = "Translation loop is not running"; + gst::error!(CAT, imp: self.pad, "{ERR}"); + return Err(gst::error_msg!(gst::StreamError::Failed, ["{ERR}"])); + } + + let transcript_items = { + // This is to make sure we send items on a timely basis or at least Gap events. + let timeout = tokio::time::sleep(GRANULARITY.into()).fuse(); + futures::pin_mut!(timeout); + + let from_translation_rx = self + .from_translation_rx + .as_mut() + .expect("from_translation chan must be available in translation mode"); + + let transcript_event_rx = self.transcript_event_rx.recv().fuse(); + futures::pin_mut!(transcript_event_rx); + + // `timeout` takes precedence over `transcript_events` reception + // because we may need to `dequeue` `items` or push a `Gap` event + // before current latency budget is exhausted. + futures::select_biased! { + _ = timeout => return Ok(()), + translated_item = from_translation_rx.next() => { + let Some(translated_item) = translated_item else { + const ERR: &str = "translation chan terminated"; + gst::debug!(CAT, imp: self.pad, "{ERR}"); + return Err(gst::error_msg!(gst::StreamError::Failed, ["{ERR}"])); + }; + + self.translated_items.push_back(translated_item); + self.pending_translations = self.pending_translations.saturating_sub(1); + + return Ok(()); + } + items_res = transcript_event_rx => { + use TranscriptEvent::*; + use broadcast::error::RecvError; + match items_res { + Ok(Items(transcript_items)) => transcript_items, + Ok(Eos) => { + gst::debug!(CAT, imp: self.pad, "Got eos"); + self.send_eos = true; + return Ok(()); + } + Err(RecvError::Lagged(nb_msg)) => { + gst::warning!(CAT, imp: self.pad, "Missed {nb_msg} transcript sets"); + return Ok(()); + } + Err(RecvError::Closed) => { + gst::debug!(CAT, imp: self.pad, "Transcript chan terminated: setting eos"); + self.send_eos = true; + return Ok(()); + } + } + } + } + }; + + for item in transcript_items.iter() { + if let Some(ready_item) = self.translation_queue.push(item) { + self.send_for_translation(ready_item).await?; + } + } + + Ok(()) + } + + async fn dequeue(&mut self) -> bool { + let (now, start_time, mut last_position, mut discont_pending); + { + let mut pad_state = self.pad.state.lock().unwrap(); + + let Some(cur_rt) = self.elem.current_running_time() else { + // Wait for the clock to be available + return true; + }; + now = cur_rt; + + if self.start_time.is_none() { + self.start_time = Some(now); + pad_state.out_segment.set_position(now); + } + + start_time = self.start_time.unwrap(); + last_position = pad_state.out_segment.position().unwrap(); + discont_pending = pad_state.discont_pending; + } + + if self.needs_translate && !self.translation_queue.is_empty() { + // Maximum delay for an item to be pushed to stream on time + // Margin: + // - 1 * GRANULARITY: the time it will take before we can check this again, + // without running late, in the case of a timeout. + // - 2 * GRANULARITY: extra margin to account for additional overheads. + // FIXME explaing which ones. + let max_delay = self.our_latency.saturating_sub(3 * GRANULARITY); + + // Estimated time of arrival for an item sent to translation now. + // (in transcript item ts base) + let translation_eta = now + self.translate_latency - start_time; + + let deadline = translation_eta.saturating_sub(max_delay); + + if let Some(ready_item) = self + .translation_queue + .dequeue(deadline, self.transcript_lookahead) + { + gst::debug!( + CAT, + imp: self.pad, + "Forcing transcript at pts {} with duration {} to translation", + ready_item.pts, + ready_item.duration, + ); + + if self.send_for_translation(ready_item).await.is_err() { + return false; + } + } + } + + /* First, check our pending buffers */ + while let Some(item) = self.translated_items.front() { + // Note: items pts start from 0 + lateness + gst::trace!( + CAT, + imp: self.pad, + "Checking now {now} if item is ready for dequeuing, PTS {}, threshold {} vs {}", + item.pts, + item.pts + self.our_latency.saturating_sub(3 * GRANULARITY), + now - start_time + ); + + // Margin: + // - 1 * GRANULARITY: the time it will take before we can check this again, + // without running late, in the case of a timeout. + // - 2 * GRANULARITY: extra margin to account for additional overheads. + // FIXME explaing which ones. + if item.pts + self.our_latency.saturating_sub(3 * GRANULARITY) < now - start_time { + /* Safe unwrap, we know we have an item */ + let TranslatedItem { + pts: item_pts, + mut duration, + content, + } = self.translated_items.pop_front().unwrap(); + + let mut pts = start_time + item_pts; + + let mut buf = gst::Buffer::from_mut_slice(content.into_bytes()); + { + let buf = buf.get_mut().unwrap(); + + if discont_pending { + buf.set_flags(gst::BufferFlags::DISCONT); + discont_pending = false; + } + + buf.set_pts(pts); + buf.set_duration(duration); + } + + use std::cmp::Ordering::*; + match pts.cmp(&last_position) { + Greater => { + // The buffer we are about to push starts after the end of + // last item previously pushed to the stream. + let gap_event = gst::event::Gap::builder(last_position) + .duration(pts - last_position) + .seqnum(self.seqnum) + .build(); + gst::log!(CAT, imp: self.pad, "Pushing gap: {last_position} -> {pts}"); + if !self.pad.obj().push_event(gap_event) { + return false; + } + } + Less => { + // The buffer we are about to push was expected to start + // before the end of last item previously pushed to the stream. + // => update it to fit in stream. + let delta = last_position - pts; + + gst::warning!( + CAT, + imp: self.pad, + "Updating item PTS ({pts} < {last_position}), consider increasing latency", + ); + + pts = last_position; + // FIXME if the resulting duration is zero, we might as well not push it. + duration = duration.saturating_sub(delta); + + { + let buf_mut = buf.get_mut().unwrap(); + + buf_mut.set_pts(pts); + buf_mut.set_duration(duration); + } + } + _ => (), + } + + last_position = pts + duration; + + gst::debug!(CAT, imp: self.pad, "Pushing buffer: {pts} -> {}", pts + duration); + + if self.pad.obj().push(buf).is_err() { + return false; + } + } else { + // Current and subsequent items are not ready to be pushed + break; + } + } + + if self.send_eos + && self.pending_translations == 0 + && self.translated_items.is_empty() + && self.translation_queue.is_empty() + { + /* We're EOS, we can pause and exit early */ + let _ = self.pad.obj().pause_task(); + + gst::info!(CAT, imp: self.pad, "Sending eos"); + return self + .pad + .obj() + .push_event(gst::event::Eos::builder().seqnum(self.seqnum).build()); + } + + /* next, push a gap if we're lagging behind the target position */ + gst::trace!( + CAT, + imp: self.pad, + "Checking now: {now} if we need to push a gap, last_position: {last_position}, threshold: {}", + last_position + self.our_latency.saturating_sub(GRANULARITY) + ); + + if now > last_position + self.our_latency.saturating_sub(GRANULARITY) { + // We are running out of latency budget since last time we pushed downstream, + // so push a Gap long enough to keep continuity before we dequeue again: + // worse case scenario, this is GRANULARITY ms from now. + let duration = now - last_position - self.our_latency.saturating_sub(GRANULARITY); + + let gap_event = gst::event::Gap::builder(last_position) + .duration(duration) + .seqnum(self.seqnum) + .build(); + + gst::log!( + CAT, + imp: self.pad, + "Pushing gap: {last_position} -> {}", + last_position + duration + ); + + last_position += duration; + + if !self.pad.obj().push_event(gap_event) { + return false; + } + } + + let mut pad_state = self.pad.state.lock().unwrap(); + pad_state.out_segment.set_position(last_position); + pad_state.discont_pending = discont_pending; + + true + } + + async fn send_for_translation( + &mut self, + transcript_item: TranscriptItem, + ) -> Result<(), gst::ErrorMessage> { + let res = self + .to_translation_tx + .as_mut() + .expect("to_translation chan must be available in translation mode") + .send(transcript_item) + .await; + + if res.is_err() { + const ERR: &str = "to_translation chan terminated"; + gst::debug!(CAT, imp: self.pad, "{ERR}"); + return Err(gst::error_msg!(gst::StreamError::Failed, ["{ERR}"])); + } + + self.pending_translations += 1; + + Ok(()) + } + + fn ensure_init_events(&mut self) -> Result<(), gst::ErrorMessage> { + if !self.send_events { + return Ok(()); + } + + let mut events = vec![]; + + { + let elem_imp = self.elem.imp(); + let elem_state = elem_imp.state.lock().unwrap(); + + let mut pad_state = self.pad.state.lock().unwrap(); + + self.seqnum = elem_state.seqnum; + pad_state.out_segment = Default::default(); + + events.push( + gst::event::StreamStart::builder("transcription") + .seqnum(self.seqnum) + .build(), + ); + + let caps = gst::Caps::builder("text/x-raw") + .field("format", "utf8") + .build(); + events.push(gst::event::Caps::builder(&caps).seqnum(self.seqnum).build()); + + events.push( + gst::event::Segment::builder(&pad_state.out_segment) + .seqnum(self.seqnum) + .build(), + ); + } + + for event in events.drain(..) { + gst::info!(CAT, imp: self.pad, "Sending {event:?}"); + if !self.pad.obj().push_event(event) { + const ERR: &str = "Failed to send initial"; + gst::error!(CAT, imp: self.pad, "{ERR}"); + return Err(gst::error_msg!(gst::StreamError::Failed, ["{ERR}"])); + } + } + + self.send_events = false; + + Ok(()) + } +} + +impl TranslationPadTask { + async fn init_translate(&mut self) -> Result<(), gst::ErrorMessage> { + let mut translation_loop = None; + + { + let elem_imp = self.elem.imp(); + let elem_settings = elem_imp.settings.lock().unwrap(); + + let pad_settings = self.pad.settings.lock().unwrap(); + + self.our_latency = TranslationSrcPad::our_latency(&elem_settings, &pad_settings); + if self.our_latency + elem_settings.lateness <= 2 * GRANULARITY { + let err = format!( + "total latency + lateness must be greater than {}", + 2 * GRANULARITY + ); + gst::error!(CAT, imp: self.pad, "{err}"); + return Err(gst::error_msg!(gst::LibraryError::Settings, ["{err}"])); + } + + self.translate_latency = elem_settings.translate_latency; + self.transcript_lookahead = elem_settings.transcript_lookahead; + + self.needs_translate = TranslationSrcPad::needs_translation( + &elem_settings.language_code, + pad_settings.language_code.as_deref(), + ); + + if self.needs_translate { + let (to_translation_tx, to_translation_rx) = mpsc::channel(64); + let (from_translation_tx, from_translation_rx) = mpsc::channel(64); + + translation_loop = Some(TranslationLoop::new( + elem_imp, + &self.pad, + &elem_settings.language_code, + pad_settings.language_code.as_deref().unwrap(), + to_translation_rx, + from_translation_tx, + )); + + self.to_translation_tx = Some(to_translation_tx); + self.from_translation_rx = Some(from_translation_rx); + } + } + + if let Some(translation_loop) = translation_loop { + translation_loop.check_language().await?; + self.translation_loop_handle = Some(RUNTIME.spawn(translation_loop.run())); + } + + Ok(()) + } +} + +#[derive(Debug)] +struct TranslationPadState { + discont_pending: bool, + out_segment: gst::FormattedSegment, + task_abort_handle: Option, +} + +impl Default for TranslationPadState { + fn default() -> TranslationPadState { + TranslationPadState { + discont_pending: true, + out_segment: Default::default(), + task_abort_handle: None, + } + } +} + +#[derive(Debug, Default, Clone)] +struct TranslationPadSettings { + language_code: Option, +} + +#[derive(Debug, Default)] +pub struct TranslationSrcPad { + state: Mutex, + settings: Mutex, +} + +impl TranslationSrcPad { + fn start_task(&self) -> Result<(), gst::LoggableError> { + gst::debug!(CAT, imp: self, "Starting task"); + + let elem = self.parent(); + let transcript_event_rx = elem.imp().transcript_event_tx.subscribe(); + let mut pad_task = TranslationPadTask::try_new(self, elem, transcript_event_rx) + .map_err(|err| gst::loggable_error!(CAT, format!("Failed to start pad task {err}")))?; + + let imp = self.ref_counted(); + let res = self.obj().start_task(move || { + let (abortable_task_iter, abort_handle) = future::abortable(pad_task.run_iter()); + imp.state.lock().unwrap().task_abort_handle = Some(abort_handle); + + let _enter = RUNTIME.enter(); + match futures::executor::block_on(abortable_task_iter) { + Ok(Ok(())) => (), + Ok(Err(err)) => { + // Don't bring down the whole element if this Pad fails + // FIXME is there a way to mark the Pad in error though? + gst::info!(CAT, imp: imp, "Pausing task due to: {err}"); + let _ = imp.obj().pause_task(); + } + Err(_) => gst::debug!(CAT, imp: imp, "task iter aborted"), + } + }); + + if res.is_err() { + return Err(gst::loggable_error!(CAT, "Failed to start pad task")); + } + + gst::debug!(CAT, imp: self, "Task started"); + + Ok(()) + } + + fn stop_task(&self) { + gst::debug!(CAT, imp: self, "Stopping task"); + + // See also the note in `start_task()`: + // 1. Mark the task as stopped so no further iteration is executed. + let _ = self.obj().stop_task(); + + // 2. Abort the task iteration if the Future is pending. + if let Some(task_abort_handle) = self.state.lock().unwrap().task_abort_handle.take() { + task_abort_handle.abort(); + } + + gst::debug!(CAT, imp: self, "Task stopped"); + } + + fn set_discont(&self) { + self.state.lock().unwrap().discont_pending = true; + } + + #[inline] + fn needs_translation(input_lang: &str, output_lang: Option<&str>) -> bool { + output_lang.map_or(false, |other| { + !input_lang.eq_ignore_ascii_case(other.as_ref()) + }) + } + + #[inline] + fn our_latency( + elem_settings: &Settings, + pad_settings: &TranslationPadSettings, + ) -> gst::ClockTime { + if Self::needs_translation( + &elem_settings.language_code, + pad_settings.language_code.as_deref(), + ) { + elem_settings.transcribe_latency + + elem_settings.transcript_lookahead + + elem_settings.translate_latency + } else { + elem_settings.transcribe_latency + } + } + + #[track_caller] + fn parent(&self) -> super::Transcriber { + self.obj() + .parent() + .map(|elem_obj| { + elem_obj + .downcast::() + .expect("Wrong Element type") + }) + .expect("Pad should have a parent at this stage") + } +} + +impl TranslationSrcPad { + #[track_caller] + pub fn activatemode( + _elem: &Transcriber, + pad: &super::TranslationSrcPad, + _mode: gst::PadMode, + active: bool, + ) -> Result<(), gst::LoggableError> { + if active { + pad.imp().start_task()?; + } else { + pad.imp().stop_task(); + } + + Ok(()) + } + + pub fn src_query( + elem: &Transcriber, + pad: &super::TranslationSrcPad, + query: &mut gst::QueryRef, + ) -> bool { + gst::log!(CAT, obj: pad, "Handling query {query:?}"); + + use gst::QueryViewMut::*; + match query.view_mut() { + Latency(q) => { + let mut peer_query = gst::query::Latency::new(); + + let ret = elem.sinkpad.peer_query(&mut peer_query); + + if ret { + let (_, min, _) = peer_query.result(); + + let our_latency = { + let elem_settings = elem.settings.lock().unwrap(); + let pad_settings = pad.imp().settings.lock().unwrap(); + + Self::our_latency(&elem_settings, &pad_settings) + }; + + gst::info!(CAT, obj: pad, "Our latency {our_latency}"); + q.set(true, our_latency + min, gst::ClockTime::NONE); + } + ret + } + Position(q) => { + if q.format() == gst::Format::Time { + let stream_time = { + let state = pad.imp().state.lock().unwrap(); + state + .out_segment + .to_stream_time(state.out_segment.position()) + }; + + let Some(stream_time) = stream_time else { return false }; + q.set(stream_time); + + true + } else { + false + } + } + _ => gst::Pad::query_default(pad, Some(pad), query), + } + } +} + +#[glib::object_subclass] +impl ObjectSubclass for TranslationSrcPad { + const NAME: &'static str = "GstTranslationSrcPad"; + type Type = super::TranslationSrcPad; + type ParentType = gst::Pad; + + fn new() -> Self { + Default::default() + } +} + +impl ObjectImpl for TranslationSrcPad { + fn properties() -> &'static [glib::ParamSpec] { + static PROPERTIES: Lazy> = Lazy::new(|| { + vec![glib::ParamSpecString::builder(OUTPUT_LANG_CODE_PROPERTY) + .nick("Language Code") + .blurb("The Language the Stream must be translated to") + .default_value(DEFAULT_OUTPUT_LANG_CODE) + .mutable_ready() + .build()] + }); + + PROPERTIES.as_ref() + } + + fn set_property(&self, _id: usize, value: &glib::Value, pspec: &glib::ParamSpec) { + match pspec.name() { + OUTPUT_LANG_CODE_PROPERTY => { + self.settings.lock().unwrap().language_code = value.get().unwrap() + } + _ => unimplemented!(), + } + } + + fn property(&self, _id: usize, pspec: &glib::ParamSpec) -> glib::Value { + match pspec.name() { + OUTPUT_LANG_CODE_PROPERTY => self.settings.lock().unwrap().language_code.to_value(), + _ => unimplemented!(), + } + } +} + +impl GstObjectImpl for TranslationSrcPad {} + +impl PadImpl for TranslationSrcPad {} diff --git a/net/aws/src/transcriber/mod.rs b/net/aws/src/transcriber/mod.rs index 69ac60597..eb2a28f7c 100644 --- a/net/aws/src/transcriber/mod.rs +++ b/net/aws/src/transcriber/mod.rs @@ -10,6 +10,18 @@ use gst::glib; use gst::prelude::*; mod imp; +mod transcribe; +mod translate; + +use once_cell::sync::Lazy; + +static CAT: Lazy = Lazy::new(|| { + gst::DebugCategory::new( + "awstranscribe", + gst::DebugColorFlags::empty(), + Some("AWS Transcribe element"), + ) +}); use aws_sdk_transcribestreaming::model::{PartialResultsStability, VocabularyFilterMethod}; @@ -68,7 +80,11 @@ impl From for VocabularyFilterMethod { } glib::wrapper! { - pub struct Transcriber(ObjectSubclass) @extends gst::Element, gst::Object; + pub struct Transcriber(ObjectSubclass) @extends gst::Element, gst::Object, @implements gst::ChildProxy; +} + +glib::wrapper! { + pub struct TranslationSrcPad(ObjectSubclass) @extends gst::Pad, gst::Object; } pub fn register(plugin: &gst::Plugin) -> Result<(), glib::BoolError> { @@ -78,6 +94,7 @@ pub fn register(plugin: &gst::Plugin) -> Result<(), glib::BoolError> { .mark_as_plugin_api(gst::PluginAPIFlags::empty()); AwsTranscriberVocabularyFilterMethod::static_type() .mark_as_plugin_api(gst::PluginAPIFlags::empty()); + TranslationSrcPad::static_type().mark_as_plugin_api(gst::PluginAPIFlags::empty()); } gst::Element::register( Some(plugin), diff --git a/net/aws/src/transcriber/transcribe.rs b/net/aws/src/transcriber/transcribe.rs new file mode 100644 index 000000000..7b683f3bf --- /dev/null +++ b/net/aws/src/transcriber/transcribe.rs @@ -0,0 +1,277 @@ +// Copyright (C) 2020 Mathieu Duponchelle +// Copyright (C) 2023 François Laignel +// +// This Source Code Form is subject to the terms of the Mozilla Public License, v2.0. +// If a copy of the MPL was not distributed with this file, You can obtain one at +// . +// +// SPDX-License-Identifier: MPL-2.0 + +use gst::subclass::prelude::*; +use gst::{glib, prelude::*}; + +use aws_sdk_transcribestreaming as aws_transcribe; +use aws_sdk_transcribestreaming::model; + +use futures::channel::mpsc; +use futures::prelude::*; +use tokio::sync::broadcast; + +use std::sync::Arc; + +use super::imp::{Settings, Transcriber}; +use super::CAT; + +#[derive(Debug)] +pub struct TranscriptionSettings { + lang_code: model::LanguageCode, + sample_rate: i32, + vocabulary: Option, + vocabulary_filter: Option, + vocabulary_filter_method: model::VocabularyFilterMethod, + session_id: Option, + results_stability: model::PartialResultsStability, +} + +impl TranscriptionSettings { + pub(super) fn from(settings: &Settings, sample_rate: i32) -> Self { + TranscriptionSettings { + lang_code: settings.language_code.as_str().into(), + sample_rate, + vocabulary: settings.vocabulary.clone(), + vocabulary_filter: settings.vocabulary_filter.clone(), + vocabulary_filter_method: settings.vocabulary_filter_method.into(), + session_id: settings.session_id.clone(), + results_stability: settings.results_stability.into(), + } + } +} + +#[derive(Clone, Debug, Default)] +pub struct TranscriptItem { + pub pts: gst::ClockTime, + pub duration: gst::ClockTime, + pub content: String, + pub is_punctuation: bool, +} + +impl TranscriptItem { + pub fn from(item: model::Item, lateness: gst::ClockTime) -> Option { + let content = item.content?; + + let start_time = ((item.start_time * 1_000_000_000.0) as u64).nseconds() + lateness; + let end_time = ((item.end_time * 1_000_000_000.0) as u64).nseconds() + lateness; + + Some(TranscriptItem { + pts: start_time, + duration: end_time - start_time, + content, + is_punctuation: matches!(item.r#type, Some(model::ItemType::Punctuation)), + }) + } + + #[inline] + pub fn push(&mut self, item: &TranscriptItem) { + self.duration += item.duration; + + self.is_punctuation &= item.is_punctuation; + if !item.is_punctuation { + self.content.push(' '); + } + + self.content.push_str(&item.content); + } +} + +#[derive(Clone)] +pub enum TranscriptEvent { + Items(Arc>), + Eos, +} + +impl From> for TranscriptEvent { + fn from(transcript_items: Vec) -> Self { + TranscriptEvent::Items(transcript_items.into()) + } +} + +pub struct TranscriberLoop { + imp: glib::subclass::ObjectImplRef, + client: aws_transcribe::Client, + settings: Option, + lateness: gst::ClockTime, + buffer_rx: Option>, + transcript_items_tx: broadcast::Sender, + partial_index: usize, +} + +impl TranscriberLoop { + pub fn new( + imp: &Transcriber, + settings: TranscriptionSettings, + lateness: gst::ClockTime, + buffer_rx: mpsc::Receiver, + transcript_items_tx: broadcast::Sender, + ) -> Self { + let aws_config = imp.aws_config.lock().unwrap(); + let aws_config = aws_config + .as_ref() + .expect("aws_config must be initialized at this stage"); + + TranscriberLoop { + imp: imp.ref_counted(), + client: aws_transcribe::Client::new(aws_config), + settings: Some(settings), + lateness, + buffer_rx: Some(buffer_rx), + transcript_items_tx, + partial_index: 0, + } + } + + pub async fn run(mut self) -> Result<(), gst::ErrorMessage> { + // Stream the incoming buffers chunked + let chunk_stream = self.buffer_rx.take().unwrap().flat_map(move |buffer: gst::Buffer| { + async_stream::stream! { + let data = buffer.map_readable().unwrap(); + use aws_transcribe::{model::{AudioEvent, AudioStream}, types::Blob}; + for chunk in data.chunks(8192) { + yield Ok(AudioStream::AudioEvent(AudioEvent::builder().audio_chunk(Blob::new(chunk)).build())); + } + } + }); + + let settings = self.settings.take().unwrap(); + let mut transcribe_builder = self + .client + .start_stream_transcription() + .language_code(settings.lang_code) + .media_sample_rate_hertz(settings.sample_rate) + .media_encoding(model::MediaEncoding::Pcm) + .enable_partial_results_stabilization(true) + .partial_results_stability(settings.results_stability) + .set_vocabulary_name(settings.vocabulary) + .set_session_id(settings.session_id); + + if let Some(vocabulary_filter) = settings.vocabulary_filter { + transcribe_builder = transcribe_builder + .vocabulary_filter_name(vocabulary_filter) + .vocabulary_filter_method(settings.vocabulary_filter_method); + } + + let mut output = transcribe_builder + .audio_stream(chunk_stream.into()) + .send() + .await + .map_err(|err| { + let err = format!("Transcribe ws init error: {err}"); + gst::error!(CAT, imp: self.imp, "{err}"); + gst::error_msg!(gst::LibraryError::Init, ["{err}"]) + })?; + + while let Some(event) = output + .transcript_result_stream + .recv() + .await + .map_err(|err| { + let err = format!("Transcribe ws stream error: {err}"); + gst::error!(CAT, imp: self.imp, "{err}"); + gst::error_msg!(gst::LibraryError::Failed, ["{err}"]) + })? + { + if let model::TranscriptResultStream::TranscriptEvent(transcript_evt) = event { + let mut ready_items = None; + + if let Some(result) = transcript_evt + .transcript + .and_then(|transcript| transcript.results) + .and_then(|mut results| results.drain(..).next()) + { + gst::trace!(CAT, imp: self.imp, "Received: {result:?}"); + + if let Some(alternative) = result + .alternatives + .and_then(|mut alternatives| alternatives.drain(..).next()) + { + ready_items = alternative.items.and_then(|items| { + self.get_ready_transcript_items(items, result.is_partial) + }); + } + } + + if let Some(ready_items) = ready_items { + if self.transcript_items_tx.send(ready_items.into()).is_err() { + gst::debug!(CAT, imp: self.imp, "No transcript items receivers"); + break; + } + } + } else { + gst::warning!( + CAT, + imp: self.imp, + "Transcribe ws returned unknown event: consider upgrading the SDK" + ) + } + } + + gst::debug!(CAT, imp: self.imp, "Transcriber loop sending EOS"); + let _ = self.transcript_items_tx.send(TranscriptEvent::Eos); + + gst::debug!(CAT, imp: self.imp, "Exiting transcriber loop"); + + Ok(()) + } + + /// Builds a list from the provided stable items. + fn get_ready_transcript_items( + &mut self, + mut items: Vec, + partial: bool, + ) -> Option> { + if items.len() <= self.partial_index { + gst::error!( + CAT, + imp: self.imp, + "sanity check failed, alternative length {} < partial_index {}", + items.len(), + self.partial_index + ); + + if !partial { + self.partial_index = 0; + } + + return None; + } + + let mut output = vec![]; + + for item in items.drain(self.partial_index..) { + if !item.stable().unwrap_or(false) { + break; + } + + let Some(item) = TranscriptItem::from(item, self.lateness) else { continue }; + gst::debug!( + CAT, + imp: self.imp, + "Item is ready for queuing: {}, PTS {}", + item.content, + item.pts, + ); + + self.partial_index += 1; + output.push(item); + } + + if !partial { + self.partial_index = 0; + } + + if output.is_empty() { + return None; + } + + Some(output) + } +} diff --git a/net/aws/src/transcriber/translate.rs b/net/aws/src/transcriber/translate.rs new file mode 100644 index 000000000..b689bd637 --- /dev/null +++ b/net/aws/src/transcriber/translate.rs @@ -0,0 +1,215 @@ +// Copyright (C) 2023 François Laignel +// +// This Source Code Form is subject to the terms of the Mozilla Public License, v2.0. +// If a copy of the MPL was not distributed with this file, You can obtain one at +// . +// +// SPDX-License-Identifier: MPL-2.0 + +use gst::glib; +use gst::subclass::prelude::*; + +use aws_sdk_translate as aws_translate; + +use futures::channel::mpsc; +use futures::prelude::*; + +use std::collections::VecDeque; + +use super::imp::TranslationSrcPad; +use super::transcribe::TranscriptItem; +use super::CAT; + +pub struct TranslatedItem { + pub pts: gst::ClockTime, + pub duration: gst::ClockTime, + pub content: String, +} + +impl From<&TranscriptItem> for TranslatedItem { + fn from(transcript_item: &TranscriptItem) -> Self { + TranslatedItem { + pts: transcript_item.pts, + duration: transcript_item.duration, + content: transcript_item.content.clone(), + } + } +} + +#[derive(Default)] +pub struct TranslationQueue { + items: VecDeque, +} + +impl TranslationQueue { + pub fn is_empty(&self) -> bool { + self.items.is_empty() + } + + /// Pushes the provided item. + /// + /// Returns `Some(..)` if items are ready for translation. + pub fn push(&mut self, transcript_item: &TranscriptItem) -> Option { + // Keep track of the item individually so we can schedule translation precisely. + self.items.push_back(transcript_item.clone()); + + if transcript_item.is_punctuation { + // This makes it a good chunk for translation. + // Concatenate as a single item for translation + + let mut items = self.items.drain(..); + + let mut item_acc = items.next()?; + for item in items { + item_acc.push(&item); + } + + item_acc.push(transcript_item); + + return Some(item_acc); + } + + // Regular case: no separator detected, don't push transcript items + // to translation now. They will be pushed either if a punctuation + // is found or of a `dequeue()` is requested. + + None + } + + /// Dequeues items from the specified `deadline` up to `lookahead`. + /// + /// Returns `Some(..)` with the accumulated items matching the criteria. + pub fn dequeue( + &mut self, + deadline: gst::ClockTime, + lookahead: gst::ClockTime, + ) -> Option { + if self.items.front()?.pts < deadline { + // First item is too early to be sent to translation now + // we can wait for more items to accumulate. + return None; + } + + // Can't wait any longer to send the first item to translation + // Try to get up to lookahead more items to improve translation accuracy + let limit = deadline + lookahead; + + let mut item_acc = self.items.pop_front().unwrap(); + while let Some(item) = self.items.front() { + if item.pts > limit { + break; + } + + let item = self.items.pop_front().unwrap(); + item_acc.push(&item); + } + + Some(item_acc) + } +} + +pub struct TranslationLoop { + pad: glib::subclass::ObjectImplRef, + client: aws_translate::Client, + input_lang: String, + output_lang: String, + transcript_rx: mpsc::Receiver, + translation_tx: mpsc::Sender, +} + +impl TranslationLoop { + pub fn new( + imp: &super::imp::Transcriber, + pad: &TranslationSrcPad, + input_lang: &str, + output_lang: &str, + transcript_rx: mpsc::Receiver, + translation_tx: mpsc::Sender, + ) -> Self { + let aws_config = imp.aws_config.lock().unwrap(); + let aws_config = aws_config + .as_ref() + .expect("aws_config must be initialized at this stage"); + + TranslationLoop { + pad: pad.ref_counted(), + client: aws_sdk_translate::Client::new(aws_config), + input_lang: input_lang.to_string(), + output_lang: output_lang.to_string(), + transcript_rx, + translation_tx, + } + } + + pub async fn check_language(&self) -> Result<(), gst::ErrorMessage> { + let language_list = self.client.list_languages().send().await.map_err(|err| { + let err = format!("Failed to call list_languages service: {err}"); + gst::info!(CAT, imp: self.pad, "{err}"); + gst::error_msg!(gst::LibraryError::Failed, ["{err}"]) + })?; + + let found_output_lang = language_list + .languages() + .and_then(|langs| { + langs + .iter() + .find(|lang| lang.language_code() == Some(&self.output_lang)) + }) + .is_some(); + + if !found_output_lang { + let err = format!("Unknown output languages: {}", self.output_lang); + gst::info!(CAT, imp: self.pad, "{err}"); + return Err(gst::error_msg!(gst::LibraryError::Failed, ["{err}"])); + } + + Ok(()) + } + + pub async fn run(mut self) -> Result<(), gst::ErrorMessage> { + while let Some(transcript_item) = self.transcript_rx.next().await { + let TranscriptItem { + pts, + duration, + content, + .. + } = transcript_item; + + let translated_text = if content.is_empty() { + content + } else { + self.client + .translate_text() + .set_source_language_code(Some(self.input_lang.clone())) + .set_target_language_code(Some(self.output_lang.clone())) + .set_text(Some(content)) + .send() + .await + .map_err(|err| { + let err = format!("Failed to call translation service: {err}"); + gst::info!(CAT, imp: self.pad, "{err}"); + gst::error_msg!(gst::LibraryError::Failed, ["{err}"]) + })? + .translated_text + .unwrap_or_default() + }; + + let translated_item = TranslatedItem { + pts, + duration, + content: translated_text, + }; + + if self.translation_tx.send(translated_item).await.is_err() { + gst::info!( + CAT, + imp: self.pad, + "translation chan terminated, exiting translation loop" + ); + break; + } + } + + Ok(()) + } +}