mirror of
https://gitlab.freedesktop.org/gstreamer/gst-plugins-rs.git
synced 2024-11-28 14:31:06 +00:00
net/aws/transcriber: use two queues for sending transcript items
* A queue dedicated to transcript items not intended for translation. * A queue dedicated to transcript items intended for translation. The items are enqueued after a separator is detected or translate-lookahead was reached. Part-of: <https://gitlab.freedesktop.org/gstreamer/gst-plugins-rs/-/merge_requests/1137>
This commit is contained in:
parent
5a5ca76d9d
commit
2b32d00589
3 changed files with 373 additions and 323 deletions
|
@ -29,12 +29,12 @@ use futures::prelude::*;
|
|||
use tokio::{runtime, sync::broadcast, task};
|
||||
|
||||
use std::collections::{BTreeSet, VecDeque};
|
||||
use std::sync::Mutex;
|
||||
use std::sync::{Arc, Mutex};
|
||||
|
||||
use once_cell::sync::Lazy;
|
||||
|
||||
use super::transcribe::{TranscriberLoop, TranscriptEvent, TranscriptItem, TranscriptionSettings};
|
||||
use super::translate::{TranslateLoop, TranslateQueue, TranslatedItem};
|
||||
use super::transcribe::{TranscriberSettings, TranscriberStream, TranscriptEvent, TranscriptItem};
|
||||
use super::translate::{TranslateLoop, TranslatedItem};
|
||||
use super::{
|
||||
AwsTranscriberResultStability, AwsTranscriberVocabularyFilterMethod,
|
||||
TranslationTokenizationMethod, CAT,
|
||||
|
@ -148,6 +148,7 @@ struct State {
|
|||
srcpads: BTreeSet<super::TranslateSrcPad>,
|
||||
pad_serial: u32,
|
||||
seqnum: gst::Seqnum,
|
||||
start_time: Option<gst::ClockTime>,
|
||||
}
|
||||
|
||||
impl Default for State {
|
||||
|
@ -158,6 +159,7 @@ impl Default for State {
|
|||
srcpads: Default::default(),
|
||||
pad_serial: 0,
|
||||
seqnum: gst::Seqnum::next(),
|
||||
start_time: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -168,7 +170,9 @@ pub struct Transcriber {
|
|||
settings: Mutex<Settings>,
|
||||
state: Mutex<State>,
|
||||
pub(super) aws_config: Mutex<Option<aws_config::SdkConfig>>,
|
||||
// sender to broadcast transcript items to the translate src pads.
|
||||
// sender to broadcast transcript items to the src pads for translation.
|
||||
transcript_event_for_translate_tx: broadcast::Sender<TranscriptEvent>,
|
||||
// sender to broadcast transcript items to the src pads, not intended for translation.
|
||||
transcript_event_tx: broadcast::Sender<TranscriptEvent>,
|
||||
}
|
||||
|
||||
|
@ -276,7 +280,10 @@ impl Transcriber {
|
|||
) -> Result<gst::FlowSuccess, gst::FlowError> {
|
||||
gst::log!(CAT, obj: pad, "Handling {buffer:?}");
|
||||
|
||||
self.ensure_connection();
|
||||
self.ensure_connection().map_err(|err| {
|
||||
gst::element_imp_error!(self, gst::StreamError::Failed, ["Streaming failed: {err}"]);
|
||||
gst::FlowError::Error
|
||||
})?;
|
||||
|
||||
let Some(mut buffer_tx) = self.state.lock().unwrap().buffer_tx.take() else {
|
||||
gst::log!(CAT, obj: pad, "Flushing");
|
||||
|
@ -292,12 +299,82 @@ impl Transcriber {
|
|||
|
||||
Ok(gst::FlowSuccess::Ok)
|
||||
}
|
||||
}
|
||||
|
||||
fn ensure_connection(&self) {
|
||||
#[derive(Default)]
|
||||
struct TranslateQueue {
|
||||
items: VecDeque<TranscriptItem>,
|
||||
}
|
||||
|
||||
impl TranslateQueue {
|
||||
fn is_empty(&self) -> bool {
|
||||
self.items.is_empty()
|
||||
}
|
||||
|
||||
/// Pushes the provided item.
|
||||
///
|
||||
/// Returns `Some(..)` if items are ready for translation.
|
||||
fn push(&mut self, transcript_item: &TranscriptItem) -> Option<Vec<TranscriptItem>> {
|
||||
// Keep track of the item individually so we can schedule translation precisely.
|
||||
self.items.push_back(transcript_item.clone());
|
||||
|
||||
if transcript_item.is_punctuation {
|
||||
// This makes it a good chunk for translation.
|
||||
// Concatenate as a single item for translation
|
||||
|
||||
return Some(self.items.drain(..).collect());
|
||||
}
|
||||
|
||||
// Regular case: no separator detected, don't push transcript items
|
||||
// to translation now. They will be pushed either if a punctuation
|
||||
// is found or of a `dequeue()` is requested.
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
/// Dequeues items from the specified `deadline` up to `lookahead`.
|
||||
///
|
||||
/// Returns `Some(..)` if some items match the criteria.
|
||||
fn dequeue(
|
||||
&mut self,
|
||||
latency: gst::ClockTime,
|
||||
threshold: gst::ClockTime,
|
||||
lookahead: gst::ClockTime,
|
||||
) -> Option<Vec<TranscriptItem>> {
|
||||
let first_pts = self.items.front()?.pts;
|
||||
if first_pts + latency > threshold {
|
||||
// First item is too early to be sent to translation now
|
||||
// we can wait for more items to accumulate.
|
||||
return None;
|
||||
}
|
||||
|
||||
// Can't wait any longer to send the first item to translation
|
||||
// Try to get up to lookahead worth of items to improve translation accuracy
|
||||
let limit = first_pts + lookahead;
|
||||
|
||||
let mut items_acc = vec![self.items.pop_front().unwrap()];
|
||||
while let Some(item) = self.items.front() {
|
||||
if item.pts > limit {
|
||||
break;
|
||||
}
|
||||
|
||||
items_acc.push(self.items.pop_front().unwrap());
|
||||
}
|
||||
|
||||
Some(items_acc)
|
||||
}
|
||||
|
||||
fn drain(&mut self) -> impl Iterator<Item = TranscriptItem> + '_ {
|
||||
self.items.drain(..)
|
||||
}
|
||||
}
|
||||
|
||||
impl Transcriber {
|
||||
fn ensure_connection(&self) -> Result<(), gst::ErrorMessage> {
|
||||
let mut state = self.state.lock().unwrap();
|
||||
|
||||
if state.buffer_tx.is_some() {
|
||||
return;
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let settings = self.settings.lock().unwrap();
|
||||
|
@ -306,21 +383,116 @@ impl Transcriber {
|
|||
let s = in_caps.structure(0).unwrap();
|
||||
let sample_rate = s.get::<i32>("rate").unwrap();
|
||||
|
||||
let transcription_settings = TranscriptionSettings::from(&settings, sample_rate);
|
||||
let transcription_settings = TranscriberSettings::from(&settings, sample_rate);
|
||||
|
||||
let (buffer_tx, buffer_rx) = mpsc::channel(1);
|
||||
|
||||
let transcriber_loop = TranscriberLoop::new(
|
||||
let _enter = RUNTIME.enter();
|
||||
let mut transcriber_stream = futures::executor::block_on(TranscriberStream::try_new(
|
||||
self,
|
||||
transcription_settings,
|
||||
settings.lateness,
|
||||
buffer_rx,
|
||||
self.transcript_event_tx.clone(),
|
||||
))?;
|
||||
|
||||
// Latency budget for an item to be pushed to stream on time
|
||||
// Margin:
|
||||
// - 2 * GRANULARITY: to make sure we don't push items up to GRANULARITY late.
|
||||
// - 1 * GRANULARITY: extra margin to account for additional overheads.
|
||||
let latency = settings.transcribe_latency.saturating_sub(3 * GRANULARITY);
|
||||
let translate_lookahead = settings.translate_lookahead;
|
||||
let mut translate_queue = TranslateQueue::default();
|
||||
let imp = self.ref_counted();
|
||||
let transcriber_loop_handle = RUNTIME.spawn(async move {
|
||||
loop {
|
||||
// This is to make sure we send items on a timely basis or at least Gap events.
|
||||
let timeout = tokio::time::sleep(GRANULARITY.into()).fuse();
|
||||
futures::pin_mut!(timeout);
|
||||
|
||||
let transcriber_next = transcriber_stream.next().fuse();
|
||||
futures::pin_mut!(transcriber_next);
|
||||
|
||||
// `transcriber_next` takes precedence over `timeout`
|
||||
// because we don't want to loose any incoming items.
|
||||
let res = futures::select_biased! {
|
||||
event = transcriber_next => Some(event?),
|
||||
_ = timeout => None,
|
||||
};
|
||||
|
||||
use TranscriptEvent::*;
|
||||
match res {
|
||||
None => (),
|
||||
Some(Items(items)) => {
|
||||
if imp.transcript_event_tx.receiver_count() > 0 {
|
||||
let _ = imp.transcript_event_tx.send(Items(items.clone()));
|
||||
}
|
||||
|
||||
if imp.transcript_event_for_translate_tx.receiver_count() > 0 {
|
||||
for item in items.iter() {
|
||||
if let Some(items_to_translate) = translate_queue.push(item) {
|
||||
let _ = imp
|
||||
.transcript_event_for_translate_tx
|
||||
.send(Items(items_to_translate.into()));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Some(Eos) => {
|
||||
gst::debug!(CAT, imp: imp, "Transcriber loop sending EOS");
|
||||
|
||||
if imp.transcript_event_tx.receiver_count() > 0 {
|
||||
let _ = imp.transcript_event_tx.send(Eos);
|
||||
}
|
||||
|
||||
if imp.transcript_event_for_translate_tx.receiver_count() > 0 {
|
||||
let items_to_translate: Vec<TranscriptItem> =
|
||||
translate_queue.drain().collect();
|
||||
let _ = imp
|
||||
.transcript_event_for_translate_tx
|
||||
.send(Items(items_to_translate.into()));
|
||||
|
||||
let _ = imp.transcript_event_for_translate_tx.send(Eos);
|
||||
}
|
||||
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if imp.transcript_event_for_translate_tx.receiver_count() > 0 {
|
||||
// Check if we need to push items for translation
|
||||
|
||||
let Some((start_time, now)) = imp.get_start_time_and_now() else {
|
||||
continue;
|
||||
};
|
||||
|
||||
if !translate_queue.is_empty() {
|
||||
let threshold = now - start_time;
|
||||
|
||||
if let Some(items_to_translate) =
|
||||
translate_queue.dequeue(latency, threshold, translate_lookahead)
|
||||
{
|
||||
gst::debug!(
|
||||
CAT,
|
||||
imp: imp,
|
||||
"Forcing to translation (threshold {threshold}): {items_to_translate:?}"
|
||||
);
|
||||
let transcriber_loop_handle = RUNTIME.spawn(transcriber_loop.run());
|
||||
let _ = imp
|
||||
.transcript_event_for_translate_tx
|
||||
.send(Items(items_to_translate.into()));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
gst::debug!(CAT, imp: imp, "Exiting transcriber loop");
|
||||
|
||||
Ok(())
|
||||
});
|
||||
|
||||
state.transcriber_loop_handle = Some(transcriber_loop_handle);
|
||||
state.buffer_tx = Some(buffer_tx);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn prepare(&self) -> Result<(), gst::ErrorMessage> {
|
||||
|
@ -382,6 +554,18 @@ impl Transcriber {
|
|||
}
|
||||
gst::info!(CAT, imp: self, "Unprepared");
|
||||
}
|
||||
|
||||
fn get_start_time_and_now(&self) -> Option<(gst::ClockTime, gst::ClockTime)> {
|
||||
let now = self.obj().current_running_time()?;
|
||||
|
||||
let mut state = self.state.lock().unwrap();
|
||||
|
||||
if state.start_time.is_none() {
|
||||
state.start_time = Some(now);
|
||||
}
|
||||
|
||||
Some((state.start_time.unwrap(), now))
|
||||
}
|
||||
}
|
||||
|
||||
#[glib::object_subclass]
|
||||
|
@ -438,6 +622,7 @@ impl ObjectSubclass for Transcriber {
|
|||
// Setting the channel capacity so that a TranslateSrcPad that would lag
|
||||
// behind for some reasons get a chance to catch-up without loosing items.
|
||||
// Receiver will be created by subscribing to sender later.
|
||||
let (transcript_event_for_translate_tx, _) = broadcast::channel(128);
|
||||
let (transcript_event_tx, _) = broadcast::channel(128);
|
||||
|
||||
Self {
|
||||
|
@ -446,6 +631,7 @@ impl ObjectSubclass for Transcriber {
|
|||
settings: Default::default(),
|
||||
state: Default::default(),
|
||||
aws_config: Default::default(),
|
||||
transcript_event_for_translate_tx,
|
||||
transcript_event_tx,
|
||||
}
|
||||
}
|
||||
|
@ -876,51 +1062,93 @@ struct TranslationPadTask {
|
|||
elem: super::Transcriber,
|
||||
transcript_event_rx: broadcast::Receiver<TranscriptEvent>,
|
||||
needs_translate: bool,
|
||||
translate_queue: TranslateQueue,
|
||||
translate_loop_handle: Option<task::JoinHandle<Result<(), gst::ErrorMessage>>>,
|
||||
to_translate_tx: Option<mpsc::Sender<Vec<TranscriptItem>>>,
|
||||
to_translate_tx: Option<mpsc::Sender<Arc<Vec<TranscriptItem>>>>,
|
||||
from_translate_rx: Option<mpsc::Receiver<Vec<TranslatedItem>>>,
|
||||
translate_latency: gst::ClockTime,
|
||||
translate_lookahead: gst::ClockTime,
|
||||
send_events: bool,
|
||||
output_items: VecDeque<OutputItem>,
|
||||
our_latency: gst::ClockTime,
|
||||
seqnum: gst::Seqnum,
|
||||
send_eos: bool,
|
||||
pending_translations: usize,
|
||||
start_time: Option<gst::ClockTime>,
|
||||
}
|
||||
|
||||
impl TranslationPadTask {
|
||||
fn try_new(
|
||||
async fn try_new(
|
||||
pad: &TranslateSrcPad,
|
||||
elem: super::Transcriber,
|
||||
transcript_event_rx: broadcast::Receiver<TranscriptEvent>,
|
||||
) -> Result<TranslationPadTask, gst::ErrorMessage> {
|
||||
let mut this = TranslationPadTask {
|
||||
let mut translation_loop = None;
|
||||
let mut translate_loop_handle = None;
|
||||
let mut to_translate_tx = None;
|
||||
let mut from_translate_rx = None;
|
||||
|
||||
let (our_latency, transcript_event_rx, needs_translate);
|
||||
|
||||
{
|
||||
let elem_imp = elem.imp();
|
||||
let elem_settings = elem_imp.settings.lock().unwrap();
|
||||
|
||||
let pad_settings = pad.settings.lock().unwrap();
|
||||
|
||||
our_latency = TranslateSrcPad::our_latency(&elem_settings, &pad_settings);
|
||||
if our_latency + elem_settings.lateness <= 2 * GRANULARITY {
|
||||
let err = format!(
|
||||
"total latency + lateness must be greater than {}",
|
||||
2 * GRANULARITY
|
||||
);
|
||||
gst::error!(CAT, imp: pad, "{err}");
|
||||
return Err(gst::error_msg!(gst::LibraryError::Settings, ["{err}"]));
|
||||
}
|
||||
|
||||
needs_translate = TranslateSrcPad::needs_translation(
|
||||
&elem_settings.language_code,
|
||||
pad_settings.language_code.as_deref(),
|
||||
);
|
||||
|
||||
if needs_translate {
|
||||
let (to_loop_tx, to_loop_rx) = mpsc::channel(64);
|
||||
let (from_loop_tx, from_loop_rx) = mpsc::channel(64);
|
||||
|
||||
translation_loop = Some(TranslateLoop::new(
|
||||
elem_imp,
|
||||
pad,
|
||||
&elem_settings.language_code,
|
||||
pad_settings.language_code.as_deref().unwrap(),
|
||||
pad_settings.tokenization_method,
|
||||
to_loop_rx,
|
||||
from_loop_tx,
|
||||
));
|
||||
|
||||
to_translate_tx = Some(to_loop_tx);
|
||||
from_translate_rx = Some(from_loop_rx);
|
||||
|
||||
transcript_event_rx = elem_imp.transcript_event_for_translate_tx.subscribe();
|
||||
} else {
|
||||
transcript_event_rx = elem_imp.transcript_event_tx.subscribe();
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(translation_loop) = translation_loop {
|
||||
translation_loop.check_language().await?;
|
||||
translate_loop_handle = Some(RUNTIME.spawn(translation_loop.run()));
|
||||
}
|
||||
|
||||
Ok(TranslationPadTask {
|
||||
pad: pad.ref_counted(),
|
||||
elem,
|
||||
transcript_event_rx,
|
||||
needs_translate: false,
|
||||
translate_queue: TranslateQueue::default(),
|
||||
translate_loop_handle: None,
|
||||
to_translate_tx: None,
|
||||
from_translate_rx: None,
|
||||
translate_latency: DEFAULT_TRANSLATE_LATENCY,
|
||||
translate_lookahead: DEFAULT_TRANSLATE_LOOKAHEAD,
|
||||
needs_translate,
|
||||
translate_loop_handle,
|
||||
to_translate_tx,
|
||||
from_translate_rx,
|
||||
send_events: true,
|
||||
output_items: VecDeque::new(),
|
||||
our_latency: DEFAULT_TRANSCRIBE_LATENCY,
|
||||
our_latency,
|
||||
seqnum: gst::Seqnum::next(),
|
||||
send_eos: false,
|
||||
pending_translations: 0,
|
||||
start_time: None,
|
||||
};
|
||||
|
||||
let _enter_guard = RUNTIME.enter();
|
||||
futures::executor::block_on(this.init_translate())?;
|
||||
|
||||
Ok(this)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -958,11 +1186,9 @@ impl TranslationPadTask {
|
|||
let transcript_event_rx = self.transcript_event_rx.recv().fuse();
|
||||
futures::pin_mut!(transcript_event_rx);
|
||||
|
||||
// `timeout` takes precedence over `transcript_events` reception
|
||||
// because we may need to `dequeue` `items` or push a `Gap` event
|
||||
// before current latency budget is exhausted.
|
||||
// `transcript_event_rx` takes precedence over `timeout`
|
||||
// because we don't want to loose any incoming items.
|
||||
futures::select_biased! {
|
||||
_ = timeout => (),
|
||||
items_res = transcript_event_rx => {
|
||||
use TranscriptEvent::*;
|
||||
use broadcast::error::RecvError;
|
||||
|
@ -983,6 +1209,7 @@ impl TranslationPadTask {
|
|||
}
|
||||
}
|
||||
}
|
||||
_ = timeout => (),
|
||||
}
|
||||
|
||||
Ok(())
|
||||
|
@ -999,121 +1226,100 @@ impl TranslationPadTask {
|
|||
return Err(gst::error_msg!(gst::StreamError::Failed, ["{ERR}"]));
|
||||
}
|
||||
|
||||
let transcript_items = {
|
||||
let items_to_translate = {
|
||||
// This is to make sure we send items on a timely basis or at least Gap events.
|
||||
let timeout = tokio::time::sleep(GRANULARITY.into()).fuse();
|
||||
futures::pin_mut!(timeout);
|
||||
|
||||
let transcript_event_rx = self.transcript_event_rx.recv().fuse();
|
||||
futures::pin_mut!(transcript_event_rx);
|
||||
|
||||
// `transcript_event_rx` takes precedence over `timeout`
|
||||
// because we don't want to loose any incoming items.
|
||||
futures::select_biased! {
|
||||
items_res = transcript_event_rx => {
|
||||
use TranscriptEvent::*;
|
||||
use broadcast::error::RecvError;
|
||||
match items_res {
|
||||
Ok(Items(items_to_translate)) => Some(items_to_translate),
|
||||
Ok(Eos) => {
|
||||
gst::debug!(CAT, imp: self.pad, "Got eos");
|
||||
self.send_eos = true;
|
||||
None
|
||||
}
|
||||
Err(RecvError::Lagged(nb_msg)) => {
|
||||
gst::warning!(CAT, imp: self.pad, "Missed {nb_msg} transcript sets");
|
||||
None
|
||||
}
|
||||
Err(RecvError::Closed) => {
|
||||
gst::debug!(CAT, imp: self.pad, "Transcript chan terminated: setting eos");
|
||||
self.send_eos = true;
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
_ = timeout => None,
|
||||
}
|
||||
};
|
||||
|
||||
if let Some(items_to_translate) = items_to_translate {
|
||||
if !items_to_translate.is_empty() {
|
||||
let res = self
|
||||
.to_translate_tx
|
||||
.as_mut()
|
||||
.expect("to_translation chan must be available in translation mode")
|
||||
.send(items_to_translate)
|
||||
.await;
|
||||
|
||||
if res.is_err() {
|
||||
const ERR: &str = "to_translation chan terminated";
|
||||
gst::debug!(CAT, imp: self.pad, "{ERR}");
|
||||
return Err(gst::error_msg!(gst::StreamError::Failed, ["{ERR}"]));
|
||||
}
|
||||
|
||||
self.pending_translations += 1;
|
||||
}
|
||||
}
|
||||
|
||||
// Check pending translated items
|
||||
let from_translate_rx = self
|
||||
.from_translate_rx
|
||||
.as_mut()
|
||||
.expect("from_translation chan must be available in translation mode");
|
||||
|
||||
let transcript_event_rx = self.transcript_event_rx.recv().fuse();
|
||||
futures::pin_mut!(transcript_event_rx);
|
||||
|
||||
// `timeout` takes precedence over `transcript_events` reception
|
||||
// because we may need to `dequeue` `items` or push a `Gap` event
|
||||
// before current latency budget is exhausted.
|
||||
futures::select_biased! {
|
||||
_ = timeout => return Ok(()),
|
||||
translated_items = from_translate_rx.next() => {
|
||||
while let Ok(translated_items) = from_translate_rx.try_next() {
|
||||
let Some(translated_items) = translated_items else {
|
||||
const ERR: &str = "translation chan terminated";
|
||||
gst::debug!(CAT, imp: self.pad, "{ERR}");
|
||||
return Err(gst::error_msg!(gst::StreamError::Failed, ["{ERR}"]));
|
||||
};
|
||||
|
||||
self.output_items.extend(translated_items.into_iter().map(Into::into));
|
||||
self.output_items
|
||||
.extend(translated_items.into_iter().map(Into::into));
|
||||
self.pending_translations = self.pending_translations.saturating_sub(1);
|
||||
|
||||
return Ok(());
|
||||
}
|
||||
items_res = transcript_event_rx => {
|
||||
use TranscriptEvent::*;
|
||||
use broadcast::error::RecvError;
|
||||
match items_res {
|
||||
Ok(Items(transcript_items)) => transcript_items,
|
||||
Ok(Eos) => {
|
||||
gst::debug!(CAT, imp: self.pad, "Got eos");
|
||||
self.send_eos = true;
|
||||
return Ok(());
|
||||
}
|
||||
Err(RecvError::Lagged(nb_msg)) => {
|
||||
gst::warning!(CAT, imp: self.pad, "Missed {nb_msg} transcript sets");
|
||||
return Ok(());
|
||||
}
|
||||
Err(RecvError::Closed) => {
|
||||
gst::debug!(CAT, imp: self.pad, "Transcript chan terminated: setting eos");
|
||||
self.send_eos = true;
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
for items in transcript_items.iter() {
|
||||
if let Some(items_to_translate) = self.translate_queue.push(items) {
|
||||
self.send_for_translation(items_to_translate).await?;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn dequeue_for_translation(
|
||||
&mut self,
|
||||
start_time: gst::ClockTime,
|
||||
now: gst::ClockTime,
|
||||
) -> Result<(), gst::ErrorMessage> {
|
||||
if !self.translate_queue.is_empty() {
|
||||
// Latency budget for an item to be pushed to stream on time
|
||||
// Margin:
|
||||
// - 2 * GRANULARITY: to make sure we don't push items up to GRANULARITY late.
|
||||
// - 1 * GRANULARITY: extra margin to account for additional overheads.
|
||||
let latency = self.our_latency.saturating_sub(3 * GRANULARITY);
|
||||
|
||||
// Estimated time of arrival for an item sent to translation now.
|
||||
// (in transcript item ts base)
|
||||
let translation_eta = now + self.translate_latency - start_time;
|
||||
|
||||
if let Some(items_to_translate) =
|
||||
self.translate_queue
|
||||
.dequeue(latency, translation_eta, self.translate_lookahead)
|
||||
{
|
||||
gst::debug!(CAT, imp: self.pad, "Forcing to translation: {items_to_translate:?}");
|
||||
self.send_for_translation(items_to_translate).await?;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn dequeue(&mut self) -> bool {
|
||||
let (now, start_time, mut last_position, mut discont_pending);
|
||||
{
|
||||
let mut pad_state = self.pad.state.lock().unwrap();
|
||||
|
||||
let Some(cur_rt) = self.elem.current_running_time() else {
|
||||
let Some((start_time, now)) = self.elem.imp().get_start_time_and_now() else {
|
||||
// Wait for the clock to be available
|
||||
return true;
|
||||
};
|
||||
now = cur_rt;
|
||||
|
||||
if self.start_time.is_none() {
|
||||
self.start_time = Some(now);
|
||||
pad_state.out_segment.set_position(now);
|
||||
}
|
||||
let (mut last_position, mut discont_pending) = {
|
||||
let mut state = self.pad.state.lock().unwrap();
|
||||
|
||||
start_time = self.start_time.unwrap();
|
||||
last_position = pad_state.out_segment.position().unwrap();
|
||||
discont_pending = pad_state.discont_pending;
|
||||
}
|
||||
let last_position = if let Some(pos) = state.out_segment.position() {
|
||||
pos
|
||||
} else {
|
||||
state.out_segment.set_position(start_time);
|
||||
start_time
|
||||
};
|
||||
|
||||
if self.needs_translate && self.dequeue_for_translation(start_time, now).await.is_err() {
|
||||
return false;
|
||||
}
|
||||
(last_position, state.discont_pending)
|
||||
};
|
||||
|
||||
/* First, check our pending buffers */
|
||||
while let Some(item) = self.output_items.front() {
|
||||
|
@ -1206,11 +1412,7 @@ impl TranslationPadTask {
|
|||
}
|
||||
}
|
||||
|
||||
if self.send_eos
|
||||
&& self.pending_translations == 0
|
||||
&& self.output_items.is_empty()
|
||||
&& self.translate_queue.is_empty()
|
||||
{
|
||||
if self.send_eos && self.pending_translations == 0 && self.output_items.is_empty() {
|
||||
/* We're EOS, we can pause and exit early */
|
||||
let _ = self.pad.obj().pause_task();
|
||||
|
||||
|
@ -1261,28 +1463,6 @@ impl TranslationPadTask {
|
|||
true
|
||||
}
|
||||
|
||||
async fn send_for_translation(
|
||||
&mut self,
|
||||
transcript_items: Vec<TranscriptItem>,
|
||||
) -> Result<(), gst::ErrorMessage> {
|
||||
let res = self
|
||||
.to_translate_tx
|
||||
.as_mut()
|
||||
.expect("to_translation chan must be available in translation mode")
|
||||
.send(transcript_items)
|
||||
.await;
|
||||
|
||||
if res.is_err() {
|
||||
const ERR: &str = "to_translation chan terminated";
|
||||
gst::debug!(CAT, imp: self.pad, "{ERR}");
|
||||
return Err(gst::error_msg!(gst::StreamError::Failed, ["{ERR}"]));
|
||||
}
|
||||
|
||||
self.pending_translations += 1;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn ensure_init_events(&mut self) -> Result<(), gst::ErrorMessage> {
|
||||
if !self.send_events {
|
||||
return Ok(());
|
||||
|
@ -1332,62 +1512,6 @@ impl TranslationPadTask {
|
|||
}
|
||||
}
|
||||
|
||||
impl TranslationPadTask {
|
||||
async fn init_translate(&mut self) -> Result<(), gst::ErrorMessage> {
|
||||
let mut translation_loop = None;
|
||||
|
||||
{
|
||||
let elem_imp = self.elem.imp();
|
||||
let elem_settings = elem_imp.settings.lock().unwrap();
|
||||
|
||||
let pad_settings = self.pad.settings.lock().unwrap();
|
||||
|
||||
self.our_latency = TranslateSrcPad::our_latency(&elem_settings, &pad_settings);
|
||||
if self.our_latency + elem_settings.lateness <= 2 * GRANULARITY {
|
||||
let err = format!(
|
||||
"total latency + lateness must be greater than {}",
|
||||
2 * GRANULARITY
|
||||
);
|
||||
gst::error!(CAT, imp: self.pad, "{err}");
|
||||
return Err(gst::error_msg!(gst::LibraryError::Settings, ["{err}"]));
|
||||
}
|
||||
|
||||
self.translate_latency = elem_settings.translate_latency;
|
||||
self.translate_lookahead = elem_settings.translate_lookahead;
|
||||
|
||||
self.needs_translate = TranslateSrcPad::needs_translation(
|
||||
&elem_settings.language_code,
|
||||
pad_settings.language_code.as_deref(),
|
||||
);
|
||||
|
||||
if self.needs_translate {
|
||||
let (to_translate_tx, to_translate_rx) = mpsc::channel(64);
|
||||
let (from_translate_tx, from_translate_rx) = mpsc::channel(64);
|
||||
|
||||
translation_loop = Some(TranslateLoop::new(
|
||||
elem_imp,
|
||||
&self.pad,
|
||||
&elem_settings.language_code,
|
||||
pad_settings.language_code.as_deref().unwrap(),
|
||||
pad_settings.tokenization_method,
|
||||
to_translate_rx,
|
||||
from_translate_tx,
|
||||
));
|
||||
|
||||
self.to_translate_tx = Some(to_translate_tx);
|
||||
self.from_translate_rx = Some(from_translate_rx);
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(translation_loop) = translation_loop {
|
||||
translation_loop.check_language().await?;
|
||||
self.translate_loop_handle = Some(RUNTIME.spawn(translation_loop.run()));
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct TranslationPadState {
|
||||
discont_pending: bool,
|
||||
|
@ -1422,8 +1546,8 @@ impl TranslateSrcPad {
|
|||
gst::debug!(CAT, imp: self, "Starting task");
|
||||
|
||||
let elem = self.parent();
|
||||
let transcript_event_rx = elem.imp().transcript_event_tx.subscribe();
|
||||
let mut pad_task = TranslationPadTask::try_new(self, elem, transcript_event_rx)
|
||||
let _enter = RUNTIME.enter();
|
||||
let mut pad_task = futures::executor::block_on(TranslationPadTask::try_new(self, elem))
|
||||
.map_err(|err| gst::loggable_error!(CAT, format!("Failed to start pad task {err}")))?;
|
||||
|
||||
let imp = self.ref_counted();
|
||||
|
|
|
@ -15,7 +15,6 @@ use aws_sdk_transcribestreaming::model;
|
|||
|
||||
use futures::channel::mpsc;
|
||||
use futures::prelude::*;
|
||||
use tokio::sync::broadcast;
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
|
@ -23,7 +22,7 @@ use super::imp::{Settings, Transcriber};
|
|||
use super::CAT;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct TranscriptionSettings {
|
||||
pub struct TranscriberSettings {
|
||||
lang_code: model::LanguageCode,
|
||||
sample_rate: i32,
|
||||
vocabulary: Option<String>,
|
||||
|
@ -33,9 +32,9 @@ pub struct TranscriptionSettings {
|
|||
results_stability: model::PartialResultsStability,
|
||||
}
|
||||
|
||||
impl TranscriptionSettings {
|
||||
impl TranscriberSettings {
|
||||
pub(super) fn from(settings: &Settings, sample_rate: i32) -> Self {
|
||||
TranscriptionSettings {
|
||||
TranscriberSettings {
|
||||
lang_code: settings.language_code.as_str().into(),
|
||||
sample_rate,
|
||||
vocabulary: settings.vocabulary.clone(),
|
||||
|
@ -83,43 +82,30 @@ impl From<Vec<TranscriptItem>> for TranscriptEvent {
|
|||
}
|
||||
}
|
||||
|
||||
pub struct TranscriberLoop {
|
||||
pub struct TranscriberStream {
|
||||
imp: glib::subclass::ObjectImplRef<Transcriber>,
|
||||
client: aws_transcribe::Client,
|
||||
settings: Option<TranscriptionSettings>,
|
||||
output: aws_transcribe::output::StartStreamTranscriptionOutput,
|
||||
lateness: gst::ClockTime,
|
||||
buffer_rx: Option<mpsc::Receiver<gst::Buffer>>,
|
||||
transcript_items_tx: broadcast::Sender<TranscriptEvent>,
|
||||
partial_index: usize,
|
||||
}
|
||||
|
||||
impl TranscriberLoop {
|
||||
pub fn new(
|
||||
impl TranscriberStream {
|
||||
pub async fn try_new(
|
||||
imp: &Transcriber,
|
||||
settings: TranscriptionSettings,
|
||||
settings: TranscriberSettings,
|
||||
lateness: gst::ClockTime,
|
||||
buffer_rx: mpsc::Receiver<gst::Buffer>,
|
||||
transcript_items_tx: broadcast::Sender<TranscriptEvent>,
|
||||
) -> Self {
|
||||
) -> Result<Self, gst::ErrorMessage> {
|
||||
let client = {
|
||||
let aws_config = imp.aws_config.lock().unwrap();
|
||||
let aws_config = aws_config
|
||||
.as_ref()
|
||||
.expect("aws_config must be initialized at this stage");
|
||||
aws_transcribe::Client::new(aws_config)
|
||||
};
|
||||
|
||||
TranscriberLoop {
|
||||
imp: imp.ref_counted(),
|
||||
client: aws_transcribe::Client::new(aws_config),
|
||||
settings: Some(settings),
|
||||
lateness,
|
||||
buffer_rx: Some(buffer_rx),
|
||||
transcript_items_tx,
|
||||
partial_index: 0,
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn run(mut self) -> Result<(), gst::ErrorMessage> {
|
||||
// Stream the incoming buffers chunked
|
||||
let chunk_stream = self.buffer_rx.take().unwrap().flat_map(move |buffer: gst::Buffer| {
|
||||
let chunk_stream = buffer_rx.flat_map(move |buffer: gst::Buffer| {
|
||||
async_stream::stream! {
|
||||
let data = buffer.map_readable().unwrap();
|
||||
use aws_transcribe::{model::{AudioEvent, AudioStream}, types::Blob};
|
||||
|
@ -129,9 +115,7 @@ impl TranscriberLoop {
|
|||
}
|
||||
});
|
||||
|
||||
let settings = self.settings.take().unwrap();
|
||||
let mut transcribe_builder = self
|
||||
.client
|
||||
let mut transcribe_builder = client
|
||||
.start_stream_transcription()
|
||||
.language_code(settings.lang_code)
|
||||
.media_sample_rate_hertz(settings.sample_rate)
|
||||
|
@ -147,17 +131,28 @@ impl TranscriberLoop {
|
|||
.vocabulary_filter_method(settings.vocabulary_filter_method);
|
||||
}
|
||||
|
||||
let mut output = transcribe_builder
|
||||
let output = transcribe_builder
|
||||
.audio_stream(chunk_stream.into())
|
||||
.send()
|
||||
.await
|
||||
.map_err(|err| {
|
||||
let err = format!("Transcribe ws init error: {err}");
|
||||
gst::error!(CAT, imp: self.imp, "{err}");
|
||||
gst::error!(CAT, imp: imp, "{err}");
|
||||
gst::error_msg!(gst::LibraryError::Init, ["{err}"])
|
||||
})?;
|
||||
|
||||
while let Some(event) = output
|
||||
Ok(TranscriberStream {
|
||||
imp: imp.ref_counted(),
|
||||
output,
|
||||
lateness,
|
||||
partial_index: 0,
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn next(&mut self) -> Result<TranscriptEvent, gst::ErrorMessage> {
|
||||
loop {
|
||||
let event = self
|
||||
.output
|
||||
.transcript_result_stream
|
||||
.recv()
|
||||
.await
|
||||
|
@ -165,8 +160,13 @@ impl TranscriberLoop {
|
|||
let err = format!("Transcribe ws stream error: {err}");
|
||||
gst::error!(CAT, imp: self.imp, "{err}");
|
||||
gst::error_msg!(gst::LibraryError::Failed, ["{err}"])
|
||||
})?
|
||||
{
|
||||
})?;
|
||||
|
||||
let Some(event) = event else {
|
||||
gst::debug!(CAT, imp: self.imp, "Transcriber loop sending EOS");
|
||||
return Ok(TranscriptEvent::Eos);
|
||||
};
|
||||
|
||||
if let model::TranscriptResultStream::TranscriptEvent(transcript_evt) = event {
|
||||
let mut ready_items = None;
|
||||
|
||||
|
@ -188,10 +188,7 @@ impl TranscriberLoop {
|
|||
}
|
||||
|
||||
if let Some(ready_items) = ready_items {
|
||||
if self.transcript_items_tx.send(ready_items.into()).is_err() {
|
||||
gst::debug!(CAT, imp: self.imp, "No transcript items receivers");
|
||||
break;
|
||||
}
|
||||
return Ok(ready_items.into());
|
||||
}
|
||||
} else {
|
||||
gst::warning!(
|
||||
|
@ -201,13 +198,6 @@ impl TranscriberLoop {
|
|||
)
|
||||
}
|
||||
}
|
||||
|
||||
gst::debug!(CAT, imp: self.imp, "Transcriber loop sending EOS");
|
||||
let _ = self.transcript_items_tx.send(TranscriptEvent::Eos);
|
||||
|
||||
gst::debug!(CAT, imp: self.imp, "Exiting transcriber loop");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Builds a list from the provided stable items.
|
||||
|
|
|
@ -14,7 +14,7 @@ use aws_sdk_translate as aws_translate;
|
|||
use futures::channel::mpsc;
|
||||
use futures::prelude::*;
|
||||
|
||||
use std::collections::VecDeque;
|
||||
use std::sync::Arc;
|
||||
|
||||
use super::imp::TranslateSrcPad;
|
||||
use super::transcribe::TranscriptItem;
|
||||
|
@ -40,77 +40,13 @@ impl From<&TranscriptItem> for TranslatedItem {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct TranslateQueue {
|
||||
items: VecDeque<TranscriptItem>,
|
||||
}
|
||||
|
||||
impl TranslateQueue {
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.items.is_empty()
|
||||
}
|
||||
|
||||
/// Pushes the provided item.
|
||||
///
|
||||
/// Returns `Some(..)` if items are ready for translation.
|
||||
pub fn push(&mut self, transcript_item: &TranscriptItem) -> Option<Vec<TranscriptItem>> {
|
||||
// Keep track of the item individually so we can schedule translation precisely.
|
||||
self.items.push_back(transcript_item.clone());
|
||||
|
||||
if transcript_item.is_punctuation {
|
||||
// This makes it a good chunk for translation.
|
||||
// Concatenate as a single item for translation
|
||||
|
||||
return Some(self.items.drain(..).collect());
|
||||
}
|
||||
|
||||
// Regular case: no separator detected, don't push transcript items
|
||||
// to translation now. They will be pushed either if a punctuation
|
||||
// is found or of a `dequeue()` is requested.
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
/// Dequeues items from the specified `deadline` up to `lookahead`.
|
||||
///
|
||||
/// Returns `Some(..)` if some items match the criteria.
|
||||
pub fn dequeue(
|
||||
&mut self,
|
||||
latency: gst::ClockTime,
|
||||
threshold: gst::ClockTime,
|
||||
lookahead: gst::ClockTime,
|
||||
) -> Option<Vec<TranscriptItem>> {
|
||||
let first_pts = self.items.front()?.pts;
|
||||
if first_pts + latency > threshold {
|
||||
// First item is too early to be sent to translation now
|
||||
// we can wait for more items to accumulate.
|
||||
return None;
|
||||
}
|
||||
|
||||
// Can't wait any longer to send the first item to translation
|
||||
// Try to get up to lookahead worth of items to improve translation accuracy
|
||||
let limit = first_pts + lookahead;
|
||||
|
||||
let mut items_acc = vec![self.items.pop_front().unwrap()];
|
||||
while let Some(item) = self.items.front() {
|
||||
if item.pts > limit {
|
||||
break;
|
||||
}
|
||||
|
||||
items_acc.push(self.items.pop_front().unwrap());
|
||||
}
|
||||
|
||||
Some(items_acc)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct TranslateLoop {
|
||||
pad: glib::subclass::ObjectImplRef<TranslateSrcPad>,
|
||||
client: aws_translate::Client,
|
||||
input_lang: String,
|
||||
output_lang: String,
|
||||
tokenization_method: TranslationTokenizationMethod,
|
||||
transcript_rx: mpsc::Receiver<Vec<TranscriptItem>>,
|
||||
transcript_rx: mpsc::Receiver<Arc<Vec<TranscriptItem>>>,
|
||||
translate_tx: mpsc::Sender<Vec<TranslatedItem>>,
|
||||
}
|
||||
|
||||
|
@ -121,7 +57,7 @@ impl TranslateLoop {
|
|||
input_lang: &str,
|
||||
output_lang: &str,
|
||||
tokenization_method: TranslationTokenizationMethod,
|
||||
transcript_rx: mpsc::Receiver<Vec<TranscriptItem>>,
|
||||
transcript_rx: mpsc::Receiver<Arc<Vec<TranscriptItem>>>,
|
||||
translate_tx: mpsc::Sender<Vec<TranslatedItem>>,
|
||||
) -> Self {
|
||||
let aws_config = imp.aws_config.lock().unwrap();
|
||||
|
@ -175,12 +111,12 @@ impl TranslateLoop {
|
|||
|
||||
let (ts_duration_list, content): (Vec<(gst::ClockTime, gst::ClockTime)>, String) =
|
||||
transcript_items
|
||||
.into_iter()
|
||||
.iter()
|
||||
.map(|item| {
|
||||
(
|
||||
(item.pts, item.duration),
|
||||
match self.tokenization_method {
|
||||
Tokenization::None => item.content,
|
||||
Tokenization::None => item.content.clone(),
|
||||
Tokenization::SpanBased => {
|
||||
format!("{SPAN_START}{}{SPAN_END}", item.content)
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue