diff --git a/audio/transcribe/src/aws_transcribe_parse.rs b/audio/transcribe/src/aws_transcribe_parse.rs index 0656c37b..bd76108c 100644 --- a/audio/transcribe/src/aws_transcribe_parse.rs +++ b/audio/transcribe/src/aws_transcribe_parse.rs @@ -108,10 +108,11 @@ static RUNTIME: Lazy = Lazy::new(|| { .unwrap() }); -const DEFAULT_LATENCY_MS: u32 = 30000; +const DEFAULT_LATENCY_MS: u32 = 8000; +const DEFAULT_USE_PARTIAL_RESULTS: bool = true; const GRANULARITY_MS: u32 = 100; -static PROPERTIES: [subclass::Property; 2] = [ +static PROPERTIES: [subclass::Property; 3] = [ subclass::Property("language-code", |name| { glib::ParamSpec::string( name, @@ -123,12 +124,21 @@ static PROPERTIES: [subclass::Property; 2] = [ glib::ParamFlags::READWRITE, ) }), + subclass::Property("use-partial-results", |name| { + glib::ParamSpec::boolean( + name, + "Latency", + "Whether partial results from AWS should be used", + DEFAULT_USE_PARTIAL_RESULTS, + glib::ParamFlags::READWRITE, + ) + }), subclass::Property("latency", |name| { glib::ParamSpec::uint( name, "Latency", "Amount of milliseconds to allow AWS transcribe", - GRANULARITY_MS, + 2 * GRANULARITY_MS, std::u32::MAX, DEFAULT_LATENCY_MS, glib::ParamFlags::READWRITE, @@ -140,6 +150,7 @@ static PROPERTIES: [subclass::Property; 2] = [ struct Settings { latency_ms: u32, language_code: Option, + use_partial_results: bool, } impl Default for Settings { @@ -147,6 +158,7 @@ impl Default for Settings { Self { latency_ms: DEFAULT_LATENCY_MS, language_code: Some("en-US".to_string()), + use_partial_results: DEFAULT_USE_PARTIAL_RESULTS, } } } @@ -162,6 +174,8 @@ struct State { buffers: VecDeque, send_eos: bool, discont: bool, + last_partial_end_time: gst::ClockTime, + partial_alternative: Option, } impl Default for State { @@ -177,6 +191,8 @@ impl Default for State { buffers: VecDeque::new(), send_eos: false, discont: true, + last_partial_end_time: gst::CLOCK_TIME_NONE, + partial_alternative: None, } } } @@ -257,13 +273,19 @@ impl Transcriber { let (latency, now, mut last_position, send_eos, seqnum) = { let mut state = self.state.lock().unwrap(); - let send_eos = state.send_eos && state.buffers.is_empty(); - + // Multiply GRANULARITY by 2 in order to not send buffers that + // are less than GRANULARITY milliseconds away too late let latency: gst::ClockTime = (self.settings.lock().unwrap().latency_ms as u64 - - GRANULARITY_MS as u64) + - (2 * GRANULARITY_MS) as u64) * gst::MSECOND; let now = element.get_current_running_time(); + if let Some(alternative) = state.partial_alternative.take() { + self.enqueue(element, &mut state, &alternative, true, latency, now); + state.partial_alternative = Some(alternative); + } + let send_eos = state.send_eos && state.buffers.is_empty(); + while let Some(buf) = state.buffers.front() { if now - buf.get_pts() > latency { /* Safe unwrap, we know we have an item */ @@ -352,6 +374,64 @@ impl Transcriber { true } + fn enqueue( + &self, + element: &gst::Element, + state: &mut State, + alternative: &TranscriptAlternative, + partial: bool, + latency: gst::ClockTime, + now: gst::ClockTime, + ) { + for item in &alternative.items { + let mut start_time: gst::ClockTime = + ((item.start_time as f64 * 1_000_000_000.0) as u64).into(); + let mut end_time: gst::ClockTime = + ((item.end_time as f64 * 1_000_000_000.0) as u64).into(); + + if start_time <= state.last_partial_end_time { + /* Already sent (hopefully) */ + continue; + } else if !partial || start_time + latency < now { + /* Should be sent now */ + gst_debug!(CAT, obj: element, "Item is ready: {}", item.content); + let mut buf = gst::Buffer::from_mut_slice(item.content.clone().into_bytes()); + state.last_partial_end_time = end_time; + + { + let buf = buf.get_mut().unwrap(); + + if state.discont { + buf.set_flags(gst::BufferFlags::DISCONT); + state.discont = false; + } + + if start_time < state.out_segment.get_position() { + gst_debug!( + CAT, + obj: element, + "Adjusting item timing({:?} < {:?})", + start_time, + state.out_segment.get_position() + ); + start_time = state.out_segment.get_position(); + if end_time < start_time { + end_time = start_time; + } + } + + buf.set_pts(start_time); + buf.set_duration(end_time - start_time); + } + + state.buffers.push_back(buf); + } else { + /* Doesn't need to be sent yet */ + break; + } + } + } + fn loop_fn( &self, element: &gst::Element, @@ -417,50 +497,78 @@ impl Transcriber { if !transcript.transcript.results.is_empty() { let mut result = transcript.transcript.results.remove(0); + let use_partial_results = self.settings.lock().unwrap().use_partial_results; if !result.is_partial && !result.alternatives.is_empty() { - let alternative = result.alternatives.remove(0); - gst_info!(CAT, obj: element, "Transcript: {}", alternative.transcript); - - let mut start_time: gst::ClockTime = - ((result.start_time as f64 * 1_000_000_000.0) as u64).into(); - let end_time: gst::ClockTime = - ((result.end_time as f64 * 1_000_000_000.0) as u64).into(); - - let mut state = self.state.lock().unwrap(); - let position = state.out_segment.get_position(); - - if end_time < position { - gst_warning!(CAT, obj: element, - "Received transcript is too late by {:?}, dropping, consider increasing the latency", - position - start_time); - } else { - if start_time < position { - gst_warning!(CAT, obj: element, - "Received transcript is too late by {:?}, clipping, consider increasing the latency", - position - start_time); - start_time = position; - } - - let mut buf = gst::Buffer::from_mut_slice( - alternative.transcript.into_bytes(), + if !use_partial_results { + let alternative = result.alternatives.remove(0); + gst_info!( + CAT, + obj: element, + "Transcript: {}", + alternative.transcript ); - { - let buf = buf.get_mut().unwrap(); + let mut start_time: gst::ClockTime = + ((result.start_time as f64 * 1_000_000_000.0) as u64).into(); + let end_time: gst::ClockTime = + ((result.end_time as f64 * 1_000_000_000.0) as u64).into(); - if state.discont { - buf.set_flags(gst::BufferFlags::DISCONT); - state.discont = false; + let mut state = self.state.lock().unwrap(); + let position = state.out_segment.get_position(); + + if end_time < position { + gst_warning!(CAT, obj: element, + "Received transcript is too late by {:?}, dropping, consider increasing the latency", + position - start_time); + } else { + if start_time < position { + gst_warning!(CAT, obj: element, + "Received transcript is too late by {:?}, clipping, consider increasing the latency", + position - start_time); + start_time = position; } - buf.set_pts(start_time); - buf.set_duration(end_time - start_time); + let mut buf = gst::Buffer::from_mut_slice( + alternative.transcript.into_bytes(), + ); + + { + let buf = buf.get_mut().unwrap(); + + if state.discont { + buf.set_flags(gst::BufferFlags::DISCONT); + state.discont = false; + } + + buf.set_pts(start_time); + buf.set_duration(end_time - start_time); + } + + gst_debug!( + CAT, + obj: element, + "Adding pending buffer: {:?}", + buf + ); + + state.buffers.push_back(buf); } - - gst_debug!(CAT, obj: element, "Adding pending buffer: {:?}", buf); - - state.buffers.push_back(buf); + } else { + let alternative = result.alternatives.remove(0); + let mut state = self.state.lock().unwrap(); + self.enqueue( + element, + &mut state, + &alternative, + false, + 0.into(), + 0.into(), + ); + state.partial_alternative = None; } + } else if !result.alternatives.is_empty() && use_partial_results { + let mut state = self.state.lock().unwrap(); + state.partial_alternative = Some(result.alternatives.remove(0)); } } Ok(()) @@ -1001,6 +1109,10 @@ impl ObjectImpl for Transcriber { let mut settings = self.settings.lock().unwrap(); settings.latency_ms = value.get_some().expect("type checked upstream"); } + subclass::Property("use-partial-results", ..) => { + let mut settings = self.settings.lock().unwrap(); + settings.use_partial_results = value.get_some().expect("type checked upstream"); + } _ => unimplemented!(), } } @@ -1017,6 +1129,10 @@ impl ObjectImpl for Transcriber { let settings = self.settings.lock().unwrap(); Ok(settings.latency_ms.to_value()) } + subclass::Property("use-partial-results", ..) => { + let settings = self.settings.lock().unwrap(); + Ok(settings.use_partial_results.to_value()) + } _ => unimplemented!(), } }