mirror of
https://gitlab.freedesktop.org/gstreamer/gst-plugins-rs.git
synced 2024-11-25 13:01:07 +00:00
awstranscriber: make use of new result stability AWS API option
<https://aws.amazon.com/blogs/machine-learning/amazon-transcribe-now-supports-partial-results-stabilization-for-streaming-audio/> Amazon seem to have realized the previous iteration of their API made it difficult to identify items from one result to the next, which made the element much more complicated than it should have been. With that new "stability" option, we can enqueue items as soon as they stabilize, and simply rely on the current index in the transcript to output them exactly once. This also means the "use_partial_results" is now useless, as there will be no difference in accuracy between a non-partial result and and of its stable items that might have been pushed from previous partial versions of the result. The property is removed, instead a new option is exposed to let users control how fast results should stabilize. This greatly simplifies the code, and also improves the output as punctuation doesn't need to be randomly discarded anymore.
This commit is contained in:
parent
d6f6f1a777
commit
640ce43fee
2 changed files with 130 additions and 173 deletions
|
@ -19,7 +19,7 @@ use gst::glib;
|
|||
use gst::prelude::*;
|
||||
use gst::subclass::prelude::*;
|
||||
use gst::{
|
||||
element_error, error_msg, gst_debug, gst_error, gst_info, gst_log, gst_warning, loggable_error,
|
||||
element_error, error_msg, gst_debug, gst_error, gst_info, gst_log, gst_trace, loggable_error,
|
||||
};
|
||||
|
||||
use std::default::Default;
|
||||
|
@ -44,11 +44,13 @@ use atomic_refcell::AtomicRefCell;
|
|||
|
||||
use super::packet::*;
|
||||
|
||||
use serde_derive::Deserialize;
|
||||
use serde_derive::{Deserialize, Serialize};
|
||||
|
||||
use once_cell::sync::Lazy;
|
||||
|
||||
#[derive(Deserialize, Debug)]
|
||||
use super::AwsTranscriberResultStability;
|
||||
|
||||
#[derive(Deserialize, Serialize, Debug)]
|
||||
#[serde(rename_all = "PascalCase")]
|
||||
struct TranscriptItem {
|
||||
content: String,
|
||||
|
@ -56,16 +58,17 @@ struct TranscriptItem {
|
|||
start_time: f32,
|
||||
#[serde(rename = "Type")]
|
||||
type_: String,
|
||||
stable: bool,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Debug)]
|
||||
#[derive(Deserialize, Serialize, Debug)]
|
||||
#[serde(rename_all = "PascalCase")]
|
||||
struct TranscriptAlternative {
|
||||
items: Vec<TranscriptItem>,
|
||||
transcript: String,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Debug)]
|
||||
#[derive(Deserialize, Serialize, Debug)]
|
||||
#[serde(rename_all = "PascalCase")]
|
||||
struct TranscriptResult {
|
||||
alternatives: Vec<TranscriptAlternative>,
|
||||
|
@ -110,16 +113,16 @@ static RUNTIME: Lazy<runtime::Runtime> = Lazy::new(|| {
|
|||
});
|
||||
|
||||
const DEFAULT_LATENCY: gst::ClockTime = gst::ClockTime::from_seconds(8);
|
||||
const DEFAULT_USE_PARTIAL_RESULTS: bool = true;
|
||||
const DEFAULT_STABILITY: AwsTranscriberResultStability = AwsTranscriberResultStability::Low;
|
||||
const GRANULARITY: gst::ClockTime = gst::ClockTime::from_mseconds(100);
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct Settings {
|
||||
latency: gst::ClockTime,
|
||||
language_code: Option<String>,
|
||||
use_partial_results: bool,
|
||||
vocabulary: Option<String>,
|
||||
session_id: Option<String>,
|
||||
results_stability: AwsTranscriberResultStability,
|
||||
}
|
||||
|
||||
impl Default for Settings {
|
||||
|
@ -127,9 +130,9 @@ impl Default for Settings {
|
|||
Self {
|
||||
latency: DEFAULT_LATENCY,
|
||||
language_code: Some("en-US".to_string()),
|
||||
use_partial_results: DEFAULT_USE_PARTIAL_RESULTS,
|
||||
vocabulary: None,
|
||||
session_id: None,
|
||||
results_stability: DEFAULT_STABILITY,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -145,8 +148,7 @@ struct State {
|
|||
buffers: VecDeque<gst::Buffer>,
|
||||
send_eos: bool,
|
||||
discont: bool,
|
||||
last_partial_end_time: Option<gst::ClockTime>,
|
||||
partial_alternative: Option<TranscriptAlternative>,
|
||||
partial_index: usize,
|
||||
}
|
||||
|
||||
impl Default for State {
|
||||
|
@ -162,8 +164,7 @@ impl Default for State {
|
|||
buffers: VecDeque::new(),
|
||||
send_eos: false,
|
||||
discont: true,
|
||||
last_partial_end_time: None,
|
||||
partial_alternative: None,
|
||||
partial_index: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -207,15 +208,12 @@ impl Transcriber {
|
|||
|
||||
let (latency, now, mut last_position, send_eos, seqnum) = {
|
||||
let mut state = self.state.lock().unwrap();
|
||||
|
||||
// Multiply GRANULARITY by 2 in order to not send buffers that
|
||||
// are less than GRANULARITY away too late
|
||||
let latency = self.settings.lock().unwrap().latency - 2 * GRANULARITY;
|
||||
let now = element.current_running_time();
|
||||
|
||||
if let Some(alternative) = state.partial_alternative.take() {
|
||||
self.enqueue(element, &mut state, &alternative, true, latency, now);
|
||||
state.partial_alternative = Some(alternative);
|
||||
}
|
||||
let send_eos = state.send_eos && state.buffers.is_empty();
|
||||
|
||||
while let Some(buf) = state.buffers.front() {
|
||||
|
@ -261,7 +259,7 @@ impl Transcriber {
|
|||
.duration(delta)
|
||||
.seqnum(seqnum)
|
||||
.build();
|
||||
gst_debug!(
|
||||
gst_log!(
|
||||
CAT,
|
||||
"Pushing gap: {} -> {}",
|
||||
last_pos,
|
||||
|
@ -279,7 +277,7 @@ impl Transcriber {
|
|||
let buf = buf.get_mut().unwrap();
|
||||
buf.set_pts(buf.pts());
|
||||
}
|
||||
gst_debug!(
|
||||
gst_log!(
|
||||
CAT,
|
||||
"Pushing buffer: {} -> {}",
|
||||
buf.pts().display(),
|
||||
|
@ -306,7 +304,7 @@ impl Transcriber {
|
|||
.seqnum(seqnum)
|
||||
.build();
|
||||
let next_position = last_pos + duration;
|
||||
gst_debug!(CAT, "Pushing gap: {} -> {}", last_pos, next_position,);
|
||||
gst_log!(CAT, "Pushing gap: {} -> {}", last_pos, next_position,);
|
||||
last_position = Some(next_position);
|
||||
if !self.srcpad.push_event(gap_event) {
|
||||
return false;
|
||||
|
@ -328,68 +326,62 @@ impl Transcriber {
|
|||
state: &mut State,
|
||||
alternative: &TranscriptAlternative,
|
||||
partial: bool,
|
||||
latency: gst::ClockTime,
|
||||
now: impl Into<Option<gst::ClockTime>> + Copy,
|
||||
) {
|
||||
for item in &alternative.items {
|
||||
for item in &alternative.items[state.partial_index..] {
|
||||
let mut start_time =
|
||||
gst::ClockTime::from_nseconds((item.start_time as f64 * 1_000_000_000.0) as u64);
|
||||
let mut end_time =
|
||||
gst::ClockTime::from_nseconds((item.end_time as f64 * 1_000_000_000.0) as u64);
|
||||
|
||||
if state
|
||||
.last_partial_end_time
|
||||
.map_or(false, |last_partial_end_time| {
|
||||
start_time <= last_partial_end_time
|
||||
})
|
||||
{
|
||||
/* Already sent (hopefully) */
|
||||
continue;
|
||||
} else if !partial || now.into().map_or(false, |now| start_time + latency < now) {
|
||||
/* Should be sent now */
|
||||
gst_debug!(CAT, obj: element, "Item is ready: {}", item.content);
|
||||
let mut buf = gst::Buffer::from_mut_slice(item.content.clone().into_bytes());
|
||||
state.last_partial_end_time = Some(end_time);
|
||||
|
||||
{
|
||||
let buf = buf.get_mut().unwrap();
|
||||
|
||||
if state.discont {
|
||||
buf.set_flags(gst::BufferFlags::DISCONT);
|
||||
state.discont = false;
|
||||
}
|
||||
|
||||
if state
|
||||
.out_segment
|
||||
.position()
|
||||
.map_or(false, |pos| start_time < pos)
|
||||
{
|
||||
let pos = state
|
||||
.out_segment
|
||||
.position()
|
||||
.expect("position checked above");
|
||||
gst_debug!(
|
||||
CAT,
|
||||
obj: element,
|
||||
"Adjusting item timing({} < {})",
|
||||
start_time,
|
||||
pos,
|
||||
);
|
||||
start_time = pos;
|
||||
if end_time < start_time {
|
||||
end_time = start_time;
|
||||
}
|
||||
}
|
||||
|
||||
buf.set_pts(start_time);
|
||||
buf.set_duration(end_time - start_time);
|
||||
}
|
||||
|
||||
state.buffers.push_back(buf);
|
||||
} else {
|
||||
/* Doesn't need to be sent yet */
|
||||
if !item.stable {
|
||||
break;
|
||||
}
|
||||
|
||||
/* Should be sent now */
|
||||
gst_debug!(CAT, obj: element, "Item is ready: {}", item.content);
|
||||
let mut buf = gst::Buffer::from_mut_slice(item.content.clone().into_bytes());
|
||||
|
||||
{
|
||||
let buf = buf.get_mut().unwrap();
|
||||
|
||||
if state.discont {
|
||||
buf.set_flags(gst::BufferFlags::DISCONT);
|
||||
state.discont = false;
|
||||
}
|
||||
|
||||
if state
|
||||
.out_segment
|
||||
.position()
|
||||
.map_or(false, |pos| start_time < pos)
|
||||
{
|
||||
let pos = state
|
||||
.out_segment
|
||||
.position()
|
||||
.expect("position checked above");
|
||||
gst_debug!(
|
||||
CAT,
|
||||
obj: element,
|
||||
"Adjusting item timing({} < {})",
|
||||
start_time,
|
||||
pos,
|
||||
);
|
||||
start_time = pos;
|
||||
if end_time < start_time {
|
||||
end_time = start_time;
|
||||
}
|
||||
}
|
||||
|
||||
buf.set_pts(start_time);
|
||||
buf.set_duration(end_time - start_time);
|
||||
}
|
||||
|
||||
state.partial_index += 1;
|
||||
|
||||
state.buffers.push_back(buf);
|
||||
}
|
||||
|
||||
if !partial {
|
||||
state.partial_index = 0;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -448,94 +440,28 @@ impl Transcriber {
|
|||
));
|
||||
}
|
||||
|
||||
let mut transcript: Transcript =
|
||||
serde_json::from_str(&payload).map_err(|err| {
|
||||
error_msg!(
|
||||
gst::StreamError::Failed,
|
||||
["Unexpected binary message: {} ({})", payload, err]
|
||||
)
|
||||
})?;
|
||||
let transcript: Transcript = serde_json::from_str(&payload).map_err(|err| {
|
||||
error_msg!(
|
||||
gst::StreamError::Failed,
|
||||
["Unexpected binary message: {} ({})", payload, err]
|
||||
)
|
||||
})?;
|
||||
|
||||
if !transcript.transcript.results.is_empty() {
|
||||
let mut result = transcript.transcript.results.remove(0);
|
||||
let use_partial_results = self.settings.lock().unwrap().use_partial_results;
|
||||
if !result.is_partial && !result.alternatives.is_empty() {
|
||||
let alternative = result.alternatives.remove(0);
|
||||
if !use_partial_results {
|
||||
gst_info!(
|
||||
CAT,
|
||||
obj: element,
|
||||
"Transcript: {}",
|
||||
alternative.transcript
|
||||
);
|
||||
if let Some(result) = transcript.transcript.results.get(0) {
|
||||
gst_trace!(
|
||||
CAT,
|
||||
obj: element,
|
||||
"result: {}",
|
||||
serde_json::to_string_pretty(&result).unwrap(),
|
||||
);
|
||||
|
||||
let mut start_time = gst::ClockTime::from_nseconds(
|
||||
(result.start_time as f64 * 1_000_000_000.0) as u64,
|
||||
);
|
||||
let end_time = gst::ClockTime::from_nseconds(
|
||||
(result.end_time as f64 * 1_000_000_000.0) as u64,
|
||||
);
|
||||
|
||||
let mut state = self.state.lock().unwrap();
|
||||
let position = state.out_segment.position();
|
||||
|
||||
if position.map_or(false, |position| end_time < position) {
|
||||
let pos = position.expect("position checked above");
|
||||
gst_warning!(CAT, obj: element,
|
||||
"Received transcript is too late by {}, dropping, consider increasing the latency",
|
||||
pos - start_time);
|
||||
} else {
|
||||
if let Some(delta) =
|
||||
position.and_then(|pos| pos.checked_sub(start_time))
|
||||
{
|
||||
gst_warning!(CAT, obj: element,
|
||||
"Received transcript is too late by {}, clipping, consider increasing the latency",
|
||||
delta);
|
||||
start_time = position.expect("position checked above");
|
||||
}
|
||||
|
||||
let mut buf = gst::Buffer::from_mut_slice(
|
||||
alternative.transcript.into_bytes(),
|
||||
);
|
||||
|
||||
{
|
||||
let buf = buf.get_mut().unwrap();
|
||||
|
||||
if state.discont {
|
||||
buf.set_flags(gst::BufferFlags::DISCONT);
|
||||
state.discont = false;
|
||||
}
|
||||
|
||||
buf.set_pts(start_time);
|
||||
buf.set_duration(end_time - start_time);
|
||||
}
|
||||
|
||||
gst_debug!(
|
||||
CAT,
|
||||
obj: element,
|
||||
"Adding pending buffer: {:?}",
|
||||
buf
|
||||
);
|
||||
|
||||
state.buffers.push_back(buf);
|
||||
}
|
||||
} else {
|
||||
let mut state = self.state.lock().unwrap();
|
||||
self.enqueue(
|
||||
element,
|
||||
&mut state,
|
||||
&alternative,
|
||||
false,
|
||||
gst::ClockTime::ZERO,
|
||||
gst::ClockTime::ZERO,
|
||||
);
|
||||
state.partial_alternative = None;
|
||||
}
|
||||
} else if !result.alternatives.is_empty() && use_partial_results {
|
||||
if let Some(alternative) = result.alternatives.get(0) {
|
||||
let mut state = self.state.lock().unwrap();
|
||||
state.partial_alternative = Some(result.alternatives.remove(0));
|
||||
|
||||
self.enqueue(element, &mut state, alternative, result.is_partial)
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
@ -673,7 +599,7 @@ impl Transcriber {
|
|||
fn sink_event(&self, pad: &gst::Pad, element: &super::Transcriber, event: gst::Event) -> bool {
|
||||
use gst::EventView;
|
||||
|
||||
gst_debug!(CAT, obj: pad, "Handling event {:?}", event);
|
||||
gst_log!(CAT, obj: pad, "Handling event {:?}", event);
|
||||
|
||||
match event.view() {
|
||||
EventView::Eos(_) => match self.handle_buffer(pad, element, None) {
|
||||
|
@ -811,7 +737,7 @@ impl Transcriber {
|
|||
element: &super::Transcriber,
|
||||
buffer: Option<gst::Buffer>,
|
||||
) -> Result<gst::FlowSuccess, gst::FlowError> {
|
||||
gst_debug!(CAT, obj: element, "Handling {:?}", buffer);
|
||||
gst_log!(CAT, obj: element, "Handling {:?}", buffer);
|
||||
|
||||
self.ensure_connection(element).map_err(|err| {
|
||||
element_error!(
|
||||
|
@ -902,6 +828,16 @@ impl Transcriber {
|
|||
signed.add_param("session-id", session_id);
|
||||
}
|
||||
|
||||
signed.add_param("enable-partial-results-stabilization", "true");
|
||||
signed.add_param(
|
||||
"partial-results-stability",
|
||||
match settings.results_stability {
|
||||
AwsTranscriberResultStability::High => "high",
|
||||
AwsTranscriberResultStability::Medium => "medium",
|
||||
AwsTranscriberResultStability::Low => "low",
|
||||
},
|
||||
);
|
||||
|
||||
let url = signed.generate_presigned_url(&creds, &std::time::Duration::from_secs(60), true);
|
||||
|
||||
let (ws, _) = {
|
||||
|
@ -1060,13 +996,6 @@ impl ObjectImpl for Transcriber {
|
|||
Some("en-US"),
|
||||
glib::ParamFlags::READWRITE | gst::PARAM_FLAG_MUTABLE_READY,
|
||||
),
|
||||
glib::ParamSpec::new_boolean(
|
||||
"use-partial-results",
|
||||
"Latency",
|
||||
"Whether partial results from AWS should be used",
|
||||
DEFAULT_USE_PARTIAL_RESULTS,
|
||||
glib::ParamFlags::READWRITE | gst::PARAM_FLAG_MUTABLE_PLAYING,
|
||||
),
|
||||
glib::ParamSpec::new_uint(
|
||||
"latency",
|
||||
"Latency",
|
||||
|
@ -1092,6 +1021,14 @@ impl ObjectImpl for Transcriber {
|
|||
None,
|
||||
glib::ParamFlags::READWRITE | gst::PARAM_FLAG_MUTABLE_READY,
|
||||
),
|
||||
glib::ParamSpec::new_enum(
|
||||
"results-stability",
|
||||
"Results stability",
|
||||
"Defines how fast results should stabilize",
|
||||
AwsTranscriberResultStability::static_type(),
|
||||
DEFAULT_STABILITY as i32,
|
||||
glib::ParamFlags::READWRITE | gst::PARAM_FLAG_MUTABLE_READY,
|
||||
),
|
||||
]
|
||||
});
|
||||
|
||||
|
@ -1124,10 +1061,6 @@ impl ObjectImpl for Transcriber {
|
|||
value.get::<u32>().expect("type checked upstream").into(),
|
||||
);
|
||||
}
|
||||
"use-partial-results" => {
|
||||
let mut settings = self.settings.lock().unwrap();
|
||||
settings.use_partial_results = value.get().expect("type checked upstream");
|
||||
}
|
||||
"vocabulary-name" => {
|
||||
let mut settings = self.settings.lock().unwrap();
|
||||
settings.vocabulary = value.get().expect("type checked upstream");
|
||||
|
@ -1136,6 +1069,12 @@ impl ObjectImpl for Transcriber {
|
|||
let mut settings = self.settings.lock().unwrap();
|
||||
settings.session_id = value.get().expect("type checked upstream");
|
||||
}
|
||||
"results-stability" => {
|
||||
let mut settings = self.settings.lock().unwrap();
|
||||
settings.results_stability = value
|
||||
.get::<AwsTranscriberResultStability>()
|
||||
.expect("type checked upstream");
|
||||
}
|
||||
_ => unimplemented!(),
|
||||
}
|
||||
}
|
||||
|
@ -1150,10 +1089,6 @@ impl ObjectImpl for Transcriber {
|
|||
let settings = self.settings.lock().unwrap();
|
||||
(settings.latency.mseconds() as u32).to_value()
|
||||
}
|
||||
"use-partial-results" => {
|
||||
let settings = self.settings.lock().unwrap();
|
||||
settings.use_partial_results.to_value()
|
||||
}
|
||||
"vocabulary-name" => {
|
||||
let settings = self.settings.lock().unwrap();
|
||||
settings.vocabulary.to_value()
|
||||
|
@ -1162,6 +1097,10 @@ impl ObjectImpl for Transcriber {
|
|||
let settings = self.settings.lock().unwrap();
|
||||
settings.session_id.to_value()
|
||||
}
|
||||
"results-stability" => {
|
||||
let settings = self.settings.lock().unwrap();
|
||||
settings.results_stability.to_value()
|
||||
}
|
||||
_ => unimplemented!(),
|
||||
}
|
||||
}
|
||||
|
|
|
@ -21,6 +21,24 @@ use gst::prelude::*;
|
|||
mod imp;
|
||||
mod packet;
|
||||
|
||||
#[derive(Debug, Eq, PartialEq, Ord, PartialOrd, Hash, Clone, Copy, glib::GEnum)]
|
||||
#[repr(u32)]
|
||||
#[genum(type_name = "GstAwsTranscriberResultStability")]
|
||||
pub enum AwsTranscriberResultStability {
|
||||
#[genum(name = "High: stabilize results as fast as possible", nick = "high")]
|
||||
High = 0,
|
||||
#[genum(
|
||||
name = "Medium: balance between stability and accuracy",
|
||||
nick = "medium"
|
||||
)]
|
||||
Medium = 1,
|
||||
#[genum(
|
||||
name = "Low: relatively less stable partial transcription results with higher accuracy",
|
||||
nick = "low"
|
||||
)]
|
||||
Low = 2,
|
||||
}
|
||||
|
||||
glib::wrapper! {
|
||||
pub struct Transcriber(ObjectSubclass<imp::Transcriber>) @extends gst::Element, gst::Object;
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue