From 36ae29d7465e62b0c1f08b713597898074356726 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Laignel?= Date: Tue, 28 Feb 2023 16:28:13 +0100 Subject: [PATCH] net/aws: enqueue transcribed buffers within the ws loop Instead of sending transcription events to the src pad loop, this commit enqueues the transcribed buffers immediately in the ws loop, then notifies the src pad loop. The src pad loop is only in charge of dequeuing the buffers. This should help with upcoming evolutions. Part-of: --- net/aws/src/transcriber/imp.rs | 207 ++++++++++++++++----------------- 1 file changed, 100 insertions(+), 107 deletions(-) diff --git a/net/aws/src/transcriber/imp.rs b/net/aws/src/transcriber/imp.rs index 18e45c08..ae766615 100644 --- a/net/aws/src/transcriber/imp.rs +++ b/net/aws/src/transcriber/imp.rs @@ -66,7 +66,7 @@ struct Settings { vocabulary_filter_method: AwsTranscriberVocabularyFilterMethod, } -impl std::default::Default for Settings { +impl Default for Settings { fn default() -> Self { Self { latency: DEFAULT_LATENCY, @@ -112,26 +112,25 @@ impl TranscriptionSettings { struct State { client: Option, buffer_tx: Option>, - transcript_tx: Option>, + transcript_notif_tx: Option>, ws_loop_handle: Option>>, in_segment: gst::FormattedSegment, out_segment: gst::FormattedSegment, seqnum: gst::Seqnum, buffers: VecDeque, send_eos: bool, - // FIXME never set to true discont: bool, partial_index: usize, send_events: bool, start_time: Option, } -impl std::default::Default for State { +impl Default for State { fn default() -> Self { Self { client: None, buffer_tx: None, - transcript_tx: None, + transcript_notif_tx: None, ws_loop_handle: None, in_segment: gst::FormattedSegment::new(), out_segment: gst::FormattedSegment::new(), @@ -297,8 +296,11 @@ impl Transcriber { true } - fn enqueue(&self, state: &mut State, items: &[model::Item], partial: bool) { - let lateness = self.settings.lock().unwrap().lateness; + /// Enqueues a buffer for each of the provided stable items. + /// + /// Returns `true` if at least one buffer was enqueued. + fn enqueue(&self, items: &[model::Item], partial: bool, lateness: gst::ClockTime) -> bool { + let mut state = self.state.lock().unwrap(); if items.len() <= state.partial_index { gst::error!( @@ -313,53 +315,55 @@ impl Transcriber { state.partial_index = 0; } - return; + return false; } - for item in &items[state.partial_index..] { - let start_time = ((item.start_time * 1_000_000_000.0) as u64).nseconds() + lateness; - let end_time = ((item.end_time * 1_000_000_000.0) as u64).nseconds() + lateness; + let mut enqueued = false; + for item in &items[state.partial_index..] { if !item.stable().unwrap_or(false) { break; } - // FIXME could probably just unwrap - if let Some(content) = item.content() { - /* Should be sent now */ - gst::debug!( - CAT, - imp: self, - "Item is ready for queuing: {content}, PTS {start_time}", - ); + let Some(content) = item.content() else { continue }; - let mut buf = gst::Buffer::from_mut_slice(content.to_string().into_bytes()); - { - let buf = buf.get_mut().unwrap(); + let start_time = ((item.start_time * 1_000_000_000.0) as u64).nseconds() + lateness; + let end_time = ((item.end_time * 1_000_000_000.0) as u64).nseconds() + lateness; - if state.discont { - buf.set_flags(gst::BufferFlags::DISCONT); - state.discont = false; - } + /* Should be sent now */ + gst::debug!( + CAT, + imp: self, + "Item is ready for queuing: {content}, PTS {start_time}", + ); - buf.set_pts(start_time); - buf.set_duration(end_time - start_time); + let mut buf = gst::Buffer::from_mut_slice(content.to_string().into_bytes()); + { + let buf = buf.get_mut().unwrap(); + + if state.discont { + buf.set_flags(gst::BufferFlags::DISCONT); + state.discont = false; } - state.partial_index += 1; - - state.buffers.push_back(buf); - } else { - gst::debug!(CAT, imp: self, "None transcript item content"); + buf.set_pts(start_time); + buf.set_duration(end_time - start_time); } + + state.partial_index += 1; + + state.buffers.push_back(buf); + enqueued = true; } if !partial { state.partial_index = 0; } + + enqueued } - fn pad_loop_fn(&self, receiver: &mut mpsc::Receiver) -> Result<(), ()> { + fn pad_loop_fn(&self, transcript_notif_rx: &mut mpsc::Receiver<()>) { let mut events = { let mut events = vec![]; @@ -400,56 +404,24 @@ impl Transcriber { } let future = async move { - enum Winner { - TranscriptEvent(Option), - Timeout, - } + let timeout = tokio::time::sleep(GRANULARITY.into()).fuse(); + futures::pin_mut!(timeout); - let timer = tokio::time::sleep(GRANULARITY.into()).fuse(); - futures::pin_mut!(timer); - - let race_res = futures::select_biased! { - transcript_evt = receiver.next() => Winner::TranscriptEvent(transcript_evt), - _ = timer => Winner::Timeout, + futures::select! { + notif = transcript_notif_rx.next() => { + if notif.is_none() { + // Transcriber loop terminated + self.state.lock().unwrap().send_eos = true; + return; + }; + } + _ = timeout => (), }; - use Winner::*; - match race_res { - TranscriptEvent(Some(transcript_evt)) => { - if let Some(result) = transcript_evt - .transcript - .as_ref() - .and_then(|transcript| transcript.results()) - .and_then(|results| results.get(0)) - { - gst::trace!(CAT, imp: self, "Received: {result:?}"); - - if let Some(alternative) = result - .alternatives - .as_ref() - .and_then(|alternatives| alternatives.get(0)) - { - if let Some(items) = alternative.items() { - let mut state = self.state.lock().unwrap(); - self.enqueue(&mut state, items, result.is_partial) - } - } - } - } - TranscriptEvent(None) => { - gst::info!(CAT, imp: self, "Transcript evt channel disconnected"); - // Something bad happened elsewhere, let the other side report. - return Err(()); - } - Timeout => (), - } - if !self.dequeue() { gst::info!(CAT, imp: self, "Failed to dequeue buffer, pausing"); let _ = self.srcpad.pause_task(); } - - Ok(()) }; let _enter = RUNTIME.enter(); @@ -459,24 +431,19 @@ impl Transcriber { fn start_task(&self) -> Result<(), gst::LoggableError> { let mut state = self.state.lock().unwrap(); - let (transcript_tx, mut transcript_rx) = mpsc::channel(1); + let (transcript_notif_tx, mut transcript_notif_rx) = mpsc::channel(1); let imp = self.ref_counted(); - let res = self.srcpad.start_task(move || { - if imp.pad_loop_fn(&mut transcript_rx).is_err() { - // Pad loop fn reported an unrecoverable error. - // FIXME we should probably stop the task as - // there's nothing we can do about it except restarting. - let _ = imp.srcpad.pause_task(); - } - }); + let res = self + .srcpad + .start_task(move || imp.pad_loop_fn(&mut transcript_notif_rx)); if res.is_err() { - state.transcript_tx = None; + state.transcript_notif_tx = None; return Err(gst::loggable_error!(CAT, "Failed to start pad task")); } - state.transcript_tx = Some(transcript_tx); + state.transcript_notif_tx = Some(transcript_notif_tx); Ok(()) } @@ -490,7 +457,7 @@ impl Transcriber { ws_loop_handle.abort(); } - state.transcript_tx = None; + state.transcript_notif_tx = None; state.buffer_tx = None; } @@ -652,7 +619,8 @@ impl Transcriber { }, } - let (client_stage, transcription_settings, transcript_tx) = { + let (client_stage, transcription_settings, lateness, transcript_notif_tx); + { let mut state = self.state.lock().unwrap(); if let Some(ref ws_loop_handle) = state.ws_loop_handle { @@ -667,14 +635,15 @@ impl Transcriber { return Ok(()); } - let transcript_tx = state - .transcript_tx + transcript_notif_tx = state + .transcript_notif_tx .take() .expect("attempting to spawn the ws loop, but the srcpad task hasn't been started"); let settings = self.settings.lock().unwrap(); - if settings.latency + settings.lateness <= 2 * GRANULARITY { + lateness = settings.lateness; + if settings.latency + lateness <= 2 * GRANULARITY { const ERR: &str = "latency + lateness must be greater than 200 milliseconds"; gst::error!(CAT, imp: self, "{ERR}"); return Err(gst::error_msg!(gst::LibraryError::Settings, ["{ERR}"])); @@ -684,9 +653,9 @@ impl Transcriber { let s = in_caps.structure(0).unwrap(); let sample_rate = s.get::("rate").unwrap(); - let transcription_settings = TranscriptionSettings::from(&settings, sample_rate); + transcription_settings = TranscriptionSettings::from(&settings, sample_rate); - let client_stage = if let Some(client) = state.client.take() { + client_stage = if let Some(client) = state.client.take() { ClientStage::Ready(client) } else { ClientStage::NotReady { @@ -695,8 +664,6 @@ impl Transcriber { session_token: settings.session_token.to_owned(), } }; - - (client_stage, transcription_settings, transcript_tx) }; let client = match client_stage { @@ -745,8 +712,9 @@ impl Transcriber { let ws_loop_handle = RUNTIME.spawn(self.build_ws_loop_fut( client, transcription_settings, + lateness, buffer_rx, - transcript_tx, + transcript_notif_tx, )); state.ws_loop_handle = Some(ws_loop_handle); @@ -759,18 +727,19 @@ impl Transcriber { &self, client: aws_transcribe::Client, settings: TranscriptionSettings, + lateness: gst::ClockTime, buffer_rx: mpsc::Receiver, - transcript_tx: mpsc::Sender, + transcript_notif_tx: mpsc::Sender<()>, ) -> impl Future> { let imp_weak = self.downgrade(); async move { use gst::glib::subclass::ObjectImplWeakRef; - // Guard that restores client & transcript_tx when the ws loop is done + // Guard that restores client & transcript_notif_tx when the ws loop is done struct Guard { imp_weak: ObjectImplWeakRef, client: Option, - transcript_tx: Option>, + transcript_notif_tx: Option>, } impl Guard { @@ -778,8 +747,8 @@ impl Transcriber { self.client.as_ref().unwrap() } - fn transcript_tx(&mut self) -> &mut mpsc::Sender { - self.transcript_tx.as_mut().unwrap() + fn transcript_notif_tx(&mut self) -> &mut mpsc::Sender<()> { + self.transcript_notif_tx.as_mut().unwrap() } } @@ -788,7 +757,7 @@ impl Transcriber { if let Some(imp) = self.imp_weak.upgrade() { let mut state = imp.state.lock().unwrap(); state.client = self.client.take(); - state.transcript_tx = self.transcript_tx.take(); + state.transcript_notif_tx = self.transcript_notif_tx.take(); } } } @@ -796,7 +765,7 @@ impl Transcriber { let mut guard = Guard { imp_weak: imp_weak.clone(), client: Some(client), - transcript_tx: Some(transcript_tx), + transcript_notif_tx: Some(transcript_notif_tx), }; // Stream the incoming buffers chunked @@ -852,9 +821,32 @@ impl Transcriber { })? { if let model::TranscriptResultStream::TranscriptEvent(transcript_evt) = event { - if guard.transcript_tx().send(transcript_evt).await.is_err() { + let mut enqueued = false; + + if let Some(result) = transcript_evt + .transcript + .as_ref() + .and_then(|transcript| transcript.results()) + .and_then(|results| results.get(0)) + { + let Some(imp) = imp_weak.upgrade() else { break }; + + gst::trace!(CAT, imp: imp, "Received: {result:?}"); + + if let Some(alternative) = result + .alternatives + .as_ref() + .and_then(|alternatives| alternatives.get(0)) + { + if let Some(items) = alternative.items() { + enqueued = imp.enqueue(items, result.is_partial, lateness); + } + } + } + + if enqueued && guard.transcript_notif_tx().send(()).await.is_err() { if let Some(imp) = imp_weak.upgrade() { - gst::debug!(CAT, imp: imp, "Terminated transcript_evt channel"); + gst::debug!(CAT, imp: imp, "Terminated transcript_notif_tx channel"); } break; } @@ -882,6 +874,7 @@ impl Transcriber { let mut state = self.state.lock().unwrap(); gst::info!(CAT, imp: self, "Unpreparing"); self.stop_task(); + // Also resets discont to true *state = State::default(); gst::info!(CAT, imp: self, "Unprepared"); }