diff --git a/net/rusoto/src/aws_transcriber/imp.rs b/net/rusoto/src/aws_transcriber/imp.rs index 08c6bddc..8dcefa5e 100644 --- a/net/rusoto/src/aws_transcriber/imp.rs +++ b/net/rusoto/src/aws_transcriber/imp.rs @@ -19,7 +19,7 @@ use gst::glib; use gst::prelude::*; use gst::subclass::prelude::*; use gst::{ - element_error, error_msg, gst_debug, gst_error, gst_info, gst_log, gst_warning, loggable_error, + element_error, error_msg, gst_debug, gst_error, gst_info, gst_log, gst_trace, loggable_error, }; use std::default::Default; @@ -44,11 +44,13 @@ use atomic_refcell::AtomicRefCell; use super::packet::*; -use serde_derive::Deserialize; +use serde_derive::{Deserialize, Serialize}; use once_cell::sync::Lazy; -#[derive(Deserialize, Debug)] +use super::AwsTranscriberResultStability; + +#[derive(Deserialize, Serialize, Debug)] #[serde(rename_all = "PascalCase")] struct TranscriptItem { content: String, @@ -56,16 +58,17 @@ struct TranscriptItem { start_time: f32, #[serde(rename = "Type")] type_: String, + stable: bool, } -#[derive(Deserialize, Debug)] +#[derive(Deserialize, Serialize, Debug)] #[serde(rename_all = "PascalCase")] struct TranscriptAlternative { items: Vec, transcript: String, } -#[derive(Deserialize, Debug)] +#[derive(Deserialize, Serialize, Debug)] #[serde(rename_all = "PascalCase")] struct TranscriptResult { alternatives: Vec, @@ -110,16 +113,16 @@ static RUNTIME: Lazy = Lazy::new(|| { }); const DEFAULT_LATENCY: gst::ClockTime = gst::ClockTime::from_seconds(8); -const DEFAULT_USE_PARTIAL_RESULTS: bool = true; +const DEFAULT_STABILITY: AwsTranscriberResultStability = AwsTranscriberResultStability::Low; const GRANULARITY: gst::ClockTime = gst::ClockTime::from_mseconds(100); #[derive(Debug, Clone)] struct Settings { latency: gst::ClockTime, language_code: Option, - use_partial_results: bool, vocabulary: Option, session_id: Option, + results_stability: AwsTranscriberResultStability, } impl Default for Settings { @@ -127,9 +130,9 @@ impl Default for Settings { Self { latency: DEFAULT_LATENCY, language_code: Some("en-US".to_string()), - use_partial_results: DEFAULT_USE_PARTIAL_RESULTS, vocabulary: None, session_id: None, + results_stability: DEFAULT_STABILITY, } } } @@ -145,8 +148,7 @@ struct State { buffers: VecDeque, send_eos: bool, discont: bool, - last_partial_end_time: Option, - partial_alternative: Option, + partial_index: usize, } impl Default for State { @@ -162,8 +164,7 @@ impl Default for State { buffers: VecDeque::new(), send_eos: false, discont: true, - last_partial_end_time: None, - partial_alternative: None, + partial_index: 0, } } } @@ -207,15 +208,12 @@ impl Transcriber { let (latency, now, mut last_position, send_eos, seqnum) = { let mut state = self.state.lock().unwrap(); + // Multiply GRANULARITY by 2 in order to not send buffers that // are less than GRANULARITY away too late let latency = self.settings.lock().unwrap().latency - 2 * GRANULARITY; let now = element.current_running_time(); - if let Some(alternative) = state.partial_alternative.take() { - self.enqueue(element, &mut state, &alternative, true, latency, now); - state.partial_alternative = Some(alternative); - } let send_eos = state.send_eos && state.buffers.is_empty(); while let Some(buf) = state.buffers.front() { @@ -261,7 +259,7 @@ impl Transcriber { .duration(delta) .seqnum(seqnum) .build(); - gst_debug!( + gst_log!( CAT, "Pushing gap: {} -> {}", last_pos, @@ -279,7 +277,7 @@ impl Transcriber { let buf = buf.get_mut().unwrap(); buf.set_pts(buf.pts()); } - gst_debug!( + gst_log!( CAT, "Pushing buffer: {} -> {}", buf.pts().display(), @@ -306,7 +304,7 @@ impl Transcriber { .seqnum(seqnum) .build(); let next_position = last_pos + duration; - gst_debug!(CAT, "Pushing gap: {} -> {}", last_pos, next_position,); + gst_log!(CAT, "Pushing gap: {} -> {}", last_pos, next_position,); last_position = Some(next_position); if !self.srcpad.push_event(gap_event) { return false; @@ -328,68 +326,62 @@ impl Transcriber { state: &mut State, alternative: &TranscriptAlternative, partial: bool, - latency: gst::ClockTime, - now: impl Into> + Copy, ) { - for item in &alternative.items { + for item in &alternative.items[state.partial_index..] { let mut start_time = gst::ClockTime::from_nseconds((item.start_time as f64 * 1_000_000_000.0) as u64); let mut end_time = gst::ClockTime::from_nseconds((item.end_time as f64 * 1_000_000_000.0) as u64); - if state - .last_partial_end_time - .map_or(false, |last_partial_end_time| { - start_time <= last_partial_end_time - }) - { - /* Already sent (hopefully) */ - continue; - } else if !partial || now.into().map_or(false, |now| start_time + latency < now) { - /* Should be sent now */ - gst_debug!(CAT, obj: element, "Item is ready: {}", item.content); - let mut buf = gst::Buffer::from_mut_slice(item.content.clone().into_bytes()); - state.last_partial_end_time = Some(end_time); - - { - let buf = buf.get_mut().unwrap(); - - if state.discont { - buf.set_flags(gst::BufferFlags::DISCONT); - state.discont = false; - } - - if state - .out_segment - .position() - .map_or(false, |pos| start_time < pos) - { - let pos = state - .out_segment - .position() - .expect("position checked above"); - gst_debug!( - CAT, - obj: element, - "Adjusting item timing({} < {})", - start_time, - pos, - ); - start_time = pos; - if end_time < start_time { - end_time = start_time; - } - } - - buf.set_pts(start_time); - buf.set_duration(end_time - start_time); - } - - state.buffers.push_back(buf); - } else { - /* Doesn't need to be sent yet */ + if !item.stable { break; } + + /* Should be sent now */ + gst_debug!(CAT, obj: element, "Item is ready: {}", item.content); + let mut buf = gst::Buffer::from_mut_slice(item.content.clone().into_bytes()); + + { + let buf = buf.get_mut().unwrap(); + + if state.discont { + buf.set_flags(gst::BufferFlags::DISCONT); + state.discont = false; + } + + if state + .out_segment + .position() + .map_or(false, |pos| start_time < pos) + { + let pos = state + .out_segment + .position() + .expect("position checked above"); + gst_debug!( + CAT, + obj: element, + "Adjusting item timing({} < {})", + start_time, + pos, + ); + start_time = pos; + if end_time < start_time { + end_time = start_time; + } + } + + buf.set_pts(start_time); + buf.set_duration(end_time - start_time); + } + + state.partial_index += 1; + + state.buffers.push_back(buf); + } + + if !partial { + state.partial_index = 0; } } @@ -448,94 +440,28 @@ impl Transcriber { )); } - let mut transcript: Transcript = - serde_json::from_str(&payload).map_err(|err| { - error_msg!( - gst::StreamError::Failed, - ["Unexpected binary message: {} ({})", payload, err] - ) - })?; + let transcript: Transcript = serde_json::from_str(&payload).map_err(|err| { + error_msg!( + gst::StreamError::Failed, + ["Unexpected binary message: {} ({})", payload, err] + ) + })?; - if !transcript.transcript.results.is_empty() { - let mut result = transcript.transcript.results.remove(0); - let use_partial_results = self.settings.lock().unwrap().use_partial_results; - if !result.is_partial && !result.alternatives.is_empty() { - let alternative = result.alternatives.remove(0); - if !use_partial_results { - gst_info!( - CAT, - obj: element, - "Transcript: {}", - alternative.transcript - ); + if let Some(result) = transcript.transcript.results.get(0) { + gst_trace!( + CAT, + obj: element, + "result: {}", + serde_json::to_string_pretty(&result).unwrap(), + ); - let mut start_time = gst::ClockTime::from_nseconds( - (result.start_time as f64 * 1_000_000_000.0) as u64, - ); - let end_time = gst::ClockTime::from_nseconds( - (result.end_time as f64 * 1_000_000_000.0) as u64, - ); - - let mut state = self.state.lock().unwrap(); - let position = state.out_segment.position(); - - if position.map_or(false, |position| end_time < position) { - let pos = position.expect("position checked above"); - gst_warning!(CAT, obj: element, - "Received transcript is too late by {}, dropping, consider increasing the latency", - pos - start_time); - } else { - if let Some(delta) = - position.and_then(|pos| pos.checked_sub(start_time)) - { - gst_warning!(CAT, obj: element, - "Received transcript is too late by {}, clipping, consider increasing the latency", - delta); - start_time = position.expect("position checked above"); - } - - let mut buf = gst::Buffer::from_mut_slice( - alternative.transcript.into_bytes(), - ); - - { - let buf = buf.get_mut().unwrap(); - - if state.discont { - buf.set_flags(gst::BufferFlags::DISCONT); - state.discont = false; - } - - buf.set_pts(start_time); - buf.set_duration(end_time - start_time); - } - - gst_debug!( - CAT, - obj: element, - "Adding pending buffer: {:?}", - buf - ); - - state.buffers.push_back(buf); - } - } else { - let mut state = self.state.lock().unwrap(); - self.enqueue( - element, - &mut state, - &alternative, - false, - gst::ClockTime::ZERO, - gst::ClockTime::ZERO, - ); - state.partial_alternative = None; - } - } else if !result.alternatives.is_empty() && use_partial_results { + if let Some(alternative) = result.alternatives.get(0) { let mut state = self.state.lock().unwrap(); - state.partial_alternative = Some(result.alternatives.remove(0)); + + self.enqueue(element, &mut state, alternative, result.is_partial) } } + Ok(()) } @@ -673,7 +599,7 @@ impl Transcriber { fn sink_event(&self, pad: &gst::Pad, element: &super::Transcriber, event: gst::Event) -> bool { use gst::EventView; - gst_debug!(CAT, obj: pad, "Handling event {:?}", event); + gst_log!(CAT, obj: pad, "Handling event {:?}", event); match event.view() { EventView::Eos(_) => match self.handle_buffer(pad, element, None) { @@ -811,7 +737,7 @@ impl Transcriber { element: &super::Transcriber, buffer: Option, ) -> Result { - gst_debug!(CAT, obj: element, "Handling {:?}", buffer); + gst_log!(CAT, obj: element, "Handling {:?}", buffer); self.ensure_connection(element).map_err(|err| { element_error!( @@ -902,6 +828,16 @@ impl Transcriber { signed.add_param("session-id", session_id); } + signed.add_param("enable-partial-results-stabilization", "true"); + signed.add_param( + "partial-results-stability", + match settings.results_stability { + AwsTranscriberResultStability::High => "high", + AwsTranscriberResultStability::Medium => "medium", + AwsTranscriberResultStability::Low => "low", + }, + ); + let url = signed.generate_presigned_url(&creds, &std::time::Duration::from_secs(60), true); let (ws, _) = { @@ -1060,13 +996,6 @@ impl ObjectImpl for Transcriber { Some("en-US"), glib::ParamFlags::READWRITE | gst::PARAM_FLAG_MUTABLE_READY, ), - glib::ParamSpec::new_boolean( - "use-partial-results", - "Latency", - "Whether partial results from AWS should be used", - DEFAULT_USE_PARTIAL_RESULTS, - glib::ParamFlags::READWRITE | gst::PARAM_FLAG_MUTABLE_PLAYING, - ), glib::ParamSpec::new_uint( "latency", "Latency", @@ -1092,6 +1021,14 @@ impl ObjectImpl for Transcriber { None, glib::ParamFlags::READWRITE | gst::PARAM_FLAG_MUTABLE_READY, ), + glib::ParamSpec::new_enum( + "results-stability", + "Results stability", + "Defines how fast results should stabilize", + AwsTranscriberResultStability::static_type(), + DEFAULT_STABILITY as i32, + glib::ParamFlags::READWRITE | gst::PARAM_FLAG_MUTABLE_READY, + ), ] }); @@ -1124,10 +1061,6 @@ impl ObjectImpl for Transcriber { value.get::().expect("type checked upstream").into(), ); } - "use-partial-results" => { - let mut settings = self.settings.lock().unwrap(); - settings.use_partial_results = value.get().expect("type checked upstream"); - } "vocabulary-name" => { let mut settings = self.settings.lock().unwrap(); settings.vocabulary = value.get().expect("type checked upstream"); @@ -1136,6 +1069,12 @@ impl ObjectImpl for Transcriber { let mut settings = self.settings.lock().unwrap(); settings.session_id = value.get().expect("type checked upstream"); } + "results-stability" => { + let mut settings = self.settings.lock().unwrap(); + settings.results_stability = value + .get::() + .expect("type checked upstream"); + } _ => unimplemented!(), } } @@ -1150,10 +1089,6 @@ impl ObjectImpl for Transcriber { let settings = self.settings.lock().unwrap(); (settings.latency.mseconds() as u32).to_value() } - "use-partial-results" => { - let settings = self.settings.lock().unwrap(); - settings.use_partial_results.to_value() - } "vocabulary-name" => { let settings = self.settings.lock().unwrap(); settings.vocabulary.to_value() @@ -1162,6 +1097,10 @@ impl ObjectImpl for Transcriber { let settings = self.settings.lock().unwrap(); settings.session_id.to_value() } + "results-stability" => { + let settings = self.settings.lock().unwrap(); + settings.results_stability.to_value() + } _ => unimplemented!(), } } diff --git a/net/rusoto/src/aws_transcriber/mod.rs b/net/rusoto/src/aws_transcriber/mod.rs index b4fdf270..7907d29c 100644 --- a/net/rusoto/src/aws_transcriber/mod.rs +++ b/net/rusoto/src/aws_transcriber/mod.rs @@ -21,6 +21,24 @@ use gst::prelude::*; mod imp; mod packet; +#[derive(Debug, Eq, PartialEq, Ord, PartialOrd, Hash, Clone, Copy, glib::GEnum)] +#[repr(u32)] +#[genum(type_name = "GstAwsTranscriberResultStability")] +pub enum AwsTranscriberResultStability { + #[genum(name = "High: stabilize results as fast as possible", nick = "high")] + High = 0, + #[genum( + name = "Medium: balance between stability and accuracy", + nick = "medium" + )] + Medium = 1, + #[genum( + name = "Low: relatively less stable partial transcription results with higher accuracy", + nick = "low" + )] + Low = 2, +} + glib::wrapper! { pub struct Transcriber(ObjectSubclass) @extends gst::Element, gst::Object; }