awstranscriber: make use of new result stability AWS API option

<https://aws.amazon.com/blogs/machine-learning/amazon-transcribe-now-supports-partial-results-stabilization-for-streaming-audio/>

Amazon seem to have realized the previous iteration of their API
made it difficult to identify items from one result to the next,
which made the element much more complicated than it should have
been. With that new "stability" option, we can enqueue items as
soon as they stabilize, and simply rely on the current index in
the transcript to output them exactly once.

This also means the "use_partial_results" is now useless, as there
will be no difference in accuracy between a non-partial result and
and of its stable items that might have been pushed from previous
partial versions of the result.

The property is removed, instead a new option is exposed to let
users control how fast results should stabilize.

This greatly simplifies the code, and also improves the output as
punctuation doesn't need to be randomly discarded anymore.
This commit is contained in:
Mathieu Duponchelle 2021-06-19 03:27:42 +02:00
parent d6f6f1a777
commit 640ce43fee
2 changed files with 130 additions and 173 deletions

View file

@ -19,7 +19,7 @@ use gst::glib;
use gst::prelude::*;
use gst::subclass::prelude::*;
use gst::{
element_error, error_msg, gst_debug, gst_error, gst_info, gst_log, gst_warning, loggable_error,
element_error, error_msg, gst_debug, gst_error, gst_info, gst_log, gst_trace, loggable_error,
};
use std::default::Default;
@ -44,11 +44,13 @@ use atomic_refcell::AtomicRefCell;
use super::packet::*;
use serde_derive::Deserialize;
use serde_derive::{Deserialize, Serialize};
use once_cell::sync::Lazy;
#[derive(Deserialize, Debug)]
use super::AwsTranscriberResultStability;
#[derive(Deserialize, Serialize, Debug)]
#[serde(rename_all = "PascalCase")]
struct TranscriptItem {
content: String,
@ -56,16 +58,17 @@ struct TranscriptItem {
start_time: f32,
#[serde(rename = "Type")]
type_: String,
stable: bool,
}
#[derive(Deserialize, Debug)]
#[derive(Deserialize, Serialize, Debug)]
#[serde(rename_all = "PascalCase")]
struct TranscriptAlternative {
items: Vec<TranscriptItem>,
transcript: String,
}
#[derive(Deserialize, Debug)]
#[derive(Deserialize, Serialize, Debug)]
#[serde(rename_all = "PascalCase")]
struct TranscriptResult {
alternatives: Vec<TranscriptAlternative>,
@ -110,16 +113,16 @@ static RUNTIME: Lazy<runtime::Runtime> = Lazy::new(|| {
});
const DEFAULT_LATENCY: gst::ClockTime = gst::ClockTime::from_seconds(8);
const DEFAULT_USE_PARTIAL_RESULTS: bool = true;
const DEFAULT_STABILITY: AwsTranscriberResultStability = AwsTranscriberResultStability::Low;
const GRANULARITY: gst::ClockTime = gst::ClockTime::from_mseconds(100);
#[derive(Debug, Clone)]
struct Settings {
latency: gst::ClockTime,
language_code: Option<String>,
use_partial_results: bool,
vocabulary: Option<String>,
session_id: Option<String>,
results_stability: AwsTranscriberResultStability,
}
impl Default for Settings {
@ -127,9 +130,9 @@ impl Default for Settings {
Self {
latency: DEFAULT_LATENCY,
language_code: Some("en-US".to_string()),
use_partial_results: DEFAULT_USE_PARTIAL_RESULTS,
vocabulary: None,
session_id: None,
results_stability: DEFAULT_STABILITY,
}
}
}
@ -145,8 +148,7 @@ struct State {
buffers: VecDeque<gst::Buffer>,
send_eos: bool,
discont: bool,
last_partial_end_time: Option<gst::ClockTime>,
partial_alternative: Option<TranscriptAlternative>,
partial_index: usize,
}
impl Default for State {
@ -162,8 +164,7 @@ impl Default for State {
buffers: VecDeque::new(),
send_eos: false,
discont: true,
last_partial_end_time: None,
partial_alternative: None,
partial_index: 0,
}
}
}
@ -207,15 +208,12 @@ impl Transcriber {
let (latency, now, mut last_position, send_eos, seqnum) = {
let mut state = self.state.lock().unwrap();
// Multiply GRANULARITY by 2 in order to not send buffers that
// are less than GRANULARITY away too late
let latency = self.settings.lock().unwrap().latency - 2 * GRANULARITY;
let now = element.current_running_time();
if let Some(alternative) = state.partial_alternative.take() {
self.enqueue(element, &mut state, &alternative, true, latency, now);
state.partial_alternative = Some(alternative);
}
let send_eos = state.send_eos && state.buffers.is_empty();
while let Some(buf) = state.buffers.front() {
@ -261,7 +259,7 @@ impl Transcriber {
.duration(delta)
.seqnum(seqnum)
.build();
gst_debug!(
gst_log!(
CAT,
"Pushing gap: {} -> {}",
last_pos,
@ -279,7 +277,7 @@ impl Transcriber {
let buf = buf.get_mut().unwrap();
buf.set_pts(buf.pts());
}
gst_debug!(
gst_log!(
CAT,
"Pushing buffer: {} -> {}",
buf.pts().display(),
@ -306,7 +304,7 @@ impl Transcriber {
.seqnum(seqnum)
.build();
let next_position = last_pos + duration;
gst_debug!(CAT, "Pushing gap: {} -> {}", last_pos, next_position,);
gst_log!(CAT, "Pushing gap: {} -> {}", last_pos, next_position,);
last_position = Some(next_position);
if !self.srcpad.push_event(gap_event) {
return false;
@ -328,68 +326,62 @@ impl Transcriber {
state: &mut State,
alternative: &TranscriptAlternative,
partial: bool,
latency: gst::ClockTime,
now: impl Into<Option<gst::ClockTime>> + Copy,
) {
for item in &alternative.items {
for item in &alternative.items[state.partial_index..] {
let mut start_time =
gst::ClockTime::from_nseconds((item.start_time as f64 * 1_000_000_000.0) as u64);
let mut end_time =
gst::ClockTime::from_nseconds((item.end_time as f64 * 1_000_000_000.0) as u64);
if state
.last_partial_end_time
.map_or(false, |last_partial_end_time| {
start_time <= last_partial_end_time
})
{
/* Already sent (hopefully) */
continue;
} else if !partial || now.into().map_or(false, |now| start_time + latency < now) {
/* Should be sent now */
gst_debug!(CAT, obj: element, "Item is ready: {}", item.content);
let mut buf = gst::Buffer::from_mut_slice(item.content.clone().into_bytes());
state.last_partial_end_time = Some(end_time);
{
let buf = buf.get_mut().unwrap();
if state.discont {
buf.set_flags(gst::BufferFlags::DISCONT);
state.discont = false;
}
if state
.out_segment
.position()
.map_or(false, |pos| start_time < pos)
{
let pos = state
.out_segment
.position()
.expect("position checked above");
gst_debug!(
CAT,
obj: element,
"Adjusting item timing({} < {})",
start_time,
pos,
);
start_time = pos;
if end_time < start_time {
end_time = start_time;
}
}
buf.set_pts(start_time);
buf.set_duration(end_time - start_time);
}
state.buffers.push_back(buf);
} else {
/* Doesn't need to be sent yet */
if !item.stable {
break;
}
/* Should be sent now */
gst_debug!(CAT, obj: element, "Item is ready: {}", item.content);
let mut buf = gst::Buffer::from_mut_slice(item.content.clone().into_bytes());
{
let buf = buf.get_mut().unwrap();
if state.discont {
buf.set_flags(gst::BufferFlags::DISCONT);
state.discont = false;
}
if state
.out_segment
.position()
.map_or(false, |pos| start_time < pos)
{
let pos = state
.out_segment
.position()
.expect("position checked above");
gst_debug!(
CAT,
obj: element,
"Adjusting item timing({} < {})",
start_time,
pos,
);
start_time = pos;
if end_time < start_time {
end_time = start_time;
}
}
buf.set_pts(start_time);
buf.set_duration(end_time - start_time);
}
state.partial_index += 1;
state.buffers.push_back(buf);
}
if !partial {
state.partial_index = 0;
}
}
@ -448,94 +440,28 @@ impl Transcriber {
));
}
let mut transcript: Transcript =
serde_json::from_str(&payload).map_err(|err| {
error_msg!(
gst::StreamError::Failed,
["Unexpected binary message: {} ({})", payload, err]
)
})?;
let transcript: Transcript = serde_json::from_str(&payload).map_err(|err| {
error_msg!(
gst::StreamError::Failed,
["Unexpected binary message: {} ({})", payload, err]
)
})?;
if !transcript.transcript.results.is_empty() {
let mut result = transcript.transcript.results.remove(0);
let use_partial_results = self.settings.lock().unwrap().use_partial_results;
if !result.is_partial && !result.alternatives.is_empty() {
let alternative = result.alternatives.remove(0);
if !use_partial_results {
gst_info!(
CAT,
obj: element,
"Transcript: {}",
alternative.transcript
);
if let Some(result) = transcript.transcript.results.get(0) {
gst_trace!(
CAT,
obj: element,
"result: {}",
serde_json::to_string_pretty(&result).unwrap(),
);
let mut start_time = gst::ClockTime::from_nseconds(
(result.start_time as f64 * 1_000_000_000.0) as u64,
);
let end_time = gst::ClockTime::from_nseconds(
(result.end_time as f64 * 1_000_000_000.0) as u64,
);
let mut state = self.state.lock().unwrap();
let position = state.out_segment.position();
if position.map_or(false, |position| end_time < position) {
let pos = position.expect("position checked above");
gst_warning!(CAT, obj: element,
"Received transcript is too late by {}, dropping, consider increasing the latency",
pos - start_time);
} else {
if let Some(delta) =
position.and_then(|pos| pos.checked_sub(start_time))
{
gst_warning!(CAT, obj: element,
"Received transcript is too late by {}, clipping, consider increasing the latency",
delta);
start_time = position.expect("position checked above");
}
let mut buf = gst::Buffer::from_mut_slice(
alternative.transcript.into_bytes(),
);
{
let buf = buf.get_mut().unwrap();
if state.discont {
buf.set_flags(gst::BufferFlags::DISCONT);
state.discont = false;
}
buf.set_pts(start_time);
buf.set_duration(end_time - start_time);
}
gst_debug!(
CAT,
obj: element,
"Adding pending buffer: {:?}",
buf
);
state.buffers.push_back(buf);
}
} else {
let mut state = self.state.lock().unwrap();
self.enqueue(
element,
&mut state,
&alternative,
false,
gst::ClockTime::ZERO,
gst::ClockTime::ZERO,
);
state.partial_alternative = None;
}
} else if !result.alternatives.is_empty() && use_partial_results {
if let Some(alternative) = result.alternatives.get(0) {
let mut state = self.state.lock().unwrap();
state.partial_alternative = Some(result.alternatives.remove(0));
self.enqueue(element, &mut state, alternative, result.is_partial)
}
}
Ok(())
}
@ -673,7 +599,7 @@ impl Transcriber {
fn sink_event(&self, pad: &gst::Pad, element: &super::Transcriber, event: gst::Event) -> bool {
use gst::EventView;
gst_debug!(CAT, obj: pad, "Handling event {:?}", event);
gst_log!(CAT, obj: pad, "Handling event {:?}", event);
match event.view() {
EventView::Eos(_) => match self.handle_buffer(pad, element, None) {
@ -811,7 +737,7 @@ impl Transcriber {
element: &super::Transcriber,
buffer: Option<gst::Buffer>,
) -> Result<gst::FlowSuccess, gst::FlowError> {
gst_debug!(CAT, obj: element, "Handling {:?}", buffer);
gst_log!(CAT, obj: element, "Handling {:?}", buffer);
self.ensure_connection(element).map_err(|err| {
element_error!(
@ -902,6 +828,16 @@ impl Transcriber {
signed.add_param("session-id", session_id);
}
signed.add_param("enable-partial-results-stabilization", "true");
signed.add_param(
"partial-results-stability",
match settings.results_stability {
AwsTranscriberResultStability::High => "high",
AwsTranscriberResultStability::Medium => "medium",
AwsTranscriberResultStability::Low => "low",
},
);
let url = signed.generate_presigned_url(&creds, &std::time::Duration::from_secs(60), true);
let (ws, _) = {
@ -1060,13 +996,6 @@ impl ObjectImpl for Transcriber {
Some("en-US"),
glib::ParamFlags::READWRITE | gst::PARAM_FLAG_MUTABLE_READY,
),
glib::ParamSpec::new_boolean(
"use-partial-results",
"Latency",
"Whether partial results from AWS should be used",
DEFAULT_USE_PARTIAL_RESULTS,
glib::ParamFlags::READWRITE | gst::PARAM_FLAG_MUTABLE_PLAYING,
),
glib::ParamSpec::new_uint(
"latency",
"Latency",
@ -1092,6 +1021,14 @@ impl ObjectImpl for Transcriber {
None,
glib::ParamFlags::READWRITE | gst::PARAM_FLAG_MUTABLE_READY,
),
glib::ParamSpec::new_enum(
"results-stability",
"Results stability",
"Defines how fast results should stabilize",
AwsTranscriberResultStability::static_type(),
DEFAULT_STABILITY as i32,
glib::ParamFlags::READWRITE | gst::PARAM_FLAG_MUTABLE_READY,
),
]
});
@ -1124,10 +1061,6 @@ impl ObjectImpl for Transcriber {
value.get::<u32>().expect("type checked upstream").into(),
);
}
"use-partial-results" => {
let mut settings = self.settings.lock().unwrap();
settings.use_partial_results = value.get().expect("type checked upstream");
}
"vocabulary-name" => {
let mut settings = self.settings.lock().unwrap();
settings.vocabulary = value.get().expect("type checked upstream");
@ -1136,6 +1069,12 @@ impl ObjectImpl for Transcriber {
let mut settings = self.settings.lock().unwrap();
settings.session_id = value.get().expect("type checked upstream");
}
"results-stability" => {
let mut settings = self.settings.lock().unwrap();
settings.results_stability = value
.get::<AwsTranscriberResultStability>()
.expect("type checked upstream");
}
_ => unimplemented!(),
}
}
@ -1150,10 +1089,6 @@ impl ObjectImpl for Transcriber {
let settings = self.settings.lock().unwrap();
(settings.latency.mseconds() as u32).to_value()
}
"use-partial-results" => {
let settings = self.settings.lock().unwrap();
settings.use_partial_results.to_value()
}
"vocabulary-name" => {
let settings = self.settings.lock().unwrap();
settings.vocabulary.to_value()
@ -1162,6 +1097,10 @@ impl ObjectImpl for Transcriber {
let settings = self.settings.lock().unwrap();
settings.session_id.to_value()
}
"results-stability" => {
let settings = self.settings.lock().unwrap();
settings.results_stability.to_value()
}
_ => unimplemented!(),
}
}

View file

@ -21,6 +21,24 @@ use gst::prelude::*;
mod imp;
mod packet;
#[derive(Debug, Eq, PartialEq, Ord, PartialOrd, Hash, Clone, Copy, glib::GEnum)]
#[repr(u32)]
#[genum(type_name = "GstAwsTranscriberResultStability")]
pub enum AwsTranscriberResultStability {
#[genum(name = "High: stabilize results as fast as possible", nick = "high")]
High = 0,
#[genum(
name = "Medium: balance between stability and accuracy",
nick = "medium"
)]
Medium = 1,
#[genum(
name = "Low: relatively less stable partial transcription results with higher accuracy",
nick = "low"
)]
Low = 2,
}
glib::wrapper! {
pub struct Transcriber(ObjectSubclass<imp::Transcriber>) @extends gst::Element, gst::Object;
}