net/aws: enqueue transcribed buffers within the ws loop

Instead of sending transcription events to the src pad loop, this commit
enqueues the transcribed buffers immediately in the ws loop, then notifies
the src pad loop. The src pad loop is only in charge of dequeuing the buffers.

This should help with upcoming evolutions.

Part-of: <https://gitlab.freedesktop.org/gstreamer/gst-plugins-rs/-/merge_requests/1104>
This commit is contained in:
François Laignel 2023-02-28 16:28:13 +01:00 committed by GStreamer Marge Bot
parent 00153754bb
commit 36ae29d746

View file

@ -66,7 +66,7 @@ struct Settings {
vocabulary_filter_method: AwsTranscriberVocabularyFilterMethod, vocabulary_filter_method: AwsTranscriberVocabularyFilterMethod,
} }
impl std::default::Default for Settings { impl Default for Settings {
fn default() -> Self { fn default() -> Self {
Self { Self {
latency: DEFAULT_LATENCY, latency: DEFAULT_LATENCY,
@ -112,26 +112,25 @@ impl TranscriptionSettings {
struct State { struct State {
client: Option<aws_transcribe::Client>, client: Option<aws_transcribe::Client>,
buffer_tx: Option<mpsc::Sender<gst::Buffer>>, buffer_tx: Option<mpsc::Sender<gst::Buffer>>,
transcript_tx: Option<mpsc::Sender<model::TranscriptEvent>>, transcript_notif_tx: Option<mpsc::Sender<()>>,
ws_loop_handle: Option<task::JoinHandle<Result<(), gst::ErrorMessage>>>, ws_loop_handle: Option<task::JoinHandle<Result<(), gst::ErrorMessage>>>,
in_segment: gst::FormattedSegment<gst::ClockTime>, in_segment: gst::FormattedSegment<gst::ClockTime>,
out_segment: gst::FormattedSegment<gst::ClockTime>, out_segment: gst::FormattedSegment<gst::ClockTime>,
seqnum: gst::Seqnum, seqnum: gst::Seqnum,
buffers: VecDeque<gst::Buffer>, buffers: VecDeque<gst::Buffer>,
send_eos: bool, send_eos: bool,
// FIXME never set to true
discont: bool, discont: bool,
partial_index: usize, partial_index: usize,
send_events: bool, send_events: bool,
start_time: Option<gst::ClockTime>, start_time: Option<gst::ClockTime>,
} }
impl std::default::Default for State { impl Default for State {
fn default() -> Self { fn default() -> Self {
Self { Self {
client: None, client: None,
buffer_tx: None, buffer_tx: None,
transcript_tx: None, transcript_notif_tx: None,
ws_loop_handle: None, ws_loop_handle: None,
in_segment: gst::FormattedSegment::new(), in_segment: gst::FormattedSegment::new(),
out_segment: gst::FormattedSegment::new(), out_segment: gst::FormattedSegment::new(),
@ -297,8 +296,11 @@ impl Transcriber {
true true
} }
fn enqueue(&self, state: &mut State, items: &[model::Item], partial: bool) { /// Enqueues a buffer for each of the provided stable items.
let lateness = self.settings.lock().unwrap().lateness; ///
/// Returns `true` if at least one buffer was enqueued.
fn enqueue(&self, items: &[model::Item], partial: bool, lateness: gst::ClockTime) -> bool {
let mut state = self.state.lock().unwrap();
if items.len() <= state.partial_index { if items.len() <= state.partial_index {
gst::error!( gst::error!(
@ -313,53 +315,55 @@ impl Transcriber {
state.partial_index = 0; state.partial_index = 0;
} }
return; return false;
} }
for item in &items[state.partial_index..] { let mut enqueued = false;
let start_time = ((item.start_time * 1_000_000_000.0) as u64).nseconds() + lateness;
let end_time = ((item.end_time * 1_000_000_000.0) as u64).nseconds() + lateness;
for item in &items[state.partial_index..] {
if !item.stable().unwrap_or(false) { if !item.stable().unwrap_or(false) {
break; break;
} }
// FIXME could probably just unwrap let Some(content) = item.content() else { continue };
if let Some(content) = item.content() {
/* Should be sent now */
gst::debug!(
CAT,
imp: self,
"Item is ready for queuing: {content}, PTS {start_time}",
);
let mut buf = gst::Buffer::from_mut_slice(content.to_string().into_bytes()); let start_time = ((item.start_time * 1_000_000_000.0) as u64).nseconds() + lateness;
{ let end_time = ((item.end_time * 1_000_000_000.0) as u64).nseconds() + lateness;
let buf = buf.get_mut().unwrap();
if state.discont { /* Should be sent now */
buf.set_flags(gst::BufferFlags::DISCONT); gst::debug!(
state.discont = false; CAT,
} imp: self,
"Item is ready for queuing: {content}, PTS {start_time}",
);
buf.set_pts(start_time); let mut buf = gst::Buffer::from_mut_slice(content.to_string().into_bytes());
buf.set_duration(end_time - start_time); {
let buf = buf.get_mut().unwrap();
if state.discont {
buf.set_flags(gst::BufferFlags::DISCONT);
state.discont = false;
} }
state.partial_index += 1; buf.set_pts(start_time);
buf.set_duration(end_time - start_time);
state.buffers.push_back(buf);
} else {
gst::debug!(CAT, imp: self, "None transcript item content");
} }
state.partial_index += 1;
state.buffers.push_back(buf);
enqueued = true;
} }
if !partial { if !partial {
state.partial_index = 0; state.partial_index = 0;
} }
enqueued
} }
fn pad_loop_fn(&self, receiver: &mut mpsc::Receiver<model::TranscriptEvent>) -> Result<(), ()> { fn pad_loop_fn(&self, transcript_notif_rx: &mut mpsc::Receiver<()>) {
let mut events = { let mut events = {
let mut events = vec![]; let mut events = vec![];
@ -400,56 +404,24 @@ impl Transcriber {
} }
let future = async move { let future = async move {
enum Winner { let timeout = tokio::time::sleep(GRANULARITY.into()).fuse();
TranscriptEvent(Option<model::TranscriptEvent>), futures::pin_mut!(timeout);
Timeout,
}
let timer = tokio::time::sleep(GRANULARITY.into()).fuse(); futures::select! {
futures::pin_mut!(timer); notif = transcript_notif_rx.next() => {
if notif.is_none() {
let race_res = futures::select_biased! { // Transcriber loop terminated
transcript_evt = receiver.next() => Winner::TranscriptEvent(transcript_evt), self.state.lock().unwrap().send_eos = true;
_ = timer => Winner::Timeout, return;
};
}
_ = timeout => (),
}; };
use Winner::*;
match race_res {
TranscriptEvent(Some(transcript_evt)) => {
if let Some(result) = transcript_evt
.transcript
.as_ref()
.and_then(|transcript| transcript.results())
.and_then(|results| results.get(0))
{
gst::trace!(CAT, imp: self, "Received: {result:?}");
if let Some(alternative) = result
.alternatives
.as_ref()
.and_then(|alternatives| alternatives.get(0))
{
if let Some(items) = alternative.items() {
let mut state = self.state.lock().unwrap();
self.enqueue(&mut state, items, result.is_partial)
}
}
}
}
TranscriptEvent(None) => {
gst::info!(CAT, imp: self, "Transcript evt channel disconnected");
// Something bad happened elsewhere, let the other side report.
return Err(());
}
Timeout => (),
}
if !self.dequeue() { if !self.dequeue() {
gst::info!(CAT, imp: self, "Failed to dequeue buffer, pausing"); gst::info!(CAT, imp: self, "Failed to dequeue buffer, pausing");
let _ = self.srcpad.pause_task(); let _ = self.srcpad.pause_task();
} }
Ok(())
}; };
let _enter = RUNTIME.enter(); let _enter = RUNTIME.enter();
@ -459,24 +431,19 @@ impl Transcriber {
fn start_task(&self) -> Result<(), gst::LoggableError> { fn start_task(&self) -> Result<(), gst::LoggableError> {
let mut state = self.state.lock().unwrap(); let mut state = self.state.lock().unwrap();
let (transcript_tx, mut transcript_rx) = mpsc::channel(1); let (transcript_notif_tx, mut transcript_notif_rx) = mpsc::channel(1);
let imp = self.ref_counted(); let imp = self.ref_counted();
let res = self.srcpad.start_task(move || { let res = self
if imp.pad_loop_fn(&mut transcript_rx).is_err() { .srcpad
// Pad loop fn reported an unrecoverable error. .start_task(move || imp.pad_loop_fn(&mut transcript_notif_rx));
// FIXME we should probably stop the task as
// there's nothing we can do about it except restarting.
let _ = imp.srcpad.pause_task();
}
});
if res.is_err() { if res.is_err() {
state.transcript_tx = None; state.transcript_notif_tx = None;
return Err(gst::loggable_error!(CAT, "Failed to start pad task")); return Err(gst::loggable_error!(CAT, "Failed to start pad task"));
} }
state.transcript_tx = Some(transcript_tx); state.transcript_notif_tx = Some(transcript_notif_tx);
Ok(()) Ok(())
} }
@ -490,7 +457,7 @@ impl Transcriber {
ws_loop_handle.abort(); ws_loop_handle.abort();
} }
state.transcript_tx = None; state.transcript_notif_tx = None;
state.buffer_tx = None; state.buffer_tx = None;
} }
@ -652,7 +619,8 @@ impl Transcriber {
}, },
} }
let (client_stage, transcription_settings, transcript_tx) = { let (client_stage, transcription_settings, lateness, transcript_notif_tx);
{
let mut state = self.state.lock().unwrap(); let mut state = self.state.lock().unwrap();
if let Some(ref ws_loop_handle) = state.ws_loop_handle { if let Some(ref ws_loop_handle) = state.ws_loop_handle {
@ -667,14 +635,15 @@ impl Transcriber {
return Ok(()); return Ok(());
} }
let transcript_tx = state transcript_notif_tx = state
.transcript_tx .transcript_notif_tx
.take() .take()
.expect("attempting to spawn the ws loop, but the srcpad task hasn't been started"); .expect("attempting to spawn the ws loop, but the srcpad task hasn't been started");
let settings = self.settings.lock().unwrap(); let settings = self.settings.lock().unwrap();
if settings.latency + settings.lateness <= 2 * GRANULARITY { lateness = settings.lateness;
if settings.latency + lateness <= 2 * GRANULARITY {
const ERR: &str = "latency + lateness must be greater than 200 milliseconds"; const ERR: &str = "latency + lateness must be greater than 200 milliseconds";
gst::error!(CAT, imp: self, "{ERR}"); gst::error!(CAT, imp: self, "{ERR}");
return Err(gst::error_msg!(gst::LibraryError::Settings, ["{ERR}"])); return Err(gst::error_msg!(gst::LibraryError::Settings, ["{ERR}"]));
@ -684,9 +653,9 @@ impl Transcriber {
let s = in_caps.structure(0).unwrap(); let s = in_caps.structure(0).unwrap();
let sample_rate = s.get::<i32>("rate").unwrap(); let sample_rate = s.get::<i32>("rate").unwrap();
let transcription_settings = TranscriptionSettings::from(&settings, sample_rate); transcription_settings = TranscriptionSettings::from(&settings, sample_rate);
let client_stage = if let Some(client) = state.client.take() { client_stage = if let Some(client) = state.client.take() {
ClientStage::Ready(client) ClientStage::Ready(client)
} else { } else {
ClientStage::NotReady { ClientStage::NotReady {
@ -695,8 +664,6 @@ impl Transcriber {
session_token: settings.session_token.to_owned(), session_token: settings.session_token.to_owned(),
} }
}; };
(client_stage, transcription_settings, transcript_tx)
}; };
let client = match client_stage { let client = match client_stage {
@ -745,8 +712,9 @@ impl Transcriber {
let ws_loop_handle = RUNTIME.spawn(self.build_ws_loop_fut( let ws_loop_handle = RUNTIME.spawn(self.build_ws_loop_fut(
client, client,
transcription_settings, transcription_settings,
lateness,
buffer_rx, buffer_rx,
transcript_tx, transcript_notif_tx,
)); ));
state.ws_loop_handle = Some(ws_loop_handle); state.ws_loop_handle = Some(ws_loop_handle);
@ -759,18 +727,19 @@ impl Transcriber {
&self, &self,
client: aws_transcribe::Client, client: aws_transcribe::Client,
settings: TranscriptionSettings, settings: TranscriptionSettings,
lateness: gst::ClockTime,
buffer_rx: mpsc::Receiver<gst::Buffer>, buffer_rx: mpsc::Receiver<gst::Buffer>,
transcript_tx: mpsc::Sender<model::TranscriptEvent>, transcript_notif_tx: mpsc::Sender<()>,
) -> impl Future<Output = Result<(), gst::ErrorMessage>> { ) -> impl Future<Output = Result<(), gst::ErrorMessage>> {
let imp_weak = self.downgrade(); let imp_weak = self.downgrade();
async move { async move {
use gst::glib::subclass::ObjectImplWeakRef; use gst::glib::subclass::ObjectImplWeakRef;
// Guard that restores client & transcript_tx when the ws loop is done // Guard that restores client & transcript_notif_tx when the ws loop is done
struct Guard { struct Guard {
imp_weak: ObjectImplWeakRef<Transcriber>, imp_weak: ObjectImplWeakRef<Transcriber>,
client: Option<aws_transcribe::Client>, client: Option<aws_transcribe::Client>,
transcript_tx: Option<mpsc::Sender<model::TranscriptEvent>>, transcript_notif_tx: Option<mpsc::Sender<()>>,
} }
impl Guard { impl Guard {
@ -778,8 +747,8 @@ impl Transcriber {
self.client.as_ref().unwrap() self.client.as_ref().unwrap()
} }
fn transcript_tx(&mut self) -> &mut mpsc::Sender<model::TranscriptEvent> { fn transcript_notif_tx(&mut self) -> &mut mpsc::Sender<()> {
self.transcript_tx.as_mut().unwrap() self.transcript_notif_tx.as_mut().unwrap()
} }
} }
@ -788,7 +757,7 @@ impl Transcriber {
if let Some(imp) = self.imp_weak.upgrade() { if let Some(imp) = self.imp_weak.upgrade() {
let mut state = imp.state.lock().unwrap(); let mut state = imp.state.lock().unwrap();
state.client = self.client.take(); state.client = self.client.take();
state.transcript_tx = self.transcript_tx.take(); state.transcript_notif_tx = self.transcript_notif_tx.take();
} }
} }
} }
@ -796,7 +765,7 @@ impl Transcriber {
let mut guard = Guard { let mut guard = Guard {
imp_weak: imp_weak.clone(), imp_weak: imp_weak.clone(),
client: Some(client), client: Some(client),
transcript_tx: Some(transcript_tx), transcript_notif_tx: Some(transcript_notif_tx),
}; };
// Stream the incoming buffers chunked // Stream the incoming buffers chunked
@ -852,9 +821,32 @@ impl Transcriber {
})? })?
{ {
if let model::TranscriptResultStream::TranscriptEvent(transcript_evt) = event { if let model::TranscriptResultStream::TranscriptEvent(transcript_evt) = event {
if guard.transcript_tx().send(transcript_evt).await.is_err() { let mut enqueued = false;
if let Some(result) = transcript_evt
.transcript
.as_ref()
.and_then(|transcript| transcript.results())
.and_then(|results| results.get(0))
{
let Some(imp) = imp_weak.upgrade() else { break };
gst::trace!(CAT, imp: imp, "Received: {result:?}");
if let Some(alternative) = result
.alternatives
.as_ref()
.and_then(|alternatives| alternatives.get(0))
{
if let Some(items) = alternative.items() {
enqueued = imp.enqueue(items, result.is_partial, lateness);
}
}
}
if enqueued && guard.transcript_notif_tx().send(()).await.is_err() {
if let Some(imp) = imp_weak.upgrade() { if let Some(imp) = imp_weak.upgrade() {
gst::debug!(CAT, imp: imp, "Terminated transcript_evt channel"); gst::debug!(CAT, imp: imp, "Terminated transcript_notif_tx channel");
} }
break; break;
} }
@ -882,6 +874,7 @@ impl Transcriber {
let mut state = self.state.lock().unwrap(); let mut state = self.state.lock().unwrap();
gst::info!(CAT, imp: self, "Unpreparing"); gst::info!(CAT, imp: self, "Unpreparing");
self.stop_task(); self.stop_task();
// Also resets discont to true
*state = State::default(); *state = State::default();
gst::info!(CAT, imp: self, "Unprepared"); gst::info!(CAT, imp: self, "Unprepared");
} }