net/aws/transcriber: use two queues for sending transcript items

* A queue dedicated to transcript items not intended for translation.
* A queue dedicated to transcript items intended for translation. The items are
  enqueued after a separator is detected or translate-lookahead was reached.

Part-of: <https://gitlab.freedesktop.org/gstreamer/gst-plugins-rs/-/merge_requests/1137>
This commit is contained in:
François Laignel 2023-03-16 18:20:08 +01:00
parent 5a5ca76d9d
commit 2b32d00589
3 changed files with 373 additions and 323 deletions

View file

@ -29,12 +29,12 @@ use futures::prelude::*;
use tokio::{runtime, sync::broadcast, task}; use tokio::{runtime, sync::broadcast, task};
use std::collections::{BTreeSet, VecDeque}; use std::collections::{BTreeSet, VecDeque};
use std::sync::Mutex; use std::sync::{Arc, Mutex};
use once_cell::sync::Lazy; use once_cell::sync::Lazy;
use super::transcribe::{TranscriberLoop, TranscriptEvent, TranscriptItem, TranscriptionSettings}; use super::transcribe::{TranscriberSettings, TranscriberStream, TranscriptEvent, TranscriptItem};
use super::translate::{TranslateLoop, TranslateQueue, TranslatedItem}; use super::translate::{TranslateLoop, TranslatedItem};
use super::{ use super::{
AwsTranscriberResultStability, AwsTranscriberVocabularyFilterMethod, AwsTranscriberResultStability, AwsTranscriberVocabularyFilterMethod,
TranslationTokenizationMethod, CAT, TranslationTokenizationMethod, CAT,
@ -148,6 +148,7 @@ struct State {
srcpads: BTreeSet<super::TranslateSrcPad>, srcpads: BTreeSet<super::TranslateSrcPad>,
pad_serial: u32, pad_serial: u32,
seqnum: gst::Seqnum, seqnum: gst::Seqnum,
start_time: Option<gst::ClockTime>,
} }
impl Default for State { impl Default for State {
@ -158,6 +159,7 @@ impl Default for State {
srcpads: Default::default(), srcpads: Default::default(),
pad_serial: 0, pad_serial: 0,
seqnum: gst::Seqnum::next(), seqnum: gst::Seqnum::next(),
start_time: None,
} }
} }
} }
@ -168,7 +170,9 @@ pub struct Transcriber {
settings: Mutex<Settings>, settings: Mutex<Settings>,
state: Mutex<State>, state: Mutex<State>,
pub(super) aws_config: Mutex<Option<aws_config::SdkConfig>>, pub(super) aws_config: Mutex<Option<aws_config::SdkConfig>>,
// sender to broadcast transcript items to the translate src pads. // sender to broadcast transcript items to the src pads for translation.
transcript_event_for_translate_tx: broadcast::Sender<TranscriptEvent>,
// sender to broadcast transcript items to the src pads, not intended for translation.
transcript_event_tx: broadcast::Sender<TranscriptEvent>, transcript_event_tx: broadcast::Sender<TranscriptEvent>,
} }
@ -276,7 +280,10 @@ impl Transcriber {
) -> Result<gst::FlowSuccess, gst::FlowError> { ) -> Result<gst::FlowSuccess, gst::FlowError> {
gst::log!(CAT, obj: pad, "Handling {buffer:?}"); gst::log!(CAT, obj: pad, "Handling {buffer:?}");
self.ensure_connection(); self.ensure_connection().map_err(|err| {
gst::element_imp_error!(self, gst::StreamError::Failed, ["Streaming failed: {err}"]);
gst::FlowError::Error
})?;
let Some(mut buffer_tx) = self.state.lock().unwrap().buffer_tx.take() else { let Some(mut buffer_tx) = self.state.lock().unwrap().buffer_tx.take() else {
gst::log!(CAT, obj: pad, "Flushing"); gst::log!(CAT, obj: pad, "Flushing");
@ -292,12 +299,82 @@ impl Transcriber {
Ok(gst::FlowSuccess::Ok) Ok(gst::FlowSuccess::Ok)
} }
}
fn ensure_connection(&self) { #[derive(Default)]
struct TranslateQueue {
items: VecDeque<TranscriptItem>,
}
impl TranslateQueue {
fn is_empty(&self) -> bool {
self.items.is_empty()
}
/// Pushes the provided item.
///
/// Returns `Some(..)` if items are ready for translation.
fn push(&mut self, transcript_item: &TranscriptItem) -> Option<Vec<TranscriptItem>> {
// Keep track of the item individually so we can schedule translation precisely.
self.items.push_back(transcript_item.clone());
if transcript_item.is_punctuation {
// This makes it a good chunk for translation.
// Concatenate as a single item for translation
return Some(self.items.drain(..).collect());
}
// Regular case: no separator detected, don't push transcript items
// to translation now. They will be pushed either if a punctuation
// is found or of a `dequeue()` is requested.
None
}
/// Dequeues items from the specified `deadline` up to `lookahead`.
///
/// Returns `Some(..)` if some items match the criteria.
fn dequeue(
&mut self,
latency: gst::ClockTime,
threshold: gst::ClockTime,
lookahead: gst::ClockTime,
) -> Option<Vec<TranscriptItem>> {
let first_pts = self.items.front()?.pts;
if first_pts + latency > threshold {
// First item is too early to be sent to translation now
// we can wait for more items to accumulate.
return None;
}
// Can't wait any longer to send the first item to translation
// Try to get up to lookahead worth of items to improve translation accuracy
let limit = first_pts + lookahead;
let mut items_acc = vec![self.items.pop_front().unwrap()];
while let Some(item) = self.items.front() {
if item.pts > limit {
break;
}
items_acc.push(self.items.pop_front().unwrap());
}
Some(items_acc)
}
fn drain(&mut self) -> impl Iterator<Item = TranscriptItem> + '_ {
self.items.drain(..)
}
}
impl Transcriber {
fn ensure_connection(&self) -> Result<(), gst::ErrorMessage> {
let mut state = self.state.lock().unwrap(); let mut state = self.state.lock().unwrap();
if state.buffer_tx.is_some() { if state.buffer_tx.is_some() {
return; return Ok(());
} }
let settings = self.settings.lock().unwrap(); let settings = self.settings.lock().unwrap();
@ -306,21 +383,116 @@ impl Transcriber {
let s = in_caps.structure(0).unwrap(); let s = in_caps.structure(0).unwrap();
let sample_rate = s.get::<i32>("rate").unwrap(); let sample_rate = s.get::<i32>("rate").unwrap();
let transcription_settings = TranscriptionSettings::from(&settings, sample_rate); let transcription_settings = TranscriberSettings::from(&settings, sample_rate);
let (buffer_tx, buffer_rx) = mpsc::channel(1); let (buffer_tx, buffer_rx) = mpsc::channel(1);
let transcriber_loop = TranscriberLoop::new( let _enter = RUNTIME.enter();
let mut transcriber_stream = futures::executor::block_on(TranscriberStream::try_new(
self, self,
transcription_settings, transcription_settings,
settings.lateness, settings.lateness,
buffer_rx, buffer_rx,
self.transcript_event_tx.clone(), ))?;
// Latency budget for an item to be pushed to stream on time
// Margin:
// - 2 * GRANULARITY: to make sure we don't push items up to GRANULARITY late.
// - 1 * GRANULARITY: extra margin to account for additional overheads.
let latency = settings.transcribe_latency.saturating_sub(3 * GRANULARITY);
let translate_lookahead = settings.translate_lookahead;
let mut translate_queue = TranslateQueue::default();
let imp = self.ref_counted();
let transcriber_loop_handle = RUNTIME.spawn(async move {
loop {
// This is to make sure we send items on a timely basis or at least Gap events.
let timeout = tokio::time::sleep(GRANULARITY.into()).fuse();
futures::pin_mut!(timeout);
let transcriber_next = transcriber_stream.next().fuse();
futures::pin_mut!(transcriber_next);
// `transcriber_next` takes precedence over `timeout`
// because we don't want to loose any incoming items.
let res = futures::select_biased! {
event = transcriber_next => Some(event?),
_ = timeout => None,
};
use TranscriptEvent::*;
match res {
None => (),
Some(Items(items)) => {
if imp.transcript_event_tx.receiver_count() > 0 {
let _ = imp.transcript_event_tx.send(Items(items.clone()));
}
if imp.transcript_event_for_translate_tx.receiver_count() > 0 {
for item in items.iter() {
if let Some(items_to_translate) = translate_queue.push(item) {
let _ = imp
.transcript_event_for_translate_tx
.send(Items(items_to_translate.into()));
}
}
}
}
Some(Eos) => {
gst::debug!(CAT, imp: imp, "Transcriber loop sending EOS");
if imp.transcript_event_tx.receiver_count() > 0 {
let _ = imp.transcript_event_tx.send(Eos);
}
if imp.transcript_event_for_translate_tx.receiver_count() > 0 {
let items_to_translate: Vec<TranscriptItem> =
translate_queue.drain().collect();
let _ = imp
.transcript_event_for_translate_tx
.send(Items(items_to_translate.into()));
let _ = imp.transcript_event_for_translate_tx.send(Eos);
}
break;
}
}
if imp.transcript_event_for_translate_tx.receiver_count() > 0 {
// Check if we need to push items for translation
let Some((start_time, now)) = imp.get_start_time_and_now() else {
continue;
};
if !translate_queue.is_empty() {
let threshold = now - start_time;
if let Some(items_to_translate) =
translate_queue.dequeue(latency, threshold, translate_lookahead)
{
gst::debug!(
CAT,
imp: imp,
"Forcing to translation (threshold {threshold}): {items_to_translate:?}"
); );
let transcriber_loop_handle = RUNTIME.spawn(transcriber_loop.run()); let _ = imp
.transcript_event_for_translate_tx
.send(Items(items_to_translate.into()));
}
}
}
}
gst::debug!(CAT, imp: imp, "Exiting transcriber loop");
Ok(())
});
state.transcriber_loop_handle = Some(transcriber_loop_handle); state.transcriber_loop_handle = Some(transcriber_loop_handle);
state.buffer_tx = Some(buffer_tx); state.buffer_tx = Some(buffer_tx);
Ok(())
} }
fn prepare(&self) -> Result<(), gst::ErrorMessage> { fn prepare(&self) -> Result<(), gst::ErrorMessage> {
@ -382,6 +554,18 @@ impl Transcriber {
} }
gst::info!(CAT, imp: self, "Unprepared"); gst::info!(CAT, imp: self, "Unprepared");
} }
fn get_start_time_and_now(&self) -> Option<(gst::ClockTime, gst::ClockTime)> {
let now = self.obj().current_running_time()?;
let mut state = self.state.lock().unwrap();
if state.start_time.is_none() {
state.start_time = Some(now);
}
Some((state.start_time.unwrap(), now))
}
} }
#[glib::object_subclass] #[glib::object_subclass]
@ -438,6 +622,7 @@ impl ObjectSubclass for Transcriber {
// Setting the channel capacity so that a TranslateSrcPad that would lag // Setting the channel capacity so that a TranslateSrcPad that would lag
// behind for some reasons get a chance to catch-up without loosing items. // behind for some reasons get a chance to catch-up without loosing items.
// Receiver will be created by subscribing to sender later. // Receiver will be created by subscribing to sender later.
let (transcript_event_for_translate_tx, _) = broadcast::channel(128);
let (transcript_event_tx, _) = broadcast::channel(128); let (transcript_event_tx, _) = broadcast::channel(128);
Self { Self {
@ -446,6 +631,7 @@ impl ObjectSubclass for Transcriber {
settings: Default::default(), settings: Default::default(),
state: Default::default(), state: Default::default(),
aws_config: Default::default(), aws_config: Default::default(),
transcript_event_for_translate_tx,
transcript_event_tx, transcript_event_tx,
} }
} }
@ -876,51 +1062,93 @@ struct TranslationPadTask {
elem: super::Transcriber, elem: super::Transcriber,
transcript_event_rx: broadcast::Receiver<TranscriptEvent>, transcript_event_rx: broadcast::Receiver<TranscriptEvent>,
needs_translate: bool, needs_translate: bool,
translate_queue: TranslateQueue,
translate_loop_handle: Option<task::JoinHandle<Result<(), gst::ErrorMessage>>>, translate_loop_handle: Option<task::JoinHandle<Result<(), gst::ErrorMessage>>>,
to_translate_tx: Option<mpsc::Sender<Vec<TranscriptItem>>>, to_translate_tx: Option<mpsc::Sender<Arc<Vec<TranscriptItem>>>>,
from_translate_rx: Option<mpsc::Receiver<Vec<TranslatedItem>>>, from_translate_rx: Option<mpsc::Receiver<Vec<TranslatedItem>>>,
translate_latency: gst::ClockTime,
translate_lookahead: gst::ClockTime,
send_events: bool, send_events: bool,
output_items: VecDeque<OutputItem>, output_items: VecDeque<OutputItem>,
our_latency: gst::ClockTime, our_latency: gst::ClockTime,
seqnum: gst::Seqnum, seqnum: gst::Seqnum,
send_eos: bool, send_eos: bool,
pending_translations: usize, pending_translations: usize,
start_time: Option<gst::ClockTime>,
} }
impl TranslationPadTask { impl TranslationPadTask {
fn try_new( async fn try_new(
pad: &TranslateSrcPad, pad: &TranslateSrcPad,
elem: super::Transcriber, elem: super::Transcriber,
transcript_event_rx: broadcast::Receiver<TranscriptEvent>,
) -> Result<TranslationPadTask, gst::ErrorMessage> { ) -> Result<TranslationPadTask, gst::ErrorMessage> {
let mut this = TranslationPadTask { let mut translation_loop = None;
let mut translate_loop_handle = None;
let mut to_translate_tx = None;
let mut from_translate_rx = None;
let (our_latency, transcript_event_rx, needs_translate);
{
let elem_imp = elem.imp();
let elem_settings = elem_imp.settings.lock().unwrap();
let pad_settings = pad.settings.lock().unwrap();
our_latency = TranslateSrcPad::our_latency(&elem_settings, &pad_settings);
if our_latency + elem_settings.lateness <= 2 * GRANULARITY {
let err = format!(
"total latency + lateness must be greater than {}",
2 * GRANULARITY
);
gst::error!(CAT, imp: pad, "{err}");
return Err(gst::error_msg!(gst::LibraryError::Settings, ["{err}"]));
}
needs_translate = TranslateSrcPad::needs_translation(
&elem_settings.language_code,
pad_settings.language_code.as_deref(),
);
if needs_translate {
let (to_loop_tx, to_loop_rx) = mpsc::channel(64);
let (from_loop_tx, from_loop_rx) = mpsc::channel(64);
translation_loop = Some(TranslateLoop::new(
elem_imp,
pad,
&elem_settings.language_code,
pad_settings.language_code.as_deref().unwrap(),
pad_settings.tokenization_method,
to_loop_rx,
from_loop_tx,
));
to_translate_tx = Some(to_loop_tx);
from_translate_rx = Some(from_loop_rx);
transcript_event_rx = elem_imp.transcript_event_for_translate_tx.subscribe();
} else {
transcript_event_rx = elem_imp.transcript_event_tx.subscribe();
}
}
if let Some(translation_loop) = translation_loop {
translation_loop.check_language().await?;
translate_loop_handle = Some(RUNTIME.spawn(translation_loop.run()));
}
Ok(TranslationPadTask {
pad: pad.ref_counted(), pad: pad.ref_counted(),
elem, elem,
transcript_event_rx, transcript_event_rx,
needs_translate: false, needs_translate,
translate_queue: TranslateQueue::default(), translate_loop_handle,
translate_loop_handle: None, to_translate_tx,
to_translate_tx: None, from_translate_rx,
from_translate_rx: None,
translate_latency: DEFAULT_TRANSLATE_LATENCY,
translate_lookahead: DEFAULT_TRANSLATE_LOOKAHEAD,
send_events: true, send_events: true,
output_items: VecDeque::new(), output_items: VecDeque::new(),
our_latency: DEFAULT_TRANSCRIBE_LATENCY, our_latency,
seqnum: gst::Seqnum::next(), seqnum: gst::Seqnum::next(),
send_eos: false, send_eos: false,
pending_translations: 0, pending_translations: 0,
start_time: None, })
};
let _enter_guard = RUNTIME.enter();
futures::executor::block_on(this.init_translate())?;
Ok(this)
} }
} }
@ -958,11 +1186,9 @@ impl TranslationPadTask {
let transcript_event_rx = self.transcript_event_rx.recv().fuse(); let transcript_event_rx = self.transcript_event_rx.recv().fuse();
futures::pin_mut!(transcript_event_rx); futures::pin_mut!(transcript_event_rx);
// `timeout` takes precedence over `transcript_events` reception // `transcript_event_rx` takes precedence over `timeout`
// because we may need to `dequeue` `items` or push a `Gap` event // because we don't want to loose any incoming items.
// before current latency budget is exhausted.
futures::select_biased! { futures::select_biased! {
_ = timeout => (),
items_res = transcript_event_rx => { items_res = transcript_event_rx => {
use TranscriptEvent::*; use TranscriptEvent::*;
use broadcast::error::RecvError; use broadcast::error::RecvError;
@ -983,6 +1209,7 @@ impl TranslationPadTask {
} }
} }
} }
_ = timeout => (),
} }
Ok(()) Ok(())
@ -999,121 +1226,100 @@ impl TranslationPadTask {
return Err(gst::error_msg!(gst::StreamError::Failed, ["{ERR}"])); return Err(gst::error_msg!(gst::StreamError::Failed, ["{ERR}"]));
} }
let transcript_items = { let items_to_translate = {
// This is to make sure we send items on a timely basis or at least Gap events. // This is to make sure we send items on a timely basis or at least Gap events.
let timeout = tokio::time::sleep(GRANULARITY.into()).fuse(); let timeout = tokio::time::sleep(GRANULARITY.into()).fuse();
futures::pin_mut!(timeout); futures::pin_mut!(timeout);
let transcript_event_rx = self.transcript_event_rx.recv().fuse();
futures::pin_mut!(transcript_event_rx);
// `transcript_event_rx` takes precedence over `timeout`
// because we don't want to loose any incoming items.
futures::select_biased! {
items_res = transcript_event_rx => {
use TranscriptEvent::*;
use broadcast::error::RecvError;
match items_res {
Ok(Items(items_to_translate)) => Some(items_to_translate),
Ok(Eos) => {
gst::debug!(CAT, imp: self.pad, "Got eos");
self.send_eos = true;
None
}
Err(RecvError::Lagged(nb_msg)) => {
gst::warning!(CAT, imp: self.pad, "Missed {nb_msg} transcript sets");
None
}
Err(RecvError::Closed) => {
gst::debug!(CAT, imp: self.pad, "Transcript chan terminated: setting eos");
self.send_eos = true;
None
}
}
}
_ = timeout => None,
}
};
if let Some(items_to_translate) = items_to_translate {
if !items_to_translate.is_empty() {
let res = self
.to_translate_tx
.as_mut()
.expect("to_translation chan must be available in translation mode")
.send(items_to_translate)
.await;
if res.is_err() {
const ERR: &str = "to_translation chan terminated";
gst::debug!(CAT, imp: self.pad, "{ERR}");
return Err(gst::error_msg!(gst::StreamError::Failed, ["{ERR}"]));
}
self.pending_translations += 1;
}
}
// Check pending translated items
let from_translate_rx = self let from_translate_rx = self
.from_translate_rx .from_translate_rx
.as_mut() .as_mut()
.expect("from_translation chan must be available in translation mode"); .expect("from_translation chan must be available in translation mode");
let transcript_event_rx = self.transcript_event_rx.recv().fuse(); while let Ok(translated_items) = from_translate_rx.try_next() {
futures::pin_mut!(transcript_event_rx);
// `timeout` takes precedence over `transcript_events` reception
// because we may need to `dequeue` `items` or push a `Gap` event
// before current latency budget is exhausted.
futures::select_biased! {
_ = timeout => return Ok(()),
translated_items = from_translate_rx.next() => {
let Some(translated_items) = translated_items else { let Some(translated_items) = translated_items else {
const ERR: &str = "translation chan terminated"; const ERR: &str = "translation chan terminated";
gst::debug!(CAT, imp: self.pad, "{ERR}"); gst::debug!(CAT, imp: self.pad, "{ERR}");
return Err(gst::error_msg!(gst::StreamError::Failed, ["{ERR}"])); return Err(gst::error_msg!(gst::StreamError::Failed, ["{ERR}"]));
}; };
self.output_items.extend(translated_items.into_iter().map(Into::into)); self.output_items
.extend(translated_items.into_iter().map(Into::into));
self.pending_translations = self.pending_translations.saturating_sub(1); self.pending_translations = self.pending_translations.saturating_sub(1);
return Ok(());
}
items_res = transcript_event_rx => {
use TranscriptEvent::*;
use broadcast::error::RecvError;
match items_res {
Ok(Items(transcript_items)) => transcript_items,
Ok(Eos) => {
gst::debug!(CAT, imp: self.pad, "Got eos");
self.send_eos = true;
return Ok(());
}
Err(RecvError::Lagged(nb_msg)) => {
gst::warning!(CAT, imp: self.pad, "Missed {nb_msg} transcript sets");
return Ok(());
}
Err(RecvError::Closed) => {
gst::debug!(CAT, imp: self.pad, "Transcript chan terminated: setting eos");
self.send_eos = true;
return Ok(());
}
}
}
}
};
for items in transcript_items.iter() {
if let Some(items_to_translate) = self.translate_queue.push(items) {
self.send_for_translation(items_to_translate).await?;
}
}
Ok(())
}
async fn dequeue_for_translation(
&mut self,
start_time: gst::ClockTime,
now: gst::ClockTime,
) -> Result<(), gst::ErrorMessage> {
if !self.translate_queue.is_empty() {
// Latency budget for an item to be pushed to stream on time
// Margin:
// - 2 * GRANULARITY: to make sure we don't push items up to GRANULARITY late.
// - 1 * GRANULARITY: extra margin to account for additional overheads.
let latency = self.our_latency.saturating_sub(3 * GRANULARITY);
// Estimated time of arrival for an item sent to translation now.
// (in transcript item ts base)
let translation_eta = now + self.translate_latency - start_time;
if let Some(items_to_translate) =
self.translate_queue
.dequeue(latency, translation_eta, self.translate_lookahead)
{
gst::debug!(CAT, imp: self.pad, "Forcing to translation: {items_to_translate:?}");
self.send_for_translation(items_to_translate).await?;
}
} }
Ok(()) Ok(())
} }
async fn dequeue(&mut self) -> bool { async fn dequeue(&mut self) -> bool {
let (now, start_time, mut last_position, mut discont_pending); let Some((start_time, now)) = self.elem.imp().get_start_time_and_now() else {
{
let mut pad_state = self.pad.state.lock().unwrap();
let Some(cur_rt) = self.elem.current_running_time() else {
// Wait for the clock to be available // Wait for the clock to be available
return true; return true;
}; };
now = cur_rt;
if self.start_time.is_none() { let (mut last_position, mut discont_pending) = {
self.start_time = Some(now); let mut state = self.pad.state.lock().unwrap();
pad_state.out_segment.set_position(now);
}
start_time = self.start_time.unwrap(); let last_position = if let Some(pos) = state.out_segment.position() {
last_position = pad_state.out_segment.position().unwrap(); pos
discont_pending = pad_state.discont_pending; } else {
} state.out_segment.set_position(start_time);
start_time
};
if self.needs_translate && self.dequeue_for_translation(start_time, now).await.is_err() { (last_position, state.discont_pending)
return false; };
}
/* First, check our pending buffers */ /* First, check our pending buffers */
while let Some(item) = self.output_items.front() { while let Some(item) = self.output_items.front() {
@ -1206,11 +1412,7 @@ impl TranslationPadTask {
} }
} }
if self.send_eos if self.send_eos && self.pending_translations == 0 && self.output_items.is_empty() {
&& self.pending_translations == 0
&& self.output_items.is_empty()
&& self.translate_queue.is_empty()
{
/* We're EOS, we can pause and exit early */ /* We're EOS, we can pause and exit early */
let _ = self.pad.obj().pause_task(); let _ = self.pad.obj().pause_task();
@ -1261,28 +1463,6 @@ impl TranslationPadTask {
true true
} }
async fn send_for_translation(
&mut self,
transcript_items: Vec<TranscriptItem>,
) -> Result<(), gst::ErrorMessage> {
let res = self
.to_translate_tx
.as_mut()
.expect("to_translation chan must be available in translation mode")
.send(transcript_items)
.await;
if res.is_err() {
const ERR: &str = "to_translation chan terminated";
gst::debug!(CAT, imp: self.pad, "{ERR}");
return Err(gst::error_msg!(gst::StreamError::Failed, ["{ERR}"]));
}
self.pending_translations += 1;
Ok(())
}
fn ensure_init_events(&mut self) -> Result<(), gst::ErrorMessage> { fn ensure_init_events(&mut self) -> Result<(), gst::ErrorMessage> {
if !self.send_events { if !self.send_events {
return Ok(()); return Ok(());
@ -1332,62 +1512,6 @@ impl TranslationPadTask {
} }
} }
impl TranslationPadTask {
async fn init_translate(&mut self) -> Result<(), gst::ErrorMessage> {
let mut translation_loop = None;
{
let elem_imp = self.elem.imp();
let elem_settings = elem_imp.settings.lock().unwrap();
let pad_settings = self.pad.settings.lock().unwrap();
self.our_latency = TranslateSrcPad::our_latency(&elem_settings, &pad_settings);
if self.our_latency + elem_settings.lateness <= 2 * GRANULARITY {
let err = format!(
"total latency + lateness must be greater than {}",
2 * GRANULARITY
);
gst::error!(CAT, imp: self.pad, "{err}");
return Err(gst::error_msg!(gst::LibraryError::Settings, ["{err}"]));
}
self.translate_latency = elem_settings.translate_latency;
self.translate_lookahead = elem_settings.translate_lookahead;
self.needs_translate = TranslateSrcPad::needs_translation(
&elem_settings.language_code,
pad_settings.language_code.as_deref(),
);
if self.needs_translate {
let (to_translate_tx, to_translate_rx) = mpsc::channel(64);
let (from_translate_tx, from_translate_rx) = mpsc::channel(64);
translation_loop = Some(TranslateLoop::new(
elem_imp,
&self.pad,
&elem_settings.language_code,
pad_settings.language_code.as_deref().unwrap(),
pad_settings.tokenization_method,
to_translate_rx,
from_translate_tx,
));
self.to_translate_tx = Some(to_translate_tx);
self.from_translate_rx = Some(from_translate_rx);
}
}
if let Some(translation_loop) = translation_loop {
translation_loop.check_language().await?;
self.translate_loop_handle = Some(RUNTIME.spawn(translation_loop.run()));
}
Ok(())
}
}
#[derive(Debug)] #[derive(Debug)]
struct TranslationPadState { struct TranslationPadState {
discont_pending: bool, discont_pending: bool,
@ -1422,8 +1546,8 @@ impl TranslateSrcPad {
gst::debug!(CAT, imp: self, "Starting task"); gst::debug!(CAT, imp: self, "Starting task");
let elem = self.parent(); let elem = self.parent();
let transcript_event_rx = elem.imp().transcript_event_tx.subscribe(); let _enter = RUNTIME.enter();
let mut pad_task = TranslationPadTask::try_new(self, elem, transcript_event_rx) let mut pad_task = futures::executor::block_on(TranslationPadTask::try_new(self, elem))
.map_err(|err| gst::loggable_error!(CAT, format!("Failed to start pad task {err}")))?; .map_err(|err| gst::loggable_error!(CAT, format!("Failed to start pad task {err}")))?;
let imp = self.ref_counted(); let imp = self.ref_counted();

View file

@ -15,7 +15,6 @@ use aws_sdk_transcribestreaming::model;
use futures::channel::mpsc; use futures::channel::mpsc;
use futures::prelude::*; use futures::prelude::*;
use tokio::sync::broadcast;
use std::sync::Arc; use std::sync::Arc;
@ -23,7 +22,7 @@ use super::imp::{Settings, Transcriber};
use super::CAT; use super::CAT;
#[derive(Debug)] #[derive(Debug)]
pub struct TranscriptionSettings { pub struct TranscriberSettings {
lang_code: model::LanguageCode, lang_code: model::LanguageCode,
sample_rate: i32, sample_rate: i32,
vocabulary: Option<String>, vocabulary: Option<String>,
@ -33,9 +32,9 @@ pub struct TranscriptionSettings {
results_stability: model::PartialResultsStability, results_stability: model::PartialResultsStability,
} }
impl TranscriptionSettings { impl TranscriberSettings {
pub(super) fn from(settings: &Settings, sample_rate: i32) -> Self { pub(super) fn from(settings: &Settings, sample_rate: i32) -> Self {
TranscriptionSettings { TranscriberSettings {
lang_code: settings.language_code.as_str().into(), lang_code: settings.language_code.as_str().into(),
sample_rate, sample_rate,
vocabulary: settings.vocabulary.clone(), vocabulary: settings.vocabulary.clone(),
@ -83,43 +82,30 @@ impl From<Vec<TranscriptItem>> for TranscriptEvent {
} }
} }
pub struct TranscriberLoop { pub struct TranscriberStream {
imp: glib::subclass::ObjectImplRef<Transcriber>, imp: glib::subclass::ObjectImplRef<Transcriber>,
client: aws_transcribe::Client, output: aws_transcribe::output::StartStreamTranscriptionOutput,
settings: Option<TranscriptionSettings>,
lateness: gst::ClockTime, lateness: gst::ClockTime,
buffer_rx: Option<mpsc::Receiver<gst::Buffer>>,
transcript_items_tx: broadcast::Sender<TranscriptEvent>,
partial_index: usize, partial_index: usize,
} }
impl TranscriberLoop { impl TranscriberStream {
pub fn new( pub async fn try_new(
imp: &Transcriber, imp: &Transcriber,
settings: TranscriptionSettings, settings: TranscriberSettings,
lateness: gst::ClockTime, lateness: gst::ClockTime,
buffer_rx: mpsc::Receiver<gst::Buffer>, buffer_rx: mpsc::Receiver<gst::Buffer>,
transcript_items_tx: broadcast::Sender<TranscriptEvent>, ) -> Result<Self, gst::ErrorMessage> {
) -> Self { let client = {
let aws_config = imp.aws_config.lock().unwrap(); let aws_config = imp.aws_config.lock().unwrap();
let aws_config = aws_config let aws_config = aws_config
.as_ref() .as_ref()
.expect("aws_config must be initialized at this stage"); .expect("aws_config must be initialized at this stage");
aws_transcribe::Client::new(aws_config)
};
TranscriberLoop {
imp: imp.ref_counted(),
client: aws_transcribe::Client::new(aws_config),
settings: Some(settings),
lateness,
buffer_rx: Some(buffer_rx),
transcript_items_tx,
partial_index: 0,
}
}
pub async fn run(mut self) -> Result<(), gst::ErrorMessage> {
// Stream the incoming buffers chunked // Stream the incoming buffers chunked
let chunk_stream = self.buffer_rx.take().unwrap().flat_map(move |buffer: gst::Buffer| { let chunk_stream = buffer_rx.flat_map(move |buffer: gst::Buffer| {
async_stream::stream! { async_stream::stream! {
let data = buffer.map_readable().unwrap(); let data = buffer.map_readable().unwrap();
use aws_transcribe::{model::{AudioEvent, AudioStream}, types::Blob}; use aws_transcribe::{model::{AudioEvent, AudioStream}, types::Blob};
@ -129,9 +115,7 @@ impl TranscriberLoop {
} }
}); });
let settings = self.settings.take().unwrap(); let mut transcribe_builder = client
let mut transcribe_builder = self
.client
.start_stream_transcription() .start_stream_transcription()
.language_code(settings.lang_code) .language_code(settings.lang_code)
.media_sample_rate_hertz(settings.sample_rate) .media_sample_rate_hertz(settings.sample_rate)
@ -147,17 +131,28 @@ impl TranscriberLoop {
.vocabulary_filter_method(settings.vocabulary_filter_method); .vocabulary_filter_method(settings.vocabulary_filter_method);
} }
let mut output = transcribe_builder let output = transcribe_builder
.audio_stream(chunk_stream.into()) .audio_stream(chunk_stream.into())
.send() .send()
.await .await
.map_err(|err| { .map_err(|err| {
let err = format!("Transcribe ws init error: {err}"); let err = format!("Transcribe ws init error: {err}");
gst::error!(CAT, imp: self.imp, "{err}"); gst::error!(CAT, imp: imp, "{err}");
gst::error_msg!(gst::LibraryError::Init, ["{err}"]) gst::error_msg!(gst::LibraryError::Init, ["{err}"])
})?; })?;
while let Some(event) = output Ok(TranscriberStream {
imp: imp.ref_counted(),
output,
lateness,
partial_index: 0,
})
}
pub async fn next(&mut self) -> Result<TranscriptEvent, gst::ErrorMessage> {
loop {
let event = self
.output
.transcript_result_stream .transcript_result_stream
.recv() .recv()
.await .await
@ -165,8 +160,13 @@ impl TranscriberLoop {
let err = format!("Transcribe ws stream error: {err}"); let err = format!("Transcribe ws stream error: {err}");
gst::error!(CAT, imp: self.imp, "{err}"); gst::error!(CAT, imp: self.imp, "{err}");
gst::error_msg!(gst::LibraryError::Failed, ["{err}"]) gst::error_msg!(gst::LibraryError::Failed, ["{err}"])
})? })?;
{
let Some(event) = event else {
gst::debug!(CAT, imp: self.imp, "Transcriber loop sending EOS");
return Ok(TranscriptEvent::Eos);
};
if let model::TranscriptResultStream::TranscriptEvent(transcript_evt) = event { if let model::TranscriptResultStream::TranscriptEvent(transcript_evt) = event {
let mut ready_items = None; let mut ready_items = None;
@ -188,10 +188,7 @@ impl TranscriberLoop {
} }
if let Some(ready_items) = ready_items { if let Some(ready_items) = ready_items {
if self.transcript_items_tx.send(ready_items.into()).is_err() { return Ok(ready_items.into());
gst::debug!(CAT, imp: self.imp, "No transcript items receivers");
break;
}
} }
} else { } else {
gst::warning!( gst::warning!(
@ -201,13 +198,6 @@ impl TranscriberLoop {
) )
} }
} }
gst::debug!(CAT, imp: self.imp, "Transcriber loop sending EOS");
let _ = self.transcript_items_tx.send(TranscriptEvent::Eos);
gst::debug!(CAT, imp: self.imp, "Exiting transcriber loop");
Ok(())
} }
/// Builds a list from the provided stable items. /// Builds a list from the provided stable items.

View file

@ -14,7 +14,7 @@ use aws_sdk_translate as aws_translate;
use futures::channel::mpsc; use futures::channel::mpsc;
use futures::prelude::*; use futures::prelude::*;
use std::collections::VecDeque; use std::sync::Arc;
use super::imp::TranslateSrcPad; use super::imp::TranslateSrcPad;
use super::transcribe::TranscriptItem; use super::transcribe::TranscriptItem;
@ -40,77 +40,13 @@ impl From<&TranscriptItem> for TranslatedItem {
} }
} }
#[derive(Default)]
pub struct TranslateQueue {
items: VecDeque<TranscriptItem>,
}
impl TranslateQueue {
pub fn is_empty(&self) -> bool {
self.items.is_empty()
}
/// Pushes the provided item.
///
/// Returns `Some(..)` if items are ready for translation.
pub fn push(&mut self, transcript_item: &TranscriptItem) -> Option<Vec<TranscriptItem>> {
// Keep track of the item individually so we can schedule translation precisely.
self.items.push_back(transcript_item.clone());
if transcript_item.is_punctuation {
// This makes it a good chunk for translation.
// Concatenate as a single item for translation
return Some(self.items.drain(..).collect());
}
// Regular case: no separator detected, don't push transcript items
// to translation now. They will be pushed either if a punctuation
// is found or of a `dequeue()` is requested.
None
}
/// Dequeues items from the specified `deadline` up to `lookahead`.
///
/// Returns `Some(..)` if some items match the criteria.
pub fn dequeue(
&mut self,
latency: gst::ClockTime,
threshold: gst::ClockTime,
lookahead: gst::ClockTime,
) -> Option<Vec<TranscriptItem>> {
let first_pts = self.items.front()?.pts;
if first_pts + latency > threshold {
// First item is too early to be sent to translation now
// we can wait for more items to accumulate.
return None;
}
// Can't wait any longer to send the first item to translation
// Try to get up to lookahead worth of items to improve translation accuracy
let limit = first_pts + lookahead;
let mut items_acc = vec![self.items.pop_front().unwrap()];
while let Some(item) = self.items.front() {
if item.pts > limit {
break;
}
items_acc.push(self.items.pop_front().unwrap());
}
Some(items_acc)
}
}
pub struct TranslateLoop { pub struct TranslateLoop {
pad: glib::subclass::ObjectImplRef<TranslateSrcPad>, pad: glib::subclass::ObjectImplRef<TranslateSrcPad>,
client: aws_translate::Client, client: aws_translate::Client,
input_lang: String, input_lang: String,
output_lang: String, output_lang: String,
tokenization_method: TranslationTokenizationMethod, tokenization_method: TranslationTokenizationMethod,
transcript_rx: mpsc::Receiver<Vec<TranscriptItem>>, transcript_rx: mpsc::Receiver<Arc<Vec<TranscriptItem>>>,
translate_tx: mpsc::Sender<Vec<TranslatedItem>>, translate_tx: mpsc::Sender<Vec<TranslatedItem>>,
} }
@ -121,7 +57,7 @@ impl TranslateLoop {
input_lang: &str, input_lang: &str,
output_lang: &str, output_lang: &str,
tokenization_method: TranslationTokenizationMethod, tokenization_method: TranslationTokenizationMethod,
transcript_rx: mpsc::Receiver<Vec<TranscriptItem>>, transcript_rx: mpsc::Receiver<Arc<Vec<TranscriptItem>>>,
translate_tx: mpsc::Sender<Vec<TranslatedItem>>, translate_tx: mpsc::Sender<Vec<TranslatedItem>>,
) -> Self { ) -> Self {
let aws_config = imp.aws_config.lock().unwrap(); let aws_config = imp.aws_config.lock().unwrap();
@ -175,12 +111,12 @@ impl TranslateLoop {
let (ts_duration_list, content): (Vec<(gst::ClockTime, gst::ClockTime)>, String) = let (ts_duration_list, content): (Vec<(gst::ClockTime, gst::ClockTime)>, String) =
transcript_items transcript_items
.into_iter() .iter()
.map(|item| { .map(|item| {
( (
(item.pts, item.duration), (item.pts, item.duration),
match self.tokenization_method { match self.tokenization_method {
Tokenization::None => item.content, Tokenization::None => item.content.clone(),
Tokenization::SpanBased => { Tokenization::SpanBased => {
format!("{SPAN_START}{}{SPAN_END}", item.content) format!("{SPAN_START}{}{SPAN_END}", item.content)
} }