mirror of
https://gitlab.freedesktop.org/gstreamer/gst-plugins-rs.git
synced 2024-11-25 04:51:26 +00:00
net/aws/transcriber: use two queues for sending transcript items
* A queue dedicated to transcript items not intended for translation.
* A queue dedicated to transcript items intended for translation. The items are enqueued after a separator is detected or translate-lookahead is reached.

Part-of: <https://gitlab.freedesktop.org/gstreamer/gst-plugins-rs/-/merge_requests/1137>
This commit is contained in:
parent
5a5ca76d9d
commit
2b32d00589
3 changed files with 373 additions and 323 deletions
|
@ -29,12 +29,12 @@ use futures::prelude::*;
|
||||||
use tokio::{runtime, sync::broadcast, task};
|
use tokio::{runtime, sync::broadcast, task};
|
||||||
|
|
||||||
use std::collections::{BTreeSet, VecDeque};
|
use std::collections::{BTreeSet, VecDeque};
|
||||||
use std::sync::Mutex;
|
use std::sync::{Arc, Mutex};
|
||||||
|
|
||||||
use once_cell::sync::Lazy;
|
use once_cell::sync::Lazy;
|
||||||
|
|
||||||
use super::transcribe::{TranscriberLoop, TranscriptEvent, TranscriptItem, TranscriptionSettings};
|
use super::transcribe::{TranscriberSettings, TranscriberStream, TranscriptEvent, TranscriptItem};
|
||||||
use super::translate::{TranslateLoop, TranslateQueue, TranslatedItem};
|
use super::translate::{TranslateLoop, TranslatedItem};
|
||||||
use super::{
|
use super::{
|
||||||
AwsTranscriberResultStability, AwsTranscriberVocabularyFilterMethod,
|
AwsTranscriberResultStability, AwsTranscriberVocabularyFilterMethod,
|
||||||
TranslationTokenizationMethod, CAT,
|
TranslationTokenizationMethod, CAT,
|
||||||
|
@ -148,6 +148,7 @@ struct State {
|
||||||
srcpads: BTreeSet<super::TranslateSrcPad>,
|
srcpads: BTreeSet<super::TranslateSrcPad>,
|
||||||
pad_serial: u32,
|
pad_serial: u32,
|
||||||
seqnum: gst::Seqnum,
|
seqnum: gst::Seqnum,
|
||||||
|
start_time: Option<gst::ClockTime>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Default for State {
|
impl Default for State {
|
||||||
|
@ -158,6 +159,7 @@ impl Default for State {
|
||||||
srcpads: Default::default(),
|
srcpads: Default::default(),
|
||||||
pad_serial: 0,
|
pad_serial: 0,
|
||||||
seqnum: gst::Seqnum::next(),
|
seqnum: gst::Seqnum::next(),
|
||||||
|
start_time: None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -168,7 +170,9 @@ pub struct Transcriber {
|
||||||
settings: Mutex<Settings>,
|
settings: Mutex<Settings>,
|
||||||
state: Mutex<State>,
|
state: Mutex<State>,
|
||||||
pub(super) aws_config: Mutex<Option<aws_config::SdkConfig>>,
|
pub(super) aws_config: Mutex<Option<aws_config::SdkConfig>>,
|
||||||
// sender to broadcast transcript items to the translate src pads.
|
// sender to broadcast transcript items to the src pads for translation.
|
||||||
|
transcript_event_for_translate_tx: broadcast::Sender<TranscriptEvent>,
|
||||||
|
// sender to broadcast transcript items to the src pads, not intended for translation.
|
||||||
transcript_event_tx: broadcast::Sender<TranscriptEvent>,
|
transcript_event_tx: broadcast::Sender<TranscriptEvent>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -276,7 +280,10 @@ impl Transcriber {
|
||||||
) -> Result<gst::FlowSuccess, gst::FlowError> {
|
) -> Result<gst::FlowSuccess, gst::FlowError> {
|
||||||
gst::log!(CAT, obj: pad, "Handling {buffer:?}");
|
gst::log!(CAT, obj: pad, "Handling {buffer:?}");
|
||||||
|
|
||||||
self.ensure_connection();
|
self.ensure_connection().map_err(|err| {
|
||||||
|
gst::element_imp_error!(self, gst::StreamError::Failed, ["Streaming failed: {err}"]);
|
||||||
|
gst::FlowError::Error
|
||||||
|
})?;
|
||||||
|
|
||||||
let Some(mut buffer_tx) = self.state.lock().unwrap().buffer_tx.take() else {
|
let Some(mut buffer_tx) = self.state.lock().unwrap().buffer_tx.take() else {
|
||||||
gst::log!(CAT, obj: pad, "Flushing");
|
gst::log!(CAT, obj: pad, "Flushing");
|
||||||
|
@ -292,12 +299,82 @@ impl Transcriber {
|
||||||
|
|
||||||
Ok(gst::FlowSuccess::Ok)
|
Ok(gst::FlowSuccess::Ok)
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fn ensure_connection(&self) {
|
#[derive(Default)]
|
||||||
|
struct TranslateQueue {
|
||||||
|
items: VecDeque<TranscriptItem>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TranslateQueue {
|
||||||
|
fn is_empty(&self) -> bool {
|
||||||
|
self.items.is_empty()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Pushes the provided item.
|
||||||
|
///
|
||||||
|
/// Returns `Some(..)` if items are ready for translation.
|
||||||
|
fn push(&mut self, transcript_item: &TranscriptItem) -> Option<Vec<TranscriptItem>> {
|
||||||
|
// Keep track of the item individually so we can schedule translation precisely.
|
||||||
|
self.items.push_back(transcript_item.clone());
|
||||||
|
|
||||||
|
if transcript_item.is_punctuation {
|
||||||
|
// This makes it a good chunk for translation.
|
||||||
|
// Concatenate as a single item for translation
|
||||||
|
|
||||||
|
return Some(self.items.drain(..).collect());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Regular case: no separator detected, don't push transcript items
|
||||||
|
// to translation now. They will be pushed either if a punctuation
|
||||||
|
// is found or if a `dequeue()` is requested.
|
||||||
|
|
||||||
|
None
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Dequeues the items within `lookahead` of the first item once its `pts + latency` has reached `threshold`.
|
||||||
|
///
|
||||||
|
/// Returns `Some(..)` if some items match the criteria.
|
||||||
|
fn dequeue(
|
||||||
|
&mut self,
|
||||||
|
latency: gst::ClockTime,
|
||||||
|
threshold: gst::ClockTime,
|
||||||
|
lookahead: gst::ClockTime,
|
||||||
|
) -> Option<Vec<TranscriptItem>> {
|
||||||
|
let first_pts = self.items.front()?.pts;
|
||||||
|
if first_pts + latency > threshold {
|
||||||
|
// First item is too early to be sent to translation now
|
||||||
|
// we can wait for more items to accumulate.
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Can't wait any longer to send the first item to translation
|
||||||
|
// Try to get up to lookahead worth of items to improve translation accuracy
|
||||||
|
let limit = first_pts + lookahead;
|
||||||
|
|
||||||
|
let mut items_acc = vec![self.items.pop_front().unwrap()];
|
||||||
|
while let Some(item) = self.items.front() {
|
||||||
|
if item.pts > limit {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
items_acc.push(self.items.pop_front().unwrap());
|
||||||
|
}
|
||||||
|
|
||||||
|
Some(items_acc)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn drain(&mut self) -> impl Iterator<Item = TranscriptItem> + '_ {
|
||||||
|
self.items.drain(..)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Transcriber {
|
||||||
|
fn ensure_connection(&self) -> Result<(), gst::ErrorMessage> {
|
||||||
let mut state = self.state.lock().unwrap();
|
let mut state = self.state.lock().unwrap();
|
||||||
|
|
||||||
if state.buffer_tx.is_some() {
|
if state.buffer_tx.is_some() {
|
||||||
return;
|
return Ok(());
|
||||||
}
|
}
|
||||||
|
|
||||||
let settings = self.settings.lock().unwrap();
|
let settings = self.settings.lock().unwrap();
|
||||||
|
@ -306,21 +383,116 @@ impl Transcriber {
|
||||||
let s = in_caps.structure(0).unwrap();
|
let s = in_caps.structure(0).unwrap();
|
||||||
let sample_rate = s.get::<i32>("rate").unwrap();
|
let sample_rate = s.get::<i32>("rate").unwrap();
|
||||||
|
|
||||||
let transcription_settings = TranscriptionSettings::from(&settings, sample_rate);
|
let transcription_settings = TranscriberSettings::from(&settings, sample_rate);
|
||||||
|
|
||||||
let (buffer_tx, buffer_rx) = mpsc::channel(1);
|
let (buffer_tx, buffer_rx) = mpsc::channel(1);
|
||||||
|
|
||||||
let transcriber_loop = TranscriberLoop::new(
|
let _enter = RUNTIME.enter();
|
||||||
|
let mut transcriber_stream = futures::executor::block_on(TranscriberStream::try_new(
|
||||||
self,
|
self,
|
||||||
transcription_settings,
|
transcription_settings,
|
||||||
settings.lateness,
|
settings.lateness,
|
||||||
buffer_rx,
|
buffer_rx,
|
||||||
self.transcript_event_tx.clone(),
|
))?;
|
||||||
);
|
|
||||||
let transcriber_loop_handle = RUNTIME.spawn(transcriber_loop.run());
|
// Latency budget for an item to be pushed to stream on time
|
||||||
|
// Margin:
|
||||||
|
// - 2 * GRANULARITY: to make sure we don't push items up to GRANULARITY late.
|
||||||
|
// - 1 * GRANULARITY: extra margin to account for additional overheads.
|
||||||
|
let latency = settings.transcribe_latency.saturating_sub(3 * GRANULARITY);
|
||||||
|
let translate_lookahead = settings.translate_lookahead;
|
||||||
|
let mut translate_queue = TranslateQueue::default();
|
||||||
|
let imp = self.ref_counted();
|
||||||
|
let transcriber_loop_handle = RUNTIME.spawn(async move {
|
||||||
|
loop {
|
||||||
|
// This is to make sure we send items on a timely basis or at least Gap events.
|
||||||
|
let timeout = tokio::time::sleep(GRANULARITY.into()).fuse();
|
||||||
|
futures::pin_mut!(timeout);
|
||||||
|
|
||||||
|
let transcriber_next = transcriber_stream.next().fuse();
|
||||||
|
futures::pin_mut!(transcriber_next);
|
||||||
|
|
||||||
|
// `transcriber_next` takes precedence over `timeout`
|
||||||
|
// because we don't want to lose any incoming items.
|
||||||
|
let res = futures::select_biased! {
|
||||||
|
event = transcriber_next => Some(event?),
|
||||||
|
_ = timeout => None,
|
||||||
|
};
|
||||||
|
|
||||||
|
use TranscriptEvent::*;
|
||||||
|
match res {
|
||||||
|
None => (),
|
||||||
|
Some(Items(items)) => {
|
||||||
|
if imp.transcript_event_tx.receiver_count() > 0 {
|
||||||
|
let _ = imp.transcript_event_tx.send(Items(items.clone()));
|
||||||
|
}
|
||||||
|
|
||||||
|
if imp.transcript_event_for_translate_tx.receiver_count() > 0 {
|
||||||
|
for item in items.iter() {
|
||||||
|
if let Some(items_to_translate) = translate_queue.push(item) {
|
||||||
|
let _ = imp
|
||||||
|
.transcript_event_for_translate_tx
|
||||||
|
.send(Items(items_to_translate.into()));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Some(Eos) => {
|
||||||
|
gst::debug!(CAT, imp: imp, "Transcriber loop sending EOS");
|
||||||
|
|
||||||
|
if imp.transcript_event_tx.receiver_count() > 0 {
|
||||||
|
let _ = imp.transcript_event_tx.send(Eos);
|
||||||
|
}
|
||||||
|
|
||||||
|
if imp.transcript_event_for_translate_tx.receiver_count() > 0 {
|
||||||
|
let items_to_translate: Vec<TranscriptItem> =
|
||||||
|
translate_queue.drain().collect();
|
||||||
|
let _ = imp
|
||||||
|
.transcript_event_for_translate_tx
|
||||||
|
.send(Items(items_to_translate.into()));
|
||||||
|
|
||||||
|
let _ = imp.transcript_event_for_translate_tx.send(Eos);
|
||||||
|
}
|
||||||
|
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if imp.transcript_event_for_translate_tx.receiver_count() > 0 {
|
||||||
|
// Check if we need to push items for translation
|
||||||
|
|
||||||
|
let Some((start_time, now)) = imp.get_start_time_and_now() else {
|
||||||
|
continue;
|
||||||
|
};
|
||||||
|
|
||||||
|
if !translate_queue.is_empty() {
|
||||||
|
let threshold = now - start_time;
|
||||||
|
|
||||||
|
if let Some(items_to_translate) =
|
||||||
|
translate_queue.dequeue(latency, threshold, translate_lookahead)
|
||||||
|
{
|
||||||
|
gst::debug!(
|
||||||
|
CAT,
|
||||||
|
imp: imp,
|
||||||
|
"Forcing to translation (threshold {threshold}): {items_to_translate:?}"
|
||||||
|
);
|
||||||
|
let _ = imp
|
||||||
|
.transcript_event_for_translate_tx
|
||||||
|
.send(Items(items_to_translate.into()));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
gst::debug!(CAT, imp: imp, "Exiting transcriber loop");
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
});
|
||||||
|
|
||||||
state.transcriber_loop_handle = Some(transcriber_loop_handle);
|
state.transcriber_loop_handle = Some(transcriber_loop_handle);
|
||||||
state.buffer_tx = Some(buffer_tx);
|
state.buffer_tx = Some(buffer_tx);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn prepare(&self) -> Result<(), gst::ErrorMessage> {
|
fn prepare(&self) -> Result<(), gst::ErrorMessage> {
|
||||||
|
@ -382,6 +554,18 @@ impl Transcriber {
|
||||||
}
|
}
|
||||||
gst::info!(CAT, imp: self, "Unprepared");
|
gst::info!(CAT, imp: self, "Unprepared");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn get_start_time_and_now(&self) -> Option<(gst::ClockTime, gst::ClockTime)> {
|
||||||
|
let now = self.obj().current_running_time()?;
|
||||||
|
|
||||||
|
let mut state = self.state.lock().unwrap();
|
||||||
|
|
||||||
|
if state.start_time.is_none() {
|
||||||
|
state.start_time = Some(now);
|
||||||
|
}
|
||||||
|
|
||||||
|
Some((state.start_time.unwrap(), now))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[glib::object_subclass]
|
#[glib::object_subclass]
|
||||||
|
@ -438,6 +622,7 @@ impl ObjectSubclass for Transcriber {
|
||||||
// Setting the channel capacity so that a TranslateSrcPad that would lag
|
// Setting the channel capacity so that a TranslateSrcPad that would lag
|
||||||
// behind for some reason gets a chance to catch up without losing items.
|
// behind for some reason gets a chance to catch up without losing items.
|
||||||
// Receiver will be created by subscribing to sender later.
|
// Receiver will be created by subscribing to sender later.
|
||||||
|
let (transcript_event_for_translate_tx, _) = broadcast::channel(128);
|
||||||
let (transcript_event_tx, _) = broadcast::channel(128);
|
let (transcript_event_tx, _) = broadcast::channel(128);
|
||||||
|
|
||||||
Self {
|
Self {
|
||||||
|
@ -446,6 +631,7 @@ impl ObjectSubclass for Transcriber {
|
||||||
settings: Default::default(),
|
settings: Default::default(),
|
||||||
state: Default::default(),
|
state: Default::default(),
|
||||||
aws_config: Default::default(),
|
aws_config: Default::default(),
|
||||||
|
transcript_event_for_translate_tx,
|
||||||
transcript_event_tx,
|
transcript_event_tx,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -876,51 +1062,93 @@ struct TranslationPadTask {
|
||||||
elem: super::Transcriber,
|
elem: super::Transcriber,
|
||||||
transcript_event_rx: broadcast::Receiver<TranscriptEvent>,
|
transcript_event_rx: broadcast::Receiver<TranscriptEvent>,
|
||||||
needs_translate: bool,
|
needs_translate: bool,
|
||||||
translate_queue: TranslateQueue,
|
|
||||||
translate_loop_handle: Option<task::JoinHandle<Result<(), gst::ErrorMessage>>>,
|
translate_loop_handle: Option<task::JoinHandle<Result<(), gst::ErrorMessage>>>,
|
||||||
to_translate_tx: Option<mpsc::Sender<Vec<TranscriptItem>>>,
|
to_translate_tx: Option<mpsc::Sender<Arc<Vec<TranscriptItem>>>>,
|
||||||
from_translate_rx: Option<mpsc::Receiver<Vec<TranslatedItem>>>,
|
from_translate_rx: Option<mpsc::Receiver<Vec<TranslatedItem>>>,
|
||||||
translate_latency: gst::ClockTime,
|
|
||||||
translate_lookahead: gst::ClockTime,
|
|
||||||
send_events: bool,
|
send_events: bool,
|
||||||
output_items: VecDeque<OutputItem>,
|
output_items: VecDeque<OutputItem>,
|
||||||
our_latency: gst::ClockTime,
|
our_latency: gst::ClockTime,
|
||||||
seqnum: gst::Seqnum,
|
seqnum: gst::Seqnum,
|
||||||
send_eos: bool,
|
send_eos: bool,
|
||||||
pending_translations: usize,
|
pending_translations: usize,
|
||||||
start_time: Option<gst::ClockTime>,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl TranslationPadTask {
|
impl TranslationPadTask {
|
||||||
fn try_new(
|
async fn try_new(
|
||||||
pad: &TranslateSrcPad,
|
pad: &TranslateSrcPad,
|
||||||
elem: super::Transcriber,
|
elem: super::Transcriber,
|
||||||
transcript_event_rx: broadcast::Receiver<TranscriptEvent>,
|
|
||||||
) -> Result<TranslationPadTask, gst::ErrorMessage> {
|
) -> Result<TranslationPadTask, gst::ErrorMessage> {
|
||||||
let mut this = TranslationPadTask {
|
let mut translation_loop = None;
|
||||||
|
let mut translate_loop_handle = None;
|
||||||
|
let mut to_translate_tx = None;
|
||||||
|
let mut from_translate_rx = None;
|
||||||
|
|
||||||
|
let (our_latency, transcript_event_rx, needs_translate);
|
||||||
|
|
||||||
|
{
|
||||||
|
let elem_imp = elem.imp();
|
||||||
|
let elem_settings = elem_imp.settings.lock().unwrap();
|
||||||
|
|
||||||
|
let pad_settings = pad.settings.lock().unwrap();
|
||||||
|
|
||||||
|
our_latency = TranslateSrcPad::our_latency(&elem_settings, &pad_settings);
|
||||||
|
if our_latency + elem_settings.lateness <= 2 * GRANULARITY {
|
||||||
|
let err = format!(
|
||||||
|
"total latency + lateness must be greater than {}",
|
||||||
|
2 * GRANULARITY
|
||||||
|
);
|
||||||
|
gst::error!(CAT, imp: pad, "{err}");
|
||||||
|
return Err(gst::error_msg!(gst::LibraryError::Settings, ["{err}"]));
|
||||||
|
}
|
||||||
|
|
||||||
|
needs_translate = TranslateSrcPad::needs_translation(
|
||||||
|
&elem_settings.language_code,
|
||||||
|
pad_settings.language_code.as_deref(),
|
||||||
|
);
|
||||||
|
|
||||||
|
if needs_translate {
|
||||||
|
let (to_loop_tx, to_loop_rx) = mpsc::channel(64);
|
||||||
|
let (from_loop_tx, from_loop_rx) = mpsc::channel(64);
|
||||||
|
|
||||||
|
translation_loop = Some(TranslateLoop::new(
|
||||||
|
elem_imp,
|
||||||
|
pad,
|
||||||
|
&elem_settings.language_code,
|
||||||
|
pad_settings.language_code.as_deref().unwrap(),
|
||||||
|
pad_settings.tokenization_method,
|
||||||
|
to_loop_rx,
|
||||||
|
from_loop_tx,
|
||||||
|
));
|
||||||
|
|
||||||
|
to_translate_tx = Some(to_loop_tx);
|
||||||
|
from_translate_rx = Some(from_loop_rx);
|
||||||
|
|
||||||
|
transcript_event_rx = elem_imp.transcript_event_for_translate_tx.subscribe();
|
||||||
|
} else {
|
||||||
|
transcript_event_rx = elem_imp.transcript_event_tx.subscribe();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(translation_loop) = translation_loop {
|
||||||
|
translation_loop.check_language().await?;
|
||||||
|
translate_loop_handle = Some(RUNTIME.spawn(translation_loop.run()));
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(TranslationPadTask {
|
||||||
pad: pad.ref_counted(),
|
pad: pad.ref_counted(),
|
||||||
elem,
|
elem,
|
||||||
transcript_event_rx,
|
transcript_event_rx,
|
||||||
needs_translate: false,
|
needs_translate,
|
||||||
translate_queue: TranslateQueue::default(),
|
translate_loop_handle,
|
||||||
translate_loop_handle: None,
|
to_translate_tx,
|
||||||
to_translate_tx: None,
|
from_translate_rx,
|
||||||
from_translate_rx: None,
|
|
||||||
translate_latency: DEFAULT_TRANSLATE_LATENCY,
|
|
||||||
translate_lookahead: DEFAULT_TRANSLATE_LOOKAHEAD,
|
|
||||||
send_events: true,
|
send_events: true,
|
||||||
output_items: VecDeque::new(),
|
output_items: VecDeque::new(),
|
||||||
our_latency: DEFAULT_TRANSCRIBE_LATENCY,
|
our_latency,
|
||||||
seqnum: gst::Seqnum::next(),
|
seqnum: gst::Seqnum::next(),
|
||||||
send_eos: false,
|
send_eos: false,
|
||||||
pending_translations: 0,
|
pending_translations: 0,
|
||||||
start_time: None,
|
})
|
||||||
};
|
|
||||||
|
|
||||||
let _enter_guard = RUNTIME.enter();
|
|
||||||
futures::executor::block_on(this.init_translate())?;
|
|
||||||
|
|
||||||
Ok(this)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -958,11 +1186,9 @@ impl TranslationPadTask {
|
||||||
let transcript_event_rx = self.transcript_event_rx.recv().fuse();
|
let transcript_event_rx = self.transcript_event_rx.recv().fuse();
|
||||||
futures::pin_mut!(transcript_event_rx);
|
futures::pin_mut!(transcript_event_rx);
|
||||||
|
|
||||||
// `timeout` takes precedence over `transcript_events` reception
|
// `transcript_event_rx` takes precedence over `timeout`
|
||||||
// because we may need to `dequeue` `items` or push a `Gap` event
|
// because we don't want to lose any incoming items.
|
||||||
// before current latency budget is exhausted.
|
|
||||||
futures::select_biased! {
|
futures::select_biased! {
|
||||||
_ = timeout => (),
|
|
||||||
items_res = transcript_event_rx => {
|
items_res = transcript_event_rx => {
|
||||||
use TranscriptEvent::*;
|
use TranscriptEvent::*;
|
||||||
use broadcast::error::RecvError;
|
use broadcast::error::RecvError;
|
||||||
|
@ -983,6 +1209,7 @@ impl TranslationPadTask {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
_ = timeout => (),
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
|
@ -999,121 +1226,100 @@ impl TranslationPadTask {
|
||||||
return Err(gst::error_msg!(gst::StreamError::Failed, ["{ERR}"]));
|
return Err(gst::error_msg!(gst::StreamError::Failed, ["{ERR}"]));
|
||||||
}
|
}
|
||||||
|
|
||||||
let transcript_items = {
|
let items_to_translate = {
|
||||||
// This is to make sure we send items on a timely basis or at least Gap events.
|
// This is to make sure we send items on a timely basis or at least Gap events.
|
||||||
let timeout = tokio::time::sleep(GRANULARITY.into()).fuse();
|
let timeout = tokio::time::sleep(GRANULARITY.into()).fuse();
|
||||||
futures::pin_mut!(timeout);
|
futures::pin_mut!(timeout);
|
||||||
|
|
||||||
let from_translate_rx = self
|
|
||||||
.from_translate_rx
|
|
||||||
.as_mut()
|
|
||||||
.expect("from_translation chan must be available in translation mode");
|
|
||||||
|
|
||||||
let transcript_event_rx = self.transcript_event_rx.recv().fuse();
|
let transcript_event_rx = self.transcript_event_rx.recv().fuse();
|
||||||
futures::pin_mut!(transcript_event_rx);
|
futures::pin_mut!(transcript_event_rx);
|
||||||
|
|
||||||
// `timeout` takes precedence over `transcript_events` reception
|
// `transcript_event_rx` takes precedence over `timeout`
|
||||||
// because we may need to `dequeue` `items` or push a `Gap` event
|
// because we don't want to lose any incoming items.
|
||||||
// before current latency budget is exhausted.
|
|
||||||
futures::select_biased! {
|
futures::select_biased! {
|
||||||
_ = timeout => return Ok(()),
|
|
||||||
translated_items = from_translate_rx.next() => {
|
|
||||||
let Some(translated_items) = translated_items else {
|
|
||||||
const ERR: &str = "translation chan terminated";
|
|
||||||
gst::debug!(CAT, imp: self.pad, "{ERR}");
|
|
||||||
return Err(gst::error_msg!(gst::StreamError::Failed, ["{ERR}"]));
|
|
||||||
};
|
|
||||||
|
|
||||||
self.output_items.extend(translated_items.into_iter().map(Into::into));
|
|
||||||
self.pending_translations = self.pending_translations.saturating_sub(1);
|
|
||||||
|
|
||||||
return Ok(());
|
|
||||||
}
|
|
||||||
items_res = transcript_event_rx => {
|
items_res = transcript_event_rx => {
|
||||||
use TranscriptEvent::*;
|
use TranscriptEvent::*;
|
||||||
use broadcast::error::RecvError;
|
use broadcast::error::RecvError;
|
||||||
match items_res {
|
match items_res {
|
||||||
Ok(Items(transcript_items)) => transcript_items,
|
Ok(Items(items_to_translate)) => Some(items_to_translate),
|
||||||
Ok(Eos) => {
|
Ok(Eos) => {
|
||||||
gst::debug!(CAT, imp: self.pad, "Got eos");
|
gst::debug!(CAT, imp: self.pad, "Got eos");
|
||||||
self.send_eos = true;
|
self.send_eos = true;
|
||||||
return Ok(());
|
None
|
||||||
}
|
}
|
||||||
Err(RecvError::Lagged(nb_msg)) => {
|
Err(RecvError::Lagged(nb_msg)) => {
|
||||||
gst::warning!(CAT, imp: self.pad, "Missed {nb_msg} transcript sets");
|
gst::warning!(CAT, imp: self.pad, "Missed {nb_msg} transcript sets");
|
||||||
return Ok(());
|
None
|
||||||
}
|
}
|
||||||
Err(RecvError::Closed) => {
|
Err(RecvError::Closed) => {
|
||||||
gst::debug!(CAT, imp: self.pad, "Transcript chan terminated: setting eos");
|
gst::debug!(CAT, imp: self.pad, "Transcript chan terminated: setting eos");
|
||||||
self.send_eos = true;
|
self.send_eos = true;
|
||||||
return Ok(());
|
None
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
_ = timeout => None,
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
for items in transcript_items.iter() {
|
if let Some(items_to_translate) = items_to_translate {
|
||||||
if let Some(items_to_translate) = self.translate_queue.push(items) {
|
if !items_to_translate.is_empty() {
|
||||||
self.send_for_translation(items_to_translate).await?;
|
let res = self
|
||||||
|
.to_translate_tx
|
||||||
|
.as_mut()
|
||||||
|
.expect("to_translation chan must be available in translation mode")
|
||||||
|
.send(items_to_translate)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
if res.is_err() {
|
||||||
|
const ERR: &str = "to_translation chan terminated";
|
||||||
|
gst::debug!(CAT, imp: self.pad, "{ERR}");
|
||||||
|
return Err(gst::error_msg!(gst::StreamError::Failed, ["{ERR}"]));
|
||||||
|
}
|
||||||
|
|
||||||
|
self.pending_translations += 1;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(())
|
// Check pending translated items
|
||||||
}
|
let from_translate_rx = self
|
||||||
|
.from_translate_rx
|
||||||
|
.as_mut()
|
||||||
|
.expect("from_translation chan must be available in translation mode");
|
||||||
|
|
||||||
async fn dequeue_for_translation(
|
while let Ok(translated_items) = from_translate_rx.try_next() {
|
||||||
&mut self,
|
let Some(translated_items) = translated_items else {
|
||||||
start_time: gst::ClockTime,
|
const ERR: &str = "translation chan terminated";
|
||||||
now: gst::ClockTime,
|
gst::debug!(CAT, imp: self.pad, "{ERR}");
|
||||||
) -> Result<(), gst::ErrorMessage> {
|
return Err(gst::error_msg!(gst::StreamError::Failed, ["{ERR}"]));
|
||||||
if !self.translate_queue.is_empty() {
|
};
|
||||||
// Latency budget for an item to be pushed to stream on time
|
|
||||||
// Margin:
|
|
||||||
// - 2 * GRANULARITY: to make sure we don't push items up to GRANULARITY late.
|
|
||||||
// - 1 * GRANULARITY: extra margin to account for additional overheads.
|
|
||||||
let latency = self.our_latency.saturating_sub(3 * GRANULARITY);
|
|
||||||
|
|
||||||
// Estimated time of arrival for an item sent to translation now.
|
self.output_items
|
||||||
// (in transcript item ts base)
|
.extend(translated_items.into_iter().map(Into::into));
|
||||||
let translation_eta = now + self.translate_latency - start_time;
|
self.pending_translations = self.pending_translations.saturating_sub(1);
|
||||||
|
|
||||||
if let Some(items_to_translate) =
|
|
||||||
self.translate_queue
|
|
||||||
.dequeue(latency, translation_eta, self.translate_lookahead)
|
|
||||||
{
|
|
||||||
gst::debug!(CAT, imp: self.pad, "Forcing to translation: {items_to_translate:?}");
|
|
||||||
self.send_for_translation(items_to_translate).await?;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn dequeue(&mut self) -> bool {
|
async fn dequeue(&mut self) -> bool {
|
||||||
let (now, start_time, mut last_position, mut discont_pending);
|
let Some((start_time, now)) = self.elem.imp().get_start_time_and_now() else {
|
||||||
{
|
// Wait for the clock to be available
|
||||||
let mut pad_state = self.pad.state.lock().unwrap();
|
return true;
|
||||||
|
};
|
||||||
|
|
||||||
let Some(cur_rt) = self.elem.current_running_time() else {
|
let (mut last_position, mut discont_pending) = {
|
||||||
// Wait for the clock to be available
|
let mut state = self.pad.state.lock().unwrap();
|
||||||
return true;
|
|
||||||
|
let last_position = if let Some(pos) = state.out_segment.position() {
|
||||||
|
pos
|
||||||
|
} else {
|
||||||
|
state.out_segment.set_position(start_time);
|
||||||
|
start_time
|
||||||
};
|
};
|
||||||
now = cur_rt;
|
|
||||||
|
|
||||||
if self.start_time.is_none() {
|
(last_position, state.discont_pending)
|
||||||
self.start_time = Some(now);
|
};
|
||||||
pad_state.out_segment.set_position(now);
|
|
||||||
}
|
|
||||||
|
|
||||||
start_time = self.start_time.unwrap();
|
|
||||||
last_position = pad_state.out_segment.position().unwrap();
|
|
||||||
discont_pending = pad_state.discont_pending;
|
|
||||||
}
|
|
||||||
|
|
||||||
if self.needs_translate && self.dequeue_for_translation(start_time, now).await.is_err() {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
/* First, check our pending buffers */
|
/* First, check our pending buffers */
|
||||||
while let Some(item) = self.output_items.front() {
|
while let Some(item) = self.output_items.front() {
|
||||||
|
@ -1206,11 +1412,7 @@ impl TranslationPadTask {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if self.send_eos
|
if self.send_eos && self.pending_translations == 0 && self.output_items.is_empty() {
|
||||||
&& self.pending_translations == 0
|
|
||||||
&& self.output_items.is_empty()
|
|
||||||
&& self.translate_queue.is_empty()
|
|
||||||
{
|
|
||||||
/* We're EOS, we can pause and exit early */
|
/* We're EOS, we can pause and exit early */
|
||||||
let _ = self.pad.obj().pause_task();
|
let _ = self.pad.obj().pause_task();
|
||||||
|
|
||||||
|
@ -1261,28 +1463,6 @@ impl TranslationPadTask {
|
||||||
true
|
true
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn send_for_translation(
|
|
||||||
&mut self,
|
|
||||||
transcript_items: Vec<TranscriptItem>,
|
|
||||||
) -> Result<(), gst::ErrorMessage> {
|
|
||||||
let res = self
|
|
||||||
.to_translate_tx
|
|
||||||
.as_mut()
|
|
||||||
.expect("to_translation chan must be available in translation mode")
|
|
||||||
.send(transcript_items)
|
|
||||||
.await;
|
|
||||||
|
|
||||||
if res.is_err() {
|
|
||||||
const ERR: &str = "to_translation chan terminated";
|
|
||||||
gst::debug!(CAT, imp: self.pad, "{ERR}");
|
|
||||||
return Err(gst::error_msg!(gst::StreamError::Failed, ["{ERR}"]));
|
|
||||||
}
|
|
||||||
|
|
||||||
self.pending_translations += 1;
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
fn ensure_init_events(&mut self) -> Result<(), gst::ErrorMessage> {
|
fn ensure_init_events(&mut self) -> Result<(), gst::ErrorMessage> {
|
||||||
if !self.send_events {
|
if !self.send_events {
|
||||||
return Ok(());
|
return Ok(());
|
||||||
|
@ -1332,62 +1512,6 @@ impl TranslationPadTask {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl TranslationPadTask {
|
|
||||||
async fn init_translate(&mut self) -> Result<(), gst::ErrorMessage> {
|
|
||||||
let mut translation_loop = None;
|
|
||||||
|
|
||||||
{
|
|
||||||
let elem_imp = self.elem.imp();
|
|
||||||
let elem_settings = elem_imp.settings.lock().unwrap();
|
|
||||||
|
|
||||||
let pad_settings = self.pad.settings.lock().unwrap();
|
|
||||||
|
|
||||||
self.our_latency = TranslateSrcPad::our_latency(&elem_settings, &pad_settings);
|
|
||||||
if self.our_latency + elem_settings.lateness <= 2 * GRANULARITY {
|
|
||||||
let err = format!(
|
|
||||||
"total latency + lateness must be greater than {}",
|
|
||||||
2 * GRANULARITY
|
|
||||||
);
|
|
||||||
gst::error!(CAT, imp: self.pad, "{err}");
|
|
||||||
return Err(gst::error_msg!(gst::LibraryError::Settings, ["{err}"]));
|
|
||||||
}
|
|
||||||
|
|
||||||
self.translate_latency = elem_settings.translate_latency;
|
|
||||||
self.translate_lookahead = elem_settings.translate_lookahead;
|
|
||||||
|
|
||||||
self.needs_translate = TranslateSrcPad::needs_translation(
|
|
||||||
&elem_settings.language_code,
|
|
||||||
pad_settings.language_code.as_deref(),
|
|
||||||
);
|
|
||||||
|
|
||||||
if self.needs_translate {
|
|
||||||
let (to_translate_tx, to_translate_rx) = mpsc::channel(64);
|
|
||||||
let (from_translate_tx, from_translate_rx) = mpsc::channel(64);
|
|
||||||
|
|
||||||
translation_loop = Some(TranslateLoop::new(
|
|
||||||
elem_imp,
|
|
||||||
&self.pad,
|
|
||||||
&elem_settings.language_code,
|
|
||||||
pad_settings.language_code.as_deref().unwrap(),
|
|
||||||
pad_settings.tokenization_method,
|
|
||||||
to_translate_rx,
|
|
||||||
from_translate_tx,
|
|
||||||
));
|
|
||||||
|
|
||||||
self.to_translate_tx = Some(to_translate_tx);
|
|
||||||
self.from_translate_rx = Some(from_translate_rx);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if let Some(translation_loop) = translation_loop {
|
|
||||||
translation_loop.check_language().await?;
|
|
||||||
self.translate_loop_handle = Some(RUNTIME.spawn(translation_loop.run()));
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
struct TranslationPadState {
|
struct TranslationPadState {
|
||||||
discont_pending: bool,
|
discont_pending: bool,
|
||||||
|
@ -1422,8 +1546,8 @@ impl TranslateSrcPad {
|
||||||
gst::debug!(CAT, imp: self, "Starting task");
|
gst::debug!(CAT, imp: self, "Starting task");
|
||||||
|
|
||||||
let elem = self.parent();
|
let elem = self.parent();
|
||||||
let transcript_event_rx = elem.imp().transcript_event_tx.subscribe();
|
let _enter = RUNTIME.enter();
|
||||||
let mut pad_task = TranslationPadTask::try_new(self, elem, transcript_event_rx)
|
let mut pad_task = futures::executor::block_on(TranslationPadTask::try_new(self, elem))
|
||||||
.map_err(|err| gst::loggable_error!(CAT, format!("Failed to start pad task {err}")))?;
|
.map_err(|err| gst::loggable_error!(CAT, format!("Failed to start pad task {err}")))?;
|
||||||
|
|
||||||
let imp = self.ref_counted();
|
let imp = self.ref_counted();
|
||||||
|
|
|
@ -15,7 +15,6 @@ use aws_sdk_transcribestreaming::model;
|
||||||
|
|
||||||
use futures::channel::mpsc;
|
use futures::channel::mpsc;
|
||||||
use futures::prelude::*;
|
use futures::prelude::*;
|
||||||
use tokio::sync::broadcast;
|
|
||||||
|
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
@ -23,7 +22,7 @@ use super::imp::{Settings, Transcriber};
|
||||||
use super::CAT;
|
use super::CAT;
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct TranscriptionSettings {
|
pub struct TranscriberSettings {
|
||||||
lang_code: model::LanguageCode,
|
lang_code: model::LanguageCode,
|
||||||
sample_rate: i32,
|
sample_rate: i32,
|
||||||
vocabulary: Option<String>,
|
vocabulary: Option<String>,
|
||||||
|
@ -33,9 +32,9 @@ pub struct TranscriptionSettings {
|
||||||
results_stability: model::PartialResultsStability,
|
results_stability: model::PartialResultsStability,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl TranscriptionSettings {
|
impl TranscriberSettings {
|
||||||
pub(super) fn from(settings: &Settings, sample_rate: i32) -> Self {
|
pub(super) fn from(settings: &Settings, sample_rate: i32) -> Self {
|
||||||
TranscriptionSettings {
|
TranscriberSettings {
|
||||||
lang_code: settings.language_code.as_str().into(),
|
lang_code: settings.language_code.as_str().into(),
|
||||||
sample_rate,
|
sample_rate,
|
||||||
vocabulary: settings.vocabulary.clone(),
|
vocabulary: settings.vocabulary.clone(),
|
||||||
|
@ -83,43 +82,30 @@ impl From<Vec<TranscriptItem>> for TranscriptEvent {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct TranscriberLoop {
|
pub struct TranscriberStream {
|
||||||
imp: glib::subclass::ObjectImplRef<Transcriber>,
|
imp: glib::subclass::ObjectImplRef<Transcriber>,
|
||||||
client: aws_transcribe::Client,
|
output: aws_transcribe::output::StartStreamTranscriptionOutput,
|
||||||
settings: Option<TranscriptionSettings>,
|
|
||||||
lateness: gst::ClockTime,
|
lateness: gst::ClockTime,
|
||||||
buffer_rx: Option<mpsc::Receiver<gst::Buffer>>,
|
|
||||||
transcript_items_tx: broadcast::Sender<TranscriptEvent>,
|
|
||||||
partial_index: usize,
|
partial_index: usize,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl TranscriberLoop {
|
impl TranscriberStream {
|
||||||
pub fn new(
|
pub async fn try_new(
|
||||||
imp: &Transcriber,
|
imp: &Transcriber,
|
||||||
settings: TranscriptionSettings,
|
settings: TranscriberSettings,
|
||||||
lateness: gst::ClockTime,
|
lateness: gst::ClockTime,
|
||||||
buffer_rx: mpsc::Receiver<gst::Buffer>,
|
buffer_rx: mpsc::Receiver<gst::Buffer>,
|
||||||
transcript_items_tx: broadcast::Sender<TranscriptEvent>,
|
) -> Result<Self, gst::ErrorMessage> {
|
||||||
) -> Self {
|
let client = {
|
||||||
let aws_config = imp.aws_config.lock().unwrap();
|
let aws_config = imp.aws_config.lock().unwrap();
|
||||||
let aws_config = aws_config
|
let aws_config = aws_config
|
||||||
.as_ref()
|
.as_ref()
|
||||||
.expect("aws_config must be initialized at this stage");
|
.expect("aws_config must be initialized at this stage");
|
||||||
|
aws_transcribe::Client::new(aws_config)
|
||||||
|
};
|
||||||
|
|
||||||
TranscriberLoop {
|
|
||||||
imp: imp.ref_counted(),
|
|
||||||
client: aws_transcribe::Client::new(aws_config),
|
|
||||||
settings: Some(settings),
|
|
||||||
lateness,
|
|
||||||
buffer_rx: Some(buffer_rx),
|
|
||||||
transcript_items_tx,
|
|
||||||
partial_index: 0,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn run(mut self) -> Result<(), gst::ErrorMessage> {
|
|
||||||
// Stream the incoming buffers chunked
|
// Stream the incoming buffers chunked
|
||||||
let chunk_stream = self.buffer_rx.take().unwrap().flat_map(move |buffer: gst::Buffer| {
|
let chunk_stream = buffer_rx.flat_map(move |buffer: gst::Buffer| {
|
||||||
async_stream::stream! {
|
async_stream::stream! {
|
||||||
let data = buffer.map_readable().unwrap();
|
let data = buffer.map_readable().unwrap();
|
||||||
use aws_transcribe::{model::{AudioEvent, AudioStream}, types::Blob};
|
use aws_transcribe::{model::{AudioEvent, AudioStream}, types::Blob};
|
||||||
|
@ -129,9 +115,7 @@ impl TranscriberLoop {
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
let settings = self.settings.take().unwrap();
|
let mut transcribe_builder = client
|
||||||
let mut transcribe_builder = self
|
|
||||||
.client
|
|
||||||
.start_stream_transcription()
|
.start_stream_transcription()
|
||||||
.language_code(settings.lang_code)
|
.language_code(settings.lang_code)
|
||||||
.media_sample_rate_hertz(settings.sample_rate)
|
.media_sample_rate_hertz(settings.sample_rate)
|
||||||
|
@ -147,26 +131,42 @@ impl TranscriberLoop {
|
||||||
.vocabulary_filter_method(settings.vocabulary_filter_method);
|
.vocabulary_filter_method(settings.vocabulary_filter_method);
|
||||||
}
|
}
|
||||||
|
|
||||||
let mut output = transcribe_builder
|
let output = transcribe_builder
|
||||||
.audio_stream(chunk_stream.into())
|
.audio_stream(chunk_stream.into())
|
||||||
.send()
|
.send()
|
||||||
.await
|
.await
|
||||||
.map_err(|err| {
|
.map_err(|err| {
|
||||||
let err = format!("Transcribe ws init error: {err}");
|
let err = format!("Transcribe ws init error: {err}");
|
||||||
gst::error!(CAT, imp: self.imp, "{err}");
|
gst::error!(CAT, imp: imp, "{err}");
|
||||||
gst::error_msg!(gst::LibraryError::Init, ["{err}"])
|
gst::error_msg!(gst::LibraryError::Init, ["{err}"])
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
while let Some(event) = output
|
Ok(TranscriberStream {
|
||||||
.transcript_result_stream
|
imp: imp.ref_counted(),
|
||||||
.recv()
|
output,
|
||||||
.await
|
lateness,
|
||||||
.map_err(|err| {
|
partial_index: 0,
|
||||||
let err = format!("Transcribe ws stream error: {err}");
|
})
|
||||||
gst::error!(CAT, imp: self.imp, "{err}");
|
}
|
||||||
gst::error_msg!(gst::LibraryError::Failed, ["{err}"])
|
|
||||||
})?
|
pub async fn next(&mut self) -> Result<TranscriptEvent, gst::ErrorMessage> {
|
||||||
{
|
loop {
|
||||||
|
let event = self
|
||||||
|
.output
|
||||||
|
.transcript_result_stream
|
||||||
|
.recv()
|
||||||
|
.await
|
||||||
|
.map_err(|err| {
|
||||||
|
let err = format!("Transcribe ws stream error: {err}");
|
||||||
|
gst::error!(CAT, imp: self.imp, "{err}");
|
||||||
|
gst::error_msg!(gst::LibraryError::Failed, ["{err}"])
|
||||||
|
})?;
|
||||||
|
|
||||||
|
let Some(event) = event else {
|
||||||
|
gst::debug!(CAT, imp: self.imp, "Transcriber loop sending EOS");
|
||||||
|
return Ok(TranscriptEvent::Eos);
|
||||||
|
};
|
||||||
|
|
||||||
if let model::TranscriptResultStream::TranscriptEvent(transcript_evt) = event {
|
if let model::TranscriptResultStream::TranscriptEvent(transcript_evt) = event {
|
||||||
let mut ready_items = None;
|
let mut ready_items = None;
|
||||||
|
|
||||||
|
@ -188,10 +188,7 @@ impl TranscriberLoop {
|
||||||
}
|
}
|
||||||
|
|
||||||
if let Some(ready_items) = ready_items {
|
if let Some(ready_items) = ready_items {
|
||||||
if self.transcript_items_tx.send(ready_items.into()).is_err() {
|
return Ok(ready_items.into());
|
||||||
gst::debug!(CAT, imp: self.imp, "No transcript items receivers");
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
gst::warning!(
|
gst::warning!(
|
||||||
|
@ -201,13 +198,6 @@ impl TranscriberLoop {
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
gst::debug!(CAT, imp: self.imp, "Transcriber loop sending EOS");
|
|
||||||
let _ = self.transcript_items_tx.send(TranscriptEvent::Eos);
|
|
||||||
|
|
||||||
gst::debug!(CAT, imp: self.imp, "Exiting transcriber loop");
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Builds a list from the provided stable items.
|
/// Builds a list from the provided stable items.
|
||||||
|
|
|
@ -14,7 +14,7 @@ use aws_sdk_translate as aws_translate;
|
||||||
use futures::channel::mpsc;
|
use futures::channel::mpsc;
|
||||||
use futures::prelude::*;
|
use futures::prelude::*;
|
||||||
|
|
||||||
use std::collections::VecDeque;
|
use std::sync::Arc;
|
||||||
|
|
||||||
use super::imp::TranslateSrcPad;
|
use super::imp::TranslateSrcPad;
|
||||||
use super::transcribe::TranscriptItem;
|
use super::transcribe::TranscriptItem;
|
||||||
|
@ -40,77 +40,13 @@ impl From<&TranscriptItem> for TranslatedItem {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Default)]
|
|
||||||
pub struct TranslateQueue {
|
|
||||||
items: VecDeque<TranscriptItem>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl TranslateQueue {
|
|
||||||
pub fn is_empty(&self) -> bool {
|
|
||||||
self.items.is_empty()
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Pushes the provided item.
|
|
||||||
///
|
|
||||||
/// Returns `Some(..)` if items are ready for translation.
|
|
||||||
pub fn push(&mut self, transcript_item: &TranscriptItem) -> Option<Vec<TranscriptItem>> {
|
|
||||||
// Keep track of the item individually so we can schedule translation precisely.
|
|
||||||
self.items.push_back(transcript_item.clone());
|
|
||||||
|
|
||||||
if transcript_item.is_punctuation {
|
|
||||||
// This makes it a good chunk for translation.
|
|
||||||
// Concatenate as a single item for translation
|
|
||||||
|
|
||||||
return Some(self.items.drain(..).collect());
|
|
||||||
}
|
|
||||||
|
|
||||||
// Regular case: no separator detected, don't push transcript items
|
|
||||||
// to translation now. They will be pushed either if a punctuation
|
|
||||||
// is found or of a `dequeue()` is requested.
|
|
||||||
|
|
||||||
None
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Dequeues items from the specified `deadline` up to `lookahead`.
|
|
||||||
///
|
|
||||||
/// Returns `Some(..)` if some items match the criteria.
|
|
||||||
pub fn dequeue(
|
|
||||||
&mut self,
|
|
||||||
latency: gst::ClockTime,
|
|
||||||
threshold: gst::ClockTime,
|
|
||||||
lookahead: gst::ClockTime,
|
|
||||||
) -> Option<Vec<TranscriptItem>> {
|
|
||||||
let first_pts = self.items.front()?.pts;
|
|
||||||
if first_pts + latency > threshold {
|
|
||||||
// First item is too early to be sent to translation now
|
|
||||||
// we can wait for more items to accumulate.
|
|
||||||
return None;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Can't wait any longer to send the first item to translation
|
|
||||||
// Try to get up to lookahead worth of items to improve translation accuracy
|
|
||||||
let limit = first_pts + lookahead;
|
|
||||||
|
|
||||||
let mut items_acc = vec![self.items.pop_front().unwrap()];
|
|
||||||
while let Some(item) = self.items.front() {
|
|
||||||
if item.pts > limit {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
items_acc.push(self.items.pop_front().unwrap());
|
|
||||||
}
|
|
||||||
|
|
||||||
Some(items_acc)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub struct TranslateLoop {
|
pub struct TranslateLoop {
|
||||||
pad: glib::subclass::ObjectImplRef<TranslateSrcPad>,
|
pad: glib::subclass::ObjectImplRef<TranslateSrcPad>,
|
||||||
client: aws_translate::Client,
|
client: aws_translate::Client,
|
||||||
input_lang: String,
|
input_lang: String,
|
||||||
output_lang: String,
|
output_lang: String,
|
||||||
tokenization_method: TranslationTokenizationMethod,
|
tokenization_method: TranslationTokenizationMethod,
|
||||||
transcript_rx: mpsc::Receiver<Vec<TranscriptItem>>,
|
transcript_rx: mpsc::Receiver<Arc<Vec<TranscriptItem>>>,
|
||||||
translate_tx: mpsc::Sender<Vec<TranslatedItem>>,
|
translate_tx: mpsc::Sender<Vec<TranslatedItem>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -121,7 +57,7 @@ impl TranslateLoop {
|
||||||
input_lang: &str,
|
input_lang: &str,
|
||||||
output_lang: &str,
|
output_lang: &str,
|
||||||
tokenization_method: TranslationTokenizationMethod,
|
tokenization_method: TranslationTokenizationMethod,
|
||||||
transcript_rx: mpsc::Receiver<Vec<TranscriptItem>>,
|
transcript_rx: mpsc::Receiver<Arc<Vec<TranscriptItem>>>,
|
||||||
translate_tx: mpsc::Sender<Vec<TranslatedItem>>,
|
translate_tx: mpsc::Sender<Vec<TranslatedItem>>,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
let aws_config = imp.aws_config.lock().unwrap();
|
let aws_config = imp.aws_config.lock().unwrap();
|
||||||
|
@ -175,12 +111,12 @@ impl TranslateLoop {
|
||||||
|
|
||||||
let (ts_duration_list, content): (Vec<(gst::ClockTime, gst::ClockTime)>, String) =
|
let (ts_duration_list, content): (Vec<(gst::ClockTime, gst::ClockTime)>, String) =
|
||||||
transcript_items
|
transcript_items
|
||||||
.into_iter()
|
.iter()
|
||||||
.map(|item| {
|
.map(|item| {
|
||||||
(
|
(
|
||||||
(item.pts, item.duration),
|
(item.pts, item.duration),
|
||||||
match self.tokenization_method {
|
match self.tokenization_method {
|
||||||
Tokenization::None => item.content,
|
Tokenization::None => item.content.clone(),
|
||||||
Tokenization::SpanBased => {
|
Tokenization::SpanBased => {
|
||||||
format!("{SPAN_START}{}{SPAN_END}", item.content)
|
format!("{SPAN_START}{}{SPAN_END}", item.content)
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue