diff --git a/net/aws/src/transcriber/imp.rs b/net/aws/src/transcriber/imp.rs index 18e45c08..ae766615 100644 --- a/net/aws/src/transcriber/imp.rs +++ b/net/aws/src/transcriber/imp.rs @@ -66,7 +66,7 @@ struct Settings { vocabulary_filter_method: AwsTranscriberVocabularyFilterMethod, } -impl std::default::Default for Settings { +impl Default for Settings { fn default() -> Self { Self { latency: DEFAULT_LATENCY, @@ -112,26 +112,25 @@ impl TranscriptionSettings { struct State { client: Option, buffer_tx: Option>, - transcript_tx: Option>, + transcript_notif_tx: Option>, ws_loop_handle: Option>>, in_segment: gst::FormattedSegment, out_segment: gst::FormattedSegment, seqnum: gst::Seqnum, buffers: VecDeque, send_eos: bool, - // FIXME never set to true discont: bool, partial_index: usize, send_events: bool, start_time: Option, } -impl std::default::Default for State { +impl Default for State { fn default() -> Self { Self { client: None, buffer_tx: None, - transcript_tx: None, + transcript_notif_tx: None, ws_loop_handle: None, in_segment: gst::FormattedSegment::new(), out_segment: gst::FormattedSegment::new(), @@ -297,8 +296,11 @@ impl Transcriber { true } - fn enqueue(&self, state: &mut State, items: &[model::Item], partial: bool) { - let lateness = self.settings.lock().unwrap().lateness; + /// Enqueues a buffer for each of the provided stable items. + /// + /// Returns `true` if at least one buffer was enqueued. + fn enqueue(&self, items: &[model::Item], partial: bool, lateness: gst::ClockTime) -> bool { + let mut state = self.state.lock().unwrap(); if items.len() <= state.partial_index { gst::error!( @@ -313,53 +315,55 @@ impl Transcriber { state.partial_index = 0; } - return; + return false; } - for item in &items[state.partial_index..] { - let start_time = ((item.start_time * 1_000_000_000.0) as u64).nseconds() + lateness; - let end_time = ((item.end_time * 1_000_000_000.0) as u64).nseconds() + lateness; + let mut enqueued = false; + for item in &items[state.partial_index..] { if !item.stable().unwrap_or(false) { break; } - // FIXME could probably just unwrap - if let Some(content) = item.content() { - /* Should be sent now */ - gst::debug!( - CAT, - imp: self, - "Item is ready for queuing: {content}, PTS {start_time}", - ); + let Some(content) = item.content() else { continue }; - let mut buf = gst::Buffer::from_mut_slice(content.to_string().into_bytes()); - { - let buf = buf.get_mut().unwrap(); + let start_time = ((item.start_time * 1_000_000_000.0) as u64).nseconds() + lateness; + let end_time = ((item.end_time * 1_000_000_000.0) as u64).nseconds() + lateness; - if state.discont { - buf.set_flags(gst::BufferFlags::DISCONT); - state.discont = false; - } + /* Should be sent now */ + gst::debug!( + CAT, + imp: self, + "Item is ready for queuing: {content}, PTS {start_time}", + ); - buf.set_pts(start_time); - buf.set_duration(end_time - start_time); + let mut buf = gst::Buffer::from_mut_slice(content.to_string().into_bytes()); + { + let buf = buf.get_mut().unwrap(); + + if state.discont { + buf.set_flags(gst::BufferFlags::DISCONT); + state.discont = false; } - state.partial_index += 1; - - state.buffers.push_back(buf); - } else { - gst::debug!(CAT, imp: self, "None transcript item content"); + buf.set_pts(start_time); + buf.set_duration(end_time - start_time); } + + state.partial_index += 1; + + state.buffers.push_back(buf); + enqueued = true; } if !partial { state.partial_index = 0; } + + enqueued } - fn pad_loop_fn(&self, receiver: &mut mpsc::Receiver) -> Result<(), ()> { + fn pad_loop_fn(&self, transcript_notif_rx: &mut mpsc::Receiver<()>) { let mut events = { let mut events = vec![]; @@ -400,56 +404,24 @@ impl Transcriber { } let future = async move { - enum Winner { - TranscriptEvent(Option), - Timeout, - } + let timeout = tokio::time::sleep(GRANULARITY.into()).fuse(); + futures::pin_mut!(timeout); - let timer = tokio::time::sleep(GRANULARITY.into()).fuse(); - futures::pin_mut!(timer); - - let race_res = futures::select_biased! { - transcript_evt = receiver.next() => Winner::TranscriptEvent(transcript_evt), - _ = timer => Winner::Timeout, + futures::select! { + notif = transcript_notif_rx.next() => { + if notif.is_none() { + // Transcriber loop terminated + self.state.lock().unwrap().send_eos = true; + return; + }; + } + _ = timeout => (), }; - use Winner::*; - match race_res { - TranscriptEvent(Some(transcript_evt)) => { - if let Some(result) = transcript_evt - .transcript - .as_ref() - .and_then(|transcript| transcript.results()) - .and_then(|results| results.get(0)) - { - gst::trace!(CAT, imp: self, "Received: {result:?}"); - - if let Some(alternative) = result - .alternatives - .as_ref() - .and_then(|alternatives| alternatives.get(0)) - { - if let Some(items) = alternative.items() { - let mut state = self.state.lock().unwrap(); - self.enqueue(&mut state, items, result.is_partial) - } - } - } - } - TranscriptEvent(None) => { - gst::info!(CAT, imp: self, "Transcript evt channel disconnected"); - // Something bad happened elsewhere, let the other side report. - return Err(()); - } - Timeout => (), - } - if !self.dequeue() { gst::info!(CAT, imp: self, "Failed to dequeue buffer, pausing"); let _ = self.srcpad.pause_task(); } - - Ok(()) }; let _enter = RUNTIME.enter(); @@ -459,24 +431,19 @@ impl Transcriber { fn start_task(&self) -> Result<(), gst::LoggableError> { let mut state = self.state.lock().unwrap(); - let (transcript_tx, mut transcript_rx) = mpsc::channel(1); + let (transcript_notif_tx, mut transcript_notif_rx) = mpsc::channel(1); let imp = self.ref_counted(); - let res = self.srcpad.start_task(move || { - if imp.pad_loop_fn(&mut transcript_rx).is_err() { - // Pad loop fn reported an unrecoverable error. - // FIXME we should probably stop the task as - // there's nothing we can do about it except restarting. - let _ = imp.srcpad.pause_task(); - } - }); + let res = self + .srcpad + .start_task(move || imp.pad_loop_fn(&mut transcript_notif_rx)); if res.is_err() { - state.transcript_tx = None; + state.transcript_notif_tx = None; return Err(gst::loggable_error!(CAT, "Failed to start pad task")); } - state.transcript_tx = Some(transcript_tx); + state.transcript_notif_tx = Some(transcript_notif_tx); Ok(()) } @@ -490,7 +457,7 @@ impl Transcriber { ws_loop_handle.abort(); } - state.transcript_tx = None; + state.transcript_notif_tx = None; state.buffer_tx = None; } @@ -652,7 +619,8 @@ impl Transcriber { }, } - let (client_stage, transcription_settings, transcript_tx) = { + let (client_stage, transcription_settings, lateness, transcript_notif_tx); + { let mut state = self.state.lock().unwrap(); if let Some(ref ws_loop_handle) = state.ws_loop_handle { @@ -667,14 +635,15 @@ impl Transcriber { return Ok(()); } - let transcript_tx = state - .transcript_tx + transcript_notif_tx = state + .transcript_notif_tx .take() .expect("attempting to spawn the ws loop, but the srcpad task hasn't been started"); let settings = self.settings.lock().unwrap(); - if settings.latency + settings.lateness <= 2 * GRANULARITY { + lateness = settings.lateness; + if settings.latency + lateness <= 2 * GRANULARITY { const ERR: &str = "latency + lateness must be greater than 200 milliseconds"; gst::error!(CAT, imp: self, "{ERR}"); return Err(gst::error_msg!(gst::LibraryError::Settings, ["{ERR}"])); @@ -684,9 +653,9 @@ impl Transcriber { let s = in_caps.structure(0).unwrap(); let sample_rate = s.get::("rate").unwrap(); - let transcription_settings = TranscriptionSettings::from(&settings, sample_rate); + transcription_settings = TranscriptionSettings::from(&settings, sample_rate); - let client_stage = if let Some(client) = state.client.take() { + client_stage = if let Some(client) = state.client.take() { ClientStage::Ready(client) } else { ClientStage::NotReady { @@ -695,8 +664,6 @@ impl Transcriber { session_token: settings.session_token.to_owned(), } }; - - (client_stage, transcription_settings, transcript_tx) }; let client = match client_stage { @@ -745,8 +712,9 @@ impl Transcriber { let ws_loop_handle = RUNTIME.spawn(self.build_ws_loop_fut( client, transcription_settings, + lateness, buffer_rx, - transcript_tx, + transcript_notif_tx, )); state.ws_loop_handle = Some(ws_loop_handle); @@ -759,18 +727,19 @@ impl Transcriber { &self, client: aws_transcribe::Client, settings: TranscriptionSettings, + lateness: gst::ClockTime, buffer_rx: mpsc::Receiver, - transcript_tx: mpsc::Sender, + transcript_notif_tx: mpsc::Sender<()>, ) -> impl Future> { let imp_weak = self.downgrade(); async move { use gst::glib::subclass::ObjectImplWeakRef; - // Guard that restores client & transcript_tx when the ws loop is done + // Guard that restores client & transcript_notif_tx when the ws loop is done struct Guard { imp_weak: ObjectImplWeakRef, client: Option, - transcript_tx: Option>, + transcript_notif_tx: Option>, } impl Guard { @@ -778,8 +747,8 @@ impl Transcriber { self.client.as_ref().unwrap() } - fn transcript_tx(&mut self) -> &mut mpsc::Sender { - self.transcript_tx.as_mut().unwrap() + fn transcript_notif_tx(&mut self) -> &mut mpsc::Sender<()> { + self.transcript_notif_tx.as_mut().unwrap() } } @@ -788,7 +757,7 @@ impl Transcriber { if let Some(imp) = self.imp_weak.upgrade() { let mut state = imp.state.lock().unwrap(); state.client = self.client.take(); - state.transcript_tx = self.transcript_tx.take(); + state.transcript_notif_tx = self.transcript_notif_tx.take(); } } } @@ -796,7 +765,7 @@ impl Transcriber { let mut guard = Guard { imp_weak: imp_weak.clone(), client: Some(client), - transcript_tx: Some(transcript_tx), + transcript_notif_tx: Some(transcript_notif_tx), }; // Stream the incoming buffers chunked @@ -852,9 +821,32 @@ impl Transcriber { })? { if let model::TranscriptResultStream::TranscriptEvent(transcript_evt) = event { - if guard.transcript_tx().send(transcript_evt).await.is_err() { + let mut enqueued = false; + + if let Some(result) = transcript_evt + .transcript + .as_ref() + .and_then(|transcript| transcript.results()) + .and_then(|results| results.get(0)) + { + let Some(imp) = imp_weak.upgrade() else { break }; + + gst::trace!(CAT, imp: imp, "Received: {result:?}"); + + if let Some(alternative) = result + .alternatives + .as_ref() + .and_then(|alternatives| alternatives.get(0)) + { + if let Some(items) = alternative.items() { + enqueued = imp.enqueue(items, result.is_partial, lateness); + } + } + } + + if enqueued && guard.transcript_notif_tx().send(()).await.is_err() { if let Some(imp) = imp_weak.upgrade() { - gst::debug!(CAT, imp: imp, "Terminated transcript_evt channel"); + gst::debug!(CAT, imp: imp, "Terminated transcript_notif_tx channel"); } break; } @@ -882,6 +874,7 @@ impl Transcriber { let mut state = self.state.lock().unwrap(); gst::info!(CAT, imp: self, "Unpreparing"); self.stop_task(); + // Also resets discont to true *state = State::default(); gst::info!(CAT, imp: self, "Unprepared"); }